Source code for simages.embeddings

"""
.. module:: embeddings
   :synopsis: Embedding module

.. moduleauthor:: Justin Shenk <shenkjustin@gmail.com>


"""

import os
import glob
from typing import Union

import closely
import numpy as np

from .extractor import EmbeddingExtractor


def linkageplot(embeddings: np.ndarray, ordered=True):
    """Show the hierarchical-clustering linkage of the embeddings' distance matrix.

    A cosine distance matrix is computed from ``embeddings`` with
    ``closely.distance_matrix`` and handed to ``closely.show_linkage``;
    whatever that call returns is returned unchanged.

    Args:
        embeddings (np.ndarray): embeddings of images in dataset
        ordered (bool): order distance matrix before plotting
    """
    return closely.show_linkage(
        closely.distance_matrix(embeddings, metric="cosine", ordered=ordered)
    )
class Embeddings:
    """Create embeddings from `input` data by training an autoencoder.

    Passes arguments for `EmbeddingExtractor`.

    Attributes:
        extractor (simages.EmbeddingExtractor): workhorse for extracting
            embeddings from dataset
        embeddings (np.ndarray): embeddings
        pairs (np.ndarray): n closest pairs (set by :meth:`duplicates`)
        distances (np.ndarray): distances between n-closest pairs
            (set by :meth:`duplicates`)
    """

    def __init__(self, input: Union[np.ndarray, str], **kwargs):
        """Inits Embeddings with data.

        Args:
            input: either a path to a directory of image files, or an array
                of images. A 4-D array takes its channel count from axis 1;
                a 3-D array whose first axis has length 1 is treated as
                single-channel.
            **kwargs: forwarded to `EmbeddingExtractor`.

        Raises:
            NotADirectoryError: `input` is a string that is not a directory.
            Exception: the directory contains no candidate files.
            ValueError: the array shape is not supported.
            NotImplementedError: `input` is neither a string nor an array.
        """
        if isinstance(input, str):
            if not os.path.isdir(input):
                # Fail here rather than leave a half-initialised instance
                # with no `embeddings` / `extractor` attributes.
                raise NotADirectoryError(f"{input} is not a directory")

            self.data_dir = input
            # Get files
            files = glob.glob(os.path.join(input, "*.*"))
            # Exclude hidden files. The check must look at the file name:
            # the globbed paths carry the directory prefix, so testing the
            # whole path would drop every file under e.g. "./images".
            files = [x for x in files if not os.path.basename(x).startswith(".")]
            # Assume they are images
            if not files:
                raise Exception(f"Files count is {len(files)}")
            self.embeddings = self.images_to_embeddings(self.data_dir, **kwargs)
        elif isinstance(input, np.ndarray):
            if input.ndim == 3 and input.shape[0] == 1:
                num_channels = 1
            elif input.ndim == 4:
                num_channels = input.shape[1]
            else:
                # A bare string cannot be raised; use a real exception type.
                raise ValueError(
                    f"Data shape {input.shape} not supported, should be N x C x H x W"
                )
            self.embeddings = self.array_to_embeddings(
                input, num_channels=num_channels, **kwargs
            )
        else:
            raise NotImplementedError(f"{type(input)}")

    @property
    def array(self):
        """Embeddings currently held by the underlying extractor."""
        return self.extractor.embeddings

    def duplicates(self, n: int = 10):
        """Compute the `n` closest embedding pairs with `closely.solve`.

        The result is stored on `self.pairs` and `self.distances` and also
        returned.

        Args:
            n (int): number of closest pairs to request.

        Returns:
            tuple: `(pairs, distances)` as returned by `closely.solve`.
        """
        self.pairs, self.distances = closely.solve(self.embeddings, n=n)
        return self.pairs, self.distances

    def show_duplicates(self, n=5):
        """Convenience wrapper for `EmbeddingExtractor.show_duplicates`"""
        return self.extractor.show_duplicates(n=n)

    def images_to_embeddings(self, data_dir: str, **kwargs):
        """Build `self.extractor` from an image directory and return its embeddings.

        Args:
            data_dir (str): directory passed to `EmbeddingExtractor`.
            **kwargs: forwarded to `EmbeddingExtractor`.
        """
        self.extractor = EmbeddingExtractor(data_dir, **kwargs)
        return self.extractor.embeddings

    def array_to_embeddings(self, array: np.ndarray, **kwargs):
        """Build `self.extractor` from an image array and return its embeddings.

        Args:
            array (np.ndarray): array passed to `EmbeddingExtractor`.
            **kwargs: forwarded to `EmbeddingExtractor`.
        """
        self.extractor = EmbeddingExtractor(array, **kwargs)
        return self.extractor.embeddings

    def __repr__(self):
        """Return the NumPy repr of the extractor's embeddings."""
        return np.array_repr(self.extractor.embeddings)