Skip to content

Commit

Permalink
All tests should pass again
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Jul 2, 2024
1 parent c3c5252 commit 9f54690
Show file tree
Hide file tree
Showing 9 changed files with 316 additions and 264 deletions.
31 changes: 30 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,32 @@
# GridTools

[![Build Status](https://github.com/jeffzwe/GridTools.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/jeffzwe/GridTools.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Static Badge](https://img.shields.io/badge/docs-stable-blue.svg)](https://jeffzwe.github.io/GridTools.jl/dev)
[![Build Status](https://github.com/jeffzwe/GridTools.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/jeffzwe/GridTools.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Static Badge](https://img.shields.io/badge/docs-stable-blue.svg)](https://jeffzwe.github.io/GridTools.jl/dev)

## Installation

### Setup python virtual environment


### Development installation

```bash
export GRIDTOOLS_JL_PATH="..."
export GT4PY_PATH="..."
# create python virtual environemnt
# make sure to use a python version that is compatible with GT4Py
python -m venv venv
# activate virtual env
# this command has be run everytime GridTools.jl is used
source venv/bin/activate
# clone gt4py
git clone [email protected]:GridTools/gt4py.git $GT4PY_PATH
pip install -r $GT4PY_PATH/requirements-dev.txt
pip install -e $GT4PY_PATH
#
```

## Troubleshooting

__undefined symbol: PyObject_Vectorcall__

Make sure to run everything in the same environment that you have build `PyCall` with. A common reason is you have built PyCall in a virtual environement and then didn't load it when executing stencils.
24 changes: 7 additions & 17 deletions src/GridTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,17 +234,6 @@ compute_remapped_field_info(
conn
)= ((), ())

function remap_ts(
field::Field,
offset::FieldOffsetTS{OffsetName, SourceDim, Tuple{TargetDim}},
nb_ind::Int64)::Field where {OffsetName, SourceDim <: Dimension, TargetDim <:Dimension}
conn = OFFSET_PROVIDER[string(OffsetName)]

new_offsets = Dict(field.dims[i] => field.origin[i] for i in 1:length(field.dims))
new_offsets[conn] = nb_ind
return Field(field.dims, field.data, field.broadcast_dims, origin = new_offsets)
end


function remap_broadcast_dims(
broadcast_dims::Tuple{T, Vararg{Dimension}},
Expand Down Expand Up @@ -282,7 +271,7 @@ function remap_ts(
# eltype(field.data)(0)
# end
#end, CartesianIndices(map(len -> Base.OneTo(len), out_field_size)))
out_field = Array{eltype(field.data)}(undef, out_field_size)
out_field = zeros(eltype(field.data), out_field_size)
for position in eachindex(IndexCartesian(), out_field)
neighbor_exists, new_position = remap_position(Tuple(position), out_field_dims, offset, nb_ind, conn)
if neighbor_exists
Expand Down Expand Up @@ -371,10 +360,10 @@ copyfield!(target, source) = target .= source

# Field operator functionalities ------------------------------------------------------------

OFFSET_PROVIDER::Union{Dict{String, Union{Connectivity, Dimension}}, Nothing} = nothing
OFFSET_PROVIDER::Union{Dict{String, <:Union{Connectivity, Dimension}}, Nothing} = nothing
FIELD_OPERATORS::Dict{Symbol, PyObject} = Dict{Symbol, PyObject}()

function (fo::FieldOp)(args...; offset_provider::Dict{String, Union{Connectivity, Dimension}} = Dict{String, Union{Connectivity, Dimension}}(), backend::String = "embedded", out = nothing, kwargs...)
function (fo::FieldOp)(args...; offset_provider::Dict{String, <:Union{Connectivity, Dimension}} = Dict{String, Union{Connectivity, Dimension}}(), backend::String = "embedded", out = nothing, kwargs...)

is_outermost_fo = isnothing(OFFSET_PROVIDER)
if is_outermost_fo
Expand Down Expand Up @@ -476,13 +465,14 @@ end

macro module_vars()
return esc(quote
# TODO(tehrengruber): for some reasons this was needed from some point on. cleanup
base_vars = Dict(name => Core.eval(Base, name) for name in [:Int64, :Int32, :Float32, :Float64])
module_vars = Dict(name => Core.eval(@__MODULE__, name) for name in names(@__MODULE__))
local_vars = Base.@locals
merge(module_vars, local_vars, GridTools.builtin_op)
merge(base_vars, module_vars, local_vars, GridTools.builtin_op)
end)
end


"""
@field_operator
Expand All @@ -499,7 +489,7 @@ macro field_operator(expr::Expr)
expr_dict = splitdef(expr)
expr_dict[:name] = generate_unique_name(f_name)
unique_expr = combinedef(expr_dict)

return Expr(:(=), esc(f_name), :(FieldOp(namify($(Expr(:quote, expr))), $(esc(unique_expr)), $(Expr(:quote, expr)), get_closure_vars($(Expr(:quote, expr)), @module_vars))))
end

Expand Down
21 changes: 13 additions & 8 deletions src/gt2py/gt2py.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ CLOSURE_VARS::Dict = Dict()

# Methods -----------------------------------------------------------------------------------

function py_field_operator(fo, backend = py_backends["gpu"], grid_type = py"None"o, operator_attributes = Dict())
function py_field_operator(fo, backend = Nothing, grid_type = py"None"o, operator_attributes = Dict())
if backend == Nothing
backend = py_backends["gpu"]
end

foast_definition_node, closure_vars = jast_to_foast(fo.expr, fo.closure_vars)
loc = foast_definition_node.location
Expand All @@ -146,12 +149,12 @@ function py_field_operator(fo, backend = py_backends["gpu"], grid_type = py"None
foast_node = FieldOperatorTypeDeduction.apply(untyped_foast_node)

return FieldOperater(
foast_node=foast_node,
closure_vars=closure_vars,
definition=py"None"o,
backend=backend,
grid_type=grid_type,
)
foast_node=foast_node,
closure_vars=closure_vars,
definition=py"None"o,
backend=backend,
grid_type=grid_type,
)
end

function jast_to_foast(expr::Expr, closure_vars::Dict)
Expand Down Expand Up @@ -187,7 +190,7 @@ function postprocess_definition(foast_node, closure_vars, annotations)
end

py_args(args::Union{Base.Pairs, Dict}) = Dict(i.first => convert_type(i.second) for i in args)
py_args(args::Tuple) = [map(convert_type, args)...]
py_args(args::Tuple) = [convert_type(arg) for arg in args]
py_args(arg) = convert_type(arg)
py_args(n::Nothing) = nothing

Expand All @@ -204,6 +207,8 @@ function convert_type(a::Field)
new_data = np.asarray(a.data)
@warn "Dtype of the Field: $a is not concrete. Data must be copied to Python which may affect performance. Try using dtypes <: Array."
end
println(new_data)
println(typeof(new_data))

offset = Dict(convert_type(dim) => a.origin[i] for (i, dim) in enumerate(a.dims))

Expand Down
21 changes: 16 additions & 5 deletions src/gt2py/jast_to_foast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function _builtin_type_constructor_symbols(captured_vars, loc)::Tuple
if name in fbuiltins.TYPE_BUILTIN_NAMES && value == py"getattr"(fbuiltins, name) # NOTE: Should be === . Doesnt work with pyobjects
)

to_be_inserted = union(python_type_builtins, captured_type_builtins)
to_be_inserted = merge(python_type_builtins, captured_type_builtins)

for (name, value) in to_be_inserted
push!(result,
Expand Down Expand Up @@ -67,13 +67,22 @@ function visit_function(args::Array, closure_vars::Dict)
function_header = function_header.args[1]
end

for param in Base.tail((Tuple(function_header.args)))
function_params = vcat(function_params, visit_types(param.args, closure_vars, inner_loc))
func_name, func_params_expr... = [function_header.args...]

# canonicalize keyword arguments
# TODO(tehrengruber): ensure this is tested properly
if length(func_params_expr) >= 1 && func_params_expr[1] isa Expr && func_params_expr[1].head == :parameters
func_kwparams_exprs = popfirst!(func_params_expr).args
push!(func_params_expr, func_kwparams_exprs...)
end

for param in func_params_expr
push!(function_params, visit_types(param.args, closure_vars, inner_loc))
end

return foast.FunctionDefinition(
id=string(function_header.args[1]),
params= function_params,
id=string(func_name),
params=function_params,
body=function_body,
closure_vars=closure_var_symbols,
location=inner_loc
Expand Down Expand Up @@ -417,6 +426,8 @@ function from_type_hint(expr::Expr, closure_vars::Dict)
dim = []
(dims, dtype) = param_type[2:end]

# TODO: do some sanity checks here for example Field{Int64, Dims} will fail terribly

for d in dims.args[2:end]
@assert string(d) in keys(closure_vars)
push!(dim, closure_vars[string(d)])
Expand Down
10 changes: 7 additions & 3 deletions test/embedded_test.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Statistics

# Setup ------------------------------------------------------------------------------------------------------------------------------
include("mesh_definitions.jl")

cell_values = Field(Cell, [1.0, 1.0, 2.0, 3.0, 5.0, 8.0])
edge_to_cell_table = [
Expand All @@ -23,7 +24,10 @@ C2E_offset_provider = Connectivity(edge_to_cell_table, Edge, Cell, 2)

offset_provider = Dict{String, Union{Connectivity, Dimension}}(
"E2C" => E2C_offset_provider,
"C2E" => C2E_offset_provider
"C2E" => C2E_offset_provider,
# TODO(tehrengruber): cleanup
"E2CDim" => E2C_offset_provider,
"C2EDim" => C2E_offset_provider
)

x = Field((Cell, K), reshape(collect(-3.0:2.0), (3, 2)))
Expand Down Expand Up @@ -94,7 +98,7 @@ end
end

fo_remapping(cell_values, offset_provider=offset_provider, out = out)
@test out == result_offset_call
@test out == result_offset_call
end


Expand Down Expand Up @@ -169,5 +173,5 @@ end
1.0 1.0 1.0]
result = Field((Cell, K), result_data)

@test concat(x, y).data == result
@test concat(a, b).data == result
end
Loading

0 comments on commit 9f54690

Please sign in to comment.