diff --git a/gem/impero_utils.py b/gem/impero_utils.py index e5635975..6dea6f5f 100644 --- a/gem/impero_utils.py +++ b/gem/impero_utils.py @@ -10,6 +10,8 @@ from functools import singledispatch from itertools import chain, groupby +from numpy import find_common_type + from gem.node import traversal, collect_refcount from gem import gem, impero as imp, optimise, scheduling @@ -21,7 +23,9 @@ # temporaries - List of GEM expressions which have assigned temporaries # declare - Where to declare temporaries to get correct C code # indices - Indices for declarations and referencing values -ImperoC = collections.namedtuple('ImperoC', ['tree', 'temporaries', 'declare', 'indices']) +# return_variable - 2-tuple of gem return variable and inferred numpy dtype +ImperoC = collections.namedtuple('ImperoC', ['tree', 'temporaries', 'declare', 'indices', + 'return_variable']) class NoopError(Exception): @@ -38,11 +42,12 @@ def preprocess_gem(expressions, replace_delta=True, remove_componenttensors=True return expressions -def compile_gem(assignments, prefix_ordering, remove_zeros=False): +def compile_gem(assignments, prefix_ordering, scalar_type, remove_zeros=False): """Compiles GEM to Impero. :arg assignments: list of (return variable, expression DAG root) pairs :arg prefix_ordering: outermost loop indices + :arg scalar_type: default scalar type :arg remove_zeros: remove zero assignment to return variables """ # Remove zeros @@ -52,6 +57,9 @@ def nonzero(assignment): return not isinstance(expression, gem.Zero) assignments = list(filter(nonzero, assignments)) + # Type inference for return value + return_variable = infer_dtype(assignments, scalar_type) + # Just the expressions expressions = [expression for variable, expression in assignments] @@ -88,7 +96,26 @@ def nonzero(assignment): declare, indices = place_declarations(tree, temporaries, get_indices) # Prepare ImperoC (Impero AST + other data for code generation) - return ImperoC(tree, temporaries, declare, indices) + return ImperoC(tree, temporaries, declare, indices, return_variable) + + +def infer_dtype(assignments, scalar_type): + from tsfc.loopy import assign_dtypes + from gem.node import traversal + + def extract_variable(expr): + x, = set(v for v in traversal([expr]) if isinstance(v, gem.Variable)) + return x + + vars = set() + dtypes = set() + for var, expression in assignments: + var = extract_variable(var) + ((_, dtype), ) = assign_dtypes([expression], scalar_type) + vars.add(var) + dtypes.add(dtype) + var, = vars + return var, find_common_type([], dtypes) def make_prefix_ordering(indices, prefix_ordering): diff --git a/tests/test_codegen.py b/tests/test_codegen.py index 8d0bc796..3cc660de 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -1,4 +1,5 @@ import pytest +import numpy from gem import impero_utils from gem.gem import Index, Indexed, IndexSum, Product, Variable @@ -18,7 +19,7 @@ def make_expression(i, j): e2 = make_expression(i, i) def gencode(expr): - impero_c = impero_utils.compile_gem([(Ri, expr)], (i, j)) + impero_c = impero_utils.compile_gem([(Ri, expr)], (i, j), numpy.dtype(numpy.float64)) return impero_c.tree assert len(gencode(e1).children) == len(gencode(e2).children) diff --git a/tsfc/driver.py b/tsfc/driver.py index cc98895d..131b45ac 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -239,7 +239,8 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co for var in return_variables])) index_ordering = tuple(quadrature_indices) + split_argument_indices try: - impero_c = impero_utils.compile_gem(assignments, index_ordering, remove_zeros=True) + impero_c = impero_utils.compile_gem(assignments, index_ordering, + parameters["scalar_type"], remove_zeros=True) except impero_utils.NoopError: # No operations, construct empty kernel return builder.construct_empty_kernel(kernel_name) @@ -421,7 +422,8 @@ def compile_expression_dual_evaluation(expression, to_element, coordinates, inte # TODO: one should apply some GEM optimisations as in assembly, # but we don't for now. ir, = impero_utils.preprocess_gem([ir]) - impero_c = impero_utils.compile_gem([(return_expr, ir)], return_indices) + impero_c = impero_utils.compile_gem([(return_expr, ir)], return_indices, + parameters["scalar_type"]) index_names = dict((idx, "p%d" % i) for (i, idx) in enumerate(basis_indices)) # Handle kernel interface requirements builder.register_requirements([ir]) diff --git a/tsfc/fem.py b/tsfc/fem.py index 1726c468..b8d580b3 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -67,6 +67,14 @@ def __init__(self, interface, **kwargs): raise ValueError("unexpected keyword argument '{0}'".format(invalid_keywords.pop())) self.__dict__.update(kwargs) + def reify(self, expr): + if self.complex_mode: + indices = gem.indices(len(expr.shape)) + return gem.ComponentTensor(gem.MathFunction("real", gem.Indexed(expr, indices)), + indices) + else: + return expr + @cached_property def fiat_cell(self): return as_fiat_cell(self.ufl_cell) @@ -136,7 +144,7 @@ def config(self): return config def cell_size(self): - return self.interface.cell_size(self.mt.restriction) + return self.interface.reify(self.interface.cell_size(self.mt.restriction)) def jacobian_at(self, point): expr = Jacobian(self.mt.terminal.ufl_domain()) @@ -427,7 +435,7 @@ def translate_spatialcoordinate(terminal, mt, ctx): # Rebuild modified terminal expr = construct_modified_terminal(mt, terminal) # Translate replaced UFL snippet - return ctx.translator(expr) + return ctx.reify(ctx.translator(expr)) class CellVolumeKernelInterface(ProxyKernelInterface): diff --git a/tsfc/kernel_interface/firedrake.py b/tsfc/kernel_interface/firedrake.py index a7321edf..816e69dd 100644 --- a/tsfc/kernel_interface/firedrake.py +++ b/tsfc/kernel_interface/firedrake.py @@ -27,7 +27,7 @@ def make_builder(*args, **kwargs): class Kernel(object): __slots__ = ("ast", "integral_type", "oriented", "subdomain_id", "domain_number", "needs_cell_sizes", "tabulations", "quadrature_rule", - "coefficient_numbers", "__weakref__") + "return_dtype", "coefficient_numbers", "__weakref__") """A compiled Kernel object. :kwarg ast: The COFFEE ast for the kernel. @@ -40,12 +40,14 @@ class Kernel(object): :kwarg coefficient_numbers: A list of which coefficients from the form the kernel needs. :kwarg quadrature_rule: The finat quadrature rule used to generate this kernel + :kwarg return_dtype: numpy dtype of the return value. :kwarg tabulations: The runtime tabulations this kernel requires :kwarg needs_cell_sizes: Does the kernel require cell sizes. """ def __init__(self, ast=None, integral_type=None, oriented=False, subdomain_id=None, domain_number=None, quadrature_rule=None, coefficient_numbers=(), + return_dtype=None, needs_cell_sizes=False): # Defaults self.ast = ast @@ -55,6 +57,7 @@ def __init__(self, ast=None, integral_type=None, oriented=False, self.subdomain_id = subdomain_id self.coefficient_numbers = coefficient_numbers self.needs_cell_sizes = needs_cell_sizes + self.return_dtype = return_dtype super(Kernel, self).__init__() diff --git a/tsfc/kernel_interface/firedrake_loopy.py b/tsfc/kernel_interface/firedrake_loopy.py index 4eb86afe..915b8360 100644 --- a/tsfc/kernel_interface/firedrake_loopy.py +++ b/tsfc/kernel_interface/firedrake_loopy.py @@ -27,7 +27,7 @@ def make_builder(*args, **kwargs): class Kernel(object): __slots__ = ("ast", "integral_type", "oriented", "subdomain_id", "domain_number", "needs_cell_sizes", "tabulations", "quadrature_rule", - "coefficient_numbers", "__weakref__") + "return_dtype", "coefficient_numbers", "__weakref__") """A compiled Kernel object. :kwarg ast: The loopy kernel object. @@ -40,12 +40,14 @@ class Kernel(object): :kwarg coefficient_numbers: A list of which coefficients from the form the kernel needs. :kwarg quadrature_rule: The finat quadrature rule used to generate this kernel + :kwarg return_dtype: numpy dtype of the return value. :kwarg tabulations: The runtime tabulations this kernel requires :kwarg needs_cell_sizes: Does the kernel require cell sizes. """ def __init__(self, ast=None, integral_type=None, oriented=False, subdomain_id=None, domain_number=None, quadrature_rule=None, coefficient_numbers=(), + return_dtype=None, needs_cell_sizes=False): # Defaults self.ast = ast @@ -55,6 +57,7 @@ def __init__(self, ast=None, integral_type=None, oriented=False, self.subdomain_id = subdomain_id self.coefficient_numbers = coefficient_numbers self.needs_cell_sizes = needs_cell_sizes + self.return_dtype = return_dtype super(Kernel, self).__init__() @@ -164,8 +167,8 @@ def construct_kernel(self, return_arg, impero_c, precision, index_names): for name_, shape in self.tabulations: args.append(lp.GlobalArg(name_, dtype=self.scalar_type, shape=shape)) - loopy_kernel = generate_loopy(impero_c, args, precision, self.scalar_type, - "expression_kernel", index_names) + loopy_kernel, _ = generate_loopy(impero_c, args, precision, self.scalar_type, + "expression_kernel", index_names, ignore_return_type=True) return ExpressionKernel(loopy_kernel, self.oriented, self.cell_sizes, self.coefficients, self.tabulations) @@ -207,6 +210,7 @@ def set_arguments(self, arguments, multiindices): :arg multiindices: GEM argument multiindices :returns: GEM expression representing the return variable """ + self.rank = len(arguments) self.local_tensor, expressions = prepare_arguments( arguments, multiindices, self.scalar_type, interior_facet=self.interior_facet, diagonal=self.diagonal) @@ -277,7 +281,11 @@ def construct_kernel(self, name, impero_c, precision, index_names, quadrature_ru :returns: :class:`Kernel` object """ - args = [self.local_tensor, self.coordinates_arg] + ignore_return_type = self.rank > 0 + if ignore_return_type: + args = [self.local_tensor, self.coordinates_arg] + else: + args = [self.coordinates_arg] if self.kernel.oriented: args.append(self.cell_orientations_loopy_arg) if self.kernel.needs_cell_sizes: @@ -292,8 +300,11 @@ def construct_kernel(self, name, impero_c, precision, index_names, quadrature_ru args.append(lp.GlobalArg(name_, dtype=self.scalar_type, shape=shape)) self.kernel.quadrature_rule = quadrature_rule - self.kernel.ast = generate_loopy(impero_c, args, precision, - self.scalar_type, name, index_names) + ast, dtype = generate_loopy(impero_c, args, precision, + self.scalar_type, name, index_names, + ignore_return_type=ignore_return_type) + self.kernel.ast = ast + self.kernel.return_dtype = dtype return self.kernel def construct_empty_kernel(self, name): diff --git a/tsfc/loopy.py b/tsfc/loopy.py index e04ba2ba..b761ec0f 100644 --- a/tsfc/loopy.py +++ b/tsfc/loopy.py @@ -186,15 +186,17 @@ def active_indices(mapping, ctx): ctx.active_indices.pop(key) -def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel", index_names=[]): +def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel", index_names=[], + ignore_return_type=True): """Generates loopy code. :arg impero_c: ImperoC tuple with Impero AST and other data :arg args: list of loopy.GlobalArgs :arg precision: floating-point precision for printing - :arg scalar_type: type of scalars as C typename string + :arg scalar_type: type of scalars as numpy dtype :arg kernel_name: function name of the kernel :arg index_names: pre-assigned index names + :arg ignore_return_type: Ignore inferred return type from impero_c? :returns: loopy kernel """ ctx = LoopyContext() @@ -205,7 +207,12 @@ def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel", ctx.epsilon = 10.0 ** (-precision) # Create arguments - data = list(args) + if ignore_return_type: + return_dtype = scalar_type + data = list(args) + else: + A, return_dtype = impero_c.return_variable + data = [lp.GlobalArg(A.name, shape=A.shape, dtype=return_dtype)] + list(args) for i, (temp, dtype) in enumerate(assign_dtypes(impero_c.temporaries, scalar_type)): name = "t%d" % i if isinstance(temp, gem.Constant): @@ -240,7 +247,7 @@ def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel", insn_new.append(insn.copy(priority=len(knl.instructions) - i)) knl = knl.copy(instructions=insn_new) - return knl + return knl, return_dtype @singledispatch