[BUG] ZeRO3 - GPU memory leakage during backward operation while training a Huggingface PEFT model #3378

Closed
suri-kunal opened this issue Apr 25, 2023 · 12 comments
Labels: bug, training

@suri-kunal (Author) commented Apr 25, 2023

@stas00, @tjruwase - Tagging you here since I have seen you working on ZeRO3 extensively. Apologies if I shouldn't do this.
Describe the bug
I am fine-tuning a LoRA model on top of BioBART-V2-Base using DeepSpeed and the Hugging Face PEFT library on a T4 instance. I am not using the Hugging Face Trainer class because I wanted to learn how to integrate DeepSpeed into arbitrary training code. To benchmark how the different ZeRO configurations behave, I ran the code with the following configurations -

Baseline -

{
    "scheduler": {
        "type": "WarmupDecayLR", 
        "params": {
            "warmup_min_lr": 0, 
            "warmup_type": "linear", 
            "total_num_steps": 6.497000e+03, 
            "warmup_max_lr": 0.001, 
            "warmup_num_steps": 650
        }
    }, 
    "optimizer": {
        "type": "Adam", 
        "params": {
            "betas": [0.9, 0.999], 
            "eps": 1e-06, 
            "weight_decay": 0.01, 
            "bias_correction": true
        }
    }, 
    "train_micro_batch_size_per_gpu": 1, 
    "gradient_accumulation_steps": 16, 
    "gradient_clipping": 1.0
}

ZeRO 2 -

{
    "scheduler": {
        "type": "WarmupDecayLR", 
        "params": {
            "warmup_min_lr": 0, 
            "warmup_type": "linear", 
            "total_num_steps": 6.497000e+03, 
            "warmup_max_lr": 0.001, 
            "warmup_num_steps": 650
        }
    }, 
    "optimizer": {
        "type": "Adam", 
        "params": {
            "betas": [0.9, 0.999], 
            "eps": 1e-06, 
            "weight_decay": 0.01, 
            "bias_correction": true
        }
    }, 
    "fp16": {
        "enabled": true, 
        "auto_cast": false, 
        "loss_scale": 0, 
        "initial_scale_power": 16, 
        "loss_scale_window": 1000, 
        "hysteresis": 2, 
        "min_loss_scale": 1
    }, 
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true
    }, 
    "train_micro_batch_size_per_gpu": 1, 
    "gradient_accumulation_steps": 16, 
    "gradient_clipping": 1.0
}

and ZeRO 3 -

{
    "scheduler": {
        "type": "WarmupDecayLR", 
        "params": {
            "warmup_min_lr": 0, 
            "warmup_type": "linear", 
            "total_num_steps": 6.497000e+03, 
            "warmup_max_lr": 0.001, 
            "warmup_num_steps": 650
        }
    }, 
    "optimizer": {
        "type": "Adam", 
        "params": {
            "betas": [0.9, 0.999], 
            "eps": 1e-06, 
            "weight_decay": 0.01, 
            "bias_correction": true
        }
    }, 
    "fp16": {
        "enabled": true, 
        "auto_cast": false, 
        "loss_scale": 0, 
        "initial_scale_power": 16, 
        "loss_scale_window": 1000, 
        "hysteresis": 2, 
        "min_loss_scale": 1
    }, 
    "zero_optimization": {
        "stage": 3, 
        "offload_optimizer": {
            "device": "cpu", 
            "pin_memory": true
        }, 
        "offload_param": {
            "device": "cpu", 
            "pin_memory": true
        }, 
        "overlap_comm": true, 
        "contiguous_gradients": true, 
        "sub_group_size": 1.000000e+09, 
        "reduce_bucket_size": "auto", 
        "stage3_prefetch_bucket_size": "auto", 
        "stage3_param_persistence_threshold": "auto", 
        "stage3_max_live_parameters": 1.000000e+09, 
        "stage3_max_reuse_distance": 1.000000e+09, 
        "stage3_gather_16bit_weights_on_model_save": true
    }, 
    "train_micro_batch_size_per_gpu": 1, 
    "gradient_accumulation_steps": 16, 
    "gradient_clipping": 1.0
}

The training learning curves match perfectly for the Baseline and ZeRO2 configurations, but I get RuntimeError: CUDA out of memory when I try to use ZeRO3.

To Reproduce
Steps to reproduce the behavior:

import shutil

import deepspeed
import numpy as np
from tqdm.auto import tqdm
from peft import get_peft_model
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq

# Note: seed_everything, tokenizer, label_pad_token_id, device, and code_config
# are defined elsewhere in the original script and are not shown here.

def create_model_optimizer(ds_config_json, peft_config):

    ds_config = ds_config_json

    seed_everything(42)
    
    model = \
    AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")

    model = get_peft_model(model,peft_config)

    if model.config.decoder_start_token_id is None:
        raise Exception("Ensure that config.decoder_start_token_id is set")

    ds_config["optimizer"]["params"]["eps"] = 1e-6
    ds_config["optimizer"]["params"]["weight_decay"] = 0.01
    ds_config["optimizer"]["params"]["bias_correction"] = True
    
    return model, ds_config
def train_summarization(ds_config,peft_config,train_ds,epoch,checkpoint_folder=None):
    
    seed_everything(code_config.TASKA_SUMMARY_SEED)
    
    model, ds_config = create_model_optimizer(ds_config, peft_config)
    
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, \
                                       model=model, \
                                       label_pad_token_id=label_pad_token_id)
    
    model_engine, _, train_dl, _ = deepspeed.initialize(model=model,
                                                 model_parameters=model.parameters(),
                                                 training_data=train_ds,
                                                 collate_fn=data_collator,
                                                 config=ds_config)
    
    if checkpoint_folder.is_dir() and checkpoint_folder.exists():
        _, client_state = model_engine.load_checkpoint(load_dir=checkpoint_folder)
        old_epoch = client_state['epoch']
    else:
        checkpoint_folder.mkdir(parents=True, exist_ok=False)
    
    model_engine.train()
    if model_engine.training is False:
        raise Exception("Model is not trainable")
    total_train_loss = 0
    for train_step,train_batch in enumerate(train_dl):
        
        if train_batch["input_ids"].shape[0] > ds_config["train_micro_batch_size_per_gpu"] :
            raise Exception("batch size is not equal to train_micro_batch_size_per_gpu")
        input_ids = train_batch["input_ids"].to(device)
        attention_mask = train_batch["attention_mask"].to(device)
        labels = train_batch["labels"].to(device)
        decoder_input_ids = train_batch["decoder_input_ids"].to(device)

        output = model_engine(input_ids=input_ids, \
                       attention_mask=attention_mask, \
                       decoder_input_ids=decoder_input_ids, \
                       labels=labels, \
                       output_hidden_states=True, \
                       use_cache=False, \
                       return_dict=True)

        loss = output.loss
        
        model_engine.backward(loss)
        
        model_engine.step()
        
        total_train_loss += loss.item()

        train_step_new = train_step + epoch * len(train_dl)
        
    model_engine.save_checkpoint(save_dir=checkpoint_folder,
                                 client_state={'epoch': epoch})

    avg_train_loss = total_train_loss / len(train_dl)
    
    return avg_train_loss
def training_loop(model_name, \
                  ds_config, \
                  peft_config, \
                  train_ds, \
                  valid_ds, \
                  checkpoint_folder=None):
    
    best_loss = np.inf
    best_model = None
    best_epoch = 0
    for epoch in tqdm(range(code_config.TASKA_SUMMARY_EPOCHS)):
        avg_train_loss = \
        train_summarization(ds_config, \
                            peft_config, \
                            train_ds, \
                            epoch, \
                            checkpoint_folder)

    shutil.rmtree(checkpoint_folder)

    return best_loss
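
For completeness, the peft_config passed into these functions is not shown here. A hypothetical LoRA configuration along these lines (all values are assumptions, not the exact settings used) would look like:

from peft import LoraConfig, TaskType

# Hypothetical example only - the exact LoRA settings are not part of this report.
peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,  # seq2seq task for BioBART
    r=8,                              # LoRA rank (assumed)
    lora_alpha=32,                    # LoRA scaling factor (assumed)
    lora_dropout=0.1,                 # dropout on the LoRA layers (assumed)
)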

Stack trace -

[2023-04-25 17:56:09,677] [INFO] [partition_parameters.py:415:__exit__] finished initializing model with 0.00B parameters
[2023-04-25 17:56:09,678] [INFO] [logging.py:93:log_dist] [Rank 0] DeepSpeed info: version=0.8.3, git-hash=unknown, git-branch=unknown
[2023-04-25 17:56:09,816] [INFO] [logging.py:93:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False
[2023-04-25 17:56:11,070] [WARNING] [cpu_adam.py:85:__init__] FP16 params for CPUAdam may not work on AMD CPUs
Using /root/.cache/torch_extensions/py38_cu121 as PyTorch extensions root...
No modifications detected for re-loaded extension module cpu_adam, skipping build step...
Loading extension module cpu_adam...
Time to load cpu_adam op: 2.6402900218963623 seconds
Adam Optimizer #13 is created with AVX2 arithmetic capability.
Config: alpha=0.001000, betas=(0.900000, 0.999000), weight_decay=0.010000, adam_w=1
[2023-04-25 17:56:14,687] [INFO] [logging.py:93:log_dist] [Rank 0] Using DeepSpeed Optimizer param name adam as basic optimizer
[2023-04-25 17:56:14,700] [INFO] [logging.py:93:log_dist] [Rank 0] DeepSpeed Basic Optimizer = DeepSpeedCPUAdam
[2023-04-25 17:56:14,701] [INFO] [utils.py:55:is_zero_supported_optimizer] Checking ZeRO support for optimizer=DeepSpeedCPUAdam type=<class 'deepspeed.ops.adam.cpu_adam.DeepSpeedCPUAdam'>
[2023-04-25 17:56:14,701] [INFO] [logging.py:93:log_dist] [Rank 0] Creating torch.float16 ZeRO stage 3 optimizer
[2023-04-25 17:56:15,020] [INFO] [utils.py:829:see_memory_usage] Stage 3 initialize beginning
[2023-04-25 17:56:15,021] [INFO] [utils.py:830:see_memory_usage] MA 12.46 GB         Max_MA 13.28 GB         CA 13.68 GB         Max_CA 14 GB 
[2023-04-25 17:56:15,021] [INFO] [utils.py:838:see_memory_usage] CPU Virtual Memory:  used = 27.17 GB, percent = 56.8%
[2023-04-25 17:56:15,023] [INFO] [stage3.py:113:__init__] Reduce bucket size 500,000,000
[2023-04-25 17:56:15,023] [INFO] [stage3.py:114:__init__] Prefetch bucket size 50,000,000
Using /root/.cache/torch_extensions/py38_cu121 as PyTorch extensions root...
No modifications detected for re-loaded extension module utils, skipping build step...
Loading extension module utils...
Time to load utils op: 0.00042366981506347656 seconds
[2023-04-25 17:56:15,323] [INFO] [utils.py:829:see_memory_usage] DeepSpeedZeRoOffload initialize [begin]
[2023-04-25 17:56:15,324] [INFO] [utils.py:830:see_memory_usage] MA 12.46 GB         Max_MA 12.46 GB         CA 13.68 GB         Max_CA 14 GB 
[2023-04-25 17:56:15,324] [INFO] [utils.py:838:see_memory_usage] CPU Virtual Memory:  used = 27.17 GB, percent = 56.8%
Parameter Offload: Total persistent parameters: 592896 in 232 params
[2023-04-25 17:56:16,307] [INFO] [utils.py:829:see_memory_usage] DeepSpeedZeRoOffload initialize [end]
[2023-04-25 17:56:16,308] [INFO] [utils.py:830:see_memory_usage] MA 12.14 GB         Max_MA 12.46 GB         CA 13.68 GB         Max_CA 14 GB 
[2023-04-25 17:56:16,308] [INFO] [utils.py:838:see_memory_usage] CPU Virtual Memory:  used = 27.64 GB, percent = 57.8%
[2023-04-25 17:56:16,595] [INFO] [utils.py:829:see_memory_usage] Before creating fp16 partitions
[2023-04-25 17:56:16,596] [INFO] [utils.py:830:see_memory_usage] MA 12.14 GB         Max_MA 12.14 GB         CA 13.68 GB         Max_CA 14 GB 
[2023-04-25 17:56:16,596] [INFO] [utils.py:838:see_memory_usage] CPU Virtual Memory:  used = 27.64 GB, percent = 57.8%
[2023-04-25 17:56:16,882] [INFO] [utils.py:829:see_memory_usage] After creating fp16 partitions: 1
[2023-04-25 17:56:16,883] [INFO] [utils.py:830:see_memory_usage] MA 12.14 GB         Max_MA 12.14 GB         CA 13.68 GB         Max_CA 14 GB 
[2023-04-25 17:56:16,883] [INFO] [utils.py:838:see_memory_usage] CPU Virtual Memory:  used = 27.65 GB, percent = 57.8%
[2023-04-25 17:56:17,169] [INFO] [utils.py:829:see_memory_usage] Before creating fp32 partitions
[2023-04-25 17:56:17,170] [INFO] [utils.py:830:see_memory_usage] MA 12.14 GB         Max_MA 12.14 GB         CA 13.68 GB         Max_CA 14 GB 
[2023-04-25 17:56:17,170] [INFO] [utils.py:838:see_memory_usage] CPU Virtual Memory:  used = 27.65 GB, percent = 57.8%
[2023-04-25 17:56:17,464] [INFO] [utils.py:829:see_memory_usage] After creating fp32 partitions
[2023-04-25 17:56:17,464] [INFO] [utils.py:830:see_memory_usage] MA 12.14 GB         Max_MA 12.14 GB         CA 13.68 GB         Max_CA 14 GB 
[2023-04-25 17:56:17,465] [INFO] [utils.py:838:see_memory_usage] CPU Virtual Memory:  used = 27.65 GB, percent = 57.8%
[2023-04-25 17:56:17,758] [INFO] [utils.py:829:see_memory_usage] Before initializing optimizer states
[2023-04-25 17:56:17,759] [INFO] [utils.py:830:see_memory_usage] MA 12.14 GB         Max_MA 12.14 GB         CA 13.68 GB         Max_CA 14 GB 
[2023-04-25 17:56:17,759] [INFO] [utils.py:838:see_memory_usage] CPU Virtual Memory:  used = 27.65 GB, percent = 57.8%
[2023-04-25 17:56:18,047] [INFO] [utils.py:829:see_memory_usage] After initializing optimizer states
[2023-04-25 17:56:18,048] [INFO] [utils.py:830:see_memory_usage] MA 12.14 GB         Max_MA 12.14 GB         CA 13.68 GB         Max_CA 14 GB 
[2023-04-25 17:56:18,048] [INFO] [utils.py:838:see_memory_usage] CPU Virtual Memory:  used = 27.65 GB, percent = 57.8%
[2023-04-25 17:56:18,049] [INFO] [stage3.py:376:_setup_for_real_optimizer] optimizer state initialized
[2023-04-25 17:56:18,440] [INFO] [utils.py:829:see_memory_usage] After initializing ZeRO optimizer
[2023-04-25 17:56:18,441] [INFO] [utils.py:830:see_memory_usage] MA 13.08 GB         Max_MA 13.08 GB         CA 13.2 GB         Max_CA 14 GB 
[2023-04-25 17:56:18,441] [INFO] [utils.py:838:see_memory_usage] CPU Virtual Memory:  used = 27.65 GB, percent = 57.8%
[2023-04-25 17:56:18,442] [INFO] [logging.py:93:log_dist] [Rank 0] DeepSpeed Final Optimizer = adam
[2023-04-25 17:56:18,442] [INFO] [logging.py:93:log_dist] [Rank 0] DeepSpeed using configured LR scheduler = WarmupDecayLR
[2023-04-25 17:56:18,442] [INFO] [logging.py:93:log_dist] [Rank 0] DeepSpeed LR Scheduler = <deepspeed.runtime.lr_schedules.WarmupDecayLR object at 0x7f7123882ee0>
[2023-04-25 17:56:18,442] [INFO] [logging.py:93:log_dist] [Rank 0] step=0, skipped=0, lr=[0.001], mom=[[0.9, 0.999]]
[2023-04-25 17:56:18,443] [INFO] [config.py:1018:print] DeepSpeedEngine configuration:
[2023-04-25 17:56:18,443] [INFO] [config.py:1022:print]   activation_checkpointing_config  {
    "partition_activations": false, 
    "contiguous_memory_optimization": false, 
    "cpu_checkpointing": false, 
    "number_checkpoints": null, 
    "synchronize_checkpoint_boundary": false, 
    "profile": false
}
[2023-04-25 17:56:18,443] [INFO] [config.py:1022:print]   aio_config ................... {'block_size': 1048576, 'queue_depth': 8, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}
[2023-04-25 17:56:18,443] [INFO] [config.py:1022:print]   amp_enabled .................. False
[2023-04-25 17:56:18,443] [INFO] [config.py:1022:print]   amp_params ................... False
[2023-04-25 17:56:18,443] [INFO] [config.py:1022:print]   autotuning_config ............ {
    "enabled": false, 
    "start_step": null, 
    "end_step": null, 
    "metric_path": null, 
    "arg_mappings": null, 
    "metric": "throughput", 
    "model_info": null, 
    "results_dir": "autotuning_results", 
    "exps_dir": "autotuning_exps", 
    "overwrite": true, 
    "fast": true, 
    "start_profile_step": 3, 
    "end_profile_step": 5, 
    "tuner_type": "gridsearch", 
    "tuner_early_stopping": 5, 
    "tuner_num_trials": 50, 
    "model_info_path": null, 
    "mp_size": 1, 
    "max_train_batch_size": null, 
    "min_train_batch_size": 1, 
    "max_train_micro_batch_size_per_gpu": 1.024000e+03, 
    "min_train_micro_batch_size_per_gpu": 1, 
    "num_tuning_micro_batch_sizes": 3
}
[2023-04-25 17:56:18,443] [INFO] [config.py:1022:print]   bfloat16_enabled ............. False
[2023-04-25 17:56:18,443] [INFO] [config.py:1022:print]   checkpoint_parallel_write_pipeline  False
[2023-04-25 17:56:18,443] [INFO] [config.py:1022:print]   checkpoint_tag_validation_enabled  True
[2023-04-25 17:56:18,443] [INFO] [config.py:1022:print]   checkpoint_tag_validation_fail  False
[2023-04-25 17:56:18,443] [INFO] [config.py:1022:print]   comms_config ................. <deepspeed.comm.config.DeepSpeedCommsConfig object at 0x7f7123c0c9a0>
[2023-04-25 17:56:18,443] [INFO] [config.py:1022:print]   communication_data_type ...... None
[2023-04-25 17:56:18,443] [INFO] [config.py:1022:print]   compression_config ........... {'weight_quantization': {'shared_parameters': {'enabled': False, 'quantizer_kernel': False, 'schedule_offset': 0, 'quantize_groups': 1, 'quantize_verbose': False, 'quantization_type': 'symmetric', 'quantize_weight_in_forward': False, 'rounding': 'nearest', 'fp16_mixed_quantize': False, 'quantize_change_ratio': 0.001}, 'different_groups': {}}, 'activation_quantization': {'shared_parameters': {'enabled': False, 'quantization_type': 'symmetric', 'range_calibration': 'dynamic', 'schedule_offset': 1000}, 'different_groups': {}}, 'sparse_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'row_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'head_pruning': {'shared_parameters': {'enabled': False, 'method': 'topk', 'schedule_offset': 1000}, 'different_groups': {}}, 'channel_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'layer_reduction': {'enabled': False}}
[2023-04-25 17:56:18,443] [INFO] [config.py:1022:print]   curriculum_enabled_legacy .... False
[2023-04-25 17:56:18,443] [INFO] [config.py:1022:print]   curriculum_params_legacy ..... False
[2023-04-25 17:56:18,443] [INFO] [config.py:1022:print]   data_efficiency_config ....... {'enabled': False, 'seed': 1234, 'data_sampling': {'enabled': False, 'num_epochs': 1000, 'num_workers': 0, 'curriculum_learning': {'enabled': False}}, 'data_routing': {'enabled': False, 'random_ltd': {'enabled': False, 'layer_token_lr_schedule': {'enabled': False}}}}
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   data_efficiency_enabled ...... False
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   dataloader_drop_last ......... False
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   disable_allgather ............ False
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   dump_state ................... False
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   dynamic_loss_scale_args ...... {'init_scale': 65536, 'scale_window': 1000, 'delayed_shift': 2, 'min_scale': 1}
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   eigenvalue_enabled ........... False
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   eigenvalue_gas_boundary_resolution  1
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   eigenvalue_layer_name ........ bert.encoder.layer
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   eigenvalue_layer_num ......... 0
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   eigenvalue_max_iter .......... 100
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   eigenvalue_stability ......... 1e-06
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   eigenvalue_tol ............... 0.01
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   eigenvalue_verbose ........... False
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   elasticity_enabled ........... False
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   flops_profiler_config ........ {
    "enabled": false, 
    "profile_step": 1, 
    "module_depth": -1, 
    "top_modules": 1, 
    "detailed": true, 
    "output_file": null
}
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   fp16_auto_cast ............... False
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   fp16_enabled ................. True
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   fp16_master_weights_and_gradients  False
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   global_rank .................. 0
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   grad_accum_dtype ............. None
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   gradient_accumulation_steps .. 16
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   gradient_clipping ............ 1.0
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   gradient_predivide_factor .... 1.0
[2023-04-25 17:56:18,444] [INFO] [config.py:1022:print]   initial_dynamic_scale ........ 65536
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   load_universal_checkpoint .... False
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   loss_scale ................... 0
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   memory_breakdown ............. False
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   monitor_config ............... tensorboard=TensorBoardConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') wandb=WandbConfig(enabled=False, group=None, team=None, project='deepspeed') csv_monitor=CSVConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') enabled=False
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   nebula_config ................ {
    "enabled": false, 
    "persistent_storage_path": null, 
    "persistent_time_interval": 100, 
    "num_of_version_in_retention": 2, 
    "enable_nebula_load": true, 
    "load_path": null
}
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   optimizer_legacy_fusion ...... False
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   optimizer_name ............... adam
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   optimizer_params ............. {'betas': [0.9, 0.999], 'eps': 1e-06, 'weight_decay': 0.01, 'bias_correction': True}
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   pipeline ..................... {'stages': 'auto', 'partition': 'best', 'seed_layers': False, 'activation_checkpoint_interval': 0}
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   pld_enabled .................. False
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   pld_params ................... False
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   prescale_gradients ........... False
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   scheduler_name ............... WarmupDecayLR
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   scheduler_params ............. {'warmup_min_lr': 0, 'warmup_type': 'linear', 'total_num_steps': 6497, 'warmup_max_lr': 0.001, 'warmup_num_steps': 650}
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   sparse_attention ............. None
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   sparse_gradients_enabled ..... False
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   steps_per_print .............. 10
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   train_batch_size ............. 16
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   train_micro_batch_size_per_gpu  1
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   use_node_local_storage ....... False
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   wall_clock_breakdown ......... False
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   world_size ................... 1
[2023-04-25 17:56:18,445] [INFO] [config.py:1022:print]   zero_allow_untested_optimizer  False
[2023-04-25 17:56:18,446] [INFO] [config.py:1022:print]   zero_config .................. stage=3 contiguous_gradients=True reduce_scatter=True reduce_bucket_size=500,000,000 allgather_partitions=True allgather_bucket_size=500,000,000 overlap_comm=True load_from_fp32_weights=True elastic_checkpoint=False offload_param=DeepSpeedZeroOffloadParamConfig(device='cpu', nvme_path=None, buffer_count=5, buffer_size=100,000,000, max_in_cpu=1,000,000,000, pin_memory=True) offload_optimizer=DeepSpeedZeroOffloadOptimizerConfig(device='cpu', nvme_path=None, buffer_count=4, pin_memory=True, pipeline=False, pipeline_read=False, pipeline_write=False, fast_init=False) sub_group_size=1000000000 cpu_offload_param=None cpu_offload_use_pin_memory=None cpu_offload=None prefetch_bucket_size=50,000,000 param_persistence_threshold=100,000 model_persistence_threshold=sys.maxsize max_live_parameters=1000000000 max_reuse_distance=1000000000 gather_16bit_weights_on_model_save=True stage3_gather_fp16_weights_on_model_save=False ignore_unused_parameters=True legacy_stage1=False round_robin_gradients=False
[2023-04-25 17:56:18,446] [INFO] [config.py:1022:print]   zero_enabled ................. True
[2023-04-25 17:56:18,446] [INFO] [config.py:1022:print]   zero_force_ds_cpu_optimizer .. True
[2023-04-25 17:56:18,446] [INFO] [config.py:1022:print]   zero_optimization_stage ...... 3
[2023-04-25 17:56:18,446] [INFO] [config.py:1007:print_user_config]   json = {
    "scheduler": {
        "type": "WarmupDecayLR", 
        "params": {
            "warmup_min_lr": 0, 
            "warmup_type": "linear", 
            "total_num_steps": 6.497000e+03, 
            "warmup_max_lr": 0.001, 
            "warmup_num_steps": 650
        }
    }, 
    "optimizer": {
        "type": "Adam", 
        "params": {
            "betas": [0.9, 0.999], 
            "eps": 1e-06, 
            "weight_decay": 0.01, 
            "bias_correction": true
        }
    }, 
    "fp16": {
        "enabled": true, 
        "auto_cast": false, 
        "loss_scale": 0, 
        "initial_scale_power": 16, 
        "loss_scale_window": 1000, 
        "hysteresis": 2, 
        "min_loss_scale": 1
    }, 
    "zero_optimization": {
        "stage": 3, 
        "offload_optimizer": {
            "device": "cpu", 
            "pin_memory": true
        }, 
        "offload_param": {
            "device": "cpu", 
            "pin_memory": true
        }, 
        "overlap_comm": true, 
        "contiguous_gradients": true, 
        "sub_group_size": 1.000000e+09, 
        "reduce_bucket_size": "auto", 
        "stage3_prefetch_bucket_size": "auto", 
        "stage3_param_persistence_threshold": "auto", 
        "stage3_max_live_parameters": 1.000000e+09, 
        "stage3_max_reuse_distance": 1.000000e+09, 
        "stage3_gather_16bit_weights_on_model_save": true
    }, 
    "train_micro_batch_size_per_gpu": 1, 
    "gradient_accumulation_steps": 16, 
    "gradient_clipping": 1.0
}
Using /root/.cache/torch_extensions/py38_cu121 as PyTorch extensions root...
No modifications detected for re-loaded extension module utils, skipping build step...
Loading extension module utils...
Time to load utils op: 0.0003974437713623047 seconds
[2023-04-25 17:56:18,450] [INFO] [torch_checkpoint_engine.py:23:load] [Torch] Loading checkpoint from /workspace/3-fold-stratified-cv-biobart-v2-base-peft-deepspeed-zero3-0/global_step559/zero_pp_rank_0_mp_rank_00_model_states.pt...
[2023-04-25 17:56:18,458] [INFO] [torch_checkpoint_engine.py:25:load] [Torch] Loaded checkpoint from /workspace/3-fold-stratified-cv-biobart-v2-base-peft-deepspeed-zero3-0/global_step559/zero_pp_rank_0_mp_rank_00_model_states.pt.
[2023-04-25 17:56:18,458] [INFO] [torch_checkpoint_engine.py:23:load] [Torch] Loading checkpoint from /workspace/3-fold-stratified-cv-biobart-v2-base-peft-deepspeed-zero3-0/global_step559/zero_pp_rank_0_mp_rank_00_model_states.pt...
[2023-04-25 17:56:18,465] [INFO] [torch_checkpoint_engine.py:25:load] [Torch] Loaded checkpoint from /workspace/3-fold-stratified-cv-biobart-v2-base-peft-deepspeed-zero3-0/global_step559/zero_pp_rank_0_mp_rank_00_model_states.pt.
[2023-04-25 17:56:18,472] [INFO] [torch_checkpoint_engine.py:23:load] [Torch] Loading checkpoint from /workspace/3-fold-stratified-cv-biobart-v2-base-peft-deepspeed-zero3-0/global_step559/zero_pp_rank_0_mp_rank_00_optim_states.pt...
[2023-04-25 17:56:18,474] [INFO] [torch_checkpoint_engine.py:25:load] [Torch] Loaded checkpoint from /workspace/3-fold-stratified-cv-biobart-v2-base-peft-deepspeed-zero3-0/global_step559/zero_pp_rank_0_mp_rank_00_optim_states.pt.
[2023-04-25 17:56:18,474] [INFO] [engine.py:3043:_get_all_zero_checkpoint_state_dicts] successfully read 1 ZeRO state_dicts for rank 0
[2023-04-25 17:56:18,493] [INFO] [engine.py:2983:_load_zero_checkpoint] loading 1 zero partition checkpoints for rank 0
You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  9%|█████████▍                                                                                                   | 13/150 [51:44<9:05:11, 238.77s/it]
Traceback (most recent call last):
  File "Task A - Summarization - Sweep with Deepspeed wo wandb.py", line 580, in <module>
    main()
  File "Task A - Summarization - Sweep with Deepspeed wo wandb.py", line 570, in main
    training_loop(model_name, \
  File "Task A - Summarization - Sweep with Deepspeed wo wandb.py", line 493, in training_loop
    train_summarization(ds_config, \
  File "Task A - Summarization - Sweep with Deepspeed wo wandb.py", line 277, in train_summarization
    output = model_engine(input_ids=input_ids, \
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/utils/nvtx.py", line 11, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py", line 1846, in forward
    loss = self.module(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/peft/peft_model.py", line 667, in forward
    return self.base_model(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/bart/modeling_bart.py", line 1392, in forward
    lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 84.00 MiB (GPU 0; 14.61 GiB total capacity; 14.06 GiB already allocated; 1.12 MiB free; 14.22 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
[2023-04-25 17:56:25,874] [INFO] [launch.py:318:sigkill_handler] Killing subprocess 143
[2023-04-25 17:56:25,875] [ERROR] [launch.py:324:sigkill_handler] ['/usr/bin/python', '-u', 'Task A - Summarization - Sweep with Deepspeed wo wandb.py', '--local_rank=0'] exits with return code = 1
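
As an aside, the allocator hint in the OOM message (max_split_size_mb) can be tried by setting PYTORCH_CUDA_ALLOC_CONF before the first CUDA allocation; this only mitigates fragmentation and is not expected to fix the underlying memory growth. A minimal sketch:

import os

# Must run before torch allocates any CUDA memory (e.g. at the very top of the script).
# Fragmentation workaround only, not a fix for the leak itself.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"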
  1. I am using the latest versions of the Hugging Face Transformers, Hugging Face PEFT, and DeepSpeed libraries.
  2. Executing the training_loop function runs the entire code. However, I am afraid this snippet can't be run as-is because it is missing code_config and the dataframes.

Expected behavior
GPU memory usage should not keep increasing during training.


Screenshots

As you can see, GPU memory usage with ZeRO3 keeps increasing compared to ZeRO2. I also tried model_engine.empty_partition_cache(), but I got an error that the empty_partition_cache attribute does not exist on model_engine.

ZeRO3 GPU usage -
[screenshot: Zero3]

ZeRO2 GPU usage -
[screenshot: zero2]


suri-kunal added the bug and training labels on Apr 25, 2023
@tjruwase (Contributor)

@suri-kunal, thanks for reporting this issue. Can you please add the following to the "zero_optimization" section of ds_config:

 "memory_efficient_linear": false

@tjruwase (Contributor)

As you can see, GPU memory usage with ZeRO3 keeps increasing compared to ZeRO2. I also tried model_engine.empty_partition_cache(), but I got an error that the empty_partition_cache attribute does not exist on model_engine.

@suri-kunal, can you please open a new issue for this and share the stack trace there? Thanks!

@suri-kunal (Author)

As you can see, GPU memory usage with ZeRO3 keeps increasing compared to ZeRO2. I also tried model_engine.empty_partition_cache(), but I got an error that the empty_partition_cache attribute does not exist on model_engine.

@suri-kunal, can you please open a new issue for this and share the stack trace there? Thanks!

I attached these graphs to show that GPU memory was indeed increasing. empty_partition_cache was mentioned in some other issue (I can't find which one), so I mentioned it here. I will still open a separate issue.

@tjruwase (Contributor)

I attached these graphs to show that GPU memory was indeed increasing. empty_partition_cache was mentioned in some other issue (I can't find which one), so I mentioned it here. I will still open a separate issue.

@suri-kunal, I am requesting this because the empty_partition_cache() failure is unexpected.
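
For context, on DeepSpeed versions where the engine exposes empty_partition_cache(), a typical usage sketch is to call it between epochs (run_one_epoch and num_epochs are hypothetical names standing in for the training loop above):

import torch

for epoch in range(num_epochs):
    run_one_epoch(model_engine)            # hypothetical helper wrapping the inner training loop
    model_engine.empty_partition_cache()   # release gathered ZeRO-3 parameter partitions
    torch.cuda.empty_cache()               # return freed blocks to the CUDA caching allocator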

@suri-kunal (Author)

@suri-kunal, thanks for reporting this issue. Can you please add the following to the "zero_optimization" section of ds_config:

 "memory_efficient_linear": false

I am getting the following error -

Traceback (most recent call last):
  File "Task A - Summarization - Sweep with Deepspeed wo wandb.py", line 580, in <module>
    main()
  File "Task A - Summarization - Sweep with Deepspeed wo wandb.py", line 570, in main
    training_loop(model_name, \
  File "Task A - Summarization - Sweep with Deepspeed wo wandb.py", line 493, in training_loop
    train_summarization(ds_config, \
  File "Task A - Summarization - Sweep with Deepspeed wo wandb.py", line 252, in train_summarization
    model_engine, _, train_dl, _ = deepspeed.initialize(model=model_zero_init,
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/__init__.py", line 125, in initialize
    engine = DeepSpeedEngine(args=args,
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py", line 272, in __init__
    self._configure_with_arguments(args, mpu)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/engine.py", line 1010, in _configure_with_arguments
    self._config = DeepSpeedConfig(self.config, mpu)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/config.py", line 813, in __init__
    self._initialize_params(copy.copy(self._param_dict))
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/config.py", line 832, in _initialize_params
    self.zero_config = get_zero_config(param_dict)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/zero/config.py", line 67, in get_zero_config
    return DeepSpeedZeroConfig(**zero_config_dict)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/config_utils.py", line 62, in __init__
    super().__init__(**data)
  File "pydantic/main.py", line 341, in pydantic.main.BaseModel.__init__
pydantic.error_wrappers.ValidationError: 1 validation error for DeepSpeedZeroConfig
memory_efficient_linear
  extra fields not permitted (type=value_error.extra)
[2023-04-25 18:50:33,573] [INFO] [launch.py:318:sigkill_handler] Killing subprocess 143
[2023-04-25 18:50:33,574] [ERROR] [launch.py:324:sigkill_handler] ['/usr/bin/python', '-u', 'Task A - Summarization - Sweep with Deepspeed wo wandb.py', '--local_rank=0'] exits with return code = 1

The new ds_config is -

{
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
          "warmup_min_lr": 0,
          "warmup_type": "linear"
      }
  },
  "optimizer": {
    "type": "Adam",
    "params": {
      "betas": [
        0.9,
        0.999
      ]
        }
  },
  "fp16": {
    "enabled": true,
    "auto_cast": false,
    "loss_scale": 0,
    "initial_scale_power": 16,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "zero_optimization": {
     "stage": 3,
     "offload_optimizer": {
       "device": "cpu",
       "pin_memory": true
     },
     "offload_param": {
       "device": "cpu",
       "pin_memory": true
     },
     "overlap_comm": true,
     "contiguous_gradients": true,
     "sub_group_size": 1.000000e+09,
     "reduce_bucket_size": "auto",
     "stage3_prefetch_bucket_size": "auto",
     "stage3_param_persistence_threshold": "auto",
     "stage3_max_live_parameters": 1.000000e+09,
     "stage3_max_reuse_distance": 1.000000e+09,
     "stage3_gather_16bit_weights_on_model_save": true,
     "memory_efficient_linear": false
  }
}

@tjruwase (Contributor)

Can you please use the latest DeepSpeed?
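
A quick way to confirm which version is active (a sketch; the log above shows 0.8.3, which does not recognize memory_efficient_linear):

import deepspeed

# A newer release than the 0.8.3 shown in the log is needed for this config field.
print(deepspeed.__version__)

You can also run ds_report from the shell to see the full setup.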

@suri-kunal (Author)

Can you please use the latest DeepSpeed?

Issue resolved!! Thanks for creating this library. Could you please look into #3377 as well?

@suri-kunal (Author) commented Apr 25, 2023

@tjruwase, @tohtana, I am sorry, but the issue still seems to persist. I have even applied model_engine.empty_partition_cache(). Please see the attached plot -

[plot: zero3 old vs new]

The 'glamorous-sweep-1' curve is ZeRO 3 with memory_efficient_linear set to true, and 'expert-sweep-1' is ZeRO 3 with memory_efficient_linear set to false.

Surprisingly, GPU utilization remains the same, as shown in the following chart -

[chart: gpu_util]

@tjruwase (Contributor)

@suri-kunal, thanks for sharing this update. We will investigate further.

@suri-kunal (Author)

@suri-kunal, thanks for sharing this update. We will investigate further.

Any update on this issue?

@tohtana (Contributor) commented May 7, 2023

@suri-kunal, I'm trying to reproduce the problem but haven't succeeded.
Can you provide the entire code, including dataset preparation and the PEFT config?

One thing you may need to fix is that you call deepspeed.initialize() for each epoch. You can reuse the DeepSpeed engine created by deepspeed.initialize() across multiple epochs, as sketched below.
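
A rough sketch of that restructuring, reusing the names from the snippet above (num_epochs stands in for code_config.TASKA_SUMMARY_EPOCHS):

# Sketch only: create the engine once, then reuse it for every epoch.
model, ds_config = create_model_optimizer(ds_config, peft_config)
model_engine, _, train_dl, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    training_data=train_ds,
    collate_fn=data_collator,
    config=ds_config,
)

for epoch in range(num_epochs):
    model_engine.train()
    for train_batch in train_dl:
        # Move every tensor in the collated batch to the engine's device.
        batch = {k: v.to(model_engine.device) for k, v in train_batch.items()}
        output = model_engine(**batch, use_cache=False, return_dict=True)
        model_engine.backward(output.loss)
        model_engine.step()
    model_engine.save_checkpoint(save_dir=checkpoint_folder, client_state={"epoch": epoch})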

@tohtana (Contributor) commented Jun 2, 2023

Closing this issue because we don't have an update.

Feel free to reopen if you still have this issue.
