     cast_trt_tensor,
     get_trt_tensor,
 )
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.converters.converter_utils import (
+    broadcast,
+    has_dynamic_shape,
+    set_layer_name,
+)
 from torch_tensorrt.fx.types import TRTElementWiseOp, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

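Both new imports feed the hunk below: in the FX converter utilities, `has_dynamic_shape` reports whether any dimension of a shape is dynamic (encoded as -1 by TensorRT), and `broadcast` equalizes the ranks of two TRT tensors by prepending size-1 dimensions.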
@@ -148,38 +152,43 @@ def convert_binary_elementwise(
         ctx, rhs_val, trt_promoted_type, f"{name}_cast_rhs_val", target, source_ir
     )

-    lhs_val_shape = lhs_val.shape
-    rhs_val_shape = rhs_val.shape
-    rank_diff = len(lhs_val_shape) - len(rhs_val_shape)
-    if rank_diff > 0:
-        rhs_val = impl.slice.expand(
-            ctx, target, source_ir, f"{name}_expand_rhs_val", rhs_val, lhs_val_shape
-        )
-    elif rank_diff < 0:
-        lhs_val = impl.slice.expand(
-            ctx, target, source_ir, f"{name}_expand_lhs_val", lhs_val, rhs_val_shape
+    if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
+        lhs_val, rhs_val = broadcast(
+            ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
         )
     else:
-        if tuple(lhs_val_shape) != tuple(rhs_val_shape):
-            sum_diff = sum(lhs_val_shape) - sum(rhs_val_shape)
-            if sum_diff > 0:
-                rhs_val = impl.slice.expand(
-                    ctx,
-                    target,
-                    source_ir,
-                    f"{name}_expand_rhs_val",
-                    rhs_val,
-                    lhs_val_shape,
-                )
-            elif sum_diff < 0:
-                lhs_val = impl.slice.expand(
-                    ctx,
-                    target,
-                    source_ir,
-                    f"{name}_expand_lhs_val",
-                    lhs_val,
-                    rhs_val_shape,
-                )
+        lhs_val_shape = lhs_val.shape
+        rhs_val_shape = rhs_val.shape
+        rank_diff = len(lhs_val_shape) - len(rhs_val_shape)
+        if rank_diff > 0:
+            rhs_val = impl.slice.expand(
+                ctx, target, source_ir, f"{name}_expand_rhs_val", rhs_val, lhs_val_shape
+            )
+        elif rank_diff < 0:
+            lhs_val = impl.slice.expand(
+                ctx, target, source_ir, f"{name}_expand_lhs_val", lhs_val, rhs_val_shape
+            )
+        else:
+            if tuple(lhs_val_shape) != tuple(rhs_val_shape):
+                sum_diff = sum(lhs_val_shape) - sum(rhs_val_shape)
+                if sum_diff > 0:
+                    rhs_val = impl.slice.expand(
+                        ctx,
+                        target,
+                        source_ir,
+                        f"{name}_expand_rhs_val",
+                        rhs_val,
+                        lhs_val_shape,
+                    )
+                elif sum_diff < 0:
+                    lhs_val = impl.slice.expand(
+                        ctx,
+                        target,
+                        source_ir,
+                        f"{name}_expand_lhs_val",
+                        lhs_val,
+                        rhs_val_shape,
+                    )

     layer = ctx.net.add_elementwise(lhs_val, rhs_val, op_type)
     set_layer_name(layer, target, name, source_ir)
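As a reading aid, here is a minimal, self-contained sketch (not Torch-TensorRT source) of the dispatch this change introduces: dynamic dimensions, which TensorRT marks as -1, cannot be compared or summed meaningfully at build time, so rank alignment is delegated to the `broadcast` helper, while fully static shapes keep the explicit `impl.slice.expand` path. `align_ranks` below is a hypothetical stand-in for the rank padding that `broadcast` performs with shuffle layers.

from typing import Sequence, Tuple

def has_dynamic_shape(shape: Sequence[int]) -> bool:
    # TensorRT encodes a dynamic (runtime-determined) dimension as -1.
    return any(dim == -1 for dim in shape)

def align_ranks(
    lhs: Tuple[int, ...], rhs: Tuple[int, ...]
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    # Hypothetical stand-in for the FX `broadcast` helper: left-pad the
    # lower-rank shape with 1s so the elementwise layer sees equal ranks.
    rank = max(len(lhs), len(rhs))
    return (1,) * (rank - len(lhs)) + lhs, (1,) * (rank - len(rhs)) + rhs

lhs, rhs = (-1, 3, 224, 224), (224,)
if has_dynamic_shape(lhs) or has_dynamic_shape(rhs):
    # Dynamic case: only ranks are aligned; concrete sizes resolve at runtime.
    print(align_ranks(lhs, rhs))  # ((-1, 3, 224, 224), (1, 1, 1, 224))
else:
    # Static case: the converter would expand one operand to the other's
    # shape via impl.slice.expand, as in the else-branch of the diff.
    pass

The guard matters because the retained static-shape branch does arithmetic that assumes concrete integers: `sum(lhs_val_shape)` and the `rank_diff`/`sum_diff` comparisons would be silently wrong if a -1 placeholder slipped into the shape tuple.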