mirror of
				https://github.com/KevinMidboe/immich.git
				synced 2025-10-29 17:40:28 +00:00 
			
		
		
		
	fix(ml): clear model cache on load error (#2951)
* clear model cache on load error * updated caught exceptions
This commit is contained in:
		| @@ -2,8 +2,11 @@ from __future__ import annotations | ||||
|  | ||||
| from abc import ABC, abstractmethod | ||||
| from pathlib import Path | ||||
| from shutil import rmtree | ||||
| from typing import Any | ||||
|  | ||||
| from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf | ||||
|  | ||||
| from ..config import get_cache_dir | ||||
| from ..schemas import ModelType | ||||
|  | ||||
| @@ -12,10 +15,8 @@ class InferenceModel(ABC): | ||||
|     _model_type: ModelType | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         model_name: str, | ||||
|         cache_dir: Path | None = None, | ||||
|     ): | ||||
|         self, model_name: str, cache_dir: Path | None = None, **model_kwargs | ||||
|     ) -> None: | ||||
|         self.model_name = model_name | ||||
|         self._cache_dir = ( | ||||
|             cache_dir | ||||
| @@ -23,6 +24,16 @@ class InferenceModel(ABC): | ||||
|             else get_cache_dir(model_name, self.model_type) | ||||
|         ) | ||||
|  | ||||
|         try: | ||||
|             self.load(**model_kwargs) | ||||
|         except (OSError, InvalidProtobuf): | ||||
|             self.clear_cache() | ||||
|             self.load(**model_kwargs) | ||||
|  | ||||
|     @abstractmethod | ||||
|     def load(self, **model_kwargs: Any) -> None: | ||||
|         ... | ||||
|  | ||||
|     @abstractmethod | ||||
|     def predict(self, inputs: Any) -> Any: | ||||
|         ... | ||||
| @@ -36,7 +47,7 @@ class InferenceModel(ABC): | ||||
|         return self._cache_dir | ||||
|  | ||||
|     @cache_dir.setter | ||||
|     def cache_dir(self, cache_dir: Path): | ||||
|     def cache_dir(self, cache_dir: Path) -> None: | ||||
|         self._cache_dir = cache_dir | ||||
|  | ||||
|     @classmethod | ||||
| @@ -50,3 +61,13 @@ class InferenceModel(ABC): | ||||
|             raise ValueError(f"Unsupported model type: {model_type}") | ||||
|  | ||||
|         return subclasses[model_type](model_name, **model_kwargs) | ||||
|  | ||||
|     def clear_cache(self) -> None: | ||||
|         if not self.cache_dir.exists(): | ||||
|             return | ||||
|         elif not rmtree.avoids_symlink_attacks: | ||||
|             raise RuntimeError( | ||||
|                 "Attempted to clear cache, but rmtree is not safe on this platform." | ||||
|             ) | ||||
|  | ||||
|         rmtree(self.cache_dir) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user