
Commit b068421

fixed prompt related bugs

NickLennonLiu committed Mar 15, 2024
1 parent f10811f commit b068421
Showing 4 changed files with 13 additions and 8 deletions.
configs/commons/templates.py: 3 additions & 3 deletions
@@ -48,7 +48,7 @@ def mc_abcd_gen_prompt_template(prompt_hint, answer_hint):
role="HUMAN",
prompt=f'{prompt_hint}{{question}}\nA: {{A}}\nB: {{B}}\nC: {{C}}\nD: {{D}}\n{answer_hint}'
),
# dict(role="BOT", prompt="{answer}")
dict(role="BOT", prompt="{answer}")
],
),
ice_token="</E>",
@@ -85,7 +85,7 @@ def mc_abcd_cot_prompt_template(prompt_hint, cot_think_hint):
role="HUMAN",
prompt=f'{prompt_hint}{{question}}\nA: {{A}}\nB: {{B}}\nC: {{C}}\nD: {{D}}\n{cot_think_hint}'
),
# dict(role="BOT", prompt="{answer}")
dict(role="BOT", prompt="{answer}")
]
),
ice_token="</E>",
@@ -116,7 +116,7 @@ def qa_gen_prompt_template(prompt_hint, answer_hint):
role="HUMAN",
prompt=f'{prompt_hint}{{question}}\n{answer_hint}'
),
# dict(role="BOT", prompt="{answer}")
dict(role="BOT", prompt="{answer}")
],
),
ice_token="</E>",
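
Note on the templates.py hunks: all three restore the BOT turn carrying "{answer}". Since the same round template is used to render in-context (ICE) examples, commenting out that turn made few-shot examples end at the question with no gold answer. A minimal sketch of why, assuming a simplified format-based renderer (render_round and the sample item are illustrative, not OpenCompass APIs):

    # Illustrative renderer: each turn's prompt is filled from a dataset row.
    def render_round(round_template, item):
        return "\n".join(turn["prompt"].format(**item) for turn in round_template)

    round_template = [
        dict(role="HUMAN", prompt="{question}\nA: {A}\nB: {B}\nC: {C}\nD: {D}\nAnswer:"),
        dict(role="BOT", prompt="{answer}"),  # the turn restored by this commit
    ]

    item = dict(question="1+1=?", A="1", B="2", C="3", D="4", answer="B")
    print(render_round(round_template, item))
    # With the BOT turn, the rendered example ends with the gold answer "B";
    # with it commented out, few-shot examples would show questions only.
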
configs/datasets/opseval/mc_gen.py: 2 additions & 2 deletions
@@ -91,8 +91,8 @@ def get_mc_gen_datasets(dataset_name, path, langs=['zh'], qtypes=['single']):
f"让我们逐个选项分析:\n"
],
[
[f'{prompts[shot_hint_id][qtype_hint_id][0]}Therefore the answer is: \n'],
[f'{prompts[shot_hint_id][qtype_hint_id][1]}因此答案是:\n']
f'{prompts[shot_hint_id][qtype_hint_id][0]}Therefore the answer is: \n',
f'{prompts[shot_hint_id][qtype_hint_id][1]}因此答案是:\n'
]
)
]
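
Note on the mc_gen.py hunk: the two answer hints were each wrapped in a one-element list, so any consumer treating them as strings would embed the list's repr into the prompt. A small sketch of the symptom (hint text shortened; stands in for the real f-strings):

    hints_buggy = [["...Therefore the answer is: \n"], ["...因此答案是:\n"]]
    hints_fixed = ["...Therefore the answer is: \n", "...因此答案是:\n"]

    print(f"{hints_buggy[0]}")  # ['...Therefore the answer is: \n']  (list repr leaks into the prompt)
    print(f"{hints_fixed[0]}")  # ...Therefore the answer is:         (plain string, as intended)
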
opencompass/openicl/icl_inferencer/icl_gen_inferencer.py: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def inference(self,
         ds_reader = retriever.dataset_reader
         if ds_reader.output_column:
             gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
-            prompt_list = list(zip(prompt_list, gold_ans))
+            # prompt_list = list(zip(prompt_list, gold_ans))
 
         # Create tmp json file for saving intermediate results and future
         # resuming
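
Note on the icl_gen_inferencer.py hunk: disabling the zip keeps prompt_list a list of plain prompt strings rather than (prompt, gold) tuples, matching downstream code that feeds entries straight to generation. A sketch of the type mismatch the change avoids (values illustrative):

    prompt_list = ["Q1 ...", "Q2 ..."]
    gold_ans = ["B", "C"]

    zipped = list(zip(prompt_list, gold_ans))
    print(type(zipped[0]))       # <class 'tuple'>: a consumer expecting str would break
    print(type(prompt_list[0]))  # <class 'str'>: what the generation call actually wants
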
opencompass/openicl/icl_inferencer/icl_sc_inferencer.py: 7 additions & 2 deletions
@@ -102,11 +102,11 @@ def inference(self,
             ice_template=ice_template,
             prompt_template=prompt_template)
 
-        # 3.1 Fetch and zip prompt & gold answer if output column exists
+        # # 3.1 Fetch and zip prompt & gold answer if output column exists
         ds_reader = retriever.dataset_reader
         if ds_reader.output_column:
             gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
-            prompt_list = list(zip(prompt_list, gold_ans))
+            # prompt_list = list(zip(prompt_list, gold_ans))
 
         # Create tmp json file for saving intermediate results and future
         # resuming
@@ -132,11 +132,16 @@ def inference(self,
             else:
                 entry = datum
                 golds = [None for _ in range(len(entry))]
+            # entry = [t[0] for t in datum]
+            # golds = [t[1] for t in datum]
             # TODO: add more types of CoT method
             # 5-1. Inference sc_size times with local model
             with torch.no_grad():
                 parsed_entries = self.model.parse_template(entry,
                                                            mode='gen')
+
+            # print("[MERGE DEBUG]: ENTRY", entry)
+            # print("[MERGE DEBUG]: PARSED_ENTRIES", parsed_entries)
             sc_results = []
             for _ in range(self.sc_size):
                 results = self.model.generate_from_template(
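
Note on the icl_sc_inferencer.py hunks: the same prompt/gold zip is disabled here, and the visible loop samples the model sc_size times per entry for self-consistency. A rough sketch of the overall pattern, sampling several generations and aggregating, typically by majority vote over parsed answers (fake_generate and the voting step are illustrative assumptions, not the OpenCompass API):

    import random
    from collections import Counter

    def fake_generate(prompt):
        # Stand-in for a stochastic model call (temperature > 0).
        return random.choice(["A", "B", "B"])

    def self_consistency(prompt, sc_size=5):
        sc_results = [fake_generate(prompt) for _ in range(sc_size)]
        return Counter(sc_results).most_common(1)[0][0]  # majority vote

    print(self_consistency("Q: ..."))
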
