refactor(ml): modularization and styling (#2835)
* basic refactor and styling
* removed batching
* module entrypoint
* removed unused imports
* model superclass, model cache now in app state
* fixed cache dir and enforced abstract method

---------

Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
machine-learning/app/models/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .clip import CLIPSTTextEncoder, CLIPSTVisionEncoder
from .facial_recognition import FaceRecognizer
from .image_classification import ImageClassifier
machine-learning/app/models/base.py (new file, 52 lines)
@@ -0,0 +1,52 @@
from __future__ import annotations

from abc import abstractmethod, ABC
from pathlib import Path
from typing import Any

from ..config import get_cache_dir
from ..schemas import ModelType


class InferenceModel(ABC):
    _model_type: ModelType

    def __init__(
        self,
        model_name: str,
        cache_dir: Path | None = None,
    ):
        self.model_name = model_name
        self._cache_dir = (
            cache_dir
            if cache_dir is not None
            else get_cache_dir(model_name, self.model_type)
        )

    @abstractmethod
    def predict(self, inputs: Any) -> Any:
        ...

    @property
    def model_type(self) -> ModelType:
        return self._model_type

    @property
    def cache_dir(self) -> Path:
        return self._cache_dir

    @cache_dir.setter
    def cache_dir(self, cache_dir: Path):
        self._cache_dir = cache_dir

    @classmethod
    def from_model_type(
        cls, model_type: ModelType, model_name, **model_kwargs
    ) -> InferenceModel:
        subclasses = {
            subclass._model_type: subclass for subclass in cls.__subclasses__()
        }
        if model_type not in subclasses:
            raise ValueError(f"Unsupported model type: {model_type}")

        return subclasses[model_type](model_name, **model_kwargs)
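
The factory above builds its registry from cls.__subclasses__(), which in Python returns direct subclasses only; called on InferenceModel, the lookup sees CLIPSTEncoder, FaceRecognizer, and ImageClassifier but not their children. A minimal usage sketch (the model name and import paths are illustrative assumptions, not part of the commit):

# Illustrative usage; "clip-ViT-B-32" and the import paths are assumptions.
from app.models.base import InferenceModel
from app.schemas import ModelType

# Dispatches to CLIPSTEncoder, whose _model_type is ModelType.CLIP.
model = InferenceModel.from_model_type(ModelType.CLIP, "clip-ViT-B-32")
embedding = model.predict("a photo of a dog")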
machine-learning/app/models/cache.py (new file, 92 lines)
@@ -0,0 +1,92 @@
import asyncio

from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock
from aiocache.plugins import BasePlugin, TimingPlugin

from ..schemas import ModelType
from .base import InferenceModel


class ModelCache:
    """Fetches a model from an in-memory cache, instantiating it if it's missing."""

    def __init__(
        self,
        ttl: float | None = None,
        revalidate: bool = False,
        timeout: int | None = None,
        profiling: bool = False,
    ):
        """
        Args:
            ttl: Unloads model after this duration. Disabled if None. Defaults to None.
            revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
            timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
            profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
        """

        self.ttl = ttl
        plugins = []

        if revalidate:
            plugins.append(RevalidationPlugin())
        if profiling:
            plugins.append(TimingPlugin())

        self.cache = SimpleMemoryCache(
            ttl=ttl, timeout=timeout, plugins=plugins, namespace=None
        )

    async def get(
        self, model_name: str, model_type: ModelType, **model_kwargs
    ) -> InferenceModel:
        """
        Args:
            model_name: Name of model in the model hub used for the task.
            model_type: Model type or task, which determines which model zoo is used.

        Returns:
            model: The requested model.
        """

        key = self.cache.build_key(model_name, model_type.value)
        model = await self.cache.get(key)
        if model is None:
            async with OptimisticLock(self.cache, key) as lock:
                model = await asyncio.get_running_loop().run_in_executor(
                    None,
                    lambda: InferenceModel.from_model_type(
                        model_type, model_name, **model_kwargs
                    ),
                )
                await lock.cas(model, ttl=self.ttl)
        return model

    async def get_profiling(self) -> dict[str, float] | None:
        if not hasattr(self.cache, "profiling"):
            return None

        return self.cache.profiling  # type: ignore


class RevalidationPlugin(BasePlugin):
    """Revalidates cache item's TTL after cache hit."""

    async def post_get(self, client, key, ret=None, namespace=None, **kwargs):
        if ret is None:
            return
        if namespace is not None:
            key = client.build_key(key, namespace)
        if key in client._handlers:
            await client.expire(key, client.ttl)

    async def post_multi_get(self, client, keys, ret=None, namespace=None, **kwargs):
        if ret is None:
            return

        for key, val in zip(keys, ret):
            if namespace is not None:
                key = client.build_key(key, namespace)
            if val is not None and key in client._handlers:
                await client.expire(key, client.ttl)
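
A sketch of driving the cache from async code: the OptimisticLock makes concurrent misses for the same key race benignly (lock.cas stores only if the key was untouched since the lock was taken), and run_in_executor keeps the blocking model load off the event loop. The model name and import paths below are assumptions:

import asyncio

from app.models.cache import ModelCache  # assumed import path
from app.schemas import ModelType


async def main() -> None:
    # Unload models idle for 5 minutes; revalidate=True resets the TTL on each hit.
    cache = ModelCache(ttl=300, revalidate=True)

    # First call loads in a thread pool; the second is a cache hit.
    model = await cache.get("buffalo_l", ModelType.FACIAL_RECOGNITION)
    again = await cache.get("buffalo_l", ModelType.FACIAL_RECOGNITION)
    assert model is again  # SimpleMemoryCache stores the object itself


asyncio.run(main())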
machine-learning/app/models/clip.py (new file, 37 lines)
@@ -0,0 +1,37 @@
from pathlib import Path

from PIL.Image import Image
from sentence_transformers import SentenceTransformer

from ..schemas import ModelType
from .base import InferenceModel


class CLIPSTEncoder(InferenceModel):
    _model_type = ModelType.CLIP

    def __init__(
        self,
        model_name: str,
        cache_dir: Path | None = None,
        **model_kwargs,
    ):
        super().__init__(model_name, cache_dir)
        self.model = SentenceTransformer(
            self.model_name,
            cache_folder=self.cache_dir.as_posix(),
            **model_kwargs,
        )

    def predict(self, image_or_text: Image | str) -> list[float]:
        return self.model.encode(image_or_text).tolist()


# stubs to allow different behavior between the two in the future
# and handle loading different image and text clip models
class CLIPSTVisionEncoder(CLIPSTEncoder):
    _model_type = ModelType.CLIP_VISION


class CLIPSTTextEncoder(CLIPSTEncoder):
    _model_type = ModelType.CLIP_TEXT
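
Because both stubs share the same encode-anything predict, a text query and an image embed into the same space and can be compared with cosine similarity. A sketch, assuming the sentence-transformers checkpoint clip-ViT-B-32 and an illustrative image path:

from PIL import Image

from app.models.clip import CLIPSTTextEncoder, CLIPSTVisionEncoder  # assumed paths


def cosine(a: list[float], b: list[float]) -> float:
    # Plain-Python cosine similarity to keep the sketch dependency-free.
    dot = sum(x * y for x, y in zip(a, b))
    return dot / ((sum(x * x for x in a) ** 0.5) * (sum(y * y for y in b) ** 0.5))


text_encoder = CLIPSTTextEncoder("clip-ViT-B-32")
vision_encoder = CLIPSTVisionEncoder("clip-ViT-B-32")

query = text_encoder.predict("a photo of a beach")
image = vision_encoder.predict(Image.open("beach.jpg"))  # path is illustrative
print(cosine(query, image))  # higher means a better text-image match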
machine-learning/app/models/facial_recognition.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from pathlib import Path
from typing import Any

import cv2
from insightface.app import FaceAnalysis

from ..config import settings
from ..schemas import ModelType
from .base import InferenceModel


class FaceRecognizer(InferenceModel):
    _model_type = ModelType.FACIAL_RECOGNITION

    def __init__(
        self,
        model_name: str,
        min_score: float = settings.min_face_score,
        cache_dir: Path | None = None,
        **model_kwargs,
    ):
        super().__init__(model_name, cache_dir)
        self.min_score = min_score
        model = FaceAnalysis(
            name=self.model_name,
            root=self.cache_dir.as_posix(),
            allowed_modules=["detection", "recognition"],
            **model_kwargs,
        )
        model.prepare(
            ctx_id=0,
            det_thresh=self.min_score,
            det_size=(640, 640),
        )
        self.model = model

    def predict(self, image: cv2.Mat) -> list[dict[str, Any]]:
        height, width, _ = image.shape
        results = []
        faces = self.model.get(image)

        for face in faces:
            x1, y1, x2, y2 = face.bbox

            results.append(
                {
                    "imageWidth": width,
                    "imageHeight": height,
                    "boundingBox": {
                        "x1": round(x1),
                        "y1": round(y1),
                        "x2": round(x2),
                        "y2": round(y2),
                    },
                    "score": face.det_score.item(),
                    "embedding": face.normed_embedding.tolist(),
                }
            )
        return results
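
predict expects a decoded BGR array, which is what cv2.imread returns and what insightface consumes directly. A usage sketch; buffalo_l is a common insightface model zoo name used here as an assumption, along with the file path and import path:

import cv2

from app.models.facial_recognition import FaceRecognizer  # assumed import path

recognizer = FaceRecognizer("buffalo_l", min_score=0.7)

image = cv2.imread("group_photo.jpg")  # BGR ndarray; path is illustrative
for face in recognizer.predict(image):
    box = face["boundingBox"]
    print(
        f"face ({box['x1']}, {box['y1']})..({box['x2']}, {box['y2']}) "
        f"score={face['score']:.2f} embedding_dims={len(face['embedding'])}"
    )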
machine-learning/app/models/image_classification.py (new file, 40 lines)
@@ -0,0 +1,40 @@
from pathlib import Path

from PIL.Image import Image
from transformers.pipelines import pipeline

from ..config import settings
from ..schemas import ModelType
from .base import InferenceModel


class ImageClassifier(InferenceModel):
    _model_type = ModelType.IMAGE_CLASSIFICATION

    def __init__(
        self,
        model_name: str,
        min_score: float = settings.min_tag_score,
        cache_dir: Path | None = None,
        **model_kwargs,
    ):
        super().__init__(model_name, cache_dir)
        self.min_score = min_score

        self.model = pipeline(
            self.model_type.value,
            self.model_name,
            model_kwargs={"cache_dir": self.cache_dir, **model_kwargs},
        )

    def predict(self, image: Image) -> list[str]:
        predictions = self.model(image)
        tags = list(
            {
                tag
                for pred in predictions
                for tag in pred["label"].split(", ")
                if pred["score"] >= self.min_score
            }
        )
        return tags
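
The set comprehension in predict exists because classifier labels can be comma-separated synonym lists (e.g. "sports car, sport car"); splitting on ", " and collecting into a set keeps each tag once, at the cost of output order. An end-to-end sketch with an assumed model name, image path, and import path:

from PIL import Image

from app.models.image_classification import ImageClassifier  # assumed import path

# "microsoft/resnet-50" is an illustrative Hugging Face checkpoint.
classifier = ImageClassifier("microsoft/resnet-50", min_score=0.1)

tags = classifier.predict(Image.open("photo.jpg"))  # path is illustrative
print(tags)  # e.g. ["seashore", "beach"]; order is unspecified (set-based)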