mirror of https://github.com/KevinMidboe/immich.git
synced 2025-10-29 17:40:28 +00:00

feat(ml)!: switch image classification and CLIP models to ONNX (#3809)
@@ -1,3 +1,3 @@
-from .clip import CLIPSTEncoder
+from .clip import CLIPEncoder
 from .facial_recognition import FaceRecognizer
 from .image_classification import ImageClassifier
@@ -1,14 +1,17 @@
 from __future__ import annotations

 import os
+import pickle
 from abc import ABC, abstractmethod
 from pathlib import Path
 from shutil import rmtree
 from typing import Any
 from zipfile import BadZipFile

+import onnxruntime as ort
 from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf  # type: ignore

-from ..config import get_cache_dir
+from ..config import get_cache_dir, settings
 from ..schemas import ModelType
@@ -16,12 +19,31 @@ class InferenceModel(ABC):
     _model_type: ModelType

     def __init__(
-        self, model_name: str, cache_dir: Path | str | None = None, eager: bool = True, **model_kwargs: Any
+        self,
+        model_name: str,
+        cache_dir: Path | str | None = None,
+        eager: bool = True,
+        inter_op_num_threads: int = settings.model_inter_op_threads,
+        intra_op_num_threads: int = settings.model_intra_op_threads,
+        **model_kwargs: Any,
     ) -> None:
         self.model_name = model_name
         self._loaded = False
         self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)
         loader = self.load if eager else self.download

+        self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"])
+        # don't pre-allocate more memory than needed
+        self.provider_options = model_kwargs.pop(
+            "provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers)
+        )
+        self.sess_options = PicklableSessionOptions()
+        # avoid thread contention between models
+        if inter_op_num_threads > 1:
+            self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
+        self.sess_options.inter_op_num_threads = inter_op_num_threads
+        self.sess_options.intra_op_num_threads = intra_op_num_threads
+
         try:
             loader(**model_kwargs)
         except (OSError, InvalidProtobuf, BadZipFile):
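For orientation, here is a minimal standalone sketch of what this constructor sets up, assuming a CPU-only deployment; the model path and thread counts are illustrative, not from the commit:

    import onnxruntime as ort

    sess_options = ort.SessionOptions()
    # ORT_PARALLEL only matters when inter_op_num_threads > 1, mirroring the guard above
    sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    sess_options.inter_op_num_threads = 2
    sess_options.intra_op_num_threads = 2

    session = ort.InferenceSession(
        "model.onnx",  # illustrative path
        sess_options=sess_options,
        providers=["CPUExecutionProvider"],
        # grow the memory arena only as much as each request needs
        provider_options=[{"arena_extend_strategy": "kSameAsRequested"}],
    )

Capping both thread pools is what keeps several models served from the same process from oversubscribing the same cores.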
@@ -30,6 +52,7 @@ class InferenceModel(ABC):
     def download(self, **model_kwargs: Any) -> None:
         if not self.cached:
+            print(f"Downloading {self.model_type.value.replace('_', ' ')} model. This may take a while...")
             self._download(**model_kwargs)

     def load(self, **model_kwargs: Any) -> None:
@@ -39,6 +62,7 @@ class InferenceModel(ABC):
     def predict(self, inputs: Any) -> Any:
         if not self._loaded:
+            print(f"Loading {self.model_type.value.replace('_', ' ')} model...")
             self.load()
         return self._predict(inputs)
@@ -89,3 +113,14 @@ class InferenceModel(ABC):
         else:
             self.cache_dir.unlink()
         self.cache_dir.mkdir(parents=True, exist_ok=True)
+
+
+# HF deep copies configs, so we need to make session options picklable
+class PicklableSessionOptions(ort.SessionOptions):
+    def __getstate__(self) -> bytes:
+        return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
+
+    def __setstate__(self, state: Any) -> None:
+        self.__init__()  # type: ignore
+        for attr, val in pickle.loads(state):
+            setattr(self, attr, val)
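A minimal sketch (assumed usage, not part of the commit) of what the subclass buys: a stock ort.SessionOptions does not survive pickling, which would break Hugging Face's deep copy of model configs, while this subclass round-trips cleanly:

    import pickle

    opts = PicklableSessionOptions()
    opts.intra_op_num_threads = 2

    # deep copy via pickle, as Hugging Face does with configs
    restored = pickle.loads(pickle.dumps(opts))
    assert restored.intra_op_num_threads == 2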
@@ -46,7 +46,7 @@ class ModelCache:
             model: The requested model.
         """
-        key = self.cache.build_key(model_name, model_type.value)
+        key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
         async with OptimisticLock(self.cache, key) as lock:
             model = await self.cache.get(key)
             if model is None:
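The new key simply concatenates model name, model type, and the optional mode, so text-only and vision-only CLIP sessions are cached as distinct entries. An illustrative helper (hypothetical; the enum value "clip" is an assumption) showing the keys it produces:

    def build_key(model_name: str, model_type_value: str, mode: str = "") -> str:
        # mirrors the inline f-string above; mode is '' for models without one
        return f"{model_name}{model_type_value}{mode}"

    assert build_key("ViT-B-32::openai", "clip", "vision") == "ViT-B-32::openaiclipvision"
    assert build_key("ViT-B-32::openai", "clip", "text") != build_key("ViT-B-32::openai", "clip", "vision")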
@@ -1,31 +1,141 @@
-from typing import Any
+import os
+import zipfile
+from typing import Any, Literal

+import onnxruntime as ort
+import torch
+from clip_server.model.clip import BICUBIC, _convert_image_to_rgb
+from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET_V2, CLIPOnnxModel, download_model
+from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE
+from clip_server.model.tokenization import Tokenizer
 from PIL.Image import Image
-from sentence_transformers import SentenceTransformer
-from sentence_transformers.util import snapshot_download
+from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

 from ..schemas import ModelType
 from .base import InferenceModel

+_ST_TO_JINA_MODEL_NAME = {
+    "clip-ViT-B-16": "ViT-B-16::openai",
+    "clip-ViT-B-32": "ViT-B-32::openai",
+    "clip-ViT-B-32-multilingual-v1": "M-CLIP/XLM-Roberta-Large-Vit-B-32",
+    "clip-ViT-L-14": "ViT-L-14::openai",
+}
+

-class CLIPSTEncoder(InferenceModel):
+class CLIPEncoder(InferenceModel):
     _model_type = ModelType.CLIP

+    def __init__(
+        self,
+        model_name: str,
+        cache_dir: str | None = None,
+        mode: Literal["text", "vision"] | None = None,
+        **model_kwargs: Any,
+    ) -> None:
+        if mode is not None and mode not in ("text", "vision"):
+            raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'")
+        if "vit-b" not in model_name.lower():
+            raise ValueError(f"Only ViT-B models are currently supported; got '{model_name}'")
+        self.mode = mode
+        jina_model_name = self._get_jina_model_name(model_name)
+        super().__init__(jina_model_name, cache_dir, **model_kwargs)
+
     def _download(self, **model_kwargs: Any) -> None:
-        repo_id = self.model_name if "/" in self.model_name else f"sentence-transformers/{self.model_name}"
-        snapshot_download(
-            cache_dir=self.cache_dir,
-            repo_id=repo_id,
-            library_name="sentence-transformers",
-            ignore_files=["flax_model.msgpack", "rust_model.ot", "tf_model.h5"],
-        )
+        models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
+        text_onnx_path = self.cache_dir / "textual.onnx"
+        vision_onnx_path = self.cache_dir / "visual.onnx"
+
+        if not text_onnx_path.is_file():
+            self._download_model(*models[0])
+
+        if not vision_onnx_path.is_file():
+            self._download_model(*models[1])

     def _load(self, **model_kwargs: Any) -> None:
-        self.model = SentenceTransformer(
-            self.model_name,
-            cache_folder=self.cache_dir.as_posix(),
-            **model_kwargs,
-        )
+        if self.mode == "text" or self.mode is None:
+            self.text_model = ort.InferenceSession(
+                self.cache_dir / "textual.onnx",
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            )
+            self.text_outputs = [output.name for output in self.text_model.get_outputs()]
+            self.tokenizer = Tokenizer(self.model_name)
+
+        if self.mode == "vision" or self.mode is None:
+            self.vision_model = ort.InferenceSession(
+                self.cache_dir / "visual.onnx",
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            )
+            self.vision_outputs = [output.name for output in self.vision_model.get_outputs()]
+
+            image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)]
+            self.transform = _transform_pil_image(image_size)

     def _predict(self, image_or_text: Image | str) -> list[float]:
-        return self.model.encode(image_or_text).tolist()
+        match image_or_text:
+            case Image():
+                if self.mode == "text":
+                    raise TypeError("Cannot encode image as text-only model")
+                pixel_values = self.transform(image_or_text)
+                assert isinstance(pixel_values, torch.Tensor)
+                pixel_values = torch.unsqueeze(pixel_values, 0).numpy()
+                outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
+            case str():
+                if self.mode == "vision":
+                    raise TypeError("Cannot encode text as vision-only model")
+                text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text)
+                inputs = {
+                    "input_ids": text_inputs["input_ids"].int().numpy(),
+                    "attention_mask": text_inputs["attention_mask"].int().numpy(),
+                }
+                outputs = self.text_model.run(self.text_outputs, inputs)
+            case _:
+                raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
+
+        return outputs[0][0].tolist()
+
+    def _get_jina_model_name(self, model_name: str) -> str:
+        if model_name in _MODELS:
+            return model_name
+        elif model_name in _ST_TO_JINA_MODEL_NAME:
+            print(
+                f"Warning: Sentence-Transformer model names such as '{model_name}' are no longer supported.",
+                f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'.",
+            )
+            return _ST_TO_JINA_MODEL_NAME[model_name]
+        else:
+            raise ValueError(f"Unknown model name {model_name}.")
+
+    def _download_model(self, model_name: str, model_md5: str) -> bool:
+        # downloading logic is adapted from clip-server's CLIPOnnxModel class
+        download_model(
+            url=_S3_BUCKET_V2 + model_name,
+            target_folder=self.cache_dir.as_posix(),
+            md5sum=model_md5,
+            with_resume=True,
+        )
+        file = self.cache_dir / model_name.split("/")[1]
+        if file.suffix == ".zip":
+            with zipfile.ZipFile(file, "r") as zip_ref:
+                zip_ref.extractall(self.cache_dir)
+            os.remove(file)
+        return True
+
+
+# same as `_transform_blob` without `_blob2image`
+def _transform_pil_image(n_px: int) -> Compose:
+    return Compose(
+        [
+            Resize(n_px, interpolation=BICUBIC),
+            CenterCrop(n_px),
+            _convert_image_to_rgb,
+            ToTensor(),
+            Normalize(
+                (0.48145466, 0.4578275, 0.40821073),
+                (0.26862954, 0.26130258, 0.27577711),
+            ),
+        ]
+    )
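A hypothetical usage sketch of the new encoder; the model name comes from the mapping above, and the image path is an assumption:

    from PIL import Image as PILImage

    # vision-only instance: loads visual.onnx and skips the text tower entirely
    vision_encoder = CLIPEncoder("ViT-B-32::openai", mode="vision")
    image_embedding = vision_encoder.predict(PILImage.open("photo.jpg"))

    # text-only instance: loads textual.onnx plus the tokenizer
    text_encoder = CLIPEncoder("ViT-B-32::openai", mode="text")
    text_embedding = text_encoder.predict("a photo of a dog")

    # both are plain lists of floats, ready for cosine-similarity search

Splitting by mode means a worker that only embeds images never pays the memory cost of the text tower, and vice versa.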
@@ -4,6 +4,7 @@ from typing import Any
 import cv2
 import numpy as np
+import onnxruntime as ort
 from insightface.model_zoo import ArcFaceONNX, RetinaFace
 from insightface.utils.face_align import norm_crop
 from insightface.utils.storage import BASE_REPO_URL, download_file
@@ -42,15 +43,31 @@ class FaceRecognizer(InferenceModel):
             rec_file = next(self.cache_dir.glob("w600k_*.onnx"))
         except StopIteration:
             raise FileNotFoundError("Facial recognition models not found in cache directory")
-        self.det_model = RetinaFace(det_file.as_posix())
-        self.rec_model = ArcFaceONNX(rec_file.as_posix())
+
+        self.det_model = RetinaFace(
+            session=ort.InferenceSession(
+                det_file.as_posix(),
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            ),
+        )
+        self.rec_model = ArcFaceONNX(
+            rec_file.as_posix(),
+            session=ort.InferenceSession(
+                rec_file.as_posix(),
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            ),
+        )

         self.det_model.prepare(
-            ctx_id=-1,
+            ctx_id=0,
             det_thresh=self.min_score,
             input_size=(640, 640),
         )
-        self.rec_model.prepare(ctx_id=-1)
+        self.rec_model.prepare(ctx_id=0)

     def _predict(self, image: cv2.Mat) -> list[dict[str, Any]]:
         bboxes, kpss = self.det_model.detect(image)
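Both insightface wrappers accept a prebuilt session, which is what lets the shared provider and thread settings apply to detection and recognition alike. The hunk also flips prepare from ctx_id=-1 to ctx_id=0: insightface responds to a negative ctx_id by forcing CPUExecutionProvider onto the session, which would override the providers configured above. A condensed sketch under assumed values (the detector file name and threshold are illustrative):

    import onnxruntime as ort
    from insightface.model_zoo import RetinaFace

    session = ort.InferenceSession(
        "det_10g.onnx",  # illustrative detector file
        providers=["CPUExecutionProvider"],
        provider_options=[{"arena_extend_strategy": "kSameAsRequested"}],
    )
    det_model = RetinaFace(session=session)
    # ctx_id=0 leaves the session's providers untouched; det_thresh and
    # input_size match the call in the hunk above
    det_model.prepare(ctx_id=0, det_thresh=0.7, input_size=(640, 640))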
@@ -2,8 +2,10 @@ from pathlib import Path
 from typing import Any

 from huggingface_hub import snapshot_download
+from optimum.onnxruntime import ORTModelForImageClassification
+from optimum.pipelines import pipeline
 from PIL.Image import Image
-from transformers.pipelines import pipeline
+from transformers import AutoImageProcessor

 from ..config import settings
 from ..schemas import ModelType
@@ -25,15 +27,34 @@ class ImageClassifier(InferenceModel):
     def _download(self, **model_kwargs: Any) -> None:
         snapshot_download(
-            cache_dir=self.cache_dir, repo_id=self.model_name, allow_patterns=["*.bin", "*.json", "*.txt"]
+            cache_dir=self.cache_dir,
+            repo_id=self.model_name,
+            allow_patterns=["*.bin", "*.json", "*.txt"],
+            local_dir=self.cache_dir,
+            local_dir_use_symlinks=True,
         )

     def _load(self, **model_kwargs: Any) -> None:
-        self.model = pipeline(
-            self.model_type.value,
-            self.model_name,
-            model_kwargs={"cache_dir": self.cache_dir, **model_kwargs},
-        )
+        processor = AutoImageProcessor.from_pretrained(self.cache_dir)
+        model_kwargs |= {
+            "cache_dir": self.cache_dir,
+            "provider": self.providers[0],
+            "provider_options": self.provider_options[0],
+            "session_options": self.sess_options,
+        }
+        model_path = self.cache_dir / "model.onnx"
+
+        if model_path.exists():
+            model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
+            self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
+        else:
+            self.sess_options.optimized_model_filepath = model_path.as_posix()
+            self.model = pipeline(
+                self.model_type.value,
+                self.model_name,
+                model_kwargs=model_kwargs,
+                feature_extractor=processor,
+            )

     def _predict(self, image: Image) -> list[str]:
         predictions: list[dict[str, Any]] = self.model(image)  # type: ignore
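Taken together, _load amounts to a one-time conversion: on the first load, model.onnx does not exist, so the slower branch builds the pipeline from the Hugging Face weights and, because sess_options.optimized_model_filepath points into the cache, ONNX Runtime serializes the optimized graph to that path as a side effect; any later load finds the file and constructs the session through ORTModelForImageClassification directly. A hypothetical walkthrough, with the model name and image path as assumptions:

    from PIL import Image as PILImage

    classifier = ImageClassifier("microsoft/resnet-50")  # illustrative model name

    # first ever load: cache_dir/model.onnx does not exist yet, so the slower
    # branch runs and writes the optimized graph to that path
    tags = classifier.predict(PILImage.open("photo.jpg"))

    # a later process start takes the fast branch: model.onnx is loaded
    # directly, skipping the conversion step entirely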