import {re, im, complex, exp, multiply, abs} from 'mathjs';
import {fft, ifft} from '@signalprocessing/transforms';
import {pipe} from 'rxjs';
import {map} from 'rxjs/operators';

/**
 * Generates an array of logarithmically spaced frequencies from
 * `frequencyMin` to `frequencyMax` (both endpoints included).
 * @method createLogSpacedArray
 * @example createLogSpacedArray(1, 32, 30)
 * @param {number} frequencyMin Minimum frequency in Hz (always the first element)
 * @param {number} frequencyMax Maximum frequency in Hz (the last element when frequencyNumber > 1)
 * @param {number} frequencyNumber Number of frequencies to generate
 *
 * @returns {Array<number>} The log-spaced frequencies, starting at frequencyMin
 */
const createLogSpacedArray = (frequencyMin, frequencyMax, frequencyNumber) => {
  // A single point has no spacing: without this guard the exponent below
  // would be 0 / 0 = NaN and the result would be [NaN].
  if (frequencyNumber === 1) {
    return [frequencyMin];
  }

  const ratio = frequencyMax / frequencyMin;
  const steps = frequencyNumber - 1;
  // Array(frequencyNumber) still throws RangeError for negative or
  // non-integer counts, as before.
  return [...Array(frequencyNumber).keys()].map(
    index => frequencyMin * Math.pow(ratio, index / steps)
  );
};

/**
 * Creates an RxJS operator that applies a Morlet wavelet transform to every
 * channel of a stream of EEG epochs. It averages the resulting power within
 * bins bounded by time triggers and four conventional EEG frequency bands.
 *
 * Bands (inclusive upper edges): <= 3 Hz (delta), <= 7 Hz (theta),
 * <= 13 Hz (alpha), and everything above 13 Hz up to frequencyMax (beta).
 *
 * Time bins use the trigger sample indices t0..t(K-1), taken in key order
 * (assumed ascending), and the channel length N:
 * [t0, t1), [t1, t2), ..., [t(K-1), N-1), [N-1, N).
 * NOTE(review): the final bin holds only the last sample. The layout is kept
 * so each channel still produces 4 * (K + 1) values.
 *
 * Each channel's values are flattened band by band, from the highest band to
 * the lowest. The first element is the beta band between triggers 0 and 1.
 * The last element is the delta band's final time bin. Empty time bins, and
 * bands that contain no analysis frequency, yield NaN.
 *
 * @method wavelet
 * @example eeg$.pipe(epoch({ duration: 256, interval: 100, samplingRate: 125 }), wavelet({ triggers: { 0: 0, 1: 64, 2: 128, 3: 192 } }))
 * @param {Object} options - Wavelet options
 * @param {number} [options.frequencyNumber=30] Number of log-spaced analysis frequencies
 * @param {number} [options.frequencyMin=1] Minimum analysis frequency in Hz
 * @param {number} [options.frequencyMax=32] Maximum analysis frequency in Hz
 * @param {Object} options.triggers Triggers in time. Key = trigger index. Value = sample index of the trigger in each channel.
 * @param {number} [options.waveNumber=6] Wavenumber (number of cycles) of the Morlet wavelet
 *
 * @returns {OperatorFunction} Maps each epoch `{ data, info }` to
 *   `{ binnedPsdsArray, info }`, where `binnedPsdsArray[c]` is the flattened
 *   band/time-binned power of channel `data[c]`. The wavelet time axis uses
 *   `info.samplingRate` as its sampling rate.
 */
export const wavelet = ({
  frequencyNumber = 30,
  frequencyMin = 1,
  frequencyMax = 32,
  triggers,
  waveNumber = 6,
}) => {
  // The analysis frequencies depend only on the options, so compute them once.
  const frequencyLogSpaced = createLogSpacedArray(frequencyMin, frequencyMax, frequencyNumber);

  // Inclusive upper edges (Hz) of the three lower bands. Frequencies above the
  // last edge go to the top band, which ends at frequencyMax.
  const bandEdges = [3, 7, 13];
  const bandCount = bandEdges.length + 1;
  const frequencyBands = frequencyLogSpaced.map(frequency => {
    const band = bandEdges.findIndex(edge => frequency <= edge);
    return band === -1 ? bandEdges.length : band;
  });

  // Arithmetic mean. Returns NaN for an empty array instead of the TypeError
  // that an unseeded reduce would throw.
  const mean = values =>
    values.length === 0 ? NaN : values.reduce((sum, value) => sum + value, 0) / values.length;

  // Builds the peak-normalized FFT of a Morlet wavelet for each analysis frequency.
  const buildKernels = (srate, signalLength) => {
    // Wavelet time axis: 4 s centred on t = 0, sampled at srate.
    const waveTime = [...Array(4 * srate + 1).keys()].map(index => index / srate - 2);
    // Zero-pad to waveTime.length + signalLength, the same length as the padded
    // signal. The FFT product is then a linear rather than circular convolution.
    const padding = Array(signalLength).fill(0);

    const spectra = frequencyLogSpaced.map(frequency => {
      const sigma = waveNumber / (2 * Math.PI * frequency);
      const amplitude = 1 / Math.sqrt(sigma * Math.sqrt(Math.PI));
      const width = 2 * sigma ** 2;

      const real = [];
      const imag = [];
      for (const t of waveTime) {
        const gaussian = amplitude * Math.exp(-(t ** 2) / width);
        const phase = 2 * Math.PI * frequency * t;
        // gaussian * e^(i * phase), split into real and imaginary parts.
        real.push(gaussian * Math.cos(phase));
        imag.push(gaussian * Math.sin(phase));
      }

      const spectrum = fft([...real, ...padding], [...imag, ...padding]);
      const spectrumRe = spectrum[0];
      const spectrumIm = spectrum[1];
      // Scale so the kernel spectrum's peak magnitude is 1. Using reduce
      // instead of Math.max(...arr) avoids argument-count limits on long arrays.
      const peak = Array.from(spectrumRe, (value, i) => Math.hypot(value, spectrumIm[i]))
        .reduce((max, value) => Math.max(max, value), -Infinity);

      return [
        Array.from(spectrumRe, value => value / peak),
        Array.from(spectrumIm, value => value / peak),
      ];
    });

    return { waveLength: waveTime.length, spectra };
  };

  // Kernels depend only on (srate, signal length), which are normally the same
  // for every channel and epoch. Keep the most recent set instead of
  // recomputing frequencyNumber FFTs per channel.
  let kernelCache = null;
  const getKernels = (srate, signalLength) => {
    const key = `${srate}:${signalLength}`;
    if (kernelCache === null || kernelCache.key !== key) {
      kernelCache = { key, ...buildKernels(srate, signalLength) };
    }
    return kernelCache;
  };

  // Wavelet-transforms one channel and returns its flattened band/time-binned power.
  const transformChannel = (signal, srate = 125) => {
    const { waveLength, spectra } = getKernels(srate, signal.length);
    // The wavelet is centred at this index, so signal sample n lands at
    // convolution index halfWave + n.
    const halfWave = Math.floor(waveLength / 2);

    // Start index of each time bin, followed by the last sample index.
    const triggerIndices = [...Object.values(triggers), signal.length - 1];

    const signalFFT = fft([...signal, ...Array(waveLength).fill(0)]);
    const signalRe = signalFFT[0];
    const signalIm = signalFFT[1];

    // bandRows[b] holds one row of time-bin averages per frequency in band b.
    const bandRows = Array.from({ length: bandCount }, () => []);

    spectra.forEach(([kernelRe, kernelIm], frequencyIndex) => {
      // Multiplying in the frequency domain is convolution in the time domain:
      // (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
      const productRe = Array.from(kernelRe, (a, i) => a * signalRe[i] - kernelIm[i] * signalIm[i]);
      const productIm = Array.from(kernelRe, (a, i) => a * signalIm[i] + kernelIm[i] * signalRe[i]);
      const convolution = ifft(productRe, productIm);
      const convRe = convolution[0];
      const convIm = convolution[1];

      // Power at exactly the N samples aligned with the input signal. The
      // previous trim kept one extra sample past the end of the signal.
      const power = Array.from({ length: signal.length }, (_, n) =>
        (2 * Math.hypot(convRe[halfWave + n], convIm[halfWave + n])) ** 2
      );

      const timeAverages = triggerIndices.map((start, bin) =>
        mean(power.slice(start, triggerIndices[bin + 1]))
      );
      bandRows[frequencyBands[frequencyIndex]].push(timeAverages);
    });

    // Average across the frequencies of each band, per time bin. A band with
    // no frequencies yields NaN for every bin, so the output length stays fixed.
    const bandAverages = bandRows.map(rows =>
      triggerIndices.map((_, bin) => mean(rows.map(row => row[bin])))
    );

    // Highest band first, flattened into one array.
    return bandAverages.reverse().flat();
  };

  return pipe(
    map(epoch => ({
      binnedPsdsArray: epoch.data.map(channel =>
        transformChannel(channel, epoch.info.samplingRate)
      ),
      info: epoch.info,
    }))
  );
};
