Skip to content

Commit

Permalink
Initial code changes (not tidy), modified and added some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
DrTVockerodtMO committed Nov 18, 2024
1 parent d865146 commit b7f4844
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 75 deletions.
170 changes: 111 additions & 59 deletions src/psyclone/psyir/transformations/intrinsics/matmul2code_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
# -----------------------------------------------------------------------------
# Author: R. W. Ford, STFC Daresbury Laboratory
# Modified: S. Siso, A. R. Porter and N. Nobre, STFC Daresbury Lab
# T. Vockerodt, Met Office

'''Module providing a transformation from a PSyIR MATMUL operator to
PSyIR code. This could be useful if the MATMUL operator is not
Expand All @@ -51,7 +52,7 @@
Intrinsic2CodeTrans)


def _create_matrix_ref(matrix_symbol, loop_idx_symbols, other_dims):
def _create_matrix_ref(matrix_symbol, loop_idx_symbols, other_dims, order):
'''
Utility function to create a reference to a matrix element being
accessed using one or more loop indices followed by zero or more
Expand All @@ -65,13 +66,24 @@ def _create_matrix_ref(matrix_symbol, loop_idx_symbols, other_dims):
not being looped over. (If there are none then this \
must be an empty list.)
:type other_dims: List[:py:class:`psyclone.psyir.nodes.ExpressionNode`]
:param order: list of indices in the original matrix where the
other_dims are.
:type order: List[int]
:returns: the new reference to a matrix element.
:rtype: :py:class:`psyclone.psyir.nodes.ArrayReference`
'''
indices = [Reference(sym) for sym in loop_idx_symbols]
indices.extend(dim.copy() for dim in other_dims)
n_indices = len(loop_idx_symbols) + len(other_dims)
indices = [-1]*n_indices
for it, order_idx in enumerate(order):
indices[order_idx] = other_dims[it].copy()
# Fill the remaining indices with loop index symbols
for it, sym in enumerate(loop_idx_symbols):
for i in range(n_indices):
if indices[i] == -1:
indices[i] = Reference(sym)
break
return ArrayReference.create(matrix_symbol, indices)


Expand Down Expand Up @@ -133,6 +145,67 @@ def _get_array_bound(array, index):
return (lower_bound, upper_bound, step)


def _get_full_range_split(array):
'''A utility function that returns the number of full ranges
and the list of indices which are not full ranges.
:param array: the reference that we are interested in.
:type array: :py:class:`psyir.nodes.Reference`
:returns: tuple with number of full ranges and
the list of non full range indices for this array.
:rtype: (int, List[psyclone.psyir.nodes.DataNode])
:raises TransformationError: if any Range nodes are not full, or
if there are more than two full range nodes.
'''
n_full_ranges = 0
non_full_ranges = []
order = []
for it, idx in enumerate(array.children):
if isinstance(idx, Range):
if array.is_full_range(it):
n_full_ranges += 1
else:
from psyclone.psyir.transformations import TransformationError
raise TransformationError(
f"To use matmul2code_trans on matmul, each Range index "
f"of the argument '{array.name}' must be a full range "
f"but found non full range at position {it}.")
else:
non_full_ranges.append(idx)
order.append(it)
# Early error raising if we go above 2 full ranges
if n_full_ranges > 2:
from psyclone.psyir.transformations import TransformationError
raise TransformationError(
f"To use matmul2code_trans on matmul, no more than "
f"two indices of the argument '{array.name}' "
f"must be full ranges but found {n_full_ranges}.")
return (n_full_ranges, non_full_ranges, order)


def _get_first_full_range_idx(array, reverse = False):

# Find last index if asked to
child_enumerate = enumerate(array.children)
if reverse:
child_enumerate = reversed(list(child_enumerate))

for it, idx in child_enumerate:
if isinstance(idx, Range):
if array.is_full_range(it):
return it
else:
from psyclone.psyir.transformations import TransformationError
raise TransformationError(
f"To use matmul2code_trans on matmul, each Range index "
f"of the argument '{array.name}' must be a full range "
f"but found non full range at position {it}.")
# Default behaviour
return -int(reverse == True)


class Matmul2CodeTrans(Intrinsic2CodeTrans):
'''Provides a transformation from a PSyIR MATMUL Operator node to
equivalent code in a PSyIR tree. Validity checks are also
Expand Down Expand Up @@ -257,28 +330,13 @@ def validate(self, node, options=None):
else:
# There should be one index per dimension. This is enforced
# by the array create method so is not tested here.

# The first two indices should be ranges. This is a
# limitation of this transformation, not the PSyIR, as it
# would be valid to have any two indices being
# ranges. Further, this transformation is currently
# limited to Ranges which specify the full extent of the
# dimension.
if not (matrix1.is_full_range(0) and matrix1.is_full_range(1)):
# For matrix1, we must have exactly 2 full ranges.
n_full_ranges, _, _ = _get_full_range_split(matrix1)
if n_full_ranges != 2:
raise TransformationError(
f"To use matmul2code_trans on matmul, the first two "
f"indices of the 1st argument '{matrix1.name}' must be "
f"full ranges.")

if len(matrix1.children) > 2:
# The 3rd index and onwards must not be ranges.
for (count, index) in enumerate(matrix1.children[2:]):
if isinstance(index, Range):
raise TransformationError(
f"To use matmul2code_trans on matmul, only the "
f"first two indices of the 1st argument are "
f"permitted to be Ranges but found "
f"{type(index).__name__} at index {count+2}.")
f"To use matmul2code_trans on matmul, exactly two "
f"indices of the 1st argument '{matrix1.name}' "
f"must be full ranges but found {n_full_ranges}.")

if len(matrix2.symbol.shape) > 2 and not matrix2.children:
# If matrix2 has no children then it is a reference. If
Expand All @@ -295,32 +353,13 @@ def validate(self, node, options=None):
else:
# There should be one index per dimension. This is enforced
# by the array create method so is not tested here.

# The first index should be a range. This
# transformation is currently limited to Ranges which
# specify the full extent of the dimension.
if not matrix2.is_full_range(0):
raise TransformationError(
f"To use matmul2code_trans on matmul, the first index of "
f"the 2nd argument '{matrix2.name}' must be a full range.")
# Check that the second dimension is a full range if it is
# a range.
if (len(matrix2.symbol.shape) > 1 and
isinstance(matrix2.children[1], Range)
and not matrix2.is_full_range(1)):
# For matrix2, we can have 1 or 2 full ranges.
n_full_ranges, _, _ = _get_full_range_split(matrix2)
if n_full_ranges not in [1,2]:
raise TransformationError(
f"To use matmul2code_trans on matmul for a matrix-matrix "
f"multiplication, the second index of the 2nd "
f"argument '{matrix2.name}' must be a full range.")
if len(matrix2.children) > 2:
# The 3rd index and onwards must not be ranges.
for (count, index) in enumerate(matrix2.children[2:]):
if isinstance(index, Range):
raise TransformationError(
f"To use matmul2code_trans on matmul, only the "
f"first two indices of the 2nd argument are "
f"permitted to be a Range but found "
f"{type(index).__name__} at index {count+2}.")
f"To use matmul2code_trans on matmul, one or two "
f"indices of the 2nd argument '{matrix2.name}' "
f"must be full ranges but found {n_full_ranges}.")

# Make sure the result has as many full range as needed
if result.children:
Expand Down Expand Up @@ -403,10 +442,12 @@ def _apply_matrix_vector(node):
vector_dims.append(child.copy())
vector_array_reference = ArrayReference.create(
vector.symbol, vector_dims)
_, ref_non_full_ranges, order = _get_full_range_split(matrix)
# Create "matrix(i,j)"
matrix_array_reference = _create_matrix_ref(matrix.symbol,
[i_loop_sym, j_loop_sym],
matrix.children[2:])
ref_non_full_ranges,
order)
# Create "matrix(i,j) * vector(j)"
multiply = BinaryOperation.create(
BinaryOperation.Operator.MUL, matrix_array_reference,
Expand All @@ -418,14 +459,16 @@ def _apply_matrix_vector(node):
assign = Assignment.create(result_ref.copy(), rhs)
# Create j loop and add the above code as a child
# Work out the bounds
lower_bound, upper_bound, step = _get_array_bound(vector, 0)
first_pos = _get_first_full_range_idx(vector)
lower_bound, upper_bound, step = _get_array_bound(vector, first_pos)
jloop = Loop.create(j_loop_sym, lower_bound, upper_bound, step,
[assign])
# Create "result(i) = 0.0"
assign = Assignment.create(result_ref.copy(),
Literal("0.0", REAL_TYPE))
# Create i loop and add assignment and j loop as children
lower_bound, upper_bound, step = _get_array_bound(matrix, 0)
first_pos = _get_first_full_range_idx(matrix)
lower_bound, upper_bound, step = _get_array_bound(matrix, first_pos)
iloop = Loop.create(i_loop_sym, lower_bound, upper_bound, step,
[assign, jloop])
# Replace the existing assignment with the new loop.
Expand Down Expand Up @@ -456,17 +499,23 @@ def _apply_matrix_matrix(node):
ii_loop_sym = symbol_table.new_symbol("ii", symbol_type=DataSymbol,
datatype=INTEGER_TYPE)
# Create "result(i,j)"
_, res_non_full_ranges, order = _get_full_range_split(result)
result_ref = _create_matrix_ref(result.symbol,
[i_loop_sym, j_loop_sym],
result.children[2:])
res_non_full_ranges,
order)
# Create "matrix2(ii,j)"
_, m2_non_full_ranges, order = _get_full_range_split(matrix2)
m2_array_reference = _create_matrix_ref(matrix2.symbol,
[ii_loop_sym, j_loop_sym],
matrix2.children[2:])
m2_non_full_ranges,
order)
# Create "matrix1(i,ii)"
_, m1_non_full_ranges, order = _get_full_range_split(matrix1)
m1_array_reference = _create_matrix_ref(matrix1.symbol,
[i_loop_sym, ii_loop_sym],
matrix1.children[2:])
m1_non_full_ranges,
order)
# Create "matrix1(i,ii) * matrix2(ii,j)"
multiply = BinaryOperation.create(
BinaryOperation.Operator.MUL, m1_array_reference,
Expand All @@ -478,19 +527,22 @@ def _apply_matrix_matrix(node):
assign = Assignment.create(result_ref.copy(), rhs)
# Create ii loop and add the above code as a child
# Work out the bounds
lower_bound, upper_bound, step = _get_array_bound(matrix1, 1)
# Must be the same as _get_array_bound(matrix2, 0)
pos_last = _get_first_full_range_idx(matrix1, reverse=True)
lower_bound, upper_bound, step = _get_array_bound(matrix1, pos_last)
# Must be the same as _get_array_bound(matrix2, pos_first)
iiloop = Loop.create(ii_loop_sym, lower_bound, upper_bound, step,
[assign])
# Create "result(i,j) = 0.0"
assign = Assignment.create(result_ref.copy(),
Literal("0.0", REAL_TYPE))
# Create i loop and add assignment and ii loop as children.
lower_bound, upper_bound, step = _get_array_bound(matrix1, 0)
pos_first = _get_first_full_range_idx(matrix1)
lower_bound, upper_bound, step = _get_array_bound(matrix1, pos_first)
iloop = Loop.create(i_loop_sym, lower_bound, upper_bound, step,
[assign, iiloop])
# Create j loop and add i loop as child.
lower_bound, upper_bound, step = _get_array_bound(matrix2, 1)
pos_last = _get_first_full_range_idx(matrix2, reverse=True)
lower_bound, upper_bound, step = _get_array_bound(matrix2, pos_last)
jloop = Loop.create(j_loop_sym, lower_bound, upper_bound, step,
[iloop])
# Replace the original assignment with the new loop.
Expand Down
Loading

0 comments on commit b7f4844

Please sign in to comment.