Skip to content

Commit

Permalink
Add safety checks for 'data' key in MegatronGPTModel cfg
Browse files Browse the repository at this point in the history
Signed-off-by: HuiyingLi <[email protected]>
  • Loading branch information
HuiyingLi committed Apr 20, 2024
1 parent c9c8408 commit a223412
Showing 1 changed file with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0)))
self.loss_broadcast_src_rank = None
self.return_output_tensors = cfg.data.get('return_output_tensors', False)
self.validation_drop_last = cfg.data.get('validation_drop_last', True)
self.sample_weight = cfg.data.get('sample_weight', 'token')
data_cfg = cfg.get('data', {})
self.return_output_tensors = data_cfg.get('return_output_tensors', False)
self.validation_drop_last = data_cfg.get('validation_drop_last', True)
self.sample_weight = data_cfg.data.get('sample_weight', 'token')
self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False)

self.inference_params = None
Expand Down

0 comments on commit a223412

Please sign in to comment.