Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⚡️ Speed up _determine_reference_key() by 30% in libs/langchain/langchain/smith/evaluation/runner_utils.py #27

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions libs/langchain/langchain/smith/evaluation/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
)
from langchain_core.tracers.langchain import LangChainTracer
from langsmith.client import Client
from langsmith.env import get_git_info, get_langchain_env_var_metadata
from langsmith.env import get_git_info
from langsmith.evaluation import EvaluationResult, RunEvaluator
from langsmith.run_helpers import as_runnable, is_traceable_function
from langsmith.schemas import Dataset, DataType, Example, TracerSession
Expand Down Expand Up @@ -499,17 +499,17 @@ def _determine_reference_key(
example_outputs: Optional[List[str]],
) -> Optional[str]:
if config.reference_key:
reference_key = config.reference_key
if example_outputs and reference_key not in example_outputs:
if example_outputs and config.reference_key not in example_outputs:
raise ValueError(
f"Reference key {reference_key} not in Dataset"
f"Reference key {config.reference_key} not in Dataset"
f" example outputs: {example_outputs}"
)
elif example_outputs and len(example_outputs) == 1:
reference_key = list(example_outputs)[0]
else:
reference_key = None
return reference_key
return config.reference_key

if example_outputs and len(example_outputs) == 1:
return example_outputs[0]

return None


def _construct_run_evaluator(
Expand Down Expand Up @@ -1227,8 +1227,6 @@ async def arun_on_dataset(
input_mapper = kwargs.pop("input_mapper", None)
if input_mapper:
warn_deprecated("0.0.305", message=_INPUT_MAPPER_DEP_WARNING, pending=True)
if revision_id is None:
revision_id = get_langchain_env_var_metadata().get("revision_id")

if kwargs:
warn_deprecated(
Expand Down Expand Up @@ -1283,8 +1281,6 @@ def run_on_dataset(
input_mapper = kwargs.pop("input_mapper", None)
if input_mapper:
warn_deprecated("0.0.305", message=_INPUT_MAPPER_DEP_WARNING, pending=True)
if revision_id is None:
revision_id = get_langchain_env_var_metadata().get("revision_id")

if kwargs:
warn_deprecated(
Expand Down