Only include necessary information in prompt for GenAI metrics (mlflow#10698)

Signed-off-by: Roshni Malani <[email protected]>
Co-authored-by: Roshni Malani <[email protected]>
2 people authored and B-Step62 committed Jan 9, 2024
1 parent 8bc8297 commit b23f8e1
Showing 7 changed files with 120 additions and 79 deletions.
41 changes: 25 additions & 16 deletions mlflow/metrics/genai/base.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Dict, Optional, Union

from mlflow.metrics.genai.prompt_template import PromptTemplate
from mlflow.utils.annotations import experimental


@@ -58,10 +59,10 @@ class EvaluationExample:
"its purpose, and its developer. It could be more concise for a 5-score."
"""

input: str
output: str
score: float
justification: str
input: Optional[str] = None
grading_context: Optional[Union[Dict[str, str], str]] = None

def _format_grading_context(self):
@@ -73,21 +74,29 @@ def _format_grading_context(self):
return self.grading_context

def __str__(self) -> str:
grading_context = (
""
if self.grading_context is None
else "Additional information used by the model:\n" f"{self._format_grading_context()}"
)

return f"""
return PromptTemplate(
[
"""
Example Input:
{self.input}
{input}
""",
"""
Example Output:
{self.output}
{output}
""",
"""
Additional information used by the model:
{grading_context}
Example score: {self.score}
Example justification: {self.justification}
"""
""",
"""
Example score: {score}
Example justification: {justification}
""",
]
).format(
input=self.input,
output=self.output,
grading_context=self._format_grading_context(),
score=self.score,
justification=self.justification,
)
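With this change, EvaluationExample.__str__ renders through a PromptTemplate built from per-field sections, so fields left as None (now possible for input and grading_context) drop their sections instead of printing "None". A minimal sketch of the new behavior, with illustrative values that are not taken from the MLflow docs:

from mlflow.metrics.genai import EvaluationExample

# input is omitted (defaults to None), so the rendered string contains no
# "Example Input:" section at all; the other sections are kept.
example = EvaluationExample(
    output="MLflow is an open-source platform for the ML lifecycle.",
    score=4,
    justification="Accurate but could also mention the model registry.",
    grading_context={"targets": "MLflow manages the end-to-end ML lifecycle."},
)
print(str(example))

The exact rendered whitespace follows the template sections shown above; the point here is only the presence or absence of each section.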
14 changes: 13 additions & 1 deletion mlflow/metrics/genai/genai_metric.py
@@ -87,6 +87,7 @@ def make_genai_metric(
version: Optional[str] = _get_latest_metric_version(),
model: Optional[str] = _get_default_model(),
grading_context_columns: Optional[Union[str, List[str]]] = [], # noqa: B006
include_input: bool = True,
parameters: Optional[Dict[str, Any]] = None,
aggregations: Optional[List[str]] = ["mean", "variance", "p90"], # noqa: B006
greater_is_better: bool = True,
@@ -112,6 +113,7 @@
``grading_context_columns`` are used by the LLM as a judge as additional information to
compute the metric. The columns are extracted from the input dataset or output predictions
based on ``col_mapping`` in the ``evaluator_config`` passed to :py:func:`mlflow.evaluate()`.
:param include_input: (Optional) Whether to include the input when computing the metric.
:param parameters: (Optional) Parameters for the LLM used to compute the metric. By default, we
set the temperature to 0.0, max_tokens to 200, and top_p to 1.0. We recommend
setting the temperature to 0.0 for the LLM used as a judge to ensure consistent results.
@@ -205,6 +207,14 @@ def process_example(example):
f" Required grading context columns: {grading_context_columns}\n"
)

if not include_input:
return EvaluationExample(
output=example.output,
score=example.score,
justification=example.justification,
grading_context=example.grading_context,
)

return example

if examples is not None:
@@ -280,7 +290,9 @@ def eval_fn(
)
grading_payloads.append(
evaluation_context["eval_prompt"].format(
input=input, output=output, grading_context_columns=arg_string
input=(input if include_input else None),
output=output,
grading_context_columns=arg_string,
)
)

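The new include_input flag flows through both paths shown above: process_example rebuilds each EvaluationExample without its input, and eval_fn formats the judge prompt with input=None so the section-aware PromptTemplate drops the Input block. A hedged sketch of defining a custom metric with the flag (the metric name, grading text, and example values below are illustrative, not from the MLflow docs):

from mlflow.metrics.genai import EvaluationExample, make_genai_metric

conciseness = make_genai_metric(
    name="conciseness",
    definition="Measures how concise the model output is.",
    grading_prompt="Score 1-5, where 5 means no unnecessary words.",
    examples=[
        EvaluationExample(
            output="MLflow is, like, a really really great tool for ML stuff and things.",
            score=2,
            justification="Filler words and repetition make the answer wordy.",
        )
    ],
    model="openai:/gpt-3.5-turbo",
    include_input=False,  # the input column is never sent to the judge
    greater_is_better=True,
)

With include_input=True (the default), the prompt keeps the Input section exactly as before.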
2 changes: 2 additions & 0 deletions mlflow/metrics/genai/metric_definitions.py
@@ -64,6 +64,7 @@ def answer_similarity(
name="answer_similarity",
definition=answer_similarity_class_module.definition,
grading_prompt=answer_similarity_class_module.grading_prompt,
include_input=False,
examples=examples,
version=metric_version,
model=model,
@@ -190,6 +191,7 @@ def faithfulness(
name="faithfulness",
definition=faithfulness_class_module.definition,
grading_prompt=faithfulness_class_module.grading_prompt,
include_input=False,
examples=examples,
version=metric_version,
model=model,
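answer_similarity and faithfulness now pass include_input=False because both are graded purely against their grading context (targets or retrieved context), so the raw question adds no signal and only lengthens the judge prompt. Callers need no change; a brief sketch for reference:

from mlflow.metrics.genai import answer_similarity, faithfulness

# Same public API as before; the judge prompt simply omits the Input block now.
similarity = answer_similarity(model="openai:/gpt-4", metric_version="v1")
faith = faithfulness(model="openai:/gpt-4", metric_version="v1")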
65 changes: 30 additions & 35 deletions mlflow/metrics/genai/prompt_template.py
@@ -1,14 +1,12 @@
import string
from typing import Any, List, Optional

from mlflow.exceptions import MlflowException
from typing import Any, List, Union


class PromptTemplate:
"""A prompt template for a language model.
A prompt template consists of a string template. It accepts a set of parameters
from the user that can be used to generate a prompt for a language model.
A prompt template consists of an array of strings that will be concatenated together. It accepts
a set of parameters from the user that can be used to generate a prompt for a language model.
The template can be formatted using f-strings.
@@ -19,7 +17,7 @@ class PromptTemplate:
from mlflow.metrics.genai.prompt_template import PromptTemplate
# Instantiation using initializer
prompt = PromptTemplate(template_str="Say {foo} {baz}", variables=["foo", "baz"])
prompt = PromptTemplate(template_str="Say {foo} {baz}")
# Instantiation using partial_fill
prompt = PromptTemplate(template_str="Say {foo} {baz}").partial_fill(foo="bar")
@@ -28,37 +26,34 @@ class PromptTemplate:
prompt.format(baz="qux")
"""

def __init__(self, template_str: str, variables: Optional[List[str]] = None):
self.template_str = template_str
extracted_variables = [
fname for _, fname, _, _ in string.Formatter().parse(template_str) if fname
]
if variables:
if not all(item in variables for item in extracted_variables):
raise MlflowException(
f"The provided variables {variables} are not a subset of "
f"the extracted variables {extracted_variables} from the template string"
)
self.variables = variables
else:
# Automatically parse variables from template string
self.variables = extracted_variables
def __init__(self, template_str: Union[str, List[str]]):
self.template_strs = [template_str] if isinstance(template_str, str) else template_str

def format(self, **kwargs: Any) -> str:
# Only keep the kwargs that are in the variables
kwargs = {k: v for k, v in kwargs.items() if k in self.variables}
safe_kwargs = {k: v for k, v in kwargs.items() if v is not None}
formatted_strs = []
for template_str in self.template_strs:
extracted_variables = [
fname for _, fname, _, _ in string.Formatter().parse(template_str) if fname
]
if all(item in safe_kwargs.keys() for item in extracted_variables):
formatted_strs.append(template_str.format(**safe_kwargs))

# Format the prompt with the provided values
return self.template_str.format(**kwargs)
return "".join(formatted_strs)

def partial_fill(self, **kwargs: Any) -> "PromptTemplate":
# Create a safe dictionary that returns the key if it doesn't exist in the dictionary
safe_dict = {k: kwargs.get(k, "{" + k + "}") for k in self.variables}

# Fill in the provided values, and return a new PromptTemplate
new_template_str = self.template_str.format_map(safe_dict)
unfilled_variables = [var for var in self.variables if var not in kwargs.keys()]
return PromptTemplate(template_str=new_template_str, variables=unfilled_variables)

def __str__(self):
return self.template_str
safe_kwargs = {k: v for k, v in kwargs.items() if v is not None}
new_template_strs = []
for template_str in self.template_strs:
extracted_variables = [
fname for _, fname, _, _ in string.Formatter().parse(template_str) if fname
]
safe_available_kwargs = {
k: safe_kwargs.get(k, "{" + k + "}") for k in extracted_variables
}
new_template_strs.append(template_str.format_map(safe_available_kwargs))

return PromptTemplate(template_str=new_template_strs)

def __str__(self) -> str:
return "".join(self.template_strs)
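The rewritten PromptTemplate accepts either a single string or a list of section strings; format() renders only the sections whose placeholders are all supplied and non-None, and partial_fill() fills what it can while leaving the remaining placeholders, and their sections' conditional behavior, intact. A small sketch mirroring the updated tests:

from mlflow.metrics.genai.prompt_template import PromptTemplate

prompt = PromptTemplate(["Input:\n{input}\n", "Output:\n{output}\n"])

# The Input section is dropped because its variable is None.
print(prompt.format(input=None, output="MLflow is great"))  # "Output:\nMLflow is great\n"

# partial_fill keeps unfilled placeholders, so the Input section can still be
# rendered later, or dropped if {input} is never provided.
partial = prompt.partial_fill(output="MLflow is great")
print(partial.format())                         # "Output:\nMLflow is great\n"
print(partial.format(input="What is MLflow?"))  # both sections rendered

Unlike the previous version, a missing variable no longer raises; its section is simply skipped, which is what lets None act as the off switch for optional blocks.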
13 changes: 9 additions & 4 deletions mlflow/metrics/genai/prompts/v1.py
@@ -12,8 +12,10 @@
"max_tokens": 200,
"top_p": 1.0,
}

grading_system_prompt_template = PromptTemplate(
"""
[
"""
Task:
You must return the following fields in your response in two lines, one below the other:
score: Your numerical score for the model's {name} based on the rubric
@@ -28,10 +30,12 @@
You must use the grading rubric to determine your score. You must also justify your score.
Examples could be included below for reference. Make sure to use them as references and to
understand them before completing the task.
understand them before completing the task.""",
"""
Input:
{input}
{input}""",
"""
Output:
{output}
@@ -51,7 +55,8 @@
justification: Your reasoning about the model's {name} score
Do not add additional new lines. Do not add any other fields.
"""
""",
]
)


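Because the v1 grading prompt is now assembled from sections, the "Input:\n{input}" block disappears whenever a metric formats it with input=None, which is exactly what answer_similarity and faithfulness do. A minimal stand-in for the full template (the real one also carries the task description, metric definition, grading rubric, and examples sections) to illustrate:

from mlflow.metrics.genai.prompt_template import PromptTemplate

# Abbreviated stand-in, not the real v1 grading template.
judge_prompt = PromptTemplate(
    [
        "Task: grade the model's {name}.",
        "\nInput:\n{input}",
        "\nOutput:\n{output}",
    ]
)

# input=None drops the Input block from the rendered judge prompt.
print(judge_prompt.format(name="faithfulness", input=None, output="MLflow has four components."))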
9 changes: 3 additions & 6 deletions tests/metrics/genai/test_genai_metrics.py
@@ -625,8 +625,8 @@ def test_extract_score_and_justification():
assert justification6 == "This is a justification"


def test_correctness_metric():
correctness_metric = answer_similarity(
def test_similarity_metric():
similarity_metric = answer_similarity(
model="gateway:/gpt-3.5-turbo", metric_version="v1", examples=[mlflow_example]
)

@@ -637,7 +637,7 @@ def test_correctness_metric():
"score_model_on_payload",
return_value=properly_formatted_openai_response1,
) as mock_predict_function:
metric_value = correctness_metric.eval_fn(
metric_value = similarity_metric.eval_fn(
pd.Series([mlflow_prediction]), {}, pd.Series([input]), pd.Series([mlflow_ground_truth])
)

@@ -658,14 +658,12 @@ def test_correctness_metric():
"grading rubric to determine your score. You must also justify your score."
"\n\nExamples could be included below for reference. Make sure to use them as "
"references and to\nunderstand them before completing the task.\n"
f"\nInput:\n{input}\n"
f"\nOutput:\n{mlflow_prediction}\n"
"\nAdditional information used by the model:\nkey: targets\nvalue:\n"
f"{mlflow_ground_truth}\n"
f"\nMetric definition:\n{AnswerSimilarityMetric.definition}\n"
f"\nGrading rubric:\n{AnswerSimilarityMetric.grading_prompt}\n"
"\nExamples:\n"
f"\nExample Input:\n{mlflow_example.input}\n"
f"\nExample Output:\n{mlflow_example.output}\n"
"\nAdditional information used by the model:\nkey: targets\nvalue:\n"
f"{mlflow_ground_truth}\n"
@@ -735,7 +733,6 @@ def test_faithfulness_metric():
"grading rubric to determine your score. You must also justify your score."
"\n\nExamples could be included below for reference. Make sure to use them as "
"references and to\nunderstand them before completing the task.\n"
f"\nInput:\n{input}\n"
f"\nOutput:\n{mlflow_prediction}\n"
"\nAdditional information used by the model:\nkey: context\nvalue:\n"
f"{mlflow_ground_truth}\n"
55 changes: 38 additions & 17 deletions tests/metrics/genai/test_prompt_template.py
@@ -1,36 +1,57 @@
import re

import pytest

from mlflow.exceptions import MlflowException
from mlflow.metrics.genai.prompt_template import PromptTemplate


def test_prompt_template_formatting():
def test_prompt_template_flat_str_no_variables():
prompt = PromptTemplate(template_str="Say {foo}")
assert prompt.format(foo="bar") == "Say bar"

prompt = PromptTemplate(template_str="Say {foo} {baz}")
with pytest.raises(KeyError, match="baz"):
prompt.format(foo="bar")
assert prompt.format(foo="bar") == ""

prompt = PromptTemplate(template_str="Say foo")
assert prompt.format() == "Say foo"

with pytest.raises(
MlflowException,
match=re.escape(
"The provided variables ['foo'] are not a subset of "
"the extracted variables ['foo', 'baz'] from the template string"
),
):
prompt = PromptTemplate(template_str="Say {foo} {baz}", variables=["foo"])

prompt = PromptTemplate(template_str="Say {foo} {foo}")
assert prompt.format(foo="bar") == "Say bar bar"

prompt = PromptTemplate(template_str="Say {foo}")
assert prompt.format(foo=None) == ""


def test_prompt_template_arr_str_no_variables():
prompt = PromptTemplate(template_str=["Say {foo}"])
assert prompt.format(foo="bar") == "Say bar"

prompt = PromptTemplate(template_str=["Say {foo} {baz}"])
assert prompt.format(foo="bar") == ""

prompt = PromptTemplate(template_str=["Say {foo}", " {baz}"])
assert prompt.format(foo="bar") == "Say bar"

prompt = PromptTemplate(template_str=["Say {foo}", " {baz}"])
assert prompt.format(baz="qux") == " qux"

prompt = PromptTemplate(template_str=["Say foo", ", and say bar"])
assert prompt.format() == "Say foo, and say bar"

prompt = PromptTemplate(template_str=["Say {foo} {foo}", ", and {foo}", ", and {foo} again"])
assert prompt.format(foo="bar") == "Say bar bar, and bar, and bar again"

prompt = PromptTemplate(template_str=["Say {foo}", " {baz}"])
assert prompt.format(foo="bar", baz="qux") == "Say bar qux"

prompt = PromptTemplate(template_str=["Say {foo}", " {baz}"])
assert prompt.format(foo="bar", baz=None) == "Say bar"


def test_prompt_template_partial_formatting():
prompt = PromptTemplate(template_str="Say {foo} {baz}")
partial_prompt = prompt.partial_fill(foo="bar")
assert partial_prompt.format(baz="qux") == "Say bar qux"
assert partial_prompt.format(baz=None) == ""

prompt = PromptTemplate(template_str=["Say {foo}", " {baz}"])
partial_prompt = prompt.partial_fill(foo="bar")
assert partial_prompt.format(baz="qux") == "Say bar qux"
assert partial_prompt.format(baz=None) == "Say bar"
assert partial_prompt.format() == "Say bar"
