# Copyright 2024, Seiko Epson Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the “Software”), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
# OTHER DEALINGS IN THE SOFTWARE.

import logging
import multiprocessing
import os
from dataclasses import dataclass
from multiprocessing import Queue
from multiprocessing.sharedctypes import Synchronized
from random import random

import pytest
from pytest_mock import MockerFixture

from logger.core import Topic
from logger.measure import ReaderArgs, WriterArgs, reader_job, writer_job
from logger.measure.comm import Comm
from logger.measure.data import MeasureData


@dataclass
class ReaderArgsForTest(ReaderArgs):
    """ReaderArgs stub whose packet parser only checks packet boundaries.

    ``format_packet`` rejects any packet that does not start with 0x80 and end
    with 0x0D. Otherwise it returns a zero-filled MeasureData whose ``count``
    cycles 1, 2, 3, 0, 1, ... (the internal counter modulo 4).
    """

    def __post_init__(self):
        super().__post_init__()
        # Internal sample counter and its wrap-around modulus (mimics A342 data).
        # NOTE(review): the modulus is hard-coded to 4 rather than taken from the
        # ``count_max`` field; keep them in sync if tests change count_max.
        self._count = 1
        self._count_max = 4

    def format_packet(self, index: int, packet: list[int]) -> MeasureData:
        """Validate ``packet`` boundaries and return a dummy record.

        Raises:
            ValueError: if the packet does not begin with 0x80 and end with 0x0D.
                Boundary correction belongs to reader_job, so a misaligned
                packet reaching this point means that correction failed.
        """
        if not (packet[0] == 0x80 and packet[-1] == 0x0D):
            raise ValueError("Invalid packet boundary")

        # Post-increment the counter and report the pre-increment value
        # wrapped into [0, _count_max).
        current = self._count
        self._count = current + 1

        # Only ``index`` and ``count`` matter to the tests; the rest is filler.
        return MeasureData(
            index=index,
            count=current % self._count_max,
            temperature=0,
            x=0,
            y=0,
            z=0,
            flag=0,
        )

    def _gen_invalid_data(self, index: int, count: int) -> MeasureData:
        # Zero-filled record for the given index/count. Presumably overrides the
        # ReaderArgs hook that builds complemented (filler) records — confirm.
        return MeasureData(
            index=index,
            count=count,
            temperature=0,
            x=0,
            y=0,
            z=0,
            flag=0,
        )


@dataclass
class WriterArgsForTest(WriterArgs):
    """WriterArgs stub that names output files ``0.csv``, ``1.csv``, ... in order."""

    def __post_init__(self):
        # NOTE(review): unlike ReaderArgsForTest this does not call
        # super().__post_init__(); confirm WriterArgs needs no post-init setup.
        self._count = 0

    def _get_filename(self) -> str:
        # Timestamp-based names are hard to assert on, so tests use a running
        # sequence number instead.
        seq = self._count
        self._count = seq + 1
        return f"{seq}.csv"


# Number of distinct values of a 16-bit counter (65536). Despite the name this
# is the wrap-around modulus used as count_max, not the largest value (65535).
int_16bit_max = 1 << 16


def _make_reader_args(
    args_cls: type[ReaderArgs] = ReaderArgsForTest, **overrides
) -> ReaderArgs:
    """Build reader arguments from the defaults shared by the reader tests.

    Args:
        args_cls: ReaderArgs (sub)class to instantiate.
        **overrides: field values that replace the defaults below.

    ``start_command``/``end_command`` are fresh lists on every call, so no two
    tests share a mutable default.
    """
    defaults = dict(
        model="TEST",
        serial="serial",
        port="/dev/AAA",
        baud=100,
        record_size=4,
        record_per_sec=2,
        record_begin=0x80,
        record_end=0x0D,
        start_command=[],
        end_command=[],
        count_diff=1,
        count_max=4,
        count_start=1,
        sensor_data_diag=True,
        diag_broken_count=10,
        read_length=Comm.DEFAULT_READ_BYTES,
        timeout=Comm.DEFAULT_TIMEOUT,
    )
    return args_cls(**{**defaults, **overrides})


def _zero_data(index: int, count: int) -> MeasureData:
    """Return a MeasureData with the given index/count and every reading zeroed."""
    return MeasureData(
        index=index, count=count, temperature=0, x=0, y=0, z=0, flag=0
    )


def _stop_reader_after_one_loop(mocker: MockerFixture) -> None:
    """Patch ``Synchronized.value`` to read True once, then False.

    reader_job's ``measuring`` flag is a ``multiprocessing.Value``; this makes
    its read loop exit after a single iteration.
    """
    mocker.patch(
        "multiprocessing.sharedctypes.Synchronized.value",
        side_effect=[True, False],
        new_callable=mocker.PropertyMock,
    )


def _make_writer_args() -> WriterArgsForTest:
    """Writer arguments shared by the writer tests (2 records per file)."""
    return WriterArgsForTest(
        model="TEST",
        serial="serial",
        port="/dev/AAA",
        logger_id="RP1",
        record_per_file=2,
    )


def _run_writer_job(output_dir, batches: list) -> list[str]:
    """Feed ``batches`` plus the ``None`` end marker to writer_job.

    Asserts that writer_job drained the queue and reported no error, and
    returns the sorted names of the files written to ``output_dir``. Both
    queues are closed even if an assertion fails.
    """
    queue: Queue = Queue(-1)
    error_queue: Queue = Queue(-1)
    try:
        for batch in batches:
            queue.put(batch)
        queue.put(None)
        writer_job(
            queue=queue,
            error_queue=error_queue,
            args=_make_writer_args(),
            output_dir=output_dir,
        )

        # Everything, including the end marker, must have been consumed.
        assert queue.empty()
        # No error may have been reported.
        assert error_queue.empty()

        return sorted(os.listdir(output_dir))
    finally:
        queue.close()
        error_queue.close()


class TestJob:
    class TestReaderJob:
        def test_read_job_success(self, mocker, mock_comm):
            _stop_reader_after_one_loop(mocker)
            valid_packet = [0x80, 0, 0, 0x0D]
            # Three well-formed packets back to back.
            mock_comm(send_return=[], read_return=valid_packet * 3)
            queue: Queue = Queue(-1)
            error_queue: Queue = Queue(-1)
            measuring: Synchronized = multiprocessing.Value("b", True)
            args = _make_reader_args(record_size=len(valid_packet))
            reader_job(
                queue=queue, error_queue=error_queue, measuring=measuring, args=args
            )

            assert not queue.empty()
            # Well-formed input must not report any error.
            assert error_queue.empty()

            # The first batch holds one second's worth of records.
            records = queue.get()
            assert len(records) == args.record_per_sec

            # After the loop exits, the remaining record is flushed as a batch.
            records = queue.get()
            assert len(records) == 1

            # None is the end-of-stream marker.
            assert queue.get() is None
            assert queue.empty()

            queue.close()
            error_queue.close()

        def test_read_job_success_with_invalid_packets(self, mocker, mock_comm):
            _stop_reader_after_one_loop(mocker)
            valid_packet = [0x80, 0, 0, 0x0D]
            # Garbage bytes between packets and a truncated packet at the end:
            # reader_job is expected to correct packet boundaries so that only
            # the two valid packets reach format_packet.
            mock_comm(
                send_return=[],
                read_return=[
                    *valid_packet,
                    0,  # garbage start
                    1,
                    2,
                    4,  # garbage end
                    *valid_packet,
                    0x80,  # truncated packet start
                    0,
                ],
            )
            queue: Queue = Queue(-1)
            error_queue: Queue = Queue(-1)
            measuring: Synchronized = multiprocessing.Value("b", True)
            args = _make_reader_args(record_size=len(valid_packet))
            reader_job(
                queue=queue, error_queue=error_queue, measuring=measuring, args=args
            )

            assert not queue.empty()

            # The invalid bytes must not have produced an error.
            assert error_queue.empty()

            # The first batch holds one second's worth of records.
            records = queue.get()
            assert len(records) == args.record_per_sec

            # The truncated trailing packet yields no record, so the final batch
            # flushed after the loop exits is empty.
            records = queue.get()
            assert len(records) == 0

            # None is the end-of-stream marker.
            assert queue.get() is None
            assert queue.empty()

            queue.close()
            error_queue.close()

    class TestWriterJob:
        def test_writer_job_success(self, tmpdir):
            data = [_zero_data(index=0, count=0)]

            # Three single-record batches at 2 records per file -> two files.
            assert _run_writer_job(tmpdir, [data, data, data]) == ["0.csv", "1.csv"]

        def test_writer_job_empty_records_should_not_create_file(self, tmpdir):
            data = [_zero_data(index=0, count=0)]

            # Two single-record batches fill exactly one file; the trailing empty
            # batch must not produce a second file.
            assert _run_writer_job(tmpdir, [data, data, []]) == ["0.csv"]

    class TestReaderJobArgs:
        def test_complement_data_success_when_prev_is_none(self):
            args = _make_reader_args()
            # The first record arrives with count == count_start: nothing to fill.
            actual = args.complement_data(None, _zero_data(index=0, count=1))
            assert actual == [_zero_data(index=0, count=1)]

        def test_complement_data_success_when_no_missing(self):
            args = _make_reader_args()
            # count wraps 3 -> 0 (count_max=4), which is the expected successor.
            actual = args.complement_data(
                _zero_data(index=0, count=3),
                _zero_data(index=1, count=0),
            )
            assert actual == [_zero_data(index=1, count=0)]

        def test_complement_data_success_when_3_data_missing(self):
            args = _make_reader_args()
            # After count 2 the sequence should be 3, 0, 1. Counts 3 and 0 were
            # dropped, so they are filled in ahead of the received record.
            # NOTE(review): despite the test name only 2 records are missing;
            # "3" is the length of the returned list.
            actual = args.complement_data(
                _zero_data(index=0, count=2),
                _zero_data(index=0, count=1),
            )
            assert actual == [
                _zero_data(index=1, count=3),
                _zero_data(index=2, count=0),
                _zero_data(index=3, count=1),
            ]

        def test_complement_data_success_when_first_data_but_missing(self):
            args = _make_reader_args()
            # The first record should carry count_start (1); receiving count 2
            # means one record was lost before it.
            actual = args.complement_data(None, _zero_data(index=0, count=2))
            assert actual == [
                _zero_data(index=0, count=1),
                _zero_data(index=1, count=2),
            ]

        def test_complement_data_success_count_start_is_not_zero(self):
            # NOTE(review): this is currently identical to
            # test_complement_data_success_when_first_data_but_missing
            # (count_start=1 in both); consider giving it a distinct count_start.
            args = _make_reader_args()
            actual = args.complement_data(None, _zero_data(index=0, count=2))
            assert actual == [
                _zero_data(index=0, count=1),
                _zero_data(index=1, count=2),
            ]

        @pytest.mark.parametrize(
            "count_start, count_diff",
            [
                (
                    (rms_pp_interval * 600 + 2) % int_16bit_max,
                    rms_pp_interval * 600,
                )
                for rms_pp_interval in range(1, 256)
            ],
        )
        def test_complement_data_success_count_max(self, count_start, count_diff):
            args = _make_reader_args(
                count_diff=count_diff,
                count_max=int_16bit_max,
                count_start=count_start,
            )
            # The first received record is 9 steps past count_start (mod 2**16).
            received_count = (count_start + count_diff * 9) % int_16bit_max
            actual = args.complement_data(
                None, _zero_data(index=0, count=received_count)
            )

            # 9 missing records are filled in ahead of the received one: 10 total.
            assert len(actual) == 10
            assert actual[0].count == count_start
            assert actual[-1].count == received_count
            # Consecutive counts advance by count_diff, wrapping at 2**16.
            for before, after in zip(actual, actual[1:]):
                assert (before.count + count_diff) % int_16bit_max == after.count

        @staticmethod
        def _new_MeasureData(index: int) -> MeasureData:
            """Record with random x/y/z readings and zeroed count/temperature/flag."""
            return MeasureData(
                index=index,
                x=random(),
                y=random(),
                z=random(),
                count=0,
                temperature=0,
                flag=0,
            )

        @staticmethod
        def _make_diag_args(diag_count: int) -> ReaderArgs:
            """Plain ReaderArgs with zeroed counter fields for diag_sensor_data tests."""
            return _make_reader_args(
                ReaderArgs,
                count_diff=0,
                count_max=0,
                count_start=0,
                diag_broken_count=diag_count,
            )

        @classmethod
        def _feed_stuck_x_then_recover(cls, args: ReaderArgs, diag_count: int) -> None:
            """Feed diag_count + 1 records with x fixed at 0.5, then one random record."""
            prev = None
            for i in range(diag_count + 1):
                data = cls._new_MeasureData(i)
                data.x = 0.5
                args.diag_sensor_data(prev=prev, next=data)
                prev = data

            # x changes again on the next record.
            data = cls._new_MeasureData(diag_count + 1)
            args.diag_sensor_data(prev=prev, next=data)

        def test_diag_sensor_data_with_many_various_data(self, caplog):
            diag_count = 3
            args = self._make_diag_args(diag_count)

            # Feed plenty of records whose readings all differ.
            prev = None
            for i in range(diag_count * 2):
                data = self._new_MeasureData(i)
                args.diag_sensor_data(prev=prev, next=data)
                prev = data

            # Nothing is logged.
            assert len(caplog.records) == 0

        def test_diag_sensor_data_with_many_same_data(self, caplog):
            diag_count = 3
            args = self._make_diag_args(diag_count)

            # Feed the very same record over and over.
            prev = None
            data = self._new_MeasureData(0)
            for _ in range(diag_count * 2):
                args.diag_sensor_data(prev=prev, next=data)
                prev = data

            # Exactly three entries, all at WARNING level ...
            assert len(caplog.records) == 3
            assert all(rec.levelname == "WARNING" for rec in caplog.records)

            # ... one per axis.
            for axis in ["x", "y", "z"]:
                assert f"Sensor on axis: {axis} is possibly broken" in caplog.messages

        def test_diag_sensor_data_with_less_same_data(self, caplog):
            diag_count = 3
            args = self._make_diag_args(diag_count)

            # Exactly diag_count records with the same x value.
            prev = None
            for i in range(diag_count):
                data = self._new_MeasureData(i)
                data.x = 0.5
                args.diag_sensor_data(prev=prev, next=data)
                prev = data

            # The first record does not count as a repeat, so no message yet.
            assert len(caplog.records) == 0

        def test_diag_sensor_data_with_more_same_data_and_recover(self, caplog):
            diag_count = 3
            args = self._make_diag_args(diag_count)

            # Capture INFO as well, so the recovery message is recorded.
            caplog.set_level(logging.INFO)

            self._feed_stuck_x_then_recover(args, diag_count)

            # Two entries: WARNING when x gets stuck, INFO when it recovers.
            assert len(caplog.records) == 2
            assert caplog.records[0].levelname == "WARNING"
            assert caplog.records[1].levelname == "INFO"
            assert caplog.records[0].message == "Sensor on axis: x is possibly broken"
            assert caplog.records[1].message == "Sensor on axis: x is fixed"

        def test_send_message_when_lost_found(self, mocker: MockerFixture):
            # Same input as test_complement_data_success_when_3_data_missing.
            args = _make_reader_args()

            # Actually publishing a message is impractical in a unit test.
            mock = mocker.patch("logger.measure.job.MessageService.send")

            args.complement_data(
                _zero_data(index=0, count=2),
                _zero_data(index=0, count=1),
            )

            # Sent exactly once, on the `lost` topic, with the logged message.
            mock.assert_called_once_with(
                Topic.sensor_lost("TEST", "serial"),
                "Missing 2 data from index: 1. Complement them.",
            )

        def test_send_message_when_diag_error_and_recover(self, mocker: MockerFixture):
            # Same input as test_diag_sensor_data_with_more_same_data_and_recover.
            diag_count = 3
            args = self._make_diag_args(diag_count)

            # Actually publishing a message is impractical in a unit test.
            mock = mocker.patch("logger.measure.job.MessageService.send")

            self._feed_stuck_x_then_recover(args, diag_count)

            # One message when x gets stuck, one when it recovers.
            assert mock.call_count == 2
            assert mock.call_args_list[0].args == (
                Topic.sensor_abnormal("TEST", "serial"),
                "Sensor on axis: x is possibly broken",
            )
            assert mock.call_args_list[1].args == (
                Topic.sensor_abnormal("TEST", "serial"),
                "Sensor on axis: x is fixed",
            )
