# Source code for metabci.brainda.algorithms.decomposition.base

# -*- coding: utf-8 -*-
#
# Authors: Swolf <swolfforever@gmail.com>
# Date: 2021/1/07
# License: MIT License


from typing import Optional, List, Tuple, Union
import warnings
import numpy as np
from numpy import ndarray
from scipy.linalg import solve
from scipy.signal import sosfiltfilt, cheby1, cheb1ord
from sklearn.base import BaseEstimator, TransformerMixin, clone
from metabci.brainda.datasets.base import BaseTimeEncodingDataset
import mne


def robust_pattern(W: ndarray, Cx: ndarray, Cs: ndarray) -> ndarray:
    """Transform spatial filters to spatial patterns based on paper [1]_.

    Referring to the method mentioned in article [1], the constructed spatial
    filter only shows how to combine information from different channels to
    extract signals of interest from EEG signals. If the goal is
    neurophysiological interpretation or visualization of weights, activation
    patterns need to be constructed from the obtained spatial filters.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    W : ndarray
        Spatial filters, shape (n_channels, n_filters).
    Cx : ndarray
        Covariance matrix of EEG data, shape (n_channels, n_channels).
    Cs : ndarray
        Covariance matrix of source data (typically ``W.T @ Cx @ W``),
        shape (n_filters, n_filters).

    Returns
    -------
    A : ndarray
        Spatial patterns, shape (n_channels, n_filters); each column is the
        spatial pattern of the corresponding filter, ``A = Cx @ W @ inv(Cs)``.

    References
    ----------
    .. [1] Haufe, Stefan, et al. "On the interpretation of weight vectors of
       linear models in multivariate neuroimaging." Neuroimage 87 (2014):
       96-110.
    """
    # A = Cx W Cs^{-1}. Solving Cs.T @ A.T = (Cx W).T with linalg.solve avoids
    # forming an explicit inverse, which is numerically more stable; see
    # https://github.com/robintibor/fbcsp/blob/master/fbcsp/signalproc.py
    # and https://ww2.mathworks.cn/help/matlab/ref/mldivide.html
    A = solve(Cs.T, (Cx @ W).T).T
    return A
class FilterBank(BaseEstimator, TransformerMixin):
    """Filter bank decomposition.

    A bank of band-pass filters divides the input signal into several
    sub-band components. One clone of ``base_estimator`` is fitted on each
    sub-band, and the per-band features are concatenated.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    base_estimator : BaseEstimator
        Estimator cloned once per sub-band for model training and feature
        extraction.
    filterbank : list[ndarray]
        Second-order-sections (``sos``) coefficients of each band-pass filter,
        as consumed by ``scipy.signal.sosfiltfilt``.
    n_jobs : int, optional
        Number of CPU cores. It is stored for API compatibility only: the
        sub-band estimators are currently fitted and applied sequentially.
        The default is None.

    References
    ----------
    .. [1] Chen X, Wang Y, Nakanishi M, et al. High-speed spelling with a
       noninvasive brain-computer interface[J]. Proceedings of the national
       academy of sciences, 2015, 112(44): E6058-E6067.
    """

    def __init__(
        self,
        base_estimator: BaseEstimator,
        filterbank: List[ndarray],
        n_jobs: Optional[int] = None,
    ):
        # Parameters are stored unmodified so sklearn's get_params/clone work.
        self.base_estimator = base_estimator
        self.filterbank = filterbank
        self.n_jobs = n_jobs

    def fit(self, X: ndarray, y: Optional[ndarray] = None, **kwargs):
        """Fit one clone of ``base_estimator`` on each sub-band of ``X``.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape (n_trials, n_channels, n_samples)
            Training signal; it is filtered along its last axis.
        y : ndarray, optional
            Labels, forwarded unchanged to every sub-band estimator.
        **kwargs
            Extra keyword arguments forwarded to each sub-band estimator's
            ``fit`` (e.g. a reference signal, if the estimator accepts one).

        Returns
        -------
        self : FilterBank
            The fitted instance, with one estimator per sub-band in
            ``estimators_``.
        """
        self.estimators_ = [clone(self.base_estimator) for _ in self.filterbank]
        subbands = self.transform_filterbank(X)
        for est, X_band in zip(self.estimators_, subbands):
            est.fit(X_band, y, **kwargs)
        return self

    def transform(self, X: ndarray, **kwargs):
        """Extract features from every sub-band of ``X`` and concatenate them.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape (n_trials, n_channels, n_samples)
            Test signal.
        **kwargs
            Extra keyword arguments forwarded to each sub-band estimator's
            ``transform``.

        Returns
        -------
        feat : ndarray
            Per-band features concatenated along the last axis, e.g. shape
            (n_trials, n_bands * n_features) when each estimator returns
            (n_trials, n_features).
        """
        subbands = self.transform_filterbank(X)
        feat = [
            est.transform(X_band, **kwargs)
            for est, X_band in zip(self.estimators_, subbands)
        ]
        return np.concatenate(feat, axis=-1)

    def transform_filterbank(self, X: ndarray):
        """Filter the input signal through every filter of the bank.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape (n_trials, n_channels, n_samples)
            Input signal.

        Returns
        -------
        Xs : ndarray, shape (n_bands, n_trials, n_channels, n_samples)
            Sub-band components of the input signal (zero-phase filtered
            along the last axis).
        """
        Xs = np.stack([sosfiltfilt(sos, X, axis=-1) for sos in self.filterbank])
        return Xs
class FilterBankSSVEP(FilterBank):
    """Filter bank analysis for SSVEP.

    The SSVEP signal is decomposed into several sub-bands by a bank of
    band-pass filters. Features are extracted on each sub-band and can then
    be combined by a weighted sum across sub-bands.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    filterbank : list[ndarray]
        The filter bank (``sos`` coefficients of each band-pass filter).
    base_estimator : BaseEstimator
        Estimator for model training and feature extraction.
    filterweights : array-like of shape (n_bands,), optional
        Weight of each sub-band's features. When None (default), the
        concatenated per-band features are returned without weighting.
    n_jobs : int, optional
        Number of CPU cores. The default is None.
    """

    def __init__(
        self,
        filterbank: List[ndarray],
        base_estimator: BaseEstimator,
        filterweights: Optional[ndarray] = None,
        n_jobs: Optional[int] = None,
    ):
        self.filterweights = filterweights
        super().__init__(base_estimator, filterbank, n_jobs=n_jobs)

    def transform(self, X: ndarray):  # type: ignore[override]
        """Extract sub-band features and optionally combine them by weight.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : ndarray, shape (n_trials, n_channels, n_samples)
            Test signal.

        Returns
        -------
        features : ndarray
            Without ``filterweights``: the per-band features concatenated
            along the last axis. With ``filterweights``: the weighted sum
            over sub-bands, shape (n_trials, n_features_per_band).
        """
        features = super().transform(X)
        if self.filterweights is None:
            return features
        # Per-band features arrive concatenated along the last axis; restore
        # the band axis before weighting and summing over it.
        features = np.reshape(
            features, (features.shape[0], len(self.filterbank), -1)
        )
        # asarray lets plain lists/tuples of weights be used as well.
        weights = np.asarray(self.filterweights)
        return np.sum(features * weights[np.newaxis, :, np.newaxis], axis=1)
class TimeDecodeTool:
    """Decoding tool set for the TDMA coding paradigm.

    Applicable datasets include the P300 speller dataset and the aVEP speller
    dataset. The main functions are: splitting trials according to the minor
    events, resampling the data, and determining the target character (or
    instruction) from the per-epoch classification results.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    dataset : BaseTimeEncodingDataset
        The dataset to be decoded. Its ``minor_events``, ``encode`` and
        ``encode_loop`` attributes are read.
    feature_operation : str
        Operation used to merge the features of repeated stimulation rounds.
        Only ``'sum'`` is currently supported.
    """

    def __init__(self, dataset: 'BaseTimeEncodingDataset',
                 feature_operation: str = 'sum'):
        # The first element of each minor event is its label value; keep the
        # labels sorted so argmax indices map onto them in ascending order.
        self.minor_class = np.array(
            sorted(event[0] for event in dataset.minor_events.values())
        )
        self.encode_map = dataset.encode
        self.encode_loop = dataset.encode_loop
        self.feature_operation = feature_operation

    def _trial_feature_split(self, key: str, feature: ndarray):
        """Split a trial's epoch features into (round, code position) cells.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        key : str
            Character label; only the length of its encoding sequence
            (``self.encode_map[key]``) is used, so any label works.
        feature : ndarray, shape (n_trials, n_class)
            Features of the epochs of one trial, where
            ``n_trials == self.encode_loop * key_encode_len``.

        Returns
        -------
        key : str
            The character label, unchanged.
        feature_storage : ndarray, shape (encode_loop, key_encode_len, n_class)
            The features reorganised per stimulation round (float64 copy).

        Raises
        ------
        ValueError
            If the number of epochs does not match the dataset settings.
        """
        key_encode_len = len(self.encode_map[key])
        if key_encode_len * self.encode_loop != feature.shape[0]:
            raise ValueError('Epochs in the test trial does not same '
                             'as the presetting parameter in dataset')
        # Epochs are ordered round by round, so a C-order reshape places epoch
        # ``row * key_encode_len + col`` at ``[row, col]``; astype(float)
        # returns a float64 copy.
        feature_storage = np.reshape(
            feature, (self.encode_loop, key_encode_len, *feature.shape[1:])
        ).astype(float)
        return key, feature_storage

    def _features_operation(self, feature_storage: ndarray, fold_num=6):
        """Merge the features of the first ``fold_num`` stimulation rounds.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        feature_storage : ndarray, shape (encode_loop, key_encode_len, n_class)
            Features split per stimulation round.
        fold_num : int
            Number of stimulation rounds to merge (counted from the first).

        Returns
        -------
        sum_feature : ndarray, shape (key_encode_len, n_class)
            Features merged over the selected rounds.

        Raises
        ------
        ValueError
            If ``fold_num`` is not in ``[1, encode_loop]`` or
            ``feature_operation`` is unsupported.
        """
        n_rounds = np.shape(feature_storage)[0]
        if fold_num > n_rounds:
            raise ValueError("The number of trial stacks cannot exceeds %d"
                             % n_rounds)
        if fold_num < 1:
            # A zero or negative slice bound would silently select the wrong
            # rounds (or none at all).
            raise ValueError("fold_num must be a positive integer, got %r"
                             % (fold_num,))
        if self.feature_operation == 'sum':
            return np.sum(feature_storage[0:fold_num], axis=0, keepdims=False)
        raise ValueError("Unsupported feature_operation %r; only 'sum' is "
                         "supported" % (self.feature_operation,))

    def _predict(self, features: ndarray):
        """Predict the minor class of each code position (aVEP speller).

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        features : ndarray, shape (key_encode_len, n_class)
            Merged scores of each code position against each template class.

        Returns
        -------
        predict_labels : ndarray, shape (key_encode_len,)
            The minor-class label with the highest score at each position.
        """
        predict_labels = self.minor_class[np.argmax(features, axis=-1)]
        return predict_labels

    def _predict_p300(self, features: ndarray):
        """Predict the target row and column of a row/column P300 speller.

        The first half of the code positions are treated as rows and the
        second half as columns; the highest-scoring row and column (last
        feature column) are labelled 2 (target), all others 1 (non-target).

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        features : ndarray, shape (key_encode_len, n_class)
            Merged scores of each code position.

        Returns
        -------
        predict_labels : ndarray of int, shape (len(minor_class),)
            1 for non-target positions, 2 for the predicted row and column.
        """
        half_len = features.shape[0] // 2
        predict_row = np.argmax(features[:half_len, -1])
        # Column indices follow the rows, so offset by the number of rows
        # (previously hard-coded as 6, which only fits a 6x6 matrix).
        predict_col = np.argmax(features[half_len:, -1]) + half_len
        predict_labels = np.ones_like(self.minor_class, dtype=int)
        predict_labels[predict_row] = 2
        predict_labels[predict_col] = 2
        return predict_labels

    def _find_command(self, predict_labels: ndarray):
        """Find the character whose encoding sequence equals the prediction.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        predict_labels : ndarray
            Predicted label of each code position.

        Returns
        -------
        key : str or None
            The matching character, or None if no encoding sequence in the
            dataset equals ``predict_labels``.
        """
        for key, value in self.encode_map.items():
            if np.array_equal(np.array(value), predict_labels):
                return key
        return None

    def decode(self, key: str, feature: ndarray, fold_num=6, paradigm='avep'):
        """Decode one trial into a character (instruction).

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        key : str
            Character label; only used to determine the encoding sequence
            length, so any label works.
        feature : ndarray, shape (n_trials,) or (n_trials, n_class)
            Features of the trial's epochs, where
            ``n_trials == encode_loop * key_encode_len``.
        fold_num : int
            Number of stimulation rounds to merge.
        paradigm : {'avep', 'p300'}
            Type of paradigm.

        Returns
        -------
        command : str or None
            The predicted character, or None if the predicted sequence
            matches no encoding sequence in the dataset.

        Raises
        ------
        ValueError
            If ``paradigm`` is not supported, or the epoch count / ``fold_num``
            does not match the dataset settings.
        """
        if feature.ndim < 2:
            feature = feature[:, np.newaxis]
        _, feature_storage = self._trial_feature_split(key, feature)
        merge_features = self._features_operation(feature_storage, fold_num)
        if paradigm == 'avep':
            predict_labels = self._predict(merge_features)
        elif paradigm == 'p300':
            predict_labels = self._predict_p300(merge_features)
        else:
            # Previously an unknown paradigm silently produced None.
            raise ValueError("Unsupported paradigm %r; expected 'avep' or "
                             "'p300'" % (paradigm,))
        return self._find_command(np.asarray(predict_labels))

    def target_calibrate(self, y, key):
        """Convert row/column flash labels into target / non-target labels.

        Designed for the classic row/column P300 speller: an epoch whose flash
        label belongs to the target character's row or column becomes 2
        (target), every other epoch becomes 1 (non-target).

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        y : list of ndarray
            Flash labels of all epochs, one array per character.
        key : object with a ``values`` attribute
            Large labels; ``key.values[i]`` is the character of ``y[i]``
            (e.g. a pandas Series — assumption, confirm against callers).

        Returns
        -------
        y_tar : list of ndarray
            Per character, 2 for target epochs and 1 for non-target epochs.
        """
        y_tar = []
        for i, events in enumerate(y):
            character = key.values[i]
            # Flash labels are 1-based positions in the encoding sequence
            # whose code value is 2.
            target_id = np.where(
                np.array(self.encode_map[character]) == 2)[0] + 1
            event = events.copy()
            is_target = np.isin(event, target_id)
            event[:] = 1
            event[is_target] = 2
            y_tar.append(event)
        return y_tar

    def resample(self, x, fs_old, fs_new, axis=None):
        """Resample data from ``fs_old`` to ``fs_new`` using MNE.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        x : ndarray
            Data to be resampled.
        fs_old : float
            The original sampling rate of ``x``.
        fs_new : float
            The target sampling rate.
        axis : int, optional
            Axis along which to resample; defaults to the last axis.

        Returns
        -------
        x_1 : ndarray
            Data after resampling.
        """
        if axis is None:
            axis = x.ndim - 1
        down_factor = fs_old / fs_new
        x_1 = mne.filter.resample(x, down=down_factor, axis=axis)
        return x_1

    def epoch_sort(self, X, y):
        """Sort the epochs of each stimulation round by their flash label.

        Designed for the classic row/column P300 speller.

        update log:
            2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

        Parameters
        ----------
        X : list of ndarray
            Data of all epochs, one array (epochs first) per character.
        y : list of ndarray
            Flash labels (current blinking row or column) of all epochs, one
            array per character.

        Returns
        -------
        X_sort : list of ndarray
            Data with the epochs of every round ordered by ascending label.
        Y_sort : list of ndarray
            Labels in the same order as ``X_sort``.
        """
        code_len = len(self.minor_class)
        X_sort = [[] for _ in range(len(X))]
        Y_sort = [[] for _ in range(len(y))]
        for char_i in range(len(X)):
            for loop_i in range(self.encode_loop):
                # Each stimulation round holds ``code_len`` consecutive epochs.
                window = slice(loop_i * code_len, (loop_i + 1) * code_len)
                y_i = y[char_i][window]
                x_i = X[char_i][window]
                order = np.argsort(y_i)
                X_sort[char_i].append(x_i[order])
                Y_sort[char_i].append(y_i[order])
            X_sort[char_i] = np.concatenate(X_sort[char_i], axis=0)
            Y_sort[char_i] = np.concatenate(Y_sort[char_i], axis=0)
        return X_sort, Y_sort
def generate_filterbank(
    passbands: List[Tuple[float, float]],
    stopbands: List[Tuple[float, float]],
    srate: int,
    order: Optional[int] = None,
    rp: float = 0.5,
    gpass: float = 3,
    gstop: float = 40,
) -> List[ndarray]:
    """Create a bank of Chebyshev type I band-pass filters.

    Each filter divides the input signal into one sub-band component.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    passbands : list of tuple(float, float)
        Passband edges (Hz) of each filter.
    stopbands : list of tuple(float, float)
        Stopband edges (Hz) of each filter; only used to estimate the filter
        order when ``order`` is None. Must have the same length as
        ``passbands``.
    srate : float
        Sampling rate (Hz).
    order : int, optional
        Filter order. If None, the minimum order meeting ``gpass``/``gstop``
        is estimated with ``scipy.signal.cheb1ord``.
    rp : float
        Maximum ripple (dB) allowed below unity gain in the passband; 0.5 by
        default.
    gpass : float
        Maximum passband loss (dB) used for order estimation; 3 by default.
    gstop : float
        Minimum stopband attenuation (dB) used for order estimation; 40 by
        default.

    Returns
    -------
    filterbank : list of ndarray
        One second-order-sections array of shape (n_sections, 6) per band.

    Raises
    ------
    ValueError
        If ``passbands`` and ``stopbands`` have different lengths.
    """
    # zip() would otherwise silently drop the unmatched bands.
    if len(passbands) != len(stopbands):
        raise ValueError(
            "passbands and stopbands must have the same length, got %d and %d"
            % (len(passbands), len(stopbands))
        )
    filterbank = []
    for wp, ws in zip(passbands, stopbands):
        if order is None:
            N, wn = cheb1ord(wp, ws, gpass, gstop, fs=srate)
            sos = cheby1(N, rp, wn, btype="bandpass", output="sos", fs=srate)
        else:
            sos = cheby1(order, rp, wp, btype="bandpass", output="sos", fs=srate)
        filterbank.append(sos)
    return filterbank
def generate_cca_references(
    freqs: Union[ndarray, int, float],
    srate,
    T,
    phases: Optional[Union[ndarray, int, float]] = None,
    n_harmonics: int = 1,
):
    """Construct sine-cosine reference signals for CCA.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    freqs : int, float or array-like
        Stimulus frequency (Hz), or one frequency per target.
    srate : int
        Sampling rate (Hz).
    T : float
        Signal duration (s); ``int(T * srate)`` samples are generated.
    phases : int, float or array-like, optional
        Phase of each frequency in multiples of pi (the signal uses
        ``pi * phases``). Default is None (zero phase).
    n_harmonics : int
        The number of harmonics. The default value is 1.

    Returns
    -------
    Yf : ndarray, shape (n_freqs, 2 * n_harmonics, int(T * srate))
        Reference signals; for harmonic h (1-based) the rows are
        ``sin(2*pi*h*f*t + pi*phase)`` followed by the matching cosine.
    """
    # atleast_1d also accepts numpy scalar types (e.g. np.int64), which the
    # isinstance(int/float) checks did not.
    freqs = np.atleast_1d(np.asarray(freqs, dtype=float))[:, np.newaxis]
    if phases is None:
        phases = 0
    phases = np.atleast_1d(np.asarray(phases, dtype=float))[:, np.newaxis]
    # Sample times k / srate. np.linspace(0, T, N) would include T itself,
    # giving a spacing of T / (N - 1) instead of 1 / srate and thus a
    # slightly wrong reference frequency.
    t = np.arange(int(T * srate)) / srate
    Yf = []
    for i in range(n_harmonics):
        arg = 2 * np.pi * (i + 1) * freqs * t + np.pi * phases
        Yf.append(np.stack([np.sin(arg), np.cos(arg)], axis=1))
    Yf = np.concatenate(Yf, axis=1)
    return Yf
def sign_flip(u, s, vh=None):
    """Resolve the sign ambiguity of SVD/EIG factors, following [1]_.

    Each component's sign is chosen so that its summed projection is
    positive. Components whose projection is exactly zero keep a positive
    sign and a warning is issued.

    update log:
        2023-12-10 by Leyi Jia <18020095036@163.com>, Add code annotation

    Parameters
    ----------
    u : ndarray
        Left singular vectors, shape (M, K).
    s : ndarray
        Singular values, shape (K,).
    vh : ndarray or None
        Transpose of right singular vectors, shape (K, N).

    Returns
    -------
    u : ndarray
        Corrected left singular vectors.
    s : ndarray
        Singular values (returned unchanged).
    vh : ndarray
        Transpose of corrected right singular vectors; only returned when
        ``vh`` was given.

    References
    ----------
    .. [1] https://www.sandia.gov/~tgkolda/pubs/pubfiles/SAND2007-6422.pdf
    """

    def _resolve_signs(projection):
        # Map each projection to +1/-1; exact zeros default to +1.
        signs = np.sign(projection)
        zero_mask = signs == 0
        if np.any(zero_mask):
            signs[zero_mask] = 1
            warnings.warn(
                "The magnitude is close to zero, the sign will become arbitrary."
            )
        return signs

    right_proj = np.sum(u * s, axis=0)
    if vh is None:
        signs = _resolve_signs(right_proj)
        return u * signs, s

    left_proj = np.sum(s[:, np.newaxis] * vh, axis=-1)
    signs = _resolve_signs(left_proj + right_proj)
    return u * signs, s, signs[:, np.newaxis] * vh