	refactor(ml): modularization and styling (#2835)
* basic refactor and styling
* removed batching
* module entrypoint
* removed unused imports
* model superclass, model cache now in app state
* fixed cache dir and enforced abstract method

Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
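In short: the service now lives in an app package started with python -m app.main, per-task model code moves behind an InferenceModel superclass (one subclass per task), and the ModelCache is stored on app.state instead of a module-level global. The resulting layout, as the diff below shows: app/main.py, app/config.py, app/schemas.py, and a new app/models/ package with base.py, cache.py, clip.py, facial_recognition.py, and image_classification.py.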
machine-learning/Dockerfile

@@ -21,8 +21,8 @@ ENV NODE_ENV=production \
   PYTHONDONTWRITEBYTECODE=1 \
   PYTHONUNBUFFERED=1 \
   PATH="/opt/venv/bin:$PATH" \
-  PYTHONPATH=`pwd`
+  PYTHONPATH=/usr/src
 
 COPY --from=builder /opt/venv /opt/venv
 COPY app .
-ENTRYPOINT ["python", "main.py"]
+ENTRYPOINT ["python", "-m", "app.main"]
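Note on the entrypoint change: running the server as a module (python -m app.main) is what makes the package-relative imports introduced below (from .config import settings, and so on) resolve, and PYTHONPATH=/usr/src replaces the shell-dependent `pwd` value with a fixed import root. This assumes the image's WORKDIR is /usr/src, which the COPY app . line implies.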
							
								
								
									
machine-learning/app/__init__.py (new file, 0 lines)

machine-learning/app/config.py
@@ -1,5 +1,10 @@
+from pathlib import Path
+
 from pydantic import BaseSettings
 
+from .schemas import ModelType
+
+
 class Settings(BaseSettings):
     cache_folder: str = "/cache"
     classification_model: str = "microsoft/resnet-50"
@@ -15,8 +20,12 @@ class Settings(BaseSettings):
     min_face_score: float = 0.7
 
     class Config(BaseSettings.Config):
-        env_prefix = 'MACHINE_LEARNING_'
+        env_prefix = "MACHINE_LEARNING_"
         case_sensitive = False
+
+
+def get_cache_dir(model_name: str, model_type: ModelType) -> Path:
+    return Path(settings.cache_folder, model_type.value, model_name)
+
 
 settings = Settings()
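A quick sketch of the new cache-path helper and the settings override mechanism; the cache folder value below is illustrative, and the model name is the classification default from above:

import os

# Settings fields can be overridden through the environment; pydantic's
# BaseSettings applies the MACHINE_LEARNING_ prefix declared in Config.
os.environ["MACHINE_LEARNING_CACHE_FOLDER"] = "/tmp/model-cache"

from app.config import get_cache_dir
from app.schemas import ModelType

# Layout is <cache_folder>/<model_type>/<model_name>, so this prints
# /tmp/model-cache/image-classification/microsoft/resnet-50
print(get_cache_dir("microsoft/resnet-50", ModelType.IMAGE_CLASSIFICATION))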
machine-learning/app/main.py

@@ -1,52 +1,58 @@
 import os
-import io
+from io import BytesIO
 from typing import Any
 
-from cache import ModelCache
-from schemas import (
+import cv2
+import numpy as np
+import uvicorn
+from fastapi import Body, Depends, FastAPI
+from PIL import Image
+
+from .config import settings
+from .models.base import InferenceModel
+from .models.cache import ModelCache
+from .schemas import (
     EmbeddingResponse,
     FaceResponse,
-    TagResponse,
     MessageResponse,
+    ModelType,
+    TagResponse,
     TextModelRequest,
     TextResponse,
 )
-import uvicorn
-from PIL import Image
-from fastapi import FastAPI, HTTPException, Depends, Body
-from models import get_model, run_classification, run_facial_recognition
-from config import settings
-
-_model_cache = None
 
 app = FastAPI()
 
 
 @app.on_event("startup")
 async def startup_event() -> None:
-    global _model_cache
-    _model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
+    app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
+    same_clip = settings.clip_image_model == settings.clip_text_model
+    app.state.clip_vision_type = ModelType.CLIP if same_clip else ModelType.CLIP_VISION
+    app.state.clip_text_type = ModelType.CLIP if same_clip else ModelType.CLIP_TEXT
     models = [
-        (settings.classification_model, "image-classification"),
-        (settings.clip_image_model, "clip"),
-        (settings.clip_text_model, "clip"),
-        (settings.facial_recognition_model, "facial-recognition"),
+        (settings.classification_model, ModelType.IMAGE_CLASSIFICATION),
+        (settings.clip_image_model, app.state.clip_vision_type),
+        (settings.clip_text_model, app.state.clip_text_type),
+        (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION),
     ]
 
     # Get all models
     for model_name, model_type in models:
         if settings.eager_startup:
-            await _model_cache.get_cached_model(model_name, model_type)
+            await app.state.model_cache.get(model_name, model_type)
         else:
-            get_model(model_name, model_type)
+            InferenceModel.from_model_type(model_type, model_name)
 
 
-def dep_model_cache():
-    if _model_cache is None:
-        raise HTTPException(status_code=500, detail="Unable to load model.")
+def dep_pil_image(byte_image: bytes = Body(...)) -> Image.Image:
+    return Image.open(BytesIO(byte_image))
+
+
+def dep_cv_image(byte_image: bytes = Body(...)) -> cv2.Mat:
+    byte_image_np = np.frombuffer(byte_image, np.uint8)
+    return cv2.imdecode(byte_image_np, cv2.IMREAD_COLOR)
 
-def dep_input_image(image: bytes = Body(...)) -> Image:
-    return Image.open(io.BytesIO(image))
 
 @app.get("/", response_model=MessageResponse)
 async def root() -> dict[str, str]:
@@ -62,33 +68,29 @@ def ping() -> str:
     "/image-classifier/tag-image",
     response_model=TagResponse,
     status_code=200,
-    dependencies=[Depends(dep_model_cache)],
 )
 async def image_classification(
-    image: Image = Depends(dep_input_image)
+    image: Image.Image = Depends(dep_pil_image),
 ) -> list[str]:
-    try:
-        model = await _model_cache.get_cached_model(
-            settings.classification_model, "image-classification"
-        )
-        labels = run_classification(model, image, settings.min_tag_score)
-    except Exception as ex:
-        raise HTTPException(status_code=500, detail=str(ex))
-    else:
-        return labels
+    model = await app.state.model_cache.get(
+        settings.classification_model, ModelType.IMAGE_CLASSIFICATION
+    )
+    labels = model.predict(image)
+    return labels
 
 
 @app.post(
     "/sentence-transformer/encode-image",
     response_model=EmbeddingResponse,
     status_code=200,
-    dependencies=[Depends(dep_model_cache)],
 )
 async def clip_encode_image(
-    image: Image = Depends(dep_input_image)
+    image: Image.Image = Depends(dep_pil_image),
 ) -> list[float]:
-    model = await _model_cache.get_cached_model(settings.clip_image_model, "clip")
-    embedding = model.encode(image).tolist()
+    model = await app.state.model_cache.get(
+        settings.clip_image_model, app.state.clip_vision_type
+    )
+    embedding = model.predict(image)
     return embedding
 
 
@@ -96,13 +98,12 @@ async def clip_encode_image(
     "/sentence-transformer/encode-text",
     response_model=EmbeddingResponse,
     status_code=200,
-    dependencies=[Depends(dep_model_cache)],
 )
-async def clip_encode_text(
-    payload: TextModelRequest
-) -> list[float]:
-    model = await _model_cache.get_cached_model(settings.clip_text_model, "clip")
-    embedding = model.encode(payload.text).tolist()
+async def clip_encode_text(payload: TextModelRequest) -> list[float]:
+    model = await app.state.model_cache.get(
+        settings.clip_text_model, app.state.clip_text_type
+    )
+    embedding = model.predict(payload.text)
    return embedding
 
 
@@ -110,22 +111,21 @@ async def clip_encode_text(
     "/facial-recognition/detect-faces",
     response_model=FaceResponse,
     status_code=200,
-    dependencies=[Depends(dep_model_cache)],
 )
 async def facial_recognition(
-    image: bytes = Body(...),
+    image: cv2.Mat = Depends(dep_cv_image),
 ) -> list[dict[str, Any]]:
-    model = await _model_cache.get_cached_model(
-        settings.facial_recognition_model, "facial-recognition"
+    model = await app.state.model_cache.get(
+        settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION
     )
-    faces = run_facial_recognition(model, image)
+    faces = model.predict(image)
     return faces
 
 
 if __name__ == "__main__":
     is_dev = os.getenv("NODE_ENV") == "development"
     uvicorn.run(
-        "main:app",
+        "app.main:app",
         host=settings.host,
         port=settings.port,
         reload=is_dev,
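For orientation, a minimal client sketch against the endpoints above. The port is an assumption (Settings defines host and port; 3003 is used here), the image path is illustrative, and each endpoint reads the raw request body as bytes:

import requests

with open("photo.jpg", "rb") as f:
    image_bytes = f.read()

# Raw bytes in the body; the Body(...) dependencies decode them server-side.
resp = requests.post(
    "http://localhost:3003/image-classifier/tag-image",
    data=image_bytes,
    headers={"Content-Type": "application/octet-stream"},
)
print(resp.json())  # a list of tag strings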
| @@ -1,119 +0,0 @@ | ||||
| import torch | ||||
| from insightface.app import FaceAnalysis | ||||
| from pathlib import Path | ||||
|  | ||||
| from transformers import pipeline, Pipeline | ||||
| from sentence_transformers import SentenceTransformer | ||||
| from typing import Any, BinaryIO | ||||
| import cv2 as cv | ||||
| import numpy as np | ||||
| from PIL import Image | ||||
| from config import settings | ||||
|  | ||||
| device = "cuda" if torch.cuda.is_available() else "cpu" | ||||
|  | ||||
|  | ||||
| def get_model(model_name: str, model_type: str, **model_kwargs): | ||||
|     """ | ||||
|     Instantiates the specified model. | ||||
|  | ||||
|     Args: | ||||
|         model_name: Name of model in the model hub used for the task. | ||||
|         model_type: Model type or task, which determines which model zoo is used. | ||||
|             `facial-recognition` uses Insightface, while all other models use the HF Model Hub. | ||||
|  | ||||
|             Options: | ||||
|                 `image-classification`, `clip`,`facial-recognition`, `tokenizer`, `processor` | ||||
|  | ||||
|     Returns: | ||||
|         model: The requested model. | ||||
|     """ | ||||
|  | ||||
|     cache_dir = _get_cache_dir(model_name, model_type) | ||||
|     match model_type: | ||||
|         case "facial-recognition": | ||||
|             model = _load_facial_recognition( | ||||
|                 model_name, cache_dir=cache_dir, **model_kwargs | ||||
|             ) | ||||
|         case "clip": | ||||
|             model = SentenceTransformer( | ||||
|                 model_name, cache_folder=cache_dir, **model_kwargs | ||||
|             ) | ||||
|         case _: | ||||
|             model = pipeline( | ||||
|                 model_type, | ||||
|                 model_name, | ||||
|                 model_kwargs={"cache_dir": cache_dir, **model_kwargs}, | ||||
|             ) | ||||
|  | ||||
|     return model | ||||
|  | ||||
|  | ||||
| def run_classification( | ||||
|     model: Pipeline, image: Image, min_score: float | None = None | ||||
| ): | ||||
|     predictions: list[dict[str, Any]] = model(image)  # type: ignore | ||||
|     result = { | ||||
|         tag | ||||
|         for pred in predictions | ||||
|         for tag in pred["label"].split(", ") | ||||
|         if min_score is None or pred["score"] >= min_score | ||||
|     } | ||||
|  | ||||
|     return list(result) | ||||
|  | ||||
|  | ||||
| def run_facial_recognition( | ||||
|     model: FaceAnalysis, image: bytes | ||||
| ) -> list[dict[str, Any]]: | ||||
|     file_bytes = np.frombuffer(image, dtype=np.uint8) | ||||
|     img = cv.imdecode(file_bytes, cv.IMREAD_COLOR) | ||||
|     height, width, _ = img.shape | ||||
|     results = [] | ||||
|     faces = model.get(img) | ||||
|  | ||||
|     for face in faces: | ||||
|         x1, y1, x2, y2 = face.bbox | ||||
|  | ||||
|         results.append( | ||||
|             { | ||||
|                 "imageWidth": width, | ||||
|                 "imageHeight": height, | ||||
|                 "boundingBox": { | ||||
|                     "x1": round(x1), | ||||
|                     "y1": round(y1), | ||||
|                     "x2": round(x2), | ||||
|                     "y2": round(y2), | ||||
|                 }, | ||||
|                 "score": face.det_score.item(), | ||||
|                 "embedding": face.normed_embedding.tolist(), | ||||
|             } | ||||
|         ) | ||||
|     return results | ||||
|  | ||||
|  | ||||
| def _load_facial_recognition( | ||||
|     model_name: str, | ||||
|     min_face_score: float | None = None, | ||||
|     cache_dir: Path | str | None = None, | ||||
|     **model_kwargs, | ||||
| ): | ||||
|     if cache_dir is None: | ||||
|         cache_dir = _get_cache_dir(model_name, "facial-recognition") | ||||
|     if isinstance(cache_dir, Path): | ||||
|         cache_dir = cache_dir.as_posix() | ||||
|     if min_face_score is None: | ||||
|         min_face_score = settings.min_face_score | ||||
|  | ||||
|     model = FaceAnalysis( | ||||
|         name=model_name, | ||||
|         root=cache_dir, | ||||
|         allowed_modules=["detection", "recognition"], | ||||
|         **model_kwargs, | ||||
|     ) | ||||
|     model.prepare(ctx_id=0, det_thresh=min_face_score, det_size=(640, 640)) | ||||
|     return model | ||||
|  | ||||
|  | ||||
| def _get_cache_dir(model_name: str, model_type: str) -> Path: | ||||
|     return Path(settings.cache_folder, device, model_type, model_name) | ||||
							
								
								
									
machine-learning/app/models/__init__.py (new file, 3 lines)

@@ -0,0 +1,3 @@
+from .clip import CLIPSTTextEncoder, CLIPSTVisionEncoder
+from .facial_recognition import FaceRecognizer
+from .image_classification import ImageClassifier
							
								
								
									
machine-learning/app/models/base.py (new file, 52 lines)

@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+from abc import abstractmethod, ABC
+from pathlib import Path
+from typing import Any
+
+from ..config import get_cache_dir
+from ..schemas import ModelType
+
+
+class InferenceModel(ABC):
+    _model_type: ModelType
+
+    def __init__(
+        self,
+        model_name: str,
+        cache_dir: Path | None = None,
+    ):
+        self.model_name = model_name
+        self._cache_dir = (
+            cache_dir
+            if cache_dir is not None
+            else get_cache_dir(model_name, self.model_type)
+        )
+
+    @abstractmethod
+    def predict(self, inputs: Any) -> Any:
+        ...
+
+    @property
+    def model_type(self) -> ModelType:
+        return self._model_type
+
+    @property
+    def cache_dir(self) -> Path:
+        return self._cache_dir
+
+    @cache_dir.setter
+    def cache_dir(self, cache_dir: Path):
+        self._cache_dir = cache_dir
+
+    @classmethod
+    def from_model_type(
+        cls, model_type: ModelType, model_name, **model_kwargs
+    ) -> InferenceModel:
+        subclasses = {
+            subclass._model_type: subclass for subclass in cls.__subclasses__()
+        }
+        if model_type not in subclasses:
+            raise ValueError(f"Unsupported model type: {model_type}")
+
+        return subclasses[model_type](model_name, **model_kwargs)
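The from_model_type classmethod dispatches on the _model_type attribute of InferenceModel's direct subclasses, so adding a model is just defining a subclass. A minimal usage sketch, assuming the subclasses registered in this package; the model name is illustrative, and instantiation downloads weights:

from app.models.base import InferenceModel
from app.schemas import ModelType

# Resolves to ImageClassifier because its _model_type matches; note the
# lookup only sees direct subclasses of InferenceModel.
classifier = InferenceModel.from_model_type(
    ModelType.IMAGE_CLASSIFICATION, "microsoft/resnet-50"
)
assert classifier.model_type is ModelType.IMAGE_CLASSIFICATION
# With default settings this is /cache/image-classification/microsoft/resnet-50
print(classifier.cache_dir)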
machine-learning/app/models/cache.py

@@ -1,8 +1,11 @@
-from aiocache.plugins import TimingPlugin, BasePlugin
+import asyncio
+
 from aiocache.backends.memory import SimpleMemoryCache
 from aiocache.lock import OptimisticLock
-from typing import Any
-from models import get_model
+from aiocache.plugins import BasePlugin, TimingPlugin
+
+from ..schemas import ModelType
+from .base import InferenceModel
 
 
 class ModelCache:
@@ -10,7 +13,7 @@ class ModelCache:
 
     def __init__(
         self,
-        ttl: int | None = None,
+        ttl: float | None = None,
         revalidate: bool = False,
         timeout: int | None = None,
         profiling: bool = False,
@@ -35,9 +38,9 @@ class ModelCache:
             ttl=ttl, timeout=timeout, plugins=plugins, namespace=None
         )
 
-    async def get_cached_model(
-        self, model_name: str, model_type: str, **model_kwargs
-    ) -> Any:
+    async def get(
+        self, model_name: str, model_type: ModelType, **model_kwargs
+    ) -> InferenceModel:
         """
         Args:
             model_name: Name of model in the model hub used for the task.
@@ -47,11 +50,16 @@ class ModelCache:
             model: The requested model.
         """
 
-        key = self.cache.build_key(model_name, model_type)
+        key = self.cache.build_key(model_name, model_type.value)
         model = await self.cache.get(key)
         if model is None:
             async with OptimisticLock(self.cache, key) as lock:
-                model = get_model(model_name, model_type, **model_kwargs)
+                model = await asyncio.get_running_loop().run_in_executor(
+                    None,
+                    lambda: InferenceModel.from_model_type(
+                        model_type, model_name, **model_kwargs
+                    ),
+                )
                 await lock.cas(model, ttl=self.ttl)
         return model
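A minimal sketch of the cache's read-through behavior; the model name is illustrative. The loader runs in the default executor so a slow model load does not block the event loop:

import asyncio

from app.models.cache import ModelCache
from app.schemas import ModelType


async def main() -> None:
    cache = ModelCache(ttl=300, revalidate=True)
    # First call constructs the model in a worker thread; later calls return
    # the cached instance until the TTL lapses (revalidate refreshes the TTL
    # on each hit).
    model = await cache.get("clip-ViT-B-32", ModelType.CLIP)
    print(len(model.predict("a photo of a dog")))  # embedding dimension


asyncio.run(main())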
							
								
								
									
machine-learning/app/models/clip.py (new file, 37 lines)

@@ -0,0 +1,37 @@
+from pathlib import Path
+
+from PIL.Image import Image
+from sentence_transformers import SentenceTransformer
+
+from ..schemas import ModelType
+from .base import InferenceModel
+
+
+class CLIPSTEncoder(InferenceModel):
+    _model_type = ModelType.CLIP
+
+    def __init__(
+        self,
+        model_name: str,
+        cache_dir: Path | None = None,
+        **model_kwargs,
+    ):
+        super().__init__(model_name, cache_dir)
+        self.model = SentenceTransformer(
+            self.model_name,
+            cache_folder=self.cache_dir.as_posix(),
+            **model_kwargs,
+        )
+
+    def predict(self, image_or_text: Image | str) -> list[float]:
+        return self.model.encode(image_or_text).tolist()
+
+
+# stubs to allow different behavior between the two in the future
+# and handle loading different image and text clip models
+class CLIPSTVisionEncoder(CLIPSTEncoder):
+    _model_type = ModelType.CLIP_VISION
+
+
+class CLIPSTTextEncoder(CLIPSTEncoder):
+    _model_type = ModelType.CLIP_TEXT
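Both encoders return plain float lists, so an image-to-text similarity check takes a few lines. A sketch assuming a local image and the clip-ViT-B-32 checkpoint from sentence-transformers:

import numpy as np
from PIL import Image

from app.models.clip import CLIPSTTextEncoder, CLIPSTVisionEncoder

# Both encoders can share the same checkpoint; the name is illustrative.
vision = CLIPSTVisionEncoder("clip-ViT-B-32")
text = CLIPSTTextEncoder("clip-ViT-B-32")

img_emb = np.array(vision.predict(Image.open("photo.jpg")))
txt_emb = np.array(text.predict("a photo of a dog"))

# Cosine similarity; higher means the caption fits the image better.
score = img_emb @ txt_emb / (np.linalg.norm(img_emb) * np.linalg.norm(txt_emb))
print(score)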
							
								
								
									
machine-learning/app/models/facial_recognition.py (new file, 59 lines)

@@ -0,0 +1,59 @@
+from pathlib import Path
+from typing import Any
+
+import cv2
+from insightface.app import FaceAnalysis
+
+from ..config import settings
+from ..schemas import ModelType
+from .base import InferenceModel
+
+
+class FaceRecognizer(InferenceModel):
+    _model_type = ModelType.FACIAL_RECOGNITION
+
+    def __init__(
+        self,
+        model_name: str,
+        min_score: float = settings.min_face_score,
+        cache_dir: Path | None = None,
+        **model_kwargs,
+    ):
+        super().__init__(model_name, cache_dir)
+        self.min_score = min_score
+        model = FaceAnalysis(
+            name=self.model_name,
+            root=self.cache_dir.as_posix(),
+            allowed_modules=["detection", "recognition"],
+            **model_kwargs,
+        )
+        model.prepare(
+            ctx_id=0,
+            det_thresh=self.min_score,
+            det_size=(640, 640),
+        )
+        self.model = model
+
+    def predict(self, image: cv2.Mat) -> list[dict[str, Any]]:
+        height, width, _ = image.shape
+        results = []
+        faces = self.model.get(image)
+
+        for face in faces:
+            x1, y1, x2, y2 = face.bbox
+
+            results.append(
+                {
+                    "imageWidth": width,
+                    "imageHeight": height,
+                    "boundingBox": {
+                        "x1": round(x1),
+                        "y1": round(y1),
+                        "x2": round(x2),
+                        "y2": round(y2),
+                    },
+                    "score": face.det_score.item(),
+                    "embedding": face.normed_embedding.tolist(),
+                }
+            )
+        return results
							
								
								
									
machine-learning/app/models/image_classification.py (new file, 40 lines)

@@ -0,0 +1,40 @@
+from pathlib import Path
+
+from PIL.Image import Image
+from transformers.pipelines import pipeline
+
+from ..config import settings
+from ..schemas import ModelType
+from .base import InferenceModel
+
+
+class ImageClassifier(InferenceModel):
+    _model_type = ModelType.IMAGE_CLASSIFICATION
+
+    def __init__(
+        self,
+        model_name: str,
+        min_score: float = settings.min_tag_score,
+        cache_dir: Path | None = None,
+        **model_kwargs,
+    ):
+        super().__init__(model_name, cache_dir)
+        self.min_score = min_score
+
+        self.model = pipeline(
+            self.model_type.value,
+            self.model_name,
+            model_kwargs={"cache_dir": self.cache_dir, **model_kwargs},
+        )
+
+    def predict(self, image: Image) -> list[str]:
+        predictions = self.model(image)
+        tags = list(
+            {
+                tag
+                for pred in predictions
+                for tag in pred["label"].split(", ")
+                if pred["score"] >= self.min_score
+            }
+        )
+        return tags
machine-learning/app/schemas.py

@@ -1,3 +1,5 @@
+from enum import Enum
+
 from pydantic import BaseModel
 
 
@@ -54,3 +56,11 @@ class Face(BaseModel):
 
 class FaceResponse(BaseModel):
     __root__: list[Face]
+
+
+class ModelType(Enum):
+    IMAGE_CLASSIFICATION = "image-classification"
+    CLIP = "clip"
+    CLIP_VISION = "clip-vision"
+    CLIP_TEXT = "clip-text"
+    FACIAL_RECOGNITION = "facial-recognition"