Skip to content

Commit

Permalink
fix vllm bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
yangky11 committed Aug 15, 2024
1 parent 47ee563 commit 4ce2e67
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
9 changes: 9 additions & 0 deletions prover/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def evaluate(
module: Optional[str] = None,
num_sampled_tactics: int = 64,
timeout: int = 600,
max_expansions: Optional[int] = None,
num_workers: int = 1,
num_gpus: int = 0,
save_results: bool = False,
Expand All @@ -135,6 +136,7 @@ def evaluate(
num_workers,
num_gpus=num_gpus,
timeout=timeout,
max_expansions=max_expansions,
num_sampled_tactics=num_sampled_tactics,
debug=verbose,
)
Expand Down Expand Up @@ -225,6 +227,12 @@ def main() -> None:
default=600,
help="Maximum number of seconds the proof search can take.",
)
parser.add_argument(
"--max-expansions",
type=int,
default=None,
help="Maximum number of expansions during proof search.",
)
parser.add_argument(
"--num-workers", type=int, default=1, help="The number of concurrent provers."
)
Expand Down Expand Up @@ -262,6 +270,7 @@ def main() -> None:
args.module,
args.num_sampled_tactics,
args.timeout,
args.max_expansions,
args.num_workers,
args.num_gpus,
args.save_results,
Expand Down
19 changes: 15 additions & 4 deletions prover/proof_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,14 @@ def __init__(
self,
tac_gen, # A given tactic generator.
timeout: int,
max_expansions: Optional[int],
num_sampled_tactics: int,
debug: bool,
) -> None:
self.tac_gen = tac_gen
self.tac_gen.initialize()
self.timeout = timeout
self.max_expansions = max_expansions
self.num_sampled_tactics = num_sampled_tactics
self.debug = debug

Expand Down Expand Up @@ -149,11 +151,14 @@ async def _best_first_search(self) -> None:
assert time.time() - time_start >= self.timeout

self.total_time = time.time() - time_start
if self.total_time > self.timeout:
if self.total_time > self.timeout or (
self.max_expansions is not None
and self.num_expansions > self.max_expansions
):
if self.root.status == Status.PROVED:
logger.info("Found a proof but timed out.")
logger.info("Found a proof!")
self.root.status = Status.OPEN
logger.info("Search timed out.")
logger.info("Hit the resource limit (timeout or max_expansions).")
break

if self.root.status == Status.FAILED:
Expand Down Expand Up @@ -307,12 +312,14 @@ def __init__(
self,
tac_gen: TacticGenerator,
timeout: int,
max_expansions: Optional[int],
num_sampled_tactics: int,
debug: bool,
) -> None:
self.prover = BestFirstSearchProver(
tac_gen,
timeout,
max_expansions,
num_sampled_tactics,
debug,
)
Expand Down Expand Up @@ -349,6 +356,7 @@ async def generate(self, prompt: str, num_samples: int) -> RequestOutput:
length_penalty=0,
use_beam_search=True,
early_stopping=False,
logprobs=0,
)

async for oup in self.engine.generate(
Expand Down Expand Up @@ -379,6 +387,7 @@ def __init__(
num_workers: int,
num_gpus: int,
timeout: int,
max_expansions: Optional[int],
num_sampled_tactics: int,
debug: Optional[bool] = False,
) -> None:
Expand Down Expand Up @@ -416,7 +425,7 @@ def __init__(
if not self.distributed:
assert num_gpus <= 1
self.prover = BestFirstSearchProver(
tac_gen, timeout, num_sampled_tactics, debug
tac_gen, timeout, max_expansions, num_sampled_tactics, debug
)
return

Expand All @@ -431,6 +440,7 @@ def __init__(
ProverActor.options(num_gpus=num_gpus_per_worker).remote(
tac_gen,
timeout=timeout,
max_expansions=max_expansions,
num_sampled_tactics=num_sampled_tactics,
debug=debug,
)
Expand All @@ -442,6 +452,7 @@ def __init__(
ProverActor.remote(
tac_gen,
timeout=timeout,
max_expansions=max_expansions,
num_sampled_tactics=num_sampled_tactics,
debug=debug,
)
Expand Down

0 comments on commit 4ce2e67

Please sign in to comment.