Skip to content

Commit e63c923

Browse files
authored
Error out if registering prim ops multiple times
Differential Revision: D69090850 Pull Request resolved: #8172
1 parent 81f7c4f commit e63c923

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

codegen/tools/gen_all_oplist.py

+41-3
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,31 @@ def resolve_model_file_path_to_buck_target(model_file_path: str) -> str:
4747
return real_path
4848

4949

50+
def _raise_if_check_prim_ops_fail(options):
51+
52+
# Error out if we have more than one targets registering prim ops.
53+
if options.DEBUG_ONLY_check_prim_ops and len(options.DEBUG_ONLY_check_prim_ops) > 1:
54+
assert (
55+
options.DEBUG_ONLY_check_prim_ops[0] == "@"
56+
), "DEBUG_ONLY_check_prim_ops is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue."
57+
58+
prim_ops_targets_file = options.DEBUG_ONLY_check_prim_ops[1:]
59+
with open(prim_ops_targets_file, "r") as file:
60+
prim_ops_targets = file.read().split()
61+
if len(prim_ops_targets) > 1:
62+
# Yellow bold: \033[33;1m
63+
# Red bold: \033[31;1m
64+
# Green bold: \033[32;1m
65+
error = (
66+
"It seems this target is depending on more than 1 `prim_ops_registry` targets: "
67+
+ f'\033[33;1m\n{", ".join(prim_ops_targets)}\033[0m. \nThis will likely cause errors such as: '
68+
+ "\n \033[31;1mRe-registering aten::sym_size.int...\033[0m"
69+
+ "\nTo find out the dependency chain, run the following command: "
70+
+ f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {prim_ops_targets[0]})"\033[0m'
71+
)
72+
raise Exception(error)
73+
74+
5075
def main(argv: List[Any]) -> None:
5176
"""This binary generates 3 files:
5277
@@ -95,8 +120,18 @@ def main(argv: List[Any]) -> None:
95120
default=False,
96121
required=False,
97122
)
123+
parser.add_argument(
124+
"--DEBUG-ONLY-check-prim-ops",
125+
"--DEBUG_ONLY_check_prim_ops",
126+
help=(
127+
"Useful argument to take BUCK targets that registers prim ops and error out if we have more than 1."
128+
),
129+
required=False,
130+
)
98131
options = parser.parse_args(argv)
99132

133+
_raise_if_check_prim_ops_fail(options)
134+
100135
# Check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
101136
# 1. a yaml file containing selected ops (could be empty), or
102137
# 2. a non-empty list of yaml files in the `model_file_list_path` or
@@ -153,14 +188,17 @@ def main(argv: List[Any]) -> None:
153188
debug_info_2 = ",".join(
154189
model_dict["operators"][op_name]["debug_info"]
155190
)
156-
error = f"Operator {op_name} is used in 2 models: {debug_info_1} and {debug_info_2}"
191+
# Yellow bold: \033[33;1m
192+
# Red bold: \033[31;1m
193+
# Green bold: \033[32;1m
194+
error = f"\033[31;1mOperator {op_name} is used in 2 models: \033[33;1m{debug_info_1} and {debug_info_2}\033[0m"
157195
if "//" not in debug_info_1 and "//" not in debug_info_2:
158196
error += "\nWe can't determine what BUCK targets these model files belong to."
159197
tail = "."
160198
else:
161199
error += "\nPlease run the following commands to find out where is the BUCK target being added as a dependency to your target:\n"
162-
error += f'\n buck2 cquery <mode> "allpaths(<target>, {debug_info_1})"'
163-
error += f'\n buck2 cquery <mode> "allpaths(<target>, {debug_info_2})"'
200+
error += f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {debug_info_1})"\033[0m'
201+
error += f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {debug_info_2})"\033[0m'
164202
tail = "as well as results from BUCK commands listed above."
165203

166204
error += (

shim/xplat/executorch/codegen/codegen.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,7 @@ def executorch_ops_check(
692692
"--model_file_list_path $(@query_outputs \"filter('.*_et_oplist', deps(set({deps})))\") " +
693693
"--allow_include_all_overloads " +
694694
"--check_ops_not_overlapping " +
695+
"--DEBUG_ONLY_check_prim_ops $(@query_targets \"filter('prim_ops_registry(?:_static|_aten)?$', deps(set({deps})))\") " +
695696
"--output_dir $OUT ").format(deps = " ".join(["\'{}\'".format(d) for d in deps])),
696697
define_static_target = False,
697698
platforms = kwargs.pop("platforms", get_default_executorch_platforms()),

0 commit comments

Comments
 (0)