Source code for thelper.viz.umap

"""Tools related to the Uniform Manifold Approximation and Projection (UMAP).

For more information on UMAP, see the original author's repository at
https://github.com/lmcinnes/umap.
"""

from typing import Any, AnyStr, Dict, List, Optional, Union  # noqa: F401

import cv2 as cv
import numpy as np
import torch
import tqdm

import thelper.draw
import thelper.tasks
import thelper.utils
from thelper.viz.tsne import plot


def visualize(model,              # type: thelper.typedefs.ModelType
              task,               # type: thelper.typedefs.TaskType
              loader,             # type: thelper.typedefs.LoaderType
              draw=False,         # type: bool
              color_map=None,     # type: Optional[Dict[int, np.ndarray]]
              max_samples=None,   # type: Optional[int]
              return_meta=False,  # type: Union[bool, List[AnyStr]]
              **kwargs
              ):  # type: (...) -> Dict[AnyStr, Any]
    """
    Creates (and optionally displays) a 2D UMAP visualization of sample embeddings.

    By default, all samples from the data loader will be projected using the model and used
    for the visualization. If the task is related to classification, the prediction and
    groundtruth labels will be highlighted using various colors.

    If the model does not possess a ``get_embedding`` attribute, its raw output will be used
    for the projection; otherwise, ``get_embedding`` will be called.

    Args:
        model: the model which will be used to produce embeddings.
        task: the task object used to decode predictions and color samples (if possible).
        loader: the data loader used to get data samples to project.
        draw: boolean flag used to toggle internal display call on or off.
        color_map: map of RGB triplets used to color predictions (for classification only).
        max_samples: maximum number of samples to draw from the data loader.
        return_meta: toggles whether sample metadata should be provided as output or not.

    Returns:
        A dictionary of the visualization result (an RGB image in numpy format), a list of
        projected embedding coordinates, the labels of the samples, and the predictions of
        the samples.
    """
    assert thelper.utils.check_installed("umap"), \
        "could not import optional 3rd-party dependency 'umap-learn'; make sure you install it first!"
    import umap
    assert loader is not None and len(loader) > 0, "no available data to load"
    assert model is not None and isinstance(model, torch.nn.Module), "invalid model"
    assert task is not None and isinstance(task, thelper.tasks.Task), "invalid task"
    assert max_samples is None or max_samples > 0, "invalid maximum loader sample count"
    thelper.viz.logger.debug("fetching data loader samples for UMAP visualization...")
    embeddings, labels, preds, idxs = [], [], [], []
    if isinstance(task, thelper.tasks.Classification) and not task.multi_label:
        assert all([isinstance(n, str) for n in task.class_names]), "unexpected class name types"
        if not color_map:
            if hasattr(task, "color_map"):
                color_map = task.color_map
            else:
                color_map = {idx: thelper.draw.get_label_color_mapping(idx + 1)
                             for idx in task.class_indices.values()}
        color_map = {idx: f"#{c[0]:02X}{c[1]:02X}{c[2]:02X}" for idx, c in color_map.items()}
    if isinstance(return_meta, bool):
        return_meta = task.meta_keys if return_meta else []
    assert isinstance(return_meta, list) and all([isinstance(key, str) for key in return_meta]), \
        "sample metadata keys must be provided as a list of strings"
    meta = {key: [] for key in return_meta}
    for sample_idx, sample in tqdm.tqdm(enumerate(loader), desc="extracting embeddings"):
        if max_samples is not None and sample_idx >= max_samples:
            break
        with torch.no_grad():
            input_tensor = sample[task.input_key]
            if task is not None and isinstance(task, thelper.tasks.Classification) and \
                    not task.multi_label and task.gt_key in sample:
                label = sample[task.gt_key]
                if isinstance(label, torch.Tensor):
                    label = label.cpu().numpy()
                if all([isinstance(lbl, str) for lbl in label]):
                    label = [task.class_indices[lbl] for lbl in label]
                pred = model(input_tensor).topk(k=1, dim=1)[1].view(input_tensor.size(0)).cpu().numpy()
                labels.append(label)
                preds.append(pred)
            if hasattr(model, "get_embedding"):
                embedding = model.get_embedding(input_tensor)
            else:
                if not thelper.viz.warned_missing_get_embedding:
                    thelper.viz.logger.warning(
                        "missing 'get_embedding' function in model object; will use output instead")
                    thelper.viz.warned_missing_get_embedding = True
                embedding = model(input_tensor)
            if embedding.dim() > 2:  # reshape to BxC
                embedding = embedding.view(embedding.size(0), -1)
            embeddings.append(embedding.cpu().numpy())
            idxs.append(sample_idx)
            for key in return_meta:
                for v in sample[key]:
                    meta[key].append(v)
    embeddings = np.concatenate(embeddings)
    if labels and preds:
        labels, preds = np.concatenate(labels), np.concatenate(preds)
    else:
        # fall back to dummy arrays so the '.tolist()' calls below remain valid
        labels = np.zeros(len(embeddings), dtype=np.int32)
        preds = np.zeros(len(embeddings), dtype=np.int32)
    seed = thelper.utils.get_key_def("seed", kwargs, 0)
    if seed is None:
        seed = np.random.randint(np.iinfo(np.int32).max)
    prev_state = np.random.get_state()
    np.random.seed(seed)  # seed the global numpy RNG to make the projection reproducible
    default_umap_args = {"n_components": 2}
    umap_args = thelper.utils.get_key_def("umap_args", kwargs, default_umap_args)
    umap_engine = umap.UMAP(**umap_args)
    thelper.viz.logger.debug("computing UMAP projection...")
    embeddings = umap_engine.fit_transform(embeddings)
    np.random.set_state(prev_state)
    fig = plot(embeddings, labels, preds, color_map=color_map, task=task, **kwargs)
    img = thelper.draw.fig2array(fig).copy()
    if draw:
        thelper.viz.logger.debug("displaying UMAP projection...")
        cv.imshow("thelper.viz.umap", img[..., ::-1])  # RGB to BGR for opencv display
        cv.waitKey(1)
    return {  # key formatting should be compatible with _write_data in thelper/train/base.py
        "tsne-projs/pickle": embeddings,
        "tsne-labels/json": labels.tolist(),
        "tsne-preds/json": preds.tolist(),
        "tsne-idxs/json": idxs,
        "tsne-meta/json": meta,
        "tsne/image": img,
    }