
finetune.cpp command-line arg #13873


Open · wants to merge 1 commit into master
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -12,6 +12,8 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()

message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")

# Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

51 changes: 51 additions & 0 deletions common/arg.cpp
Collaborator:
Do not autoformat files in the same PR where you make functional changes. It creates a lot of unnecessary work for maintainers. As I said, please fix your environment to avoid doing this.

Author:
Copy; will resolve. As I said, the intention is to autoformat only the new code I add. If other lines were accidentally affected, I'm happy to revert them.

@@ -3376,5 +3376,56 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SERVER}));

add_opt(common_arg({ "-save", "--opt-save-model-to" }, "ALPHA",
string_format(
"adamw or sgd optimizer alpha (default: %s); note: sgd alpha recommended ~10x (no momentum)",
Collaborator:
Forgot to update this string?

Author:
Yes, thank you :)

params.opt_save_model_to.c_str()),
[](common_params & params, const std::string & value) { params.opt_save_model_to = value; })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(
common_arg({ "-lr", "--learning-rate" }, "ALPHA",
string_format(
"adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)",
(double) params.lr.lr),
[](common_params & params, const std::string & value) { params.lr.lr = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{ "-lr-half", "--learning-rate-halflife-epochs" }, "N",
string_format("reduce lr in half every N epochs (default: %.3g)", (double) params.lr.halflife_epochs),
[](common_params & params, const std::string & value) { params.lr.halflife_epochs = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-lr-halvings", "--learning-rate-halvings" }, "N",
string_format("max N lr halvings (default: %.3g)", (double) params.lr.halvings),
[](common_params & params, const std::string & value) { params.lr.halvings = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
Comment on lines +3393 to +3400

Collaborator:
To me the more intuitive parameterization of a decaying learning rate would be to set a minimum value for the learning rate rather than a maximum number of times the learning rate is halved.

Author:
Sure, but then you need to distinguish whether that minimum was explicitly specified or not; as it is, the halving parameters can be left at their defaults while tweaking just -lr. Not a big deal either way, let me know.
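
For orientation, a minimal standalone sketch (not part of the PR) of how the halvings-based cap and a hypothetical minimum learning rate map onto each other; lr_min is an illustrative name, not an existing flag:

#include <cmath>
#include <cstdio>

int main() {
    // Defaults taken from the PR's lr_decay struct: lr = 1e-5, halvings = 10.
    const float lr0      = 1e-5f;
    const float halvings = 10.0f;

    // The halvings cap implies a floor that the learning rate decays towards.
    const float lr_min_implied = lr0 * std::pow(0.5f, halvings);

    // Conversely, an explicit minimum could be converted into an equivalent cap.
    const float lr_min         = 1e-8f; // hypothetical --lr-min value
    const float halvings_equiv = std::log2(lr0 / lr_min);

    std::printf("implied floor: %.3g, halvings equivalent to lr_min=%.0e: %.2f\n",
                lr_min_implied, lr_min, halvings_equiv);
    return 0;
}

Either parameterization yields the same schedule; the question in the thread is only which knob is exposed to the user.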

add_opt(common_arg(
{ "-wd", "--weight-decay" }, "WD",
string_format(
"adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
(double) params.lr.wd),
[](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-val", "--val-split" }, "FRACTION",
string_format("portion of data to use as validation when optimizing (default: %.2g).",
(double) params.val_split),
[](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-epochs", "--epochs" }, "N",
string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
[](common_params & params, int epochs) { params.lr.epochs = epochs; })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-period", "--opt-period" }, "N",
string_format("make logical batch this multiple of physical batch - needs more memory for accumulation if >1 (default: %d)", params.opt_period),
[](common_params & params, int opt_period) { params.opt_period = opt_period; })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
Comment on lines +3417 to +3420

Collaborator:
This should be parametrized by setting the logical and physical batch sizes.

Author:
I believe it was, but (correct me if I'm wrong) other optimizer init code in master adjusts the batch sizes to the minimum of the logical and physical sizes. I didn't have the confidence to change that. We can drop the option and leave it for someone else to investigate if you prefer.
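
For reference, a tiny sketch (arbitrary values, not the actual ggml-opt code) of the relationship --opt-period expresses: gradients from opt_period physical batches are accumulated before a single optimizer step, so the logical batch is that multiple of the physical one:

#include <cstdint>
#include <cstdio>

int main() {
    const int32_t n_batch_physical = 512; // batch actually evaluated at once
    const int32_t opt_period       = 4;   // --opt-period
    const int32_t n_batch_logical  = opt_period * n_batch_physical; // datapoints per optimizer step

    // Gradients from opt_period forward/backward passes are accumulated
    // before one optimizer step is applied.
    for (int32_t i = 1; i <= opt_period; ++i) {
        std::printf("accumulate gradients for physical batch %d/%d\n", i, opt_period);
    }
    std::printf("optimizer step over logical batch of %d datapoints\n", n_batch_logical);
    return 0;
}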

add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or sgd",
[](common_params & params, const std::string & name) {
params.optimizer = ggml_opt_get_optimizer(name.c_str());
if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
}
})
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));

return ctx_arg;
}
8 changes: 8 additions & 0 deletions common/common.cpp
@@ -1535,3 +1535,11 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std

return result;
}

ggml_opt_optimizer_params common_lr_opt_pars(void * userdata) {
ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
const lr_opt & d = *(lr_opt *) userdata;
result.adamw.alpha = result.sgd.alpha = d.decayed(d.epoch);
result.sgd.wd = result.adamw.wd = d.wd;
return result;
}
35 changes: 32 additions & 3 deletions common/common.h
@@ -2,13 +2,15 @@

#pragma once

#include "llama-cpp.h"

#include <set>
#include <sstream>
#include <string>
#include <string_view>
#include <vector>
#include <sstream>
#include <cmath>

#include "ggml-opt.h"
#include "llama-cpp.h"

#ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\'
@@ -80,6 +82,7 @@ enum llama_example {
LLAMA_EXAMPLE_LOOKUP,
LLAMA_EXAMPLE_PARALLEL,
LLAMA_EXAMPLE_TTS,
LLAMA_EXAMPLE_FINETUNE,

LLAMA_EXAMPLE_COUNT,
};
@@ -219,6 +222,25 @@ enum common_reasoning_format {
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
};

struct lr_decay {
float lr = 1e-5;
float halflife_epochs = 100;
float halvings = 10;

float decayed(float epoch) const {
float maxepoch = halvings * halflife_epochs;
return lr * std::pow(.5, (epoch > maxepoch ? maxepoch : epoch) / halflife_epochs);
}
};

struct lr_opt : lr_decay {
float epoch = 0;
float wd = 0;
unsigned epochs = 2;
};

struct ggml_opt_optimizer_params common_lr_opt_pars(void * userdata);
Comment on lines +225 to +242

Collaborator:
I think a decaying learning rate is a common enough use case in machine learning that it would make sense to implement it in ggml-opt.

The way you've implemented it, the learning rate will be scaled down by discrete factors of 2 rather than decaying smoothly. Is this intentional?

Author:
It was intentional to step the rate only at the end of a full epoch; I realize this is not how everyone does it. I don't agree that discrete factors of 2 follow, though: the half-life is a configurable (possibly fractional) number of epochs.
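
To make the schedule concrete, here is a minimal standalone program evaluating the same formula as lr_decay::decayed from the common.h hunk above, with arbitrary example values:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    const float lr = 1e-4f, halflife_epochs = 2.0f, halvings = 3.0f;
    const float max_epoch = halvings * halflife_epochs; // decay stops here
    for (float epoch = 0.0f; epoch <= 8.0f; epoch += 1.0f) {
        const float decayed = lr * std::pow(0.5f, std::min(epoch, max_epoch) / halflife_epochs);
        std::printf("epoch %2.0f: lr = %.3g\n", epoch, decayed);
    }
    return 0;
}

With halflife_epochs = 2 the per-epoch factor is 0.5^(1/2) ≈ 0.71 rather than 0.5, which is the author's point about fractional half-lives; once halvings * halflife_epochs epochs have elapsed, the rate stops decaying.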


struct common_params {
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 4096; // context size
@@ -350,6 +372,13 @@ struct common_params {
bool no_mmproj = false; // explicitly disable multimodal model
std::vector<std::string> image; // path to image file(s)

// finetune
struct lr_opt lr;
enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
float val_split = 0.05f; // fraction of data used for validation when optimizing
int32_t opt_period = 1;
std::string opt_save_model_to = "finetuned-model.gguf";

// embedding
bool embedding = false; // get only sentence embedding
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
70 changes: 37 additions & 33 deletions examples/training/finetune.cpp
@@ -1,29 +1,31 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <vector>

#include "arg.h"
#include "common.h"
#include "llama.h"
#include "log.h"

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
# pragma warning(disable : 4244 4267) // possible loss of data
#endif



int main(int argc, char ** argv) {
common_params params;

params.escape = false;

if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
return 1;
}

if (params.use_mmap) {
LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n",
__func__);
params.use_mmap = false;
}
if (params.cache_type_k != GGML_TYPE_F32) {
@@ -38,11 +40,11 @@ int main(int argc, char ** argv) {
common_init();
llama_backend_init();
llama_numa_init(params.numa);

// load the model and apply lora adapter, if any
common_init_result llama_init = common_init_from_params(params);
llama_model_ptr & model = llama_init.model;
llama_context_ptr & ctx = llama_init.context;
common_init_result llama_init = common_init_from_params(params);
llama_model_ptr & model = llama_init.model;
llama_context_ptr & ctx = llama_init.context;
auto pctx = ctx.get();

if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__);
@@ -55,31 +57,33 @@ int main(int argc, char ** argv) {
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
}

constexpr float val_split = 0.05f;

std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);

struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
optimizer_params.adamw.alpha = 1e-7f; // learning rate

struct llama_opt_params lopt_params {
/*n_ctx_train =*/ 0,
/*param_filter =*/ llama_opt_param_filter_all,
/*param_filter_ud =*/ nullptr,
/*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params,
/*get_opt_pars_ud =*/ &optimizer_params,
std::vector<llama_token> tokens = common_tokenize(pctx, params.prompt, true);
ggml_opt_dataset_t dataset = common_opt_dataset_init(pctx, tokens, llama_n_ctx(pctx) / 2);

auto & lr = params.lr;
LOG_INF("-optimizer %s -lr %.2g -wd %.2g -lr-half %.2g -epochs %d -period %d -val %.2g\n",
ggml_opt_optimizer_name(params.optimizer), (double) lr.lr, (double) lr.wd, (double) lr.halflife_epochs,
(unsigned) lr.epochs, (unsigned) params.opt_period, (double) params.val_split);

struct llama_opt_params lopt_params{
/*n_ctx_train =*/0,
/*param_filter =*/llama_opt_param_filter_all,
/*param_filter_ud =*/nullptr,
/*get_opt_pars =*/common_lr_opt_pars,
/*get_opt_pars_ud =*/&params.lr,
/*optimizer_type =*/params.optimizer,
/*opt_period =*/params.opt_period,
};
llama_opt_init(ctx.get(), model.get(), lopt_params);
llama_opt_init(pctx, model.get(), lopt_params);

const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);

ggml_opt_result_t result_train = ggml_opt_result_init();
ggml_opt_result_t result_eval = ggml_opt_result_init();

for (int epoch = 0; epoch < 2; ++epoch) {
llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
for (unsigned epoch = 0; epoch < lr.epochs; ++epoch) {
llama_opt_epoch(pctx, dataset, result_train, result_eval, idata_split,
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
fprintf(stderr, "\n");

ggml_opt_result_reset(result_train);
Expand All @@ -88,7 +92,7 @@ int main(int argc, char ** argv) {
ggml_opt_result_free(result_train);
ggml_opt_result_free(result_eval);

llama_model_save_to_file(model.get(), "finetuned-model.gguf");
llama_model_save_to_file(model.get(), params.opt_save_model_to.c_str());

llama_backend_free();

35 changes: 27 additions & 8 deletions ggml/include/ggml-opt.h
@@ -74,16 +74,30 @@ extern "C" {
GGML_OPT_BUILD_TYPE_OPT = 30,
};

enum ggml_opt_optimizer_type {
GGML_OPT_OPTIMIZER_TYPE_ADAMW,
GGML_OPT_OPTIMIZER_TYPE_SGD,

GGML_OPT_OPTIMIZER_TYPE_COUNT
};

// "adamw" or "sgd" (case insensitive)
GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
GGML_API enum ggml_opt_optimizer_type ggml_opt_get_optimizer(const char *);

// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
struct ggml_opt_optimizer_params {
// AdamW optimizer parameters
struct {
float alpha; // learning rate
float beta1;
float beta2;
float eps; // epsilon for numerical stability
float wd; // weight decay for AdamW, use 0.0f to disable
float alpha; // learning rate
float beta1; // adamw
float beta2; // adamw
Comment on lines +92 to +93

Collaborator:
Suggested change:
- float beta1; // adamw
- float beta2; // adamw
+ float beta1; // first AdamW momentum
+ float beta2; // second AdamW momentum

float eps; // epsilon for numerical stability
float wd; // weight decay - 0.0f to disable
} adamw;
struct {
float alpha; // learning rate
float wd; // weight decay
} sgd;
};

// callback to calculate optimizer parameters prior to a backward pass
@@ -113,7 +127,10 @@ extern "C" {
int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done

ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
void * get_opt_pars_ud; // userdata for calculating optimizer parameters

// only GGML_OPT_OPTIMIZER_TYPE_ADAMW allocates m, v per parameter
enum ggml_opt_optimizer_type optimizer;
};

// get parameters for an optimization context with defaults set where possible
@@ -186,7 +203,7 @@ extern "C" {
// The second context should contain all other tensors and will be (re)allocated automatically.
// Due to this automated allocation the data of the second context is not defined when accessed in user code.
// Note that the second dimension of the inputs/outputs are interpreted as the number of datapoints in those tensors.
// 4. Call ggml_opt_fit. If you need more control you can use ggml_opt_epoch instead.
// 4. Call ggml_opt_fit. If you need more control (e.g. optimizer sgd) you can use ggml_opt_epoch instead.
Collaborator:
Suggested change:
- // 4. Call ggml_opt_fit. If you need more control (e.g. optimizer sgd) you can use ggml_opt_epoch instead.
+ // 4. Call ggml_opt_fit. If you need more control you can use ggml_opt_epoch instead.

There is SGD support in ggml_opt_fit, so did you perhaps forget to adjust this comment again?


// signature for a callback while evaluating opt_ctx on dataset, called after an evaluation
typedef void (*ggml_opt_epoch_callback)(
@@ -226,12 +243,14 @@ extern "C" {
struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
enum ggml_opt_loss_type loss_type, // loss to minimize
enum ggml_opt_optimizer_type optimizer, // sgd or adamw
ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
int64_t nepoch, // how many times the dataset should be iterated over
int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
bool silent); // whether or not info prints to stderr should be suppressed

GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t);
Collaborator:
Move this declaration upwards so that it's in the same place as the other getters for ggml_opt_context (remember to also move the implementation).

#ifdef __cplusplus
}
#endif
13 changes: 11 additions & 2 deletions ggml/include/ggml.h
@@ -450,7 +450,7 @@ extern "C" {
GGML_OP_REPEAT_BACK,
GGML_OP_CONCAT,
GGML_OP_SILU_BACK,
GGML_OP_NORM, // normalize
GGML_OP_NORM, // normalize
GGML_OP_RMS_NORM,
GGML_OP_RMS_NORM_BACK,
GGML_OP_GROUP_NORM,
@@ -486,7 +486,7 @@ extern "C" {
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
GGML_OP_POOL_2D_BACK,
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_PAD,
GGML_OP_PAD_REFLECT_1D,
GGML_OP_ARANGE,
@@ -517,6 +517,7 @@ extern "C" {
GGML_OP_CROSS_ENTROPY_LOSS,
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
GGML_OP_OPT_STEP_ADAMW,
GGML_OP_OPT_STEP_SGD,

GGML_OP_COUNT,
};
@@ -2063,6 +2064,14 @@ extern "C" {
struct ggml_tensor * v,
struct ggml_tensor * adamw_params); // parameters such as the learning rate

// SGD (with weight decay) step
GGML_API struct ggml_tensor * ggml_opt_step_sgd(
// params: alpha (learning rate), wd (weight decay)
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * grad,
struct ggml_tensor * adamw_params);

//
// automatic differentiation
//
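The ggml.h hunk above declares ggml_opt_step_sgd alongside GGML_OP_OPT_STEP_SGD. As orientation only, a sketch of one common element-wise formulation of an SGD step with decoupled weight decay, using the alpha/wd parameters named in the header; the PR's actual kernel may differ in details:

#include <cstddef>

// Illustrative update (assumed semantics, not the ggml source):
// x <- x * (1 - alpha * wd) - alpha * grad
static void sgd_step_sketch(float * x, const float * grad, size_t n, float alpha, float wd) {
    for (size_t i = 0; i < n; ++i) {
        x[i] = x[i] * (1.0f - alpha * wd) - alpha * grad[i];
    }
}

int main() {
    float x[2]    = { 1.0f, -2.0f };
    float grad[2] = { 0.5f, 0.25f };
    sgd_step_sketch(x, grad, 2, /*alpha=*/0.1f, /*wd=*/0.0f);
    return 0;
}
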
6 changes: 6 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -2061,6 +2061,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
ggml_compute_forward_opt_step_adamw(params, tensor);
}
break;
case GGML_OP_OPT_STEP_SGD:
{
ggml_compute_forward_opt_step_sgd(params, tensor);
}
break;
case GGML_OP_NONE:
{
// nop
@@ -2345,6 +2350,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
{
n_tasks = n_threads;
} break;