Skip to content

Commit

Permalink
minor update to preprocess
Browse files Browse the repository at this point in the history
  • Loading branch information
yangky11 committed Jul 30, 2024
1 parent 3cb2f0c commit 47ee563
Showing 1 changed file with 23 additions and 19 deletions.
42 changes: 23 additions & 19 deletions generation/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Script for preprocessing state-tactic pairs into the format required by [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)."""

import os
import json
import random
import argparse
Expand All @@ -11,30 +12,33 @@ def main() -> None:
parser.add_argument(
"--data-path",
type=str,
default="./data/leandojo_benchmark_4/random/train.json",
default="./data/leandojo_benchmark_4/random",
)
parser.add_argument("--dst-path", type=str, default="state_tactic_pairs.json")
parser.add_argument("--dst-path", type=str, default="state_tactic_pairs")
args = parser.parse_args()
logger.info(args)

pairs = []
for thm in json.load(open(args.data_path)):
for tac in thm["traced_tactics"]:
pairs.append({"state": tac["state_before"], "output": tac["tactic"]})
logger.info(f"Read {len(pairs)} state-tactic paris from {args.data_path}")
for split in ("train", "val"):
data_path = os.path.join(args.data_path, f"{split}.json")
pairs = []
for thm in json.load(open(data_path)):
for tac in thm["traced_tactics"]:
pairs.append({"state": tac["state_before"], "output": tac["tactic"]})
logger.info(f"Read {len(pairs)} state-tactic paris from {data_path}")

random.shuffle(pairs)
data = [
{
"instruction": f"[GOAL]\n{pair['state']}\n[PROOFSTEP]\n",
"input": "",
"output": pair["output"],
}
for pair in pairs
]
logger.info(data[0])
json.dump(data, open(args.dst_path, "wt"))
logger.info(f"Preprocessed data saved to {args.dst_path}")
random.shuffle(pairs)
data = [
{
"instruction": f"[GOAL]\n{pair['state']}\n[PROOFSTEP]\n",
"input": "",
"output": pair["output"],
}
for pair in pairs
]
logger.info(data[0])
dst_path = args.dst_path + f"_{split}.json"
json.dump(data, open(dst_path, "wt"))
logger.info(f"Preprocessed data saved to {dst_path}")


if __name__ == "__main__":
Expand Down

0 comments on commit 47ee563

Please sign in to comment.