
Commit

add autotune (#4822)
Xu-Kai authored Sep 28, 2023
1 parent 822051d commit c3bef20
Showing 2 changed files with 182 additions and 5 deletions.
176 changes: 176 additions & 0 deletions colossalai/kernel/triton/custom_autotune.py
@@ -0,0 +1,176 @@
# code from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/nn_modules/triton_utils/custom_autotune.py

import builtins
import math
import time
from typing import Dict

import triton


class CustomizedTritonAutoTuner(triton.KernelInterface):
    def __init__(
        self,
        fn,
        arg_names,
        configs,
        key,
        reset_to_zero,
        prune_configs_by: Dict = None,
        nearest_power_of_two: bool = False,
    ):
        if not configs:
            self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
        else:
            self.configs = configs
        self.key_idx = [arg_names.index(k) for k in key]
        self.nearest_power_of_two = nearest_power_of_two
        self.cache = {}
        # hook to reset all required tensors to zero before relaunching a kernel
        self.hook = lambda args: 0
        if reset_to_zero is not None:
            self.reset_idx = [arg_names.index(k) for k in reset_to_zero]

            def _hook(args):
                for i in self.reset_idx:
                    args[i].zero_()

            self.hook = _hook
        self.arg_names = arg_names
        # prune configs
        if prune_configs_by:
            perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"]
            if "early_config_prune" in prune_configs_by:
                early_config_prune = prune_configs_by["early_config_prune"]
        else:
            perf_model, top_k, early_config_prune = None, None, None
        self.perf_model, self.configs_top_k = perf_model, top_k
        self.early_config_prune = early_config_prune
        self.fn = fn

    def _bench(self, *args, config, **meta):
        # check for conflicts, i.e. meta-parameters both provided
        # as kwargs and by the autotuner
        conflicts = meta.keys() & config.kwargs.keys()
        if conflicts:
            raise ValueError(
                f"Conflicting meta-parameters: {', '.join(conflicts)}."
                " Make sure that you don't re-define auto-tuned symbols."
            )
        # augment meta-parameters with tunable ones
        current = dict(meta, **config.kwargs)

        def kernel_call():
            if config.pre_hook:
                config.pre_hook(self.nargs)
            self.hook(args)
            self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)

        try:
            # In testing, using only 40 reps seems to be close enough, and it appears to be what PyTorch uses.
            # PyTorch also sets fast_flush to True, but I didn't see any speedup, so I'll leave the default.
            return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40)
        except triton.compiler.OutOfResources:
            return (float("inf"), float("inf"), float("inf"))

    def run(self, *args, **kwargs):
        self.nargs = dict(zip(self.arg_names, args))
        if len(self.configs) > 1:
            key = tuple(args[i] for i in self.key_idx)

            # This reduces the amount of autotuning by rounding the keys to the nearest power of two
            # In my testing this gives decent results, and greatly reduces the amount of tuning required
            if self.nearest_power_of_two:
                key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])

            if key not in self.cache:
                # prune configs
                pruned_configs = self.prune_configs(kwargs)
                bench_start = time.time()
                timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = builtins.min(timings, key=timings.get)
                self.hook(args)
                self.configs_timings = timings
            config = self.cache[key]
        else:
            config = self.configs[0]
        self.best_config = config
        if config.pre_hook is not None:
            config.pre_hook(self.nargs)
        return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)

    def prune_configs(self, kwargs):
        pruned_configs = self.configs
        if self.early_config_prune:
            pruned_configs = self.early_config_prune(self.configs, self.nargs)
        if self.perf_model:
            top_k = self.configs_top_k
            if isinstance(top_k, float) and top_k <= 1.0:
                top_k = int(len(self.configs) * top_k)
            if len(pruned_configs) > top_k:
                est_timing = {
                    config: self.perf_model(
                        **self.nargs,
                        **kwargs,
                        **config.kwargs,
                        num_stages=config.num_stages,
                        num_warps=config.num_warps,
                    )
                    for config in pruned_configs
                }
                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
        return pruned_configs

    def warmup(self, *args, **kwargs):
        self.nargs = dict(zip(self.arg_names, args))
        for config in self.prune_configs(kwargs):
            self.fn.warmup(
                *args,
                num_warps=config.num_warps,
                num_stages=config.num_stages,
                **kwargs,
                **config.kwargs,
            )
        self.nargs = None


def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
    def decorator(fn):
        return CustomizedTritonAutoTuner(
            fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two
        )

    return decorator


def matmul248_kernel_config_pruner(configs, nargs):
    """
    The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
    """
    m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16)
    n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16)
    k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16)

    used = set()
    for config in configs:
        block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"])
        block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"])
        block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"])
        group_size_m = config.kwargs["GROUP_SIZE_M"]

        if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used:
            continue

        used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps))
        yield triton.Config(
            {
                "BLOCK_SIZE_M": block_size_m,
                "BLOCK_SIZE_N": block_size_n,
                "BLOCK_SIZE_K": block_size_k,
                "GROUP_SIZE_M": group_size_m,
            },
            num_stages=config.num_stages,
            num_warps=config.num_warps,
        )
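
For orientation, here is a minimal sketch (not part of this commit) of how the new decorator is meant to wrap a Triton kernel. The kernel, its configs, and the launch below are hypothetical illustrations; the real usage is in gptq_triton.py further down.

import torch
import triton
import triton.language as tl

from colossalai.kernel.triton.custom_autotune import autotune


@autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_SIZE": 256}, num_warps=8, num_stages=2),
    ],
    key=["n_elements"],          # cache the best config per problem size
    nearest_power_of_two=True,   # round n_elements to the nearest power of two before caching
)
@triton.jit
def _scale_kernel(x_ptr, out_ptr, scale, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * scale, mask=mask)


x = torch.randn(10000, device="cuda")
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
_scale_kernel[grid](x, out, 2.0, x.numel())  # first call benchmarks both configs, later calls hit the cache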
11 changes: 6 additions & 5 deletions colossalai/kernel/triton/gptq_triton.py
@@ -3,7 +3,8 @@
 import torch
 import triton
 import triton.language as tl
-from auto_gptq.nn_modules.triton_utils import custom_autotune
+
+from .custom_autotune import autotune, matmul248_kernel_config_pruner
 
 
 @triton.jit
@@ -94,7 +95,7 @@ def silu(x):
     return x * tl.sigmoid(x)
 
 
-@custom_autotune.autotune(
+@autotune(
     configs=[
         triton.Config(
             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
@@ -124,7 +125,7 @@ def silu(x):
     key=["M", "N", "K"],
     nearest_power_of_two=True,
     prune_configs_by={
-        "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
+        "early_config_prune": matmul248_kernel_config_pruner,
         "perf_model": None,
         "top_k": None,
     },
@@ -266,7 +267,7 @@ def cai_gptq_matmul_248_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
-@custom_autotune.autotune(
+@autotune(
     configs=[
         triton.Config(
             {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
@@ -296,7 +297,7 @@ def cai_gptq_matmul_248_kernel(
     key=["M", "N", "K"],
     nearest_power_of_two=True,
     prune_configs_by={
-        "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
+        "early_config_prune": matmul248_kernel_config_pruner,
         "perf_model": None,
         "top_k": None,
     },

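As a quick illustration of what matmul248_kernel_config_pruner does, the sketch below (not part of this commit) runs it on two candidate configs for a small decode-size matmul. The first config matches one shown in the diff above; the second is hypothetical and only there to show the de-duplication. With M=16 both configs shrink to BLOCK_SIZE_M=16, become identical, and only one is kept for benchmarking.

import triton

from colossalai.kernel.triton.custom_autotune import matmul248_kernel_config_pruner

configs = [
    triton.Config(
        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
    ),
]

# Both candidates collapse to BLOCK_SIZE_M=16, so only one config survives.
for cfg in matmul248_kernel_config_pruner(configs, {"M": 16, "N": 4096, "K": 4096}):
    print(cfg.kwargs, cfg.num_stages, cfg.num_warps)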