feat(ml)!: customizable ML settings (#3891)

* consolidated endpoints, added live configuration

* added ml settings to server

* added settings dashboard

* updated deps, fixed typos

* simplified ModelConfig

updated tests

* added ml settings accordion for admin page

updated tests

* merge `clipText` and `clipVision`

* added face distance setting

clarified setting

* add clip mode in request, dropdown for face models

* polished ml settings

updated descriptions

* update clip field on error

* removed unused import

* add description for image classification threshold

* pin safetensors for arm wheel

updated poetry lock

* moved dto

* set model type only in ml repository

* revert form-data package install

use fetch instead of axios

* added slotted description with link

updated facial recognition description

clarified effect of disabling tasks

* validation before model load

* removed unnecessary getConfig call

* added migration

* updated api

---------

Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
Author: Mert
Committed by: GitHub
Date: 2023-08-29 09:58:00 -04:00
Commit: bcc36d14a1 (parent: 22f5e05060)
56 changed files with 2324 additions and 655 deletions
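
Taken together, the changes above replace the old `tagImageEnabled`/`clipEncodeEnabled` flags with per-task sub-configs that can be edited live from the admin page. A minimal sketch of the resulting `machineLearning` section of the system config, using the classification and CLIP values that appear in the updated tests below; the `facialRecognition` entry and its values are an assumed illustration of a `RecognitionConfig`, not defaults taken from this diff:

// Sketch of the machineLearning config shape implied by the DTOs and service
// guards in this commit; the facialRecognition values are placeholders.
const machineLearning = {
  enabled: true,
  url: 'http://immich-machine-learning:3003',
  classification: { enabled: true, modelName: 'microsoft/resnet-50', minScore: 0.9 },
  clip: { enabled: true, modelName: 'ViT-B-32::openai' },
  facialRecognition: { enabled: true, modelName: 'buffalo_l', minScore: 0.7, maxDistance: 0.6 },
};

Each sub-object validates against the corresponding DTO added below (`ClassificationConfig`, `CLIPConfig`, `RecognitionConfig`).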

View File: smart-info/dto/index.ts

@@ -0,0 +1 @@
+export * from './model-config.dto';

View File: smart-info/dto/model-config.dto.ts

@@ -0,0 +1,50 @@
+import { ApiProperty } from '@nestjs/swagger';
+import { Type } from 'class-transformer';
+import { IsBoolean, IsEnum, IsNotEmpty, IsNumber, IsOptional, IsString, Max, Min } from 'class-validator';
+import { CLIPMode, ModelType } from '../machine-learning.interface';
+
+export class ModelConfig {
+  @IsBoolean()
+  enabled!: boolean;
+
+  @IsString()
+  @IsNotEmpty()
+  modelName!: string;
+
+  @IsEnum(ModelType)
+  @IsOptional()
+  @ApiProperty({ enumName: 'ModelType', enum: ModelType })
+  modelType?: ModelType;
+}
+
+export class ClassificationConfig extends ModelConfig {
+  @IsNumber()
+  @Min(0)
+  @Max(1)
+  @Type(() => Number)
+  @ApiProperty({ type: 'integer' })
+  minScore!: number;
+}
+
+export class CLIPConfig extends ModelConfig {
+  @IsEnum(CLIPMode)
+  @IsOptional()
+  @ApiProperty({ enumName: 'CLIPMode', enum: CLIPMode })
+  mode?: CLIPMode;
+}
+
+export class RecognitionConfig extends ModelConfig {
+  @IsNumber()
+  @Min(0)
+  @Max(1)
+  @Type(() => Number)
+  @ApiProperty({ type: 'integer' })
+  minScore!: number;
+
+  @IsNumber()
+  @Min(0)
+  @Max(2)
+  @Type(() => Number)
+  @ApiProperty({ type: 'integer' })
+  maxDistance!: number;
+}
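
These DTOs are plain class-validator classes, so a sub-config arriving from the admin form or API can be coerced and range-checked before it is applied. A minimal sketch of that round trip, assuming `ClassificationConfig` is imported from the DTO file above (note that `minScore` and `maxDistance` are fractional values despite the `type: 'integer'` Swagger annotation):

import { plainToInstance } from 'class-transformer';
import { validateSync } from 'class-validator';
import { ClassificationConfig } from './model-config.dto';

// @Type(() => Number) coerces string input (e.g. from a form field) before the
// @Min/@Max checks run; out-of-range values surface as validation errors.
const config = plainToInstance(ClassificationConfig, {
  enabled: true,
  modelName: 'microsoft/resnet-50',
  minScore: '0.9', // coerced to the number 0.9
});
const errors = validateSync(config);
console.log(errors.length === 0 ? 'valid config' : errors);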

View File: smart-info/index.ts

@@ -1,3 +1,4 @@
+export * from './dto';
 export * from './machine-learning.interface';
 export * from './smart-info.repository';
 export * from './smart-info.service';

View File: smart-info/machine-learning.interface.ts

@@ -1,9 +1,15 @@
+import { ClassificationConfig, CLIPConfig, RecognitionConfig } from './dto';
+
 export const IMachineLearningRepository = 'IMachineLearningRepository';

-export interface MachineLearningInput {
+export interface VisionModelInput {
   imagePath: string;
 }

+export interface TextModelInput {
+  text: string;
+}
+
 export interface BoundingBox {
   x1: number;
   y1: number;
@@ -19,9 +25,20 @@ export interface DetectFaceResult {
   embedding: number[];
 }

-export interface IMachineLearningRepository {
-  classifyImage(url: string, input: MachineLearningInput): Promise<string[]>;
-  encodeImage(url: string, input: MachineLearningInput): Promise<number[]>;
-  encodeText(url: string, input: string): Promise<number[]>;
-  detectFaces(url: string, input: MachineLearningInput): Promise<DetectFaceResult[]>;
+export enum ModelType {
+  IMAGE_CLASSIFICATION = 'image-classification',
+  FACIAL_RECOGNITION = 'facial-recognition',
+  CLIP = 'clip',
+}
+
+export enum CLIPMode {
+  VISION = 'vision',
+  TEXT = 'text',
+}
+
+export interface IMachineLearningRepository {
+  classifyImage(url: string, input: VisionModelInput, config: ClassificationConfig): Promise<string[]>;
+  encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]>;
+  encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]>;
+  detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]>;
 }
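
Per the notes above ("use fetch instead of axios", "set model type only in ml repository", "add clip mode in request"), the repository implementation is now the one place that attaches `modelType` and the CLIP `mode` to the request sent to the ML server. A hedged sketch of what one such method might look like; the `/predict` endpoint, form-field names, and response shape are assumptions for illustration, not taken from this diff:

import { readFile } from 'node:fs/promises';
import { CLIPConfig } from './dto';
import { CLIPMode, ModelType, VisionModelInput } from './machine-learning.interface';

// Hypothetical sketch of encodeImage: the image and its config are POSTed to
// the ML server with the built-in fetch API (no axios or form-data package).
export async function encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]> {
  const formData = new FormData();
  formData.append('config', JSON.stringify({ ...config, modelType: ModelType.CLIP, mode: CLIPMode.VISION }));
  formData.append('image', new Blob([await readFile(input.imagePath)]));

  const res = await fetch(`${url}/predict`, { method: 'POST', body: formData });
  if (!res.ok) {
    throw new Error(`Machine learning request failed with status ${res.status}`);
  }
  return res.json();
}

The other repository methods presumably follow the same pattern, e.g. with `mode: CLIPMode.TEXT` and a text field for `encodeText`.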

View File: smart-info/smart-info.service.spec.ts

@@ -84,9 +84,13 @@ describe(SmartInfoService.name, () => {
       await sut.handleClassifyImage({ id: asset.id });
-      expect(machineMock.classifyImage).toHaveBeenCalledWith('http://immich-machine-learning:3003', {
-        imagePath: 'path/to/resize.ext',
-      });
+      expect(machineMock.classifyImage).toHaveBeenCalledWith(
+        'http://immich-machine-learning:3003',
+        {
+          imagePath: 'path/to/resize.ext',
+        },
+        { enabled: true, minScore: 0.9, modelName: 'microsoft/resnet-50' },
+      );
       expect(smartMock.upsert).toHaveBeenCalledWith({
         assetId: 'asset-1',
         tags: ['tag1', 'tag2', 'tag3'],
@@ -141,13 +145,16 @@ describe(SmartInfoService.name, () => {
     });
     it('should save the returned objects', async () => {
       smartMock.upsert.mockResolvedValue();
       machineMock.encodeImage.mockResolvedValue([0.01, 0.02, 0.03]);
       await sut.handleEncodeClip({ id: asset.id });
-      expect(machineMock.encodeImage).toHaveBeenCalledWith('http://immich-machine-learning:3003', {
-        imagePath: 'path/to/resize.ext',
-      });
+      expect(machineMock.encodeImage).toHaveBeenCalledWith(
+        'http://immich-machine-learning:3003',
+        { imagePath: 'path/to/resize.ext' },
+        { enabled: true, modelName: 'ViT-B-32::openai' },
+      );
       expect(smartMock.upsert).toHaveBeenCalledWith({
         assetId: 'asset-1',
         clipEmbedding: [0.01, 0.02, 0.03],

View File: smart-info/smart-info.service.ts

@@ -22,7 +22,7 @@ export class SmartInfoService {
   async handleQueueObjectTagging({ force }: IBaseJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.tagImageEnabled) {
+    if (!machineLearning.enabled || !machineLearning.classification.enabled) {
       return true;
     }
@@ -43,7 +43,7 @@ export class SmartInfoService {
   async handleClassifyImage({ id }: IEntityJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.tagImageEnabled) {
+    if (!machineLearning.enabled || !machineLearning.classification.enabled) {
       return true;
     }
@@ -52,7 +52,11 @@ export class SmartInfoService {
       return false;
     }
-    const tags = await this.machineLearning.classifyImage(machineLearning.url, { imagePath: asset.resizePath });
+    const tags = await this.machineLearning.classifyImage(
+      machineLearning.url,
+      { imagePath: asset.resizePath },
+      machineLearning.classification,
+    );
     await this.repository.upsert({ assetId: asset.id, tags });
     return true;
@@ -60,7 +64,7 @@ export class SmartInfoService {
   async handleQueueEncodeClip({ force }: IBaseJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.clipEncodeEnabled) {
+    if (!machineLearning.enabled || !machineLearning.clip.enabled) {
       return true;
     }
@@ -81,7 +85,7 @@ export class SmartInfoService {
   async handleEncodeClip({ id }: IEntityJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.clipEncodeEnabled) {
+    if (!machineLearning.enabled || !machineLearning.clip.enabled) {
       return true;
     }
@@ -90,7 +94,12 @@ export class SmartInfoService {
       return false;
     }
-    const clipEmbedding = await this.machineLearning.encodeImage(machineLearning.url, { imagePath: asset.resizePath });
+    const clipEmbedding = await this.machineLearning.encodeImage(
+      machineLearning.url,
+      { imagePath: asset.resizePath },
+      machineLearning.clip,
+    );
     await this.repository.upsert({ assetId: asset.id, clipEmbedding: clipEmbedding });
     return true;
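
The facial-recognition job is not part of this excerpt, but the "added face distance setting" above corresponds to `RecognitionConfig.maxDistance` (bounded 0-2), which caps how far apart two face embeddings returned by `detectFaces` may be while still being treated as the same person. A rough sketch of that comparison, assuming cosine distance is the metric; the actual clustering logic lives elsewhere in the server:

// Illustrative only: cosine distance lies in [0, 2], matching the @Min(0)/@Max(2)
// bounds on maxDistance. Smaller maxDistance values merge faces more cautiously.
function cosineDistance(a: number[], b: number[]): number {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return 1 - dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

const isSamePerson = (a: number[], b: number[], maxDistance: number) => cosineDistance(a, b) <= maxDistance;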