Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose function evaluation from KLLVM bindings #4242

Merged
merged 2 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion pyk/src/pyk/kllvm/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from types import ModuleType
from typing import Any

from .ast import Pattern, Sort
from .ast import CompositePattern, Pattern, Sort


class Runtime:
Expand Down Expand Up @@ -43,6 +43,11 @@ def simplify_bool(self, pattern: Pattern) -> bool:
self._module.free_all_gc_memory()
return res

def evaluate(self, pattern: CompositePattern) -> Pattern:
res = self._module.evaluate_function(pattern)
self._module.free_all_gc_memory()
return res


class Term:
_block: Any # module.InternalTerm
Expand Down
49 changes: 49 additions & 0 deletions pyk/src/tests/integration/kllvm/test_evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import pyk.kllvm.load # noqa: F401
from pyk.kllvm import parser
from pyk.testing import RuntimeTest

from ..utils import K_FILES

if TYPE_CHECKING:
from pyk.kllvm.runtime import Runtime


EVALUATE_TEST_DATA = (
('1 + 2', r"""Lbl'UndsPlus'Int'Unds'{}(\dv{SortInt{}}("1"),\dv{SortInt{}}("2"))""", r'\dv{SortInt{}}("3")'),
('1 * 2', r"""Lbl'UndsStar'Int'Unds'{}(\dv{SortInt{}}("1"),\dv{SortInt{}}("2"))""", r'\dv{SortInt{}}("2")'),
(
'1 + (2 * 3)',
r"""
Lbl'UndsPlus'Int'Unds'{}(
\dv{SortInt{}}("1"),
Lbl'UndsStar'Int'Unds'{}(\dv{SortInt{}}("2"), \dv{SortInt{}}("3"))
)
""",
r'\dv{SortInt{}}("7")',
),
)


class TestEvaluate(RuntimeTest):
KOMPILE_MAIN_FILE = K_FILES / 'imp.k'

@pytest.mark.parametrize(
'test_id,pattern_text,expected',
EVALUATE_TEST_DATA,
ids=[test_id for test_id, *_ in EVALUATE_TEST_DATA],
)
def test_simplify(self, runtime: Runtime, test_id: str, pattern_text: str, expected: str) -> None:
# Given
pattern = parser.parse_pattern(pattern_text)

# When
actual = runtime.evaluate(pattern)

# Then
assert str(actual) == expected
Loading