From de739f260b9d572894c0c67424bb2f8db9c341fe Mon Sep 17 00:00:00 2001
From: Roman Zlobin
Date: Fri, 11 Oct 2024 16:32:38 +0300
Subject: [PATCH] update benchmark utils

---
 chatsky/utils/db_benchmark/basic_config.py | 38 +++++++++++---------
 chatsky/utils/db_benchmark/benchmark.py    | 41 +++++++++++-----------
 2 files changed, 42 insertions(+), 37 deletions(-)

diff --git a/chatsky/utils/db_benchmark/basic_config.py b/chatsky/utils/db_benchmark/basic_config.py
index 2b329895d..7b328ea39 100644
--- a/chatsky/utils/db_benchmark/basic_config.py
+++ b/chatsky/utils/db_benchmark/basic_config.py
@@ -16,6 +16,7 @@
 from pympler import asizeof
 
 from chatsky.core import Message, Context, AbsoluteNodeLabel
+from chatsky.context_storages import MemoryContextStorage
 from chatsky.utils.db_benchmark.benchmark import BenchmarkConfig
 
 
@@ -59,7 +60,8 @@ def get_message(message_dimensions: Tuple[int, ...]):
     return Message(misc=get_dict(message_dimensions))
 
 
-def get_context(
+async def get_context(
+    db,
     dialog_len: int,
     message_dimensions: Tuple[int, ...],
     misc_dimensions: Tuple[int, ...],
@@ -73,12 +75,16 @@
     :param misc_dimensions: A parameter used to generate misc field.
         See :py:func:`~.get_dict`.
     """
-    return Context(
-        labels={i: (f"flow_{i}", f"node_{i}") for i in range(dialog_len)},
-        requests={i: get_message(message_dimensions) for i in range(dialog_len)},
-        responses={i: get_message(message_dimensions) for i in range(dialog_len)},
-        misc=get_dict(misc_dimensions),
-    )
+    ctx = await Context.connected(db, start_label=("flow", "node"))
+    ctx.current_turn_id = -1
+    for i in range(dialog_len):
+        ctx.current_turn_id += 1
+        ctx.labels[ctx.current_turn_id] = AbsoluteNodeLabel(flow_name=f"flow_{i}", node_name=f"node_{i}")
+        ctx.requests[ctx.current_turn_id] = get_message(message_dimensions)
+        ctx.responses[ctx.current_turn_id] = get_message(message_dimensions)
+    await ctx.misc.update(get_dict(misc_dimensions))
+
+    return ctx
 
 
 class BasicBenchmarkConfig(BenchmarkConfig, frozen=True):
@@ -121,15 +127,15 @@ class BasicBenchmarkConfig(BenchmarkConfig, frozen=True):
         See :py:func:`~.get_dict`.
     """
 
-    def get_context(self) -> Context:
+    async def get_context(self, db) -> Context:
         """
         Return context with `from_dialog_len`, `message_dimensions`, `misc_dimensions`.
 
         Wraps :py:func:`~.get_context`.
         """
-        return get_context(self.from_dialog_len, self.message_dimensions, self.misc_dimensions)
+        return await get_context(db, self.from_dialog_len, self.message_dimensions, self.misc_dimensions)
 
-    def info(self):
+    async def info(self):
         """
         Return fields of this instance and sizes of objects defined by this config.
 
@@ -147,9 +153,9 @@ def info(self):
         return {
             "params": self.model_dump(),
             "sizes": {
-                "starting_context_size": naturalsize(asizeof.asizeof(self.get_context()), gnu=True),
+                "starting_context_size": naturalsize(asizeof.asizeof(await self.get_context(MemoryContextStorage())), gnu=True),
                 "final_context_size": naturalsize(
-                    asizeof.asizeof(get_context(self.to_dialog_len, self.message_dimensions, self.misc_dimensions)),
+                    asizeof.asizeof(await get_context(MemoryContextStorage(), self.to_dialog_len, self.message_dimensions, self.misc_dimensions)),
                     gnu=True,
                 ),
                 "misc_size": naturalsize(asizeof.asizeof(get_dict(self.misc_dimensions)), gnu=True),
@@ -157,7 +163,7 @@ def info(self):
             },
         }
 
-    def context_updater(self, context: Context) -> Optional[Context]:
+    async def context_updater(self, context: Context) -> Optional[Context]:
         """
         Update context to have `step_dialog_len` more labels, requests and responses,
         unless such dialog len would be equal to `to_dialog_len` or exceed than it,
@@ -166,10 +172,10 @@
         start_len = len(context.labels)
         if start_len + self.step_dialog_len < self.to_dialog_len:
             for i in range(start_len, start_len + self.step_dialog_len):
-                context.current_turn_id = context.current_turn_id + 1
+                context.current_turn_id += 1
-                context.labels[context.current_turn_id] = AbsoluteNodeLabel(flow_name="flow_{i}", node_name="node_{i}")
+                context.labels[context.current_turn_id] = AbsoluteNodeLabel(flow_name=f"flow_{i}", node_name=f"node_{i}")
                 context.requests[context.current_turn_id] = get_message(self.message_dimensions)
-                context.responses[context.current_turn_id] =get_message(self.message_dimensions)
+                context.responses[context.current_turn_id] = get_message(self.message_dimensions)
             return context
         else:
             return None
diff --git a/chatsky/utils/db_benchmark/benchmark.py b/chatsky/utils/db_benchmark/benchmark.py
index fee678e66..b830a62c3 100644
--- a/chatsky/utils/db_benchmark/benchmark.py
+++ b/chatsky/utils/db_benchmark/benchmark.py
@@ -22,12 +22,13 @@
 from uuid import uuid4
 from pathlib import Path
 from time import perf_counter
-from typing import Tuple, List, Dict, Union, Optional, Callable, Any
+from typing import Tuple, List, Dict, Union, Optional, Callable, Any, Awaitable
 import json
 import importlib
 from statistics import mean
 import abc
 from traceback import extract_tb, StackSummary
+import asyncio
 
 from pydantic import BaseModel, Field
 from tqdm.auto import tqdm
@@ -36,11 +37,11 @@
 from chatsky.core import Context
 
 
-def time_context_read_write(
+async def time_context_read_write(
     context_storage: DBContextStorage,
-    context_factory: Callable[[], Context],
+    context_factory: Callable[[DBContextStorage], Awaitable[Context]],
     context_num: int,
-    context_updater: Optional[Callable[[Context], Optional[Context]]] = None,
+    context_updater: Optional[Callable[[Context], Awaitable[Optional[Context]]]] = None,
 ) -> Tuple[List[float], List[Dict[int, float]], List[Dict[int, float]]]:
     """
     Benchmark `context_storage` by writing and reading `context`\\s generated by `context_factory`
@@ -81,20 +82,18 @@ async def time_context_read_write(
     dialog_len of the context returned by `context_factory`.
     So if `context_updater` is None, all dictionaries will be empty.
     """
-    context_storage.clear()
+    await context_storage.clear_all()
 
     write_times: List[float] = []
    read_times: List[Dict[int, float]] = []
     update_times: List[Dict[int, float]] = []
 
     for _ in tqdm(range(context_num), desc=f"Benchmarking context storage:{context_storage.full_path}", leave=False):
-        context = context_factory()
-
-        ctx_id = uuid4()
+        context = await context_factory(context_storage)
 
         # write operation benchmark
         write_start = perf_counter()
-        context_storage[ctx_id] = context
+        await context.store()
         write_times.append(perf_counter() - write_start)
 
         read_times.append({})
@@ -102,27 +101,27 @@
 
         # read operation benchmark
         read_start = perf_counter()
-        _ = context_storage[ctx_id]
+        _ = await Context.connected(context_storage, start_label=("flow", "node"), id=context.id)
         read_time = perf_counter() - read_start
         read_times[-1][len(context.labels)] = read_time
 
         if context_updater is not None:
-            updated_context = context_updater(context)
+            updated_context = await context_updater(context)
 
             while updated_context is not None:
                 update_start = perf_counter()
-                context_storage[ctx_id] = updated_context
+                await updated_context.store()
                 update_time = perf_counter() - update_start
                 update_times[-1][len(updated_context.labels)] = update_time
 
                 read_start = perf_counter()
-                _ = context_storage[ctx_id]
+                _ = await Context.connected(context_storage, start_label=("flow", "node"), id=updated_context.id)
                 read_time = perf_counter() - read_start
                 read_times[-1][len(updated_context.labels)] = read_time
 
-                updated_context = context_updater(updated_context)
+                updated_context = await context_updater(updated_context)
 
-    context_storage.clear()
+    await context_storage.clear_all()
 
     return write_times, read_times, update_times
 
@@ -167,7 +166,7 @@ class BenchmarkConfig(BaseModel, abc.ABC, frozen=True):
     """
 
     @abc.abstractmethod
-    def get_context(self) -> Context:
+    async def get_context(self, db: DBContextStorage) -> Context:
         """
         Return context to benchmark read and write
         operations with.
@@ -176,14 +175,14 @@ def get_context(self) -> Context:
         ...
 
     @abc.abstractmethod
-    def info(self) -> Dict[str, Any]:
+    async def info(self) -> Dict[str, Any]:
         """
         Return a dictionary with information about this configuration.
         """
         ...
 
     @abc.abstractmethod
-    def context_updater(self, context: Context) -> Optional[Context]:
+    async def context_updater(self, context: Context) -> Optional[Context]:
         """
         Update context with new dialog turns or return `None` to stop updates.
 
@@ -287,12 +286,12 @@ def get_complex_stats(results):
 
     def _run(self):
         try:
-            write_times, read_times, update_times = time_context_read_write(
+            write_times, read_times, update_times = asyncio.run(time_context_read_write(
                 self.db_factory.db(),
                 self.benchmark_config.get_context,
                 self.benchmark_config.context_num,
                 self.benchmark_config.context_updater,
-            )
+            ))
             return {
                 "success": True,
                 "result": {
@@ -369,7 +368,7 @@ def save_results_to_file(
         result["benchmarks"].append(
             {
                 **case.model_dump(exclude={"benchmark_config"}),
-                "benchmark_config": case.benchmark_config.info(),
+                "benchmark_config": asyncio.run(case.benchmark_config.info()),
                 **case.run(),
             }
         )