feat(ml)!: customizable ML settings (#3891)

* consolidated endpoints, added live configuration

* added ML settings to server

* added settings dashboard

* updated deps, fixed typos

* simplified ModelConfig

updated tests

* added ML settings accordion for admin page

updated tests

* merged `clipText` and `clipVision`

* added face distance setting

clarified setting

* added CLIP mode in request, dropdown for face models

* polished ML settings

updated descriptions

* updated CLIP field on error

* removed unused import

* added description for image classification threshold

* pinned safetensors for ARM wheel

updated poetry lock

* moved DTO

* set model type only in ML repository

* reverted form-data package install

used fetch instead of axios

* added slotted description with link

updated facial recognition description

clarified effect of disabling tasks

* added validation before model load

* removed unnecessary getConfig call

* added migration

* updated API

---------

Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
Author: Mert
Date: 2023-08-29 09:58:00 -04:00
Committed by: GitHub
Parent: 22f5e05060
Commit: bcc36d14a1
56 changed files with 2324 additions and 655 deletions
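
At a glance, the three per-task booleans (`clipEncodeEnabled`, `tagImageEnabled`, `facialRecognitionEnabled`) are replaced by nested, per-model config objects. A sketch of the new shape, using the defaults introduced in `system-config.core.ts` further down:

// New machineLearning section of SystemConfig; values mirror the new defaults below.
const machineLearningDefaults = {
  enabled: true,
  url: 'http://immich-machine-learning:3003',
  classification: { enabled: true, modelName: 'microsoft/resnet-50', minScore: 0.9 },
  clip: { enabled: true, modelName: 'ViT-B-32::openai' },
  facialRecognition: { enabled: true, modelName: 'buffalo_l', minScore: 0.7, maxDistance: 0.6 },
};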


@@ -115,6 +115,7 @@ describe(FacialRecognitionService.name, () => {
     personMock = newPersonRepositoryMock();
     searchMock = newSearchRepositoryMock();
     storageMock = newStorageRepositoryMock();
+    configMock = newSystemConfigRepositoryMock();
     mediaMock.crop.mockResolvedValue(croppedFace);
@@ -179,9 +180,18 @@ describe(FacialRecognitionService.name, () => {
       machineLearningMock.detectFaces.mockResolvedValue([]);
       assetMock.getByIds.mockResolvedValue([assetStub.image]);
       await sut.handleRecognizeFaces({ id: assetStub.image.id });
-      expect(machineLearningMock.detectFaces).toHaveBeenCalledWith('http://immich-machine-learning:3003', {
-        imagePath: assetStub.image.resizePath,
-      });
+      expect(machineLearningMock.detectFaces).toHaveBeenCalledWith(
+        'http://immich-machine-learning:3003',
+        {
+          imagePath: assetStub.image.resizePath,
+        },
+        {
+          enabled: true,
+          maxDistance: 0.6,
+          minScore: 0.7,
+          modelName: 'buffalo_l',
+        },
+      );
       expect(faceMock.create).not.toHaveBeenCalled();
       expect(jobMock.queue).not.toHaveBeenCalled();
     });


@@ -32,7 +32,7 @@ export class FacialRecognitionService {
   async handleQueueRecognizeFaces({ force }: IBaseJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.facialRecognitionEnabled) {
+    if (!machineLearning.enabled || !machineLearning.facialRecognition.enabled) {
       return true;
     }
@@ -59,7 +59,7 @@ export class FacialRecognitionService {
   async handleRecognizeFaces({ id }: IEntityJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.facialRecognitionEnabled) {
+    if (!machineLearning.enabled || !machineLearning.facialRecognition.enabled) {
       return true;
     }
@@ -68,7 +68,11 @@ export class FacialRecognitionService {
       return false;
     }
-    const faces = await this.machineLearning.detectFaces(machineLearning.url, { imagePath: asset.resizePath });
+    const faces = await this.machineLearning.detectFaces(
+      machineLearning.url,
+      { imagePath: asset.resizePath },
+      machineLearning.facialRecognition,
+    );
     this.logger.debug(`${faces.length} faces detected in ${asset.resizePath}`);
     this.logger.verbose(faces.map((face) => ({ ...face, embedding: `float[${face.embedding.length}]` })));
@@ -80,7 +84,7 @@ export class FacialRecognitionService {
       // try to find a matching face and link to the associated person
       // The closer to 0, the better the match. Range is from 0 to 2
-      if (faceSearchResult.total && faceSearchResult.distances[0] < 0.6) {
+      if (faceSearchResult.total && faceSearchResult.distances[0] <= machineLearning.facialRecognition.maxDistance) {
         this.logger.verbose(`Match face with distance ${faceSearchResult.distances[0]}`);
         personId = faceSearchResult.items[0].personId;
       }
@@ -115,7 +119,7 @@ export class FacialRecognitionService {
   async handleGenerateFaceThumbnail(data: IFaceThumbnailJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.facialRecognitionEnabled) {
+    if (!machineLearning.enabled || !machineLearning.facialRecognition.enabled) {
       return true;
     }


@@ -86,6 +86,7 @@ export interface ISearchRepository {
   deleteAssets(ids: string[]): Promise<void>;
   deleteFaces(ids: string[]): Promise<void>;
   deleteAllFaces(): Promise<number>;
+  updateCLIPField(num_dim: number): Promise<void>;
   searchAlbums(query: string, filters: SearchFilter): Promise<SearchResult<AlbumEntity>>;
   searchAssets(query: string, filters: SearchFilter): Promise<SearchResult<AssetEntity>>;
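
`updateCLIPField` is the hook that lets the search schema be rebuilt when a CLIP model with a different embedding size is selected. A minimal sketch of a caller against the interface above — the helper and its wiring are illustrative assumptions, not part of this diff:

// Illustrative only: resize the CLIP vector field so vectorSearch matches the newly configured model.
async function syncClipSchema(search: ISearchRepository, embedding: number[]): Promise<void> {
  // Recreate the embedding field with the new dimensionality; existing assets then
  // need to be re-encoded with the new model for searches to return sensible results.
  await search.updateCLIPField(embedding.length);
}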


@@ -121,15 +121,18 @@ export class SearchService {
     await this.configCore.requireFeature(FeatureFlag.SEARCH);
     const query = dto.q || dto.query || '*';
-    const hasClip = machineLearning.enabled && machineLearning.clipEncodeEnabled;
+    const hasClip = machineLearning.enabled && machineLearning.clip.enabled;
     const strategy = dto.clip && hasClip ? SearchStrategy.CLIP : SearchStrategy.TEXT;
     const filters = { userId: authUser.id, ...dto };
     let assets: SearchResult<AssetEntity>;
     switch (strategy) {
       case SearchStrategy.CLIP:
-        const clip = await this.machineLearning.encodeText(machineLearning.url, query);
-        assets = await this.searchRepository.vectorSearch(clip, filters);
+        const {
+          machineLearning: { clip },
+        } = await this.configCore.getConfig();
+        const embedding = await this.machineLearning.encodeText(machineLearning.url, { text: query }, clip);
+        assets = await this.searchRepository.vectorSearch(embedding, filters);
         break;
       case SearchStrategy.TEXT:
       default:


@@ -0,0 +1 @@
+export * from './model-config.dto';


@@ -0,0 +1,50 @@
+import { ApiProperty } from '@nestjs/swagger';
+import { Type } from 'class-transformer';
+import { IsBoolean, IsEnum, IsNotEmpty, IsNumber, IsOptional, IsString, Max, Min } from 'class-validator';
+import { CLIPMode, ModelType } from '../machine-learning.interface';
+
+export class ModelConfig {
+  @IsBoolean()
+  enabled!: boolean;
+
+  @IsString()
+  @IsNotEmpty()
+  modelName!: string;
+
+  @IsEnum(ModelType)
+  @IsOptional()
+  @ApiProperty({ enumName: 'ModelType', enum: ModelType })
+  modelType?: ModelType;
+}
+
+export class ClassificationConfig extends ModelConfig {
+  @IsNumber()
+  @Min(0)
+  @Max(1)
+  @Type(() => Number)
+  @ApiProperty({ type: 'integer' })
+  minScore!: number;
+}
+
+export class CLIPConfig extends ModelConfig {
+  @IsEnum(CLIPMode)
+  @IsOptional()
+  @ApiProperty({ enumName: 'CLIPMode', enum: CLIPMode })
+  mode?: CLIPMode;
+}
+
+export class RecognitionConfig extends ModelConfig {
+  @IsNumber()
+  @Min(0)
+  @Max(1)
+  @Type(() => Number)
+  @ApiProperty({ type: 'integer' })
+  minScore!: number;
+
+  @IsNumber()
+  @Min(0)
+  @Max(2)
+  @Type(() => Number)
+  @ApiProperty({ type: 'integer' })
+  maxDistance!: number;
+}
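
Since these are plain class-validator DTOs, out-of-range values are rejected the same way as the rest of the system config (see the `plainToClass` call in `system-config.core.ts` below). A standalone sketch of what that catches, assuming it sits alongside the DTO file above:

import { plainToClass } from 'class-transformer';
import { validateSync } from 'class-validator';
import { ClassificationConfig } from './dto';

// minScore violates @Max(1), so validation reports an error instead of silently accepting the value.
const invalid = plainToClass(ClassificationConfig, {
  enabled: true,
  modelName: 'microsoft/resnet-50',
  minScore: 1.5,
});
console.log(validateSync(invalid).length > 0); // true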


@@ -1,3 +1,4 @@
+export * from './dto';
 export * from './machine-learning.interface';
 export * from './smart-info.repository';
 export * from './smart-info.service';


@@ -1,9 +1,15 @@
+import { ClassificationConfig, CLIPConfig, RecognitionConfig } from './dto';
 export const IMachineLearningRepository = 'IMachineLearningRepository';
-export interface MachineLearningInput {
+export interface VisionModelInput {
   imagePath: string;
 }
+export interface TextModelInput {
+  text: string;
+}
 export interface BoundingBox {
   x1: number;
   y1: number;
@@ -19,9 +25,20 @@ export interface DetectFaceResult {
   embedding: number[];
 }
-export interface IMachineLearningRepository {
-  classifyImage(url: string, input: MachineLearningInput): Promise<string[]>;
-  encodeImage(url: string, input: MachineLearningInput): Promise<number[]>;
-  encodeText(url: string, input: string): Promise<number[]>;
-  detectFaces(url: string, input: MachineLearningInput): Promise<DetectFaceResult[]>;
+export enum ModelType {
+  IMAGE_CLASSIFICATION = 'image-classification',
+  FACIAL_RECOGNITION = 'facial-recognition',
+  CLIP = 'clip',
+}
+export enum CLIPMode {
+  VISION = 'vision',
+  TEXT = 'text',
+}
+export interface IMachineLearningRepository {
+  classifyImage(url: string, input: VisionModelInput, config: ClassificationConfig): Promise<string[]>;
+  encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]>;
+  encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]>;
+  detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]>;
 }
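
Every repository method now receives the relevant `ModelConfig`, so the server can tell the machine-learning service which model to load per request, and the commit notes the repository switched from axios to `fetch`. A rough sketch of what such a call could look like — the `/predict` path and JSON body layout are assumptions for illustration, not taken from this diff:

// Illustrative fetch-based call against a consolidated inference endpoint (path and body shape are assumed).
async function predict<T>(url: string, input: VisionModelInput | TextModelInput, config: ModelConfig): Promise<T> {
  const res = await fetch(`${url}/predict`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ ...input, ...config }),
  });
  if (res.status >= 400) {
    throw new Error(`Machine learning request to "${url}" failed with status ${res.status}`);
  }
  return res.json();
}

// For example, encodeText(url, { text: query }, clip) could delegate to:
// predict<number[]>(url, { text: query }, { ...clip, modelType: ModelType.CLIP, mode: CLIPMode.TEXT });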


@@ -84,9 +84,13 @@ describe(SmartInfoService.name, () => {
       await sut.handleClassifyImage({ id: asset.id });
-      expect(machineMock.classifyImage).toHaveBeenCalledWith('http://immich-machine-learning:3003', {
-        imagePath: 'path/to/resize.ext',
-      });
+      expect(machineMock.classifyImage).toHaveBeenCalledWith(
+        'http://immich-machine-learning:3003',
+        {
+          imagePath: 'path/to/resize.ext',
+        },
+        { enabled: true, minScore: 0.9, modelName: 'microsoft/resnet-50' },
+      );
       expect(smartMock.upsert).toHaveBeenCalledWith({
         assetId: 'asset-1',
         tags: ['tag1', 'tag2', 'tag3'],
@@ -141,13 +145,16 @@ describe(SmartInfoService.name, () => {
     });
     it('should save the returned objects', async () => {
       smartMock.upsert.mockResolvedValue();
       machineMock.encodeImage.mockResolvedValue([0.01, 0.02, 0.03]);
       await sut.handleEncodeClip({ id: asset.id });
-      expect(machineMock.encodeImage).toHaveBeenCalledWith('http://immich-machine-learning:3003', {
-        imagePath: 'path/to/resize.ext',
-      });
+      expect(machineMock.encodeImage).toHaveBeenCalledWith(
+        'http://immich-machine-learning:3003',
+        { imagePath: 'path/to/resize.ext' },
+        { enabled: true, modelName: 'ViT-B-32::openai' },
+      );
       expect(smartMock.upsert).toHaveBeenCalledWith({
         assetId: 'asset-1',
         clipEmbedding: [0.01, 0.02, 0.03],


@@ -22,7 +22,7 @@ export class SmartInfoService {
   async handleQueueObjectTagging({ force }: IBaseJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.tagImageEnabled) {
+    if (!machineLearning.enabled || !machineLearning.classification.enabled) {
      return true;
     }
@@ -43,7 +43,7 @@ export class SmartInfoService {
   async handleClassifyImage({ id }: IEntityJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.tagImageEnabled) {
+    if (!machineLearning.enabled || !machineLearning.classification.enabled) {
       return true;
     }
@@ -52,7 +52,11 @@ export class SmartInfoService {
       return false;
     }
-    const tags = await this.machineLearning.classifyImage(machineLearning.url, { imagePath: asset.resizePath });
+    const tags = await this.machineLearning.classifyImage(
+      machineLearning.url,
+      { imagePath: asset.resizePath },
+      machineLearning.classification,
+    );
     await this.repository.upsert({ assetId: asset.id, tags });
     return true;
@@ -60,7 +64,7 @@ export class SmartInfoService {
   async handleQueueEncodeClip({ force }: IBaseJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.clipEncodeEnabled) {
+    if (!machineLearning.enabled || !machineLearning.clip.enabled) {
       return true;
     }
@@ -81,7 +85,7 @@ export class SmartInfoService {
   async handleEncodeClip({ id }: IEntityJob) {
     const { machineLearning } = await this.configCore.getConfig();
-    if (!machineLearning.enabled || !machineLearning.clipEncodeEnabled) {
+    if (!machineLearning.enabled || !machineLearning.clip.enabled) {
       return true;
     }
@@ -90,7 +94,12 @@ export class SmartInfoService {
       return false;
     }
-    const clipEmbedding = await this.machineLearning.encodeImage(machineLearning.url, { imagePath: asset.resizePath });
+    const clipEmbedding = await this.machineLearning.encodeImage(
+      machineLearning.url,
+      { imagePath: asset.resizePath },
+      machineLearning.clip,
+    );
     await this.repository.upsert({ assetId: asset.id, clipEmbedding: clipEmbedding });
     return true;


@@ -1,4 +1,6 @@
-import { IsBoolean, IsUrl, ValidateIf } from 'class-validator';
+import { ClassificationConfig, CLIPConfig, RecognitionConfig } from '@app/domain';
+import { Type } from 'class-transformer';
+import { IsBoolean, IsObject, IsUrl, ValidateIf, ValidateNested } from 'class-validator';
 export class SystemConfigMachineLearningDto {
   @IsBoolean()
@@ -8,12 +10,18 @@ export class SystemConfigMachineLearningDto {
   @ValidateIf((dto) => dto.enabled)
   url!: string;
-  @IsBoolean()
-  clipEncodeEnabled!: boolean;
+  @Type(() => ClassificationConfig)
+  @ValidateNested()
+  @IsObject()
+  classification!: ClassificationConfig;
-  @IsBoolean()
-  facialRecognitionEnabled!: boolean;
+  @Type(() => CLIPConfig)
+  @ValidateNested()
+  @IsObject()
+  clip!: CLIPConfig;
-  @IsBoolean()
-  tagImageEnabled!: boolean;
+  @Type(() => RecognitionConfig)
+  @ValidateNested()
+  @IsObject()
+  facialRecognition!: RecognitionConfig;
 }
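
With the DTO reshaped this way, the admin dashboard can tune or disable each task independently; disabling one task leaves the others untouched (see the feature-flag mapping in `system-config.core.ts` below). A hedged example of the machine-learning slice of an update payload — only the field shape comes from this diff, the surrounding system-config endpoint is assumed:

// Example payload slice: CLIP encoding turned off, classification and facial recognition left at defaults.
const machineLearning: SystemConfigMachineLearningDto = {
  enabled: true,
  url: 'http://immich-machine-learning:3003',
  classification: { enabled: true, modelName: 'microsoft/resnet-50', minScore: 0.9 },
  clip: { enabled: false, modelName: 'ViT-B-32::openai' },
  facialRecognition: { enabled: true, modelName: 'buffalo_l', minScore: 0.7, maxDistance: 0.6 },
};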


@@ -47,12 +47,25 @@ export const defaults = Object.freeze<SystemConfig>({
     [QueueName.THUMBNAIL_GENERATION]: { concurrency: 5 },
     [QueueName.VIDEO_CONVERSION]: { concurrency: 1 },
   },
   machineLearning: {
     enabled: process.env.IMMICH_MACHINE_LEARNING_ENABLED !== 'false',
     url: process.env.IMMICH_MACHINE_LEARNING_URL || 'http://immich-machine-learning:3003',
-    facialRecognitionEnabled: true,
-    tagImageEnabled: true,
-    clipEncodeEnabled: true,
+    classification: {
+      enabled: true,
+      modelName: 'microsoft/resnet-50',
+      minScore: 0.9,
+    },
+    clip: {
+      enabled: true,
+      modelName: 'ViT-B-32::openai',
+    },
+    facialRecognition: {
+      enabled: true,
+      modelName: 'buffalo_l',
+      minScore: 0.7,
+      maxDistance: 0.6,
+    },
   },
   oauth: {
     enabled: false,
@@ -143,9 +156,9 @@ export class SystemConfigCore {
     const mlEnabled = config.machineLearning.enabled;
     return {
-      [FeatureFlag.CLIP_ENCODE]: mlEnabled && config.machineLearning.clipEncodeEnabled,
-      [FeatureFlag.FACIAL_RECOGNITION]: mlEnabled && config.machineLearning.facialRecognitionEnabled,
-      [FeatureFlag.TAG_IMAGE]: mlEnabled && config.machineLearning.tagImageEnabled,
+      [FeatureFlag.CLIP_ENCODE]: mlEnabled && config.machineLearning.clip.enabled,
+      [FeatureFlag.FACIAL_RECOGNITION]: mlEnabled && config.machineLearning.facialRecognition.enabled,
+      [FeatureFlag.TAG_IMAGE]: mlEnabled && config.machineLearning.classification.enabled,
       [FeatureFlag.SIDECAR]: true,
       [FeatureFlag.SEARCH]: process.env.TYPESENSE_ENABLED !== 'false',
@@ -230,7 +243,7 @@ export class SystemConfigCore {
       _.set(config, key, value);
     }
-    return _.defaultsDeep(config, defaults) as SystemConfig;
+    return plainToClass(SystemConfigDto, _.defaultsDeep(config, defaults));
   }
   private async loadFromFile(filepath: string, force = false) {


@@ -49,9 +49,21 @@ const updatedConfig = Object.freeze<SystemConfig>({
   machineLearning: {
     enabled: true,
     url: 'http://immich-machine-learning:3003',
-    facialRecognitionEnabled: true,
-    tagImageEnabled: true,
-    clipEncodeEnabled: true,
+    classification: {
+      enabled: true,
+      modelName: 'microsoft/resnet-50',
+      minScore: 0.9,
+    },
+    clip: {
+      enabled: true,
+      modelName: 'ViT-B-32::openai',
+    },
+    facialRecognition: {
+      enabled: true,
+      modelName: 'buffalo_l',
+      minScore: 0.7,
+      maxDistance: 0.6,
+    },
   },
   oauth: {
     autoLaunch: true,