chore(ml): memory optimisations (#3934)

Author: Mert
Date: 2023-08-31 19:30:53 -04:00
Committed by: GitHub
Parent: c0a48d7357
Commit: 41461e0d5d
8 changed files with 122 additions and 107 deletions


@@ -1,11 +1,8 @@
 import asyncio
-import logging
-import os
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any
 
 import orjson
-import uvicorn
 from fastapi import FastAPI, Form, HTTPException, UploadFile
 from fastapi.responses import ORJSONResponse
 from starlette.formparsers import MultiPartParser
@@ -33,7 +30,7 @@ def init_state() -> None:
         )
     )
     # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
-    app.state.thread_pool = ThreadPoolExecutor(settings.request_threads)
+    app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
     log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
@@ -73,17 +70,7 @@ async def predict(
 async def run(model: InferenceModel, inputs: Any) -> Any:
-    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
-
-
-if __name__ == "__main__":
-    is_dev = os.getenv("NODE_ENV") == "development"
-    uvicorn.run(
-        "app.main:app",
-        host=settings.host,
-        port=settings.port,
-        reload=is_dev,
-        workers=settings.workers,
-        log_config=None,
-        access_log=log.isEnabledFor(logging.INFO),
-    )
+    if app.state.thread_pool is not None:
+        return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
+    else:
+        return model.predict(inputs)
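Putting both pieces together, here is a self-contained sketch of the dispatch pattern the new run body implements: hand the blocking predict call to the executor when one exists, otherwise call it inline on the event loop thread. slow_predict and the pool size are illustrative stand-ins, not code from the commit:

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any

def slow_predict(inputs: Any) -> Any:
    # Stand-in for a blocking model.predict call.
    time.sleep(0.1)
    return {"result": inputs}

# Set to None to emulate request_threads <= 0.
thread_pool: ThreadPoolExecutor | None = ThreadPoolExecutor(2)

async def run(inputs: Any) -> Any:
    if thread_pool is not None:
        # Offload the blocking call to a worker thread so the event
        # loop stays free to serve other requests.
        return await asyncio.get_running_loop().run_in_executor(thread_pool, slow_predict, inputs)
    # No pool configured: run inline, blocking the event loop briefly.
    return slow_predict(inputs)

print(asyncio.run(run("hello")))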