Skip to content

Commit

Permalink
Code style changes
Browse files Browse the repository at this point in the history
Signed-off-by: Tyler Gu <[email protected]>
  • Loading branch information
tylergu committed Jan 21, 2024
1 parent 4c646e5 commit be68d37
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 63 deletions.
114 changes: 68 additions & 46 deletions acto/input/generator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""This module provides a decorator for generating test cases for a schema and
a function to get all test cases for a schema."""

from collections import namedtuple
from dataclasses import dataclass
from typing import Callable, Literal, Optional

from acto.input.k8s_schemas import KubernetesObjectSchema, KubernetesSchema
from acto.input.k8s_schemas import KubernetesObjectSchema
from acto.input.testcase import TestCase
from acto.schema import (
AnyOfSchema,
Expand All @@ -19,21 +19,34 @@
StringSchema,
)

TestGenerator = namedtuple(
"TestGeneratorObject",
[
"k8s_schema_name",
"field_name",
"field_type",
"paths",
"priority",
"func",
],
)

@dataclass
class TestGenerator:
"""A test generator object"""

k8s_schema_name: Optional[str]
field_name: Optional[str]
field_type: Optional[
Literal[
"AnyOf",
"Array",
"Boolean",
"Integer",
"Number",
"Object",
"OneOf",
"Opaque",
"String",
]
]
paths: Optional[list[str]]
priority: int
func: Callable[[BaseSchema], list[TestCase]]


# singleton
# global variable for registered test generators
test_generators: TestGenerator = []
TEST_GENERATORS: list[TestGenerator] = []


def generator(
Expand Down Expand Up @@ -79,75 +92,84 @@ def wrapped_func(func: Callable[[BaseSchema], list[TestCase]]):
priority,
func,
)
test_generators.append(gen_obj)
TEST_GENERATORS.append(gen_obj)
return func

return wrapped_func


def get_testcases(
schema: BaseSchema,
matched_schemas: [tuple[BaseSchema, KubernetesSchema]],
) -> list[tuple[list[str], TestCase]]:
matched_schemas: list[tuple[BaseSchema, KubernetesObjectSchema]],
) -> list[tuple[list[str], list[TestCase]]]:
"""Get all test cases for a schema from registered test generators"""
matched_schemas: dict[str, KubernetesObjectSchema] = {
matched_schemas_dict: dict[str, KubernetesObjectSchema] = {
"/".join(s.path): m for s, m in matched_schemas
}

def get_testcases_helper(schema: BaseSchema, field_name: Optional[str]):
def get_testcases_helper(
schema: BaseSchema, field_name: Optional[str]
) -> list[tuple[list[str], list[TestCase]]]:
# print(schema_name, schema.path, type(schema))
test_cases = []
generator_candidates = []
test_cases: list[tuple[list[str], list[TestCase]]] = []
generator_candidates: list[TestGenerator] = []
# check paths
path_str = "/".join(schema.path)
matched_schema = matched_schemas.get(path_str)
for test_gen in test_generators:
matched_schema = matched_schemas_dict.get(path_str)
for test_generator in TEST_GENERATORS:
# check paths
for path in test_gen.paths or []:
for path in test_generator.paths or []:
if path_str.endswith(path):
generator_candidates.append(test_gen)
generator_candidates.append(test_generator)
continue

# check field name
if (
test_gen.field_name is not None
and test_gen.field_name == field_name
test_generator.field_name is not None
and test_generator.field_name == field_name
):
generator_candidates.append(test_gen)
generator_candidates.append(test_generator)
continue

# check k8s schema name
if (
test_gen.k8s_schema_name is not None
test_generator.k8s_schema_name is not None
and matched_schema is not None
and matched_schema.k8s_schema_name.endswith(
test_gen.k8s_schema_name
test_generator.k8s_schema_name
)
):
generator_candidates.append(test_gen)
generator_candidates.append(test_generator)
continue

# check type
matching_types = {
"AnyOf": AnyOfSchema,
"Array": ArraySchema,
"Boolean": BooleanSchema,
"Integer": IntegerSchema,
"Number": NumberSchema,
"Object": ObjectSchema,
"OneOf": OneOfSchema,
"Opaque": OpaqueSchema,
"String": StringSchema,
}
if schema_type_obj := matching_types.get(test_gen.field_type):
if isinstance(schema, schema_type_obj):
generator_candidates.append(test_gen)
if test_generator.field_type is not None:
matching_types = {
"AnyOf": AnyOfSchema,
"Array": ArraySchema,
"Boolean": BooleanSchema,
"Integer": IntegerSchema,
"Number": NumberSchema,
"Object": ObjectSchema,
"OneOf": OneOfSchema,
"Opaque": OpaqueSchema,
"String": StringSchema,
}
if schema_type_obj := matching_types.get(
test_generator.field_type
):
if isinstance(schema, schema_type_obj):
generator_candidates.append(test_generator)
else:
raise ValueError(

Check warning on line 164 in acto/input/generator.py

View workflow job for this annotation

GitHub Actions / coverage-report

Missing coverage

Missing coverage on line 164
f"Unknown schema type: {test_generator.field_type}"
)

# sort by priority
generator_candidates.sort(key=lambda x: x.priority, reverse=True)
if len(generator_candidates) > 0:
test_cases.append(
(schema.path, generator_candidates[0].func(schema))
(schema.path, generator_candidates[0].func(schema)),
)

# check sub schemas
Expand Down
34 changes: 17 additions & 17 deletions test/integration_tests/test_testcase_generator_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import yaml

from acto.input.generator import generator, get_testcases, test_generators
from acto.input.generator import TEST_GENERATORS, generator, get_testcases
from acto.input.k8s_schemas import K8sSchemaMatcher
from acto.input.testcase import TestCase
from acto.schema import extract_schema
Expand Down Expand Up @@ -37,12 +37,12 @@ def setUpClass(cls):
cls.matches = schema_matcher.find_matched_schemas(cls.spec_schema)

def test_path_suffix(self):
test_generators.clear()
TEST_GENERATORS.clear()
generator(paths=["serviceAccountToken/expirationSeconds"])(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 1)

test_generators.clear()
TEST_GENERATORS.clear()
generator(
paths=[
"serviceAccountToken/expirationSeconds",
Expand All @@ -53,75 +53,75 @@ def test_path_suffix(self):
self.assertEqual(len(testcases), 2)

def test_k8s_schema_name(self):
test_generators.clear()
TEST_GENERATORS.clear()
generator(k8s_schema_name="v1.NodeAffinity")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 2)

test_generators.clear()
TEST_GENERATORS.clear()
generator(k8s_schema_name="HTTPHeader")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 15)

def test_field_name(self):
test_generators.clear()
TEST_GENERATORS.clear()
generator(field_name="ports")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 4)

test_generators.clear()
TEST_GENERATORS.clear()
generator(field_name="image")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 5)

def test_field_type(self):
test_generators.clear()
TEST_GENERATORS.clear()
generator(field_type="AnyOf")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 38)

test_generators.clear()
TEST_GENERATORS.clear()
generator(field_type="Array")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 173)

test_generators.clear()
TEST_GENERATORS.clear()
generator(field_type="Boolean")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 73)

test_generators.clear()
TEST_GENERATORS.clear()
generator(field_type="Integer")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 106)

test_generators.clear()
TEST_GENERATORS.clear()
generator(field_type="Number")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 106)

test_generators.clear()
TEST_GENERATORS.clear()
generator(field_type="Object")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 368)

# test_generators.clear()
# TEST_GENERATORS.clear()
# generator(field_type="OneOf")(gen)
# testcases = get_testcases(self.spec_schema, self.matches)
# self.assertEqual(len(testcases), 0)

# test_generators.clear()
# TEST_GENERATORS.clear()
# generator(field_type="Opaque")(gen)
# testcases = get_testcases(self.spec_schema, self.matches)
# self.assertEqual(len(testcases), 0)

test_generators.clear()
TEST_GENERATORS.clear()
generator(field_type="String")(gen)
testcases = get_testcases(self.spec_schema, self.matches)
self.assertEqual(len(testcases), 550)

def test_priority(self):
test_generators.clear()
TEST_GENERATORS.clear()

@generator(field_type="Integer", priority=0)
def gen0(_):
Expand Down

0 comments on commit be68d37

Please sign in to comment.