feat(ml)!: customizable ML settings (#3891)

* consolidated endpoints, added live configuration

* added ml settings to server

* added settings dashboard

* updated deps, fixed typos

* simplified modelconfig

updated tests

* Added ml setting accordion for admin page

updated tests

* merge `clipText` and `clipVision`

* added face distance setting

clarified setting

* add clip mode in request, dropdown for face models

* polished ml settings

updated descriptions

* update clip field on error

* removed unused import

* add description for image classification threshold

* pin safetensors for arm wheel

updated poetry lock

* moved dto

* set model type only in ml repository

* revert form-data package install

use fetch instead of axios

* added slotted description with link

updated facial recognition description

clarified effect of disabling tasks

* validation before model load

* removed unnecessary getconfig call

* added migration

* updated api

updated api

updated api

---------

Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
Author: Mert
Committed by: GitHub
Date: 2023-08-29 09:58:00 -04:00
Parent: 22f5e05060
Commit: bcc36d14a1

56 changed files with 2324 additions and 655 deletions


@@ -1,4 +1,4 @@
-import { QueueName } from '@app/domain/job/job.constants';
+import { QueueName } from '@app/domain';
 import { Column, Entity, PrimaryColumn } from 'typeorm';
 @Entity('system_config')
@@ -39,9 +39,18 @@ export enum SystemConfigKey {
   MACHINE_LEARNING_ENABLED = 'machineLearning.enabled',
   MACHINE_LEARNING_URL = 'machineLearning.url',
-  MACHINE_LEARNING_FACIAL_RECOGNITION_ENABLED = 'machineLearning.facialRecognitionEnabled',
-  MACHINE_LEARNING_TAG_IMAGE_ENABLED = 'machineLearning.tagImageEnabled',
-  MACHINE_LEARNING_CLIP_ENCODE_ENABLED = 'machineLearning.clipEncodeEnabled',
+  MACHINE_LEARNING_CLASSIFICATION_ENABLED = 'machineLearning.classification.enabled',
+  MACHINE_LEARNING_CLASSIFICATION_MODEL_NAME = 'machineLearning.classification.modelName',
+  MACHINE_LEARNING_CLASSIFICATION_MIN_SCORE = 'machineLearning.classification.minScore',
+  MACHINE_LEARNING_CLIP_ENABLED = 'machineLearning.clip.enabled',
+  MACHINE_LEARNING_CLIP_MODEL_NAME = 'machineLearning.clip.modelName',
+  MACHINE_LEARNING_FACIAL_RECOGNITION_ENABLED = 'machineLearning.facialRecognition.enabled',
+  MACHINE_LEARNING_FACIAL_RECOGNITION_MODEL_NAME = 'machineLearning.facialRecognition.modelName',
+  MACHINE_LEARNING_FACIAL_RECOGNITION_MIN_SCORE = 'machineLearning.facialRecognition.minScore',
+  MACHINE_LEARNING_FACIAL_RECOGNITION_MAX_DISTANCE = 'machineLearning.facialRecognition.maxDistance',
   OAUTH_ENABLED = 'oauth.enabled',
   OAUTH_ISSUER_URL = 'oauth.issuerUrl',
@@ -114,9 +123,21 @@ export interface SystemConfig {
   machineLearning: {
     enabled: boolean;
     url: string;
-    clipEncodeEnabled: boolean;
-    facialRecognitionEnabled: boolean;
-    tagImageEnabled: boolean;
+    classification: {
+      enabled: boolean;
+      modelName: string;
+      minScore: number;
+    };
+    clip: {
+      enabled: boolean;
+      modelName: string;
+    };
+    facialRecognition: {
+      enabled: boolean;
+      modelName: string;
+      minScore: number;
+      maxDistance: number;
+    };
   };
   oauth: {
     enabled: boolean;
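For reference, a value matching the reshaped machineLearning section could look like the sketch below. The import path, service URL, model names, and thresholds are illustrative assumptions, not necessarily the project's defaults.

// Hedged sketch of the new nested machineLearning settings (values are illustrative).
import { SystemConfig } from '@app/infra/entities'; // assumed import path

const machineLearning: SystemConfig['machineLearning'] = {
  enabled: true,
  url: 'http://immich-machine-learning:3003', // assumed ML service URL
  classification: { enabled: true, modelName: 'microsoft/resnet-50', minScore: 0.9 },
  clip: { enabled: true, modelName: 'ViT-B-32::openai' },
  facialRecognition: { enabled: true, modelName: 'buffalo_l', minScore: 0.7, maxDistance: 0.6 },
};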


@@ -0,0 +1,25 @@
+import { MigrationInterface, QueryRunner } from "typeorm"
+
+export class RenameMLEnableFlags1693236627291 implements MigrationInterface {
+  public async up(queryRunner: QueryRunner): Promise<void> {
+    await queryRunner.query(`
+      UPDATE system_config SET key = CASE
+        WHEN key = 'ffmpeg.classificationEnabled' THEN 'ffmpeg.classification.enabled'
+        WHEN key = 'ffmpeg.clipEnabled' THEN 'ffmpeg.clip.enabled'
+        WHEN key = 'ffmpeg.facialRecognitionEnabled' THEN 'ffmpeg.facialRecognition.enabled'
+        ELSE key
+      END
+    `);
+  }
+
+  public async down(queryRunner: QueryRunner): Promise<void> {
+    await queryRunner.query(`
+      UPDATE system_config SET key = CASE
+        WHEN key = 'ffmpeg.classification.enabled' THEN 'ffmpeg.classificationEnabled'
+        WHEN key = 'ffmpeg.clip.enabled' THEN 'ffmpeg.clipEnabled'
+        WHEN key = 'ffmpeg.facialRecognition.enabled' THEN 'ffmpeg.facialRecognitionEnabled'
+        ELSE key
+      END
+    `);
+  }
+}
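For context, a migration class like this is picked up and executed by TypeORM's migration runner; a minimal sketch of that wiring follows. The data source options and the import path are assumptions, not the project's actual configuration.

import { DataSource } from 'typeorm';
import { RenameMLEnableFlags1693236627291 } from './1693236627291-RenameMLEnableFlags'; // assumed file path

// Hedged sketch: register the migration on a DataSource and run anything pending.
const dataSource = new DataSource({
  type: 'postgres',
  url: process.env.DB_URL, // assumed connection string variable
  migrations: [RenameMLEnableFlags1693236627291],
});

export async function runPendingMigrations(): Promise<void> {
  await dataSource.initialize();
  await dataSource.runMigrations();
  await dataSource.destroy();
}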


@@ -1,29 +1,65 @@
-import { DetectFaceResult, IMachineLearningRepository, MachineLearningInput } from '@app/domain';
+import {
+  ClassificationConfig,
+  CLIPConfig,
+  CLIPMode,
+  DetectFaceResult,
+  IMachineLearningRepository,
+  ModelConfig,
+  ModelType,
+  RecognitionConfig,
+  TextModelInput,
+  VisionModelInput,
+} from '@app/domain';
 import { Injectable } from '@nestjs/common';
-import axios from 'axios';
-import { createReadStream } from 'fs';
-const client = axios.create();
+import { readFile } from 'fs/promises';
 @Injectable()
 export class MachineLearningRepository implements IMachineLearningRepository {
-  private post<T>(input: MachineLearningInput, endpoint: string): Promise<T> {
-    return client.post<T>(endpoint, createReadStream(input.imagePath)).then((res) => res.data);
+  private async post<T>(url: string, input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<T> {
+    const formData = await this.getFormData(input, config);
+    const res = await fetch(`${url}/predict`, { method: 'POST', body: formData });
+    return res.json();
   }
-  classifyImage(url: string, input: MachineLearningInput): Promise<string[]> {
-    return this.post<string[]>(input, `${url}/image-classifier/tag-image`);
+  classifyImage(url: string, input: VisionModelInput, config: ClassificationConfig): Promise<string[]> {
+    return this.post<string[]>(url, input, { ...config, modelType: ModelType.IMAGE_CLASSIFICATION });
   }
-  detectFaces(url: string, input: MachineLearningInput): Promise<DetectFaceResult[]> {
-    return this.post<DetectFaceResult[]>(input, `${url}/facial-recognition/detect-faces`);
+  detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]> {
+    return this.post<DetectFaceResult[]>(url, input, { ...config, modelType: ModelType.FACIAL_RECOGNITION });
   }
-  encodeImage(url: string, input: MachineLearningInput): Promise<number[]> {
-    return this.post<number[]>(input, `${url}/sentence-transformer/encode-image`);
+  encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]> {
+    return this.post<number[]>(url, input, {
+      ...config,
+      modelType: ModelType.CLIP,
+      mode: CLIPMode.VISION,
+    } as CLIPConfig);
   }
-  encodeText(url: string, input: string): Promise<number[]> {
-    return client.post<number[]>(`${url}/sentence-transformer/encode-text`, { text: input }).then((res) => res.data);
+  encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]> {
+    return this.post<number[]>(url, input, { ...config, modelType: ModelType.CLIP, mode: CLIPMode.TEXT } as CLIPConfig);
   }
+  async getFormData(input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<FormData> {
+    const formData = new FormData();
+    const { modelName, modelType, ...options } = config;
+    formData.append('modelName', modelName);
+    if (modelType) {
+      formData.append('modelType', modelType);
+    }
+    if (options) {
+      formData.append('options', JSON.stringify(options));
+    }
+    if ('imagePath' in input) {
+      formData.append('image', new Blob([await readFile(input.imagePath)]));
+    } else if ('text' in input) {
+      formData.append('text', input.text);
+    } else {
+      throw new Error('Invalid input');
+    }
+    return formData;
+  }
 }
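The repository now talks to a single /predict endpoint and sends the model configuration alongside each request as multipart form data, using the runtime's built-in fetch, FormData, and Blob (Node 18+) instead of axios. A hedged usage sketch follows; the URL, image path, and config field values are illustrative assumptions, and the exact CLIPConfig/RecognitionConfig fields come from the domain DTOs rather than this diff.

import { MachineLearningRepository } from './machine-learning.repository'; // assumed path

// Hedged sketch: callers now pass the per-model settings from system config on every call.
async function example(): Promise<void> {
  const repo = new MachineLearningRepository();
  const url = 'http://immich-machine-learning:3003'; // assumed machineLearning.url value

  // CLIP image embedding; the repository fills in modelType and mode itself.
  const embedding = await repo.encodeImage(url, { imagePath: '/data/thumbs/asset-1.jpeg' }, {
    enabled: true,
    modelName: 'ViT-B-32::openai',
  });

  // Face detection with the thresholds from the facialRecognition settings.
  const faces = await repo.detectFaces(url, { imagePath: '/data/thumbs/asset-1.jpeg' }, {
    enabled: true,
    modelName: 'buffalo_l',
    minScore: 0.7,
    maxDistance: 0.6,
  });

  console.log(embedding.length, faces.length);
}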


@@ -52,6 +52,8 @@ export class TypesenseRepository implements ISearchRepository {
   private logger = new Logger(TypesenseRepository.name);
   private _client: Client | null = null;
+  private _updateCLIPLock = false;
   private get client(): Client {
     if (!this._client) {
       throw new Error('Typesense client not available (no apiKey was provided)');
@@ -141,7 +143,7 @@ export class TypesenseRepository implements ISearchRepository {
         await this.updateAlias(collection);
       }
     } catch (error: any) {
-      this.handleError(error);
+      await this.handleError(error);
     }
   }
@@ -221,6 +223,30 @@ export class TypesenseRepository implements ISearchRepository {
     return records.num_deleted;
   }
+  async deleteAllAssets(): Promise<number> {
+    const records = await this.client.collections(assetSchema.name).documents().delete({ filter_by: 'ownerId:!=null' });
+    return records.num_deleted;
+  }
+  async updateCLIPField(num_dim: number): Promise<void> {
+    const clipField = assetSchema.fields?.find((field) => field.name === 'smartInfo.clipEmbedding');
+    if (clipField && !this._updateCLIPLock) {
+      try {
+        this._updateCLIPLock = true;
+        clipField.num_dim = num_dim;
+        await this.deleteAllAssets();
+        await this.client
+          .collections(assetSchema.name)
+          .update({ fields: [{ name: 'smartInfo.clipEmbedding', drop: true } as any, clipField] });
+        this.logger.log(`Successfully updated CLIP dimensions to ${num_dim}`);
+      } catch (err: any) {
+        this.logger.error(`Error while updating CLIP field: ${err.message}`);
+      } finally {
+        this._updateCLIPLock = false;
+      }
+    }
+  }
   async delete(collection: SearchCollection, ids: string[]): Promise<void> {
     await this.client
       .collections(schemaMap[collection].name)
@@ -326,21 +352,34 @@ export class TypesenseRepository implements ISearchRepository {
     } as SearchResult<T>;
   }
-  private handleError(error: any) {
+  private async handleError(error: any) {
     this.logger.error('Unable to index documents');
     const results = error.importResults || [];
+    let dimsChanged = false;
     for (const result of results) {
       try {
         result.document = JSON.parse(result.document);
+        if (result.error.includes('Field `smartInfo.clipEmbedding` must have')) {
+          dimsChanged = true;
+          this.logger.warn(
+            `CLIP embedding dimensions have changed, now ${result.document.smartInfo.clipEmbedding.length} dims. Updating schema...`,
+          );
+          await this.updateCLIPField(result.document.smartInfo.clipEmbedding.length);
+          break;
+        }
         if (result.document?.smartInfo?.clipEmbedding) {
           result.document.smartInfo.clipEmbedding = '<truncated>';
         }
-      } catch {}
+      } catch (err: any) {
+        this.logger.error(`Error while updating CLIP field: ${(err.message, err.stack)}`);
+      }
     }
-    this.logger.verbose(JSON.stringify(results, null, 2));
+    if (!dimsChanged) {
+      this.logger.log(JSON.stringify(results, null, 2));
+    }
   }
   private async updateAlias(collection: SearchCollection) {
     const schema = schemaMap[collection];
     const alias = await this.client
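In short, when Typesense rejects documents because the configured CLIP model emits embeddings of a different length than the indexed schema expects, handleError now detects that error, rebuilds the smartInfo.clipEmbedding field with the new dimensionality via updateCLIPField, and skips the usual per-result log dump in that case. A minimal sketch of what that entails; the 512-dimension figure is illustrative (what a ViT-B-32 CLIP model produces), and the import path is assumed.

import { TypesenseRepository } from './typesense.repository'; // assumed path

// Hedged sketch: what a CLIP dimension change implies for the search index.
async function onClipModelChanged(repo: TypesenseRepository): Promise<void> {
  // updateCLIPField() deletes every indexed asset document, then drops and re-creates
  // the smartInfo.clipEmbedding field with the new num_dim, guarded by _updateCLIPLock.
  await repo.updateCLIPField(512); // 512 = ViT-B-32 embedding size (illustrative)
  // The asset documents are gone from the index at this point, so assets must be
  // re-encoded and re-indexed afterwards.
}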