diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml
index 46865e1b556e..f675e052acf6 100644
--- a/.github/workflows/build-test.yml
+++ b/.github/workflows/build-test.yml
@@ -83,10 +83,21 @@ jobs:
             python/test/unit/language/test_annotations.py \
             python/test/unit/language/test_block_pointer.py \
             python/test/unit/language/test_conversions.py \
+            python/test/unit/language/test_compile_errors.py \
+            python/test/unit/language/test_decorator.py \
+            python/test/unit/language/test_pipeliner.py \
+            python/test/unit/language/test_random.py \
+            python/test/unit/language/test_standard.py \
+            python/test/unit/runtime/test_bindings.py \
+            python/test/unit/runtime/test_driver.py \
+            python/test/unit/runtime/test_jit.py \
+            python/test/unit/runtime/test_launch.py \
+            python/test/unit/runtime/test_subproc.py \
+            python/test/unit/runtime/test_autotuner.py \
+            python/test/unit/runtime/test_cache.py \
             python/test/unit/cpu/test_libdevice.py \
             python/test/unit/cpu/test_libmvec.py \
-            python/test/unit/cpu/test_opt.py \
-            python/test/unit/runtime/test_autotuner.py
+            python/test/unit/cpu/test_opt.py

       - name: Run lit tests
         run: |
diff --git a/python/test/unit/language/test_pipeliner.py b/python/test/unit/language/test_pipeliner.py
index 7f148ae8daea..24c3e0408bbd 100644
--- a/python/test/unit/language/test_pipeliner.py
+++ b/python/test/unit/language/test_pipeliner.py
@@ -5,6 +5,7 @@
 import triton
 import triton.language as tl
 import triton.tools.experimental_descriptor
+from test_core import is_cpu


 def is_cuda():
@@ -127,7 +128,10 @@ def test_pipeline_matmul(device):
     handler = matmul_kernel[grid](a, b, output, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                                   output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
                                   NUM_STAGES=NUM_STAGES)
-    ref_out = torch.matmul(a, b)
+    if is_cpu():
+        ref_out = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(torch.float16)
+    else:
+        ref_out = torch.matmul(a, b)
     atol = 1e-2 if is_hip_mi200() else None
     # Bigger tolerance for AMD MI200 devices.
     # MI200 devices use reduced precision fp16 and bf16 and flush input and
diff --git a/python/test/unit/language/test_standard.py b/python/test/unit/language/test_standard.py
index b3392d4750c4..2938773867cc 100644
--- a/python/test/unit/language/test_standard.py
+++ b/python/test/unit/language/test_standard.py
@@ -3,7 +3,7 @@
 import torch

 import triton.language as tl
-from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random
+from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random, is_cpu

 # ---------------
 # test maximum/minimum ops
@@ -26,7 +26,8 @@ def test_maximum_minium(dtype, op, device):


 @pytest.mark.interpreter
-@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
+@pytest.mark.parametrize(
+    "M, N", [[1, 512], [8, 64], [256, 16], [512, 8]] if not is_cpu() else [[1, 128], [8, 64], [64, 16], [128, 8]])
 @pytest.mark.parametrize("descending", [False, True])
 @pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
 def test_sort(M, N, descending, dtype_str, device):
@@ -54,7 +55,8 @@ def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr


 @pytest.mark.interpreter
-@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
+@pytest.mark.parametrize(
+    "M, N", [[1, 512], [8, 64], [256, 16], [512, 8]] if not is_cpu() else [[1, 128], [8, 64], [64, 16], [128, 8]])
 @pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
 def test_flip(M, N, dtype_str, device):

diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py
index 8a3403bc1df1..40a6105057db 100644
--- a/python/test/unit/runtime/test_cache.py
+++ b/python/test/unit/runtime/test_cache.py
@@ -8,7 +8,7 @@

 import triton
 import triton.language as tl
-from triton.runtime.jit import JITFunction
+from triton.runtime.jit import JITFunction, get_device_key


 @triton.jit
@@ -193,12 +193,12 @@ def kernel(X, i: tl.int32):

     x = torch.empty(1, dtype=torch.int32, device=device)
-    device = getattr(torch, device).current_device()
+    device_key = get_device_key()
     kernel[(1, )](x, 1)
     kernel[(1, )](x, 8)
     kernel[(1, )](x, 16)
     kernel[(1, )](x, 17)
-    assert len(kernel.cache[device]) == 3
+    assert len(kernel.cache[device_key]) == 3


 GLOBAL_DEFAULT_ARG = 1
@@ -221,7 +221,7 @@ def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG):
     kernel[(1, )](x)
     assert x == torch.ones_like(x)
-    device = getattr(torch, device).current_device()
+    device = get_device_key()
     assert len(kernel.cache[device]) == 1
@@ -414,7 +414,7 @@ def kernel_add(a, b, o, N: tl.constexpr):
         torch.randn(32, dtype=torch.float32, device=device),
         32,
     ]
-    device = getattr(torch, device).current_device()
+    device = get_device_key()
     assert len(kernel_add.cache[device]) == 0
     kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
     assert len(kernel_add.cache[device]) == 1
@@ -432,17 +432,20 @@ def kernel_add(a, b, o, N: tl.constexpr):
         tl.device_assert(idx < 32, "idx < 32")
         tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))

-    device = getattr(torch, device).current_device()
-    assert len(kernel_add.cache[device]) == 0
+    if device == "cpu":
+        pytest.skip('Device Assert is not yet supported on CPU')
+
+    device_key = get_device_key()
+    assert len(kernel_add.cache[device_key]) == 0
     kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.cache[device_key]) == 1
     kernel_add.debug = False
     kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 2
+    assert len(kernel_add.cache[device_key]) == 2
     kernel_add.debug = True
     kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 3
-    bins = list(kernel_add.cache[device].values())
+    assert len(kernel_add.cache[device_key]) == 3
+    bins = list(kernel_add.cache[device_key].values())
     assert bins[2].asm['ttir'] != bins[1].asm['ttir']
@@ -458,7 +461,7 @@ def test_jit_noinline(device) -> None:
     def kernel_add_device(a, b, o, N: tl.constexpr):
         add_fn(a, b, o, N)

-    device = getattr(torch, device).current_device()
+    device = get_device_key()
     assert len(kernel_add_device.cache[device]) == 0
     kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
     assert len(kernel_add_device.cache[device]) == 1
@@ -502,7 +505,7 @@ def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr):
         tl.device_assert(idx < 32, "idx < 32")
         tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx))

-    device = getattr(torch, device).current_device()
+    device = get_device_key()

     # get the serialized specialization data
     specialization_data = None
diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py
index 95f469ec1f5a..bccde9150ccb 100644
--- a/python/triton/runtime/jit.py
+++ b/python/triton/runtime/jit.py
@@ -439,6 +439,12 @@ def create_function_from_signature(sig, kparams):
     type_canonicalisation_dict[v] = v


+def get_device_key():
+    target = driver.active.get_current_target()
+    device = driver.active.get_current_device()
+    return f"{target.backend}:{device}"
+
+
 class JITFunction(KernelInterface[T]):
     # Hook for inspecting compiled functions and modules
     cache_hook = None
@@ -614,7 +620,7 @@ def run(self, *args, grid, warmup, **kwargs):
         bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)

         # compute cache key
-        device_key = f"{target.backend}:{device}"
+        device_key = get_device_key()
         key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs))
         kernel = self.cache[device_key].get(key, None)
@@ -767,7 +773,7 @@ def preload(self, specialization_data):
         from ..compiler import AttrsDescriptor, compile, ASTSource
         import json
         import triton.language as tl
-        device = driver.active.get_current_device()
+        device_key = get_device_key()
         deserialized_obj = json.loads(specialization_data)
         if deserialized_obj['name'] != self.fn.__name__:
             raise RuntimeError(
@@ -784,7 +790,7 @@
         }
         key = deserialized_obj['key']
         kernel = compile(src, None, options)
-        self.cache[device][key] = kernel
+        self.cache[device_key][key] = kernel
         return kernel

     # we do not parse `src` in the constructor because
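
Note (not part of the patch above): a minimal sketch of how the new get_device_key() helper is intended to be used when inspecting the per-device JIT cache, mirroring the updated tests. The toy copy_kernel and the "cpu" device are illustrative assumptions, and the snippet presumes this patch is applied on a Triton build whose active driver is the CPU backend, as in this repo.

    # Sketch only: exercises the "<backend>:<device>" cache key introduced above.
    import torch
    import triton
    import triton.language as tl
    from triton.runtime.jit import get_device_key  # helper added by this patch


    @triton.jit
    def copy_kernel(X, Y, N: tl.constexpr):  # hypothetical toy kernel, not from the patch
        idx = tl.arange(0, N)
        tl.store(Y + idx, tl.load(X + idx))


    x = torch.ones(16, dtype=torch.float32, device="cpu")  # assumes the CPU backend is active
    y = torch.empty_like(x)
    copy_kernel[(1, )](x, y, 16)

    # The JIT cache is now keyed by "<backend>:<device>" rather than the raw torch
    # device index, so entries for different backends can no longer collide.
    device_key = get_device_key()
    assert len(copy_kernel.cache[device_key]) == 1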