Mirror of https://github.com/KevinMidboe/immich.git (synced 2025-10-29 17:40:28 +00:00)
			
		
		
		
	feat(ml): export clip models to ONNX and host models on Hugging Face (#4700)
* export clip models
* export to hf refactored export code
* export mclip, general refactoring cleanup
* updated conda deps
* do transforms with pillow and numpy, add tokenization config to export, general refactoring
* moved conda dockerfile, re-added poetry
* minor fixes
* updated link
* updated tests
* removed `requirements.txt` from workflow
* fixed mimalloc path
* removed torchvision
* cleaner np typing
* review suggestions
* update default model name
* update test
		| @@ -10,9 +10,8 @@ RUN poetry config installer.max-workers 10 && \ | ||||
| RUN python -m venv /opt/venv | ||||
| ENV VIRTUAL_ENV="/opt/venv" PATH="/opt/venv/bin:${PATH}" | ||||
|  | ||||
| COPY poetry.lock pyproject.toml requirements.txt ./ | ||||
| COPY poetry.lock pyproject.toml ./ | ||||
| RUN poetry install --sync --no-interaction --no-ansi --no-root --only main | ||||
| RUN pip install --no-deps -r requirements.txt | ||||
|  | ||||
| FROM python:3.11-slim-bookworm | ||||
|  | ||||
|   | ||||
| @@ -1,5 +1,6 @@ | ||||
| import json | ||||
| from typing import Any, Iterator, TypeAlias | ||||
| from pathlib import Path | ||||
| from typing import Any, Iterator | ||||
| from unittest import mock | ||||
|  | ||||
| import numpy as np | ||||
| @@ -8,8 +9,7 @@ from fastapi.testclient import TestClient | ||||
| from PIL import Image | ||||
|  | ||||
| from .main import app, init_state | ||||
|  | ||||
| ndarray: TypeAlias = np.ndarray[int, np.dtype[np.float32]] | ||||
| from .schemas import ndarray_f32 | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| @@ -18,13 +18,13 @@ def pil_image() -> Image.Image: | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def cv_image(pil_image: Image.Image) -> ndarray: | ||||
| def cv_image(pil_image: Image.Image) -> ndarray_f32: | ||||
|     return np.asarray(pil_image)[:, :, ::-1]  # PIL uses RGB while cv2 uses BGR | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def mock_get_model() -> Iterator[mock.Mock]: | ||||
|     with mock.patch("app.models.cache.InferenceModel.from_model_type", autospec=True) as mocked: | ||||
|     with mock.patch("app.models.cache.from_model_type", autospec=True) as mocked: | ||||
|         yield mocked | ||||
|  | ||||
|  | ||||
| @@ -37,3 +37,25 @@ def deployed_app() -> TestClient: | ||||
| @pytest.fixture(scope="session") | ||||
| def responses() -> dict[str, Any]: | ||||
|     return json.load(open("responses.json", "r")) | ||||
|  | ||||
|  | ||||
| @pytest.fixture(scope="session") | ||||
| def clip_model_cfg() -> dict[str, Any]: | ||||
|     return { | ||||
|         "embed_dim": 512, | ||||
|         "vision_cfg": {"image_size": 224, "layers": 12, "width": 768, "patch_size": 32}, | ||||
|         "text_cfg": {"context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12}, | ||||
|     } | ||||
|  | ||||
|  | ||||
| @pytest.fixture(scope="session") | ||||
| def clip_preprocess_cfg() -> dict[str, Any]: | ||||
|     return { | ||||
|         "size": [224, 224], | ||||
|         "mode": "RGB", | ||||
|         "mean": [0.48145466, 0.4578275, 0.40821073], | ||||
|         "std": [0.26862954, 0.26130258, 0.27577711], | ||||
|         "interpolation": "bicubic", | ||||
|         "resize_mode": "shortest", | ||||
|         "fill_color": 0, | ||||
|     } | ||||
|   | ||||
| @@ -1,3 +1,25 @@ | ||||
| from .clip import CLIPEncoder | ||||
| from typing import Any | ||||
|  | ||||
| from app.schemas import ModelType | ||||
|  | ||||
| from .base import InferenceModel | ||||
| from .clip import MCLIPEncoder, OpenCLIPEncoder, is_mclip, is_openclip | ||||
| from .facial_recognition import FaceRecognizer | ||||
| from .image_classification import ImageClassifier | ||||
|  | ||||
|  | ||||
| def from_model_type(model_type: ModelType, model_name: str, **model_kwargs: Any) -> InferenceModel: | ||||
|     match model_type: | ||||
|         case ModelType.CLIP: | ||||
|             if is_openclip(model_name): | ||||
|                 return OpenCLIPEncoder(model_name, **model_kwargs) | ||||
|             elif is_mclip(model_name): | ||||
|                 return MCLIPEncoder(model_name, **model_kwargs) | ||||
|             else: | ||||
|                 raise ValueError(f"Unknown CLIP model {model_name}") | ||||
|         case ModelType.FACIAL_RECOGNITION: | ||||
|             return FaceRecognizer(model_name, **model_kwargs) | ||||
|         case ModelType.IMAGE_CLASSIFICATION: | ||||
|             return ImageClassifier(model_name, **model_kwargs) | ||||
|         case _: | ||||
|             raise ValueError(f"Unknown model type {model_type}") | ||||
|   | ||||
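For context, this module-level factory replaces the old `InferenceModel.from_model_type` classmethod (see the cache change further down). A minimal usage sketch; the cache directory and keyword arguments below are illustrative, and whether construction also downloads or loads weights depends on the `InferenceModel` base class, which is only partially shown here:

    from app.models import from_model_type
    from app.schemas import ModelType

    # "ViT-B-32__openai" is listed in _OPENCLIP_MODELS, so this routes to OpenCLIPEncoder;
    # an unrecognized CLIP name raises ValueError("Unknown CLIP model ...").
    encoder = from_model_type(ModelType.CLIP, "ViT-B-32__openai", cache_dir="/tmp/clip", mode="vision")

    # Non-CLIP model types route to their own classes, e.g. facial recognition.
    recognizer = from_model_type(ModelType.FACIAL_RECOGNITION, "buffalo_l")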
| @@ -25,7 +25,7 @@ class InferenceModel(ABC): | ||||
|     ) -> None: | ||||
|         self.model_name = model_name | ||||
|         self.loaded = False | ||||
|         self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type) | ||||
|         self._cache_dir = Path(cache_dir) if cache_dir is not None else None | ||||
|         self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"]) | ||||
|         #  don't pre-allocate more memory than needed | ||||
|         self.provider_options = model_kwargs.pop( | ||||
| @@ -92,7 +92,7 @@ class InferenceModel(ABC): | ||||
|  | ||||
|     @property | ||||
|     def cache_dir(self) -> Path: | ||||
|         return self._cache_dir | ||||
|         return self._cache_dir if self._cache_dir is not None else get_cache_dir(self.model_name, self.model_type) | ||||
|  | ||||
|     @cache_dir.setter | ||||
|     def cache_dir(self, cache_dir: Path) -> None: | ||||
|   | ||||
| @@ -4,6 +4,8 @@ from aiocache.backends.memory import SimpleMemoryCache | ||||
| from aiocache.lock import OptimisticLock | ||||
| from aiocache.plugins import BasePlugin, TimingPlugin | ||||
|  | ||||
| from app.models import from_model_type | ||||
|  | ||||
| from ..schemas import ModelType | ||||
| from .base import InferenceModel | ||||
|  | ||||
| @@ -50,7 +52,7 @@ class ModelCache: | ||||
|         async with OptimisticLock(self.cache, key) as lock: | ||||
|             model = await self.cache.get(key) | ||||
|             if model is None: | ||||
|                 model = InferenceModel.from_model_type(model_type, model_name, **model_kwargs) | ||||
|                 model = from_model_type(model_type, model_name, **model_kwargs) | ||||
|                 await lock.cas(model, ttl=self.ttl) | ||||
|         return model | ||||
|  | ||||
|   | ||||
| @@ -1,23 +1,24 @@ | ||||
| import os | ||||
| import zipfile | ||||
| import json | ||||
| from abc import abstractmethod | ||||
| from functools import cached_property | ||||
| from io import BytesIO | ||||
| from pathlib import Path | ||||
| from typing import Any, Literal | ||||
|  | ||||
| import numpy as np | ||||
| import onnxruntime as ort | ||||
| import torch | ||||
| from clip_server.model.clip import BICUBIC, _convert_image_to_rgb | ||||
| from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET_V2, CLIPOnnxModel, download_model | ||||
| from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE | ||||
| from clip_server.model.tokenization import Tokenizer | ||||
| from huggingface_hub import snapshot_download | ||||
| from PIL import Image | ||||
| from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor | ||||
| from transformers import AutoTokenizer | ||||
|  | ||||
| from app.config import log | ||||
| from app.models.transforms import crop, get_pil_resampling, normalize, resize, to_numpy | ||||
| from app.schemas import ModelType, ndarray_f32, ndarray_i32, ndarray_i64 | ||||
|  | ||||
| from ..config import log | ||||
| from ..schemas import ModelType | ||||
| from .base import InferenceModel | ||||
|  | ||||
|  | ||||
| class CLIPEncoder(InferenceModel): | ||||
| class BaseCLIPEncoder(InferenceModel): | ||||
|     _model_type = ModelType.CLIP | ||||
|  | ||||
|     def __init__( | ||||
| @@ -27,48 +28,29 @@ class CLIPEncoder(InferenceModel): | ||||
|         mode: Literal["text", "vision"] | None = None, | ||||
|         **model_kwargs: Any, | ||||
|     ) -> None: | ||||
|         if mode is not None and mode not in ("text", "vision"): | ||||
|             raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'") | ||||
|         if model_name not in _MODELS: | ||||
|             raise ValueError(f"Unknown model name {model_name}.") | ||||
|         self.mode = mode | ||||
|         super().__init__(model_name, cache_dir, **model_kwargs) | ||||
|  | ||||
|     def _download(self) -> None: | ||||
|         models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name] | ||||
|         text_onnx_path = self.cache_dir / "textual.onnx" | ||||
|         vision_onnx_path = self.cache_dir / "visual.onnx" | ||||
|  | ||||
|         if not text_onnx_path.is_file(): | ||||
|             self._download_model(*models[0]) | ||||
|  | ||||
|         if not vision_onnx_path.is_file(): | ||||
|             self._download_model(*models[1]) | ||||
|  | ||||
|     def _load(self) -> None: | ||||
|         if self.mode == "text" or self.mode is None: | ||||
|             log.debug(f"Loading clip text model '{self.model_name}'") | ||||
|  | ||||
|             self.text_model = ort.InferenceSession( | ||||
|                 self.cache_dir / "textual.onnx", | ||||
|                 self.textual_path.as_posix(), | ||||
|                 sess_options=self.sess_options, | ||||
|                 providers=self.providers, | ||||
|                 provider_options=self.provider_options, | ||||
|             ) | ||||
|             self.text_outputs = [output.name for output in self.text_model.get_outputs()] | ||||
|             self.tokenizer = Tokenizer(self.model_name) | ||||
|  | ||||
|         if self.mode == "vision" or self.mode is None: | ||||
|             log.debug(f"Loading clip vision model '{self.model_name}'") | ||||
|  | ||||
|             self.vision_model = ort.InferenceSession( | ||||
|                 self.cache_dir / "visual.onnx", | ||||
|                 self.visual_path.as_posix(), | ||||
|                 sess_options=self.sess_options, | ||||
|                 providers=self.providers, | ||||
|                 provider_options=self.provider_options, | ||||
|             ) | ||||
|             self.vision_outputs = [output.name for output in self.vision_model.get_outputs()] | ||||
|  | ||||
|             image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)] | ||||
|             self.transform = _transform_pil_image(image_size) | ||||
|  | ||||
|     def _predict(self, image_or_text: Image.Image | str) -> list[float]: | ||||
|         if isinstance(image_or_text, bytes): | ||||
| @@ -78,55 +60,163 @@ class CLIPEncoder(InferenceModel): | ||||
|             case Image.Image(): | ||||
|                 if self.mode == "text": | ||||
|                     raise TypeError("Cannot encode image as text-only model") | ||||
|                 pixel_values = self.transform(image_or_text) | ||||
|                 assert isinstance(pixel_values, torch.Tensor) | ||||
|                 pixel_values = torch.unsqueeze(pixel_values, 0).numpy() | ||||
|                 outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values}) | ||||
|  | ||||
|                 outputs = self.vision_model.run(None, self.transform(image_or_text)) | ||||
|             case str(): | ||||
|                 if self.mode == "vision": | ||||
|                     raise TypeError("Cannot encode text as vision-only model") | ||||
|                 text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text) | ||||
|                 inputs = { | ||||
|                     "input_ids": text_inputs["input_ids"].int().numpy(), | ||||
|                     "attention_mask": text_inputs["attention_mask"].int().numpy(), | ||||
|                 } | ||||
|                 outputs = self.text_model.run(self.text_outputs, inputs) | ||||
|  | ||||
|                 outputs = self.text_model.run(None, self.tokenize(image_or_text)) | ||||
|             case _: | ||||
|                 raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}") | ||||
|  | ||||
|         return outputs[0][0].tolist() | ||||
|  | ||||
|     def _download_model(self, model_name: str, model_md5: str) -> bool: | ||||
|         # downloading logic is adapted from clip-server's CLIPOnnxModel class | ||||
|         download_model( | ||||
|             url=_S3_BUCKET_V2 + model_name, | ||||
|             target_folder=self.cache_dir.as_posix(), | ||||
|             md5sum=model_md5, | ||||
|             with_resume=True, | ||||
|         ) | ||||
|         file = self.cache_dir / model_name.split("/")[1] | ||||
|         if file.suffix == ".zip": | ||||
|             with zipfile.ZipFile(file, "r") as zip_ref: | ||||
|                 zip_ref.extractall(self.cache_dir) | ||||
|             os.remove(file) | ||||
|         return True | ||||
|     @abstractmethod | ||||
|     def tokenize(self, text: str) -> dict[str, ndarray_i32]: | ||||
|         pass | ||||
|  | ||||
|     @abstractmethod | ||||
|     def transform(self, image: Image.Image) -> dict[str, ndarray_f32]: | ||||
|         pass | ||||
|  | ||||
|     @property | ||||
|     def textual_dir(self) -> Path: | ||||
|         return self.cache_dir / "textual" | ||||
|  | ||||
|     @property | ||||
|     def visual_dir(self) -> Path: | ||||
|         return self.cache_dir / "visual" | ||||
|  | ||||
|     @property | ||||
|     def model_cfg_path(self) -> Path: | ||||
|         return self.cache_dir / "config.json" | ||||
|  | ||||
|     @property | ||||
|     def textual_path(self) -> Path: | ||||
|         return self.textual_dir / "model.onnx" | ||||
|  | ||||
|     @property | ||||
|     def visual_path(self) -> Path: | ||||
|         return self.visual_dir / "model.onnx" | ||||
|  | ||||
|     @property | ||||
|     def preprocess_cfg_path(self) -> Path: | ||||
|         return self.visual_dir / "preprocess_cfg.json" | ||||
|  | ||||
|     @property | ||||
|     def cached(self) -> bool: | ||||
|         return (self.cache_dir / "textual.onnx").is_file() and (self.cache_dir / "visual.onnx").is_file() | ||||
|         return self.textual_path.is_file() and self.visual_path.is_file() | ||||
|  | ||||
|  | ||||
| # same as `_transform_blob` without `_blob2image` | ||||
| def _transform_pil_image(n_px: int) -> Compose: | ||||
|     return Compose( | ||||
|         [ | ||||
|             Resize(n_px, interpolation=BICUBIC), | ||||
|             CenterCrop(n_px), | ||||
|             _convert_image_to_rgb, | ||||
|             ToTensor(), | ||||
|             Normalize( | ||||
|                 (0.48145466, 0.4578275, 0.40821073), | ||||
|                 (0.26862954, 0.26130258, 0.27577711), | ||||
|             ), | ||||
|         ] | ||||
|     ) | ||||
| class OpenCLIPEncoder(BaseCLIPEncoder): | ||||
|     def __init__( | ||||
|         self, | ||||
|         model_name: str, | ||||
|         cache_dir: str | None = None, | ||||
|         mode: Literal["text", "vision"] | None = None, | ||||
|         **model_kwargs: Any, | ||||
|     ) -> None: | ||||
|         super().__init__(_clean_model_name(model_name), cache_dir, mode, **model_kwargs) | ||||
|  | ||||
|     def _download(self) -> None: | ||||
|         snapshot_download( | ||||
|             f"immich-app/{self.model_name}", | ||||
|             cache_dir=self.cache_dir, | ||||
|             local_dir=self.cache_dir, | ||||
|             local_dir_use_symlinks=False, | ||||
|         ) | ||||
|  | ||||
|     def _load(self) -> None: | ||||
|         super()._load() | ||||
|  | ||||
|         self.tokenizer = AutoTokenizer.from_pretrained(self.textual_dir) | ||||
|         self.sequence_length = self.model_cfg["text_cfg"]["context_length"] | ||||
|  | ||||
|         self.size = ( | ||||
|             self.preprocess_cfg["size"][0] if type(self.preprocess_cfg["size"]) == list else self.preprocess_cfg["size"] | ||||
|         ) | ||||
|         self.resampling = get_pil_resampling(self.preprocess_cfg["interpolation"]) | ||||
|         self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32) | ||||
|         self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32) | ||||
|  | ||||
|     def tokenize(self, text: str) -> dict[str, ndarray_i32]: | ||||
|         input_ids: ndarray_i64 = self.tokenizer( | ||||
|             text, | ||||
|             max_length=self.sequence_length, | ||||
|             return_tensors="np", | ||||
|             return_attention_mask=False, | ||||
|             padding="max_length", | ||||
|             truncation=True, | ||||
|         ).input_ids | ||||
|         return {"text": input_ids.astype(np.int32)} | ||||
|  | ||||
|     def transform(self, image: Image.Image) -> dict[str, ndarray_f32]: | ||||
|         image = resize(image, self.size) | ||||
|         image = crop(image, self.size) | ||||
|         image_np = to_numpy(image) | ||||
|         image_np = normalize(image_np, self.mean, self.std) | ||||
|         return {"image": np.expand_dims(image_np.transpose(2, 0, 1), 0)} | ||||
|  | ||||
|     @cached_property | ||||
|     def model_cfg(self) -> dict[str, Any]: | ||||
|         return json.load(self.model_cfg_path.open()) | ||||
|  | ||||
|     @cached_property | ||||
|     def preprocess_cfg(self) -> dict[str, Any]: | ||||
|         return json.load(self.preprocess_cfg_path.open()) | ||||
|  | ||||
|  | ||||
| class MCLIPEncoder(OpenCLIPEncoder): | ||||
|     def tokenize(self, text: str) -> dict[str, ndarray_i32]: | ||||
|         tokens: dict[str, ndarray_i64] = self.tokenizer(text, return_tensors="np") | ||||
|         return {k: v.astype(np.int32) for k, v in tokens.items()} | ||||
|  | ||||
|  | ||||
| _OPENCLIP_MODELS = { | ||||
|     "RN50__openai", | ||||
|     "RN50__yfcc15m", | ||||
|     "RN50__cc12m", | ||||
|     "RN101__openai", | ||||
|     "RN101__yfcc15m", | ||||
|     "RN50x4__openai", | ||||
|     "RN50x16__openai", | ||||
|     "RN50x64__openai", | ||||
|     "ViT-B-32__openai", | ||||
|     "ViT-B-32__laion2b_e16", | ||||
|     "ViT-B-32__laion400m_e31", | ||||
|     "ViT-B-32__laion400m_e32", | ||||
|     "ViT-B-32__laion2b-s34b-b79k", | ||||
|     "ViT-B-16__openai", | ||||
|     "ViT-B-16__laion400m_e31", | ||||
|     "ViT-B-16__laion400m_e32", | ||||
|     "ViT-B-16-plus-240__laion400m_e31", | ||||
|     "ViT-B-16-plus-240__laion400m_e32", | ||||
|     "ViT-L-14__openai", | ||||
|     "ViT-L-14__laion400m_e31", | ||||
|     "ViT-L-14__laion400m_e32", | ||||
|     "ViT-L-14__laion2b-s32b-b82k", | ||||
|     "ViT-L-14-336__openai", | ||||
|     "ViT-H-14__laion2b-s32b-b79k", | ||||
|     "ViT-g-14__laion2b-s12b-b42k", | ||||
| } | ||||
|  | ||||
|  | ||||
| _MCLIP_MODELS = { | ||||
|     "LABSE-Vit-L-14", | ||||
|     "XLM-Roberta-Large-Vit-B-32", | ||||
|     "XLM-Roberta-Large-Vit-B-16Plus", | ||||
|     "XLM-Roberta-Large-Vit-L-14", | ||||
| } | ||||
|  | ||||
|  | ||||
| def _clean_model_name(model_name: str) -> str: | ||||
|     return model_name.split("/")[-1].replace("::", "__") | ||||
|  | ||||
|  | ||||
| def is_openclip(model_name: str) -> bool: | ||||
|     return _clean_model_name(model_name) in _OPENCLIP_MODELS | ||||
|  | ||||
|  | ||||
| def is_mclip(model_name: str) -> bool: | ||||
|     return _clean_model_name(model_name) in _MCLIP_MODELS | ||||
|   | ||||
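The `_clean_model_name` helper is what lets both clip-server style names ("ViT-B-32::openai") and Hugging Face repo ids ("M-CLIP/XLM-Roberta-Large-Vit-B-32") resolve against the sets above. A few illustrative checks:

    from app.models.clip import _clean_model_name, is_mclip, is_openclip

    assert _clean_model_name("ViT-B-32::openai") == "ViT-B-32__openai"
    assert is_openclip("ViT-B-32::openai")                # listed in _OPENCLIP_MODELS
    assert is_mclip("M-CLIP/XLM-Roberta-Large-Vit-B-32")  # the "M-CLIP/" prefix is stripped
    assert not is_openclip("some-unknown-model")          # unknown names later raise ValueError in from_model_type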
| @@ -9,7 +9,8 @@ from insightface.model_zoo import ArcFaceONNX, RetinaFace | ||||
| from insightface.utils.face_align import norm_crop | ||||
| from insightface.utils.storage import BASE_REPO_URL, download_file | ||||
|  | ||||
| from ..schemas import ModelType | ||||
| from app.schemas import ModelType, ndarray_f32 | ||||
|  | ||||
| from .base import InferenceModel | ||||
|  | ||||
|  | ||||
| @@ -68,7 +69,7 @@ class FaceRecognizer(InferenceModel): | ||||
|         ) | ||||
|         self.rec_model.prepare(ctx_id=0) | ||||
|  | ||||
|     def _predict(self, image: np.ndarray[int, np.dtype[Any]] | bytes) -> list[dict[str, Any]]: | ||||
|     def _predict(self, image: ndarray_f32 | bytes) -> list[dict[str, Any]]: | ||||
|         if isinstance(image, bytes): | ||||
|             image = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR) | ||||
|         bboxes, kpss = self.det_model.detect(image) | ||||
|   | ||||
							
								
								
									
machine-learning/app/models/transforms.py (new file, 35 lines)
									
								
							| @@ -0,0 +1,35 @@ | ||||
| import numpy as np | ||||
| from PIL import Image | ||||
|  | ||||
| from app.schemas import ndarray_f32 | ||||
|  | ||||
| _PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling} | ||||
|  | ||||
|  | ||||
| def resize(img: Image.Image, size: int) -> Image.Image: | ||||
|     if img.width < img.height: | ||||
|         return img.resize((size, int((img.height / img.width) * size)), resample=Image.BICUBIC) | ||||
|     else: | ||||
|         return img.resize((int((img.width / img.height) * size), size), resample=Image.BICUBIC) | ||||
|  | ||||
|  | ||||
| # https://stackoverflow.com/a/60883103 | ||||
| def crop(img: Image.Image, size: int) -> Image.Image: | ||||
|     left = int((img.size[0] / 2) - (size / 2)) | ||||
|     upper = int((img.size[1] / 2) - (size / 2)) | ||||
|     right = left + size | ||||
|     lower = upper + size | ||||
|  | ||||
|     return img.crop((left, upper, right, lower)) | ||||
|  | ||||
|  | ||||
| def to_numpy(img: Image.Image) -> ndarray_f32: | ||||
|     return np.asarray(img.convert("RGB")).astype(np.float32) / 255.0 | ||||
|  | ||||
|  | ||||
| def normalize(img: ndarray_f32, mean: float | ndarray_f32, std: float | ndarray_f32) -> ndarray_f32: | ||||
|     return (img - mean) / std | ||||
|  | ||||
|  | ||||
| def get_pil_resampling(resample: str) -> Image.Resampling: | ||||
|     return _PIL_RESAMPLING_METHODS[resample.lower()] | ||||
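Together these helpers replace the torchvision `Compose` pipeline removed from clip.py. The sketch below mirrors how `OpenCLIPEncoder.transform` chains them; the image path is illustrative, and the mean/std values are the standard CLIP constants from the preprocessing config:

    import numpy as np
    from PIL import Image

    from app.models.transforms import crop, normalize, resize, to_numpy

    mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
    std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)

    img = Image.open("photo.jpg")              # illustrative input
    img = crop(resize(img, 224), 224)          # shortest side to 224, then center crop
    arr = normalize(to_numpy(img), mean, std)  # HWC float32 in [0, 1], then standardized
    batch = {"image": np.expand_dims(arr.transpose(2, 0, 1), 0)}  # NCHW batch of one for the ONNX session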
| @@ -1,5 +1,7 @@ | ||||
| from enum import StrEnum | ||||
| from typing import TypeAlias | ||||
|  | ||||
| import numpy as np | ||||
| from pydantic import BaseModel | ||||
|  | ||||
|  | ||||
| @@ -31,3 +33,8 @@ class ModelType(StrEnum): | ||||
|     IMAGE_CLASSIFICATION = "image-classification" | ||||
|     CLIP = "clip" | ||||
|     FACIAL_RECOGNITION = "facial-recognition" | ||||
|  | ||||
|  | ||||
| ndarray_f32: TypeAlias = np.ndarray[int, np.dtype[np.float32]] | ||||
| ndarray_i64: TypeAlias = np.ndarray[int, np.dtype[np.int64]] | ||||
| ndarray_i32: TypeAlias = np.ndarray[int, np.dtype[np.int32]] | ||||
|   | ||||
| @@ -1,7 +1,8 @@ | ||||
| import json | ||||
| import pickle | ||||
| from io import BytesIO | ||||
| from typing import Any, TypeAlias | ||||
| from pathlib import Path | ||||
| from typing import Any, Callable | ||||
| from unittest import mock | ||||
|  | ||||
| import cv2 | ||||
| @@ -14,13 +15,11 @@ from pytest_mock import MockerFixture | ||||
| from .config import settings | ||||
| from .models.base import PicklableSessionOptions | ||||
| from .models.cache import ModelCache | ||||
| from .models.clip import CLIPEncoder | ||||
| from .models.clip import OpenCLIPEncoder | ||||
| from .models.facial_recognition import FaceRecognizer | ||||
| from .models.image_classification import ImageClassifier | ||||
| from .schemas import ModelType | ||||
|  | ||||
| ndarray: TypeAlias = np.ndarray[int, np.dtype[np.float32]] | ||||
|  | ||||
|  | ||||
| class TestImageClassifier: | ||||
|     classifier_preds = [ | ||||
| @@ -56,30 +55,50 @@ class TestImageClassifier: | ||||
|  | ||||
| class TestCLIP: | ||||
|     embedding = np.random.rand(512).astype(np.float32) | ||||
|     cache_dir = Path("test_cache") | ||||
|  | ||||
|     def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None: | ||||
|         mocker.patch.object(CLIPEncoder, "download") | ||||
|     def test_basic_image( | ||||
|         self, | ||||
|         pil_image: Image.Image, | ||||
|         mocker: MockerFixture, | ||||
|         clip_model_cfg: dict[str, Any], | ||||
|         clip_preprocess_cfg: Callable[[Path], dict[str, Any]], | ||||
|     ) -> None: | ||||
|         mocker.patch.object(OpenCLIPEncoder, "download") | ||||
|         mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg) | ||||
|         mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg) | ||||
|         mocker.patch("app.models.clip.AutoTokenizer.from_pretrained", autospec=True) | ||||
|         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True) | ||||
|         mocked.return_value.run.return_value = [[self.embedding]] | ||||
|         clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision") | ||||
|         assert clip_encoder.mode == "vision" | ||||
|  | ||||
|         clip_encoder = OpenCLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision") | ||||
|         embedding = clip_encoder.predict(pil_image) | ||||
|  | ||||
|         assert clip_encoder.mode == "vision" | ||||
|         assert isinstance(embedding, list) | ||||
|         assert len(embedding) == 512 | ||||
|         assert len(embedding) == clip_model_cfg["embed_dim"] | ||||
|         assert all([isinstance(num, float) for num in embedding]) | ||||
|         clip_encoder.vision_model.run.assert_called_once() | ||||
|  | ||||
|     def test_basic_text(self, mocker: MockerFixture) -> None: | ||||
|         mocker.patch.object(CLIPEncoder, "download") | ||||
|     def test_basic_text( | ||||
|         self, | ||||
|         mocker: MockerFixture, | ||||
|         clip_model_cfg: dict[str, Any], | ||||
|         clip_preprocess_cfg: Callable[[Path], dict[str, Any]], | ||||
|     ) -> None: | ||||
|         mocker.patch.object(OpenCLIPEncoder, "download") | ||||
|         mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg) | ||||
|         mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg) | ||||
|         mocker.patch("app.models.clip.AutoTokenizer.from_pretrained", autospec=True) | ||||
|         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True) | ||||
|         mocked.return_value.run.return_value = [[self.embedding]] | ||||
|         clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text") | ||||
|         assert clip_encoder.mode == "text" | ||||
|  | ||||
|         clip_encoder = OpenCLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text") | ||||
|         embedding = clip_encoder.predict("test search query") | ||||
|  | ||||
|         assert clip_encoder.mode == "text" | ||||
|         assert isinstance(embedding, list) | ||||
|         assert len(embedding) == 512 | ||||
|         assert len(embedding) == clip_model_cfg["embed_dim"] | ||||
|         assert all([isinstance(num, float) for num in embedding]) | ||||
|         clip_encoder.text_model.run.assert_called_once() | ||||
|  | ||||
|   | ||||
							
								
								
									
machine-learning/export/Dockerfile (new file, 21 lines)
									
								
							| @@ -0,0 +1,21 @@ | ||||
| FROM mambaorg/micromamba:bookworm-slim as builder | ||||
|  | ||||
| ENV NODE_ENV=production \ | ||||
|   TRANSFORMERS_CACHE=/cache \ | ||||
|   PYTHONDONTWRITEBYTECODE=1 \ | ||||
|   PYTHONUNBUFFERED=1 \ | ||||
|   PATH="/opt/venv/bin:$PATH" \ | ||||
|   PYTHONPATH=/usr/src | ||||
|  | ||||
| COPY --chown=$MAMBA_USER:$MAMBA_USER conda-lock.yml /tmp/conda-lock.yml | ||||
| RUN micromamba install -y -n base -f /tmp/conda-lock.yml && \ | ||||
|     micromamba remove -y -n base cxx-compiler && \ | ||||
|     micromamba clean --all --yes | ||||
|  | ||||
| WORKDIR /usr/src/app | ||||
|  | ||||
| COPY --chown=$MAMBA_USER:$MAMBA_USER start.sh . | ||||
| COPY --chown=$MAMBA_USER:$MAMBA_USER app . | ||||
|  | ||||
| ENTRYPOINT ["/usr/local/bin/_entrypoint.sh"] | ||||
| CMD ["./start.sh"] | ||||
							
								
								
									
machine-learning/export/conda-lock.yml (new file, 3520 lines; diff suppressed because it is too large)
											
										
									
								
							
							
								
								
									
machine-learning/export/env.dev.yaml (new file, 15 lines)
									
								
							| @@ -0,0 +1,15 @@ | ||||
| name: base | ||||
| channels: | ||||
|   - conda-forge | ||||
| platforms: | ||||
|   - linux-64 | ||||
|   - linux-aarch64 | ||||
| dependencies: | ||||
|   - black | ||||
|   - conda-lock | ||||
|   - mypy | ||||
|   - pytest | ||||
|   - pytest-cov | ||||
|   - pytest-mock | ||||
|   - ruff | ||||
| category: dev | ||||
							
								
								
									
machine-learning/export/env.yaml (new file, 25 lines)
									
								
							| @@ -0,0 +1,25 @@ | ||||
| name: base | ||||
| channels: | ||||
|   - conda-forge | ||||
|   - nvidia | ||||
|   - pytorch-nightly | ||||
| platforms: | ||||
|   - linux-64 | ||||
| dependencies: | ||||
|   - cxx-compiler | ||||
|   - onnx==1.* | ||||
|   - onnxruntime==1.* | ||||
|   - open-clip-torch==2.* | ||||
|   - orjson==3.* | ||||
|   - pip | ||||
|   - python==3.11.* | ||||
|   - pytorch | ||||
|   - rich==13.* | ||||
|   - safetensors==0.* | ||||
|   - setuptools==68.* | ||||
|   - torchvision | ||||
|   - transformers==4.* | ||||
|   - pip: | ||||
|     - multilingual-clip | ||||
|     - onnx-simplifier | ||||
| category: main | ||||
							
								
								
									
machine-learning/export/models/__init__.py (new file, empty)
									
								
							
							
								
								
									
machine-learning/export/models/mclip.py (new file, 67 lines)
									
								
							| @@ -0,0 +1,67 @@ | ||||
| import tempfile | ||||
| import warnings | ||||
| from pathlib import Path | ||||
|  | ||||
| import torch | ||||
| from multilingual_clip.pt_multilingual_clip import MultilingualCLIP | ||||
| from transformers import AutoTokenizer | ||||
|  | ||||
| from .openclip import OpenCLIPModelConfig | ||||
| from .openclip import to_onnx as openclip_to_onnx | ||||
| from .optimize import optimize | ||||
| from .util import get_model_path | ||||
|  | ||||
| _MCLIP_TO_OPENCLIP = { | ||||
|     "M-CLIP/XLM-Roberta-Large-Vit-B-32": OpenCLIPModelConfig("ViT-B-32", "openai"), | ||||
|     "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus": OpenCLIPModelConfig("ViT-B-16-plus-240", "laion400m_e32"), | ||||
|     "M-CLIP/LABSE-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"), | ||||
|     "M-CLIP/XLM-Roberta-Large-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"), | ||||
| } | ||||
|  | ||||
|  | ||||
| def to_onnx( | ||||
|     model_name: str, | ||||
|     output_dir_visual: Path | str, | ||||
|     output_dir_textual: Path | str, | ||||
| ) -> None: | ||||
|     textual_path = get_model_path(output_dir_textual) | ||||
|     with tempfile.TemporaryDirectory() as tmpdir: | ||||
|         model = MultilingualCLIP.from_pretrained(model_name, cache_dir=tmpdir) | ||||
|         AutoTokenizer.from_pretrained(model_name).save_pretrained(output_dir_textual) | ||||
|  | ||||
|         for param in model.parameters(): | ||||
|             param.requires_grad_(False) | ||||
|  | ||||
|         export_text_encoder(model, textual_path) | ||||
|         openclip_to_onnx(_MCLIP_TO_OPENCLIP[model_name], output_dir_visual) | ||||
|         optimize(textual_path) | ||||
|  | ||||
|  | ||||
| def export_text_encoder(model: MultilingualCLIP, output_path: Path | str) -> None: | ||||
|     output_path = Path(output_path) | ||||
|  | ||||
|     def forward(self: MultilingualCLIP, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: | ||||
|         embs = self.transformer(input_ids, attention_mask)[0] | ||||
|         embs = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None] | ||||
|         embs = self.LinearTransformation(embs) | ||||
|         return torch.nn.functional.normalize(embs, dim=-1) | ||||
|  | ||||
|     # unfortunately need to monkeypatch for tracing to work here | ||||
|     # otherwise it hits the 2GiB protobuf serialization limit | ||||
|     MultilingualCLIP.forward = forward | ||||
|  | ||||
|     args = (torch.ones(1, 77, dtype=torch.int32), torch.ones(1, 77, dtype=torch.int32)) | ||||
|     with warnings.catch_warnings(): | ||||
|         warnings.simplefilter("ignore", UserWarning) | ||||
|         torch.onnx.export( | ||||
|             model, | ||||
|             args, | ||||
|             output_path.as_posix(), | ||||
|             input_names=["input_ids", "attention_mask"], | ||||
|             output_names=["text_embedding"], | ||||
|             opset_version=17, | ||||
|             dynamic_axes={ | ||||
|                 "input_ids": {0: "batch_size", 1: "sequence_length"}, | ||||
|                 "attention_mask": {0: "batch_size", 1: "sequence_length"}, | ||||
|             }, | ||||
|         ) | ||||
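For illustration (not part of this commit), the exported text encoder can be exercised with the input and output names declared in `torch.onnx.export` above; the output directory below is an assumed path following the `textual/` layout that `to_onnx` produces:

    import numpy as np
    import onnxruntime as ort
    from transformers import AutoTokenizer

    textual_dir = "XLM-Roberta-Large-Vit-B-32/textual"  # assumed export location
    tokenizer = AutoTokenizer.from_pretrained(textual_dir)
    session = ort.InferenceSession(f"{textual_dir}/model.onnx", providers=["CPUExecutionProvider"])

    tokens = tokenizer("a photo of a cat", return_tensors="np")
    (embedding,) = session.run(
        ["text_embedding"],
        {
            "input_ids": tokens["input_ids"].astype(np.int32),
            "attention_mask": tokens["attention_mask"].astype(np.int32),
        },
    )
    print(embedding.shape)  # (1, embedding_dim)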
							
								
								
									
machine-learning/export/models/openclip.py (new file, 109 lines)
									
								
							| @@ -0,0 +1,109 @@ | ||||
| import tempfile | ||||
| import warnings | ||||
| from dataclasses import dataclass, field | ||||
| from pathlib import Path | ||||
|  | ||||
| import open_clip | ||||
| import torch | ||||
| from transformers import AutoTokenizer | ||||
|  | ||||
| from .optimize import optimize | ||||
| from .util import get_model_path, save_config | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class OpenCLIPModelConfig: | ||||
|     name: str | ||||
|     pretrained: str | ||||
|     image_size: int = field(init=False) | ||||
|     sequence_length: int = field(init=False) | ||||
|  | ||||
|     def __post_init__(self) -> None: | ||||
|         open_clip_cfg = open_clip.get_model_config(self.name) | ||||
|         if open_clip_cfg is None: | ||||
|             raise ValueError(f"Unknown model {self.name}") | ||||
|         self.image_size = open_clip_cfg["vision_cfg"]["image_size"] | ||||
|         self.sequence_length = open_clip_cfg["text_cfg"]["context_length"] | ||||
|  | ||||
|  | ||||
| def to_onnx( | ||||
|     model_cfg: OpenCLIPModelConfig, | ||||
|     output_dir_visual: Path | str | None = None, | ||||
|     output_dir_textual: Path | str | None = None, | ||||
| ) -> None: | ||||
|     with tempfile.TemporaryDirectory() as tmpdir: | ||||
|         model = open_clip.create_model( | ||||
|             model_cfg.name, | ||||
|             pretrained=model_cfg.pretrained, | ||||
|             jit=False, | ||||
|             cache_dir=tmpdir, | ||||
|             require_pretrained=True, | ||||
|         ) | ||||
|  | ||||
|         text_vision_cfg = open_clip.get_model_config(model_cfg.name) | ||||
|  | ||||
|         for param in model.parameters(): | ||||
|             param.requires_grad_(False) | ||||
|  | ||||
|         if output_dir_visual is not None: | ||||
|             output_dir_visual = Path(output_dir_visual) | ||||
|             visual_path = get_model_path(output_dir_visual) | ||||
|  | ||||
|             save_config(open_clip.get_model_preprocess_cfg(model), output_dir_visual / "preprocess_cfg.json") | ||||
|             save_config(text_vision_cfg, output_dir_visual.parent / "config.json") | ||||
|             export_image_encoder(model, model_cfg, visual_path) | ||||
|  | ||||
|             optimize(visual_path) | ||||
|  | ||||
|         if output_dir_textual is not None: | ||||
|             output_dir_textual = Path(output_dir_textual) | ||||
|             textual_path = get_model_path(output_dir_textual) | ||||
|  | ||||
|             tokenizer_name = text_vision_cfg["text_cfg"].get("hf_tokenizer_name", "openai/clip-vit-base-patch32") | ||||
|             AutoTokenizer.from_pretrained(tokenizer_name).save_pretrained(output_dir_textual) | ||||
|             export_text_encoder(model, model_cfg, textual_path) | ||||
|             optimize(textual_path) | ||||
|  | ||||
|  | ||||
| def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, output_path: Path | str) -> None: | ||||
|     output_path = Path(output_path) | ||||
|  | ||||
|     def encode_image(image: torch.Tensor) -> torch.Tensor: | ||||
|         return model.encode_image(image, normalize=True) | ||||
|  | ||||
|     args = (torch.randn(1, 3, model_cfg.image_size, model_cfg.image_size),) | ||||
|     traced = torch.jit.trace(encode_image, args) | ||||
|  | ||||
|     with warnings.catch_warnings(): | ||||
|         warnings.simplefilter("ignore", UserWarning) | ||||
|         torch.onnx.export( | ||||
|             traced, | ||||
|             args, | ||||
|             output_path.as_posix(), | ||||
|             input_names=["image"], | ||||
|             output_names=["image_embedding"], | ||||
|             opset_version=17, | ||||
|             dynamic_axes={"image": {0: "batch_size"}}, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def export_text_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, output_path: Path | str) -> None: | ||||
|     output_path = Path(output_path) | ||||
|  | ||||
|     def encode_text(text: torch.Tensor) -> torch.Tensor: | ||||
|         return model.encode_text(text, normalize=True) | ||||
|  | ||||
|     args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),) | ||||
|     traced = torch.jit.trace(encode_text, args) | ||||
|  | ||||
|     with warnings.catch_warnings(): | ||||
|         warnings.simplefilter("ignore", UserWarning) | ||||
|         torch.onnx.export( | ||||
|             traced, | ||||
|             args, | ||||
|             output_path.as_posix(), | ||||
|             input_names=["text"], | ||||
|             output_names=["text_embedding"], | ||||
|             opset_version=17, | ||||
|             dynamic_axes={"text": {0: "batch_size"}}, | ||||
|         ) | ||||
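Similarly, a quick sketch (not from this commit) of driving the exported image encoder with a preprocessed NCHW batch under the "image" input name declared above; the path is assumed, and the 224px size matches ViT-B-32:

    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(
        "ViT-B-32__openai/visual/model.onnx",  # assumed export location
        providers=["CPUExecutionProvider"],
    )

    pixels = np.random.rand(1, 3, 224, 224).astype(np.float32)  # stand-in for a preprocessed image
    (embedding,) = session.run(["image_embedding"], {"image": pixels})
    print(embedding.shape)  # (1, embed_dim), e.g. (1, 512) for ViT-B-32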
							
								
								
									
machine-learning/export/models/optimize.py (new file, 38 lines)
									
								
							| @@ -0,0 +1,38 @@ | ||||
| from pathlib import Path | ||||
|  | ||||
| import onnx | ||||
| import onnxruntime as ort | ||||
| import onnxsim | ||||
|  | ||||
|  | ||||
| def optimize_onnxsim(model_path: Path | str, output_path: Path | str) -> None: | ||||
|     model_path = Path(model_path) | ||||
|     output_path = Path(output_path) | ||||
|     model = onnx.load(model_path.as_posix()) | ||||
|     model, check = onnxsim.simplify(model, skip_shape_inference=True) | ||||
|     assert check, "Simplified ONNX model could not be validated" | ||||
|     onnx.save(model, output_path.as_posix()) | ||||
|  | ||||
|  | ||||
| def optimize_ort( | ||||
|     model_path: Path | str, | ||||
|     output_path: Path | str, | ||||
|     level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC, | ||||
| ) -> None: | ||||
|     model_path = Path(model_path) | ||||
|     output_path = Path(output_path) | ||||
|  | ||||
|     sess_options = ort.SessionOptions() | ||||
|     sess_options.graph_optimization_level = level | ||||
|     sess_options.optimized_model_filepath = output_path.as_posix() | ||||
|  | ||||
|     ort.InferenceSession(model_path.as_posix(), providers=["CPUExecutionProvider"], sess_options=sess_options) | ||||
|  | ||||
|  | ||||
| def optimize(model_path: Path | str) -> None: | ||||
|     model_path = Path(model_path) | ||||
|  | ||||
|     optimize_ort(model_path, model_path) | ||||
|     # onnxsim serializes large models as a blob, which uses much more memory when loading the model at runtime | ||||
|     if not any(file.name.startswith("Constant") for file in model_path.parent.iterdir()): | ||||
|         optimize_onnxsim(model_path, model_path) | ||||
							
								
								
									
machine-learning/export/models/util.py (new file, 15 lines)
									
								
							| @@ -0,0 +1,15 @@ | ||||
| import json | ||||
| from pathlib import Path | ||||
| from typing import Any | ||||
|  | ||||
|  | ||||
| def get_model_path(output_dir: Path | str) -> Path: | ||||
|     output_dir = Path(output_dir) | ||||
|     output_dir.mkdir(parents=True, exist_ok=True) | ||||
|     return output_dir / "model.onnx" | ||||
|  | ||||
|  | ||||
| def save_config(config: Any, output_path: Path | str) -> None: | ||||
|     output_path = Path(output_path) | ||||
|     output_path.parent.mkdir(parents=True, exist_ok=True) | ||||
|     json.dump(config, output_path.open("w")) | ||||
							
								
								
									
machine-learning/export/run.py (new file, 76 lines)
									
								
							| @@ -0,0 +1,76 @@ | ||||
| import gc | ||||
| import os | ||||
| from pathlib import Path | ||||
| from tempfile import TemporaryDirectory | ||||
|  | ||||
| from huggingface_hub import create_repo, login, upload_folder | ||||
| from models import mclip, openclip | ||||
| from rich.progress import Progress | ||||
|  | ||||
| models = [ | ||||
|     "RN50::openai", | ||||
|     "RN50::yfcc15m", | ||||
|     "RN50::cc12m", | ||||
|     "RN101::openai", | ||||
|     "RN101::yfcc15m", | ||||
|     "RN50x4::openai", | ||||
|     "RN50x16::openai", | ||||
|     "RN50x64::openai", | ||||
|     "ViT-B-32::openai", | ||||
|     "ViT-B-32::laion2b_e16", | ||||
|     "ViT-B-32::laion400m_e31", | ||||
|     "ViT-B-32::laion400m_e32", | ||||
|     "ViT-B-32::laion2b-s34b-b79k", | ||||
|     "ViT-B-16::openai", | ||||
|     "ViT-B-16::laion400m_e31", | ||||
|     "ViT-B-16::laion400m_e32", | ||||
|     "ViT-B-16-plus-240::laion400m_e31", | ||||
|     "ViT-B-16-plus-240::laion400m_e32", | ||||
|     "ViT-L-14::openai", | ||||
|     "ViT-L-14::laion400m_e31", | ||||
|     "ViT-L-14::laion400m_e32", | ||||
|     "ViT-L-14::laion2b-s32b-b82k", | ||||
|     "ViT-L-14-336::openai", | ||||
|     "ViT-H-14::laion2b-s32b-b79k", | ||||
|     "ViT-g-14::laion2b-s12b-b42k", | ||||
|     "M-CLIP/LABSE-Vit-L-14", | ||||
|     "M-CLIP/XLM-Roberta-Large-Vit-B-32", | ||||
|     "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus", | ||||
|     "M-CLIP/XLM-Roberta-Large-Vit-L-14", | ||||
| ] | ||||
|  | ||||
| login(token=os.environ["HF_AUTH_TOKEN"]) | ||||
|  | ||||
| with Progress() as progress: | ||||
|     task1 = progress.add_task("[green]Exporting models...", total=len(models)) | ||||
|     task2 = progress.add_task("[yellow]Uploading models...", total=len(models)) | ||||
|  | ||||
|     with TemporaryDirectory() as tmp: | ||||
|         tmpdir = Path(tmp) | ||||
|         for model in models: | ||||
|             model_name = model.split("/")[-1].replace("::", "__") | ||||
|             config_path = tmpdir / model_name / "config.json" | ||||
|  | ||||
|             def upload() -> None: | ||||
|                 progress.update(task2, description=f"[yellow]Uploading {model_name}") | ||||
|                 repo_id = f"immich-app/{model_name}" | ||||
|  | ||||
|                 create_repo(repo_id, exist_ok=True) | ||||
|                 upload_folder(repo_id=repo_id, folder_path=tmpdir / model_name) | ||||
|                 progress.update(task2, advance=1) | ||||
|  | ||||
|             def export() -> None: | ||||
|                 progress.update(task1, description=f"[green]Exporting {model_name}") | ||||
|                 visual_dir = tmpdir / model_name / "visual" | ||||
|                 textual_dir = tmpdir / model_name / "textual" | ||||
|                 if model.startswith("M-CLIP"): | ||||
|                     mclip.to_onnx(model, visual_dir, textual_dir) | ||||
|                 else: | ||||
|                     name, _, pretrained = model_name.partition("__") | ||||
|                     openclip.to_onnx(openclip.OpenCLIPModelConfig(name, pretrained), visual_dir, textual_dir) | ||||
|  | ||||
|                 progress.update(task1, advance=1) | ||||
|                 gc.collect() | ||||
|  | ||||
|             export() | ||||
|             upload() | ||||
| @@ -1,11 +1,12 @@ | ||||
| from io import BytesIO | ||||
| import json | ||||
| from argparse import ArgumentParser | ||||
| from io import BytesIO | ||||
| from typing import Any | ||||
|  | ||||
| from locust import HttpUser, events, task | ||||
| from locust.env import Environment | ||||
| from PIL import Image | ||||
| from argparse import ArgumentParser | ||||
|  | ||||
| byte_image = BytesIO() | ||||
|  | ||||
|  | ||||
| @@ -14,11 +15,21 @@ def _(parser: ArgumentParser) -> None: | ||||
|     parser.add_argument("--tag-model", type=str, default="microsoft/resnet-50") | ||||
|     parser.add_argument("--clip-model", type=str, default="ViT-B-32::openai") | ||||
|     parser.add_argument("--face-model", type=str, default="buffalo_l") | ||||
|     parser.add_argument("--tag-min-score", type=int, default=0.0,  | ||||
|                         help="Returns all tags at or above this score. The default returns all tags.") | ||||
|     parser.add_argument("--face-min-score", type=int, default=0.034,  | ||||
|                         help=("Returns all faces at or above this score. The default returns 1 face per request; " | ||||
|                               "setting this to 0 blows up the number of faces to the thousands.")) | ||||
|     parser.add_argument( | ||||
|         "--tag-min-score", | ||||
|         type=int, | ||||
|         default=0.0, | ||||
|         help="Returns all tags at or above this score. The default returns all tags.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--face-min-score", | ||||
|         type=int, | ||||
|         default=0.034, | ||||
|         help=( | ||||
|             "Returns all faces at or above this score. The default returns 1 face per request; " | ||||
|             "setting this to 0 blows up the number of faces to the thousands." | ||||
|         ), | ||||
|     ) | ||||
|     parser.add_argument("--image-size", type=int, default=1000) | ||||
|  | ||||
|  | ||||
| @@ -62,7 +73,7 @@ class CLIPTextFormDataLoadTest(InferenceLoadTest): | ||||
|             ("modelName", self.environment.parsed_options.clip_model), | ||||
|             ("modelType", "clip"), | ||||
|             ("options", json.dumps({"mode": "text"})), | ||||
|             ("text", "test search query") | ||||
|             ("text", "test search query"), | ||||
|         ] | ||||
|         self.client.post("/predict", data=data) | ||||
|  | ||||
| @@ -88,5 +99,5 @@ class RecognitionFormDataLoadTest(InferenceLoadTest): | ||||
|             ("options", json.dumps({"minScore": self.environment.parsed_options.face_min_score})), | ||||
|         ] | ||||
|         files = {"image": self.data} | ||||
|              | ||||
|  | ||||
|         self.client.post("/predict", data=data, files=files) | ||||
|   | ||||
							
								
								
									
machine-learning/poetry.lock (generated, 3875 lines; diff suppressed because it is too large)
											
										
									
								
							| @@ -9,8 +9,8 @@ packages = [{include = "app"}] | ||||
| [tool.poetry.dependencies] | ||||
| python = "^3.11" | ||||
| torch = [ | ||||
|     {markers = "platform_machine == 'arm64' or platform_machine == 'aarch64'", version = "=2.0.1", source = "pypi"}, | ||||
|     {markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=2.0.1", source = "pytorch-cpu"} | ||||
|     {markers = "platform_machine == 'arm64' or platform_machine == 'aarch64'", version = "=2.1.0", source = "pypi"}, | ||||
|     {markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=2.1.0", source = "pytorch-cpu"} | ||||
| ] | ||||
| transformers = "^4.29.2" | ||||
| onnxruntime = "^1.15.0" | ||||
| @@ -22,14 +22,9 @@ uvicorn = {extras = ["standard"], version = "^0.22.0"} | ||||
| pydantic = "^1.10.8" | ||||
| aiocache = "^0.12.1" | ||||
| optimum = "^1.9.1" | ||||
| torchvision = [ | ||||
|     {markers = "platform_machine == 'arm64' or platform_machine == 'aarch64'", version = "=0.15.2", source = "pypi"}, | ||||
|     {markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=0.15.2", source = "pytorch-cpu"} | ||||
| ] | ||||
| rich = "^13.4.2" | ||||
| ftfy = "^6.1.1" | ||||
| setuptools = "^68.0.0" | ||||
| open-clip-torch = "^2.20.0" | ||||
| python-multipart = "^0.0.6" | ||||
| orjson = "^3.9.5" | ||||
| safetensors = "0.3.2" | ||||
| @@ -63,6 +58,7 @@ warn_redundant_casts = true | ||||
| disallow_any_generics = true | ||||
| check_untyped_defs = true | ||||
| disallow_untyped_defs = true | ||||
| ignore_missing_imports = true | ||||
|  | ||||
| [tool.pydantic-mypy] | ||||
| init_forbid_extra = true | ||||
| @@ -70,30 +66,6 @@ init_typed = true | ||||
| warn_required_dynamic_aliases = true | ||||
| warn_untyped_fields = true | ||||
|  | ||||
| [[tool.mypy.overrides]] | ||||
| module = [ | ||||
|     "huggingface_hub", | ||||
|     "transformers", | ||||
|     "gunicorn", | ||||
|     "cv2", | ||||
|     "insightface.model_zoo", | ||||
|     "insightface.utils.face_align", | ||||
|     "insightface.utils.storage", | ||||
|     "onnxruntime", | ||||
|     "optimum", | ||||
|     "optimum.pipelines", | ||||
|     "optimum.onnxruntime", | ||||
|     "clip_server.model.clip", | ||||
|     "clip_server.model.clip_onnx", | ||||
|     "clip_server.model.pretrained_models", | ||||
|     "clip_server.model.tokenization", | ||||
|     "torchvision.transforms", | ||||
|     "aiocache.backends.memory", | ||||
|     "aiocache.lock", | ||||
|     "aiocache.plugins" | ||||
| ] | ||||
| ignore_missing_imports = true | ||||
|  | ||||
| [tool.ruff] | ||||
| line-length = 120 | ||||
| target-version = "py311" | ||||
|   | ||||
| @@ -1,2 +0,0 @@ | ||||
| # requirements to be installed with `--no-deps` flag | ||||
| clip-server==0.8.* | ||||