Skip to content

Commit

Permalink
juliaformatter+docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
pasq-cat committed Sep 21, 2024
1 parent d809afb commit 33d84f5
Showing 1 changed file with 83 additions and 0 deletions.
83 changes: 83 additions & 0 deletions src/direct_mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,34 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist
fit_prior_nsteps::Int = 100::(_ > 0)
end

@doc """
MMI.fit(m::LaplaceRegressor, verbosity, X, y)
Fit a LaplaceRegressor model using the provided features and target values.
# Arguments
- `m::LaplaceRegressor`: The LaplaceRegressor model to be fitted.
- `verbosity`: Verbosity level for logging.
- `X`: Input features, expected to be in a format compatible with MLJBase.matrix.
- `y`: Target values.
# Returns
- `fitresult`: The fitted Laplace model.
- `cache`: Currently unused, returns `nothing`.
- `report`: A tuple containing the status and message of the fitting process.
# Description
This function performs the following steps:
1. Converts the input features `X` to a matrix and transposes it.
2. Reshapes the target values `y` to shape (1,:).
3. Creates a data loader for batching the data.
4. Sets up the optimizer state using the Adam optimizer.
5. Trains the model for a specified number of epochs.
6. Initializes a Laplace model with the trained Flux model and specified parameters.
7. Fits the Laplace model using the data loader.
8. Optimizes the prior of the Laplace model.
9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report indicating success.
"""
function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
#features = Tables.schema(X).names

Expand Down Expand Up @@ -79,6 +107,22 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y)
return (fitresult, cache, report)
end

@doc """
function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
Predicts the response for new data using a fitted LaplaceRegressor model.
# Arguments
- `m::LaplaceRegressor`: The LaplaceRegressor model.
- `fitresult`: The result of fitting the LaplaceRegressor model.
- `Xnew`: The new data for which predictions are to be made.
# Returns
- An array of Normal distributions, each centered around the predicted mean and variance for the corresponding input in `Xnew`.
The function first converts `Xnew` to a matrix and permutes its dimensions. It then uses the `LaplaceRedux.predict` function to obtain the predicted means and variances.
Finally, it creates Normal distributions from these means and variances and returns them as an array.
"""
function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
Xnew = MLJBase.matrix(Xnew) |> permutedims
la = fitresult
Expand Down Expand Up @@ -132,6 +176,31 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis
link_approx::Symbol = :probit::(_ in (:probit, :plugin))
end

@doc """
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
Description:
This function fits a LaplaceClassifier model using the provided data. It first preprocesses the input data `X` and target labels `y`,
then trains a neural network model using the Flux library. After training, it fits a Laplace approximation to the trained model.
Arguments:
- `m::LaplaceClassifier`: The LaplaceClassifier model to be fitted.
- `verbosity`: Verbosity level for logging.
- `X`: Input data features.
- `y`: Target labels.
Returns:
- A tuple containing:
- `(la, decode)`: The fitted Laplace model and the decode function for the target labels.
- `cache`: A placeholder for any cached data (currently `nothing`).
- `report`: A report dictionary containing the status and message of the fitting process.
Notes:
- The function uses the Flux library for neural network training and the LaplaceRedux library for fitting the Laplace approximation.
- The `optimize_prior!` function is called to optimize the prior parameters of the Laplace model.
"""
function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
X = MLJBase.matrix(X) |> permutedims
decode = y[1]
Expand Down Expand Up @@ -167,6 +236,20 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
return ((la, decode), cache, report)
end

@doc """
Predicts the class probabilities for new data using a Laplace classifier.
# Arguments
- `m::LaplaceClassifier`: The Laplace classifier model.
- `(fitresult, decode)`: A tuple containing the fitted model result and the decode function.
- `Xnew`: The new data for which predictions are to be made.
# Returns
- `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data.
The function transforms the new data `Xnew` into a matrix, applies the LaplaceRedux
prediction function, and then returns the predictions as a `MLJBase.UnivariateFinite` object.
"""
function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
la = fitresult
Xnew = MLJBase.matrix(Xnew) |> permutedims
Expand Down

0 comments on commit 33d84f5

Please sign in to comment.