From 5929f56a156731c77d289d8af24661c2598494ea Mon Sep 17 00:00:00 2001 From: Luca Nardelli Date: Mon, 11 Oct 2021 00:02:46 +0200 Subject: [PATCH] Custom filter support for typeorm --- packages/core/src/helpers/filter.builder.ts | 2 +- .../core/src/interfaces/filter.interface.ts | 3 +- .../__snapshots__/where.builder.spec.ts.snap | 13 ++ .../query/filter-query.builder.spec.ts | 86 +++++++------ .../__tests__/query/where.builder.spec.ts | 45 ++++++- .../src/query/custom-filter.registry.ts | 54 ++++++++ .../src/query/filter-query.builder.ts | 119 +++++++++++++----- packages/query-typeorm/src/query/index.ts | 1 + .../src/query/relation-query.builder.ts | 18 ++- .../query-typeorm/src/query/where.builder.ts | 113 ++++++++++++----- 10 files changed, 350 insertions(+), 104 deletions(-) create mode 100644 packages/query-typeorm/src/query/custom-filter.registry.ts diff --git a/packages/core/src/helpers/filter.builder.ts b/packages/core/src/helpers/filter.builder.ts index 7b4f7e9a43..e2383f4f98 100644 --- a/packages/core/src/helpers/filter.builder.ts +++ b/packages/core/src/helpers/filter.builder.ts @@ -58,6 +58,6 @@ export class FilterBuilder { throw new Error(`unknown comparison ${JSON.stringify(fieldOrNested)}`); } const nestedFilterFn = this.build(value); - return (dto?: DTO) => nestedFilterFn(dto ? dto[fieldOrNested] : null); + return (dto?: DTO) => nestedFilterFn(dto ? dto[fieldOrNested] : undefined); } } diff --git a/packages/core/src/interfaces/filter.interface.ts b/packages/core/src/interfaces/filter.interface.ts index 1b63d95ecd..263d3f15d9 100644 --- a/packages/core/src/interfaces/filter.interface.ts +++ b/packages/core/src/interfaces/filter.interface.ts @@ -111,5 +111,6 @@ type FilterGrouping = { * ``` * * @typeparam T - the type of object to filter on. + * @typeparam C - custom filters defined on the object. */ -export type Filter = FilterGrouping & FilterComparisons; +export type Filter> = FilterGrouping & FilterComparisons & { [K in keyof C]: C[K] }; diff --git a/packages/query-typeorm/__tests__/query/__snapshots__/where.builder.spec.ts.snap b/packages/query-typeorm/__tests__/query/__snapshots__/where.builder.spec.ts.snap index a0618ae704..a125772d0b 100644 --- a/packages/query-typeorm/__tests__/query/__snapshots__/where.builder.spec.ts.snap +++ b/packages/query-typeorm/__tests__/query/__snapshots__/where.builder.spec.ts.snap @@ -109,3 +109,16 @@ Array [ exports[`WhereBuilder should accept a empty filter 1`] = `SELECT "TestEntity"."test_entity_pk" AS "TestEntity_test_entity_pk", "TestEntity"."string_type" AS "TestEntity_string_type", "TestEntity"."bool_type" AS "TestEntity_bool_type", "TestEntity"."number_type" AS "TestEntity_number_type", "TestEntity"."date_type" AS "TestEntity_date_type", "TestEntity"."oneTestRelationTestRelationPk" AS "TestEntity_oneTestRelationTestRelationPk" FROM "test_entity" "TestEntity"`; exports[`WhereBuilder should accept a empty filter 2`] = `Array []`; + +exports[`WhereBuilder should accept custom filters alongside regular filters 1`] = `SELECT "TestEntity"."test_entity_pk" AS "TestEntity_test_entity_pk", "TestEntity"."string_type" AS "TestEntity_string_type", "TestEntity"."bool_type" AS "TestEntity_bool_type", "TestEntity"."number_type" AS "TestEntity_number_type", "TestEntity"."date_type" AS "TestEntity_date_type", "TestEntity"."oneTestRelationTestRelationPk" AS "TestEntity_oneTestRelationTestRelationPk" FROM "test_entity" "TestEntity" WHERE ("TestEntity"."number_type" >= ? OR "TestEntity"."number_type" <= ? OR ("TestEntity"."numberType" % ?) == 0) AND (ST_Distance("TestEntity"."fakePointType", ST_MakePoint(?,?)) <= ?)`; + +exports[`WhereBuilder should accept custom filters alongside regular filters 2`] = ` +Array [ + 1, + 10, + 5, + 45.3, + 9.5, + 50000, +] +`; diff --git a/packages/query-typeorm/__tests__/query/filter-query.builder.spec.ts b/packages/query-typeorm/__tests__/query/filter-query.builder.spec.ts index 5bda27a779..173a988b07 100644 --- a/packages/query-typeorm/__tests__/query/filter-query.builder.spec.ts +++ b/packages/query-typeorm/__tests__/query/filter-query.builder.spec.ts @@ -1,10 +1,10 @@ -import { anything, instance, mock, verify, when, deepEqual } from 'ts-mockito'; -import { QueryBuilder, WhereExpression } from 'typeorm'; import { Class, Filter, Query, SortDirection, SortNulls } from '@nestjs-query/core'; +import { anything, deepEqual, instance, mock, verify, when } from 'ts-mockito'; +import { QueryBuilder, WhereExpression } from 'typeorm'; +import { CustomFilterRegistry, FilterQueryBuilder, WhereBuilder } from '../../src/query'; import { closeTestConnection, createTestConnection, getTestConnection } from '../__fixtures__/connection.fixture'; import { TestSoftDeleteEntity } from '../__fixtures__/test-soft-delete.entity'; import { TestEntity } from '../__fixtures__/test.entity'; -import { FilterQueryBuilder, WhereBuilder } from '../../src/query'; describe('FilterQueryBuilder', (): void => { beforeEach(createTestConnection); @@ -98,15 +98,23 @@ describe('FilterQueryBuilder', (): void => { it('should not call whereBuilder#build', () => { const mockWhereBuilder = mock>(WhereBuilder); expectSelectSQLSnapshot({}, instance(mockWhereBuilder)); - verify(mockWhereBuilder.build(anything(), anything(), {}, 'TestEntity')).never(); + verify(mockWhereBuilder.build(anything(), anything(), {}, TestEntity, undefined, 'TestEntity')).never(); }); it('should call whereBuilder#build if there is a filter', () => { const mockWhereBuilder = mock>(WhereBuilder); const query = { filter: { stringType: { eq: 'foo' } } }; - when(mockWhereBuilder.build(anything(), query.filter, deepEqual({}), 'TestEntity')).thenCall( - (where: WhereExpression, field: Filter, relationNames: string[], alias: string) => - where.andWhere(`${alias}.stringType = 'foo'`), + when( + mockWhereBuilder.build(anything(), query.filter, deepEqual({}), TestEntity, undefined, 'TestEntity'), + ).thenCall( + ( + where: WhereExpression, + field: Filter, + relationNames: string[], + klass: Class, + customFilters: CustomFilterRegistry, + alias: string, + ) => where.andWhere(`${alias}.stringType = 'foo'`), ); expectSelectSQLSnapshot(query, instance(mockWhereBuilder)); }); @@ -116,19 +124,23 @@ describe('FilterQueryBuilder', (): void => { it('should apply empty paging args', () => { const mockWhereBuilder = mock>(WhereBuilder); expectSelectSQLSnapshot({}, instance(mockWhereBuilder)); - verify(mockWhereBuilder.build(anything(), anything(), deepEqual({}), 'TestEntity')).never(); + verify( + mockWhereBuilder.build(anything(), anything(), deepEqual({}), TestEntity, undefined, 'TestEntity'), + ).never(); }); it('should apply paging args going forward', () => { const mockWhereBuilder = mock>(WhereBuilder); expectSelectSQLSnapshot({ paging: { limit: 10, offset: 11 } }, instance(mockWhereBuilder)); - verify(mockWhereBuilder.build(anything(), anything(), deepEqual({}), 'TestEntity')).never(); + verify( + mockWhereBuilder.build(anything(), anything(), deepEqual({}), TestEntity, undefined, 'TestEntity'), + ).never(); }); it('should apply paging args going backward', () => { const mockWhereBuilder = mock>(WhereBuilder); expectSelectSQLSnapshot({ paging: { limit: 10, offset: 10 } }, instance(mockWhereBuilder)); - verify(mockWhereBuilder.build(anything(), anything(), {}, 'TestEntity')).never(); + verify(mockWhereBuilder.build(anything(), anything(), {}, TestEntity, undefined, 'TestEntity')).never(); }); }); @@ -139,7 +151,7 @@ describe('FilterQueryBuilder', (): void => { { sorting: [{ field: 'numberType', direction: SortDirection.ASC }] }, instance(mockWhereBuilder), ); - verify(mockWhereBuilder.build(anything(), anything(), {}, 'TestEntity')).never(); + verify(mockWhereBuilder.build(anything(), anything(), {}, TestEntity, undefined, 'TestEntity')).never(); }); it('should apply ASC NULLS_FIRST sorting', () => { @@ -148,7 +160,7 @@ describe('FilterQueryBuilder', (): void => { { sorting: [{ field: 'numberType', direction: SortDirection.ASC, nulls: SortNulls.NULLS_FIRST }] }, instance(mockWhereBuilder), ); - verify(mockWhereBuilder.build(anything(), anything(), {}, 'TestEntity')).never(); + verify(mockWhereBuilder.build(anything(), anything(), {}, TestEntity, undefined, 'TestEntity')).never(); }); it('should apply ASC NULLS_LAST sorting', () => { @@ -157,7 +169,7 @@ describe('FilterQueryBuilder', (): void => { { sorting: [{ field: 'numberType', direction: SortDirection.ASC, nulls: SortNulls.NULLS_LAST }] }, instance(mockWhereBuilder), ); - verify(mockWhereBuilder.build(anything(), anything(), {}, 'TestEntity')).never(); + verify(mockWhereBuilder.build(anything(), anything(), {}, TestEntity, undefined, 'TestEntity')).never(); }); it('should apply DESC sorting', () => { @@ -166,7 +178,7 @@ describe('FilterQueryBuilder', (): void => { { sorting: [{ field: 'numberType', direction: SortDirection.DESC }] }, instance(mockWhereBuilder), ); - verify(mockWhereBuilder.build(anything(), anything(), {}, 'TestEntity')).never(); + verify(mockWhereBuilder.build(anything(), anything(), {}, TestEntity, undefined, 'TestEntity')).never(); }); it('should apply DESC NULLS_FIRST sorting', () => { @@ -183,7 +195,7 @@ describe('FilterQueryBuilder', (): void => { { sorting: [{ field: 'numberType', direction: SortDirection.DESC, nulls: SortNulls.NULLS_LAST }] }, instance(mockWhereBuilder), ); - verify(mockWhereBuilder.build(anything(), anything(), {}, 'TestEntity')).never(); + verify(mockWhereBuilder.build(anything(), anything(), {}, TestEntity, undefined, 'TestEntity')).never(); }); it('should apply multiple sorts', () => { @@ -199,7 +211,7 @@ describe('FilterQueryBuilder', (): void => { }, instance(mockWhereBuilder), ); - verify(mockWhereBuilder.build(anything(), anything(), {}, 'TestEntity')).never(); + verify(mockWhereBuilder.build(anything(), anything(), {}, TestEntity, undefined, 'TestEntity')).never(); }); }); }); @@ -214,9 +226,9 @@ describe('FilterQueryBuilder', (): void => { it('should call whereBuilder#build if there is a filter', () => { const mockWhereBuilder = mock>(WhereBuilder); const query = { filter: { stringType: { eq: 'foo' } } }; - when(mockWhereBuilder.build(anything(), query.filter, deepEqual({}), undefined)).thenCall( - (where: WhereExpression) => where.andWhere(`stringType = 'foo'`), - ); + when( + mockWhereBuilder.build(anything(), query.filter, deepEqual({}), TestEntity, undefined, undefined), + ).thenCall((where: WhereExpression) => where.andWhere(`stringType = 'foo'`)); expectUpdateSQLSnapshot(query, instance(mockWhereBuilder)); }); }); @@ -224,7 +236,7 @@ describe('FilterQueryBuilder', (): void => { it('should ignore paging args', () => { const mockWhereBuilder = mock>(WhereBuilder); expectUpdateSQLSnapshot({ paging: { limit: 10, offset: 11 } }, instance(mockWhereBuilder)); - verify(mockWhereBuilder.build(anything(), anything(), anything())).never(); + verify(mockWhereBuilder.build(anything(), anything(), anything(), anything())).never(); }); }); @@ -235,7 +247,7 @@ describe('FilterQueryBuilder', (): void => { { sorting: [{ field: 'numberType', direction: SortDirection.ASC }] }, instance(mockWhereBuilder), ); - verify(mockWhereBuilder.build(anything(), anything(), anything())).never(); + verify(mockWhereBuilder.build(anything(), anything(), anything(), anything())).never(); }); it('should apply ASC NULLS_FIRST sorting', () => { @@ -244,7 +256,7 @@ describe('FilterQueryBuilder', (): void => { { sorting: [{ field: 'numberType', direction: SortDirection.ASC, nulls: SortNulls.NULLS_FIRST }] }, instance(mockWhereBuilder), ); - verify(mockWhereBuilder.build(anything(), anything(), anything())).never(); + verify(mockWhereBuilder.build(anything(), anything(), anything(), anything())).never(); }); it('should apply ASC NULLS_LAST sorting', () => { @@ -253,7 +265,7 @@ describe('FilterQueryBuilder', (): void => { { sorting: [{ field: 'numberType', direction: SortDirection.ASC, nulls: SortNulls.NULLS_LAST }] }, instance(mockWhereBuilder), ); - verify(mockWhereBuilder.build(anything(), anything(), anything())).never(); + verify(mockWhereBuilder.build(anything(), anything(), anything(), anything())).never(); }); it('should apply DESC sorting', () => { @@ -262,7 +274,7 @@ describe('FilterQueryBuilder', (): void => { { sorting: [{ field: 'numberType', direction: SortDirection.DESC }] }, instance(mockWhereBuilder), ); - verify(mockWhereBuilder.build(anything(), anything(), anything())).never(); + verify(mockWhereBuilder.build(anything(), anything(), anything(), anything())).never(); }); it('should apply DESC NULLS_FIRST sorting', () => { @@ -271,7 +283,7 @@ describe('FilterQueryBuilder', (): void => { { sorting: [{ field: 'numberType', direction: SortDirection.DESC, nulls: SortNulls.NULLS_FIRST }] }, instance(mockWhereBuilder), ); - verify(mockWhereBuilder.build(anything(), anything(), anything())).never(); + verify(mockWhereBuilder.build(anything(), anything(), anything(), anything())).never(); }); it('should apply DESC NULLS_LAST sorting', () => { @@ -280,7 +292,7 @@ describe('FilterQueryBuilder', (): void => { { sorting: [{ field: 'numberType', direction: SortDirection.DESC, nulls: SortNulls.NULLS_LAST }] }, instance(mockWhereBuilder), ); - verify(mockWhereBuilder.build(anything(), anything(), anything())).never(); + verify(mockWhereBuilder.build(anything(), anything(), anything(), anything())).never(); }); it('should apply multiple sorts', () => { @@ -296,7 +308,7 @@ describe('FilterQueryBuilder', (): void => { }, instance(mockWhereBuilder), ); - verify(mockWhereBuilder.build(anything(), anything(), anything())).never(); + verify(mockWhereBuilder.build(anything(), anything(), anything(), anything())).never(); }); }); }); @@ -311,9 +323,9 @@ describe('FilterQueryBuilder', (): void => { it('should call whereBuilder#build if there is a filter', () => { const mockWhereBuilder = mock>(WhereBuilder); const query = { filter: { stringType: { eq: 'foo' } } }; - when(mockWhereBuilder.build(anything(), query.filter, deepEqual({}), undefined)).thenCall( - (where: WhereExpression) => where.andWhere(`stringType = 'foo'`), - ); + when( + mockWhereBuilder.build(anything(), query.filter, deepEqual({}), TestEntity, undefined, undefined), + ).thenCall((where: WhereExpression) => where.andWhere(`stringType = 'foo'`)); expectDeleteSQLSnapshot(query, instance(mockWhereBuilder)); }); }); @@ -321,7 +333,7 @@ describe('FilterQueryBuilder', (): void => { it('should ignore paging args', () => { const mockWhereBuilder = mock>(WhereBuilder); expectDeleteSQLSnapshot({ paging: { limit: 10, offset: 11 } }, instance(mockWhereBuilder)); - verify(mockWhereBuilder.build(anything(), anything(), anything())).never(); + verify(mockWhereBuilder.build(anything(), anything(), anything(), anything())).never(); }); }); @@ -339,7 +351,7 @@ describe('FilterQueryBuilder', (): void => { }, instance(mockWhereBuilder), ); - verify(mockWhereBuilder.build(anything(), anything(), anything())).never(); + verify(mockWhereBuilder.build(anything(), anything(), anything(), anything())).never(); }); }); }); @@ -357,9 +369,9 @@ describe('FilterQueryBuilder', (): void => { it('should call whereBuilder#build if there is a filter', () => { const mockWhereBuilder = mock>(WhereBuilder); const query = { filter: { stringType: { eq: 'foo' } } }; - when(mockWhereBuilder.build(anything(), query.filter, deepEqual({}), undefined)).thenCall( - (where: WhereExpression) => where.andWhere(`stringType = 'foo'`), - ); + when( + mockWhereBuilder.build(anything(), query.filter, deepEqual({}), TestSoftDeleteEntity, undefined, undefined), + ).thenCall((where: WhereExpression) => where.andWhere(`stringType = 'foo'`)); expectSoftDeleteSQLSnapshot(query, instance(mockWhereBuilder)); }); }); @@ -367,7 +379,7 @@ describe('FilterQueryBuilder', (): void => { it('should ignore paging args', () => { const mockWhereBuilder = mock>(WhereBuilder); expectSoftDeleteSQLSnapshot({ paging: { limit: 10, offset: 11 } }, instance(mockWhereBuilder)); - verify(mockWhereBuilder.build(anything(), anything(), anything())).never(); + verify(mockWhereBuilder.build(anything(), anything(), anything(), anything())).never(); }); }); @@ -383,7 +395,7 @@ describe('FilterQueryBuilder', (): void => { }, instance(mockWhereBuilder), ); - verify(mockWhereBuilder.build(anything(), anything(), anything())).never(); + verify(mockWhereBuilder.build(anything(), anything(), anything(), anything())).never(); }); }); }); diff --git a/packages/query-typeorm/__tests__/query/where.builder.spec.ts b/packages/query-typeorm/__tests__/query/where.builder.spec.ts index 2e386092e9..65e690c843 100644 --- a/packages/query-typeorm/__tests__/query/where.builder.spec.ts +++ b/packages/query-typeorm/__tests__/query/where.builder.spec.ts @@ -1,7 +1,8 @@ import { Filter } from '@nestjs-query/core'; +import { randomString } from '../../src/common'; +import { CustomFilterRegistry, CustomFilterResult, WhereBuilder } from '../../src/query'; import { closeTestConnection, createTestConnection, getTestConnection } from '../__fixtures__/connection.fixture'; import { TestEntity } from '../__fixtures__/test.entity'; -import { WhereBuilder } from '../../src/query'; describe('WhereBuilder', (): void => { beforeEach(createTestConnection); @@ -11,8 +12,40 @@ describe('WhereBuilder', (): void => { const getQueryBuilder = () => getRepo().createQueryBuilder(); const createWhereBuilder = () => new WhereBuilder(); + const customFilterRegistry = new CustomFilterRegistry(); + customFilterRegistry.setFilter(TestEntity, 'numberType', 'isMultipleOf', { + apply(field, cmp, val: number, alias): CustomFilterResult { + alias = alias ? alias : ''; + const pname = `param${randomString()}`; + return { + sql: `("${alias}"."${field}" % :${pname}) == 0`, + params: { [pname]: val }, + }; + }, + }); + // This property does not actually exist in the entity, but since we are testing only the generated SQL it's ok. + customFilterRegistry.setFilter(TestEntity, 'fakePointType', 'distanceFrom', { + apply(field, cmp, val: { point: { lat: number; lng: number }; radius: number }, alias): CustomFilterResult { + alias = alias ? alias : ''; + const plat = `param${randomString()}`; + const plng = `param${randomString()}`; + const prad = `param${randomString()}`; + return { + sql: `ST_Distance("${alias}"."${field}", ST_MakePoint(:${plat},:${plng})) <= :${prad}`, + params: { [plat]: val.point.lat, [plng]: val.point.lng, [prad]: val.radius }, + }; + }, + }); + const expectSQLSnapshot = (filter: Filter): void => { - const selectQueryBuilder = createWhereBuilder().build(getQueryBuilder(), filter, {}, 'TestEntity'); + const selectQueryBuilder = createWhereBuilder().build( + getQueryBuilder(), + filter, + {}, + TestEntity, + customFilterRegistry, + 'TestEntity', + ); const [sql, params] = selectQueryBuilder.getQueryAndParameters(); expect(sql).toMatchSnapshot(); expect(params).toMatchSnapshot(); @@ -30,6 +63,14 @@ describe('WhereBuilder', (): void => { expectSQLSnapshot({ numberType: { eq: 1 }, stringType: { like: 'foo%' }, boolType: { is: true } }); }); + // TODO Fix typings to avoid usage of any + it('should accept custom filters alongside regular filters', (): void => { + expectSQLSnapshot({ + numberType: { gte: 1, lte: 10, isMultipleOf: 5 }, + fakePointType: { distanceFrom: { point: { lat: 45.3, lng: 9.5 }, radius: 50000 } }, + } as any); + }); + describe('and', (): void => { it('and multiple expressions together', (): void => { expectSQLSnapshot({ diff --git a/packages/query-typeorm/src/query/custom-filter.registry.ts b/packages/query-typeorm/src/query/custom-filter.registry.ts new file mode 100644 index 0000000000..4ac7f71610 --- /dev/null +++ b/packages/query-typeorm/src/query/custom-filter.registry.ts @@ -0,0 +1,54 @@ +import { Class } from '@nestjs-query/core'; +import merge from 'lodash.merge'; +import { ObjectLiteral } from 'typeorm'; + +/** + * @internal + */ +export type CustomFilterResult = { sql: string; params: ObjectLiteral }; + +/** + * @internal + * Used to create custom filters + */ +export interface CustomFilter { + apply(field: keyof Entity & string, cmp: string, val: T, alias?: string): CustomFilterResult; +} + +type EntityCustomFilters = Record>; + +export class CustomFilterRegistry { + private registry: Map, EntityCustomFilters> = new Map(); + + getEntityFilters(klass: Class): EntityCustomFilters { + if (!this.registry.has(klass)) { + this.registry.set(klass, {}); + } + return this.registry.get(klass) as EntityCustomFilters; + } + + getFieldFilters(klass: Class, field: string | keyof Entity): Record { + return this.getEntityFilters(klass)[field]; + } + + getFilter( + klass: Class, + field: string | keyof Entity, + opName: string, + ): CustomFilter | undefined { + return this.getFieldFilters(klass, field)?.[opName]; + } + + setFilter( + klass: Class, + field: keyof Entity | string, + opName: string, + filter: CustomFilter, + ): void { + if (!this.registry.has(klass)) { + this.registry.set(klass, {}); + } + const klassFilters = this.registry.get(klass) as EntityCustomFilters; + klassFilters[field] = merge(klassFilters[field], { [opName]: filter }); + } +} diff --git a/packages/query-typeorm/src/query/filter-query.builder.ts b/packages/query-typeorm/src/query/filter-query.builder.ts index 13d1383f6f..0a037d2e2f 100644 --- a/packages/query-typeorm/src/query/filter-query.builder.ts +++ b/packages/query-typeorm/src/query/filter-query.builder.ts @@ -1,17 +1,18 @@ -import { Filter, Paging, Query, SortField, getFilterFields, AggregateQuery } from '@nestjs-query/core'; +import { AggregateQuery, Class, Filter, getFilterFields, Paging, Query, SortField } from '@nestjs-query/core'; +import merge from 'lodash.merge'; import { DeleteQueryBuilder, + EntityMetadata, QueryBuilder, Repository, SelectQueryBuilder, UpdateQueryBuilder, WhereExpression, - EntityMetadata, } from 'typeorm'; import { SoftDeleteQueryBuilder } from 'typeorm/query-builder/SoftDeleteQueryBuilder'; import { AggregateBuilder } from './aggregate.builder'; +import { CustomFilterRegistry } from './custom-filter.registry'; import { WhereBuilder } from './where.builder'; -import merge from 'lodash.merge'; /** * @internal @@ -33,8 +34,11 @@ interface Groupable extends QueryBuilder { */ interface Pageable extends QueryBuilder { limit(limit?: number): this; + offset(offset?: number): this; + skip(skip?: number): this; + take(take?: number): this; } @@ -55,6 +59,13 @@ interface R { */ export type NestedRecord = R; +export interface RelationMeta { + targetKlass: Class; + relations: Record; +} + +export type RelationsMeta = Record; + /** * @internal * @@ -67,6 +78,10 @@ export class FilterQueryBuilder { readonly aggregateBuilder: AggregateBuilder = new AggregateBuilder(), ) {} + private get relationNames(): string[] { + return this.repo.metadata.relations.map((r) => r.propertyName); + } + /** * Create a `typeorm` SelectQueryBuilder with `WHERE`, `ORDER BY` and `LIMIT/OFFSET` clauses. * @@ -75,10 +90,11 @@ export class FilterQueryBuilder { select(query: Query): SelectQueryBuilder { const hasRelations = this.filterHasRelations(query.filter); let qb = this.createQueryBuilder(); + const klass = this.repo.metadata.target as Class; qb = hasRelations ? this.applyRelationJoinsRecursive(qb, this.getReferencedRelationsRecursive(this.repo.metadata, query.filter)) : qb; - qb = this.applyFilter(qb, query.filter, qb.alias); + qb = this.applyFilter(qb, klass, undefined, query.filter, qb.alias); qb = this.applySorting(qb, query.sorting, qb.alias); qb = this.applyPaging(qb, query.paging, hasRelations); return qb; @@ -87,11 +103,12 @@ export class FilterQueryBuilder { selectById(id: string | number | (string | number)[], query: Query): SelectQueryBuilder { const hasRelations = this.filterHasRelations(query.filter); let qb = this.createQueryBuilder(); + const klass = this.repo.metadata.target as Class; qb = hasRelations ? this.applyRelationJoinsRecursive(qb, this.getReferencedRelationsRecursive(this.repo.metadata, query.filter)) : qb; qb = qb.andWhereInIds(id); - qb = this.applyFilter(qb, query.filter, qb.alias); + qb = this.applyFilter(qb, klass, undefined, query.filter, qb.alias); qb = this.applySorting(qb, query.sorting, qb.alias); qb = this.applyPaging(qb, query.paging, hasRelations); return qb; @@ -99,8 +116,9 @@ export class FilterQueryBuilder { aggregate(query: Query, aggregate: AggregateQuery): SelectQueryBuilder { let qb = this.createQueryBuilder(); + const klass = this.repo.metadata.target as Class; qb = this.applyAggregate(qb, aggregate, qb.alias); - qb = this.applyFilter(qb, query.filter, qb.alias); + qb = this.applyFilter(qb, klass, undefined, query.filter, qb.alias); qb = this.applyAggregateSorting(qb, aggregate.groupBy, qb.alias); qb = this.applyGroupBy(qb, aggregate.groupBy, qb.alias); return qb; @@ -112,7 +130,9 @@ export class FilterQueryBuilder { * @param query - the query to apply. */ delete(query: Query): DeleteQueryBuilder { - return this.applyFilter(this.repo.createQueryBuilder().delete(), query.filter); + const qb = this.repo.createQueryBuilder().delete(); + const klass = this.repo.metadata.target as Class; + return this.applyFilter(qb, klass, undefined, query.filter); } /** @@ -121,10 +141,9 @@ export class FilterQueryBuilder { * @param query - the query to apply. */ softDelete(query: Query): SoftDeleteQueryBuilder { - return this.applyFilter( - this.repo.createQueryBuilder().softDelete() as SoftDeleteQueryBuilder, - query.filter, - ); + const qb = this.repo.createQueryBuilder().softDelete() as SoftDeleteQueryBuilder; + const klass = this.repo.metadata.target as Class; + return this.applyFilter(qb, klass, undefined, query.filter); } /** @@ -133,7 +152,9 @@ export class FilterQueryBuilder { * @param query - the query to apply. */ update(query: Query): UpdateQueryBuilder { - const qb = this.applyFilter(this.repo.createQueryBuilder().update(), query.filter); + const qb = this.repo.createQueryBuilder().update(); + const klass = this.repo.metadata.target as Class; + this.applyFilter(qb, klass, undefined, query.filter); return this.applySorting(qb, query.sorting); } @@ -170,14 +191,29 @@ export class FilterQueryBuilder { * Applies the filter from a Query to a `typeorm` QueryBuilder. * * @param qb - the `typeorm` QueryBuilder. + * @param klass - the class currently being processed + * @param customFilters - the custom filters map that this builder should process * @param filter - the filter. * @param alias - optional alias to use to qualify an identifier */ - applyFilter(qb: Where, filter?: Filter, alias?: string): Where { + applyFilter( + qb: Where, + klass: Class, + customFilters?: CustomFilterRegistry, + filter?: Filter, + alias?: string, + ): Where { if (!filter) { return qb; } - return this.whereBuilder.build(qb, filter, this.getReferencedRelationsRecursive(this.repo.metadata, filter), alias); + return this.whereBuilder.build( + qb, + filter, + this.getReferencedRelationsMetaRecursive(this.repo.metadata, filter), + klass, + customFilters, + alias, + ); } /** @@ -216,14 +252,6 @@ export class FilterQueryBuilder { }, qb); } - /** - * Create a `typeorm` SelectQueryBuilder which can be used as an entry point to create update, delete or insert - * QueryBuilders. - */ - private createQueryBuilder(): SelectQueryBuilder { - return this.repo.createQueryBuilder(); - } - /** * Gets relations referenced in the filter and adds joins for them to the query builder * @param qb - the `typeorm` QueryBuilder. @@ -263,15 +291,10 @@ export class FilterQueryBuilder { return this.getReferencedRelations(filter).length > 0; } - private getReferencedRelations(filter: Filter): string[] { - const { relationNames } = this; - const referencedFields = getFilterFields(filter); - return referencedFields.filter((f) => relationNames.includes(f)); - } - getReferencedRelationsRecursive(metadata: EntityMetadata, filter: Filter = {}): NestedRecord { - const referencedFields = Array.from(new Set(Object.keys(filter) as (keyof Filter)[])); + const referencedFields = Array.from(new Set(Object.keys(filter))); return referencedFields.reduce((prev, curr) => { + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment const currFilterValue = filter[curr]; if ((curr === 'and' || curr === 'or') && currFilterValue) { for (const subFilter of currFilterValue) { @@ -290,7 +313,41 @@ export class FilterQueryBuilder { }, {}); } - private get relationNames(): string[] { - return this.repo.metadata.relations.map((r) => r.propertyName); + getReferencedRelationsMetaRecursive(metadata: EntityMetadata, filter: Filter = {}): RelationsMeta { + const referencedFields = Array.from(new Set(Object.keys(filter))); + let meta: RelationsMeta = {}; + for (const referencedField of referencedFields) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + const currFilterValue = filter[referencedField]; + if ((referencedField === 'and' || referencedField === 'or') && currFilterValue) { + for (const subFilter of currFilterValue) { + meta = merge(meta, this.getReferencedRelationsRecursive(metadata, subFilter)); + } + } + const referencedRelation = metadata.relations.find((r) => r.propertyName === referencedField); + if (!referencedRelation) continue; + meta[referencedField] = { + targetKlass: referencedRelation.target as Class, + relations: merge( + meta?.[referencedField]?.relations, + this.getReferencedRelationsRecursive(referencedRelation.inverseEntityMetadata, currFilterValue), + ), + }; + } + return meta; + } + + /** + * Create a `typeorm` SelectQueryBuilder which can be used as an entry point to create update, delete or insert + * QueryBuilders. + */ + private createQueryBuilder(): SelectQueryBuilder { + return this.repo.createQueryBuilder(); + } + + private getReferencedRelations(filter: Filter): string[] { + const { relationNames } = this; + const referencedFields = getFilterFields(filter); + return referencedFields.filter((f) => relationNames.includes(f)); } } diff --git a/packages/query-typeorm/src/query/index.ts b/packages/query-typeorm/src/query/index.ts index 852faaf423..37a3c7968c 100644 --- a/packages/query-typeorm/src/query/index.ts +++ b/packages/query-typeorm/src/query/index.ts @@ -3,3 +3,4 @@ export * from './where.builder'; export * from './sql-comparison.builder'; export * from './relation-query.builder'; export * from './aggregate.builder'; +export * from './custom-filter.registry'; diff --git a/packages/query-typeorm/src/query/relation-query.builder.ts b/packages/query-typeorm/src/query/relation-query.builder.ts index 7bdd8d5b38..072125e74e 100644 --- a/packages/query-typeorm/src/query/relation-query.builder.ts +++ b/packages/query-typeorm/src/query/relation-query.builder.ts @@ -79,7 +79,14 @@ export class RelationQueryBuilder { this.filterQueryBuilder.getReferencedRelationsRecursive(this.relationRepo.metadata, query.filter), ) : relationBuilder; - relationBuilder = this.filterQueryBuilder.applyFilter(relationBuilder, query.filter, relationBuilder.alias); + const klass = this.repo.metadata.target as Class; + relationBuilder = this.filterQueryBuilder.applyFilter( + relationBuilder, + klass, + undefined, + query.filter, + relationBuilder.alias, + ); relationBuilder = this.filterQueryBuilder.applyPaging(relationBuilder, query.paging); return this.filterQueryBuilder.applySorting(relationBuilder, query.sorting, relationBuilder.alias); } @@ -120,8 +127,15 @@ export class RelationQueryBuilder { aggregateQuery: AggregateQuery, ): SelectQueryBuilder { let relationBuilder = this.createRelationQueryBuilder(entity); + const klass = this.repo.metadata.target as Class; relationBuilder = this.filterQueryBuilder.applyAggregate(relationBuilder, aggregateQuery, relationBuilder.alias); - relationBuilder = this.filterQueryBuilder.applyFilter(relationBuilder, query.filter, relationBuilder.alias); + relationBuilder = this.filterQueryBuilder.applyFilter( + relationBuilder, + klass, + undefined, + query.filter, + relationBuilder.alias, + ); relationBuilder = this.filterQueryBuilder.applyAggregateSorting( relationBuilder, aggregateQuery.groupBy, diff --git a/packages/query-typeorm/src/query/where.builder.ts b/packages/query-typeorm/src/query/where.builder.ts index 79c47f3174..43d18619ad 100644 --- a/packages/query-typeorm/src/query/where.builder.ts +++ b/packages/query-typeorm/src/query/where.builder.ts @@ -1,7 +1,8 @@ +import { Class, Filter, FilterComparisons, FilterFieldComparison } from '@nestjs-query/core'; import { Brackets, WhereExpression } from 'typeorm'; -import { Filter, FilterComparisons, FilterFieldComparison } from '@nestjs-query/core'; +import { CustomFilterRegistry } from './custom-filter.registry'; +import { RelationsMeta } from './filter-query.builder'; import { EntityComparisonField, SQLComparisonBuilder } from './sql-comparison.builder'; -import { NestedRecord } from './filter-query.builder'; /** * @internal @@ -14,23 +15,27 @@ export class WhereBuilder { * Builds a WHERE clause from a Filter. * @param where - the `typeorm` WhereExpression * @param filter - the filter to build the WHERE clause from. - * @param relationNames - the relations tree. + * @param relationMeta - the relations tree. + * @param klass - the class currently being processed + * @param customFilters - the custom filters map that this builder should process * @param alias - optional alias to use to qualify an identifier */ build( where: Where, filter: Filter, - relationNames: NestedRecord, + relationMeta: RelationsMeta, + klass: Class, + customFilters?: CustomFilterRegistry, alias?: string, ): Where { const { and, or } = filter; if (and && and.length) { - this.filterAnd(where, and, relationNames, alias); + this.filterAnd(where, and, relationMeta, klass, customFilters, alias); } if (or && or.length) { - this.filterOr(where, or, relationNames, alias); + this.filterOr(where, or, relationMeta, klass, customFilters, alias); } - return this.filterFields(where, filter, relationNames, alias); + return this.filterFields(where, filter, relationMeta, klass, customFilters, alias); } /** @@ -38,17 +43,23 @@ export class WhereBuilder { * * @param where - the `typeorm` WhereExpression * @param filters - the array of filters to AND together - * @param relationNames - the relations tree. + * @param relationMeta - the relations tree. + * @param klass - the class currently being processed + * @param customFilters - the custom filters map that this builder should process * @param alias - optional alias to use to qualify an identifier */ private filterAnd( where: Where, filters: Filter[], - relationNames: NestedRecord, + relationMeta: RelationsMeta, + klass: Class, + customFilters?: CustomFilterRegistry, alias?: string, ): Where { return where.andWhere( - new Brackets((qb) => filters.reduce((w, f) => qb.andWhere(this.createBrackets(f, relationNames, alias)), qb)), + new Brackets((qb) => + filters.reduce((w, f) => qb.andWhere(this.createBrackets(f, relationMeta, klass, customFilters, alias)), qb), + ), ); } @@ -57,17 +68,23 @@ export class WhereBuilder { * * @param where - the `typeorm` WhereExpression * @param filter - the array of filters to OR together - * @param relationNames - the relations tree. + * @param relationMeta - the relations tree. + * @param klass - the class currently being processed + * @param customFilters - the custom filters map that this builder should process * @param alias - optional alias to use to qualify an identifier */ private filterOr( where: Where, filter: Filter[], - relationNames: NestedRecord, + relationMeta: RelationsMeta, + klass: Class, + customFilters?: CustomFilterRegistry, alias?: string, ): Where { return where.andWhere( - new Brackets((qb) => filter.reduce((w, f) => qb.orWhere(this.createBrackets(f, relationNames, alias)), qb)), + new Brackets((qb) => + filter.reduce((w, f) => qb.orWhere(this.createBrackets(f, relationMeta, klass, customFilters, alias)), qb), + ), ); } @@ -78,33 +95,47 @@ export class WhereBuilder { * {a: { eq: 1 }, b: { gt: 2 } } // "((a = 1) AND (b > 2))" * ``` * @param filter - the filter to wrap in brackets. - * @param relationNames - the relations tree. + * @param relationMeta - the relations tree. + * @param klass - the class currently being processed + * @param customFilters - the custom filters map that this builder should process * @param alias - optional alias to use to qualify an identifier */ - private createBrackets(filter: Filter, relationNames: NestedRecord, alias?: string): Brackets { - return new Brackets((qb) => this.build(qb, filter, relationNames, alias)); + private createBrackets( + filter: Filter, + relationMeta: RelationsMeta, + klass: Class, + customFilters?: CustomFilterRegistry, + alias?: string, + ): Brackets { + return new Brackets((qb) => this.build(qb, filter, relationMeta, klass, customFilters, alias)); } /** * Creates field comparisons from a filter. This method will ignore and/or properties. * @param where - the `typeorm` WhereExpression * @param filter - the filter with fields to create comparisons for. - * @param relationNames - the relations tree. + * @param relationMeta - the relations tree. + * @param klass - the class currently being processed + * @param customFilters - the custom filters map that this builder should process * @param alias - optional alias to use to qualify an identifier */ private filterFields( where: Where, filter: Filter, - relationNames: NestedRecord, + relationMeta: RelationsMeta, + klass: Class, + customFilters?: CustomFilterRegistry, alias?: string, ): Where { return Object.keys(filter).reduce((w, field) => { if (field !== 'and' && field !== 'or') { return this.withFilterComparison( where, - field as keyof Entity, + field as keyof Entity & string, this.getField(filter, field as keyof Entity), - relationNames, + relationMeta, + klass, + customFilters, alias, ); } @@ -121,20 +152,40 @@ export class WhereBuilder { private withFilterComparison( where: Where, - field: T, + field: T & string, cmp: FilterFieldComparison, - relationNames: NestedRecord, + relationMeta: RelationsMeta, + klass: Class, + customFilters?: CustomFilterRegistry, alias?: string, ): Where { - if (relationNames[field as string]) { - return this.withRelationFilter(where, field, cmp as Filter, relationNames[field as string]); + if (relationMeta && relationMeta[field as string]) { + return this.withRelationFilter( + where, + field, + cmp as Filter, + relationMeta[field as string].relations, + relationMeta[field as string].targetKlass, + customFilters, + ); } return where.andWhere( new Brackets((qb) => { - const opts = Object.keys(cmp) as (keyof FilterFieldComparison)[]; - const sqlComparisons = opts.map((cmpType) => - this.sqlComparisonBuilder.build(field, cmpType, cmp[cmpType] as EntityComparisonField, alias), - ); + // Fallback sqlComparisonBuilder + const opts = Object.keys(cmp) as (keyof FilterFieldComparison & string)[]; + const sqlComparisons = opts.map((cmpType) => { + const customFilter = customFilters?.getFilter(klass, field, cmpType); + // If we have a registered customfilter for this cmpType, this has priority over the standard sqlComparisonBuilder + if (customFilter) { + return customFilter.apply(field, cmpType, cmp[cmpType], alias); + } + return this.sqlComparisonBuilder.build( + field, + cmpType, + cmp[cmpType] as EntityComparisonField, + alias, + ); + }); sqlComparisons.map(({ sql, params }) => qb.orWhere(sql, params)); }), ); @@ -144,12 +195,14 @@ export class WhereBuilder { where: Where, field: T, cmp: Filter, - relationNames: NestedRecord, + relationMeta: RelationsMeta, + klass: Class, + customFilters?: CustomFilterRegistry, ): Where { return where.andWhere( new Brackets((qb) => { const relationWhere = new WhereBuilder(); - return relationWhere.build(qb, cmp, relationNames, field as string); + return relationWhere.build(qb, cmp, relationMeta, klass, customFilters, field as string); }), ); }