Commit d285d27

fix dynamic shape bugs for test_binary_ops_aten

Parent: b0e92d8

1 file changed: py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py (+40, −31)
```diff
@@ -13,7 +13,11 @@
     cast_trt_tensor,
     get_trt_tensor,
 )
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.converters.converter_utils import (
+    broadcast,
+    has_dynamic_shape,
+    set_layer_name,
+)
 from torch_tensorrt.fx.types import TRTElementWiseOp, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
```
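The two new imports are existing helpers from `torch_tensorrt.fx.converters.converter_utils`. Roughly, `has_dynamic_shape` reports whether any dimension of a shape is TensorRT's dynamic-dimension placeholder `-1`, and `broadcast` prepends unit dimensions so the two TRT tensors reach the same rank. A minimal sketch of the shape check, written here for illustration rather than copied from the library:

```python
from typing import Sequence

def has_dynamic_shape(shape: Sequence[int]) -> bool:
    """Illustrative sketch: TensorRT encodes dynamic dims as -1."""
    return any(dim == -1 for dim in shape)

assert has_dynamic_shape((-1, 3, 224, 224))      # dynamic batch dim
assert not has_dynamic_shape((8, 3, 224, 224))   # fully static shape
```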

```diff
@@ -148,38 +152,43 @@ def convert_binary_elementwise(
         ctx, rhs_val, trt_promoted_type, f"{name}_cast_rhs_val", target, source_ir
     )
 
-    lhs_val_shape = lhs_val.shape
-    rhs_val_shape = rhs_val.shape
-    rank_diff = len(lhs_val_shape) - len(rhs_val_shape)
-    if rank_diff > 0:
-        rhs_val = impl.slice.expand(
-            ctx, target, source_ir, f"{name}_expand_rhs_val", rhs_val, lhs_val_shape
-        )
-    elif rank_diff < 0:
-        lhs_val = impl.slice.expand(
-            ctx, target, source_ir, f"{name}_expand_lhs_val", lhs_val, rhs_val_shape
+    if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
+        lhs_val, rhs_val = broadcast(
+            ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
         )
     else:
-        if tuple(lhs_val_shape) != tuple(rhs_val_shape):
-            sum_diff = sum(lhs_val_shape) - sum(rhs_val_shape)
-            if sum_diff > 0:
-                rhs_val = impl.slice.expand(
-                    ctx,
-                    target,
-                    source_ir,
-                    f"{name}_expand_rhs_val",
-                    rhs_val,
-                    lhs_val_shape,
-                )
-            elif sum_diff < 0:
-                lhs_val = impl.slice.expand(
-                    ctx,
-                    target,
-                    source_ir,
-                    f"{name}_expand_lhs_val",
-                    lhs_val,
-                    rhs_val_shape,
-                )
+        lhs_val_shape = lhs_val.shape
+        rhs_val_shape = rhs_val.shape
+        rank_diff = len(lhs_val_shape) - len(rhs_val_shape)
+        if rank_diff > 0:
+            rhs_val = impl.slice.expand(
+                ctx, target, source_ir, f"{name}_expand_rhs_val", rhs_val, lhs_val_shape
+            )
+        elif rank_diff < 0:
+            lhs_val = impl.slice.expand(
+                ctx, target, source_ir, f"{name}_expand_lhs_val", lhs_val, rhs_val_shape
+            )
+        else:
+            if tuple(lhs_val_shape) != tuple(rhs_val_shape):
+                sum_diff = sum(lhs_val_shape) - sum(rhs_val_shape)
+                if sum_diff > 0:
+                    rhs_val = impl.slice.expand(
+                        ctx,
+                        target,
+                        source_ir,
+                        f"{name}_expand_rhs_val",
+                        rhs_val,
+                        lhs_val_shape,
+                    )
+                elif sum_diff < 0:
+                    lhs_val = impl.slice.expand(
+                        ctx,
+                        target,
+                        source_ir,
+                        f"{name}_expand_lhs_val",
+                        lhs_val,
+                        rhs_val_shape,
+                    )
 
     layer = ctx.net.add_elementwise(lhs_val, rhs_val, op_type)
     set_layer_name(layer, target, name, source_ir)
```
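The bug being fixed: the old code compared shapes arithmetically (`rank_diff`, then `sum_diff`) and expanded one operand to the other's static shape via `impl.slice.expand`. Once a dimension is the dynamic placeholder `-1`, those comparisons are meaningless, so the commit routes dynamic-shape inputs through `broadcast`, which only equalizes ranks and leaves the actual size matching to TensorRT's elementwise layer, and keeps the expand-based path for fully static shapes. A small illustration of one failure mode, using made-up shapes:

```python
# Hypothetical shapes: same rank, but lhs has a dynamic batch dimension.
lhs_val_shape = (-1, 3, 224, 224)
rhs_val_shape = (1, 3, 224, 224)

# The removed heuristic: since the ranks match but the tuples differ,
# it compared the *sums* of the dimensions...
sum_diff = sum(lhs_val_shape) - sum(rhs_val_shape)  # (-1) - 1 == -2

# ...and sum_diff < 0 would expand lhs_val to rhs_val_shape, pinning the
# dynamic batch dimension to a static size of 1 -- wrong for any batch > 1.
assert sum_diff < 0
```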
