chore(ml): added testing and github workflow (#2969)
* added testing
* github action for python, made mypy happy
* formatted with black
* minor fixes and styling
* test model cache
* cache test dependencies
* narrowed model cache tests
* moved endpoint tests to their own class
* cleaned up fixtures
* formatting
* removed unused dep
```diff
@@ -5,7 +5,7 @@ from pathlib import Path
 from shutil import rmtree
 from typing import Any
 
-from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf
+from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf  # type: ignore
 
 from ..config import get_cache_dir
 from ..schemas import ModelType
```
```diff
@@ -14,15 +14,9 @@ from ..schemas import ModelType
 class InferenceModel(ABC):
     _model_type: ModelType
 
-    def __init__(
-        self, model_name: str, cache_dir: Path | None = None, **model_kwargs
-    ) -> None:
+    def __init__(self, model_name: str, cache_dir: Path | str | None = None, **model_kwargs: Any) -> None:
         self.model_name = model_name
-        self._cache_dir = (
-            cache_dir
-            if cache_dir is not None
-            else get_cache_dir(model_name, self.model_type)
-        )
+        self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)
 
         try:
             self.load(**model_kwargs)
```
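The substantive change here is that `cache_dir` now accepts a plain string as well as a `Path`, with normalization done once in the constructor via `Path(cache_dir)`. A minimal runnable sketch of the pattern, with a hypothetical `get_cache_dir` standing in for the app's config helper:

```python
from abc import ABC
from pathlib import Path
from typing import Any


def get_cache_dir(model_name: str, model_type: str) -> Path:
    # Hypothetical stand-in for the app's get_cache_dir config helper.
    return Path.home() / ".cache" / model_type / model_name


class InferenceModel(ABC):
    _model_type: str

    def __init__(self, model_name: str, cache_dir: Path | str | None = None, **model_kwargs: Any) -> None:
        self.model_name = model_name
        # Path() accepts both str and Path; None falls back to the default location.
        self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self._model_type)
```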
```diff
@@ -51,12 +45,8 @@ class InferenceModel(ABC):
         self._cache_dir = cache_dir
 
     @classmethod
-    def from_model_type(
-        cls, model_type: ModelType, model_name, **model_kwargs
-    ) -> InferenceModel:
-        subclasses = {
-            subclass._model_type: subclass for subclass in cls.__subclasses__()
-        }
+    def from_model_type(cls, model_type: ModelType, model_name: str, **model_kwargs: Any) -> InferenceModel:
+        subclasses = {subclass._model_type: subclass for subclass in cls.__subclasses__()}
         if model_type not in subclasses:
             raise ValueError(f"Unsupported model type: {model_type}")
 
```
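The `from_model_type` factory builds its registry from `cls.__subclasses__()`, so a new model type registers itself simply by subclassing. A self-contained sketch of the technique (the class and model names below are illustrative):

```python
from __future__ import annotations

from abc import ABC
from typing import Any


class InferenceModel(ABC):
    _model_type: str  # each concrete subclass declares its type here

    def __init__(self, model_name: str, **model_kwargs: Any) -> None:
        self.model_name = model_name

    @classmethod
    def from_model_type(cls, model_type: str, model_name: str, **model_kwargs: Any) -> InferenceModel:
        # Map each subclass's declared type to the subclass itself.
        subclasses = {subclass._model_type: subclass for subclass in cls.__subclasses__()}
        if model_type not in subclasses:
            raise ValueError(f"Unsupported model type: {model_type}")
        return subclasses[model_type](model_name, **model_kwargs)


class ImageClassifier(InferenceModel):
    _model_type = "image-classification"


model = InferenceModel.from_model_type("image-classification", "example-model")
```

One caveat of this design: `__subclasses__()` only sees direct subclasses whose modules have already been imported, so each model module must be imported before the factory can dispatch to it.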
```diff
@@ -66,8 +56,6 @@ class InferenceModel(ABC):
         if not self.cache_dir.exists():
             return
         elif not rmtree.avoids_symlink_attacks:
-            raise RuntimeError(
-                "Attempted to clear cache, but rmtree is not safe on this platform."
-            )
+            raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform.")
 
         rmtree(self.cache_dir)
```
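For context, `shutil.rmtree.avoids_symlink_attacks` is a documented stdlib attribute: it is True only on platforms whose `os` module supports fd-based directory access, which lets `rmtree` resist an attacker swapping a directory for a symlink mid-deletion. A quick check:

```python
import shutil

# True on platforms (e.g. modern Linux) where recursive deletion is
# resistant to symlink races; the cache-clearing code refuses to run otherwise.
print(shutil.rmtree.avoids_symlink_attacks)
```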
```diff
@@ -1,4 +1,5 @@
 import asyncio
+from typing import Any
 
 from aiocache.backends.memory import SimpleMemoryCache
 from aiocache.lock import OptimisticLock
```
```diff
@@ -34,13 +35,9 @@ class ModelCache:
         if profiling:
             plugins.append(TimingPlugin())
 
-        self.cache = SimpleMemoryCache(
-            ttl=ttl, timeout=timeout, plugins=plugins, namespace=None
-        )
+        self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None)
 
-    async def get(
-        self, model_name: str, model_type: ModelType, **model_kwargs
-    ) -> InferenceModel:
+    async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel:
         """
         Args:
             model_name: Name of model in the model hub used for the task.
```
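As a reference point, `SimpleMemoryCache` takes a default `ttl`, a `timeout`, and a list of plugins, and aiocache's `TimingPlugin` accumulates per-operation timing stats on the cache instance. A minimal sketch, assuming aiocache's behavior at the time of this commit:

```python
import asyncio

from aiocache.backends.memory import SimpleMemoryCache
from aiocache.plugins import TimingPlugin


async def main() -> None:
    # ttl applies to every set() unless overridden per call.
    cache = SimpleMemoryCache(ttl=300, timeout=30, plugins=[TimingPlugin()])
    await cache.set("key", "value")
    await cache.get("key")
    # TimingPlugin attaches a `profiling` dict of per-operation timings.
    print(cache.profiling)


asyncio.run(main())
```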
```diff
@@ -56,9 +53,7 @@ class ModelCache:
         async with OptimisticLock(self.cache, key) as lock:
             model = await asyncio.get_running_loop().run_in_executor(
                 None,
-                lambda: InferenceModel.from_model_type(
-                    model_type, model_name, **model_kwargs
-                ),
+                lambda: InferenceModel.from_model_type(model_type, model_name, **model_kwargs),
             )
             await lock.cas(model, ttl=self.ttl)
         return model
```
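This load path combines two techniques: `run_in_executor` moves the blocking model load off the event loop, and aiocache's `OptimisticLock.cas` only stores the result if no concurrent coroutine has written the key since the lock was taken, avoiding duplicate loads. A self-contained sketch (the `load_model` stub is hypothetical):

```python
import asyncio
from typing import Any

from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock

cache = SimpleMemoryCache()


def load_model(name: str) -> Any:
    # Hypothetical blocking load (e.g. reading model weights from disk).
    return f"model:{name}"


async def get_model(name: str) -> Any:
    model = await cache.get(name)
    if model is None:
        async with OptimisticLock(cache, name) as lock:
            # Run the blocking load in the default thread pool executor.
            model = await asyncio.get_running_loop().run_in_executor(None, lambda: load_model(name))
            # cas() writes only if the key is unchanged since lock entry.
            await lock.cas(model, ttl=60)
    return model


print(asyncio.run(get_model("example")))
```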
```diff
@@ -73,7 +68,14 @@ class ModelCache:
 class RevalidationPlugin(BasePlugin):
     """Revalidates cache item's TTL after cache hit."""
 
-    async def post_get(self, client, key, ret=None, namespace=None, **kwargs):
+    async def post_get(
+        self,
+        client: SimpleMemoryCache,
+        key: str,
+        ret: Any | None = None,
+        namespace: str | None = None,
+        **kwargs: Any,
+    ) -> None:
         if ret is None:
             return
         if namespace is not None:
```
```diff
@@ -81,7 +83,14 @@ class RevalidationPlugin(BasePlugin):
         if key in client._handlers:
             await client.expire(key, client.ttl)
 
-    async def post_multi_get(self, client, keys, ret=None, namespace=None, **kwargs):
+    async def post_multi_get(
+        self,
+        client: SimpleMemoryCache,
+        keys: list[str],
+        ret: list[Any] | None = None,
+        namespace: str | None = None,
+        **kwargs: Any,
+    ) -> None:
         if ret is None:
             return
 
```
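aiocache plugins hook into cache operations through `pre_*`/`post_*` coroutines; `post_get` receives the client, the key, and the returned value (`None` on a miss), which is why both hooks above bail out early when `ret is None`. A small sketch of a custom plugin in the same style (the logging behavior is illustrative):

```python
import asyncio
from typing import Any

from aiocache.backends.memory import SimpleMemoryCache
from aiocache.plugins import BasePlugin


class HitLoggerPlugin(BasePlugin):
    # aiocache calls post_get after every get(); ret is the cached value or None.
    async def post_get(
        self,
        client: SimpleMemoryCache,
        key: str,
        ret: Any | None = None,
        namespace: str | None = None,
        **kwargs: Any,
    ) -> None:
        if ret is None:
            return  # miss: nothing to do
        print(f"cache hit for {key!r}")


async def main() -> None:
    cache = SimpleMemoryCache(plugins=[HitLoggerPlugin()])
    await cache.set("k", 1)
    await cache.get("k")  # prints: cache hit for 'k'


asyncio.run(main())
```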
```diff
@@ -16,8 +16,8 @@ class FaceRecognizer(InferenceModel):
         self,
         model_name: str,
         min_score: float = settings.min_face_score,
-        cache_dir: Path | None = None,
-        **model_kwargs,
+        cache_dir: Path | str | None = None,
+        **model_kwargs: Any,
     ) -> None:
         self.min_score = min_score
         super().__init__(model_name, cache_dir, **model_kwargs)
```
```diff
@@ -16,8 +16,8 @@ class ImageClassifier(InferenceModel):
         self,
         model_name: str,
         min_score: float = settings.min_tag_score,
-        cache_dir: Path | None = None,
-        **model_kwargs,
+        cache_dir: Path | str | None = None,
+        **model_kwargs: Any,
     ) -> None:
         self.min_score = min_score
         super().__init__(model_name, cache_dir, **model_kwargs)
```
```diff
@@ -30,13 +30,7 @@ class ImageClassifier(InferenceModel):
         )
 
     def predict(self, image: Image) -> list[str]:
-        predictions = self.model(image)
-        tags = list(
-            {
-                tag
-                for pred in predictions
-                for tag in pred["label"].split(", ")
-                if pred["score"] >= self.min_score
-            }
-        )
+        predictions: list[dict[str, Any]] = self.model(image)  # type: ignore
+        tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
+
         return tags
```
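The rewritten `predict` flattens comma-separated labels from every prediction that clears `min_score`. Worth noting: the old version used a set comprehension, so the new list comprehension preserves order but no longer deduplicates repeated tags. A worked example with dummy predictions:

```python
from typing import Any

predictions: list[dict[str, Any]] = [
    {"label": "cat, animal", "score": 0.9},
    {"label": "skateboard", "score": 0.1},
]
min_score = 0.5

# Keep only confident predictions, then split multi-label strings into tags.
tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= min_score]
print(tags)  # ['cat', 'animal']
```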