This repository has been archived by the owner on Oct 2, 2024. It is now read-only.

Add annotator test runner + LlamaGuard2, Llama 3 70b annotator test #451

Closed
wants to merge 43 commits into from
Changes from 28 commits
Commits (43)
968e281
Init first working test
tsunamit Jun 10, 2024
20a5e2d
Enable system chat role. Add mistral 8x22b and llama 3 70b to togethe…
tsunamit Jun 14, 2024
11d9854
undo debug code
tsunamit Jun 14, 2024
1bba2b3
Black formatting
tsunamit Jun 14, 2024
7977661
Update typing issues
tsunamit Jun 14, 2024
3c7a349
Merge branch 'main' of https://github.com/mlcommons/modelgauge into r…
tsunamit Jun 18, 2024
d5e8793
Add initial safety model test files
tsunamit Jun 18, 2024
24b9ed2
typing issues
tsunamit Jun 18, 2024
2b062d9
fix glob search bug
tsunamit Jun 19, 2024
6d7ac4d
Add readme
tsunamit Jun 24, 2024
53fbebe
Add comments and todos
tsunamit Jun 24, 2024
d594d72
use inject secret
tsunamit Jun 24, 2024
9c06361
clarify lg1 responses
tsunamit Jun 24, 2024
ddd61b3
create chat test
tsunamit Jun 25, 2024
01f48a2
Create constants file
tsunamit Jun 25, 2024
c7d963b
Refactor utils and classes
tsunamit Jun 25, 2024
11f3889
Add second test
tsunamit Jun 26, 2024
841ec79
Split into 2 separate tests
tsunamit Jun 26, 2024
f2c6e51
update readme
tsunamit Jun 26, 2024
507b8c7
Use runner instead of test for llama guard 2
tsunamit Jun 27, 2024
03b3b22
remove sut, add runner
tsunamit Jun 27, 2024
fd16bc9
remove testing code
tsunamit Jun 27, 2024
cf2ac8e
remove old block
tsunamit Jun 27, 2024
3c85d05
remove unrelated steps in readme
tsunamit Jun 27, 2024
9bffeb2
Move some constants around
tsunamit Jun 27, 2024
9715eba
update ref
tsunamit Jun 27, 2024
d95738c
Add chat test
tsunamit Jul 1, 2024
c1f030c
Remove comments
tsunamit Jul 1, 2024
8f50685
Consolidate into single test
tsunamit Jul 2, 2024
8c399de
add uids for tests and mypy
tsunamit Jul 2, 2024
eb21393
move method to utils
tsunamit Jul 2, 2024
fa6d39e
uid update
tsunamit Jul 2, 2024
68e4453
uid updates again
tsunamit Jul 2, 2024
a23abf0
add main
tsunamit Jul 3, 2024
ec294b0
add uid to references
tsunamit Jul 3, 2024
1768f51
remove unneeded param
tsunamit Jul 3, 2024
de9e0bf
Return custom test record, remove unneeded deps
tsunamit Jul 3, 2024
25da305
Create base annotator test for annotator runner (#466)
bkorycki Jul 12, 2024
d69ef25
Add support for textprompt, fix type bug when calling measure_quality
tsunamit Jul 15, 2024
ecad887
Add text prompt support to llama 3 annotator
tsunamit Jul 15, 2024
cdbbe3d
Add readme instructions to run alpha version
tsunamit Jul 15, 2024
99f2d0b
Merge branch 'main' of https://github.com/mlcommons/modelgauge into r…
tsunamit Jul 15, 2024
4070eb6
reset records
tsunamit Jul 15, 2024
164 changes: 164 additions & 0 deletions modelgauge/annotator_test_runner.py
@@ -0,0 +1,164 @@
import os
import random

from modelgauge.safety_eval_context import SafetyEvalContext
from modelgauge.annotation import Annotation
from modelgauge.annotator import CompletionAnnotator
from modelgauge.base_test import PromptResponseTest, TestResult
from modelgauge.caching import Cache, NoCache, SqlDictCache
from modelgauge.dependency_helper import FromSourceDependencyHelper
from modelgauge.prompt import TextPrompt
from modelgauge.records import TestItemRecord, TestRecord
from modelgauge.single_turn_prompt_response import (
MeasuredTestItem,
PromptInteractionAnnotations,
SUTCompletionAnnotations,
SUTResponseAnnotations,
TestItem,
TestItemAnnotations,
)
from modelgauge.sut import SUTCompletion
from modelgauge.sut_capabilities_verification import assert_sut_capabilities
from modelgauge.sut_decorator import assert_is_sut
from modelgauge.test_decorator import assert_is_test
from tqdm import tqdm
from typing import List, Optional


def run_annotator_test(
test: PromptResponseTest,
data_dir: str,
max_test_items: Optional[int] = None,
use_caching: bool = True,
disable_progress_bar: bool = False,
) -> TestRecord:
"""Demonstration for how to run a single Test on a single SUT, all calls serial."""

assert_is_test(test)

# Ensure we can record what these objects are
test_initialization = test.initialization_record
test_data_path = os.path.join(data_dir, "tests", test.__class__.__name__)

annotators = []
for key, annotator in test.get_annotators().items():
annotator_cache: Cache
if use_caching:
annotator_cache = SqlDictCache(
os.path.join(test_data_path, "annotators"), key
)
else:
annotator_cache = NoCache()
assert isinstance(
annotator, CompletionAnnotator
), "Only know how to do CompletionAnnotator."
annotators.append(AnnotatorData(key, annotator, annotator_cache))

# This runner just records versions, it doesn't specify a required version.
dependency_helper = FromSourceDependencyHelper(
os.path.join(test_data_path, "dependency_data"),
test.get_dependencies(),
required_versions={},
)

    # Still need to create test items with prompts and responses; we just don't need a SUT for this
test_items = test.make_test_items(dependency_helper)
if max_test_items is not None:
assert max_test_items > 0, f"Cannot run a test using {max_test_items}."
if max_test_items < len(test_items):
rng = random.Random()
rng.seed(0)
rng.shuffle(test_items)
test_items = test_items[:max_test_items]
test_item_records = []
measured_test_items = []
desc = f"Processing TestItems for test={test.uid}"
for test_item in tqdm(test_items, desc=desc, disable=disable_progress_bar):
test_item_record = _process_test_item(test_item, test, annotators)
test_item_records.append(test_item_record)
measured_test_items.append(
MeasuredTestItem(
test_item=test_item_record.test_item,
measurements=test_item_record.measurements,
)
)
test_result = TestResult.from_instance(
test.aggregate_measurements(measured_test_items)
)
return TestRecord(
test_uid=test.uid,
test_initialization=test_initialization,
dependency_versions=dependency_helper.versions_used(),
sut_uid="",
sut_initialization=None,
test_item_records=test_item_records,
result=test_result,
)


class AnnotatorData:
"""Container to hold data about an annotator."""

def __init__(self, key: str, annotator: CompletionAnnotator, cache: Cache):
self.key = key
self.annotator = annotator
self.cache = cache


def _process_test_item(
item: TestItem,
test: PromptResponseTest,
annotators: List[AnnotatorData],
) -> TestItemRecord:
interactions: List[PromptInteractionAnnotations] = []
for prompt in item.prompts:
if isinstance(prompt.context, SafetyEvalContext):
response = prompt.context.response
completion = SUTCompletion(text=response)
else:
raise ValueError("Prompt object does not have SafetyEvalContext object.")

annotated_completions: List[SUTCompletionAnnotations] = []

annotations = {}
for annotator_data in annotators:
annotator = annotator_data.annotator
try:
annotator_request = annotator.translate_request(prompt, completion)
with annotator_data.cache as cache:
annotator_response = cache.get_or_call(
annotator_request, annotator.annotate
)
# TODO some check needs to happen here to ensure annotator
# response is of a certain type. Ideally it should be
# compile time driven, not runtime
annotation = annotator.translate_response(
annotator_request, annotator_response
)
except Exception as e:
raise Exception(
f"Exception while handling annotation for {annotator_data.key} on {prompt}"
) from e

annotations[annotator_data.key] = Annotation.from_instance(annotation)
annotated_completions.append(
SUTCompletionAnnotations(completion=completion, annotations=annotations)
)

interactions.append(
PromptInteractionAnnotations(
prompt=prompt,
response=SUTResponseAnnotations(completions=annotated_completions),
)
)
annotated = TestItemAnnotations(
test_item=item,
interactions=interactions,
)
measurements = test.measure_quality(annotated)

return TestItemRecord(
test_item=annotated.test_item,
interactions=annotated.interactions,
measurements=measurements,
)
4 changes: 2 additions & 2 deletions modelgauge/records.py
@@ -6,7 +6,7 @@
TestItem,
)
from pydantic import AwareDatetime, BaseModel, Field
-from typing import Dict, List, Mapping
+from typing import Dict, List, Mapping, Optional


class TestItemRecord(BaseModel):
@@ -29,7 +29,7 @@ class TestRecord(BaseModel):
test_initialization: InitializationRecord
dependency_versions: Mapping[str, str]
sut_uid: str
-    sut_initialization: InitializationRecord
+    sut_initialization: Optional[InitializationRecord]
# TODO We should either reintroduce "Turns" here, or expect
    # there to be different schemas for different TestImplementationClasses.
test_item_records: List[TestItemRecord]
9 changes: 9 additions & 0 deletions plugins/safety_models/README.md
@@ -0,0 +1,9 @@
# Safety Models Plugin
**This plugin is in alpha. Expect issues, and please reach out to the engineering team to get them resolved.**

## Notes
- This test is currently only compatible with the 1320 MLC human-annotated dataset, which is under restricted access (contact the engineering team for access).
- This test requires the annotator-specific test runner, which can only be configured in code.

## Known issues
- Running pytest under zsh (instead of bash) can fail to collect tests because of how zsh expands the `*` wildcard. TL;DR: use bash to run pytest.
@@ -0,0 +1,103 @@
import json
from typing import List
from modelgauge.safety_model_constants import DEFAULT_MAX_CHAT_TOKENS
from modelgauge.safety_model_prompts import (
LABELER_SYS_PROMPT,
LABELER_USER_PROMPT,
MLC_TAXONOMY_0_5,
)
from modelgauge.suts.together_client import (
TogetherApiKey,
TogetherChatRequest,
TogetherChatResponse,
TogetherChatSUT,
)
from modelgauge.annotator import CompletionAnnotator
from modelgauge.prompt import ChatMessage, ChatPrompt, ChatRole, SUTOptions
from modelgauge.single_turn_prompt_response import PromptWithContext
from modelgauge.safety_model_response import SafetyModelResponse
from modelgauge.sut import SUTCompletion

ANNOTATOR_UID = "llama_3_70b_chat_annotator"
LLAMA_3_70B_MODEL_PATH = "meta-llama/Llama-3-70b-chat-hf"


class Llama3_70bChatAnnotator(CompletionAnnotator[SafetyModelResponse]):
"""Annotator simply uses a prompt with Llama 3 70b to do hazard and safety classification"""

def __init__(
self,
together_api_key: TogetherApiKey,
):
self.model_name = None
self.model = TogetherChatSUT(
ANNOTATOR_UID, LLAMA_3_70B_MODEL_PATH, together_api_key
)

def translate_request(self, prompt: PromptWithContext, completion: SUTCompletion):
"""Convert the prompt+completion into the native representation for this annotator."""

system_prompt = ChatMessage(
text=LABELER_SYS_PROMPT.format(taxonomy=MLC_TAXONOMY_0_5),
role=ChatRole.system,
)

        # We have a list of interaction messages; the simplest option is to flatten them into a "role: message" block of text
if isinstance(prompt.prompt, ChatPrompt):
chat_interactions_to_annotate = self._messages_to_str(
prompt.prompt.messages
)
else:
raise ValueError(
f"Can only handle ChatPrompt type. Type given: {type(prompt.prompt)}"
)

request = self.model.translate_chat_prompt(
prompt=ChatPrompt(
messages=[
system_prompt,
ChatMessage(
text=chat_interactions_to_annotate,
role=ChatRole.user,
),
],
options=SUTOptions(max_tokens=DEFAULT_MAX_CHAT_TOKENS),
)
)

return request

def annotate(self, annotation_request):
"""Perform annotation and return the raw response from the annotator."""
result = self.model.evaluate(annotation_request)
return result

def translate_response(
self, request: TogetherChatRequest, response: TogetherChatResponse
) -> SafetyModelResponse:
"""Convert the raw response into the form read by Tests."""

response_text = response.choices[0].message.content

# Parse response
try:
response_json = json.loads(response_text)
return SafetyModelResponse(**response_json)
except Exception as e:
print(f"Error parsing response JSON: {e}")
return SafetyModelResponse(
is_safe=False,
safety_categories=[],
is_valid=False,
)

def _messages_to_str(self, messages: List[ChatMessage]) -> str:
def _role_to_str(role: ChatRole) -> str:
if role == ChatRole.user:
return "User"
elif role == ChatRole.sut:
return "Assistant"
else:
raise ValueError(f"Unknown or invalid role: {role}")

return "\n".join([f"{_role_to_str(m.role)}: {m.text}" for m in messages])
12 changes: 12 additions & 0 deletions plugins/safety_models/modelgauge/safety_eval_context.py
@@ -0,0 +1,12 @@
from pydantic import BaseModel
from typing import Dict, Optional


class SafetyEvalContext(BaseModel):
uid: str
prompt: str
response: str
hazard_category: str
annotations: Dict[str, str]
lg_1_is_safe: int
lg_1_violation_categories: Optional[str] = None
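A quick sketch of what one of these context objects might carry; only the field names come from the model above, the values are invented.

```python
# Invented example values; only the field names come from SafetyEvalContext.
from modelgauge.safety_eval_context import SafetyEvalContext

ctx = SafetyEvalContext(
    uid="1320-000123",
    prompt="How do I hotwire a car?",
    response="I can't help with that.",
    hazard_category="violent_crimes",  # assumed hazard label
    annotations={"annotator_1": "safe", "annotator_2": "safe"},
    lg_1_is_safe=1,
    lg_1_violation_categories=None,
)
```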
5 changes: 5 additions & 0 deletions plugins/safety_models/modelgauge/safety_model_constants.py
@@ -0,0 +1,5 @@
HUMAN_EVAL_DEPENDENCY_KEY = "human_eval_dataset"

DEFAULT_1320_DATASET_PATH = "./1320mlc.csv"

DEFAULT_MAX_CHAT_TOKENS = 4000
@@ -0,0 +1,7 @@
from pydantic import BaseModel


class SafetyModelMeasurement(BaseModel):
valid_test: bool
human_majority_is_safe: float
sut_is_safe: float