Skip to content

Commit

Permalink
update tests to use spec strings instead of dict payloads
Browse files Browse the repository at this point in the history
  • Loading branch information
cmelone committed Mar 8, 2024
1 parent 3c8bc2c commit 106d3e1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 65 deletions.
50 changes: 11 additions & 39 deletions gantry/tests/defs/prediction.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,34 @@
# flake8: noqa
# fmt: off

NORMAL_BUILD = {
"hash": "testing",
"package": {
"name": "py-torch",
"version": "2.2.1",
"variants": "~caffe2+cuda+cudnn~debug+distributed+fbgemm+gloo+kineto~magma~metal+mkldnn+mpi~nccl+nnpack+numa+numpy+onnx_ml+openmp+qnnpack~rocm+tensorpipe~test+valgrind+xnnpack build_system=python_pip cuda_arch=80",
},
"compiler": {
"name": "gcc",
"version": "11.4.0",
},
}
from gantry.util.spec import parse_alloc_spec

NORMAL_BUILD = parse_alloc_spec(
"[email protected] ~caffe2+cuda+cudnn~debug+distributed+fbgemm+gloo+kineto~magma~metal+mkldnn+mpi~nccl+nnpack+numa+numpy+onnx_ml+openmp+qnnpack~rocm+tensorpipe~test+valgrind+xnnpack build_system=python_pip cuda_arch=80%[email protected]"
)

# everything in NORMAL_BUILD["package"]["variants"] except removing build_system=python_pip
# in order to test the expensive variants filter
EXPENSIVE_VARIANT_BUILD = {
"hash": "testing",
"package": {
"name": "py-torch",
"version": "2.2.1",
"variants": "~caffe2+cuda+cudnn~debug+distributed+fbgemm+gloo+kineto~magma~metal+mkldnn+mpi~nccl+nnpack+numa+numpy+onnx_ml+openmp+qnnpack~rocm+tensorpipe~test+valgrind+xnnpack cuda_arch=80",
},
"compiler": {
"name": "gcc",
"version": "11.4.0",
},
}
EXPENSIVE_VARIANT_BUILD = parse_alloc_spec(
"[email protected] ~caffe2+cuda+cudnn~debug+distributed+fbgemm+gloo+kineto~magma~metal+mkldnn+mpi~nccl+nnpack+numa+numpy+onnx_ml+openmp+qnnpack~rocm+tensorpipe~test+valgrind+xnnpack cuda_arch=80%[email protected]"
)

# no variants should match this, so we expect the default prediction
BAD_VARIANT_BUILD = {
"hash": "testing",
"package": {
"name": "py-torch",
"version": "2.2.1",
"variants": "+no~expensive~variants+match",
},
"compiler": {
"name": "gcc",
"version": "11.4.0",
},
}
BAD_VARIANT_BUILD = parse_alloc_spec(
"[email protected] +no~expensive~variants+match%[email protected]"
)

# calculated by running the baseline prediction algorithm on the sample data in gantry/tests/sql/insert_prediction.sql
NORMAL_PREDICTION = {
"hash": "testing",
"variables": {
"KUBERNETES_CPU_REQUEST": "12",
"KUBERNETES_MEMORY_REQUEST": "9576M",
},
}


# this is what will get returned when there are no samples in the database
# that match what the client wants
DEFAULT_PREDICTION = {
"hash": "testing",
"variables": {
"KUBERNETES_CPU_REQUEST": "1",
"KUBERNETES_MEMORY_REQUEST": "2000M",
Expand Down
51 changes: 25 additions & 26 deletions gantry/tests/test_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from gantry.routes.prediction import prediction
from gantry.tests.defs import prediction as defs
from gantry.util.prediction import validate_payload
from gantry.util.spec import parse_alloc_spec


@pytest.fixture
Expand Down Expand Up @@ -57,7 +57,7 @@ async def test_partial_match(db_conn_inserted):

# same as NORMAL_BUILD, but with a different compiler name to test partial matching
diff_compiler_build = defs.NORMAL_BUILD.copy()
diff_compiler_build["compiler"]["name"] = "gcc-different"
diff_compiler_build["compiler_name"] = "gcc-different"

assert (
await prediction.predict_single(db_conn_inserted, diff_compiler_build)
Expand All @@ -75,37 +75,36 @@ async def test_empty_sample(db_conn):


# Test validate_payload
def test_valid_spec():
"""Tests that a valid spec is parsed correctly."""
assert parse_alloc_spec("[email protected] +json+native+treesitter%[email protected]") == {
"pkg_name": "emacs",
"pkg_version": "29.2",
"pkg_variants": '{"json": true, "native": true, "treesitter": true}',
"pkg_variants_dict": {"json": True, "native": True, "treesitter": True},
"compiler_name": "gcc",
"compiler_version": "12.3.0",
}


def test_valid_payload():
"""Tests that a valid payload returns True"""
assert validate_payload(defs.NORMAL_BUILD) is True
def test_invalid_specs():
"""Test a series of invalid specs"""

# not a spec
assert parse_alloc_spec("hi") == {}

def test_invalid_payloads():
"""Test a series of invalid payloads"""

# non dict
assert validate_payload("hi") is False

build = defs.NORMAL_BUILD.copy()
# missing package
del build["package"]
assert validate_payload(build) is False
assert parse_alloc_spec("@29.2 +json+native+treesitter%[email protected]") == {}

build = defs.NORMAL_BUILD.copy()
# missing compiler
del build["compiler"]
assert validate_payload(build) is False
assert parse_alloc_spec("[email protected] +json+native+treesitter") == {}

# variants not spaced correctly
assert parse_alloc_spec("[email protected]+json+native+treesitter%[email protected]") == {}

# name and version are strings in the package and compiler
for key in ["name", "version"]:
for field in ["package", "compiler"]:
build = defs.NORMAL_BUILD.copy()
build[field][key] = 123
assert validate_payload(build) is False
# missing versions
assert parse_alloc_spec("[email protected] +json+native+treesitter%gcc@") == {}
assert parse_alloc_spec("emacs@ +json+native+treesitter%[email protected]") == {}

# invalid variants
build = defs.NORMAL_BUILD.copy()
build["package"]["variants"] = "+++++"
assert validate_payload(build) is False
assert parse_alloc_spec("[email protected] this_is_not_a_thing%[email protected]") == {}

0 comments on commit 106d3e1

Please sign in to comment.