From 493c0b83e131676e36cd932f6960d78d0987b689 Mon Sep 17 00:00:00 2001 From: Oyvind Tafjord Date: Tue, 5 Mar 2024 12:13:27 -0800 Subject: [PATCH] Add MMLU prompt variants (#484) --- CHANGELOG.md | 2 +- olmo/eval/downstream.py | 161 ++++++++++++++++++++++++++++------------ pyproject.toml | 1 + 3 files changed, 114 insertions(+), 50 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3724bbffe..36fce791e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `output_hidden_states` argument and associated functionality to `OLMo` and `OLMoForCausalLM` to return model intermediate hidden states. -- Added MMLU downstream evaluation tasks. +- Added MMLU downstream evaluation tasks, with prompt variations. - Added support for PyTorch v2.2. ## [v0.2.4](https://github.com/allenai/OLMo/releases/tag/v0.2.4) - 2024-02-02 diff --git a/olmo/eval/downstream.py b/olmo/eval/downstream.py index 815a24956..fc823cb3e 100644 --- a/olmo/eval/downstream.py +++ b/olmo/eval/downstream.py @@ -1,4 +1,5 @@ import abc +import logging import re from typing import Any, ClassVar, Dict, List, Optional, Sequence, Union @@ -10,6 +11,8 @@ from ..tokenizer import Tokenizer +log = logging.getLogger(__name__) + class ICLMetric(Metric): # update method does not require access to global metric state @@ -152,6 +155,7 @@ def __init__( dataset_name: Union[str, Sequence[str], None] = None, model_ctx_len: int = 2048, split="validation", + prompts=[None], # List of prompt variants to use ): super().__init__() @@ -159,6 +163,9 @@ def __init__( self.dataset_path = dataset_path self.dataset_name = dataset_name self.model_ctx_len = model_ctx_len + self.prompts = prompts + self.current_prompt = None + self.log_instances = 5 # Log the first few instances as a sanity check self.samples: List[Dict[str, Any]] = [] dataset_names: Sequence[Optional[str]] @@ -174,6 +181,7 @@ def __init__( path=self.dataset_path, name=ds_name, split=split, + trust_remote_code=True, ) ) self.dataset = datasets.concatenate_datasets(dataset_list) @@ -191,51 +199,65 @@ def prep_examples(self): """Append doc_ids to each example so that they are processed together in the metric""" doc_id = 0 for doc in self.dataset: - # from EAI harness - # how this all works: - # CTX CONT - # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] - # gpt2 \ \ - # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the - # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice - - continuations = self.doc_to_continuations(doc) - label_id = self.doc_to_label(doc) - ctx = self.token_encode(self.doc_to_text(doc)) - dc = self.token_encode(self.doc_to_domain_conditional(doc)) - - for cont_id, continuation_str in enumerate(continuations): - cont_str_len = len(continuation_str) - 1 # continuation contain leading blank - continuation = self.token_encode(continuation_str) - - # query, remove last token from continuation, truncate from left is longer than model ctx length - query = ctx + continuation[:-1] - query = query[-self.model_ctx_len :] - - # get domain conditional query - # we don't expect this to be longer than self.model_ctx_len and it won't make sense to truncate from left - dc_query = dc + continuation[:-1] - - # form a sample - self.samples.append( - { - "doc_id": doc_id, - "cont_id": cont_id, - "ctx": ctx, - "continuation": continuation, - "ctx_len": len(ctx), - "dc_len": len(dc), - "cont_len": len( - continuation - ), # even if query has last token removed, LM will output same cont len - "cont_str_len": cont_str_len, - "query": query, # remove last token from continuation - "dc_query": dc_query, - "label_id": label_id, - } - ) - - doc_id += 1 + for prompt in self.prompts: + self.current_prompt = prompt + # from EAI harness + # how this all works: + # CTX CONT + # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] + # gpt2 \ \ + # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the + # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice + + continuations = self.doc_to_continuations(doc) + label_id = self.doc_to_label(doc) + doc_text = self.doc_to_text(doc) + ctx = self.token_encode(doc_text) + dc = self.token_encode(self.doc_to_domain_conditional(doc)) + if self.log_instances > 0: + self.log_instances -= 1 + ds_name = self.dataset_name + if isinstance(ds_name, list): + ds_name = ds_name[0] + log.info( + f"Sample doc from ({self.dataset_path}, {ds_name}, {self.current_prompt}):" + + f"\ndoc_text: {doc_text}\ncontinuations: {continuations}" + ) + + for cont_id, continuation_str in enumerate(continuations): + cont_str_len = len(continuation_str) - 1 # continuation contain leading blank + continuation = self.token_encode(continuation_str) + + # query, remove last token from continuation, truncate from left is longer than model ctx length + query = ctx + continuation[:-1] + query = query[-self.model_ctx_len :] + # this will be different from len(ctx) when truncated by model_ctx_len + actual_ctx_len = len(query) - len(continuation) + 1 + + # get domain conditional query + # we don't expect this to be longer than self.model_ctx_len and it won't make sense to truncate from left + dc_query = dc + continuation[:-1] + + # form a sample + self.samples.append( + { + "doc_id": doc_id, + "cont_id": cont_id, + "ctx": ctx, + "continuation": continuation, + "ctx_len": actual_ctx_len, + "dc_len": len(dc), + "cont_len": len( + continuation + ), # even if query has last token removed, LM will output same cont len + "cont_str_len": cont_str_len, + "query": query, # remove last token from continuation + "dc_query": dc_query, + "label_id": label_id, + } + ) + + doc_id += 1 def pad_tokens_until_max(self, tokens, max_len=2048): """truncate from left if len(tokens) > model_ctx_len, max_len is not considered then @@ -655,7 +677,7 @@ def __init__(self, tokenizer, dataset_path="sciq", dataset_name=None): ) def doc_to_text(self, doc): - return doc["support"] + "\nQuestion: " + doc["question"] + "\nAnswer:".strip() + return doc["support"].strip() + "\nQuestion: " + doc["question"] + "\nAnswer:" def doc_to_continuations(self, doc): # add spaces in front of continuation @@ -1055,7 +1077,14 @@ class MMLU(ICLMultiChoiceTaskDataset): "other": ["other", "business", "health"], } - def __init__(self, tokenizer, dataset_path="hails/mmlu_no_train", dataset_name=None, split="validation"): + def __init__( + self, + tokenizer, + dataset_path="hails/mmlu_no_train", + dataset_name=None, + split="validation", + prompt_variations=None, + ): dataset_names = [] # Collect the relevant categories if dataset_name in MMLU._categories: @@ -1069,10 +1098,40 @@ def __init__(self, tokenizer, dataset_path="hails/mmlu_no_train", dataset_name=N for name, cats in MMLU._subcategories.items(): if dataset_name in cats: dataset_names.append(name) - super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_names, split=split) + self.dev_set = {} + if prompt_variations == 1: + prompts = [None, "inst", "inst+1", "inst+2", "inst+3", "inst+4", "inst+5"] + # Need to grab the dev set for the few-shot prompts + for name in dataset_names: + self.dev_set[name] = datasets.load_dataset( + path=dataset_path, name=name, split="dev", trust_remote_code=True + ) + super().__init__( + tokenizer=tokenizer, + dataset_path=dataset_path, + dataset_name=dataset_names, + split=split, + prompts=prompts, + ) def doc_to_text(self, doc): - return "Question: " + doc["question"] + "\nAnswer:" + output_text = "Question: " + doc["question"] + "\nAnswer:" + if self.current_prompt is not None: + prefix = "" + if "inst" in self.current_prompt: + subject = doc.get("subject").replace("_", " ") + prefix = f"The following are multiple choice questions (with answers) about {subject}:\n\n" + num_shots = re.findall("\\+(\\d+)", self.current_prompt) + if num_shots: + dev_set = self.dev_set.get(doc.get("subject"), []) + num_shots_int = int(num_shots[0]) + for idx, dev_doc in enumerate(dev_set): + if idx >= num_shots_int: + break + answer = dev_doc["choices"][dev_doc["answer"]] + prefix += "Question: " + dev_doc["question"] + "\nAnswer: " + answer + "\n\n" + output_text = prefix + output_text + return output_text def doc_to_continuations(self, doc): # add spaces in front of continuation @@ -1108,4 +1167,8 @@ def doc_to_domain_conditional(self, doc): "mmlu_humanities": (MMLU, {"dataset_name": "humanities"}), "mmlu_social_sciences": (MMLU, {"dataset_name": "social_sciences"}), "mmlu_other": (MMLU, {"dataset_name": "other"}), + "mmlu_stem_var": (MMLU, {"dataset_name": "stem", "prompt_variations": 1}), + "mmlu_humanities_var": (MMLU, {"dataset_name": "humanities", "prompt_variations": 1}), + "mmlu_social_sciences_var": (MMLU, {"dataset_name": "social_sciences", "prompt_variations": 1}), + "mmlu_other_var": (MMLU, {"dataset_name": "other", "prompt_variations": 1}), } diff --git a/pyproject.toml b/pyproject.toml index 6ff1c2f91..07177dbdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "numpy", "torch>=2.0,<2.3", "omegaconf", + "fsspec==2023.5.0", # temporary fix for HF dataset downloads "rich", "boto3", "google-cloud-storage",