
Commit 903fac4

minor update
yangky11 committed Aug 23, 2024
1 parent 4ce2e67 commit 903fac4
Showing 5 changed files with 35 additions and 6 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -272,7 +272,7 @@ After the models are trained, run the following commands to retrieve premises fo
python retrieval/main.py predict --config retrieval/confs/cli_lean4_random.yaml --ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --trainer.logger.name predict_retriever_random --trainer.logger.save_dir logs/predict_retriever_random
python retrieval/main.py predict --config retrieval/confs/cli_lean4_novel_premises.yaml --ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --trainer.logger.name predict_retriever_novel_premises --trainer.logger.save_dir logs/predict_retriever_novel_premises
```
-, where `PATH_TO_RETRIEVER_CHECKPOINT` is the model checkpoint produced in the previous step. Retrieved premises are saved to `./logs/predict_retriever_*/predictions.pickle`.
+Retrieved premises are saved to `./logs/predict_retriever_*/predictions.pickle`. Note that `PATH_TO_RETRIEVER_CHECKPOINT` is the DeepSpeed model checkpoint produced in the previous step. To use a Hugging Face checkpoint instead, a workaround is to run training for a single step with a zero learning rate (see the new `retrieval/confs/cli_dummy.yaml` below and the sketch that follows).
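
A minimal sketch of that workaround, using the `retrieval/confs/cli_dummy.yaml` config added in this commit. The `fit` subcommand and the `--model.model_name` override are assumed from Lightning's CLI conventions rather than stated in this commit, and `$HF_MODEL_NAME` is a placeholder for the Hugging Face checkpoint you want to convert:
```
# Sketch only: assumes Lightning CLI's `fit` subcommand and `--model.model_name` override;
# $HF_MODEL_NAME is a placeholder for the Hugging Face retriever checkpoint.
python retrieval/main.py fit --config retrieval/confs/cli_dummy.yaml --model.model_name $HF_MODEL_NAME
```
The DeepSpeed checkpoint written by this one-step run can then be passed as `$PATH_TO_RETRIEVER_CHECKPOINT` in the predict commands above.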


### Evaluating the Retrieved Premises
generation/model.py (1 change: 0 additions & 1 deletion)
@@ -157,7 +157,6 @@ def _log_io_texts(
    def on_fit_start(self) -> None:
        if self.logger is not None:
            self.logger.log_hyperparams(self.hparams)
-            self.logger.watch(self.generator)
        assert self.trainer is not None
        logger.info(f"Logging to {self.trainer.log_dir}")

prover/proof_search.py (5 changes: 3 additions & 2 deletions)
@@ -344,8 +344,9 @@ def initialize(self) -> None:
        engine_args = AsyncEngineArgs(
            model=self.model_path,
            tensor_parallel_size=self.num_gpus,
-            max_num_batched_tokens=2048,
-            enable_chunked_prefill=True,
+            max_num_batched_tokens=8192,
+            # max_num_batched_tokens=2048,
+            # enable_chunked_prefill=True,
        )
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

retrieval/confs/cli_dummy.yaml (30 changes: 30 additions & 0 deletions)
@@ -0,0 +1,30 @@
seed_everything: 3407 # https://arxiv.org/abs/2109.08203
trainer:
  accelerator: gpu
  devices: 1
  precision: bf16-mixed
  strategy:
    class_path: pytorch_lightning.strategies.DeepSpeedStrategy
    init_args:
      stage: 2
      offload_optimizer: false
      cpu_checkpointing: false
  gradient_clip_val: 1.0
  max_steps: 1
  logger: null

model:
  model_name: google/byt5-small
  lr: 0
  warmup_steps: 2000
  num_retrieved: 100

data:
  data_path: data/leandojo_benchmark_4/random/
  corpus_path: data/leandojo_benchmark_4/corpus.jsonl
  num_negatives: 3
  num_in_file_negatives: 1
  batch_size: 8
  eval_batch_size: 64
  max_seq_len: 1024
  num_workers: 4
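
This new config is the workaround referenced in the README change above: with `lr: 0` and `max_steps: 1`, the single optimizer step should leave the weights unchanged, so the run exists only to re-save the starting Hugging Face model in the DeepSpeed checkpoint format that `--ckpt_path` expects, and `logger: null` keeps the throwaway run out of the experiment logs.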
retrieval/model.py (3 changes: 1 addition & 2 deletions)
@@ -9,7 +9,7 @@
from loguru import logger
import pytorch_lightning as pl
import torch.nn.functional as F
-from typing import List, Dict, Any, Tuple, Union, Optional
+from typing import List, Dict, Any, Tuple, Union
from transformers import AutoModelForTextEncoding, AutoTokenizer

from common import (
@@ -146,7 +146,6 @@ def forward(
    def on_fit_start(self) -> None:
        if self.logger is not None:
            self.logger.log_hyperparams(self.hparams)
-            self.logger.watch(self.encoder)
        logger.info(f"Logging to {self.trainer.log_dir}")

        self.corpus = self.trainer.datamodule.corpus
