update benchmark utils
RLKRo committed Oct 11, 2024
1 parent d3af3b2 commit de739f2
Showing 2 changed files with 41 additions and 36 deletions.
36 changes: 21 additions & 15 deletions chatsky/utils/db_benchmark/basic_config.py
@@ -16,6 +16,7 @@
from pympler import asizeof

from chatsky.core import Message, Context, AbsoluteNodeLabel
from chatsky.context_storages import MemoryContextStorage
from chatsky.utils.db_benchmark.benchmark import BenchmarkConfig


@@ -59,7 +60,8 @@ def get_message(message_dimensions: Tuple[int, ...]):
return Message(misc=get_dict(message_dimensions))


def get_context(
async def get_context(
db,
dialog_len: int,
message_dimensions: Tuple[int, ...],
misc_dimensions: Tuple[int, ...],
@@ -73,12 +75,16 @@ def get_context(
:param misc_dimensions:
A parameter used to generate misc field. See :py:func:`~.get_dict`.
"""
return Context(
labels={i: (f"flow_{i}", f"node_{i}") for i in range(dialog_len)},
requests={i: get_message(message_dimensions) for i in range(dialog_len)},
responses={i: get_message(message_dimensions) for i in range(dialog_len)},
misc=get_dict(misc_dimensions),
)
ctx = await Context.connected(db, start_label=("flow", "node"))
ctx.current_turn_id = -1
for i in range(dialog_len):
ctx.current_turn_id += 1
ctx.labels[ctx.current_turn_id] = AbsoluteNodeLabel(flow_name=f"flow_{i}", node_name=f"node_{i}")
ctx.requests[ctx.current_turn_id] = get_message(message_dimensions)
ctx.responses[ctx.current_turn_id] = get_message(message_dimensions)
await ctx.misc.update(get_dict(misc_dimensions))

return ctx


class BasicBenchmarkConfig(BenchmarkConfig, frozen=True):
@@ -121,15 +127,15 @@ class BasicBenchmarkConfig(BenchmarkConfig, frozen=True):
See :py:func:`~.get_dict`.
"""

def get_context(self) -> Context:
async def get_context(self, db) -> Context:
"""
Return context with `from_dialog_len`, `message_dimensions`, `misc_dimensions`.
Wraps :py:func:`~.get_context`.
"""
return get_context(self.from_dialog_len, self.message_dimensions, self.misc_dimensions)
return await get_context(db, self.from_dialog_len, self.message_dimensions, self.misc_dimensions)

def info(self):
async def info(self):
"""
Return fields of this instance and sizes of objects defined by this config.
@@ -147,17 +153,17 @@ def info(self):
return {
"params": self.model_dump(),
"sizes": {
"starting_context_size": naturalsize(asizeof.asizeof(self.get_context()), gnu=True),
"starting_context_size": naturalsize(asizeof.asizeof(await self.get_context(MemoryContextStorage())), gnu=True),
"final_context_size": naturalsize(
asizeof.asizeof(get_context(self.to_dialog_len, self.message_dimensions, self.misc_dimensions)),
asizeof.asizeof(await get_context(MemoryContextStorage(), self.to_dialog_len, self.message_dimensions, self.misc_dimensions)),
gnu=True,
),
"misc_size": naturalsize(asizeof.asizeof(get_dict(self.misc_dimensions)), gnu=True),
"message_size": naturalsize(asizeof.asizeof(get_message(self.message_dimensions)), gnu=True),
},
}

def context_updater(self, context: Context) -> Optional[Context]:
async def context_updater(self, context: Context) -> Optional[Context]:
"""
Update context to have `step_dialog_len` more labels, requests and responses,
unless such dialog len would be equal to `to_dialog_len` or exceed it,
@@ -166,10 +172,10 @@ def context_updater(self, context: Context) -> Optional[Context]:
start_len = len(context.labels)
if start_len + self.step_dialog_len < self.to_dialog_len:
for i in range(start_len, start_len + self.step_dialog_len):
context.current_turn_id = context.current_turn_id + 1
context.current_turn_id += 1
context.labels[context.current_turn_id] = AbsoluteNodeLabel(flow_name=f"flow_{i}", node_name=f"node_{i}")
context.requests[context.current_turn_id] = get_message(self.message_dimensions)
context.responses[context.current_turn_id] =get_message(self.message_dimensions)
context.responses[context.current_turn_id] = get_message(self.message_dimensions)
return context
else:
return None
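
For orientation, a minimal sketch of driving the reworked module-level `get_context`, which now builds the context against a storage instance and has to be awaited. The example is illustrative rather than part of the diff, and the dimension values are arbitrary:

import asyncio

from chatsky.context_storages import MemoryContextStorage
from chatsky.utils.db_benchmark.basic_config import get_context


async def main():
    db = MemoryContextStorage()
    # The context is now created through the storage and must be awaited.
    ctx = await get_context(
        db,
        dialog_len=5,
        message_dimensions=(5, 10),  # arbitrary example dimensions
        misc_dimensions=(2, 2),
    )
    print(len(ctx.labels))  # number of generated turns
    await ctx.store()  # persist to the storage, as the benchmark loop does


asyncio.run(main())
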
41 changes: 20 additions & 21 deletions chatsky/utils/db_benchmark/benchmark.py
@@ -22,12 +22,13 @@
from uuid import uuid4
from pathlib import Path
from time import perf_counter
from typing import Tuple, List, Dict, Union, Optional, Callable, Any
from typing import Tuple, List, Dict, Union, Optional, Callable, Any, Awaitable
import json
import importlib
from statistics import mean
import abc
from traceback import extract_tb, StackSummary
import asyncio

from pydantic import BaseModel, Field
from tqdm.auto import tqdm
@@ -36,11 +37,11 @@
from chatsky.core import Context


def time_context_read_write(
async def time_context_read_write(
context_storage: DBContextStorage,
context_factory: Callable[[], Context],
context_factory: Callable[[DBContextStorage], Awaitable[Context]],
context_num: int,
context_updater: Optional[Callable[[Context], Optional[Context]]] = None,
context_updater: Optional[Callable[[Context], Awaitable[Optional[Context]]]] = None,
) -> Tuple[List[float], List[Dict[int, float]], List[Dict[int, float]]]:
"""
Benchmark `context_storage` by writing and reading `context`\\s generated by `context_factory`
@@ -81,48 +82,46 @@ def time_context_read_write(
dialog_len of the context returned by `context_factory`.
So if `context_updater` is None, all dictionaries will be empty.
"""
context_storage.clear()
await context_storage.clear_all()

write_times: List[float] = []
read_times: List[Dict[int, float]] = []
update_times: List[Dict[int, float]] = []

for _ in tqdm(range(context_num), desc=f"Benchmarking context storage:{context_storage.full_path}", leave=False):
context = context_factory()

ctx_id = uuid4()
context = await context_factory(context_storage)

# write operation benchmark
write_start = perf_counter()
context_storage[ctx_id] = context
await context.store()
write_times.append(perf_counter() - write_start)

read_times.append({})
update_times.append({})

# read operation benchmark
read_start = perf_counter()
_ = context_storage[ctx_id]
_ = await Context.connected(context_storage, start_label=("flow", "node"), id=context.id)
read_time = perf_counter() - read_start
read_times[-1][len(context.labels)] = read_time

if context_updater is not None:
updated_context = context_updater(context)
updated_context = await context_updater(context)

while updated_context is not None:
update_start = perf_counter()
context_storage[ctx_id] = updated_context
await updated_context.store()
update_time = perf_counter() - update_start
update_times[-1][len(updated_context.labels)] = update_time

read_start = perf_counter()
_ = context_storage[ctx_id]
_ = await Context.connected(context_storage, start_label=("flow", "node"), id=updated_context.id)
read_time = perf_counter() - read_start
read_times[-1][len(updated_context.labels)] = read_time

updated_context = context_updater(updated_context)
updated_context = await context_updater(updated_context)

context_storage.clear()
await context_storage.clear_all()
return write_times, read_times, update_times
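
As an illustrative aside (not part of the diff), the now-async benchmark loop can be driven directly with `asyncio.run`, mirroring the `_run` wrapper further down; this assumes `BasicBenchmarkConfig` can be instantiated with its defaults:

import asyncio
from statistics import mean

from chatsky.context_storages import MemoryContextStorage
from chatsky.utils.db_benchmark.basic_config import BasicBenchmarkConfig
from chatsky.utils.db_benchmark.benchmark import time_context_read_write

config = BasicBenchmarkConfig()  # assumed defaults; pass fields explicitly otherwise
write_times, read_times, update_times = asyncio.run(
    time_context_read_write(
        MemoryContextStorage(),
        config.get_context,      # async factory that now receives the storage
        config.context_num,
        config.context_updater,  # async updater, or None to skip update steps
    )
)
print("mean write time:", mean(write_times))
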


@@ -167,7 +166,7 @@ class BenchmarkConfig(BaseModel, abc.ABC, frozen=True):
"""

@abc.abstractmethod
def get_context(self) -> Context:
async def get_context(self, db: DBContextStorage) -> Context:
"""
Return context to benchmark read and write operations with.
@@ -176,14 +175,14 @@ def get_context(self) -> Context:
...

@abc.abstractmethod
def info(self) -> Dict[str, Any]:
async def info(self) -> Dict[str, Any]:
"""
Return a dictionary with information about this configuration.
"""
...

@abc.abstractmethod
def context_updater(self, context: Context) -> Optional[Context]:
async def context_updater(self, context: Context) -> Optional[Context]:
"""
Update context with new dialog turns or return `None` to stop updates.
@@ -287,12 +286,12 @@ def get_complex_stats(results):

def _run(self):
try:
write_times, read_times, update_times = time_context_read_write(
write_times, read_times, update_times = asyncio.run(time_context_read_write(
self.db_factory.db(),
self.benchmark_config.get_context,
self.benchmark_config.context_num,
self.benchmark_config.context_updater,
)
))
return {
"success": True,
"result": {
@@ -369,7 +368,7 @@ def save_results_to_file(
result["benchmarks"].append(
{
**case.model_dump(exclude={"benchmark_config"}),
"benchmark_config": case.benchmark_config.info(),
"benchmark_config": asyncio.run(case.benchmark_config.info()),
**case.run(),
}
)
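
Because the abstract methods of `BenchmarkConfig` are now coroutines, custom configurations have to declare them with `async def`. A minimal, hypothetical subclass sketch follows; the `TinyBenchmarkConfig` name and the `DBContextStorage` import path are assumptions, not taken from the diff:

from typing import Any, Dict, Optional

from chatsky.core import Context
from chatsky.context_storages import DBContextStorage
from chatsky.utils.db_benchmark.benchmark import BenchmarkConfig


class TinyBenchmarkConfig(BenchmarkConfig, frozen=True):
    """Hypothetical config: one minimal context per iteration and no update steps."""

    async def get_context(self, db: DBContextStorage) -> Context:
        # Build an empty context attached to the given storage.
        return await Context.connected(db, start_label=("flow", "node"))

    async def info(self) -> Dict[str, Any]:
        return {"params": self.model_dump()}

    async def context_updater(self, context: Context) -> Optional[Context]:
        return None  # no dialog-extension steps

# Usage (assuming context_num is a regular model field, as it is read by the benchmark runner):
# config = TinyBenchmarkConfig(context_num=5)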
