Skip to content

Commit

Permalink
Extend the parametric keywords to include anything (#66)
Browse files Browse the repository at this point in the history
* Extend the parametric keywords to include anything

* Fix spelling an format
Azzaare authored Jun 17, 2024
1 parent 9c623ff commit f1c39de
Showing 7 changed files with 160 additions and 161 deletions.
8 changes: 4 additions & 4 deletions src/composition.jl
Original file line number Diff line number Diff line change
@@ -75,7 +75,7 @@ function generate(c::Composition, name, ::Val{:Julia})
co = reduce_symbols(symbs[4], ", ", false; prefix = CN * "co_")

documentation = """\"\"\"
$name(x; X = zeros(length(x), $tr_length), param=nothing, dom_size)
$name(x; X = zeros(length(x), $tr_length), params...)
Composition `$name` generated by CompositionalNetworks.jl.
```
@@ -85,10 +85,10 @@ function generate(c::Composition, name, ::Val{:Julia})
"""

output = """
function $name(x; X = zeros(length(x), $tr_length), param=nothing, dom_size)
$(CN)tr_in(Tuple($tr), X, x, param)
function $name(x; X = zeros(length(x), $tr_length), dom_size, params...)
$(CN)tr_in(Tuple($tr), X, x; params)
X[1:length(x), 1] .= 1:length(x) .|> (i -> $ar(@view X[i, 1:$tr_length]))
return $ag(@view X[:, 1]) |> (y -> $co(y; param, dom_size, nvars=length(x)))
return $ag(@view X[:, 1]) |> (y -> $co(y; dom_size, nvars=length(x), params...))
end
"""
return documentation * format_text(output, BlueStyle(); pipe_to_function_call = false)
11 changes: 3 additions & 8 deletions src/icn.jl
Original file line number Diff line number Diff line change
@@ -151,16 +151,11 @@ function _compose(icn::ICN)
end
end

function composition(
x;
X = zeros(length(x), length(funcs[1])),
param = nothing,
dom_size,
)
tr_in(Tuple(funcs[1]), X, x, param)
function composition(x; X = zeros(length(x), length(funcs[1])), dom_size, params...)
tr_in(Tuple(funcs[1]), X, x; params...)
X[1:length(x), 1] .=
1:length(x) .|> (i -> funcs[2][1](@view X[i, 1:length(funcs[1])]))
return (y -> funcs[4][1](y; param, dom_size, nvars = length(x)))(
return (y -> funcs[4][1](y; dom_size, nvars = length(x), params...))(
funcs[3][1](@view X[:, 1]),
)
end
2 changes: 1 addition & 1 deletion src/layer.jl
Original file line number Diff line number Diff line change
@@ -74,7 +74,7 @@ end

"""
generate_exclusive_operation(max_op_number)
Generates the operations (weigths) of a layer with exclusive operations.
Generates the operations (weights) of a layer with exclusive operations.
"""
function generate_exclusive_operation(max_op_number)
op = rand(1:max_op_number)
80 changes: 39 additions & 41 deletions src/layers/comparison.jl
Original file line number Diff line number Diff line change
@@ -2,61 +2,59 @@
co_identity(x)
Identity function. Already defined in Julia as `identity`, specialized for scalars in the `comparison` layer.
"""
co_identity(x; param = nothing, dom_size = 0, nvars = 0) = identity(x)
co_identity(x; params...) = identity(x)

"""
co_abs_diff_val_param(x; param)
Return the absolute difference between `x` and `param`.
co_abs_diff_var_val(x; val)
Return the absolute difference between `x` and `val`.
"""
co_abs_diff_val_param(x; param, dom_size = 0, nvars = 0) = abs(x - param)
co_abs_diff_var_val(x; val, params...) = abs(x - val)

"""
co_val_minus_param(x; param)
Return the difference `x - param` if positive, `0.0` otherwise.
co_var_minus_val(x; val)
Return the difference `x - val` if positive, `0.0` otherwise.
"""
co_val_minus_param(x; param, dom_size = 0, nvars = 0) = max(0.0, x - param)
co_var_minus_val(x; val, params...) = max(0.0, x - val)

"""
co_param_minus_val(x; param)
Return the difference `param - x` if positive, `0.0` otherwise.
co_val_minus_var(x; val)
Return the difference `val - x` if positive, `0.0` otherwise.
"""
co_param_minus_val(x; param, dom_size = 0, nvars = 0) = max(0.0, param - x)
co_val_minus_var(x; val, params...) = max(0.0, val - x)

"""
co_euclidean_param(x; param, dom_size)
Compute an euclidean norm with domain size `dom_size`, weighted by `param`, of a scalar.
co_euclidean_val(x; val, dom_size)
Compute an euclidean norm with domain size `dom_size`, weighted by `val`, of a scalar.
"""
function co_euclidean_param(x; param, dom_size, nvars = 0)
return x == param ? 0.0 : (1.0 + abs(x - param) / dom_size)
function co_euclidean_val(x; val, dom_size, params...)
return x == val ? 0.0 : (1.0 + abs(x - val) / dom_size)
end

"""
co_euclidean(x; dom_size)
Compute an euclidean norm with domain size `dom_size` of a scalar.
"""
function co_euclidean(x; param = nothing, dom_size, nvars = 0)
return co_euclidean_param(x; param = 0.0, dom_size = dom_size)
function co_euclidean(x; dom_size, params...)
return co_euclidean_val(x; val = 0.0, dom_size)
end

"""
co_abs_diff_val_vars(x; nvars)
co_abs_diff_var_vars(x; nvars)
Return the absolute difference between `x` and the number of variables `nvars`.
"""
co_abs_diff_val_vars(x; param = nothing, dom_size = 0, nvars) = abs(x - nvars)
co_abs_diff_var_vars(x; nvars, params...) = abs(x - nvars)

"""
co_val_minus_vars(x; nvars)
co_var_minus_vars(x; nvars)
Return the difference `x - nvars` if positive, `0.0` otherwise, where `nvars` denotes the numbers of variables.
"""
co_val_minus_vars(x; param = nothing, dom_size = 0, nvars) =
co_val_minus_param(x; param = nvars)
co_var_minus_vars(x; nvars, params...) = co_var_minus_val(x; val = nvars)

"""
co_vars_minus_val(x; nvars)
co_vars_minus_var(x; nvars)
Return the difference `nvars - x` if positive, `0.0` otherwise, where `nvars` denotes the numbers of variables.
"""
co_vars_minus_val(x; param = nothing, dom_size = 0, nvars) =
co_param_minus_val(x; param = nvars)
co_vars_minus_var(x; nvars, params...) = co_val_minus_var(x; val = nvars)


# Parametric layers
@@ -66,18 +64,18 @@ function make_comparisons(::Val{:none})
return LittleDict{Symbol,Function}(
:identity => co_identity,
:euclidean => co_euclidean,
:abs_diff_val_vars => co_abs_diff_val_vars,
:val_minus_vars => co_val_minus_vars,
:vars_minus_val => co_vars_minus_val,
:abs_diff_var_vars => co_abs_diff_var_vars,
:var_minus_vars => co_var_minus_vars,
:vars_minus_var => co_vars_minus_var,
)
end

function make_comparisons(::Val{:val})
return LittleDict{Symbol,Function}(
:abs_diff_val_param => co_abs_diff_val_param,
:val_minus_param => co_val_minus_param,
:param_minus_val => co_param_minus_val,
:euclidean_param => co_euclidean_param,
:abs_diff_var_val => co_abs_diff_var_val,
:var_minus_val => co_var_minus_val,
:val_minus_var => co_val_minus_var,
:euclidean_val => co_euclidean_val,
)
end

@@ -113,21 +111,21 @@ end
end

funcs_param = [
CN.co_abs_diff_val_param => [2, 5],
CN.co_val_minus_param => [2, 0],
CN.co_param_minus_val => [0, 5],
CN.co_abs_diff_var_val => [2, 5],
CN.co_var_minus_val => [2, 0],
CN.co_val_minus_var => [0, 5],
]

for (f, results) in funcs_param
for (key, vals) in enumerate(data)
@test f(vals.first; param = vals.second[1]) == results[key]
@test f(vals.first; val = vals.second[1]) == results[key]
end
end

funcs_vars = [
CN.co_abs_diff_val_vars => [2, 0],
CN.co_val_minus_vars => [0, 0],
CN.co_vars_minus_val => [2, 0],
CN.co_abs_diff_var_vars => [2, 0],
CN.co_var_minus_vars => [0, 0],
CN.co_vars_minus_var => [2, 0],
]

for (f, results) in funcs_vars
@@ -136,11 +134,11 @@ end
end
end

funcs_param_dom = [CN.co_euclidean_param => [1.4, 2.0]]
funcs_val_dom = [CN.co_euclidean_val => [1.4, 2.0]]

for (f, results) in funcs_param_dom
for (f, results) in funcs_val_dom
for (key, vals) in enumerate(data)
@test f(vals.first, param = vals.second[1], dom_size = vals.second[2])
@test f(vals.first, val = vals.second[1], dom_size = vals.second[2])
results[key]
end
end
Loading

0 comments on commit f1c39de

Please sign in to comment.