Source code for metabci.brainda.paradigms.base

# -*- coding: utf-8 -*-
#
# Authors: Swolf <swolfforever@gmail.com>
# Date: 2020/6/01
# License: MIT License
"""
Base Paradigm Design.

"""
from abc import ABCMeta, abstractmethod
from typing import Union, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import mne
from mne.utils import verbose
from joblib import Parallel, delayed
from ..utils import pick_channels
from ..datasets.base import BaseDataset, BaseTimeEncodingDataset


[docs]def label_encoder(y, labels): new_y = y.copy() for i, label in enumerate(labels): ix = y == label new_y[ix] = i return new_y
[docs]class BaseParadigm(metaclass=ABCMeta): """Abstract Base Paradigm.""" def __init__( self, channels: Optional[List[str]] = None, events: Optional[List[str]] = None, intervals: Optional[List[Tuple[float, float]]] = None, srate: Optional[float] = None, ): """ Parameters ---------- channels : Optional[List[str]], optional selected channel names, if None use all channels in dataset, by default None events : Optional[List[str]], optional selected event names, if None use all events in dataset, by default None intervals : Optional[List[Tuple[Union[int, float]]]], optional selected event intervals, if None use default intervals in dataset. If only one interval passed, all events use the same interval. Otherwise the number of tuples should be the same as the number of events, by default None srate : Optional[float], optional sampling rate, if None use default srate in dataset, by default None """ self.select_channels = ( None if channels is None else [ ch_name.upper() for ch_name in channels] ) self.event_list = events self.intervals = intervals self.srate = srate self._raw_hook = None self._epochs_hook = None self._data_hook = None
[docs] @abstractmethod def is_valid(self, dataset: BaseDataset) -> bool: """Verify the dataset is compatible with the paradigm. This method is called to verify dataset is compatible with the paradigm. This method should raise an error if the dataset is not compatible with the paradigm. This is for example the case if the dataset is an ERP dataset for motor imagery paradigm, or if the dataset does not contain any of the required events. Parameters ---------- dataset : BaseDataset dataset """ pass
def _map_events_intervals(self, dataset: BaseDataset): """Select and map events with their inervals. Args: dataset (BaseDataset): a pre defined dataset Raises: ValueError: length of intervals should be the same number of events Returns: used_evnets: selected events, return in dict. used_intervals: intervals of selected events, return in dict """ event_list = self.event_list intervals = self.intervals if event_list is None: # use all events in dataset event_list = list(dataset.events.keys()) used_events = {ev: dataset.events[ev][0] for ev in event_list} if intervals is None: used_intervals = {ev: dataset.events[ev][1] for ev in event_list} elif len(intervals) == 1: used_intervals = {ev: intervals[0] for ev in event_list} else: if len(event_list) != len(intervals): raise ValueError( "intervals should be the same number of events") used_intervals = { ev: interval for ev, interval in zip(event_list, intervals) } return used_events, used_intervals
[docs] def register_raw_hook(self, hook): """Register raw hook before epoch operation. Parameters ---------- hook : callable object Callable object to process Raw object before epoch operation. Its signature should look like: hook(raw, caches) -> raw, caches where caches is an dict stroing information, raw is MNE Raw instance. """ self._raw_hook = hook
[docs] def register_epochs_hook(self, hook): """Register epochs hook after epoch operation. Parameters ---------- hook : callable object Callable object to process Epochs object after epoch operation. Its' signature should look like: hook(epochs, caches) -> epochs, caches where caches is an dict storing information, epochs is MNE Epochs instance. """ self._epochs_hook = hook
[docs] def register_data_hook(self, hook): """Register data hook before return data. Parameters ---------- hook : callable object Callable object to process ndarray data before return it. Its' signature should look like: hook(X, y, meta, caches) -> X, y, meta, caches where caches is an dict storing information, X, y are ndarray object, meta is a pandas DataFrame instance. """ self._data_hook = hook
[docs] def unregister_raw_hook(self): """Unregister raw hook before epoch operation.""" self._raw_hook = None
[docs] def unregister_epochs_hook(self): """Register epochs hook after epoch operation.""" self._epochs_hook = None
[docs] def unregister_data_hook(self): """Register data hook before return data.""" self._data_hook = None
@verbose def _get_single_subject_data(self, dataset, subject_id, verbose=False): """Return data in micro-volt.""" if not self.is_valid(dataset): raise TypeError( "Dataset {:s} is not valid for the current paradigm. Check your events and channels settings".format( dataset.dataset_code ) ) # # events, interval checking used_events, used_intervals = self._map_events_intervals(dataset) Xs = {} ys = {} metas = {} data = dataset.get_data([subject_id]) for subject, sessions in data.items(): for session, runs in sessions.items(): for run, raw in runs.items(): # do raw hook either self-implemented or dataset inherited caches = {} if self._raw_hook: raw, caches = self._raw_hook(raw, caches) elif hasattr(dataset, "raw_hook"): raw, caches = dataset.raw_hook(raw, caches) # pick selected channels by order channels = ( dataset.channels if self.select_channels is None else self.select_channels ) picks = pick_channels(raw.ch_names, channels, ordered=True) # find available events, first check stim_channels then annotations stim_channels = mne.utils._get_stim_channel( None, raw.info, raise_error=False ) if len(stim_channels) > 0: events = mne.find_events( raw, shortest_event=0, initial_event=True ) else: # convert event_id to its number type instead of default auto-renaming in 0.19.2 events, _ = mne.events_from_annotations( raw, event_id=(lambda x: int(x)) ) for event_name in used_events.keys(): # mne.pick_events returns any matching events in include # only raise Runtime Error when nothing is found # then we just skip this event try: selected_events = mne.pick_events( events, include=used_events[event_name] ) except RuntimeError: continue # transform Raw to Epochs epochs = mne.Epochs( raw, selected_events, event_id={event_name: used_events[event_name]}, event_repeated="drop", tmin=used_intervals[event_name][0], tmax=used_intervals[event_name][1] - 1.0 / raw.info["sfreq"], picks=picks, proj=False, baseline=None, preload=True, ) # skip invalid time intervals if len(epochs) == 0: continue # do epochs hook if self._epochs_hook: epochs, caches = self._epochs_hook(epochs, caches) elif hasattr(dataset, "epochs_hook"): epochs, caches = dataset.epochs_hook( epochs, caches) # FIXME: is this resample reasonable? if self.srate: # as MNE suggested, decimate after extract epochs # low-pass raw object in raw_hook to prevent aliasing problem epochs = epochs.resample(self.srate) # epochs = epochs.decimate(dataset.srate//self.srate) # retrieve X, y and meta X = epochs.get_data() * 1e6 # micro-volt default y = epochs.events[:, -1] trial_ids = np.argwhere( events[:, -1] == list(epochs.event_id.values())[0] ).reshape((-1)) meta = pd.DataFrame( { "subject": [subject] * len(epochs), "session": [session] * len(epochs), "run": [run] * len(epochs), "event": [event_name] * len(epochs), "trial_id": trial_ids, "dataset": [dataset.dataset_code] * len(epochs), } ) # do data hook if self._data_hook: X, y, meta, caches = self._data_hook( X, y, meta, caches) elif hasattr(dataset, "data_hook"): X, y, meta, caches = dataset.data_hook( X, y, meta, caches) # collecting data pre_X = Xs.get(event_name) if pre_X is not None: Xs[event_name] = np.concatenate((pre_X, X), axis=0) else: Xs[event_name] = X pre_y = ys.get(event_name) if pre_y is not None: ys[event_name] = np.concatenate((pre_y, y), axis=0) else: ys[event_name] = y pre_meta = metas.get(event_name) if pre_meta is not None: metas[event_name] = pd.concat( (pre_meta, meta), axis=0, ignore_index=True ) else: metas[event_name] = meta return Xs, ys, metas
[docs] @verbose def get_data( self, dataset: BaseDataset, subjects: List[Union[int, str]] = [], label_encode: bool = True, return_concat: bool = False, n_jobs: int = -1, verbose: Optional[bool] = None, ) -> Tuple[ Union[ Dict[str, Union[np.ndarray, pd.DataFrame]], Union[np.ndarray, pd.DataFrame] ], ..., ]: """Get data from dataset with selected subjects. Parameters ---------- dataset : BaseDataset dataset subjects : List[Union[int, str]], selected subjects, by default empty label_encode: bool, optional, if True, return y in label encode way return_concat : bool, optional if True, return concated ndarray object, otherwise return dict of events, by default False n_jobs : int, optional Parallel jobs, by default -1 verbose : Optional[bool], optional verbose, by default None Returns ------- Tuple[Union[Dict[str, Union[np.ndarray, pd.DataFrame]], Union[np.ndarray, pd.DataFrame]], ...] Xs, ys, metas, corresponding to data, label and meta data Raises ------ TypeError raise error if dataset is not avaliable for the paradigm """ if not self.is_valid(dataset): raise TypeError( "Dataset {:s} is not valid for the current paradigm. Check your events and channels settings".format( dataset.dataset_code ) ) # events, interval checking used_events, used_intervals = self._map_events_intervals(dataset) Xs = {} ys = {} metas = {} X, y, meta = zip( *Parallel(n_jobs=n_jobs)( delayed(self._get_single_subject_data)( dataset, sub_id, verbose=verbose) for sub_id in subjects ) ) for event_name in used_events.keys(): Xs[event_name] = np.concatenate( [X[i][event_name] for i in range(len(subjects)) if event_name in X[i]], axis=0, ) ys[event_name] = np.concatenate( [y[i][event_name] for i in range(len(subjects)) if event_name in y[i]], axis=0, ) metas[event_name] = pd.concat( [ meta[i][event_name] for i in range(len(subjects)) if event_name in meta[i] ], axis=0, ignore_index=True, ) if label_encode: event_list = list(used_events.keys()) event_id = [dataset.events[e][0] for e in event_list] for event_name in used_events.keys(): ys[event_name] = label_encoder(ys[event_name], event_id) # python gaurante values in insert order. if return_concat: Xs = np.concatenate(list(Xs.values()), axis=0) ys = np.concatenate(list(ys.values()), axis=0) metas = pd.concat(list(metas.values()), axis=0, ignore_index=True) return Xs, ys, metas
def __str__(self): desc = "{}".format(self.__class__.__name__) return desc
[docs]class BaseTimeEncodingParadigm(BaseParadigm): def __init__( self, channels: Optional[List[str]] = None, events: Optional[List[str]] = None, intervals: Optional[List[Tuple[float, float]]] = None, minor_event_intervals: Optional[List[Tuple[float, float]]] = None, srate: Optional[float] = None, ): super().__init__( channels=channels, events=events, intervals=intervals, srate=srate ) self._trial_hook = None self.minor_event_intervals = minor_event_intervals
[docs] def is_valid(self, dataset): pass
def _map_events_intervals(self, dataset): event_list = self.event_list intervals = self.intervals minor_event_intervals = self.minor_event_intervals if event_list is None: # If no given events, using the dataset defined events event_list = list(dataset.events.keys()) used_events = {ev: dataset.events[ev][0] for ev in event_list} if intervals is None: used_intervals = {ev: dataset.events[ev][1] for ev in event_list} elif len(intervals) == 1: used_intervals = {ev: intervals[0] for ev in event_list} else: if len(event_list) != len(intervals): raise ValueError( "Intervals should be the same number of events") used_intervals = { ev: intervals for ev, interval in zip(event_list, intervals) } # extract minor events, all the minor events should be pre-defined in the dataset minor_event_list = list(dataset.minor_events.keys()) used_minor_events = { ev: dataset.minor_events[ev][0] for ev in minor_event_list} if minor_event_intervals is None: used_minor_intervals = { ev: dataset.minor_events[ev][1] for ev in minor_event_list} elif len(minor_event_intervals) == 1: used_minor_intervals = {ev: minor_event_intervals[0] for ev in minor_event_list} else: if len(event_list) != len(intervals): raise ValueError( "Intervals should be the same number of events" ) used_minor_intervals = { ev: intervals for ev, interval in zip(minor_event_list, minor_event_intervals) } encode_dict = dataset.encode encode_loop = dataset.encode_loop return used_events, used_intervals, used_minor_events, used_minor_intervals, encode_loop, encode_dict
[docs] def register_trial_hook(self, hook): """Register trial hook before trial operation. Parameters __________ hook : callable object to process Raw object before epoch operation. Different from the raw_hook, the trial hook allows you to do some specific operation BEFORE epoch operation (i.e. smallest encode unit) and AFTER raw continuous data operation Its signature should look like: hook(raw, caches) -> raw, caches where caches is a dict storing information, raw is MNE Raw instance Returns ------- """ self._trial_hook = hook
[docs] def unregister_trial_hook(self): self._trial_hook = None
@verbose def _get_single_subject_data(self, dataset, subject_id, verbose=False): """ Parameters ---------- dataset subject_id verbose Returns ------- """ used_events, used_intervals, used_minor_events, \ used_minor_intervals, encode_loop, encode_dict = \ self._map_events_intervals(dataset) # interval equally verification intervals = list(used_minor_intervals.values()) if intervals.count(intervals[0]) == len(intervals): epoch_tmin = intervals[0][0] epoch_tmax = intervals[0][1] else: raise ValueError( 'The defined intervals of minor event do not equal, please check') Xs = {} ys = {} metas = {} data = dataset.get_data([subject_id]) for subject, sessions in data.items(): for session, runs in sessions.items(): for run, raw in runs.items(): # do raw hook either self-implemented or dataset inherited caches = {} if self._raw_hook: raw, caches = self._raw_hook(raw, caches) elif hasattr(dataset, "raw_hook"): raw, caches = dataset.raw_hook(raw, caches) # pick selected channels by order channels = ( dataset.channels if self.select_channels is None else self.select_channels ) picks = pick_channels(raw.ch_names, channels, ordered=True) stim_channels = mne.utils._get_stim_channel( None, raw.info, raise_error=False ) if len(stim_channels) > 0: events = mne.find_events( raw, shortest_event=0, initial_event=True ) else: events, _ = mne.events_from_annotations( raw, event_id=(lambda x: int(x)) ) # extract main events main_events = mne.pick_events( events, include=list(used_events.values()) ) for event_name in used_events.keys(): # mne.pick_events returns any matching events in include # only raise Runtime Error when nothing is found # then we just skip this event try: selected_events = mne.pick_events( events, include=used_events[event_name] ) except RuntimeError: continue # Find trial_index in the original events series trial_index = list(np.argwhere( main_events[:, -1] == selected_events[0, 2] )) selected_annots = mne.annotations_from_events( selected_events, sfreq=raw.info['sfreq']) selected_annots.set_durations( used_intervals[event_name][1] - used_intervals[event_name][0]) unit_raws = raw.copy().crop_by_annotations(annotations=selected_annots) try: unit_encode = encode_dict[event_name] except Exception: raise Exception( "Dataset does not contain the encode key {:s}".format( event_name) ) if isinstance(encode_loop, dict): try: encode_loop_size = encode_loop[event_name] except Exception: raise Exception( "Dataset does not contain the encode key {:s}".format( event_name) ) elif isinstance(encode_loop, int): encode_loop_size = encode_loop else: raise TypeError( "Unknown encode_loop type" ) for unit_raw in unit_raws: # do trial hook if self._trial_hook: unit_raw, caches = self._trial_hook( unit_raw, caches) elif hasattr(dataset, "epochs_hook"): unit_raw, caches = dataset.trial_hook( unit_raw, caches) # Try to extract minor events minor_events = mne.find_events( unit_raw, shortest_event=0, initial_event=True ) minor_events = np.delete(minor_events, 0, axis=0) selected_minor_events = mne.pick_events(minor_events, include=list(used_minor_events.values())) # transform Raw to Epochs epochs = mne.Epochs( unit_raw, selected_minor_events, event_id=used_minor_events, event_repeated="drop", tmin=epoch_tmin, tmax=epoch_tmax - 1.0 / unit_raw.info['sfreq'], picks=picks, proj=False, baseline=None, preload=True, on_missing='ignore' ) # skip invalid time intervals if len(epochs) == 0: continue # check if the len of epochs matches with setting parameters if epochs.__len__() != len(unit_encode) * encode_loop_size: raise RuntimeError( "The setting parameters does not match the Epoch length" ) # do epochs hook if self._epochs_hook: epochs, caches = self._epochs_hook( epochs, caches) elif hasattr(dataset, "epochs_hook"): epochs, caches = dataset.epochs_hook( epochs, caches) # Get all epochs within a single 'character' event. unit_X = epochs.get_data() * 1e6 unit_y = epochs.events[:, -1] # trial_id is the index in the original event series of raw # for the time encode paradigms, the trial_id indicate the index of main events trial_id = trial_index[0] trial_index.pop(0) # Unlike the base paradigm class, we manually process a single trial # So the meta only contains a single trial info meta = pd.DataFrame( { "subject": [subject], "session": [session], "run": [run], "event": [event_name], "trial_id": trial_id, "dataset": [dataset.dataset_code], "code": [unit_encode] } ) # collecting data pre_X = Xs.get(event_name) if pre_X is not None: Xs[event_name].append(unit_X) else: Xs[event_name] = list() Xs[event_name].append(unit_X) pre_y = ys.get(event_name) if pre_y is not None: ys[event_name].append(unit_y) else: ys[event_name] = list() ys[event_name].append(unit_y) pre_meta = metas.get(event_name) if pre_meta is not None: metas[event_name] = pd.concat( (pre_meta, meta), axis=0, ignore_index=True ) else: metas[event_name] = meta if self._data_hook: Xs, ys, metas, caches = self._data_hook( Xs, ys, metas, caches) elif hasattr(dataset, "data_hook"): Xs, ys, metas, caches = dataset.data_hook( Xs, ys, metas, caches) return Xs, ys, metas
[docs] @verbose def get_data( self, dataset: BaseTimeEncodingDataset, subjects: List[Union[int, str]] = [], return_concat: bool = False, n_jobs: int = -1, verbose: Optional[bool] = None, ): if not self.is_valid(dataset): raise TypeError( "Dataset {:s} is not valid for the current paradigm. Check your events and channels settings".format( dataset.dataset_code ) ) used_events, used_intervals, used_minor_events, \ used_minor_intervals, encode_loop, encode_dict = \ self._map_events_intervals(dataset) Xs = [] ys = [] metas = {} # Need to sort here # due to the subject data are storage in list in sequence subjects.sort() X, y, meta = zip( *Parallel(n_jobs=n_jobs)( delayed(self._get_single_subject_data)( dataset, sub_id, verbose=verbose) for sub_id in subjects ) ) for event_name in used_events.keys(): for i in range(len(subjects)): if event_name in X[i]: for j in range(len(X[i][event_name])): Xs.append(X[i][event_name][j]) for i in range(len(subjects)): if event_name in y[i]: for j in range(len(y[i][event_name])): ys.append(y[i][event_name][j]) if event_name in meta[i]: metas[event_name] = pd.concat( [ meta[i][event_name] for i in range(len(subjects)) if event_name in meta[i] ], axis=0, ignore_index=True ) metas = pd.concat(list(metas.values()), axis=0, ignore_index=True) return Xs, ys, metas