mirror of
https://github.com/KevinMidboe/immich.git
synced 2025-12-08 20:29:05 +00:00
feat(ml)!: customizable ML settings (#3891)
* consolidated endpoints, added live configuration * added ml settings to server * added settings dashboard * updated deps, fixed typos * simplified modelconfig updated tests * Added ml setting accordion for admin page updated tests * merge `clipText` and `clipVision` * added face distance setting clarified setting * add clip mode in request, dropdown for face models * polished ml settings updated descriptions * update clip field on error * removed unused import * add description for image classification threshold * pin safetensors for arm wheel updated poetry lock * moved dto * set model type only in ml repository * revert form-data package install use fetch instead of axios * added slotted description with link updated facial recognition description clarified effect of disabling tasks * validation before model load * removed unnecessary getconfig call * added migration * updated api updated api updated api --------- Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
This commit is contained in:
@@ -1,29 +1,65 @@
|
||||
import { DetectFaceResult, IMachineLearningRepository, MachineLearningInput } from '@app/domain';
|
||||
import {
|
||||
ClassificationConfig,
|
||||
CLIPConfig,
|
||||
CLIPMode,
|
||||
DetectFaceResult,
|
||||
IMachineLearningRepository,
|
||||
ModelConfig,
|
||||
ModelType,
|
||||
RecognitionConfig,
|
||||
TextModelInput,
|
||||
VisionModelInput,
|
||||
} from '@app/domain';
|
||||
import { Injectable } from '@nestjs/common';
|
||||
import axios from 'axios';
|
||||
import { createReadStream } from 'fs';
|
||||
|
||||
const client = axios.create();
|
||||
import { readFile } from 'fs/promises';
|
||||
|
||||
@Injectable()
|
||||
export class MachineLearningRepository implements IMachineLearningRepository {
|
||||
private post<T>(input: MachineLearningInput, endpoint: string): Promise<T> {
|
||||
return client.post<T>(endpoint, createReadStream(input.imagePath)).then((res) => res.data);
|
||||
private async post<T>(url: string, input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<T> {
|
||||
const formData = await this.getFormData(input, config);
|
||||
const res = await fetch(`${url}/predict`, { method: 'POST', body: formData });
|
||||
return res.json();
|
||||
}
|
||||
|
||||
classifyImage(url: string, input: MachineLearningInput): Promise<string[]> {
|
||||
return this.post<string[]>(input, `${url}/image-classifier/tag-image`);
|
||||
classifyImage(url: string, input: VisionModelInput, config: ClassificationConfig): Promise<string[]> {
|
||||
return this.post<string[]>(url, input, { ...config, modelType: ModelType.IMAGE_CLASSIFICATION });
|
||||
}
|
||||
|
||||
detectFaces(url: string, input: MachineLearningInput): Promise<DetectFaceResult[]> {
|
||||
return this.post<DetectFaceResult[]>(input, `${url}/facial-recognition/detect-faces`);
|
||||
detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]> {
|
||||
return this.post<DetectFaceResult[]>(url, input, { ...config, modelType: ModelType.FACIAL_RECOGNITION });
|
||||
}
|
||||
|
||||
encodeImage(url: string, input: MachineLearningInput): Promise<number[]> {
|
||||
return this.post<number[]>(input, `${url}/sentence-transformer/encode-image`);
|
||||
encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]> {
|
||||
return this.post<number[]>(url, input, {
|
||||
...config,
|
||||
modelType: ModelType.CLIP,
|
||||
mode: CLIPMode.VISION,
|
||||
} as CLIPConfig);
|
||||
}
|
||||
|
||||
encodeText(url: string, input: string): Promise<number[]> {
|
||||
return client.post<number[]>(`${url}/sentence-transformer/encode-text`, { text: input }).then((res) => res.data);
|
||||
encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]> {
|
||||
return this.post<number[]>(url, input, { ...config, modelType: ModelType.CLIP, mode: CLIPMode.TEXT } as CLIPConfig);
|
||||
}
|
||||
|
||||
async getFormData(input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<FormData> {
|
||||
const formData = new FormData();
|
||||
const { modelName, modelType, ...options } = config;
|
||||
|
||||
formData.append('modelName', modelName);
|
||||
if (modelType) {
|
||||
formData.append('modelType', modelType);
|
||||
}
|
||||
if (options) {
|
||||
formData.append('options', JSON.stringify(options));
|
||||
}
|
||||
if ('imagePath' in input) {
|
||||
formData.append('image', new Blob([await readFile(input.imagePath)]));
|
||||
} else if ('text' in input) {
|
||||
formData.append('text', input.text);
|
||||
} else {
|
||||
throw new Error('Invalid input');
|
||||
}
|
||||
|
||||
return formData;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,6 +52,8 @@ export class TypesenseRepository implements ISearchRepository {
|
||||
private logger = new Logger(TypesenseRepository.name);
|
||||
|
||||
private _client: Client | null = null;
|
||||
private _updateCLIPLock = false;
|
||||
|
||||
private get client(): Client {
|
||||
if (!this._client) {
|
||||
throw new Error('Typesense client not available (no apiKey was provided)');
|
||||
@@ -141,7 +143,7 @@ export class TypesenseRepository implements ISearchRepository {
|
||||
await this.updateAlias(collection);
|
||||
}
|
||||
} catch (error: any) {
|
||||
this.handleError(error);
|
||||
await this.handleError(error);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,6 +223,30 @@ export class TypesenseRepository implements ISearchRepository {
|
||||
return records.num_deleted;
|
||||
}
|
||||
|
||||
async deleteAllAssets(): Promise<number> {
|
||||
const records = await this.client.collections(assetSchema.name).documents().delete({ filter_by: 'ownerId:!=null' });
|
||||
return records.num_deleted;
|
||||
}
|
||||
|
||||
async updateCLIPField(num_dim: number): Promise<void> {
|
||||
const clipField = assetSchema.fields?.find((field) => field.name === 'smartInfo.clipEmbedding');
|
||||
if (clipField && !this._updateCLIPLock) {
|
||||
try {
|
||||
this._updateCLIPLock = true;
|
||||
clipField.num_dim = num_dim;
|
||||
await this.deleteAllAssets();
|
||||
await this.client
|
||||
.collections(assetSchema.name)
|
||||
.update({ fields: [{ name: 'smartInfo.clipEmbedding', drop: true } as any, clipField] });
|
||||
this.logger.log(`Successfully updated CLIP dimensions to ${num_dim}`);
|
||||
} catch (err: any) {
|
||||
this.logger.error(`Error while updating CLIP field: ${err.message}`);
|
||||
} finally {
|
||||
this._updateCLIPLock = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async delete(collection: SearchCollection, ids: string[]): Promise<void> {
|
||||
await this.client
|
||||
.collections(schemaMap[collection].name)
|
||||
@@ -326,21 +352,34 @@ export class TypesenseRepository implements ISearchRepository {
|
||||
} as SearchResult<T>;
|
||||
}
|
||||
|
||||
private handleError(error: any) {
|
||||
private async handleError(error: any) {
|
||||
this.logger.error('Unable to index documents');
|
||||
const results = error.importResults || [];
|
||||
let dimsChanged = false;
|
||||
for (const result of results) {
|
||||
try {
|
||||
result.document = JSON.parse(result.document);
|
||||
if (result.error.includes('Field `smartInfo.clipEmbedding` must have')) {
|
||||
dimsChanged = true;
|
||||
this.logger.warn(
|
||||
`CLIP embedding dimensions have changed, now ${result.document.smartInfo.clipEmbedding.length} dims. Updating schema...`,
|
||||
);
|
||||
await this.updateCLIPField(result.document.smartInfo.clipEmbedding.length);
|
||||
break;
|
||||
}
|
||||
|
||||
if (result.document?.smartInfo?.clipEmbedding) {
|
||||
result.document.smartInfo.clipEmbedding = '<truncated>';
|
||||
}
|
||||
} catch {}
|
||||
} catch (err: any) {
|
||||
this.logger.error(`Error while updating CLIP field: ${(err.message, err.stack)}`);
|
||||
}
|
||||
}
|
||||
|
||||
this.logger.verbose(JSON.stringify(results, null, 2));
|
||||
if (!dimsChanged) {
|
||||
this.logger.log(JSON.stringify(results, null, 2));
|
||||
}
|
||||
}
|
||||
|
||||
private async updateAlias(collection: SearchCollection) {
|
||||
const schema = schemaMap[collection];
|
||||
const alias = await this.client
|
||||
|
||||
Reference in New Issue
Block a user