Skip to content

Commit

Permalink
#2739: Remove test scripts and replace them with a script_factory fix…
Browse files Browse the repository at this point in the history
…ture
  • Loading branch information
sergisiso committed Nov 7, 2024
1 parent 5d05863 commit bd9e340
Show file tree
Hide file tree
Showing 13 changed files with 126 additions and 395 deletions.
4 changes: 0 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ extend-ignore = E266
# errors on specific lines with # noqa: <error> in a few other files.
exclude =
.git,__pycache__,conf.py,
# Contain deliberate Python errors for testing purposes
src/psyclone/tests/test_files/dynamo0p3/error_syntax.py,
src/psyclone/tests/test_files/dynamo0p3/error_import.py,
src/psyclone/tests/test_files/dynamo0p3/runtime_error.py,
# Contain multiple imports flagged with F401 imported but unused
__init__.py,
tutorial/practicals/LFRic/single_node/1_openmp/omp_script.py,
Expand Down
192 changes: 126 additions & 66 deletions src/psyclone/tests/generator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,32 @@
"test_files", "gocean1p0")


def delete_module(modname):
'''A function to remove a module from Python's internal modules
list. This is useful as some tests affect others by importing
modules.
@pytest.fixture(name="script_factory", scope="function")
def create_script_factor(tmpdir):
''' Fixture that creates a psyclone optimisation script given the string
representing the body of the script:
script_path = script_factory("def trans(psyir):\n pass")
It has a 'function' scope and a tear down section because using a script
imports the file and this is kept in the python interpreter state, so we
delete it for future tests.
'''
del modules[modname]
tmpfile = os.path.join(tmpdir, "test_script.py")

def populate_script(string):
with open(tmpfile, 'w+', encoding="utf8") as script:
script.write(string)
return tmpfile

yield populate_script
# Tear down section executed after each test that uses the fixture
# If the created script was used, then its module (file) was imported
# into the interpreter runtime, we need to make sure it is deleted
modname = "test_script"
if modname in modules:
del modules[modname]
for mod in modules.values():
try:
delattr(mod, modname)
Expand Down Expand Up @@ -142,66 +161,77 @@ def test_script_file_wrong_extension():
"extension" in str(error.value))


def test_script_invalid_content():
def test_script_invalid_content(script_factory):
'''Checks that load_script() in generator.py raises the expected
exception when a script file does not contain valid python. This
test uses the generate() function to call load_script as this is
a simple way to create its required arguments.
'''
with pytest.raises(Exception) as error_syntax:
error_syntax = script_factory("""
this is invalid python
""")
with pytest.raises(Exception) as err:
_, _ = generate(
os.path.join(BASE_PATH, "dynamo0p3", "1_single_invoke.f90"),
api="lfric", script_name=os.path.join(BASE_PATH, "dynamo0p3",
"error_syntax.py"))
assert ("invalid syntax (error_syntax.py, line 5)"
in str(error_syntax.value))
api="lfric", script_name=error_syntax)
assert ("invalid syntax (test_script.py, line 2)" in str(err.value))

with pytest.raises(Exception) as error_import:
error_import = script_factory("""
import non_existent
""")
with pytest.raises(Exception) as err:
_, _ = generate(
os.path.join(BASE_PATH, "dynamo0p3", "1_single_invoke.f90"),
api="lfric", script_name=os.path.join(BASE_PATH, "dynamo0p3",
"error_import.py"))
assert "No module named 'non_existent'" in str(error_import.value)
api="lfric", script_name=error_import)
assert "No module named 'non_existent'" in str(err.value)


def test_script_invalid_content_runtime():
def test_script_invalid_content_runtime(script_factory):
'''Checks that load_script() function in generator.py raises the
expected exception when a script file contains valid python
syntactically but produces a runtime exception. This test uses the
generate() function to call load_script as this is a simple way
to create its required arguments.
'''
runtime_error = script_factory("""
def trans(psyir):
# this will produce a runtime error as b has not been assigned
psyir = b
""")
with pytest.raises(Exception) as error:
_, _ = generate(
os.path.join(BASE_PATH, "dynamo0p3", "1_single_invoke.f90"),
api="lfric",
script_name=os.path.join(
BASE_PATH, "dynamo0p3", "runtime_error.py"))
api="lfric", script_name=runtime_error)
assert "name 'b' is not defined" in str(error.value)


def test_script_no_trans():
def test_script_no_trans(script_factory):
'''Checks that load_script() function in generator.py raises the
expected exception when a script file does not contain a trans()
function. This test uses the generate() function to call
load_script as this is a simple way to create its required
arguments.
'''
no_trans_script = script_factory("""
def nottrans(psyir):
pass
def tran():
pass
""")
with pytest.raises(GenerationError) as error:
_, _ = generate(
os.path.join(BASE_PATH, "dynamo0p3", "1_single_invoke.f90"),
api="lfric",
script_name=os.path.join(
BASE_PATH, "dynamo0p3", "no_trans.py"))
api="lfric", script_name=no_trans_script)
assert ("attempted to use specified PSyclone transformation module "
"'no_trans' but it does not contain a callable 'trans' function"
"'test_script' but it does not contain a callable 'trans' function"
in str(error.value))


def test_script_no_trans_alg(capsys):
def test_script_no_trans_alg(capsys, script_factory):
'''Checks that load_script() function in generator.py does not raise
an exception when a script file does not contain a trans_alg()
function as these are optional. At the moment this function is
Expand All @@ -210,17 +240,17 @@ def test_script_no_trans_alg(capsys):
its required arguments.
'''
no_alg_script = script_factory("def trans(psyir):\n pass")
_, _ = generate(
os.path.join(BASE_PATH, "gocean1p0", "single_invoke.f90"),
api="gocean",
script_name=os.path.join(BASE_PATH, "gocean1p0", "script.py"))
api="gocean", script_name=no_alg_script)

# The legacy script deprecation warning is not printed in this case
captured = capsys.readouterr()
assert "Deprecation warning:" not in captured.err


def test_script_with_legacy_trans_signature(capsys):
def test_script_with_legacy_trans_signature(capsys, script_factory):
'''Checks that load_script() function in generator.py does not raise
an exception when a script file uses the legacy trans signature.
Expand All @@ -230,10 +260,16 @@ def test_script_with_legacy_trans_signature(capsys):
This will eventually be deprecated.
'''
legacy_script = script_factory("""
def trans(psy):
# The following are backwards-compatible expressions with legacy scripts
_ = psy.invokes.invoke_list
_ = psy.invokes.names
return psy
""")
_, _ = generate(
os.path.join(BASE_PATH, "gocean1p0", "single_invoke.f90"),
api="gocean",
script_name=os.path.join(BASE_PATH, "gocean1p0", "legacy_script.py"))
api="gocean", script_name=legacy_script)

# The deprecation warning message was printed
captured = capsys.readouterr()
Expand Down Expand Up @@ -391,16 +427,23 @@ def test_no_script_gocean():
assert "MODULE psy_single_invoke_test" in str(psy)


def test_script_gocean():
def test_script_gocean(script_factory):
'''Test that the generate function in generator.py returns
successfully if a script (containing both trans_alg() and trans()
functions) is specified.
'''
alg_script = script_factory("""
def trans_alg(psyir):
pass
def trans(psyir):
pass
""")

_, _ = generate(
os.path.join(BASE_PATH, "gocean1p0", "single_invoke.f90"),
api="gocean",
script_name=os.path.join(BASE_PATH, "gocean1p0", "alg_script.py"))
api="gocean", script_name=alg_script)


def test_profile_gocean():
Expand All @@ -417,41 +460,43 @@ def test_profile_gocean():
Profiler._options = []


def test_script_attr_error():
def test_script_attr_error(script_factory):
'''Checks that generator.py raises an appropriate error when a script
file contains a trans() function which raises an attribute
error. This is what we previously used to check for a script file
not containing a trans() function.
file contains a trans() function which raises an attribute error.
'''
error_script = script_factory("""
from psyclone.psyGen import Loop
from psyclone.transformations import ColourTrans
def trans(psyir):
''' A valid trans function which produces an attribute error as
we have mistyped apply()'''
ctrans = ColourTrans()
for child in psyir.walk(Loop):
if isinstance(child, Loop) and child.field_space != "w3":
ctrans.appy(child)
""")
with pytest.raises(Exception) as excinfo:
_, _ = generate(os.path.join(BASE_PATH, "dynamo0p3",
"1_single_invoke.f90"),
api="lfric",
script_name=os.path.join(BASE_PATH,
"dynamo0p3",
"error_trans.py"))
api="lfric", script_name=error_script)
assert 'object has no attribute' in str(excinfo.value)


def test_script_null_trans():
def test_script_null_trans(script_factory):
'''Checks that generator.py works correctly when the trans() function
in a valid script file does no transformations (it simply passes
input to output). In this case the valid script file has an
explicit path and must therefore exist at this location.
in a valid script file does no transformations. In this case the valid
script file has an absolut path and must therefore exist at this location.
'''
empty_script = script_factory("def trans(psyir):\n pass")
alg1, psy1 = generate(os.path.join(BASE_PATH, "dynamo0p3",
"1_single_invoke.f90"),
api="lfric")
alg2, psy2 = generate(os.path.join(BASE_PATH, "dynamo0p3",
"1_single_invoke.f90"),
api="lfric",
script_name=os.path.join(BASE_PATH,
"dynamo0p3",
"null_trans.py"))
# remove module so we do not affect any following tests
delete_module("null_trans")
api="lfric", script_name=empty_script)
# we need to remove the first line before comparing output as
# this line is an instance specific header
assert '\n'.join(str(alg1).split('\n')[1:]) == \
Expand All @@ -460,7 +505,7 @@ def test_script_null_trans():
'\n'.join(str(psy2).split('\n')[1:])


def test_script_null_trans_relative():
def test_script_null_trans_relative(script_factory):
'''Checks that generator.py works correctly when the trans() function
in a valid script file does no transformations (it simply passes
input to output). In this case the valid script file contains no
Expand All @@ -470,13 +515,15 @@ def test_script_null_trans_relative():
alg1, psy1 = generate(os.path.join(BASE_PATH, "dynamo0p3",
"1_single_invoke.f90"),
api="lfric")
# set up the python path so that null_trans.py can be found
os.sys.path.append(os.path.join(BASE_PATH, "dynamo0p3"))
empty_script = script_factory("def trans(psyir):\n pass")
basename = os.path.basename(empty_script)
path = os.path.dirname(empty_script)
# Set the script directory in the PYTHONPATH
os.sys.path.append(path)
alg2, psy2 = generate(os.path.join(BASE_PATH, "dynamo0p3",
"1_single_invoke.f90"),
api="lfric", script_name="null_trans.py")
# remove imported module so we do not affect any following tests
delete_module("null_trans")
api="lfric", script_name=basename)
# Remove the path from PYTHONPATH
os.sys.path.pop()
# we need to remove the first line before comparing output as
# this line is an instance specific header
Expand All @@ -485,12 +532,22 @@ def test_script_null_trans_relative():
assert str(psy1) == str(psy2)


def test_script_trans_dynamo0p3():
def test_script_trans_dynamo0p3(script_factory):
'''Checks that generator.py works correctly when a transformation is
provided as a script, i.e. it applies the transformations
correctly. We use loop fusion as an example.
correctly.
'''
fuse_loop_script = script_factory("""
from psyclone.domain.lfric.transformations import LFRicLoopFuseTrans
def trans(psyir):
module = psyir.children[0]
schedule = [x for x in module.children if x.name == "invoke_0"][0]
loop1 = schedule.children[4]
loop2 = schedule.children[5]
transform = LFRicLoopFuseTrans()
transform.apply(loop1, loop2)
""")
root_path = os.path.dirname(os.path.abspath(__file__))
base_path = os.path.join(root_path, "test_files", "dynamo0p3")
# First loop fuse explicitly (without using generator.py)
Expand All @@ -506,10 +563,7 @@ def test_script_trans_dynamo0p3():
generated_code_1 = psy.gen
# Second loop fuse using generator.py and a script
_, generated_code_2 = generate(parse_file, api="lfric",
script_name=os.path.join(
base_path, "loop_fuse_trans.py"))
# remove module so we do not affect any following tests
delete_module("loop_fuse_trans")
script_name=fuse_loop_script)
# third - check that the results are the same ...
assert str(generated_code_1) == str(generated_code_2)

Expand Down Expand Up @@ -1347,7 +1401,7 @@ def test_no_script_lfric_new(monkeypatch):
assert "use _psyclone_builtins" not in alg


def test_script_lfric_new(monkeypatch):
def test_script_lfric_new(monkeypatch, script_factory):
'''Test that the generate function in generator.py returns
successfully if a script (containing both trans_alg() and trans()
functions) is specified. This test uses the new PSyIR approach to
Expand All @@ -1356,11 +1410,17 @@ def test_script_lfric_new(monkeypatch):
monkeypatching.
'''
alg_script = script_factory("""
def trans_alg(psyir):
pass
def trans(psyir):
pass
""")
monkeypatch.setattr(generator, "LFRIC_TESTING", True)
alg, _ = generate(
os.path.join(BASE_PATH, "dynamo0p3", "1_single_invoke.f90"),
api="lfric",
script_name=os.path.join(BASE_PATH, "dynamo0p3", "alg_script.py"))
api="lfric", script_name=alg_script)
# new call replaces invoke
assert "use single_invoke_psy, only : invoke_0_testkern_type" in alg
assert "call invoke_0_testkern_type(a, f1, f2, m1, m2)" in alg
Expand Down
Loading

0 comments on commit bd9e340

Please sign in to comment.