	chore(ml): move to fastAPI (#2336)
@@ -1,14 +1,15 @@
 FROM python:3.10 as builder
 
 ENV PYTHONDONTWRITEBYTECODE=1 \
-    PYTHONUNBUFFERED=1 \
-    PIP_NO_CACHE_DIR=true
+  PYTHONUNBUFFERED=1 \
+  PIP_NO_CACHE_DIR=true
 
 RUN python -m venv /opt/venv
 RUN /opt/venv/bin/pip install --pre torch  -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
-RUN /opt/venv/bin/pip install transformers tqdm numpy scikit-learn scipy nltk sentencepiece flask Pillow gunicorn
+RUN /opt/venv/bin/pip install transformers tqdm numpy scikit-learn scipy nltk sentencepiece fastapi Pillow uvicorn[standard]
 RUN /opt/venv/bin/pip install --no-deps sentence-transformers
 
+
 FROM python:3.10-slim
 
 ENV NODE_ENV=production
@@ -16,12 +17,12 @@ ENV NODE_ENV=production
 COPY --from=builder /opt/venv /opt/venv
 
 ENV TRANSFORMERS_CACHE=/cache \
-    PYTHONDONTWRITEBYTECODE=1 \
-    PYTHONUNBUFFERED=1 \
-    PATH="/opt/venv/bin:$PATH"
+  PYTHONDONTWRITEBYTECODE=1 \
+  PYTHONUNBUFFERED=1 \
+  PATH="/opt/venv/bin:$PATH"
 
 WORKDIR /usr/src/app
 
 COPY . .
-
-CMD ["gunicorn", "src.main:server"]
+ENV PYTHONPATH=`pwd`
+CMD ["python", "main.py"]
@@ -1,29 +0,0 @@
-"""
-Gunicorn configuration options.
-https://docs.gunicorn.org/en/stable/settings.html
-"""
-import os
-
-
-# Set the bind address based on the env
-port = os.getenv("MACHINE_LEARNING_PORT") or "3003"
-listen_ip = os.getenv("MACHINE_LEARNING_IP") or "0.0.0.0"
-bind = [f"{listen_ip}:{port}"]
-
-# Preload the Flask app / models etc. before starting the server
-preload_app = True
-
-# Logging settings - log to stdout and set log level
-accesslog = "-"
-loglevel = os.getenv("MACHINE_LEARNING_LOG_LEVEL") or "info"
-
-# Worker settings
-# ----------------------
-# It is important these are chosen carefully as per
-# https://pythonspeed.com/articles/gunicorn-in-docker/
-# Otherwise we get workers failing to respond to heartbeat checks,
-# especially as requests take a long time to complete.
-workers = 2
-threads = 4
-worker_tmp_dir = "/dev/shm"
-timeout = 60
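The deleted module held the gunicorn-specific settings: bind address, app preloading, access logging, and worker sizing. With the move to uvicorn, the bind address is taken care of by the uvicorn.run(...) call in main.py (note the env variable changes from MACHINE_LEARNING_IP to MACHINE_LEARNING_HOST), while the log-level, worker/thread, preload, and timeout settings are not carried over by this diff. A minimal sketch of forwarding the old log-level variable, assuming uvicorn's standard log_level option; this is not part of the change:

# Assumption, not in the diff: re-exposing MACHINE_LEARNING_LOG_LEVEL under uvicorn.
# uvicorn.run accepts a lower-case level name ("debug", "info", ...) via log_level.
import os

import uvicorn

uvicorn.run(
    "main:app",
    host=os.getenv("MACHINE_LEARNING_HOST", "0.0.0.0"),
    port=int(os.getenv("MACHINE_LEARNING_PORT", "3003")),
    log_level=os.getenv("MACHINE_LEARNING_LOG_LEVEL", "info"),
)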
@@ -1,58 +1,77 @@
-import os
-from flask import Flask, request
 from transformers import pipeline
 from sentence_transformers import SentenceTransformer, util
 from PIL import Image
+from fastapi import FastAPI
+import uvicorn
+import os
+from pydantic import BaseModel
 
+
+class MlRequestBody(BaseModel):
+    thumbnailPath: str
+
+
+class ClipRequestBody(BaseModel):
+    text: str
+
+
 is_dev = os.getenv('NODE_ENV') == 'development'
 server_port = os.getenv('MACHINE_LEARNING_PORT', 3003)
 server_host = os.getenv('MACHINE_LEARNING_HOST', '0.0.0.0')
 
-classification_model = os.getenv('MACHINE_LEARNING_CLASSIFICATION_MODEL', 'microsoft/resnet-50')
+app = FastAPI()
+
+"""
+Model Initialization
+"""
+classification_model = os.getenv(
+    'MACHINE_LEARNING_CLASSIFICATION_MODEL', 'microsoft/resnet-50')
 object_model = os.getenv('MACHINE_LEARNING_OBJECT_MODEL', 'hustvl/yolos-tiny')
-clip_image_model = os.getenv('MACHINE_LEARNING_CLIP_IMAGE_MODEL', 'clip-ViT-B-32')
-clip_text_model = os.getenv('MACHINE_LEARNING_CLIP_TEXT_MODEL', 'clip-ViT-B-32')
+clip_image_model = os.getenv(
+    'MACHINE_LEARNING_CLIP_IMAGE_MODEL', 'clip-ViT-B-32')
+clip_text_model = os.getenv(
+    'MACHINE_LEARNING_CLIP_TEXT_MODEL', 'clip-ViT-B-32')
 
 _model_cache = {}
-def _get_model(model, task=None):
-  global _model_cache
-  key = '|'.join([model, str(task)])
-  if key not in _model_cache:
-    if task:
-      _model_cache[key] = pipeline(model=model, task=task)
-    else:
-      _model_cache[key] = SentenceTransformer(model)
-  return _model_cache[key]
 
-server = Flask(__name__)
 
-@server.route("/ping")
+@app.get("/")
+async def root():
+    return {"message": "Immich ML"}
+
+
+@app.get("/ping")
 def ping():
     return "pong"
 
-@server.route("/object-detection/detect-object", methods=['POST'])
-def object_detection():
+
+@app.post("/object-detection/detect-object", status_code=200)
+def object_detection(payload: MlRequestBody):
     model = _get_model(object_model, 'object-detection')
-    assetPath = request.json['thumbnailPath']
-    return run_engine(model, assetPath), 200
+    assetPath = payload.thumbnailPath
+    return run_engine(model, assetPath)
 
-@server.route("/image-classifier/tag-image", methods=['POST'])
-def image_classification():
+
+@app.post("/image-classifier/tag-image", status_code=200)
+def image_classification(payload: MlRequestBody):
     model = _get_model(classification_model, 'image-classification')
-    assetPath = request.json['thumbnailPath']
-    return run_engine(model, assetPath), 200
+    assetPath = payload.thumbnailPath
+    return run_engine(model, assetPath)
 
-@server.route("/sentence-transformer/encode-image", methods=['POST'])
-def clip_encode_image():
+
+@app.post("/sentence-transformer/encode-image", status_code=200)
+def clip_encode_image(payload: MlRequestBody):
     model = _get_model(clip_image_model)
-    assetPath = request.json['thumbnailPath']
-    return model.encode(Image.open(assetPath)).tolist(), 200
+    assetPath = payload.thumbnailPath
+    return model.encode(Image.open(assetPath)).tolist()
 
-@server.route("/sentence-transformer/encode-text", methods=['POST'])
-def clip_encode_text():
+
+@app.post("/sentence-transformer/encode-text", status_code=200)
+def clip_encode_text(payload: ClipRequestBody):
     model = _get_model(clip_text_model)
-    text = request.json['text']
-    return model.encode(text).tolist(), 200
+    text = payload.text
+    return model.encode(text).tolist()
 
+
 def run_engine(engine, path):
     result = []
@@ -69,5 +88,17 @@ def run_engine(engine, path):
     return result
 
 
+def _get_model(model, task=None):
+    global _model_cache
+    key = '|'.join([model, str(task)])
+    if key not in _model_cache:
+        if task:
+            _model_cache[key] = pipeline(model=model, task=task)
+        else:
+            _model_cache[key] = SentenceTransformer(model)
+    return _model_cache[key]
+
+
 if __name__ == "__main__":
-    server.run(debug=is_dev, host=server_host, port=server_port)
+    uvicorn.run("main:app", host=server_host,
+                port=int(server_port), reload=is_dev, workers=1)
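With the endpoints declared through FastAPI and pydantic models, request bodies are validated before the handlers run (a missing or misnamed field returns a 422 instead of raising a KeyError the way the old request.json lookups did), and FastAPI serves interactive documentation at /docs by default. A minimal sketch of calling the new endpoints over HTTP, assuming the service is reachable on the default port 3003 and that an HTTP client such as requests is available; the thumbnail path below is only a placeholder:

# Sketch only: exercising the new FastAPI endpoints from a client.
import requests

BASE = "http://localhost:3003"  # default MACHINE_LEARNING_PORT in this diff

# Body shape matches MlRequestBody (thumbnailPath) defined in main.py.
tags = requests.post(
    f"{BASE}/image-classifier/tag-image",
    json={"thumbnailPath": "/path/to/thumbnail.webp"},  # placeholder path
).json()

# Body shape matches ClipRequestBody (text).
embedding = requests.post(
    f"{BASE}/sentence-transformer/encode-text",
    json={"text": "a photo of a dog"},
).json()

print(tags)
print(len(embedding))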