Mirror of https://github.com/KevinMidboe/immich.git, synced 2025-10-29 17:40:28 +00:00
feat(ml)!: customizable ML settings (#3891)
* consolidated endpoints, added live configuration
* added ml settings to server
* added settings dashboard
* updated deps, fixed typos
* simplified modelconfig, updated tests
* added ml setting accordion for admin page, updated tests
* merge `clipText` and `clipVision`
* added face distance setting, clarified setting
* add clip mode in request, dropdown for face models
* polished ml settings, updated descriptions
* update clip field on error
* removed unused import
* add description for image classification threshold
* pin safetensors for arm wheel, updated poetry lock
* moved dto
* set model type only in ml repository
* revert form-data package install, use fetch instead of axios
* added slotted description with link, updated facial recognition description, clarified effect of disabling tasks
* validation before model load
* removed unnecessary getconfig call
* added migration
* updated api

---------

Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
@@ -8,17 +8,11 @@ from .schemas import ModelType

class Settings(BaseSettings):
    cache_folder: str = "/cache"
    classification_model: str = "microsoft/resnet-50"
    clip_image_model: str = "ViT-B-32::openai"
    clip_text_model: str = "ViT-B-32::openai"
    facial_recognition_model: str = "buffalo_l"
    min_tag_score: float = 0.9
    eager_startup: bool = False
    model_ttl: int = 0
    host: str = "0.0.0.0"
    port: int = 3003
    workers: int = 1
    min_face_score: float = 0.7
    test_full: bool = False
    request_threads: int = os.cpu_count() or 4
    model_inter_op_threads: int = 1
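The per-model fields above appear to move out of this file in this change (the commit message notes "added ml settings to server"), while the remaining fields are still plain pydantic BaseSettings fields and can therefore be overridden per deployment through environment variables. A minimal, self-contained sketch of that mechanism; DemoSettings and the env var name are illustrative only and not taken from this diff:

import os

from pydantic import BaseSettings  # pydantic v1-style import, matching the usage shown above


class DemoSettings(BaseSettings):
    model_ttl: int = 0
    port: int = 3003


os.environ["MODEL_TTL"] = "300"  # case-insensitive match against the field name
print(DemoSettings().model_ttl)  # 300
print(DemoSettings().port)       # 3003 (unchanged default)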
@@ -1,29 +1,26 @@
import asyncio
import os
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Any

import cv2
import numpy as np
import orjson
import uvicorn
from fastapi import Body, Depends, FastAPI
from PIL import Image
from fastapi import FastAPI, Form, HTTPException, UploadFile
from fastapi.responses import ORJSONResponse
from starlette.formparsers import MultiPartParser

from app.models.base import InferenceModel

from .config import settings
from .models.cache import ModelCache
from .schemas import (
    EmbeddingResponse,
    FaceResponse,
    MessageResponse,
    ModelType,
    TagResponse,
    TextModelRequest,
    TextResponse,
)

MultiPartParser.max_file_size = 2**24  # spools to disk if payload is 16 MiB or larger

app = FastAPI()
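For reference, the spool threshold set above works out to 16 MiB, matching the inline comment; a quick arithmetic check in plain Python, unrelated to the codebase itself:

# 2**24 bytes is the size at which Starlette's multipart parser stops buffering
# an upload in memory and spools it to a temporary file instead.
limit = 2**24
print(limit)           # 16777216
print(limit / 2**20)   # 16.0 (MiB)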
@@ -33,37 +30,9 @@ def init_state() -> None:
    app.state.thread_pool = ThreadPoolExecutor(settings.request_threads)


async def load_models() -> None:
    models: list[tuple[str, ModelType, dict[str, Any]]] = [
        (settings.classification_model, ModelType.IMAGE_CLASSIFICATION, {}),
        (settings.clip_image_model, ModelType.CLIP, {"mode": "vision"}),
        (settings.clip_text_model, ModelType.CLIP, {"mode": "text"}),
        (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION, {}),
    ]

    # Get all models
    for model_name, model_type, model_kwargs in models:
        await app.state.model_cache.get(model_name, model_type, eager=settings.eager_startup, **model_kwargs)


@app.on_event("startup")
async def startup_event() -> None:
    init_state()
    await load_models()


@app.on_event("shutdown")
async def shutdown_event() -> None:
    app.state.thread_pool.shutdown()


def dep_pil_image(byte_image: bytes = Body(...)) -> Image.Image:
    return Image.open(BytesIO(byte_image))


def dep_cv_image(byte_image: bytes = Body(...)) -> np.ndarray[int, np.dtype[Any]]:
    byte_image_np = np.frombuffer(byte_image, np.uint8)
    return cv2.imdecode(byte_image_np, cv2.IMREAD_COLOR)


@app.get("/", response_model=MessageResponse)
@@ -76,57 +45,27 @@ def ping() -> str:
    return "pong"


@app.post(
    "/image-classifier/tag-image",
    response_model=TagResponse,
    status_code=200,
)
async def image_classification(
    image: Image.Image = Depends(dep_pil_image),
) -> list[str]:
    model = await app.state.model_cache.get(settings.classification_model, ModelType.IMAGE_CLASSIFICATION)
    labels = await predict(model, image)
    return labels
@app.post("/predict")
async def predict(
    model_name: str = Form(alias="modelName"),
    model_type: ModelType = Form(alias="modelType"),
    options: str = Form(default="{}"),
    text: str | None = Form(default=None),
    image: UploadFile | None = None,
) -> Any:
    if image is not None:
        inputs: str | bytes = await image.read()
    elif text is not None:
        inputs = text
    else:
        raise HTTPException(400, "Either image or text must be provided")

    model: InferenceModel = await app.state.model_cache.get(model_name, model_type, **orjson.loads(options))
    outputs = await run(model, inputs)
    return ORJSONResponse(outputs)


@app.post(
    "/sentence-transformer/encode-image",
    response_model=EmbeddingResponse,
    status_code=200,
)
async def clip_encode_image(
    image: Image.Image = Depends(dep_pil_image),
) -> list[float]:
    model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP, mode="vision")
    embedding = await predict(model, image)
    return embedding


@app.post(
    "/sentence-transformer/encode-text",
    response_model=EmbeddingResponse,
    status_code=200,
)
async def clip_encode_text(payload: TextModelRequest) -> list[float]:
    model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP, mode="text")
    embedding = await predict(model, payload.text)
    return embedding


@app.post(
    "/facial-recognition/detect-faces",
    response_model=FaceResponse,
    status_code=200,
)
async def facial_recognition(
    image: cv2.Mat = Depends(dep_cv_image),
) -> list[dict[str, Any]]:
    model = await app.state.model_cache.get(settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION)
    faces = await predict(model, image)
    return faces


async def predict(model: InferenceModel, inputs: Any) -> Any:
async def run(model: InferenceModel, inputs: Any) -> Any:
    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
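The hunk above consolidates the per-task routes into a single /predict endpoint that takes the model name, model type, and a JSON options string as form fields, plus either an uploaded image or a text field. A hedged client-side sketch of how such a request could look; the host, port, file name, and option values are illustrative, taken from the defaults in this diff, and `requests` is assumed to be available on the client:

import json

import requests  # assumed client-side dependency, not part of this diff

# CLIP image embedding: model identifiers go in form fields, the image as a file upload.
with open("photo.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:3003/predict",
        data={
            "modelName": "ViT-B-32::openai",
            "modelType": "clip",
            "options": json.dumps({"mode": "vision"}),
        },
        files={"image": f},
    )
print(resp.json())  # expected: an embedding (list of floats)

# CLIP text embedding: same route, but with the `text` form field instead of a file.
resp = requests.post(
    "http://localhost:3003/predict",
    data={
        "modelName": "ViT-B-32::openai",
        "modelType": "clip",
        "options": json.dumps({"mode": "text"}),
        "text": "a photo of a cat",
    },
)
print(resp.json())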
@@ -60,16 +60,21 @@ class InferenceModel(ABC):
        self._load(**model_kwargs)
        self._loaded = True

    def predict(self, inputs: Any) -> Any:
    def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
        if not self._loaded:
            print(f"Loading {self.model_type.value.replace('_', ' ')} model...")
            self.load()
        if model_kwargs:
            self.configure(**model_kwargs)
        return self._predict(inputs)

    @abstractmethod
    def _predict(self, inputs: Any) -> Any:
        ...

    def configure(self, **model_kwargs: Any) -> None:
        pass

    @abstractmethod
    def _download(self, **model_kwargs: Any) -> None:
        ...
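The base-class change above threads per-call keyword arguments through a new configure() hook before delegating to _predict(). A small self-contained sketch of that flow with a hypothetical subclass; DemoModel is illustrative and not part of the diff:

from typing import Any


class DemoModel:
    """Stand-in for an InferenceModel subclass, showing the predict/configure flow."""

    def __init__(self) -> None:
        self.min_score = 0.9
        self._loaded = False

    def load(self) -> None:
        # A real model would load weights here.
        self._loaded = True

    def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
        if not self._loaded:
            self.load()
        if model_kwargs:
            self.configure(**model_kwargs)  # apply per-call settings before inference
        return self._predict(inputs)

    def configure(self, **model_kwargs: Any) -> None:
        self.min_score = model_kwargs.get("min_score", self.min_score)

    def _predict(self, inputs: Any) -> Any:
        return {"inputs": inputs, "min_score": self.min_score}


print(DemoModel().predict("some input", min_score=0.5))
# {'inputs': 'some input', 'min_score': 0.5}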
@@ -1,5 +1,6 @@
import os
import zipfile
from io import BytesIO
from typing import Any, Literal

import onnxruntime as ort
@@ -8,7 +9,7 @@ from clip_server.model.clip import BICUBIC, _convert_image_to_rgb
from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET_V2, CLIPOnnxModel, download_model
from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE
from clip_server.model.tokenization import Tokenizer
from PIL.Image import Image
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

from ..schemas import ModelType
@@ -74,9 +75,12 @@ class CLIPEncoder(InferenceModel):
        image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)]
        self.transform = _transform_pil_image(image_size)

    def _predict(self, image_or_text: Image | str) -> list[float]:
    def _predict(self, image_or_text: Image.Image | str) -> list[float]:
        if isinstance(image_or_text, bytes):
            image_or_text = Image.open(BytesIO(image_or_text))

        match image_or_text:
            case Image():
            case Image.Image():
                if self.mode == "text":
                    raise TypeError("Cannot encode image as text-only model")
                pixel_values = self.transform(image_or_text)
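Because the import switches from `from PIL.Image import Image` to `from PIL import Image`, the class pattern in the match statement becomes `Image.Image()`. A small sketch of that dispatch; Pillow is assumed to be installed and the describe() helper is illustrative:

from PIL import Image


def describe(image_or_text: Image.Image | str) -> str:
    # Class patterns match by isinstance: Image.Image() matches any PIL image,
    # str() matches plain text, and anything else falls through to the error case.
    match image_or_text:
        case Image.Image():
            return f"image of size {image_or_text.size}"
        case str():
            return f"text of length {len(image_or_text)}"
        case _:
            raise TypeError(f"unsupported input type: {type(image_or_text)}")


print(describe("a photo of a cat"))            # text of length 16
print(describe(Image.new("RGB", (224, 224))))  # image of size (224, 224)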
@@ -9,7 +9,6 @@ from insightface.model_zoo import ArcFaceONNX, RetinaFace
from insightface.utils.face_align import norm_crop
from insightface.utils.storage import BASE_REPO_URL, download_file

from ..config import settings
from ..schemas import ModelType
from .base import InferenceModel

@@ -20,7 +19,7 @@ class FaceRecognizer(InferenceModel):
    def __init__(
        self,
        model_name: str,
        min_score: float = settings.min_face_score,
        min_score: float = 0.7,
        cache_dir: Path | str | None = None,
        **model_kwargs: Any,
    ) -> None:
@@ -69,11 +68,13 @@ class FaceRecognizer(InferenceModel):
        )
        self.rec_model.prepare(ctx_id=0)

    def _predict(self, image: cv2.Mat) -> list[dict[str, Any]]:
    def _predict(self, image: np.ndarray[int, np.dtype[Any]] | bytes) -> list[dict[str, Any]]:
        if isinstance(image, bytes):
            image = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
        bboxes, kpss = self.det_model.detect(image)
        if bboxes.size == 0:
            return []
        assert isinstance(kpss, np.ndarray)
        assert isinstance(image, np.ndarray) and isinstance(kpss, np.ndarray)

        scores = bboxes[:, 4].tolist()
        bboxes = bboxes[:, :4].round().tolist()
@@ -102,3 +103,6 @@ class FaceRecognizer(InferenceModel):
    @property
    def cached(self) -> bool:
        return self.cache_dir.is_dir() and any(self.cache_dir.glob("*.onnx"))

    def configure(self, **model_kwargs: Any) -> None:
        self.det_model.det_thresh = model_kwargs.get("min_score", self.det_model.det_thresh)
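The new configure() override maps an optional min_score keyword onto the detector's det_thresh attribute and leaves it untouched when the key is absent. A tiny sketch of that behaviour with a stand-in detector object; SimpleNamespace is used purely for illustration:

from types import SimpleNamespace

det_model = SimpleNamespace(det_thresh=0.7)  # stand-in for the RetinaFace detector


def configure(**model_kwargs) -> None:
    det_model.det_thresh = model_kwargs.get("min_score", det_model.det_thresh)


configure(min_score=0.6)
print(det_model.det_thresh)  # 0.6
configure()  # no min_score supplied: threshold stays as it was
print(det_model.det_thresh)  # 0.6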
@@ -1,13 +1,13 @@
from io import BytesIO
from pathlib import Path
from typing import Any

from huggingface_hub import snapshot_download
from optimum.onnxruntime import ORTModelForImageClassification
from optimum.pipelines import pipeline
from PIL.Image import Image
from PIL import Image
from transformers import AutoImageProcessor

from ..config import settings
from ..schemas import ModelType
from .base import InferenceModel

@@ -18,7 +18,7 @@ class ImageClassifier(InferenceModel):
    def __init__(
        self,
        model_name: str,
        min_score: float = settings.min_tag_score,
        min_score: float = 0.9,
        cache_dir: Path | str | None = None,
        **model_kwargs: Any,
    ) -> None:
@@ -56,8 +56,13 @@ class ImageClassifier(InferenceModel):
            feature_extractor=processor,
        )

    def _predict(self, image: Image) -> list[str]:
    def _predict(self, image: Image.Image | bytes) -> list[str]:
        if isinstance(image, bytes):
            image = Image.open(BytesIO(image))
        predictions: list[dict[str, Any]] = self.model(image)  # type: ignore
        tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]

        return tags

    def configure(self, **model_kwargs: Any) -> None:
        self.min_score = model_kwargs.get("min_score", self.min_score)
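The tag comprehension above both filters predictions by min_score and splits comma-separated, ImageNet-style labels into individual tags. A short sketch with made-up classifier output; the prediction dicts are illustrative, not real model output:

from typing import Any

predictions: list[dict[str, Any]] = [
    {"label": "tabby, tabby cat", "score": 0.93},
    {"label": "tiger cat", "score": 0.42},  # below the threshold, dropped entirely
]
min_score = 0.9

tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= min_score]
print(tags)  # ['tabby', 'tabby cat']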
@@ -1,4 +1,4 @@
from enum import Enum
from enum import StrEnum

from pydantic import BaseModel

@@ -20,18 +20,6 @@ class MessageResponse(BaseModel):
    message: str


class TagResponse(BaseModel):
    __root__: list[str]


class Embedding(BaseModel):
    __root__: list[float]


class EmbeddingResponse(BaseModel):
    __root__: Embedding


class BoundingBox(BaseModel):
    x1: int
    y1: int
@@ -39,23 +27,7 @@ class BoundingBox(BaseModel):
    y2: int


class Face(BaseModel):
    image_width: int
    image_height: int
    bounding_box: BoundingBox
    score: float
    embedding: Embedding

    class Config:
        alias_generator = to_lower_camel
        allow_population_by_field_name = True


class FaceResponse(BaseModel):
    __root__: list[Face]


class ModelType(Enum):
class ModelType(StrEnum):
    IMAGE_CLASSIFICATION = "image-classification"
    CLIP = "clip"
    FACIAL_RECOGNITION = "facial-recognition"
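The schema change swaps Enum for StrEnum (Python 3.11+), so ModelType members compare equal to, and format as, their string values, which is convenient wherever the type is treated as a plain string such as the modelType form field. A minimal sketch of that behaviour:

from enum import StrEnum


class ModelType(StrEnum):
    IMAGE_CLASSIFICATION = "image-classification"
    CLIP = "clip"
    FACIAL_RECOGNITION = "facial-recognition"


print(ModelType("clip") is ModelType.CLIP)  # True: constructed straight from the string value
print(ModelType.CLIP == "clip")             # True: members compare equal to their values
print(f"model type: {ModelType.CLIP}")      # "model type: clip" (str() yields the value)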