Skip to content

Commit

Permalink
Merge pull request #17 from lorenzovarese/13-refactor-tests-in-gt4py_…
Browse files Browse the repository at this point in the history
…fo_execjl

13 refactor tests in gt4py fo execjl
  • Loading branch information
lorenzovarese authored Jul 17, 2024
2 parents adae26a + 512dd45 commit 084cf9f
Show file tree
Hide file tree
Showing 8 changed files with 450 additions and 154 deletions.
34 changes: 32 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,33 @@
/Manifest.toml
# Files generated by invoking Julia with --code-coverage
*.jl.cov
*.jl.*.cov

# Files generated by invoking Julia with --track-allocation
*.jl.mem

# System-specific files and directories generated by the BinaryProvider and BinDeps packages
# They contain absolute paths specific to the host computer, and so should not be committed
deps/deps.jl
deps/build.log
deps/downloads/
deps/usr/
deps/src/

# Build artifacts for creating documentation generated by the Documenter package
docs/build/
/.DS_Store
docs/site/

# File generated by Pkg, the package manager, based on a corresponding Project.toml
# It records a fixed state of all packages used by the project. As such, it should not be
# committed for packages, but should be committed for applications that require a static
# environment.
Manifest.toml

# Python Env
.venv
env_setup.sh
.python-version

# Misc
.DS_Store
.vscode
10 changes: 7 additions & 3 deletions src/GridTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ function Field(
return Field(Tuple(dim), data, Tuple(broadcast_dims), origin = origin)
end

# TODO(tehrengruber): There is no need to have FieldOffset and FieldOffsetTS, remove FieldOffset
struct FieldOffsetTS{
Name,
Source <: Dimension,
Expand Down Expand Up @@ -367,10 +368,11 @@ function remap_ts(
offset::FieldOffsetTS{OffsetName, SourceDim, Tuple{TargetDim}},
nb_ind::Int64)::Field where {OffsetName, SourceDim <: Dimension, TargetDim <:Dimension}
conn = OFFSET_PROVIDER[string(OffsetName)]
@assert conn isa Dimension

new_offsets = Dict(field.dims[i] => field.origin[i] for i in 1:length(field.dims))
new_offsets[conn] = nb_ind
return Field(field.dims, field.data, field.broadcast_dims, origin = new_offsets)
new_origin = Dict(field.dims[i] => field.origin[i] for i in 1:length(field.dims))
new_origin[conn] -= nb_ind
return Field(field.dims, field.data, field.broadcast_dims, origin = new_origin)
end

function remap_ts(
Expand Down Expand Up @@ -466,6 +468,8 @@ struct Connectivity
dims::Integer
end

Base.getindex(conn::Connectivity, row::Union{Integer, Colon}, col::Integer) = conn.data[row, col]

# Field operator ----------------------------------------------------------------------

struct FieldOp
Expand Down
4 changes: 2 additions & 2 deletions src/embedded/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ julia> a = Field((Cell, K), fill(1.0, (3,3)));
julia> b = Field((Cell, K), fill(2.0, (3,3)));
julia> where(mask, a, b)
3x3 Field with dimensions ("Cell", "K") with indices 1:3×1:3:
1.0 2.0 2.0
2.0 1.0 2.0
2.0 2.0 1.0
1.0 1.0 1.0
1.0 2.0 2.0
```
The `where` function builtin also allows for nesting of tuples. In this scenario, it will first perform an unrolling:
Expand Down
8 changes: 5 additions & 3 deletions src/embedded/cust_broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
Base.BroadcastStyle(::Type{<:Field}) = Broadcast.ArrayStyle{Field}()

# TODO(tehrengruber): Implement a range with an attached dimension instead of this single object
# for an entire domain. Invesitage what broadcast_dims is needed for here.
struct FieldShape{
N,
Dim <: NTuple{N, Dimension},
Expand All @@ -11,9 +13,9 @@ struct FieldShape{
broadcast_dims::B_Dim
end

function shape(f::Field)
return FieldShape(f.dims, axes(f), f.broadcast_dims)
end
shape(f::Field) = FieldShape(f.dims, axes(f), f.broadcast_dims)

Base.length(bc::FieldShape{N}) where N = N

# Only called for assign broadcasting (.=)
@inline function Base.Broadcast.materialize!(dest, bc::Broadcasted{ArrayStyle{Field}})
Expand Down
20 changes: 10 additions & 10 deletions src/gt2py/gt2py.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,20 +208,20 @@ function py_field_operator(
end

function jast_to_foast(expr::Expr, closure_vars::Dict)
expr, closure_vars, annotations = preprocess_definiton(expr, closure_vars)
expr, closure_vars = remove_function_aliases(expr, closure_vars) # TODO Can be ommited once gt4py allows aliases
foast_node = visit_jast(expr, closure_vars)
foast_node = postprocess_definition(foast_node, closure_vars, annotations)
return foast_node, closure_vars
expr, py_closure_vars, annotations = preprocess_definiton(expr, closure_vars)
expr, py_closure_vars = remove_function_aliases(expr, py_closure_vars) # TODO Can be ommited once gt4py allows aliases
foast_node = visit_jast(expr, py_closure_vars)
foast_node = postprocess_definition(foast_node, py_closure_vars, annotations)
return foast_node, py_closure_vars
end

function preprocess_definiton(expr::Expr, closure_vars::Dict)
sat = single_assign_target_pass(expr)
ucc = unchain_compairs_pass(sat)
ssa = single_static_assign_pass(ucc)
closure_vars = translate_closure_vars(closure_vars)
annotations = get_annotation(ssa, closure_vars)
return (ssa, closure_vars, annotations)
py_closure_vars = translate_closure_vars(closure_vars)
annotations = get_annotation(ssa, py_closure_vars)
return (ssa, py_closure_vars, annotations)
end

function postprocess_definition(foast_node, closure_vars, annotations)
Expand All @@ -241,6 +241,7 @@ function postprocess_definition(foast_node, closure_vars, annotations)
return foast_node
end

# TODO(tehrengruber): unify with convert_type
py_args(args::Union{Base.Pairs, Dict}) =
Dict(i.first => convert_type(i.second) for i in args)
py_args(args::Tuple) = [convert_type(arg) for arg in args]
Expand All @@ -260,8 +261,6 @@ function convert_type(a::Field)
new_data = np.asarray(a.data)
@warn "Dtype of the Field: $a is not concrete. Data must be copied to Python which may affect performance. Try using dtypes <: Array."
end
println(new_data)
println(typeof(new_data))

offset = Dict(convert_type(dim) => a.origin[i] for (i, dim) in enumerate(a.dims))

Expand All @@ -282,6 +281,7 @@ function convert_type(a::Connectivity)

# account for different indexing in python
return gtx.NeighborTableOffsetProvider(
# TODO(lorenzovarese): fix performance (conversion from 0-index to 1-index) (caching or directly store the 0-index version of the connectivity)
ifelse.(a.data .!= -1, a.data .- 1, a.data),
target_dim,
source_dim,
Expand Down
21 changes: 13 additions & 8 deletions src/gt2py/jast_to_foast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,11 @@ function visit_(sym::Val{:call}, args::Array, outer_loc)
func = visit(args[1], outer_loc),
args = [
visit(x, outer_loc) for
x in Base.tail(Tuple(args)) if (typeof(x) != Expr || x.head != :(kw))
x in args[2:end] if (typeof(x) != Expr || x.head != :(kw))
],
kwargs = Dict(
x.args[1] => visit(x.args[2], outer_loc) for
x in Base.tail(Tuple(args)) if (typeof(x) == Expr && x.head == :kw)
x in args[2:end] if (typeof(x) == Expr && x.head == :kw)
),
location = outer_loc,
)
Expand Down Expand Up @@ -345,10 +345,15 @@ function visit_(sym::Val{:(/=)}, args::Array, outer_loc)
end

function visit_(sym::Val{:ref}, args::Array, outer_loc)
if typeof(args[2]) <: Integer
if typeof(args[2]) <: Integer # TODO: also check that args[1] is an offset
# TODO(tehrengruber): This is an extremely dirty hack, we need to get
# the information from the offset provider or similar, but it is not
# available here.
is_cartesian = string(args[1])[1] in ['I', 'J', 'K']
index = is_cartesian ? args[2] : args[2]-1
return foast.Subscript(
value = visit(args[1], outer_loc),
index = args[2] - 1, # Due to different indexing in python
index = index,
location = outer_loc,
)
else
Expand Down Expand Up @@ -437,18 +442,18 @@ function from_type_hint(expr::Expr, closure_vars::Dict)
param_type = expr.args
if param_type[1] == :Tuple
return ts.TupleType(
types = [recursive_make_symbol(arg) for arg in Base.tail(param_type)]
types = [recursive_make_symbol(arg) for arg in param_type[]]
)
elseif param_type[1] == :Field
@assert length(param_type) == 3 (
"Field type requires two arguments, got $(length(param_type)-1) in $(param_type)."
)

dim = []
(dims, dtype) = param_type[2:end]

# TODO: do some sanity checks here for example Field{Int64, Dims} will fail terribly


dim = []
for d in dims.args[2:end]
@assert string(d) in keys(closure_vars)
push!(dim, closure_vars[string(d)])
Expand Down
13 changes: 5 additions & 8 deletions src/gt2py/preprocessing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,9 @@ function translate_closure_vars(j_closure_vars::Dict)::Dict
new_value = nothing

if typeof(value) <: FieldOffset
py_source = map(
dim -> gtx.Dimension(
get_dim_name(dim),
kind = py_dim_kind[get_dim_kind(dim)]
),
value.source
py_source = gtx.Dimension(
get_dim_name(value.source),
kind = py_dim_kind[get_dim_kind(value.source)]
)
py_target = map(
dim -> gtx.Dimension(
Expand All @@ -73,8 +70,8 @@ function translate_closure_vars(j_closure_vars::Dict)::Dict
)
new_value = gtx.FieldOffset(
value.name,
source = length(py_source) == 1 ? py_source[1] : py_source,
target = length(py_target) == 1 ? py_target[1] : py_target
source = py_source,
target = py_target
)
elseif typeof(value) <: Function
new_value = builtin_py_op[key]
Expand Down
Loading

0 comments on commit 084cf9f

Please sign in to comment.