# -*- coding: utf-8 -*-
#
# Authors: Swolf <swolfforever@gmail.com>
# Date: 2021/1/23
# License: MIT License
"""
Riemannian Procrustes Analysis.
Modified from https://github.com/plcrodrigues/RPA
"""
from typing import Optional
from functools import partial
import warnings

import numpy as np
from numpy import ndarray
import autograd.numpy as anp

try:
    from pymanopt.manifolds import Rotations
except Exception:
    from pymanopt.manifolds import SpecialOrthogonalGroup as Rotations
from pymanopt import Problem

try:
    from pymanopt.solvers import SteepestDescent
except Exception:
    from pymanopt.optimizers import SteepestDescent

from ..utils.covariance import covariances, sqrtm, invsqrtm, logm, powm
from .riemann import mean_riemann, distance_riemann
def get_recenter(
    X: ndarray, cov_method: str = "cov", mean_method: str = "riemann", n_jobs: int = 1
):
    """Estimate the re-centering matrix of Riemannian Procrustes Analysis.

    The trials are flattened to shape ``(-1, *X.shape[-2:])`` and demeaned
    along the last axis. Their covariance matrices are then averaged. The
    inverse matrix square root of that mean is returned, ready to be applied
    with ``iM12 @ X``.

    Parameters
    ----------
    X : ndarray
        Trials; the last two axes form one trial (presumably
        ``(n_channels, n_samples)`` -- TODO confirm).
    cov_method : str
        Estimator name forwarded to ``covariances(..., estimator=...)``.
    mean_method : {"riemann", "euclid"}
        ``"riemann"`` uses ``mean_riemann``; ``"euclid"`` uses the arithmetic
        mean of the covariance matrices.
    n_jobs : int
        Parallel jobs forwarded to ``covariances`` / ``mean_riemann``.

    Returns
    -------
    iM12 : ndarray
        ``invsqrtm`` of the mean covariance matrix.

    Raises
    ------
    ValueError
        If ``mean_method`` is not ``"riemann"`` or ``"euclid"``.
    """
    # Validate before the (potentially expensive) covariance estimation; an
    # unknown method used to surface later as an UnboundLocalError on ``M``.
    if mean_method not in ("riemann", "euclid"):
        raise ValueError(
            f"mean_method must be 'riemann' or 'euclid', got {mean_method!r}"
        )
    X = np.reshape(X, (-1, *X.shape[-2:]))
    X = X - np.mean(X, axis=-1, keepdims=True)
    C = covariances(X, estimator=cov_method, n_jobs=n_jobs)
    if mean_method == "riemann":
        M = mean_riemann(C, n_jobs=n_jobs)
    else:
        M = np.mean(C, axis=0)
    iM12 = invsqrtm(M)
    return iM12
def recenter(X: ndarray, iM12: ndarray):
    """Re-center every trial with a precomputed matrix.

    The input is flattened to shape ``(-1, *X.shape[-2:])``. Each trial is
    demeaned along its last axis and then left-multiplied by ``iM12``
    (typically the output of ``get_recenter``).

    Parameters
    ----------
    X : ndarray
        Trials; the last two axes form one trial.
    iM12 : ndarray
        Matrix applied on the left of every demeaned trial.

    Returns
    -------
    ndarray
        The transformed trials, stacked along the first axis.
    """
    trials = np.reshape(X, (-1,) + tuple(X.shape[-2:]))
    centered = trials - trials.mean(axis=-1, keepdims=True)
    return np.matmul(iM12, centered)
def get_rescale(X: ndarray, cov_method: str = "cov", n_jobs: int = 1):
    """Estimate the mean and stretching factor used by :func:`rescale`.

    The trials are flattened and demeaned along the last axis, and their
    covariances are estimated. ``M`` is their Riemannian mean (``mean_riemann``).
    ``d`` is the mean squared ``distance_riemann`` between each covariance and
    ``M``, and the returned factor is ``sqrt(1 / d)``.

    Parameters
    ----------
    X : ndarray
        Trials; the last two axes form one trial.
    cov_method : str
        Estimator name forwarded to ``covariances``.
    n_jobs : int
        Parallel jobs forwarded to the covariance / mean / distance helpers.

    Returns
    -------
    M : ndarray
        Riemannian mean of the trial covariances.
    scale : float
        ``sqrt(1 / d)``.

    Raises
    ------
    ValueError
        If the dispersion ``d`` is zero or undefined (e.g. a single trial, or
        all covariances identical). In that case no finite factor exists, and
        the old code returned ``inf``, which turned ``rescale`` output into NaN.
    """
    X = np.reshape(X, (-1, *X.shape[-2:]))
    X = X - np.mean(X, axis=-1, keepdims=True)
    C = covariances(X, estimator=cov_method, n_jobs=n_jobs)
    M = mean_riemann(C, n_jobs=n_jobs)
    d = np.mean(np.square(distance_riemann(C, M, n_jobs=n_jobs)))
    # ``not d > 0`` also catches NaN.
    if not d > 0:
        raise ValueError(
            "cannot estimate the rescaling factor: the mean squared Riemannian "
            "dispersion of the covariances around their mean is zero or undefined"
        )
    scale = np.sqrt(1 / d)
    return M, scale
def rescale(
    X: ndarray, M: ndarray, scale: float, cov_method: str = "cov", n_jobs: int = 1
):
    """Stretch each trial's covariance around ``M`` by ``scale``.

    Each demeaned trial ``x`` with covariance ``C`` is mapped to ``B @ x``, where
    ``A = M^{-1/2} C M^{-1/2}`` and ``B = M^{1/2} A^{(scale-1)/2} M^{-1/2}``.
    When the matrix roots are symmetric, ``B C B^T = M^{1/2} A^{scale} M^{1/2}``.

    Parameters
    ----------
    X : ndarray
        Trials; the last two axes form one trial.
    M : ndarray
        Reference matrix (e.g. the mean returned by ``get_rescale``).
    scale : float
        Stretching exponent (e.g. the factor returned by ``get_rescale``).
    cov_method : str
        Estimator name forwarded to ``covariances``.
    n_jobs : int
        Parallel jobs forwarded to ``covariances``.

    Returns
    -------
    ndarray
        The transformed trials.
    """
    trials = np.reshape(X, (-1, *X.shape[-2:]))
    trials = trials - np.mean(trials, axis=-1, keepdims=True)
    covs = covariances(trials, estimator=cov_method, n_jobs=n_jobs)
    M_isqrt = invsqrtm(M)
    M_sqrt = sqrtm(M)
    whitened = M_isqrt @ covs @ M_isqrt
    stretch = M_sqrt @ powm(whitened, (scale - 1) / 2) @ M_isqrt
    return stretch @ trials
def _cost_euc(R: ndarray, Mt: ndarray, Ms: ndarray, weights: Optional[ndarray] = None):
    """Weighted Euclidean Procrustes cost.

    Returns ``sum_i weights[i] * ||R @ Mt[i] @ R.T - Ms[i]||_F ** 2`` over the
    pairs produced by ``zip(Ms, Mt)``. Unit weights are used when ``weights``
    is ``None``. Only ``autograd.numpy`` operations are used, so the cost can
    be differentiated with autograd.
    """
    if weights is None:
        weights = anp.ones(len(Mt))
    total = 0
    for idx, (src, tgt) in enumerate(zip(Ms, Mt)):
        rotated = anp.dot(R, anp.dot(tgt, R.T))
        total += weights[idx] * anp.square(anp.linalg.norm(rotated - src))
    return total
def _cost_rie(R: ndarray, Mt: ndarray, Ms: ndarray, weights: Optional[ndarray] = None):
    """Weighted Riemannian Procrustes cost.

    Every matrix in ``Mt`` is rotated to ``R @ Mt[i] @ R.T``. The result is
    ``dot(distance_riemann(Ms, rotated) ** 2, weights)``, with unit weights
    when ``weights`` is ``None``.

    ``distance_riemann`` is not autograd-aware, so this cost needs an explicit
    Euclidean gradient (``_egrad_rie``) when optimized.
    """
    w = anp.ones(len(Mt)) if weights is None else weights
    rotated = anp.matmul(R, anp.matmul(Mt, R.T))
    squared = anp.square(distance_riemann(Ms, rotated))
    return anp.dot(squared, w)
def _egrad_rie(R: ndarray, Mt: ndarray, Ms: ndarray, weights: Optional[ndarray] = None):
    """Euclidean gradient of ``_cost_rie`` with respect to ``R``.

    Each cost term is ``w_i * d(Ms_i, R Mt_i R^T) ** 2``, assuming
    ``distance_riemann`` is the affine-invariant distance
    ``d(A, B) = ||log(A^{-1/2} B A^{-1/2})||_F`` (TODO confirm). Evaluated at
    an orthogonal ``R``, the gradient of one term is::

        4 * Ms_i^{-1/2} @ log(Ms_i^{-1/2} R Mt_i R^T Ms_i^{-1/2}) @ Ms_i^{1/2} @ R

    The terms are multiplied by their weights and summed. The previous code
    whitened by ``Mt`` and mixed ``invsqrtm(Mt)`` with ``sqrtm(Ms)``, which is
    not the gradient of this cost.

    Assumes ``invsqrtm`` / ``sqrtm`` / ``logm`` accept stacks of matrices
    shaped ``(n_pairs, c, c)`` -- TODO confirm.
    """
    if weights is None:
        weights = anp.ones(len(Mt))
    # Whiten by the *source* means: the cost compares Ms with the rotated Mt.
    iMs12 = invsqrtm(Ms)
    Ms12 = sqrtm(Ms)
    rotated = anp.matmul(R, anp.matmul(Mt, R.T))
    whitened = anp.matmul(iMs12, anp.matmul(rotated, iMs12))
    g = 4 * anp.matmul(anp.matmul(iMs12, logm(whitened)), anp.matmul(Ms12, R))
    g = g * weights[:, np.newaxis, np.newaxis]
    return anp.sum(g, axis=0)
def _procruster_cost_function_euc(R, Mt, Ms):
    """Unweighted Euclidean Procrustes cost.

    Returns the sum over ``zip(Mt, Ms)`` of
    ``||Ms[i] - R @ Mt[i] @ R.T||_F ** 2``. Only ``autograd.numpy`` operations
    are used, so pymanopt can differentiate it.
    """
    weights = anp.ones(len(Mt))
    per_pair = anp.array(
        [
            anp.linalg.norm(Msi - anp.dot(R, anp.dot(Mti, R.T))) ** 2
            for Mti, Msi in zip(Mt, Ms)
        ]
    )
    return anp.dot(per_pair, weights)
def _procruster_cost_function_rie(R, Mt, Ms):
    """Unweighted Riemannian Procrustes cost.

    Returns the sum over ``zip(Mt, Ms)`` of
    ``distance_riemann(Ms[i], R @ Mt[i] @ R.T)[0] ** 2``. ``distance_riemann``
    is not autograd-aware, so ``_procruster_egrad_function_rie`` supplies the
    gradient.
    """
    squared = anp.array(
        [
            distance_riemann(src, anp.dot(R, anp.dot(tgt, R.T)))[0] ** 2
            for tgt, src in zip(Mt, Ms)
        ]
    )
    return anp.dot(squared, anp.ones(len(Mt)))
def _procruster_egrad_function_rie(R, Mt, Ms):
    """Euclidean gradient of ``_procruster_cost_function_rie`` w.r.t. ``R``.

    That cost sums ``d(Ms_i, R Mt_i R^T) ** 2``. Assuming ``distance_riemann``
    is the affine-invariant distance ``||log(A^{-1/2} B A^{-1/2})||_F``
    (TODO confirm), the gradient of one term at an orthogonal ``R`` is::

        4 * Ms_i^{-1/2} @ log(Ms_i^{-1/2} R Mt_i R^T Ms_i^{-1/2}) @ Ms_i^{1/2} @ R

    The previous implementation rotated ``Ms`` and whitened by ``Mt``, which
    is the gradient of a different objective. It also multiplied by
    ``sqrtm(Ms)`` instead of the square root of the whitening matrix. With that
    mismatch, steepest descent could move away from the optimum.
    """
    weights = anp.ones(len(Mt))
    g = []
    for Mti, Msi, wi in zip(Mt, Ms, weights):
        iMsi12 = invsqrtm(Msi)
        Msi12 = sqrtm(Msi)
        # Rotate the target mean, then whiten by the source mean (matches the cost).
        term_aux = anp.dot(R, anp.dot(Mti, R.T))
        term_aux = anp.dot(iMsi12, anp.dot(term_aux, iMsi12))
        gi = 4 * anp.dot(anp.dot(iMsi12, logm(term_aux)), anp.dot(Msi12, R))
        g.append(gi * wi)
    g = anp.sum(g, axis=0)
    return g
def _get_rotation_matrix(
    Mt: ndarray, Ms: ndarray, weights: Optional[ndarray] = None, metric: str = "euclid"
):
    """Find the rotation ``R`` that best maps the target means onto the source means.

    The objective, minimized over ``Rotations(n)`` with pymanopt steepest
    descent, is the sum over pairs ``i`` of the distance between ``Ms[i]`` and
    ``R @ Mt[i] @ R.T``. The distance is squared Frobenius for
    ``metric="euclid"`` and squared Riemannian for ``metric="riemann"``.

    Parameters
    ----------
    Mt, Ms : ndarray
        Stacks of square matrices, reshaped to ``(-1, n, n)``; ``Mt[i]`` is
        paired with ``Ms[i]``.
    weights : ndarray, optional
        Not supported: all pairs are weighted equally. A ``UserWarning`` is
        emitted if weights are given.
    metric : {"euclid", "riemann"}
        Distance used in the cost.

    Returns
    -------
    Ropt : ndarray
        The point returned by ``solver.solve(problem)``.

    Raises
    ------
    ValueError
        If ``metric`` is unknown (previously an UnboundLocalError on ``problem``).
    """
    if metric not in ("euclid", "riemann"):
        raise ValueError(f"metric must be 'euclid' or 'riemann', got {metric!r}")
    if weights is not None:
        warnings.warn(
            "weights are not supported by _get_rotation_matrix and will be ignored",
            UserWarning,
            stacklevel=2,
        )
    Mt = Mt.reshape(-1, *Mt.shape[-2:])
    Ms = Ms.reshape(-1, *Ms.shape[-2:])
    n = Mt.shape[-2]
    manifolds = Rotations(n)
    # NOTE(review): `verbosity=` / `egrad=` on Problem, `mingradnorm=` and
    # `solver.solve` follow the pymanopt 0.2.x API. The module's fallback imports
    # hint at pymanopt>=2 support, where these names differ -- confirm versions.
    if metric == "euclid":
        # No egrad: the cost only uses autograd.numpy ops, so it is differentiated automatically.
        cost = partial(_procruster_cost_function_euc, Mt=Mt, Ms=Ms)
        problem = Problem(manifold=manifolds, cost=cost, verbosity=0)
    else:
        # distance_riemann is not autograd-aware, so the gradient is given explicitly.
        cost = partial(_procruster_cost_function_rie, Mt=Mt, Ms=Ms)
        egrad = partial(_procruster_egrad_function_rie, Mt=Mt, Ms=Ms)
        problem = Problem(manifold=manifolds, cost=cost, egrad=egrad, verbosity=0)
    solver = SteepestDescent(mingradnorm=1e-3)
    Ropt = solver.solve(problem)
    return Ropt
def get_rotate(
    Xs: ndarray,
    ys: ndarray,
    Xt: ndarray,
    yt: ndarray,
    cov_method: str = "cov",
    metric: str = "euclid",
    n_jobs: int = 1,
):
    """Estimate the Procrustes rotation aligning target class means to source ones.

    Both domains are flattened and demeaned along the last axis, and their
    covariances are estimated. For each class label present in *both*
    domains, a Riemannian mean is computed per domain. The rotation matching
    target means to source means is then found with ``_get_rotation_matrix``.

    Parameters
    ----------
    Xs, Xt : ndarray
        Source / target trials; the last two axes form one trial.
    ys, yt : array-like
        Labels of the source / target trials (flattened to 1-D).
    cov_method : str
        Estimator name forwarded to ``covariances``.
    metric : {"euclid", "riemann"}
        Distance used to fit the rotation.
    n_jobs : int
        Parallel jobs forwarded to ``covariances`` / ``mean_riemann``.

    Returns
    -------
    Ropt : ndarray
        Rotation to apply to target trials (see ``rotate``).

    Raises
    ------
    ValueError
        If source and target share no class label.
    """
    # Convert to arrays so ``y == label`` yields a boolean mask even for list input.
    ys = np.ravel(ys)
    yt = np.ravel(yt)
    # Pair means by a single shared, sorted label list. Stacking each domain
    # over its own labels misaligns Ms[i] and Mt[i] when the label sets differ.
    labels = np.intersect1d(ys, yt)
    if labels.size == 0:
        raise ValueError(
            "source and target share no class label; cannot estimate the rotation"
        )
    Xs = np.reshape(Xs, (-1, *Xs.shape[-2:]))
    Xt = np.reshape(Xt, (-1, *Xt.shape[-2:]))
    Xs = Xs - np.mean(Xs, axis=-1, keepdims=True)
    Xt = Xt - np.mean(Xt, axis=-1, keepdims=True)
    Cs = covariances(Xs, estimator=cov_method, n_jobs=n_jobs)
    Ct = covariances(Xt, estimator=cov_method, n_jobs=n_jobs)
    Ms = np.stack([mean_riemann(Cs[ys == label], n_jobs=n_jobs) for label in labels])
    Mt = np.stack([mean_riemann(Ct[yt == label], n_jobs=n_jobs) for label in labels])
    Ropt = _get_rotation_matrix(Mt, Ms, metric=metric)
    return Ropt
def rotate(Xt: ndarray, Ropt: ndarray):
    """Apply a rotation to every demeaned target trial.

    Parameters
    ----------
    Xt : ndarray
        Target trials, flattened to ``(-1, *Xt.shape[-2:])``.
    Ropt : ndarray
        Matrix left-multiplied onto each trial (e.g. from ``get_rotate``).

    Returns
    -------
    ndarray
        ``Ropt`` applied to each trial after removing its mean along the last axis.
    """
    trials = np.reshape(Xt, (-1,) + tuple(Xt.shape[-2:]))
    centered = trials - trials.mean(axis=-1, keepdims=True)
    return np.matmul(Ropt, centered)