Skip to content

Commit

Permalink
move float8_experimental to torchao/float8
Browse files Browse the repository at this point in the history
Differential Revision: D60409879

Pull Request resolved: pytorch#551
  • Loading branch information
vkuzo authored Jul 30, 2024
1 parent 77c99d1 commit fc92268
Show file tree
Hide file tree
Showing 35 changed files with 7,637 additions and 0 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ In some cases we rewrote popular GenAI models to be significantly faster in nati

### Training

#### Float8

[torchao.float8](torchao/float8) implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433.

#### Sparsity

We've added support for semi-structured 2:4 sparsity with 6% end to end speedups on ViT-L

The code change is a 1 liner with the full example available [here](torchao/sparsity/training/)
Expand Down
307 changes: 307 additions & 0 deletions benchmarks/float8/bench_linear_float8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import copy
from dataclasses import dataclass
from itertools import product
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import pandas as pd

import torch
import torch.utils.benchmark as benchmark
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
linear_requires_sync,
sync_float8_amax_and_scale_history,
)
from torchao.float8.float8_tensor import ScaledMMConfig
from tqdm import tqdm

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 1979e12
h100_peak_tops_float8_tc = 3958e12

dtype_to_peak_tops = {
torch.float32: h100_peak_flops_float32,
torch.float16: h100_peak_flops_fp16_tc,
torch.bfloat16: h100_peak_flops_fp16_tc,
torch.float8_e4m3fn: h100_peak_tops_float8_tc,
torch.float8_e5m2: h100_peak_tops_float8_tc,
}

# prevent splitting columns when printing a data frame
pd.set_option("display.expand_frame_repr", False)
# print the entire data frame
pd_print_full_ctx = pd.option_context(
"display.max_rows", None, "display.max_columns", None
)


def benchmark_torch_function_in_microseconds(
func: Callable,
*args,
**kwargs,
) -> float:
t0 = benchmark.Timer(
stmt="func(*args, **kwargs)",
globals={"args": args, "kwargs": kwargs, "func": func},
)
return t0.blocked_autorange().median * 1e6


@dataclass
class Experiment:
name: str
shape: Tuple[int, int, int]
ref_time_sec: float
float8_time_sec: float
dtype: torch.dtype
compiled: bool
use_fast_accum: bool
scaling_repr: str

# 3 Times since we are calculating forward backward
@property
def ref_tops_sec(self):
M, K, N = self.shape
return float(3 * (2 * M * K * N)) / self.ref_time_sec

@property
def ref_pct_top_peak(self):
return self.ref_tops_sec / dtype_to_peak_tops[self.dtype]

@property
def float8_tops_sec(self):
M, K, N = self.shape
return float(3 * (2 * M * K * N)) / self.float8_time_sec

@property
def float8_pct_top_peak(self):
return self.float8_tops_sec / dtype_to_peak_tops[torch.float8_e4m3fn]


def main(
sweep_path: Optional[Path] = None,
compile: bool = True,
n_limit: Optional[int] = None,
fast_accum_filter: Optional[bool] = None,
shape_name_filter: Optional[str] = None,
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
):
device = "cuda"
print(f"Compile is set to | {compile}")

scaling_type_input = ScalingType(scaling_type_input)
scaling_type_weight = ScalingType(scaling_type_weight)
scaling_type_grad_output = ScalingType(scaling_type_grad_output)
config = Float8LinearConfig(
cast_config_input=CastConfig(scaling_type=scaling_type_input),
cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
)

# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
name_to_shapes_70b = {
"attn.wqkv": (8192, 1280),
"attn.w0": (1024, 8192),
"ffn.w13": (8192, 7168),
"ffn.w2": (3584, 8192),
}
input_bias = False
if fast_accum_filter is not None:
use_fast_accum = [fast_accum_filter]
else:
use_fast_accum = [True, False]
if shape_name_filter is not None:
k = shape_name_filter
name_to_shapes_70b = {k: name_to_shapes_70b[k]}
experiment_list: List[Experiment] = []
dtype = torch.bfloat16
for idx, (fast_accum, (name, (K, N))) in enumerate(
tqdm(list(product(use_fast_accum, name_to_shapes_70b.items())))
):
if n_limit is not None and idx >= n_limit:
break
linear_ref = torch.nn.Linear(K, N, bias=input_bias).to(
device=device, dtype=dtype
)

linear_float8 = Float8Linear.from_float(
copy.deepcopy(linear_ref),
config=config,
)
scaling_repr = linear_float8.scaling_repr()

if fast_accum:
linear_float8.forward_config = ScaledMMConfig(False, True, False)
else:
linear_float8.forward_config = ScaledMMConfig(False, False, False)

bsz, seq_len = 4, 4096
M = bsz * seq_len
input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

def float8_forw_backward():
if linear_requires_sync(config):
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()

def n_times(n, fn, *args, **kwargs):
def wrapper(*args, **kwargs):
for _ in range(n):
fn(*args, **kwargs)

return wrapper

REPEAT_N = 100

ref_forw_backward = n_times(REPEAT_N, ref_forw_backward)
float8_forw_backward = n_times(REPEAT_N, float8_forw_backward)

if compile:
ref_forw_backward = torch.compile(ref_forw_backward)
float8_forw_backward = torch.compile(float8_forw_backward)

for _ in range(5):
ref_forw_backward()
float8_forw_backward()

ref_time = (
benchmark_torch_function_in_microseconds(ref_forw_backward)
* 1e-6
/ REPEAT_N
)
float8_time = (
benchmark_torch_function_in_microseconds(float8_forw_backward)
* 1e-6
/ REPEAT_N
)
experiment = Experiment(
name,
(M, K, N),
ref_time,
float8_time,
dtype,
compile,
use_fast_accum=fast_accum,
scaling_repr=scaling_repr,
)
print(experiment)
print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
experiment_list.append(experiment)
torch._dynamo.reset()

headers = [
"name",
"M",
"K",
"N",
"scaling_repr",
"ref_dtype",
"compiled",
"use_fast_accum",
"ref_time_sec",
"pt_fp8_time_sec",
"ref_tops_sec",
"ref_pct_top_peak",
"pt_fp8_tops_sec",
"pt_fp8_pct_top_peak",
]
data = []
for experiment in experiment_list:
data.append(
[
experiment.name,
experiment.shape[0],
experiment.shape[1],
experiment.shape[2],
experiment.scaling_repr,
experiment.dtype,
experiment.compiled,
experiment.use_fast_accum,
experiment.ref_time_sec,
experiment.float8_time_sec,
experiment.ref_tops_sec,
experiment.ref_pct_top_peak,
experiment.float8_tops_sec,
experiment.float8_pct_top_peak,
]
)

data_pd = pd.DataFrame(data, columns=headers)
data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"]
data_pd["shape"] = (
"("
+ data_pd["M"].astype(str)
+ ", "
+ data_pd["K"].astype(str)
+ ", "
+ data_pd["N"].astype(str)
+ ")"
)

data_pd_simple = data_pd[
[
"name",
"shape",
"scaling_repr",
"compiled",
"use_fast_accum",
"ref_time_sec",
"pt_fp8_time_sec",
"pt_fp8_speedup",
]
]
with pd_print_full_ctx:
print(data_pd_simple)

if sweep_path is not None:
sweep_path = sweep_path.with_suffix(".csv")
data_pd.to_csv(sweep_path)


def invoke_main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output_path", type=str, required=False)
parser.add_argument("--disable_compile", action="store_true")
parser.add_argument("-n", "--n_limit", type=int, required=False)
parser.add_argument("--fast_accum_filter", type=bool, required=False)
parser.add_argument("--shape_name_filter", type=str, required=False)
parser.add_argument("--scaling_type_input", type=str, required=False)
parser.add_argument("--scaling_type_weight", type=str, required=False)
parser.add_argument("--scaling_type_grad_output", type=str, required=False)
args = parser.parse_args()
output_path = Path(args.output_path) if args.output_path is not None else None
kwargs = {}
if args.scaling_type_input is not None:
kwargs["scaling_type_input"] = args.scaling_type_input
if args.scaling_type_weight is not None:
kwargs["scaling_type_weight"] = args.scaling_type_weight
if args.scaling_type_grad_output is not None:
kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output
main(
output_path,
not args.disable_compile,
args.n_limit,
args.fast_accum_filter,
args.shape_name_filter,
**kwargs,
)


if __name__ == "__main__":
invoke_main() # pragma: no cover
Loading

0 comments on commit fc92268

Please sign in to comment.