Source code for metabci.brainda.algorithms.deep_learning.base

# -*- coding: utf-8 -*-
#
# Authors: Swolf <swolfforever@gmail.com>
# Date: 2021/1/23
# License: MIT License
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from skorch.classifier import NeuralNetClassifier
from skorch.dataset import ValidSplit
from skorch.callbacks import LRScheduler, EpochScoring, Checkpoint, EarlyStopping


def compute_out_size(
    input_size: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
):
    """Return the output length of a strided, dilated, padded 1-D window.

    Parameters
    ----------
    input_size: int
        Length of the input along the axis of interest.
    kernel_size: int
        Number of taps in the kernel.
    stride: int
        Step between consecutive window positions.
    padding: int
        Padding added to *each* side of the input.
    dilation: int
        Spacing between kernel taps.

    Returns
    -------
    int
        ``(input_size + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1``
        truncated towards zero by ``int``.
    """
    padded_size = input_size + 2 * padding
    # Span covered by the kernel once dilation gaps are included.
    effective_kernel = dilation * (kernel_size - 1) + 1
    return int((padded_size - effective_kernel) / stride + 1)
def compute_same_pad1d(input_size, kernel_size, stride=1, dilation=1):
    """Compute the two-sided padding that keeps a 1-D output as long as its input.

    The total padding ``P`` solves
    ``(input_size + P - dilation*(kernel_size-1) - 1) / stride + 1 == input_size``.

    Parameters
    ----------
    input_size: int
        Length of the input axis.
    kernel_size: int
        Number of taps in the kernel.
    stride: int
        Step between consecutive window positions.
    dilation: int
        Spacing between kernel taps.

    Returns
    -------
    (int, int)
        Padding before and after the axis; when the total is odd the extra
        element goes to the second entry.
    """
    effective_kernel = dilation * (kernel_size - 1) + 1
    total = (stride - 1) * input_size - stride + effective_kernel
    before = total // 2
    return (before, total - before)
def compute_same_pad2d(input_size, kernel_size, stride=(1, 1), dilation=(1, 1)):
    """Compute size-preserving padding for a 2-D operation.

    Each argument is indexed at positions 0 and 1; position 0 is treated as
    the up/down axis and position 1 as the left/right axis.

    Parameters
    ----------
    input_size: (int, int)
        Input extent along the two axes.
    kernel_size: (int, int)
        Kernel extent along the two axes.
    stride: (int, int)
        Stride along the two axes.
    dilation: (int, int)
        Dilation along the two axes.

    Returns
    -------
    list of int
        ``[left, right, up, down]`` padding amounts.
    """
    vertical, horizontal = 0, 1
    up_down = compute_same_pad1d(
        input_size[vertical],
        kernel_size[vertical],
        stride=stride[vertical],
        dilation=dilation[vertical],
    )
    left_right = compute_same_pad1d(
        input_size[horizontal],
        kernel_size[horizontal],
        stride=stride[horizontal],
        dilation=dilation[horizontal],
    )
    # The last-axis (left/right) pair comes first in the returned list.
    return list(left_right) + list(up_down)
class MaxNormConstraintConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose weights are re-scaled to a maximum norm on every forward.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``nn.Conv2d``.
    max_norm_value: float
        Upper bound for the L2 norm taken along ``norm_axis``.
    norm_axis: int
        Axis of the weight tensor along which the L2 norm is computed.
    """

    def __init__(self, *args, max_norm_value=1, norm_axis=2, **kwargs):
        self.max_norm_value = max_norm_value
        self.norm_axis = norm_axis
        super().__init__(*args, **kwargs)

    def forward(self, input):
        """Clip the weight norms in place, then run the regular convolution."""
        self.weight.data = self._max_norm(self.weight.data)
        return super().forward(input)

    def _max_norm(self, w):
        """Scale ``w`` in place so norms along ``norm_axis`` do not exceed the limit.

        Slices whose norm is already within range are left untouched, and
        all-zero slices stay zero instead of turning into NaN through 0/0.
        """
        with torch.no_grad():
            # similar behavior as keras MaxNorm constraint
            norms = torch.sqrt(
                torch.sum(torch.square(w), dim=self.norm_axis, keepdim=True)
            )
            desired = torch.clamp(norms, 0, self.max_norm_value)
            # instead of desired/(eps+norm), without changing norm in range;
            # a zero norm needs no rescaling, so use a factor of 1 there.
            scale = torch.where(norms > 0, desired / norms, torch.ones_like(norms))
            w *= scale
        return w
class MaxNormConstraintLinear(nn.Linear):
    """``nn.Linear`` whose weights are re-scaled to a maximum norm on every forward.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``nn.Linear``.
    max_norm_value: float
        Upper bound for the L2 norm taken along ``norm_axis``.
    norm_axis: int
        Axis of the weight matrix along which the L2 norm is computed.
    """

    def __init__(self, *args, max_norm_value=1, norm_axis=0, **kwargs):
        self.max_norm_value = max_norm_value
        self.norm_axis = norm_axis
        super().__init__(*args, **kwargs)

    def forward(self, input):
        """Clip the weight norms in place, then apply the regular linear map."""
        self.weight.data = self._max_norm(self.weight.data)
        return super().forward(input)

    def _max_norm(self, w):
        """Scale ``w`` in place so norms along ``norm_axis`` do not exceed the limit.

        Slices whose norm is already within range are left untouched, and
        all-zero slices stay zero instead of turning into NaN through 0/0.
        """
        with torch.no_grad():
            # similar behavior as keras MaxNorm constraint
            norms = torch.sqrt(
                torch.sum(torch.square(w), dim=self.norm_axis, keepdim=True)
            )
            desired = torch.clamp(norms, 0, self.max_norm_value)
            # instead of desired/(eps+norm), without changing norm in range;
            # a zero norm needs no rescaling, so use a factor of 1 there.
            scale = torch.where(norms > 0, desired / norms, torch.ones_like(norms))
            w *= scale
        return w
def _batch_channel_stats(x):
    """Return the per-channel mean and variance of ``x``.

    Axis 1 is treated as the channel axis; every other axis is reduced, so
    the same code serves 2-D, 3-D, 4-D and 5-D batch-norm inputs.
    """
    reduce_dims = tuple(axis for axis in range(x.dim()) if axis != 1)
    return torch.mean(x, dim=reduce_dims), torch.var(x, dim=reduce_dims)


def _adabn_pre_forward_hook(self, inputs):
    """Forward pre-hook that accumulates running statistics over all batches seen.

    Intended to be registered on a batch-norm module; ``self`` is that module
    and ``inputs`` is the tuple of positional forward arguments.  On the first
    call the running mean/variance are reset (mean 0, variance 1) and a sample
    counter ``num_samples_tracked`` is attached to the module.  Every call
    then merges the statistics of the current batch into the running values,
    weighting by the number of samples.  Returns ``None``, so the forward
    input itself is left unchanged.
    """
    old_training_state = self.training
    self.eval()
    with torch.no_grad():
        if not hasattr(self, "num_samples_tracked"):
            self.num_samples_tracked = 0
            self.running_mean.data.zero_()
            self.running_var.data.fill_(1)
        k = len(inputs[0])
        self.num_samples_tracked += k
        mean, var = _batch_channel_stats(inputs[0])
        # see https://www.sciencedirect.com/science/article/abs/pii/S003132031830092X
        d = mean - self.running_mean.data
        self.running_mean.data.add_(d * k / self.num_samples_tracked)
        self.running_var.data.add_(
            (var - self.running_var.data) * k / self.num_samples_tracked
            + torch.square(d)
            * k
            * (self.num_samples_tracked - k)
            / (self.num_samples_tracked**2)
        )
    # Restore the mode the module was in before the hook ran.
    if old_training_state:
        self.train()


def _global_adabn_pre_forward_hook(self, inputs):
    """Forward pre-hook that overwrites running statistics with the current batch.

    Intended to be registered on a batch-norm module; ``self`` is that module
    and ``inputs`` is the tuple of positional forward arguments.  The running
    mean and variance are replaced by the per-channel mean and variance of
    ``inputs[0]``.  Returns ``None``, so the forward input itself is left
    unchanged.
    """
    old_training_state = self.training
    self.eval()
    # global AdaBN
    with torch.no_grad():
        mean, var = _batch_channel_stats(inputs[0])
        self.running_mean.data.zero_()
        self.running_mean.data.add_(mean)
        self.running_var.data.zero_()
        self.running_var.data.add_(var)
    # Restore the mode the module was in before the hook ran.
    if old_training_state:
        self.train()
def adaptive_batch_norm(model, use_global=False):
    """Attach an AdaBN forward pre-hook to every batch-norm module of ``model``.

    A module is selected when its class name contains ``"BatchNorm"``.

    Parameters
    ----------
    model: Module
        Network whose sub-modules are scanned.
    use_global: bool
        If True register ``_global_adabn_pre_forward_hook``, otherwise
        ``_adabn_pre_forward_hook``.

    Returns
    -------
    model: Module
        The same model object that was passed in.
    handles: list
        Objects returned by ``register_forward_pre_hook``, one per hooked module.
    """
    if use_global:
        hook = _global_adabn_pre_forward_hook
    else:
        hook = _adabn_pre_forward_hook
    handles = [
        module.register_forward_pre_hook(hook)
        for module in model.modules()
        if "BatchNorm" in module.__class__.__name__
    ]
    return model, handles
def _init_weights_zero_bias(model, init_weight):
    """Apply ``init_weight`` to non-batch-norm weights, 1 to batch-norm weights, 0 to biases.

    ``weight`` / ``bias`` attributes that are missing, ``None`` (e.g. batch
    norm without affine parameters) or not tensors (e.g. a boolean ``bias``
    flag) are skipped instead of being handed to ``nn.init``.
    """
    for module in model.modules():
        weight = getattr(module, "weight", None)
        if isinstance(weight, torch.Tensor):
            if "BatchNorm" in module.__class__.__name__:
                nn.init.constant_(weight, 1)
            else:
                init_weight(weight)
        bias = getattr(module, "bias", None)
        if isinstance(bias, torch.Tensor):
            nn.init.constant_(bias, 0)


def _glorot_weight_zero_bias(model):
    """Initialize parameters of all modules by initializing weights with
    glorot uniform/xavier initialization, and setting biases to zero.
    Weights from batch norm layers are set to 1.

    Parameters
    ----------
    model: Module
    """
    _init_weights_zero_bias(
        model, lambda weight: nn.init.xavier_uniform_(weight, gain=1)
    )


def _narrow_normal_weight_zero_bias(model):
    """Initialize parameters of all modules by initializing weights with
    narrow normal distribution N(0, 0.01), and setting biases to zero.
    Weights from batch norm layers are set to 1.

    Parameters
    ----------
    model: Module
    """
    _init_weights_zero_bias(
        model, lambda weight: nn.init.normal_(weight, mean=0.0, std=1e-2)
    )
class NeuralNetClassifierNoLog(NeuralNetClassifier):
    """``NeuralNetClassifier`` variant that skips the parent's ``get_loss`` and ``fit``.

    Both overrides dispatch with ``super(NeuralNetClassifier, self)``, i.e. to
    the class that follows ``NeuralNetClassifier`` in the MRO.
    NOTE(review): judging by the class name this is meant to avoid a log
    transform applied by ``NeuralNetClassifier.get_loss`` -- confirm against
    the skorch version in use.
    """

    def get_loss(self, y_pred, y_true, *args, **kwargs):
        """Compute the loss using the implementation above ``NeuralNetClassifier``."""
        grandparent = super(NeuralNetClassifier, self)
        return grandparent.get_loss(y_pred, y_true, *args, **kwargs)

    def fit(self, X, y, **fit_params):
        """Fit the net, then reload parameters from the ``"checkpoint"`` callback if any.

        Returns whatever the underlying ``fit`` returned.
        """
        fitted = super(NeuralNetClassifier, self).fit(X, y, **fit_params)
        named_callbacks = OrderedDict(fitted.callbacks)
        if "checkpoint" not in named_callbacks:
            return fitted
        fitted.load_params(checkpoint=named_callbacks["checkpoint"])
        return fitted
class SkorchNet:
    """Wrap a module class so that calling it yields a preconfigured skorch classifier.

    Parameters
    ----------
    module: callable
        Called with the arguments given to ``__call__`` to build the model.
    """

    def __init__(self, module):
        self.module = module

    def __call__(self, *args, **kwargs):
        """Instantiate the wrapped module and return it inside a ``NeuralNetClassifierNoLog``.

        The model is converted with ``.double()`` and the checkpoint directory
        name is derived from ``id(model)``.
        """
        model = self.module(*args, **kwargs)

        max_epochs = 300
        checkpoint_dir = "checkpoints/{:s}".format(str(id(model)))

        train_accuracy = EpochScoring(
            "accuracy",
            name="train_acc",
            on_train=True,
            lower_is_better=False,
        )
        callbacks = [
            ("train_acc", train_accuracy),
            # The cosine schedule spans the full epoch budget.
            ("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=max_epochs - 1)),
            ("estoper", EarlyStopping(patience=50)),
            ("checkpoint", Checkpoint(dirname=checkpoint_dir)),
        ]

        return NeuralNetClassifierNoLog(
            model.double(),
            criterion=nn.CrossEntropyLoss,
            optimizer=optim.Adam,
            optimizer__weight_decay=0,
            batch_size=128,
            lr=1e-2,
            max_epochs=max_epochs,
            device="cpu",
            train_split=ValidSplit(0.2, stratified=True),
            iterator_train__shuffle=True,
            callbacks=callbacks,
            verbose=True,
        )
def np_to_th(X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs):
    """
    Convenience function to transform numpy array to `torch.Tensor`.

    Objects without ``__len__`` (plain numbers) are wrapped in a one-element
    list first; everything is then passed through ``np.asarray``.

    Parameters
    ----------
    X: ndarray or list or number
        Input data.
    requires_grad: bool
        Passed on to ``torch.tensor``.
    dtype: numpy dtype, optional
        If given, the array is cast with ``astype`` before conversion.
    pin_memory: bool
        If True, ``pin_memory()`` is called on the resulting tensor.
    tensor_kwargs:
        Additional keyword arguments passed on to ``torch.tensor``.

    Returns
    -------
    var: `torch.Tensor`
    """
    array = np.asarray(X if hasattr(X, "__len__") else [X])
    if dtype is not None:
        array = array.astype(dtype)
    tensor = torch.tensor(array, requires_grad=requires_grad, **tensor_kwargs)
    return tensor.pin_memory() if pin_memory else tensor
class AvgPool2dWithConv(nn.Module):
    """
    Compute average pooling using a convolution, to have the dilation parameter.

    Parameters
    ----------
    kernel_size: (int,int)
        Size of the pooling region.
    stride: (int,int)
        Stride of the pooling operation.
    dilation: int or (int,int)
        Dilation applied to the pooling filter.
    padding: int or (int,int)
        Padding applied before the pooling operation.
    """

    def __init__(self, kernel_size, stride, dilation=1, padding=0):
        super(AvgPool2dWithConv, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        # don't name them "weights" to
        # make sure these are not accidentally used by some procedure
        # that initializes parameters or something
        self._pool_weights = None

    def forward(self, x):
        """Average-pool ``x`` with a depthwise convolution of constant weights.

        The constant kernel is cached and rebuilt whenever the channel count,
        dtype or device of ``x`` differs from the cached one.  It is created
        directly in ``x``'s dtype and on ``x``'s device, so float64 inputs are
        averaged with full-precision weights and inputs on any device get
        weights on that same device.
        """
        in_channels = x.size()[1]
        weight_shape = (
            in_channels,
            1,
            self.kernel_size[0],
            self.kernel_size[1],
        )
        cached = self._pool_weights
        if (
            cached is None
            or tuple(cached.size()) != weight_shape
            or cached.dtype != x.dtype
            or cached.device != x.device
        ):
            n_pool = self.kernel_size[0] * self.kernel_size[1]
            self._pool_weights = torch.full(
                weight_shape, 1.0 / n_pool, dtype=x.dtype, device=x.device
            )
        # groups=in_channels: every channel is pooled independently.
        pooled = F.conv2d(
            x,
            self._pool_weights,
            bias=None,
            stride=self.stride,
            dilation=self.dilation,
            padding=self.padding,
            groups=in_channels,
        )
        return pooled
[docs]class Expression(nn.Module): """Compute given expression on forward pass. Parameters ---------- expression_fn : callable Should accept variable number of objects of type `torch.autograd.Variable` to compute its output. """ def __init__(self, expression_fn): super(Expression, self).__init__() self.expression_fn = expression_fn
[docs] def forward(self, *x): return self.expression_fn(*x)
def __repr__(self): if hasattr(self.expression_fn, "func") and hasattr( self.expression_fn, "kwargs" ): expression_str = "{:s} {:s}".format( self.expression_fn.func.__name__, str(self.expression_fn.kwargs) ) elif hasattr(self.expression_fn, "__name__"): expression_str = self.expression_fn.__name__ else: expression_str = repr(self.expression_fn) return self.__class__.__name__ + "(expression=%s) " % expression_str
class Ensure4d(nn.Module):
    """Append trailing singleton axes until the input has at least four dimensions."""

    def forward(self, x):
        """Return ``x`` with size-1 axes appended at the end while it has fewer than 4 dims."""
        for _ in range(4 - len(x.shape)):
            x = x.unsqueeze(-1)
        return x
def identity(x):
    """Return ``x`` unchanged (the very same object)."""
    return x
def transpose_time_to_spat(x):
    """Swap time and spatial dimensions of a 4-D tensor.

    Returns
    -------
    x: torch.Tensor
        tensor in which axes 1 and 3 are exchanged; axes 0 and 2 stay in place
    """
    # Keep axes 0 and 2, exchange axes 1 and 3.
    swapped_order = (0, 3, 2, 1)
    return x.permute(*swapped_order)
def squeeze_final_output(x):
    """Drop the size-1 axis 3 and, if it is also of size 1, axis 2.

    A plain ``squeeze`` is avoided on purpose so that the first dimension is
    never removed.

    Returns
    -------
    x: torch.Tensor
        squeezed tensor
    """
    assert x.size()[3] == 1
    squeezed = x.select(3, 0)
    if squeezed.size()[2] == 1:
        squeezed = squeezed.select(2, 0)
    return squeezed