Skip to content

Commit

Permalink
[zero] fix missing hook removal
Browse files Browse the repository at this point in the history
  • Loading branch information
botbw committed Jun 17, 2024
1 parent a10802e commit f8cdf5f
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions colossalai/zero/low_level/low_level_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,20 +94,26 @@ def __init__(
# reduction hook is only used if overlapping communication
# or stage 2 is used
# if it is stage 1 without overlapping, no hook will be attached
self.grad_handles = []
if self._overlap_communication or self._partition_grad:
# we iterate over the working params
# on each param, we register a hook to its AccumulateGrad object
param_group = self.working_param_group
for param in param_group:
if param.requires_grad:

def _grad_handler(grad, param):
def _grad_handler(grad):
# if run with no_sync context, would not sync grad when backward
# see LowLevelOptStrategyBase.__del__ for hook removal
if self.require_grad_sync:
self._add_to_bucket(param)
return grad

param.register_hook(partial(_grad_handler, param=param))
self.grad_handles.append(param.register_post_accumulate_grad_hook(partial(_grad_handler)))

def __del__(self):
for handle in self.grad_handles:
handle.remove()

def _create_master_param_current_rank(self, param_list):
# split each param evenly by world size
Expand Down

0 comments on commit f8cdf5f

Please sign in to comment.