Skip to content

Commit 014daa6

Browse files
authored
add basic ci functionality (#46)
* add basic ci functionality for serial cpu tests
1 parent 08078ae commit 014daa6

File tree

5 files changed

+194
-2
lines changed

5 files changed

+194
-2
lines changed

.github/workflows/serial-tests.yml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Runs the serial (CPU-only) baseline check once per benchmark problem.
name: Serial CPU Tests

on:
  push:
    branches:
      - develop
  pull_request:
    branches:
      - develop

jobs:
  serial-tests:
    runs-on: ubuntu-latest
    strategy:
      # Each matrix entry tests a different problem. Keep the remaining jobs
      # running when one fails so a single CI run reports every failing
      # problem instead of cancelling the rest after the first failure.
      fail-fast: false
      matrix:
        problem: ["00_dense_la_lu_decomp", "01_dense_la_solve", "02_dense_la_gemm", "03_dense_la_axpy", "04_dense_la_gemv", "05_fft_inverse_fft", "06_fft_dft", "07_fft_fft_conjugate", "08_fft_split_fft", "09_fft_fft_out_of_place", "10_geometry_convex_hull", "11_geometry_convex_hull_perimeter", "12_geometry_smallest_triangle", "13_geometry_closest_pair_2d", "14_geometry_closest_pair_1d", "15_graph_edge_count", "16_graph_largest_component", "17_graph_highest_degree", "18_graph_count_components", "19_graph_shortest_path", "20_histogram_pixel_histogram", "21_histogram_bin_0-100", "22_histogram_count_quadrants", "23_histogram_first_letter_counts", "24_histogram_count_quartile", "25_reduce_xor", "26_reduce_product_of_inverses", "27_reduce_average", "28_reduce_smallest_odd_number", "29_reduce_sum_of_min_of_pairs", "30_scan_prefix_sum", "31_scan_scan_with_min_function", "32_scan_sum_of_prefix_sum_array", "33_scan_reverse_prefix_sum", "34_scan_largest_contiguous_subarray_sum", "35_search_search_for_last_struct_by_key", "36_search_check_if_array_contains_value", "37_search_find_the_closest_number_to_pi", "38_search_find_the_first_even_number", "39_search_xor_contains", "40_sort_sort_an_array_of_complex_numbers_by_magnitude", "41_sort_k-th_smallest_element", "42_sort_sorted_ranks", "43_sort_sort_an_array_of_structs_by_key", "44_sort_sort_non-zero_elements", "45_sparse_la_sparse_solve", "46_sparse_la_spmm", "47_sparse_la_spmv", "48_sparse_la_sparse_axpy", "49_sparse_la_sparse_lu_decomp", "50_stencil_xor_kernel", "51_stencil_edge_kernel", "52_stencil_1d_jacobi_3-point_stencil", "53_stencil_2d_jacobi_5-point_stencil", "54_stencil_game_of_life", "55_transform_relu", "56_transform_negate_odds", "57_transform_inverse_offset", "58_transform_squaring", "59_transform_map_function"]
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.x'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install tqdm

      # The script receives the problem name as its only argument.
      - name: Run CPU test for ${{ matrix.problem }}
        run: bash test/test-serial.bash "${{ matrix.problem }}"

prompts/create-serial-tests.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def get_return_type(code: str) -> str:
2424
# then return the type
2525
lines = code.split('\n')
2626
for line in lines:
27-
if line.strip().endswith(') {'):
27+
if "NO_INLINE correct" in line and line.strip().endswith(') {'):
2828
return line.split()[0]
2929

3030
def main():
@@ -45,7 +45,8 @@ def main():
4545
continue
4646

4747
baseline = get_file_contents(baseline_fpath)
48-
impl = get_substr_after_first_of(baseline, ') {')
48+
func_start = get_substr_after_first_of(baseline, 'NO_INLINE correct')
49+
impl = get_substr_after_first_of(func_start, ') {')
4950
return_type = get_return_type(baseline)
5051
prompt['outputs'] = [
5152
impl,

test/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Tests
2+
3+
Tests for the benchmark. Currently, only the sequential CPU capabilities are tested.

test/test-serial.bash

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#!/bin/bash
# Uses the baseline implementations to test the CPU capabilities of the system.
#
# usage: bash test/test-serial.bash [problem]
#   problem  Optional problem name. When omitted, the drivers are run without
#            a --problem filter and every problem name found in the results
#            file is validated.
#
# Paths below are relative to the repository root, so run it from there.

# Stop at the first failing step (generation, build, or driver run) instead of
# carrying on and validating stale or missing results.
set -euo pipefail

# Keep the optional filter as an array so it expands to zero or two words
# without relying on word-splitting of an unquoted string.
PROBLEM_ARGS=()
if [ $# -eq 0 ]; then
    echo "No problem specified. Using default: 'all'."
else
    PROBLEM_ARGS=(--problem "$1")
fi

# First, use the baseline implementations to mimic LLM outputs.
python prompts/create-serial-tests.py drivers/cpp/benchmarks prompts/generation-prompts.json serial-generations.json

# make sure the model drivers are built
cd drivers
make -C cpp

# Run the drivers using these generations
# (the ${arr[@]+...} form keeps an empty array safe under `set -u` on old bash)
python run-all.py \
    ../serial-generations.json \
    --output results.json \
    --launch-configs launch-configs.json \
    --problem-sizes problem-sizes.json \
    --yes-to-all \
    --include-models serial \
    ${PROBLEM_ARGS[@]+"${PROBLEM_ARGS[@]}"} \
    --build-timeout 60 \
    --run-timeout 120 \
    --log info

# check results
cd ..
if [ $# -eq 0 ]; then
    # The validator checks one problem per call, so list the distinct problem
    # names recorded in the results file and validate each of them.
    # NOTE(review): assumes the expected counts below hold for every problem,
    # as the CI matrix already applies them to each problem individually.
    PROBLEMS=$(python -c 'import json, sys; print("\n".join(dict.fromkeys(r["name"] for r in json.load(open(sys.argv[1])))))' drivers/results.json)
else
    PROBLEMS="$1"
fi

STATUS=0
VALIDATED=0
while IFS= read -r problem; do
    [ -n "${problem}" ] || continue
    VALIDATED=$((VALIDATED + 1))
    # Record a failure but keep going so every failing problem is reported.
    python test/validate-test-results.py \
        --results drivers/results.json \
        --problem "${problem}" \
        --expected-write 3 \
        --expected-source-valid 3 \
        --expected-build 2 \
        --expected-run 2 \
        --expected-correct 1 || STATUS=1
done <<< "${PROBLEMS}"

# An empty results file must not count as a pass.
if [ "${VALIDATED}" -eq 0 ]; then
    echo "No problems found in drivers/results.json."
    exit 1
fi

exit "${STATUS}"

test/validate-test-results.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
""" Checks if the expected test results are present in the output JSON file.
2+
usage: python test/validate-test-results.py \
3+
--results <results.json> \
4+
--problem <problem_name> \
5+
--expected-write <expected_write_count> \
6+
--expected-source-valid <expected_source_valid_count> \
7+
--expected-build <expected_build_count> \
8+
--expected-run <expected_run_count> \
9+
--expected-correct <expected_correct_count>
10+
"""
11+
from argparse import ArgumentParser
12+
import json
13+
from collections import Counter
14+
15+
16+
def parse_args(argv=None):
    """Parse the command-line options of the results validator.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads the process command line, which
            is what the original zero-argument call did.

    Returns:
        The parsed namespace with ``results`` and ``problem`` strings and
        the integer attributes ``expected_write``, ``expected_source_valid``,
        ``expected_build``, ``expected_run`` and ``expected_correct``.
        All options are required; argparse exits with an error otherwise.
    """
    parser = ArgumentParser(description="Validate test results.")
    parser.add_argument(
        "--results",
        type=str,
        required=True,
        help="Path to the results JSON file.",
    )
    parser.add_argument(
        "--problem",
        type=str,
        required=True,
        help="Name of the problem to validate.",
    )

    # The five count options differ only in the stage name, so they are
    # generated from one list; flags and help text match the spelled-out form.
    for stage in ("write", "source-valid", "build", "run", "correct"):
        parser.add_argument(
            f"--expected-{stage}",
            type=int,
            required=True,
            help=f"Expected number of {stage.replace('-', ' ')} operations.",
        )

    return parser.parse_args(argv)
62+
63+
64+
def validate_outputs(outputs, expected_counts):
    """Compare per-stage success counts of ``outputs`` with the expected ones.

    Args:
        outputs: Iterable of dict-like output records. A record contributes
            to a count when the corresponding flag is present and truthy;
            a missing flag counts as a failure for that stage.
        expected_counts: Mapping from count name (``"write"``,
            ``"source_valid"``, ``"build"``, ``"run"``, ``"correct"``) to
            the expected number of records passing that stage.

    Returns:
        ``True`` when every expected count matches, ``False`` otherwise.
        One line is printed for each mismatching count, so a single run
        shows everything that is wrong rather than only the first problem.
    """
    # Count name -> flag in an output record that feeds it.
    flag_for_count = {
        "write": "source_write_success",
        "source_valid": "is_source_valid",
        "build": "did_build",
        "run": "did_all_run",
        "correct": "are_all_valid",
    }

    actual_counts = Counter()
    for output in outputs:
        for count_name, flag in flag_for_count.items():
            if output.get(flag, False):
                actual_counts[count_name] += 1

    all_match = True
    for key, expected in expected_counts.items():
        # Counter yields 0 for names that never occurred.
        actual = actual_counts[key]
        if actual != expected:
            print(f"Expected {expected} for {key}, but got {actual}.")
            all_match = False
    return all_match
85+
86+
87+
def main():
    """Validate one problem's results and return a process exit status.

    Reads the results JSON named by ``--results``, selects the first entry
    whose ``"name"`` equals ``--problem`` and checks its ``"outputs"``
    against the expected counts given on the command line.

    Returns:
        0 when all counts match; 1 when the problem is not present in the
        results file or when any count differs.
    """
    args = parse_args()

    # Load the results JSON file
    with open(args.results, "r") as f:
        results = json.load(f)

    expected_counts = {
        "write": args.expected_write,
        "source_valid": args.expected_source_valid,
        "build": args.expected_build,
        "run": args.expected_run,
        "correct": args.expected_correct,
    }

    # A missing problem is reported as a validation failure with a clear
    # message rather than surfacing as an IndexError traceback.
    problem_results = next(
        (r for r in results if r["name"] == args.problem), None
    )
    if problem_results is None:
        print(f"Problem {args.problem} not found in {args.results}.")
        return 1

    # Validate the results
    if not validate_outputs(problem_results["outputs"], expected_counts):
        print(f"Validation failed for problem {args.problem}.")
        return 1

    return 0


if __name__ == "__main__":
    # Propagate main()'s status as the exit code; without this the script
    # always exits 0 and the calling shell/CI job cannot detect a failure.
    raise SystemExit(main())
112+

0 commit comments

Comments
 (0)