From 33d84f570c1fcbbcd91c3ab3c7c0c80d10bdd8eb Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Sat, 21 Sep 2024 07:22:10 +0200 Subject: [PATCH] juliaformatter+docstrings --- src/direct_mlj.jl | 83 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 7375825..77409e1 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -43,6 +43,34 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJFlux.MLJFluxProbabilist fit_prior_nsteps::Int = 100::(_ > 0) end +@doc """ + MMI.fit(m::LaplaceRegressor, verbosity, X, y) + +Fit a LaplaceRegressor model using the provided features and target values. + +# Arguments +- `m::LaplaceRegressor`: The LaplaceRegressor model to be fitted. +- `verbosity`: Verbosity level for logging. +- `X`: Input features, expected to be in a format compatible with MLJBase.matrix. +- `y`: Target values. + +# Returns +- `fitresult`: The fitted Laplace model. +- `cache`: Currently unused, returns `nothing`. +- `report`: A tuple containing the status and message of the fitting process. + +# Description +This function performs the following steps: +1. Converts the input features `X` to a matrix and transposes it. +2. Reshapes the target values `y` to shape (1,:). +3. Creates a data loader for batching the data. +4. Sets up the optimizer state using the Adam optimizer. +5. Trains the model for a specified number of epochs. +6. Initializes a Laplace model with the trained Flux model and specified parameters. +7. Fits the Laplace model using the data loader. +8. Optimizes the prior of the Laplace model. +9. Returns the fitted Laplace model, a cache (currently `nothing`), and a report indicating success. +""" function MMI.fit(m::LaplaceRegressor, verbosity, X, y) #features = Tables.schema(X).names @@ -79,6 +107,22 @@ function MMI.fit(m::LaplaceRegressor, verbosity, X, y) return (fitresult, cache, report) end +@doc """ +function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) + + Predicts the response for new data using a fitted LaplaceRegressor model. + + # Arguments + - `m::LaplaceRegressor`: The LaplaceRegressor model. + - `fitresult`: The result of fitting the LaplaceRegressor model. + - `Xnew`: The new data for which predictions are to be made. + + # Returns + - An array of Normal distributions, each centered around the predicted mean and variance for the corresponding input in `Xnew`. + + The function first converts `Xnew` to a matrix and permutes its dimensions. It then uses the `LaplaceRedux.predict` function to obtain the predicted means and variances. +Finally, it creates Normal distributions from these means and variances and returns them as an array. +""" function MMI.predict(m::LaplaceRegressor, fitresult, Xnew) Xnew = MLJBase.matrix(Xnew) |> permutedims la = fitresult @@ -132,6 +176,31 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilis link_approx::Symbol = :probit::(_ in (:probit, :plugin)) end +@doc """ + + function MMI.fit(m::LaplaceClassifier, verbosity, X, y) + + Description: + This function fits a LaplaceClassifier model using the provided data. It first preprocesses the input data `X` and target labels `y`, + then trains a neural network model using the Flux library. After training, it fits a Laplace approximation to the trained model. + + Arguments: + - `m::LaplaceClassifier`: The LaplaceClassifier model to be fitted. + - `verbosity`: Verbosity level for logging. + - `X`: Input data features. + - `y`: Target labels. + + Returns: + - A tuple containing: + - `(la, decode)`: The fitted Laplace model and the decode function for the target labels. + - `cache`: A placeholder for any cached data (currently `nothing`). + - `report`: A report dictionary containing the status and message of the fitting process. + + Notes: + - The function uses the Flux library for neural network training and the LaplaceRedux library for fitting the Laplace approximation. + - The `optimize_prior!` function is called to optimize the prior parameters of the Laplace model. + +""" function MMI.fit(m::LaplaceClassifier, verbosity, X, y) X = MLJBase.matrix(X) |> permutedims decode = y[1] @@ -167,6 +236,20 @@ function MMI.fit(m::LaplaceClassifier, verbosity, X, y) return ((la, decode), cache, report) end +@doc """ +Predicts the class probabilities for new data using a Laplace classifier. + + # Arguments + - `m::LaplaceClassifier`: The Laplace classifier model. + - `(fitresult, decode)`: A tuple containing the fitted model result and the decode function. + - `Xnew`: The new data for which predictions are to be made. + + # Returns + - `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data. + +The function transforms the new data `Xnew` into a matrix, applies the LaplaceRedux +prediction function, and then returns the predictions as a `MLJBase.UnivariateFinite` object. +""" function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew) la = fitresult Xnew = MLJBase.matrix(Xnew) |> permutedims