Skip to content

Commit

Permalink
Add local declaration of vertical extent, needed for vertical flipping
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinswales committed Oct 26, 2023
1 parent aae4b43 commit c3d2b3a
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions scripts/suite_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,8 @@ def __init__(self, scheme_xml, context, parent, run_env):
self.__lib = scheme_xml.get('lib', None)
self.__has_vertical_dimension = False
self.__group = None
self.__alloc_transforms = list()
self.__depend_transforms = list()
self.__forward_transforms = list()
self.__reverse_transforms = list()
super().__init__(name, context, parent, run_env, active_call_list=True)
Expand Down Expand Up @@ -1194,19 +1196,27 @@ def analyze(self, phase, group, scheme_library, suite_vars, level):
# end if
# Are there any forward/reverse transforms for this variable?
if compat_obj is not None and (compat_obj.has_vert_transforms or compat_obj.has_unit_transforms):
# Add local variable (<var>_local) needed for transformation.
tmp_var = var.clone(var.get_prop_value('local_name')+'_local')
alloc_stmt = "allocate({}({}))"
self.__alloc_transforms.append(alloc_stmt.format(var.get_prop_value('local_name')+'_local',''))
self.__group.manage_variable(tmp_var)

# Create indices for vertical flipping (if needed)
# Add local variable (<var>_nlay) needed for vertical flipping.
indices = [':']*var.get_rank()
dim = find_vertical_dimension(var.get_dimensions())[0]
for dpart in dim.split(':'):
if (dpart in var_local["std_name"]):
vli = var_local["std_name"].index(dpart)
if (compat_obj.has_vert_transforms):
indices[find_vertical_dimension(var.get_dimensions())[1]] = var_local["local_name"][vli] + ':1:-1'
else:
indices[find_vertical_dimension(var.get_dimensions())[1]] = '1:' + var_local["local_name"][vli]
if compat_obj.has_vert_transforms:
verti_var = Var({'local_name':var.get_prop_value('local_name')+'_nlay',
'standard_name':var.get_prop_value('local_name')+'_nlay',
'type':'integer', 'units':'count',
'dimensions':'()'}, _API_LOCAL, self.run_env)
self.__group.manage_variable(verti_var)
# Set indices for vertical flipping.
dim = find_vertical_dimension(var.get_dimensions())
for dpart in dim[0].split(':'):
indices[dim[1]] = var.get_prop_value('local_name')+'_nlay:1:-1'
# Create statement for use in write stage.
write_stmt = var.get_prop_value('local_name')+"_nlay = size({},{})"
self.__depend_transforms.append(write_stmt.format(var.get_prop_value('local_name'),dim[1]+1))

# Add any forward transforms.
if (var.get_prop_value('intent') != 'in'):
Expand Down Expand Up @@ -1242,14 +1252,18 @@ def write(self, outfile, errcode, indent):
is_func_call=True,
subname=self.subroutine_name)
stmt = 'call {}({})'
# Write any reverse transforms.
for reverse_transform in self.__reverse_transforms: outfile.write(reverse_transform, indent)
# Write the scheme call.
outfile.write('if ({} == 0) then'.format(errcode), indent)
# Write any allocate statements needed for transforms.
#for alloc_transform in self.__alloc_transforms: outfile.write(alloc_transform, indent+1)
# Write any dependencies needed for transforms.
for depend_transform in self.__depend_transforms: outfile.write(depend_transform, indent+1)
# Write any reverse transforms.
for reverse_transform in self.__reverse_transforms: outfile.write(reverse_transform, indent+1)
outfile.write(stmt.format(self.subroutine_name, my_args), indent+1)
outfile.write('end if', indent)
# Write any forward transforms.
for forward_transform in self.__forward_transforms: outfile.write(forward_transform, indent)
for forward_transform in self.__forward_transforms: outfile.write(forward_transform, indent+1)
outfile.write('end if', indent)

def schemes(self):
"""Return self as a list for consistency with subcycle"""
Expand Down

0 comments on commit c3d2b3a

Please sign in to comment.