Commit 8255f87 (1 parent: b459ccc)
1 file changed in transformer_engine/pytorch/module: +5 -2 lines
Columns below: original file line number | diff line number | diff line change
@@ -111,6 +111,7 @@ def initialize_ub(
111
111
shape : list ,
112
112
tp_size : int ,
113
113
use_fp8 : bool = False ,
114
+ dtype : torch .dtype = torch .bfloat16 ,
114
115
ub_cfgs : Optional [dict ] = None
115
116
) -> None :
116
117
"""Initialize communicators for TP comm overlap using userbuffers."""
@@ -151,8 +152,10 @@ def add_ub(
151
152
num_splits : int = 4 ,
152
153
aggregate : int = 0 ,
153
154
) -> None :
154
- dtype = torch .uint8 if (use_fp8 and name in fp8_buf ) else torch .bfloat16
155
- sample_buffer = torch .empty (shape , dtype = dtype , device = 'cuda' )
155
+ sample_buffer = torch .empty (
156
+ shape ,
157
+ dtype = torch .uint8 if (use_fp8 and name in fp8_buf ) else dtype ,
158
+ device = 'cuda' )
156
159
if method == 'ring_exchange' :
157
160
ub_obj = tex .UbufP2PCommOverlap (
158
161
sample_buffer , # Sample userbuffer
You can’t perform that action at this time.
0 commit comments