Skip to content

Commit 3cc6ecb

Browse files
committed
resolve comments
1 parent f12bdfb commit 3cc6ecb

File tree

3 files changed

+13
-10
lines changed

3 files changed

+13
-10
lines changed

.github/workflows/build-test-linux.yml

+1
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ jobs:
143143
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/
144144
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin.py
145145
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin_with_attrs.py
146+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/flashinfer_plugin.py
146147
popd
147148
148149
tests-py-dynamo-fe:

py/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ torch>=2.7.0.dev,<2.8.0
66
torchvision>=0.22.0.dev,<0.23.0
77
--extra-index-url https://pypi.ngc.nvidia.com
88
pyyaml
9+
flashinfer-python

py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import logging
23
from types import FunctionType
34
from typing import Any, Callable, Tuple
@@ -133,22 +134,22 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]:
133134
shape_calc_fns = [None] * output.ndim
134135

135136
for i in range(output.ndim):
136-
input_node_expr = [
137-
syms_arg[j].node.expr
138-
for syms_arg in syms_args
139-
for j in range(len(syms_arg))
140-
]
137+
input_node_expr = input_node_expr = list(
138+
itertools.chain.from_iterable(
139+
[sym.node.expr for sym in syms_arg] for syms_arg in syms_args
140+
)
141+
)
142+
141143
shape_calc_fns[i] = lambdify(
142144
tuple(input_node_expr), output.shape[i].node.expr, "math"
143145
)
144146

145147
out_desc = tensor_args[0].like()
146148
for i in range(out_desc.ndim):
147-
input_shape_expr = [
148-
arg.shape_expr[j]
149-
for arg in tensor_args
150-
for j in range(len(arg.shape_expr))
151-
]
149+
input_shape_expr = list(
150+
itertools.chain.from_iterable(arg.shape_expr for arg in tensor_args)
151+
)
152+
152153
if output.shape[i].node.expr is None:
153154
raise ValueError(f"output.shape[{i}].node.expr cannot be None")
154155
out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr) # type: ignore[misc]

0 commit comments

Comments
 (0)