Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[zero] fix missing hook removal #5824

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions colossalai/zero/low_level/low_level_strategy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import weakref
from abc import ABC, abstractmethod
from copy import deepcopy
from functools import partial
Expand Down Expand Up @@ -94,20 +95,27 @@ def __init__(
# reduction hook is only used if overlapping communication
# or stage 2 is used
# if it is stage 1 without overlapping, no hook will be attached
self.grad_handles = []
if self._overlap_communication or self._partition_grad:
# we iterate over the working params
# on each param, we register a hook to its AccumulateGrad object
param_group = self.working_param_group
for param in param_group:
if param.requires_grad:
self_weak_proxy = weakref.proxy(self)
param_weak_proxy = weakref.proxy(param)

def _grad_handler(grad, param):
def _grad_handler(grad):
# if run with no_sync context, would not sync grad when backward
if self.require_grad_sync:
self._add_to_bucket(param)
if self_weak_proxy.require_grad_sync:
self_weak_proxy._add_to_bucket(param_weak_proxy)
return grad

param.register_hook(partial(_grad_handler, param=param))
self.grad_handles.append(param.register_post_accumulate_grad_hook(partial(_grad_handler)))

def __del__(self):
for handle in self.grad_handles:
handle.remove()

def _create_master_param_current_rank(self, param_list):
# split each param evenly by world size
Expand Down
Loading