diff --git a/test/data/unicode_example.csv b/test/data/unicode_example.csv new file mode 100644 index 0000000..8846970 --- /dev/null +++ b/test/data/unicode_example.csv @@ -0,0 +1,2 @@ +Jamés,Likes,Coffee +Анна,Likes,Tea \ No newline at end of file diff --git a/test/fuzz.py b/test/fuzz.py new file mode 100644 index 0000000..d9ae212 --- /dev/null +++ b/test/fuzz.py @@ -0,0 +1,141 @@ +import random +from kgl.graph import KnowledgeGraph +import os +import lark +from nltk.corpus import stopwords +from nltk import download as nltk_download + +print("Downloading stopwords...") +nltk_download("stopwords") + +print("Running tests...") + +test_dir = os.path.dirname(os.path.abspath(__file__)) + +kg = KnowledgeGraph().load_from_csv(os.path.join(test_dir, "data", "example.csv")) + +seeds = [ + "{ coffee -> is }", + "{ coffee -> is -> coffee }", + "{ tea -> type-of }", + "{ James -> favourite-songs } + { Taylor -> favourite-songs }", + "{ coffee } INTERSECTION { tea }", + "{ coffee } - { tea }", + "{ coffee -> is } - { tea -> is }", +] + +seed_templates = { + # query structure, number of words to generate + "single_query": ("{ %s }", 1), + "single_query_with_two_word_clause": ("{ %s %s -> %s }", 3), + "set_union": ("{ %s } + { %s }", 2), + "set_intersection": ("{ %s } INTERSECTION { %s }", 2), + "set_difference": ("{ %s } - { %s }", 2), +} + +supported_languages = stopwords.fileids() + +character_ranges = { + file_id: list(stopwords.words(file_id)) for file_id in supported_languages +} +character_ranges["unicode"] = [chr(i) for i in range(0x0000, 0x10FFFF)] +character_ranges["numbers"] = [str(random.randint(1, 10_000_000)) for _ in range(1000)] +character_ranges["long_numbers"] = [ + str(random.randint(10_000_000_000_000, 10_000_000_000_000_000)) for _ in range(1000) +] + +supported_languages.append("unicode") + +CHANGE_RATE = 0.1 +ITERATIONS_PER_SEED = 100 + + +def change(): + return ( + random.choices( + population=[["do not change"], ["change"]], + weights=[1 - CHANGE_RATE, CHANGE_RATE], + k=1, + )[0][0] + == "change" + ) + + +def mutate( + seed, characters_to_skip=["{", "}", "-", ">", "<"], character_range="unicode" +): + seed = list(seed) + + for i in range(len(seed)): + if change() and i not in characters_to_skip: + seed[i] = random.choice(character_ranges[character_range]) + + return "".join(seed) + + +def get_random_word_from_random_language(): + return random.choice(character_ranges[random.choice(supported_languages)]) + + +def generate_query_from_scratch(template, num_words_to_generate): + return template % tuple( + get_random_word_from_random_language() for _ in range(num_words_to_generate) + ) + + +def execute_query(query): + try: + kg.evaluate(query) + except (lark.exceptions.UnexpectedCharacters, ValueError): + # In this case, the program has successfully detected an invalid input. + return False + except Exception as e: + # In this case, an unknown error has been raised. + return True + + +def test_fuzzer(): + failed_tests = [] + + tests = [] + + tests.extend([mutate(seed) for seed in seeds for _ in range(ITERATIONS_PER_SEED)]) + tests.extend( + [mutate(seed, []) for seed in seeds for _ in range(ITERATIONS_PER_SEED)] + ) + + for character_range in character_ranges: + tests.extend( + [ + mutate(seed, [], character_range) + for seed in seeds + for _ in range(ITERATIONS_PER_SEED) + ] + ) + + tests.extend( + [ + generate_query_from_scratch(template, num_words) + for template, num_words in seed_templates.values() + for _ in range(ITERATIONS_PER_SEED) + ] + ) + + test_count = len(tests) + + for test in tests: + if execute_query(test): + failed_tests.append(test) + if __name__ != "__main__": + print(test) + assert False + + failed_tests_count = len(failed_tests) + + print( + f"Ran {test_count} tests with {failed_tests_count} failures ({(test_count - failed_tests_count) / test_count * 100}% success rate)" + ) + + +if __name__ == "__main__": + test_fuzzer() diff --git a/test/test.py b/test/test.py index 580cb3b..85ea986 100644 --- a/test/test.py +++ b/test/test.py @@ -12,25 +12,41 @@ def kg(): kg = KnowledgeGraph().load_from_csv(os.path.join(test_dir, "data", "example.csv")) return kg +@pytest.fixture +def unicode_kg(): + from kgl import KnowledgeGraph + + kg = KnowledgeGraph().load_from_csv(os.path.join(test_dir, "data", "unicode_example.csv")) + return kg + def test_evaluate(kg): assert kg.evaluate("{ James }")[0] == [{"Likes": ["Coffee"]}] assert kg.evaluate("{ James -> Likes }")[0] == [["Coffee"]] assert kg.evaluate("{ James <-> Coffee }")[0] == [["James", ("Coffee", "Likes")]] +def test_unicode_query(unicode_kg): + assert unicode_kg.evaluate("{ Jamés }")[0] == [{"Likes": ["Coffee"]}] + assert unicode_kg.evaluate("{ Анна -> Likes }")[0] == [["Tea"]] + assert unicode_kg.evaluate("{ Анна <-> Tea }")[0] == [["Анна", ("Tea", "Likes")]] + + def test_returns_query_time(kg): _, time_taken = kg.evaluate("{ James }") - + assert time_taken > 0 + def test_evaluate_operations(kg): assert kg.evaluate("{ James -> Likes }#")[0] == 1 assert kg.evaluate("{ James -> Likes }?")[0] == True assert kg.evaluate("{ James <-> Coffee }?")[0] == True + def test_add_node_with_query(kg): assert kg.evaluate("{evermore, is, amazing}")[0] == {"is": ["amazing"]} + def test_adding_valid_triple_with_list_value(kg): kg.add_node(("James", "Likes", ["Terraria", "Cats"])) result = kg.evaluate("{ James -> Likes }")[0] @@ -120,3 +136,25 @@ def test_read_from_json(kg): ) assert kg.evaluate("{ James }")[0] == [{"Likes": ["Coffee"]}] assert kg.evaluate("{ Anna }")[0] == [{"Likes": ["Tea"]}] + + +def test_max_query_call_invocation_error(kg): + from kgl import QueryDepthExceededError + + # length of this will be 150 calls, over default max of 50 + query = "{" + ("coffee -> is -> coffee" * 50) + "}" + + with pytest.raises(QueryDepthExceededError): + kg.evaluate(query) + + +def test_incomplete_queries(kg): + with pytest.raises(ValueError): + kg.evaluate("{ James") + kg.evaluate("{ James -> Likes") + kg.evaluate("{") + kg.evaluate("}") + kg.evaluate("{{ James -> Likes }") + kg.evaluate("{ James -> Likes }}") + kg.evaluate("{ James -> Likes }{") + kg.evaluate("{ James -> Likes } + ")