mirror of
				https://github.com/KevinMidboe/immich.git
				synced 2025-10-29 17:40:28 +00:00 
			
		
		
		
	feat(ml): env variables for tags, faces and eager startup (#2626)
* env variables for tags, faces and eager startup * chore(server,ml): remove object detection job and endpoint (#2627) * removed object detection job * removed object detection endpoint * env variables for tags, faces and eager startup * download without caching models if not eager * simplified `get_cached_model` * re-added env for clip text model
This commit is contained in:
		@@ -28,6 +28,12 @@ facial_recognition_model = os.getenv(
 | 
			
		||||
    "MACHINE_LEARNING_FACIAL_RECOGNITION_MODEL", "buffalo_l"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
min_face_score = float(os.getenv("MACHINE_LEARNING_MIN_FACE_SCORE", 0.7))
 | 
			
		||||
min_tag_score = float(os.getenv("MACHINE_LEARNING_MIN_TAG_SCORE", 0.9))
 | 
			
		||||
eager_startup = (
 | 
			
		||||
    os.getenv("MACHINE_LEARNING_EAGER_STARTUP", "true") == "true"
 | 
			
		||||
)  # loads all models at startup
 | 
			
		||||
 | 
			
		||||
cache_folder = os.getenv("MACHINE_LEARNING_CACHE_FOLDER", "/cache")
 | 
			
		||||
 | 
			
		||||
_model_cache = {}
 | 
			
		||||
@@ -37,11 +43,19 @@ app = FastAPI()
 | 
			
		||||
 | 
			
		||||
@app.on_event("startup")
 | 
			
		||||
async def startup_event():
 | 
			
		||||
    models = [
 | 
			
		||||
        (classification_model, "image-classification"),
 | 
			
		||||
        (clip_image_model, "clip"),
 | 
			
		||||
        (clip_text_model, "clip"),
 | 
			
		||||
        (facial_recognition_model, "facial-recognition"),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    # Get all models
 | 
			
		||||
    _get_model(classification_model, "image-classification")
 | 
			
		||||
    _get_model(clip_image_model)
 | 
			
		||||
    _get_model(clip_text_model)
 | 
			
		||||
    _get_model(facial_recognition_model, "facial-recognition")
 | 
			
		||||
    for model_name, model_type in models:
 | 
			
		||||
        if eager_startup:
 | 
			
		||||
            get_cached_model(model_name, model_type)
 | 
			
		||||
        else:
 | 
			
		||||
            _get_model(model_name, model_type)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.get("/")
 | 
			
		||||
@@ -53,30 +67,31 @@ async def root():
 | 
			
		||||
def ping():
 | 
			
		||||
    return "pong"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/image-classifier/tag-image", status_code=200)
 | 
			
		||||
def image_classification(payload: MlRequestBody):
 | 
			
		||||
    model = _get_model(classification_model, "image-classification")
 | 
			
		||||
    model = get_cached_model(classification_model, "image-classification")
 | 
			
		||||
    assetPath = payload.thumbnailPath
 | 
			
		||||
    return run_engine(model, assetPath)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/sentence-transformer/encode-image", status_code=200)
 | 
			
		||||
def clip_encode_image(payload: MlRequestBody):
 | 
			
		||||
    model = _get_model(clip_image_model)
 | 
			
		||||
    model = get_cached_model(clip_image_model, "clip")
 | 
			
		||||
    assetPath = payload.thumbnailPath
 | 
			
		||||
    return model.encode(Image.open(assetPath)).tolist()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/sentence-transformer/encode-text", status_code=200)
 | 
			
		||||
def clip_encode_text(payload: ClipRequestBody):
 | 
			
		||||
    model = _get_model(clip_text_model)
 | 
			
		||||
    model = get_cached_model(clip_text_model, "clip")
 | 
			
		||||
    text = payload.text
 | 
			
		||||
    return model.encode(text).tolist()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/facial-recognition/detect-faces", status_code=200)
 | 
			
		||||
def facial_recognition(payload: MlRequestBody):
 | 
			
		||||
    model = _get_model(facial_recognition_model, "facial-recognition")
 | 
			
		||||
    model = get_cached_model(facial_recognition_model, "facial-recognition")
 | 
			
		||||
    assetPath = payload.thumbnailPath
 | 
			
		||||
    img = cv.imread(assetPath)
 | 
			
		||||
    height, width, _ = img.shape
 | 
			
		||||
@@ -84,7 +99,7 @@ def facial_recognition(payload: MlRequestBody):
 | 
			
		||||
    faces = model.get(img)
 | 
			
		||||
 | 
			
		||||
    for face in faces:
 | 
			
		||||
        if face.det_score < 0.7:
 | 
			
		||||
        if face.det_score < min_face_score:
 | 
			
		||||
            continue
 | 
			
		||||
        x1, y1, x2, y2 = face.bbox
 | 
			
		||||
 | 
			
		||||
@@ -111,7 +126,7 @@ def run_engine(engine, path):
 | 
			
		||||
 | 
			
		||||
    for index, pred in enumerate(predictions):
 | 
			
		||||
        tags = pred["label"].split(", ")
 | 
			
		||||
        if pred["score"] > 0.9:
 | 
			
		||||
        if pred["score"] > min_tag_score:
 | 
			
		||||
            result = [*result, *tags]
 | 
			
		||||
 | 
			
		||||
    if len(result) > 1:
 | 
			
		||||
@@ -120,26 +135,32 @@ def run_engine(engine, path):
 | 
			
		||||
    return result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_model(model, task=None):
 | 
			
		||||
def get_cached_model(model, task):
 | 
			
		||||
    global _model_cache
 | 
			
		||||
    key = "|".join([model, str(task)])
 | 
			
		||||
    if key not in _model_cache:
 | 
			
		||||
        if task:
 | 
			
		||||
            if task == "facial-recognition":
 | 
			
		||||
                face_model = FaceAnalysis(
 | 
			
		||||
                    name=model,
 | 
			
		||||
                    root=cache_folder,
 | 
			
		||||
                    allowed_modules=["detection", "recognition"],
 | 
			
		||||
                )
 | 
			
		||||
                face_model.prepare(ctx_id=0, det_size=(640, 640))
 | 
			
		||||
                _model_cache[key] = face_model
 | 
			
		||||
            else:
 | 
			
		||||
                _model_cache[key] = pipeline(model=model, task=task)
 | 
			
		||||
        else:
 | 
			
		||||
            _model_cache[key] = SentenceTransformer(model, cache_folder=cache_folder)
 | 
			
		||||
        model = _get_model(model, task)
 | 
			
		||||
        _model_cache[key] = model
 | 
			
		||||
 | 
			
		||||
    return _model_cache[key]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_model(model, task):
 | 
			
		||||
    match task:
 | 
			
		||||
        case "facial-recognition":
 | 
			
		||||
            model = FaceAnalysis(
 | 
			
		||||
                name=model,
 | 
			
		||||
                root=cache_folder,
 | 
			
		||||
                allowed_modules=["detection", "recognition"],
 | 
			
		||||
            )
 | 
			
		||||
            model.prepare(ctx_id=0, det_size=(640, 640))
 | 
			
		||||
        case "clip":
 | 
			
		||||
            model = SentenceTransformer(model, cache_folder=cache_folder)
 | 
			
		||||
        case _:
 | 
			
		||||
            model = pipeline(model=model, task=task)
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    host = os.getenv("MACHINE_LEARNING_HOST", "0.0.0.0")
 | 
			
		||||
    port = int(os.getenv("MACHINE_LEARNING_PORT", 3003))
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user