[Muti_backend]part_2_device (#344)
* new feature, muti_backend

* update auto_tune_module

* update auto_tune_module

* update auto_tune_module

* update __init__

* rebase

* fix bug

* modify auto_tune_config

* fix bug

* fix bug

* update

* update

* update scatter&gather

* fix auto_tune

* add gen_torch_device_fn

* fix codestyle

* fix codestyle

* Modify code based on comments

* Modify gen_impl with loops instead of recursion

* Update code structure

* Polish code

* update

* Polish code

* Modify code based on comments

* modify based on comment

* Modify code based on comments

* update

* final fix

* modify_device_unify

* fix bug

* update

* remove '-> list'

* modify

* modify

* modify

* fix bug
Galaxy1458 authored Dec 12, 2024
1 parent 3a740e0 commit 9de501e
Showing 103 changed files with 843 additions and 699 deletions.
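Every change below follows one pattern: hard-coded `"cuda"` strings become `flag_gems.device`, direct `torch.cuda.*` calls are routed through `flag_gems.runtime.torch_device_fn`, and backend settings go through `torch_backend_device`, so a non-CUDA backend can be swapped in at a single point. A minimal sketch of the idea behind such a shim, assuming the backend maps to a `torch.<vendor>` namespace (hypothetical resolver; the real `flag_gems.runtime` code is not part of this diff):

```python
import torch

def _resolve_torch_device_fn(vendor_name: str = "cuda"):
    # Hypothetical: pick the torch device namespace (torch.cuda, torch.xpu, ...)
    # matching the detected vendor, defaulting to torch.cuda.
    return getattr(torch, vendor_name, torch.cuda)

torch_device_fn = _resolve_torch_device_fn()
torch_device_fn.synchronize()   # same call site shape as torch.cuda.synchronize()
torch_device_fn.empty_cache()   # same call site shape as torch.cuda.empty_cache()
```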
6 changes: 3 additions & 3 deletions README.md
@@ -138,16 +138,16 @@ pip install .
 import flag_gems

 M, N, K = 1024, 1024, 1024
-A = torch.randn((M, K), dtype=torch.float16, device="cuda")
-B = torch.randn((K, N), dtype=torch.float16, device="cuda")
+A = torch.randn((M, K), dtype=torch.float16, device=flag_gems.device)
+B = torch.randn((K, N), dtype=torch.float16, device=flag_gems.device)
 with flag_gems.use_gems():
     C = torch.mm(A, B)
 ```

 ### Execute

 1. Test Operator Accuracy
-   - Run reference on cuda
+   - Run reference on specific backend like cuda
 ```shell
 cd tests
 pytest test_xx_ops.py
6 changes: 3 additions & 3 deletions README_cn.md
@@ -136,16 +136,16 @@ pip install .
 import flag_gems

 M, N, K = 1024, 1024, 1024
-A = torch.randn((M, K), dtype=torch.float16, device="cuda")
-B = torch.randn((K, N), dtype=torch.float16, device="cuda")
+A = torch.randn((M, K), dtype=torch.float16, device=flag_gems.device)
+B = torch.randn((K, N), dtype=torch.float16, device=flag_gems.device)
 with flag_gems.use_gems():
     C = torch.mm(A, B)
 ```

 ### Execute

 1. Operator Accuracy Test
-   - Run the reference implementation on CUDA
+   - Run the reference implementation on a specific backend such as CUDA
 ```shell
 cd tests
 pytest test_xx_ops.py
15 changes: 10 additions & 5 deletions benchmark/conftest.py
@@ -5,6 +5,9 @@
 import pytest
 import torch

+import flag_gems
+from flag_gems.runtime import torch_device_fn
+
 from .attri_util import (
     ALL_AVAILABLE_METRICS,
     BOOL_DTYPES,
@@ -17,6 +20,8 @@
     get_recommended_shapes,
 )

+device = flag_gems.device
+

 class BenchConfig:
     def __init__(self):
@@ -38,12 +43,12 @@ def pytest_addoption(parser):
     parser.addoption(
         "--mode",
         action="store",
-        default="cuda",
+        default=device,
         required=False,
-        choices=["cuda", "cpu"],
+        choices=[device, "cpu"],
         help=(
             "Specify how to measure latency, "
-            "'cpu' for CPU-side measurement or 'cuda' for GPU-side measurement."
+            f"'cpu' for CPU-side measurement or {device} for GPU-side measurement."
         ),
     )

@@ -186,13 +191,13 @@ def setup_once(request):
@pytest.fixture(scope="function", autouse=True)
def clear_function_cache():
yield
torch.cuda.empty_cache()
torch_device_fn.empty_cache()


@pytest.fixture(scope="module", autouse=True)
def clear_module_cache():
yield
torch.cuda.empty_cache()
torch_device_fn.empty_cache()


@pytest.fixture()
12 changes: 7 additions & 5 deletions benchmark/performance_utils.py
@@ -9,6 +9,7 @@
 import yaml

 import flag_gems
+from flag_gems.runtime import torch_backend_device, torch_device_fn

 from .attri_util import (
     BOOL_DTYPES,
@@ -24,11 +25,12 @@
 )
 from .conftest import Config

-torch.backends.cuda.matmul.allow_tf32 = False
+torch_backend_device.matmul.allow_tf32 = False
+device = flag_gems.device


 class Benchmark:
-    device: str = "cuda"
+    device: str = device
     DEFAULT_METRICS = DEFAULT_METRICS
     DEFAULT_DTYPES = FLOAT_DTYPES
     DEFAULT_SHAPES = DEFAULT_SHAPES
@@ -191,11 +193,11 @@ def get_latency(self, op, *args, **kwargs):
         if Config.cpu_mode:
             for i in range(Config.warm_up):
                 fn()
-            torch.cuda.synchronize()
+            torch_device_fn.synchronize()
             start = time.time()
             for i in range(Config.repetition):
                 fn()
-            torch.cuda.synchronize()
+            torch_device_fn.synchronize()
             end = time.time()
             latency = (end - start) / Config.repetition * 1000
         else:
@@ -309,7 +311,7 @@ def run(self):
                 level=Config.bench_level.value,
                 op_name=self.op_name,
                 dtype=str(dtype),
-                mode="cpu" if Config.cpu_mode else "cuda",
+                mode="cpu" if Config.cpu_mode else device,
                 result=metrics,
             )
             print(result)
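Note that `get_latency` above times the CPU branch with `time.time()` around explicit `synchronize()` calls. For contrast, a device-side measurement (the `else:` branch is not shown in this hunk) would typically rely on device events; a hedged sketch, assuming a CUDA-like backend that exposes `Event`:

```python
import torch

def device_side_latency_ms(fn, warmup=25, rep=100):
    # Warm up first so compilation and caching do not skew the numbers.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(rep):
        fn()
    end.record()
    torch.cuda.synchronize()              # wait for the recorded events to complete
    return start.elapsed_time(end) / rep  # average latency in milliseconds
```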
8 changes: 5 additions & 3 deletions examples/model_bert_test.py
@@ -6,6 +6,8 @@

 import flag_gems

+device = flag_gems.device
+

 @pytest.mark.parametrize(
     "prompt",
@@ -16,16 +18,16 @@ def test_accuracy_bert(prompt, dtype):
     config = BertConfig()
     model = BertModel(config)
     tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
-    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)

     ref_model = copy.deepcopy(model)
-    ref_model.to(torch.float64).to("cuda").eval()
+    ref_model.to(torch.float64).to(device).eval()
     ref_inputs = copy.deepcopy(inputs).to(torch.float64)
     with torch.no_grad():
         ref_outputs = ref_model(**ref_inputs).last_hidden_state.to(dtype)

     res_model = copy.deepcopy(model)
-    res_model.to(dtype).to("cuda").eval()
+    res_model.to(dtype).to(device).eval()
     res_inputs = copy.deepcopy(inputs).to(dtype)
     with flag_gems.use_gems():
         with torch.no_grad():
6 changes: 4 additions & 2 deletions examples/model_llama_test.py
@@ -4,6 +4,8 @@

 import flag_gems

+device = flag_gems.device
+

 @pytest.mark.parametrize(
     "prompt",
@@ -13,8 +15,8 @@ def test_accuracy_llama(prompt):
     tokenizer = AutoTokenizer.from_pretrained("sharpbai/Llama-2-7b-hf")
     model = AutoModelForCausalLM.from_pretrained("sharpbai/Llama-2-7b-hf")

-    model.to("cuda").eval()
-    inputs = tokenizer(prompt, return_tensors="pt").to(device="cuda")
+    model.to(device).eval()
+    inputs = tokenizer(prompt, return_tensors="pt").to(device=device)
     with torch.no_grad():
         ref_output = model.generate(**inputs, max_length=100, num_beams=5)

8 changes: 6 additions & 2 deletions examples/model_llava_test.py
@@ -6,6 +6,8 @@

 import flag_gems

+device = flag_gems.device
+

 @pytest.mark.parametrize(
     "prompt", ["USER: <image>\nWhat's the content of the image? ASSISTANT:"]
@@ -22,9 +24,11 @@ def test_accuracy_llava(prompt, url):
     model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
     processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
     torch.manual_seed(1234)
-    model.to("cuda").eval()
+    model.to(device).eval()
     image = Image.open(requests.get(url, stream=True).raw)
-    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device="cuda")
+    inputs = processor(text=prompt, images=image, return_tensors="pt").to(
+        device=flag_gems.device
+    )

     with torch.no_grad():
         ref_output = model(**inputs).logits
11 changes: 4 additions & 7 deletions src/flag_gems/fused/gelu_and_mul.py
@@ -4,15 +4,12 @@
 import triton
 import triton.language as tl

+from ..runtime.moduel_tool import tl_extra_module
 from ..utils import pointwise_dynamic

-try:
-    from triton.language.extra.cuda.libdevice import erf, pow, tanh
-except ImportError:
-    try:
-        from triton.language.math import erf, pow, tanh
-    except ImportError:
-        from triton.language.libdevice import erf, pow, tanh
+erf = tl_extra_module.erf
+pow = tl_extra_module.pow
+tanh = tl_extra_module.tanh


 @pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
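The deleted `try`/`except` chain probed three import paths for `erf`, `pow`, and `tanh`; `tl_extra_module` centralizes that probing in the runtime. A sketch of such a resolver, reconstructed from the removed imports (the actual `runtime.moduel_tool` implementation may differ):

```python
def _load_tl_extra_module():
    # Mirror the old fallback chain: vendor libdevice first, then the
    # generic triton.language.math namespace, then the legacy location.
    try:
        from triton.language.extra.cuda import libdevice as mod
    except ImportError:
        try:
            from triton.language import math as mod
        except ImportError:
            from triton.language import libdevice as mod
    return mod

tl_extra_module = _load_tl_extra_module()
erf, pow, tanh = tl_extra_module.erf, tl_extra_module.pow, tl_extra_module.tanh
```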
3 changes: 2 additions & 1 deletion src/flag_gems/fused/rotary_embedding.py
@@ -4,6 +4,7 @@
 import triton
 import triton.language as tl

+from ..runtime import torch_device_fn
 from ..utils import libentry
 from ..utils import triton_lang_extension as tle

@@ -163,7 +164,7 @@ def apply_rotary_pos_emb(
     padded_head_dim = max(triton.next_power_of_2(head_dim), 16)

     grid = (n_tokens,)
-    with torch.cuda.device(q_embed.device):
+    with torch_device_fn.device(q_embed.device):
         apply_rotary_pos_emb_kernel[grid](
             q_embed,
             k_embed,
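The same launch-site change recurs in every op file below: each Triton kernel call is wrapped in a device context manager so it runs on the input tensor's device, with `torch_device_fn.device(...)` standing in for `torch.cuda.device(...)`. A minimal sketch of what such a guard does, written against the CUDA API purely for illustration:

```python
import contextlib
import torch

@contextlib.contextmanager
def device_guard(device):
    # Minimal stand-in for torch.cuda.device: switch the active device
    # for the duration of the block, then restore the previous one.
    prev = torch.cuda.current_device()
    try:
        torch.cuda.set_device(device)
        yield
    finally:
        torch.cuda.set_device(prev)
```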
3 changes: 2 additions & 1 deletion src/flag_gems/fused/skip_layernorm.py
@@ -5,6 +5,7 @@
 import triton
 import triton.language as tl

+from ..runtime import torch_device_fn
 from ..utils import libentry
 from ..utils import triton_lang_extension as tle

@@ -71,7 +72,7 @@ def forward(ctx, x, residual, normalized_shape, weight, bias, eps=1e-5):
         bias = bias.contiguous()
         y = torch.empty_like(x)

-        with torch.cuda.device(x.device):
+        with torch_device_fn.device(x.device):
             skip_layer_norm_kernel[M,](
                 y, x, residual, weight, bias, N, 1, N, 1, N, 1, N, eps, BLOCK_SIZE
             )
3 changes: 2 additions & 1 deletion src/flag_gems/fused/skip_rms_norm.py
@@ -5,6 +5,7 @@
 import triton
 import triton.language as tl

+from ..runtime import torch_device_fn
 from ..utils import libentry
 from ..utils import triton_lang_extension as tle

@@ -60,7 +61,7 @@ def forward(ctx, x, residual, normalized_shape, weight, eps=1e-5):
         weight = weight.contiguous()
         y = torch.empty_like(x)

-        with torch.cuda.device(x.device):
+        with torch_device_fn.device(x.device):
             skip_rms_norm_kernel[M,](
                 y, x, residual, weight, N, 1, N, 1, N, 1, N, eps, BLOCK_SIZE
             )
3 changes: 2 additions & 1 deletion src/flag_gems/ops/addmm.py
@@ -5,6 +5,7 @@
 import triton.language as tl

 from .. import runtime
+from ..runtime import torch_device_fn
 from ..utils import libentry
 from ..utils import triton_lang_extension as tle

@@ -85,7 +86,7 @@ def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
         triton.cdiv(M, META["BLOCK_SIZE_M"]),
         triton.cdiv(N, META["BLOCK_SIZE_N"]),
     )
-    with torch.cuda.device(mat1.device):
+    with torch_device_fn.device(mat1.device):
         addmm_kernel[grid](
             mat1,
             mat2,
7 changes: 4 additions & 3 deletions src/flag_gems/ops/all.py
@@ -6,6 +6,7 @@
 import triton.language as tl

 from .. import runtime
+from ..runtime import torch_device_fn
 from ..utils import dim_compress, libentry
 from ..utils import triton_lang_extension as tle

@@ -89,7 +90,7 @@ def all(inp):
     mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
     out = torch.empty([], dtype=torch.bool, device=inp.device)

-    with torch.cuda.device(inp.device):
+    with torch_device_fn.device(inp.device):
         all_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)
         all_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

@@ -114,7 +115,7 @@ def all_dim(inp, dim=None, keepdim=False):
     out = torch.empty(shape, dtype=torch.bool, device=inp.device)

     grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
-    with torch.cuda.device(inp.device):
+    with torch_device_fn.device(inp.device):
         all_kernel_dim[grid](inp, out, M, N)
     if not keepdim:
         out = out.squeeze(dim=dim)
@@ -140,7 +141,7 @@ def all_dims(inp, dim=None, keepdim=False):
     out = torch.empty(shape, dtype=torch.bool, device=inp.device)

     grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
-    with torch.cuda.device(inp.device):
+    with torch_device_fn.device(inp.device):
         all_kernel_dim[grid](inp, out, M, N)
     if not keepdim:
         out = out.squeeze(dim=dim)
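For orientation, `all` above is a two-pass reduction: `all_kernel_1` collapses each block of `block_size` elements into one boolean in the `mid` buffer, then `all_kernel_2` collapses `mid` into the scalar output (`any` below mirrors this). A plain-PyTorch sketch of that structure, not the Triton kernels themselves:

```python
import torch

def two_pass_all(inp: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    flat = inp.reshape(-1)
    n = flat.numel()
    mid_size = (n + block_size - 1) // block_size
    mid = torch.empty(mid_size, dtype=torch.bool, device=inp.device)
    # Pass 1: one "program" per block reduces block_size elements to one bool.
    for b in range(mid_size):
        mid[b] = flat[b * block_size : (b + 1) * block_size].bool().all()
    # Pass 2: a single program reduces the mid buffer to the final scalar.
    return mid.all()
```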
5 changes: 3 additions & 2 deletions src/flag_gems/ops/amax.py
@@ -6,6 +6,7 @@
 import triton.language as tl

 from .. import runtime
+from ..runtime import torch_device_fn
 from ..utils import dim_compress, libentry
 from ..utils import triton_lang_extension as tle

@@ -86,7 +87,7 @@ def amax(inp, dim=None, keepdim=False):
         for i in range(0, inp.dim()):
             shape[i] = 1
         out = torch.empty(shape, dtype=dtype, device=inp.device)
-        with torch.cuda.device(inp.device):
+        with torch_device_fn.device(inp.device):
             amax_kernel_1[(mid_size, 1)](
                 inp,
                 mid,
@@ -115,7 +116,7 @@ def amax(inp, dim=None, keepdim=False):
         out = torch.empty(shape, dtype=dtype, device=inp.device)

         grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
-        with torch.cuda.device(inp.device):
+        with torch_device_fn.device(inp.device):
             amax_kernel[grid](inp, out, M, N)
         if not keepdim:
             out = out.squeeze(dim=dim)
7 changes: 4 additions & 3 deletions src/flag_gems/ops/any.py
@@ -6,6 +6,7 @@
 import triton.language as tl

 from .. import runtime
+from ..runtime import torch_device_fn
 from ..utils import dim_compress, libentry
 from ..utils import triton_lang_extension as tle

@@ -89,7 +90,7 @@ def any(inp):
     mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
     out = torch.empty([], dtype=torch.bool, device=inp.device)

-    with torch.cuda.device(inp.device):
+    with torch_device_fn.device(inp.device):
         any_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)
         any_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

@@ -114,7 +115,7 @@ def any_dim(inp, dim=None, keepdim=False):
     out = torch.empty(shape, dtype=torch.bool, device=inp.device)

     grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
-    with torch.cuda.device(inp.device):
+    with torch_device_fn.device(inp.device):
         any_kernel_dim[grid](inp, out, M, N)
     if not keepdim:
         out = out.squeeze(dim=dim)
@@ -140,7 +141,7 @@ def any_dims(inp, dim=None, keepdim=False):
     out = torch.empty(shape, dtype=torch.bool, device=inp.device)

     grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
-    with torch.cuda.device(inp.device):
+    with torch_device_fn.device(inp.device):
         any_kernel_dim[grid](inp, out, M, N)
     if not keepdim:
         out = out.squeeze(dim=dim)