Direct mlj interface #126

Merged: 62 commits, Nov 7, 2024

Commits (62)
51ed526
new first commit
pasq-cat Sep 12, 2024
e8e96d1
various stuff
pasq-cat Sep 12, 2024
255cf19
fixes
pasq-cat Sep 12, 2024
35ac2d8
changes
pasq-cat Sep 13, 2024
44469a2
there is still a problem with the classifier
pasq-cat Sep 18, 2024
5315264
almost fixed
pasq-cat Sep 18, 2024
f6e7a00
works but i have to fix the hyperparameters
pasq-cat Sep 18, 2024
d7c4f7b
question on parameters....
pasq-cat Sep 18, 2024
bbab460
there is some problem with the one hot encoding
pasq-cat Sep 18, 2024
8af38ae
fixed error in univariatefinite
pasq-cat Sep 19, 2024
fe19d4d
performance improvement
pasq-cat Sep 19, 2024
d809afb
JuliaFormatter
pasq-cat Sep 21, 2024
33d84f5
juliaformatter+docstrings
pasq-cat Sep 21, 2024
9731297
removed predict_proba and ret_Distr from the struct
pasq-cat Sep 21, 2024
f70d239
mlj docstring in progress
pasq-cat Sep 21, 2024
80c6553
ah fixed constant, added prototype for regression
pasq-cat Sep 21, 2024
d1c895c
small stuff here and there in the docstring plus
pasq-cat Sep 21, 2024
19ffa16
still writing this long ass docstring
pasq-cat Sep 21, 2024
de0bd91
added fit_params functions
pasq-cat Sep 22, 2024
87df85f
switched to customized loop
pasq-cat Sep 22, 2024
24459a1
fixed error in custom loop
pasq-cat Sep 22, 2024
0e2ca03
various fixes
pasq-cat Sep 22, 2024
841d5eb
added reformat. must update the docstring again....
pasq-cat Sep 22, 2024
de784f1
work on the docstring and then made it in a module
pasq-cat Sep 22, 2024
b7a99f6
fixed uuid, made test file.for direct_mlj. shut down the tests for ml…
pasq-cat Sep 23, 2024
c44b8d8
added tests. should be good....
pasq-cat Sep 23, 2024
b762185
added mlj to the dependency in test
pasq-cat Sep 23, 2024
ced3da0
prep for update + added mljmodelinterface to doc env
pasq-cat Sep 25, 2024
b700f85
changed the loop so that it now uses optimisers from Optimisers.jl
pasq-cat Oct 1, 2024
da6fc76
started joining the functions in a single common function for both mo…
pasq-cat Oct 3, 2024
70df568
various fixes
pasq-cat Oct 4, 2024
9889872
merged functions for both cases
pasq-cat Oct 4, 2024
0f46fd6
julia formatter
pasq-cat Oct 4, 2024
ab8b6bf
added unit tests
pasq-cat Oct 15, 2024
263cc67
more units
pasq-cat Oct 15, 2024
453b49f
fix
pasq-cat Oct 15, 2024
656b24e
changed unit test and a minor fix in the update function. there is st…
pasq-cat Oct 16, 2024
7c4d744
only things left to fix are the selectrows functions
pasq-cat Oct 16, 2024
f872d96
returning one-hot encoded directly
pat-alt Oct 16, 2024
71a3611
nearly there I think
pat-alt Oct 16, 2024
74d778e
one more issue with regression
pat-alt Oct 16, 2024
80784bb
fixed predict so that it return a vector of distributions-> fixed eva…
pasq-cat Oct 18, 2024
be80e32
amend: fixed predict so that it return a vector of distributions-> fi…
pasq-cat Oct 18, 2024
d426844
Merge branch 'direct_mlj_interface' of https://github.com/JuliaTrustw…
pasq-cat Oct 18, 2024
f4fcd95
made a mess with commits.... bah
pasq-cat Oct 18, 2024
851784f
trying to increase patch coverage
pasq-cat Oct 18, 2024
0752b83
fkn hell this codecov bot is worse than the inquisition
pasq-cat Oct 18, 2024
573ffd8
uhmmmmmm
pasq-cat Oct 21, 2024
db14b84
fixed _isdefined
pasq-cat Oct 21, 2024
82c5714
trying to fix docs issue and no longer importing MLJ nor MLJBase name…
pat-alt Oct 22, 2024
7202013
formatting
pat-alt Oct 22, 2024
a05e25f
removing mlj_flux
pat-alt Oct 22, 2024
05df2e1
fixed issues
pat-alt Oct 22, 2024
02abec2
removing reference to deep_propertier
pat-alt Oct 23, 2024
374aca5
hadn't saved file
pat-alt Oct 23, 2024
06343bd
Merge branch 'main' into local_direct_mlj
pasq-cat Oct 29, 2024
59917f8
added default mlp
pasq-cat Oct 29, 2024
9da5d7d
reducing number of epochs and trying to extend patch coverage
pasq-cat Oct 30, 2024
503763d
removed the else because it seems to have no role.
pasq-cat Oct 30, 2024
e78e9b8
oops, forgot to remove the comment
pasq-cat Oct 30, 2024
6a5f26f
various change in the documentation
pasq-cat Oct 30, 2024
12f2584
ufffffffffffffffffffff
pasq-cat Oct 30, 2024
Changes from 1 commit

commit 12f2584462f5dbad7090941f3d489071fcb258a8 ("ufffffffffffffffffffff")
pasq-cat committed Oct 30, 2024 (verified signature, created on GitHub.com)
src/direct_mlj.jl: 8 changes (4 additions, 4 deletions)
@@ -83,7 +83,7 @@
     n_input = size(X, 1)
     dims = size(y)
     if length(dims) == 1
         n_output = 1
     else
         n_output = dims[1]
     end
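This hunk is shown as context, but the shape inference deserves a gloss: `y` is read in the Flux convention of (targets, observations), so a plain vector implies a single output. A small illustration (the variable names below are hypothetical):

    # Illustration of the n_output logic above, assuming the Flux
    # (targets, observations) layout for y.
    y_reg   = rand(100)      # size == (100,)   -> length(dims) == 1 -> n_output = 1
    y_multi = rand(3, 100)   # size == (3, 100) -> n_output = dims[1] = 3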
@@ -110,7 +110,7 @@
     chain = Chain(
         Dense(n_input, 20, relu),
-        Dense(20, 20, relu),
+        #Dense(20, 20, relu),
         Dense(20, 20, relu),
         Dense(20, n_output)
     )

@@ -140,7 +140,7 @@
     y, decode = y

     if (m.model === nothing)
-        @warn "Warning: no Flux model has been provided in the model. LaplaceRedux will use a standard MLP with 3 hidden layers with 20 neurons each and input and output layers compatible with the dataset."
+        @warn "Warning: no Flux model has been provided in the model. LaplaceRedux will use a standard MLP with 2 hidden layers with 20 neurons each and input and output layers compatible with the dataset."
         shape = dataset_shape(m, X, y)

         m.model = default_build(11, shape)
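The fallback named in this warning pairs `dataset_shape` with `default_build(11, shape)`. A hedged sketch of what such a builder could look like, matching the chain in the previous hunk; the signature and the meaning of the first argument (presumably an RNG seed) are assumptions, not the package's actual internals:

    using Flux, Random

    # Hypothetical reconstruction of the fallback builder: a 2-hidden-layer
    # MLP (20 neurons each) sized to the dataset's input/output dimensions.
    function default_build(seed::Integer, shape)
        n_input, n_output = shape
        Random.seed!(seed)
        return Chain(
            Dense(n_input, 20, relu),  # input -> hidden layer 1
            Dense(20, 20, relu),       # hidden layer 2
            Dense(20, n_output),       # output layer
        )
    end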
@@ -182,7 +182,7 @@

         # Print loss every 100 epochs if verbosity is 1 or more
         if verbosity >= 1 && epoch % 100 == 0
             println("Epoch $epoch: Loss: $loss_per_epoch ")
         end
     end
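Per commits 87df85f and b700f85, this print sits inside a hand-written training loop driven by Optimisers.jl rather than Flux's implicit training API. A self-contained sketch of that style of loop (every name here is an assumption, not the package's actual code):

    using Flux, Optimisers

    # Minimal explicit-gradient loop in the style the commits describe.
    function train!(chain, loss, data, epochs; verbosity=1)
        state = Optimisers.setup(Optimisers.Adam(), chain)
        for epoch in 1:epochs
            loss_per_epoch = 0.0
            for (x, y) in data
                l, grads = Flux.withgradient(m -> loss(m(x), y), chain)
                state, chain = Optimisers.update(state, chain, grads[1])
                loss_per_epoch += l
            end
            # Print loss every 100 epochs if verbosity is 1 or more
            if verbosity >= 1 && epoch % 100 == 0
                println("Epoch $epoch: Loss: $loss_per_epoch ")
            end
        end
        return chain
    end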

@@ -386,10 +386,10 @@
     for name in names
         if !(name in exceptions)
             if !_isdefined(m1, name)
                 !_isdefined(m2, name) || return false
             elseif _isdefined(m2, name)
                 if name in MLJBase.deep_properties(LaplaceRegressor)
                     _equal_to_depth_one(getproperty(m1, name), getproperty(m2, name)) ||
                         return false
                 else
                     (
@@ -404,7 +404,7 @@
                     ) || return false
                 end
             else
                 return false
             end
         end
     end
@@ -419,21 +419,21 @@

 function _equal_flux_chain(chain1::Flux.Chain, chain2::Flux.Chain)
     if length(chain1.layers) != length(chain2.layers)
         return false
     end
     params1 = Flux.params(chain1)
     params2 = Flux.params(chain2)
     if length(params1) != length(params2)
         return false
     end
     for (p1, p2) in zip(params1, params2)
         if !isequal(p1, p2)
             return false
         end
     end
     for (layer1, layer2) in zip(chain1.layers, chain2.layers)
         if typeof(layer1) != typeof(layer2)
             return false
         end
     end
     return true
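A quick check of how this helper behaves (assuming it is in scope): weights count, so two independently initialized copies of the same architecture compare unequal.

    using Flux

    c1 = Chain(Dense(2, 20, relu), Dense(20, 1))
    c2 = Chain(Dense(2, 20, relu), Dense(20, 1))

    _equal_flux_chain(c1, deepcopy(c1))  # true: same layer types, identical weights
    _equal_flux_chain(c1, c2)            # false: same architecture, different random init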
@@ -605,7 +605,7 @@
 # Hyperparameters (format: name-type-default value-restrictions)
-- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 3 hidden layer with 20 neurons each.
+- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layer with 20 neurons each.
 - `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
@@ -743,7 +743,7 @@
 # Hyperparameters (format: name-type-default value-restrictions)
-- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 3 hidden layer with 20 neurons each.
+- `model::Union{Flux.Chain,Nothing} = nothing`: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layer with 20 neurons each.
 - `flux_loss = Flux.Losses.logitcrossentropy` : a Flux loss function
 - `optimiser = Adam()` a Flux optimiser
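Taken together with the hyperparameter lists above, end-to-end usage through MLJ might look like the following sketch. The constructor and its `model` keyword follow the docstrings; `make_regression` is MLJ's synthetic-data helper, and everything else is the standard machine workflow rather than code confirmed by this PR:

    using MLJ, Flux
    using LaplaceRedux

    X, y = make_regression(100, 4)   # synthetic tabular data

    # Leaving `model = nothing` triggers the default MLP warned about above.
    model = LaplaceRegressor(model=nothing)

    # Alternatively, pass a custom Flux chain compatible with the dataset:
    # model = LaplaceRegressor(model=Chain(Dense(4, 20, relu), Dense(20, 1)))

    mach = machine(model, X, y)
    fit!(mach)
    yhat = predict(mach, X)          # a vector of distributions (see commit 80784bb)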