Commit

update index.py
yangky11 committed Jul 12, 2024
1 parent a554532 commit 40e07a6
Showing 3 changed files with 10 additions and 14 deletions.
prover/proof_search.py: 14 changes (4 additions, 10 deletions)
@@ -336,8 +336,8 @@ def initialize(self) -> None:
         engine_args = AsyncEngineArgs(
             model=self.model_path,
             tensor_parallel_size=self.num_gpus,
-            max_num_batched_tokens=8192,
-            enable_chunked_prefill=False,
+            max_num_batched_tokens=2048,
+            enable_chunked_prefill=True,
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)

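Side note, not part of the commit: a minimal sketch of how an engine built with the new arguments might be queried, assuming a recent vLLM release for the imports and the generate() call. The model path, prompt, and request id are placeholders. With enable_chunked_prefill=True, a long prompt no longer has to fit into one scheduling step, which is why max_num_batched_tokens can shrink from 8192 to 2048.

# Sketch only (not code from this commit): querying an AsyncLLMEngine built
# with the settings above. Model path, prompt, and request id are placeholders.
import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams


async def demo() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(
            model="path/to/tactic-generator",  # placeholder checkpoint
            tensor_parallel_size=1,
            max_num_batched_tokens=2048,  # per-step token budget
            enable_chunked_prefill=True,  # long prefills are split across steps
        )
    )
    params = SamplingParams(temperature=0.0, max_tokens=64)
    final = None
    async for out in engine.generate("theorem test : 2 + 2 = 4 := by", params, request_id="req-0"):
        final = out  # the last yielded output holds the finished generation
    print(final.outputs[0].text)


asyncio.run(demo())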
@@ -394,17 +394,11 @@ def __init__(
             tac_gen = VllmGenerator(vllm_actor)
         elif indexed_corpus_path is not None:
             tac_gen = RetrievalAugmentedGenerator(
-                gen_ckpt_path,
-                ret_ckpt_path,
-                indexed_corpus_path,
-                device,
-                max_num_retrieved=100,
+                gen_ckpt_path, ret_ckpt_path, indexed_corpus_path, device, max_num_retrieved=100
             )
         else:
             device = torch.device("cuda") if num_gpus > 0 else torch.device("cpu")
-            tac_gen = HuggingFaceGenerator(
-                gen_ckpt_path, device, max_oup_seq_len, length_penalty
-            )
+            tac_gen = HuggingFaceGenerator(gen_ckpt_path, device, max_oup_seq_len, length_penalty)

         self.distributed = num_workers > 1
         if not self.distributed:
retrieval/index.py: 2 changes (1 addition, 1 deletion)
@@ -30,7 +30,7 @@ def main() -> None:
         device = torch.device("cpu")
     else:
         device = torch.device("cuda")
-    model = PremiseRetriever.load_hf(args.ckpt_path, device, num_retrieved=100)
+    model = PremiseRetriever.load_hf(args.ckpt_path, device, max_seq_len=2048)
     model.load_corpus(args.corpus_path)
     model.reindex_corpus(batch_size=args.batch_size)
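The old call in retrieval/index.py passed num_retrieved=100 while omitting dtype, which had no default in the previous load_hf signature; the updated call instead caps input length with max_seq_len=2048. Below is a hedged usage sketch, not the repository script: the import path, checkpoint path, corpus path, and batch size are placeholders.

# Sketch of the updated script-level usage, not the repository's index.py itself.
import torch

from retrieval.model import PremiseRetriever  # assumed import path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenized inputs are now truncated to 2048 tokens when building the index.
model = PremiseRetriever.load_hf("path/to/retriever-ckpt", device, max_seq_len=2048)
model.load_corpus("path/to/corpus.jsonl")
model.reindex_corpus(batch_size=32)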
retrieval/model.py: 8 changes (5 additions, 3 deletions)
@@ -9,7 +9,7 @@
 from loguru import logger
 import pytorch_lightning as pl
 import torch.nn.functional as F
-from typing import List, Dict, Any, Tuple, Union
+from typing import List, Dict, Any, Tuple, Union, Optional
 from transformers import AutoModelForTextEncoding, AutoTokenizer

 from common import (
@@ -51,10 +51,12 @@ def load(cls, ckpt_path: str, device, freeze: bool) -> "PremiseRetriever":

     @classmethod
     def load_hf(
-        cls, ckpt_path: str, device, dtype, num_retrieved: int
+        cls, ckpt_path: str, device: int, dtype = None, max_seq_len: Optional[int] = None
     ) -> "PremiseRetriever":
+        if max_seq_len is None:
+            max_seq_len = 999999999999
         model = (
-            PremiseRetriever(ckpt_path, 0.0, 0, 999999999999, num_retrieved)
+            PremiseRetriever(ckpt_path, 0.0, 0, max_seq_len, 100)
             .to(device)
             .eval()
         )
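A small sketch of the revised load_hf defaults, not code from this commit: omitting max_seq_len falls back to the effectively unbounded 999999999999, and the number of retrieved premises is now fixed at 100 inside load_hf instead of being passed by the caller. The checkpoint path is a placeholder and the import path is assumed.

# Sketch illustrating the new load_hf defaults; not code from this commit.
import torch

from retrieval.model import PremiseRetriever  # assumed import path

device = torch.device("cpu")

# No max_seq_len: falls back to the effectively unbounded 999999999999.
unbounded = PremiseRetriever.load_hf("path/to/retriever-ckpt", device)

# Explicit cap on input length (assumed to control tokenizer truncation).
capped = PremiseRetriever.load_hf("path/to/retriever-ckpt", device, max_seq_len=2048)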
