mirror of
https://github.com/KevinMidboe/immich.git
synced 2025-10-29 17:40:28 +00:00
feat(ml): export clip models to ONNX and host models on Hugging Face (#4700)
* export clip models * export to hf refactored export code * export mclip, general refactoring cleanup * updated conda deps * do transforms with pillow and numpy, add tokenization config to export, general refactoring * moved conda dockerfile, re-added poetry * minor fixes * updated link * updated tests * removed `requirements.txt` from workflow * fixed mimalloc path * removed torchvision * cleaner np typing * review suggestions * update default model name * update test
This commit is contained in:
21
machine-learning/export/Dockerfile
Normal file
21
machine-learning/export/Dockerfile
Normal file
@@ -0,0 +1,21 @@
|
||||
FROM mambaorg/micromamba:bookworm-slim as builder
|
||||
|
||||
ENV NODE_ENV=production \
|
||||
TRANSFORMERS_CACHE=/cache \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PATH="/opt/venv/bin:$PATH" \
|
||||
PYTHONPATH=/usr/src
|
||||
|
||||
COPY --chown=$MAMBA_USER:$MAMBA_USER conda-lock.yml /tmp/conda-lock.yml
|
||||
RUN micromamba install -y -n base -f /tmp/conda-lock.yml && \
|
||||
micromamba remove -y -n base cxx-compiler && \
|
||||
micromamba clean --all --yes
|
||||
|
||||
WORKDIR /usr/src/app
|
||||
|
||||
COPY --chown=$MAMBA_USER:$MAMBA_USER start.sh .
|
||||
COPY --chown=$MAMBA_USER:$MAMBA_USER app .
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/_entrypoint.sh"]
|
||||
CMD ["./start.sh"]
|
||||
3520
machine-learning/export/conda-lock.yml
Normal file
3520
machine-learning/export/conda-lock.yml
Normal file
File diff suppressed because it is too large
Load Diff
15
machine-learning/export/env.dev.yaml
Normal file
15
machine-learning/export/env.dev.yaml
Normal file
@@ -0,0 +1,15 @@
|
||||
name: base
|
||||
channels:
|
||||
- conda-forge
|
||||
platforms:
|
||||
- linux-64
|
||||
- linux-aarch64
|
||||
dependencies:
|
||||
- black
|
||||
- conda-lock
|
||||
- mypy
|
||||
- pytest
|
||||
- pytest-cov
|
||||
- pytest-mock
|
||||
- ruff
|
||||
category: dev
|
||||
25
machine-learning/export/env.yaml
Normal file
25
machine-learning/export/env.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
name: base
|
||||
channels:
|
||||
- conda-forge
|
||||
- nvidia
|
||||
- pytorch-nightly
|
||||
platforms:
|
||||
- linux-64
|
||||
dependencies:
|
||||
- cxx-compiler
|
||||
- onnx==1.*
|
||||
- onnxruntime==1.*
|
||||
- open-clip-torch==2.*
|
||||
- orjson==3.*
|
||||
- pip
|
||||
- python==3.11.*
|
||||
- pytorch
|
||||
- rich==13.*
|
||||
- safetensors==0.*
|
||||
- setuptools==68.*
|
||||
- torchvision
|
||||
- transformers==4.*
|
||||
- pip:
|
||||
- multilingual-clip
|
||||
- onnx-simplifier
|
||||
category: main
|
||||
0
machine-learning/export/models/__init__.py
Normal file
0
machine-learning/export/models/__init__.py
Normal file
67
machine-learning/export/models/mclip.py
Normal file
67
machine-learning/export/models/mclip.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import tempfile
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from multilingual_clip.pt_multilingual_clip import MultilingualCLIP
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from .openclip import OpenCLIPModelConfig
|
||||
from .openclip import to_onnx as openclip_to_onnx
|
||||
from .optimize import optimize
|
||||
from .util import get_model_path
|
||||
|
||||
_MCLIP_TO_OPENCLIP = {
|
||||
"M-CLIP/XLM-Roberta-Large-Vit-B-32": OpenCLIPModelConfig("ViT-B-32", "openai"),
|
||||
"M-CLIP/XLM-Roberta-Large-Vit-B-16Plus": OpenCLIPModelConfig("ViT-B-16-plus-240", "laion400m_e32"),
|
||||
"M-CLIP/LABSE-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
|
||||
"M-CLIP/XLM-Roberta-Large-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
|
||||
}
|
||||
|
||||
|
||||
def to_onnx(
|
||||
model_name: str,
|
||||
output_dir_visual: Path | str,
|
||||
output_dir_textual: Path | str,
|
||||
) -> None:
|
||||
textual_path = get_model_path(output_dir_textual)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model = MultilingualCLIP.from_pretrained(model_name, cache_dir=tmpdir)
|
||||
AutoTokenizer.from_pretrained(model_name).save_pretrained(output_dir_textual)
|
||||
|
||||
for param in model.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
export_text_encoder(model, textual_path)
|
||||
openclip_to_onnx(_MCLIP_TO_OPENCLIP[model_name], output_dir_visual)
|
||||
optimize(textual_path)
|
||||
|
||||
|
||||
def export_text_encoder(model: MultilingualCLIP, output_path: Path | str) -> None:
|
||||
output_path = Path(output_path)
|
||||
|
||||
def forward(self: MultilingualCLIP, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
embs = self.transformer(input_ids, attention_mask)[0]
|
||||
embs = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
|
||||
embs = self.LinearTransformation(embs)
|
||||
return torch.nn.functional.normalize(embs, dim=-1)
|
||||
|
||||
# unfortunately need to monkeypatch for tracing to work here
|
||||
# otherwise it hits the 2GiB protobuf serialization limit
|
||||
MultilingualCLIP.forward = forward
|
||||
|
||||
args = (torch.ones(1, 77, dtype=torch.int32), torch.ones(1, 77, dtype=torch.int32))
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
torch.onnx.export(
|
||||
model,
|
||||
args,
|
||||
output_path.as_posix(),
|
||||
input_names=["input_ids", "attention_mask"],
|
||||
output_names=["text_embedding"],
|
||||
opset_version=17,
|
||||
dynamic_axes={
|
||||
"input_ids": {0: "batch_size", 1: "sequence_length"},
|
||||
"attention_mask": {0: "batch_size", 1: "sequence_length"},
|
||||
},
|
||||
)
|
||||
109
machine-learning/export/models/openclip.py
Normal file
109
machine-learning/export/models/openclip.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import tempfile
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import open_clip
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from .optimize import optimize
|
||||
from .util import get_model_path, save_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenCLIPModelConfig:
|
||||
name: str
|
||||
pretrained: str
|
||||
image_size: int = field(init=False)
|
||||
sequence_length: int = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
open_clip_cfg = open_clip.get_model_config(self.name)
|
||||
if open_clip_cfg is None:
|
||||
raise ValueError(f"Unknown model {self.name}")
|
||||
self.image_size = open_clip_cfg["vision_cfg"]["image_size"]
|
||||
self.sequence_length = open_clip_cfg["text_cfg"]["context_length"]
|
||||
|
||||
|
||||
def to_onnx(
|
||||
model_cfg: OpenCLIPModelConfig,
|
||||
output_dir_visual: Path | str | None = None,
|
||||
output_dir_textual: Path | str | None = None,
|
||||
) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model = open_clip.create_model(
|
||||
model_cfg.name,
|
||||
pretrained=model_cfg.pretrained,
|
||||
jit=False,
|
||||
cache_dir=tmpdir,
|
||||
require_pretrained=True,
|
||||
)
|
||||
|
||||
text_vision_cfg = open_clip.get_model_config(model_cfg.name)
|
||||
|
||||
for param in model.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
if output_dir_visual is not None:
|
||||
output_dir_visual = Path(output_dir_visual)
|
||||
visual_path = get_model_path(output_dir_visual)
|
||||
|
||||
save_config(open_clip.get_model_preprocess_cfg(model), output_dir_visual / "preprocess_cfg.json")
|
||||
save_config(text_vision_cfg, output_dir_visual.parent / "config.json")
|
||||
export_image_encoder(model, model_cfg, visual_path)
|
||||
|
||||
optimize(visual_path)
|
||||
|
||||
if output_dir_textual is not None:
|
||||
output_dir_textual = Path(output_dir_textual)
|
||||
textual_path = get_model_path(output_dir_textual)
|
||||
|
||||
tokenizer_name = text_vision_cfg["text_cfg"].get("hf_tokenizer_name", "openai/clip-vit-base-patch32")
|
||||
AutoTokenizer.from_pretrained(tokenizer_name).save_pretrained(output_dir_textual)
|
||||
export_text_encoder(model, model_cfg, textual_path)
|
||||
optimize(textual_path)
|
||||
|
||||
|
||||
def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, output_path: Path | str) -> None:
|
||||
output_path = Path(output_path)
|
||||
|
||||
def encode_image(image: torch.Tensor) -> torch.Tensor:
|
||||
return model.encode_image(image, normalize=True)
|
||||
|
||||
args = (torch.randn(1, 3, model_cfg.image_size, model_cfg.image_size),)
|
||||
traced = torch.jit.trace(encode_image, args)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
torch.onnx.export(
|
||||
traced,
|
||||
args,
|
||||
output_path.as_posix(),
|
||||
input_names=["image"],
|
||||
output_names=["image_embedding"],
|
||||
opset_version=17,
|
||||
dynamic_axes={"image": {0: "batch_size"}},
|
||||
)
|
||||
|
||||
|
||||
def export_text_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, output_path: Path | str) -> None:
|
||||
output_path = Path(output_path)
|
||||
|
||||
def encode_text(text: torch.Tensor) -> torch.Tensor:
|
||||
return model.encode_text(text, normalize=True)
|
||||
|
||||
args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),)
|
||||
traced = torch.jit.trace(encode_text, args)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
torch.onnx.export(
|
||||
traced,
|
||||
args,
|
||||
output_path.as_posix(),
|
||||
input_names=["text"],
|
||||
output_names=["text_embedding"],
|
||||
opset_version=17,
|
||||
dynamic_axes={"text": {0: "batch_size"}},
|
||||
)
|
||||
38
machine-learning/export/models/optimize.py
Normal file
38
machine-learning/export/models/optimize.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from pathlib import Path
|
||||
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
import onnxsim
|
||||
|
||||
|
||||
def optimize_onnxsim(model_path: Path | str, output_path: Path | str) -> None:
|
||||
model_path = Path(model_path)
|
||||
output_path = Path(output_path)
|
||||
model = onnx.load(model_path.as_posix())
|
||||
model, check = onnxsim.simplify(model, skip_shape_inference=True)
|
||||
assert check, "Simplified ONNX model could not be validated"
|
||||
onnx.save(model, output_path.as_posix())
|
||||
|
||||
|
||||
def optimize_ort(
|
||||
model_path: Path | str,
|
||||
output_path: Path | str,
|
||||
level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
|
||||
) -> None:
|
||||
model_path = Path(model_path)
|
||||
output_path = Path(output_path)
|
||||
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.graph_optimization_level = level
|
||||
sess_options.optimized_model_filepath = output_path.as_posix()
|
||||
|
||||
ort.InferenceSession(model_path.as_posix(), providers=["CPUExecutionProvider"], sess_options=sess_options)
|
||||
|
||||
|
||||
def optimize(model_path: Path | str) -> None:
|
||||
model_path = Path(model_path)
|
||||
|
||||
optimize_ort(model_path, model_path)
|
||||
# onnxsim serializes large models as a blob, which uses much more memory when loading the model at runtime
|
||||
if not any(file.name.startswith("Constant") for file in model_path.parent.iterdir()):
|
||||
optimize_onnxsim(model_path, model_path)
|
||||
15
machine-learning/export/models/util.py
Normal file
15
machine-learning/export/models/util.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def get_model_path(output_dir: Path | str) -> Path:
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
return output_dir / "model.onnx"
|
||||
|
||||
|
||||
def save_config(config: Any, output_path: Path | str) -> None:
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
json.dump(config, output_path.open("w"))
|
||||
76
machine-learning/export/run.py
Normal file
76
machine-learning/export/run.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import gc
|
||||
import os
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from huggingface_hub import create_repo, login, upload_folder
|
||||
from models import mclip, openclip
|
||||
from rich.progress import Progress
|
||||
|
||||
models = [
|
||||
"RN50::openai",
|
||||
"RN50::yfcc15m",
|
||||
"RN50::cc12m",
|
||||
"RN101::openai",
|
||||
"RN101::yfcc15m",
|
||||
"RN50x4::openai",
|
||||
"RN50x16::openai",
|
||||
"RN50x64::openai",
|
||||
"ViT-B-32::openai",
|
||||
"ViT-B-32::laion2b_e16",
|
||||
"ViT-B-32::laion400m_e31",
|
||||
"ViT-B-32::laion400m_e32",
|
||||
"ViT-B-32::laion2b-s34b-b79k",
|
||||
"ViT-B-16::openai",
|
||||
"ViT-B-16::laion400m_e31",
|
||||
"ViT-B-16::laion400m_e32",
|
||||
"ViT-B-16-plus-240::laion400m_e31",
|
||||
"ViT-B-16-plus-240::laion400m_e32",
|
||||
"ViT-L-14::openai",
|
||||
"ViT-L-14::laion400m_e31",
|
||||
"ViT-L-14::laion400m_e32",
|
||||
"ViT-L-14::laion2b-s32b-b82k",
|
||||
"ViT-L-14-336::openai",
|
||||
"ViT-H-14::laion2b-s32b-b79k",
|
||||
"ViT-g-14::laion2b-s12b-b42k",
|
||||
"M-CLIP/LABSE-Vit-L-14",
|
||||
"M-CLIP/XLM-Roberta-Large-Vit-B-32",
|
||||
"M-CLIP/XLM-Roberta-Large-Vit-B-16Plus",
|
||||
"M-CLIP/XLM-Roberta-Large-Vit-L-14",
|
||||
]
|
||||
|
||||
login(token=os.environ["HF_AUTH_TOKEN"])
|
||||
|
||||
with Progress() as progress:
|
||||
task1 = progress.add_task("[green]Exporting models...", total=len(models))
|
||||
task2 = progress.add_task("[yellow]Uploading models...", total=len(models))
|
||||
|
||||
with TemporaryDirectory() as tmp:
|
||||
tmpdir = Path(tmp)
|
||||
for model in models:
|
||||
model_name = model.split("/")[-1].replace("::", "__")
|
||||
config_path = tmpdir / model_name / "config.json"
|
||||
|
||||
def upload() -> None:
|
||||
progress.update(task2, description=f"[yellow]Uploading {model_name}")
|
||||
repo_id = f"immich-app/{model_name}"
|
||||
|
||||
create_repo(repo_id, exist_ok=True)
|
||||
upload_folder(repo_id=repo_id, folder_path=tmpdir / model_name)
|
||||
progress.update(task2, advance=1)
|
||||
|
||||
def export() -> None:
|
||||
progress.update(task1, description=f"[green]Exporting {model_name}")
|
||||
visual_dir = tmpdir / model_name / "visual"
|
||||
textual_dir = tmpdir / model_name / "textual"
|
||||
if model.startswith("M-CLIP"):
|
||||
mclip.to_onnx(model, visual_dir, textual_dir)
|
||||
else:
|
||||
name, _, pretrained = model_name.partition("__")
|
||||
openclip.to_onnx(openclip.OpenCLIPModelConfig(name, pretrained), visual_dir, textual_dir)
|
||||
|
||||
progress.update(task1, advance=1)
|
||||
gc.collect()
|
||||
|
||||
export()
|
||||
upload()
|
||||
Reference in New Issue
Block a user