Skip to content

Commit

Permalink
Merge pull request #14 from JuliaSmoothOptimizers/PR-Codecov
Browse files Browse the repository at this point in the history
Pr codecov
  • Loading branch information
farhadrclass authored Jul 6, 2023
2 parents 5ebf17d + cd5defe commit 54a840d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
8 changes: 6 additions & 2 deletions src/FluxNLPModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ function FluxNLPModel(
chain_ANN::T,
data_train,
data_test;
current_training_minibatch = first(data_train),
current_test_minibatch = first(data_test),
current_training_minibatch = [],
current_test_minibatch = [],
size_minibatch::Int = 100,
loss_f::F = Flux.mse, #Flux.crossentropy,
) where {T <: Chain, F <: Function}
Expand All @@ -66,6 +66,10 @@ function FluxNLPModel(
if (isempty(data_train) || isempty(data_test))
error("train data or test is empty")
end
if (isempty(current_training_minibatch) || isempty(current_test_minibatch))
current_training_minibatch = first(data_train)
current_test_minibatch = first(data_test)
end

return FluxNLPModel(
meta,
Expand Down
27 changes: 25 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,29 @@ device = cpu #TODO should we test on GPU?
println(norm(grad_x1 - grad_x1_2))
@test norm(grad_x1 - grad_x1_2) ≈ 0.0

# @test grad_x1 ≈ grad_x1_2
# @test all(grad_x1 .≈ grad_x1_2)
@test x1 == DNNLPModel.w
@test Flux.params(DNNLPModel.chain)[1][1] == x1[1]
@test Flux.params(DNNLPModel.chain)[1][2] == x1[2]

@test_throws Exception FluxNLPModel(DN, [], test_data) # if the train data is empty
@test_throws Exception FluxNLPModel(DN, train_data, []) # if the test data is empty
@test_throws Exception FluxNLPModel(DN, [], []) # if both data sets are empty

# Testing if the value of the first batch was passed in
DNNLPModel_2 = FluxNLPModel(
DN,
train_data,
test_data,
current_training_minibatch = first(train_data),
current_test_minibatch = first(test_data),
)

#checking if we can call accuracy
train_acc = FluxNLPModels.accuracy(DNNLPModel_2; data_loader = train_data) # accuracy on train data
test_acc = FluxNLPModels.accuracy(DNNLPModel_2) # on the test data

@test train_acc >= 0.0
@test train_acc <= 1.0
end

@testset "minibatch tests" begin
Expand All @@ -88,5 +106,10 @@ end
@test nlp.current_training_minibatch_status === nothing
buffer_minibatch = deepcopy(nlp.current_training_minibatch)
@test minibatch_next_train!(nlp) # should return true
@test minibatch_next_train!(nlp) # should return true
@test !isequal(nlp.current_training_minibatch, buffer_minibatch)

reset_minibatch_test!(nlp)
@test minibatch_next_test!(nlp) # should return true
@test minibatch_next_test!(nlp) # should return true
end

0 comments on commit 54a840d

Please sign in to comment.