fix(ml): load models in separate threads (#4034)

* load models in thread (see the sketch below)

* set clip mode logs to debug level

* updated tests

* made fixtures slightly less ugly

* moved responses to json file

* formatting
Author:    Mert
Date:      2023-09-09 05:02:44 -04:00
Committer: GitHub
Parent:    f1db257628
Commit:    258b98c262

9 changed files with 1683 additions and 114 deletions
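The hunks below cover the model classes themselves; the actual off-thread loading happens in the FastAPI entrypoint, which is not part of this excerpt. A minimal sketch of the pattern the title describes, assuming a shared executor owned by the app (`thread_pool` and the `load` helper are illustrative, not the commit's verbatim code):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

# Assumption: a shared pool owned by the app; size it for the ML workload.
thread_pool = ThreadPoolExecutor(max_workers=1)


async def load(model: "InferenceModel") -> "InferenceModel":
    """Run the blocking model.load() in a worker thread so the event loop stays responsive."""
    if model.loaded:
        return model
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(thread_pool, model.load)
    return model
```

The `loaded` flag this helper checks is made public in base.py below, and `load()` itself becomes idempotent, which is what makes an off-thread call like this safe to repeat.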


@@ -5,10 +5,8 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from shutil import rmtree
 from typing import Any
-from zipfile import BadZipFile
 
 import onnxruntime as ort
-from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile  # type: ignore
 
 from ..config import get_cache_dir, log, settings
 from ..schemas import ModelType
@@ -21,16 +19,13 @@ class InferenceModel(ABC):
         self,
         model_name: str,
         cache_dir: Path | str | None = None,
-        eager: bool = True,
         inter_op_num_threads: int = settings.model_inter_op_threads,
         intra_op_num_threads: int = settings.model_intra_op_threads,
         **model_kwargs: Any,
     ) -> None:
         self.model_name = model_name
-        self._loaded = False
+        self.loaded = False
         self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)
-        loader = self.load if eager else self.download
-
         self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"])
         # don't pre-allocate more memory than needed
         self.provider_options = model_kwargs.pop(
@@ -55,34 +50,23 @@ class InferenceModel(ABC):
         self.sess_options.intra_op_num_threads = intra_op_num_threads
         self.sess_options.enable_cpu_mem_arena = False
 
-        try:
-            loader(**model_kwargs)
-        except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
-            log.warn(
-                (
-                    f"Failed to load {self.model_type.replace('_', ' ')} model '{self.model_name}'."
-                    "Clearing cache and retrying."
-                )
-            )
-            self.clear_cache()
-            loader(**model_kwargs)
-
-    def download(self, **model_kwargs: Any) -> None:
+    def download(self) -> None:
         if not self.cached:
             log.info(
-                (f"Downloading {self.model_type.replace('_', ' ')} model '{self.model_name}'." "This may take a while.")
+                (f"Downloading {self.model_type.replace('-', ' ')} model '{self.model_name}'." "This may take a while.")
             )
-            self._download(**model_kwargs)
+            self._download()
 
-    def load(self, **model_kwargs: Any) -> None:
-        self.download(**model_kwargs)
-        self._load(**model_kwargs)
-        self._loaded = True
+    def load(self) -> None:
+        if self.loaded:
+            return
+        self.download()
+        log.info(f"Loading {self.model_type.replace('-', ' ')} model '{self.model_name}'")
+        self._load()
+        self.loaded = True
 
     def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
-        if not self._loaded:
-            log.info(f"Loading {self.model_type.replace('_', ' ')} model '{self.model_name}'")
-            self.load()
+        self.load()
         if model_kwargs:
             self.configure(**model_kwargs)
         return self._predict(inputs)
@@ -95,11 +79,11 @@ class InferenceModel(ABC):
         pass
 
     @abstractmethod
-    def _download(self, **model_kwargs: Any) -> None:
+    def _download(self) -> None:
        ...
 
     @abstractmethod
-    def _load(self, **model_kwargs: Any) -> None:
+    def _load(self) -> None:
        ...
 
    @property
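With `**model_kwargs` stripped from the hooks, subclasses now implement bare `_download`/`_load` and the base class drives the download-then-load sequence. A hypothetical minimal subclass (not part of this commit) illustrating the new contract, assuming it lives alongside the real models in the app package:

```python
from typing import Any

import onnxruntime as ort


class DummyModel(InferenceModel):
    # model-type declaration omitted; see the real subclasses below for the full pattern

    def _download(self) -> None:
        ...  # fetch model.onnx into self.cache_dir; no kwargs to thread through anymore

    def _load(self) -> None:
        self.session = ort.InferenceSession(
            self.cache_dir / "model.onnx",
            sess_options=self.sess_options,
            providers=self.providers,
            provider_options=self.provider_options,
        )

    def _predict(self, inputs: Any) -> Any:
        return self.session.run(None, inputs)

    def configure(self, **model_kwargs: Any) -> None:
        pass
```

Calling `predict()` on such a model downloads on first use, loads exactly once, and skips straight to inference afterwards, since `predict()` now just calls the idempotent `load()`.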


@@ -17,7 +17,7 @@ class ModelCache:
         revalidate: bool = False,
         timeout: int | None = None,
         profiling: bool = False,
-    ):
+    ) -> None:
         """
         Args:
             ttl: Unloads model after this duration. Disabled if None. Defaults to None.
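The ttl above is what eventually unloads these models again. `ModelCache.get` is not shown in this excerpt, so the sketch below hedges its signature as an assumption based on the docstring:

```python
# Assumed usage; get()'s exact signature is an inference from the docstring.
cache = ModelCache(ttl=300, revalidate=True)  # unload a model 5 minutes after last use


async def handler() -> None:
    model = await cache.get("buffalo_l", ModelType.FACIAL_RECOGNITION)
    ...  # the entry is evicted (and the model unloaded) ttl seconds after last use
```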


@@ -42,7 +42,7 @@ class CLIPEncoder(InferenceModel):
         jina_model_name = self._get_jina_model_name(model_name)
         super().__init__(jina_model_name, cache_dir, **model_kwargs)
 
-    def _download(self, **model_kwargs: Any) -> None:
+    def _download(self) -> None:
         models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
         text_onnx_path = self.cache_dir / "textual.onnx"
         vision_onnx_path = self.cache_dir / "visual.onnx"
@@ -53,8 +53,9 @@ class CLIPEncoder(InferenceModel):
         if not vision_onnx_path.is_file():
             self._download_model(*models[1])
 
-    def _load(self, **model_kwargs: Any) -> None:
+    def _load(self) -> None:
         if self.mode == "text" or self.mode is None:
+            log.debug(f"Loading clip text model '{self.model_name}'")
             self.text_model = ort.InferenceSession(
                 self.cache_dir / "textual.onnx",
                 sess_options=self.sess_options,
@@ -65,6 +66,7 @@ class CLIPEncoder(InferenceModel):
             self.tokenizer = Tokenizer(self.model_name)
 
         if self.mode == "vision" or self.mode is None:
+            log.debug(f"Loading clip vision model '{self.model_name}'")
             self.vision_model = ort.InferenceSession(
                 self.cache_dir / "visual.onnx",
                 sess_options=self.sess_options,
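These debug lines are the "set clip mode logs to debug level" bullet from the commit message: the per-mode load messages drop to debug, while the single info-level "Loading … model" line now lives in base.py. Since `mode` gates which ONNX session is created, a text-only encoder never touches the larger visual model; a sketch, assuming `mode` is accepted as a constructor keyword as the hunk implies:

```python
# Assumption: CLIPEncoder accepts a mode keyword ("text" or "vision").
encoder = CLIPEncoder("ViT-B-32", mode="text")
encoder.load()  # creates only the textual.onnx session; mode details logged at debug
```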


@@ -26,7 +26,7 @@ class FaceRecognizer(InferenceModel):
         self.min_score = model_kwargs.pop("minScore", min_score)
         super().__init__(model_name, cache_dir, **model_kwargs)
 
-    def _download(self, **model_kwargs: Any) -> None:
+    def _download(self) -> None:
         zip_file = self.cache_dir / f"{self.model_name}.zip"
         download_file(f"{BASE_REPO_URL}/{self.model_name}.zip", zip_file)
         with zipfile.ZipFile(zip_file, "r") as zip:
@@ -36,7 +36,7 @@ class FaceRecognizer(InferenceModel):
             zip.extractall(self.cache_dir, members=[det_file, rec_file])
         zip_file.unlink()
 
-    def _load(self, **model_kwargs: Any) -> None:
+    def _load(self) -> None:
         try:
             det_file = next(self.cache_dir.glob("det_*.onnx"))
             rec_file = next(self.cache_dir.glob("w600k_*.onnx"))


@@ -26,7 +26,7 @@ class ImageClassifier(InferenceModel):
         self.min_score = model_kwargs.pop("minScore", min_score)
         super().__init__(model_name, cache_dir, **model_kwargs)
 
-    def _download(self, **model_kwargs: Any) -> None:
+    def _download(self) -> None:
         snapshot_download(
             cache_dir=self.cache_dir,
             repo_id=self.model_name,
@@ -35,10 +35,10 @@ class ImageClassifier(InferenceModel):
             local_dir_use_symlinks=True,
         )
 
-    def _load(self, **model_kwargs: Any) -> None:
+    def _load(self) -> None:
         processor = AutoImageProcessor.from_pretrained(self.cache_dir, cache_dir=self.cache_dir)
         model_path = self.cache_dir / "model.onnx"
-        model_kwargs |= {
+        model_kwargs = {
             "cache_dir": self.cache_dir,
             "provider": self.providers[0],
             "provider_options": self.provider_options[0],