Skip to content

Commit

Permalink
ydb finished
Browse files (browse the repository at this point in the history)
  • Loading branch information
pseusys committed Oct 4, 2024
1 parent e1cb50d commit 782bf66
Show file tree
Hide file tree
Showing 5 changed files with 314 additions and 365 deletions.
62 changes: 31 additions & 31 deletions chatsky/context_storages/mongo.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -61,19 +61,19 @@ def __init__(
self._mongo = AsyncIOMotorClient(self.full_path, uuidRepresentation="standard")
db = self._mongo.get_default_database()

self._main_table = db[f"{collection_prefix}_{self._main_table_name}"]
self._turns_table = db[f"{collection_prefix}_{self._turns_table_name}"]
self._misc_table = db[f"{collection_prefix}_{self._misc_table_name}"]
self.main_table = db[f"{collection_prefix}_{self._main_table_name}"]
self.turns_table = db[f"{collection_prefix}_{self._turns_table_name}"]
self.misc_table = db[f"{collection_prefix}_{self._misc_table_name}"]

asyncio.run(
asyncio.gather(
self._main_table.create_index(
self.main_table.create_index(
self._id_column_name, background=True, unique=True
),
self._turns_table.create_index(
self.turns_table.create_index(
[self._id_column_name, self._key_column_name], background=True, unique=True
),
self._misc_table.create_index(
self.misc_table.create_index(
[self._id_column_name, self._key_column_name], background=True, unique=True
)
)
Expand All @@ -82,25 +82,25 @@ def __init__(
# TODO: this method (and similar) repeat often. Optimize?
def _get_config_for_field(self, field_name: str) -> Tuple[Collection, str, FieldConfig]:
if field_name == self.labels_config.name:
return self._turns_table, field_name, self.labels_config
return self.turns_table, field_name, self.labels_config
elif field_name == self.requests_config.name:
return self._turns_table, field_name, self.requests_config
return self.turns_table, field_name, self.requests_config
elif field_name == self.responses_config.name:
return self._turns_table, field_name, self.responses_config
return self.turns_table, field_name, self.responses_config
elif field_name == self.misc_config.name:
return self._misc_table, self._value_column_name, self.misc_config
return self.misc_table, self._value_column_name, self.misc_config
else:
raise ValueError(f"Unknown field name: {field_name}!")

async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes]]:
result = await self._main_table.find_one(
result = await self.main_table.find_one(
{self._id_column_name: ctx_id},
[self._current_turn_id_column_name, self._created_at_column_name, self._updated_at_column_name, self._framework_data_column_name]
)
return (result[self._current_turn_id_column_name], result[self._created_at_column_name], result[self._updated_at_column_name], result[self._framework_data_column_name]) if result is not None else None

async def update_main_info(self, ctx_id: str, turn_id: int, crt_at: int, upd_at: int, fw_data: bytes) -> None:
await self._main_table.update_one(
await self.main_table.update_one(
{self._id_column_name: ctx_id},
{
"$set": {
Expand All @@ -116,62 +116,62 @@ async def update_main_info(self, ctx_id: str, turn_id: int, crt_at: int, upd_at:

async def delete_context(self, ctx_id: str) -> None:
await asyncio.gather(
self._main_table.delete_one({self._id_column_name: ctx_id}),
self._turns_table.delete_one({self._id_column_name: ctx_id}),
self._misc_table.delete_one({self._id_column_name: ctx_id})
self.main_table.delete_one({self._id_column_name: ctx_id}),
self.turns_table.delete_one({self._id_column_name: ctx_id}),
self.misc_table.delete_one({self._id_column_name: ctx_id})
)

async def load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[Hashable, bytes]]:
field_table, field_name, field_config = self._get_config_for_field(field_name)
field_table, key_name, field_config = self._get_config_for_field(field_name)
sort, limit, key = None, 0, dict()
if field_table == self._turns_table:
if field_table == self.turns_table:
sort = [(self._key_column_name, -1)]
if isinstance(field_config.subscript, int):
limit = field_config.subscript
elif isinstance(field_config.subscript, Set):
key = {self._key_column_name: {"$in": list(field_config.subscript)}}
result = await field_table.find(
{self._id_column_name: ctx_id, field_name: {"$exists": True, "$ne": None}, **key},
[self._key_column_name, field_name],
{self._id_column_name: ctx_id, key_name: {"$exists": True, "$ne": None}, **key},
[self._key_column_name, key_name],
sort=sort
).limit(limit).to_list(None)
return [(item[self._key_column_name], item[field_name]) for item in result]
return [(item[self._key_column_name], item[key_name]) for item in result]

async def load_field_keys(self, ctx_id: str, field_name: str) -> List[Hashable]:
field_table, field_name, _ = self._get_config_for_field(field_name)
field_table, key_name, _ = self._get_config_for_field(field_name)
result = await field_table.aggregate(
[
{"$match": {self._id_column_name: ctx_id, field_name: {"$ne": None}}},
{"$match": {self._id_column_name: ctx_id, key_name: {"$ne": None}}},
{"$group": {"_id": None, self._UNIQUE_KEYS: {"$addToSet": f"${self._key_column_name}"}}},
]
).to_list(None)
return result[0][self._UNIQUE_KEYS] if len(result) == 1 else list()

async def load_field_items(self, ctx_id: str, field_name: str, keys: Set[Hashable]) -> List[bytes]:
field_table, field_name, _ = self._get_config_for_field(field_name)
field_table, key_name, _ = self._get_config_for_field(field_name)
result = await field_table.find(
{self._id_column_name: ctx_id, self._key_column_name: {"$in": list(keys)}, field_name: {"$exists": True, "$ne": None}},
[self._key_column_name, field_name]
{self._id_column_name: ctx_id, self._key_column_name: {"$in": list(keys)}, key_name: {"$exists": True, "$ne": None}},
[self._key_column_name, key_name]
).to_list(None)
return [(item[self._key_column_name], item[field_name]) for item in result]
return [(item[self._key_column_name], item[key_name]) for item in result]

async def update_field_items(self, ctx_id: str, field_name: str, items: List[Tuple[Hashable, bytes]]) -> None:
field_table, field_name, _ = self._get_config_for_field(field_name)
field_table, key_name, _ = self._get_config_for_field(field_name)
if len(items) == 0:
return
await field_table.bulk_write(
[
UpdateOne(
{self._id_column_name: ctx_id, self._key_column_name: k},
{"$set": {field_name: v}},
{"$set": {key_name: v}},
upsert=True,
) for k, v in items
]
)

async def clear_all(self) -> None:
await asyncio.gather(
self._main_table.delete_many({}),
self._turns_table.delete_many({}),
self._misc_table.delete_many({})
self.main_table.delete_many({}),
self.turns_table.delete_many({}),
self.misc_table.delete_many({})
)
44 changes: 22 additions & 22 deletions chatsky/context_storages/redis.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
raise ImportError("`redis` package is missing.\n" + install_suggestion)
if not bool(key_prefix):
raise ValueError("`key_prefix` parameter shouldn't be empty")
self._redis = Redis.from_url(self.full_path)
self.database = Redis.from_url(self.full_path)

self._prefix = key_prefix
self._main_key = f"{key_prefix}:{self._main_table_name}"
Expand Down Expand Up @@ -97,63 +97,63 @@ def _get_config_for_field(self, field_name: str, ctx_id: str) -> Tuple[str, Call
raise ValueError(f"Unknown field name: {field_name}!")

async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes]]:
if await self._redis.exists(f"{self._main_key}:{ctx_id}"):
if await self.database.exists(f"{self._main_key}:{ctx_id}"):
cti, ca, ua, fd = await gather(
self._redis.hget(f"{self._main_key}:{ctx_id}", self._current_turn_id_column_name),
self._redis.hget(f"{self._main_key}:{ctx_id}", self._created_at_column_name),
self._redis.hget(f"{self._main_key}:{ctx_id}", self._updated_at_column_name),
self._redis.hget(f"{self._main_key}:{ctx_id}", self._framework_data_column_name)
self.database.hget(f"{self._main_key}:{ctx_id}", self._current_turn_id_column_name),
self.database.hget(f"{self._main_key}:{ctx_id}", self._created_at_column_name),
self.database.hget(f"{self._main_key}:{ctx_id}", self._updated_at_column_name),
self.database.hget(f"{self._main_key}:{ctx_id}", self._framework_data_column_name)
)
return (int(cti), int(ca), int(ua), fd)
else:
return None

async def update_main_info(self, ctx_id: str, turn_id: int, crt_at: int, upd_at: int, fw_data: bytes) -> None:
await gather(
self._redis.hset(f"{self._main_key}:{ctx_id}", self._current_turn_id_column_name, str(turn_id)),
self._redis.hset(f"{self._main_key}:{ctx_id}", self._created_at_column_name, str(crt_at)),
self._redis.hset(f"{self._main_key}:{ctx_id}", self._updated_at_column_name, str(upd_at)),
self._redis.hset(f"{self._main_key}:{ctx_id}", self._framework_data_column_name, fw_data)
self.database.hset(f"{self._main_key}:{ctx_id}", self._current_turn_id_column_name, str(turn_id)),
self.database.hset(f"{self._main_key}:{ctx_id}", self._created_at_column_name, str(crt_at)),
self.database.hset(f"{self._main_key}:{ctx_id}", self._updated_at_column_name, str(upd_at)),
self.database.hset(f"{self._main_key}:{ctx_id}", self._framework_data_column_name, fw_data)
)

async def delete_context(self, ctx_id: str) -> None:
keys = await self._redis.keys(f"{self._prefix}:*:{ctx_id}*")
keys = await self.database.keys(f"{self._prefix}:*:{ctx_id}*")
if len(keys) > 0:
await self._redis.delete(*keys)
await self.database.delete(*keys)

async def load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[Hashable, bytes]]:
field_key, field_converter, field_config = self._get_config_for_field(field_name, ctx_id)
keys = await self._redis.hkeys(field_key)
keys = await self.database.hkeys(field_key)
if field_key.startswith(self._turns_key):
keys = sorted(keys, key=lambda k: int(k), reverse=True)
if isinstance(field_config.subscript, int):
keys = keys[:field_config.subscript]
elif isinstance(field_config.subscript, Set):
keys = [k for k in keys if k in self._keys_to_bytes(field_config.subscript)]
values = await gather(*[self._redis.hget(field_key, k) for k in keys])
values = await gather(*[self.database.hget(field_key, k) for k in keys])
return [(k, v) for k, v in zip(field_converter(keys), values)]

async def load_field_keys(self, ctx_id: str, field_name: str) -> List[Hashable]:
field_key, field_converter, _ = self._get_config_for_field(field_name, ctx_id)
return field_converter(await self._redis.hkeys(field_key))
return field_converter(await self.database.hkeys(field_key))

async def load_field_items(self, ctx_id: str, field_name: str, keys: List[Hashable]) -> List[Tuple[Hashable, bytes]]:
field_key, field_converter, _ = self._get_config_for_field(field_name, ctx_id)
load = [k for k in await self._redis.hkeys(field_key) if k in self._keys_to_bytes(keys)]
values = await gather(*[self._redis.hget(field_key, k) for k in load])
load = [k for k in await self.database.hkeys(field_key) if k in self._keys_to_bytes(keys)]
values = await gather(*[self.database.hget(field_key, k) for k in load])
return [(k, v) for k, v in zip(field_converter(load), values)]

async def update_field_items(self, ctx_id: str, field_name: str, items: List[Tuple[Hashable, bytes]]) -> None:
field_key, _, _ = self._get_config_for_field(field_name, ctx_id)
await gather(*[self._redis.hset(field_key, str(k), v) for k, v in items])
await gather(*[self.database.hset(field_key, str(k), v) for k, v in items])

async def delete_field_keys(self, ctx_id: str, field_name: str, keys: List[Hashable]) -> None:
field_key, _, _ = self._get_config_for_field(field_name, ctx_id)
match = [k for k in await self._redis.hkeys(field_key) if k in self._keys_to_bytes(keys)]
match = [k for k in await self.database.hkeys(field_key) if k in self._keys_to_bytes(keys)]
if len(match) > 0:
await self._redis.hdel(field_key, *match)
await self.database.hdel(field_key, *match)

async def clear_all(self) -> None:
keys = await self._redis.keys(f"{self._prefix}:*")
keys = await self.database.keys(f"{self._prefix}:*")
if len(keys) > 0:
await self._redis.delete(*keys)
await self.database.delete(*keys)
Loading

0 comments on commit 782bf66

Please sign in to comment.