From 6143f6b75a2e26a6821aa0b3ad1ae97f2d6d7dcc Mon Sep 17 00:00:00 2001
From: Igor Gitman
Date: Thu, 25 Jan 2024 09:29:35 -0800
Subject: [PATCH] RSyncing random seed between ranks (#8230)

Signed-off-by: Igor Gitman
Co-authored-by: Eric Harper
---
 .../modules/common/text_generation_server.py | 10 ++++++++++
 .../modules/common/text_generation_utils.py  | 18 ++++++++++++++----
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/text_generation_server.py b/nemo/collections/nlp/modules/common/text_generation_server.py
index 52efc95c3d1a..6c257317b99f 100644
--- a/nemo/collections/nlp/modules/common/text_generation_server.py
+++ b/nemo/collections/nlp/modules/common/text_generation_server.py
@@ -46,6 +46,7 @@
         "min_tokens_to_generate",
         "end_strings",
         "compute_logprob",
+        "random_seed",
     ]
 )
 
@@ -172,6 +173,14 @@ def put(self):
         if not isinstance(compute_logprob, bool):
             return "compute_logprob must be a boolean value"
 
+        random_seed = None
+        if "random_seed" in request.get_json():
+            random_seed = request.get_json()["random_seed"]
+            if random_seed is not None and not isinstance(random_seed, int):
+                return "random_seed must be a positive integer number or None"
+            if random_seed is not None and random_seed < 0:
+                return "random_seed must be a positive integer number or None"
+
         with lock:  # Need to get lock to keep multiple threads from hitting code
             MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
             extra = {}
@@ -200,6 +209,7 @@
                 end_strings=end_strings,
                 min_tokens_to_generate=min_tokens_to_generate,
                 compute_logprob=compute_logprob,
+                random_seed=random_seed,
                 **extra,
             )
             for k in output:
diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py
index 486536c1521a..c6a8f1e46900 100644
--- a/nemo/collections/nlp/modules/common/text_generation_utils.py
+++ b/nemo/collections/nlp/modules/common/text_generation_utils.py
@@ -360,12 +360,15 @@ def send_generate_info(
     repetition_penalty,
     min_tokens_to_generate,
     end_strings,
+    random_seed,
 ):
     """
     Needs to be synced up with receive_generate_info
     """
     model_parallel_group = parallel_state.get_model_parallel_group()
     src = get_model_parallel_src_rank()
+    if random_seed is None:
+        random_seed = -1  # to be able to convert to float
     # Send the sizes of the tensors
     input_info = [
         context_tokens_tensor.size(0),  # batch_size
@@ -379,6 +382,7 @@
         greedy,
         repetition_penalty,
         min_tokens_to_generate,
+        random_seed,
     ]
     input_info_tensor = torch.cuda.FloatTensor(input_info)
     torch.distributed.broadcast(input_info_tensor, src, model_parallel_group)
@@ -402,7 +406,7 @@
     """
     model_parallel_group = parallel_state.get_model_parallel_group()
     src = get_model_parallel_src_rank()
-    input_info_tensor = torch.empty(11, dtype=torch.float32, device=torch.cuda.current_device())
+    input_info_tensor = torch.empty(12, dtype=torch.float32, device=torch.cuda.current_device())
     torch.distributed.broadcast(input_info_tensor, src, model_parallel_group)
     batch_size = int(input_info_tensor[0].item())
     seq_len = int(input_info_tensor[1].item())
@@ -415,6 +419,9 @@
     greedy = bool(input_info_tensor[8].item())
     repetition_penalty = float(input_info_tensor[9].item())
     min_tokens_to_generate = int(input_info_tensor[10].item())
+    random_seed = int(input_info_tensor[11].item())
+    if random_seed == -1:  # was converted to -1 before broadcast
+        random_seed = None
     context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
     context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
 
@@ -443,6 +450,7 @@
         repetition_penalty,
         min_tokens_to_generate,
         end_strings,
+        random_seed,
     )
 
 
@@ -586,9 +594,6 @@ def generate(
         token_ids: List[Tensor], output sentence token ids
         offsets: List[List[int]] # list of tokens start positions in text
     """
-    if random_seed is not None:
-        seed_everything(random_seed)
-
    if 'strategy' in strategy_args:
         inference_strategy = strategy_args['strategy']
     else:
@@ -615,6 +620,7 @@ def generate(
             repetition_penalty,
             min_tokens_to_generate,
             end_strings,
+            random_seed,
         )
     else:
         (
@@ -630,8 +636,12 @@ def generate(
             repetition_penalty,
             min_tokens_to_generate,
             end_strings,
+            random_seed,
         ) = receive_generate_info()
 
+    if random_seed is not None:
+        seed_everything(random_seed)
+
     output = synced_generate(
         model,
         inference_strategy,
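
For reference, the change packs the seed into the same float tensor that already carries the other scalar generation parameters, encodes None as -1 so it survives the float round-trip, and moves the seed_everything call to after receive_generate_info, so every model-parallel rank seeds identically instead of only the rank that served the HTTP request. The standalone sketch below (not part of the patch) reduces that pattern to plain torch.distributed: send_seed/receive_seed are hypothetical stand-ins for NeMo's send_generate_info/receive_generate_info, and it assumes a CPU gloo process group launched with torchrun rather than NeMo's CUDA model-parallel group.

```python
# Sketch of the seed-broadcast pattern from the patch. Assumptions: gloo
# backend on CPU, torchrun launcher; helper names are illustrative only.
import torch
import torch.distributed as dist


def send_seed(random_seed, src=0):
    """Source rank: encode None as -1 so the value survives the float tensor."""
    if random_seed is None:
        random_seed = -1  # to be able to convert to float
    info = torch.tensor([float(random_seed)], dtype=torch.float32)
    dist.broadcast(info, src)


def receive_seed(src=0):
    """Other ranks: receive the float tensor and decode -1 back to None."""
    info = torch.empty(1, dtype=torch.float32)
    dist.broadcast(info, src)
    random_seed = int(info[0].item())
    return None if random_seed == -1 else random_seed


if __name__ == "__main__":
    # Launch with: torchrun --nproc_per_node=2 seed_broadcast_sketch.py
    dist.init_process_group("gloo")
    rank = dist.get_rank()

    if rank == 0:
        seed = 1234  # in the server this value arrives via the REST request
        send_seed(seed)
    else:
        seed = receive_seed()

    # Mirrors the relocated seed_everything call: seeding happens on every
    # rank only after the broadcast, so all ranks agree on the seed.
    if seed is not None:
        torch.manual_seed(seed)

    # Every rank now draws the same sample, keeping sampling in sync.
    print(f"rank {rank}: seed={seed}, draw={torch.rand(1).item():.6f}")
    dist.destroy_process_group()
```

As in the server-side validation above, a client may omit the seed entirely: when "random_seed" is present in the JSON body it must be a non-negative integer, and None (or omission) means the generate path does not reseed at all.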