-
Notifications
You must be signed in to change notification settings - Fork 0
/
vendor_embed.py
178 lines (147 loc) · 5.44 KB
/
vendor_embed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import json
import logging
import os
import uuid
from enum import StrEnum
import boto3
from pydantic import BaseModel
from pydantic import ValidationError
from pymongo import MongoClient
# Configure the root logger for AWS Lambda.
# Handlers already attached to the root logger are detached first, because
# logging.basicConfig() does nothing when the root logger still has handlers.
logger = logging.getLogger()
# Iterate over a snapshot: removeHandler() mutates logger.handlers, and looping
# over the live list would skip every other handler and leave some attached.
for handler in list(logger.handlers):
    logger.removeHandler(handler)
logging.basicConfig(
    level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
def get_mongo_client():
    """Build a MongoClient from the MONGODB_URI environment variable.

    The client is created with ``uuidRepresentation="standard"``. It is not
    closed here; the caller owns the returned client.
    """
    uri = os.environ.get("MONGODB_URI")
    logger.info("Creating MongoDB client.")
    return MongoClient(uri, uuidRepresentation="standard")
def get_knowledge_source_id_from_app(app_id: str) -> str:
    """Return the ID of the first knowledge source attached to an app.

    Looks up the document whose ``_id`` equals ``app_id`` (parsed as a UUID)
    in the ``knowledge_apps`` collection of the ``chat`` database.

    Args:
        app_id: String form of the app's UUID.

    Returns:
        ``str()`` of the ``oid`` attribute of the first entry in the app's
        ``knowledge_sources`` list.

    Raises:
        ValueError: If ``app_id`` is not a valid UUID, no app matches it, or
            the app has no knowledge sources.
    """
    logger.info(f"Fetching source ID for app_id: {app_id}")

    # Validate the ID before connecting so that malformed input never opens
    # a database connection.
    try:
        app_uuid = uuid.UUID(app_id)
    except ValueError as e:
        logger.error(f"Invalid app_id format: {app_id}. Error: {e}")
        raise

    # Bound up front so the ``finally`` clause is safe even when creating the
    # client itself fails (otherwise an UnboundLocalError would mask the
    # original exception).
    client = None
    try:
        client = get_mongo_client()
        app = client["chat"].knowledge_apps.find_one({"_id": app_uuid})

        if not app:
            logger.error(f"No app found for app_id: {app_id}")
            raise ValueError(f"No app found for app_id: {app_id}")

        knowledge_sources = app.get("knowledge_sources")
        if not knowledge_sources:
            logger.error(f"No knowledge sources found for app_id: {app_id}")
            raise ValueError(f"No knowledge sources found for app_id: {app_id}")

        logger.info(f"App found for app_id: {app_id}")
        # NOTE(review): entries are expected to expose an ``oid`` attribute;
        # confirm against the stored reference type (bson.DBRef exposes ``id``).
        return str(knowledge_sources[0].oid)
    finally:
        if client is not None:
            client.close()
class DependsOnType(StrEnum):
N_TO_N = "N_TO_N"
SEQUENTIAL = "SEQUENTIAL"
class DependsOn(BaseModel):
    """Reference to another job by ID, together with the dependency type."""

    # ID of the job being depended on.
    jobId: str
    # Kind of dependency on that job.
    type: DependsOnType

    def model_dump(self, **kwargs) -> dict:
        """Return the dependency as a plain dict with the enum as its string value.

        Keyword arguments are accepted but ignored; the result always
        contains exactly the ``jobId`` and ``type`` keys.
        """
        return dict(jobId=self.jobId, type=self.type.value)
class EnvVar(BaseModel):
    """A single environment variable expressed as a name/value pair."""

    # Variable name.
    name: str
    # Variable value.
    value: str
class ContainerOverrides(BaseModel):
    """Per-job overrides for the container: its command and environment."""

    # Command to run in the container, one list item per argument.
    command: list[str]
    # Environment variables to pass to the container.
    environment: list[EnvVar]
def remove_special_characters(input_str: str) -> str:
    """Return ``input_str`` with every non-alphanumeric character dropped.

    A character is kept when ``str.isalnum`` is true for it, so whitespace,
    punctuation, hyphens and underscores are all removed.
    """
    return "".join(filter(str.isalnum, input_str))
class BatchManager:
    """Helper for submitting embed jobs through the boto3 ``batch`` client."""

    @staticmethod
    def submit_embed_job(
        job_name: str,
        job_queue: str,
        job_definition: str,
        container_overrides: ContainerOverrides,
        depends_on: list[DependsOn] | None = None,
    ) -> str:
        """Submit a job and return its job ID.

        Args:
            job_name: Name to give the submitted job.
            job_queue: Queue the job is submitted to.
            job_definition: Job definition to run.
            container_overrides: Command and environment overrides for the
                job's container.
            depends_on: Jobs this job depends on. ``None`` (the default) means
                no dependencies.

        Returns:
            The ``jobId`` value from the ``submit_job`` response.

        Raises:
            Exception: Any error raised while submitting the job is logged
                and re-raised unchanged.
        """
        # ``None`` replaces the former mutable ``[]`` default, which would be
        # a single list object shared across every call.
        dependencies = [] if depends_on is None else depends_on

        client = boto3.client("batch")
        try:
            response = client.submit_job(
                jobName=job_name,
                jobQueue=job_queue,
                jobDefinition=job_definition,
                containerOverrides={
                    "command": container_overrides.command,
                    "environment": [
                        {"name": env.name, "value": env.value}
                        for env in container_overrides.environment
                    ],
                },
                dependsOn=[
                    {"jobId": dep.jobId, "type": dep.type.value}
                    for dep in dependencies
                ],
            )
            logger.info(f"Embed job '{job_name}' submitted successfully. ID: {response['jobId']}")
            return response["jobId"]
        except Exception as e:
            logger.error(
                f"Failed to submit embed job '{job_name}'. Error: {str(e)}",
                exc_info=True,
            )
            # Bare ``raise`` re-raises the active exception without adding
            # this frame to its traceback a second time.
            raise
class EmbedJobEvent(BaseModel):
    """Validated payload of an embed-job request."""

    # Name of the embed job; used in log messages and to derive the job name.
    name: str
    # Identifier passed on as the app ID when the job is prepared.
    id: str
def handle_vendor_embed_request(event, context):
    """Lambda entry point: validate the request body and submit an embed job.

    Args:
        event: Invocation event. Its ``"body"`` entry must be a JSON string
            encoding an object with ``name`` and ``id`` fields.
        context: Lambda context object (unused).

    Returns:
        A dict with ``statusCode`` and a JSON-encoded ``body``:
        200 when the job was submitted; 400 when the body is missing, is not
        valid JSON, is not a JSON object, or fails validation; 500 when
        preparing or submitting the job raises.
    """
    body = event.get("body") if isinstance(event, dict) else None
    logger.debug("Received event body: %s", body)

    # A missing or non-string body would otherwise raise KeyError/TypeError,
    # which are client errors and must not escape as unhandled exceptions.
    if not isinstance(body, (str, bytes, bytearray)):
        logger.error("Request body is missing or is not a JSON string")
        return {
            "statusCode": 400,
            "body": json.dumps("Request body is missing or is not a JSON string"),
        }

    try:
        payload = json.loads(body)
    except (json.JSONDecodeError, UnicodeDecodeError) as json_err:
        logger.error("JSON parsing error: %s", str(json_err), exc_info=True)
        return {
            "statusCode": 400,
            "body": json.dumps(f"JSON parsing error: {str(json_err)}"),
        }

    # Valid JSON that is not an object (list, number, string, null) cannot be
    # unpacked into keyword arguments for the model.
    if not isinstance(payload, dict):
        logger.error("Validation error: request body must be a JSON object")
        return {
            "statusCode": 400,
            "body": json.dumps("Validation error: request body must be a JSON object"),
        }

    try:
        validated_event = EmbedJobEvent(**payload)
    except ValidationError as validation_err:
        logger.error("Validation error: %s", str(validation_err), exc_info=True)
        return {
            "statusCode": 400,
            "body": json.dumps(f"Validation error: {str(validation_err)}"),
        }

    logger.info(
        f"Event validated successfully for embed job with name: '{validated_event.name}' and id: '{validated_event.id}'"
    )

    # Top-level boundary: any failure while preparing or submitting the job
    # is reported to the caller as a 500.
    try:
        prepare_embed_job(validated_event.name, validated_event.id)
    except Exception as e:
        logger.error("Error submitting embed job: %s", str(e), exc_info=True)
        return {
            "statusCode": 500,
            "body": json.dumps(f"Error submitting embed job: {str(e)}"),
        }

    return {
        "statusCode": 200,
        "body": json.dumps("Embed job submitted successfully"),
    }
def prepare_embed_job(name: str, app_id: str) -> str:
    """Resolve the app's knowledge source and submit an embed job for it.

    Args:
        name: Human-readable job name; non-alphanumeric characters are
            stripped before it is used as the submitted job's name.
        app_id: App identifier, used both to look up the knowledge source and
            as the ``--namespace`` argument of the job command.

    Returns:
        The ID of the submitted job.
    """
    logger.debug(f"Preparing embed job for '{name}'")

    knowledge_source_id = get_knowledge_source_id_from_app(app_id)
    logger.info(f"Source ID for '{name}': {knowledge_source_id}")

    batch_job_name = remove_special_characters(name)
    overrides = ContainerOverrides(
        command=["emb", "update-website", knowledge_source_id, "--namespace", app_id],
        environment=[],
    )
    submitted_id = BatchManager.submit_embed_job(
        job_name=batch_job_name,
        job_queue="website-processing-queue",
        job_definition="embedding-job",
        container_overrides=overrides,
        depends_on=[],
    )

    logger.info(f"Embed job for '{name}' prepared successfully. Job ID: {submitted_id}")
    return submitted_id