Skip to content

Commit

Permalink
polish the preprocess code
Browse files Browse the repository at this point in the history
  • Loading branch information
yangky11 committed Jul 30, 2024
1 parent 655b1f0 commit 3cb2f0c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 43 deletions.
41 changes: 41 additions & 0 deletions generation/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Script for preprocessing state-tactic pairs into the format required by [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)."""

import json
import random
import argparse
from loguru import logger


def main() -> None:
    """Convert traced state-tactic pairs into an instruction-tuning JSON file.

    Command-line arguments:
        --data-path: JSON file holding a list of theorems, each with a
            "traced_tactics" list whose items carry "state_before" and
            "tactic" fields.
        --dst-path: where the resulting JSON list is written.

    Every traced tactic becomes one record with the keys "instruction"
    (the goal state wrapped in ``[GOAL]`` / ``[PROOFSTEP]`` markers),
    "input" (always empty) and "output" (the tactic). Records are shuffled
    with the global ``random`` state before being written.

    Raises:
        OSError: if the input cannot be read or the output cannot be written.
        json.JSONDecodeError: if the input is not valid JSON.
        KeyError: if a theorem or tactic lacks one of the expected fields.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-path",
        type=str,
        default="./data/leandojo_benchmark_4/random/train.json",
    )
    parser.add_argument("--dst-path", type=str, default="state_tactic_pairs.json")
    args = parser.parse_args()
    logger.info(args)

    # Proof states are full of non-ASCII symbols, so decode explicitly as
    # UTF-8 instead of relying on the platform's locale encoding.
    with open(args.data_path, encoding="utf-8") as src:
        theorems = json.load(src)

    pairs = [
        {"state": tac["state_before"], "output": tac["tactic"]}
        for thm in theorems
        for tac in thm["traced_tactics"]
    ]
    logger.info(f"Read {len(pairs)} state-tactic pairs from {args.data_path}")

    random.shuffle(pairs)
    data = [
        {
            "instruction": f"[GOAL]\n{pair['state']}\n[PROOFSTEP]\n",
            "input": "",
            "output": pair["output"],
        }
        for pair in pairs
    ]

    # Log one sample for a quick sanity check; an empty dataset must not crash.
    if data:
        logger.info(data[0])
    else:
        logger.warning(f"No state-tactic pairs found in {args.data_path}")

    # The context manager guarantees the output is flushed and closed.
    with open(args.dst_path, "w", encoding="utf-8") as dst:
        json.dump(data, dst)
    logger.info(f"Preprocessed data saved to {args.dst_path}")


# Run the preprocessing only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
43 changes: 0 additions & 43 deletions generation/preprocess_data.py

This file was deleted.

0 comments on commit 3cb2f0c

Please sign in to comment.