chore(ml): memory optimisations (#3934)

This commit is contained in:
Mert
2023-08-31 19:30:53 -04:00
committed by GitHub
parent c0a48d7357
commit 41461e0d5d
8 changed files with 122 additions and 107 deletions

View File

@@ -2,6 +2,7 @@ import logging
import os
from pathlib import Path
import gunicorn
import starlette
from pydantic import BaseSettings
from rich.console import Console
@@ -56,12 +57,14 @@ LOG_LEVELS: dict[str, int] = {
settings = Settings()
log_settings = LogSettings()
console = Console(color_system="standard", no_color=log_settings.no_color)
logging.basicConfig(
format="%(message)s",
handlers=[
RichHandler(show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[starlette])
],
)
log = logging.getLogger("uvicorn")
class CustomRichHandler(RichHandler):
    """Log handler preconfigured for this service's console output.

    Builds its own Rich console (respecting the configured no-color
    setting) and suppresses gunicorn/starlette frames in tracebacks.
    """

    def __init__(self) -> None:
        # Dedicated console so handler output honors log color settings.
        rich_console = Console(color_system="standard", no_color=log_settings.no_color)
        super().__init__(
            show_path=False,
            omit_repeated_times=False,
            console=rich_console,
            tracebacks_suppress=[gunicorn, starlette],
        )
log = logging.getLogger("gunicorn.access")
log.setLevel(LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO))

View File

@@ -1,11 +1,8 @@
import asyncio
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any
import orjson
import uvicorn
from fastapi import FastAPI, Form, HTTPException, UploadFile
from fastapi.responses import ORJSONResponse
from starlette.formparsers import MultiPartParser
@@ -33,7 +30,7 @@ def init_state() -> None:
)
)
# asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
app.state.thread_pool = ThreadPoolExecutor(settings.request_threads)
app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
@@ -73,17 +70,7 @@ async def predict(
async def run(model: InferenceModel, inputs: Any) -> Any:
return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
if __name__ == "__main__":
is_dev = os.getenv("NODE_ENV") == "development"
uvicorn.run(
"app.main:app",
host=settings.host,
port=settings.port,
reload=is_dev,
workers=settings.workers,
log_config=None,
access_log=log.isEnabledFor(logging.INFO),
)
if app.state.thread_pool is not None:
return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
else:
return model.predict(inputs)

View File

@@ -53,6 +53,7 @@ class InferenceModel(ABC):
log.debug(f"Setting intra_op_num_threads to {intra_op_num_threads}")
self.sess_options.inter_op_num_threads = inter_op_num_threads
self.sess_options.intra_op_num_threads = intra_op_num_threads
self.sess_options.enable_cpu_mem_arena = False
try:
loader(**model_kwargs)