[Pytest] Support CPU device
Enable test suites for the CPU device.
- language
  - python/test/unit/language/test_random.py
  - python/test/unit/language/test_standard.py
- runtime
  - python/test/unit/runtime/test_bindings.py
  - python/test/unit/runtime/test_cache.py
  - python/test/unit/runtime/test_driver.py
  - python/test/unit/runtime/test_jit.py
  - python/test/unit/runtime/test_launch.py

python/test/unit/runtime/test_cache.py expects the creation and use of
cache files on disk, which does not work with multiple workers.

Signed-off-by: Dmitrii Makarenko <[email protected]>
Devjiu committed Aug 14, 2024
1 parent 7af3b7e commit 3ccf910
Showing 6 changed files with 50 additions and 28 deletions.
13 changes: 10 additions & 3 deletions .github/workflows/build-test.yml
@@ -86,11 +86,18 @@ jobs:
python/test/unit/language/test_compile_errors.py \
python/test/unit/language/test_decorator.py \
python/test/unit/language/test_pipeliner.py \
python/test/unit/test_random.py \
python/test/unit/language/test_random.py \
python/test/unit/language/test_standard.py \
python/test/unit/runtime/test_bindings.py \
python/test/unit/runtime/test_driver.py \
python/test/unit/runtime/test_jit.py \
python/test/unit/runtime/test_launch.py \
python/test/unit/runtime/test_subproc.py \
python/test/unit/runtime/test_autotuner.py \
python/test/unit/runtime/test_cache.py \
python/test/unit/cpu/test_libdevice.py \
python/test/unit/cpu/test_libmvec.py \
python/test/unit/cpu/test_opt.py \
python/test/unit/runtime/test_autotuner.py
python/test/unit/cpu/test_opt.py
- name: Run lit tests
run: |
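As the commit message notes, test_cache.py creates and reads cache files on disk, so it cannot run under multiple pytest-xdist workers. A minimal sketch of running it in a single worker — the "-n 0" flag and the separate invocation are assumptions about the CI setup, not part of this commit:

# Hypothetical single-worker run for the cache tests; assumes pytest-xdist
# is installed ("-n 0" keeps all tests in the main process) while the
# remaining suites run in a separate, parallelized invocation.
import sys

import pytest

ret = pytest.main(["-n", "0", "python/test/unit/runtime/test_cache.py"])
sys.exit(ret)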
6 changes: 5 additions & 1 deletion python/test/unit/language/test_pipeliner.py
@@ -5,6 +5,7 @@
import triton
import triton.language as tl
import triton.tools.experimental_descriptor
from test_core import is_cpu


def is_cuda():
@@ -127,7 +128,10 @@ def test_pipeline_matmul(device):
handler = matmul_kernel[grid](a, b, output, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1),
output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
NUM_STAGES=NUM_STAGES)
ref_out = torch.matmul(a, b)
if is_cpu():
ref_out = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(torch.float16)
else:
ref_out = torch.matmul(a, b)
atol = 1e-2 if is_hip_mi200() else None
# Bigger tolerance for AMD MI200 devices.
# MI200 devices use reduced precision fp16 and bf16 and flush input and
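The CPU branch above computes the float16 reference in float32 and casts the result back: half-precision matmul is not reliably available (or is much less accurate) on the PyTorch CPU backend, so upcasting is the usual workaround. A minimal sketch of the pattern, assuming float16 inputs:

import torch

a = torch.randn(64, 64, dtype=torch.float16)
b = torch.randn(64, 64, dtype=torch.float16)

# Upcast to float32 for the reference matmul, then cast back to float16,
# mirroring the is_cpu() branch in test_pipeline_matmul above.
ref = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(torch.float16)
assert ref.dtype == torch.float16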
8 changes: 5 additions & 3 deletions python/test/unit/language/test_standard.py
@@ -3,7 +3,7 @@
import torch
import triton.language as tl

from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random
from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random, is_cpu

# ---------------
# test maximum/minimum ops
@@ -26,7 +26,8 @@ def test_maximum_minium(dtype, op, device):


@pytest.mark.interpreter
@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
@pytest.mark.parametrize(
"M, N", [[1, 512], [8, 64], [256, 16], [512, 8]] if not is_cpu() else [[1, 128], [8, 64], [64, 16], [128, 8]])
@pytest.mark.parametrize("descending", [False, True])
@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
def test_sort(M, N, descending, dtype_str, device):
@@ -54,7 +55,8 @@ def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr


@pytest.mark.interpreter
@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
@pytest.mark.parametrize(
"M, N", [[1, 512], [8, 64], [256, 16], [512, 8]] if not is_cpu() else [[1, 128], [8, 64], [64, 16], [128, 8]])
@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
def test_flip(M, N, dtype_str, device):

33 changes: 18 additions & 15 deletions python/test/unit/runtime/test_cache.py
@@ -8,7 +8,7 @@

import triton
import triton.language as tl
from triton.runtime.jit import JITFunction
from triton.runtime.jit import JITFunction, get_device_key


@triton.jit
@@ -193,12 +193,12 @@ def kernel(X, i: tl.int32):

x = torch.empty(1, dtype=torch.int32, device=device)

device = torch.cuda.current_device()
device_key = get_device_key()
kernel[(1, )](x, 1)
kernel[(1, )](x, 8)
kernel[(1, )](x, 16)
kernel[(1, )](x, 17)
assert len(kernel.cache[device]) == 3
assert len(kernel.cache[device_key]) == 3


GLOBAL_DEFAULT_ARG = 1
@@ -221,7 +221,7 @@ def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG):
kernel[(1, )](x)
assert x == torch.ones_like(x)

device = torch.cuda.current_device()
device = get_device_key()
assert len(kernel.cache[device]) == 1


@@ -414,7 +414,7 @@ def kernel_add(a, b, o, N: tl.constexpr):
torch.randn(32, dtype=torch.float32, device=device),
32,
]
device = torch.cuda.current_device()
device = get_device_key()
assert len(kernel_add.cache[device]) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 1
@@ -424,25 +424,28 @@ def kernel_add(a, b, o, N: tl.constexpr):
assert len(kernel_add.cache[device]) == 1


def test_jit_debug() -> None:
def test_jit_debug(device) -> None:

@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.device_assert(idx < 32, "idx < 32")
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))

device = torch.cuda.current_device()
assert len(kernel_add.cache[device]) == 0
if device == "cpu":
pytest.skip('Device Assert is not yet supported on CPU')

device_key = get_device_key()
assert len(kernel_add.cache[device_key]) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 1
assert len(kernel_add.cache[device_key]) == 1
kernel_add.debug = False
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 2
assert len(kernel_add.cache[device_key]) == 2
kernel_add.debug = True
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 3
bins = list(kernel_add.cache[device].values())
assert len(kernel_add.cache[device_key]) == 3
bins = list(kernel_add.cache[device_key].values())
assert bins[2].asm['ttir'] != bins[1].asm['ttir']


@@ -452,13 +455,13 @@ def add_fn(a, b, o, N: tl.constexpr):
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))


def test_jit_noinline() -> None:
def test_jit_noinline(device) -> None:

@triton.jit
def kernel_add_device(a, b, o, N: tl.constexpr):
add_fn(a, b, o, N)

device = torch.cuda.current_device()
device = get_device_key()
assert len(kernel_add_device.cache[device]) == 0
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add_device.cache[device]) == 1
@@ -502,7 +505,7 @@ def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr):
tl.device_assert(idx < 32, "idx < 32")
tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx))

device = torch.cuda.current_device()
device = get_device_key()

# get the serialized specialization data
specialization_data = None
6 changes: 3 additions & 3 deletions python/test/unit/runtime/test_launch.py
@@ -43,7 +43,7 @@ def kernel(x):
assert used_hook


def test_memory_leak() -> None:
def test_memory_leak(device) -> None:

@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
@@ -57,8 +57,8 @@ def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):

tracemalloc.start()
try:
inp = torch.randn(10, device='cuda')
out = torch.randn(10, device='cuda')
inp = torch.randn(10, device=device)
out = torch.randn(10, device=device)
kernel[(10, )](inp, out, 10, XBLOCK=16)
gc.collect()
begin, _ = tracemalloc.get_traced_memory()
12 changes: 9 additions & 3 deletions python/triton/runtime/jit.py
@@ -437,6 +437,12 @@ def create_function_from_signature(sig, kparams):
type_canonicalisation_dict[v] = v


def get_device_key():
target = driver.active.get_current_target()
device = driver.active.get_current_device()
return f"{target.backend}:{device}"


class JITFunction(KernelInterface[T]):
# Hook for inspecting compiled functions and modules
cache_hook = None
@@ -605,7 +611,7 @@ def run(self, *args, grid, warmup, **kwargs):
bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)

# compute cache key
device_key = f"{target.backend}:{device}"
device_key = get_device_key()
key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs))
kernel = self.cache[device_key].get(key, None)

@@ -757,7 +763,7 @@ def preload(self, specialization_data):
from ..compiler import AttrsDescriptor, compile, ASTSource
import json
import triton.language as tl
device = driver.active.get_current_device()
device_key = get_device_key()
deserialized_obj = json.loads(specialization_data)
if deserialized_obj['name'] != self.fn.__name__:
raise RuntimeError(
Expand All @@ -774,7 +780,7 @@ def preload(self, specialization_data):
}
key = deserialized_obj['key']
kernel = compile(src, None, options)
self.cache[device][key] = kernel
self.cache[device_key][key] = kernel
return kernel

# we do not parse `src` in the constructor because
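For reference, the get_device_key() helper added above folds the active backend name and device index into a single cache key, so kernels compiled for different backends (e.g. a CUDA GPU and the CPU) land in separate buckets of JITFunction.cache. An illustrative sketch of the keying scheme — the backend names and indices here are assumptions, and this stand-in does not touch the real driver:

from collections import defaultdict

def device_key(backend: str, device: int) -> str:
    # Mirrors get_device_key(): "<backend>:<device index>".
    return f"{backend}:{device}"

# JITFunction.cache maps device key -> {specialization key -> compiled kernel}.
cache = defaultdict(dict)
cache[device_key("cuda", 0)]["(fp32, fp32, fp32, 32)"] = "<binary for cuda:0>"
cache[device_key("cpu", 0)]["(fp32, fp32, fp32, 32)"] = "<binary for cpu:0>"

assert len(cache) == 2  # the same specialization is cached per backend/device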
