Syncing random seed between ranks (#8230)
Signed-off-by: Igor Gitman <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Kipok and ericharper authored Jan 25, 2024
1 parent f10d694 commit 6143f6b
Showing 2 changed files with 24 additions and 4 deletions.
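
This commit threads an optional random_seed through the text generation path: the REST server accepts and validates it as a request parameter, send_generate_info/receive_generate_info broadcast it from the model-parallel source rank to the other ranks, and generate() seeds only after every rank knows the seed, so all ranks sample identically.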
10 changes: 10 additions & 0 deletions nemo/collections/nlp/modules/common/text_generation_server.py
@@ -46,6 +46,7 @@
         "min_tokens_to_generate",
         "end_strings",
         "compute_logprob",
+        "random_seed",
     ]
 )
 
@@ -172,6 +173,14 @@ def put(self):
             if not isinstance(compute_logprob, bool):
                 return "compute_logprob must be a boolean value"
 
+        random_seed = None
+        if "random_seed" in request.get_json():
+            random_seed = request.get_json()["random_seed"]
+            if random_seed is not None and not isinstance(random_seed, int):
+                return "random_seed must be a positive integer number or None"
+            if random_seed is not None and random_seed < 0:
+                return "random_seed must be a positive integer number or None"
+
         with lock:  # Need to get lock to keep multiple threads from hitting code
             MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
             extra = {}
@@ -200,6 +209,7 @@ def put(self):
                 end_strings=end_strings,
                 min_tokens_to_generate=min_tokens_to_generate,
                 compute_logprob=compute_logprob,
+                random_seed=random_seed,
                 **extra,
             )
             for k in output:
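
With "random_seed" added to API_ALLOWED_KEYS and validated in put(), a client can pin the sampling seed per request. A minimal client sketch, not part of this commit: the host, port, and the request fields other than "random_seed" are assumptions for illustration.

import requests

payload = {
    "sentences": ["Deep learning is"],  # assumed request field
    "tokens_to_generate": 32,           # assumed request field
    "random_seed": 1234,                # must be a non-negative int or None, per put()
}
# PUT /generate on localhost:5555 is an assumed endpoint for a running server
resp = requests.put("http://localhost:5555/generate", json=payload)
print(resp.json())

Repeating the same request with the same random_seed should now produce the same sampled continuation.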
18 changes: 14 additions & 4 deletions nemo/collections/nlp/modules/common/text_generation_utils.py
@@ -360,12 +360,15 @@ def send_generate_info(
     repetition_penalty,
     min_tokens_to_generate,
     end_strings,
+    random_seed,
 ):
     """
     Needs to be synced up with receive_generate_info
     """
     model_parallel_group = parallel_state.get_model_parallel_group()
     src = get_model_parallel_src_rank()
+    if random_seed is None:
+        random_seed = -1  # to be able to convert to float
     # Send the sizes of the tensors
     input_info = [
         context_tokens_tensor.size(0),  # batch_size
@@ -379,6 +382,7 @@ def send_generate_info(
         greedy,
         repetition_penalty,
         min_tokens_to_generate,
+        random_seed,
     ]
     input_info_tensor = torch.cuda.FloatTensor(input_info)
     torch.distributed.broadcast(input_info_tensor, src, model_parallel_group)
@@ -402,7 +406,7 @@ def receive_generate_info():
     """
     model_parallel_group = parallel_state.get_model_parallel_group()
     src = get_model_parallel_src_rank()
-    input_info_tensor = torch.empty(11, dtype=torch.float32, device=torch.cuda.current_device())
+    input_info_tensor = torch.empty(12, dtype=torch.float32, device=torch.cuda.current_device())
     torch.distributed.broadcast(input_info_tensor, src, model_parallel_group)
     batch_size = int(input_info_tensor[0].item())
     seq_len = int(input_info_tensor[1].item())
@@ -415,6 +419,9 @@ def receive_generate_info():
     greedy = bool(input_info_tensor[8].item())
     repetition_penalty = float(input_info_tensor[9].item())
     min_tokens_to_generate = int(input_info_tensor[10].item())
+    random_seed = int(input_info_tensor[11].item())
+    if random_seed == -1:  # was converted to -1 before broadcast
+        random_seed = None
 
     context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
     context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
@@ -443,6 +450,7 @@ def receive_generate_info():
         repetition_penalty,
         min_tokens_to_generate,
         end_strings,
+        random_seed,
     )
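
The seed travels in the same float tensor as the other scalar generation parameters, so None is mapped to a -1 sentinel before the broadcast and mapped back after it; this is also why the tensor in receive_generate_info grows from 11 to 12 entries. A standalone sketch of that round trip, mimicking the send/receive logic above without requiring an initialized process group:

import torch

def pack_seed(random_seed):
    # None cannot ride in a float tensor, so use -1 as a sentinel
    if random_seed is None:
        random_seed = -1
    return torch.tensor([float(random_seed)], dtype=torch.float32)

def unpack_seed(info_tensor):
    random_seed = int(info_tensor[0].item())
    return None if random_seed == -1 else random_seed

assert unpack_seed(pack_seed(None)) is None
assert unpack_seed(pack_seed(1234)) == 1234

The -1 sentinel is safe because put() already rejects negative seeds, so a legitimate seed can never collide with it.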


@@ -586,9 +594,6 @@ def generate(
         token_ids: List[Tensor], output sentence token ids
         offsets: List[List[int]] # list of tokens start positions in text
     """
-    if random_seed is not None:
-        seed_everything(random_seed)
-
     if 'strategy' in strategy_args:
         inference_strategy = strategy_args['strategy']
     else:
@@ -615,6 +620,7 @@ def generate(
             repetition_penalty,
             min_tokens_to_generate,
             end_strings,
+            random_seed,
         )
     else:
         (
@@ -630,8 +636,12 @@ def generate(
             repetition_penalty,
             min_tokens_to_generate,
             end_strings,
+            random_seed,
         ) = receive_generate_info()
 
+    if random_seed is not None:
+        seed_everything(random_seed)
+
     output = synced_generate(
         model,
         inference_strategy,
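
Moving seed_everything below the send/receive exchange is the heart of the fix: on the server path, only the source rank enters generate() with a concrete random_seed, while the other ranks obtain it via receive_generate_info, so the old top-of-function seeding ran with None on those ranks and left them unseeded. Seeding after the broadcast runs on every rank with the same value. A minimal sketch of the property relied on, assuming seed_everything is pytorch_lightning's helper (the import is not shown in this diff):

import torch
from pytorch_lightning import seed_everything  # assumed import

seed_everything(1234)
a = torch.multinomial(torch.ones(10), num_samples=3)
seed_everything(1234)
b = torch.multinomial(torch.ones(10), num_samples=3)
assert torch.equal(a, b)  # same seed, same draws on every rank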
