Skip to content

Commit

Permalink
bug[next]: respect DEFAULT_BACKEND and no_backend mechanism (#1380)
Browse files Browse the repository at this point in the history
fixes #1376.

Thanks @DropD for the testcase.
  • Loading branch information
havogt authored Jan 4, 2024
1 parent 7a9489f commit 27bf18f
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 31 deletions.
16 changes: 11 additions & 5 deletions src/gt4py/next/embedded/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from gt4py import eve
from gt4py._core import definitions as core_defs
from gt4py.next import common, constructors, utils
from gt4py.next import common, constructors, errors, utils
from gt4py.next.embedded import common as embedded_common, context as embedded_context


Expand Down Expand Up @@ -77,17 +77,20 @@ def scan_loop(hpos):
def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any):
if "out" in kwargs:
# called from program or direct field_operator as program
offset_provider = kwargs.pop("offset_provider", None)

new_context_kwargs = {}
if embedded_context.within_context():
# called from program
assert offset_provider is None
assert "offset_provider" not in kwargs
else:
# field_operator as program
if "offset_provider" not in kwargs:
raise errors.MissingArgumentError(None, "offset_provider", True)
offset_provider = kwargs.pop("offset_provider", None)

new_context_kwargs["offset_provider"] = offset_provider

out = kwargs.pop("out")

domain = kwargs.pop("domain", None)

flattened_out: tuple[common.Field, ...] = utils.flatten_nested_tuple((out,))
Expand All @@ -105,7 +108,10 @@ def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any):
domain=out_domain,
)
else:
# called from other field_operator
# called from other field_operator or missing `out` argument
if "offset_provider" in kwargs:
# assuming we wanted to call the field_operator as program, otherwise `offset_provider` would not be there
raise errors.MissingArgumentError(None, "out", True)
return op(*args, **kwargs)


Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/errors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .exceptions import (
DSLError,
InvalidParameterAnnotationError,
MissingArgumentError,
MissingAttributeError,
MissingParameterAnnotationError,
UndefinedSymbolError,
Expand All @@ -33,6 +34,7 @@
"InvalidParameterAnnotationError",
"MissingAttributeError",
"MissingParameterAnnotationError",
"MissingArgumentError",
"UndefinedSymbolError",
"UnsupportedPythonFeatureError",
"set_verbose_exceptions",
Expand Down
12 changes: 12 additions & 0 deletions src/gt4py/next/errors/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ def __init__(self, location: Optional[SourceLocation], attr_name: str) -> None:
self.attr_name = attr_name


class MissingArgumentError(DSLError):
arg_name: str
is_kwarg: bool

def __init__(self, location: Optional[SourceLocation], arg_name: str, is_kwarg: bool) -> None:
super().__init__(
location, f"Expected {'keyword-' if is_kwarg else ''}argument '{arg_name}'."
)
self.attr_name = arg_name
self.is_kwarg = is_kwarg


class TypeError_(DSLError):
def __init__(self, location: Optional[SourceLocation], message: str) -> None:
super().__init__(location, message)
Expand Down
47 changes: 28 additions & 19 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@

from devtools import debug

from gt4py import eve
from gt4py._core import definitions as core_defs
from gt4py.eve import utils as eve_utils
from gt4py.eve.extended_typing import Any, Optional
from gt4py.next import allocators as next_allocators, embedded as next_embedded
from gt4py.next import allocators as next_allocators, embedded as next_embedded, errors
from gt4py.next.common import Dimension, DimensionKind, GridType
from gt4py.next.embedded import operators as embedded_operators
from gt4py.next.ffront import (
Expand Down Expand Up @@ -61,11 +62,10 @@
sym,
)
from gt4py.next.program_processors import processor_interface as ppi
from gt4py.next.program_processors.runners import roundtrip
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation


DEFAULT_BACKEND: Callable = roundtrip.executor
DEFAULT_BACKEND: Callable = None


def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -176,15 +176,15 @@ class Program:

past_node: past.Program
closure_vars: dict[str, Any]
definition: Optional[types.FunctionType] = None
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND
grid_type: Optional[GridType] = None
definition: Optional[types.FunctionType]
backend: Optional[ppi.ProgramExecutor]
grid_type: Optional[GridType]

@classmethod
def from_function(
cls,
definition: types.FunctionType,
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND,
backend: Optional[ppi.ProgramExecutor],
grid_type: Optional[GridType] = None,
) -> Program:
source_def = SourceDefinition.from_function(definition)
Expand Down Expand Up @@ -495,7 +495,7 @@ def program(*, backend: Optional[ppi.ProgramExecutor]) -> Callable[[types.Functi
def program(
definition=None,
*,
backend=None,
backend=eve.NOTHING, # `NOTHING` -> default backend, `None` -> no backend (embedded execution)
grid_type=None,
) -> Program | Callable[[types.FunctionType], Program]:
"""
Expand All @@ -517,7 +517,9 @@ def program(
"""

def program_inner(definition: types.FunctionType) -> Program:
return Program.from_function(definition, backend, grid_type)
return Program.from_function(
definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type
)

return program_inner if definition is None else program_inner(definition)

Expand Down Expand Up @@ -549,17 +551,17 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]):

foast_node: OperatorNodeT
closure_vars: dict[str, Any]
definition: Optional[types.FunctionType] = None
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND
grid_type: Optional[GridType] = None
definition: Optional[types.FunctionType]
backend: Optional[ppi.ProgramExecutor]
grid_type: Optional[GridType]
operator_attributes: Optional[dict[str, Any]] = None
_program_cache: dict = dataclasses.field(default_factory=dict)

@classmethod
def from_function(
cls,
definition: types.FunctionType,
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND,
backend: Optional[ppi.ProgramExecutor],
grid_type: Optional[GridType] = None,
*,
operator_node_cls: type[OperatorNodeT] = foast.FieldOperator,
Expand Down Expand Up @@ -686,6 +688,7 @@ def as_program(
self._program_cache[hash_] = Program(
past_node=past_node,
closure_vars=closure_vars,
definition=None,
backend=self.backend,
grid_type=self.grid_type,
)
Expand All @@ -698,7 +701,12 @@ def __call__(
) -> None:
if not next_embedded.context.within_context() and self.backend is not None:
# non embedded execution
offset_provider = kwargs.pop("offset_provider", None)
if "offset_provider" not in kwargs:
raise errors.MissingArgumentError(None, "offset_provider", True)
offset_provider = kwargs.pop("offset_provider")

if "out" not in kwargs:
raise errors.MissingArgumentError(None, "out", True)
out = kwargs.pop("out")
args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs)
# TODO(tehrengruber): check all offset providers are given
Expand Down Expand Up @@ -744,7 +752,7 @@ def field_operator(
...


def field_operator(definition=None, *, backend=None, grid_type=None):
def field_operator(definition=None, *, backend=eve.NOTHING, grid_type=None):
"""
Generate an implementation of the field operator from a Python function object.
Expand All @@ -762,7 +770,9 @@ def field_operator(definition=None, *, backend=None, grid_type=None):
"""

def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast.FieldOperator]:
return FieldOperator.from_function(definition, backend, grid_type)
return FieldOperator.from_function(
definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type
)

return field_operator_inner if definition is None else field_operator_inner(definition)

Expand Down Expand Up @@ -798,7 +808,7 @@ def scan_operator(
axis: Dimension,
forward: bool = True,
init: core_defs.Scalar = 0.0,
backend=None,
backend=eve.NOTHING,
grid_type: GridType = None,
) -> (
FieldOperator[foast.ScanOperator]
Expand Down Expand Up @@ -836,8 +846,7 @@ def scan_operator(
def scan_operator_inner(definition: types.FunctionType) -> FieldOperator:
return FieldOperator.from_function(
definition,
backend,
grid_type,
DEFAULT_BACKEND if backend is eve.NOTHING else backend,
operator_node_cls=foast.ScanOperator,
operator_attributes={"axis": axis, "forward": forward, "init": init},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import gt4py.next as gtx
from gt4py.next.ffront import decorator
from gt4py.next.iterator import ir as itir
from gt4py.next.program_processors import processor_interface as ppi
from gt4py.next.program_processors.runners import gtfn, roundtrip


try:
Expand All @@ -36,9 +38,10 @@
import next_tests.exclusion_matrices as definitions


@ppi.program_executor
def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None:
"""Temporary default backend to not accidentally test the wrong backend."""
raise ValueError("No backend selected. Backend selection is mandatory in tests.")
raise ValueError("No backend selected! Backend selection is mandatory in tests.")


OPTIONAL_PROCESSORS = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import math
from typing import Callable
from typing import Callable, Optional

import numpy as np
import pytest
Expand All @@ -22,6 +22,7 @@
from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast
from gt4py.next.ffront.decorator import FieldOperator
from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction
from gt4py.next.program_processors import processor_interface as ppi
from gt4py.next.type_system import type_translation

from next_tests.integration_tests import cases
Expand All @@ -39,7 +40,7 @@
# becomes easier.


def make_builtin_field_operator(builtin_name: str):
def make_builtin_field_operator(builtin_name: str, backend: Optional[ppi.ProgramExecutor]):
# TODO(tehrengruber): creating a field operator programmatically should be
# easier than what we need to do here.
# construct annotation dictionary containing the input argument and return
Expand Down Expand Up @@ -109,8 +110,9 @@ def make_builtin_field_operator(builtin_name: str):
return FieldOperator(
foast_node=typed_foast_node,
closure_vars=closure_vars,
backend=None,
definition=None,
backend=backend,
grid_type=None,
)


Expand All @@ -129,9 +131,7 @@ def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inp
expected = ref_impl(*inputs)
out = cartesian_case.as_field([IDim], np.zeros_like(expected))

builtin_field_op = make_builtin_field_operator(builtin_name).with_backend(
cartesian_case.backend
)
builtin_field_op = make_builtin_field_operator(builtin_name, cartesian_case.backend)

builtin_field_op(*inps, out=out, offset_provider={})

Expand Down
Loading

0 comments on commit 27bf18f

Please sign in to comment.