Skip to content

Commit

Permalink
feat: 优化向量化任务数据库查询效率 TencentBlueKing#2596
Browse files Browse the repository at this point in the history
* feat: 优化向量化任务数据库查询效率 TencentBlueKing#2596

* feat: 优化向量化任务数据库查询效率 TencentBlueKing#2596

* feat: 优化向量化任务数据库查询效率 TencentBlueKing#2596

* feat: 优化向量化任务数据库查询效率 TencentBlueKing#2596

* feat: 优化向量化任务数据库查询效率 TencentBlueKing#2596

* feat: 优化向量化任务数据库查询效率 TencentBlueKing#2596
  • Loading branch information
cnlkl authored Sep 25, 2024
1 parent a834a96 commit b79129f
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ package com.tencent.bkrepo.job.batch.task.cache.preload

import com.tencent.bkrepo.auth.constant.PIPELINE
import com.tencent.bkrepo.common.artifact.event.base.EventType
import com.tencent.bkrepo.common.mongo.constant.ID
import com.tencent.bkrepo.common.mongo.constant.MIN_OBJECT_ID
import com.tencent.bkrepo.common.mongo.dao.util.sharding.MonthRangeShardingUtils
import com.tencent.bkrepo.common.operate.service.model.TOperateLog
import com.tencent.bkrepo.job.batch.base.DefaultContextJob
Expand All @@ -41,19 +43,22 @@ import com.tencent.bkrepo.job.batch.task.cache.preload.ai.milvus.MilvusClient
import com.tencent.bkrepo.job.batch.task.cache.preload.ai.milvus.MilvusVectorStore
import com.tencent.bkrepo.job.batch.task.cache.preload.ai.milvus.MilvusVectorStoreProperties
import com.tencent.bkrepo.job.config.properties.ArtifactAccessLogEmbeddingJobProperties
import org.bson.types.ObjectId
import org.slf4j.LoggerFactory
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty
import org.springframework.boot.context.properties.EnableConfigurationProperties
import org.springframework.data.domain.Sort
import org.springframework.data.mongodb.core.MongoTemplate
import org.springframework.data.mongodb.core.find
import org.springframework.data.mongodb.core.findOne
import org.springframework.data.mongodb.core.query.Criteria
import org.springframework.data.mongodb.core.query.Query
import org.springframework.data.mongodb.core.query.isEqualTo
import org.springframework.data.mongodb.core.query.gte
import org.springframework.stereotype.Component
import java.time.Duration
import java.time.LocalDate
import java.time.LocalDateTime
import java.time.ZoneId
import kotlin.math.abs
import kotlin.system.measureTimeMillis

@Component
Expand Down Expand Up @@ -120,12 +125,16 @@ class ArtifactAccessLogEmbeddingJob(
after: LocalDateTime? = null,
before: LocalDateTime? = null
) {
properties.projects.forEach { projectId ->
processDataBatch(projectId, minusMonth, after, before) { paths ->
val documents = paths.map { Document(content = it, metadata = emptyMap()) }
val elapsed = measureTimeMillis { insert(documents) }
logger.info("[$projectId] insert ${documents.size} data into [${collectionName()}] in $elapsed ms")
findAndHandle(minusMonth, after, before) { projectId, paths ->
val documents = paths.map {
val metadata = mapOf(
METADATA_KEY_DOWNLOAD_TIMESTAMP to it.value.downloadTimestamp.joinToString(","),
METADATA_KEY_ACCESS_COUNT to it.value.count.toString()
)
Document(content = it.key, metadata = metadata)
}
val elapsed = measureTimeMillis { insert(documents) }
logger.info("[$projectId] insert ${documents.size} data into [${collectionName()}] in $elapsed ms")
}
}

Expand All @@ -141,62 +150,114 @@ class ArtifactAccessLogEmbeddingJob(
return MilvusVectorStore(config, milvusClient, embeddingModel)
}

/**
* 获取有访问记录的路径
*/
private fun processDataBatch(
projectId: String,

private fun findAndHandle(
minusMonth: Long,
after: LocalDateTime?,
before: LocalDateTime?,
handler: (paths: Set<String>) -> Unit,
after: LocalDateTime? = null,
before: LocalDateTime? = null,
handler: (String, Map<String, AccessLog>) -> Unit
) {
val collectionName = collectionName(minusMonth)
// buffer存储的内容结构为(projectId, (path, accessLog))
val projectBuffer = HashMap<String, MutableMap<String, AccessLog>>()
iterateCollection(collectionName, findFirstObjectId(collectionName, after)) { operateLog ->
val createDate = operateLog.createdDate
val outOfDateRange =
after != null && createDate.isBefore(after) || before != null && createDate.isAfter(before)
val acceptableType = operateLog.type == EventType.NODE_DOWNLOADED.name
val acceptableProject = operateLog.projectId in properties.projects

if (!outOfDateRange && acceptableProject && acceptableType) {
val shouldFlush = projectBuffer.addToBuffer(operateLog)
if (shouldFlush) {
handler(operateLog.projectId, projectBuffer[operateLog.projectId]!!)
projectBuffer.remove(operateLog.projectId)
}
}
}
projectBuffer.forEach { (projectId, paths) -> handler(projectId, paths) }
}

/**
 * Folds one operate-log record into the per-project buffer.
 *
 * The receiver maps projectId -> (projectRepoFullPath -> [AccessLog]). The record's access count is
 * incremented and its creation time is recorded as a download timestamp when it is far enough from
 * the last recorded one.
 *
 * @param operateLog the record to add
 * @return true when the buffer of the record's project holds at least `properties.batchToInsert`
 * paths, or the touched path has accumulated at least that many timestamps
 */
private fun HashMap<String, MutableMap<String, AccessLog>>.addToBuffer(operateLog: OperateLog): Boolean {
    with(operateLog) {
        val normalizedKey = if (repoName == PIPELINE) {
            // Pipeline repo paths look like /p-xxx/b-xxx/xxx; the build id segment is left out of
            // the similarity computation.
            val secondSlashIndex = resourceKey.indexOf("/", 1)
            val thirdSlashIndex = if (secondSlashIndex < 0) -1 else resourceKey.indexOf("/", secondSlashIndex + 1)
            if (thirdSlashIndex < 0) {
                // The key does not have the expected three-segment shape: keep it unchanged instead
                // of letting substring() throw and abort the whole iteration.
                resourceKey
            } else {
                resourceKey.substring(0, secondSlashIndex) + resourceKey.substring(thirdSlashIndex)
            }
        } else {
            resourceKey
        }
        val projectRepoFullPath = "/$projectId/$repoName$normalizedKey"

        val buffer = getOrPut(projectId) { HashMap() }
        val accessLog = buffer.getOrPut(projectRepoFullPath) {
            AccessLog(
                projectId = projectId,
                repoName = repoName,
                fullPath = resourceKey,
                projectRepoFullPath = projectRepoFullPath
            )
        }
        accessLog.count += 1
        // Only timestamps more than 10 minutes away from the last recorded one are added, so the
        // number of timestamps may be smaller than count.
        val createdTimestamp = createdDate.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli()
        val lastTimestamp = accessLog.downloadTimestamp.lastOrNull() ?: 0L
        if (abs(createdTimestamp - lastTimestamp) > 600_000) {
            accessLog.downloadTimestamp.add(createdTimestamp)
        }
        return buffer.size >= properties.batchToInsert ||
            accessLog.downloadTimestamp.size >= properties.batchToInsert
    }
}

private fun iterateCollection(collectionName: String, startId: ObjectId, handler: (OperateLog) -> Unit) {
if (!mongoTemplate.collectionExists(collectionName)) {
logger.warn("mongo collection[$collectionName] not exists")
return
}
val pageSize = properties.batchSize
var offset = 0L
var querySize: Int
val criteria = buildCriteria(projectId, after, before)
val count = mongoTemplate.count(Query(), collectionName)
var progress = 0
var records: List<OperateLog>
var lastId = startId
do {
val query = Query(criteria)
.limit(pageSize)
.skip(offset)
.with(Sort.by(TOperateLog::projectId.name).ascending())
query.fields().include(
TOperateLog::repoName.name,
TOperateLog::resourceKey.name,
TOperateLog::createdDate.name
)

val query = buildQuery(lastId)
val start = System.currentTimeMillis()
val data = mongoTemplate.find<Map<String, Any?>>(query, collectionName)
logger.info("find [$projectId] access log from db elapsed[${System.currentTimeMillis() - start}]ms")
if (data.isEmpty()) {
break
}
// 记录制品访问时间
val accessPaths = data.mapTo(HashSet(pageSize)) {
val repoName = it[TOperateLog::repoName.name] as String
val fullPath = it[TOperateLog::resourceKey.name] as String
val projectRepoFullPath = if (repoName == PIPELINE) {
// 流水线仓库路径/p-xxx/b-xxx/xxx中的构建id不参与相似度计算
val secondSlashIndex = fullPath.indexOf("/", 1)
val pipelinePath = fullPath.substring(0, secondSlashIndex)
val artifactPath = fullPath.substring(fullPath.indexOf("/", secondSlashIndex + 1))
pipelinePath + artifactPath
} else {
fullPath
}
"/$projectId/$repoName$projectRepoFullPath"
records = mongoTemplate.find(query, OperateLog::class.java, collectionName)

progress += records.size
if (progress % 1000000 == 0) {
val end = System.currentTimeMillis()
logger.info("find access log from db elapsed[${end - start}]ms, $progress/$count")
}

handler(accessPaths)
querySize = data.size
offset += data.size
} while (querySize == pageSize && shouldRun())
records.forEach { handler(it) }
lastId = records.lastOrNull()?.id ?: break
} while (records.size == query.limit && shouldRun())
}

/**
 * Builds the query for the next page of records: ids greater than [lastId], sorted by id ascending,
 * limited to `properties.batchSize`, and projected to the fields listed below.
 */
private fun buildQuery(lastId: ObjectId): Query {
    val projectedFields = arrayOf(
        ID,
        TOperateLog::projectId.name,
        TOperateLog::repoName.name,
        TOperateLog::type.name,
        TOperateLog::resourceKey.name,
        TOperateLog::createdDate.name
    )
    val idAfterLast = Criteria.where(ID).gt(lastId)
    return Query(idAfterLast)
        .with(Sort.by(ID).ascending())
        .limit(properties.batchSize)
        .apply { fields().include(*projectedFields) }
}

/**
 * Picks the id from which the collection iteration should start.
 *
 * Without [after] the minimum object id is returned. Otherwise the id of a record created within
 * the hour before [after] is used as the starting point. When no such record exists the access
 * volume is considered small enough to iterate from the minimum id.
 */
private fun findFirstObjectId(collectionName: String, after: LocalDateTime?): ObjectId {
    after ?: return ObjectId(MIN_OBJECT_ID)

    val windowStart = after.minusHours(1L)
    val query = Query(TOperateLog::createdDate.gte(windowStart).lt(after))
        .apply { fields().include(ID) }
    val match = mongoTemplate.findOne<Map<String, Any?>>(query, collectionName)
        ?: return ObjectId(MIN_OBJECT_ID)
    return match[ID] as ObjectId
}

private fun collectionName(minusMonth: Long): String {
Expand All @@ -205,20 +266,27 @@ class ArtifactAccessLogEmbeddingJob(
return "artifact_oplog_$seq"
}

/**
 * Builds the criteria matching NODE_DOWNLOADED records of [projectId], optionally restricted to
 * records created at or after [after] and strictly before [before].
 */
private fun buildCriteria(projectId: String, after: LocalDateTime?, before: LocalDateTime?): Criteria {
    val base = Criteria
        .where(TOperateLog::projectId.name).isEqualTo(projectId)
        .and(TOperateLog::type.name).isEqualTo(EventType.NODE_DOWNLOADED.name)
    val dateField = TOperateLog::createdDate.name
    // Both bounds go into a single clause on the date field; otherwise only the bound that is
    // present is added.
    when {
        after != null && before != null -> base.and(dateField).gte(after).lt(before)
        after != null -> base.and(dateField).gte(after)
        before != null -> base.and(dateField).lt(before)
    }
    return base
}
/**
 * Projection of an operate-log record holding only the fields this job reads.
 *
 * All properties are read-only: instances are only read after being loaded, never modified.
 *
 * @property id document id, used as the paging cursor when iterating the collection
 * @property projectId project the record belongs to
 * @property repoName repository the record belongs to
 * @property resourceKey path of the accessed resource inside the repository
 * @property createdDate time the record was created
 * @property type event type name, compared against [EventType] names
 */
private data class OperateLog(
    val id: ObjectId,
    val projectId: String,
    val repoName: String,
    val resourceKey: String,
    val createdDate: LocalDateTime,
    val type: String,
)

/**
 * Aggregated access information of one artifact path, accumulated while iterating operate logs.
 *
 * @property projectId project of the accessed artifact
 * @property repoName repository of the accessed artifact
 * @property fullPath resource key of the record that first created this entry
 * @property projectRepoFullPath buffer key in the form /projectId/repoName/path; for pipeline
 * repositories the build id segment is removed from the path
 * @property count number of download records folded into this entry
 * @property downloadTimestamp download times in epoch milliseconds, in insertion order; a time is
 * only added when it is more than 10 minutes away from the last recorded one, so the size may be
 * smaller than [count]
 */
private data class AccessLog(
val projectId: String,
val repoName: String,
val fullPath: String,
val projectRepoFullPath: String,
var count: Long = 0,
val downloadTimestamp: LinkedHashSet<Long> = LinkedHashSet(),
)

// Logger plus the metadata keys attached to every document written to the vector store.
companion object {
private val logger = LoggerFactory.getLogger(ArtifactAccessLogEmbeddingJob::class.java)
// Metadata key holding the comma-joined download timestamps of a path.
private const val METADATA_KEY_DOWNLOAD_TIMESTAMP = "download_timestamp"
// Metadata key holding the access count of a path as a string.
private const val METADATA_KEY_ACCESS_COUNT = "access_count"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,9 @@ class ArtifactAccessLogEmbeddingJobProperties(
/**
* 需要将访问记录保存到向量数据库的项目
*/
var projects: Set<String> = emptySet()
var projects: Set<String> = emptySet(),
/**
* 批量向量化并写入向量数据库的数量
*/
var batchToInsert: Int = 500,
) : MongodbJobProperties(enabled)

0 comments on commit b79129f

Please sign in to comment.