[Tool] Detect SQL-Tester conf and filter the cases (#50336)
Signed-off-by: AndyZiYe <[email protected]>
andyziye authored Sep 4, 2024
1 parent 5822240 commit 4eff9df
Showing 7 changed files with 377 additions and 124 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/ci-pipeline.yml
@@ -592,7 +592,7 @@ jobs:
build:
runs-on: [self-hosted, normal]
needs: [test-checker, clang-tidy, fe-ut, thirdparty-info]
needs: [test-checker, be-ut, fe-ut, thirdparty-info]
name: BUILD
env:
PR_NUMBER: ${{ github.event.number }}
@@ -606,7 +606,7 @@ jobs:
is_self_build: ${{ steps.run_build.outputs.is_self_build }}
build_nece: ${{ steps.check-necessity.outputs.BUILD_NECE }}
if: >
always() && needs.clang-tidy.result != 'failure' && needs.fe-ut.result != 'failure'
always() && needs.be-ut.result != 'failure' && needs.fe-ut.result != 'failure' && (needs.be-ut.result == 'success' || needs.fe-ut.result == 'success' || needs.test-checker.result == 'success')
steps:
- name: CLEAN
run: |
@@ -666,9 +666,9 @@ jobs:
- name: Check necessity
id: check-necessity
if: >
(needs.clang-tidy.result == 'success' && needs.fe-ut.result == 'success') ||
(needs.be-ut.result == 'success' && needs.fe-ut.result == 'success') ||
(steps.parsing-be-path-filter.outputs.src_filter != 'true' && steps.parsing-fe-path-filter.outputs.src_filter == 'true' && needs.fe-ut.result == 'success') ||
(steps.parsing-fe-path-filter.outputs.src_filter != 'true' && steps.parsing-be-path-filter.outputs.src_filter == 'true' && needs.clang-tidy.result == 'success') ||
(steps.parsing-fe-path-filter.outputs.src_filter != 'true' && steps.parsing-be-path-filter.outputs.src_filter == 'true' && needs.be-ut.result == 'success') ||
(steps.parsing-be-path-filter.outputs.src_filter != 'true' && steps.parsing-fe-path-filter.outputs.src_filter != 'true' && needs.test-checker.outputs.output1 == 'true')
run: |
echo "BUILD_NECE=true" >> $GITHUB_OUTPUT
125 changes: 65 additions & 60 deletions test/conf/sr.conf
@@ -1,69 +1,74 @@
[mysql-client]
host =
port =
user =
password =
http_port =
host_user =
host_password =
cluster_path =
[cluster]
host =
port =
user =
password =
http_port =
host_user =
host_password =
cluster_path =

[trino-client]
host =
port =
user =
[client]
[.trino-client]
host =
port =
user =

[hive-client]
host =
port =
user =

[spark-client]
host =
port =
user =
[.hive-client]
host =
port =
user =
[.spark-client]
host =
port =
user =

[replace]
url = http://${mysql-client:host}:${mysql-client:http_port}
mysql_cmd = mysql -h${mysql-client:host} -P${mysql-client:port} -u${mysql-client:user}
url = http://${cluster.host}:${cluster.http_port}
mysql_cmd = mysql -h${cluster.host} -P${cluster.port} -u${cluster.user}

[env]
oss_bucket =
oss_ak =
oss_sk =
oss_region =
oss_endpoint =

hdfs_host =
hdfs_port =
hdfs_user =
hdfs_passwd =
hdfs_path = /starrocks_ci_data
hdfs_broker_name = hdfs_broker

hive_metastore_uris =

hudi_hive_metastore_uris =

iceberg_catalog_hive_metastore_uris =

deltalake_catalog_hive_metastore_uris =

external_mysql_ip =
external_mysql_port =
external_mysql_user =
external_mysql_password =
jdbc_url =

aws_ak =
aws_sk =
aws_region =
aws_assume_role =
aws_sts_region =
aws_sts_endpoint =
[.oss]
oss_bucket =
oss_ak =
oss_sk =
oss_region =
oss_endpoint =

udf_url = http://starrocks-thirdparty.oss-cn-zhangjiakou.aliyuncs.com
[.hdfs]
hdfs_host =
hdfs_port =
hdfs_user =
hdfs_passwd =
hdfs_path = /starrocks_ci_data
hdfs_broker_name = hdfs_broker

[.hive]
hive_metastore_uris =
deltalake_catalog_hive_metastore_uris =
hudi_hive_metastore_uris =
iceberg_catalog_hive_metastore_uris =

[.kafka]
broker_list =
kafka_tool_path =

[.mysql]
external_mysql_ip =
external_mysql_port =
external_mysql_user =
external_mysql_password =
jdbc_url =

[.aws]
aws_ak =
aws_sk =
aws_region =
aws_assume_role =
aws_sts_region =
aws_sts_endpoint =


broker_list =
kafka_tool_path =
[.others]
udf_url = http://starrocks-thirdparty.oss-cn-zhangjiakou.aliyuncs.com
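
The reorganized sr.conf groups the optional clients and external data sources into dotted sub-sections ([.trino-client], [.oss], [.hdfs], and so on), which is what lets the SQL-Tester detect whether a component is configured before scheduling cases against it. The snippet below is a rough, hypothetical sketch of such a detection pass: the real logic lives in sr_sql_lib.StarrocksSQLApiLib (which exposes component_status to choose_cases.py), and the function name detect_components plus the reliance on configparser are illustrative assumptions, not the project's implementation.

# Hypothetical sketch, not the project's parser. Assumes a component counts as
# "configured" when at least one key in its dotted sub-section has a non-empty value.
import configparser

def detect_components(conf_path="test/conf/sr.conf"):
    parser = configparser.ConfigParser(interpolation=None)  # skip ${...} placeholders
    parser.optionxform = str  # keep key names exactly as written
    parser.read(conf_path)

    component_status = {}
    for section in parser.sections():
        if not section.startswith("."):
            continue  # only the dotted sub-sections describe optional components
        keys = dict(parser.items(section))
        component_status[section.lstrip(".")] = {
            "status": any(value.strip() for value in keys.values()),
            "keys": list(keys),
        }
    return component_status

if __name__ == "__main__":
    for name, info in detect_components().items():
        print(f"{name:<15} configured={info['status']}")

choose_cases.py then consults a map of this shape: a case that issues a trino:/hive:/spark: prefixed statement, or that references a ${...} variable belonging to an unconfigured component, is filtered out before execution.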
164 changes: 143 additions & 21 deletions test/lib/choose_cases.py
@@ -40,6 +40,7 @@


CASE_DIR = "sql"
LOG_FILTERED_WARN = "You can use `--log_filtered` to show the details..."


class ChooseCase(object):
@@ -86,7 +87,7 @@ def __str__(self):
def __init__(self, case_dir=None, record_mode=False, file_regex=None, case_regex=None):
"""init"""
super().__init__()
# self.sr_lib_obj = sr_sql_lib.StarrocksSQLApiLib()
self.sr_lib_obj = sr_sql_lib.StarrocksSQLApiLib()

# case_dir = sql dir by default
self.case_dir = os.path.join(sr_sql_lib.root_path, CASE_DIR) if case_dir is None else case_dir
@@ -98,6 +99,8 @@ def __init__(self, case_dir=None, record_mode=False, file_regex=None, case_regex

self.list_t_r_files(file_regex)
self.get_cases(record_mode, case_regex)
self.filter_cases_by_component_status()
self.filter_cases_by_data_status()

# sort
self.case_list.sort()
@@ -149,13 +152,146 @@ def get_cases(self, record_mode, case_regex):
for file in file_list:
base_file = os.path.basename(file)
if base_file in skip.skip_files:
print('skip file {} because it is in skip_files'.format(file))
sr_sql_lib.self_print(f'skip file {file} because it is in skip_files', color=ColorEnum.YELLOW)
continue

self.read_t_r_file(file, case_regex)

self.case_list = list(filter(lambda x: x.name.strip() != "", self.case_list))

def filter_cases_by_component_status(self):
""" filter cases by component status """
new_case_list = []

filtered_cases_dict = {}

for each_case in self.case_list:
_case_sqls = []
for each_stat in each_case.sql:
if isinstance(each_stat, str):
_case_sqls.append(each_stat)
elif isinstance(each_stat, dict) and each_stat.get("type", "") == LOOP_FLAG:
tools.assert_in("stat", each_stat, "LOOP STATEMENT FORMAT ERROR!")
_case_sqls.extend(each_stat["stat"])
elif isinstance(each_stat, dict) and each_stat.get("type", "") == CONCURRENCY_FLAG:
tools.assert_in("thread", each_stat, "CONCURRENCY THREAD FORMAT ERROR!")
for each_thread in each_stat["thread"]:
_case_sqls.extend(each_thread["cmd"])
else:
tools.ok_(False, "Init data error!")

is_pass = True

# check trino/spark/hive flag
for each_client_flag in [TRINO_FLAG, SPARK_FLAG, HIVE_FLAG]:
if any(_case_sql.lstrip().startswith(each_client_flag) for _case_sql in _case_sqls):
# check client status
client_name = each_client_flag.split(":")[0].lower() + "-client"
if client_name not in self.sr_lib_obj.component_status:
sr_sql_lib.self_print(f"[Config ERROR]: {client_name} config not found!")
filtered_cases_dict.setdefault(client_name, []).append(each_case.name)
is_pass = False
break

if not self.sr_lib_obj.component_status[client_name]["status"]:
filtered_cases_dict.setdefault(client_name, []).append(each_case.name)
is_pass = False
break

if not is_pass:
# client check failed, no need to check component info
continue

# check ${} contains component info
_case_sqls = " ".join(_case_sqls)
_vars = re.findall(r"\${([a-zA-Z0-9._-]+)}", _case_sqls)
for _var in _vars:

if not is_pass:
break

if _var not in self.sr_lib_obj.__dict__:
continue

for each_component_name, each_component_info in self.sr_lib_obj.component_status.items():
if _var in each_component_info["keys"] and not each_component_info["status"]:
filtered_cases_dict.setdefault(each_component_name, []).append(each_case.name)
is_pass = False
break

if is_pass:
new_case_list.append(each_case)

if filtered_cases_dict:
if os.environ.get("log_filtered") == "True":
sr_sql_lib.self_print(f"\n{'-' * 60}\n[Component filter]\n{'-' * 60}", color=ColorEnum.BLUE,
logout=True, bold=True)
for k, cases in filtered_cases_dict.items():
sr_sql_lib.self_print(f"▶ {k.upper()}", color=ColorEnum.BLUE, logout=True, bold=True)
sr_sql_lib.self_print(f" ▶ %s" % '\n ▶ '.join(cases), logout=True)
sr_sql_lib.self_print('-' * 60, color=ColorEnum.BLUE, logout=True, bold=True)
else:
filtered_count = sum([len(x) for x in filtered_cases_dict.values()])
sr_sql_lib.self_print(f"\n{'-' * 60}\n[Component filter]: {filtered_count}\n{LOG_FILTERED_WARN}\n{'-' * 60}",
color=ColorEnum.BLUE, logout=True, bold=True)

self.case_list = new_case_list

def filter_cases_by_data_status(self):
""" filter cases by data status """
new_case_list = []

filtered_cases_dict = {}

for each_case in self.case_list:
_case_sqls = []
for each_stat in each_case.sql:
if isinstance(each_stat, str):
_case_sqls.append(each_stat)
elif isinstance(each_stat, dict) and each_stat.get("type", "") == LOOP_FLAG:
tools.assert_in("stat", each_stat, "LOOP STATEMENT FORMAT ERROR!")
_case_sqls.extend(each_stat["stat"])
elif isinstance(each_stat, dict) and each_stat.get("type", "") == CONCURRENCY_FLAG:
tools.assert_in("thread", each_stat, "CONCURRENCY THREAD FORMAT ERROR!")
for each_thread in each_stat["thread"]:
_case_sqls.extend(each_thread["cmd"])
else:
tools.ok_(False, "Init data error!")

is_pass = True
# check trino/spark/hive flag
function_stats = list(filter(lambda x: x.lstrip().startswith(FUNCTION_FLAG) and "prepare_data(" in x, _case_sqls))

for func_stat in function_stats:
if not is_pass:
break

data_source_names = re.findall(r"prepare_data\(['|\"]([a-zA-Z_-]+)['|\"]", func_stat)
for data_source in data_source_names:

if self.sr_lib_obj.data_status.get(data_source, False) is False:
filtered_cases_dict.setdefault(data_source, []).append(each_case.name)
is_pass = False
break

if is_pass:
new_case_list.append(each_case)

if filtered_cases_dict:
if os.environ.get("log_filtered") == "True":
sr_sql_lib.self_print(f"\n{'-' * 60}\n[Data filter]\n{'-' * 60}", color=ColorEnum.BLUE, logout=True, bold=True)
for k in list(sorted(filtered_cases_dict.keys())):
cases = filtered_cases_dict[k]
sr_sql_lib.self_print(f"▶ {k.upper()}", color=ColorEnum.BLUE, logout=True, bold=True)
sr_sql_lib.self_print(f" ▶ %s" % '\n ▶ '.join(cases), logout=True)
sr_sql_lib.self_print('-' * 60, color=ColorEnum.BLUE, logout=True, bold=True)
else:
filtered_count = sum([len(x) for x in filtered_cases_dict.values()])
sr_sql_lib.self_print(f"\n{'-' * 60}\n[Data filter]: {filtered_count}\n{LOG_FILTERED_WARN}\n{'-' * 60}",
color=ColorEnum.BLUE, logout=True, bold=True)

self.case_list = new_case_list

def read_t_r_file(self, file, case_regex):
"""read t r file and get case & result"""

@@ -316,7 +452,8 @@ def __read_single_stat_and_result(_line_content, _line_id, _stat_list, _res_list
and not re.compile(f'}}(\\s)*{END_LOOP_FLAG}').fullmatch(f_lines[line_id].strip())):
# read loop stats, unnecessary to record result
line_content = f_lines[line_id].strip()
line_id = __read_single_stat_and_result(line_content, line_id, tmp_sql, tmp_res, in_loop_flag, tmp_loop_stat)
line_id = __read_single_stat_and_result(line_content, line_id, tmp_sql, tmp_res, in_loop_flag,
tmp_loop_stat)
tools.assert_less(line_id, len(f_lines), "LOOP FORMAT ERROR!")

# reach the end loop line
@@ -328,7 +465,7 @@ def __read_single_stat_and_result(_line_content, _line_id, _stat_list, _res_list
"type": LOOP_FLAG,
"stat": tmp_loop_stat,
"prop": tmp_loop_prop,
"ori": f_lines[l_loop_line: r_loop_line+1]
"ori": f_lines[l_loop_line: r_loop_line + 1]
})
tmp_res.append(None)
line_id += 1
@@ -439,9 +576,8 @@ def choose_cases(record_mode=False):
filename_regex = os.environ.get("file_filter")
case_name_regex = os.environ.get("case_filter")

run_info = f"""
{'-' * 60}
[DIR]: {confirm_case_dir}
run_info = f"""{'-' * 60}
[DIR]: {"DEFAULT" if confirm_case_dir is None else confirm_case_dir}
[Mode]: {"RECORD" if record_mode else "VALIDATE"}
[file regex]: {filename_regex}
[case regex]: {case_name_regex}
@@ -458,17 +594,3 @@
log.info("%s:%s" % (case.file, case.name))

return cases


def check_db_unique(case_list: List[ChooseCase.CaseTR]):
"""check db unique in case list"""
db_and_case_dict = {}

# get info dict, key: db value: [..case_names]
for case in case_list:
for each_db in case.db:
db_and_case_dict.setdefault(each_db, []).append(case.name)

error_info_dict = {db: cases for db, cases in db_and_case_dict.items() if len(cases) > 1}

tools.assert_true(len(error_info_dict) <= 0, "Duplicate DBs: \n%s" % json.dumps(error_info_dict, indent=2))
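
With these filters in place, what gets skipped is controlled through environment variables rather than command-line flags. Below is a hedged usage sketch: the file_filter, case_filter, and log_filtered variables come from the code above, while the import path lib.choose_cases, the example regexes, and the driver loop are assumptions about how the harness might be invoked, not documented interfaces.

# Usage sketch (assumed import path; the SQL-Tester's own entrypoint normally drives this).
import os
from lib import choose_cases  # assumption: module location inside the test framework

os.environ["file_filter"] = "test_agg.*"   # only test (T/R) files whose names match this regex
os.environ["case_filter"] = ".*join.*"     # only case names matching this regex
os.environ["log_filtered"] = "True"        # list every case dropped by the component/data filters

cases = choose_cases.choose_cases(record_mode=False)
for case in cases.case_list:
    print(case.file, case.name)

Without log_filtered set to "True", only the filtered-case counts are printed together with the LOG_FILTERED_WARN hint.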
