fix grad scaler import
samsja committed May 13, 2024
1 parent d20e810 commit 1e1d4a0
Showing 1 changed file with 18 additions and 2 deletions.
hivemind/optim/grad_scaler.py: 18 additions & 2 deletions
@@ -1,16 +1,32 @@
 import contextlib
 import threading
 from copy import deepcopy
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 
 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
-from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
 from torch.optim import Optimizer as TorchOptimizer
 
 import hivemind
 from hivemind.utils.logging import get_logger
 
+if torch.cuda.is_available():
+    from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
+else:
+    # on CPU this import does not work, so the code is copied here since it is simple
+    # code taken from: https://github.com/pytorch/pytorch/blob/afda6685ae87cce7ac2fe4bac3926572da2960f7/torch/amp/grad_scaler.py#L44
+
+    from enum import Enum
+
+    class OptState(Enum):
+        READY = 0
+        UNSCALED = 1
+        STEPPED = 2
+
+    def _refresh_per_optimizer_state() -> Dict[str, Any]:
+        return {"stage": OptState.READY, "found_inf_per_device": {}}
+
 
 logger = get_logger(__name__)
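
A minimal sanity check of the fix (not part of the commit; it assumes hivemind is installed and that the GradScaler class defined in hivemind/optim/grad_scaler.py is importable from that module): before this change, importing the module on a CPU-only PyTorch build could fail because OptState and _refresh_per_optimizer_state were imported from torch.cuda.amp.grad_scaler unconditionally.

# hypothetical check, not from the commit: run on a machine without CUDA
import torch
from hivemind.optim.grad_scaler import GradScaler  # should now import without CUDA

print("CUDA available:", torch.cuda.is_available())  # False on a CPU-only machine
print("import succeeded:", GradScaler is not None)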


