Source code for metabci.brainda.algorithms.feature_analysis.time_freq_analysis

import numpy as np
from scipy.signal import detrend, stft
import mne
from scipy.signal import hilbert
import matplotlib.pyplot as plt


class TimeFrequencyAnalysis:
    """Time-frequency analysis helpers for multi-channel EEG arrays.

    Args:
        fs: float, default sampling rate. Every method that accepts its own
            rate argument falls back to this value when that argument is
            ``None``.
    """

    def __init__(self, fs):
        self.fs = fs

    def func_morlet_wavelet(self, data, xtimes, omega, sigma, fs=None):
        """Complex Morlet wavelet transform of every channel.

        -author: Xin Fengran
        -Created on: 2022-8-8
        -update log:
            2022-8-11 by Xin Fengran

        Args:
            data: ndarray(Nchannel, Ntimes), the EEG data. It is linearly
                detrended along the last axis before the transform.
            xtimes: ndarray(N,), divided by ``fs`` to obtain the normalised
                centre frequency of each wavelet, i.e. despite its name it
                holds the analysis frequencies in the same unit as ``fs``.
                Entries must be non-zero (each one is used as a divisor).
            omega: float, number of oscillations per unit of the scaled
                wavelet axis; together with the frequency it sets the scale.
            sigma: float, standard deviation of the Gaussian envelope on the
                scaled wavelet axis.
            fs: float, the sampling rate. Defaults to ``self.fs``.

        Returns:
            P: ndarray(Nchannel, N, Ntimes), squared amplitude of the
                Morlet wavelet transform.
            S: ndarray(Nchannel, N, Ntimes), complex values of the Morlet
                wavelet transform.
        """
        data = detrend(data, axis=-1, type="linear")
        n_channels, n_times = data.shape[0], data.shape[1]
        n_freqs = xtimes.shape[0]
        if fs is None:
            fs = self.fs
        norm_freqs = xtimes / fs

        # The coefficients are complex. A real-valued buffer would silently
        # drop their imaginary part on assignment and corrupt both outputs.
        S = np.zeros((n_channels, n_freqs, n_times), dtype=complex)

        # The wavelet support spans 2 * n_times + 1 samples centred on zero.
        half_width = n_times
        offsets = -np.arange(-half_width, half_width + 1)

        for fi in range(n_freqs):
            scaling_factor = omega / norm_freqs[fi]
            u = offsets / scaling_factor
            wavelet = (
                np.sqrt(1 / scaling_factor)
                * np.exp(-(u**2) / (2 * sigma**2))
                * np.exp(1j * 2 * np.pi * omega * u)
            )
            kernel = wavelet.conjugate()
            for ci in range(n_channels):
                full = np.convolve(data[ci, :], kernel)
                # Keep the central part so the output is aligned with, and
                # as long as, the input signal.
                S[ci, fi, :] = full[half_width:half_width + n_times]

        P = np.abs(S) ** 2
        return P, S

    def fun_stft(
        self,
        data,
        fs=None,
        window="hann",
        nperseg=256,
        noverlap=None,
        nfft=None,
        detrend=False,
        return_onesided=True,
        boundary="zeros",
        padded=True,
        axis=-1,
    ):
        """Short-time Fourier transform, delegated to ``scipy.signal.stft``.

        -author: Xin Fengran
        -Created on: 2022-8-8
        -update log:
            2022-8-11 by Xin Fengran

        Args:
            data(ndarray): the EEG data
            fs(float): the sampling rate. Defaults to ``self.fs``.
            window(str or tuple or ndarray): desired window to use.
            nperseg(int): length of each segment.
            noverlap(int): number of points to overlap between segments.
            nfft(int): length of the FFT used.
            detrend(str or function or False): specifies how to detrend
                each segment.
            return_onesided(bool): If True, return a one-sided spectrum for
                real data. If False return a two-sided spectrum. Defaults to
                True, but for complex data, a two-sided spectrum is always
                returned.
            boundary(str): Specifies whether the input signal is extended at
                both ends, and how to generate the new values, in order to
                center the first windowed segment on the first input point.
            padded(bool): Specifies whether the input signal is zero-padded
                at the end.
            axis(int): axis along which the STFT is computed

        Returns:
            f(ndarray): array of sample frequencies
            t(ndarray): array of segment times
            Zxx(ndarray): the STFT of the EEG data
        """
        if fs is None:
            fs = self.fs
        # Keyword arguments keep the call independent of SciPy's positional
        # parameter order.
        f, t, Zxx = stft(
            data,
            fs=fs,
            window=window,
            nperseg=nperseg,
            noverlap=noverlap,
            nfft=nfft,
            detrend=detrend,
            return_onesided=return_onesided,
            boundary=boundary,
            padded=padded,
            axis=axis,
        )
        return f, t, Zxx

    def fun_topoplot(self, X, chan_names, sfreq=None, ch_types="eeg"):
        """Draw and show a scalp topography of one value per channel.

        -author: Li Xiaoyu
        -Created on: 2022-8-8
        -update log:
            2022-8-11 by Li Xiaoyu

        Args:
            X(ndarray): the input data
            chan_names(list): the name of channels; they are placed with the
                "standard_1020" montage.
            sfreq(float): the sampling rate. Defaults to ``self.fs``.
            ch_types(str): the type of channel

        The figure is displayed with ``plt.show()``; nothing is returned.
        """
        if sfreq is None:
            sfreq = self.fs
        montage_type = mne.channels.make_standard_montage(
            "standard_1020", head_size=0.1
        )
        epoch_info = mne.create_info(
            ch_names=chan_names, ch_types=ch_types, sfreq=sfreq
        )
        epoch_info.set_montage(montage_type)

        # The colour range spans exactly the extremes of the data.
        vmax = np.max(X)
        vmin = np.min(X)

        fig, ax = plt.subplots(1, figsize=(4, 4), sharex=True, sharey=True)
        # NOTE(review): newer MNE releases appear to take the colour limits
        # as vlim=(vmin, vmax) -- confirm against the MNE version in use.
        im, cn = mne.viz.plot_topomap(
            X, epoch_info, axes=ax, show=False, vmax=vmax, vmin=vmin, cnorm=None
        )
        ax.set_title("topomap", fontsize=25)
        plt.colorbar(im)
        plt.show()

    def fun_hilbert(self, X, N=None, axis=-1):
        """Analytic signal of ``X`` and the quantities derived from it.

        -author: Li Xiaoyu
        -Created on: 2022-8-8
        -update log:
            2022-8-11 by Li Xiaoyu

        Args:
            X(ndarray): the input data
            N(int): number of Fourier components passed to
                ``scipy.signal.hilbert``.
            axis(int): axis along which the transform is computed.

        Returns:
            analytic_signal(ndarray): discrete-time analytic signal
            realEnv(ndarray): the real part of the analytic signal.
            imagEnv(ndarray): the imaginary part of the analytic signal.
            angle(ndarray): the angle (instantaneous phase) of the analytic
                signal.
            envModu(ndarray): the envelope (magnitude) of the analytic
                signal.
        """
        analytic_signal = hilbert(X, N=N, axis=axis)
        realEnv = np.real(analytic_signal)
        imagEnv = np.imag(analytic_signal)
        angle = np.angle(analytic_signal)
        envModu = np.abs(analytic_signal)
        return analytic_signal, realEnv, imagEnv, angle, envModu