Source code for metabci.brainda.datasets.schirrmeister2017

# -*- coding: utf-8 -*-
#
# Authors: Swolf <swolfforever@gmail.com>
# Date: 2021/3/4
# License: MIT License
"""
High-gamma dataset.
"""
import re
from typing import Union, Optional, Dict, List
from pathlib import Path
import warnings
import numpy as np
import h5py
import mne
from mne.io import Raw
from mne.channels import make_standard_montage
from .base import BaseDataset
from ..utils.download import mne_data_path
from ..utils.channels import upper_ch_names

GIN_URL = "https://web.gin.g-node.org/robintibor/high-gamma-dataset/raw/master/data"


[docs]class Schirrmeister2017(BaseDataset): """High-gamma dataset discribed in Schirrmeister et al. 2017 Our “High-Gamma Dataset” is a 128-electrode dataset (of which we later only use 44 sensors covering the motor cortex, (see Section 2.7.1), obtained from 14 healthy subjects (6 female, 2 left-handed, age 27.2 ± 3.6 (mean ± std)) with roughly 1000 (963.1 ± 150.9, mean ± std) four-second trials of executed movements divided into 13 runs per subject. The four classes of movements were movements of either the left hand, the right hand, both feet, and rest (no movement, but same type of visual cue as for the other classes). The training set consists of the approx. 880 trials of all runs except the last two runs, the test set of the approx. 160 trials of the last 2 runs. This dataset was acquired in an EEG lab optimized for non-invasive detection of high- frequency movement-related EEG components (Ball et al., 2008; Darvas et al., 2010). Depending on the direction of a gray arrow that was shown on black back- ground, the subjects had to repetitively clench their toes (downward arrow), perform sequential finger-tapping of their left (leftward arrow) or right (rightward arrow) hand, or relax (upward arrow). The movements were selected to require little proximal muscular activity while still being complex enough to keep subjects in- volved. Within the 4-s trials, the subjects performed the repetitive movements at their own pace, which had to be maintained as long as the arrow was showing. Per run, 80 arrows were displayed for 4 s each, with 3 to 4 s of continuous random inter-trial interval. The order of presentation was pseudo-randomized, with all four arrows being shown every four trials. Ideally 13 runs were performed to collect 260 trials of each movement and rest. The stimuli were presented and the data recorded with BCI2000 (Schalk et al., 2004). The experiment was approved by the ethical committee of the University of Freiburg. References ---------- .. [1] Schirrmeister, Robin Tibor, et al. "Deep learning with convolutional neural networks for EEG decoding and visualization." Human brain mapping 38.11 (2017): 5391-5420. """ _EVENTS = { "right_hand": (1, (0, 4)), "left_hand": (2, (0, 4)), "rest": (3, (0, 4)), "feet": (4, (0, 4)), } _CHANNELS = [ "FP1", "FP2", "FPZ", "F7", "F3", "FZ", "F4", "F8", "FC5", "FC1", "FC2", "FC6", "T7", "C3", "CZ", "C4", "T8", "CP5", "CP1", "CP2", "CP6", "P7", "P3", "PZ", "P4", "P8", "POZ", "O1", "OZ", "O2", "AF7", "AF3", "AF4", "AF8", "F5", "F1", "F2", "F6", "FC3", "FCZ", "FC4", "C5", "C1", "C2", "C6", "CP3", "CPZ", "CP4", "P5", "P1", "P2", "P6", "PO5", "PO3", "PO4", "PO6", "FT7", "FT8", "TP7", "TP8", "PO7", "PO8", "FT9", "FT10", "TPP9H", "TPP10H", "PO9", "PO10", "P9", "P10", "AFF1", "AFZ", "AFF2", "FFC5H", "FFC3H", "FFC4H", "FFC6H", "FCC5H", "FCC3H", "FCC4H", "FCC6H", "CCP5H", "CCP3H", "CCP4H", "CCP6H", "CPP5H", "CPP3H", "CPP4H", "CPP6H", "PPO1", "PPO2", "I1", "IZ", "I2", "AFP3H", "AFP4H", "AFF5H", "AFF6H", "FFT7H", "FFC1H", "FFC2H", "FFT8H", "FTT9H", "FTT7H", "FCC1H", "FCC2H", "FTT8H", "FTT10H", "TTP7H", "CCP1H", "CCP2H", "TTP8H", "TPP7H", "CPP1H", "CPP2H", "TPP8H", "PPO9H", "PPO5H", "PPO6H", "PPO10H", "POO9H", "POO3H", "POO4H", "POO10H", "OI1H", "OI2H", ] def __init__(self): super().__init__( dataset_code="schirrmeister2017", subjects=list(range(1, 15)), events=self._EVENTS, channels=self._CHANNELS, srate=500, paradigm="imagery", )
[docs] def data_path( self, subject: Union[str, int], path: Optional[Union[str, Path]] = None, force_update: bool = False, update_path: Optional[bool] = None, proxies: Optional[Dict[str, str]] = None, verbose: Optional[Union[bool, str, int]] = None, ) -> List[List[Union[str, Path]]]: if subject not in self.subjects: raise (ValueError("Invalid subject id")) dests = [] base_url = "{u:s}/{t:s}/{s:d}.mat" dests = [ [ mne_data_path( base_url.format(u=GIN_URL, t=t, s=subject), self.dataset_code, path=path, proxies=proxies, force_update=force_update, update_path=update_path, ) for t in ["train", "test"] ] ] return dests
def _get_single_subject_data( self, subject: Union[str, int], verbose: Optional[Union[bool, str, int]] = None ) -> Dict[str, Dict[str, Raw]]: dests = self.data_path(subject) montage = make_standard_montage("standard_1005") montage.rename_channels( {ch_name: ch_name.upper() for ch_name in montage.ch_names} ) # montage.ch_names = [ch_name.upper() for ch_name in montage.ch_names] sess = dict() for isess, run_dests in enumerate(dests): runs = dict() for irun, run_array in enumerate(run_dests): raw = BBCIDataset(run_array).load() raw = upper_ch_names(raw) raw.set_montage(montage) runs["run_{:d}".format(irun)] = raw sess["session_{:d}".format(isess)] = runs return sess
[docs]class BBCIDataset(object): """ Loader class for files created by saving BBCI files in matlab (make sure to save with '-v7.3' in matlab, see https://de.mathworks.com/help/matlab/import_export/mat-file-versions.html#buk6i87 ) Parameters ---------- filename: str load_sensor_names: list of str, optional Also speeds up loading if you only load some sensors. None means load all sensors. Copyright Robin Schirrmeister, 2017 Altered by Vinay Jayaram, 2018 """ def __init__(self, filename, load_sensor_names=None): self.__dict__.update(locals()) del self.self
[docs] def load(self): cnt = self._load_continuous_signal() cnt = self._add_markers(cnt) return cnt
def _load_continuous_signal(self): wanted_chan_inds, wanted_sensor_names = self._determine_sensors() fs = self._determine_samplingrate() with h5py.File(self.filename, "r") as h5file: samples = int(h5file["nfo"]["T"][0, 0]) cnt_signal_shape = (samples, len(wanted_chan_inds)) continuous_signal = np.ones(cnt_signal_shape, dtype=np.float32) * np.nan for chan_ind_arr, chan_ind_set in enumerate(wanted_chan_inds): # + 1 because matlab/this hdf5-naming logic # has 1-based indexing # i.e ch1,ch2,.... chan_set_name = "ch" + str(chan_ind_set + 1) # first 0 to unpack into vector, before it is 1xN matrix chan_signal = h5file[chan_set_name][ : ].squeeze() # already load into memory continuous_signal[:, chan_ind_arr] = chan_signal assert not np.any(np.isnan(continuous_signal)), "No NaNs expected in signal" # Assume we cant know channel type here automatically ch_types = ["eeg"] * len(wanted_chan_inds) info = mne.create_info( ch_names=wanted_sensor_names, sfreq=fs, ch_types=ch_types ) # Scale to volts from microvolts, (VJ 19.6.18) continuous_signal = continuous_signal * 1e-6 cnt = mne.io.RawArray(continuous_signal.T, info) return cnt def _determine_sensors(self): all_sensor_names = self.get_all_sensors(self.filename, pattern=None) if self.load_sensor_names is None: # if no sensor names given, take all EEG-chans eeg_sensor_names = all_sensor_names eeg_sensor_names = filter( lambda s: not s.startswith("BIP"), eeg_sensor_names ) eeg_sensor_names = filter(lambda s: not s.startswith("E"), eeg_sensor_names) eeg_sensor_names = filter( lambda s: not s.startswith("Microphone"), eeg_sensor_names ) eeg_sensor_names = filter( lambda s: not s.startswith("Breath"), eeg_sensor_names ) eeg_sensor_names = filter( lambda s: not s.startswith("GSR"), eeg_sensor_names ) eeg_sensor_names = list(eeg_sensor_names) assert len(eeg_sensor_names) in set( [128, 64, 32, 16] ), "check this code if you have different sensors..." # noqa self.load_sensor_names = eeg_sensor_names chan_inds = self._determine_chan_inds(all_sensor_names, self.load_sensor_names) return chan_inds, self.load_sensor_names def _determine_samplingrate(self): with h5py.File(self.filename, "r") as h5file: fs = h5file["nfo"]["fs"][0, 0] assert isinstance(fs, int) or fs.is_integer() fs = int(fs) return fs @staticmethod def _determine_chan_inds(all_sensor_names, sensor_names): assert sensor_names is not None chan_inds = [all_sensor_names.index(s) for s in sensor_names] assert len(chan_inds) == len(sensor_names), "All" "sensors" "should be there." # TODO: is it possible for this to fail? the list # comp fails first right? assert len(set(chan_inds)) == len(chan_inds), ( "No" "duplicated sensors" "wanted." ) return chan_inds
[docs] @staticmethod def get_all_sensors(filename, pattern=None): """ Get all sensors that exist in the given file. Parameters ---------- filename: str pattern: str, optional Only return those sensor names that match the given pattern. Returns ------- sensor_names: list of str Sensor names that match the pattern or all sensor names in the file. """ with h5py.File(filename, "r") as h5file: clab_set = h5file["nfo"]["clab"][:].squeeze() all_sensor_names = [ "".join(chr(c.squeeze()) for c in h5file[obj_ref]) for obj_ref in clab_set ] if pattern is not None: all_sensor_names = filter( lambda sname: re.search(pattern, sname), all_sensor_names ) return all_sensor_names
def _add_markers(self, cnt): with h5py.File(self.filename, "r") as h5file: event_times_in_ms = h5file["mrk"]["time"][:].squeeze() event_classes = h5file["mrk"]["event"]["desc"][:].squeeze().astype(np.int64) # Check whether class names known and correct order # class_name_set = h5file['nfo']['className'][:].squeeze() # all_class_names = [''.join(chr(c) for c in h5file[obj_ref]) # for obj_ref in class_name_set] event_times_in_samples = event_times_in_ms * cnt.info["sfreq"] / 1000.0 event_times_in_samples = np.uint32(np.round(event_times_in_samples)) # Check if there are markers at the same time previous_i_sample = -1 for i_event, (i_sample, id_class) in enumerate( zip(event_times_in_samples, event_classes) ): if i_sample == previous_i_sample: info = "{:d}: ({:.0f} and {:.0f}).\n".format( i_sample, event_classes[i_event - 1], event_classes[i_event] ) warnings.warn( "Same sample has at least two markers.\n" + info + "Marker codes will be summed." ) previous_i_sample = i_sample # Now create stim chan stim_chan = np.zeros_like(cnt.get_data()[0]) for i_sample, id_class in zip(event_times_in_samples, event_classes): stim_chan[i_sample] += id_class info = mne.create_info( ch_names=["STI 014"], sfreq=cnt.info["sfreq"], ch_types=["stim"] ) stim_cnt = mne.io.RawArray(stim_chan[None], info, verbose="WARNING") cnt = cnt.add_channels([stim_cnt]) event_arr = [ event_times_in_samples, [0] * len(event_times_in_samples), event_classes, ] cnt.info["events"] = np.array(event_arr).T return cnt