Module toplot.algorithms
Functions
def cluster_latent_components(*weights, metric='jensenshannon', n_iterations: int = 11, verbose=False)
-
Expand source code
def cluster_latent_components( *weights, metric="jensenshannon", n_iterations: int = 11, verbose=False ): r"""Solve cluster identifiability problem using the Hungarian algorithm. Given \( s=1,\dot,S \) sets of `K` latent components that are similar upto a permutation of the component index \( k = 1,\dots,K \) (e.g., from `S` model fits), match the components `k` between the chains `s` and determine the centroid (i.e., the consensus component). The algorith uses multiple restarts and selects the best clustering based on the silhouette score. Example: See the notebook `examples/identifiability.ipynb`. Args: *weights: The component matrices \( \pmb{W}_1, \dots, \pmb{W}_S \) to cluster. Each matrix must have the same shape and is clustered on the leading axis. metric: The metric to use for computing the pairwise distances between the components. n_iterations: Number of iterations to run the clustering algorithm. verbose: Print the silhouette score at each iteration. Returns: A pair of cluster_assignments: Per chain (leading axis), the cluster assignments (second axis). centeroids: The centre of mass of the clusters. """ if len(weights) < 2: raise ValueError("Only one argument. Nothing to cluster against.") # Restart the clustering algorithm with different chain centroids. n_chains = n_restarts = len(weights) n_components = len(weights[0]) weights_np = np.stack(weights) if metric == "jensenshannon": if not np.all(weights_np.sum(axis=1) == 1): raise ValueError( "Trailing axis of weights should sum to 1 when using Jenson-Shannon distance." ) best_score = -1 for i in range(n_restarts): # Restart clustering algorithm with centroid initially at the i-th chain. centeroids = weights_np[i] for _ in range(n_iterations): # 1) Assignment step. cluster_assignments = _assign_clusters(weights_np, centeroids, metric) # 2) Update step. centeroids = _update_clusters(weights_np, cluster_assignments) # 3) Compute pairwise distances per chain. x_flat = weights_np.reshape([n_chains * n_components, -1]) distances = pdist(x_flat, metric=metric) X_dist = squareform(distances) ss = silhouette_score( X_dist, cluster_assignments.flatten(), metric="precomputed" ) if verbose: print("Silhouette score:", ss) if verbose: print("-------------------") print(f"Final silhouette score {i}:", ss) if ss > best_score: best_score = ss best_cluster_assignments = cluster_assignments best_centeroids = centeroids if verbose: print("Best score:", best_score) return best_cluster_assignments, best_centeroids
Solve cluster identifiability problem using the Hungarian algorithm.
Given s=1,\dot,S sets of
K
latent components that are similar upto a permutation of the component index k = 1,\dots,K (e.g., fromS
model fits), match the componentsk
between the chainss
and determine the centroid (i.e., the consensus component). The algorith uses multiple restarts and selects the best clustering based on the silhouette score.Example
See the notebook
examples/identifiability.ipynb
.Args
*weights
- The component matrices \pmb{W}_1, \dots, \pmb{W}_S to cluster. Each matrix must have the same shape and is clustered on the leading axis.
metric
- The metric to use for computing the pairwise distances between the components.
n_iterations
- Number of iterations to run the clustering algorithm.
verbose
- Print the silhouette score at each iteration.
Returns: A pair of cluster_assignments: Per chain (leading axis), the cluster assignments (second axis). centeroids: The centre of mass of the clusters.
def invert_cluster_mapping(cluster_assignments)
-
Expand source code
def invert_cluster_mapping(cluster_assignments): n_chains = cluster_assignments.shape[0] n_components = cluster_assignments.shape[1] # Invert cluster assignment mappings. cluster_assignments_inv = np.zeros_like(cluster_assignments) for k in range(n_components): for j in range(n_chains): cluster_assignments_inv[j, int(cluster_assignments[j, k])] = k return cluster_assignments_inv