feat(web,server)!: configure machine learning via the UI (#3768)

This commit is contained in:
Jason Rasmussen
2023-08-25 00:15:03 -04:00
committed by GitHub
parent 2cccef174a
commit 8211afb726
52 changed files with 831 additions and 649 deletions

View File

@@ -1,5 +1,4 @@
import { AssetType } from '@app/infra/entities';
import { BadRequestException } from '@nestjs/common';
import { Duration } from 'luxon';
import { extname } from 'node:path';
import pkg from 'src/../../package.json';
@@ -24,17 +23,6 @@ export const SERVER_VERSION = `${serverVersion.major}.${serverVersion.minor}.${s
export const APP_MEDIA_LOCATION = process.env.IMMICH_MEDIA_LOCATION || './upload';
export const SEARCH_ENABLED = process.env.TYPESENSE_ENABLED !== 'false';
export const MACHINE_LEARNING_URL = process.env.IMMICH_MACHINE_LEARNING_URL || 'http://immich-machine-learning:3003';
export const MACHINE_LEARNING_ENABLED = MACHINE_LEARNING_URL !== 'false';
export function assertMachineLearningEnabled() {
if (!MACHINE_LEARNING_ENABLED) {
throw new BadRequestException('Machine learning is not enabled.');
}
}
const image: Record<string, string[]> = {
'.3fr': ['image/3fr', 'image/x-hasselblad-3fr'],
'.ari': ['image/ari', 'image/x-arriflex-ari'],

View File

@@ -9,6 +9,7 @@ import {
newPersonRepositoryMock,
newSearchRepositoryMock,
newStorageRepositoryMock,
newSystemConfigRepositoryMock,
personStub,
} from '@test';
import { IAssetRepository, WithoutProperty } from '../asset';
@@ -18,6 +19,7 @@ import { IPersonRepository } from '../person';
import { ISearchRepository } from '../search';
import { IMachineLearningRepository } from '../smart-info';
import { IStorageRepository } from '../storage';
import { ISystemConfigRepository } from '../system-config';
import { IFaceRepository } from './face.repository';
import { FacialRecognitionService } from './facial-recognition.services';
@@ -94,6 +96,7 @@ const faceSearch = {
describe(FacialRecognitionService.name, () => {
let sut: FacialRecognitionService;
let assetMock: jest.Mocked<IAssetRepository>;
let configMock: jest.Mocked<ISystemConfigRepository>;
let faceMock: jest.Mocked<IFaceRepository>;
let jobMock: jest.Mocked<IJobRepository>;
let machineLearningMock: jest.Mocked<IMachineLearningRepository>;
@@ -104,6 +107,7 @@ describe(FacialRecognitionService.name, () => {
beforeEach(async () => {
assetMock = newAssetRepositoryMock();
configMock = newSystemConfigRepositoryMock();
faceMock = newFaceRepositoryMock();
jobMock = newJobRepositoryMock();
machineLearningMock = newMachineLearningRepositoryMock();
@@ -116,6 +120,7 @@ describe(FacialRecognitionService.name, () => {
sut = new FacialRecognitionService(
assetMock,
configMock,
faceMock,
jobMock,
machineLearningMock,
@@ -174,7 +179,7 @@ describe(FacialRecognitionService.name, () => {
machineLearningMock.detectFaces.mockResolvedValue([]);
assetMock.getByIds.mockResolvedValue([assetStub.image]);
await sut.handleRecognizeFaces({ id: assetStub.image.id });
expect(machineLearningMock.detectFaces).toHaveBeenCalledWith({
expect(machineLearningMock.detectFaces).toHaveBeenCalledWith('http://immich-machine-learning:3003', {
imagePath: assetStub.image.resizePath,
});
expect(faceMock.create).not.toHaveBeenCalled();

View File

@@ -1,7 +1,6 @@
import { Inject, Logger } from '@nestjs/common';
import { join } from 'path';
import { IAssetRepository, WithoutProperty } from '../asset';
import { MACHINE_LEARNING_ENABLED } from '../domain.constant';
import { usePagination } from '../domain.util';
import { IBaseJob, IEntityJob, IFaceThumbnailJob, IJobRepository, JobName, JOBS_ASSET_PAGINATION_SIZE } from '../job';
import { CropOptions, FACE_THUMBNAIL_SIZE, IMediaRepository } from '../media';
@@ -9,14 +8,17 @@ import { IPersonRepository } from '../person/person.repository';
import { ISearchRepository } from '../search/search.repository';
import { IMachineLearningRepository } from '../smart-info';
import { IStorageRepository, StorageCore, StorageFolder } from '../storage';
import { ISystemConfigRepository, SystemConfigCore } from '../system-config';
import { AssetFaceId, IFaceRepository } from './face.repository';
export class FacialRecognitionService {
private logger = new Logger(FacialRecognitionService.name);
private storageCore = new StorageCore();
private configCore: SystemConfigCore;
constructor(
@Inject(IAssetRepository) private assetRepository: IAssetRepository,
@Inject(ISystemConfigRepository) configRepository: ISystemConfigRepository,
@Inject(IFaceRepository) private faceRepository: IFaceRepository,
@Inject(IJobRepository) private jobRepository: IJobRepository,
@Inject(IMachineLearningRepository) private machineLearning: IMachineLearningRepository,
@@ -24,9 +26,16 @@ export class FacialRecognitionService {
@Inject(IPersonRepository) private personRepository: IPersonRepository,
@Inject(ISearchRepository) private searchRepository: ISearchRepository,
@Inject(IStorageRepository) private storageRepository: IStorageRepository,
) {}
) {
this.configCore = new SystemConfigCore(configRepository);
}
async handleQueueRecognizeFaces({ force }: IBaseJob) {
const { machineLearning } = await this.configCore.getConfig();
if (!machineLearning.enabled || !machineLearning.facialRecognitionEnabled) {
return true;
}
const assetPagination = usePagination(JOBS_ASSET_PAGINATION_SIZE, (pagination) => {
return force
? this.assetRepository.getAll(pagination, { order: 'DESC' })
@@ -49,12 +58,17 @@ export class FacialRecognitionService {
}
async handleRecognizeFaces({ id }: IEntityJob) {
const { machineLearning } = await this.configCore.getConfig();
if (!machineLearning.enabled || !machineLearning.facialRecognitionEnabled) {
return true;
}
const [asset] = await this.assetRepository.getByIds([id]);
if (!asset || !MACHINE_LEARNING_ENABLED || !asset.resizePath) {
if (!asset || !asset.resizePath) {
return false;
}
const faces = await this.machineLearning.detectFaces({ imagePath: asset.resizePath });
const faces = await this.machineLearning.detectFaces(machineLearning.url, { imagePath: asset.resizePath });
this.logger.debug(`${faces.length} faces detected in ${asset.resizePath}`);
this.logger.verbose(faces.map((face) => ({ ...face, embedding: `float[${face.embedding.length}]` })));
@@ -100,6 +114,11 @@ export class FacialRecognitionService {
}
async handleGenerateFaceThumbnail(data: IFaceThumbnailJob) {
const { machineLearning } = await this.configCore.getConfig();
if (!machineLearning.enabled || !machineLearning.facialRecognitionEnabled) {
return true;
}
const { assetId, personId, boundingBox, imageWidth, imageHeight } = data;
const [asset] = await this.assetRepository.getByIds([assetId]);

View File

@@ -2,8 +2,7 @@ import { AssetType } from '@app/infra/entities';
import { BadRequestException, Inject, Injectable, Logger } from '@nestjs/common';
import { IAssetRepository, mapAsset } from '../asset';
import { CommunicationEvent, ICommunicationRepository } from '../communication';
import { assertMachineLearningEnabled } from '../domain.constant';
import { ISystemConfigRepository } from '../system-config';
import { FeatureFlag, ISystemConfigRepository } from '../system-config';
import { SystemConfigCore } from '../system-config/system-config.core';
import { JobCommand, JobName, QueueName } from './job.constants';
import { AllJobStatusResponseDto, JobCommandDto, JobStatusDto } from './job.dto';
@@ -78,23 +77,25 @@ export class JobService {
return this.jobRepository.queue({ name: JobName.STORAGE_TEMPLATE_MIGRATION });
case QueueName.OBJECT_TAGGING:
assertMachineLearningEnabled();
await this.configCore.requireFeature(FeatureFlag.TAG_IMAGE);
return this.jobRepository.queue({ name: JobName.QUEUE_OBJECT_TAGGING, data: { force } });
case QueueName.CLIP_ENCODING:
assertMachineLearningEnabled();
await this.configCore.requireFeature(FeatureFlag.CLIP_ENCODE);
return this.jobRepository.queue({ name: JobName.QUEUE_ENCODE_CLIP, data: { force } });
case QueueName.METADATA_EXTRACTION:
return this.jobRepository.queue({ name: JobName.QUEUE_METADATA_EXTRACTION, data: { force } });
case QueueName.SIDECAR:
await this.configCore.requireFeature(FeatureFlag.SIDECAR);
return this.jobRepository.queue({ name: JobName.QUEUE_SIDECAR, data: { force } });
case QueueName.THUMBNAIL_GENERATION:
return this.jobRepository.queue({ name: JobName.QUEUE_GENERATE_THUMBNAILS, data: { force } });
case QueueName.RECOGNIZE_FACES:
await this.configCore.requireFeature(FeatureFlag.FACIAL_RECOGNITION);
return this.jobRepository.queue({ name: JobName.QUEUE_RECOGNIZE_FACES, data: { force } });
default:

View File

@@ -1,3 +1,2 @@
export * from './search-config-response.dto';
export * from './search-explore.response.dto';
export * from './search-response.dto';

View File

@@ -1,3 +0,0 @@
export class SearchConfigResponseDto {
enabled!: boolean;
}

View File

@@ -1,5 +1,3 @@
import { BadRequestException } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import {
albumStub,
assetStub,
@@ -12,12 +10,14 @@ import {
newJobRepositoryMock,
newMachineLearningRepositoryMock,
newSearchRepositoryMock,
newSystemConfigRepositoryMock,
searchStub,
} from '@test';
import { plainToInstance } from 'class-transformer';
import { IAlbumRepository } from '../album/album.repository';
import { IAssetRepository } from '../asset/asset.repository';
import { IFaceRepository } from '../facial-recognition';
import { ISystemConfigRepository } from '../index';
import { JobName } from '../job';
import { IJobRepository } from '../job/job.repository';
import { IMachineLearningRepository } from '../smart-info';
@@ -31,29 +31,26 @@ describe(SearchService.name, () => {
let sut: SearchService;
let albumMock: jest.Mocked<IAlbumRepository>;
let assetMock: jest.Mocked<IAssetRepository>;
let configMock: jest.Mocked<ISystemConfigRepository>;
let faceMock: jest.Mocked<IFaceRepository>;
let jobMock: jest.Mocked<IJobRepository>;
let machineMock: jest.Mocked<IMachineLearningRepository>;
let searchMock: jest.Mocked<ISearchRepository>;
let configMock: jest.Mocked<ConfigService>;
const makeSut = (value?: string) => {
if (value) {
configMock.get.mockReturnValue(value);
}
return new SearchService(albumMock, assetMock, faceMock, jobMock, machineMock, searchMock, configMock);
};
beforeEach(() => {
beforeEach(async () => {
albumMock = newAlbumRepositoryMock();
assetMock = newAssetRepositoryMock();
configMock = newSystemConfigRepositoryMock();
faceMock = newFaceRepositoryMock();
jobMock = newJobRepositoryMock();
machineMock = newMachineLearningRepositoryMock();
searchMock = newSearchRepositoryMock();
configMock = { get: jest.fn() } as unknown as jest.Mocked<ConfigService>;
sut = makeSut();
sut = new SearchService(albumMock, assetMock, configMock, faceMock, jobMock, machineMock, searchMock);
searchMock.checkMigrationStatus.mockResolvedValue({ assets: false, albums: false, faces: false });
await sut.init();
});
afterEach(() => {
@@ -86,45 +83,18 @@ describe(SearchService.name, () => {
});
});
describe('isEnabled', () => {
it('should be enabled by default', () => {
expect(sut.isEnabled()).toBe(true);
});
it('should be disabled via an env variable', () => {
const sut = makeSut('false');
expect(sut.isEnabled()).toBe(false);
});
});
describe('getConfig', () => {
it('should return the config', () => {
expect(sut.getConfig()).toEqual({ enabled: true });
});
it('should return the config when search is disabled', () => {
const sut = makeSut('false');
expect(sut.getConfig()).toEqual({ enabled: false });
});
});
describe(`init`, () => {
it('should skip when search is disabled', async () => {
const sut = makeSut('false');
// it('should skip when search is disabled', async () => {
// await sut.init();
await sut.init();
// expect(searchMock.setup).not.toHaveBeenCalled();
// expect(searchMock.checkMigrationStatus).not.toHaveBeenCalled();
// expect(jobMock.queue).not.toHaveBeenCalled();
expect(searchMock.setup).not.toHaveBeenCalled();
expect(searchMock.checkMigrationStatus).not.toHaveBeenCalled();
expect(jobMock.queue).not.toHaveBeenCalled();
sut.teardown();
});
// sut.teardown();
// });
it('should skip schema migration if not needed', async () => {
searchMock.checkMigrationStatus.mockResolvedValue({ assets: false, albums: false, faces: false });
await sut.init();
expect(searchMock.setup).toHaveBeenCalled();
@@ -145,14 +115,14 @@ describe(SearchService.name, () => {
});
describe('search', () => {
it('should throw an error is search is disabled', async () => {
const sut = makeSut('false');
// it('should throw an error is search is disabled', async () => {
// sut['enabled'] = false;
await expect(sut.search(authStub.admin, {})).rejects.toBeInstanceOf(BadRequestException);
// await expect(sut.search(authStub.admin, {})).rejects.toBeInstanceOf(BadRequestException);
expect(searchMock.searchAlbums).not.toHaveBeenCalled();
expect(searchMock.searchAssets).not.toHaveBeenCalled();
});
// expect(searchMock.searchAlbums).not.toHaveBeenCalled();
// expect(searchMock.searchAssets).not.toHaveBeenCalled();
// });
it('should search assets and albums', async () => {
searchMock.searchAssets.mockResolvedValue(searchStub.emptyResults);
@@ -205,7 +175,7 @@ describe(SearchService.name, () => {
});
it('should skip if search is disabled', async () => {
const sut = makeSut('false');
sut['enabled'] = false;
await sut.handleIndexAssets();
@@ -216,7 +186,7 @@ describe(SearchService.name, () => {
describe('handleIndexAsset', () => {
it('should skip if search is disabled', () => {
const sut = makeSut('false');
sut['enabled'] = false;
sut.handleIndexAsset({ ids: [assetStub.image.id] });
});
@@ -227,7 +197,7 @@ describe(SearchService.name, () => {
describe('handleIndexAlbums', () => {
it('should skip if search is disabled', () => {
const sut = makeSut('false');
sut['enabled'] = false;
sut.handleIndexAlbums();
});
@@ -242,7 +212,7 @@ describe(SearchService.name, () => {
describe('handleIndexAlbum', () => {
it('should skip if search is disabled', () => {
const sut = makeSut('false');
sut['enabled'] = false;
sut.handleIndexAlbum({ ids: [albumStub.empty.id] });
});
@@ -253,7 +223,7 @@ describe(SearchService.name, () => {
describe('handleRemoveAlbum', () => {
it('should skip if search is disabled', () => {
const sut = makeSut('false');
sut['enabled'] = false;
sut.handleRemoveAlbum({ ids: ['album1'] });
});
@@ -264,7 +234,7 @@ describe(SearchService.name, () => {
describe('handleRemoveAsset', () => {
it('should skip if search is disabled', () => {
const sut = makeSut('false');
sut['enabled'] = false;
sut.handleRemoveAsset({ ids: ['asset1'] });
});
@@ -305,7 +275,7 @@ describe(SearchService.name, () => {
});
it('should skip if search is disabled', async () => {
const sut = makeSut('false');
sut['enabled'] = false;
await sut.handleIndexFaces();
@@ -315,7 +285,7 @@ describe(SearchService.name, () => {
describe('handleIndexAsset', () => {
it('should skip if search is disabled', () => {
const sut = makeSut('false');
sut['enabled'] = false;
sut.handleIndexFace({ assetId: 'asset-1', personId: 'person-1' });
expect(searchMock.importFaces).not.toHaveBeenCalled();
@@ -333,7 +303,7 @@ describe(SearchService.name, () => {
describe('handleRemoveFace', () => {
it('should skip if search is disabled', () => {
const sut = makeSut('false');
sut['enabled'] = false;
sut.handleRemoveFace({ assetId: 'asset-1', personId: 'person-1' });
});

View File

@@ -1,18 +1,17 @@
import { AlbumEntity, AssetEntity, AssetFaceEntity } from '@app/infra/entities';
import { BadRequestException, Inject, Injectable, Logger } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { Inject, Injectable, Logger } from '@nestjs/common';
import { mapAlbumWithAssets } from '../album';
import { IAlbumRepository } from '../album/album.repository';
import { AssetResponseDto, mapAsset } from '../asset';
import { IAssetRepository } from '../asset/asset.repository';
import { AuthUserDto } from '../auth';
import { MACHINE_LEARNING_ENABLED } from '../domain.constant';
import { usePagination } from '../domain.util';
import { AssetFaceId, IFaceRepository } from '../facial-recognition';
import { IAssetFaceJob, IBulkEntityJob, IJobRepository, JobName, JOBS_ASSET_PAGINATION_SIZE } from '../job';
import { IMachineLearningRepository } from '../smart-info';
import { FeatureFlag, ISystemConfigRepository, SystemConfigCore } from '../system-config';
import { SearchDto } from './dto';
import { SearchConfigResponseDto, SearchResponseDto } from './response-dto';
import { SearchResponseDto } from './response-dto';
import {
ISearchRepository,
OwnedFaceEntity,
@@ -30,8 +29,9 @@ interface SyncQueue {
@Injectable()
export class SearchService {
private logger = new Logger(SearchService.name);
private enabled: boolean;
private enabled = false;
private timer: NodeJS.Timer | null = null;
private configCore: SystemConfigCore;
private albumQueue: SyncQueue = {
upsert: new Set(),
@@ -51,16 +51,13 @@ export class SearchService {
constructor(
@Inject(IAlbumRepository) private albumRepository: IAlbumRepository,
@Inject(IAssetRepository) private assetRepository: IAssetRepository,
@Inject(ISystemConfigRepository) configRepository: ISystemConfigRepository,
@Inject(IFaceRepository) private faceRepository: IFaceRepository,
@Inject(IJobRepository) private jobRepository: IJobRepository,
@Inject(IMachineLearningRepository) private machineLearning: IMachineLearningRepository,
@Inject(ISearchRepository) private searchRepository: ISearchRepository,
configService: ConfigService,
) {
this.enabled = configService.get('TYPESENSE_ENABLED') !== 'false';
if (this.enabled) {
this.timer = setInterval(() => this.flush(), 5_000);
}
this.configCore = new SystemConfigCore(configRepository);
}
teardown() {
@@ -70,17 +67,8 @@ export class SearchService {
}
}
isEnabled() {
return this.enabled;
}
getConfig(): SearchConfigResponseDto {
return {
enabled: this.enabled,
};
}
async init() {
this.enabled = await this.configCore.hasFeature(FeatureFlag.SEARCH);
if (!this.enabled) {
return;
}
@@ -101,10 +89,13 @@ export class SearchService {
this.logger.debug('Queueing job to re-index all faces');
await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_FACES });
}
this.timer = setInterval(() => this.flush(), 5_000);
}
async getExploreData(authUser: AuthUserDto): Promise<SearchExploreItem<AssetResponseDto>[]> {
this.assertEnabled();
await this.configCore.requireFeature(FeatureFlag.SEARCH);
const results = await this.searchRepository.explore(authUser.id);
const lookup = await this.getLookupMap(
results.reduce(
@@ -126,16 +117,18 @@ export class SearchService {
}
async search(authUser: AuthUserDto, dto: SearchDto): Promise<SearchResponseDto> {
this.assertEnabled();
const { machineLearning } = await this.configCore.getConfig();
await this.configCore.requireFeature(FeatureFlag.SEARCH);
const query = dto.q || dto.query || '*';
const strategy = dto.clip && MACHINE_LEARNING_ENABLED ? SearchStrategy.CLIP : SearchStrategy.TEXT;
const hasClip = machineLearning.enabled && machineLearning.clipEncodeEnabled;
const strategy = dto.clip && hasClip ? SearchStrategy.CLIP : SearchStrategy.TEXT;
const filters = { userId: authUser.id, ...dto };
let assets: SearchResult<AssetEntity>;
switch (strategy) {
case SearchStrategy.CLIP:
const clip = await this.machineLearning.encodeText(query);
const clip = await this.machineLearning.encodeText(machineLearning.url, query);
assets = await this.searchRepository.vectorSearch(clip, filters);
break;
case SearchStrategy.TEXT:
@@ -333,12 +326,6 @@ export class SearchService {
}
}
private assertEnabled() {
if (!this.enabled) {
throw new BadRequestException('Search is disabled');
}
}
private async idsToAlbums(ids: string[]): Promise<AlbumEntity[]> {
const entities = await this.albumRepository.getByIds(ids);
return this.patchAlbums(entities);

View File

@@ -1,4 +1,4 @@
import { IServerVersion } from '@app/domain';
import { FeatureFlags, IServerVersion } from '@app/domain';
import { ApiProperty, ApiResponseProperty } from '@nestjs/swagger';
export class ServerPingResponse {
@@ -79,10 +79,14 @@ export class ServerMediaTypesResponseDto {
sidecar!: string[];
}
export class ServerFeaturesDto {
machineLearning!: boolean;
export class ServerFeaturesDto implements FeatureFlags {
clipEncode!: boolean;
facialRecognition!: boolean;
sidecar!: boolean;
search!: boolean;
tagImage!: boolean;
// TODO: use these instead of `POST oauth/config`
oauth!: boolean;
oauthAutoLaunch!: boolean;
passwordLogin!: boolean;

View File

@@ -147,11 +147,14 @@ describe(ServerInfoService.name, () => {
describe('getFeatures', () => {
it('should respond the server features', async () => {
await expect(sut.getFeatures()).resolves.toEqual({
machineLearning: true,
clipEncode: true,
facialRecognition: true,
oauth: false,
oauthAutoLaunch: false,
passwordLogin: true,
search: true,
sidecar: true,
tagImage: true,
});
expect(configMock.load).toHaveBeenCalled();
});

View File

@@ -1,9 +1,8 @@
import { Inject, Injectable } from '@nestjs/common';
import { MACHINE_LEARNING_ENABLED, mimeTypes, SEARCH_ENABLED, serverVersion } from '../domain.constant';
import { mimeTypes, serverVersion } from '../domain.constant';
import { asHumanReadable } from '../domain.util';
import { IStorageRepository, StorageCore, StorageFolder } from '../storage';
import { ISystemConfigRepository } from '../system-config';
import { SystemConfigCore } from '../system-config/system-config.core';
import { ISystemConfigRepository, SystemConfigCore } from '../system-config';
import { IUserRepository, UserStatsQueryResponse } from '../user';
import {
ServerFeaturesDto,
@@ -52,18 +51,8 @@ export class ServerInfoService {
return serverVersion;
}
async getFeatures(): Promise<ServerFeaturesDto> {
const config = await this.configCore.getConfig();
return {
machineLearning: MACHINE_LEARNING_ENABLED,
search: SEARCH_ENABLED,
// TODO: use these instead of `POST oauth/config`
oauth: config.oauth.enabled,
oauthAutoLaunch: config.oauth.autoLaunch,
passwordLogin: config.passwordLogin.enabled,
};
getFeatures(): Promise<ServerFeaturesDto> {
return this.configCore.getFeatures();
}
async getStats(): Promise<ServerStatsResponseDto> {

View File

@@ -20,8 +20,8 @@ export interface DetectFaceResult {
}
export interface IMachineLearningRepository {
classifyImage(input: MachineLearningInput): Promise<string[]>;
encodeImage(input: MachineLearningInput): Promise<number[]>;
encodeText(input: string): Promise<number[]>;
detectFaces(input: MachineLearningInput): Promise<DetectFaceResult[]>;
classifyImage(url: string, input: MachineLearningInput): Promise<string[]>;
encodeImage(url: string, input: MachineLearningInput): Promise<number[]>;
encodeText(url: string, input: string): Promise<number[]>;
detectFaces(url: string, input: MachineLearningInput): Promise<DetectFaceResult[]>;
}

View File

@@ -5,9 +5,11 @@ import {
newJobRepositoryMock,
newMachineLearningRepositoryMock,
newSmartInfoRepositoryMock,
newSystemConfigRepositoryMock,
} from '@test';
import { IAssetRepository, WithoutProperty } from '../asset';
import { IJobRepository, JobName } from '../job';
import { ISystemConfigRepository } from '../system-config';
import { IMachineLearningRepository } from './machine-learning.interface';
import { ISmartInfoRepository } from './smart-info.repository';
import { SmartInfoService } from './smart-info.service';
@@ -20,16 +22,18 @@ const asset = {
describe(SmartInfoService.name, () => {
let sut: SmartInfoService;
let assetMock: jest.Mocked<IAssetRepository>;
let configMock: jest.Mocked<ISystemConfigRepository>;
let jobMock: jest.Mocked<IJobRepository>;
let smartMock: jest.Mocked<ISmartInfoRepository>;
let machineMock: jest.Mocked<IMachineLearningRepository>;
beforeEach(async () => {
assetMock = newAssetRepositoryMock();
configMock = newSystemConfigRepositoryMock();
smartMock = newSmartInfoRepositoryMock();
jobMock = newJobRepositoryMock();
machineMock = newMachineLearningRepositoryMock();
sut = new SmartInfoService(assetMock, jobMock, smartMock, machineMock);
sut = new SmartInfoService(assetMock, configMock, jobMock, smartMock, machineMock);
assetMock.getByIds.mockResolvedValue([asset]);
});
@@ -80,7 +84,9 @@ describe(SmartInfoService.name, () => {
await sut.handleClassifyImage({ id: asset.id });
expect(machineMock.classifyImage).toHaveBeenCalledWith({ imagePath: 'path/to/resize.ext' });
expect(machineMock.classifyImage).toHaveBeenCalledWith('http://immich-machine-learning:3003', {
imagePath: 'path/to/resize.ext',
});
expect(smartMock.upsert).toHaveBeenCalledWith({
assetId: 'asset-1',
tags: ['tag1', 'tag2', 'tag3'],
@@ -139,7 +145,9 @@ describe(SmartInfoService.name, () => {
await sut.handleEncodeClip({ id: asset.id });
expect(machineMock.encodeImage).toHaveBeenCalledWith({ imagePath: 'path/to/resize.ext' });
expect(machineMock.encodeImage).toHaveBeenCalledWith('http://immich-machine-learning:3003', {
imagePath: 'path/to/resize.ext',
});
expect(smartMock.upsert).toHaveBeenCalledWith({
assetId: 'asset-1',
clipEmbedding: [0.01, 0.02, 0.03],

View File

@@ -1,23 +1,31 @@
import { Inject, Injectable, Logger } from '@nestjs/common';
import { Inject, Injectable } from '@nestjs/common';
import { IAssetRepository, WithoutProperty } from '../asset';
import { MACHINE_LEARNING_ENABLED } from '../domain.constant';
import { usePagination } from '../domain.util';
import { IBaseJob, IEntityJob, IJobRepository, JobName, JOBS_ASSET_PAGINATION_SIZE } from '../job';
import { ISystemConfigRepository, SystemConfigCore } from '../system-config';
import { IMachineLearningRepository } from './machine-learning.interface';
import { ISmartInfoRepository } from './smart-info.repository';
@Injectable()
export class SmartInfoService {
private logger = new Logger(SmartInfoService.name);
private configCore: SystemConfigCore;
constructor(
@Inject(IAssetRepository) private assetRepository: IAssetRepository,
@Inject(ISystemConfigRepository) configRepository: ISystemConfigRepository,
@Inject(IJobRepository) private jobRepository: IJobRepository,
@Inject(ISmartInfoRepository) private repository: ISmartInfoRepository,
@Inject(IMachineLearningRepository) private machineLearning: IMachineLearningRepository,
) {}
) {
this.configCore = new SystemConfigCore(configRepository);
}
async handleQueueObjectTagging({ force }: IBaseJob) {
const { machineLearning } = await this.configCore.getConfig();
if (!machineLearning.enabled || !machineLearning.tagImageEnabled) {
return true;
}
const assetPagination = usePagination(JOBS_ASSET_PAGINATION_SIZE, (pagination) => {
return force
? this.assetRepository.getAll(pagination)
@@ -34,19 +42,28 @@ export class SmartInfoService {
}
async handleClassifyImage({ id }: IEntityJob) {
const [asset] = await this.assetRepository.getByIds([id]);
const { machineLearning } = await this.configCore.getConfig();
if (!machineLearning.enabled || !machineLearning.tagImageEnabled) {
return true;
}
if (!MACHINE_LEARNING_ENABLED || !asset.resizePath) {
const [asset] = await this.assetRepository.getByIds([id]);
if (!asset.resizePath) {
return false;
}
const tags = await this.machineLearning.classifyImage({ imagePath: asset.resizePath });
const tags = await this.machineLearning.classifyImage(machineLearning.url, { imagePath: asset.resizePath });
await this.repository.upsert({ assetId: asset.id, tags });
return true;
}
async handleQueueEncodeClip({ force }: IBaseJob) {
const { machineLearning } = await this.configCore.getConfig();
if (!machineLearning.enabled || !machineLearning.clipEncodeEnabled) {
return true;
}
const assetPagination = usePagination(JOBS_ASSET_PAGINATION_SIZE, (pagination) => {
return force
? this.assetRepository.getAll(pagination)
@@ -63,13 +80,17 @@ export class SmartInfoService {
}
async handleEncodeClip({ id }: IEntityJob) {
const [asset] = await this.assetRepository.getByIds([id]);
const { machineLearning } = await this.configCore.getConfig();
if (!machineLearning.enabled || !machineLearning.clipEncodeEnabled) {
return true;
}
if (!MACHINE_LEARNING_ENABLED || !asset.resizePath) {
const [asset] = await this.assetRepository.getByIds([id]);
if (!asset.resizePath) {
return false;
}
const clipEmbedding = await this.machineLearning.encodeImage({ imagePath: asset.resizePath });
const clipEmbedding = await this.machineLearning.encodeImage(machineLearning.url, { imagePath: asset.resizePath });
await this.repository.upsert({ assetId: asset.id, clipEmbedding: clipEmbedding });
return true;

View File

@@ -0,0 +1,19 @@
import { IsBoolean, IsUrl, ValidateIf } from 'class-validator';
export class SystemConfigMachineLearningDto {
@IsBoolean()
enabled!: boolean;
@IsUrl({ require_tld: false })
@ValidateIf((dto) => dto.enabled)
url!: string;
@IsBoolean()
clipEncodeEnabled!: boolean;
@IsBoolean()
facialRecognitionEnabled!: boolean;
@IsBoolean()
tagImageEnabled!: boolean;
}

View File

@@ -4,16 +4,22 @@ import { Type } from 'class-transformer';
import { IsObject, ValidateNested } from 'class-validator';
import { SystemConfigFFmpegDto } from './system-config-ffmpeg.dto';
import { SystemConfigJobDto } from './system-config-job.dto';
import { SystemConfigMachineLearningDto } from './system-config-machine-learning.dto';
import { SystemConfigOAuthDto } from './system-config-oauth.dto';
import { SystemConfigPasswordLoginDto } from './system-config-password-login.dto';
import { SystemConfigStorageTemplateDto } from './system-config-storage-template.dto';
export class SystemConfigDto {
export class SystemConfigDto implements SystemConfig {
@Type(() => SystemConfigFFmpegDto)
@ValidateNested()
@IsObject()
ffmpeg!: SystemConfigFFmpegDto;
@Type(() => SystemConfigMachineLearningDto)
@ValidateNested()
@IsObject()
machineLearning!: SystemConfigMachineLearningDto;
@Type(() => SystemConfigOAuthDto)
@ValidateNested()
@IsObject()

View File

@@ -1,5 +1,6 @@
export * from './dto';
export * from './response-dto';
export * from './system-config.constants';
export * from './system-config.core';
export * from './system-config.repository';
export * from './system-config.service';

View File

@@ -9,7 +9,7 @@ import {
TranscodePolicy,
VideoCodec,
} from '@app/infra/entities';
import { BadRequestException, Injectable, Logger } from '@nestjs/common';
import { BadRequestException, ForbiddenException, Injectable, Logger } from '@nestjs/common';
import * as _ from 'lodash';
import { Subject } from 'rxjs';
import { DeepPartial } from 'typeorm';
@@ -44,6 +44,13 @@ export const defaults = Object.freeze<SystemConfig>({
[QueueName.THUMBNAIL_GENERATION]: { concurrency: 5 },
[QueueName.VIDEO_CONVERSION]: { concurrency: 1 },
},
machineLearning: {
enabled: process.env.IMMICH_MACHINE_LEARNING_ENABLED !== 'false',
url: process.env.IMMICH_MACHINE_LEARNING_URL || 'http://immich-machine-learning:3003',
facialRecognitionEnabled: true,
tagImageEnabled: true,
clipEncodeEnabled: true,
},
oauth: {
enabled: false,
issuerUrl: '',
@@ -71,6 +78,19 @@ export const defaults = Object.freeze<SystemConfig>({
},
});
export enum FeatureFlag {
CLIP_ENCODE = 'clipEncode',
FACIAL_RECOGNITION = 'facialRecognition',
TAG_IMAGE = 'tagImage',
SIDECAR = 'sidecar',
SEARCH = 'search',
OAUTH = 'oauth',
OAUTH_AUTO_LAUNCH = 'oauthAutoLaunch',
PASSWORD_LOGIN = 'passwordLogin',
}
export type FeatureFlags = Record<FeatureFlag, boolean>;
const singleton = new Subject<SystemConfig>();
@Injectable()
@@ -82,6 +102,53 @@ export class SystemConfigCore {
constructor(private repository: ISystemConfigRepository) {}
async requireFeature(feature: FeatureFlag) {
const hasFeature = await this.hasFeature(feature);
if (!hasFeature) {
switch (feature) {
case FeatureFlag.CLIP_ENCODE:
throw new BadRequestException('Clip encoding is not enabled');
case FeatureFlag.FACIAL_RECOGNITION:
throw new BadRequestException('Facial recognition is not enabled');
case FeatureFlag.TAG_IMAGE:
throw new BadRequestException('Image tagging is not enabled');
case FeatureFlag.SIDECAR:
throw new BadRequestException('Sidecar is not enabled');
case FeatureFlag.SEARCH:
throw new BadRequestException('Search is not enabled');
case FeatureFlag.OAUTH:
throw new BadRequestException('OAuth is not enabled');
case FeatureFlag.PASSWORD_LOGIN:
throw new BadRequestException('Password login is not enabled');
default:
throw new ForbiddenException(`Missing required feature: ${feature}`);
}
}
}
async hasFeature(feature: FeatureFlag) {
const features = await this.getFeatures();
return features[feature] ?? false;
}
async getFeatures(): Promise<FeatureFlags> {
const config = await this.getConfig();
const mlEnabled = config.machineLearning.enabled;
return {
[FeatureFlag.CLIP_ENCODE]: mlEnabled && config.machineLearning.clipEncodeEnabled,
[FeatureFlag.FACIAL_RECOGNITION]: mlEnabled && config.machineLearning.facialRecognitionEnabled,
[FeatureFlag.TAG_IMAGE]: mlEnabled && config.machineLearning.tagImageEnabled,
[FeatureFlag.SIDECAR]: true,
[FeatureFlag.SEARCH]: process.env.TYPESENSE_ENABLED !== 'false',
// TODO: use these instead of `POST oauth/config`
[FeatureFlag.OAUTH]: config.oauth.enabled,
[FeatureFlag.OAUTH_AUTO_LAUNCH]: config.oauth.autoLaunch,
[FeatureFlag.PASSWORD_LOGIN]: config.passwordLogin.enabled,
};
}
public getDefaults(): SystemConfig {
return defaults;
}

View File

@@ -46,6 +46,13 @@ const updatedConfig = Object.freeze<SystemConfig>({
accel: TranscodeHWAccel.DISABLED,
tonemap: ToneMapping.HABLE,
},
machineLearning: {
enabled: true,
url: 'http://immich-machine-learning:3003',
facialRecognitionEnabled: true,
tagImageEnabled: true,
clipEncodeEnabled: true,
},
oauth: {
autoLaunch: true,
autoRegister: true,