Skip to content

Commit

Permalink
perf: use best available torch backend for embedding
Browse files Browse the repository at this point in the history
Automatically detect and use the best available backend for PyTorch. This is used to specify the device on which to compute the embeddings. Currently, this searches in order:
1. CUDA
2. MPS

If neither backend is available, this process will fall back to the CPU.
  • Loading branch information
kevinsbarnard committed May 14, 2024
1 parent a5570d1 commit 59dcc06
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions vars_gridview/lib/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,37 @@
from pathlib import Path

import numpy as np
import torch
from dreamsim import dreamsim
from PIL import Image
from torch.types import Device

from vars_gridview.lib.settings import SettingsManager


def get_torch_device() -> Device:
    """
    Get the appropriate torch device for embedding computation.

    Backends are probed in priority order (CUDA first, then MPS) and the
    first one that reports itself available is selected. If none is
    available, the CPU is used.

    Returns:
        Device: torch device string, one of "cuda", "mps" or "cpu".
    """
    # Some torch builds do not provide the torch.backends.mps module at all;
    # look it up defensively so that a missing backend means "not available"
    # rather than an AttributeError while building the probe list.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_is_available = getattr(mps_backend, "is_available", lambda: False)

    # Define an ordered list of predicates to check the device availability
    # and its corresponding device string
    device_check_order = [
        (torch.cuda.is_available, "cuda"),
        (mps_is_available, "mps"),
    ]

    # Check each predicate in order, returning the first device that is present
    for check, device_str in device_check_order:
        if check():
            return device_str

    # Failing all, use CPU
    return "cpu"


class Embedding(ABC):
"""
Embedding abstract base class. Produces embedding vectors for images.
Expand Down Expand Up @@ -38,17 +63,20 @@ def __init__(self) -> None:
base_cache_dir = Path(settings.cache_dir.value)
dreamsim_cache_dir = base_cache_dir / DreamSimEmbedding.CACHE_SUBDIR_NAME

# Get the appropriate torch device
self._device = get_torch_device()

# Download / load the models
self._model, self._preprocess = dreamsim(
pretrained=True,
device="cuda",
device=self._device,
cache_dir=str(dreamsim_cache_dir.resolve().absolute()),
)

def embed(self, image: np.ndarray) -> np.ndarray:
# Preprocess the image
image_pil = Image.fromarray(image)
image_tensor = self._preprocess(image_pil).cuda()
image_tensor = self._preprocess(image_pil).to(self._device)

# Compute the embedding
embedding = self._model.embed(image_tensor).cpu().detach().numpy().flatten()
Expand Down

0 comments on commit 59dcc06

Please sign in to comment.