 import flashinfer
-
 import torch
 import torch.nn as nn
 import torch_tensorrt
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt._enums import dtype

 from ..conversion.harness import DispatchTestCase
-import flashinfer


-@torch.library.custom_op("torchtrt_ex::flashinfer_rmsnorm", mutates_args=())  # type: ignore[misc]
+@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=())  # type: ignore[misc]
 def flashinfer_rmsnorm(
     input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> torch.Tensor:
     return flashinfer.norm.rmsnorm(input, weight)


-@torch.library.register_fake("torchtrt_ex::flashinfer_rmsnorm")
+@torch.library.register_fake("flashinfer::rmsnorm")
 def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
     return input


-
 torch_tensorrt.dynamo.conversion.plugins.custom_op(
-    "torchtrt_ex::flashinfer_rmsnorm", supports_dynamic_shapes=True
+    "flashinfer::rmsnorm", supports_dynamic_shapes=True
 )


 class TestAutomaticPlugin(DispatchTestCase):
     @parameterized.expand(
         [
-            ((64, 64), (64, ), torch.float16),
-            ((256, 256), (256, ), torch.float16),
+            ((64, 64), (64,), torch.float16),
+            ((256, 256), (256,), torch.float16),
         ]
     )
-    def test_rmsnorm_float(self, input_shape, weight_shape, dtype):
+    def test_rmsnorm_float(self, input_shape, weight_shape, data_type):
         class rmsnorm(nn.Module):
             def forward(self, input, weight):
-                return torch.ops.torchtrt_ex.flashinfer_rmsnorm.default(input, weight)
+                return torch.ops.flashinfer.rmsnorm.default(input, weight)

-        inputs = [torch.randn(input_shape, device="cuda", dtype=dtype), torch.randn(weight_shape, device="cuda", dtype=dtype)]
+        inputs = [
+            torch.randn(input_shape, device="cuda", dtype=data_type),
+            torch.randn(weight_shape, device="cuda", dtype=data_type),
+        ]

-        self.run_test(rmsnorm(), inputs)
+        self.run_test(rmsnorm(), inputs, precision=dtype.f16)


 if __name__ == "__main__":
-    run_tests()
+    run_tests()
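For reference, the renamed flashinfer::rmsnorm op and its auto-generated TensorRT plugin can also be exercised end to end through torch_tensorrt.compile rather than the DispatchTestCase.run_test harness. A minimal sketch, assuming the custom-op registration, fake registration, and plugins.custom_op call from the file above have already run; the wrapper module name, shapes, and compile settings below are illustrative and not part of this commit:

import torch
import torch.nn as nn
import torch_tensorrt


class RMSNorm(nn.Module):  # illustrative wrapper, not part of this commit
    def forward(self, input, weight):
        # Calls the custom op registered above; Torch-TensorRT lowers it
        # to the auto-generated TensorRT plugin during compilation.
        return torch.ops.flashinfer.rmsnorm.default(input, weight)


inputs = [
    torch.randn((64, 64), device="cuda", dtype=torch.float16),
    torch.randn((64,), device="cuda", dtype=torch.float16),
]

trt_mod = torch_tensorrt.compile(
    RMSNorm().eval().cuda(),
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.float16},
    min_block_size=1,  # keep the single-op graph from falling back to PyTorch
)
print(trt_mod(*inputs).shape)  # expected: torch.Size([64, 64])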