Skip to content

Commit

Permalink
#2768 refactor fix and apply to abs2codetrans too
Browse files Browse the repository at this point in the history
  • Loading branch information
arporter committed Nov 6, 2024
1 parent 6ac8e1f commit 02cd43e
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 33 deletions.
29 changes: 19 additions & 10 deletions src/psyclone/psyir/transformations/intrinsics/abs2code_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,19 @@ def __init__(self):
super().__init__()
self._intrinsic = IntrinsicCall.Intrinsic.ABS

def validate(self, node, options=None):
'''
Check that it is safe to apply the transformation to the supplied node.
:param node: the SIGN call to transform.
:type node: :py:class:`psyclone.psyir.nodes.IntrinsicCall`
:param options: any of options for the transformation.
:type options: dict[str, Any]
'''
super().validate(node, options=options)
super()._validate_scalar_arg(node)

def apply(self, node, options=None):
'''Apply the ABS intrinsic conversion transformation to the specified
node. This node must be an ABS UnaryOperation. The ABS
Expand Down Expand Up @@ -112,16 +125,12 @@ def apply(self, node, options=None):
symbol_table = node.scope.symbol_table
assignment = node.ancestor(Assignment)

# Create two temporary variables. There is an assumption here
# that the ABS Operator returns a PSyIR real type. This might
# not be what is wanted (e.g. the args might PSyIR integers),
# or there may be errors (arguments are of different types)
# but this can't be checked as we don't have the appropriate
# methods to query nodes (see #658).
# Create two temporary variables.
result_type = node.arguments[0].datatype
symbol_res_var = symbol_table.new_symbol(
"res_abs", symbol_type=DataSymbol, datatype=REAL_TYPE)
"res_abs", symbol_type=DataSymbol, datatype=result_type)
symbol_tmp_var = symbol_table.new_symbol(
"tmp_abs", symbol_type=DataSymbol, datatype=REAL_TYPE)
"tmp_abs", symbol_type=DataSymbol, datatype=result_type)

# Replace operation with a temporary (res_X).
node.replace_with(Reference(symbol_res_var))
Expand All @@ -134,7 +143,7 @@ def apply(self, node, options=None):

# if condition: tmp_var>0.0
lhs = Reference(symbol_tmp_var)
rhs = Literal("0.0", REAL_TYPE)
rhs = Literal("0", result_type)
if_condition = BinaryOperation.create(BinaryOperation.Operator.GT,
lhs, rhs)

Expand All @@ -146,7 +155,7 @@ def apply(self, node, options=None):
# else_body: res_var=-1.0*tmp_var
lhs = Reference(symbol_res_var)
lhs_child = Reference(symbol_tmp_var)
rhs_child = Literal("-1.0", REAL_TYPE)
rhs_child = Literal("-1", result_type)
rhs = BinaryOperation.create(BinaryOperation.Operator.MUL, lhs_child,
rhs_child)
else_body = [Assignment.create(lhs, rhs)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@
import abc
from psyclone.psyGen import Transformation
from psyclone.psyir.nodes import Assignment, IntrinsicCall
from psyclone.psyir.transformations.transformation_error import \
TransformationError
from psyclone.psyir.symbols import ArrayType, ScalarType
from psyclone.psyir.transformations.transformation_error import (
TransformationError)


class Intrinsic2CodeTrans(Transformation, metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -95,3 +96,31 @@ def validate(self, node, options=None):
f"Error in {self.name} transformation. This transformation "
f"requires the operator to be part of an assignment "
f"statement, but no such assignment was found.")

def _validate_scalar_arg(self, node, options=None):
'''
Check that the argument to the intrinsic is a scalar of known type.
:param node: the target intrinsic call.
:type node: :py:class:`psyclone.psyir.nodes.IntrinsicCall`
:param options: any options for the transformation.
:type options: dict[str, Any]
:raises TransformationError: if the supplied SIGN call operates on
an argument of array type or unsupported/unresolved type.
'''
result_type = node.arguments[0].datatype
if isinstance(result_type, ArrayType):
raise TransformationError(
f"Transformation {self.name} cannot be applied to SIGN calls "
f"which have an array as argument but "
f"'{node.arguments[0].debug_string()}' is of array type. It "
f"may be possible to use the ArrayAssignment2LoopsTrans "
f"to convert this to a scalar argument.")
if not isinstance(result_type, ScalarType):
raise TransformationError(
f"Transformation {self.name} cannot be applied to "
f"'{node.debug_string()} because the type of the "
f"argument '{node.arguments[0].debug_string()}' is "
f"{result_type}")
25 changes: 4 additions & 21 deletions src/psyclone/psyir/transformations/intrinsics/sign2code_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@
'''
from psyclone.psyir.transformations.intrinsics.intrinsic2code_trans import (
Intrinsic2CodeTrans)
from psyclone.psyir.transformations import (
Abs2CodeTrans, TransformationError)
from psyclone.psyir.transformations import Abs2CodeTrans
from psyclone.psyir.nodes import (
BinaryOperation, Assignment, Reference, Literal, IfBlock, IntrinsicCall)
from psyclone.psyir.symbols import ArrayType, DataSymbol, ScalarType
from psyclone.psyir.symbols import DataSymbol


class Sign2CodeTrans(Intrinsic2CodeTrans):
Expand Down Expand Up @@ -81,28 +80,12 @@ def validate(self, node, options=None):
:param node: the SIGN call to transform.
:type node: :py:class:`psyclone.psyir.nodes.IntrinsicCall`
:param options: any of options for the transformation.
:param options: any options for the transformation.
:type options: dict[str, Any]
:raises TransformationError: if the supplied SIGN call operates on
an argument of array type or unsupported/unresolved type.
'''
super().validate(node, options=options)
result_type = node.arguments[0].datatype
if isinstance(result_type, ArrayType):
raise TransformationError(
f"Transformation {self.name} cannot be applied to SIGN calls "
f"which have an array as argument but "
f"'{node.arguments[0].debug_string()}' is of array type. It "
f"may be possible to use the ArrayAssignment2LoopsTrans "
f"to convert this to a scalar argument.")
if not isinstance(result_type, ScalarType):
raise TransformationError(
f"Transformation {self.name} cannot be applied to "
f"'{node.debug_string()} because the type of the "
f"argument '{node.arguments[0].debug_string()}' is "
f"{result_type}")
super()._validate_scalar_arg(node)

def apply(self, node, options=None):
'''Apply the SIGN intrinsic conversion transformation to the specified
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,26 @@ def test_correct_2sign(tmpdir, fortran_writer):
assert Compile(tmpdir).string_compiles(result)


def test_sign_with_integer_arg(fortran_reader, fortran_writer, tmpdir):
'''
Test that the transformation works when the SIGN argument is an
integer.
'''
code = '''\
program test_prog
integer, parameter :: idef = kind(1)
integer(idef) :: my_arg, other_arg
my_arg = SIGN(my_arg, other_arg)
end program test_prog'''
psyir = fortran_reader.psyir_from_source(code)
trans = Sign2CodeTrans()
sgn_call = psyir.walk(IntrinsicCall)
trans.apply(sgn_call)
result = fortran_writer(psyir)
assert "dadadada" in result


def test_sign_of_unknown_type(fortran_reader):
'''
Check that we refuse to apply the transformation if the argument
Expand Down

0 comments on commit 02cd43e

Please sign in to comment.