
"""
.. module:: extractor
   :synopsis: Embedding extractor module

.. moduleauthor:: Justin Shenk <shenkjustin@gmail.com>


"""

import logging
import os
from typing import Union, Optional, Tuple
import warnings

import closely
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torch.utils.data as utils

from .dataset import PILDataset, ImageFolder
from .models import BasicAutoencoder, UnNormalize

warnings.filterwarnings("ignore", message="Palette images with Transparency")
log = logging.getLogger(__name__)


class EmbeddingExtractor:
    """Extract embeddings from data with models and allow visualization.

    Attributes:
        trainloader (torch loader)
        evalloader (torch loader)
        model (torch.nn.Module)
        embeddings (np.ndarray)

    """

    def __init__(
        self,
        input: Union[str, np.ndarray],
        num_channels: int = 3,
        num_epochs: int = 2,
        batch_size: int = 32,
        show: bool = False,
        show_path: bool = False,
        show_train: bool = False,
        z_dim: int = 8,
        metric: str = "cosine",
        model: Optional[torch.nn.Module] = None,
        db=None,
        **kwargs,
    ):
        """Inits EmbeddingExtractor with `input`, either `str` or `np.ndarray`,
        and performs training and validation.

        Args:
            input (np.ndarray or str): data
            num_channels (int): grayscale = 1, color = 3
            num_epochs (int): more is better (generally)
            batch_size (int): number of images per batch
            show (bool): show closest pairs
            show_path (bool): show path of duplicates
            show_train (bool): show intermediate training results
            z_dim (int): compression size
            metric (str): distance metric for :meth:`scipy.spatial.distance.cdist` (eg, euclidean, cosine, hamming, etc.)
            model (torch.nn.Module, optional): class implementing same methods as :class:`~simages.BasicAutoencoder`
            db (optional): MongoDB connection
            kwargs (dict)

        """
        self.num_epochs = num_epochs
        self._batch_size = batch_size
        self._show = show
        self._show_path = show_path
        self._show_train = show_train
        self._db = db
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self._num_channels = num_channels
        self._metric = metric
        self._z_dim = z_dim
        self._hw = 48
        self._mean = [0.5] * self._num_channels
        self._std = [0.25] * self._num_channels

        train_transforms = transforms.Compose(
            [
                transforms.RandomResizedCrop(self._hw),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(self._mean, self._std),
            ]
        )

        basic_transforms = transforms.Compose(
            [
                transforms.Resize(self._hw),
                transforms.CenterCrop(self._hw),
                transforms.ToTensor(),
                transforms.Normalize(self._mean, self._std),
            ]
        )

        def is_valid(path):
            img_extensions = [
                ".jpg",
                ".jpeg",
                ".png",
                ".ppm",
                ".bmp",
                ".pgm",
                ".tif",
                ".gif",
                ".octet-stream",
            ]
            _, file_extension = os.path.splitext(path)
            valid_ext = file_extension.lower() in img_extensions
            if not valid_ext:
                return False
            try:
                Image.open(path).verify()
            except Exception as e:
                log.info(f"Skipping {os.path.basename(path)}: {e}")
                return False
            return True

        if isinstance(input, str):
            data_dir = os.path.abspath(input)
            self.train_dataset = ImageFolder(
                data_dir, transform=train_transforms, is_valid_file=is_valid
            )
            self.eval_dataset = ImageFolder(
                data_dir, transform=basic_transforms, is_valid_file=is_valid
            )
            self.trainloader = torch.utils.data.DataLoader(
                self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
            )
            self.evalloader = torch.utils.data.DataLoader(
                self.eval_dataset, batch_size=batch_size, shuffle=False, num_workers=4
            )
        elif isinstance(input, np.ndarray):
            self.trainloader = self._tensor_dataloader(
                input, train_transforms, shuffle=True
            )
            self.evalloader = self._tensor_dataloader(
                input, basic_transforms, shuffle=False
            )

        if not torch.cuda.is_available():
            log.info(
                "Note: No GPU found, using CPU. Performance is improved on a CUDA device."
            )

        if model is not None:
            self.model = model
        else:
            self.model = BasicAutoencoder(num_channels=num_channels, z_dim=z_dim)

        if torch.cuda.device_count() > 1:
            log.info(f"Using {torch.cuda.device_count()} GPUs.")
            self.model = nn.DataParallel(self.model)

        self.model.to(self._device)

        self._distance = nn.MSELoss()
        self._optimizer = torch.optim.Adam(self.model.parameters(), weight_decay=1e-5)

        self.train()
        self.eval()

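    # A minimal usage sketch (data and paths are illustrative, not part of this
    # module): constructing an extractor runs training and evaluation in one step.
    #
    #     import numpy as np
    #     from simages.extractor import EmbeddingExtractor
    #
    #     array_data = np.random.random((100, 28, 28))  # N x H x W grayscale
    #     extractor = EmbeddingExtractor(array_data, num_channels=1, num_epochs=2)
    #
    #     # or point it at a folder of images on disk:
    #     extractor = EmbeddingExtractor("path/to/images", num_channels=3)
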
    def _truncate_middle(self, string: str, n: int) -> str:
        if len(string) <= n:
            # string is already short enough
            return string
        # half of the size, minus the 3 dots
        n_2 = int(n / 2 - 3)
        # whatever's left
        n_1 = n - n_2 - 3
        return f"{string[:n_1]}...{string[-n_2:]}"

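    # Illustrative behavior: with n=15 the middle is elided, keeping the start
    # of the name and the extension, e.g.
    #
    #     self._truncate_middle("a_very_long_image_filename.png", 15)
    #     # -> 'a_very_l....png'
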
    def get_image(self, index: int) -> torch.Tensor:
        """Return the image tensor at `index` of the eval dataset."""
        result = self.evalloader.dataset[index]
        if isinstance(result, tuple):
            # Folder-backed datasets return (image, target) tuples
            return result[0].cpu()
        else:
            return result.cpu()

    def _tensor_dataloader(
        self,
        array: np.ndarray,
        transforms: torchvision.transforms.Compose,
        shuffle: bool = True,
    ) -> utils.DataLoader:
        log.debug(f"INFO: data shape: {array.shape} (Target: N x C x H x W)")
        if array.ndim == 3:
            log.debug(
                f"Converting to grayscale dataset of dims {array.shape[0]} x 1 x {array.shape[1]} x {array.shape[2]}"
            )
            array = array[:, np.newaxis, ...]
            log.debug(f"New shape: {array.shape}")

        tensor = torch.Tensor(array)
        # Convert each sample to a PIL image so torchvision transforms apply
        pil_list = [TF.to_pil_image(sample.squeeze()) for sample in tensor]
        dataset = PILDataset(pil_list, transform=transforms)
        dataloader = utils.DataLoader(
            dataset, batch_size=self._batch_size, shuffle=shuffle
        )
        return dataloader

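    # Shape-handling sketch (values illustrative): a 3-dim grayscale batch
    # gains a channel axis before the PIL conversion.
    #
    #     arr = np.random.random((100, 28, 28))
    #     arr[:, np.newaxis, ...].shape  # (100, 1, 28, 28), i.e. N x C x H x W
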
    def train(self):
        """Train autoencoder to build embeddings of dataset. Final embeddings are
        created in :meth:`~simages.extractor.EmbeddingExtractor.eval`.
        """
        log.info(
            f"Building embeddings for {len(self.evalloader.dataset)} images. This may take some time..."
        )
        for epoch in range(self.num_epochs):
            for data in self.trainloader:
                if isinstance(data, list):
                    data = data[0]
                img = data.to(self._device)
                # ===================forward=====================
                output, embedding = self.model(img)
                loss = self._distance(output, img)
                # ===================backward====================
                self._optimizer.zero_grad()
                loss.backward()
                self._optimizer.step()

                if self._show_train:
                    try:
                        img_array = img.cpu()[0]
                        output_array = output.detach().cpu()[0]
                        grid_img = torchvision.utils.make_grid(
                            [img_array, output_array], nrow=1
                        )
                        self.show(
                            grid_img,
                            title=f"Building embeddings: epoch [{epoch + 1}/{self.num_epochs}]",
                            block=False,
                            y_labels=[(2, "Original"), (5, "Reconstruction")],
                        )
                    except Exception as e:
                        log.error(f"{e}")
            # ===================log========================
            log.info(
                "epoch [{}/{}], loss:{:.4f}".format(epoch + 1, self.num_epochs, loss)
            )

    def eval(self):
        """Evaluate reconstruction of embeddings built in `train`."""
        embeddings = []
        # Switch model to `eval` mode so layers such as dropout behave
        # deterministically; no weights are updated during evaluation.
        self.model.eval()
        with torch.no_grad():
            for data in self.evalloader:
                if isinstance(data, list):
                    data = data[0]
                img = data.to(self._device)
                # ===================forward=====================
                output, embedding = self.model(img)
                embeddings.append(embedding)
                loss = self._distance(output, img)

                if self._show_train:
                    try:
                        img_array = img.cpu()[0]
                        output_array = output.cpu()[0]
                        grid_img = torchvision.utils.make_grid(
                            [img_array, output_array], nrow=1
                        )
                        self.show(
                            grid_img,
                            title="Reconstruction",
                            y_labels=[(2, "Original"), (5, "Reconstruction")],
                        )
                    except Exception as e:
                        log.error(f"{e}")
            # ===================log========================
            log.info("eval, loss:{:.4f}".format(loss))

        self.embeddings = torch.cat(embeddings).cpu().numpy()

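    # After `eval`, `self.embeddings` holds one row per eval image; with the
    # default z_dim=8 and, say, 100 input images (numbers illustrative):
    #
    #     extractor.embeddings.shape  # (100, 8)
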
    def duplicates(
        self, n: int = 10, quantile: Optional[float] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Identify the `n` closest pairs of images, or the closest `quantile`
        of all pairs (for example, the closest 0.05).

        Args:
            n (int): number of pairs
            quantile (float): quantile of total combinations (suggested range: 0.001 - 0.01)

        """
        if quantile is not None:
            pairs, distances = closely.solve(
                self.embeddings, quantile=quantile, metric=self._metric
            )
        else:
            pairs, distances = closely.solve(self.embeddings, n=n, metric=self._metric)
        return pairs, distances

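    # Both call styles, sketched with illustrative values:
    #
    #     pairs, distances = extractor.duplicates(n=5)             # 5 closest pairs
    #     pairs, distances = extractor.duplicates(quantile=0.005)  # closest 0.5%
    #     # pairs is an array of index pairs such as [[3, 17], [4, 9], ...];
    #     # distances holds the matching values under the chosen metric.
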
    @staticmethod
    def channels_last(img: np.ndarray) -> np.ndarray:
        """Move channels from first to last position by swapping axes."""
        img_t = np.transpose(img, (1, 2, 0))
        return img_t

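    # For example, a (3, 48, 48) channels-first image becomes (48, 48, 3), the
    # layout matplotlib's `imshow` expects:
    #
    #     EmbeddingExtractor.channels_last(np.zeros((3, 48, 48))).shape
    #     # -> (48, 48, 3)
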
    def show(
        self,
        img: Union[torch.Tensor, np.ndarray],
        title: str = "",
        block: bool = True,
        y_labels=None,
        unnormalize=True,
    ):
        """Plot `img` with `title`.

        Args:
            img (torch.Tensor or np.ndarray): image to plot
            title (str): plot title
            block (bool): block matplotlib plot until window closed
            y_labels (list of (int, str), optional): tick positions and labels for the y-axis
            unnormalize (bool): undo input normalization before plotting

        """
        if unnormalize:
            img = self.unnormalize(img)

        if isinstance(img, torch.Tensor):
            npimg = img.detach().numpy()
        elif isinstance(img, np.ndarray):
            npimg = img
        else:
            raise NotImplementedError(f"{type(img)}")

        if npimg.shape[0] in [1, 2, 3]:
            # Channels-first image; convert to channels-last for matplotlib
            npimg = self.channels_last(npimg).squeeze()

        fig, ax = plt.subplots(1, 1)
        plt.title(f"{title}")
        ax.imshow(npimg, interpolation="nearest")

        if y_labels is not None:
            labels = [item.get_text() for item in ax.get_yticklabels()]
            for idx, label in y_labels:
                labels[idx] = label
            ax.set_yticklabels(labels)

        plt.show(block=block)

    def show_images(self, indices: Union[list, int], title=""):
        """Plot images (from validation data) at `indices` with `title`."""
        if isinstance(indices, int):
            indices = [indices]
        tensors = [self.get_image(idx) for idx in indices]
        self.show(torchvision.utils.make_grid(tensors), title=title)

    def image_paths(self, indices, short=True):
        """Get path(s) to image(s) at `indices` of the eval dataset/embeddings.

        Args:
            indices (int or list of int): indices of embeddings in dataset
            short (bool): truncate filepath to 30 characters

        Returns:
            paths (str or list of str): paths to images in image folder

        """
        if isinstance(indices, (int, np.int_)):
            indices = [indices]

        paths = []
        for index in indices:
            path = self.evalloader.dataset.samples[index]
            if short:
                path = self._truncate_middle(os.path.basename(path), 30)
            paths.append(path)

        if len(paths) == 1:
            return paths[0]  # backward compatibility
        return paths

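    # Sketch (filenames hypothetical): a single index returns one path, a list
    # of indices returns a list of paths.
    #
    #     extractor.image_paths(0)       # 'IMG_001.jpg'
    #     extractor.image_paths([0, 1])  # ['IMG_001.jpg', 'IMG_002.jpg']
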
    def show_duplicates(self, n=5, path=False) -> Tuple[np.ndarray, np.ndarray]:
        """Show duplicates from comparison of embeddings. Uses the `closely`
        package to get pairs.

        Args:
            n (int): how many closest pairs to identify
            path (bool): plot pairs of images with abbreviated paths

        Returns:
            pairs (np.ndarray): pairs as indices
            distances (np.ndarray): distances of pairs

        """
        show_path = path or self._show_path
        pairs, distances = self.duplicates(n=n)

        # Plot pairs
        for idx, pair in enumerate(pairs):
            img0 = self.get_image(pair[0])
            img1 = self.get_image(pair[1])
            img0_reconst = self.decode(index=pair[0], astensor=True)[0]
            img1_reconst = self.decode(index=pair[1], astensor=True)[0]
            pair_details = (
                f"{self.image_paths(pair[0])}\n{self.image_paths(pair[1])}"
                if show_path
                else pair
            )
            title = f"{pair_details}, dist={distances[idx]:.2f}"
            self.show(
                torchvision.utils.make_grid(
                    [img0, img1, img0_reconst, img1_reconst], nrow=2
                ),
                title=title,
                y_labels=[(2, "Original"), (5, "Reconstruction")],
            )

        return pairs, distances

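    # Typical interactive use, sketched: plot the five closest pairs with their
    # abbreviated file paths in each title.
    #
    #     pairs, distances = extractor.show_duplicates(n=5, path=True)
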
    def unnormalize(self, image: torch.Tensor) -> torch.Tensor:
        """Unnormalize an image.

        Args:
            image (:class:`torch.Tensor`)

        Returns:
            image (:class:`torch.Tensor`)

        """
        unorm = UnNormalize(mean=self._mean, std=self._std)
        return unorm(image)

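    # UnNormalize inverts `transforms.Normalize` channel-wise (x * std + mean),
    # so with the defaults above a normalized pixel v maps back to v * 0.25 + 0.5.
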
    def decode(
        self,
        embedding: Optional[np.ndarray] = None,
        index: Optional[int] = None,
        show: bool = False,
        astensor: bool = False,
    ) -> np.ndarray:
        """Decode the embedding at `index`, or pass `embedding` directly.

        Args:
            embedding (np.ndarray, optional): embedding of image
            index (int): index (of validation set / embeddings) to decode
            show (bool): plot the results
            astensor (bool): return a torch.Tensor instead of an np.ndarray

        Returns:
            image (np.ndarray or torch.Tensor): image reconstructed from embedding

        """
        self.model.eval()
        if embedding is None:
            embedding = self.embeddings[index]
        emb = np.expand_dims(embedding, 0)  # add batch axis

        # `decode` lives on the wrapped module when the model is in nn.DataParallel
        if not hasattr(self.model, "decode"):
            image, _ = self.model.module.decode(torch.Tensor(emb).to(self._device))
        else:
            image, _ = self.model.decode(torch.Tensor(emb).to(self._device))

        image = self.unnormalize(image)
        if show:
            grid_img = torchvision.utils.make_grid(image)
            self.show(grid_img, title=index)
        if astensor:
            return image.detach().cpu()
        return image.detach().cpu().numpy()
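
    # Round-trip sketch: decode the embedding of the first eval image back to
    # pixel space (shape assumes the default 48x48 color pipeline):
    #
    #     reconstruction = extractor.decode(index=0)
    #     reconstruction.shape  # (1, 3, 48, 48)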