diff --git a/backend/package.json b/backend/package.json index 13bcfd94..9e0b75b8 100644 --- a/backend/package.json +++ b/backend/package.json @@ -35,7 +35,6 @@ "@leoscope/openapi-response-validator": "^1.0.2", "@metlo/testing": "^0.0.3", "@types/async-retry": "^1.4.4", - "@types/newman": "^5.3.0", "@types/ssh2": "^1.11.5", "async-retry": "^1.3.3", "aws-sdk": "^2.1189.0", @@ -53,7 +52,6 @@ "json-source-map": "^0.6.1", "lodash": "^4.17.21", "luxon": "^3.0.3", - "memory-cache": "^0.2.0", "multer": "^1.4.5-lts.1", "node-schedule": "^2.1.0", "node-ssh": "^13.0.0", @@ -79,7 +77,6 @@ "@types/js-yaml": "^4.0.5", "@types/lodash": "^4.14.184", "@types/luxon": "^3.0.1", - "@types/memory-cache": "^0.2.2", "@types/multer": "^1.4.7", "@types/node": "^18.6.1", "@types/node-schedule": "^2.1.0", diff --git a/backend/src/analyze-traces.ts b/backend/src/analyze-traces.ts index 95762fea..1845f2fc 100644 --- a/backend/src/analyze-traces.ts +++ b/backend/src/analyze-traces.ts @@ -1,10 +1,9 @@ import { v4 as uuidv4 } from "uuid" import { AppDataSource } from "data-source" -import { ApiTrace, ApiEndpoint, DataField, Alert } from "models" +import { ApiTrace, ApiEndpoint, DataField, Alert, OpenApiSpec } from "models" import { DataFieldService } from "services/data-field" import { SpecService } from "services/spec" import { AlertService } from "services/alert" -import { DatabaseService } from "services/database" import { RedisClient } from "utils/redis" import { TRACES_QUEUE } from "~/constants" import { QueryRunner } from "typeorm" @@ -18,14 +17,20 @@ import { import { getPathTokens } from "@common/utils" import { AlertType } from "@common/enums" import { isGraphQlEndpoint } from "services/graphql" +import { isQueryFailedError, retryTypeormTransaction } from "utils/db" +import { MetloContext } from "types" +import { DatabaseService } from "services/database" +import { getEntityManager, getQB } from "services/database/utils" -const GET_ENDPOINT_QUERY = ` +const getEndpointQuery = (ctx: MetloContext) => ` SELECT endpoint. *, CASE WHEN spec."isAutoGenerated" IS NULL THEN NULL ELSE json_build_object('isAutoGenerated', spec."isAutoGenerated") END as "openapiSpec" FROM - "api_endpoint" endpoint - LEFT JOIN "open_api_spec" spec ON endpoint."openapiSpecName" = spec.name + ${ApiEndpoint.getTableName(ctx)} endpoint + LEFT JOIN ${OpenApiSpec.getTableName( + ctx, + )} spec ON endpoint."openapiSpecName" = spec.name WHERE $1 ~ "pathRegex" AND method = $2 @@ -39,7 +44,7 @@ LIMIT 1 ` -const GET_DATA_FIELDS_QUERY = ` +const getDataFieldsQuery = (ctx: MetloContext) => ` SELECT uuid, "dataClasses"::text[], @@ -53,14 +58,19 @@ SELECT "dataPath", "apiEndpointUuid" FROM - data_field + ${DataField.getTableName(ctx)} data_field WHERE "apiEndpointUuid" = $1 ` -const getQueuedApiTrace = async (): Promise => { +const getQueuedApiTrace = async ( + ctx: MetloContext, +): Promise => { try { - const traceString = await RedisClient.popValueFromRedisList(TRACES_QUEUE) + const traceString = await RedisClient.popValueFromRedisList( + ctx, + TRACES_QUEUE, + ) return JSON.parse(traceString) } catch (err) { return null @@ -68,6 +78,7 @@ const getQueuedApiTrace = async (): Promise => { } const analyze = async ( + ctx: MetloContext, trace: QueuedApiTrace, apiEndpoint: ApiEndpoint, queryRunner: QueryRunner, @@ -76,11 +87,13 @@ const analyze = async ( endpointUpdateDates(trace.createdAt, apiEndpoint) const dataFields = DataFieldService.findAllDataFields(trace, apiEndpoint) let alerts = await SpecService.findOpenApiSpecDiff( + ctx, trace, apiEndpoint, queryRunner, ) const sensitiveDataAlerts = await AlertService.createDataFieldAlerts( + ctx, dataFields, apiEndpoint.uuid, apiEndpoint.path, @@ -90,6 +103,7 @@ const analyze = async ( alerts = alerts?.concat(sensitiveDataAlerts) if (newEndpoint) { const newEndpointAlert = await AlertService.createAlert( + ctx, AlertType.NEW_ENDPOINT, apiEndpoint, ) @@ -99,18 +113,17 @@ const analyze = async ( } await queryRunner.startTransaction() - await DatabaseService.retryTypeormTransaction( + await retryTypeormTransaction( () => - queryRunner.manager.insert(ApiTrace, { + getEntityManager(ctx, queryRunner).insert(ApiTrace, { ...trace, apiEndpointUuid: apiEndpoint.uuid, }), 5, ) - await DatabaseService.retryTypeormTransaction( + await retryTypeormTransaction( () => - queryRunner.manager - .createQueryBuilder() + getQB(ctx, queryRunner) .insert() .into(DataField) .values(dataFields) @@ -127,10 +140,9 @@ const analyze = async ( .execute(), 5, ) - await DatabaseService.retryTypeormTransaction( + await retryTypeormTransaction( () => - queryRunner.manager - .createQueryBuilder() + getQB(ctx, queryRunner) .insert() .into(Alert) .values(alerts) @@ -138,10 +150,9 @@ const analyze = async ( .execute(), 5, ) - await DatabaseService.retryTypeormTransaction( + await retryTypeormTransaction( () => - queryRunner.manager - .createQueryBuilder() + getQB(ctx, queryRunner) .update(ApiEndpoint) .set({ firstDetected: apiEndpoint.firstDetected, @@ -156,6 +167,7 @@ const analyze = async ( } const generateEndpoint = async ( + ctx: MetloContext, trace: QueuedApiTrace, queryRunner: QueryRunner, ): Promise => { @@ -201,10 +213,9 @@ const generateEndpoint = async ( try { await queryRunner.startTransaction() - await DatabaseService.retryTypeormTransaction( + await retryTypeormTransaction( () => - queryRunner.manager - .createQueryBuilder() + getQB(ctx, queryRunner) .insert() .into(ApiEndpoint) .values(apiEndpoint) @@ -212,25 +223,25 @@ const generateEndpoint = async ( 5, ) await queryRunner.commitTransaction() - await analyze(trace, apiEndpoint, queryRunner, true) + await analyze(ctx, trace, apiEndpoint, queryRunner, true) } catch (err) { if (queryRunner.isTransactionActive) { await queryRunner.rollbackTransaction() } - if (DatabaseService.isQueryFailedError(err) && err.code === "23505") { - const existingEndpoint = await queryRunner.manager.findOne( - ApiEndpoint, - { - where: { - path: trace.path, - host: trace.host, - method: trace.method, - }, - relations: { dataFields: true }, + if (isQueryFailedError(err) && err.code === "23505") { + const existingEndpoint = await getEntityManager( + ctx, + queryRunner, + ).findOne(ApiEndpoint, { + where: { + path: trace.path, + host: trace.host, + method: trace.method, }, - ) + relations: { dataFields: true }, + }) if (existingEndpoint) { - await analyze(trace, existingEndpoint, queryRunner) + await analyze(ctx, trace, existingEndpoint, queryRunner) } } else { console.error(`Error generating new endpoint: ${err}`) @@ -240,6 +251,8 @@ const generateEndpoint = async ( } const analyzeTraces = async (): Promise => { + const ctx: MetloContext = {} + const datasource = await AppDataSource.initialize() if (!datasource.isInitialized) { console.error("Couldn't initialize datasource...") @@ -251,26 +264,26 @@ const analyzeTraces = async (): Promise => { await queryRunner.connect() while (true) { try { - const trace = await getQueuedApiTrace() + const trace = await getQueuedApiTrace(ctx) if (trace) { trace.createdAt = new Date(trace.createdAt) const apiEndpoint: ApiEndpoint = ( - await queryRunner.query(GET_ENDPOINT_QUERY, [ + await queryRunner.query(getEndpointQuery(ctx), [ trace.path, trace.method, trace.host, ]) )?.[0] if (apiEndpoint && !skipAutoGeneratedMatch(apiEndpoint, trace.path)) { - const dataFields: DataField[] = await queryRunner.query( - GET_DATA_FIELDS_QUERY, + const dataFields: DataField[] = await DatabaseService.executeRawQuery( + getDataFieldsQuery(ctx), [apiEndpoint.uuid], ) apiEndpoint.dataFields = dataFields - await analyze(trace, apiEndpoint, queryRunner) + await analyze(ctx, trace, apiEndpoint, queryRunner) } else { if (trace.responseStatus !== 404 && trace.responseStatus !== 405) { - await generateEndpoint(trace, queryRunner) + await generateEndpoint(ctx, trace, queryRunner) } } } diff --git a/backend/src/api/alert/index.ts b/backend/src/api/alert/index.ts index 1e21b717..3369c905 100644 --- a/backend/src/api/alert/index.ts +++ b/backend/src/api/alert/index.ts @@ -2,14 +2,15 @@ import { Request, Response } from "express" import { AlertService } from "services/alert" import { GetAlertParams, UpdateAlertParams } from "@common/types" import ApiResponseHandler from "api-response-handler" +import { MetloRequest } from "types" export const getAlertsHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { const alertParams: GetAlertParams = req.query - const alerts = await AlertService.getAlerts(alertParams) + const alerts = await AlertService.getAlerts(req.ctx, alertParams) await ApiResponseHandler.success(res, alerts) } catch (err) { await ApiResponseHandler.error(res, err) @@ -17,13 +18,14 @@ export const getAlertsHandler = async ( } export const updateAlertHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { const { alertId } = req.params const updateAlertParams: UpdateAlertParams = req.body const updatedAlert = await AlertService.updateAlert( + req.ctx, alertId, updateAlertParams, ) diff --git a/backend/src/api/alert/vulnerability.ts b/backend/src/api/alert/vulnerability.ts index c7e9adc3..62f62ec5 100644 --- a/backend/src/api/alert/vulnerability.ts +++ b/backend/src/api/alert/vulnerability.ts @@ -1,15 +1,16 @@ -import { Request, Response } from "express" +import { Response } from "express" import { GetVulnerabilityAggParams } from "@common/types" import ApiResponseHandler from "api-response-handler" import { getVulnerabilityAgg } from "services/summary/vulnerabilities" +import { MetloRequest } from "types" export const getVulnerabilitySummaryHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { const params: GetVulnerabilityAggParams = req.query - const out = await getVulnerabilityAgg(params) + const out = await getVulnerabilityAgg(req.ctx, params) await ApiResponseHandler.success(res, out) } catch (err) { await ApiResponseHandler.error(res, err) diff --git a/backend/src/api/connections/index.ts b/backend/src/api/connections/index.ts index 75dc8626..d22e5338 100644 --- a/backend/src/api/connections/index.ts +++ b/backend/src/api/connections/index.ts @@ -1,24 +1,27 @@ -import { Request, Response } from "express" +import { Response } from "express" import { ConnectionsService } from "services/connections" import ApiResponseHandler from "api-response-handler" import { decrypt } from "utils/encryption" import { delete_connection as delete_connection_request } from "suricata_setup/" import { ConnectionType } from "@common/enums" import { randomUUID } from "crypto" -import { addToRedis, addToRedisFromPromise } from "suricata_setup/utils" +import { RedisClient } from "utils/redis" +import { MetloRequest } from "types" -const listConnections = async (req: Request, res: Response) => { +const listConnections = async (req: MetloRequest, res: Response) => { try { - const connections = (await ConnectionsService.listConnections()).map(v => { - if (v.connectionType === ConnectionType.AWS) { - delete v.aws.keypair - delete v.aws.access_id - delete v.aws.secret_access_key - } else if (v.connectionType === ConnectionType.GCP) { - delete v.gcp.key_file - } - return v - }) + const connections = (await ConnectionsService.listConnections(req.ctx)).map( + v => { + if (v.connectionType === ConnectionType.AWS) { + delete v.aws.keypair + delete v.aws.access_id + delete v.aws.secret_access_key + } else if (v.connectionType === ConnectionType.GCP) { + delete v.gcp.key_file + } + return v + }, + ) await ApiResponseHandler.success(res, connections) } catch (err) { @@ -26,10 +29,13 @@ const listConnections = async (req: Request, res: Response) => { } } -const getConnectionForUuid = async (req: Request, res: Response) => { +const getConnectionForUuid = async (req: MetloRequest, res: Response) => { try { const { uuid } = req.params - const connection = await ConnectionsService.getConnectionForUuid(uuid) + const connection = await ConnectionsService.getConnectionForUuid( + req.ctx, + uuid, + ) delete connection.aws.keypair delete connection.aws.access_id @@ -41,10 +47,14 @@ const getConnectionForUuid = async (req: Request, res: Response) => { } } -const getSshKeyForConnectionUuid = async (req: Request, res: Response) => { +const getSshKeyForConnectionUuid = async (req: MetloRequest, res: Response) => { try { const { uuid } = req.params - const connection = await ConnectionsService.getConnectionForUuid(uuid, true) + const connection = await ConnectionsService.getConnectionForUuid( + req.ctx, + uuid, + true, + ) const ssh_key = decrypt( connection.aws.keypair, Buffer.from(process.env.ENCRYPTION_KEY, "base64"), @@ -57,22 +67,29 @@ const getSshKeyForConnectionUuid = async (req: Request, res: Response) => { } } -const updateConnection = async (req: Request, res: Response) => { +const updateConnection = async (req: MetloRequest, res: Response) => { try { const { name, id: uuid } = req.body - let resp = await ConnectionsService.updateConnectionForUuid({ name, uuid }) + let resp = await ConnectionsService.updateConnectionForUuid(req.ctx, { + name, + uuid, + }) await ApiResponseHandler.success(res, { name: name }) } catch (err) { await ApiResponseHandler.error(res, err) } } -const deleteConnection = async (req: Request, res: Response) => { +const deleteConnection = async (req: MetloRequest, res: Response) => { const { uuid } = req.params try { - const connection = await ConnectionsService.getConnectionForUuid(uuid, true) + const connection = await ConnectionsService.getConnectionForUuid( + req.ctx, + uuid, + true, + ) const retry_uuid = randomUUID() - await addToRedis(retry_uuid, { success: "FETCHING" }) + await RedisClient.addToRedis(req.ctx, retry_uuid, { success: "FETCHING" }) if (connection.connectionType === ConnectionType.AWS) { const access_key = decrypt( connection.aws.access_id, @@ -89,7 +106,8 @@ const deleteConnection = async (req: Request, res: Response) => { connection.aws.access_id = access_key connection.aws.secret_access_key = secret_access_key - addToRedisFromPromise( + RedisClient.addToRedisFromPromise( + req.ctx, retry_uuid, delete_connection_request(connection.connectionType, { ...connection.aws, @@ -97,7 +115,7 @@ const deleteConnection = async (req: Request, res: Response) => { name: connection.name, }) .then(() => { - return ConnectionsService.deleteConnectionForUuid({ + return ConnectionsService.deleteConnectionForUuid(req.ctx, { uuid: connection.uuid, }).then(() => ({ success: "OK", @@ -114,7 +132,8 @@ const deleteConnection = async (req: Request, res: Response) => { ) connection.gcp.key_file = key_file - addToRedisFromPromise( + RedisClient.addToRedisFromPromise( + req.ctx, retry_uuid, delete_connection_request(connection.connectionType, { ...connection.gcp, @@ -122,7 +141,7 @@ const deleteConnection = async (req: Request, res: Response) => { name: connection.name, }) .then(() => { - return ConnectionsService.deleteConnectionForUuid({ + return ConnectionsService.deleteConnectionForUuid(req.ctx, { uuid: connection.uuid, }).then(() => ({ success: "OK", diff --git a/backend/src/api/data-field/index.ts b/backend/src/api/data-field/index.ts index 44522f99..91fbcd6d 100644 --- a/backend/src/api/data-field/index.ts +++ b/backend/src/api/data-field/index.ts @@ -1,12 +1,13 @@ -import { Request, Response } from "express" +import { Response } from "express" import { DataFieldService } from "services/data-field" import { UpdateDataFieldClassesParams } from "@common/types" import ApiResponseHandler from "api-response-handler" import Error400BadRequest from "errors/error-400-bad-request" import { GetEndpointsService } from "services/get-endpoints" +import { MetloRequest } from "types" export const updateDataFieldClasses = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { @@ -23,6 +24,7 @@ export const updateDataFieldClasses = async ( throw new Error400BadRequest("No data section provided.") } const updatedDataField = await DataFieldService.updateDataClasses( + req.ctx, dataFieldId, dataClasses, dataPath, @@ -30,6 +32,7 @@ export const updateDataFieldClasses = async ( ) if (updatedDataField) { await GetEndpointsService.updateEndpointRiskScore( + req.ctx, updatedDataField.apiEndpointUuid, ) } @@ -40,14 +43,18 @@ export const updateDataFieldClasses = async ( } export const deleteDataFieldHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { const { dataFieldId } = req.params - const removedDataField = await DataFieldService.deleteDataField(dataFieldId) + const removedDataField = await DataFieldService.deleteDataField( + req.ctx, + dataFieldId, + ) if (removedDataField) { await GetEndpointsService.updateEndpointRiskScore( + req.ctx, removedDataField.apiEndpointUuid, ) } diff --git a/backend/src/api/data-field/sensitive-data.ts b/backend/src/api/data-field/sensitive-data.ts index 42227141..dab013a0 100644 --- a/backend/src/api/data-field/sensitive-data.ts +++ b/backend/src/api/data-field/sensitive-data.ts @@ -1,15 +1,16 @@ -import { Request, Response } from "express" -import { GetSensitiveDataAggParams, SensitiveDataSummary } from "@common/types" +import { Response } from "express" +import { GetSensitiveDataAggParams } from "@common/types" import ApiResponseHandler from "api-response-handler" import { getPIIDataTypeAgg } from "services/summary/piiData" +import { MetloRequest } from "types" export const getSensitiveDataSummaryHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { const params: GetSensitiveDataAggParams = req.query - const out = await getPIIDataTypeAgg(params) + const out = await getPIIDataTypeAgg(req.ctx, params) await ApiResponseHandler.success(res, out) } catch (err) { await ApiResponseHandler.error(res, err) diff --git a/backend/src/api/get-endpoints/index.ts b/backend/src/api/get-endpoints/index.ts index ff2c54f7..8e67c239 100644 --- a/backend/src/api/get-endpoints/index.ts +++ b/backend/src/api/get-endpoints/index.ts @@ -1,17 +1,21 @@ -import { Request, Response } from "express" +import { Response } from "express" import validator from "validator" import { GetEndpointsService } from "services/get-endpoints" import { GetEndpointParams } from "@common/types" import ApiResponseHandler from "api-response-handler" import Error404NotFound from "errors/error-404-not-found" +import { MetloRequest } from "types" export const getEndpointsHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { const getEndpointParams: GetEndpointParams = req.query try { - const endpoints = await GetEndpointsService.getEndpoints(getEndpointParams) + const endpoints = await GetEndpointsService.getEndpoints( + req.ctx, + getEndpointParams, + ) await ApiResponseHandler.success(res, endpoints) } catch (err) { await ApiResponseHandler.error(res, err) @@ -19,7 +23,7 @@ export const getEndpointsHandler = async ( } export const getEndpointHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { @@ -27,7 +31,7 @@ export const getEndpointHandler = async ( if (!validator.isUUID(endpointId)) { throw new Error404NotFound("Endpoint does not exist.") } - const endpoint = await GetEndpointsService.getEndpoint(endpointId) + const endpoint = await GetEndpointsService.getEndpoint(req.ctx, endpointId) await ApiResponseHandler.success(res, endpoint) } catch (err) { await ApiResponseHandler.error(res, err) @@ -35,11 +39,11 @@ export const getEndpointHandler = async ( } export const getHostsHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { - const hosts = await GetEndpointsService.getHosts() + const hosts = await GetEndpointsService.getHosts(req.ctx) await ApiResponseHandler.success(res, hosts) } catch (err) { await ApiResponseHandler.error(res, err) @@ -47,7 +51,7 @@ export const getHostsHandler = async ( } export const getUsageHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { @@ -55,7 +59,7 @@ export const getUsageHandler = async ( if (!validator.isUUID(endpointId)) { throw new Error404NotFound("Endpoint does not exist.") } - const usageData = await GetEndpointsService.getUsage(endpointId) + const usageData = await GetEndpointsService.getUsage(req.ctx, endpointId) await ApiResponseHandler.success(res, usageData) } catch (err) { await ApiResponseHandler.error(res, err) @@ -63,13 +67,14 @@ export const getUsageHandler = async ( } export const updateEndpointIsAuthenticated = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { const { endpointId } = req.params const params: { authenticated: boolean } = req.body await GetEndpointsService.updateIsAuthenticated( + req.ctx, endpointId, params.authenticated, ) diff --git a/backend/src/api/keys/index.ts b/backend/src/api/keys/index.ts index 1965a028..3425a2c1 100644 --- a/backend/src/api/keys/index.ts +++ b/backend/src/api/keys/index.ts @@ -1,14 +1,18 @@ import ApiResponseHandler from "api-response-handler" -import { Request, Response } from "express" -import { AppDataSource } from "data-source" +import { Response } from "express" import { ApiKey } from "models" import { ApiKey as ApiKeyType } from "@common/types" import Error404NotFound from "errors/error-404-not-found" import { createApiKey } from "./service" import Error400BadRequest from "errors/error-400-bad-request" +import { createQB, getRepository } from "services/database/utils" +import { MetloRequest } from "types" -export const listKeys = async (req: Request, res: Response): Promise => { - const keys = await AppDataSource.getRepository(ApiKey).find() +export const listKeys = async ( + req: MetloRequest, + res: Response, +): Promise => { + const keys = await getRepository(req.ctx, ApiKey).find() return ApiResponseHandler.success( res, keys.map(v => ({ @@ -20,9 +24,12 @@ export const listKeys = async (req: Request, res: Response): Promise => { ) } -export const createKey = async (req: Request, res: Response): Promise => { +export const createKey = async ( + req: MetloRequest, + res: Response, +): Promise => { const { name: keyName } = req.body - const key_exists = await AppDataSource.getRepository(ApiKey).countBy({ + const key_exists = await getRepository(req.ctx, ApiKey).countBy({ name: keyName, }) if (key_exists) { @@ -38,7 +45,7 @@ export const createKey = async (req: Request, res: Response): Promise => { ) } const [key, rawKey] = createApiKey(keyName) - await AppDataSource.getRepository(ApiKey).save(key) + await getRepository(req.ctx, ApiKey).save(key) return ApiResponseHandler.success(res, { apiKey: rawKey, name: key.name, @@ -48,10 +55,13 @@ export const createKey = async (req: Request, res: Response): Promise => { }) } -export const deleteKey = async (req: Request, res: Response): Promise => { +export const deleteKey = async ( + req: MetloRequest, + res: Response, +): Promise => { const { name: keyName } = req.params - let del_resp = await AppDataSource.createQueryBuilder() + let del_resp = await createQB(req.ctx) .delete() .from(ApiKey) .where("name = :name", { name: keyName }) diff --git a/backend/src/api/settings/index.ts b/backend/src/api/settings/index.ts index fc2aec42..e77b3407 100644 --- a/backend/src/api/settings/index.ts +++ b/backend/src/api/settings/index.ts @@ -1,15 +1,15 @@ -import { Request, Response } from "express" +import { Response } from "express" import ApiResponseHandler from "api-response-handler" -import { AppDataSource } from "data-source" import { InstanceSettings } from "models" +import { getRepository } from "services/database/utils" +import { MetloRequest } from "types" export const getInstanceSettingsHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { - const instanceSettingsRepository = - AppDataSource.getRepository(InstanceSettings) + const instanceSettingsRepository = getRepository(req.ctx, InstanceSettings) const instanceSettings = await instanceSettingsRepository.findOne({ where: {}, }) @@ -20,13 +20,12 @@ export const getInstanceSettingsHandler = async ( } export const putInstanceSettingsHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { const { email, skip } = req.body - const instanceSettingsRepository = - AppDataSource.getRepository(InstanceSettings) + const instanceSettingsRepository = getRepository(req.ctx, InstanceSettings) const instanceSettings = await instanceSettingsRepository.findOne({ where: {}, }) diff --git a/backend/src/api/setup/index.ts b/backend/src/api/setup/index.ts index d6bd00f3..d7db1807 100644 --- a/backend/src/api/setup/index.ts +++ b/backend/src/api/setup/index.ts @@ -1,16 +1,17 @@ -import { Request, Response } from "express" +import { Response } from "express" import ApiResponseHandler from "api-response-handler" -import { AWS_CONNECTION, SSH_INFO, STEP_RESPONSE } from "@common/types" +import { STEP_RESPONSE } from "@common/types" import { ConnectionType } from "@common/enums" import { setup } from "suricata_setup" import "express-session" import { EC2_CONN } from "suricata_setup/aws-services/create-ec2-instance" import { VirtualizationType } from "@aws-sdk/client-ec2" -import { deleteKeyFromRedis, getFromRedis } from "suricata_setup/utils" import { list_images, list_machines, } from "suricata_setup/gcp-services/gcp_setup" +import { MetloRequest } from "types" +import { RedisClient } from "utils/redis" declare module "express-session" { interface SessionData { @@ -28,7 +29,7 @@ declare module "express-session" { } export const setupConnection = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { @@ -50,7 +51,7 @@ export const setupConnection = async ( ...params, id: id, } - let resp = await setup(step, type, combined_params) + let resp = await setup(req.ctx, step, type, combined_params) req.session.connection_config[id] = { ...req.session.connection_config[id], ...resp, @@ -65,7 +66,7 @@ export const setupConnection = async ( } export const awsOsChoices = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { const { id } = req.body @@ -79,7 +80,7 @@ export const awsOsChoices = async ( } export const gcpOsChoices = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { @@ -96,7 +97,7 @@ export const gcpOsChoices = async ( } export const awsInstanceChoices = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { @@ -119,7 +120,7 @@ export const awsInstanceChoices = async ( } export const gcpInstanceChoices = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { @@ -145,12 +146,12 @@ export const gcpInstanceChoices = async ( } } -export const getLongRunningState = async (req: Request, res: Response) => { +export const getLongRunningState = async (req: MetloRequest, res: Response) => { const { uuid } = req.params try { - let resp: STEP_RESPONSE = await getFromRedis(uuid) + let resp: STEP_RESPONSE = await RedisClient.getFromRedis(req.ctx, uuid) if (["OK", "FAIL"].includes(resp.success)) { - await deleteKeyFromRedis(uuid) + await RedisClient.deleteKeyFromRedis(req.ctx, uuid) } try { // try to add things to connection cache if they exist diff --git a/backend/src/api/spec/index.ts b/backend/src/api/spec/index.ts index 885a34ad..40ea9185 100644 --- a/backend/src/api/spec/index.ts +++ b/backend/src/api/spec/index.ts @@ -8,14 +8,16 @@ import { AppDataSource } from "data-source" import { OpenApiSpec } from "models" import { SpecExtension } from "@common/enums" import { EXTENSION_TO_MIME_TYPE } from "~/constants" +import { MetloRequest } from "types" +import { getRepository } from "services/database/utils" export const getSpecHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { const { specFileName } = req.params - const spec = await SpecService.getSpec(specFileName) + const spec = await SpecService.getSpec(req.ctx, specFileName) await ApiResponseHandler.success(res, spec) } catch (err) { await ApiResponseHandler.error(res, err) @@ -23,13 +25,13 @@ export const getSpecHandler = async ( } export const getSpecListHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { const { listAutoGenerated: listAutoGeneratedParam } = req.query const listAutoGenerated = listAutoGeneratedParam === "true" ? true : false - const specList = await SpecService.getSpecs(listAutoGenerated) + const specList = await SpecService.getSpecs(req.ctx, listAutoGenerated) await ApiResponseHandler.success(res, specList) } catch (err) { await ApiResponseHandler.error(res, err) @@ -37,7 +39,7 @@ export const getSpecListHandler = async ( } export const uploadNewSpecHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { @@ -66,7 +68,7 @@ export const uploadNewSpecHandler = async ( if (!fileName) { throw new Error400BadRequest("No filename provided.") } - const openApiSpecRepository = AppDataSource.getRepository(OpenApiSpec) + const openApiSpecRepository = getRepository(req.ctx, OpenApiSpec) const exisitingSpec = await openApiSpecRepository.findOneBy({ name: fileName, }) @@ -77,6 +79,7 @@ export const uploadNewSpecHandler = async ( specFile.buffer.toString(), ) as JSONValue await SpecService.uploadNewSpec( + req.ctx, specObject, fileName, extension, @@ -89,12 +92,12 @@ export const uploadNewSpecHandler = async ( } export const deleteSpecHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { const { specFileName } = req.params - await SpecService.deleteSpec(specFileName) + await SpecService.deleteSpec(req.ctx, specFileName) await ApiResponseHandler.success(res, null) } catch (err) { await ApiResponseHandler.error(res, err) @@ -102,7 +105,7 @@ export const deleteSpecHandler = async ( } export const updateSpecHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { @@ -125,7 +128,7 @@ export const updateSpecHandler = async ( if (!specFileName) { throw new Error400BadRequest("No filename provided.") } - const openApiSpecRepository = AppDataSource.getRepository(OpenApiSpec) + const openApiSpecRepository = getRepository(req.ctx, OpenApiSpec) const exisitingSpec = await openApiSpecRepository.findOneBy({ name: specFileName, }) @@ -136,6 +139,7 @@ export const updateSpecHandler = async ( specFile.buffer.toString(), ) as JSONValue await SpecService.updateSpec( + req.ctx, specObject, specFileName, extension, diff --git a/backend/src/api/summary/index.ts b/backend/src/api/summary/index.ts index 45351406..b611b6ec 100644 --- a/backend/src/api/summary/index.ts +++ b/backend/src/api/summary/index.ts @@ -1,13 +1,14 @@ -import { Request, Response } from "express" -import { SummaryService } from "services/summary" +import { Response } from "express" +import { getSummaryData } from "services/summary" import ApiResponseHandler from "api-response-handler" +import { MetloRequest } from "types" export const getSummaryHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { - const summaryResponse = await SummaryService.getSummaryData() + const summaryResponse = await getSummaryData(req.ctx) await ApiResponseHandler.success(res, summaryResponse) } catch (err) { await ApiResponseHandler.error(res, err) diff --git a/backend/src/api/tests/index.ts b/backend/src/api/tests/index.ts index 6a003dd9..43f641e1 100644 --- a/backend/src/api/tests/index.ts +++ b/backend/src/api/tests/index.ts @@ -1,17 +1,21 @@ -import { Request, Response } from "express" +import { Response } from "express" import ApiResponseHandler from "api-response-handler" import { runTest } from "@metlo/testing" -import { AppDataSource } from "data-source" import { ApiEndpointTest } from "models" import { GetEndpointsService } from "services/get-endpoints" +import { getRepoQB } from "services/database/utils" +import { MetloRequest } from "types" export const runTestHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { try { const { test, endpointUuid } = req.body - const endpoint = await GetEndpointsService.getEndpoint(endpointUuid) + const endpoint = await GetEndpointsService.getEndpoint( + req.ctx, + endpointUuid, + ) let envVars = new Map() envVars.set("baseUrl", `https://${endpoint.host}`) const testRes = await runTest(test, envVars) @@ -21,13 +25,15 @@ export const runTestHandler = async ( } } -export const saveTest = async (req: Request, res: Response): Promise => { +export const saveTest = async ( + req: MetloRequest, + res: Response, +): Promise => { const { test: { uuid, name, tags, requests }, endpointUuid, } = req.body - let testInsert = await AppDataSource.getRepository(ApiEndpointTest) - .createQueryBuilder() + let testInsert = await getRepoQB(req.ctx, ApiEndpointTest) .insert() .into(ApiEndpointTest) .values({ @@ -41,19 +47,20 @@ export const saveTest = async (req: Request, res: Response): Promise => { }) .orUpdate(["name", "tags", "requests"], ["uuid"]) .execute() - let resp = await AppDataSource.getRepository(ApiEndpointTest) - .createQueryBuilder() + let resp = await getRepoQB(req.ctx, ApiEndpointTest) .select() .where("uuid = :uuid", testInsert.identifiers[0]) .getOne() await ApiResponseHandler.success(res, resp) } -export const getTest = async (req: Request, res: Response): Promise => { +export const getTest = async ( + req: MetloRequest, + res: Response, +): Promise => { const { uuid } = req.params try { - let resp = await AppDataSource.getRepository(ApiEndpointTest) - .createQueryBuilder() + let resp = await getRepoQB(req.ctx, ApiEndpointTest) .select() .where("uuid = :uuid", { uuid }) .getOne() @@ -63,12 +70,14 @@ export const getTest = async (req: Request, res: Response): Promise => { } } -export const listTests = async (req: Request, res: Response): Promise => { +export const listTests = async ( + req: MetloRequest, + res: Response, +): Promise => { const { hostname } = req.query var resp: ApiEndpointTest[] try { - let partial_resp = AppDataSource.getRepository(ApiEndpointTest) - .createQueryBuilder("test") + let partial_resp = getRepoQB(req.ctx, ApiEndpointTest, "test") .select() .leftJoinAndSelect("test.apiEndpoint", "endpoint") if (hostname) { @@ -86,14 +95,13 @@ export const listTests = async (req: Request, res: Response): Promise => { } export const deleteTest = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { const { uuid } = req.params try { - let resp = await AppDataSource.getRepository(ApiEndpointTest) - .createQueryBuilder() + let resp = await getRepoQB(req.ctx, ApiEndpointTest) .delete() .from(ApiEndpointTest) .where("uuid = :uuid", { uuid }) diff --git a/backend/src/collector.ts b/backend/src/collector.ts index 05333037..ce010b18 100644 --- a/backend/src/collector.ts +++ b/backend/src/collector.ts @@ -1,28 +1,26 @@ -import express, { Express, Request, Response } from "express" import dotenv from "dotenv" -import yaml from "js-yaml" -import fs from "fs" +dotenv.config() + +import express, { Express, Request, Response } from "express" import { AppDataSource } from "data-source" import { logRequestBatchHandler, logRequestSingleHandler, } from "collector_src/log-request" import { verifyApiKeyMiddleware } from "middleware/verify-api-key-middleware" -import { AuthenticationConfig } from "models" -import { getPathRegex } from "utils" -import { AuthType, DisableRestMethod } from "@common/enums" import { bodyParserMiddleware } from "middleware/body-parser-middleware" -import { AUTH_CONFIG_LIST_KEY, BLOCK_FIELDS_ALL_REGEX } from "./constants" -import { BlockFieldsService } from "services/block-fields" -import { BlockFieldEntry } from "@common/types" -import { RedisClient } from "utils/redis" - -dotenv.config() +import { populateBlockFields } from "collector_src/block-fields" +import { populateAuthentication } from "collector_src/authentication" +import { MetloContext, MetloRequest } from "types" const app: Express = express() const port = process.env.PORT || 8081 - app.disable("x-powered-by") + +app.use(async (req: MetloRequest, res, next) => { + req.ctx = {} + next() +}) app.use(express.json({ limit: "250mb" })) app.use(express.urlencoded({ limit: "250mb", extended: true })) @@ -36,169 +34,6 @@ app.use(bodyParserMiddleware) app.post("/api/v1/log-request/single", logRequestSingleHandler) app.post("/api/v1/log-request/batch", logRequestBatchHandler) -const addToBlockFields = ( - blockFieldsEntries: Record, - host: string, - method: DisableRestMethod, - path: string, - pathRegex: string, - disabledPaths: string[], -) => { - const disabledPathsObj = { - reqQuery: [], - reqHeaders: [], - reqBody: [], - resHeaders: [], - resBody: [], - } - disabledPaths.forEach(path => { - if (path.includes("req.query")) disabledPathsObj.reqQuery.push(path) - else if (path.includes("req.headers")) - disabledPathsObj.reqHeaders.push(path) - else if (path.includes("req.body")) disabledPathsObj.reqBody.push(path) - else if (path.includes("res.headers")) - disabledPathsObj.resHeaders.push(path) - else if (path.includes("res.body")) disabledPathsObj.resBody.push(path) - }) - const entry = { - method, - path, - pathRegex, - disabledPaths: disabledPathsObj, - numberParams: BlockFieldsService.getNumberParams(pathRegex, method, path), - } - if (blockFieldsEntries[host]) { - blockFieldsEntries[host].push(entry) - } else { - blockFieldsEntries[host] = [entry] - } -} - -const populateBlockFields = async () => { - try { - const metloConfig: object = yaml.load( - fs.readFileSync("./metlo-config.yaml", "utf-8"), - ) as object - const blockFieldsDoc = metloConfig?.["blockFields"] - const blockFieldsEntries: Record = {} - if (blockFieldsDoc) { - for (const host in blockFieldsDoc) { - const hostObj = blockFieldsDoc[host] - let allDisablePaths = [] - if (hostObj) { - if (hostObj["ALL"]) { - allDisablePaths = hostObj["ALL"]["disable_paths"] ?? [] - const pathRegex = BLOCK_FIELDS_ALL_REGEX - const path = "/" - addToBlockFields( - blockFieldsEntries, - host, - DisableRestMethod.ALL, - path, - pathRegex, - allDisablePaths, - ) - } - for (const endpoint in hostObj) { - if (endpoint && endpoint !== "ALL") { - let endpointDisablePaths = allDisablePaths - if (hostObj[endpoint]["ALL"]) { - endpointDisablePaths = endpointDisablePaths?.concat( - hostObj[endpoint]["ALL"]["disable_paths"] ?? [], - ) - const pathRegex = getPathRegex(endpoint) - addToBlockFields( - blockFieldsEntries, - host, - DisableRestMethod.ALL, - endpoint, - pathRegex, - endpointDisablePaths, - ) - } - for (const method in hostObj[endpoint]) { - if (method && method !== "ALL") { - const blockFieldMethod = DisableRestMethod[method] - const pathRegex = getPathRegex(endpoint) - const disabledPaths = endpointDisablePaths?.concat( - hostObj[endpoint][method]?.["disable_paths"] ?? [], - ) - addToBlockFields( - blockFieldsEntries, - host, - blockFieldMethod, - endpoint, - pathRegex, - disabledPaths, - ) - } - } - } - } - } - } - } - BlockFieldsService.entries = blockFieldsEntries - } catch (err) { - console.error(`Error in populating metlo config blockFields: ${err}`) - } -} - -const populateAuthentication = async () => { - const key = process.env.ENCRYPTION_KEY - if (!key) { - console.error(`No ENCRYPTION_KEY found. Cannot set authentication config.`) - return - } - const queryRunner = AppDataSource.createQueryRunner() - await queryRunner.connect() - try { - await queryRunner.startTransaction() - const metloConfig: object = yaml.load( - fs.readFileSync("./metlo-config.yaml", "utf-8"), - ) as object - const authConfigDoc = metloConfig?.["authentication"] - const authConfigEntries: AuthenticationConfig[] = [] - const currAuthConfigEntries = await RedisClient.getValuesFromSet( - AUTH_CONFIG_LIST_KEY, - ) - if (authConfigDoc) { - authConfigDoc.forEach(item => { - const newConfig = new AuthenticationConfig() - newConfig.host = item.host - newConfig.authType = item.authType as AuthType - if (item.headerKey) newConfig.headerKey = item.headerKey - if (item.jwtUserPath) newConfig.jwtUserPath = item.jwtUserPath - if (item.cookieName) newConfig.cookieName = item.cookieName - authConfigEntries.push(newConfig) - }) - } - const deleteQb = queryRunner.manager - .createQueryBuilder() - .delete() - .from(AuthenticationConfig) - const addQb = queryRunner.manager - .createQueryBuilder() - .insert() - .into(AuthenticationConfig) - .values(authConfigEntries) - await deleteQb.execute() - await addQb.execute() - await queryRunner.commitTransaction() - if (currAuthConfigEntries) { - await RedisClient.deleteFromRedis([ - ...currAuthConfigEntries, - AUTH_CONFIG_LIST_KEY, - ]) - } - } catch (err) { - console.error(`Error in populating metlo config authentication: ${err}`) - await queryRunner.rollbackTransaction() - } finally { - await queryRunner?.release() - } -} - const main = async () => { try { const datasource = await AppDataSource.initialize() @@ -210,7 +45,8 @@ const main = async () => { app.listen(port, () => { console.log(`⚡️[server]: Server is running at http://localhost:${port}`) }) - await Promise.all([populateBlockFields(), populateAuthentication()]) + const ctx: MetloContext = {} + await Promise.all([populateBlockFields(), populateAuthentication(ctx)]) } catch (err) { console.error(`CatchBlockInsideMain: ${err}`) } diff --git a/backend/src/collector_src/authentication.ts b/backend/src/collector_src/authentication.ts new file mode 100644 index 00000000..c78dc355 --- /dev/null +++ b/backend/src/collector_src/authentication.ts @@ -0,0 +1,61 @@ +import yaml from "js-yaml" +import fs from "fs" +import { AppDataSource } from "data-source" +import { AuthenticationConfig } from "models" +import { AuthType } from "@common/enums" +import { RedisClient } from "utils/redis" +import { AUTH_CONFIG_LIST_KEY } from "~/constants" +import { MetloContext } from "types" +import { getQB } from "services/database/utils" + +export const populateAuthentication = async (ctx: MetloContext) => { + const key = process.env.ENCRYPTION_KEY + if (!key) { + console.error(`No ENCRYPTION_KEY found. Cannot set authentication config.`) + return + } + const queryRunner = AppDataSource.createQueryRunner() + await queryRunner.connect() + try { + await queryRunner.startTransaction() + const metloConfig: object = yaml.load( + fs.readFileSync("./metlo-config.yaml", "utf-8"), + ) as object + const authConfigDoc = metloConfig?.["authentication"] + const authConfigEntries: AuthenticationConfig[] = [] + const currAuthConfigEntries = await RedisClient.getValuesFromSet( + ctx, + AUTH_CONFIG_LIST_KEY, + ) + if (authConfigDoc) { + authConfigDoc.forEach(item => { + const newConfig = new AuthenticationConfig() + newConfig.host = item.host + newConfig.authType = item.authType as AuthType + if (item.headerKey) newConfig.headerKey = item.headerKey + if (item.jwtUserPath) newConfig.jwtUserPath = item.jwtUserPath + if (item.cookieName) newConfig.cookieName = item.cookieName + authConfigEntries.push(newConfig) + }) + } + const deleteQb = getQB(ctx, queryRunner).delete().from(AuthenticationConfig) + const addQb = getQB(ctx, queryRunner) + .insert() + .into(AuthenticationConfig) + .values(authConfigEntries) + await deleteQb.execute() + await addQb.execute() + await queryRunner.commitTransaction() + if (currAuthConfigEntries) { + await RedisClient.deleteFromRedis(ctx, [ + ...currAuthConfigEntries, + AUTH_CONFIG_LIST_KEY, + ]) + } + } catch (err) { + console.error(`Error in populating metlo config authentication: ${err}`) + await queryRunner.rollbackTransaction() + } finally { + await queryRunner?.release() + } +} diff --git a/backend/src/collector_src/block-fields.ts b/backend/src/collector_src/block-fields.ts new file mode 100644 index 00000000..974bed52 --- /dev/null +++ b/backend/src/collector_src/block-fields.ts @@ -0,0 +1,115 @@ +import yaml from "js-yaml" +import fs from "fs" +import { getPathRegex } from "utils" +import { DisableRestMethod } from "@common/enums" +import { BlockFieldsService } from "services/block-fields" +import { BlockFieldEntry } from "@common/types" +import { BLOCK_FIELDS_ALL_REGEX } from "~/constants" + +const addToBlockFields = ( + blockFieldsEntries: Record, + host: string, + method: DisableRestMethod, + path: string, + pathRegex: string, + disabledPaths: string[], +) => { + const disabledPathsObj = { + reqQuery: [], + reqHeaders: [], + reqBody: [], + resHeaders: [], + resBody: [], + } + disabledPaths.forEach(path => { + if (path.includes("req.query")) disabledPathsObj.reqQuery.push(path) + else if (path.includes("req.headers")) + disabledPathsObj.reqHeaders.push(path) + else if (path.includes("req.body")) disabledPathsObj.reqBody.push(path) + else if (path.includes("res.headers")) + disabledPathsObj.resHeaders.push(path) + else if (path.includes("res.body")) disabledPathsObj.resBody.push(path) + }) + const entry = { + method, + path, + pathRegex, + disabledPaths: disabledPathsObj, + numberParams: BlockFieldsService.getNumberParams(pathRegex, method, path), + } + if (blockFieldsEntries[host]) { + blockFieldsEntries[host].push(entry) + } else { + blockFieldsEntries[host] = [entry] + } +} + +export const populateBlockFields = async () => { + try { + const metloConfig: object = yaml.load( + fs.readFileSync("./metlo-config.yaml", "utf-8"), + ) as object + const blockFieldsDoc = metloConfig?.["blockFields"] + const blockFieldsEntries: Record = {} + if (blockFieldsDoc) { + for (const host in blockFieldsDoc) { + const hostObj = blockFieldsDoc[host] + let allDisablePaths = [] + if (hostObj) { + if (hostObj["ALL"]) { + allDisablePaths = hostObj["ALL"]["disable_paths"] ?? [] + const pathRegex = BLOCK_FIELDS_ALL_REGEX + const path = "/" + addToBlockFields( + blockFieldsEntries, + host, + DisableRestMethod.ALL, + path, + pathRegex, + allDisablePaths, + ) + } + for (const endpoint in hostObj) { + if (endpoint && endpoint !== "ALL") { + let endpointDisablePaths = allDisablePaths + if (hostObj[endpoint]["ALL"]) { + endpointDisablePaths = endpointDisablePaths?.concat( + hostObj[endpoint]["ALL"]["disable_paths"] ?? [], + ) + const pathRegex = getPathRegex(endpoint) + addToBlockFields( + blockFieldsEntries, + host, + DisableRestMethod.ALL, + endpoint, + pathRegex, + endpointDisablePaths, + ) + } + for (const method in hostObj[endpoint]) { + if (method && method !== "ALL") { + const blockFieldMethod = DisableRestMethod[method] + const pathRegex = getPathRegex(endpoint) + const disabledPaths = endpointDisablePaths?.concat( + hostObj[endpoint][method]?.["disable_paths"] ?? [], + ) + addToBlockFields( + blockFieldsEntries, + host, + blockFieldMethod, + endpoint, + pathRegex, + disabledPaths, + ) + } + } + } + } + } + } + } + BlockFieldsService.entries = blockFieldsEntries + } catch (err) { + console.error(`Error in populating metlo config blockFields: ${err}`) + } +} diff --git a/backend/src/collector_src/log-request/index.ts b/backend/src/collector_src/log-request/index.ts index 32b98fd4..88e08f00 100644 --- a/backend/src/collector_src/log-request/index.ts +++ b/backend/src/collector_src/log-request/index.ts @@ -1,15 +1,16 @@ -import { Request, Response } from "express" +import { Response } from "express" import { LogRequestService } from "services/log-request" import { TraceParams } from "@common/types" import ApiResponseHandler from "api-response-handler" +import { MetloRequest } from "types" export const logRequestSingleHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { const traceParams: TraceParams = req.body try { - await LogRequestService.logRequest(traceParams) + await LogRequestService.logRequest(req.ctx, traceParams) await ApiResponseHandler.success(res, null) } catch (err) { await ApiResponseHandler.error(res, err) @@ -17,12 +18,12 @@ export const logRequestSingleHandler = async ( } export const logRequestBatchHandler = async ( - req: Request, + req: MetloRequest, res: Response, ): Promise => { const traceParamsBatch: TraceParams[] = req.body try { - await LogRequestService.logRequestBatch(traceParamsBatch) + await LogRequestService.logRequestBatch(req.ctx, traceParamsBatch) await ApiResponseHandler.success(res, null) } catch (err) { await ApiResponseHandler.error(res, err) diff --git a/backend/src/index.ts b/backend/src/index.ts index 88a28a7e..2a8cfaa8 100644 --- a/backend/src/index.ts +++ b/backend/src/index.ts @@ -1,7 +1,7 @@ import dotenv from "dotenv" dotenv.config() -import express, { Express, Request, Response } from "express" +import express, { Express, Response } from "express" import { TypeormStore } from "connect-typeorm" import session from "express-session" import { InstanceSettings, Session as SessionModel } from "models" @@ -22,6 +22,7 @@ import { import { getAlertsHandler, updateAlertHandler } from "api/alert" import { deleteDataFieldHandler, updateDataFieldClasses } from "api/data-field" import { getSummaryHandler } from "api/summary" +import { MetloRequest } from "types" import { AppDataSource } from "data-source" import { MulterSource } from "multer-source" import { @@ -56,11 +57,16 @@ import { putInstanceSettingsHandler, } from "api/settings" -const app: Express = express() const port = process.env.PORT || 8080 RedisClient.getInstance() +const app: Express = express() app.disable("x-powered-by") + +app.use(async (req: MetloRequest, res, next) => { + req.ctx = {} + next() +}) app.use(express.json()) app.use(express.urlencoded({ extended: true })) app.use( @@ -84,7 +90,7 @@ app.use(async (req, res, next) => { } }) -app.get("/api/v1", (req: Request, res: Response) => { +app.get("/api/v1", (req: MetloRequest, res: Response) => { res.send("OK") }) diff --git a/backend/src/jobs.ts b/backend/src/jobs.ts index 17e153fc..bcff4793 100644 --- a/backend/src/jobs.ts +++ b/backend/src/jobs.ts @@ -10,6 +10,7 @@ import { import runAllTests from "services/testing/runAllTests" import { logAggregatedStats } from "services/logging" import { DateTime } from "luxon" +import { MetloContext } from "types" const log = (logMessage: string, newLine?: boolean) => console.log( @@ -17,6 +18,7 @@ const log = (logMessage: string, newLine?: boolean) => ) const main = async () => { + const ctx: MetloContext = {} const datasource = await AppDataSource.initialize() if (!datasource.isInitialized) { console.error("Couldn't initialize datasource...") @@ -34,7 +36,7 @@ const main = async () => { schedule.scheduleJob("*/10 * * * *", () => { generateSpecSem.take(async () => { log("Generating OpenAPI Spec Files...", true) - await generateOpenApiSpec() + await generateOpenApiSpec(ctx) log("Finished generating OpenAPI Spec Files.") generateSpecSem.leave() }) @@ -43,7 +45,7 @@ const main = async () => { schedule.scheduleJob("30 * * * * ", () => { checkForUnauthenticatedSem.take(async () => { log("Checking for Unauthenticated Endpoints", true) - await checkForUnauthenticatedEndpoints() + await checkForUnauthenticatedEndpoints(ctx) log("Finished checking for Unauthenticated Endpoints") checkForUnauthenticatedSem.leave() }) @@ -53,7 +55,7 @@ const main = async () => { schedule.scheduleJob("15 * * * *", () => { unsecuredAlertsSem.take(async () => { log("Generating Alerts for Unsecured Endpoints", true) - await monitorEndpointForHSTS() + await monitorEndpointForHSTS(ctx) log("Finished generating alerts for Unsecured Endpoints.") unsecuredAlertsSem.leave() }) @@ -62,7 +64,7 @@ const main = async () => { schedule.scheduleJob("30 * * * *", () => { testsSem.take(async () => { log("Running Tests...", true) - await runAllTests() + await runAllTests(ctx) log("Finished running tests.") testsSem.leave() }) @@ -71,7 +73,7 @@ const main = async () => { schedule.scheduleJob("*/10 * * * *", () => { clearApiTracesSem.take(async () => { log("Clearing Api Trace data...", true) - await clearApiTraces() + await clearApiTraces(ctx) log("Finished clearing Api Trace data.") clearApiTracesSem.leave() }) @@ -81,7 +83,7 @@ const main = async () => { schedule.scheduleJob("0 */6 * * *", () => { logAggregateStatsSem.take(async () => { log("Logging Aggregated Stats...", true) - await logAggregatedStats() + await logAggregatedStats(ctx) log("Finished Logging Aggregated Stats.") logAggregateStatsSem.leave() }) diff --git a/backend/src/middleware/verify-api-key-middleware.ts b/backend/src/middleware/verify-api-key-middleware.ts index cd898545..886a4ffe 100644 --- a/backend/src/middleware/verify-api-key-middleware.ts +++ b/backend/src/middleware/verify-api-key-middleware.ts @@ -1,25 +1,25 @@ import ApiResponseHandler from "api-response-handler" -import { AppDataSource } from "data-source" import Error401Unauthorized from "errors/error-401-unauthorized" -import { NextFunction, Request, Response } from "express" +import { NextFunction, Response } from "express" import { ApiKey } from "models" +import { getRepoQB } from "services/database/utils" +import { MetloRequest } from "types" import { hasher } from "utils/hash" import { RedisClient } from "utils/redis" export async function verifyApiKeyMiddleware( - req: Request, + req: MetloRequest, res: Response, next: NextFunction, ) { try { let hashKey = hasher(req.headers.authorization) - const cachedHashKey = await RedisClient.getFromRedis(hashKey) + const cachedHashKey = await RedisClient.getFromRedis(req.ctx, hashKey) if (!cachedHashKey) { - await AppDataSource.getRepository(ApiKey) - .createQueryBuilder("key") + await getRepoQB(req.ctx, ApiKey, "key") .where("key.apiKeyHash = :hash", { hash: hashKey }) .getOneOrFail() - RedisClient.addToRedis(hashKey, true, 5) + RedisClient.addToRedis(req.ctx, hashKey, true, 5) } next() } catch (err) { diff --git a/backend/src/models/aggregate-trace-data-hourly.ts b/backend/src/models/aggregate-trace-data-hourly.ts index ffeba0c3..434bd279 100644 --- a/backend/src/models/aggregate-trace-data-hourly.ts +++ b/backend/src/models/aggregate-trace-data-hourly.ts @@ -1,17 +1,17 @@ import { Entity, Unique, - BaseEntity, PrimaryGeneratedColumn, Column, ManyToOne, JoinColumn, } from "typeorm" import { ApiEndpoint } from "./api-endpoint" +import MetloBaseEntity from "./metlo-base-entity" @Entity() @Unique("unique_constraint_hourly", ["apiEndpoint", "hour"]) -export class AggregateTraceDataHourly extends BaseEntity { +export class AggregateTraceDataHourly extends MetloBaseEntity { @PrimaryGeneratedColumn("uuid") uuid: string diff --git a/backend/src/models/alert.ts b/backend/src/models/alert.ts index 87940e08..71a33b60 100644 --- a/backend/src/models/alert.ts +++ b/backend/src/models/alert.ts @@ -1,5 +1,4 @@ import { - BaseEntity, Column, CreateDateColumn, Entity, @@ -10,9 +9,10 @@ import { } from "typeorm" import { AlertType, RiskScore, Status } from "@common/enums" import { ApiEndpoint } from "models/api-endpoint" +import MetloBaseEntity from "./metlo-base-entity" @Entity() -export class Alert extends BaseEntity { +export class Alert extends MetloBaseEntity { @PrimaryGeneratedColumn("uuid") uuid: string diff --git a/backend/src/models/api-endpoint-test.ts b/backend/src/models/api-endpoint-test.ts index 3dfb43ab..d6404d25 100644 --- a/backend/src/models/api-endpoint-test.ts +++ b/backend/src/models/api-endpoint-test.ts @@ -1,5 +1,4 @@ import { - BaseEntity, Column, Entity, JoinColumn, @@ -8,9 +7,10 @@ import { } from "typeorm" import { Request } from "@metlo/testing" import { ApiEndpoint } from "models/api-endpoint" +import MetloBaseEntity from "./metlo-base-entity" @Entity() -export class ApiEndpointTest extends BaseEntity { +export class ApiEndpointTest extends MetloBaseEntity { @PrimaryGeneratedColumn("uuid") uuid: string diff --git a/backend/src/models/api-endpoint.ts b/backend/src/models/api-endpoint.ts index 061ee9d4..a5474fb6 100644 --- a/backend/src/models/api-endpoint.ts +++ b/backend/src/models/api-endpoint.ts @@ -16,10 +16,11 @@ import { OpenApiSpec } from "models/openapi-spec" import { RestMethod, RiskScore } from "@common/enums" import { isParameter } from "utils" import { getPathTokens } from "@common/utils" +import MetloBaseEntity from "./metlo-base-entity" @Entity() @Unique("unique_constraint_api_endpoint", ["path", "method", "host"]) -export class ApiEndpoint extends BaseEntity { +export class ApiEndpoint extends MetloBaseEntity { @PrimaryGeneratedColumn("uuid") uuid: string diff --git a/backend/src/models/api-trace.ts b/backend/src/models/api-trace.ts index 2e1be40b..b4895184 100644 --- a/backend/src/models/api-trace.ts +++ b/backend/src/models/api-trace.ts @@ -1,5 +1,4 @@ import { - BaseEntity, Column, CreateDateColumn, Entity, @@ -11,9 +10,10 @@ import { import { Meta, PairObject, SessionMeta } from "@common/types" import { RestMethod } from "@common/enums" import { ApiEndpoint } from "models/api-endpoint" +import MetloBaseEntity from "./metlo-base-entity" @Entity() -export class ApiTrace extends BaseEntity { +export class ApiTrace extends MetloBaseEntity { @PrimaryGeneratedColumn("uuid") uuid: string diff --git a/backend/src/models/attack.ts b/backend/src/models/attack.ts index 95873b5b..3a3ae5d7 100644 --- a/backend/src/models/attack.ts +++ b/backend/src/models/attack.ts @@ -1,5 +1,4 @@ import { - BaseEntity, Column, CreateDateColumn, Entity, @@ -11,9 +10,10 @@ import { import { RiskScore, AttackType } from "@common/enums" import { AttackMeta } from "@common/types" import { ApiEndpoint } from "models/api-endpoint" +import MetloBaseEntity from "./metlo-base-entity" -@Entity() -export class Attack extends BaseEntity { +@Entity("attack") +export class Attack extends MetloBaseEntity { @PrimaryGeneratedColumn("uuid") uuid: string diff --git a/backend/src/models/authentication-config.ts b/backend/src/models/authentication-config.ts index a1a1c6fd..76bb25ea 100644 --- a/backend/src/models/authentication-config.ts +++ b/backend/src/models/authentication-config.ts @@ -1,8 +1,9 @@ import { BaseEntity, Column, Entity, PrimaryColumn } from "typeorm" import { AuthType } from "@common/enums" +import MetloBaseEntity from "./metlo-base-entity" @Entity() -export class AuthenticationConfig extends BaseEntity { +export class AuthenticationConfig extends MetloBaseEntity { @PrimaryColumn() host: string diff --git a/backend/src/models/block-fields.ts b/backend/src/models/block-fields.ts index 11b85c1e..432f1083 100644 --- a/backend/src/models/block-fields.ts +++ b/backend/src/models/block-fields.ts @@ -1,5 +1,4 @@ import { - BaseEntity, Column, Entity, PrimaryGeneratedColumn, @@ -8,9 +7,10 @@ import { import { DisableRestMethod } from "@common/enums" import { isParameter } from "utils" import { getPathTokens } from "@common/utils" +import MetloBaseEntity from "./metlo-base-entity" @Entity() -export class BlockFields extends BaseEntity { +export class BlockFields extends MetloBaseEntity { @PrimaryGeneratedColumn("uuid") uuid: string diff --git a/backend/src/models/connections.ts b/backend/src/models/connections.ts index c868d313..d4a5364f 100644 --- a/backend/src/models/connections.ts +++ b/backend/src/models/connections.ts @@ -1,6 +1,5 @@ import { Entity, - BaseEntity, Column, PrimaryGeneratedColumn, CreateDateColumn, @@ -16,9 +15,10 @@ import { SSH_INFO, } from "@common/types" import { encrypt, generate_iv } from "utils/encryption" +import MetloBaseEntity from "./metlo-base-entity" @Entity() -export class Connections extends BaseEntity { +export class Connections extends MetloBaseEntity { @PrimaryGeneratedColumn("uuid") uuid: string diff --git a/backend/src/models/data-field.ts b/backend/src/models/data-field.ts index 65951ed8..0352e12b 100644 --- a/backend/src/models/data-field.ts +++ b/backend/src/models/data-field.ts @@ -1,6 +1,5 @@ import { Entity, - BaseEntity, Column, ManyToOne, PrimaryGeneratedColumn, @@ -11,6 +10,7 @@ import { } from "typeorm" import { DataClass, DataTag, DataType, DataSection } from "@common/enums" import { ApiEndpoint } from "models/api-endpoint" +import MetloBaseEntity from "./metlo-base-entity" @Entity() @Unique("unique_constraint_data_field", [ @@ -18,7 +18,7 @@ import { ApiEndpoint } from "models/api-endpoint" "dataPath", "apiEndpointUuid", ]) -export class DataField extends BaseEntity { +export class DataField extends MetloBaseEntity { @PrimaryGeneratedColumn("uuid") uuid: string diff --git a/backend/src/models/instance-settings.ts b/backend/src/models/instance-settings.ts index 5c603200..b7098501 100644 --- a/backend/src/models/instance-settings.ts +++ b/backend/src/models/instance-settings.ts @@ -1,7 +1,8 @@ -import { BaseEntity, Column, Entity, PrimaryGeneratedColumn } from "typeorm" +import { Column, Entity, PrimaryGeneratedColumn } from "typeorm" +import MetloBaseEntity from "./metlo-base-entity" @Entity() -export class InstanceSettings extends BaseEntity { +export class InstanceSettings extends MetloBaseEntity { @PrimaryGeneratedColumn("uuid") uuid: string diff --git a/backend/src/models/keys.ts b/backend/src/models/keys.ts index b2aafa6c..49dbe981 100644 --- a/backend/src/models/keys.ts +++ b/backend/src/models/keys.ts @@ -1,16 +1,15 @@ import { API_KEY_TYPE } from "@common/enums" import { Entity, - BaseEntity, Column, PrimaryGeneratedColumn, CreateDateColumn, UpdateDateColumn, - Generated, } from "typeorm" +import MetloBaseEntity from "./metlo-base-entity" @Entity() -export class ApiKey extends BaseEntity { +export class ApiKey extends MetloBaseEntity { @PrimaryGeneratedColumn("uuid") uuid: string diff --git a/backend/src/models/metlo-base-entity.ts b/backend/src/models/metlo-base-entity.ts new file mode 100644 index 00000000..e312d0ee --- /dev/null +++ b/backend/src/models/metlo-base-entity.ts @@ -0,0 +1,8 @@ +import { BaseEntity } from "typeorm" +import { MetloContext } from "types" + +export default class MetloBaseEntity extends BaseEntity { + static getTableName(ctx: MetloContext) { + return this.getRepository().metadata.tableName + } +} diff --git a/backend/src/models/openapi-spec.ts b/backend/src/models/openapi-spec.ts index a3287c4b..dd42499b 100644 --- a/backend/src/models/openapi-spec.ts +++ b/backend/src/models/openapi-spec.ts @@ -1,5 +1,4 @@ import { - BaseEntity, Entity, PrimaryColumn, Column, @@ -8,9 +7,10 @@ import { } from "typeorm" import { SpecExtension } from "@common/enums" import { MinimizedSpecContext } from "@common/types" +import MetloBaseEntity from "./metlo-base-entity" @Entity() -export class OpenApiSpec extends BaseEntity { +export class OpenApiSpec extends MetloBaseEntity { @PrimaryColumn() name: string diff --git a/backend/src/models/sessions.ts b/backend/src/models/sessions.ts index d6164bdc..4ea1008c 100644 --- a/backend/src/models/sessions.ts +++ b/backend/src/models/sessions.ts @@ -1,8 +1,9 @@ import { ISession } from "connect-typeorm" import { Column, DeleteDateColumn, Entity, Index, PrimaryColumn } from "typeorm" +import MetloBaseEntity from "./metlo-base-entity" @Entity() -export class Session implements ISession { +export class Session extends MetloBaseEntity implements ISession { @Index() @Column("bigint") public expiredAt = Date.now() diff --git a/backend/src/scripts/generate-alerts.ts b/backend/src/scripts/generate-alerts.ts index 5a18500d..73bfe9f2 100644 --- a/backend/src/scripts/generate-alerts.ts +++ b/backend/src/scripts/generate-alerts.ts @@ -4,8 +4,10 @@ import { ALERT_TYPE_TO_RISK_SCORE } from "@common/maps" import { AppDataSource } from "data-source" import { Alert } from "models" import { DatabaseService } from "services/database" +import { MetloContext } from "types" const generateAlert = async ( + ctx: MetloContext, alertType: string, apiEndpointUuid: string, description: string, @@ -18,7 +20,7 @@ const generateAlert = async ( newAlert.description = description newAlert.riskScore = ALERT_TYPE_TO_RISK_SCORE[AlertType[alertType]] newAlert.context = JSON.parse(context || "{}") - await DatabaseService.executeTransactions([[newAlert]], [], false) + await DatabaseService.executeTransactions(ctx, [[newAlert]], [], false) } catch (err) { console.error(`Error generating new alert from script: ${err}`) } @@ -36,7 +38,7 @@ const main = async () => { const apiEndpointUuid = args["endpointUuid"] const description = args["description"] const context = args["context"] - await generateAlert(alertType, apiEndpointUuid, description, context) + await generateAlert({}, alertType, apiEndpointUuid, description, context) } main() diff --git a/backend/src/scripts/generate-attacks.ts b/backend/src/scripts/generate-attacks.ts index 0ca57c1e..fb82b676 100644 --- a/backend/src/scripts/generate-attacks.ts +++ b/backend/src/scripts/generate-attacks.ts @@ -5,6 +5,8 @@ import { ApiEndpoint, Attack } from "models" import { AttackType } from "@common/enums" import { ATTACK_TYPE_TO_RISK_SCORE } from "@common/maps" import { DateTime } from "luxon" +import { getEntityManager } from "services/database/utils" +import { MetloContext } from "types" const randomDate = (start?: boolean) => { const startTime = start @@ -16,13 +18,16 @@ const randomDate = (start?: boolean) => { return new Date(startTime + Math.random() * (endTime - startTime)) } -const generateAttacks = async (numAttacks: number) => { +const generateAttacks = async (ctx: MetloContext, numAttacks: number) => { const queryRunner = AppDataSource.createQueryRunner() await queryRunner.connect() try { - const endpoints = await queryRunner.manager.find(ApiEndpoint, { - select: { uuid: true, host: true }, - }) + const endpoints = await getEntityManager(ctx, queryRunner).find( + ApiEndpoint, + { + select: { uuid: true, host: true }, + }, + ) const attackTypes = Object.keys(AttackType) const insertAttacks: Attack[] = [] for (let i = 0; i < numAttacks; i++) { @@ -39,7 +44,7 @@ const generateAttacks = async (numAttacks: number) => { newAttack.host = endpoints[randEndpointNum].host insertAttacks.push(newAttack) } - await queryRunner.manager.insert(Attack, insertAttacks) + await getEntityManager(ctx, queryRunner).insert(Attack, insertAttacks) } catch (err) { console.error(`Encountered error while generating sample attacks: ${err}`) } finally { @@ -48,6 +53,7 @@ const generateAttacks = async (numAttacks: number) => { } const main = async () => { + const ctx: MetloContext = {} const datasource = await AppDataSource.initialize() if (!datasource.isInitialized) { console.error("Couldn't initialize datasource...") @@ -56,7 +62,7 @@ const main = async () => { console.log("AppDataSource Initialized...") const args = yargs.argv const numAttacks = args["numAttacks"] ?? 20 - await generateAttacks(numAttacks) + await generateAttacks(ctx, numAttacks) } main() diff --git a/backend/src/scripts/generate-endpoints.ts b/backend/src/scripts/generate-endpoints.ts index 8cfbce5e..eb95e63f 100644 --- a/backend/src/scripts/generate-endpoints.ts +++ b/backend/src/scripts/generate-endpoints.ts @@ -1,5 +1,6 @@ import { AppDataSource } from "data-source" import { generateEndpointsFromTraces } from "services/jobs" +import { MetloContext } from "types" const main = async () => { const datasource = await AppDataSource.initialize() @@ -9,7 +10,8 @@ const main = async () => { } console.log("AppDataSource Initialized...") console.log("Generating Endpoints and OpenAPI Spec Files...") - await generateEndpointsFromTraces() + const ctx: MetloContext = {} + await generateEndpointsFromTraces(ctx) console.log("Finished generating Endpoints and OpenAPI Spec Files.") } diff --git a/backend/src/services/alert/index.ts b/backend/src/services/alert/index.ts index daabb0d6..c77bd725 100644 --- a/backend/src/services/alert/index.ts +++ b/backend/src/services/alert/index.ts @@ -10,7 +10,6 @@ import validator from "validator" import jsonMap from "json-source-map" import yaml from "js-yaml" import SourceMap from "js-yaml-source-map" -import { AppDataSource } from "data-source" import { Alert, ApiEndpoint, ApiTrace, DataField, OpenApiSpec } from "models" import { AlertType, @@ -32,13 +31,16 @@ import Error409Conflict from "errors/error-409-conflict" import Error500InternalServer from "errors/error-500-internal-server" import { getPathTokens } from "@common/utils" import Error404NotFound from "errors/error-404-not-found" +import { createQB, getQB, getRepository } from "services/database/utils" +import { MetloContext } from "types" export class AlertService { static async updateAlert( + ctx: MetloContext, alertId: string, updateAlertParams: UpdateAlertParams, ): Promise { - const alertRepository = AppDataSource.getRepository(Alert) + const alertRepository = getRepository(ctx, Alert) const alert = await alertRepository.findOne({ where: { uuid: alertId, @@ -81,14 +83,19 @@ export class AlertService { default: throw new Error500InternalServer("Unknown update type.") } - await alertRepository.update({ uuid: alertId }, alert) + await createQB(ctx) + .update(Alert) + .set({ status: alert.status, resolutionMessage: alert.resolutionMessage }) + .where("uuid = :uuid", { uuid: alertId }) + .execute() return alert } static async getAlerts( + ctx: MetloContext, alertParams: GetAlertParams, ): Promise<[AlertResponse[], number]> { - const alertRepository = AppDataSource.getRepository(Alert) + const alertRepository = getRepository(ctx, Alert) let whereConditions: FindOptionsWhere[] | FindOptionsWhere = {} let paginationParams: FindManyOptions = {} @@ -199,30 +206,15 @@ export class AlertService { return alerts } - static async getAlert(alertId: string): Promise { - const alertRepository = AppDataSource.getRepository(Alert) - if (!validator.isUUID(alertId)) { - throw new Error404NotFound("Alert not found.") - } - return await alertRepository.findOneBy({ uuid: alertId }) - } - - static async getAlertWithConditions( - conditions: FindOptionsWhere, - ): Promise { - const alertRepository = AppDataSource.getRepository(Alert) - return await alertRepository.findOneBy(conditions) - } - static async existingUnresolvedAlert( + ctx: MetloContext, apiEndpointUuid: string, type: AlertType, description: string, queryRunner?: QueryRunner, ) { if (queryRunner) { - return await queryRunner.manager - .createQueryBuilder() + return await getQB(ctx, queryRunner) .select(["uuid"]) .from(Alert, "alert") .where(`"apiEndpointUuid" = :id`, { id: apiEndpointUuid }) @@ -231,7 +223,7 @@ export class AlertService { .andWhere("description = :description", { description }) .getRawOne() } - const alertRepository = AppDataSource.getRepository(Alert) + const alertRepository = getRepository(ctx, Alert) return await alertRepository.findOne({ select: { uuid: true, @@ -246,6 +238,7 @@ export class AlertService { } static async createAlert( + ctx: MetloContext, alertType: AlertType, apiEndpoint: ApiEndpoint, description?: string, @@ -270,6 +263,7 @@ export class AlertService { } if (noDuplicate) { const existing = await this.existingUnresolvedAlert( + ctx, apiEndpoint.uuid, alertType, alertDescription, @@ -307,6 +301,7 @@ export class AlertService { } static async createDataFieldAlerts( + ctx: MetloContext, dataFields: DataField[], apiEndpointUuid: string, apiEndpointPath: string, @@ -340,6 +335,7 @@ export class AlertService { basicAuthDescription, ) || (await this.existingUnresolvedAlert( + ctx, apiEndpointUuid, AlertType.BASIC_AUTHENTICATION_DETECTED, basicAuthDescription, @@ -418,6 +414,7 @@ export class AlertService { alert.description, ) || (await this.existingUnresolvedAlert( + ctx, apiEndpointUuid, alert.type, alert.description, @@ -446,6 +443,7 @@ export class AlertService { } static async createSpecDiffAlerts( + ctx: MetloContext, alertItems: Record, apiEndpointUuid: string, apiTrace: QueuedApiTrace, @@ -462,6 +460,7 @@ export class AlertService { let alerts: Alert[] = [] for (const key in alertItems) { const existing = await this.existingUnresolvedAlert( + ctx, apiEndpointUuid, AlertType.OPEN_API_SPEC_DIFF, key, @@ -516,8 +515,7 @@ export class AlertService { newAlert.description = key alerts.push(newAlert) } - await queryRunner.manager - .createQueryBuilder() + await getQB(ctx, queryRunner) .update(OpenApiSpec) .set({ minimizedSpecContext: openApiSpec.minimizedSpecContext }) .where("name = :name", { name: openApiSpec.name }) @@ -531,6 +529,7 @@ export class AlertService { } static async createMissingHSTSAlert( + ctx: MetloContext, alertProps: Array<[ApiEndpoint, ApiTrace, string]>, ) { try { @@ -540,6 +539,7 @@ export class AlertService { let alerts: Alert[] = [] for (const alertProp of alertProps) { const existing = await this.existingUnresolvedAlert( + ctx, alertProp[0].uuid, AlertType.UNSECURED_ENDPOINT_DETECTED, alertProp[2], diff --git a/backend/src/services/authentication-config/index.ts b/backend/src/services/authentication-config/index.ts index bd1af5d0..fb46fe63 100644 --- a/backend/src/services/authentication-config/index.ts +++ b/backend/src/services/authentication-config/index.ts @@ -1,20 +1,22 @@ import { AuthType } from "@common/enums" import { QueuedApiTrace, SessionMeta } from "@common/types" -import { AppDataSource } from "data-source" import { AuthenticationConfig } from "models" import { encryptEcb } from "utils/encryption" import { AuthenticationConfig as CachedAuthConfig } from "@common/types" import { AUTH_CONFIG_LIST_KEY } from "~/constants" import { RedisClient } from "utils/redis" +import { MetloContext } from "types" +import { getRepository } from "services/database/utils" export class AuthenticationConfigService { - static async setSessionMetadata(apiTrace: QueuedApiTrace) { + static async setSessionMetadata(ctx: MetloContext, apiTrace: QueuedApiTrace) { const redisKey = `auth_config_${apiTrace.host}` let cachedAuthConfig: CachedAuthConfig = await RedisClient.getFromRedis( + ctx, redisKey, ) if (!cachedAuthConfig) { - const authConfigRepo = AppDataSource.getRepository(AuthenticationConfig) + const authConfigRepo = getRepository(ctx, AuthenticationConfig) const authConfig = await authConfigRepo.findOneBy({ host: apiTrace.host, }) @@ -30,8 +32,8 @@ export class AuthenticationConfigService { } else { cachedAuthConfig = {} as CachedAuthConfig } - RedisClient.addToRedis(redisKey, cachedAuthConfig) - RedisClient.addValueToSet(AUTH_CONFIG_LIST_KEY, [ + RedisClient.addToRedis(ctx, redisKey, cachedAuthConfig) + RedisClient.addValueToSet(ctx, AUTH_CONFIG_LIST_KEY, [ `auth_config_${apiTrace.host}`, ]) } diff --git a/backend/src/services/connections/index.ts b/backend/src/services/connections/index.ts index e05399a5..d3920618 100644 --- a/backend/src/services/connections/index.ts +++ b/backend/src/services/connections/index.ts @@ -1,19 +1,23 @@ import { ConnectionType } from "@common/enums" import { AWS_CONNECTION, GCP_CONNECTION, SSH_INFO } from "@common/types" -import { AppDataSource } from "data-source" import Error500InternalServer from "errors/error-500-internal-server" import { Connections } from "models" +import { createQB, getRepoQB, getRepository } from "services/database/utils" +import { MetloContext } from "types" export class ConnectionsService { - static saveConnectionAws = async ({ - conn_meta, - name, - id, - }: { - conn_meta: AWS_CONNECTION & SSH_INFO - name: string - id: string - }) => { + static saveConnectionAws = async ( + ctx: MetloContext, + { + conn_meta, + name, + id, + }: { + conn_meta: AWS_CONNECTION & SSH_INFO + name: string + id: string + }, + ) => { const { access_id, secret_access_key, @@ -60,7 +64,7 @@ export class ConnectionsService { conn.uuid = id conn.name = name try { - const connectionRepository = AppDataSource.getRepository(Connections) + const connectionRepository = getRepository(ctx, Connections) await connectionRepository.save(conn) } catch (err) { console.error(`Error in saving connection: ${err}`) @@ -68,15 +72,18 @@ export class ConnectionsService { } } - static saveConnectionGcp = async ({ - conn_meta, - name, - id, - }: { - conn_meta: GCP_CONNECTION - name: string - id: string - }) => { + static saveConnectionGcp = async ( + ctx: MetloContext, + { + conn_meta, + name, + id, + }: { + conn_meta: GCP_CONNECTION + name: string + id: string + }, + ) => { const { key_file, project, @@ -131,7 +138,7 @@ export class ConnectionsService { conn.uuid = id conn.name = name try { - const connectionRepository = AppDataSource.getRepository(Connections) + const connectionRepository = getRepository(ctx, Connections) await connectionRepository.save(conn) } catch (err) { console.error(`Error in saving connection: ${err}`) @@ -139,11 +146,9 @@ export class ConnectionsService { } } - static listConnections = async () => { + static listConnections = async (ctx: MetloContext) => { try { - const connectionRepository = AppDataSource.getRepository(Connections) - let resp = await connectionRepository - .createQueryBuilder("conn") + let resp = await getRepoQB(ctx, Connections, "conn") .select([ "conn.uuid", "conn.name", @@ -162,11 +167,11 @@ export class ConnectionsService { } static getConnectionForUuid = async ( + ctx: MetloContext, uuid: string, with_metadata: boolean = false, ) => { try { - const connectionRepository = AppDataSource.getRepository(Connections) const selects = [ "conn.uuid", "conn.name", @@ -180,8 +185,7 @@ export class ConnectionsService { selects.push("conn.aws_meta") selects.push("conn.gcp_meta") } - let resp = connectionRepository - .createQueryBuilder("conn") + let resp = getRepoQB(ctx, Connections, "conn") .select(selects) .where("conn.uuid = :uuid", { uuid }) .getOne() @@ -192,15 +196,18 @@ export class ConnectionsService { } } - static updateConnectionForUuid = async ({ - name, - uuid, - }: { - name: string - uuid: string - }) => { + static updateConnectionForUuid = async ( + ctx: MetloContext, + { + name, + uuid, + }: { + name: string + uuid: string + }, + ) => { try { - let resp = AppDataSource.createQueryBuilder() + let resp = createQB(ctx) .update(Connections) .set({ name: name }) .where("uuid = :uuid", { uuid }) @@ -212,9 +219,9 @@ export class ConnectionsService { } } - static deleteConnectionForUuid = async ({ uuid }) => { + static deleteConnectionForUuid = async (ctx: MetloContext, { uuid }) => { try { - let resp = AppDataSource.createQueryBuilder() + let resp = createQB(ctx) .delete() .from(Connections) .where("uuid = :uuid", { uuid }) @@ -226,9 +233,9 @@ export class ConnectionsService { } } - static getNumConnections = async (): Promise => { + static getNumConnections = async (ctx: MetloContext): Promise => { try { - return await AppDataSource.getRepository(Connections).count() + return await getRepository(ctx, Connections).count() } catch (err) { console.error(`Error in Get Num Connections service: ${err}`) throw new Error500InternalServer(err) diff --git a/backend/src/services/data-field/index.ts b/backend/src/services/data-field/index.ts index 0d06d83b..2f928248 100644 --- a/backend/src/services/data-field/index.ts +++ b/backend/src/services/data-field/index.ts @@ -1,12 +1,13 @@ import { PairObject, QueuedApiTrace } from "@common/types" import { DataClass, DataSection, DataTag, DataType } from "@common/enums" -import { ApiEndpoint, ApiTrace, DataField } from "models" +import { ApiEndpoint, DataField } from "models" import { getDataType, getRiskScore, isParameter, parsedJson } from "utils" import { getPathTokens } from "@common/utils" import { ScannerService } from "services/scanner/scan" -import { AppDataSource } from "data-source" import Error404NotFound from "errors/error-404-not-found" import { addDataClass } from "./utils" +import { createQB, getRepository } from "services/database/utils" +import { MetloContext } from "types" export class DataFieldService { static dataFields: Record @@ -14,14 +15,21 @@ export class DataFieldService { static traceCreatedAt: Date static dataFieldsLength: number - static async deleteDataField(dataFieldId: string): Promise { - const dataFieldRepository = AppDataSource.getRepository(DataField) + static async deleteDataField( + ctx: MetloContext, + dataFieldId: string, + ): Promise { + const dataFieldRepository = getRepository(ctx, DataField) const dataField = await dataFieldRepository.findOneBy({ uuid: dataFieldId }) const fieldUuid = dataField.uuid if (!dataField) { throw new Error404NotFound("DataField for provided id not found.") } - await dataFieldRepository.remove(dataField) + await createQB(ctx) + .delete() + .from(DataField) + .where("uuid = :uuid", { uuid: fieldUuid }) + .execute() return { ...dataField, uuid: fieldUuid, @@ -29,12 +37,13 @@ export class DataFieldService { } static async updateDataClasses( + ctx: MetloContext, dataFieldId: string, dataClasses: DataClass[], dataPath: string, dataSection: DataSection, ) { - const dataFieldRepository = AppDataSource.getRepository(DataField) + const dataFieldRepository = getRepository(ctx, DataField) const dataField = await dataFieldRepository.findOneBy({ uuid: dataFieldId, dataPath, diff --git a/backend/src/services/database/index.ts b/backend/src/services/database/index.ts index 29e012d3..4be77183 100644 --- a/backend/src/services/database/index.ts +++ b/backend/src/services/database/index.ts @@ -1,71 +1,50 @@ -import { QueryFailedError } from "typeorm" -import { DatabaseError } from "pg-protocol" import { AppDataSource } from "data-source" import { DatabaseModel } from "models" import Error500InternalServer from "errors/error-500-internal-server" -import { getDataType } from "utils" -import { DataType } from "@common/enums" +import { retryTypeormTransaction } from "utils/db" +import { getEntityManager } from "./utils" +import { MetloContext } from "types" export class DatabaseService { - static isQueryFailedError = ( - err: unknown, - ): err is QueryFailedError & DatabaseError => err instanceof QueryFailedError + static validateQuery(query: string) {} - static delay = (fn: any, ms: number) => - new Promise(resolve => setTimeout(() => resolve(fn()), ms)) - - static randInt = (min: number, max: number) => - Math.floor(Math.random() * (max - min + 1) + min) + static async executeRawQuery( + rawQuery: string, + parameters?: any[], + ): Promise { + this.validateQuery(rawQuery) - static async retryTypeormTransaction(fn: any, maxAttempts: number) { - const execute = async (attempt: number) => { - try { - return await fn() - } catch (err) { - if (this.isQueryFailedError(err)) { - if (err.code === "40P01" || err.code === "55P03") { - if (attempt <= maxAttempts) { - const nextAttempt = attempt + 1 - const delayInMilliseconds = this.randInt(200, 1000) - console.error( - `Retrying after ${delayInMilliseconds} ms due to:`, - err, - ) - return this.delay(() => execute(nextAttempt), delayInMilliseconds) - } else { - throw err - } - } else { - throw err - } - } else { - throw err - } - } + const queryRunner = AppDataSource.createQueryRunner() + await queryRunner.connect() + await queryRunner.startTransaction() + let res = null + try { + res = await queryRunner.query(rawQuery, parameters ?? []) + await queryRunner.commitTransaction() + } catch (err) { + console.error(`Encountered error while executing raw sql query: ${err}`) + await queryRunner.rollbackTransaction() + throw new Error500InternalServer(err) + } finally { + await queryRunner.release() } - return execute(1) + return res } static async executeRawQueries( - rawQueries: string | string[], - parameters?: any[][], + rawQueries: string[], + parameters?: [], ): Promise { + rawQueries.forEach(e => this.validateQuery(e)) + const queryRunner = AppDataSource.createQueryRunner() await queryRunner.connect() await queryRunner.startTransaction() - const isMultiple = getDataType(rawQueries) === DataType.ARRAY - let res = null - if (isMultiple) { - res = [] - } + let res = [] try { - if (isMultiple) { - for (let i = 0; i < rawQueries?.length; i++) { - const query = rawQueries[i] - res.push(await queryRunner.query(query, parameters?.[i] ?? [])) - } - } else { - res = await queryRunner.query(rawQueries as string, parameters ?? []) + for (let i = 0; i < rawQueries?.length; i++) { + const query = rawQueries[i] + res.push(await queryRunner.query(query, parameters?.[i] ?? [])) } await queryRunner.commitTransaction() } catch (err) { @@ -79,6 +58,7 @@ export class DatabaseService { } static async executeTransactions( + ctx: MetloContext, saveItems: DatabaseModel[][], removeItems: DatabaseModel[][], retry?: boolean, @@ -86,27 +66,24 @@ export class DatabaseService { const queryRunner = AppDataSource.createQueryRunner() await queryRunner.connect() await queryRunner.startTransaction() - try { const chunkBatch = 1000 for (let i = 0; i < saveItems.length; i++) { const item = saveItems[i] - const chunkSize = - item.length > chunkBatch ? item.length / chunkBatch : chunkBatch - const fn = () => queryRunner.manager.save(item, { chunk: chunkBatch }) + const fn = () => + getEntityManager(ctx, queryRunner).save(item, { chunk: chunkBatch }) if (retry) { - await this.retryTypeormTransaction(fn, 5) + await retryTypeormTransaction(fn, 5) } else { await fn() } } for (let i = 0; i < removeItems.length; i++) { const item = removeItems[i] - const chunkSize = - item.length > chunkBatch ? item.length / chunkBatch : chunkBatch - const fn = () => queryRunner.manager.remove(item, { chunk: chunkBatch }) + const fn = () => + getEntityManager(ctx, queryRunner).remove(item, { chunk: chunkBatch }) if (retry) { - await this.retryTypeormTransaction(fn, 5) + await retryTypeormTransaction(fn, 5) } else { await fn() } diff --git a/backend/src/services/database/utils.ts b/backend/src/services/database/utils.ts new file mode 100644 index 00000000..da537f03 --- /dev/null +++ b/backend/src/services/database/utils.ts @@ -0,0 +1,160 @@ +import { + DeepPartial, + EntityManager, + FindManyOptions, + FindOneOptions, + FindOptionsWhere, + InsertResult, + QueryRunner, + RemoveOptions, + Repository, + SaveOptions, +} from "typeorm" +import { ObjectLiteral } from "typeorm/common/ObjectLiteral" +import { EntityTarget } from "typeorm/common/EntityTarget" +import { AppDataSource } from "data-source" +import { MetloContext } from "types" +import { QueryDeepPartialEntity } from "typeorm/query-builder/QueryPartialEntity" + +export const createQB = (ctx: MetloContext) => { + let qb = AppDataSource.createQueryBuilder() + qb.where = qb.andWhere + return qb +} + +export const getQB = (ctx: MetloContext, queryRunner: QueryRunner) => { + let qb = queryRunner.manager.createQueryBuilder() + qb.where = qb.andWhere + return qb +} + +export function getRepoQB( + ctx: MetloContext, + target: EntityTarget, + alias?: string, +) { + let qb = AppDataSource.getRepository(target).createQueryBuilder(alias) + qb.where = qb.andWhere + return qb +} + +export class WrappedRepository { + ctx: MetloContext + repository: Repository + + constructor(ctx: MetloContext, repository: Repository) { + this.repository = repository + this.ctx = ctx + } + + count(options?: FindManyOptions): Promise { + return this.repository.count(options) + } + + countBy( + where: FindOptionsWhere | FindOptionsWhere[], + ): Promise { + return this.repository.countBy(where) + } + + find(options?: FindManyOptions): Promise { + return this.repository.find(options) + } + + findBy( + where: FindOptionsWhere | FindOptionsWhere[], + ): Promise { + return this.repository.findBy(where) + } + + findOne(options: FindOneOptions): Promise { + return this.repository.findOne(options) + } + + findOneBy( + where: FindOptionsWhere | FindOptionsWhere[], + ): Promise { + return this.repository.findOneBy(where) + } + + findAndCount(options?: FindManyOptions): Promise<[Entity[], number]> { + return this.repository.findAndCount(options) + } + + save>( + entity: T, + options?: SaveOptions, + ) { + return this.repository.manager.save( + this.repository.metadata.target, + entity, + options, + ) + } +} + +export function getRepository( + ctx: MetloContext, + target: EntityTarget, +) { + const repo = AppDataSource.getRepository(target) + return new WrappedRepository(ctx, repo) +} + +export class WrappedEntityManager { + ctx: MetloContext + manager: EntityManager + + constructor(ctx: MetloContext, manager: EntityManager) { + this.manager = manager + this.ctx = ctx + } + + find( + entityClass: EntityTarget, + options?: FindManyOptions, + ): Promise { + return this.manager.find(entityClass, options) + } + + findOne( + entityClass: EntityTarget, + options: FindOneOptions, + ): Promise { + return this.manager.findOne(entityClass, options) + } + + findOneBy( + entityClass: EntityTarget, + where: FindOptionsWhere | FindOptionsWhere[], + ): Promise { + return this.manager.findOneBy(entityClass, where) + } + + save( + targetOrEntity: Entity, + maybeEntityOrOptions?: SaveOptions, + ): Promise + save>( + targetOrEntity: EntityTarget, + maybeEntityOrOptions: T, + maybeOptions?: SaveOptions, + ) { + return this.manager.save(targetOrEntity, maybeEntityOrOptions, maybeOptions) + } + + remove(entity: Entity, options?: RemoveOptions): Promise { + return this.manager.remove(entity, options) + } + + insert( + target: EntityTarget, + entity: QueryDeepPartialEntity | QueryDeepPartialEntity[], + ): Promise { + return this.manager.insert(target, entity) + } +} + +export function getEntityManager(ctx: MetloContext, queryRunner: QueryRunner) { + return new WrappedEntityManager(ctx, queryRunner.manager) +} diff --git a/backend/src/services/get-endpoints/index.ts b/backend/src/services/get-endpoints/index.ts index 383fe706..87207b64 100644 --- a/backend/src/services/get-endpoints/index.ts +++ b/backend/src/services/get-endpoints/index.ts @@ -19,8 +19,16 @@ import { Test } from "@metlo/testing" import Error404NotFound from "errors/error-404-not-found" import { getRiskScore } from "utils" import { getEndpointsCountQuery, getEndpointsQuery } from "./queries" +import { + createQB, + getEntityManager, + getQB, + getRepoQB, + getRepository, +} from "services/database/utils" +import { MetloContext } from "types" -const GET_DATA_FIELDS_QUERY = ` +const getDataFieldsQuery = (ctx: MetloContext) => ` SELECT uuid, "dataClasses"::text[], @@ -33,8 +41,7 @@ SELECT "updatedAt", "dataPath", "apiEndpointUuid" -FROM - data_field +FROM ${DataField.getTableName(ctx)} data_field WHERE "apiEndpointUuid" = $1 ORDER BY @@ -44,10 +51,11 @@ ORDER BY export class GetEndpointsService { static async updateIsAuthenticated( + ctx: MetloContext, apiEndpointUuid: string, authenticated: boolean, ): Promise { - await AppDataSource.createQueryBuilder() + await createQB(ctx) .update(ApiEndpoint) .set({ isAuthenticatedUserSet: authenticated }) .where("uuid = :id", { id: apiEndpointUuid }) @@ -55,9 +63,10 @@ export class GetEndpointsService { } static async updateEndpointRiskScore( + ctx: MetloContext, apiEndpointUuid: string, ): Promise { - const apiEndpointRepository = AppDataSource.getRepository(ApiEndpoint) + const apiEndpointRepository = getRepository(ctx, ApiEndpoint) const apiEndpoint = await apiEndpointRepository.findOne({ where: { uuid: apiEndpointUuid, @@ -67,14 +76,16 @@ export class GetEndpointsService { }, }) apiEndpoint.riskScore = getRiskScore(apiEndpoint.dataFields) - await apiEndpointRepository.update( - { uuid: apiEndpointUuid }, - { riskScore: apiEndpoint.riskScore }, - ) + await getRepoQB(ctx, ApiEndpoint) + .andWhere("uuid = :uuid", { uuid: apiEndpointUuid }) + .update() + .set({ riskScore: apiEndpoint.riskScore }) + .execute() return apiEndpoint } static async getEndpoints( + ctx: MetloContext, getEndpointParams: GetEndpointParams, ): Promise<[ApiEndpointResponse[], number]> { const queryRunner = AppDataSource.createQueryRunner() @@ -120,11 +131,11 @@ export class GetEndpointsService { const offsetFilter = `OFFSET ${getEndpointParams?.offset ?? 10}` const endpointResults = await queryRunner.query( - getEndpointsQuery(whereFilterString, limitFilter, offsetFilter), + getEndpointsQuery(ctx, whereFilterString, limitFilter, offsetFilter), parameters, ) const countResults = await queryRunner.query( - getEndpointsCountQuery(whereFilterString), + getEndpointsCountQuery(ctx, whereFilterString), parameters, ) @@ -138,42 +149,43 @@ export class GetEndpointsService { } static async getEndpoint( + ctx: MetloContext, endpointId: string, ): Promise { const queryRunner = AppDataSource.createQueryRunner() try { await queryRunner.connect() - const endpoint = await queryRunner.manager - .createQueryBuilder() + const endpoint = await getQB(ctx, queryRunner) .from(ApiEndpoint, "endpoint") .where("uuid = :id", { id: endpointId }) .getRawOne() if (!endpoint) { throw new Error404NotFound("Endpoint does not exist.") } - const alerts = await queryRunner.manager - .createQueryBuilder() + const alerts = await getQB(ctx, queryRunner) .select(["uuid", "status"]) .from(Alert, "alert") .where(`"apiEndpointUuid" = :id`, { id: endpointId }) .getRawMany() const dataFields: DataField[] = await queryRunner.query( - GET_DATA_FIELDS_QUERY, + getDataFieldsQuery(ctx), [endpointId], ) - const openapiSpec = await queryRunner.manager - .createQueryBuilder() + const openapiSpec = await getQB(ctx, queryRunner) .from(OpenApiSpec, "spec") .where("name = :name", { name: endpoint.openapiSpecName }) .getRawOne() - const traces = await queryRunner.manager.find(ApiTrace, { + const traces = await getEntityManager(ctx, queryRunner).find(ApiTrace, { where: { apiEndpointUuid: endpoint.uuid }, order: { createdAt: "DESC" }, take: 100, }) - const tests = await queryRunner.manager.find(ApiEndpointTest, { - where: { apiEndpoint: { uuid: endpointId } }, - }) + const tests = await getEntityManager(ctx, queryRunner).find( + ApiEndpointTest, + { + where: { apiEndpoint: { uuid: endpointId } }, + }, + ) return { ...endpoint, alerts, @@ -190,11 +202,13 @@ export class GetEndpointsService { } } - static async getHosts(): Promise { + static async getHosts(ctx: MetloContext): Promise { try { - const apiEndpointRepository = AppDataSource.getRepository(ApiEndpoint) - const hosts: { [host: string]: string }[] = await apiEndpointRepository - .createQueryBuilder("apiEndpoint") + const hosts: { [host: string]: string }[] = await getRepoQB( + ctx, + ApiEndpoint, + "apiEndpoint", + ) .select(["host"]) .distinct(true) .getRawMany() @@ -205,13 +219,12 @@ export class GetEndpointsService { } } - static async getUsage(endpointId: string): Promise { + static async getUsage( + ctx: MetloContext, + endpointId: string, + ): Promise { try { - const aggregateTraceDataRepo = AppDataSource.getRepository( - AggregateTraceDataHourly, - ) - const usage = await aggregateTraceDataRepo - .createQueryBuilder("trace") + const usage = await getRepoQB(ctx, AggregateTraceDataHourly, "trace") .select([`DATE_TRUNC('day', hour) AS date`, `SUM("numCalls") AS count`]) .where(`"apiEndpointUuid" = :id`, { id: endpointId }) .groupBy(`DATE_TRUNC('day', hour)`) diff --git a/backend/src/services/get-endpoints/queries.ts b/backend/src/services/get-endpoints/queries.ts index b7a0d695..8be6b01f 100644 --- a/backend/src/services/get-endpoints/queries.ts +++ b/backend/src/services/get-endpoints/queries.ts @@ -1,4 +1,8 @@ +import { ApiEndpoint, DataField } from "models" +import { MetloContext } from "types" + export const getEndpointsQuery = ( + ctx: MetloContext, whereFilter: string, limitFilter: string, offsetFilter: string, @@ -7,12 +11,12 @@ export const getEndpointsQuery = ( endpoint.*, data_field."dataClasses" FROM - "api_endpoint" endpoint + ${ApiEndpoint.getTableName(ctx)} endpoint LEFT JOIN LATERAL ( SELECT array_agg(DISTINCT "classes")::varchar[] as "dataClasses" FROM - data_field, + ${DataField.getTableName(ctx)} data_field, unnest(data_field."dataClasses") as "classes" WHERE data_field."apiEndpointUuid" = endpoint.uuid @@ -26,16 +30,19 @@ export const getEndpointsQuery = ( ${offsetFilter} ` -export const getEndpointsCountQuery = (whereFilter: string) => ` +export const getEndpointsCountQuery = ( + ctx: MetloContext, + whereFilter: string, +) => ` SELECT COUNT(endpoint.uuid) as count FROM - "api_endpoint" endpoint + ${ApiEndpoint.getTableName(ctx)} endpoint LEFT JOIN LATERAL ( SELECT array_agg(DISTINCT "classes") as "dataClasses" FROM - data_field, + ${DataField.getTableName(ctx)} data_field, unnest(data_field."dataClasses") as "classes" WHERE data_field."apiEndpointUuid" = endpoint.uuid diff --git a/backend/src/services/jobs/check-unauthenticated-endpoints.ts b/backend/src/services/jobs/check-unauthenticated-endpoints.ts index 09d7d006..7d4919ac 100644 --- a/backend/src/services/jobs/check-unauthenticated-endpoints.ts +++ b/backend/src/services/jobs/check-unauthenticated-endpoints.ts @@ -6,8 +6,10 @@ import { getUnauthenticatedEndpointsSensitiveData, } from "./queries" import { AlertService } from "services/alert" +import { getQB } from "services/database/utils" +import { MetloContext } from "types" -const checkForUnauthenticatedEndpoints = async (): Promise => { +const checkForUnauthenticatedEndpoints = async (ctx: MetloContext): Promise => { const queryRunner = AppDataSource.createQueryRunner() try { await queryRunner.connect() @@ -24,12 +26,7 @@ const checkForUnauthenticatedEndpoints = async (): Promise => { const alerts = await AlertService.createUnauthEndpointSenDataAlerts( endpointsToAlert, ) - await queryRunner.manager - .createQueryBuilder() - .insert() - .into(Alert) - .values(alerts) - .execute() + await getQB(ctx, queryRunner).insert().into(Alert).values(alerts).execute() } catch (err) { console.error( `Encountered error when checking for unauthenticated endpoints: ${err}`, diff --git a/backend/src/services/jobs/clear-api-traces.ts b/backend/src/services/jobs/clear-api-traces.ts index e3ba0a12..b821d12e 100644 --- a/backend/src/services/jobs/clear-api-traces.ts +++ b/backend/src/services/jobs/clear-api-traces.ts @@ -2,16 +2,17 @@ import { DateTime } from "luxon" import { ApiTrace } from "models" import { AppDataSource } from "data-source" import { aggregateTracesDataHourlyQuery } from "./queries" +import { MetloContext } from "types" +import { getQB } from "services/database/utils" -const clearApiTraces = async (): Promise => { +const clearApiTraces = async (ctx: MetloContext): Promise => { const queryRunner = AppDataSource.createQueryRunner() await queryRunner.connect() try { const now = DateTime.now() const oneHourAgo = now.minus({ hours: 1 }).toJSDate() - const maxTimeRes = await queryRunner.manager - .createQueryBuilder() + const maxTimeRes = await getQB(ctx, queryRunner) .select([`MAX("createdAt") as "maxTime"`]) .from(ApiTrace, "traces") .where('"apiEndpointUuid" IS NOT NULL') @@ -22,8 +23,7 @@ const clearApiTraces = async (): Promise => { if (maxTime) { await queryRunner.startTransaction() await queryRunner.query(aggregateTracesDataHourlyQuery, [maxTime]) - await queryRunner.manager - .createQueryBuilder() + await getQB(ctx, queryRunner) .delete() .from(ApiTrace) .where('"apiEndpointUuid" IS NOT NULL') diff --git a/backend/src/services/jobs/generate-endpoints-traces.ts b/backend/src/services/jobs/generate-endpoints-traces.ts index 1baaf986..29911551 100644 --- a/backend/src/services/jobs/generate-endpoints-traces.ts +++ b/backend/src/services/jobs/generate-endpoints-traces.ts @@ -3,13 +3,17 @@ import { v4 as uuidv4 } from "uuid" import { ApiTrace, ApiEndpoint, Alert } from "models" import { AppDataSource } from "data-source" import { AlertType } from "@common/enums" -import { DatabaseService } from "services/database" import { AlertService } from "services/alert" import { getPathTokens } from "@common/utils" import { skipAutoGeneratedMatch, isSuspectedParamater } from "utils" import { GenerateEndpoint } from "./types" +import { retryTypeormTransaction } from "utils/db" +import { MetloContext } from "types" +import { getEntityManager, getQB } from "services/database/utils" -const generateEndpointsFromTraces = async (): Promise => { +const generateEndpointsFromTraces = async ( + ctx: MetloContext, +): Promise => { const queryRunner = AppDataSource.createQueryRunner() await queryRunner.connect() try { @@ -31,12 +35,15 @@ const generateEndpointsFromTraces = async (): Promise => { }, take: 1000, } - let traces = await queryRunner.manager.find(ApiTrace, tracesFindOptions) + let traces = await getEntityManager(ctx, queryRunner).find( + ApiTrace, + tracesFindOptions, + ) while (traces && traces?.length > 0) { const regexToTracesMap: Record = {} for (let i = 0; i < traces.length; i++) { const trace = traces[i] - const apiEndpoint = await queryRunner.manager.findOne(ApiEndpoint, { + const apiEndpoint = await getEntityManager(ctx, queryRunner).findOne(ApiEndpoint, { where: { pathRegex: Raw(alias => `:path ~ ${alias}`, { path: trace.path, @@ -53,20 +60,18 @@ const generateEndpointsFromTraces = async (): Promise => { apiEndpoint.updateDates(trace.createdAt) await queryRunner.startTransaction() - await DatabaseService.retryTypeormTransaction( + await retryTypeormTransaction( () => - queryRunner.manager - .createQueryBuilder() + getQB(ctx, queryRunner) .update(ApiTrace) .set({ apiEndpointUuid: apiEndpoint.uuid }) .where("uuid = :id", { id: trace.uuid }) .execute(), 5, ) - await DatabaseService.retryTypeormTransaction( + await retryTypeormTransaction( () => - queryRunner.manager - .createQueryBuilder() + getQB(ctx, queryRunner) .update(ApiEndpoint) .set({ firstDetected: apiEndpoint.firstDetected, @@ -146,35 +151,33 @@ const generateEndpointsFromTraces = async (): Promise => { traceIds.push(trace.uuid) } const alert = await AlertService.createAlert( + ctx, AlertType.NEW_ENDPOINT, apiEndpoint, ) await queryRunner.startTransaction() - await DatabaseService.retryTypeormTransaction( + await retryTypeormTransaction( () => - queryRunner.manager - .createQueryBuilder() + getQB(ctx, queryRunner) .insert() .into(ApiEndpoint) .values(apiEndpoint) .execute(), 5, ) - await DatabaseService.retryTypeormTransaction( + await retryTypeormTransaction( () => - queryRunner.manager - .createQueryBuilder() + getQB(ctx, queryRunner) .insert() .into(Alert) .values(alert) .execute(), 5, ) - await DatabaseService.retryTypeormTransaction( + await retryTypeormTransaction( () => - queryRunner.manager - .createQueryBuilder() + getQB(ctx, queryRunner) .update(ApiTrace) .set({ apiEndpointUuid: apiEndpoint.uuid }) .where("uuid IN(:...ids)", { ids: traceIds }) @@ -183,7 +186,7 @@ const generateEndpointsFromTraces = async (): Promise => { ) await queryRunner.commitTransaction() } - traces = await queryRunner.manager.find(ApiTrace, tracesFindOptions) + traces = await getEntityManager(ctx, queryRunner).find(ApiTrace, tracesFindOptions) } console.log("Finished Generating Endpoints.") } catch (err) { diff --git a/backend/src/services/jobs/generate-openapi-spec.ts b/backend/src/services/jobs/generate-openapi-spec.ts index 50807621..1d400094 100644 --- a/backend/src/services/jobs/generate-openapi-spec.ts +++ b/backend/src/services/jobs/generate-openapi-spec.ts @@ -1,18 +1,18 @@ import { IsNull } from "typeorm" import { SpecExtension } from "@common/enums" import { ApiEndpoint, OpenApiSpec, ApiTrace } from "models" -import { AppDataSource } from "data-source" import { DatabaseService } from "services/database" import { getPathTokens } from "@common/utils" import { isParameter, parsedJsonNonNull } from "utils" import { BodySchema, BodyContent, Responses } from "./types" import { parseSchema, parseContent } from "./utils" +import { getRepoQB, getRepository } from "services/database/utils" +import { MetloContext } from "types" -const generateOpenApiSpec = async (): Promise => { +const generateOpenApiSpec = async (ctx: MetloContext): Promise => { try { - const apiEndpointRepository = AppDataSource.getRepository(ApiEndpoint) - const openApiSpecRepository = AppDataSource.getRepository(OpenApiSpec) - const apiTraceRepository = AppDataSource.getRepository(ApiTrace) + const apiEndpointRepository = getRepository(ctx, ApiEndpoint) + const openApiSpecRepository = getRepository(ctx, OpenApiSpec) const nonSpecEndpoints = await apiEndpointRepository.findBy({ openapiSpecName: IsNull(), }) @@ -62,9 +62,10 @@ const generateOpenApiSpec = async (): Promise => { const paths = openApiSpec["paths"] const path = endpoint.path const method = endpoint.method.toLowerCase() - let tracesQb = apiTraceRepository - .createQueryBuilder() - .where('"apiEndpointUuid" = :id', { id: endpoint.uuid }) + let tracesQb = getRepoQB(ctx, ApiTrace).where( + '"apiEndpointUuid" = :id', + { id: endpoint.uuid }, + ) if (spec.specUpdatedAt) { tracesQb = tracesQb .andWhere('"createdAt" > :updated', { @@ -216,7 +217,12 @@ const generateOpenApiSpec = async (): Promise => { spec.updatedAt = currTime spec.specUpdatedAt = currTime spec.extension = SpecExtension.JSON - await DatabaseService.executeTransactions([[spec], endpoints], [], true) + await DatabaseService.executeTransactions( + ctx, + [[spec], endpoints], + [], + true, + ) } } catch (err) { console.error(`Encountered error while generating OpenAPI specs: ${err}`) diff --git a/backend/src/services/jobs/monitor-endpoint-hsts.ts b/backend/src/services/jobs/monitor-endpoint-hsts.ts index 167a45b1..e6929217 100644 --- a/backend/src/services/jobs/monitor-endpoint-hsts.ts +++ b/backend/src/services/jobs/monitor-endpoint-hsts.ts @@ -1,19 +1,17 @@ import axios from "axios" import { ApiEndpoint, ApiTrace, Alert } from "models" -import { AppDataSource } from "data-source" import { AlertService } from "services/alert" +import { MetloContext } from "types" +import { getRepoQB, getRepository } from "services/database/utils" -const monitorEndpointForHSTS = async (): Promise => { +const monitorEndpointForHSTS = async (ctx: MetloContext): Promise => { try { - const apiEndpointRepository = AppDataSource.getRepository(ApiEndpoint) - const apiTraceRepository = AppDataSource.getRepository(ApiTrace) - const alertsRepository = AppDataSource.getRepository(Alert) + const apiTraceRepository = getRepository(ctx, ApiTrace) + const alertsRepository = getRepository(ctx, Alert) const alertableData: Array<[ApiEndpoint, ApiTrace, string]> = [] - for (const endpoint of await apiEndpointRepository - .createQueryBuilder() - .getMany()) { + for (const endpoint of await getRepoQB(ctx, ApiEndpoint).getMany()) { const latest_trace_for_endpoint = await apiTraceRepository.findOne({ where: { apiEndpointUuid: endpoint.uuid }, order: { createdAt: "DESC" }, @@ -51,7 +49,7 @@ const monitorEndpointForHSTS = async (): Promise => { } } } - let alerts = await AlertService.createMissingHSTSAlert(alertableData) + let alerts = await AlertService.createMissingHSTSAlert(ctx, alertableData) await alertsRepository.save(alerts) } catch (err) { console.error( diff --git a/backend/src/services/log-request/index.ts b/backend/src/services/log-request/index.ts index 43f3915f..3b7b76af 100644 --- a/backend/src/services/log-request/index.ts +++ b/backend/src/services/log-request/index.ts @@ -4,12 +4,16 @@ import { BlockFieldsService } from "services/block-fields" import { AuthenticationConfigService } from "services/authentication-config" import { RedisClient } from "utils/redis" import { TRACES_QUEUE } from "~/constants" +import { MetloContext } from "types" export class LogRequestService { - static async logRequest(traceParams: TraceParams): Promise { + static async logRequest( + ctx: MetloContext, + traceParams: TraceParams, + ): Promise { try { /** Log Request in ApiTrace table **/ - const queueLength = await RedisClient.getListLength(TRACES_QUEUE) + const queueLength = await RedisClient.getListLength(ctx, TRACES_QUEUE) if (queueLength > 1000) { return } @@ -39,9 +43,10 @@ export class LogRequestService { } await BlockFieldsService.redactBlockedFields(apiTraceObj) - await AuthenticationConfigService.setSessionMetadata(apiTraceObj) + await AuthenticationConfigService.setSessionMetadata(ctx, apiTraceObj) RedisClient.pushValueToRedisList( + ctx, TRACES_QUEUE, [JSON.stringify(apiTraceObj)], true, @@ -52,9 +57,12 @@ export class LogRequestService { } } - static async logRequestBatch(traceParamsBatch: TraceParams[]): Promise { + static async logRequestBatch( + ctx: MetloContext, + traceParamsBatch: TraceParams[], + ): Promise { for (let i = 0; i < traceParamsBatch.length; i++) { - await this.logRequest(traceParamsBatch[i]) + await this.logRequest(ctx, traceParamsBatch[i]) } } } diff --git a/backend/src/services/logging/index.ts b/backend/src/services/logging/index.ts index 25c8e29f..3608b67b 100644 --- a/backend/src/services/logging/index.ts +++ b/backend/src/services/logging/index.ts @@ -1,17 +1,18 @@ import axios from "axios" -import { AppDataSource } from "data-source" import { InstanceSettings } from "models" +import { getRepository } from "services/database/utils" import { getCounts } from "services/summary/usageStats" +import { MetloContext } from "types" -export const logAggregatedStats = async () => { - const settingRepository = AppDataSource.getRepository(InstanceSettings) +export const logAggregatedStats = async (ctx: MetloContext) => { + const settingRepository = getRepository(ctx, InstanceSettings) const settingsLs = await settingRepository.find() if (settingsLs.length == 0) { console.log("No instance settings found...") return } const settings = settingsLs[0] - const counts = await getCounts() + const counts = await getCounts(ctx) await axios({ url: "https://logger.metlo.com/log", method: "POST", diff --git a/backend/src/services/spec/index.ts b/backend/src/services/spec/index.ts index 89f492f6..220a48f5 100644 --- a/backend/src/services/spec/index.ts +++ b/backend/src/services/spec/index.ts @@ -4,7 +4,6 @@ import SwaggerParser from "@apidevtools/swagger-parser" import Converter from "swagger2openapi" import yaml from "js-yaml" import YAML from "yaml" -import OpenAPIRequestValidator from "@leoscope/openapi-request-validator" import OpenAPIResponseValidator, { OpenAPIResponseValidatorValidationError, } from "@leoscope/openapi-response-validator" @@ -22,22 +21,16 @@ import { OpenApiSpec as OpenApiSpecResponse, QueuedApiTrace, } from "@common/types" -import { getPathTokens } from "@common/utils" import { AppDataSource } from "data-source" -import { getPathRegex, isParameter, parsedJsonNonNull } from "utils" +import { getPathRegex, parsedJsonNonNull } from "utils" import Error409Conflict from "errors/error-409-conflict" import Error422UnprocessableEntity from "errors/error-422-unprocessable-entity" import { - generateAlertMessageFromReqErrors, generateAlertMessageFromRespErrors, getOpenAPISpecVersion, - getSpecRequestParameters, getSpecResponses, - parsePathParameter, - SpecValue, AjvError, validateSpecSchema, - getSpecRequestBody, getHostsV3, getServersV3, } from "./utils" @@ -53,6 +46,8 @@ import { updateOldEndpointUuids, getAllOldEndpoints, } from "./queries" +import { MetloContext } from "types" +import { getEntityManager, getQB, getRepository } from "services/database/utils" interface EndpointsMap { endpoint: ApiEndpoint @@ -60,16 +55,20 @@ interface EndpointsMap { } export class SpecService { - static async getSpec(specName: string): Promise { - const openApiSpecRepository = AppDataSource.getRepository(OpenApiSpec) + static async getSpec( + ctx: MetloContext, + specName: string, + ): Promise { + const openApiSpecRepository = getRepository(ctx, OpenApiSpec) const spec = await openApiSpecRepository.findOneBy({ name: specName }) return spec } static async getSpecs( + ctx: MetloContext, listAutogenerated: boolean = true, ): Promise { - const openApiSpecRepository = AppDataSource.getRepository(OpenApiSpec) + const openApiSpecRepository = getRepository(ctx, OpenApiSpec) const specList = await openApiSpecRepository.find({ where: { isAutoGenerated: listAutogenerated }, order: { updatedAt: "DESC" }, @@ -78,6 +77,7 @@ export class SpecService { } static async updateSpec( + ctx: MetloContext, specObject: JSONValue, fileName: string, extension: SpecExtension, @@ -100,8 +100,8 @@ export class SpecService { await queryRunner.connect() await queryRunner.startTransaction() try { - await this.deleteSpec(fileName) - await this.uploadNewSpec(specObject, fileName, extension, specString) + await this.deleteSpec(ctx, fileName) + await this.uploadNewSpec(ctx, specObject, fileName, extension, specString) await queryRunner.commitTransaction() } catch (err) { console.error(`Error updating spec file: ${err}`) @@ -113,6 +113,7 @@ export class SpecService { } static async deleteSpec( + ctx: MetloContext, fileName: string, existingQueryRunner?: QueryRunner, ): Promise { @@ -125,9 +126,12 @@ export class SpecService { await queryRunner.startTransaction() } try { - const openApiSpec = await queryRunner.manager.findOneBy(OpenApiSpec, { - name: fileName, - }) + const openApiSpec = await getEntityManager(ctx, queryRunner).findOneBy( + OpenApiSpec, + { + name: fileName, + }, + ) if (!openApiSpec) { throw new Error404NotFound( "No spec file with the provided name exists.", @@ -137,14 +141,12 @@ export class SpecService { throw new Error409Conflict("Can't delete auto generated spec.") } await queryRunner.query(deleteOpenAPISpecDiffAlerts, [fileName]) - await queryRunner.manager - .createQueryBuilder() + await getQB(ctx, queryRunner) .update(ApiEndpoint) .set({ openapiSpecName: null }) .where('"openapiSpecName" = :name', { name: fileName }) .execute() - await queryRunner.manager - .createQueryBuilder() + await getQB(ctx, queryRunner) .delete() .from(OpenApiSpec) .where("name = :name", { name: fileName }) @@ -166,6 +168,7 @@ export class SpecService { } static async uploadNewSpec( + ctx: MetloContext, specObject: JSONValue, fileName: string, extension: SpecExtension, @@ -205,8 +208,8 @@ export class SpecService { const paths: JSONValue = specObject["paths"] - const apiEndpointRepository = AppDataSource.getRepository(ApiEndpoint) - const openApiSpecRepository = AppDataSource.getRepository(OpenApiSpec) + const apiEndpointRepository = getRepository(ctx, ApiEndpoint) + const openApiSpecRepository = getRepository(ctx, OpenApiSpec) let existingSpec = await openApiSpecRepository.findOneBy({ name: fileName, @@ -334,9 +337,9 @@ export class SpecService { } try { - await queryRunner.manager.save(existingSpec) + await getEntityManager(ctx, queryRunner).save(existingSpec) for (const item of Object.values(endpointsMap)) { - await queryRunner.manager.save(item.endpoint) + await getEntityManager(ctx, queryRunner).save(item.endpoint) const similarEndpointUuids = [] for (const e of Object.values(item.similarEndpoints)) { similarEndpointUuids.push(e.uuid) @@ -353,23 +356,20 @@ export class SpecService { } if (similarEndpointUuids.length > 0) { - await queryRunner.manager.save(item.endpoint) - const updateTracesQb = queryRunner.manager - .createQueryBuilder() + await getEntityManager(ctx, queryRunner).save(item.endpoint) + const updateTracesQb = getQB(ctx, queryRunner) .update(ApiTrace) .set({ apiEndpointUuid: item.endpoint.uuid }) .where(`"apiEndpointUuid" IN(:...ids)`, { ids: similarEndpointUuids, }) - const deleteDataFieldsQb = queryRunner.manager - .createQueryBuilder() + const deleteDataFieldsQb = getQB(ctx, queryRunner) .delete() .from(DataField) .where(`"apiEndpointUuid" IN(:...ids)`, { ids: similarEndpointUuids, }) - const updateAlertsQb = queryRunner.manager - .createQueryBuilder() + const updateAlertsQb = getQB(ctx, queryRunner) .update(Alert) .set({ apiEndpointUuid: item.endpoint.uuid }) .where(`"apiEndpointUuid" IN(:...ids)`, { @@ -378,8 +378,7 @@ export class SpecService { .andWhere(`type NOT IN(:...types)`, { types: [AlertType.NEW_ENDPOINT, AlertType.OPEN_API_SPEC_DIFF], }) - const deleteAlertsQb = queryRunner.manager - .createQueryBuilder() + const deleteAlertsQb = getQB(ctx, queryRunner) .delete() .from(Alert) .where(`"apiEndpointUuid" IN(:...ids)`, { @@ -388,8 +387,7 @@ export class SpecService { .andWhere(`type IN(:...types)`, { types: [AlertType.NEW_ENDPOINT, AlertType.OPEN_API_SPEC_DIFF], }) - const deleteAggregateHourlyQb = queryRunner.manager - .createQueryBuilder() + const deleteAggregateHourlyQb = getQB(ctx, queryRunner) .delete() .from(AggregateTraceDataHourly) .where(`"apiEndpointUuid" IN(:...ids)`, { @@ -423,8 +421,7 @@ export class SpecService { similarEndpointUuids, ]) await deleteAggregateHourlyQb.execute() - await queryRunner.manager - .createQueryBuilder() + await getQB(ctx, queryRunner) .delete() .from(ApiEndpoint) .where(`"uuid" IN(:...ids)`, { ids: similarEndpointUuids }) @@ -448,12 +445,13 @@ export class SpecService { } static async findOpenApiSpecDiff( + ctx: MetloContext, trace: QueuedApiTrace, endpoint: ApiEndpoint, queryRunner: QueryRunner, ): Promise { try { - const openApiSpecRepository = AppDataSource.getRepository(OpenApiSpec) + const openApiSpecRepository = getRepository(ctx, OpenApiSpec) const openApiSpec = await openApiSpecRepository.findOneBy({ name: endpoint.openapiSpecName, }) @@ -494,6 +492,7 @@ export class SpecService { const errorItems = { ...respErrorItems } return await AlertService.createSpecDiffAlerts( + ctx, errorItems, endpoint.uuid, trace, diff --git a/backend/src/services/summary/alerts.ts b/backend/src/services/summary/alerts.ts index c0943ade..ca0aed4a 100644 --- a/backend/src/services/summary/alerts.ts +++ b/backend/src/services/summary/alerts.ts @@ -1,31 +1,33 @@ import { Status, AlertType } from "@common/enums" -import { AppDataSource } from "data-source" import { Alert } from "models" -import cache from "memory-cache" import { DatabaseService } from "services/database" +import { getRepository } from "services/database/utils" +import { MetloContext } from "types" +import { RedisClient } from "utils/redis" -export const getAlertTypeAgg = async () => { +export const getAlertTypeAgg = async (ctx: MetloContext) => { const alertTypeAggRes: { type: AlertType; count: number }[] = - await DatabaseService.executeRawQueries(` + await DatabaseService.executeRawQuery(` SELECT type, CAST(COUNT(*) AS INTEGER) as count - FROM alert WHERE status = 'Open' + FROM ${Alert.getTableName(ctx)} WHERE status = 'Open' GROUP BY 1 `) return Object.fromEntries(alertTypeAggRes.map(e => [e.type, e.count])) } -export const getAlertTypeAggCached = async () => { - const cacheRes: Record | null = cache.get("alertTypeAgg") +export const getAlertTypeAggCached = async (ctx: MetloContext) => { + const cacheRes: Record | null = + await RedisClient.getFromRedis(ctx, "alertTypeAgg") if (cacheRes) { return cacheRes } - const realRes = await getAlertTypeAgg() - cache.put("alertTypeAgg", realRes, 5000) + const realRes = await getAlertTypeAgg(ctx) + await RedisClient.addToRedis(ctx, "alertTypeAgg", realRes, 5) return realRes } -export const getTopAlerts = async () => { - const alertRepository = AppDataSource.getRepository(Alert) +export const getTopAlerts = async (ctx: MetloContext) => { + const alertRepository = getRepository(ctx, Alert) return await alertRepository.find({ where: { status: Status.OPEN, @@ -41,12 +43,15 @@ export const getTopAlerts = async () => { }) } -export const getTopAlertsCached = async () => { - const cacheRes: Alert[] | null = cache.get("topAlertsCached") +export const getTopAlertsCached = async (ctx: MetloContext) => { + const cacheRes: Alert[] | null = await RedisClient.getFromRedis( + ctx, + "topAlertsCached", + ) if (cacheRes) { return cacheRes } - const realRes = await getTopAlerts() - cache.put("topAlertsCached", realRes, 5000) + const realRes = await getTopAlerts(ctx) + await RedisClient.addToRedis(ctx, "topAlertsCached", realRes, 5) return realRes } diff --git a/backend/src/services/summary/endpoints.ts b/backend/src/services/summary/endpoints.ts index 3f10ff5b..fb187533 100644 --- a/backend/src/services/summary/endpoints.ts +++ b/backend/src/services/summary/endpoints.ts @@ -1,28 +1,29 @@ import groupBy from "lodash/groupBy" import { In } from "typeorm" -import { AppDataSource } from "data-source" import { ApiEndpoint, ApiTrace } from "models" import { EndpointAndUsage } from "@common/types" -import cache from "memory-cache" import { DatabaseService } from "services/database" +import { MetloContext } from "types" +import { RedisClient } from "utils/redis" +import { getRepository } from "services/database/utils" -export const getTopEndpoints = async () => { - const apiTraceRepository = AppDataSource.getRepository(ApiTrace) - const apiEndpointRepository = AppDataSource.getRepository(ApiEndpoint) +export const getTopEndpoints = async (ctx: MetloContext) => { + const apiTraceRepository = getRepository(ctx, ApiTrace) + const apiEndpointRepository = getRepository(ctx, ApiEndpoint) const endpointStats: { endpoint: string last1MinCnt: number last5MinCnt: number last30MinCnt: number - }[] = await DatabaseService.executeRawQueries(` + }[] = await DatabaseService.executeRawQuery(` SELECT traces."apiEndpointUuid" as endpoint, CAST(COUNT(*) AS INTEGER) as "last30MinCnt", CAST(SUM(CASE WHEN traces."createdAt" > (NOW() - INTERVAL '5 minutes') THEN 1 ELSE 0 END) AS INTEGER) as "last5MinCnt", CAST(SUM(CASE WHEN traces."createdAt" > (NOW() - INTERVAL '1 minutes') THEN 1 ELSE 0 END) AS INTEGER) as "last1MinCnt" FROM - api_trace traces + ${ApiTrace.getTableName(ctx)} traces WHERE traces."apiEndpointUuid" IS NOT NULL AND traces."createdAt" > (NOW() - INTERVAL '30 minutes') @@ -68,12 +69,15 @@ export const getTopEndpoints = async () => { ) } -export const getTopEndpointsCached = async () => { - const cacheRes: EndpointAndUsage[] | null = cache.get("endpointUsageStats") +export const getTopEndpointsCached = async (ctx: MetloContext) => { + const cacheRes: EndpointAndUsage[] | null = await RedisClient.getFromRedis( + ctx, + "endpointUsageStats", + ) if (cacheRes) { return cacheRes } - const realRes = await getTopEndpoints() - cache.put("endpointUsageStats", realRes, 15000) + const realRes = await getTopEndpoints(ctx) + await RedisClient.addToRedis(ctx, "endpointUsageStats", realRes, 15) return realRes } diff --git a/backend/src/services/summary/index.ts b/backend/src/services/summary/index.ts index 5e5a3924..c3637b6c 100644 --- a/backend/src/services/summary/index.ts +++ b/backend/src/services/summary/index.ts @@ -1,27 +1,28 @@ import { Summary as SummaryResponse } from "@common/types" import { ConnectionsService } from "services/connections" +import { MetloContext } from "types" import { getAlertTypeAggCached, getTopAlertsCached } from "./alerts" import { getTopEndpointsCached } from "./endpoints" import { getPIIDataTypeCountCached } from "./piiData" import { getCountsCached, getUsageStatsCached } from "./usageStats" -export class SummaryService { - static async getSummaryData(): Promise { - const topEndpoints = await getTopEndpointsCached() - const alertTypeCount = await getAlertTypeAggCached() - const topAlerts = await getTopAlertsCached() - const piiDataTypeCount = await getPIIDataTypeCountCached() - const usageStats = await getUsageStatsCached() - const counts = await getCountsCached() - const numConnections = await ConnectionsService.getNumConnections() - return { - piiDataTypeCount: piiDataTypeCount as any, - alertTypeCount: alertTypeCount as any, - topAlerts, - topEndpoints, - usageStats, - numConnections, - ...counts, - } +export const getSummaryData = async ( + ctx: MetloContext, +): Promise => { + const topEndpoints = await getTopEndpointsCached(ctx) + const alertTypeCount = await getAlertTypeAggCached(ctx) + const topAlerts = await getTopAlertsCached(ctx) + const piiDataTypeCount = await getPIIDataTypeCountCached(ctx) + const usageStats = await getUsageStatsCached(ctx) + const counts = await getCountsCached(ctx) + const numConnections = await ConnectionsService.getNumConnections(ctx) + return { + piiDataTypeCount: piiDataTypeCount as any, + alertTypeCount: alertTypeCount as any, + topAlerts, + topEndpoints, + usageStats, + numConnections, + ...counts, } } diff --git a/backend/src/services/summary/piiData.ts b/backend/src/services/summary/piiData.ts index 6530f916..a4f76650 100644 --- a/backend/src/services/summary/piiData.ts +++ b/backend/src/services/summary/piiData.ts @@ -5,34 +5,38 @@ import { SensitiveDataSummary, } from "@common/types" import { DATA_CLASS_TO_RISK_SCORE } from "@common/maps" -import { AppDataSource } from "data-source" -import cache from "memory-cache" import { DatabaseService } from "services/database" +import { ApiEndpoint, DataField } from "models" +import { MetloContext } from "types" +import { RedisClient } from "utils/redis" -export const getPIIDataTypeCount = async () => { +export const getPIIDataTypeCount = async (ctx: MetloContext) => { const piiDataTypeCountRes: { type: DataClass; cnt: number }[] = - await DatabaseService.executeRawQueries(` + await DatabaseService.executeRawQuery(` SELECT data_class as type, CAST(COUNT(*) AS INTEGER) as cnt - FROM (SELECT UNNEST("dataClasses") as data_class FROM data_field) tbl + FROM (SELECT UNNEST("dataClasses") as data_class FROM ${DataField.getTableName( + ctx, + )}) tbl GROUP BY 1 `) return Object.fromEntries(piiDataTypeCountRes.map(e => [e.type, e.cnt])) } -export const getPIIDataTypeCountCached = async () => { +export const getPIIDataTypeCountCached = async (ctx: MetloContext) => { const cacheRes: Record | null = - cache.get("PIIDataTypeCount") + await RedisClient.getFromRedis(ctx, "PIIDataTypeCount") if (cacheRes) { return cacheRes } - const realRes = await getPIIDataTypeCount() - cache.put("PIIDataTypeCount", realRes, 5000) + const realRes = await getPIIDataTypeCount(ctx) + await RedisClient.addToRedis(ctx, "PIIDataTypeCount", realRes, 5) return realRes } -export const getPIIDataTypeAgg = async (params: GetSensitiveDataAggParams) => { - const queryRunner = AppDataSource.createQueryRunner() - +export const getPIIDataTypeAgg = async ( + ctx: MetloContext, + params: GetSensitiveDataAggParams, +) => { let queryParams = [] let dataFieldFilters: string[] = [] let riskFilter = "" @@ -66,8 +70,10 @@ export const getPIIDataTypeAgg = async (params: GetSensitiveDataAggParams) => { const filtered_data_fields = ` SELECT data_field.*, api_endpoint.host as host - FROM data_field - JOIN api_endpoint ON data_field."apiEndpointUuid" = api_endpoint.uuid + FROM ${DataField.getTableName(ctx)} data_field + JOIN ${ApiEndpoint.getTableName( + ctx, + )} api_endpoint ON data_field."apiEndpointUuid" = api_endpoint.uuid ${dataFieldFilter} ` const unnest_fields = ` @@ -114,17 +120,15 @@ export const getPIIDataTypeAgg = async (params: GetSensitiveDataAggParams) => { FROM all_filtered_fields ` - const piiDataTypeRes: PIIDataClassAggItem[] = await queryRunner.query( + const piiDataTypeRes: PIIDataClassAggItem[] = await DatabaseService.executeRawQuery( piiQuery, queryParams, ) - const endpointRes: { count: number }[] = await queryRunner.query( + const endpointRes: { count: number }[] = await DatabaseService.executeRawQuery( endpointQuery, queryParams, ) - await queryRunner.release() - return { piiDataTypeCount: Object.fromEntries( piiDataTypeRes.map(e => [e.dataClass, e.count]), diff --git a/backend/src/services/summary/usageStats.ts b/backend/src/services/summary/usageStats.ts index e70522bf..25e8ebb9 100644 --- a/backend/src/services/summary/usageStats.ts +++ b/backend/src/services/summary/usageStats.ts @@ -1,13 +1,21 @@ import { UsageStats } from "@common/types" -import cache from "memory-cache" +import { + AggregateTraceDataHourly, + Alert, + ApiEndpoint, + ApiTrace, + DataField, +} from "models" import { DatabaseService } from "services/database" +import { MetloContext } from "types" +import { RedisClient } from "utils/redis" -export const getUsageStats = async () => { +export const getUsageStats = async (ctx: MetloContext) => { const statsQuery = ` SELECT DATE_TRUNC('day', traces.hour) as day, SUM(traces."numCalls") as cnt - FROM aggregate_trace_data_hourly traces + FROM ${AggregateTraceDataHourly.getTableName(ctx)} traces WHERE traces.hour > (NOW() - INTERVAL '15 days') GROUP BY 1 ORDER BY 1 @@ -16,7 +24,7 @@ export const getUsageStats = async () => { SELECT CAST(SUM(CASE WHEN traces."createdAt" > (NOW() - INTERVAL '1 minutes') THEN 1 ELSE 0 END) AS INTEGER) as "last1MinCnt", CAST(COUNT(*) AS INTEGER) as "last60MinCnt" - FROM api_trace traces + FROM ${ApiTrace.getTableName(ctx)} traces WHERE traces."createdAt" > (NOW() - INTERVAL '60 minutes') ` const queryResponses = await DatabaseService.executeRawQueries([ @@ -38,13 +46,16 @@ export const getUsageStats = async () => { } as UsageStats } -export const getUsageStatsCached = async () => { - const cacheRes: UsageStats | null = cache.get("usageStats") +export const getUsageStatsCached = async (ctx: MetloContext) => { + const cacheRes: UsageStats | null = await RedisClient.getFromRedis( + ctx, + "usageStats", + ) if (cacheRes) { return cacheRes } - const realRes = await getUsageStats() - cache.put("usageStats", realRes, 60000) + const realRes = await getUsageStats(ctx) + await RedisClient.addToRedis(ctx, "usageStats", realRes, 60) return realRes } @@ -56,22 +67,22 @@ interface CountsResponse { highRiskAlerts: number } -export const getCounts = async () => { +export const getCounts = async (ctx: MetloContext) => { const newAlertQuery = ` SELECT CAST(COUNT(*) AS INTEGER) as count, CAST(SUM(CASE WHEN "riskScore" = 'high' THEN 1 ELSE 0 END) AS INTEGER) as high_risk_count - FROM alert WHERE status = 'Open' + FROM ${Alert.getTableName(ctx)} alert WHERE status = 'Open' ` const endpointsTrackedQuery = ` SELECT CAST(COUNT(*) AS INTEGER) as endpoint_count, CAST(COUNT(DISTINCT(host)) AS INTEGER) as host_count - FROM api_endpoint + FROM ${ApiEndpoint.getTableName(ctx)} api_endpoint ` const piiDataFieldsQuery = ` SELECT CAST(COUNT(*) AS INTEGER) as count - FROM data_field WHERE "dataTag" = 'PII' + FROM ${DataField.getTableName(ctx)} data_field WHERE "dataTag" = 'PII' ` const [newAlertQueryRes, endpointsTrackedQueryRes, piiDataFieldsQueryRes] = await DatabaseService.executeRawQueries([ @@ -93,12 +104,15 @@ export const getCounts = async () => { } } -export const getCountsCached = async () => { - const cacheRes: CountsResponse | null = cache.get("usageCounts") +export const getCountsCached = async (ctx: MetloContext) => { + const cacheRes: CountsResponse | null = await RedisClient.getFromRedis( + ctx, + "usageCounts", + ) if (cacheRes) { return cacheRes } - const realRes = await getCounts() - cache.put("usageCounts", realRes, 5000) + const realRes = await getCounts(ctx) + await RedisClient.addToRedis(ctx, "usageCounts", realRes, 60) return realRes } diff --git a/backend/src/services/summary/vulnerabilities.ts b/backend/src/services/summary/vulnerabilities.ts index 623a5e40..600240cc 100644 --- a/backend/src/services/summary/vulnerabilities.ts +++ b/backend/src/services/summary/vulnerabilities.ts @@ -5,13 +5,14 @@ import { VulnerabilitySummary, } from "@common/types" import { ALERT_TYPE_TO_RISK_SCORE } from "@common/maps" -import { AppDataSource } from "data-source" +import { Alert } from "models" +import { MetloContext } from "types" +import { DatabaseService } from "services/database" export const getVulnerabilityAgg = async ( + ctx: MetloContext, params: GetVulnerabilityAggParams, ) => { - const queryRunner = AppDataSource.createQueryRunner() - let queryParams = [] let alertFilters: string[] = [] @@ -44,7 +45,7 @@ export const getVulnerabilityAgg = async ( alert.*, ${riskCase} as risk, api_endpoint.host as host - FROM alert + FROM ${Alert.getTableName(ctx)} alert JOIN api_endpoint ON alert."apiEndpointUuid" = api_endpoint.uuid ${alertFilter} ` @@ -67,16 +68,10 @@ export const getVulnerabilityAgg = async ( FROM filtered_alerts ` - const vulnerabilityItemRes: VulnerabilityAggItem[] = await queryRunner.query( - vulnerabilityQuery, - queryParams, - ) - const endpointRes: { count: number }[] = await queryRunner.query( - endpointQuery, - queryParams, - ) - - await queryRunner.release() + const vulnerabilityItemRes: VulnerabilityAggItem[] = + await DatabaseService.executeRawQuery(vulnerabilityQuery, queryParams) + const endpointRes: { count: number }[] = + await DatabaseService.executeRawQuery(endpointQuery, queryParams) return { vulnerabilityTypeCount: Object.fromEntries( diff --git a/backend/src/services/testing/runAllTests.ts b/backend/src/services/testing/runAllTests.ts index 868eb7dc..091ec6fd 100644 --- a/backend/src/services/testing/runAllTests.ts +++ b/backend/src/services/testing/runAllTests.ts @@ -1,9 +1,10 @@ -import { AppDataSource } from "data-source" import { ApiEndpointTest } from "models" import { runTest } from "@metlo/testing" +import { getRepository } from "services/database/utils" +import { MetloContext } from "types" -export const runAllTests = async (): Promise => { - const testRepository = AppDataSource.getRepository(ApiEndpointTest) +export const runAllTests = async (ctx: MetloContext): Promise => { + const testRepository = getRepository(ctx, ApiEndpointTest) const allTests = await testRepository.find({ relations: { apiEndpoint: true, diff --git a/backend/src/suricata_setup/gcp-services/gcp_setup.ts b/backend/src/suricata_setup/gcp-services/gcp_setup.ts index aa455de9..8670510d 100644 --- a/backend/src/suricata_setup/gcp-services/gcp_setup.ts +++ b/backend/src/suricata_setup/gcp-services/gcp_setup.ts @@ -16,9 +16,10 @@ import { format, } from "suricata_setup/ssh-services/ssh-setup" import path from "path" -import { AppDataSource } from "data-source" import { ApiKey } from "models" import { createApiKey } from "api/keys/service" +import { getRepository } from "services/database/utils" +import { MetloContext } from "types" const promiseExec = promisify(exec) @@ -833,18 +834,21 @@ export async function test_ssh({ } } -export async function push_files({ - key_file, - source_private_ip, - project, - id, - instance_url, - ...rest -}: RESPONSE["data"]): Promise { +export async function push_files( + ctx: MetloContext, + { + key_file, + source_private_ip, + project, + id, + instance_url, + ...rest + }: RESPONSE["data"], +): Promise { const instance_name = instance_url.split("/").at(-1) let [key, raw] = createApiKey(`Metlo-collector-${id}`) key.for = API_KEY_TYPE.GCP - let api_key = await AppDataSource.getRepository(ApiKey).save(key) + let api_key = await getRepository(ctx, ApiKey).save(key) try { let filepath_ingestor_out = path.normalize( diff --git a/backend/src/suricata_setup/index.ts b/backend/src/suricata_setup/index.ts index 866831f5..3507be0d 100644 --- a/backend/src/suricata_setup/index.ts +++ b/backend/src/suricata_setup/index.ts @@ -14,7 +14,6 @@ import { import { delete_aws_data } from "./aws-services/delete" import { test_ssh, push_files, execute_commands } from "./ssh-services" import { v4 as uuidv4 } from "uuid" -import { addToRedis, addToRedisFromPromise } from "./utils" import { ConnectionsService } from "services/connections" import { get_destination_subnet, @@ -32,6 +31,8 @@ import { execute_commands as gcp_execute_commands, } from "./gcp-services/gcp_setup" import { delete_gcp_data } from "./gcp-services/delete" +import { RedisClient } from "utils/redis" +import { MetloContext } from "types" function dummy_response(uuid, step, data, type: ConnectionType) { if (type == ConnectionType.AWS) { @@ -62,6 +63,7 @@ function dummy_response(uuid, step, data, type: ConnectionType) { } export async function setup( + ctx: MetloContext, step: number = 0, type: ConnectionType, metadata_for_step: STEP_RESPONSE["data"], @@ -93,8 +95,9 @@ export async function setup( case 10: uuid = uuidv4() resp = dummy_response(uuid, 10, metadata_for_step, ConnectionType.AWS) - await addToRedis(uuid, resp) - addToRedisFromPromise( + await RedisClient.addToRedis(ctx, uuid, resp) + RedisClient.addToRedisFromPromise( + ctx, uuid, test_ssh({ ...metadata, @@ -105,10 +108,11 @@ export async function setup( case 11: uuid = uuidv4() resp = dummy_response(uuid, 11, metadata_for_step, ConnectionType.AWS) - await addToRedis(uuid, resp) - addToRedisFromPromise( + await RedisClient.addToRedis(ctx, uuid, resp) + RedisClient.addToRedisFromPromise( + ctx, uuid, - push_files({ + push_files(ctx, { ...metadata, step: 11, }), @@ -117,15 +121,16 @@ export async function setup( case 12: uuid = uuidv4() resp = dummy_response(uuid, 12, metadata_for_step, ConnectionType.AWS) - await addToRedis(uuid, resp) - addToRedisFromPromise( + await RedisClient.addToRedis(ctx, uuid, resp) + RedisClient.addToRedisFromPromise( + ctx, uuid, execute_commands({ ...metadata, step: 12, } as any).then(resp => { if (resp.status === "COMPLETE") { - ConnectionsService.saveConnectionAws({ + ConnectionsService.saveConnectionAws(ctx, { id: resp.data.id, name: resp.data.name, conn_meta: { ...resp.data } as Required, @@ -150,8 +155,12 @@ export async function setup( case GCP_STEPS.CREATE_DESTINATION_SUBNET: uuid = uuidv4() resp = await dummy_response(uuid, 3, metadata, ConnectionType.GCP) - await addToRedis(uuid, resp) - addToRedisFromPromise(uuid, get_destination_subnet(metadata)) + await RedisClient.addToRedis(ctx, uuid, resp) + RedisClient.addToRedisFromPromise( + ctx, + uuid, + get_destination_subnet(metadata), + ) return resp case GCP_STEPS.CREATE_FIREWALL: return await create_firewall_rule(metadata) @@ -160,35 +169,47 @@ export async function setup( case GCP_STEPS.CREATE_MIG: uuid = uuidv4() resp = dummy_response(uuid, 6, metadata, ConnectionType.GCP) - await addToRedis(uuid, resp) - addToRedisFromPromise(uuid, create_mig(metadata)) + await RedisClient.addToRedis(ctx, uuid, resp) + RedisClient.addToRedisFromPromise(ctx, uuid, create_mig(metadata)) return resp case GCP_STEPS.CREATE_HEALTH_CHECK: uuid = uuidv4() resp = dummy_response(uuid, 8, metadata, ConnectionType.GCP) - await addToRedis(uuid, resp) - addToRedisFromPromise(uuid, create_health_check(metadata)) + await RedisClient.addToRedis(ctx, uuid, resp) + RedisClient.addToRedisFromPromise( + ctx, + uuid, + create_health_check(metadata), + ) return resp case GCP_STEPS.CREATE_BACKEND_SERVICE: uuid = uuidv4() resp = dummy_response(uuid, 9, metadata, ConnectionType.GCP) - await addToRedis(uuid, resp) + await RedisClient.addToRedis(ctx, uuid, resp) - addToRedisFromPromise(uuid, create_backend_service(metadata)) + RedisClient.addToRedisFromPromise( + ctx, + uuid, + create_backend_service(metadata), + ) return resp case GCP_STEPS.CREATE_ILB: uuid = uuidv4() resp = dummy_response(uuid, 10, metadata, ConnectionType.GCP) - await addToRedis(uuid, resp) + await RedisClient.addToRedis(ctx, uuid, resp) - addToRedisFromPromise(uuid, create_load_balancer(metadata)) + RedisClient.addToRedisFromPromise( + ctx, + uuid, + create_load_balancer(metadata), + ) return resp case GCP_STEPS.START_PACKET_MIRRORING: uuid = uuidv4() resp = dummy_response(uuid, 11, metadata, ConnectionType.GCP) - await addToRedis(uuid, resp) + await RedisClient.addToRedis(ctx, uuid, resp) - addToRedisFromPromise(uuid, packet_mirroring(metadata)) + RedisClient.addToRedisFromPromise(ctx, uuid, packet_mirroring(metadata)) return resp case GCP_STEPS.TEST_SSH: uuid = uuidv4() @@ -198,8 +219,8 @@ export async function setup( metadata, ConnectionType.GCP, ) - await addToRedis(uuid, resp) - addToRedisFromPromise(uuid, gcp_test_ssh(metadata)) + await RedisClient.addToRedis(ctx, uuid, resp) + RedisClient.addToRedisFromPromise(ctx, uuid, gcp_test_ssh(metadata)) return resp case GCP_STEPS.PUSH_FILES: uuid = uuidv4() @@ -209,8 +230,12 @@ export async function setup( metadata_for_step, ConnectionType.GCP, ) - await addToRedis(uuid, resp) - addToRedisFromPromise(uuid, gcp_push_files(metadata)) + await RedisClient.addToRedis(ctx, uuid, resp) + RedisClient.addToRedisFromPromise( + ctx, + uuid, + gcp_push_files(ctx, metadata), + ) return resp case GCP_STEPS.EXEC_COMMAND: uuid = uuidv4() @@ -220,12 +245,13 @@ export async function setup( metadata_for_step, ConnectionType.GCP, ) - await addToRedis(uuid, resp) - addToRedisFromPromise( + await RedisClient.addToRedis(ctx, uuid, resp) + RedisClient.addToRedisFromPromise( + ctx, uuid, gcp_execute_commands(metadata).then(resp => { if (resp.status === "COMPLETE") { - ConnectionsService.saveConnectionGcp({ + ConnectionsService.saveConnectionGcp(ctx, { id: resp.data.id, name: resp.data.name, conn_meta: { diff --git a/backend/src/suricata_setup/ssh-services/index.ts b/backend/src/suricata_setup/ssh-services/index.ts index a77f47b3..c6e3896f 100644 --- a/backend/src/suricata_setup/ssh-services/index.ts +++ b/backend/src/suricata_setup/ssh-services/index.ts @@ -2,8 +2,9 @@ import { API_KEY_TYPE, ConnectionType } from "@common/enums" import { STEP_RESPONSE } from "@common/types" import { createApiKey } from "api/keys/service" import { randomUUID } from "crypto" -import { AppDataSource } from "data-source" import { ApiKey } from "models" +import { getRepository } from "services/database/utils" +import { MetloContext } from "types" import { SSH_CONN, put_data_file, format, remove_file } from "./ssh-setup" type RESPONSE = STEP_RESPONSE @@ -57,19 +58,22 @@ export async function test_ssh({ } } -export async function push_files({ - keypair, - remote_machine_url, - source_private_ip, - username, - step, - id, - ...rest -}: RESPONSE["data"] & { step: number }): Promise { +export async function push_files( + ctx: MetloContext, + { + keypair, + remote_machine_url, + source_private_ip, + username, + step, + id, + ...rest + }: RESPONSE["data"] & { step: number }, +): Promise { let conn = new SSH_CONN(keypair, remote_machine_url, username) let [key, raw] = createApiKey(`Metlo-collector-${id}`) key.for = API_KEY_TYPE.AWS - let api_key = await AppDataSource.getRepository(ApiKey).save(key) + let api_key = await getRepository(ctx, ApiKey).save(key) try { let filepath_ingestor = `${__dirname}/../generics/scripts/metlo-ingestor-${randomUUID()}.service` let filepath_rules = `${__dirname}/../generics/scripts/local-${randomUUID()}.rules` diff --git a/backend/src/suricata_setup/utils/index.ts b/backend/src/suricata_setup/utils/index.ts deleted file mode 100644 index 72e1368a..00000000 --- a/backend/src/suricata_setup/utils/index.ts +++ /dev/null @@ -1,24 +0,0 @@ -import { RedisClient } from "utils/redis" - -async function addToRedis(key, data: Object) { - const redisClient = RedisClient.getInstance() - await redisClient.set(key, JSON.stringify(data)) -} -function addToRedisFromPromise(key, data: Promise) { - const redisClient = RedisClient.getInstance() - data - .then(resp => redisClient.set(key, JSON.stringify(resp))) - .catch(err => redisClient.set(key, JSON.stringify(err))) -} - -async function getFromRedis(key) { - const redisClient = RedisClient.getInstance() - return JSON.parse(await redisClient.get(key)) -} - -async function deleteKeyFromRedis(key) { - const redisClient = RedisClient.getInstance() - return await redisClient.del([key]) -} - -export { addToRedis, addToRedisFromPromise, getFromRedis, deleteKeyFromRedis } diff --git a/backend/src/types.ts b/backend/src/types.ts new file mode 100644 index 00000000..5c792c8d --- /dev/null +++ b/backend/src/types.ts @@ -0,0 +1,7 @@ +import { Request } from "express" + +export interface MetloContext {} + +export interface MetloRequest extends Request { + ctx?: MetloContext +} diff --git a/backend/src/utils/db.ts b/backend/src/utils/db.ts new file mode 100644 index 00000000..c93224b7 --- /dev/null +++ b/backend/src/utils/db.ts @@ -0,0 +1,41 @@ +import { QueryFailedError } from "typeorm" +import { DatabaseError } from "pg-protocol" + +const delay = (fn: any, ms: number) => + new Promise(resolve => setTimeout(() => resolve(fn()), ms)) + +const randInt = (min: number, max: number) => + Math.floor(Math.random() * (max - min + 1) + min) + +export const isQueryFailedError = ( + err: unknown, +): err is QueryFailedError & DatabaseError => err instanceof QueryFailedError + +export const retryTypeormTransaction = async (fn: any, maxAttempts: number) => { + const execute = async (attempt: number) => { + try { + return await fn() + } catch (err) { + if (isQueryFailedError(err)) { + if (err.code === "40P01" || err.code === "55P03") { + if (attempt <= maxAttempts) { + const nextAttempt = attempt + 1 + const delayInMilliseconds = randInt(200, 1000) + console.error( + `Retrying after ${delayInMilliseconds} ms due to:`, + err, + ) + return delay(() => execute(nextAttempt), delayInMilliseconds) + } else { + throw err + } + } else { + throw err + } + } else { + throw err + } + } + } + return execute(1) +} diff --git a/backend/src/utils/redis.ts b/backend/src/utils/redis.ts index 2fa0e428..655d0d15 100644 --- a/backend/src/utils/redis.ts +++ b/backend/src/utils/redis.ts @@ -1,4 +1,5 @@ import IORedis from "ioredis" +import { MetloContext } from "types" export class RedisClient { private static instance: RedisClient @@ -16,16 +17,21 @@ export class RedisClient { return RedisClient.client } - public static addToRedis(key: string, data: Object, expireIn?: number) { + public static async addToRedis( + ctx: MetloContext, + key: string, + data: Object, + expireIn?: number, + ) { try { - this.getInstance().set(key, JSON.stringify(data)) + await this.getInstance().set(key, JSON.stringify(data)) if (expireIn) { this.getInstance().expire(key, expireIn, "NX") } } catch {} } - public static async getFromRedis(key: string) { + public static async getFromRedis(ctx: MetloContext, key: string) { try { return JSON.parse(await this.getInstance().get(key)) } catch { @@ -33,7 +39,7 @@ export class RedisClient { } } - public static async deleteFromRedis(keys: string[]) { + public static async deleteFromRedis(ctx: MetloContext, keys: string[]) { try { return await this.getInstance().del(keys) } catch (err) { @@ -41,23 +47,33 @@ export class RedisClient { } } - public static pushValueToRedisList( + public static async deleteKeyFromRedis(ctx: MetloContext, key: string) { + const redisClient = RedisClient.getInstance() + return await redisClient.del([key]) + } + + public static async pushValueToRedisList( + ctx: MetloContext, key: string, data: (string | number | Buffer)[], right?: boolean, ) { try { if (right) { - this.getInstance().rpush(key, ...data) + await this.getInstance().rpush(key, ...data) } else { - this.getInstance().lpush(key, ...data) + await this.getInstance().lpush(key, ...data) } } catch (err) { console.error(`Error pushing value to redis list: ${err}`) } } - public static async popValueFromRedisList(key: string, right?: boolean) { + public static async popValueFromRedisList( + ctx: MetloContext, + key: string, + right?: boolean, + ) { try { if (right) { return await this.getInstance().rpop(key) @@ -69,15 +85,19 @@ export class RedisClient { } } - public static addValueToSet(key: string, data: string[]) { + public static async addValueToSet( + ctx: MetloContext, + key: string, + data: string[], + ) { try { - this.getInstance().sadd(key, ...data) + await this.getInstance().sadd(key, ...data) } catch (err) { console.error(`Error adding value to redis set: ${err}`) } } - public static async getValuesFromSet(key: string) { + public static async getValuesFromSet(ctx: MetloContext, key: string) { try { return await this.getInstance().smembers(key) } catch (err) { @@ -86,6 +106,7 @@ export class RedisClient { } public static async getListValueFromRedis( + ctx: MetloContext, key: string, start: number, end: number, @@ -98,11 +119,22 @@ export class RedisClient { } } - public static async getListLength(key: string) { + public static async getListLength(ctx: MetloContext, key: string) { try { return await this.getInstance().llen(key) } catch (err) { return 0 } } + + public static addToRedisFromPromise( + ctx: MetloContext, + key: string, + data: Promise, + ) { + const redisClient = RedisClient.getInstance() + data + .then(resp => redisClient.set(key, JSON.stringify(resp))) + .catch(err => redisClient.set(key, JSON.stringify(err))) + } } diff --git a/backend/yarn.lock b/backend/yarn.lock index f13fabf1..c0c6dac1 100644 --- a/backend/yarn.lock +++ b/backend/yarn.lock @@ -1083,11 +1083,6 @@ resolved "https://registry.yarnpkg.com/@types/mdurl/-/mdurl-1.0.2.tgz#e2ce9d83a613bacf284c7be7d491945e39e1f8e9" integrity sha512-eC4U9MlIcu2q0KQmXszyn5Akca/0jrQmwDRgpAMJai7qBWq4amIQhZyNau4VYGtCeALvW1/NtjzJJ567aZxfKA== -"@types/memory-cache@^0.2.2": - version "0.2.2" - resolved "https://registry.yarnpkg.com/@types/memory-cache/-/memory-cache-0.2.2.tgz#f8fb6d8aa0eb006ed44fc659bf8bfdc1a5cc57fa" - integrity sha512-xNnm6EkmYYhTnLiOHC2bdKgcYY5qjjrq5vl9KXD2nh0em0koZoFS500EL4Q4V/eW+A3P7NC7P7GIYzNOSQp7jQ== - "@types/mime@*": version "3.0.1" resolved "https://registry.yarnpkg.com/@types/mime/-/mime-3.0.1.tgz#5f8f2bca0a5863cb69bc0b0acd88c96cb1d4ae10" @@ -1100,14 +1095,6 @@ dependencies: "@types/express" "*" -"@types/newman@^5.3.0": - version "5.3.0" - resolved "https://registry.yarnpkg.com/@types/newman/-/newman-5.3.0.tgz#3ca70f06adc7d653c603c83dc1c6433dbd2f0593" - integrity sha512-3w2C8Rqo38BXJFWg3zDqTy7qRPBxDSZzEW54UqHghkLqQKAFeC/VwbSuevK+iOdpNwKn3fal8SowpFjdWtVDow== - dependencies: - "@types/postman-collection" "*" - "@types/tough-cookie" "*" - "@types/node-schedule@^2.1.0": version "2.1.0" resolved "https://registry.yarnpkg.com/@types/node-schedule/-/node-schedule-2.1.0.tgz#60375640c0509bab963573def9d1f417f438c290" @@ -1125,13 +1112,6 @@ resolved "https://registry.yarnpkg.com/@types/node/-/node-18.7.9.tgz#180bfc495c91dc62573967edf047e15dbdce1491" integrity sha512-0N5Y1XAdcl865nDdjbO0m3T6FdmQ4ijE89/urOHLREyTXbpMWbSafx9y7XIsgWGtwUP2iYTinLyyW3FatAxBLQ== -"@types/postman-collection@*": - version "3.5.7" - resolved "https://registry.yarnpkg.com/@types/postman-collection/-/postman-collection-3.5.7.tgz#c62fed598928cb0c45f3287782bb4d5c4127d1b3" - integrity sha512-wqZ/MlGEYP+RoiofnAnKDJAHxRzmMK97CeFLoHPNoHdHX0uyBLCDl+uZV9x4xuPVRjkeM4xcarIaUaUk47c7SQ== - dependencies: - "@types/node" "*" - "@types/qs@*": version "6.9.7" resolved "https://registry.yarnpkg.com/@types/qs/-/qs-6.9.7.tgz#63bb7d067db107cc1e457c303bc25d511febf6cb" @@ -1172,11 +1152,6 @@ resolved "https://registry.yarnpkg.com/@types/stack-utils/-/stack-utils-2.0.1.tgz#20f18294f797f2209b5f65c8e3b5c8e8261d127c" integrity sha512-Hl219/BT5fLAaz6NDkSuhzasy49dwQS/DSdu4MdggFB8zcXv7vflBI3xp7FEmkmdDkBUI2bPUNeMttp2knYdxw== -"@types/tough-cookie@*": - version "4.0.2" - resolved "https://registry.yarnpkg.com/@types/tough-cookie/-/tough-cookie-4.0.2.tgz#6286b4c7228d58ab7866d19716f3696e03a09397" - integrity sha512-Q5vtl1W5ue16D+nIaW8JWebSSraJVlK+EthKn7e7UcD4KWsaSJ8BqGPXNaPghgtcn/fhvrN17Tv8ksUsQpiplw== - "@types/uuid@^8.3.4": version "8.3.4" resolved "https://registry.yarnpkg.com/@types/uuid/-/uuid-8.3.4.tgz#bd86a43617df0594787d38b735f55c805becf1bc" @@ -2914,11 +2889,6 @@ media-typer@0.3.0: resolved "https://registry.yarnpkg.com/media-typer/-/media-typer-0.3.0.tgz#8710d7af0aa626f8fffa1ce00168545263255748" integrity sha512-dq+qelQ9akHpcOl/gUVRTxVIOkAJ1wR3QAvb4RsVjS8oVoFjDGTc679wJYmUmknUF5HwMLOgb5O+a3KxfWapPQ== -memory-cache@^0.2.0: - version "0.2.0" - resolved "https://registry.yarnpkg.com/memory-cache/-/memory-cache-0.2.0.tgz#7890b01d52c00c8ebc9d533e1f8eb17e3034871a" - integrity sha512-OcjA+jzjOYzKmKS6IQVALHLVz+rNTMPoJvCztFaZxwG14wtAW7VRZjwTQu06vKCYOxh4jVnik7ya0SXTB0W+xA== - merge-descriptors@1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/merge-descriptors/-/merge-descriptors-1.0.1.tgz#b00aaa556dd8b44568150ec9d1b953f3f90cbb61"