mirror of
				https://github.com/KevinMidboe/immich.git
				synced 2025-10-29 17:40:28 +00:00 
			
		
		
		
	refactor(ml): modularization and styling (#2835)
* basic refactor and styling * removed batching * module entrypoint * removed unused imports * model superclass, model cache now in app state * fixed cache dir and enforced abstract method --------- Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
This commit is contained in:
		
							
								
								
									
										92
									
								
								machine-learning/app/models/cache.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								machine-learning/app/models/cache.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,92 @@ | ||||
| import asyncio | ||||
|  | ||||
| from aiocache.backends.memory import SimpleMemoryCache | ||||
| from aiocache.lock import OptimisticLock | ||||
| from aiocache.plugins import BasePlugin, TimingPlugin | ||||
|  | ||||
| from ..schemas import ModelType | ||||
| from .base import InferenceModel | ||||
|  | ||||
|  | ||||
| class ModelCache: | ||||
|     """Fetches a model from an in-memory cache, instantiating it if it's missing.""" | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         ttl: float | None = None, | ||||
|         revalidate: bool = False, | ||||
|         timeout: int | None = None, | ||||
|         profiling: bool = False, | ||||
|     ): | ||||
|         """ | ||||
|         Args: | ||||
|             ttl: Unloads model after this duration. Disabled if None. Defaults to None. | ||||
|             revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False. | ||||
|             timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None. | ||||
|             profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False. | ||||
|         """ | ||||
|  | ||||
|         self.ttl = ttl | ||||
|         plugins = [] | ||||
|  | ||||
|         if revalidate: | ||||
|             plugins.append(RevalidationPlugin()) | ||||
|         if profiling: | ||||
|             plugins.append(TimingPlugin()) | ||||
|  | ||||
|         self.cache = SimpleMemoryCache( | ||||
|             ttl=ttl, timeout=timeout, plugins=plugins, namespace=None | ||||
|         ) | ||||
|  | ||||
|     async def get( | ||||
|         self, model_name: str, model_type: ModelType, **model_kwargs | ||||
|     ) -> InferenceModel: | ||||
|         """ | ||||
|         Args: | ||||
|             model_name: Name of model in the model hub used for the task. | ||||
|             model_type: Model type or task, which determines which model zoo is used. | ||||
|  | ||||
|         Returns: | ||||
|             model: The requested model. | ||||
|         """ | ||||
|  | ||||
|         key = self.cache.build_key(model_name, model_type.value) | ||||
|         model = await self.cache.get(key) | ||||
|         if model is None: | ||||
|             async with OptimisticLock(self.cache, key) as lock: | ||||
|                 model = await asyncio.get_running_loop().run_in_executor( | ||||
|                     None, | ||||
|                     lambda: InferenceModel.from_model_type( | ||||
|                         model_type, model_name, **model_kwargs | ||||
|                     ), | ||||
|                 ) | ||||
|                 await lock.cas(model, ttl=self.ttl) | ||||
|         return model | ||||
|  | ||||
|     async def get_profiling(self) -> dict[str, float] | None: | ||||
|         if not hasattr(self.cache, "profiling"): | ||||
|             return None | ||||
|  | ||||
|         return self.cache.profiling  # type: ignore | ||||
|  | ||||
|  | ||||
| class RevalidationPlugin(BasePlugin): | ||||
|     """Revalidates cache item's TTL after cache hit.""" | ||||
|  | ||||
|     async def post_get(self, client, key, ret=None, namespace=None, **kwargs): | ||||
|         if ret is None: | ||||
|             return | ||||
|         if namespace is not None: | ||||
|             key = client.build_key(key, namespace) | ||||
|         if key in client._handlers: | ||||
|             await client.expire(key, client.ttl) | ||||
|  | ||||
|     async def post_multi_get(self, client, keys, ret=None, namespace=None, **kwargs): | ||||
|         if ret is None: | ||||
|             return | ||||
|  | ||||
|         for key, val in zip(keys, ret): | ||||
|             if namespace is not None: | ||||
|                 key = client.build_key(key, namespace) | ||||
|             if val is not None and key in client._handlers: | ||||
|                 await client.expire(key, client.ttl) | ||||
		Reference in New Issue
	
	Block a user