update float8 integration after UX changes #484

Merged · 1 commit · Jul 26, 2024
40 changes: 15 additions & 25 deletions torchtitan/float8_linear.py
@@ -12,7 +12,6 @@
 
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
-import contextlib
 import functools
 from typing import Optional
 
@@ -24,20 +23,6 @@
 from torchtitan.logging_utils import logger
 
 
-@contextlib.contextmanager
-def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool):
-    import float8_experimental.config as config
-
-    prev = config.enable_fsdp_fp8_all_gather
-    torch.distributed.barrier()
-    config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
-    try:
-        yield
-    finally:
-        torch.distributed.barrier()
-        config.enable_fsdp_fp8_all_gather = prev
-
-
 @functools.lru_cache(None)
 def is_sm90_or_later():
     # Float8 is only supported on H100+ GPUs
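The body of is_sm90_or_later is collapsed in the hunk above; a minimal sketch of this kind of capability gate, assuming the standard torch.cuda.get_device_capability API rather than the exact collapsed lines, looks like this:

```python
import functools

import torch


@functools.lru_cache(None)
def is_sm90_or_later() -> bool:
    # Float8 kernels need H100-class (SM 9.0) hardware or newer.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
```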
@@ -63,21 +48,26 @@ def maybe_build_fp8_linear(
         )
         return
     try:
-        from float8_experimental.float8_linear import TensorScalingType
-        from float8_experimental.float8_linear_utils import (
-            swap_linear_with_float8_linear,
+        from float8_experimental import (
+            CastConfig,
+            convert_to_float8_training,
+            Float8LinearConfig,
+            ScalingType,
         )
 
         # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
         enable_fsdp_float8_all_gather = (
             job_config.training.enable_fsdp_float8_all_gather and dp_enabled
         )
-        with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
-            swap_linear_with_float8_linear(
-                model,
-                scaling_type_w=TensorScalingType.DYNAMIC,
-                skip_fqn_list=["output"],
-            )
+        float8_config = Float8LinearConfig(
+            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+            cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
+        )
+        convert_to_float8_training(
+            model,
+            config=float8_config,
+            module_filter_fn=lambda mod, fqn: fqn != "output",
+        )
         logger.info(
             f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
         )
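Taken together, this hunk replaces the old global flag plus keyword arguments with a single Float8LinearConfig passed to convert_to_float8_training. A standalone sketch of the new call pattern, using only the APIs that appear in the diff (the toy model and its dimensions are illustrative, not from this repo):

```python
import torch
import torch.nn as nn

from float8_experimental import (
    CastConfig,
    convert_to_float8_training,
    Float8LinearConfig,
    ScalingType,
)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(64, 64)
        self.output = nn.Linear(64, 8)

    def forward(self, x):
        return self.output(torch.relu(self.w1(x)))


# Conversion can run anywhere; actual float8 execution needs an SM90+ GPU.
model = ToyModel()

config = Float8LinearConfig(
    # Only meaningful once the model is wrapped in FSDP2 with fp8 all-gather on.
    enable_fsdp_float8_all_gather=False,
    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
)

# Mutates the model in place, swapping nn.Linear for Float8Linear; the filter
# keeps the "output" projection in high precision, mirroring the diff above.
convert_to_float8_training(
    model,
    config=config,
    module_filter_fn=lambda mod, fqn: fqn != "output",
)
```

Returning False from module_filter_fn for the "output" fqn reproduces the behavior of the removed skip_fqn_list=["output"] argument.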
@@ -102,6 +92,6 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
             "Skipped precomputing fp8 scales because SM90 or later is not available",
         )
         return
-    from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
+    from float8_experimental import precompute_float8_dynamic_scale_for_fsdp
 
     precompute_float8_dynamic_scale_for_fsdp(model)
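With the import now coming from the package root, the helper itself is unchanged: it is meant to run once per training iteration after the optimizer step when fp8 all-gather is enabled under FSDP2. A hedged sketch of that call site (the loop, model, optimizer, and data_loader are illustrative, not copied from train.py):

```python
from float8_experimental import precompute_float8_dynamic_scale_for_fsdp

# Assumed to exist: an FSDP2-wrapped float8 model, its optimizer, and a data loader.
for batch in data_loader:
    loss = model(batch).sum()  # placeholder loss computation
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Refresh the fp8 weight scales in one batched reduction so the next
    # forward's float8 all-gather can reuse them.
    precompute_float8_dynamic_scale_for_fsdp(model)
```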