# Copyright 2024, Seiko Epson Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the “Software”), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
# OTHER DEALINGS IN THE SOFTWARE.

from dataclasses import dataclass, field
from enum import Enum

import numpy as np
import numpy.typing as npt

from vc_calc.vc_constants import VC_JUDGE_TBL, VC_LEVEL_TBL
from vc_calc.vc_data import DataType, MeasureData, OctBand, VcCalcResult


class VcCalc:
    """
    VC Calculation Class
    Performs VC calculations at specified sampling intervals and outputs the results.

    Samples are fed one at a time through ``calc``. The first calculation runs once
    ``fft_size`` samples have been buffered. After that, one calculation runs every
    ``exec_size`` samples on the most recent ``fft_size`` samples (sliding window).
    """
    # =======================================================================================================
    # Public Methods
    # =======================================================================================================
    def __init__(
        self,
        fft_size: int = 10000,
        exec_size: int = 1000,
        avr_size: int = 30,
        oct_band: Enum = OctBand.Oct_1_3,
        data_type: Enum = DataType.Velocity,
        data_rate: int = 1000,
    ):
        """
        Constructor
        - input
            fft_size: Number of FFT calculation points
                      (default = 10000, range = 2000 to 40000, must be divisible by 2 and by exec_size)
            exec_size: Number of input (acceleration) samples between two calculations
                       (default = 1000, range = 1000 to fft_size)
            avr_size: Number of single results averaged into one average result (default = 30, range = 10 to 50)
            oct_band: octave band units for frequency domain data compression
                      (OctBand.Oct_1_3, Oct_1_6 or Oct_1_12; any other value yields an empty band table)
            data_type: output data type (acceleration or velocity)
            data_rate: input data rate in sps; only 1000 is accepted
        - raises
            ValueError if any parameter value is out of specification.
        """
        # Parameter check
        # -- Range of fft_size (2000 to 40000)
        if fft_size < 2000 or fft_size > 40000:
            raise ValueError("fft_size should be in the range 2000 to 40000")
        # -- fft_size must be even
        if fft_size % 2 != 0:
            raise ValueError("fft_size should be divisible by 2")
        # -- Range of exec_size (1000 to fft_size)
        if exec_size < 1000 or exec_size > fft_size:
            raise ValueError("exec_size should be in the range 1000 to fft_size")
        # -- fft_size must be divisible by exec_size
        if fft_size % exec_size != 0:
            raise ValueError("fft_size should be divisible by exec_size")
        # -- Range of avr_size (10 to 50)
        if avr_size < 10 or avr_size > 50:
            raise ValueError("avr_size should be in the range 10 to 50")
        # -- Range of data_rate (1000)
        if data_rate != 1000:
            raise ValueError("data_rate should be 1000")

        # Class scope variables
        self.fft_size: int = fft_size  # FFT size
        self.exec_size: int = exec_size  # Calculation unit
        self.avr_size: int = avr_size  # Number of VC result averages
        self.data_rate: int = data_rate  # data rate
        self.oct_band: Enum = oct_band  # octave band units
        self.data_type: Enum = data_type  # output data type

        self.acc_buf = Acc_buf()  # Acceleration buffer
        self.vc_sum = Acc_buf()  # VC sum buffer
        self.vc_result_sgl = VcCalcResult()  # VC calculation result (single)
        self.vc_result_avr = VcCalcResult()  # VC calculation result (average)
        self.avr_count: int = 0  # Number of VC result average processes
        self.buff_full: bool = False  # Buffer full for FFT calculation
        self.buff_1st: bool = False  # True once the buffer has been filled for the first time

        # make octave band frequency table
        self.octave_band_tbl: npt.NDArray[np.float64] = self._make_oct_tbl()
        self.oct_tbl_size = len(self.octave_band_tbl)
        # make fft position table
        self.fft_position_table: npt.NDArray[np.float64] = self._make_fft_position_table()
        # Hanning window
        self.hanning: npt.NDArray[np.float64] = np.hanning(self.fft_size)

        # Allocate memory area for VC sum buffer (one entry per octave band)
        self._clear_vc_sum()

        # Allocate memory area for VC result lists (one entry per octave band)
        self.vc_result_sgl.c = [0] * self.oct_tbl_size
        self.vc_result_sgl.x = [0] * self.oct_tbl_size
        self.vc_result_sgl.y = [0] * self.oct_tbl_size
        self.vc_result_sgl.z = [0] * self.oct_tbl_size
        self.vc_result_avr.c = [0] * self.oct_tbl_size
        self.vc_result_avr.x = [0] * self.oct_tbl_size
        self.vc_result_avr.y = [0] * self.oct_tbl_size
        self.vc_result_avr.z = [0] * self.oct_tbl_size

    # =======================================================================================================
    def calc(self, measData: MeasureData) -> tuple[VcCalcResult, VcCalcResult]:
        """
        - summary: VC Calculation
            Takes one sensor sample per call. The first VC calculation runs once fft_size
            samples have been collected. After that, one runs every exec_size samples on
            the latest fft_size samples.
            The VC judgement uses the 1/3-octave judgement table. For 1/6 and 1/12 octave
            bands, the band index is mapped onto a 1/3-octave row by integer division
            (by 2 and by 4 respectively).
            The VC judgement will not produce correct results unless the following condition is met:
                - The sampling rate of the input acceleration is 1000 sps
        - input: MeasureData (one sample; its x, y and z attributes are buffered)
        - output: VcCalcResult (single), VcCalcResult (average)
            Each result's `done` flag is True only on calls where that result was (re)computed.
            Both are this instance's own result objects and are overwritten on later calls;
            copy them if a snapshot is needed.
        """
        self.vc_result_sgl.done = False  # Mark calculation as incomplete
        self.vc_result_avr.done = False  # Mark calculation as incomplete

        # Save measurement values to the calculation buffer.
        # After the first fill, drop the oldest exec_size samples whenever the buffer is
        # full, so the window slides by exec_size samples per calculation.
        if self.buff_1st and len(self.acc_buf.x) == self.fft_size:
            self.acc_buf.x = self.acc_buf.x[self.exec_size :]
            self.acc_buf.y = self.acc_buf.y[self.exec_size :]
            self.acc_buf.z = self.acc_buf.z[self.exec_size :]

        self.acc_buf.x.append(measData.x)
        self.acc_buf.y.append(measData.y)
        self.acc_buf.z.append(measData.z)

        if len(self.acc_buf.x) == self.fft_size:
            self.buff_1st = True
            self.buff_full = True

        # Nothing to compute until the window is full again
        if not self.buff_full:
            return self.vc_result_sgl, self.vc_result_avr
        self.buff_full = False

        # VC calculation (one value per octave band for each axis)
        vc_x = self._calc_sub_one_axis(self.acc_buf.x)
        vc_y = self._calc_sub_one_axis(self.acc_buf.y)
        vc_z = self._calc_sub_one_axis(self.acc_buf.z)

        # Single mode stores this result as is (plus the 3-axis composite);
        # average mode accumulates it into vc_sum.
        for cnt, (vx, vy, vz) in enumerate(zip(vc_x, vc_y, vc_z)):
            self.vc_result_sgl.x[cnt] = vx
            self.vc_result_sgl.y[cnt] = vy
            self.vc_result_sgl.z[cnt] = vz
            self.vc_result_sgl.c[cnt] = np.sqrt(vx**2 + vy**2 + vz**2)

            self.vc_sum.x[cnt] = self.vc_sum.x[cnt] + vx
            self.vc_sum.y[cnt] = self.vc_sum.y[cnt] + vy
            self.vc_sum.z[cnt] = self.vc_sum.z[cnt] + vz

        # Average processing for average mode
        self.avr_count = self.avr_count + 1
        self.vc_result_avr.avr_num = self.avr_count
        # -- Average once avr_size single results have been accumulated
        if self.avr_count == self.avr_size:
            for cnt in range(len(vc_x)):
                x: float = self.vc_sum.x[cnt] / self.avr_size
                y: float = self.vc_sum.y[cnt] / self.avr_size
                z: float = self.vc_sum.z[cnt] / self.avr_size

                self.vc_result_avr.x[cnt] = x
                self.vc_result_avr.y[cnt] = y
                self.vc_result_avr.z[cnt] = z
                self.vc_result_avr.c[cnt] = np.sqrt(x**2 + y**2 + z**2)

            # Clear buffer and counter
            self._clear_vc_sum()
            self.avr_count = 0

            # execute vc judgement process in average mode
            (
                self.vc_result_avr.max_f,
                self.vc_result_avr.max_v,
                self.vc_result_avr.vc_lvl,
            ) = self._vc_analysis(self.vc_result_avr.c)

            # Average mode calculation complete
            self.vc_result_avr.done = True

        # execute vc judgement process in single mode
        (
            self.vc_result_sgl.max_f,
            self.vc_result_sgl.max_v,
            self.vc_result_sgl.vc_lvl,
        ) = self._vc_analysis(self.vc_result_sgl.c)

        # Single mode calculation complete
        self.vc_result_sgl.done = True

        return self.vc_result_sgl, self.vc_result_avr

    # =======================================================================================================
    # Local methods
    # =======================================================================================================
    def _clear_vc_sum(self) -> None:
        """
        - summary: Reset the per-axis VC accumulation buffers to zero (one entry per octave band)
        """
        self.vc_sum.x = [0] * self.oct_tbl_size
        self.vc_sum.y = [0] * self.oct_tbl_size
        self.vc_sum.z = [0] * self.oct_tbl_size

    # =======================================================================================================
    @staticmethod
    def _acc_to_velocity(acc: float, freq: float) -> float:
        """
        - summary: Convert an acceleration value in G to a velocity value in mm/s at frequency freq [Hz]
            1. Convert G to Gal (cm/s^2) (1[G] = 980.665[Gal] (cm/s^2))
            2. Convert acceleration[Gal] to velocity[cm/s] (1[cm/s] = 1[cm/s^2] /2πf[1/s])
            3. Convert unit (1[mm/s] = 1[cm/s] *10)
        """
        return acc * 980.665 / (2 * np.pi * freq) * 10

    # =======================================================================================================
    def _calc_sub_one_axis(self, acc_list: list[float]) -> list[float]:
        """
        - summary: Sub-method of calc, performs calculations for one axis
        - input: List of acceleration data to perform FFT (fft_size samples)
        - output: One value per octave band, in octave_band_tbl order.
            Each value is the square root of the summed (scaled) power spectrum over that band.
            When data_type is Velocity it is converted from G to mm/s at the band's center frequency.
        """
        vc_result: list[float] = []

        # 3. Calculate average bias
        bias: np.float64 = np.mean(acc_list)

        # 4. Subtract average bias and apply Hanning window
        fft_in: npt.NDArray[np.float64] = (np.array(acc_list) - bias) * self.hanning

        # 5. Perform FFT power spectrum calculation
        spectrum: npt.NDArray[np.complex128] = np.fft.fft(fft_in)
        half: int = len(spectrum) // 2
        spectrum = spectrum[0:half]  # Exclude the folded (mirrored) part
        # Scale by the number of one-sided bins (fft_size / 2), not by fft_size
        spectrum = spectrum / half
        power: npt.NDArray[np.float64] = np.real(spectrum) ** 2 + np.imag(spectrum) ** 2

        # 6. Adjust FFT using Parseval's theorem
        #  - Adjust to match the total energy in the time domain (1σ) with the total energy in the frequency domain (sum of the spectrum) *1/2
        #  - Multiply by the attenuation of the Hanning window *8/3 ... 8/3 * 1/2 = 4/3
        power = power * 4 / 3

        # 7. Calculate acceleration or velocity in octave band
        for cf, first_bin, last_bin in self.fft_position_table:
            # Sum the power from the lower to the upper bin of the band (inclusive)
            work: float = sum(power[int(first_bin) : int(last_bin) + 1])

            # Square root of the sum of values (power spectrum to amplitude power spectrum)
            work = np.sqrt(work)

            # if output data type is Velocity, convert G to velocity(mm/s)
            if self.data_type.value == DataType.Velocity.value:
                work = self._acc_to_velocity(work, cf)

            vc_result.append(work)

        return vc_result

    # =======================================================================================================
    def _vc_analysis(self, vc_in: list[float]) -> tuple[float, float, str]:
        """
        - summary: Calculate the maximum velocity, the center frequency of the maximum velocity, and the judgment level from the VC results
        - input: VC results, one value per octave band in octave_band_tbl order
            (converted from G to mm/s first when data_type is Acceleration)
        - output: (center frequency of the maximum velocity, maximum velocity, judgment level)
            An empty input returns (0, 0, "").
        """
        max_v: float = 0
        max_f: float = 0
        lowest_level: int = len(VC_LEVEL_TBL) - 1
        vc_max: int = lowest_level
        vc_judge_tbl_cols: int = VC_JUDGE_TBL.shape[1]

        # Number of bands per 1/3-octave judgement row (1/6 -> 2, 1/12 and any other non-1/3 -> 4)
        if self.oct_band.value == OctBand.Oct_1_3.value:
            bands_per_row = 1
        elif self.oct_band.value == OctBand.Oct_1_6.value:
            bands_per_row = 2
        else:
            bands_per_row = 4

        for cnt, value in enumerate(vc_in):
            # Calculate the maximum velocity and the center frequency of the maximum velocity
            freq: float = self.octave_band_tbl[cnt][0]

            work = value
            # if output data type is acceleration, convert G to velocity(mm/s)
            if self.data_type.value == DataType.Acceleration.value:
                work = self._acc_to_velocity(work, freq)

            if max_v < work:
                max_v = work
                max_f = freq

            # Judgement level of this band: first column whose threshold is exceeded
            judge_row = VC_JUDGE_TBL[cnt // bands_per_row]
            vc_tmp: int = lowest_level
            for col in range(vc_judge_tbl_cols):
                if work > judge_row[col]:
                    vc_tmp = col
                    break
            # Keep the most severe (lowest index) level over all bands
            vc_max = min(vc_max, vc_tmp)

        # Convert VC judgment level
        vc_lvl: str = VC_LEVEL_TBL[vc_max] if len(vc_in) > 0 else ""

        return max_f, max_v, vc_lvl

    # =======================================================================================================
    def _make_oct_tbl(self) -> npt.NDArray[np.float64]:
        """
        - summary: Generate an octave band frequency table
        - output: array of shape (n_bands, 3); row i is [center, lower limit, upper limit] in Hz with
            center = 2 ** (i / n), lower = 2 ** ((i - 0.5) / n), upper = 2 ** ((i + 0.5) / n),
            where n is 3, 6 or 12 bands per octave and n_bands = 9 * n (27, 54 or 108).
            An unrecognized oct_band yields an empty (0, 3) table.
        """
        table_num: int = 0
        div_num: int = 0

        if self.oct_band.value == OctBand.Oct_1_3.value:
            table_num, div_num = 27, 3
        elif self.oct_band.value == OctBand.Oct_1_6.value:
            table_num, div_num = 54, 6
        elif self.oct_band.value == OctBand.Oct_1_12.value:
            table_num, div_num = 108, 12
        # otherwise table_num stays 0, so no division by div_num ever happens

        cent = [2 ** (i / div_num) for i in range(table_num)]
        min_val = [2 ** ((i - 0.5) / div_num) for i in range(table_num)]
        max_val = [2 ** ((i + 0.5) / div_num) for i in range(table_num)]

        return np.array([cent, min_val, max_val]).T

    # =======================================================================================================
    def _make_fft_position_table(self) -> npt.NDArray[np.float64]:
        """
        - summary: Generate a table to extract the FFT data position
        - output: array of shape (n_bands, 3); each row is
            [center frequency, first FFT bin index, last FFT bin index] (indices stored as float64).
            The bins selected are those whose frequency f satisfies lower limit < f <= upper limit.
        """
        # Frequency of each FFT bin from 0 Hz to Nyquist, built by repeated addition
        step: float = self.data_rate / self.fft_size
        freq_step: list[float] = [0]
        for _ in range(self.fft_size // 2):
            freq_step.append(freq_step[-1] + step)
        freq_bins = np.asarray(freq_step, dtype=np.float64)

        center = self.octave_band_tbl[:, 0]
        # side="right" gives the index of the first bin strictly above each limit.
        # A limit at or beyond the last bin maps to len(freq_bins) instead of failing.
        min_idx = np.searchsorted(freq_bins, self.octave_band_tbl[:, 1], side="right")
        max_idx = np.searchsorted(freq_bins, self.octave_band_tbl[:, 2], side="right") - 1

        return np.array([center, min_idx, max_idx], dtype=np.float64).T

#######################################################
# Structure class
@dataclass
class Acc_buf:
    """
    Three-axis sample buffer.

    Holds one list per axis (x, y, z). Every instance starts with its own
    fresh empty lists, so buffers are never shared between instances.
    """
    x: list[float] = field(default_factory=lambda: [])  # X-axis samples
    y: list[float] = field(default_factory=lambda: [])  # Y-axis samples
    z: list[float] = field(default_factory=lambda: [])  # Z-axis samples
