Skip to content

Commit

Permalink
feat(aggregations): Add aggregation support to sequelize
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-martin committed Jul 16, 2020
1 parent 7233c23 commit c37b7ae
Show file tree
Hide file tree
Showing 6 changed files with 284 additions and 3 deletions.
92 changes: 92 additions & 0 deletions packages/query-sequelize/__tests__/query/aggregate.builder.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/* eslint-disable @typescript-eslint/naming-convention */
import { AggregateQuery } from '@nestjs-query/core';
import sequelize, { Projectable } from 'sequelize';
import { Test, TestingModule } from '@nestjs/testing';
import { SequelizeModule } from '@nestjs/sequelize';
import { Sequelize } from 'sequelize-typescript';
import { CONNECTION_OPTIONS } from '../__fixtures__/sequelize.fixture';
import { TestEntityTestRelationEntity } from '../__fixtures__/test-entity-test-relation.entity';
import { TestRelation } from '../__fixtures__/test-relation.entity';
import { TestEntity } from '../__fixtures__/test.entity';
import { AggregateBuilder } from '../../src/query';

describe('AggregateBuilder', (): void => {
let moduleRef: TestingModule;
const createAggregateBuilder = () => new AggregateBuilder<TestEntity>(TestEntity);

const assertSQL = (agg: AggregateQuery<TestEntity>, expected: Projectable): void => {
const actual = createAggregateBuilder().build(agg);
expect(actual).toEqual(expected);
};

afterEach(() => moduleRef.get(Sequelize).close());

beforeEach(async () => {
moduleRef = await Test.createTestingModule({
imports: [
SequelizeModule.forRoot(CONNECTION_OPTIONS),
SequelizeModule.forFeature([TestEntity, TestRelation, TestEntityTestRelationEntity]),
],
}).compile();
await moduleRef.get(Sequelize).sync();
});

it('should throw an error if no selects are generated', (): void => {
expect(() => createAggregateBuilder().build({})).toThrow('No aggregate fields found.');
});

it('or multiple operators for a single field together', (): void => {
assertSQL(
{
count: ['testEntityPk'],
avg: ['numberType'],
sum: ['numberType'],
max: ['stringType', 'dateType', 'numberType'],
min: ['stringType', 'dateType', 'numberType'],
},
{
attributes: [
[sequelize.fn('COUNT', sequelize.col('test_entity_pk')), 'COUNT_testEntityPk'],
[sequelize.fn('SUM', sequelize.col('number_type')), 'SUM_numberType'],
[sequelize.fn('AVG', sequelize.col('number_type')), 'AVG_numberType'],
[sequelize.fn('MAX', sequelize.col('string_type')), 'MAX_stringType'],
[sequelize.fn('MAX', sequelize.col('date_type')), 'MAX_dateType'],
[sequelize.fn('MAX', sequelize.col('number_type')), 'MAX_numberType'],
[sequelize.fn('MIN', sequelize.col('string_type')), 'MIN_stringType'],
[sequelize.fn('MIN', sequelize.col('date_type')), 'MIN_dateType'],
[sequelize.fn('MIN', sequelize.col('number_type')), 'MIN_numberType'],
],
},
);
});

describe('.convertToAggregateResponse', () => {
it('should convert a flat response into an Aggregtate response', () => {
const dbResult = {
COUNT_testEntityPk: 10,
SUM_numberType: 55,
AVG_numberType: 5,
MAX_stringType: 'z',
MAX_numberType: 10,
MIN_stringType: 'a',
MIN_numberType: 1,
};
expect(AggregateBuilder.convertToAggregateResponse<TestEntity>(dbResult)).toEqual({
count: { testEntityPk: 10 },
sum: { numberType: 55 },
avg: { numberType: 5 },
max: { stringType: 'z', numberType: 10 },
min: { stringType: 'a', numberType: 1 },
});
});

it('should throw an error if a column is not expected', () => {
const dbResult = {
COUNTtestEntityPk: 10,
};
expect(() => AggregateBuilder.convertToAggregateResponse<TestEntity>(dbResult)).toThrow(
'Unknown aggregate column encountered.',
);
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,82 @@ describe('SequelizeQueryService', (): void => {
});
});

describe('#aggregate', () => {
it('call select with the aggregate columns and return the result', async () => {
const queryService = moduleRef.get(TestEntityService);
const queryResult = await queryService.aggregate(
{},
{
count: ['testEntityPk'],
avg: ['numberType'],
sum: ['numberType'],
max: ['testEntityPk', 'dateType', 'numberType', 'stringType'],
min: ['testEntityPk', 'dateType', 'numberType', 'stringType'],
},
);
return expect(queryResult).toEqual({
avg: {
numberType: 5.5,
},
count: {
testEntityPk: 10,
},
max: {
dateType: expect.stringMatching('2020-02-10'),
numberType: 10,
stringType: 'foo9',
testEntityPk: 'test-entity-9',
},
min: {
dateType: expect.stringMatching('2020-02-01'),
numberType: 1,
stringType: 'foo1',
testEntityPk: 'test-entity-1',
},
sum: {
numberType: 55,
},
});
});

it('call select with the aggregate columns and return the result with a filter', async () => {
const queryService = moduleRef.get(TestEntityService);
const queryResult = await queryService.aggregate(
{ stringType: { in: ['foo1', 'foo2', 'foo3'] } },
{
count: ['testEntityPk'],
avg: ['numberType'],
sum: ['numberType'],
max: ['testEntityPk', 'dateType', 'numberType', 'stringType'],
min: ['testEntityPk', 'dateType', 'numberType', 'stringType'],
},
);
return expect(queryResult).toEqual({
avg: {
numberType: 2,
},
count: {
testEntityPk: 3,
},
max: {
dateType: expect.stringMatching('2020-02-03'),
numberType: 3,
stringType: 'foo3',
testEntityPk: 'test-entity-3',
},
min: {
dateType: expect.stringMatching('2020-02-01'),
numberType: 1,
stringType: 'foo1',
testEntityPk: 'test-entity-1',
},
sum: {
numberType: 6,
},
});
});
});

describe('#count', () => {
it('call select and return the result', async () => {
const queryService = moduleRef.get(TestEntityService);
Expand Down
74 changes: 74 additions & 0 deletions packages/query-sequelize/src/query/aggregate.builder.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import sequelize, { Projectable } from 'sequelize';
import { AggregateQuery, AggregateResponse } from '@nestjs-query/core';
import { Model, ModelCtor } from 'sequelize-typescript';
import { BadRequestException } from '@nestjs/common';

enum AggregateFuncs {
AVG = 'AVG',
SUM = 'SUM',
COUNT = 'COUNT',
MAX = 'MAX',
MIN = 'MIN',
}

const AGG_REGEXP = /(AVG|SUM|COUNT|MAX|MIN)_(.*)/;

/**
* @internal
* Builds a WHERE clause from a Filter.
*/
export class AggregateBuilder<Entity extends Model<Entity>> {
static convertToAggregateResponse<Entity>(response: Record<string, unknown>): AggregateResponse<Entity> {
return Object.keys(response).reduce((agg, resultField: string) => {
const matchResult = AGG_REGEXP.exec(resultField);
if (!matchResult) {
throw new Error('Unknown aggregate column encountered.');
}
const [matchedFunc, matchedFieldName] = matchResult.slice(1);
const aggFunc = matchedFunc.toLowerCase() as keyof AggregateResponse<Entity>;
const fieldName = matchedFieldName as keyof Entity;
const aggResult = agg[aggFunc] || {};
return {
...agg,
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
[aggFunc]: { ...aggResult, [fieldName]: response[resultField] },
};
}, {} as AggregateResponse<Entity>);
}

constructor(readonly model: ModelCtor<Entity>) {}

/**
* Builds a aggregate SELECT clause from a aggregate.
* @param qb - the `typeorm` SelectQueryBuilder
* @param aggregate - the aggregates to select.
* @param alias - optional alias to use to qualify an identifier
*/
build(aggregate: AggregateQuery<Entity>): Projectable {
const selects = [
...this.createAggSelect(AggregateFuncs.COUNT, aggregate.count),
...this.createAggSelect(AggregateFuncs.SUM, aggregate.sum),
...this.createAggSelect(AggregateFuncs.AVG, aggregate.avg),
...this.createAggSelect(AggregateFuncs.MAX, aggregate.max),
...this.createAggSelect(AggregateFuncs.MIN, aggregate.min),
];
if (!selects.length) {
throw new BadRequestException('No aggregate fields found.');
}
return {
attributes: selects,
};
}

private createAggSelect(func: AggregateFuncs, fields?: (keyof Entity)[]): [sequelize.Utils.Fn, string][] {
if (!fields) {
return [];
}
return fields.map((field) => {
const aggAlias = `${func}_${field as string}`;
const colName = this.model.rawAttributes[field as string].field;
const fn = sequelize.fn(func, sequelize.col(colName || (field as string)));
return [fn, aggAlias];
});
}
}
23 changes: 22 additions & 1 deletion packages/query-sequelize/src/query/filter-query.builder.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Filter, getFilterFields, Paging, Query, SortField } from '@nestjs-query/core';
import { AggregateQuery, Filter, getFilterFields, Paging, Query, SortField } from '@nestjs-query/core';
import {
FindOptions,
Filterable,
Expand All @@ -8,8 +8,10 @@ import {
UpdateOptions,
CountOptions,
Association,
Projectable,
} from 'sequelize';
import { Model, ModelCtor } from 'sequelize-typescript';
import { AggregateBuilder } from './aggregate.builder';
import { WhereBuilder } from './where.builder';

/**
Expand Down Expand Up @@ -40,6 +42,7 @@ export class FilterQueryBuilder<Entity extends Model<Entity>> {
constructor(
readonly model: ModelCtor<Entity>,
readonly whereBuilder: WhereBuilder<Entity> = new WhereBuilder<Entity>(),
readonly aggregateBuilder: AggregateBuilder<Entity> = new AggregateBuilder<Entity>(model),
) {}

/**
Expand All @@ -55,6 +58,18 @@ export class FilterQueryBuilder<Entity extends Model<Entity>> {
return opts;
}

/**
* Create a `sequelize` SelectQueryBuilder with `WHERE`, `ORDER BY` and `LIMIT/OFFSET` clauses.
*
* @param query - the query to apply.
*/
aggregateOptions(query: Query<Entity>, aggregate: AggregateQuery<Entity>): FindOptions {
let opts: FindOptions = {};
opts = this.applyAggregate(opts, aggregate);
opts = this.applyFilter(opts, query.filter);
return opts;
}

countOptions(query: Query<Entity>): CountOptions {
let opts: CountOptions = this.applyAssociationIncludes({}, query.filter);
opts.distinct = true;
Expand Down Expand Up @@ -144,6 +159,12 @@ export class FilterQueryBuilder<Entity extends Model<Entity>> {
return qb;
}

private applyAggregate<P extends Projectable>(opts: P, aggregate: AggregateQuery<Entity>): P {
// eslint-disable-next-line no-param-reassign
opts.attributes = this.aggregateBuilder.build(aggregate).attributes;
return opts;
}

private applyAssociationIncludes<Opts extends FindOptions | CountOptions>(
findOpts: Opts,
filter?: Filter<Entity>,
Expand Down
1 change: 1 addition & 0 deletions packages/query-sequelize/src/query/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export * from './filter-query.builder';
export * from './where.builder';
export * from './sql-comparison.builder';
export * from './aggregate.builder';
21 changes: 19 additions & 2 deletions packages/query-sequelize/src/services/sequelize-query.service.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import { Query, DeleteManyResponse, UpdateManyResponse, DeepPartial, QueryService, Filter } from '@nestjs-query/core';
import {
Query,
DeleteManyResponse,
UpdateManyResponse,
DeepPartial,
QueryService,
Filter,
AggregateQuery,
AggregateResponse,
} from '@nestjs-query/core';
import lodashPick from 'lodash.pick';
import { Model, ModelCtor } from 'sequelize-typescript';
import { WhereOptions } from 'sequelize';
import { NotFoundException } from '@nestjs/common';
import { FilterQueryBuilder } from '../query';
import { FilterQueryBuilder, AggregateBuilder } from '../query';
import { RelationQueryService } from './relation-query.service';

/**
Expand Down Expand Up @@ -48,6 +57,14 @@ export class SequelizeQueryService<Entity extends Model<Entity>> extends Relatio
return this.model.findAll<Entity>(this.filterQueryBuilder.findOptions(query));
}

async aggregate(filter: Filter<Entity>, aggregate: AggregateQuery<Entity>): Promise<AggregateResponse<Entity>> {
const result = await this.model.findOne(this.filterQueryBuilder.aggregateOptions({ filter }, aggregate));
if (!result) {
return {};
}
return AggregateBuilder.convertToAggregateResponse(result.get({ plain: true }) as Record<string, unknown>);
}

async count(filter: Filter<Entity>): Promise<number> {
return this.model.count(this.filterQueryBuilder.countOptions({ filter }));
}
Expand Down

0 comments on commit c37b7ae

Please sign in to comment.