diff --git a/db/migrations/20240417164251_alter_user_add_publish_rate.js b/db/migrations/20240417164251_alter_user_add_publish_rate.js new file mode 100644 index 000000000..39d884b1e --- /dev/null +++ b/db/migrations/20240417164251_alter_user_add_publish_rate.js @@ -0,0 +1,13 @@ +const table = 'user' + +exports.up = async (knex) => { + await knex.schema.table(table, (t) => { + t.jsonb('publish_rate') + }) +} + +exports.down = async (knex) => { + await knex.schema.table(table, (t) => { + t.dropColumn('publish_rate') + }) +} diff --git a/src/common/enums/user.ts b/src/common/enums/user.ts index b7d1a33ca..04e7e6ca9 100644 --- a/src/common/enums/user.ts +++ b/src/common/enums/user.ts @@ -1,3 +1,5 @@ +import { isProd } from 'common/environment' + export const USER_STATE = { frozen: 'frozen', active: 'active', @@ -16,3 +18,6 @@ export const AUTHOR_TYPE = { default: 'default', trendy: 'trendy', } as const + +export const PUBLISH_ARTICLE_RATE_LIMIT = isProd ? 1 : 1000 +export const PUBLISH_ARTICLE_RATE_PERIOD = 720 // for 12 minutes; diff --git a/src/common/utils/index.ts b/src/common/utils/index.ts index ab340ad63..f8f560790 100644 --- a/src/common/utils/index.ts +++ b/src/common/utils/index.ts @@ -25,6 +25,7 @@ export * from './genDisplayName' export * from './counter' export * from './verify' export * from './nanoid' +export * from './rateLimit' /** * Make a valid user name based on a given email address. It removes all special characters including _. diff --git a/src/common/utils/rateLimit.ts b/src/common/utils/rateLimit.ts new file mode 100644 index 000000000..f79239a12 --- /dev/null +++ b/src/common/utils/rateLimit.ts @@ -0,0 +1,59 @@ +export const checkOperationLimit = async ({ + user, + operation, + limit, + period, + redis, +}: { + user: string + operation: string + limit: number + period: number + redis: Redis +}) => { + const cacheKey = genCacheKey({ + id: user, + field: operation, + prefix: CACHE_PREFIX.OPERATION_LOG, + }) + + const operationLog = await redis.lrange(cacheKey, 0, -1) + + // timestamp in seconds + const current = Math.floor(Date.now() / 1000) + + // no record + if (!operationLog) { + // create + redis.lpush(cacheKey, current).then(() => { + redis.expire(cacheKey, period) + }) + + // pass + return true + } + + // count times within period + const cutoff = current - period + let times = 0 + for (const timestamp of operationLog) { + if (parseInt(timestamp, 10) >= cutoff) { + times += 1 + } else { + break + } + } + + // over limit + if (times >= limit) { + return false + } + + // add, trim, update expiration + redis.lpush(cacheKey, current) + redis.ltrim(cacheKey, 0, times) + redis.expire(cacheKey, period) + + // pass + return true +} diff --git a/src/mutations/article/publishArticle.ts b/src/mutations/article/publishArticle.ts index f695c9782..d56e75a09 100644 --- a/src/mutations/article/publishArticle.ts +++ b/src/mutations/article/publishArticle.ts @@ -1,6 +1,12 @@ import type { GQLMutationResolvers } from 'definitions' +import type { Redis } from 'ioredis' -import { PUBLISH_STATE, USER_STATE } from 'common/enums' +import { + PUBLISH_ARTICLE_RATE_LIMIT, + PUBLISH_ARTICLE_RATE_PERIOD, + PUBLISH_STATE, + USER_STATE, +} from 'common/enums' import { DraftNotFoundError, ForbiddenByStateError, @@ -8,6 +14,7 @@ import { UserInputError, } from 'common/errors' import { fromGlobalId } from 'common/utils' +import { checkOperationLimit } from 'types/directives' const resolver: GQLMutationResolvers['publishArticle'] = async ( _, @@ -18,6 +25,7 @@ const resolver: GQLMutationResolvers['publishArticle'] = async ( draftService, atomService, queues: { publicationQueue }, + redis, }, } ) => { @@ -50,6 +58,20 @@ const resolver: GQLMutationResolvers['publishArticle'] = async ( throw new UserInputError('content is required') } + const pass = await checkOperationLimit({ + user: viewer.id || viewer.ip, + operation: 'publishArticle', + limit: viewer?.publishRate?.limit ?? PUBLISH_ARTICLE_RATE_LIMIT, + period: viewer?.publishRate?.period ?? PUBLISH_ARTICLE_RATE_PERIOD, + redis, // : connections.redis, + }) + + if (!pass) { + throw new ActionLimitExceededError( + `rate exceeded for operation ${fieldName}` + ) + } + if ( draft.publishState === PUBLISH_STATE.pending || (draft.archived && isPublished) diff --git a/src/types/directives/rateLimit.ts b/src/types/directives/rateLimit.ts index 4042a28f8..fdcdd76dc 100644 --- a/src/types/directives/rateLimit.ts +++ b/src/types/directives/rateLimit.ts @@ -5,68 +5,9 @@ import { defaultFieldResolver, GraphQLSchema } from 'graphql' import { CACHE_PREFIX } from 'common/enums' import { ActionLimitExceededError } from 'common/errors' +import { checkOperationLimit } from 'common/utils' import { genCacheKey } from 'connectors' -const checkOperationLimit = async ({ - user, - operation, - limit, - period, - redis, -}: { - user: string - operation: string - limit: number - period: number - redis: Redis -}) => { - const cacheKey = genCacheKey({ - id: user, - field: operation, - prefix: CACHE_PREFIX.OPERATION_LOG, - }) - - const operationLog = await redis.lrange(cacheKey, 0, -1) - - // timestamp in seconds - const current = Math.floor(Date.now() / 1000) - - // no record - if (!operationLog) { - // create - redis.lpush(cacheKey, current).then(() => { - redis.expire(cacheKey, period) - }) - - // pass - return true - } - - // count times within period - const cutoff = current - period - let times = 0 - for (const timestamp of operationLog) { - if (parseInt(timestamp, 10) >= cutoff) { - times += 1 - } else { - break - } - } - - // over limit - if (times >= limit) { - return false - } - - // add, trim, update expiration - redis.lpush(cacheKey, current) - redis.ltrim(cacheKey, 0, times) - redis.expire(cacheKey, period) - - // pass - return true -} - export const rateLimitDirective = (directiveName = 'rateLimit') => ({ typeDef: `"Rate limit within a given period of time, in seconds" directive @${directiveName}(period: Int!, limit: Int!) on FIELD_DEFINITION`,