Mirror of https://github.com/KevinMidboe/immich.git (synced 2025-10-29 17:40:28 +00:00)
	chore(ml): load models on start up (#2487)
* chore(ml): load models on start up
* Download correct model
@@ -5,7 +5,7 @@ import uvicorn
 
 from insightface.app import FaceAnalysis
 from transformers import pipeline
-from sentence_transformers import SentenceTransformer, util
+from sentence_transformers import SentenceTransformer
 from PIL import Image
 from fastapi import FastAPI
 from pydantic import BaseModel
@@ -20,22 +20,32 @@ class ClipRequestBody(BaseModel):
 
 
 classification_model = os.getenv(
-    'MACHINE_LEARNING_CLASSIFICATION_MODEL', 'microsoft/resnet-50')
-object_model = os.getenv('MACHINE_LEARNING_OBJECT_MODEL', 'hustvl/yolos-tiny')
-clip_image_model = os.getenv(
-    'MACHINE_LEARNING_CLIP_IMAGE_MODEL', 'clip-ViT-B-32')
-clip_text_model = os.getenv(
-    'MACHINE_LEARNING_CLIP_TEXT_MODEL', 'clip-ViT-B-32')
+    "MACHINE_LEARNING_CLASSIFICATION_MODEL", "microsoft/resnet-50"
+)
+object_model = os.getenv("MACHINE_LEARNING_OBJECT_MODEL", "hustvl/yolos-tiny")
+clip_image_model = os.getenv("MACHINE_LEARNING_CLIP_IMAGE_MODEL", "clip-ViT-B-32")
+clip_text_model = os.getenv("MACHINE_LEARNING_CLIP_TEXT_MODEL", "clip-ViT-B-32")
 facial_recognition_model = os.getenv(
-    'MACHINE_LEARNING_FACIAL_RECOGNITION_MODEL', 'buffalo_l')
+    "MACHINE_LEARNING_FACIAL_RECOGNITION_MODEL", "buffalo_l"
+)
 
-cache_folder = os.getenv('MACHINE_LEARNING_CACHE_FOLDER', '/cache')
+cache_folder = os.getenv("MACHINE_LEARNING_CACHE_FOLDER", "/cache")
 
 _model_cache = {}
 
 app = FastAPI()
 
 
+@app.on_event("startup")
+async def startup_event():
+    # Get all models
+    _get_model(object_model, "object-detection")
+    _get_model(classification_model, "image-classification")
+    _get_model(clip_image_model)
+    _get_model(clip_text_model)
+    _get_model(facial_recognition_model, "facial-recognition")
+
+
 @app.get("/")
 async def root():
     return {"message": "Immich ML"}
@@ -48,14 +58,14 @@ def ping():
 
 @app.post("/object-detection/detect-object", status_code=200)
 def object_detection(payload: MlRequestBody):
-    model = _get_model(object_model, 'object-detection')
+    model = _get_model(object_model, "object-detection")
     assetPath = payload.thumbnailPath
     return run_engine(model, assetPath)
 
 
 @app.post("/image-classifier/tag-image", status_code=200)
 def image_classification(payload: MlRequestBody):
-    model = _get_model(classification_model, 'image-classification')
+    model = _get_model(classification_model, "image-classification")
     assetPath = payload.thumbnailPath
     return run_engine(model, assetPath)
 
@@ -76,31 +86,32 @@ def clip_encode_text(payload: ClipRequestBody):
 
 @app.post("/facial-recognition/detect-faces", status_code=200)
 def facial_recognition(payload: MlRequestBody):
-    model = _get_model(facial_recognition_model, 'facial-recognition')
+    model = _get_model(facial_recognition_model, "facial-recognition")
     assetPath = payload.thumbnailPath
     img = cv.imread(assetPath)
     height, width, _ = img.shape
     results = []
     faces = model.get(img)
 
     for face in faces:
         if face.det_score < 0.7:
             continue
         x1, y1, x2, y2 = face.bbox
-        # min face size as percent of original image
-        # if (x2 - x1) / width < 0.03 or (y2 - y1) / height < 0.05:
-        #     continue
-        results.append({
-            "imageWidth": width,
-            "imageHeight": height,
-            "boundingBox": {
-                "x1": round(x1),
-                "y1": round(y1),
-                "x2": round(x2),
-                "y2": round(y2),
-            },
-            "score": face.det_score.item(),
-            "embedding": face.normed_embedding.tolist()
-        })
+        results.append(
+            {
+                "imageWidth": width,
+                "imageHeight": height,
+                "boundingBox": {
+                    "x1": round(x1),
+                    "y1": round(y1),
+                    "x2": round(x2),
+                    "y2": round(y2),
+                },
+                "score": face.det_score.item(),
+                "embedding": face.normed_embedding.tolist(),
+            }
+        )
     return results
 
 
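For reference, a hypothetical client call against this endpoint; the host, port and image path are assumptions (port 3003 is the service default from the __main__ block, and the body field thumbnailPath comes from MlRequestBody):

import requests

resp = requests.post(
    "http://localhost:3003/facial-recognition/detect-faces",
    json={"thumbnailPath": "/cache/thumbs/example.jpg"},  # illustrative path
    timeout=60,
)
for face in resp.json():
    box = face["boundingBox"]
    print(face["score"], box["x1"], box["y1"], box["x2"], box["y2"])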
@@ -109,11 +120,11 @@ def run_engine(engine, path):
     predictions = engine(path)
 
     for index, pred in enumerate(predictions):
-        tags = pred['label'].split(', ')
-        if (pred['score'] > 0.9):
+        tags = pred["label"].split(", ")
+        if pred["score"] > 0.9:
             result = [*result, *tags]
 
-    if (len(result) > 1):
+    if len(result) > 1:
         result = list(set(result))
 
     return result
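A worked example of the filtering above, with made-up classifier output: labels are split on ", ", only predictions scoring above 0.9 are kept, and duplicates are removed:

predictions = [  # made-up classifier output
    {"label": "tabby, tabby cat", "score": 0.95},
    {"label": "Egyptian cat", "score": 0.60},
    {"label": "tiger cat", "score": 0.93},
]

result = []
for pred in predictions:
    tags = pred["label"].split(", ")
    if pred["score"] > 0.9:
        result = [*result, *tags]

if len(result) > 1:
    result = list(set(result))  # de-duplicate

print(sorted(result))  # ['tabby', 'tabby cat', 'tiger cat']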
@@ -121,25 +132,27 @@ def run_engine(engine, path):
 
 def _get_model(model, task=None):
     global _model_cache
-    key = '|'.join([model, str(task)])
+    key = "|".join([model, str(task)])
     if key not in _model_cache:
         if task:
-            if task == 'facial-recognition':
+            if task == "facial-recognition":
                 face_model = FaceAnalysis(
-                    name=model, root=cache_folder, allowed_modules=["detection", "recognition"])
+                    name=model,
+                    root=cache_folder,
+                    allowed_modules=["detection", "recognition"],
+                )
                 face_model.prepare(ctx_id=0, det_size=(640, 640))
                 _model_cache[key] = face_model
             else:
                 _model_cache[key] = pipeline(model=model, task=task)
         else:
-            _model_cache[key] = SentenceTransformer(
-                model, cache_folder=cache_folder)
+            _model_cache[key] = SentenceTransformer(model, cache_folder=cache_folder)
     return _model_cache[key]
 
 
 if __name__ == "__main__":
-    host = os.getenv('MACHINE_LEARNING_HOST', '0.0.0.0')
-    port = int(os.getenv('MACHINE_LEARNING_PORT', 3003))
-    is_dev = os.getenv('NODE_ENV') == 'development'
+    host = os.getenv("MACHINE_LEARNING_HOST", "0.0.0.0")
+    port = int(os.getenv("MACHINE_LEARNING_PORT", 3003))
+    is_dev = os.getenv("NODE_ENV") == "development"
 
     uvicorn.run("main:app", host=host, port=port, reload=is_dev, workers=1)
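A simplified stand-in for the _get_model cache, showing the "<model>|<task>" key scheme; a call with task=None (the CLIP/sentence-transformers path) is keyed with the string "None":

_model_cache = {}


def _get_model(model, task=None):
    key = "|".join([model, str(task)])
    if key not in _model_cache:
        _model_cache[key] = object()  # stand-in for the real model loader
    return _model_cache[key]


a = _get_model("clip-ViT-B-32")
b = _get_model("clip-ViT-B-32")
c = _get_model("buffalo_l", "facial-recognition")
assert a is b      # second call hits the cache ("clip-ViT-B-32|None")
assert a is not c  # different key -> separate entry ("buffalo_l|facial-recognition")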