mirror of
https://github.com/KevinMidboe/immich.git
synced 2025-10-29 17:40:28 +00:00
feat(ml)!: switch image classification and CLIP models to ONNX (#3809)
This commit is contained in:
@@ -2,8 +2,10 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from optimum.onnxruntime import ORTModelForImageClassification
|
||||
from optimum.pipelines import pipeline
|
||||
from PIL.Image import Image
|
||||
from transformers.pipelines import pipeline
|
||||
from transformers import AutoImageProcessor
|
||||
|
||||
from ..config import settings
|
||||
from ..schemas import ModelType
|
||||
@@ -25,15 +27,34 @@ class ImageClassifier(InferenceModel):
|
||||
|
||||
def _download(self, **model_kwargs: Any) -> None:
|
||||
snapshot_download(
|
||||
cache_dir=self.cache_dir, repo_id=self.model_name, allow_patterns=["*.bin", "*.json", "*.txt"]
|
||||
cache_dir=self.cache_dir,
|
||||
repo_id=self.model_name,
|
||||
allow_patterns=["*.bin", "*.json", "*.txt"],
|
||||
local_dir=self.cache_dir,
|
||||
local_dir_use_symlinks=True,
|
||||
)
|
||||
|
||||
def _load(self, **model_kwargs: Any) -> None:
|
||||
self.model = pipeline(
|
||||
self.model_type.value,
|
||||
self.model_name,
|
||||
model_kwargs={"cache_dir": self.cache_dir, **model_kwargs},
|
||||
)
|
||||
processor = AutoImageProcessor.from_pretrained(self.cache_dir)
|
||||
model_kwargs |= {
|
||||
"cache_dir": self.cache_dir,
|
||||
"provider": self.providers[0],
|
||||
"provider_options": self.provider_options[0],
|
||||
"session_options": self.sess_options,
|
||||
}
|
||||
model_path = self.cache_dir / "model.onnx"
|
||||
|
||||
if model_path.exists():
|
||||
model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
|
||||
self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
|
||||
else:
|
||||
self.sess_options.optimized_model_filepath = model_path.as_posix()
|
||||
self.model = pipeline(
|
||||
self.model_type.value,
|
||||
self.model_name,
|
||||
model_kwargs=model_kwargs,
|
||||
feature_extractor=processor,
|
||||
)
|
||||
|
||||
def _predict(self, image: Image) -> list[str]:
|
||||
predictions: list[dict[str, Any]] = self.model(image) # type: ignore
|
||||
|
||||
Reference in New Issue
Block a user