Change from no_grad to inference_mode (#57)
loubbrad authored Sep 29, 2024
1 parent 3e20c33 commit 4120f57
Showing 1 changed file with 7 additions and 4 deletions.
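For context (an illustrative sketch, not code from this repo): torch.inference_mode() is the stricter sibling of torch.no_grad(). Both disable gradient tracking during the forward pass, but inference mode additionally skips view and version-counter bookkeeping, which makes it marginally faster for pure inference; the trade-off is that tensors created under it can never participate in autograd later. The model and predict function below are hypothetical:

import torch

model = torch.nn.Linear(4, 2)  # stand-in model, for illustration only

@torch.inference_mode()
def predict(x: torch.Tensor) -> torch.Tensor:
    # No autograd graph is built here, as with no_grad, and no
    # view/version-counter tracking happens either.
    return model(x)

y = predict(torch.randn(1, 4))
print(y.requires_grad)        # False, same as under no_grad
print(torch.is_inference(y))  # True; autograd ops will reject y later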
11 changes: 7 additions & 4 deletions amt/inference/transcribe.py
@@ -151,7 +151,7 @@ def wrapper(*args, **kwargs):
     return wrapper


-@torch.no_grad()
+@torch.inference_mode()
 def decode_token(
     model: AmtEncoderDecoder,
     x: torch.Tensor,
@@ -170,7 +170,7 @@ def decode_token(
     return logits, next_tok_ids


-@torch.no_grad()
+@torch.inference_mode()
 def prefill(
     model: AmtEncoderDecoder,
     x: torch.Tensor,
@@ -211,7 +211,7 @@ def calculate_input_pos(prefix_lens: torch.Tensor, prefill: bool):


 @optional_bf16_autocast
-@torch.no_grad()
+@torch.inference_mode()
 def process_segments(
     tasks: List,
     model: AmtEncoderDecoder,
@@ -304,6 +304,7 @@ def process_segments(
     return results


+@torch.inference_mode()
 def gpu_manager(
     gpu_batch_queue: Queue,
     gpu_waiting_dict: dict,
@@ -1024,7 +1025,9 @@ def batch_transcribe(
         min(batch_size, multiprocessing.cpu_count() - num_gpus),
         file_queue.qsize(),
     )
-    num_processes_per_worker = min(10, file_queue.qsize() // num_workers)
+    num_processes_per_worker = min(
+        3 * (batch_size // num_workers), file_queue.qsize() // num_workers
+    )

     mp_manager = Manager()
     gpu_waiting_dict = mp_manager.dict()
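On the second change in batch_transcribe: the fixed cap of 10 processes per worker becomes a cap that scales with the GPU batch size. A quick worked example with hypothetical numbers:

batch_size, num_workers, queued = 16, 4, 100  # hypothetical values

old_cap = min(10, queued // num_workers)
# = min(10, 25) = 10, a fixed ceiling regardless of batch size

new_cap = min(3 * (batch_size // num_workers), queued // num_workers)
# = min(3 * 4, 25) = 12; the ceiling now grows with batch_size, so
# larger batches keep more feeder processes per worker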
