Try to use triton.CompiledKernel launch hooks
Signed-off-by: Anatoly Myachev <[email protected]>
anmyachev committed Sep 25, 2024
1 parent 6955edf commit f5ce852
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions python/triton/testing.py
@@ -160,6 +160,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
"""
assert return_mode in ["min", "max", "mean", "median", "all"]
import torch
import triton

di = torch._dynamo.device_interface.get_interface_for_device(device_type)

@@ -196,9 +197,27 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
n_repeat = max(1, int(rep / estimate_ms))
start_event = [Event(enable_timing=True) for i in range(n_repeat)]
end_event = [Event(enable_timing=True) for i in range(n_repeat)]

# prepare hooks
counter = 0

def enter_hook(launch_metadata):
nonlocal counter
start_event[counter].record()

def exit_hook(launch_metadata):
nonlocal counter
end_event[counter].record()
counter += 1

# Warm-up
for _ in range(n_warmup):
fn()

# setup hooks
triton.compiler.CompiledKernel.launch_enter_hook = enter_hook
triton.compiler.CompiledKernel.launch_exit_hook = exit_hook

# Benchmark
for i in range(n_repeat):
# we don't want `fn` to accumulate gradient values
@@ -210,14 +229,17 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
# we clear the L2 cache before each run
cache.zero_()
# record time of `fn`
start_event[i].record()
fn()
if USE_WALL_TIME:
di.synchronize()
end_event[i].record()
# Record clocks
if not USE_WALL_TIME:
di.synchronize()

# remove hooks
triton.compiler.CompiledKernel.launch_enter_hook = None
triton.compiler.CompiledKernel.launch_exit_hook = None

times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
return _summarize_statistics(times, quantiles, return_mode)

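The change above moves the start/end event recording out of the benchmark loop and into Triton's kernel-launch hooks, so each event is recorded immediately around the actual kernel launch rather than around the whole `fn()` call. The sketch below shows the same hook mechanism in isolation; it is a minimal illustration, not part of the commit. It assumes a CUDA device, PyTorch, and a Triton version in which `triton.compiler.CompiledKernel.launch_enter_hook` / `launch_exit_hook` exist and are invoked with a single `launch_metadata` argument, as in the diff; the `add_kernel` used here is a made-up example kernel.

```python
# Minimal sketch of timing a single Triton kernel via launch hooks.
# Assumes: CUDA device available; Triton version where
# triton.compiler.CompiledKernel.launch_enter_hook / launch_exit_hook
# are class attributes called with one `launch_metadata` argument.
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Illustrative element-wise add kernel (not from the commit).
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def main():
    n = 1 << 20
    x = torch.rand(n, device="cuda")
    y = torch.rand(n, device="cuda")
    out = torch.empty_like(x)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # The enter hook fires right before the kernel is launched,
    # the exit hook right after the launch call returns.
    def enter_hook(launch_metadata):
        start.record()

    def exit_hook(launch_metadata):
        end.record()

    triton.compiler.CompiledKernel.launch_enter_hook = enter_hook
    triton.compiler.CompiledKernel.launch_exit_hook = exit_hook
    try:
        grid = (triton.cdiv(n, 1024),)
        add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    finally:
        # Always unhook, mirroring the cleanup in the diff.
        triton.compiler.CompiledKernel.launch_enter_hook = None
        triton.compiler.CompiledKernel.launch_exit_hook = None

    torch.cuda.synchronize()
    print(f"kernel time: {start.elapsed_time(end):.3f} ms")


if __name__ == "__main__":
    main()
```

Note that the hooks fire for every Triton kernel launched while they are installed, which is presumably why the diff installs them only after the warm-up loop and resets them to None immediately after the timed loop.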