mirror of
https://github.com/KevinMidboe/immich.git
synced 2025-10-29 17:40:28 +00:00
chore(ml): memory optimisations (#3934)
This commit is contained in:
@@ -2,6 +2,7 @@ import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import gunicorn
|
||||
import starlette
|
||||
from pydantic import BaseSettings
|
||||
from rich.console import Console
|
||||
@@ -56,12 +57,14 @@ LOG_LEVELS: dict[str, int] = {
|
||||
settings = Settings()
|
||||
log_settings = LogSettings()
|
||||
|
||||
console = Console(color_system="standard", no_color=log_settings.no_color)
|
||||
logging.basicConfig(
|
||||
format="%(message)s",
|
||||
handlers=[
|
||||
RichHandler(show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[starlette])
|
||||
],
|
||||
)
|
||||
log = logging.getLogger("uvicorn")
|
||||
|
||||
class CustomRichHandler(RichHandler):
|
||||
def __init__(self) -> None:
|
||||
console = Console(color_system="standard", no_color=log_settings.no_color)
|
||||
super().__init__(
|
||||
show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[gunicorn, starlette]
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger("gunicorn.access")
|
||||
log.setLevel(LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO))
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from starlette.formparsers import MultiPartParser
|
||||
@@ -33,7 +30,7 @@ def init_state() -> None:
|
||||
)
|
||||
)
|
||||
# asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
|
||||
app.state.thread_pool = ThreadPoolExecutor(settings.request_threads)
|
||||
app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
|
||||
log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
|
||||
|
||||
|
||||
@@ -73,17 +70,7 @@ async def predict(
|
||||
|
||||
|
||||
async def run(model: InferenceModel, inputs: Any) -> Any:
|
||||
return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
is_dev = os.getenv("NODE_ENV") == "development"
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
reload=is_dev,
|
||||
workers=settings.workers,
|
||||
log_config=None,
|
||||
access_log=log.isEnabledFor(logging.INFO),
|
||||
)
|
||||
if app.state.thread_pool is not None:
|
||||
return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
|
||||
else:
|
||||
return model.predict(inputs)
|
||||
|
||||
@@ -53,6 +53,7 @@ class InferenceModel(ABC):
|
||||
log.debug(f"Setting intra_op_num_threads to {intra_op_num_threads}")
|
||||
self.sess_options.inter_op_num_threads = inter_op_num_threads
|
||||
self.sess_options.intra_op_num_threads = intra_op_num_threads
|
||||
self.sess_options.enable_cpu_mem_arena = False
|
||||
|
||||
try:
|
||||
loader(**model_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user