Skip to content

Commit f12bdfb

Browse files
committed
update test case
1 parent 59e541f commit f12bdfb

File tree

3 files changed

+28
-34
lines changed

3 files changed

+28
-34
lines changed

py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py

+14-12
Original file line number | Diff line number | Diff line change
@@ -130,23 +130,25 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]:
130130
output = torch_op(*fake_args, **kwargs)
131131

132132
# We assume that number of dimensions are the same in torch op
133-
# shape_calc_fns = [None] * args[0].ndim
134133
shape_calc_fns = [None] * output.ndim
135-
134+
136135
for i in range(output.ndim):
137-
input_node_expr = [syms_arg[j].node.expr for syms_arg in syms_args for j in range(len(syms_arg))]
138-
shape_calc_fns[i] = lambdify(tuple(input_node_expr), output.shape[i].node.expr, "math")
139-
140-
# for i in range(args[0].ndim):
141-
# input_node_expr = [syms_arg[i].node.expr for syms_arg in syms_args]
142-
# shape_calc_fns[i] = lambdify(
143-
# tuple(input_node_expr), output.shape[i].node.expr, "math"
144-
# )
136+
input_node_expr = [
137+
syms_arg[j].node.expr
138+
for syms_arg in syms_args
139+
for j in range(len(syms_arg))
140+
]
141+
shape_calc_fns[i] = lambdify(
142+
tuple(input_node_expr), output.shape[i].node.expr, "math"
143+
)
145144

146145
out_desc = tensor_args[0].like()
147146
for i in range(out_desc.ndim):
148-
input_shape_expr = [arg.shape_expr[j] for arg in tensor_args for j in range(len(arg.shape_expr))]
149-
# input_shape_expr = [tensor_arg.shape_expr[i] for tensor_arg in tensor_args]
147+
input_shape_expr = [
148+
arg.shape_expr[j]
149+
for arg in tensor_args
150+
for j in range(len(arg.shape_expr))
151+
]
150152
if output.shape[i].node.expr is None:
151153
raise ValueError(f"output.shape[{i}].node.expr cannot be None")
152154
out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr) # type: ignore[misc]

tests/py/dynamo/automatic_plugin/test_automatic_plugin.py

-9
Original file line number | Diff line number | Diff line change
@@ -81,12 +81,3 @@ def forward(self, lhs, rhs):
8181

8282
if __name__ == "__main__":
8383
run_tests()
84-
85-
# Example Usage
86-
# A = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
87-
# B = torch.full((64, 64), 3, device="cuda", dtype=torch.float)
88-
89-
# C, D = torch.ops.torchtrt_ex.elementwise_add_mul.default(A, B)
90-
91-
# print("C (Addition):", C)
92-
# print("D (Multiplication):", D)
Original file line number | Diff line number | Diff line change
@@ -1,49 +1,50 @@
11
import flashinfer
2-
32
import torch
43
import torch.nn as nn
54
import torch_tensorrt
65
from parameterized import parameterized
76
from torch.testing._internal.common_utils import run_tests
7+
from torch_tensorrt._enums import dtype
88

99
from ..conversion.harness import DispatchTestCase
10-
import flashinfer
1110

1211

13-
@torch.library.custom_op("torchtrt_ex::flashinfer_rmsnorm", mutates_args=()) # type: ignore[misc]
12+
@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc]
1413
def flashinfer_rmsnorm(
1514
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
1615
) -> torch.Tensor:
1716
return flashinfer.norm.rmsnorm(input, weight)
1817

1918

20-
@torch.library.register_fake("torchtrt_ex::flashinfer_rmsnorm")
19+
@torch.library.register_fake("flashinfer::rmsnorm")
2120
def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
2221
return input
2322

2423

25-
2624
torch_tensorrt.dynamo.conversion.plugins.custom_op(
27-
"torchtrt_ex::flashinfer_rmsnorm", supports_dynamic_shapes=True
25+
"flashinfer::rmsnorm", supports_dynamic_shapes=True
2826
)
2927

3028

3129
class TestAutomaticPlugin(DispatchTestCase):
3230
@parameterized.expand(
3331
[
34-
((64, 64), (64, ), torch.float16),
35-
((256, 256), (256, ), torch.float16),
32+
((64, 64), (64,), torch.float16),
33+
((256, 256), (256,), torch.float16),
3634
]
3735
)
38-
def test_rmsnorm_float(self, input_shape, weight_shape, dtype):
36+
def test_rmsnorm_float(self, input_shape, weight_shape, data_type):
3937
class rmsnorm(nn.Module):
4038
def forward(self, input, weight):
41-
return torch.ops.torchtrt_ex.flashinfer_rmsnorm.default(input, weight)
39+
return torch.ops.flashinfer.rmsnorm.default(input, weight)
4240

43-
inputs = [torch.randn(input_shape, device="cuda", dtype=dtype), torch.randn(weight_shape, device="cuda", dtype=dtype)]
41+
inputs = [
42+
torch.randn(input_shape, device="cuda", dtype=data_type),
43+
torch.randn(weight_shape, device="cuda", dtype=data_type),
44+
]
4445

45-
self.run_test(rmsnorm(), inputs)
46+
self.run_test(rmsnorm(), inputs, precision=dtype.f16)
4647

4748

4849
if __name__ == "__main__":
49-
run_tests()
50+
run_tests()

0 commit comments

Comments (0)