Source code for homcloud.codebook

"""
References:

* https://link.springer.com/article/10.1007/s10462-020-09897-4
* https://github.com/bziiuj/pcodebooks
"""

import numpy as np
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from matplotlib.colors import LogNorm


def random_sampling_pairs(diagrams, n, weighting_function, random_state):
    births = np.concatenate([pd.births for pd in diagrams])
    deaths = np.concatenate([pd.deaths for pd in diagrams])
    npairs = births.shape[0]

    if weighting_function:
        weights = np.array([weighting_function(b, d) for b, d in zip(births, deaths)])
        p = weights / np.sum(weights)
    else:
        p = weights = None

    pairs = np.stack([births, deaths], axis=-1)
    return pairs[random_state.choice(np.arange(npairs), n, p=p), :], weights


def setup_random_state(random_state):
    if random_state is None:
        return np.random.RandomState()
    elif isinstance(random_state, int):
        return np.random.RandomState(random_state)
    else:
        return random_state


def diagram_weights(pd, weighting_function):
    if weighting_function:
        return np.array([weighting_function(b, d) for b, d in zip(pd.births, pd.deaths)])
    else:
        return None


def diagram2array(diagram):
    return np.stack([diagram.births, diagram.deaths], axis=-1)


def bow_normalize(vector, really_normalize):
    if not really_normalize:
        return vector

    norm = np.linalg.norm(vector)
    if norm == 0:
        return vector
    else:
        return np.sqrt(vector) / norm


class PBoW:
    def __init__(self, n_random_sampled_pairs, n_clusters, weighting_function=None, random_state=None, normalize=True):
        self.n_random_sampled_pairs = n_random_sampled_pairs
        self.n_clusters = n_clusters
        self.weighting_function = weighting_function
        self.random_state = setup_random_state(random_state)
        self.kmeans = None
        self.normalize = normalize

    def fit(self, diagrams):
        sampled_pairs, _ = random_sampling_pairs(
            diagrams, self.n_random_sampled_pairs, self.weighting_function, self.random_state
        )
        self.kmeans = KMeans(self.n_clusters, init="k-means++", n_init=1, random_state=self.random_state)
        self.kmeans.fit(sampled_pairs)

    @property
    def cluster_centers(self):
        return self.kmeans.cluster_centers_

    def vectorize(self, diagram):
        labels = self.kmeans.predict(np.stack([diagram.births, diagram.deaths], axis=-1))
        counts = np.bincount(labels, diagram_weights(diagram, self.weighting_function), self.n_clusters)
        return bow_normalize(counts, self.normalize)


class GMMClusteringBase:
    def __init__(self, n_random_sampled_pairs, n_clusters, weighting_function=None, random_state=None):
        self.n_random_sampled_pairs = n_random_sampled_pairs
        self.n_clusters = n_clusters
        self.weighting_function = weighting_function
        self.random_state = setup_random_state(random_state)
        self.gmm = None

    def fit(self, diagrams):
        sampled_pairs, _ = random_sampling_pairs(
            diagrams, self.n_random_sampled_pairs, self.weighting_function, self.random_state
        )
        self.gmm = GaussianMixture(self.n_clusters, random_state=self.random_state)
        self.gmm.fit(sampled_pairs)

    @property
    def cluster_centers(self):
        return self.gmm.means_

    @property
    def cluster_weights(self):
        return self.gmm.weights_

    @property
    def cluster_covariances(self):
        return self.gmm.covariances_

    def plot_gmm_density_estimation(self, ax, x_range, x_num, y_range=None, y_num=None):
        if y_range is None:
            y_range = x_range
        if y_num is None:
            y_num = x_num
        x, y = np.meshgrid(np.linspace(x_range[0], x_range[1], x_num), np.linspace(y_range[0], y_range[1], y_num))
        xx = np.array([x.ravel(), y.ravel()]).T
        z = -self.gmm.score_samples(xx)
        z = z.reshape(x.shape)
        zmax = np.max(z)
        ax.contour(x, y, z, norm=LogNorm(vmin=1.0, vmax=zmax), levels=np.logspace(0, np.log(zmax), 20))


class StablePBoW(GMMClusteringBase):
    def __init__(self, n_random_sampled_pairs, n_clusters, weighting_function=None, random_state=None, normalize=True):
        super().__init__(n_random_sampled_pairs, n_clusters, weighting_function, random_state)
        self.normalize = normalize

    def vectorize(self, diagram):
        wlogprob = self.gmm._estimate_weighted_log_prob(np.stack([diagram.births, diagram.deaths], axis=-1))
        if self.weighting_function is None:
            return bow_normalize(np.sum(np.exp(wlogprob), axis=0), self.normalize)
        else:
            dweights = diagram_weights(diagram, self.weighting_function)
            mask = dweights > 0
            return bow_normalize(
                np.sum(np.exp(wlogprob[mask] + np.log(dweights[mask].reshape(-1, 1))), axis=0), self.normalize
            )


# * https://www.vlfeat.org/api/fisher-fundamentals.html
# * F. Perronnin and C. Dance. Fisher kenrels on visual vocabularies
#   for image categorizaton. In Proc. CVPR, 2006.
# * Florent Perronnin, Jorge Sánchez, and Thomas Mensink.
#   Improving the fisher kernel for large-scale image classification. In Proc. ECCV, 2010.
class PFV(GMMClusteringBase):
    def __init__(self, n_random_sampled_pairs, n_clusters, weighting_function=None, random_state=None, normalize=0.5):
        super().__init__(n_random_sampled_pairs, n_clusters, weighting_function, random_state)
        self.normalize = normalize

    def vectorize(self, diagram):
        npairs = diagram.num_pairs
        n_clusters = self.n_clusters

        pairs = diagram2array(diagram)  # (npairs, 2)
        gammas = self.gmm.predict_proba(pairs)  # (npairs, n_clusters)
        r = pairs.reshape(npairs, 1, 2) - self.gmm.means_.reshape(1, n_clusters, 2)  # (npairs, n_clusters, 2)
        sigmas = np.diagonal(self.gmm.covariances_, 0, 1, 2)  # (n_clusters, 2)
        dmu = np.sum(
            gammas.reshape(npairs, n_clusters, 1) * r / sigmas.reshape(1, n_clusters, 2), axis=0
        )  # (n_clusters, 2)
        dsigma = np.sum(
            gammas.reshape(npairs, n_clusters, 1) * ((r**2 / (sigmas**2).reshape(1, n_clusters, 2)) - 1), axis=0
        )  # (n_clusters, 2)
        nc = np.sqrt(n_clusters * self.gmm.weights_).reshape(n_clusters, 1)  # (n_clusters, 2)
        pfv_mu = dmu / nc
        pfv_sigma = dsigma / nc / np.sqrt(2)
        pfv = np.hstack([pfv_mu, pfv_sigma]).ravel()
        if self.normalize:
            return np.sign(pfv) * (np.abs(pfv) ** self.normalize)
        else:
            return pfv