From 9de501e38a545aa0107ef89c72a1bb3c846992a3 Mon Sep 17 00:00:00 2001 From: Galaxy1458 <55453380+Galaxy1458@users.noreply.github.com> Date: Thu, 12 Dec 2024 20:10:31 +0800 Subject: [PATCH] [Muti_backend]part_2_device (#344) * new feature, muti_backend * update auto_tune_module * update auto_tune_module * update auto_tune_module * update __init__ * rebase * fix bug * modifiy auto_tune_config * fix bug * fix bug * update * update * update scatter&gather * fix auto_tune * add gen_torch_device_fn * fix codestyle * fix codestyle * Modify code based on comments * Modify gen_impl with loops instead of recursion * Update code structure * Polish code * update * Polish code * Modify code based on comments * modify based on comment * Modify code based on comments * update * final fix * modify_device_unify * fix bug * update * remove '-> list' * modify * modify * modify * fix bug --- README.md | 6 +- README_cn.md | 6 +- benchmark/conftest.py | 15 +- benchmark/performance_utils.py | 12 +- examples/model_bert_test.py | 8 +- examples/model_llama_test.py | 6 +- examples/model_llava_test.py | 8 +- src/flag_gems/fused/gelu_and_mul.py | 11 +- src/flag_gems/fused/rotary_embedding.py | 3 +- src/flag_gems/fused/skip_layernorm.py | 3 +- src/flag_gems/fused/skip_rms_norm.py | 3 +- src/flag_gems/ops/addmm.py | 3 +- src/flag_gems/ops/all.py | 7 +- src/flag_gems/ops/amax.py | 5 +- src/flag_gems/ops/any.py | 7 +- src/flag_gems/ops/arange.py | 5 +- src/flag_gems/ops/argmax.py | 5 +- src/flag_gems/ops/bmm.py | 3 +- src/flag_gems/ops/cross_entropy_loss.py | 7 +- src/flag_gems/ops/cumsum.py | 17 +- src/flag_gems/ops/diag.py | 5 +- src/flag_gems/ops/div.py | 12 +- src/flag_gems/ops/dropout.py | 13 +- src/flag_gems/ops/embedding.py | 9 +- src/flag_gems/ops/eq.py | 5 +- src/flag_gems/ops/exponential_.py | 11 +- src/flag_gems/ops/fill.py | 5 +- src/flag_gems/ops/full.py | 3 +- src/flag_gems/ops/full_like.py | 3 +- src/flag_gems/ops/gelu.py | 12 +- src/flag_gems/ops/groupnorm.py | 14 +- src/flag_gems/ops/isclose.py | 18 +- src/flag_gems/ops/isfinite.py | 18 +- src/flag_gems/ops/isin.py | 5 +- src/flag_gems/ops/isinf.py | 9 +- src/flag_gems/ops/isnan.py | 9 +- src/flag_gems/ops/layernorm.py | 5 +- src/flag_gems/ops/log_softmax.py | 5 +- src/flag_gems/ops/masked_select.py | 3 +- src/flag_gems/ops/max.py | 5 +- src/flag_gems/ops/maximum.py | 5 +- src/flag_gems/ops/mean.py | 5 +- src/flag_gems/ops/min.py | 5 +- src/flag_gems/ops/minimum.py | 5 +- src/flag_gems/ops/mm.py | 3 +- src/flag_gems/ops/multinomial.py | 4 +- src/flag_gems/ops/mv.py | 3 +- src/flag_gems/ops/nonzero.py | 3 +- src/flag_gems/ops/normal.py | 7 +- src/flag_gems/ops/ones.py | 7 +- src/flag_gems/ops/ones_like.py | 3 +- src/flag_gems/ops/pad.py | 3 +- src/flag_gems/ops/pow.py | 9 +- src/flag_gems/ops/prod.py | 5 +- src/flag_gems/ops/rand.py | 15 +- src/flag_gems/ops/rand_like.py | 8 +- src/flag_gems/ops/randn.py | 16 +- src/flag_gems/ops/randn_like.py | 8 +- src/flag_gems/ops/randperm.py | 19 +- src/flag_gems/ops/repeat.py | 3 +- src/flag_gems/ops/rms_norm.py | 3 +- src/flag_gems/ops/sigmoid.py | 9 +- src/flag_gems/ops/silu.py | 9 +- src/flag_gems/ops/softmax.py | 9 +- src/flag_gems/ops/sum.py | 5 +- src/flag_gems/ops/tanh.py | 18 +- src/flag_gems/ops/tile.py | 3 +- src/flag_gems/ops/topk.py | 5 +- src/flag_gems/ops/triu.py | 3 +- src/flag_gems/ops/uniform.py | 12 +- src/flag_gems/ops/unique.py | 9 +- src/flag_gems/ops/upsample_bicubic2d_aa.py | 5 +- src/flag_gems/ops/upsample_nearest2d.py | 5 +- src/flag_gems/ops/var_mean.py | 5 +- 
src/flag_gems/ops/vector_norm.py | 14 +- src/flag_gems/ops/vstack.py | 3 +- src/flag_gems/ops/weightnorm.py | 13 +- src/flag_gems/ops/zeros.py | 7 +- src/flag_gems/ops/zeros_like.py | 3 +- src/flag_gems/runtime/__init__.py | 11 +- src/flag_gems/runtime/backend/__init__.py | 50 ++++- .../runtime/backend/_nvidia/ops/add.py | 3 +- src/flag_gems/runtime/backend/device.py | 6 +- src/flag_gems/runtime/moduel_tool.py | 17 ++ src/flag_gems/runtime/register.py | 26 +-- src/flag_gems/utils/libentry.py | 4 +- src/flag_gems/utils/pointwise_dynamic.py | 5 +- src/flag_gems/utils/random_utils.py | 8 +- tests/conftest.py | 8 +- tests/ks_tests.py | 14 +- tests/test_binary_pointwise_ops.py | 212 +++++++++--------- tests/test_blas_ops.py | 22 +- tests/test_distribution_ops.py | 16 +- tests/test_general_reduction_ops.py | 40 ++-- tests/test_libentry.py | 10 +- tests/test_norm_ops.py | 58 +++-- tests/test_pointwise_dynamic.py | 90 ++++---- tests/test_pointwise_type_promotion.py | 24 +- tests/test_reduction_ops.py | 80 +++---- tests/test_special_ops.py | 122 +++++----- tests/test_tensor_constructor_ops.py | 44 ++-- tests/test_tensor_wrapper.py | 13 +- tests/test_unary_pointwise_ops.py | 74 +++--- 103 files changed, 843 insertions(+), 699 deletions(-) create mode 100644 src/flag_gems/runtime/moduel_tool.py diff --git a/README.md b/README.md index 9a0b86f4c..03f4c6c15 100644 --- a/README.md +++ b/README.md @@ -138,8 +138,8 @@ pip install . import flag_gems M, N, K = 1024, 1024, 1024 - A = torch.randn((M, K), dtype=torch.float16, device="cuda") - B = torch.randn((K, N), dtype=torch.float16, device="cuda") + A = torch.randn((M, K), dtype=torch.float16, device=flag_gems.device) + B = torch.randn((K, N), dtype=torch.float16, device=flag_gems.device) with flag_gems.use_gems(): C = torch.mm(A, B) ``` @@ -147,7 +147,7 @@ pip install . ### Execute 1. Test Operator Accuracy - - Run reference on cuda + - Run reference on specific backend like cuda ```shell cd tests pytest test_xx_ops.py diff --git a/README_cn.md b/README_cn.md index 1851895cd..79f417cae 100644 --- a/README_cn.md +++ b/README_cn.md @@ -136,8 +136,8 @@ pip install . import flag_gems M, N, K = 1024, 1024, 1024 - A = torch.randn((M, K), dtype=torch.float16, device="cuda") - B = torch.randn((K, N), dtype=torch.float16, device="cuda") + A = torch.randn((M, K), dtype=torch.float16, device=flag_gems.device) + B = torch.randn((K, N), dtype=torch.float16, device=flag_gems.device) with flag_gems.use_gems(): C = torch.mm(A, B) ``` @@ -145,7 +145,7 @@ pip install . ### 执行 1. 算子正确性测试 - - 在CUDA上运行参考实现 + - 在例如CUDA的异构设备上运行参考实现 ```shell cd tests pytest test_xx_ops.py diff --git a/benchmark/conftest.py b/benchmark/conftest.py index 0039b2c16..fbb2b4b1d 100644 --- a/benchmark/conftest.py +++ b/benchmark/conftest.py @@ -5,6 +5,9 @@ import pytest import torch +import flag_gems +from flag_gems.runtime import torch_device_fn + from .attri_util import ( ALL_AVAILABLE_METRICS, BOOL_DTYPES, @@ -17,6 +20,8 @@ get_recommended_shapes, ) +device = flag_gems.device + class BenchConfig: def __init__(self): @@ -38,12 +43,12 @@ def pytest_addoption(parser): parser.addoption( "--mode", action="store", - default="cuda", + default=device, required=False, - choices=["cuda", "cpu"], + choices=[device, "cpu"], help=( "Specify how to measure latency, " - "'cpu' for CPU-side measurement or 'cuda' for GPU-side measurement." + f"'cpu' for CPU-side measurement or {device} for GPU-side measurement." 
), ) @@ -186,13 +191,13 @@ def setup_once(request): @pytest.fixture(scope="function", autouse=True) def clear_function_cache(): yield - torch.cuda.empty_cache() + torch_device_fn.empty_cache() @pytest.fixture(scope="module", autouse=True) def clear_module_cache(): yield - torch.cuda.empty_cache() + torch_device_fn.empty_cache() @pytest.fixture() diff --git a/benchmark/performance_utils.py b/benchmark/performance_utils.py index db1284d2f..b7492d4d5 100644 --- a/benchmark/performance_utils.py +++ b/benchmark/performance_utils.py @@ -9,6 +9,7 @@ import yaml import flag_gems +from flag_gems.runtime import torch_backend_device, torch_device_fn from .attri_util import ( BOOL_DTYPES, @@ -24,11 +25,12 @@ ) from .conftest import Config -torch.backends.cuda.matmul.allow_tf32 = False +torch_backend_device.matmul.allow_tf32 = False +device = flag_gems.device class Benchmark: - device: str = "cuda" + device: str = device DEFAULT_METRICS = DEFAULT_METRICS DEFAULT_DTYPES = FLOAT_DTYPES DEFAULT_SHAPES = DEFAULT_SHAPES @@ -191,11 +193,11 @@ def get_latency(self, op, *args, **kwargs): if Config.cpu_mode: for i in range(Config.warm_up): fn() - torch.cuda.synchronize() + torch_device_fn.synchronize() start = time.time() for i in range(Config.repetition): fn() - torch.cuda.synchronize() + torch_device_fn.synchronize() end = time.time() latency = (end - start) / Config.repetition * 1000 else: @@ -309,7 +311,7 @@ def run(self): level=Config.bench_level.value, op_name=self.op_name, dtype=str(dtype), - mode="cpu" if Config.cpu_mode else "cuda", + mode="cpu" if Config.cpu_mode else device, result=metrics, ) print(result) diff --git a/examples/model_bert_test.py b/examples/model_bert_test.py index 7b49b6254..6db8393ff 100644 --- a/examples/model_bert_test.py +++ b/examples/model_bert_test.py @@ -6,6 +6,8 @@ import flag_gems +device = flag_gems.device + @pytest.mark.parametrize( "prompt", @@ -16,16 +18,16 @@ def test_accuracy_bert(prompt, dtype): config = BertConfig() model = BertModel(config) tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") - inputs = tokenizer(prompt, return_tensors="pt").to("cuda") + inputs = tokenizer(prompt, return_tensors="pt").to(device) ref_model = copy.deepcopy(model) - ref_model.to(torch.float64).to("cuda").eval() + ref_model.to(torch.float64).to(device).eval() ref_inputs = copy.deepcopy(inputs).to(torch.float64) with torch.no_grad(): ref_outputs = ref_model(**ref_inputs).last_hidden_state.to(dtype) res_model = copy.deepcopy(model) - res_model.to(dtype).to("cuda").eval() + res_model.to(dtype).to(device).eval() res_inputs = copy.deepcopy(inputs).to(dtype) with flag_gems.use_gems(): with torch.no_grad(): diff --git a/examples/model_llama_test.py b/examples/model_llama_test.py index 5b2642278..9400455d3 100644 --- a/examples/model_llama_test.py +++ b/examples/model_llama_test.py @@ -4,6 +4,8 @@ import flag_gems +device = flag_gems.device + @pytest.mark.parametrize( "prompt", @@ -13,8 +15,8 @@ def test_accuracy_llama(prompt): tokenizer = AutoTokenizer.from_pretrained("sharpbai/Llama-2-7b-hf") model = AutoModelForCausalLM.from_pretrained("sharpbai/Llama-2-7b-hf") - model.to("cuda").eval() - inputs = tokenizer(prompt, return_tensors="pt").to(device="cuda") + model.to(device).eval() + inputs = tokenizer(prompt, return_tensors="pt").to(device=device) with torch.no_grad(): ref_output = model.generate(**inputs, max_length=100, num_beams=5) diff --git a/examples/model_llava_test.py b/examples/model_llava_test.py index 472e1c6eb..15a220dca 100644 --- 
a/examples/model_llava_test.py +++ b/examples/model_llava_test.py @@ -6,6 +6,8 @@ import flag_gems +device = flag_gems.device + @pytest.mark.parametrize( "prompt", ["USER: \nWhat's the content of the image? ASSISTANT:"] @@ -22,9 +24,11 @@ def test_accuracy_llava(prompt, url): model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") torch.manual_seed(1234) - model.to("cuda").eval() + model.to(device).eval() image = Image.open(requests.get(url, stream=True).raw) - inputs = processor(text=prompt, images=image, return_tensors="pt").to(device="cuda") + inputs = processor(text=prompt, images=image, return_tensors="pt").to( + device=flag_gems.device + ) with torch.no_grad(): ref_output = model(**inputs).logits diff --git a/src/flag_gems/fused/gelu_and_mul.py b/src/flag_gems/fused/gelu_and_mul.py index c4dea3e0e..da123bb63 100644 --- a/src/flag_gems/fused/gelu_and_mul.py +++ b/src/flag_gems/fused/gelu_and_mul.py @@ -4,15 +4,12 @@ import triton import triton.language as tl +from ..runtime.moduel_tool import tl_extra_module from ..utils import pointwise_dynamic -try: - from triton.language.extra.cuda.libdevice import erf, pow, tanh -except ImportError: - try: - from triton.language.math import erf, pow, tanh - except ImportError: - from triton.language.libdevice import erf, pow, tanh +erf = tl_extra_module.erf +pow = tl_extra_module.pow +tanh = tl_extra_module.tanh @pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) diff --git a/src/flag_gems/fused/rotary_embedding.py b/src/flag_gems/fused/rotary_embedding.py index d2e8f2fc5..93086578d 100644 --- a/src/flag_gems/fused/rotary_embedding.py +++ b/src/flag_gems/fused/rotary_embedding.py @@ -4,6 +4,7 @@ import triton import triton.language as tl +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -163,7 +164,7 @@ def apply_rotary_pos_emb( padded_head_dim = max(triton.next_power_of_2(head_dim), 16) grid = (n_tokens,) - with torch.cuda.device(q_embed.device): + with torch_device_fn.device(q_embed.device): apply_rotary_pos_emb_kernel[grid]( q_embed, k_embed, diff --git a/src/flag_gems/fused/skip_layernorm.py b/src/flag_gems/fused/skip_layernorm.py index a6b3f9d8f..5f9396b21 100644 --- a/src/flag_gems/fused/skip_layernorm.py +++ b/src/flag_gems/fused/skip_layernorm.py @@ -5,6 +5,7 @@ import triton import triton.language as tl +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -71,7 +72,7 @@ def forward(ctx, x, residual, normalized_shape, weight, bias, eps=1e-5): bias = bias.contiguous() y = torch.empty_like(x) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): skip_layer_norm_kernel[M,]( y, x, residual, weight, bias, N, 1, N, 1, N, 1, N, eps, BLOCK_SIZE ) diff --git a/src/flag_gems/fused/skip_rms_norm.py b/src/flag_gems/fused/skip_rms_norm.py index a70b50c33..162a1e009 100644 --- a/src/flag_gems/fused/skip_rms_norm.py +++ b/src/flag_gems/fused/skip_rms_norm.py @@ -5,6 +5,7 @@ import triton import triton.language as tl +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -60,7 +61,7 @@ def forward(ctx, x, residual, normalized_shape, weight, eps=1e-5): weight = weight.contiguous() y = torch.empty_like(x) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): skip_rms_norm_kernel[M,]( y, x, residual, weight, N, 1, 
N, 1, N, 1, N, eps, BLOCK_SIZE ) diff --git a/src/flag_gems/ops/addmm.py b/src/flag_gems/ops/addmm.py index 867cbd32f..fc7785df1 100644 --- a/src/flag_gems/ops/addmm.py +++ b/src/flag_gems/ops/addmm.py @@ -5,6 +5,7 @@ import triton.language as tl from .. import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -85,7 +86,7 @@ def addmm(bias, mat1, mat2, *, beta=1, alpha=1): triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"]), ) - with torch.cuda.device(mat1.device): + with torch_device_fn.device(mat1.device): addmm_kernel[grid]( mat1, mat2, diff --git a/src/flag_gems/ops/all.py b/src/flag_gems/ops/all.py index 8e3fe7804..344ba617f 100644 --- a/src/flag_gems/ops/all.py +++ b/src/flag_gems/ops/all.py @@ -6,6 +6,7 @@ import triton.language as tl from .. import runtime +from ..runtime import torch_device_fn from ..utils import dim_compress, libentry from ..utils import triton_lang_extension as tle @@ -89,7 +90,7 @@ def all(inp): mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device) out = torch.empty([], dtype=torch.bool, device=inp.device) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): all_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size) all_kernel_2[(1, 1)](mid, out, mid_size, block_mid) @@ -114,7 +115,7 @@ def all_dim(inp, dim=None, keepdim=False): out = torch.empty(shape, dtype=torch.bool, device=inp.device) grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): all_kernel_dim[grid](inp, out, M, N) if not keepdim: out = out.squeeze(dim=dim) @@ -140,7 +141,7 @@ def all_dims(inp, dim=None, keepdim=False): out = torch.empty(shape, dtype=torch.bool, device=inp.device) grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): all_kernel_dim[grid](inp, out, M, N) if not keepdim: out = out.squeeze(dim=dim) diff --git a/src/flag_gems/ops/amax.py b/src/flag_gems/ops/amax.py index 6d0e0375b..975c57e65 100644 --- a/src/flag_gems/ops/amax.py +++ b/src/flag_gems/ops/amax.py @@ -6,6 +6,7 @@ import triton.language as tl from .. import runtime +from ..runtime import torch_device_fn from ..utils import dim_compress, libentry from ..utils import triton_lang_extension as tle @@ -86,7 +87,7 @@ def amax(inp, dim=None, keepdim=False): for i in range(0, inp.dim()): shape[i] = 1 out = torch.empty(shape, dtype=dtype, device=inp.device) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): amax_kernel_1[(mid_size, 1)]( inp, mid, @@ -115,7 +116,7 @@ def amax(inp, dim=None, keepdim=False): out = torch.empty(shape, dtype=dtype, device=inp.device) grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): amax_kernel[grid](inp, out, M, N) if not keepdim: out = out.squeeze(dim=dim) diff --git a/src/flag_gems/ops/any.py b/src/flag_gems/ops/any.py index 9fcf9e61b..150ba5353 100644 --- a/src/flag_gems/ops/any.py +++ b/src/flag_gems/ops/any.py @@ -6,6 +6,7 @@ import triton.language as tl from .. 
import runtime +from ..runtime import torch_device_fn from ..utils import dim_compress, libentry from ..utils import triton_lang_extension as tle @@ -89,7 +90,7 @@ def any(inp): mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device) out = torch.empty([], dtype=torch.bool, device=inp.device) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): any_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size) any_kernel_2[(1, 1)](mid, out, mid_size, block_mid) @@ -114,7 +115,7 @@ def any_dim(inp, dim=None, keepdim=False): out = torch.empty(shape, dtype=torch.bool, device=inp.device) grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): any_kernel_dim[grid](inp, out, M, N) if not keepdim: out = out.squeeze(dim=dim) @@ -140,7 +141,7 @@ def any_dims(inp, dim=None, keepdim=False): out = torch.empty(shape, dtype=torch.bool, device=inp.device) grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): any_kernel_dim[grid](inp, out, M, N) if not keepdim: out = out.squeeze(dim=dim) diff --git a/src/flag_gems/ops/arange.py b/src/flag_gems/ops/arange.py index 6c98613e7..aa8b35d0f 100644 --- a/src/flag_gems/ops/arange.py +++ b/src/flag_gems/ops/arange.py @@ -5,6 +5,7 @@ import triton import triton.language as tl +from .. import runtime from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -42,8 +43,8 @@ def arange_start( pin_memory = False if device is None: - device = torch.device( - "cuda" + device = ( + runtime.device.name ) # Note(Zhengzekang): Torch default value is CPU, but triton is target to GPU. result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory) diff --git a/src/flag_gems/ops/argmax.py b/src/flag_gems/ops/argmax.py index ef6d53660..310cbd227 100644 --- a/src/flag_gems/ops/argmax.py +++ b/src/flag_gems/ops/argmax.py @@ -6,6 +6,7 @@ import triton.language as tl from .. import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -120,7 +121,7 @@ def argmax(inp, dim=None, keepdim=False, *, dtype=None): else: out = torch.empty([], dtype=torch.int64, device=inp.device) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): argmax_kernel_1[(mid_size, 1, 1)]( inp, mid_value, @@ -150,7 +151,7 @@ def argmax(inp, dim=None, keepdim=False, *, dtype=None): triton.cdiv(M, meta["BLOCK_M"]), K, ) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): argmax_kernel[grid]( inp, out_index, diff --git a/src/flag_gems/ops/bmm.py b/src/flag_gems/ops/bmm.py index fb45c2a86..76c8377f6 100644 --- a/src/flag_gems/ops/bmm.py +++ b/src/flag_gems/ops/bmm.py @@ -5,6 +5,7 @@ import triton.language as tl from .. import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -145,6 +146,6 @@ def bmm(A, B): triton.cdiv(meta["N"], meta["TILE_N"]), batch, ) - with torch.cuda.device(A.device): + with torch_device_fn.device(A.device): bmm_kernel[grid_fn](A, B, out, M, N, K) return out diff --git a/src/flag_gems/ops/cross_entropy_loss.py b/src/flag_gems/ops/cross_entropy_loss.py index 3b462c14b..6c367813e 100644 --- a/src/flag_gems/ops/cross_entropy_loss.py +++ b/src/flag_gems/ops/cross_entropy_loss.py @@ -5,6 +5,7 @@ import triton.language as tl from .. 
import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -536,7 +537,7 @@ def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing): if tgt.ndim == dim: # target probabilities - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): celoss_probability_kernel[grid]( inp, tgt, @@ -549,7 +550,7 @@ def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing): elif label_smoothing == 0: # target indices w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): celoss_indices_kernel[grid]( inp, tgt, @@ -562,7 +563,7 @@ def forward(ctx, inp, target, weight, reduction, ignore_index, label_smoothing): ) else: w_tgt = torch.empty(shape, dtype=torch.float32, device=inp.device) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): celoss_indices_smooth_kernel[grid]( inp, tgt, diff --git a/src/flag_gems/ops/cumsum.py b/src/flag_gems/ops/cumsum.py index 51dc98a22..44de4bdf3 100644 --- a/src/flag_gems/ops/cumsum.py +++ b/src/flag_gems/ops/cumsum.py @@ -7,8 +7,11 @@ from flag_gems.utils import libentry +from ..runtime import device, torch_device_fn from ..utils import triton_lang_extension as tle +device = device.name + @libentry() @triton.jit(do_not_specialize=["n_elements", "part_num"]) @@ -160,12 +163,12 @@ def scan_then_fan_col(inp, out, n_ele, dtype): partial_sum = torch.empty(part_num, dtype=dtype, device=inp.device) grid = (part_num,) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): scan_part_sum_kernel[grid](inp, out, partial_sum, n_ele, part_num, BLOCK_SIZE) if part_num >= 2: scan_then_fan_col(partial_sum, partial_sum, part_num, dtype) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): add_base_sum_kernel[grid](out, partial_sum, n_ele, part_num, BLOCK_SIZE) @@ -178,14 +181,14 @@ def scan_then_fan(inp, out, A, B, C, dtype): partial_sum = torch.empty(A, part_num, C, dtype=dtype, device=inp.device) grid = (A, part_num, C) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): scan_part_sum_abc_kernel[grid]( inp, out, partial_sum, B, C, part_num, BLOCK_SIZE ) if part_num >= 2: scan_then_fan(partial_sum, partial_sum, A, part_num, C, dtype) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): add_base_sum_abc_kernel[grid](out, partial_sum, B, C, part_num, BLOCK_SIZE) @@ -375,9 +378,9 @@ def normed_cumsum(inp, dim=-1): inp = inp.transpose(dim, -1).contiguous() dim = -1 out = torch.empty_like(inp) - with torch.cuda.device(inp.device.index): + with torch_device_fn.device(inp.device.index): # Pass one, scan a (batch, n_tiles * TILE) sized block within each cta - num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count + num_sms = torch_device_fn.get_device_properties(device).multi_processor_count TILE = 2048 # Each row is split into n_chunks of chunks where each chunk is compised of # n_tiles of tiles. Different chunks are assigned to different ctas. 
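The hunks above and below replace every direct `torch.cuda.*` call in the op wrappers with `torch_device_fn` from `flag_gems.runtime`. A minimal sketch of the resulting launch pattern, assuming `flag_gems` is installed and `torch_device_fn` mirrors the `torch.cuda` surface (`device`, `synchronize`, `empty_cache`, `get_device_properties`) for the active backend; `launch_on_active_backend` is an illustrative name, not a function from the patch:

```python
# Minimal sketch (not part of the patch) of the backend-neutral launch pattern.
import torch

import flag_gems
from flag_gems.runtime import torch_device_fn


def launch_on_active_backend(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    # Replaces the old `with torch.cuda.device(x.device):` guard so the same
    # wrapper runs on any registered backend, not only NVIDIA GPUs.
    with torch_device_fn.device(x.device):
        # some_triton_kernel[grid](x, out, ...) would be launched here
        pass
    return out


if __name__ == "__main__":
    a = torch.randn(1024, device=flag_gems.device)
    launch_on_active_backend(a)
    torch_device_fn.synchronize()  # was torch.cuda.synchronize()
    torch_device_fn.empty_cache()  # was torch.cuda.empty_cache()
    props = torch_device_fn.get_device_properties(flag_gems.device)
    print(props.multi_processor_count)
```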
@@ -416,7 +419,7 @@ def normed_cumsum(inp, dim=-1): if inp.dtype != torch.float64: acc_dtype = torch.float32 - sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device="cuda") + sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=device.name) cumsums = torch.empty_like(sums) block_cumsum_kernel[grid]( inp, diff --git a/src/flag_gems/ops/diag.py b/src/flag_gems/ops/diag.py index aa312674c..2abf54321 100644 --- a/src/flag_gems/ops/diag.py +++ b/src/flag_gems/ops/diag.py @@ -2,6 +2,7 @@ import triton import triton.language as tl +from ..runtime import torch_device_fn from ..utils import triton_lang_extension as tle @@ -63,7 +64,7 @@ def diag_1d_to_2d(x, diagonal=0): grid = lambda meta: (triton.cdiv(N, BLOCK_SIZE),) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): diag_1d_to_2d_kernel[grid]( x, output, N, M, stride, diagonal, BLOCK_SIZE=BLOCK_SIZE ) @@ -85,7 +86,7 @@ def diag_2d_to_1d(x, diagonal=0): grid = lambda meta: (triton.cdiv(diag_len, BLOCK_SIZE),) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): diag_2d_to_1d_kernel[grid]( x, output, N, M, stride0, stride1, diagonal, BLOCK_SIZE=BLOCK_SIZE ) diff --git a/src/flag_gems/ops/div.py b/src/flag_gems/ops/div.py index 2cf0eb2ca..56eb8cd01 100644 --- a/src/flag_gems/ops/div.py +++ b/src/flag_gems/ops/div.py @@ -4,15 +4,13 @@ import triton import triton.language as tl +from ..runtime import tl_extra_module from ..utils import pointwise_dynamic -try: - from triton.language.extra.cuda.libdevice import div_rn, div_rz, fmod, trunc -except ImportError: - try: - from triton.language.math import div_rn, div_rz, fmod, trunc - except ImportError: - from triton.language.libdevice import div_rn, div_rz, fmod, trunc +div_rn = tl_extra_module.div_rn +div_rz = tl_extra_module.div_rz +fmod = tl_extra_module.fmod +trunc = tl_extra_module.trunc @pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")]) diff --git a/src/flag_gems/ops/dropout.py b/src/flag_gems/ops/dropout.py index 2a23c4a0d..8cf012222 100644 --- a/src/flag_gems/ops/dropout.py +++ b/src/flag_gems/ops/dropout.py @@ -4,7 +4,12 @@ import triton import triton.language as tl -from flag_gems.utils.random_utils import philox_cuda_seed_offset, uint_to_uniform_float +from flag_gems.utils.random_utils import ( + philox_backend_seed_offset, + uint_to_uniform_float, +) + +from ..runtime import torch_device_fn def heur_block(args): @@ -152,8 +157,8 @@ def forward(ctx, x, p, train): # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. 
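The random ops in this patch (dropout, exponential_, rand, randn, multinomial, randperm) all switch from `philox_cuda_seed_offset` to `philox_backend_seed_offset`. A short sketch of the calling convention, assuming the renamed helper keeps the old contract of reserving `increment` Philox counter states on the active backend's default generator and returning the `(seed, offset)` pair passed to the Triton kernel; `draw_philox_state` is a hypothetical wrapper used only for illustration:

```python
import torch
import triton

from flag_gems.runtime import torch_device_fn
from flag_gems.utils.random_utils import philox_backend_seed_offset

UNROLL = 4  # each Philox round yields 4 random uints, matching the kernels here


def draw_philox_state(x: torch.Tensor):
    N = x.numel()
    increment = triton.cdiv(N, UNROLL)
    with torch_device_fn.device(x.device):
        philox_seed, philox_offset = philox_backend_seed_offset(increment)
        # e.g. dropout_forward_kernel[grid](x, out, N, p, philox_seed, philox_offset)
    return philox_seed, philox_offset
```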
increment = triton.cdiv(N, UNROLL) - with torch.cuda.device(device): - philox_seed, philox_offset = philox_cuda_seed_offset(increment) + with torch_device_fn.device(device): + philox_seed, philox_offset = philox_backend_seed_offset(increment) dropout_forward_kernel[grid_fn](x, out, N, p, philox_seed, philox_offset) ctx.p = p ctx.philox_seed = philox_seed @@ -168,7 +173,7 @@ def backward(ctx, grad_outputs, kwargs): grad_inputs = torch.empty_like(grad_outputs) N = grad_outputs.numel() grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),) - with torch.cuda.device(device): + with torch_device_fn.device(device): dropout_backward_kernel[grid_fn]( grad_outputs, grad_inputs, N, ctx.p, ctx.philox_seed, ctx.philox_offset ) diff --git a/src/flag_gems/ops/embedding.py b/src/flag_gems/ops/embedding.py index ae1f36d26..37cd898e8 100644 --- a/src/flag_gems/ops/embedding.py +++ b/src/flag_gems/ops/embedding.py @@ -5,6 +5,7 @@ import triton import triton.language as tl +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -126,7 +127,7 @@ def forward( (*indices.shape, N), device=indices.device, dtype=weight.dtype ) - with torch.cuda.device(weight.device): + with torch_device_fn.device(weight.device): embedding_kernel[M,](output, indices, weight, N, BLOCK_SIZE) ctx.M = M @@ -162,7 +163,7 @@ def backward(ctx, grad_outputs): INDICE_BLOCK_SIZE = 256 indice_grid = lambda meta: (triton.cdiv(ctx.M, INDICE_BLOCK_SIZE),) - with torch.cuda.device(grad_outputs.device): + with torch_device_fn.device(grad_outputs.device): indice_freq_kernel[indice_grid]( indice_freq, ctx.indices, ctx.M, INDICE_BLOCK_SIZE ) @@ -173,7 +174,7 @@ def backward(ctx, grad_outputs): HAS_PADDING_IDX = ctx.padding_idx is not None - with torch.cuda.device(grad_outputs.device): + with torch_device_fn.device(grad_outputs.device): embedding_backward_kernel[ctx.M,]( grad_inputs, grad_outputs, @@ -185,7 +186,7 @@ def backward(ctx, grad_outputs): ) if ctx.scale_grad_by_freq: - with torch.cuda.device(grad_outputs.device): + with torch_device_fn.device(grad_outputs.device): embedding_grad_scale_kernel[ctx.M,]( grad_inputs, indice_freq, ctx.num_weights, ctx.N, BLOCK_SIZE ) diff --git a/src/flag_gems/ops/eq.py b/src/flag_gems/ops/eq.py index 1cc5d6f79..a290d11a4 100644 --- a/src/flag_gems/ops/eq.py +++ b/src/flag_gems/ops/eq.py @@ -3,8 +3,11 @@ import triton import triton.language as tl +from ..runtime import device from ..utils import pointwise_dynamic +device = device.name + @pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")]) @triton.jit @@ -14,7 +17,7 @@ def eq_func(x, y): def eq(A, B): if A.device != B.device: - if A.device.type == "cuda": + if A.device.type == device: B = B.to(A.device) else: A = A.to(B.device) diff --git a/src/flag_gems/ops/exponential_.py b/src/flag_gems/ops/exponential_.py index b47328edc..4ab322d67 100644 --- a/src/flag_gems/ops/exponential_.py +++ b/src/flag_gems/ops/exponential_.py @@ -4,7 +4,12 @@ import triton import triton.language as tl -from flag_gems.utils.random_utils import philox_cuda_seed_offset, uint_to_uniform_float +from flag_gems.utils.random_utils import ( + philox_backend_seed_offset, + uint_to_uniform_float, +) + +from ..runtime import torch_device_fn def heur_block(args): @@ -109,10 +114,10 @@ def exponential_(x, lambd: float = 1.0, *, gen=None): # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. 
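Ops such as eq, maximum and minimum stop testing `Tensor.is_cuda` and instead compare `Tensor.device.type` against the backend name exported by `flag_gems.runtime.device`. A sketch of that guard, assuming `device.name` is the torch device-type string of the active backend ("cuda" on NVIDIA hardware); `same_backend` is an illustrative helper, not part of the patch:

```python
import torch

from flag_gems.runtime import device

device = device.name  # e.g. "cuda"


def same_backend(x: torch.Tensor, y: torch.Tensor):
    # Mirror eq(): move the operand that is not on the active backend.
    if x.device != y.device:
        if x.device.type == device:
            y = y.to(x.device)
        else:
            x = x.to(y.device)
    # Mirror maximum()/minimum(): both operands must live on the backend.
    assert x.device.type == device and y.device.type == device
    return x, y
```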
increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_cuda_seed_offset(increment) + philox_seed, philox_offset = philox_backend_seed_offset(increment) eps = torch.finfo(dtype).eps x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device) - with torch.cuda.device(device): + with torch_device_fn.device(device): fused_exponential_kernel[grid_fn]( x_, N, is_double, lambd, eps, philox_seed, philox_offset ) diff --git a/src/flag_gems/ops/fill.py b/src/flag_gems/ops/fill.py index 8a1e46ebe..9da425e75 100644 --- a/src/flag_gems/ops/fill.py +++ b/src/flag_gems/ops/fill.py @@ -4,6 +4,7 @@ import triton import triton.language as tl +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -44,7 +45,7 @@ def fill_tensor(input, value): BLOCK_SIZE = 512 grid = triton.cdiv(N, BLOCK_SIZE) - with torch.cuda.device(input.device): + with torch_device_fn.device(input.device): fill_tensor_kernel[grid,](out, N, value, BLOCK_SIZE) return out @@ -56,6 +57,6 @@ def fill_scalar(input, value): BLOCK_SIZE = 512 grid = triton.cdiv(N, BLOCK_SIZE) - with torch.cuda.device(input.device): + with torch_device_fn.device(input.device): fill_scalar_kernel[grid,](out, N, value, BLOCK_SIZE) return out diff --git a/src/flag_gems/ops/full.py b/src/flag_gems/ops/full.py index 8f1448769..b77d56ca2 100644 --- a/src/flag_gems/ops/full.py +++ b/src/flag_gems/ops/full.py @@ -4,6 +4,7 @@ import triton import triton.language as tl +from ..runtime import torch_device_fn from ..utils import triton_lang_extension as tle from ..utils.shape_utils import volume @@ -67,7 +68,7 @@ def full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=N out = torch.empty(size, device=device, dtype=dtype) N = volume(size) grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) - with torch.cuda.device(device): + with torch_device_fn.device(device): full_kernel[grid_fn]( out, N, diff --git a/src/flag_gems/ops/full_like.py b/src/flag_gems/ops/full_like.py index 7bf02ad5a..597692107 100644 --- a/src/flag_gems/ops/full_like.py +++ b/src/flag_gems/ops/full_like.py @@ -3,6 +3,7 @@ import torch import triton +from ..runtime import torch_device_fn from .full import check_dtype, full_kernel @@ -25,7 +26,7 @@ def full_like( out = torch.empty_like(x, device=device, dtype=dtype) N = x.numel() grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): full_kernel[grid_fn]( out, N, diff --git a/src/flag_gems/ops/gelu.py b/src/flag_gems/ops/gelu.py index fc6dd3245..07537e46e 100644 --- a/src/flag_gems/ops/gelu.py +++ b/src/flag_gems/ops/gelu.py @@ -4,15 +4,13 @@ import triton import triton.language as tl +from ..runtime.moduel_tool import tl_extra_module from ..utils import pointwise_dynamic -try: - from triton.language.extra.cuda.libdevice import erf, exp, pow, tanh -except ImportError: - try: - from triton.language.math import erf, exp, pow, tanh - except ImportError: - from triton.language.libdevice import erf, exp, pow, tanh +erf = tl_extra_module.erf +exp = tl_extra_module.exp +pow = tl_extra_module.pow +tanh = tl_extra_module.tanh @pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) diff --git a/src/flag_gems/ops/groupnorm.py b/src/flag_gems/ops/groupnorm.py index 7b4872aa8..3ce8ce860 100644 --- a/src/flag_gems/ops/groupnorm.py +++ b/src/flag_gems/ops/groupnorm.py @@ -4,16 +4,12 @@ import triton import triton.language as tl +from ..runtime import torch_device_fn +from 
..runtime.moduel_tool import tl_extra_module from ..utils import libentry from ..utils import triton_lang_extension as tle -try: - from triton.language.extra.cuda.libdevice import rsqrt -except ImportError: - try: - from triton.language.math import rsqrt - except ImportError: - from triton.language.libdevice import rsqrt +rsqrt = tl_extra_module.rsqrt @libentry() @@ -185,7 +181,7 @@ def forward(ctx, x, weight, bias, N, C, HW, num_groups, eps): rstd = torch.empty((N, num_groups), dtype=x.dtype, device=x.device) grid = (N * num_groups,) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): group_norm_kernel[grid]( x, y, @@ -223,7 +219,7 @@ def backward(ctx, y_grad, mean_grad, rstd_grad): weight_grad = torch.empty_like(weight) bias_grad = torch.empty_like(weight) grid = (N * num_groups,) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): group_norm_backward_kernel[grid]( y_grad, x, diff --git a/src/flag_gems/ops/isclose.py b/src/flag_gems/ops/isclose.py index b396a1883..78b4e9551 100644 --- a/src/flag_gems/ops/isclose.py +++ b/src/flag_gems/ops/isclose.py @@ -4,24 +4,12 @@ import triton import triton.language as tl +from ..runtime.moduel_tool import tl_extra_module from ..utils import pointwise_dynamic from .all import all -try: - from triton.language.extra.cuda.libdevice import isfinited as _isfinited -except ImportError: - try: - from triton.language.math import isfinited as _isfinited - except ImportError: - from triton.language.libdevice import isfinited as _isfinited - -try: - from triton.language.extra.cuda.libdevice import finitef as _finitef -except ImportError: - try: - from triton.language.math import finitef as _finitef - except ImportError: - from triton.language.libdevice import finitef as _finitef +_isfinited = tl_extra_module.isfinited +_finitef = tl_extra_module.finitef @pointwise_dynamic( diff --git a/src/flag_gems/ops/isfinite.py b/src/flag_gems/ops/isfinite.py index d0dc0b2cb..72739b021 100644 --- a/src/flag_gems/ops/isfinite.py +++ b/src/flag_gems/ops/isfinite.py @@ -4,23 +4,11 @@ import triton import triton.language as tl +from ..runtime.moduel_tool import tl_extra_module from ..utils import pointwise_dynamic -try: - from triton.language.extra.cuda.libdevice import isfinited as _isfinited -except ImportError: - try: - from triton.language.math import isfinited as _isfinited - except ImportError: - from triton.language.libdevice import isfinited as _isfinited - -try: - from triton.language.extra.cuda.libdevice import finitef as _finitef -except ImportError: - try: - from triton.language.math import finitef as _finitef - except ImportError: - from triton.language.libdevice import finitef as _finitef +_isfinited = tl_extra_module.isfinited +_finitef = tl_extra_module.finitef @pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "ALWAYS_BOOL")]) diff --git a/src/flag_gems/ops/isin.py b/src/flag_gems/ops/isin.py index 07d15bf30..3bb21aaa4 100644 --- a/src/flag_gems/ops/isin.py +++ b/src/flag_gems/ops/isin.py @@ -6,6 +6,7 @@ from flag_gems.utils.libentry import libentry +from ..runtime import torch_device_fn from ..utils import triton_lang_extension as tle from .all import reduce_all from .any import reduce_any @@ -105,7 +106,7 @@ def isin_by_comparation( tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num) grid = (ctas_num,) out = torch.empty_like(in0_ravel, dtype=torch.bool) - with torch.cuda.device(in0_ravel.device.index): + with torch_device_fn.device(in0_ravel.device.index): isin_by_comparation_kernel[grid]( in0_ravel, 
in1_ravel, # in @@ -226,7 +227,7 @@ def isin_by_search( tiles_per_cta = triton.cdiv(M, BLOCK_M * ctas_num) grid = (ctas_num,) out = torch.empty_like(in0_ravel, dtype=torch.bool) - with torch.cuda.device(in0_ravel.device.index): + with torch_device_fn.device(in0_ravel.device.index): isin_by_search_kernel[grid]( in0_ravel, in1_ravel, # in diff --git a/src/flag_gems/ops/isinf.py b/src/flag_gems/ops/isinf.py index 7f92b2341..04627d679 100644 --- a/src/flag_gems/ops/isinf.py +++ b/src/flag_gems/ops/isinf.py @@ -3,15 +3,10 @@ import triton import triton.language as tl +from ..runtime import tl_extra_module from ..utils import pointwise_dynamic -try: - from triton.language.extra.cuda.libdevice import isinf as _isinf -except ImportError: - try: - from triton.language.math import isinf as _isinf - except ImportError: - from triton.language.libdevice import isinf as _isinf +_isinf = tl_extra_module.isinf @pointwise_dynamic(promotion_methods=[(0, "ALWAYS_BOOL")]) diff --git a/src/flag_gems/ops/isnan.py b/src/flag_gems/ops/isnan.py index f175c3432..75247b292 100644 --- a/src/flag_gems/ops/isnan.py +++ b/src/flag_gems/ops/isnan.py @@ -3,15 +3,10 @@ import triton import triton.language as tl +from ..runtime import tl_extra_module from ..utils import pointwise_dynamic -try: - from triton.language.extra.cuda.libdevice import isnan as _isnan -except ImportError: - try: - from triton.language.math import isnan as _isnan - except ImportError: - from triton.language.libdevice import isnan as _isnan +_isnan = tl_extra_module.isnan @pointwise_dynamic(promotion_methods=[(0, "ALWAYS_BOOL")]) diff --git a/src/flag_gems/ops/layernorm.py b/src/flag_gems/ops/layernorm.py index 3c8053aee..e678f37b0 100644 --- a/src/flag_gems/ops/layernorm.py +++ b/src/flag_gems/ops/layernorm.py @@ -6,6 +6,7 @@ import triton.language as tl from .. import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle from ..utils.type_utils import get_accumulator_dtype @@ -319,7 +320,7 @@ def forward(ctx, x, normalized_shape, weight, bias, eps=1e-5, cudnn_enable=True) mean = torch.empty(M, dtype=acc_type, device=x.device) rstd = torch.empty(M, dtype=acc_type, device=x.device) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): if N <= 128: TILE_N = triton.next_power_of_2(N) TILE_M = triton.cdiv(1024, TILE_N) @@ -378,7 +379,7 @@ def backward(ctx, out_grad, mean_grad, rstd_grad): M = ctx.M N = ctx.N - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): in_grad = torch.empty_like(x) grid = lambda meta: (triton.cdiv(M, meta["BLOCK_ROW_SIZE"]), 1, 1) layer_norm_backward_kernel[grid]( diff --git a/src/flag_gems/ops/log_softmax.py b/src/flag_gems/ops/log_softmax.py index 250baffaa..ffa95cc49 100644 --- a/src/flag_gems/ops/log_softmax.py +++ b/src/flag_gems/ops/log_softmax.py @@ -5,6 +5,7 @@ import triton.language as tl from .. 
import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -127,7 +128,7 @@ def forward(ctx, x, dim, dtype): triton.cdiv(M, meta["BLOCK_M"]), K, ) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): log_softmax_kernel[grid]( out, inp, @@ -161,7 +162,7 @@ def backward(ctx, out_grad): triton.cdiv(M, meta["BLOCK_M"]), K, ) - with torch.cuda.device(in_grad.device): + with torch_device_fn.device(in_grad.device): log_softmax_backward_kernel[grid]( out, out_grad, diff --git a/src/flag_gems/ops/masked_select.py b/src/flag_gems/ops/masked_select.py index ab701719e..c98932bde 100644 --- a/src/flag_gems/ops/masked_select.py +++ b/src/flag_gems/ops/masked_select.py @@ -5,6 +5,7 @@ import triton.language as tl from .. import runtime +from ..runtime import torch_device_fn from ..utils import broadcastable, libentry from ..utils import triton_lang_extension as tle @@ -52,6 +53,6 @@ def masked_select(inp, mask): n_elements = inp.numel() grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): masked_select_kernel[grid](inp, mask_flattened, prefix_sum, out, n_elements) return out diff --git a/src/flag_gems/ops/max.py b/src/flag_gems/ops/max.py index b367a0049..763751b3a 100644 --- a/src/flag_gems/ops/max.py +++ b/src/flag_gems/ops/max.py @@ -7,6 +7,7 @@ import triton.language as tl from .. import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -101,7 +102,7 @@ def max(inp): mid = torch.empty((mid_size,), dtype=dtype, device=inp.device) out = torch.empty([], dtype=dtype, device=inp.device) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): max_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size) max_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid) return out @@ -131,7 +132,7 @@ def max_dim(inp, dim=None, keepdim=False): triton.cdiv(M, meta["BLOCK_M"]), K, ) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): max_kernel[grid](inp, out_value, out_index, M, N, K) Max_out = namedtuple("max", ["values", "indices"]) out = Max_out(values=out_value, indices=out_index) diff --git a/src/flag_gems/ops/maximum.py b/src/flag_gems/ops/maximum.py index cc9236fbd..4ba086d86 100644 --- a/src/flag_gems/ops/maximum.py +++ b/src/flag_gems/ops/maximum.py @@ -3,8 +3,11 @@ import triton import triton.language as tl +from ..runtime import device from ..utils import pointwise_dynamic +device = device.name + @pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, 1, "DEFAULT")]) @triton.jit @@ -18,5 +21,5 @@ def maximum_kernel(X, Y): def maximum(X, Y): logging.debug("GEMS MAXIMUM") - assert X.is_cuda and Y.is_cuda + assert X.device.type == device and Y.device.type == device return maximum_kernel(X, Y) diff --git a/src/flag_gems/ops/mean.py b/src/flag_gems/ops/mean.py index 4b86dd853..e3aba6864 100644 --- a/src/flag_gems/ops/mean.py +++ b/src/flag_gems/ops/mean.py @@ -6,6 +6,7 @@ import triton.language as tl from .. 
import runtime +from ..runtime import torch_device_fn from ..utils import dim_compress, libentry from ..utils import triton_lang_extension as tle @@ -51,7 +52,7 @@ def mean(inp, *, dtype=None): mid = torch.empty((mid_size,), dtype=dtype, device=inp.device) out = torch.empty([], dtype=dtype, device=inp.device) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): mean_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size) mean_kernel_2[(1, 1, 1)](mid, out, M, mid_size, block_mid) return out @@ -106,7 +107,7 @@ def mean_dim(x, dim, keepdim=False, *, dtype=None): out = torch.empty(shape, dtype=dtype, device=x.device) grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): mean_dim_kernel[grid](x, out, M, N) if not keepdim: out = out.squeeze(dim) diff --git a/src/flag_gems/ops/min.py b/src/flag_gems/ops/min.py index 2d77db4fe..6d3815cd3 100644 --- a/src/flag_gems/ops/min.py +++ b/src/flag_gems/ops/min.py @@ -7,6 +7,7 @@ import triton.language as tl from .. import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -100,7 +101,7 @@ def min(inp): mid = torch.empty((mid_size,), dtype=dtype, device=inp.device) out = torch.empty([], dtype=dtype, device=inp.device) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): min_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size) min_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid) return out @@ -130,7 +131,7 @@ def min_dim(inp, dim=None, keepdim=False): triton.cdiv(M, meta["BLOCK_M"]), K, ) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): min_kernel[grid](inp, out_value, out_index, M, N, K) Min_out = namedtuple("min", ["values", "indices"]) out = Min_out(values=out_value, indices=out_index) diff --git a/src/flag_gems/ops/minimum.py b/src/flag_gems/ops/minimum.py index d16f27dd6..aa12246d7 100644 --- a/src/flag_gems/ops/minimum.py +++ b/src/flag_gems/ops/minimum.py @@ -3,8 +3,11 @@ import triton import triton.language as tl +from ..runtime import device from ..utils import pointwise_dynamic +device = device.name + @pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, 0, "DEFAULT")]) @triton.jit @@ -17,5 +20,5 @@ def minimum_kernel(X, Y): def minimum(X, Y): logging.debug("GEMS MINIMUM") - assert X.is_cuda and Y.is_cuda + assert X.device.type == device and Y.device.type == device return minimum_kernel(X, Y) diff --git a/src/flag_gems/ops/mm.py b/src/flag_gems/ops/mm.py index 8e0998074..7bbc5ead3 100644 --- a/src/flag_gems/ops/mm.py +++ b/src/flag_gems/ops/mm.py @@ -5,6 +5,7 @@ import triton.language as tl from .. 
import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -132,7 +133,7 @@ def mm(a, b): triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"], ) - with torch.cuda.device(a.device): + with torch_device_fn.device(a.device): mm_kernel[grid]( a, b, diff --git a/src/flag_gems/ops/multinomial.py b/src/flag_gems/ops/multinomial.py index 85412da26..d2f40404f 100644 --- a/src/flag_gems/ops/multinomial.py +++ b/src/flag_gems/ops/multinomial.py @@ -5,7 +5,7 @@ import triton.language as tl from flag_gems.utils import libentry -from flag_gems.utils.random_utils import philox_cuda_seed_offset, uniform +from flag_gems.utils.random_utils import philox_backend_seed_offset, uniform @libentry() @@ -90,7 +90,7 @@ def multinomial(prob, n_samples, with_replacement=False, *, gen=None): # The CTA level parallelism is framed in a 2d grid of blocks with grid.y # indexing into distributions and grid.x output sample batches increment = n_dist * n_samples - philox_seed, philox_offset = philox_cuda_seed_offset(increment) + philox_seed, philox_offset = philox_backend_seed_offset(increment) grid = lambda META: (triton.cdiv(n_samples, META["NBLOCK"]), n_dist) multinomial_with_replacement[grid]( cum_prob, out, n_categories, n_samples, philox_seed, philox_offset diff --git a/src/flag_gems/ops/mv.py b/src/flag_gems/ops/mv.py index 616345ffd..723d94292 100644 --- a/src/flag_gems/ops/mv.py +++ b/src/flag_gems/ops/mv.py @@ -5,6 +5,7 @@ import triton.language as tl from .. import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -54,7 +55,7 @@ def mv(inp, vec): N, M = inp.shape out = torch.empty((N,), device=inp.device, dtype=inp.dtype) grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): mv_kernel[grid]( inp, vec, diff --git a/src/flag_gems/ops/nonzero.py b/src/flag_gems/ops/nonzero.py index 147ac34e7..0578b4fd9 100644 --- a/src/flag_gems/ops/nonzero.py +++ b/src/flag_gems/ops/nonzero.py @@ -5,6 +5,7 @@ import triton.language as tl from .. 
import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -65,7 +66,7 @@ def nonzero(inp, *, as_tuple=False): out = torch.empty(num_nonzeros, inp_ndim, dtype=torch.int64, device=inp.device) grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): nonzero_kernel[grid](inp_bool, prefix_sum, out, n_elements, shape, inp_ndim) num_nonzeros = prefix_sum[n_elements - 1].item() diff --git a/src/flag_gems/ops/normal.py b/src/flag_gems/ops/normal.py index ab05166f6..b24b4398e 100644 --- a/src/flag_gems/ops/normal.py +++ b/src/flag_gems/ops/normal.py @@ -3,8 +3,9 @@ import torch import triton +from ..runtime import torch_device_fn from ..utils import pointwise_dynamic -from ..utils.random_utils import philox_cuda_seed_offset +from ..utils.random_utils import philox_backend_seed_offset from ..utils.shape_utils import broadcast_shapes, volume from .randn import randn_kernel @@ -49,8 +50,8 @@ def normal_distribution(shape, device, *, generator=None): grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),) increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_cuda_seed_offset(increment) - with torch.cuda.device(device): + philox_seed, philox_offset = philox_backend_seed_offset(increment) + with torch_device_fn.device(device): randn_kernel[grid_fn](out, N, philox_seed, philox_offset) return out diff --git a/src/flag_gems/ops/ones.py b/src/flag_gems/ops/ones.py index 77d3a59a6..00eb42c59 100644 --- a/src/flag_gems/ops/ones.py +++ b/src/flag_gems/ops/ones.py @@ -4,10 +4,13 @@ import triton import triton.language as tl +from ..runtime import device, torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle from ..utils.shape_utils import volume +device_ = device + @libentry() @triton.jit @@ -28,12 +31,12 @@ def ones(size, *, dtype=None, layout=None, device=None, pin_memory=None): if dtype is None: dtype = torch.get_default_dtype() if device is None: - device = torch.device("cuda") + device = torch.device(device_.name) out = torch.empty(size, device=device, dtype=dtype) N = volume(size) BLOCK_SIZE = 1024 grid = (triton.cdiv(N, BLOCK_SIZE),) - with torch.cuda.device(device): + with torch_device_fn.device(device): ones_kernel[grid](out, N, BLOCK_SIZE) return out diff --git a/src/flag_gems/ops/ones_like.py b/src/flag_gems/ops/ones_like.py index 88c894ea3..12e712606 100644 --- a/src/flag_gems/ops/ones_like.py +++ b/src/flag_gems/ops/ones_like.py @@ -3,6 +3,7 @@ import torch import triton +from ..runtime import torch_device_fn from .ones import ones_kernel @@ -17,6 +18,6 @@ def ones_like( out = torch.empty_like(x, device=device, dtype=dtype) N = x.numel() grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): ones_kernel[grid_fn](out, N, BLOCK_SIZE=1024) return out diff --git a/src/flag_gems/ops/pad.py b/src/flag_gems/ops/pad.py index 439be2e9b..8f508e5a4 100644 --- a/src/flag_gems/ops/pad.py +++ b/src/flag_gems/ops/pad.py @@ -68,6 +68,7 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer: code.writeline("from triton import language as tl") code.newline() code.writeline("from flag_gems.utils.libentry import libentry") + code.writeline("from flag_gems.runtime import torch_device_fn") code.writeline("from flag_gems.utils import triton_lang_extension as tle") code.writeline("from 
flag_gems.utils.type_utils import type_promotion") code.newline() @@ -166,7 +167,7 @@ def generate_destination_passing_padding_wrapper( code.writeline("# kernel launch") # launch kernel - code.writeline("with torch.cuda.device(in0.device):") + code.writeline("with torch_device_fn.device(in0.device):") with code.indent(): kernel_launch: str = f"{kernel_name}[grid](" code.writeline(kernel_launch) diff --git a/src/flag_gems/ops/pow.py b/src/flag_gems/ops/pow.py index af6090d0c..a55859c06 100644 --- a/src/flag_gems/ops/pow.py +++ b/src/flag_gems/ops/pow.py @@ -3,15 +3,10 @@ import triton import triton.language as tl +from ..runtime import tl_extra_module from ..utils import pointwise_dynamic -try: - from triton.language.extra.cuda.libdevice import pow as _pow -except ImportError: - try: - from triton.language.math import pow as _pow - except ImportError: - from triton.language.libdevice import pow as _pow +_pow = tl_extra_module.pow @pointwise_dynamic(promotion_methods=[(0, 1, "BOOL_TO_LONG")]) diff --git a/src/flag_gems/ops/prod.py b/src/flag_gems/ops/prod.py index d9740f35a..cc4c623d8 100644 --- a/src/flag_gems/ops/prod.py +++ b/src/flag_gems/ops/prod.py @@ -6,6 +6,7 @@ import triton.language as tl from .. import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -57,7 +58,7 @@ def prod(inp, *, dtype=None): mid = torch.empty((mid_size,), dtype=dtype, device=inp.device) out = torch.empty([], dtype=dtype, device=inp.device) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): prod_kernel_mid[(mid_size, 1, 1)](inp, mid, M, block_size) prod_kernel_result[(1, 1, 1)](mid, out, mid_size, block_mid) return out @@ -133,7 +134,7 @@ def prod_dim(inp, dim=None, keepdim=False, *, dtype=None): triton.cdiv(M, meta["BLOCK_M"]), K, ) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): prod_kernel[grid](inp, out, M, N, K) return out diff --git a/src/flag_gems/ops/rand.py b/src/flag_gems/ops/rand.py index 1a604d755..124d450fd 100644 --- a/src/flag_gems/ops/rand.py +++ b/src/flag_gems/ops/rand.py @@ -4,9 +4,16 @@ import triton import triton.language as tl -from flag_gems.utils.random_utils import philox_cuda_seed_offset, uint_to_uniform_float +from flag_gems.utils.random_utils import ( + philox_backend_seed_offset, + uint_to_uniform_float, +) from flag_gems.utils.shape_utils import volume +from ..runtime import device, torch_device_fn + +device_ = device + def heur_block(args): if args["N"] <= 512: @@ -68,7 +75,7 @@ def rand(size, *, dtype=None, layout=None, device=None, pin_memory=None): if dtype is None: dtype = torch.get_default_dtype() if device is None: - device = torch.device("cuda") + device = torch.device(device_.name) out = torch.empty(size, device=device, dtype=dtype) N = volume(size) @@ -76,7 +83,7 @@ def rand(size, *, dtype=None, layout=None, device=None, pin_memory=None): # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. 
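Across gelu, div, pow, sigmoid, silu, isinf, isnan, isclose, isfinite and groupnorm, the per-file try/except ladders over Triton's libdevice locations are collapsed into attribute lookups on `tl_extra_module`. A sketch of what such a module can resolve to, reproducing only the CUDA fallback chain that the removed blocks used; the actual `runtime/moduel_tool.py` is backend-aware and may resolve differently:

```python
import importlib


def _resolve_tl_extra_module():
    for name in (
        "triton.language.extra.cuda.libdevice",  # newer Triton layout
        "triton.language.math",                  # intermediate fallback
        "triton.language.libdevice",             # oldest fallback
    ):
        try:
            return importlib.import_module(name)
        except ImportError:
            continue
    raise ImportError("no libdevice-compatible Triton module found")


tl_extra_module = _resolve_tl_extra_module()
# Ops now bind the functions they need once, e.g. in pow.py and gelu.py:
_pow = tl_extra_module.pow
erf = tl_extra_module.erf
```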
increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_cuda_seed_offset(increment) - with torch.cuda.device(device): + philox_seed, philox_offset = philox_backend_seed_offset(increment) + with torch_device_fn.device(device): rand_kernel[grid_fn](out, N, philox_seed, philox_offset) return out diff --git a/src/flag_gems/ops/rand_like.py b/src/flag_gems/ops/rand_like.py index f7dc407e2..85338f595 100644 --- a/src/flag_gems/ops/rand_like.py +++ b/src/flag_gems/ops/rand_like.py @@ -4,7 +4,9 @@ import triton from flag_gems.ops.rand import rand_kernel -from flag_gems.utils.random_utils import philox_cuda_seed_offset +from flag_gems.utils.random_utils import philox_backend_seed_offset + +from ..runtime import torch_device_fn UNROLL = 4 @@ -23,7 +25,7 @@ def rand_like( # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_cuda_seed_offset(increment) - with torch.cuda.device(x.device): + philox_seed, philox_offset = philox_backend_seed_offset(increment) + with torch_device_fn.device(x.device): rand_kernel[grid_fn](out, N, philox_seed, philox_offset) return out diff --git a/src/flag_gems/ops/randn.py b/src/flag_gems/ops/randn.py index 43f591a87..47c090bd9 100644 --- a/src/flag_gems/ops/randn.py +++ b/src/flag_gems/ops/randn.py @@ -4,9 +4,14 @@ import triton import triton.language as tl -from flag_gems.utils.random_utils import philox_cuda_seed_offset, uint_to_uniform_float +from flag_gems.utils.random_utils import ( + philox_backend_seed_offset, + uint_to_uniform_float, +) from flag_gems.utils.shape_utils import volume +from ..runtime import device, torch_device_fn + try: pair_uniform_to_normal = tl.pair_uniform_to_normal except AttributeError: @@ -20,6 +25,9 @@ def pair_uniform_to_normal(u1, u2): return r * tl.cos(th), r * tl.sin(th) +device_ = device + + def heur_block(args): if args["N"] <= 512: return 512 @@ -82,14 +90,14 @@ def randn(size, *, dtype=None, layout=None, device=None, pin_memory=None): if dtype is None: dtype = torch.get_default_dtype() if device is None: - device = torch.device("cuda") + device = torch.device(device_.name) out = torch.empty(size, device=device, dtype=dtype) N = volume(size) grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),) # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_cuda_seed_offset(increment) - with torch.cuda.device(device): + philox_seed, philox_offset = philox_backend_seed_offset(increment) + with torch_device_fn.device(device): randn_kernel[grid_fn](out, N, philox_seed, philox_offset) return out diff --git a/src/flag_gems/ops/randn_like.py b/src/flag_gems/ops/randn_like.py index 77034ce6f..0458328dc 100644 --- a/src/flag_gems/ops/randn_like.py +++ b/src/flag_gems/ops/randn_like.py @@ -4,7 +4,9 @@ import triton from flag_gems.ops.randn import randn_kernel -from flag_gems.utils.random_utils import philox_cuda_seed_offset +from flag_gems.utils.random_utils import philox_backend_seed_offset + +from ..runtime import torch_device_fn UNROLL = 4 @@ -23,7 +25,7 @@ def randn_like( # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller, # hence we cannot obtain the per thread offset as in Pytorch. 
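Tensor constructors (arange, ones, rand, randn, randperm, zeros) now derive their default device from `flag_gems.runtime.device` instead of hard-coding `"cuda"`. A minimal sketch of that fallback, with `resolve_device` as an illustrative name; the patch itself aliases the runtime object as `device_` because the constructors take a `device` parameter that would otherwise shadow it:

```python
import torch

from flag_gems.runtime import device as runtime_device


def resolve_device(device=None):
    # When the caller passes device=None, fall back to the active backend.
    if device is None:
        device = torch.device(runtime_device.name)
    return device
```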
increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_cuda_seed_offset(increment) - with torch.cuda.device(x.device): + philox_seed, philox_offset = philox_backend_seed_offset(increment) + with torch_device_fn.device(x.device): randn_kernel[grid_fn](out, N, philox_seed, philox_offset) return out diff --git a/src/flag_gems/ops/randperm.py b/src/flag_gems/ops/randperm.py index aa2e3b4b9..46377d757 100644 --- a/src/flag_gems/ops/randperm.py +++ b/src/flag_gems/ops/randperm.py @@ -4,12 +4,15 @@ import triton import triton.language as tl -from flag_gems.utils.random_utils import philox_cuda_seed_offset +from flag_gems.utils.random_utils import philox_backend_seed_offset from .. import runtime +from ..runtime import device, torch_device_fn from ..utils import libentry from .topk import argsort +device_ = device + _MIN_INT8_VAL: tl.constexpr = torch.iinfo(torch.int8).min _MAX_INT8_VAL: tl.constexpr = torch.iinfo(torch.int8).max _MIN_INT16_VAL: tl.constexpr = torch.iinfo(torch.int16).min @@ -254,7 +257,7 @@ def radix_sortbykey_scatter_kernel( tl.store(value_out + global_offsets, value_data, mask=key_digit_mask) -# for parallelization, randomly shuffle the entire block rather than adjacent equal elements as pytorch cuda backend +# for parallelization, randomly shuffle the entire block rather than adjacent equal elements as pytorch GPU backend @libentry() @triton.jit(do_not_specialize=["philox_seed", "philox_offset"]) def duplicate_keys_shuffle_kernel( @@ -324,7 +327,7 @@ def sort_by_key(key, value, valid_bits): # step1 d_lookback.zero_() - with torch.cuda.device(key.device): + with torch_device_fn.device(key.device): digit_hist_kernel[grid_hist]( digit_hist_slice, key, @@ -355,7 +358,7 @@ def sort_by_key(key, value, valid_bits): ) tiles_per_portion = triton.cdiv(portion_items, BLOCK_SIZE) grid_scatter = (tiles_per_portion, grid_hist[1]) - with torch.cuda.device(key.device): + with torch_device_fn.device(key.device): radix_sortbykey_scatter_kernel[grid_scatter]( k_out, v_out, @@ -381,8 +384,8 @@ def sort_by_key(key, value, valid_bits): # last step, shuffle inner-block data BLOCK_SIZE_SHUFFLE = 512 grid_shuffle = (triton.cdiv(n_elements, BLOCK_SIZE_SHUFFLE),) - philox_seed, philox_offset = philox_cuda_seed_offset(n_elements) - with torch.cuda.device(key.device): + philox_seed, philox_offset = philox_backend_seed_offset(n_elements) + with torch_device_fn.device(key.device): duplicate_keys_shuffle_kernel[grid_shuffle]( v_out, n_elements, @@ -398,7 +401,7 @@ def sort_by_key(key, value, valid_bits): grid = (1,) k_out = torch.empty_like(key) v_out = torch.empty_like(value) - with torch.cuda.device(key.device): + with torch_device_fn.device(key.device): bitonic_sortbykey_kernel[grid]( k_out, v_out, key, value, n_elements, BLOCK_SIZE, False ) @@ -421,7 +424,7 @@ def randperm( assert n <= _MAX_INT64_VAL, "n exceeds maximum int64" if device is None: - device = torch.device("cuda") + device = torch.device(device_.name) in_range = torch.arange(n, dtype=dtype, device=device) u8max = 2**8 diff --git a/src/flag_gems/ops/repeat.py b/src/flag_gems/ops/repeat.py index 74d2d64c4..a25a369fc 100644 --- a/src/flag_gems/ops/repeat.py +++ b/src/flag_gems/ops/repeat.py @@ -55,6 +55,7 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer: code.writeline("import triton") code.writeline("from triton import language as tl") code.newline() + code.writeline("from flag_gems.runtime import torch_device_fn") code.writeline("from flag_gems.utils.shape_utils import volume") code.writeline("from 
flag_gems.utils.libentry import libentry") code.writeline("from flag_gems.utils.type_utils import type_promotion") @@ -170,7 +171,7 @@ def generate_destination_passing_repeat_wrapper( code.writeline("# kernel launch") # launch kernel - code.writeline("with torch.cuda.device(in0.device.index):") + code.writeline("with torch_device_fn.device(in0.device.index):") with code.indent(): kernel_launch: str = f"{kernel_name}[grid](" code.writeline(kernel_launch) diff --git a/src/flag_gems/ops/rms_norm.py b/src/flag_gems/ops/rms_norm.py index 0089e5660..c6d68f8ce 100644 --- a/src/flag_gems/ops/rms_norm.py +++ b/src/flag_gems/ops/rms_norm.py @@ -5,6 +5,7 @@ import triton import triton.language as tl +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -52,7 +53,7 @@ def forward(ctx, x, normalized_shape, weight, eps=1e-5): weight = weight.contiguous() y = torch.empty_like(x) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): rms_norm_kernel[M,](y, x, weight, N, 1, N, 1, N, eps, BLOCK_SIZE) return y diff --git a/src/flag_gems/ops/sigmoid.py b/src/flag_gems/ops/sigmoid.py index da559efcd..23216b5dd 100644 --- a/src/flag_gems/ops/sigmoid.py +++ b/src/flag_gems/ops/sigmoid.py @@ -4,15 +4,10 @@ import triton import triton.language as tl +from ..runtime import tl_extra_module from ..utils import pointwise_dynamic -try: - from triton.language.extra.cuda.libdevice import exp2 -except ImportError: - try: - from triton.language.math import exp2 - except ImportError: - from triton.language.libdevice import exp2 +exp2 = tl_extra_module.exp2 @pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) diff --git a/src/flag_gems/ops/silu.py b/src/flag_gems/ops/silu.py index 53ae6812a..59786b853 100644 --- a/src/flag_gems/ops/silu.py +++ b/src/flag_gems/ops/silu.py @@ -4,15 +4,10 @@ import triton import triton.language as tl +from ..runtime import tl_extra_module from ..utils import pointwise_dynamic -try: - from triton.language.extra.cuda.libdevice import div_rn -except ImportError: - try: - from triton.language.math import div_rn - except ImportError: - from triton.language.libdevice import div_rn +div_rn = tl_extra_module.div_rn @pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) diff --git a/src/flag_gems/ops/softmax.py b/src/flag_gems/ops/softmax.py index 4bb3900c2..745103a85 100644 --- a/src/flag_gems/ops/softmax.py +++ b/src/flag_gems/ops/softmax.py @@ -5,12 +5,13 @@ import triton.language as tl from .. 
import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle MAX_TILE_K = 8192 -NUM_SMS = torch.cuda.get_device_properties( - torch.cuda.current_device() +NUM_SMS = torch_device_fn.get_device_properties( + torch_device_fn.current_device() ).multi_processor_count @@ -376,7 +377,7 @@ def forward(ctx, x, dim, dtype): out = torch.empty_like(inp, dtype=dtype) K = inp.numel() // M // N # post_dim - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): if K > 1: grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1) softmax_kernel_non_inner[grid]( @@ -415,7 +416,7 @@ def backward(ctx, out_grad): in_grad = torch.empty_like(out) K = out.numel() // M // N - with torch.cuda.device(in_grad.device): + with torch_device_fn.device(in_grad.device): if K > 1: grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1) softmax_backward_kernel_non_inner[grid]( diff --git a/src/flag_gems/ops/sum.py b/src/flag_gems/ops/sum.py index fd4ce365b..3194bfb90 100644 --- a/src/flag_gems/ops/sum.py +++ b/src/flag_gems/ops/sum.py @@ -6,6 +6,7 @@ import triton.language as tl from .. import runtime +from ..runtime import torch_device_fn from ..utils import dim_compress, libentry from ..utils import triton_lang_extension as tle @@ -105,7 +106,7 @@ def sum(inp, *, dtype=None): mid = torch.empty((mid_size,), dtype=dtype, device=inp.device) out = torch.empty([], dtype=dtype, device=inp.device) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size) sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid) return out @@ -137,7 +138,7 @@ def sum_dim(inp, dim=None, keepdim=False, *, dtype=None): out = torch.empty(shape, dtype=dtype, device=inp.device) grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) - with torch.cuda.device(inp.device): + with torch_device_fn.device(inp.device): sum_kernel[grid](inp, out, M, N) if not keepdim: out = out.squeeze(dim=dim) diff --git a/src/flag_gems/ops/tanh.py b/src/flag_gems/ops/tanh.py index ac093d4b1..729fa3eed 100644 --- a/src/flag_gems/ops/tanh.py +++ b/src/flag_gems/ops/tanh.py @@ -4,23 +4,11 @@ import triton import triton.language as tl +from ..runtime import tl_extra_module from ..utils import pointwise_dynamic -try: - from triton.language.extra.cuda.libdevice import pow -except ImportError: - try: - from triton.language.math import pow - except ImportError: - from triton.language.libdevice import pow - -try: - from triton.language.extra.cuda.libdevice import tanh as _tanh -except ImportError: - try: - from triton.language.math import tanh as _tanh - except ImportError: - from triton.language.libdevice import tanh as _tanh +pow = tl_extra_module.pow +_tanh = tl_extra_module.tanh @pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")]) diff --git a/src/flag_gems/ops/tile.py b/src/flag_gems/ops/tile.py index 11a177be9..197860604 100644 --- a/src/flag_gems/ops/tile.py +++ b/src/flag_gems/ops/tile.py @@ -55,6 +55,7 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer: code.writeline("import triton") code.writeline("from triton import language as tl") code.newline() + code.writeline("from flag_gems.runtime import torch_device_fn") code.writeline("from flag_gems.utils.shape_utils import volume") code.writeline("from flag_gems.utils.libentry import libentry") code.writeline("from flag_gems.utils.type_utils import type_promotion") @@ -170,7 +171,7 @@ def generate_destination_passing_tile_wrapper( 
code.writeline("# kernel launch") # launch kernel - code.writeline("with torch.cuda.device(in0.device.index):") + code.writeline("with torch_device_fn.device(in0.device.index):") with code.indent(): kernel_launch: str = f"{kernel_name}[grid](" code.writeline(kernel_launch) diff --git a/src/flag_gems/ops/topk.py b/src/flag_gems/ops/topk.py index 7da0202b7..fa9c002a7 100644 --- a/src/flag_gems/ops/topk.py +++ b/src/flag_gems/ops/topk.py @@ -7,6 +7,7 @@ import triton.language.core as core from triton.language.standard import _log2, zeros_like +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -303,7 +304,7 @@ def topk(x, k, dim=-1, largest=True, sorted=True): stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype) stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): topk_stage1_kernel[ batch_size, chunk_num, @@ -319,7 +320,7 @@ def topk(x, k, dim=-1, largest=True, sorted=True): stage2_elem_cnt = chunk_num * k BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): topk_stage2_kernel[batch_size,]( stage2_out, stage2_out_idx, diff --git a/src/flag_gems/ops/triu.py b/src/flag_gems/ops/triu.py index c247e8b2e..429dc87ef 100644 --- a/src/flag_gems/ops/triu.py +++ b/src/flag_gems/ops/triu.py @@ -5,6 +5,7 @@ import triton.language as tl from .. import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -79,7 +80,7 @@ def triu(A, diagonal=0): out = torch.empty_like(A) assert len(A.shape) > 1, "Input tensor must have at least 2 dimensions" M, N = A.shape[-2:] - with torch.cuda.device(A.device): + with torch_device_fn.device(A.device): if len(A.shape) == 2: grid = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),) triu_kernel[grid](A, out, M, N, diagonal) diff --git a/src/flag_gems/ops/uniform.py b/src/flag_gems/ops/uniform.py index 0ade2a960..0c8de6aa2 100644 --- a/src/flag_gems/ops/uniform.py +++ b/src/flag_gems/ops/uniform.py @@ -1,12 +1,16 @@ import logging -import torch import triton import triton.language as tl -from flag_gems.utils.random_utils import philox_cuda_seed_offset, uint_to_uniform_float +from flag_gems.utils.random_utils import ( + philox_backend_seed_offset, + uint_to_uniform_float, +) from flag_gems.utils.shape_utils import volume +from ..runtime import torch_device_fn + def heur_block(args): if args["N"] <= 512: @@ -71,7 +75,7 @@ def uniform_(self, from_=0.0, to=1.0, *, generator=None): grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),) increment = triton.cdiv(N, UNROLL) - philox_seed, philox_offset = philox_cuda_seed_offset(increment) - with torch.cuda.device(self.device): + philox_seed, philox_offset = philox_backend_seed_offset(increment) + with torch_device_fn.device(self.device): uniform_kernel[grid_fn](self, N, philox_seed, philox_offset, from_, to) return self diff --git a/src/flag_gems/ops/unique.py b/src/flag_gems/ops/unique.py index 0ac97a205..499ec945a 100644 --- a/src/flag_gems/ops/unique.py +++ b/src/flag_gems/ops/unique.py @@ -4,6 +4,7 @@ from flag_gems.utils.libentry import libentry +from ..runtime import torch_device_fn from ..utils import triton_lang_extension as tle @@ -390,7 +391,7 @@ def sorted_quick_unique_flat(sorted_data: torch.Tensor, return_counts: bool): data_out = torch.empty_like(sorted_data) # launch kernel - 
with torch.cuda.device(sorted_data.device.index): + with torch_device_fn.device(sorted_data.device.index): local_quick_unique_flat_kernel[grid]( sorted_data, # in local_unique, @@ -665,7 +666,7 @@ def sorted_indices_unique_flat( idx = torch.empty_like(inverse_indices) # launch kernel - with torch.cuda.device(sorted_data.device.index): + with torch_device_fn.device(sorted_data.device.index): local_ne_flat_kernel[grid]( sorted_data, # in ne_result, @@ -734,7 +735,7 @@ def simple_unique_flat( unique_size = torch.empty([1], dtype=torch.int64, device=sorted_data.device) # launch kernel - with torch.cuda.device(sorted_data.device.index): + with torch_device_fn.device(sorted_data.device.index): simple_unique_flat_kernel[grid]( sorted_data, sorted_indices, # in @@ -753,7 +754,7 @@ def simple_unique_flat( if return_counts: idx = idx[:out_size] counts = torch.empty_like(idx) - with torch.cuda.device(sorted_data.device.index): + with torch_device_fn.device(sorted_data.device.index): output_counts_flat_kernel[grid]( idx, num_tasks, # in diff --git a/src/flag_gems/ops/upsample_bicubic2d_aa.py b/src/flag_gems/ops/upsample_bicubic2d_aa.py index 55c9af277..64d1414e2 100644 --- a/src/flag_gems/ops/upsample_bicubic2d_aa.py +++ b/src/flag_gems/ops/upsample_bicubic2d_aa.py @@ -5,8 +5,11 @@ import triton import triton.language as tl +from ..runtime import device from ..utils import triton_lang_extension as tle +device = device.name + def configs(): block = [(bx, by) for bx in (512, 256, 128, 64) for by in (2, 1)] @@ -488,7 +491,7 @@ def _upsample_bicubic2d_aa( scales_w: Optional[float] = None, ): logging.debug("GEMS UPSAMPLE BICUBIC2D AA") - assert input.is_cuda + assert input.device.type == device assert input.ndim == 4, "The ndim of input must be 4" assert len(output_size) == 2, "The len of output_size must be 2" diff --git a/src/flag_gems/ops/upsample_nearest2d.py b/src/flag_gems/ops/upsample_nearest2d.py index b0e6be089..521be1f51 100644 --- a/src/flag_gems/ops/upsample_nearest2d.py +++ b/src/flag_gems/ops/upsample_nearest2d.py @@ -5,8 +5,11 @@ import triton import triton.language as tl +from ..runtime import device from ..utils import triton_lang_extension as tle +device = device.name + def configs(): block = [1024, 2048] @@ -67,7 +70,7 @@ def upsample_nearest2d( scales_w: Optional[float] = None, ) -> torch.Tensor: logging.debug("GEMS UPSAMPLE NEAREST2D") - assert input.is_cuda + assert input.device.type == device assert input.ndim == 4, "The ndim of input must be 4" assert len(output_size) == 2, "The len of output_size must be 2" OH, OW = output_size diff --git a/src/flag_gems/ops/var_mean.py b/src/flag_gems/ops/var_mean.py index d7cdcaf77..f9f6612ea 100644 --- a/src/flag_gems/ops/var_mean.py +++ b/src/flag_gems/ops/var_mean.py @@ -5,6 +5,7 @@ import triton.language as tl from .. 
import runtime +from ..runtime import torch_device_fn from ..utils import dim_compress, libentry from ..utils import triton_lang_extension as tle @@ -152,7 +153,7 @@ def var_mean(x, dim=None, *, correction=None, keepdim=False): average = torch.empty([BLOCK_NUM], dtype=x.dtype, device=x.device) count = torch.empty([BLOCK_NUM], dtype=x.dtype, device=x.device) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): var_mean_kernel_1[(BLOCK_NUM,)](x, acc, average, count, N, BLOCK_N=BLOCK_N) var_mean_kernel_2[(1,)]( acc, average, count, var, mean, N, correction, BLOCK_NUM @@ -170,7 +171,7 @@ def var_mean(x, dim=None, *, correction=None, keepdim=False): mean = torch.empty(shape, dtype=x.dtype, device=x.device) grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): var_mean_welford_kernel[grid](x, var, mean, M, N, correction) if not keepdim: diff --git a/src/flag_gems/ops/vector_norm.py b/src/flag_gems/ops/vector_norm.py index 2cf3c9bde..023b30747 100644 --- a/src/flag_gems/ops/vector_norm.py +++ b/src/flag_gems/ops/vector_norm.py @@ -5,18 +5,12 @@ import triton import triton.language as tl +from .. import runtime +from ..runtime import tl_extra_module, torch_device_fn from ..utils import dim_compress, libentry from ..utils import triton_lang_extension as tle -try: - from triton.language.extra.cuda.libdevice import pow -except ImportError: - try: - from triton.language.math import pow - except ImportError: - from triton.language.libdevice import pow - -from .. import runtime +pow = tl_extra_module.pow @libentry() @@ -267,7 +261,7 @@ def vector_norm(x, ord=2, dim=None, keepdim=False, dtype=None): if dtype not in [torch.float16, torch.float32, torch.bfloat16]: raise NotImplementedError(f"vector_norm not implemented for {dtype}") - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): if (not dim) or len(dim) == x.ndim: dim = list(range(x.ndim)) shape = [1] * x.ndim diff --git a/src/flag_gems/ops/vstack.py b/src/flag_gems/ops/vstack.py index 54fd25a5f..55991ec33 100644 --- a/src/flag_gems/ops/vstack.py +++ b/src/flag_gems/ops/vstack.py @@ -5,6 +5,7 @@ import triton.language as tl from .. import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -116,7 +117,7 @@ def vstack(tensors: list): scheduled_num_tensors, ) # Launch the kernel - with torch.cuda.device(c_tensors[0].device): + with torch_device_fn.device(c_tensors[0].device): vstack_kernel[grid]( itensors[0], itensors[1], diff --git a/src/flag_gems/ops/weightnorm.py b/src/flag_gems/ops/weightnorm.py index df4f0b8d4..61eb38bd2 100644 --- a/src/flag_gems/ops/weightnorm.py +++ b/src/flag_gems/ops/weightnorm.py @@ -6,6 +6,7 @@ import triton.language as tl from .. 
import runtime +from ..runtime import torch_device_fn from ..utils import libentry from ..utils import triton_lang_extension as tle @@ -290,7 +291,7 @@ def forward(ctx, v, g, dim): M = v.shape[0] N = math.prod(v.shape[1:]) grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),) - with torch.cuda.device(v.device): + with torch_device_fn.device(v.device): weight_norm_kernel_first[grid]( output, norm, v, g, M, N, eps=torch.finfo(torch.float32).tiny ) @@ -298,7 +299,7 @@ def forward(ctx, v, g, dim): M = math.prod(v.shape[:-1]) N = v.shape[dim] grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),) - with torch.cuda.device(v.device): + with torch_device_fn.device(v.device): weight_norm_kernel_last[grid]( output, norm, v, g, M, N, eps=torch.finfo(torch.float32).tiny ) @@ -320,7 +321,7 @@ def backward(ctx, w_grad, norm_grad): M = v.shape[0] N = math.prod(v.shape[1:]) grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),) - with torch.cuda.device(v.device): + with torch_device_fn.device(v.device): weight_norm_bwd_kernel_first[grid]( v_grad, g_grad, @@ -336,7 +337,7 @@ def backward(ctx, w_grad, norm_grad): M = math.prod(v.shape[:dim]) N = v.shape[dim] grid = lambda META: (triton.cdiv(N, META["BLOCK_COL_SIZE"]),) - with torch.cuda.device(v.device): + with torch_device_fn.device(v.device): weight_norm_bwd_kernel_last[grid]( v_grad, g_grad, @@ -369,7 +370,7 @@ def forward(ctx, v, dim): grid = lambda META: (triton.cdiv(v_shape[1], META["BLOCK_ROW_SIZE"]),) - with torch.cuda.device(v.device): + with torch_device_fn.device(v.device): norm_kernel[grid]( output, v, @@ -390,7 +391,7 @@ def backward(ctx, norm_grad): v_grad = torch.empty_like(v) grid = lambda META: (triton.cdiv(ctx.V_SHAPE[1], META["BLOCK_ROW_SIZE"]),) - with torch.cuda.device(v.device): + with torch_device_fn.device(v.device): norm_bwd_kernel[grid]( v_grad, norm_grad, diff --git a/src/flag_gems/ops/zeros.py b/src/flag_gems/ops/zeros.py index 0651fa3da..5b566a01e 100644 --- a/src/flag_gems/ops/zeros.py +++ b/src/flag_gems/ops/zeros.py @@ -4,9 +4,12 @@ import triton import triton.language as tl +from ..runtime import device, torch_device_fn from ..utils import triton_lang_extension as tle from ..utils.shape_utils import volume +device_ = device + @triton.jit def zeros_kernel( @@ -26,11 +29,11 @@ def zeros(size, *, dtype=None, layout=None, device=None, pin_memory=None): if dtype is None: dtype = torch.get_default_dtype() if device is None: - device = torch.device("cuda") + device = torch.device(device_.name) out = torch.empty(size, device=device, dtype=dtype) N = volume(size) grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) - with torch.cuda.device(device): + with torch_device_fn.device(device): zeros_kernel[grid_fn](out, N, BLOCK_SIZE=1024) return out diff --git a/src/flag_gems/ops/zeros_like.py b/src/flag_gems/ops/zeros_like.py index 264d7874f..b8e5e2573 100644 --- a/src/flag_gems/ops/zeros_like.py +++ b/src/flag_gems/ops/zeros_like.py @@ -3,6 +3,7 @@ import torch import triton +from ..runtime import torch_device_fn from .zeros import zeros_kernel @@ -17,6 +18,6 @@ def zeros_like( out = torch.empty_like(x, device=device, dtype=dtype) N = x.numel() grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) - with torch.cuda.device(x.device): + with torch_device_fn.device(x.device): zeros_kernel[grid_fn](out, N, BLOCK_SIZE=1024) return out diff --git a/src/flag_gems/runtime/__init__.py b/src/flag_gems/runtime/__init__.py index 1ac9fff43..3ba390072 100644 --- a/src/flag_gems/runtime/__init__.py +++ 
b/src/flag_gems/runtime/__init__.py @@ -1,17 +1,18 @@ -from . import backend, commom_utils +from . import backend, commom_utils, moduel_tool from .backend.device import DeviceDetector from .configloader import ConfigLoader config_loader = ConfigLoader() device = DeviceDetector() +# torch_device_fn is like 'torch.cuda' object +torch_device_fn = backend.gen_torch_device_object() +tl_extra_module = moduel_tool.tl_extra_module +# torch_backend_device is like 'torch.backend.cuda' object +torch_backend_device = backend.get_torch_backend_device_fn() def get_triton_config(op_name): return config_loader.get_triton_config(op_name) -def get_device_fn(api_name): - return backend.gen_torch_device_fn(api_name) - - __all__ = ["commom_utils", "backend", "device", "get_triton_config"] diff --git a/src/flag_gems/runtime/backend/__init__.py b/src/flag_gems/runtime/backend/__init__.py index ef6eb9bbe..f74f59257 100644 --- a/src/flag_gems/runtime/backend/__init__.py +++ b/src/flag_gems/runtime/backend/__init__.py @@ -9,6 +9,9 @@ vendor_module = None device_name = None +torch_device_object = None +torch_device_fn_device = None +tl_extra_backend_module = None device_fn_cache = {} @@ -33,18 +36,41 @@ def gen_torch_tensor_attr_res(tensor, attr_name): return get_codegen_result(code, "res") -def gen_torch_device_fn(api_name, vendor_name=None): - global device_name +def set_tl_extra_backend_module(vendor_name=None): + global device_name, tl_extra_backend_module + device_name = device_name or get_vendor_info(vendor_name).device_name + module_str = f"triton.language.extra.{device_name}.libdevice" + tl_extra_backend_module = importlib.import_module(module_str) + + +def get_tl_extra_backend_module(): + global tl_extra_backend_module + return tl_extra_backend_module + + +def set_torch_backend_device_fn(vendor_name=None): + global device_name, torch_device_fn_device + device_name = device_name or get_vendor_info(vendor_name).device_name + module_str = f"torch.backends.{device_name}" + torch_device_fn_device = importlib.import_module(module_str) + + +def get_torch_backend_device_fn(): + global torch_device_fn_device + return torch_device_fn_device + + +def gen_torch_device_object(vendor_name=None): + global device_name, torch_device_object + if torch_device_object is not None: + return torch_device_object device_name = device_name or get_vendor_info(vendor_name).device_name - if api_name in device_fn_cache: - return device_fn_cache[api_name] code = f""" import torch -fn = torch.{device_name}.{api_name} +fn = torch.{device_name} """ - fn = get_codegen_result(code, "fn") - device_fn_cache[api_name] = fn - return fn + torch_device_object = get_codegen_result(code, "fn") + return torch_device_object def get_vendor_module(vendor_name, query=False): @@ -73,7 +99,7 @@ def get_vendor_info(vendor_name=None, query=False): return vendor_module.vendor_info -def get_vendor_infos() -> list: +def get_vendor_infos(): infos = [] for vendor_name in vendors_map: vendor_name = "_" + vendor_name @@ -86,7 +112,7 @@ def get_vendor_infos() -> list: return infos -def get_curent_device_extend_op(vendor_name=None) -> dict: +def get_curent_device_extend_op(vendor_name=None): global vendor_module get_vendor_module(vendor_name) tuples = vendor_module.get_register_op_config() @@ -96,13 +122,13 @@ def get_curent_device_extend_op(vendor_name=None) -> dict: return configs -def get_curent_device_unused_op(vendor_name=None) -> list: +def get_curent_device_unused_op(vendor_name=None): global vendor_module get_vendor_module(vendor_name) return 
vendor_module.get_unused_op() -def get_tune_config(vendor_name=None) -> dict: +def get_tune_config(vendor_name=None): global vendor_module get_vendor_module(vendor_name) return backend_utils.get_tune_config(vendor_name) diff --git a/src/flag_gems/runtime/backend/_nvidia/ops/add.py b/src/flag_gems/runtime/backend/_nvidia/ops/add.py index fd8fd9973..5000de188 100644 --- a/src/flag_gems/runtime/backend/_nvidia/ops/add.py +++ b/src/flag_gems/runtime/backend/_nvidia/ops/add.py @@ -36,7 +36,6 @@ def add(x: torch.Tensor, y: torch.Tensor): # We need to preallocate the output. print("\n.......test for mutibackend........\n") output = torch.empty_like(x) - assert x.is_cuda and y.is_cuda and output.is_cuda n_elements = output.numel() # The SPMD launch grid denotes the number of kernel instances that run in parallel. # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. @@ -47,6 +46,6 @@ def add(x: torch.Tensor, y: torch.Tensor): # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. # - Don't forget to pass meta-parameters as keywords arguments. add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) - # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still + # We return a handle to z but, since `torch_device_fn.synchronize()` hasn't been called, the kernel is still # running asynchronously at this point. return output diff --git a/src/flag_gems/runtime/backend/device.py b/src/flag_gems/runtime/backend/device.py index f57dd4e1e..62b1475af 100644 --- a/src/flag_gems/runtime/backend/device.py +++ b/src/flag_gems/runtime/backend/device.py @@ -29,9 +29,9 @@ def __init__(self, vendor_name=None): self.vendor_name = self.info.vendor_name self.name = self.info.device_name self.vendor = vendors_map[self.vendor_name] - self.device_count = backend.gen_torch_device_fn( - "device_count", self.vendor_name - )() + self.device_count = backend.gen_torch_device_object( + self.vendor_name + ).device_count() def get_vendor(self, vendor_name=None) -> tuple: # Try to get the vendor name from a quick special command like 'torch.mlu'. diff --git a/src/flag_gems/runtime/moduel_tool.py b/src/flag_gems/runtime/moduel_tool.py new file mode 100644 index 000000000..a99b403c0 --- /dev/null +++ b/src/flag_gems/runtime/moduel_tool.py @@ -0,0 +1,17 @@ +import triton + +from . import backend +from .backend.device import DeviceDetector + +device = DeviceDetector() +backend.set_torch_backend_device_fn(device.vendor_name) +tl_extra_module = None +if tl_extra_module is None: + try: + backend.set_tl_extra_backend_module(device.vendor_name) + tl_extra_module = backend.get_tl_extra_backend_module() + except ImportError: + try: + tl_extra_module = triton.language.math + except ImportError: + tl_extra_module = triton.language.libdevice diff --git a/src/flag_gems/runtime/register.py b/src/flag_gems/runtime/register.py index 2146070c8..dbdbbfe57 100644 --- a/src/flag_gems/runtime/register.py +++ b/src/flag_gems/runtime/register.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from . 
import backend, commom_utils, error @@ -11,10 +9,10 @@ class Register: def __init__( self, - config: Optional[tuple[tuple]], - user_unused_ops_list: Optional[list[str]] = None, - lib: Optional[any] = None, - debug: Optional[bool] = False, + config, + user_unused_ops_list=None, + lib=None, + debug=False, ): # lib is a instance of torch.library.Library self.lib = lib @@ -77,20 +75,20 @@ def _set_info(self, config): fn_name ) if hasbackward else self.forward_ops.append(fn_name) - def get_forward_ops(self) -> list[str]: + def get_forward_ops(self): return self.forward_ops if self.debug else [] - def get_backward_ops(self) -> list[str]: - return self.backward_opss if self.debug else [] + def get_backward_ops(self): + return self.backward_ops if self.debug else [] - def get_unused_ops(self) -> list[str]: + def get_unused_ops(self): return self.unused_ops - def get_vendor_name(self) -> str: + def get_vendor_name(self): return self.device.vendor_name - def get_current_device(self) -> str: + def get_current_device(self): return self.device.name - def support_backward(self, fn) -> bool: - return fn.__name__ in self.backend_ops + def support_backward(self, fn): + return fn.__name__ in self.backward_ops diff --git a/src/flag_gems/utils/libentry.py b/src/flag_gems/utils/libentry.py index b681909e5..0df60ff69 100644 --- a/src/flag_gems/utils/libentry.py +++ b/src/flag_gems/utils/libentry.py @@ -1,10 +1,10 @@ import inspect import threading -import torch import triton from .. import runtime +from ..runtime import torch_device_fn DEVICE_COUNT = runtime.device.device_count @@ -92,7 +92,7 @@ def run(self, *args, **kwargs): k_args.append(val) entry_key = self.key(spec_args, dns_args, const_args) - device = torch.cuda.current_device() + device = torch_device_fn.current_device() cache = self.kernel_cache[device] while entry_key not in cache: # NOTE: we serialize the first run of a jit function regardless of which device to run on diff --git a/src/flag_gems/utils/pointwise_dynamic.py b/src/flag_gems/utils/pointwise_dynamic.py index 208bf5b94..69310defd 100644 --- a/src/flag_gems/utils/pointwise_dynamic.py +++ b/src/flag_gems/utils/pointwise_dynamic.py @@ -902,7 +902,7 @@ def gen_kernel_launch( else: code.writeline(f"out{i}_stride_order = (0,)") - code.writeline("with torch.cuda._DeviceGuard(in0.device.index):") + code.writeline("with torch_device_fn._DeviceGuard(in0.device.index):") with code.indent(): code.writeline(f"{self.jit_fn_name}[grid](") with code.indent(): @@ -962,7 +962,7 @@ def gen_kernel_launch_1d( for i in range(schema.num_output_tensors()): code.writeline(f"out{i}_strides = out{i}.stride()") - code.writeline("with torch.cuda._DeviceGuard(in0.device.index):") + code.writeline("with torch_device_fn._DeviceGuard(in0.device.index):") with code.indent(): code.writeline(f"{self.jit_fn_name}[grid](") with code.indent(): @@ -1058,6 +1058,7 @@ def generate_imports(code: IndentedBuffer) -> IndentedBuffer: code.writeline("from flag_gems.utils.tensor_wrapper import StridedBuffer") code.writeline("from flag_gems.utils.libentry import libentry") code.writeline("from flag_gems.utils import triton_lang_extension as tle") + code.writeline("from flag_gems.runtime import torch_device_fn") code.newline() code.newline() return code diff --git a/src/flag_gems/utils/random_utils.py b/src/flag_gems/utils/random_utils.py index 32b511fce..22f7155ce 100644 --- a/src/flag_gems/utils/random_utils.py +++ b/src/flag_gems/utils/random_utils.py @@ -2,6 +2,8 @@ import triton import triton.language as tl +from ..runtime import 
torch_device_fn + try: uint_to_uniform_float = tl.uint_to_uniform_float except AttributeError: @@ -32,9 +34,9 @@ def uint_to_uniform_float(x): # https://github.com/pytorch/pytorch/blob/8a4597980c2692b73f35fb3c7145eaeaf2273e77/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp#L452 # It returns the current state of the default Philox RNG in seed and offset and # updates the next offset by adding `increment`. -def philox_cuda_seed_offset(increment, device=None): - device = device or torch.cuda.current_device() - gen = torch.cuda.default_generators[device] +def philox_backend_seed_offset(increment, device=None): + device = device or torch_device_fn.current_device() + gen = torch_device_fn.default_generators[device] state_copy = gen.get_state() c0, c1 = state_copy.view(torch.int64) seed, offset = int(c0), int(c1) diff --git a/tests/conftest.py b/tests/conftest.py index 3c8f5b86a..27d4e00c4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,18 @@ import json import logging +import flag_gems + +device = flag_gems.device + def pytest_addoption(parser): parser.addoption( "--ref", action="store", - default="cuda", + default=device, required=False, - choices=["cuda", "cpu"], + choices=[device, "cpu"], help="device to run reference tests on", ) parser.addoption( diff --git a/tests/ks_tests.py b/tests/ks_tests.py index ba91565ab..32dc5e1ba 100644 --- a/tests/ks_tests.py +++ b/tests/ks_tests.py @@ -14,8 +14,10 @@ @pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_normal_pvalue(shape, dtype): - loc = torch.full(size=shape, fill_value=3.0, dtype=dtype, device="cuda") - scale = torch.full(size=shape, fill_value=10.0, dtype=dtype, device="cuda") + loc = torch.full(size=shape, fill_value=3.0, dtype=dtype, device=flag_gems.device) + scale = torch.full( + size=shape, fill_value=10.0, dtype=dtype, device=flag_gems.device + ) with flag_gems.use_gems(): res_out = torch.distributions.normal.Normal(loc, scale).sample() pvalue = scipy.stats.kstest( @@ -28,7 +30,7 @@ def test_accuracy_normal_pvalue(shape, dtype): @pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_accuracy_uniform_pvalue(shape, dtype): - x = torch.randn(size=shape, dtype=dtype, device="cuda") + x = torch.randn(size=shape, dtype=dtype, device=flag_gems.device) with flag_gems.use_gems(): x.uniform_(-3, 3) pvalue = scipy.stats.kstest( @@ -42,7 +44,7 @@ def test_accuracy_uniform_pvalue(shape, dtype): @pytest.mark.parametrize("dtype", (torch.float32,)) @pytest.mark.parametrize("lambd", (0.01, 0.5, 100.0)) def test_accuracy_exponential_pvalue(shape, dtype, lambd): - x = torch.empty(size=shape, dtype=dtype, device="cuda") + x = torch.empty(size=shape, dtype=dtype, device=flag_gems.device) with flag_gems.use_gems(): x.exponential_(lambd=lambd) expo_cdf = lambda x: np.where(x < 0, 0, 1.0 - np.exp(-lambd * x)) @@ -54,7 +56,7 @@ def test_accuracy_exponential_pvalue(shape, dtype, lambd): @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_accuracy_rand_pvalue(shape, dtype): with flag_gems.use_gems(): - res_out = torch.rand(shape, dtype=dtype, device="cuda") + res_out = torch.rand(shape, dtype=dtype, device=flag_gems.device) pvalue = scipy.stats.kstest( res_out.cpu().numpy().flatten(), lambda x: scipy.stats.uniform.cdf(x) ).pvalue @@ -65,7 +67,7 @@ def test_accuracy_rand_pvalue(shape, dtype): @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def 
test_accuracy_randn_pvalue(shape, dtype): with flag_gems.use_gems(): - res_out = torch.randn(shape, dtype=dtype, device="cuda") + res_out = torch.randn(shape, dtype=dtype, device=flag_gems.device) pvalue = scipy.stats.kstest( res_out.cpu().numpy().flatten(), lambda x: scipy.stats.norm.cdf(x) ).pvalue diff --git a/tests/test_binary_pointwise_ops.py b/tests/test_binary_pointwise_ops.py index e29eab238..822a520d4 100644 --- a/tests/test_binary_pointwise_ops.py +++ b/tests/test_binary_pointwise_ops.py @@ -31,8 +31,8 @@ def replace_zeros(inp): @pytest.mark.parametrize("alpha", SCALARS) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_add(shape, alpha, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1, True) ref_inp2 = to_reference(inp2, True) @@ -49,7 +49,7 @@ def test_accuracy_add(shape, alpha, dtype): @pytest.mark.parametrize("alpha", SCALARS) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_add_tensor_scalar(shape, scalar, alpha, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) inp2 = scalar ref_inp1 = to_reference(inp1, True) @@ -67,7 +67,7 @@ def test_accuracy_add_tensor_scalar(shape, scalar, alpha, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_add_scalar_tensor(shape, scalar, alpha, dtype): inp1 = scalar - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp2 = to_reference(inp2, True) ref_out = torch.add(inp1, ref_inp2, alpha=alpha) @@ -104,14 +104,14 @@ def test_accuracy_add_scalar_scalar(dtype): @pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES) def test_accuracy_bitwiseand(shape, dtype): if dtype in BOOL_TYPES: - inp1 = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") - inp2 = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") + inp1 = torch.randint(0, 2, size=shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randint(0, 2, size=shape, dtype=dtype, device=flag_gems.device) else: inp1 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device=flag_gems.device ) inp2 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device=flag_gems.device ) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) @@ -128,11 +128,11 @@ def test_accuracy_bitwiseand(shape, dtype): @pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES) def test_accuracy_bitwiseand_scalar(shape, dtype): if dtype in BOOL_TYPES: - inp1 = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") + inp1 = torch.randint(0, 2, size=shape, dtype=dtype, device=flag_gems.device) inp2 = bool(random.randint(0, 2)) else: inp1 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device=flag_gems.device ) inp2 = 0x00FF ref_inp1 = to_reference(inp1) @@ -150,11 +150,11 @@ def test_accuracy_bitwiseand_scalar(shape, dtype): def test_accuracy_bitwiseand_scalar_tensor(shape, dtype): if dtype in BOOL_TYPES: inp1 = bool(random.randint(0, 2)) - inp2 = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") 
+ inp2 = torch.randint(0, 2, size=shape, dtype=dtype, device=flag_gems.device) else: inp1 = 0x00FF inp2 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device=flag_gems.device ) ref_inp2 = to_reference(inp2) @@ -171,14 +171,14 @@ def test_accuracy_bitwiseand_scalar_tensor(shape, dtype): @pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES) def test_accuracy_bitwiseor(shape, dtype): if dtype in BOOL_TYPES: - inp1 = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") - inp2 = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") + inp1 = torch.randint(0, 2, size=shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randint(0, 2, size=shape, dtype=dtype, device=flag_gems.device) else: inp1 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device=flag_gems.device ) inp2 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device=flag_gems.device ) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) @@ -196,11 +196,11 @@ def test_accuracy_bitwiseor(shape, dtype): @pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES) def test_accuracy_bitwiseor_scalar(shape, dtype): if dtype in BOOL_TYPES: - inp1 = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") + inp1 = torch.randint(0, 2, size=shape, dtype=dtype, device=flag_gems.device) inp2 = bool(random.randint(0, 2)) else: inp1 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device=flag_gems.device ) inp2 = 0x00FF ref_inp1 = to_reference(inp1) @@ -219,11 +219,13 @@ def test_accuracy_bitwiseor_scalar(shape, dtype): def test_accuracy_bitwiseor_scalar_tensor(shape, dtype): if dtype in BOOL_TYPES: inp1 = bool(random.randint(0, 2)) - inp2 = torch.randint(0, 2, size=shape, dtype=torch.bool, device="cuda") + inp2 = torch.randint( + 0, 2, size=shape, dtype=torch.bool, device=flag_gems.device + ) else: inp1 = 0x00FF inp2 = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device=flag_gems.device ) ref_inp2 = to_reference(inp2) @@ -241,7 +243,7 @@ def test_accuracy_bitwiseor_scalar_tensor(shape, dtype): @pytest.mark.parametrize("isnone", [None, "max", "min"]) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_clamp(shape, maxi, mini, isnone, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) if isnone == "min": mini = None elif isnone == "max": @@ -260,9 +262,9 @@ def test_accuracy_clamp(shape, maxi, mini, isnone, dtype): @pytest.mark.parametrize("isnone", [None, "max", "min"]) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_clamp_tensor(shape, isnone, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - maxi = torch.randn(shape, dtype=dtype, device="cuda") - mini = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + maxi = torch.randn(shape, dtype=dtype, device=flag_gems.device) + mini = torch.randn(shape, dtype=dtype, device=flag_gems.device) if isnone == "min": mini = None elif isnone == "max": @@ -282,8 +284,8 @@ def test_accuracy_clamp_tensor(shape, isnone, dtype): @pytest.mark.parametrize("shape", 
POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_div_tensor_tensor(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1, False) ref_inp2 = to_reference(inp2, False) @@ -299,7 +301,7 @@ def test_accuracy_div_tensor_tensor(shape, dtype): @pytest.mark.parametrize("scalar", SCALARS) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_div_tensor_scalar(shape, scalar, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) inp2 = scalar ref_inp1 = to_reference(inp1, False) @@ -316,7 +318,7 @@ def test_accuracy_div_tensor_scalar(shape, scalar, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_div_scalar_tensor(shape, scalar, dtype): inp1 = scalar - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp2 = to_reference(inp2, False) ref_out = torch.div(inp1, ref_inp2) @@ -352,8 +354,8 @@ def test_accuracy_div_scalar_scalar(dtype): # Note : tl.math.div_rz only support float32, cast will cause diff # with torch, so we only do float32 test for now. def test_accuracy_trunc_div(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1, True) ref_inp2 = to_reference(inp2, True) @@ -394,8 +396,8 @@ def test_accuracy_trunc_divide_scalar_scalar(dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", [torch.float32]) def test_accuracy_floor_div_float(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1, False) ref_inp2 = to_reference(inp2, False) @@ -415,14 +417,14 @@ def test_accuracy_floor_div_int(shape, dtype): torch.iinfo(dtype).max, shape, dtype=dtype, - device="cuda", + device=flag_gems.device, ) inp2 = torch.randint( torch.iinfo(dtype).min, torch.iinfo(dtype).max, shape, dtype=dtype, - device="cuda", + device=flag_gems.device, ) if TO_CPU: inp1 = replace_zeros(inp1) @@ -478,14 +480,14 @@ def test_accuracy_remainder(shape, dtype): torch.iinfo(dtype).max, shape, dtype=dtype, - device="cuda", + device=flag_gems.device, ) inp2 = torch.randint( torch.iinfo(dtype).min, torch.iinfo(dtype).max, shape, dtype=dtype, - device="cuda", + device=flag_gems.device, ) if TO_CPU: inp1 = replace_zeros(inp1) @@ -516,8 +518,8 @@ def test_accuracy_remainder(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_eq(shape, dtype): - inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") - inp2 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + inp1 = torch.randint(0, 10, shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randint(0, 10, shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) @@ -532,7 +534,7 @@ def 
test_accuracy_eq(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_eq_scalar(shape, dtype): - inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + inp1 = torch.randint(0, 10, shape, dtype=dtype, device=flag_gems.device) inp2 = 0 ref_inp1 = to_reference(inp1) @@ -547,8 +549,8 @@ def test_accuracy_eq_scalar(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_ge(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) @@ -563,7 +565,7 @@ def test_accuracy_ge(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_ge_scalar(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) inp2 = 0 ref_inp1 = to_reference(inp1) @@ -579,8 +581,8 @@ def test_accuracy_ge_scalar(shape, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) @pytest.mark.parametrize("approximate", ["none", "tanh"]) def test_accuracy_gelu_and_mul(shape, approximate, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1, True) ref_inp2 = to_reference(inp2, True) @@ -597,8 +599,8 @@ def test_accuracy_gelu_and_mul(shape, approximate, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_gt(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) @@ -613,7 +615,7 @@ def test_accuracy_gt(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_gt_scalar(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1) inp2 = 0 @@ -628,8 +630,8 @@ def test_accuracy_gt_scalar(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_le(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) @@ -644,7 +646,7 @@ def test_accuracy_le(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_le_scalar(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) inp2 = 0 ref_inp1 = to_reference(inp1) @@ -659,8 +661,8 @@ def 
test_accuracy_le_scalar(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_lt(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) @@ -675,7 +677,7 @@ def test_accuracy_lt(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_lt_scalar(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) inp2 = 0 ref_inp1 = to_reference(inp1) @@ -690,8 +692,8 @@ def test_accuracy_lt_scalar(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_mul(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1, True) ref_inp2 = to_reference(inp2, True) @@ -707,7 +709,7 @@ def test_accuracy_mul(shape, dtype): @pytest.mark.parametrize("scalar", SCALARS) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_mul_tensor_scalar(shape, scalar, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) inp2 = scalar ref_inp1 = to_reference(inp1, True) @@ -724,7 +726,7 @@ def test_accuracy_mul_tensor_scalar(shape, scalar, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_mul_scalar_tensor(shape, scalar, dtype): inp1 = scalar - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp2 = to_reference(inp2, True) ref_out = torch.mul(inp1, ref_inp2) @@ -758,8 +760,8 @@ def test_accuracy_mul_scalar_scalar(dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_ne(shape, dtype): - inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") - inp2 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + inp1 = torch.randint(0, 10, shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randint(0, 10, shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) @@ -774,7 +776,7 @@ def test_accuracy_ne(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_ne_scalar(shape, dtype): - inp1 = torch.randint(0, 10, shape, dtype=dtype, device="cuda") + inp1 = torch.randint(0, 10, shape, dtype=dtype, device=flag_gems.device) inp2 = 0 ref_inp1 = to_reference(inp1) @@ -789,8 +791,8 @@ def test_accuracy_ne_scalar(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_pow(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1, True) ref_inp2 = 
to_reference(inp2, True) @@ -805,8 +807,8 @@ def test_accuracy_pow(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_maximum(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) @@ -821,8 +823,8 @@ def test_accuracy_maximum(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_minimum(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) @@ -839,7 +841,7 @@ def test_accuracy_minimum(shape, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_pow_scalar_tensor(scalar, shape, dtype): inp1 = scalar - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp2 = to_reference(inp2, True) ref_out = torch.pow(inp1, ref_inp2) @@ -854,7 +856,7 @@ def test_accuracy_pow_scalar_tensor(scalar, shape, dtype): @pytest.mark.parametrize("scalar", SCALARS) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_pow_tensor_scalar(scalar, shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) inp2 = scalar ref_inp1 = to_reference(inp1, True) @@ -870,8 +872,8 @@ def test_accuracy_pow_tensor_scalar(scalar, shape, dtype): @pytest.mark.parametrize("alpha", SCALARS) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_rsub(shape, alpha, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1, True) ref_inp2 = to_reference(inp2, True) @@ -886,8 +888,8 @@ def test_accuracy_rsub(shape, alpha, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_silu_and_mul(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1, True) ref_inp2 = to_reference(inp2, True) @@ -903,8 +905,8 @@ def test_accuracy_silu_and_mul(shape, dtype): @pytest.mark.parametrize("alpha", SCALARS) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_sub(shape, alpha, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1, True) ref_inp2 = to_reference(inp2, True) @@ -921,7 +923,7 @@ def test_accuracy_sub(shape, alpha, dtype): @pytest.mark.parametrize("alpha", SCALARS) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def 
test_accuracy_sub_tensor_scalar(shape, scalar, alpha, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) inp2 = scalar ref_inp1 = to_reference(inp1, True) @@ -939,7 +941,7 @@ def test_accuracy_sub_tensor_scalar(shape, scalar, alpha, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_sub_scalar_tensor(shape, scalar, alpha, dtype): inp1 = scalar - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp2 = to_reference(inp2, True) ref_out = torch.sub(inp1, ref_inp2, alpha=alpha) @@ -975,9 +977,9 @@ def test_accuracy_sub_scalar_scalar(dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_where_self_out_cross_device(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") - cond = torch.randint(0, 2, shape, dtype=torch.bool, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + cond = torch.randint(0, 2, shape, dtype=torch.bool, device=flag_gems.device) import itertools @@ -1002,10 +1004,10 @@ def test_accuracy_where_self_out_cross_device(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_where_self_out(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") - cond = torch.randint(0, 2, shape, dtype=torch.bool, device="cuda") - out = torch.empty(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + cond = torch.randint(0, 2, shape, dtype=torch.bool, device=flag_gems.device) + out = torch.empty(shape, dtype=dtype, device=flag_gems.device) ref_out = to_reference(out) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) @@ -1022,8 +1024,8 @@ def test_accuracy_where_self_out(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_where_self(shape, dtype): - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) @@ -1040,7 +1042,7 @@ def test_accuracy_where_self(shape, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_where_scalar_self(shape, scalar, dtype): inp1 = scalar - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp2 = to_reference(inp2) ref_out = torch.where(ref_inp2 > 0, inp1, ref_inp2) @@ -1056,7 +1058,7 @@ def test_accuracy_where_scalar_self(shape, scalar, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_where_scalar_other(shape, scalar, dtype): inp1 = scalar - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp2 = to_reference(inp2) ref_out = torch.where(ref_inp2 > 0, ref_inp2, inp1) @@ -1075,30 +1077,31 @@ def test_accuracy_where_scalar_other(shape, scalar, dtype): def test_accuracy_isclose(shape, 
dtype, zero_tol, equal_nan, gen_nan): # [gen_nan] 1: nan, 2: inf, 3: -inf, 4: inf vs -inf rtol = ( - torch.rand(1, dtype=torch.float32, device="cuda").item() * 0.0001 + torch.rand(1, dtype=torch.float32, device=flag_gems.device).item() * 0.0001 if not zero_tol else 0 ) if dtype in ALL_FLOAT_DTYPES: - inp1 = torch.randn(shape, dtype=dtype, device="cuda") - inp2 = torch.randn(shape, dtype=dtype, device="cuda") + inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device) if gen_nan: nan_num = torch.full( (1,), float("nan" if gen_nan == 1 else "inf"), dtype=dtype, - device="cuda", + device=flag_gems.device, ) inp1.view(-1)[0] = -nan_num if gen_nan == 3 else nan_num inp2.view(-1)[0] = -nan_num if gen_nan >= 3 else nan_num atol = ( - torch.finfo(dtype).tiny * torch.randint(0, 4, (1,), device="cuda").item() + torch.finfo(dtype).tiny + * torch.randint(0, 4, (1,), device=flag_gems.device).item() if not zero_tol else 0 ) else: - inp1 = torch.randint(-1000, 1000, shape, device="cuda").to(dtype) - inp2 = torch.randint(-1000, 1000, shape, device="cuda").to(dtype) + inp1 = torch.randint(-1000, 1000, shape, device=flag_gems.device).to(dtype) + inp2 = torch.randint(-1000, 1000, shape, device=flag_gems.device).to(dtype) if dtype in [torch.int64]: inp1.view(-1)[0] = 2**63 - 1 inp2.view(-1)[0] = -(2**63) @@ -1125,7 +1128,7 @@ def test_accuracy_isclose(shape, dtype, zero_tol, equal_nan, gen_nan): atol = ( ( torch.finfo(torch.float16).eps - * torch.randint(0, 10, (1,), device="cuda").item() + * torch.randint(0, 10, (1,), device=flag_gems.device).item() ) if not zero_tol else 0 @@ -1172,29 +1175,32 @@ def test_accuracy_isclose(shape, dtype, zero_tol, equal_nan, gen_nan): @pytest.mark.parametrize("gen_nan", [0, 1, 2, 3, 4]) def test_accuracy_allclose(shape, dtype, equal_nan, gen_nan): # [gen_nan] 1: nan, 2: inf, 3: -inf, 4: inf vs -inf - rtol = torch.rand(1, dtype=torch.float32, device="cuda").item() * ( + rtol = torch.rand(1, dtype=torch.float32, device=flag_gems.device).item() * ( 0.0001 if dtype in [torch.bfloat16, torch.float16] else 0.01 ) if dtype in ALL_FLOAT_DTYPES: - atol = torch.finfo(dtype).tiny * torch.randint(0, 4, (1,), device="cuda").item() - inp1 = torch.full(shape, 1.234, dtype=dtype, device="cuda") - inp2 = torch.full(shape, 1.234, dtype=dtype, device="cuda") + atol = ( + torch.finfo(dtype).tiny + * torch.randint(0, 4, (1,), device=flag_gems.device).item() + ) + inp1 = torch.full(shape, 1.234, dtype=dtype, device=flag_gems.device) + inp2 = torch.full(shape, 1.234, dtype=dtype, device=flag_gems.device) if gen_nan: nan_num = torch.full( (1,), float("nan" if gen_nan == 1 else "inf"), dtype=dtype, - device="cuda", + device=flag_gems.device, ) inp1.view(-1)[0] = -nan_num if gen_nan == 3 else nan_num inp2.view(-1)[0] = -nan_num if gen_nan >= 3 else nan_num else: atol = ( torch.finfo(torch.float16).eps - * torch.randint(0, 10, (1,), device="cuda").item() + * torch.randint(0, 10, (1,), device=flag_gems.device).item() ) - inp1 = torch.randint(-1000, 1000, shape, device="cuda").to(dtype) - inp2 = torch.randint(-1000, 1000, shape, device="cuda").to(dtype) + inp1 = torch.randint(-1000, 1000, shape, device=flag_gems.device).to(dtype) + inp2 = torch.randint(-1000, 1000, shape, device=flag_gems.device).to(dtype) ref_inp1 = to_reference(inp1, False) ref_inp2 = to_reference(inp2, False) diff --git a/tests/test_blas_ops.py b/tests/test_blas_ops.py index 4bbedd9c0..a7773f011 100644 --- a/tests/test_blas_ops.py +++ b/tests/test_blas_ops.py @@ 
-20,9 +20,9 @@ @pytest.mark.parametrize("scalar", SCALARS) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_addmm(M, N, K, scalar, dtype): - mat1 = torch.randn((M, K), dtype=dtype, device="cuda") - mat2 = torch.randn((K, N), dtype=dtype, device="cuda") - bias = torch.randn((N,), dtype=dtype, device="cuda") + mat1 = torch.randn((M, K), dtype=dtype, device=flag_gems.device) + mat2 = torch.randn((K, N), dtype=dtype, device=flag_gems.device) + bias = torch.randn((N,), dtype=dtype, device=flag_gems.device) ref_mat1 = to_reference(mat1, True) ref_mat2 = to_reference(mat2, True) ref_bias = to_reference(bias, True) @@ -41,8 +41,8 @@ def test_accuracy_addmm(M, N, K, scalar, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_bmm(M, N, K, dtype): batch = 4 - mat1 = torch.randn((batch, M, K), dtype=dtype, device="cuda") - mat2 = torch.randn((batch, K, N), dtype=dtype, device="cuda") + mat1 = torch.randn((batch, M, K), dtype=dtype, device=flag_gems.device) + mat2 = torch.randn((batch, K, N), dtype=dtype, device=flag_gems.device) ref_mat1 = to_reference(mat1, True) ref_mat2 = to_reference(mat2, True) @@ -58,8 +58,8 @@ def test_accuracy_bmm(M, N, K, dtype): @pytest.mark.parametrize("M, N, K", MNK_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_mm(M, N, K, dtype): - mat1 = torch.randn((M, K), dtype=dtype, device="cuda") - mat2 = torch.randn((K, N), dtype=dtype, device="cuda") + mat1 = torch.randn((M, K), dtype=dtype, device=flag_gems.device) + mat2 = torch.randn((K, N), dtype=dtype, device=flag_gems.device) ref_mat1 = to_reference(mat1, True) ref_mat2 = to_reference(mat2, True) @@ -74,8 +74,8 @@ def test_accuracy_mm(M, N, K, dtype): @pytest.mark.parametrize("M, N", MN_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_mv(M, N, dtype): - matrix = torch.randn((N, M), dtype=dtype, device="cuda") - vector = torch.randn((M,), dtype=dtype, device="cuda") + matrix = torch.randn((N, M), dtype=dtype, device=flag_gems.device) + vector = torch.randn((M,), dtype=dtype, device=flag_gems.device) ref_matrix = to_reference(matrix, True) ref_vector = to_reference(vector, True) @@ -90,8 +90,8 @@ def test_accuracy_mv(M, N, dtype): @pytest.mark.parametrize("M, N", MN_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_outer(M, N, dtype): - inp1 = torch.randn(M, dtype=dtype, device="cuda", requires_grad=True) - inp2 = torch.randn(N, dtype=dtype, device="cuda", requires_grad=True) + inp1 = torch.randn(M, dtype=dtype, device=flag_gems.device, requires_grad=True) + inp2 = torch.randn(N, dtype=dtype, device=flag_gems.device, requires_grad=True) ref_inp1 = to_reference(inp1, True) ref_inp2 = to_reference(inp2, True) diff --git a/tests/test_distribution_ops.py b/tests/test_distribution_ops.py index 04083ea17..4554eff47 100644 --- a/tests/test_distribution_ops.py +++ b/tests/test_distribution_ops.py @@ -7,6 +7,8 @@ from .accuracy_utils import DISTRIBUTION_SHAPES, FLOAT_DTYPES +device = flag_gems.device + @pytest.mark.normal @pytest.mark.parametrize("float", ["none", "mean", "std"]) @@ -16,12 +18,16 @@ def test_accuracy_normal(float, shape, dtype): loc = ( 3.0 if float == "mean" - else torch.full(size=shape, fill_value=3.0, dtype=dtype, device="cuda") + else torch.full( + size=shape, fill_value=3.0, dtype=dtype, device=flag_gems.device + ) ) scale = ( 10.0 if float == "std" - else torch.full(size=shape, fill_value=10.0, dtype=dtype, device="cuda") + else torch.full( + size=shape, fill_value=10.0, dtype=dtype, 
device=flag_gems.device + ) ) with flag_gems.use_gems(): res_out = torch.normal(loc, scale) @@ -35,7 +41,7 @@ def test_accuracy_normal(float, shape, dtype): @pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_uniform(shape, dtype): - x = torch.randn(size=shape, dtype=dtype, device="cuda") + x = torch.randn(size=shape, dtype=dtype, device=flag_gems.device) with flag_gems.use_gems(): x.uniform_(-3, 3) assert (x <= 3.0).all() @@ -46,7 +52,7 @@ def test_accuracy_uniform(shape, dtype): @pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_exponential_(shape, dtype): - x = torch.empty(size=shape, dtype=dtype, device="cuda") + x = torch.empty(size=shape, dtype=dtype, device=flag_gems.device) with flag_gems.use_gems(): x.exponential_() assert x.min() > 0 @@ -59,7 +65,7 @@ def test_accuracy_exponential_(shape, dtype): def test_accuracy_multinomial_with_replacement(shape, dtype, n_samples): # First use multinomial to generate a series of indices, then # use the index counts as the input probabilities (scaled) - rand_indices = torch.multinomial(torch.rand(shape), n_samples, True).to("cuda") + rand_indices = torch.multinomial(torch.rand(shape), n_samples, True).to(device) inp_counts = torch.nn.functional.one_hot(rand_indices).sum(1) with flag_gems.use_gems(): out_indices = torch.multinomial(inp_counts.to(dtype=dtype), n_samples, True) diff --git a/tests/test_general_reduction_ops.py b/tests/test_general_reduction_ops.py index b05fb30ff..41700cf53 100644 --- a/tests/test_general_reduction_ops.py +++ b/tests/test_general_reduction_ops.py @@ -43,9 +43,9 @@ @pytest.mark.parametrize("kind", ["normal", "allTrue"]) def test_accuracy_all_without_dim(shape, dtype, kind): if kind == "allTrue": - inp = torch.ones(shape, dtype=dtype, device="cuda") + inp = torch.ones(shape, dtype=dtype, device=flag_gems.device) else: - inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") + inp = torch.randint(0, 2, shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.all(ref_inp) @@ -61,9 +61,9 @@ def test_accuracy_all_without_dim(shape, dtype, kind): @pytest.mark.parametrize("dtype", FLOAT_DTYPES + [torch.bool]) def test_accuracy_all_dims(shape, dim, keepdim, dtype, kind): if kind == "allTrue": - inp = torch.ones(shape, dtype=dtype, device="cuda") + inp = torch.ones(shape, dtype=dtype, device=flag_gems.device) else: - inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") + inp = torch.randint(0, 2, shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.all(ref_inp, dim=dim, keepdim=keepdim) @@ -79,9 +79,9 @@ def test_accuracy_all_dims(shape, dim, keepdim, dtype, kind): @pytest.mark.parametrize("kind", ["normal", "allFalse"]) def test_accuracy_any_without_dim(shape, dtype, kind): if kind == "allFalse": - inp = torch.zeros(shape, dtype=dtype, device="cuda") + inp = torch.zeros(shape, dtype=dtype, device=flag_gems.device) else: - inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") + inp = torch.randint(0, 2, shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.any(ref_inp) @@ -97,9 +97,9 @@ def test_accuracy_any_without_dim(shape, dtype, kind): @pytest.mark.parametrize("dtype", FLOAT_DTYPES + [torch.bool]) def test_accuracy_any_dims(shape, dim, keepdim, dtype, kind): if kind == "allFalse": - inp = torch.zeros(shape, dtype=dtype, device="cuda") + inp = torch.zeros(shape, 
dtype=dtype, device=flag_gems.device) else: - inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda") + inp = torch.randint(0, 2, shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.any(ref_inp, dim=dim, keepdim=keepdim) @@ -113,7 +113,7 @@ def test_accuracy_any_dims(shape, dim, keepdim, dtype, kind): @pytest.mark.parametrize("shape", REDUCTION_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_max_without_dim(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.max(ref_inp) @@ -127,7 +127,7 @@ def test_accuracy_max_without_dim(shape, dtype): @pytest.mark.parametrize("shape", REDUCTION_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_max_without_dim_uncontiguous(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda")[::2, ::2] + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)[::2, ::2] ref_inp = to_reference(inp) ref_out = torch.max(ref_inp) @@ -143,7 +143,7 @@ def test_accuracy_max_without_dim_uncontiguous(shape, dtype): @pytest.mark.parametrize("keepdim, dim", KEEPDIM_DIM) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_max_dim(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out_value, ref_out_index = torch.max(ref_inp, dim=dim, keepdim=keepdim) @@ -159,7 +159,7 @@ def test_accuracy_max_dim(shape, dim, keepdim, dtype): @pytest.mark.parametrize("keepdim, dim", [(True, 1), (False, 1)]) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_max_dim_big_shape(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out_value, ref_out_index = torch.max(ref_inp, dim=dim, keepdim=keepdim) @@ -174,7 +174,7 @@ def test_accuracy_max_dim_big_shape(shape, dim, keepdim, dtype): @pytest.mark.parametrize("shape", REDUCTION_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_mean_without_dim(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp, True) ref_out = torch.mean(ref_inp) @@ -189,7 +189,7 @@ def test_accuracy_mean_without_dim(shape, dtype): @pytest.mark.parametrize("keepdim, dim", KEEPDIM_DIMS) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_mean_dim(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp, True) ref_out = torch.mean(ref_inp, dim, keepdim) @@ -203,7 +203,7 @@ def test_accuracy_mean_dim(shape, dim, keepdim, dtype): @pytest.mark.parametrize("shape", REDUCTION_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_min_without_dim(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.min(ref_inp) @@ -219,7 +219,7 @@ def test_accuracy_min_without_dim(shape, dtype): @pytest.mark.parametrize("keepdim, dim", KEEPDIM_DIM) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_min_dim(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, 
device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out_value, ref_out_index = torch.min(ref_inp, dim=dim, keepdim=keepdim) @@ -234,7 +234,7 @@ def test_accuracy_min_dim(shape, dim, keepdim, dtype): @pytest.mark.parametrize("shape", REDUCTION_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_prod_without_dim(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp, True) ref_out = torch.prod(ref_inp) @@ -250,7 +250,7 @@ def test_accuracy_prod_without_dim(shape, dtype): @pytest.mark.parametrize("keepdim, dim", KEEPDIM_DIM) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_prod_dim(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp, True) ref_out = torch.prod(ref_inp, dim=dim, keepdim=keepdim) @@ -264,7 +264,7 @@ def test_accuracy_prod_dim(shape, dim, keepdim, dtype): @pytest.mark.parametrize("shape", REDUCTION_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_sum_without_dim(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp, True) ref_out = torch.sum(ref_inp) @@ -279,7 +279,7 @@ def test_accuracy_sum_without_dim(shape, dtype): @pytest.mark.parametrize("keepdim, dim", KEEPDIM_DIM + [(False, []), (True, [])]) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_sum_dim(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp, True) ref_out = torch.sum(ref_inp, dim=dim, keepdim=keepdim) diff --git a/tests/test_libentry.py b/tests/test_libentry.py index e365acfe2..fa3777d04 100644 --- a/tests/test_libentry.py +++ b/tests/test_libentry.py @@ -5,6 +5,8 @@ import triton from triton import language as tl +import flag_gems +from flag_gems.runtime import torch_device_fn from flag_gems.utils import libentry @@ -34,7 +36,7 @@ def softmax_inner_decorator_cascade(x, dim, dtype=None): out = torch.empty_like(inp, dtype=dtype) - with torch.cuda.device(out.device): + with torch_device_fn.device(out.device): grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1) softmax_kernel_inner[grid]( out, @@ -168,19 +170,19 @@ def softmax_kernel_inner( def test_decorator_cascade(): # to test inner decorator can use arguments supplied by outer decorator # and grid function can use arguments supplied by all the decorator - x = torch.randn((128, 128, 128), device="cuda") + x = torch.randn((128, 128, 128), device=flag_gems.device) with not_raises(KeyError): _ = softmax_inner_decorator_cascade(x, dim=2) def test_pass_kernel_arg_via_kw(): - x = torch.randn((128, 128, 128), device="cuda") + x = torch.randn((128, 128, 128), device=flag_gems.device) with not_raises(KeyError): _ = softmax_inner_pass_kernel_arg_via_kw(x, dim=2) def test_kernel_arg_apply_default(): - x = torch.randn((128, 128, 128), device="cuda") + x = torch.randn((128, 128, 128), device=flag_gems.device) with not_raises(KeyError): _ = softmax_inner_kernel_arg_apply_default(x, dim=2) diff --git a/tests/test_norm_ops.py b/tests/test_norm_ops.py index eedb38ce8..e9579a1de 100644 --- a/tests/test_norm_ops.py +++ b/tests/test_norm_ops.py @@ -35,9 +35,15 @@ 
@pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_groupnorm(N, C, H, W, num_groups, dtype): HW = H * W - inp = torch.randn(size=(N, C, H, W), dtype=dtype, device="cuda", requires_grad=True) - weight = torch.randn(size=(C,), dtype=dtype, device="cuda", requires_grad=True) - bias = torch.randn(size=(C,), dtype=dtype, device="cuda", requires_grad=True) + inp = torch.randn( + size=(N, C, H, W), dtype=dtype, device=flag_gems.device, requires_grad=True + ) + weight = torch.randn( + size=(C,), dtype=dtype, device=flag_gems.device, requires_grad=True + ) + bias = torch.randn( + size=(C,), dtype=dtype, device=flag_gems.device, requires_grad=True + ) eps = 1e-5 ref_inp = to_reference(inp, True) @@ -95,9 +101,15 @@ def test_accuracy_layernorm(shape, dtype): layer_shape = [ N, ] - inp = torch.randn(shape[:2], dtype=dtype, device="cuda", requires_grad=True) - weight = torch.randn(layer_shape, dtype=dtype, device="cuda", requires_grad=True) - bias = torch.randn(layer_shape, dtype=dtype, device="cuda", requires_grad=True) + inp = torch.randn( + shape[:2], dtype=dtype, device=flag_gems.device, requires_grad=True + ) + weight = torch.randn( + layer_shape, dtype=dtype, device=flag_gems.device, requires_grad=True + ) + bias = torch.randn( + layer_shape, dtype=dtype, device=flag_gems.device, requires_grad=True + ) eps = 1e-5 ref_inp = to_reference(inp, True) @@ -256,11 +268,11 @@ def test_accuracy_instancenorm( @pytest.mark.parametrize("shape, dtype, dim", WEIGHT_NORM_SHAPE_DTYPE_DIM) def test_accuracy_weightnorm(shape, dtype, dim): dim = dim % len(shape) - v = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True) g = torch.randn( [1 if i != dim else shape[i] for i in range(v.ndim)], dtype=dtype, - device="cuda", + device=flag_gems.device, requires_grad=True, ) reduce_size = v.numel() // shape[dim] @@ -272,7 +284,9 @@ def test_accuracy_weightnorm(shape, dtype, dim): res_w_out = torch._weight_norm(v, g, dim) gems_assert_close(res_w_out, ref_w_out, dtype, reduce_dim=reduce_size) - res_w_grad = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + res_w_grad = torch.randn( + shape, dtype=dtype, device=flag_gems.device, requires_grad=True + ) ref_w_grad = to_reference(res_w_grad, False) ref_v_grad, ref_g_grad = torch.autograd.grad( @@ -289,8 +303,10 @@ def test_accuracy_weightnorm(shape, dtype, dim): @pytest.mark.parametrize("shape, dtype, dim", WEIGHT_NORM_SHAPE_DTYPE_DIM) def test_accuracy_weightnorm_interface(shape, dtype, dim): dim = dim % len(shape) - v = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) - g = torch.randn(shape[dim], dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True) + g = torch.randn( + shape[dim], dtype=dtype, device=flag_gems.device, requires_grad=True + ) reduce_size = v.numel() // shape[dim] ref_v = to_reference(v, True) @@ -326,8 +342,8 @@ def test_accuracy_rmsnorm(shape, dtype): layer_shape = [ N, ] - inp = torch.randn(shape[:2], dtype=dtype, device="cuda") - weight = torch.randn(layer_shape, dtype=dtype, device="cuda") + inp = torch.randn(shape[:2], dtype=dtype, device=flag_gems.device) + weight = torch.randn(layer_shape, dtype=dtype, device=flag_gems.device) eps = 1e-5 ref_inp = to_reference(inp, True) @@ -353,10 +369,10 @@ def test_accuracy_skip_layernorm(shape, dtype): layer_shape = [ N, ] - inp = torch.randn(shape[:2], dtype=dtype, device="cuda") - residual = 
torch.randn(shape[:2], dtype=dtype, device="cuda") - weight = torch.randn(layer_shape, dtype=dtype, device="cuda") - bias = torch.randn(layer_shape, dtype=dtype, device="cuda") + inp = torch.randn(shape[:2], dtype=dtype, device=flag_gems.device) + residual = torch.randn(shape[:2], dtype=dtype, device=flag_gems.device) + weight = torch.randn(layer_shape, dtype=dtype, device=flag_gems.device) + bias = torch.randn(layer_shape, dtype=dtype, device=flag_gems.device) eps = 1e-5 ref_inp = to_reference(inp, True) @@ -386,9 +402,9 @@ def test_accuracy_skip_rmsnorm(shape, dtype): layer_shape = [ N, ] - inp = torch.randn(shape[:2], dtype=dtype, device="cuda") - residual = torch.randn(shape[:2], dtype=dtype, device="cuda") - weight = torch.randn(layer_shape, dtype=dtype, device="cuda") + inp = torch.randn(shape[:2], dtype=dtype, device=flag_gems.device) + residual = torch.randn(shape[:2], dtype=dtype, device=flag_gems.device) + weight = torch.randn(layer_shape, dtype=dtype, device=flag_gems.device) eps = 1e-5 ref_inp = to_reference(inp, True) @@ -423,7 +439,7 @@ def _torch_rms_norm(x, residual, weight, eps): @pytest.mark.parametrize("keepdim, dim", KEEPDIM_DIMS) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_vectornorm(shape, ord, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp, True) ref_out = torch.linalg.vector_norm(ref_inp, ord, dim, keepdim) diff --git a/tests/test_pointwise_dynamic.py b/tests/test_pointwise_dynamic.py index 990b585c9..477c8ec86 100644 --- a/tests/test_pointwise_dynamic.py +++ b/tests/test_pointwise_dynamic.py @@ -2,6 +2,8 @@ import torch import triton +import flag_gems +from flag_gems.runtime import torch_device_fn from flag_gems.utils.pointwise_dynamic import ( CodeGenConfig, FunctionSchema, @@ -158,7 +160,7 @@ def add(x, y): SIZE = 2 for ndim in range(8): shape = [SIZE] * ndim - x = torch.randn(shape, device="cuda") + x = torch.randn(shape, device=flag_gems.device) y = torch.randn_like(x) out = add(x, y) torch.testing.assert_close(out, x + y) @@ -187,7 +189,7 @@ def axpy(x, y, alpha): SIZE = 2 for ndim in range(8): shape = [SIZE] * ndim - x = torch.randn(shape, device="cuda") + x = torch.randn(shape, device=flag_gems.device) y = torch.randn_like(x) alpha = 2.0 out = axpy(x, y, alpha) @@ -218,7 +220,7 @@ def multiple_out(x, y, alpha): SIZE = 2 for ndim in range(8): shape = [SIZE] * ndim - x = torch.randn(shape, device="cuda") + x = torch.randn(shape, device=flag_gems.device) y = torch.randn_like(x) alpha = 2.0 out0, out1 = multiple_out(x, y, alpha) @@ -250,8 +252,8 @@ def axpy(x, y, alpha): return alpha * x + y SIZE = 10 - x = torch.randn([SIZE, 1, SIZE], device="cuda") - y = torch.randn([1, SIZE, 1], device="cuda") + x = torch.randn([SIZE, 1, SIZE], device=flag_gems.device) + y = torch.randn([1, SIZE, 1], device=flag_gems.device) alpha = 2.0 out = axpy(x, y, alpha) torch.testing.assert_close(out, alpha * x + y) @@ -279,8 +281,8 @@ def axpy(x, y, alpha): return alpha * x + y SIZE = 10 - x = torch.randn([SIZE, 1, SIZE], device="cuda") - y = torch.randn([], device="cuda") + x = torch.randn([SIZE, 1, SIZE], device=flag_gems.device) + y = torch.randn([], device=flag_gems.device) alpha = 2.0 out = axpy(x, y, alpha) torch.testing.assert_close(out, alpha * x + y) @@ -307,10 +309,10 @@ def axpy(x, y, alpha): return alpha * x + y SIZE = 10 - x = torch.randn([SIZE, SIZE, SIZE], device="cuda") - y = torch.randn([], device="cuda") + x = 
torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device) + y = torch.randn([], device=flag_gems.device) alpha = 2.0 - o = torch.empty([SIZE, SIZE, SIZE], device="cuda") + o = torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device) out = axpy(x, y, alpha, out0=o) torch.testing.assert_close(out, alpha * x + y) @@ -336,10 +338,10 @@ def axpyaxmy(x, y, alpha): return alpha * x + y, alpha * x - y SIZE = 10 - x = torch.randn([SIZE, SIZE, SIZE], device="cuda") - y = torch.randn([], device="cuda") + x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device) + y = torch.randn([], device=flag_gems.device) alpha = 2.0 - o = torch.empty([SIZE, SIZE, SIZE], device="cuda") + o = torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device) out0, out1 = axpyaxmy(x, y, alpha, out0=o) assert out0 is o torch.testing.assert_close(out0, alpha * x + y) @@ -367,10 +369,10 @@ def axpyaxmy(x, y, alpha): return alpha * x + y, alpha * x - y SIZE = 10 - x = torch.randn([SIZE, SIZE, SIZE], device="cuda") - y = torch.randn([], device="cuda") + x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device) + y = torch.randn([], device=flag_gems.device) alpha = 2.0 - o = torch.empty([SIZE, SIZE, SIZE], device="cuda") + o = torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device) out0, out1 = axpyaxmy(x, y, alpha, out1=o) assert out1 is o torch.testing.assert_close(out0, alpha * x + y) @@ -398,7 +400,7 @@ def invert(x): return ~x SIZE = 10 - x = torch.randn([SIZE, SIZE, SIZE], device="cuda") > 0 + x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device) > 0 notx = invert(x) torch.testing.assert_close(notx, ~x) @@ -425,7 +427,7 @@ def invert(x): return ~x SIZE = 10 - x = torch.randn([SIZE, SIZE, SIZE], device="cuda") > 0 + x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device) > 0 o = torch.empty_like(x) # manually instantiated overload does not handle output allocation # since it is kind of low level @@ -455,10 +457,10 @@ def axpyaxmy(x, y, alpha): return alpha * x + y, alpha * x - y M, N, K = 40, 60, 80 - x = torch.randn([M, N, K], device="cuda")[::2, ::2, ::2] - y = torch.randn([N // 2, K // 2, M // 2], device="cuda").permute(2, 0, 1) + x = torch.randn([M, N, K], device=flag_gems.device)[::2, ::2, ::2] + y = torch.randn([N // 2, K // 2, M // 2], device=flag_gems.device).permute(2, 0, 1) alpha = 2.0 - o = torch.empty([M // 2, N // 2, K // 2], device="cuda") + o = torch.empty([M // 2, N // 2, K // 2], device=flag_gems.device) out0, out1 = axpyaxmy(x, y, alpha, out0=o) assert out0 is o torch.testing.assert_close(out0, alpha * x + y) @@ -486,10 +488,10 @@ def axpyaxmy(x, y, alpha): return alpha * x + y, alpha * x - y M, N, K = 40, 60, 80 - x = torch.randn([M, N, K], device="cuda") - y = torch.randn([N, K, M], device="cuda").permute(2, 0, 1) + x = torch.randn([M, N, K], device=flag_gems.device) + y = torch.randn([N, K, M], device=flag_gems.device).permute(2, 0, 1) alpha = 2.0 - o = torch.empty([M, N, K], device="cuda") + o = torch.empty([M, N, K], device=flag_gems.device) out0, out1 = axpyaxmy(x, y, alpha, out0=o) assert out0 is o torch.testing.assert_close(out0, alpha * x + y) @@ -519,11 +521,11 @@ def axpyaxmy(x, y, alpha): return alpha * x + y, alpha * x - y SIZE = 10 - x = torch.randn([SIZE, SIZE, SIZE], device="cuda") - y = torch.randn([SIZE, SIZE, SIZE], device="cuda") + x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device) + y = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device) alpha = 2.0 - _out0 = torch.empty([SIZE, SIZE, SIZE], device="cuda") - _out1 = StridedBuffer(torch.empty([SIZE, SIZE, 
SIZE], device="cuda")) + _out0 = torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device) + _out1 = StridedBuffer(torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device)) out0, out1 = axpyaxmy.instantiate(3)(x, y, alpha, out0=_out0, out1=_out1) assert isinstance(out0, torch.Tensor) @@ -554,11 +556,11 @@ def axpyaxmy(x, y, alpha): return alpha * x + y, alpha * x - y SIZE = 10 - x = torch.randn([SIZE, SIZE, SIZE], device="cuda") - y = torch.randn([1, SIZE], device="cuda") + x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device) + y = torch.randn([1, SIZE], device=flag_gems.device) alpha = 2.0 - _out0 = torch.empty([SIZE, SIZE, SIZE], device="cuda") - _out1 = StridedBuffer(torch.empty([SIZE, SIZE, SIZE], device="cuda")) + _out0 = torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device) + _out1 = StridedBuffer(torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device)) with pytest.raises(Exception): out0, out1 = axpyaxmy.instantiate(3)(x, y, alpha, out0=_out0, out1=_out1) @@ -588,11 +590,11 @@ def axpyaxmy(x, y, alpha): return alpha * x + y, alpha * x - y SIZE = 10 - x = torch.randn([SIZE, SIZE, SIZE], device="cuda") - y = torch.randn([SIZE, 1, SIZE], device="cuda") + x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device) + y = torch.randn([SIZE, 1, SIZE], device=flag_gems.device) alpha = 2.0 - _out0 = torch.empty([SIZE, SIZE, SIZE], device="cuda") - _out1 = StridedBuffer(torch.empty([SIZE, SIZE, SIZE], device="cuda")) + _out0 = torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device) + _out1 = StridedBuffer(torch.empty([SIZE, SIZE, SIZE], device=flag_gems.device)) with pytest.raises(Exception): out0, out1 = axpyaxmy.instantiate(3)(x, y, alpha, out0=_out0, out1=_out1) @@ -622,8 +624,8 @@ def axpyaxmy(x, y, alpha): return alpha * x + y, alpha * x - y SIZE = 10 - x = torch.randn([SIZE, SIZE, SIZE], device="cuda") - y = torch.randn([SIZE, 1, SIZE], device="cuda") + x = torch.randn([SIZE, SIZE, SIZE], device=flag_gems.device) + y = torch.randn([SIZE, 1, SIZE], device=flag_gems.device) alpha = 2.0 with pytest.raises(Exception): @@ -650,14 +652,14 @@ def add(x, y): SIZE = 2 for ndim in range(8): shape = [SIZE] * ndim - x = torch.randn(shape, device="cuda") + x = torch.randn(shape, device=flag_gems.device) y = torch.randn_like(x) out = add(x, y) torch.testing.assert_close(out, x + y) @pytest.mark.skipif( - torch.cuda.get_device_properties(0).total_memory < (80 * 1024**3), + torch_device_fn.get_device_properties(0).total_memory < (80 * 1024**3), reason="This test requires a lot of memory.", ) @pytest.mark.parametrize("use_block_pointer", USE_BLOCK_POINTER) @@ -675,7 +677,7 @@ def test_dynamic_function_int64_index(use_block_pointer): def f(x): return x * 2.0 - x = torch.randn((2, 1024, 1024, 1024), dtype=torch.float16, device="cuda") + x = torch.randn((2, 1024, 1024, 1024), dtype=torch.float16, device=flag_gems.device) y1 = f(x) y2 = x * 2.0 torch.testing.assert_close(y1, y2) @@ -700,7 +702,7 @@ def add(x, y): return x + y shape = () - x = torch.randn(shape, device="cuda") + x = torch.randn(shape, device=flag_gems.device) y = torch.randn_like(x) out = add(x, y) torch.testing.assert_close(out, x + y) @@ -723,7 +725,7 @@ def f(x): return x * 2.0 shape = (0, 10) - x = torch.randn(shape, device="cuda") + x = torch.randn(shape, device=flag_gems.device) out = f(x) torch.testing.assert_close(out, x * 2.0) @@ -747,7 +749,7 @@ def f(x, y): return x * 2.0 + y shape = (0, 10) - x = torch.randn(shape, device="cuda") + x = torch.randn(shape, device=flag_gems.device) y = torch.randn_like(x) out = f(x, y) 
torch.testing.assert_close(out, x * 2.0 + y) diff --git a/tests/test_pointwise_type_promotion.py b/tests/test_pointwise_type_promotion.py index 4b81c1b54..66d1d3332 100644 --- a/tests/test_pointwise_type_promotion.py +++ b/tests/test_pointwise_type_promotion.py @@ -19,8 +19,8 @@ @pytest.mark.parametrize("alpha", SCALARS) @pytest.mark.parametrize("float_type", FLOAT_DTYPES) def test_type_promotion_default(shape, alpha, float_type): - inp1 = torch.randint(10, shape, device="cuda") - inp2 = torch.randn(shape, dtype=float_type, device="cuda") + inp1 = torch.randint(10, shape, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=float_type, device=flag_gems.device) ref_inp1 = to_reference(inp1, True) ref_inp2 = to_reference(inp2, True) # arg0:int arg1:float @@ -38,8 +38,8 @@ def test_type_promotion_default(shape, alpha, float_type): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("float_type", FLOAT_DTYPES) def test_type_promotion_no_opmath(shape, float_type): - inp1 = torch.randint(10, shape, device="cuda") - inp2 = torch.randn(shape, dtype=float_type, device="cuda") + inp1 = torch.randint(10, shape, device=flag_gems.device) + inp2 = torch.randn(shape, dtype=float_type, device=flag_gems.device) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) # arg0:bool arg1:int arg2:float @@ -59,7 +59,7 @@ def test_type_promotion_no_opmath(shape, float_type): @pytest.mark.parametrize("float_type", FLOAT_DTYPES) def test_type_promotion_int_to_float(shape, float_type): # arg0:float - inp_float = torch.randn(shape, dtype=float_type, device="cuda") + inp_float = torch.randn(shape, dtype=float_type, device=flag_gems.device) ref_inp = to_reference(inp_float) ref_out = torch.sin(ref_inp) with flag_gems.use_gems(): @@ -67,7 +67,7 @@ def test_type_promotion_int_to_float(shape, float_type): gems_assert_close(res_out, ref_out, float_type) # arg0:int - inp_int = torch.randint(10, shape, device="cuda") + inp_int = torch.randint(10, shape, device=flag_gems.device) ref_inp_int = to_reference(inp_int) ref_out = torch.sin(ref_inp_int) with flag_gems.use_gems(): @@ -78,8 +78,8 @@ def test_type_promotion_int_to_float(shape, float_type): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) def test_type_promotion_always_bool(shape): # arg0:int arg0:int - inp1 = torch.randint(0, 10, shape, device="cuda") - inp2 = torch.randint(0, 10, shape, device="cuda") + inp1 = torch.randint(0, 10, shape, device=flag_gems.device) + inp2 = torch.randint(0, 10, shape, device=flag_gems.device) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) ref_out = torch.eq(ref_inp1, ref_inp2) @@ -92,7 +92,7 @@ def test_type_promotion_always_bool(shape): @pytest.mark.parametrize("float_type", FLOAT_DTYPES) def test_type_promotion_complex_to_long(shape, float_type): # arg0:float - inp = torch.randn(shape, dtype=float_type, device="cuda") + inp = torch.randn(shape, dtype=float_type, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.abs(ref_inp) with flag_gems.use_gems(): @@ -100,7 +100,7 @@ def test_type_promotion_complex_to_long(shape, float_type): gems_assert_equal(res_out, ref_out) # arg0:int - inp1 = torch.randint(0, 10, shape, device="cuda") + inp1 = torch.randint(0, 10, shape, device=flag_gems.device) ref_inp1 = to_reference(inp1) ref_out1 = torch.abs(ref_inp1) with flag_gems.use_gems(): @@ -111,8 +111,8 @@ def test_type_promotion_complex_to_long(shape, float_type): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("float_dtype", FLOAT_DTYPES) def 
test_type_promotion_bool_to_long(shape, float_dtype): - inp1 = torch.randn(shape, dtype=float_dtype, device="cuda") - inp2 = torch.randint(0, 10, shape, device="cuda") + inp1 = torch.randn(shape, dtype=float_dtype, device=flag_gems.device) + inp2 = torch.randint(0, 10, shape, device=flag_gems.device) ref_inp1 = to_reference(inp1) ref_inp2 = to_reference(inp2) # arg0: float arg1: int diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index bbaa390f4..f3fbc7037 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -68,7 +68,7 @@ @pytest.mark.parametrize("keepdim, dim, shape", KEEPDIM_DIMS_SHAPE) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_amax(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.amax(ref_inp, dim=dim, keepdim=keepdim) @@ -85,7 +85,7 @@ def test_accuracy_amax(shape, dim, keepdim, dtype): @pytest.mark.parametrize("keepdim", [True, False]) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_argmax(shape, dim, keepdim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.argmax(ref_inp, dim=dim, keepdim=keepdim) @@ -108,13 +108,13 @@ def test_accuracy_cross_entropy_loss_indices( target_shape = list(shape) del target_shape[dim] - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) - target = torch.randint(0, up_limit, target_shape, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True) + target = torch.randint(0, up_limit, target_shape, device=flag_gems.device) ref_inp = to_reference(inp, True) ref_target = to_reference(target) if weight: - wgt = torch.randn(shape[dim], dtype=dtype, device="cuda") + wgt = torch.randn(shape[dim], dtype=dtype, device=flag_gems.device) ref_wgt = to_reference(wgt, True) else: wgt = None @@ -152,9 +152,9 @@ def test_accuracy_cross_entropy_loss_probabilities( shape, dtype, reduction, label_smoothing ): dim = 1 - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) - target = torch.randn(shape, dtype=dtype, device="cuda") - weight = torch.randn(shape[dim], dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True) + target = torch.randn(shape, dtype=dtype, device=flag_gems.device) + weight = torch.randn(shape[dim], dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp, True) ref_target = to_reference(target, True) ref_weight = to_reference(weight, True) @@ -192,9 +192,9 @@ def test_accuracy_cross_entropy_loss_probabilities( def test_accuracy_cumsum(shape, dtype): dim = 1 if shape == REDUCTION_SHAPES[-1] else -1 if dtype in INT_DTYPES: - inp = torch.randint(-3, 3, shape, device="cuda").to(dtype) + inp = torch.randint(-3, 3, shape, device=flag_gems.device).to(dtype) else: - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp, True) ref_out = torch.cumsum(ref_inp, dim=dim) @@ -239,11 +239,13 @@ def test_accuracy_cummin(shape, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES + [torch.bool]) def test_accuracy_nonzero(shape, dtype): if dtype == torch.bool: - inp = torch.randint(0, 2, shape, dtype=torch.int, device="cuda").to(dtype) + inp = torch.randint(0, 2, shape, 
dtype=torch.int, device=flag_gems.device).to( + dtype + ) elif dtype in INT_DTYPES: - inp = torch.randint(-3, 3, shape, device="cuda").to(dtype) + inp = torch.randint(-3, 3, shape, device=flag_gems.device).to(dtype) else: - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp, False) ref_out = torch.nonzero(ref_inp) @@ -276,7 +278,7 @@ def test_accuracy_count_nonzero(shape, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_log_softmax(shape, dtype): dim = 1 - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True) ref_inp = to_reference(inp, True) ref_out = torch.nn.functional.log_softmax(ref_inp, dim=dim) @@ -300,7 +302,7 @@ def test_accuracy_log_softmax(shape, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) @pytest.mark.parametrize("dim", DIM_LIST) def test_accuracy_softmax(shape, dtype, dim): - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True) ref_inp = to_reference(inp, True) ref_out = torch.nn.functional.softmax(ref_inp, dim=dim) @@ -323,7 +325,7 @@ def test_accuracy_softmax(shape, dtype, dim): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) @pytest.mark.parametrize("dim", DIM_LIST) def test_accuracy_softmax_with_neg_inf(shape, dtype, dim): - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True) inp = torch.where(inp < 0.0, float("-inf"), inp) ref_inp = to_reference(inp, True) @@ -351,7 +353,7 @@ def test_accuracy_softmax_with_neg_inf(shape, dtype, dim): def test_accuracy_varmean(shape, dim, correction, keepdim, dtype): if shape[0] == 1: # TODO: res is inf, while ref is nan shape = (2, 2) - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp, True) ref_var, ref_mean = torch.var_mean( @@ -376,8 +378,8 @@ def test_accuracy_varmean(shape, dim, correction, keepdim, dtype): @pytest.mark.parametrize("dim", [0, 1, 2]) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_scatter_src(src_shape, inp_shape, dim, dtype): - inp = torch.randn(inp_shape, dtype=dtype, device="cuda") - src = torch.randn(src_shape, dtype=dtype, device="cuda") + inp = torch.randn(inp_shape, dtype=dtype, device=flag_gems.device) + src = torch.randn(src_shape, dtype=dtype, device=flag_gems.device) size_dim = min(src_shape[dim], inp_shape[dim]) import random @@ -387,7 +389,7 @@ def test_accuracy_scatter_src(src_shape, inp_shape, dim, dtype): random.randint(1, min(src_shape[1], inp_shape[1])), random.randint(1, min(src_shape[2], inp_shape[2])), ] - index = torch.empty(tuple(index_shape), dtype=torch.long, device="cuda") + index = torch.empty(tuple(index_shape), dtype=torch.long, device=flag_gems.device) m, n, o = index_shape @@ -420,8 +422,8 @@ def test_accuracy_scatter_src(src_shape, inp_shape, dim, dtype): @pytest.mark.parametrize("dim", [0, 1, 2]) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_scatter_add(src_shape, inp_shape, dim, dtype): - inp = torch.randn(inp_shape, dtype=dtype, device="cuda") - src = torch.randn(src_shape, dtype=dtype, device="cuda") + inp = torch.randn(inp_shape, dtype=dtype, device=flag_gems.device) + src = torch.randn(src_shape, dtype=dtype, 
device=flag_gems.device) size_dim = min(src_shape[dim], inp_shape[dim]) import random @@ -431,7 +433,7 @@ def test_accuracy_scatter_add(src_shape, inp_shape, dim, dtype): random.randint(1, min(src_shape[1], inp_shape[1])), random.randint(1, min(src_shape[2], inp_shape[2])), ] - index = torch.empty(tuple(index_shape), dtype=torch.long, device="cuda") + index = torch.empty(tuple(index_shape), dtype=torch.long, device=flag_gems.device) m, n, o = index_shape @@ -464,8 +466,8 @@ def test_accuracy_scatter_add(src_shape, inp_shape, dim, dtype): @pytest.mark.parametrize("dim", [0, 1, 2]) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_scatter_mul(src_shape, inp_shape, dim, dtype): - inp = torch.randn(inp_shape, dtype=dtype, device="cuda") - src = torch.randn(src_shape, dtype=dtype, device="cuda") + inp = torch.randn(inp_shape, dtype=dtype, device=flag_gems.device) + src = torch.randn(src_shape, dtype=dtype, device=flag_gems.device) size_dim = min(src_shape[dim], inp_shape[dim]) import random @@ -475,7 +477,7 @@ def test_accuracy_scatter_mul(src_shape, inp_shape, dim, dtype): random.randint(1, min(src_shape[1], inp_shape[1])), random.randint(1, min(src_shape[2], inp_shape[2])), ] - index = torch.empty(tuple(index_shape), dtype=torch.long, device="cuda") + index = torch.empty(tuple(index_shape), dtype=torch.long, device=flag_gems.device) m, n, o = index_shape @@ -506,7 +508,7 @@ def test_accuracy_scatter_mul(src_shape, inp_shape, dim, dtype): @pytest.mark.parametrize("dim", [0, 1, 2]) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_gather(inp_shape, dim, dtype): - inp = torch.randn(inp_shape, dtype=dtype, device="cuda") + inp = torch.randn(inp_shape, dtype=dtype, device=flag_gems.device) size_dim = inp_shape[dim] import random @@ -516,7 +518,7 @@ def test_accuracy_gather(inp_shape, dim, dtype): random.randint(1, inp_shape[1]), random.randint(1, inp_shape[2]), ] - index = torch.empty(tuple(index_shape), dtype=torch.long, device="cuda") + index = torch.empty(tuple(index_shape), dtype=torch.long, device=flag_gems.device) m, n, o = index_shape @@ -547,11 +549,11 @@ def test_accuracy_select_scatter(shape, dim, dtype): import random index = random.randint(0, shape[dim] - 1) - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) src_shape = list(inp.shape) del src_shape[dim] - src = torch.randn(src_shape, dtype=dtype, device="cuda") + src = torch.randn(src_shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_src = to_reference(src) @@ -569,7 +571,7 @@ def test_accuracy_select_scatter(shape, dim, dtype): @pytest.mark.parametrize("end", [1024, 256]) @pytest.mark.parametrize("step", [1, 2]) def test_accuracy_slice_scatter_v2(shape, stride, dim, dtype, start, end, step): - inp = torch.empty_strided(shape, stride, dtype=dtype, device="cuda") + inp = torch.empty_strided(shape, stride, dtype=dtype, device=flag_gems.device) inp.copy_(1) valid_shape = list(inp.shape) @@ -585,7 +587,7 @@ def test_accuracy_slice_scatter_v2(shape, stride, dim, dtype, start, end, step): valid_shape[dim] = (end - start + step - 1) // step - src = torch.rand(valid_shape, dtype=dtype, device="cuda") + src = torch.rand(valid_shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_src = to_reference(src) @@ -607,7 +609,7 @@ def test_accuracy_slice_scatter_v2(shape, stride, dim, dtype, start, end, step): @pytest.mark.parametrize("end", [1024, 256]) @pytest.mark.parametrize("step", [1, 2]) def 
test_accuracy_slice_scatter_fallback(shape, stride, dim, dtype, start, end, step): - inp = torch.empty_strided(shape, stride, dtype=dtype, device="cuda") + inp = torch.empty_strided(shape, stride, dtype=dtype, device=flag_gems.device) inp.copy_(1) valid_shape = list(inp.shape) @@ -623,7 +625,7 @@ def test_accuracy_slice_scatter_fallback(shape, stride, dim, dtype, start, end, valid_shape[dim] = (end - start + step - 1) // step - src = torch.rand(valid_shape, dtype=dtype, device="cuda") + src = torch.rand(valid_shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_src = to_reference(src) @@ -644,11 +646,13 @@ def test_accuracy_slice_scatter_fallback(shape, stride, dim, dtype, start, end, @pytest.mark.parametrize("dim", DIM_LIST) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_index_select(shape, dim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) index_size = inp.size(dim) from math import floor - index = torch.randint(0, index_size, [floor(index_size * 0.8)], device="cuda") + index = torch.randint( + 0, index_size, [floor(index_size * 0.8)], device=flag_gems.device + ) ref_inp = to_reference(inp) ref_index = to_reference(index) @@ -663,8 +667,8 @@ def test_accuracy_index_select(shape, dim, dtype): @pytest.mark.parametrize("threshold, shape", THRESHOLD_SHAPE) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_masked_select(shape, dtype, threshold): - inp = torch.randn(shape, dtype=dtype, device="cuda") - mask = torch.randn(shape, dtype=dtype, device="cuda") < threshold + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + mask = torch.randn(shape, dtype=dtype, device=flag_gems.device) < threshold ref_inp = to_reference(inp) ref_mask = to_reference(mask) diff --git a/tests/test_special_ops.py b/tests/test_special_ops.py index 0c84937a8..9b3e111e3 100644 --- a/tests/test_special_ops.py +++ b/tests/test_special_ops.py @@ -24,6 +24,8 @@ ) from .conftest import TO_CPU +device = flag_gems.device + # TODO: sometimes failed at (8192,), 0.6, bfloat16 @pytest.mark.dropout @@ -34,7 +36,7 @@ def test_accuracy_dropout(shape, p, dtype): if TO_CPU or shape == (1,): shape = (32768,) - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True) ref_inp = to_reference(inp) # NOTE: ensure that scalars are float32(instead of float64) @@ -78,7 +80,7 @@ def test_accuracy_dropout(shape, p, dtype): ), f"num_equal: {num_equal}, exp_equal: {exp_equal}, num_total: {inp.numel()}" -def get_rope_cos_sin(max_seq_len, dim, dtype, base=10000, device="cuda"): +def get_rope_cos_sin(max_seq_len, dim, dtype, base=10000, device=flag_gems.device): inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) t = torch.arange(max_seq_len, device=device, dtype=inv_freq.dtype) freqs = torch.outer(t, inv_freq) @@ -156,14 +158,16 @@ def test_apply_rotary_pos_emb( ): seq_len = torch.randint(1, max_seq_len, (1,)).item() q = torch.randn( - (batch_size, seq_len, q_heads, head_dim), dtype=dtype, device="cuda" + (batch_size, seq_len, q_heads, head_dim), dtype=dtype, device=flag_gems.device ) k = torch.randn( - (batch_size, seq_len, k_heads, head_dim), dtype=dtype, device="cuda" + (batch_size, seq_len, k_heads, head_dim), dtype=dtype, device=flag_gems.device ) - position_ids = torch.randint(0, max_seq_len, (batch_size, seq_len), device="cuda") - cos, sin = 
get_rope_cos_sin(max_seq_len, head_dim, dtype, device="cuda") + position_ids = torch.randint( + 0, max_seq_len, (batch_size, seq_len), device=flag_gems.device + ) + cos, sin = get_rope_cos_sin(max_seq_len, head_dim, dtype, device=flag_gems.device) ref_q = to_reference(q, True) ref_k = to_reference(k, True) @@ -204,10 +208,10 @@ def test_apply_rotary_pos_emb( @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_embedding(EmbeddingSize, Batch, M, N, padding_idx, scale_grad_by_freq, dtype): indices = torch.randint( - 0, EmbeddingSize, (Batch, M), device="cuda", requires_grad=False + 0, EmbeddingSize, (Batch, M), device=flag_gems.device, requires_grad=False ) embedding = torch.randn( - (EmbeddingSize, N), device="cuda", dtype=dtype, requires_grad=True + (EmbeddingSize, N), device=flag_gems.device, dtype=dtype, requires_grad=True ) ref_embedding = to_reference(embedding) ref_indices = to_reference(indices) @@ -233,7 +237,7 @@ def test_embedding(EmbeddingSize, Batch, M, N, padding_idx, scale_grad_by_freq, @pytest.mark.parametrize("shape", SPECIAL_SHAPES) @pytest.mark.parametrize("dtype", [torch.cfloat]) def test_accuracy_resolve_neg(shape, dtype): - x = torch.randn(size=shape, dtype=dtype, device="cuda") + x = torch.randn(size=shape, dtype=dtype, device=flag_gems.device) y = x.conj() z = y.imag assert z.is_neg() @@ -255,7 +259,7 @@ def test_topk( largest, dtype, ): - x = torch.arange(hiddensize, dtype=dtype, device="cuda") + x = torch.arange(hiddensize, dtype=dtype, device=flag_gems.device) x = x.repeat(batch_size).reshape(batch_size, hiddensize) # Each row use different shuffled index. @@ -276,7 +280,7 @@ def test_topk( @pytest.mark.parametrize("shape", SPECIAL_SHAPES) @pytest.mark.parametrize("dtype", [torch.cfloat]) def test_accuracy_resolve_conj(shape, dtype): - x = torch.randn(size=shape, dtype=dtype, device="cuda") + x = torch.randn(size=shape, dtype=dtype, device=flag_gems.device) y = x.conj() assert y.is_conj() with flag_gems.use_gems(): @@ -292,9 +296,9 @@ def test_accuracy_resolve_conj(shape, dtype): @pytest.mark.parametrize("return_counts", [False, True]) def test_accuracy_unique(shape, dtype, sorted, return_inverse, return_counts): if dtype in FLOAT_DTYPES: - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) else: - inp = torch.randint(-10, 10, shape, device="cuda").to(dtype) + inp = torch.randint(-10, 10, shape, device=flag_gems.device).to(dtype) ref_inp = to_reference(inp, False) if return_counts: @@ -371,14 +375,14 @@ def test_accuracy_unique(shape, dtype, sorted, return_inverse, return_counts): @pytest.mark.parametrize("n_samples", [1000]) def test_accuracy_multinomial_with_replacement(shape, dtype, n_samples): if shape[-1] == 1: - dist = torch.rand(size=shape, dtype=dtype, device="cuda") + dist = torch.rand(size=shape, dtype=dtype, device=flag_gems.device) with flag_gems.use_gems(): res_out = torch.multinomial(dist, n_samples, True) assert torch.all(res_out == 0) else: # Mask p% off of the categories and test the sampling results fall in the rest for p in (0.1, 0.5, 0.9): - dist = torch.rand(size=shape, dtype=dtype, device="cuda") + dist = torch.rand(size=shape, dtype=dtype, device=flag_gems.device) dist[torch.rand(shape) < p] = 0 # Make sure there's at least one non-zero probability dist[..., -1] = 0.5 @@ -393,7 +397,7 @@ def test_accuracy_multinomial_with_replacement(shape, dtype, n_samples): @pytest.mark.parametrize("pool", UT_SHAPES_2D) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def 
test_accuracy_multinomial_without_replacement(pool, dtype): - dist = torch.rand(size=pool, dtype=dtype, device="cuda") + dist = torch.rand(size=pool, dtype=dtype, device=flag_gems.device) k = pool[-1] if k > 1: ns = [k // 2, k] @@ -413,7 +417,7 @@ def test_accuracy_multinomial_without_replacement(pool, dtype): @pytest.mark.parametrize("pad_mode", ["constant", "reflect", "replicate", "circular"]) @pytest.mark.parametrize("contiguous", [True, False]) def test_pad(shape, dtype, pad_mode, contiguous): - x = torch.randn(size=shape, dtype=dtype, device="cuda") + x = torch.randn(size=shape, dtype=dtype, device=flag_gems.device) if not contiguous: x = x[::2, ::2] @@ -455,7 +459,7 @@ def test_pad(shape, dtype, pad_mode, contiguous): ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) def test_upsample_bicubic2d_aa(dtype, shape, scale, align_corners): - input = torch.rand(shape, dtype=dtype, device="cuda") + input = torch.rand(shape, dtype=dtype, device=flag_gems.device) ref_i = to_reference(input, True) output_size = tuple([int(input.shape[i + 2] * scale[i]) for i in range(2)]) ref_out = torch._C._nn._upsample_bicubic2d_aa( @@ -480,7 +484,7 @@ def span(scale): @pytest.mark.parametrize("shape", UPSAMPLE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_upsample_nearest2d(dtype, shape, scale): - input = torch.randn(shape, dtype=dtype, device="cuda") + input = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_i = to_reference(input).to(torch.float32) output_size = [int(input.shape[i + 2] * scale[i]) for i in range(2)] ref_out = torch._C._nn.upsample_nearest2d(ref_i, output_size=output_size).to(dtype) @@ -494,7 +498,7 @@ def test_upsample_nearest2d(dtype, shape, scale): @pytest.mark.parametrize("step", [1, 2, 5]) @pytest.mark.parametrize("end", [128, 256, 1024]) @pytest.mark.parametrize("dtype", FLOAT_DTYPES + ALL_INT_DTYPES + [None]) -@pytest.mark.parametrize("device", ["cuda", None]) +@pytest.mark.parametrize("device", [device, None]) @pytest.mark.parametrize( "pin_memory", [False, None] ) # Since triton only target to GPU, pin_memory only used in CPU tensors. 
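Note (not part of the patch): the hunks above and below all apply the same conversion, so a minimal sketch of the resulting device-agnostic test pattern may help reviewers. It assumes only what this change introduces — `flag_gems.device`, `flag_gems.use_gems()`, and `flag_gems.runtime.torch_device_fn` as a stand-in for `torch.cuda`; the test name, shapes, and tolerances below are illustrative, not taken from the patch.

# Hedged sketch of the pattern applied throughout tests/: allocate tensors on
# flag_gems.device instead of a hard-coded "cuda" string, and route
# device-runtime calls through torch_device_fn so the same test runs on any
# supported backend.
import pytest
import torch

import flag_gems
from flag_gems.runtime import torch_device_fn

device = flag_gems.device  # e.g. "cuda" on NVIDIA, another backend name elsewhere


@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_accuracy_example_add(dtype):
    # hypothetical test mirroring the structure used in the hunks above
    inp1 = torch.randn((128, 128), dtype=dtype, device=flag_gems.device)
    inp2 = torch.randn((128, 128), dtype=dtype, device=flag_gems.device)
    ref_out = inp1.to(torch.float32) + inp2.to(torch.float32)  # reference in fp32
    with flag_gems.use_gems():
        res_out = inp1 + inp2
    torch.testing.assert_close(
        res_out.to(torch.float32), ref_out, rtol=1e-3, atol=1e-3
    )


def launch_with_device_guard(out):
    # device-guard pattern from the test_libentry.py hunk: torch_device_fn.device
    # replaces torch.cuda.device so kernel launches stay backend-neutral
    with torch_device_fn.device(out.device):
        pass  # a Triton kernel launch would go here

The sketch shows the design choice behind the patch: tests never name a backend directly, so adding a new backend only requires the runtime layer to resolve `flag_gems.device` and `torch_device_fn`, with no further edits to the test suite.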
@@ -518,10 +522,10 @@ def test_arange(start, step, end, dtype, device, pin_memory): @pytest.mark.parametrize("assume_unique", [False, True]) @pytest.mark.parametrize("invert", [False, True]) def test_accuracy_isin(shape, dtype, assume_unique, invert): - inp1 = torch.randint(-100, 100, shape, device="cuda").to(dtype) + inp1 = torch.randint(-100, 100, shape, device=flag_gems.device).to(dtype) test_numel = inp1.numel() // 2 if inp1.numel() > 1 else 1 test_shape = (test_numel,) - inp2 = torch.randint(-10, 10, test_shape, device="cuda").to(dtype) + inp2 = torch.randint(-10, 10, test_shape, device=flag_gems.device).to(dtype) inp1.ravel()[-1] = 0 if assume_unique: inp1 = torch.unique(inp1) @@ -546,7 +550,7 @@ def test_accuracy_isin(shape, dtype, assume_unique, invert): ref2_out = torch.isin(ref_inp1, inp2_s, assume_unique=assume_unique, invert=invert) gems_assert_equal(res2_out, ref2_out) - inp0 = torch.tensor([], device="cuda") + inp0 = torch.tensor([], device=flag_gems.device) ref_inp0 = to_reference(inp0, False) with flag_gems.use_gems(): res0_out = torch.isin(inp0, inp2, assume_unique=assume_unique, invert=invert) @@ -562,7 +566,7 @@ def test_accuracy_isin(shape, dtype, assume_unique, invert): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_fill(value, shape, dtype): # Test fill.Scalar - x = torch.ones(shape, device="cuda", dtype=dtype) + x = torch.ones(shape, device=flag_gems.device, dtype=dtype) ref_x = to_reference(x, False) ref_out = torch.fill(ref_x, value) @@ -572,7 +576,7 @@ def test_fill(value, shape, dtype): gems_assert_equal(res_out, ref_out) # Test fill.Tensor - value_tensor = torch.tensor(value, device="cuda", dtype=dtype) + value_tensor = torch.tensor(value, device=flag_gems.device, dtype=dtype) ref_out_tensor = torch.fill(ref_x, value_tensor) with flag_gems.use_gems(): res_out_tensor = torch.fill(x, value_tensor) @@ -586,12 +590,12 @@ def test_fill(value, shape, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES) def test_accuracy_stack(shape, dim, dtype): if dtype in FLOAT_DTYPES: - inp = [torch.randn(s, dtype=dtype, device="cuda") for s in shape] + inp = [torch.randn(s, dtype=dtype, device=flag_gems.device) for s in shape] else: inp = [ - torch.randint(low=0, high=0x7FFF, size=s, dtype=dtype, device="cuda").to( - dtype - ) + torch.randint( + low=0, high=0x7FFF, size=s, dtype=dtype, device=flag_gems.device + ).to(dtype) for s in shape ] ref_inp = [to_reference(_) for _ in inp] @@ -614,12 +618,12 @@ def test_accuracy_stack(shape, dim, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES) def test_accuracy_hstack(shape, dtype): if dtype in FLOAT_DTYPES: - inp = [torch.randn(s, dtype=dtype, device="cuda") for s in shape] + inp = [torch.randn(s, dtype=dtype, device=flag_gems.device) for s in shape] else: inp = [ - torch.randint(low=0, high=0x7FFF, size=s, dtype=dtype, device="cuda").to( - dtype - ) + torch.randint( + low=0, high=0x7FFF, size=s, dtype=dtype, device=flag_gems.device + ).to(dtype) for s in shape ] ref_inp = [to_reference(_) for _ in inp] @@ -641,12 +645,12 @@ def test_accuracy_hstack(shape, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES) def test_exception_hstack(shape, dtype): if dtype in FLOAT_DTYPES: - inp = [torch.randn(s, dtype=dtype, device="cuda") for s in shape] + inp = [torch.randn(s, dtype=dtype, device=flag_gems.device) for s in shape] else: inp = [ - torch.randint(low=0, high=0x7FFF, size=s, dtype=dtype, device="cuda").to( - dtype - ) + torch.randint( + low=0, high=0x7FFF, size=s, dtype=dtype, 
device=flag_gems.device + ).to(dtype) for s in shape ] @@ -694,12 +698,12 @@ def gen_cat_shapes_dim(shapes): @pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES) def test_accuracy_cat(shape, dim, dtype): if dtype in FLOAT_DTYPES: - inp = [torch.randn(s, dtype=dtype, device="cuda") for s in shape] + inp = [torch.randn(s, dtype=dtype, device=flag_gems.device) for s in shape] else: inp = [ - torch.randint(low=0, high=0x7FFF, size=s, dtype=dtype, device="cuda").to( - dtype - ) + torch.randint( + low=0, high=0x7FFF, size=s, dtype=dtype, device=flag_gems.device + ).to(dtype) for s in shape ] ref_inp = [to_reference(_) for _ in inp] @@ -721,7 +725,7 @@ def test_accuracy_cat(shape, dim, dtype): ) @pytest.mark.parametrize("dtype", [torch.float32]) def test_accuracy_cat_empty_tensor(shape, dim, dtype): - inp = [torch.randn(s, dtype=dtype, device="cuda") for s in shape] + inp = [torch.randn(s, dtype=dtype, device=flag_gems.device) for s in shape] ref_inp = [to_reference(_) for _ in inp] ref_out = torch.cat(ref_inp, dim) @@ -749,12 +753,12 @@ def test_accuracy_cat_empty_tensor(shape, dim, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES) def test_accuracy_vstack(shape, dtype): if dtype in FLOAT_DTYPES: - inp = [torch.randn(s, dtype=dtype, device="cuda") for s in shape] + inp = [torch.randn(s, dtype=dtype, device=flag_gems.device) for s in shape] else: inp = [ - torch.randint(low=0, high=0x7FFF, size=s, dtype=dtype, device="cuda").to( - dtype - ) + torch.randint( + low=0, high=0x7FFF, size=s, dtype=dtype, device=flag_gems.device + ).to(dtype) for s in shape ] ref_inp = [to_reference(_) for _ in inp] @@ -780,7 +784,7 @@ def test_accuracy_vstack(shape, dtype): @pytest.mark.parametrize("dim", REPEAT_INTERLEAVE_DIM) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_repeat_interleave_self_int(shape, dim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) repeats = 2 ref_inp = to_reference(inp) @@ -795,7 +799,7 @@ def test_accuracy_repeat_interleave_self_int(shape, dim, dtype): @pytest.mark.parametrize("dim", REPEAT_INTERLEAVE_DIM) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_repeat_interleave_self_int_non_contiguous(shape, dim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda")[::2] + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)[::2] repeats = 2 ref_inp = to_reference(inp) @@ -809,7 +813,7 @@ def test_accuracy_repeat_interleave_self_int_non_contiguous(shape, dim, dtype): @pytest.mark.parametrize("shape", UT_SHAPES_1D) @pytest.mark.parametrize("dtype", [torch.int32]) def test_accuracy_repeat_interleave_tensor(shape, dtype): - repeats = torch.randint(0, 30, shape, dtype=dtype, device="cuda") + repeats = torch.randint(0, 30, shape, dtype=dtype, device=flag_gems.device) ref_repeats = to_reference(repeats) ref_out = torch.repeat_interleave(ref_repeats) @@ -823,8 +827,8 @@ def test_accuracy_repeat_interleave_tensor(shape, dtype): @pytest.mark.parametrize("dim", [-1, 0, 1]) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_repeat_interleave_self_tensor(shape, dim, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") - repeats = torch.randint(0, 30, (shape[dim],), device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + repeats = torch.randint(0, 30, (shape[dim],), device=flag_gems.device) ref_inp = to_reference(inp) ref_repeats = to_reference(repeats) @@ -840,11 +844,11 @@ def 
test_accuracy_repeat_interleave_self_tensor(shape, dim, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES + BOOL_TYPES) def test_accuracy_diag(shape, diagonal, dtype): if dtype in FLOAT_DTYPES: - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) elif dtype in BOOL_TYPES: - inp = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") + inp = torch.randint(0, 2, size=shape, dtype=dtype, device=flag_gems.device) else: - inp = torch.randint(0, 0x7FFF, size=shape, dtype=dtype, device="cuda") + inp = torch.randint(0, 0x7FFF, size=shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.diag(ref_inp, diagonal) @@ -881,11 +885,15 @@ def get_diag_embed_shape_and_dims(): @pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES + BOOL_TYPES) def test_accuracy_diag_embed(shape, dtype, offset, dim1, dim2): if dtype in FLOAT_DTYPES: - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) elif dtype in INT_DTYPES: - inp = torch.randint(low=0, high=0x7FFF, size=shape, dtype=dtype, device="cuda") + inp = torch.randint( + low=0, high=0x7FFF, size=shape, dtype=dtype, device=flag_gems.device + ) else: - inp = torch.randint(low=0, high=2, size=shape, dtype=dtype, device="cuda") + inp = torch.randint( + low=0, high=2, size=shape, dtype=dtype, device=flag_gems.device + ) ref_inp = to_reference(inp) diff --git a/tests/test_tensor_constructor_ops.py b/tests/test_tensor_constructor_ops.py index 1e1a8742d..1403d1e0e 100644 --- a/tests/test_tensor_constructor_ops.py +++ b/tests/test_tensor_constructor_ops.py @@ -14,13 +14,15 @@ ) from .conftest import TO_CPU +device = flag_gems.device + @pytest.mark.rand @pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_rand(shape, dtype): with flag_gems.use_gems(): - res_out = torch.rand(shape, dtype=dtype, device="cuda") + res_out = torch.rand(shape, dtype=dtype, device=device) assert (res_out <= 1.0).all() assert (res_out >= 0.0).all() @@ -30,7 +32,7 @@ def test_accuracy_rand(shape, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_randn(shape, dtype): with flag_gems.use_gems(): - res_out = torch.randn(shape, dtype=dtype, device="cuda") + res_out = torch.randn(shape, dtype=dtype, device=device) mean = torch.mean(res_out) std = torch.std(res_out) assert torch.abs(mean) < 0.01 @@ -41,7 +43,7 @@ def test_accuracy_randn(shape, dtype): @pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_rand_like(shape, dtype): - x = torch.randn(size=shape, dtype=dtype, device="cuda") + x = torch.randn(size=shape, dtype=dtype, device=device) with flag_gems.use_gems(): res_out = torch.rand_like(x) assert (res_out <= 1.0).all() @@ -52,7 +54,7 @@ def test_accuracy_rand_like(shape, dtype): @pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_randn_like(shape, dtype): - x = torch.randn(size=shape, dtype=dtype, device="cuda") + x = torch.randn(size=shape, dtype=dtype, device=device) with flag_gems.use_gems(): res_out = torch.randn_like(x) mean = torch.mean(res_out) @@ -67,14 +69,14 @@ def test_accuracy_randn_like(shape, dtype): def test_accuracy_zeros(shape, dtype): # without dtype with flag_gems.use_gems(): - res_out = torch.zeros(shape, device="cuda") - gems_assert_equal(res_out, 
torch.zeros(shape, device="cpu" if TO_CPU else "cuda")) + res_out = torch.zeros(shape, device=flag_gems.device) + gems_assert_equal(res_out, torch.zeros(shape, device="cpu" if TO_CPU else device)) # with dtype with flag_gems.use_gems(): - res_out = torch.zeros(shape, dtype=dtype, device="cuda") + res_out = torch.zeros(shape, dtype=dtype, device=flag_gems.device) gems_assert_equal( - res_out, torch.zeros(shape, dtype=dtype, device="cpu" if TO_CPU else "cuda") + res_out, torch.zeros(shape, dtype=dtype, device="cpu" if TO_CPU else device) ) @@ -84,14 +86,14 @@ def test_accuracy_zeros(shape, dtype): def test_accuracy_ones(shape, dtype): # without dtype with flag_gems.use_gems(): - res_out = torch.ones(shape, device="cuda") - gems_assert_equal(res_out, torch.ones(shape, device="cpu" if TO_CPU else "cuda")) + res_out = torch.ones(shape, device=flag_gems.device) + gems_assert_equal(res_out, torch.ones(shape, device="cpu" if TO_CPU else device)) # with dtype with flag_gems.use_gems(): - res_out = torch.ones(shape, dtype=dtype, device="cuda") + res_out = torch.ones(shape, dtype=dtype, device=flag_gems.device) gems_assert_equal( - res_out, torch.ones(shape, dtype=dtype, device="cpu" if TO_CPU else "cuda") + res_out, torch.ones(shape, dtype=dtype, device="cpu" if TO_CPU else device) ) @@ -101,17 +103,17 @@ def test_accuracy_ones(shape, dtype): @pytest.mark.parametrize("fill_value", [3.1415926, 2, False]) def test_accuracy_full(shape, dtype, fill_value): # without dtype - ref_out = torch.full(shape, fill_value, device="cpu" if TO_CPU else "cuda") + ref_out = torch.full(shape, fill_value, device="cpu" if TO_CPU else device) with flag_gems.use_gems(): - res_out = torch.full(shape, fill_value, device="cuda") + res_out = torch.full(shape, fill_value, device=flag_gems.device) gems_assert_equal(res_out, ref_out) # with dtype ref_out = torch.full( - shape, fill_value, dtype=dtype, device="cpu" if TO_CPU else "cuda" + shape, fill_value, dtype=dtype, device="cpu" if TO_CPU else device ) with flag_gems.use_gems(): - res_out = torch.full(shape, fill_value, dtype=dtype, device="cuda") + res_out = torch.full(shape, fill_value, dtype=dtype, device=flag_gems.device) gems_assert_equal(res_out, ref_out) @@ -119,7 +121,7 @@ def test_accuracy_full(shape, dtype, fill_value): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_zeros_like(shape, dtype): - x = torch.empty(size=shape, dtype=dtype, device="cpu" if TO_CPU else "cuda") + x = torch.empty(size=shape, dtype=dtype, device="cpu" if TO_CPU else device) with flag_gems.use_gems(): res_out = torch.zeros_like(x) gems_assert_equal(res_out, torch.zeros_like(x)) @@ -129,7 +131,7 @@ def test_accuracy_zeros_like(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_ones_like(shape, dtype): - x = torch.empty(size=shape, dtype=dtype, device="cpu" if TO_CPU else "cuda") + x = torch.empty(size=shape, dtype=dtype, device="cpu" if TO_CPU else device) with flag_gems.use_gems(): res_out = torch.ones_like(x) gems_assert_equal(res_out, torch.ones_like(x)) @@ -141,7 +143,7 @@ def test_accuracy_ones_like(shape, dtype): @pytest.mark.parametrize("xdtype", BOOL_TYPES + ALL_INT_DTYPES + ALL_FLOAT_DTYPES) @pytest.mark.parametrize("fill_value", [3.1415926, 2, False]) def test_accuracy_full_like(shape, dtype, xdtype, fill_value): - x = torch.empty(size=shape, dtype=xdtype, device="cpu" if TO_CPU else "cuda") + x = torch.empty(size=shape, dtype=xdtype, 
device="cpu" if TO_CPU else device) # without dtype with flag_gems.use_gems(): @@ -161,9 +163,9 @@ def test_accuracy_randperm(n, dtype): if n > torch.iinfo(torch.int16).max and dtype == torch.int16: return - ref_out = torch.randperm(n, dtype=dtype, device="cpu" if TO_CPU else "cuda") + ref_out = torch.randperm(n, dtype=dtype, device="cpu" if TO_CPU else device) with flag_gems.use_gems(): - res_out = torch.randperm(n, dtype=dtype, device="cuda") + res_out = torch.randperm(n, dtype=dtype, device=flag_gems.device) sorted_ref, _ = torch.sort(ref_out) sorted_res, _ = torch.sort(res_out) gems_assert_equal(sorted_res, sorted_ref) diff --git a/tests/test_tensor_wrapper.py b/tests/test_tensor_wrapper.py index 30eb3eedc..b56cdf9ac 100644 --- a/tests/test_tensor_wrapper.py +++ b/tests/test_tensor_wrapper.py @@ -2,6 +2,7 @@ import triton from triton import language as tl +import flag_gems from flag_gems.utils import tensor_wrapper @@ -16,8 +17,8 @@ def double(in_ptr, out_ptr, n, TILE_SIZE: tl.constexpr): def test_typed_pointer(): - real = torch.randn(10, 10, device="cuda") - imag = torch.randn(10, 10, device="cuda") + real = torch.randn(10, 10, device=flag_gems.device) + imag = torch.randn(10, 10, device=flag_gems.device) x = torch.complex(real, imag) out = torch.empty_like(x) @@ -34,8 +35,8 @@ def test_typed_pointer(): def test_typed_pointer_reinterpret_with_offset(): - real = torch.randn(100, device="cuda") - imag = torch.randn(100, device="cuda") + real = torch.randn(100, device=flag_gems.device) + imag = torch.randn(100, device=flag_gems.device) x = torch.complex(real, imag) out = torch.empty_like(x) @@ -55,7 +56,7 @@ def test_typed_pointer_reinterpret_with_offset(): def test_typed_pointer_as_is(): - x = torch.randn(100, device="cuda") + x = torch.randn(100, device=flag_gems.device) out = torch.empty_like(x) TILE_SIZE = 128 k = 10 @@ -71,7 +72,7 @@ def test_typed_pointer_as_is(): def test_strided_buffer_slice(): - x = torch.randn(100, 100, device="cuda") + x = torch.randn(100, 100, device=flag_gems.device) x_buffer = tensor_wrapper.StridedBuffer(x, (10, 10), (100, 1)) assert x_buffer.size() == (10, 10) assert x.element_size() == x.element_size() diff --git a/tests/test_unary_pointwise_ops.py b/tests/test_unary_pointwise_ops.py index 22192feee..f3aceaf54 100644 --- a/tests/test_unary_pointwise_ops.py +++ b/tests/test_unary_pointwise_ops.py @@ -22,7 +22,7 @@ @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_abs(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.abs(ref_inp) @@ -37,10 +37,10 @@ def test_accuracy_abs(shape, dtype): @pytest.mark.parametrize("dtype", INT_DTYPES + BOOL_TYPES) def test_accuracy_bitwisenot(shape, dtype): if dtype in BOOL_TYPES: - inp = torch.randint(0, 2, size=shape, dtype=dtype, device="cuda") + inp = torch.randint(0, 2, size=shape, dtype=dtype, device=flag_gems.device) else: inp = torch.randint( - low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device="cuda" + low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device=flag_gems.device ) ref_inp = to_reference(inp) @@ -55,7 +55,7 @@ def test_accuracy_bitwisenot(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_cos(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) 
ref_inp = to_reference(inp, True) ref_out = torch.cos(ref_inp) @@ -69,7 +69,7 @@ def test_accuracy_cos(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_exp(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp, True) ref_out = torch.exp(ref_inp) @@ -84,7 +84,7 @@ def test_accuracy_exp(shape, dtype): @pytest.mark.parametrize("dtype", FLOAT_DTYPES) @pytest.mark.parametrize("approximate", ["none", "tanh"]) def test_accuracy_gelu(shape, dtype, approximate): - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True) ref_inp = to_reference(inp, True) ref_out = torch.nn.functional.gelu(ref_inp, approximate=approximate) @@ -105,7 +105,7 @@ def test_accuracy_gelu(shape, dtype, approximate): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_isinf(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) inp = torch.masked_fill(inp, inp > 1.0, -float("inf")) ref_inp = to_reference(inp) @@ -120,7 +120,7 @@ def test_accuracy_isinf(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_isnan(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) inp = torch.masked_fill(inp, inp > 1.0, float("nan")) ref_inp = to_reference(inp) @@ -135,7 +135,7 @@ def test_accuracy_isnan(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_neg(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.neg(ref_inp) @@ -149,7 +149,7 @@ def test_accuracy_neg(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_reciprocal(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp, True) ref_out = torch.reciprocal(ref_inp) @@ -163,7 +163,7 @@ def test_accuracy_reciprocal(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_relu(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True) ref_inp = to_reference(inp, True) ref_out = torch.nn.functional.relu(ref_inp) @@ -184,7 +184,7 @@ def test_accuracy_relu(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_rsqrt(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp, True) ref_out = torch.rsqrt(ref_inp) @@ -198,7 +198,7 @@ def test_accuracy_rsqrt(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_sigmoid(shape, dtype): - inp = torch.randn(shape, dtype=dtype, 
device="cuda", requires_grad=True) + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True) ref_inp = to_reference(inp, True) ref_out = torch.sigmoid(ref_inp) @@ -219,7 +219,7 @@ def test_accuracy_sigmoid(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_silu(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True) ref_inp = to_reference(inp, True) ref_out = torch.nn.functional.silu(ref_inp) @@ -240,7 +240,7 @@ def test_accuracy_silu(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_sin(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp, True) ref_out = torch.sin(ref_inp) @@ -254,7 +254,7 @@ def test_accuracy_sin(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_tanh(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True) ref_inp = to_reference(inp, True) ref_out = torch.tanh(ref_inp) @@ -278,7 +278,7 @@ def test_accuracy_tanh(shape, dtype): @pytest.mark.parametrize("shape, diagonal", SHAPE_DIAGONAL) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_triu(shape, diagonal, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) inp = unsqueeze_tensor(inp, 2) ref_inp = to_reference(inp) @@ -293,7 +293,7 @@ def test_accuracy_triu(shape, diagonal, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_erf(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.erf(ref_inp) @@ -307,7 +307,7 @@ def test_accuracy_erf(shape, dtype): @pytest.mark.parametrize("shape", POINTWISE_SHAPES) @pytest.mark.parametrize("dtype", ALL_FLOAT_DTYPES) def test_accuracy_isfinite(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) inp = torch.masked_fill(inp, inp > 1.0, float("inf")) inp = torch.masked_fill(inp, inp < -1.0, float("-inf")) inp = torch.masked_fill(inp, (inp > -0.1) & (inp < 0.1), float("nan")) @@ -337,9 +337,9 @@ def get_max_ndim(shape, dims): @pytest.mark.parametrize("dims", FLIP_DIMS) def test_accuracy_flip_general(shape, dtype, dims): if dtype in ALL_FLOAT_DTYPES: - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) else: - inp = torch.randint(-1000, 1000, shape, device="cuda").to(dtype) + inp = torch.randint(-1000, 1000, shape, device=flag_gems.device).to(dtype) max_ndim = get_max_ndim(shape, dims) inp = unsqueeze_tensor(inp, max_ndim) ref_inp = to_reference(inp, False) @@ -361,11 +361,11 @@ def test_accuracy_flip_with_non_dense_input(shape, dtype, dims): shape_dialted = tuple(item * 2 for item in shape) if dtype in ALL_FLOAT_DTYPES: - inp = torch.randn(shape_dialted, dtype=dtype, device="cuda")[::2, ::2] + inp = torch.randn(shape_dialted, dtype=dtype, 
device=flag_gems.device)[::2, ::2] else: - inp = torch.randint(-1000, 1000, shape_dialted, device="cuda").to(dtype)[ - ::2, ::2 - ] + inp = torch.randint(-1000, 1000, shape_dialted, device=flag_gems.device).to( + dtype + )[::2, ::2] ref_inp = to_reference(inp, False) with flag_gems.use_gems(): @@ -380,11 +380,15 @@ def test_accuracy_flip_with_non_dense_input(shape, dtype, dims): @pytest.mark.parametrize("threshold", [0.3, 0.5, 0.7]) @pytest.mark.parametrize( "value", - [torch.tensor(1024, device="cuda"), torch.scalar_tensor(1024, device="cuda"), 1024], + [ + torch.tensor(1024, device=flag_gems.device), + torch.scalar_tensor(1024, device=flag_gems.device), + 1024, + ], ) def test_accuracy_masked_fill(shape, dtype, threshold, value): - inp = torch.zeros(shape, dtype=dtype, device="cuda") - mask = torch.randn(shape, dtype=dtype, device="cuda") < threshold + inp = torch.zeros(shape, dtype=dtype, device=flag_gems.device) + mask = torch.randn(shape, dtype=dtype, device=flag_gems.device) < threshold ref_inp = to_reference(inp) ref_mask = to_reference(mask) @@ -404,11 +408,15 @@ def test_accuracy_masked_fill(shape, dtype, threshold, value): @pytest.mark.parametrize("threshold", [0.3, 0.5, 0.7]) @pytest.mark.parametrize( "value", - [torch.tensor(1024, device="cuda"), torch.scalar_tensor(1024, device="cuda"), 1024], + [ + torch.tensor(1024, device=flag_gems.device), + torch.scalar_tensor(1024, device=flag_gems.device), + 1024, + ], ) def test_accuracy_masked_fill_(shape, dtype, threshold, value): - inp = torch.zeros(shape, dtype=dtype, device="cuda") - mask = torch.randn(shape, dtype=dtype, device="cuda") < threshold + inp = torch.zeros(shape, dtype=dtype, device=flag_gems.device) + mask = torch.randn(shape, dtype=dtype, device=flag_gems.device) < threshold ref_inp = to_reference(inp) ref_mask = to_reference(mask) @@ -430,7 +438,7 @@ def test_accuracy_masked_fill_(shape, dtype, threshold, value): @pytest.mark.parametrize("dims", TILE_DIMS) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_tile(shape, dims, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) ref_out = torch.tile(ref_inp, dims) @@ -448,7 +456,7 @@ def test_accuracy_tile(shape, dims, dtype): @pytest.mark.parametrize("sizes", REPEAT_SIZES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_repeat(shape, sizes, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda") + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) ref_inp = to_reference(inp) sizes = unsqueeze_tuple(sizes, inp.ndim)
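
The conversions above only touch the construction lines of each test, but they all share one flow: build inputs on `flag_gems.device`, derive a reference via the suite's `to_reference` helper (called in the hunks, defined elsewhere), run the operator once in eager PyTorch and once under `flag_gems.use_gems()`, then compare. A self-contained sketch of that flow follows; `make_reference` and the plain `torch.testing.assert_close` call are hypothetical stand-ins for the suite's helpers, whose exact behavior may differ.

```python
# Self-contained sketch of the shared accuracy-test flow -- illustrative only.
# make_reference is a hypothetical stand-in for the suite's to_reference helper.
import pytest
import torch

import flag_gems


def make_reference(t: torch.Tensor, upcast: bool = False) -> torch.Tensor:
    # Detach to CPU and optionally upcast so the eager reference is computed
    # in higher precision than a half/bfloat16 input.
    ref = t.detach().cpu()
    return ref.to(torch.float32) if upcast else ref


@pytest.mark.parametrize("shape", [(1024, 1024)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_accuracy_cos_sketch(shape, dtype):
    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)
    ref_out = torch.cos(make_reference(inp, upcast=True))

    with flag_gems.use_gems():
        res_out = torch.cos(inp)

    torch.testing.assert_close(
        res_out.cpu().to(torch.float32), ref_out, rtol=1e-3, atol=1e-3
    )
```
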