feat(ml)!: switch image classification and CLIP models to ONNX (#3809)

Mert
2023-08-25 00:28:51 -04:00
committed by GitHub
parent 8211afb726
commit 165b91b068
14 changed files with 1617 additions and 507 deletions

View File

@@ -1,3 +1,3 @@
from .clip import CLIPSTEncoder
from .clip import CLIPEncoder
from .facial_recognition import FaceRecognizer
from .image_classification import ImageClassifier

View File

@@ -1,14 +1,17 @@
from __future__ import annotations
import os
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
from typing import Any
from zipfile import BadZipFile
import onnxruntime as ort
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf # type: ignore
from ..config import get_cache_dir
from ..config import get_cache_dir, settings
from ..schemas import ModelType
@@ -16,12 +19,31 @@ class InferenceModel(ABC):
_model_type: ModelType
def __init__(
self, model_name: str, cache_dir: Path | str | None = None, eager: bool = True, **model_kwargs: Any
self,
model_name: str,
cache_dir: Path | str | None = None,
eager: bool = True,
inter_op_num_threads: int = settings.model_inter_op_threads,
intra_op_num_threads: int = settings.model_intra_op_threads,
**model_kwargs: Any,
) -> None:
self.model_name = model_name
self._loaded = False
self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)
loader = self.load if eager else self.download
self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"])
# don't pre-allocate more memory than needed
self.provider_options = model_kwargs.pop(
"provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers)
)
self.sess_options = PicklableSessionOptions()
# avoid thread contention between models
if inter_op_num_threads > 1:
self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
self.sess_options.inter_op_num_threads = inter_op_num_threads
self.sess_options.intra_op_num_threads = intra_op_num_threads
try:
loader(**model_kwargs)
except (OSError, InvalidProtobuf, BadZipFile):
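The new constructor parameters default to settings.model_inter_op_threads and settings.model_intra_op_threads, but the config change itself is not part of this excerpt. A minimal sketch of what those fields could look like, assuming a pydantic BaseSettings-style config (the field names come from the diff; the defaults here are assumptions):

from pydantic import BaseSettings

class Settings(BaseSettings):
    # any value > 1 flips the session into ORT_PARALLEL execution
    # in InferenceModel.__init__ above; defaults are assumptions
    model_inter_op_threads: int = 1
    model_intra_op_threads: int = 2

settings = Settings()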
@@ -30,6 +52,7 @@ class InferenceModel(ABC):
def download(self, **model_kwargs: Any) -> None:
if not self.cached:
print(f"Downloading {self.model_type.value.replace('_', ' ')} model. This may take a while...")
self._download(**model_kwargs)
def load(self, **model_kwargs: Any) -> None:
@@ -39,6 +62,7 @@ class InferenceModel(ABC):
def predict(self, inputs: Any) -> Any:
if not self._loaded:
print(f"Loading {self.model_type.value.replace('_', ' ')} model...")
self.load()
return self._predict(inputs)
@@ -89,3 +113,14 @@ class InferenceModel(ABC):
else:
self.cache_dir.unlink()
self.cache_dir.mkdir(parents=True, exist_ok=True)
# HF deep copies configs, so we need to make session options picklable
class PicklableSessionOptions(ort.SessionOptions):
def __getstate__(self) -> bytes:
return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
def __setstate__(self, state: Any) -> None:
self.__init__() # type: ignore
for attr, val in pickle.loads(state):
setattr(self, attr, val)
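For context: ort.SessionOptions is a pybind11 object and cannot be pickled as-is, which breaks Hugging Face pipelines that deep-copy their configs. An illustrative round-trip using the subclass above:

import pickle

opts = PicklableSessionOptions()
opts.intra_op_num_threads = 2

# a plain ort.SessionOptions would raise here; the subclass survives the copy
restored = pickle.loads(pickle.dumps(opts))
assert restored.intra_op_num_threads == 2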

View File

@@ -46,7 +46,7 @@ class ModelCache:
model: The requested model.
"""
key = self.cache.build_key(model_name, model_type.value)
key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
async with OptimisticLock(self.cache, key) as lock:
model = await self.cache.get(key)
if model is None:
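The key now folds in an optional mode kwarg, so the text and vision halves of the same CLIP model are cached as separate entries instead of colliding on one key. A usage sketch, assuming get() takes (model_name, model_type, **model_kwargs) as this hunk suggests:

from typing import Any

async def get_clip_sessions(model_cache: ModelCache) -> tuple[Any, Any]:
    # the mode kwarg flows into the cache key, so each half of the
    # CLIP model gets its own cache entry
    text_clip = await model_cache.get("ViT-B-32::openai", ModelType.CLIP, mode="text")
    vision_clip = await model_cache.get("ViT-B-32::openai", ModelType.CLIP, mode="vision")
    return text_clip, vision_clip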

View File

@@ -1,31 +1,141 @@
from typing import Any
import os
import zipfile
from typing import Any, Literal
import onnxruntime as ort
import torch
from clip_server.model.clip import BICUBIC, _convert_image_to_rgb
from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET_V2, CLIPOnnxModel, download_model
from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE
from clip_server.model.tokenization import Tokenizer
from PIL.Image import Image
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import snapshot_download
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from ..schemas import ModelType
from .base import InferenceModel
_ST_TO_JINA_MODEL_NAME = {
"clip-ViT-B-16": "ViT-B-16::openai",
"clip-ViT-B-32": "ViT-B-32::openai",
"clip-ViT-B-32-multilingual-v1": "M-CLIP/XLM-Roberta-Large-Vit-B-32",
"clip-ViT-L-14": "ViT-L-14::openai",
}
class CLIPSTEncoder(InferenceModel):
class CLIPEncoder(InferenceModel):
_model_type = ModelType.CLIP
def __init__(
self,
model_name: str,
cache_dir: str | None = None,
mode: Literal["text", "vision"] | None = None,
**model_kwargs: Any,
) -> None:
if mode is not None and mode not in ("text", "vision"):
raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'")
if "vit-b" not in model_name.lower():
raise ValueError(f"Only ViT-B models are currently supported; got '{model_name}'")
self.mode = mode
jina_model_name = self._get_jina_model_name(model_name)
super().__init__(jina_model_name, cache_dir, **model_kwargs)
def _download(self, **model_kwargs: Any) -> None:
repo_id = self.model_name if "/" in self.model_name else f"sentence-transformers/{self.model_name}"
snapshot_download(
cache_dir=self.cache_dir,
repo_id=repo_id,
library_name="sentence-transformers",
ignore_files=["flax_model.msgpack", "rust_model.ot", "tf_model.h5"],
)
models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
text_onnx_path = self.cache_dir / "textual.onnx"
vision_onnx_path = self.cache_dir / "visual.onnx"
if not text_onnx_path.is_file():
self._download_model(*models[0])
if not vision_onnx_path.is_file():
self._download_model(*models[1])
def _load(self, **model_kwargs: Any) -> None:
self.model = SentenceTransformer(
self.model_name,
cache_folder=self.cache_dir.as_posix(),
**model_kwargs,
)
if self.mode == "text" or self.mode is None:
self.text_model = ort.InferenceSession(
self.cache_dir / "textual.onnx",
sess_options=self.sess_options,
providers=self.providers,
provider_options=self.provider_options,
)
self.text_outputs = [output.name for output in self.text_model.get_outputs()]
self.tokenizer = Tokenizer(self.model_name)
if self.mode == "vision" or self.mode is None:
self.vision_model = ort.InferenceSession(
self.cache_dir / "visual.onnx",
sess_options=self.sess_options,
providers=self.providers,
provider_options=self.provider_options,
)
self.vision_outputs = [output.name for output in self.vision_model.get_outputs()]
image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)]
self.transform = _transform_pil_image(image_size)
def _predict(self, image_or_text: Image | str) -> list[float]:
return self.model.encode(image_or_text).tolist()
match image_or_text:
case Image():
if self.mode == "text":
raise TypeError("Cannot encode image as text-only model")
pixel_values = self.transform(image_or_text)
assert isinstance(pixel_values, torch.Tensor)
pixel_values = torch.unsqueeze(pixel_values, 0).numpy()
outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
case str():
if self.mode == "vision":
raise TypeError("Cannot encode text as vision-only model")
text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text)
inputs = {
"input_ids": text_inputs["input_ids"].int().numpy(),
"attention_mask": text_inputs["attention_mask"].int().numpy(),
}
outputs = self.text_model.run(self.text_outputs, inputs)
case _:
raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
return outputs[0][0].tolist()
def _get_jina_model_name(self, model_name: str) -> str:
if model_name in _MODELS:
return model_name
elif model_name in _ST_TO_JINA_MODEL_NAME:
print(
(f"Warning: Sentence-Transformer model names such as '{model_name}' are no longer supported."),
(f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."),
)
return _ST_TO_JINA_MODEL_NAME[model_name]
else:
raise ValueError(f"Unknown model name {model_name}.")
def _download_model(self, model_name: str, model_md5: str) -> bool:
# downloading logic is adapted from clip-server's CLIPOnnxModel class
download_model(
url=_S3_BUCKET_V2 + model_name,
target_folder=self.cache_dir.as_posix(),
md5sum=model_md5,
with_resume=True,
)
file = self.cache_dir / model_name.split("/")[1]
if file.suffix == ".zip":
with zipfile.ZipFile(file, "r") as zip_ref:
zip_ref.extractall(self.cache_dir)
os.remove(file)
return True
# same as `_transform_blob` without `_blob2image`
def _transform_pil_image(n_px: int) -> Compose:
return Compose(
[
Resize(n_px, interpolation=BICUBIC),
CenterCrop(n_px),
_convert_image_to_rgb,
ToTensor(),
Normalize(
(0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711),
),
]
)
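A usage sketch of the reworked encoder; the image path is hypothetical, and 512 is the embedding width of the ViT-B-32 checkpoints:

from PIL import Image

encoder = CLIPEncoder("ViT-B-32::openai")  # mode=None loads both towers
image_embedding = encoder.predict(Image.open("photo.jpg"))  # hypothetical file
text_embedding = encoder.predict("a photo of a dog")
assert len(image_embedding) == len(text_embedding) == 512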

View File

@@ -4,6 +4,7 @@ from typing import Any
import cv2
import numpy as np
import onnxruntime as ort
from insightface.model_zoo import ArcFaceONNX, RetinaFace
from insightface.utils.face_align import norm_crop
from insightface.utils.storage import BASE_REPO_URL, download_file
@@ -42,15 +43,31 @@ class FaceRecognizer(InferenceModel):
rec_file = next(self.cache_dir.glob("w600k_*.onnx"))
except StopIteration:
raise FileNotFoundError("Facial recognition models not found in cache directory")
self.det_model = RetinaFace(det_file.as_posix())
self.rec_model = ArcFaceONNX(rec_file.as_posix())
self.det_model = RetinaFace(
session=ort.InferenceSession(
det_file.as_posix(),
sess_options=self.sess_options,
providers=self.providers,
provider_options=self.provider_options,
),
)
self.rec_model = ArcFaceONNX(
rec_file.as_posix(),
session=ort.InferenceSession(
rec_file.as_posix(),
sess_options=self.sess_options,
providers=self.providers,
provider_options=self.provider_options,
),
)
self.det_model.prepare(
ctx_id=-1,
ctx_id=0,
det_thresh=self.min_score,
input_size=(640, 640),
)
self.rec_model.prepare(ctx_id=-1)
self.rec_model.prepare(ctx_id=0)
def _predict(self, image: cv2.Mat) -> list[dict[str, Any]]:
bboxes, kpss = self.det_model.detect(image)

View File

@@ -2,8 +2,10 @@ from pathlib import Path
from typing import Any
from huggingface_hub import snapshot_download
from optimum.onnxruntime import ORTModelForImageClassification
from optimum.pipelines import pipeline
from PIL.Image import Image
from transformers.pipelines import pipeline
from transformers import AutoImageProcessor
from ..config import settings
from ..schemas import ModelType
@@ -25,15 +27,34 @@ class ImageClassifier(InferenceModel):
def _download(self, **model_kwargs: Any) -> None:
snapshot_download(
cache_dir=self.cache_dir, repo_id=self.model_name, allow_patterns=["*.bin", "*.json", "*.txt"]
cache_dir=self.cache_dir,
repo_id=self.model_name,
allow_patterns=["*.bin", "*.json", "*.txt"],
local_dir=self.cache_dir,
local_dir_use_symlinks=True,
)
def _load(self, **model_kwargs: Any) -> None:
self.model = pipeline(
self.model_type.value,
self.model_name,
model_kwargs={"cache_dir": self.cache_dir, **model_kwargs},
)
processor = AutoImageProcessor.from_pretrained(self.cache_dir)
model_kwargs |= {
"cache_dir": self.cache_dir,
"provider": self.providers[0],
"provider_options": self.provider_options[0],
"session_options": self.sess_options,
}
model_path = self.cache_dir / "model.onnx"
if model_path.exists():
model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
else:
self.sess_options.optimized_model_filepath = model_path.as_posix()
self.model = pipeline(
self.model_type.value,
self.model_name,
model_kwargs=model_kwargs,
feature_extractor=processor,
)
def _predict(self, image: Image) -> list[str]:
predictions: list[dict[str, Any]] = self.model(image) # type: ignore
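A note on the two branches above: on first load there is no model.onnx in the cache, so the else branch points sess_options.optimized_model_filepath at it and ONNX Runtime writes the optimized graph there as a side effect of building the pipeline; every later load finds the file and takes the ORTModelForImageClassification.from_pretrained path instead. A usage sketch with illustrative names:

from PIL import Image

classifier = ImageClassifier("microsoft/resnet-50")  # hypothetical model id
tags = classifier.predict(Image.open("photo.jpg"))  # hypothetical file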