fix(ignite): allow setting param_groups to arbitrary values
This is needed for compatibility with LR schedulers.

Resolves Issue #1260
sebffischer authored and dfalbel committed Jan 29, 2025
1 parent 2cf01b4 · commit cb1b09a
Showing 2 changed files with 64 additions and 32 deletions.
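
For context: torch's LR schedulers record per-group bookkeeping such as `initial_lr` directly in the optimizer's `param_groups`, and the ignite optimizers previously rejected any name the C++ side does not know (see #1260). A minimal sketch of the pattern this commit unblocks — the `lr_step()` call and the `initial_lr` field follow torch's scheduler conventions and are not part of this diff:

library(torch)

p <- torch_tensor(1, requires_grad = TRUE)
opt <- optim_ignite_adamw(list(p), lr = 0.1)

# schedulers stash extra per-group fields (e.g. initial_lr) in param_groups;
# with this fix such fields are kept on the R side instead of raising an error
sched <- lr_step(opt, step_size = 1, gamma = 0.5)
opt$param_groups[[1]]$initial_lr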
67 changes: 36 additions & 31 deletions R/ignite.R
@@ -29,6 +29,8 @@ OptimizerIgnite <- R6::R6Class(
opts
}

private$.additional_param_groups <- list(list())

if (is.list(params) && is.list(params[[1]]$params)) {
opts <- helper(params[[1]])
self$ptr <- do.call(private$.optim, c(list(params = params[[1]]$params), opts))
@@ -49,7 +51,7 @@ OptimizerIgnite <- R6::R6Class(
#' they belong, converted to character.
#' @return (`list()`)
state_dict = function() {
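# delegate to the subclass: .get_states returns the raw optimizer states
# from C++ and .state_names gives them their R-level names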
stop("Abstract method")
extract_ignite_state_dict(self, private$.get_states(self$ptr), private$.state_names)
},
#' @description
#' Loads the state dictionary into the optimizer.
@@ -110,13 +112,15 @@ OptimizerIgnite <- R6::R6Class(
param_group <- c(param_group, self$defaults[!(names(self$defaults) %in% names(param_group))])
do.call(private$.assert_params, param_group)
do.call(private$.add_param_group, c(list(opt = self$ptr, params = params), param_group))
private$.additional_param_groups[[length(private$.additional_param_groups) + 1]] <- list()
}
),
active = list(
#' @description
#' The parameter groups of the optimizer.
param_groups = function(rhs) {
if (!missing(rhs)) {
cpp_names <- c("params", private$.config_names)
prev_param_groups <- self$param_groups
all_params = unlist(lapply(prev_param_groups, function(x) x$params))
if (!is.list(rhs) && length(rhs) == length(prev_param_groups)) {
@@ -125,32 +129,35 @@ OptimizerIgnite <- R6::R6Class(
walk(seq_along(prev_param_groups), function(i) {
prev_param_group <- prev_param_groups[[i]]
new_param_group <- rhs[[i]]
if (!is_permutation(names(new_param_group), names(prev_param_group))) {
value_error("Parameter groups must have names {paste0(names(prev_param_group), collapse = ', ')} but got {paste0(names(new_param_group), collapse = ', ')}.")
if (!is_subset(cpp_names, names(new_param_group))) {
value_error("Parameter groups must include names '{paste0(cpp_names, collapse = ', ')}' but only included '{paste0(names(new_param_group), collapse = ', ')}'.")
}

new_param_group_additional <- new_param_group[!(names(new_param_group) %in% cpp_names)]
private$.additional_param_groups[[i]] <- new_param_group_additional
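# params may arrive as integer indices (as returned by the C++ getter)
# or as a list of tensors; normalize for the comparison below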
param_cmp_value = if (is.integer(new_param_group$params)) {
all_params[new_param_group$params]
} else {
new_param_group$params
}

if (!identical(prev_param_group$params, param_cmp_value)) {
value_error("Cannot change the parameter groups, use `$add_param_group()` to add a new parameter group.")
}

private$.set_param_group_options(self$ptr, rhs)
})
# the C++ setter only consumes the config values; any additional entries in rhs are simply ignored here
private$.set_param_group_options(self$ptr, rhs)
}
private$.get_param_groups(self$ptr)
pgs = private$.get_param_groups(self$ptr)
lapply(seq_along(pgs), function(i) {
c(pgs[[i]], private$.additional_param_groups[[i]])
})
}
),
private = list(
.additional_param_groups = NULL,
.optim = function(params, ...) stop("Abstract method"),
.set_states = function(ptr, params, states) stop("Abstract method"),
.add_param_group = function(ptr, params, options) stop("Abstract method"),
.get_states = function(ptr) stop("Abstract method"),
.assert_params = function(...) stop("Abstract method"),
.set_param_group_options = function(ptr, options) stop("Abstract method"),
.get_param_groups = function(ptr) stop("Abstract method")
@@ -195,14 +202,13 @@ optim_ignite_adagrad <- optimizer_ignite(
initial_accumulator_value = 0, eps = 1e-10) {
super$initialize(params, defaults = list(lr = lr, lr_decay = lr_decay, weight_decay = weight_decay, initial_accumulator_value = initial_accumulator_value, eps = eps))
},
state_dict = function() {
extract_ignite_state_dict(self, rcpp_ignite_adagrad_get_states(self$ptr), c("sum", "step"))
},
private = list(
.optim = function(params, ...) {
rcpp_ignite_adagrad(params = params, ...)
},

.get_states = rcpp_ignite_adagrad_get_states,
.state_names = c("sum", "step"),
.config_names = c("lr", "lr_decay", "weight_decay", "initial_accumulator_value", "eps"),
.set_states = rcpp_ignite_adagrad_set_states,
.add_param_group = rcpp_ignite_adagrad_add_param_group,
.assert_params = assert_adagrad_params,
@@ -234,14 +240,13 @@ optim_ignite_rmsprop <- optimizer_ignite(
weight_decay = 0, momentum = 0, centered = FALSE) {
super$initialize(params, defaults = list(lr = lr, alpha = alpha, eps = eps, weight_decay = weight_decay, momentum = momentum, centered = centered))
},
state_dict = function() {
extract_ignite_state_dict(self, rcpp_ignite_rmsprop_get_states(self$ptr), c("grad_avg", "square_avg", "momentum_buffer", "step"))
},
private = list(
.optim = function(params, ...) {
rcpp_ignite_rmsprop(params = params, ...)
},

.get_states = rcpp_ignite_rmsprop_get_states,
.state_names = c("grad_avg", "square_avg", "momentum_buffer", "step"),
.config_names = c("lr", "alpha", "eps", "weight_decay", "momentum", "centered"),
.set_states = rcpp_ignite_rmsprop_set_states,
.add_param_group = rcpp_ignite_rmsprop_add_param_group,
.assert_params = assert_rmsprop_params,
@@ -273,13 +278,13 @@ optim_ignite_sgd <- optimizer_ignite(
weight_decay = 0, nesterov = FALSE) {
super$initialize(params, defaults = list(lr = lr, momentum = momentum, dampening = dampening, weight_decay = weight_decay, nesterov = nesterov))
},
state_dict = function() {
extract_ignite_state_dict(self, rcpp_ignite_sgd_get_states(self$ptr), "momentum_buffer")
},
private = list(
.optim = function(params, ...) {
rcpp_ignite_sgd(params = params, ...)
},
.get_states = rcpp_ignite_sgd_get_states,
.state_names = "momentum_buffer",
.config_names = c("lr", "momentum", "dampening", "weight_decay", "nesterov"),
.set_states = rcpp_ignite_sgd_set_states,
.add_param_group = rcpp_ignite_sgd_add_param_group,
.assert_params = assert_sgd_params,
@@ -311,15 +316,13 @@ optim_ignite_adam <- optimizer_ignite(
weight_decay = 0, amsgrad = FALSE) {
super$initialize(params, defaults = list(lr = lr, betas = betas, eps = eps, weight_decay = weight_decay, amsgrad = amsgrad))
},
state_dict = function() {
extract_ignite_state_dict(self, rcpp_ignite_adam_get_states(self$ptr),
c("exp_avg", "exp_avg_sq", "max_exp_avg_sq", "step"))
},
private = list(
.optim = function(params, ...) {
rcpp_ignite_adam(params = params, ...)
},

.get_states = rcpp_ignite_adam_get_states,
.config_names = c("lr", "betas", "eps", "weight_decay", "amsgrad"),
.state_names = c("exp_avg", "exp_avg_sq", "max_exp_avg_sq", "step"),
.set_states = rcpp_ignite_adam_set_states,
.add_param_group = rcpp_ignite_adam_add_param_group,
.assert_params = assert_adam_params,
@@ -351,15 +354,13 @@ optim_ignite_adamw <- optimizer_ignite(
weight_decay = 1e-2, amsgrad = FALSE) {
super$initialize(params, defaults = list(lr = lr, betas = betas, eps = eps, weight_decay = weight_decay, amsgrad = amsgrad))
},
state_dict = function() {
extract_ignite_state_dict(self, rcpp_ignite_adamw_get_states(self$ptr),
c("exp_avg", "exp_avg_sq", "max_exp_avg_sq", "step"))
},
private = list(
.optim = function(params, ...) {
rcpp_ignite_adamw(params = params, ...)
},
.step = rcpp_ignite_optim_step,
.get_states = rcpp_ignite_adamw_get_states,
.config_names = c("lr", "betas", "eps", "weight_decay", "amsgrad"),
.state_names = c("exp_avg", "exp_avg_sq", "max_exp_avg_sq", "step"),
.set_states = rcpp_ignite_adamw_set_states,
.add_param_group = rcpp_ignite_adamw_add_param_group,
.assert_params = assert_adamw_params,
@@ -371,6 +372,10 @@ optim_ignite_adamw <- optimizer_ignite(
)
)

is_subset <- function(vec1, vec2) {
all(vec1 %in% vec2)
}

is_permutation <- function(vec1, vec2) {
# Check if lengths are the same
if (length(vec1) != length(vec2)) {
29 changes: 28 additions & 1 deletion tests/testthat/test-ignite.R
@@ -212,7 +212,7 @@ test_that("base class: error handling when loading state dict", {
expect_error(o$load_state_dict(sd2), "The 1-th state has elements with names exp_avg")
sd3 = o$state_dict()
sd3$param_groups[[1]]$lr = NULL
expect_error(o$load_state_dict(sd3), "but got params, weight_decay")
expect_error(o$load_state_dict(sd3), "must include names 'params")
})

test_that("base class: deep cloning not possible", {
@@ -240,3 +240,30 @@ test_that("base class: changing the learning rate has an effect", {
s(n2, o2)
expect_false(torch_equal(n1$parameters[[1]], n2$parameters[[1]]) && torch_equal(n1$parameters[[2]], n2$parameters[[2]]))
})


test_that("can specify additional param_groups", {
o = optim_ignite_adamw(list(torch_tensor(1, requires_grad = TRUE)), lr = 0.1)
o$param_groups[[1]]$initial_lr = 0.2
expect_equal(o$param_groups[[1]]$initial_lr, 0.2)
expect_equal(o$state_dict()$param_groups[[1]]$initial_lr, 0.2)
o$param_groups[[1]]$initial_lr = 0.3
expect_equal(o$param_groups[[1]]$initial_lr, 0.3)
expect_equal(o$state_dict()$param_groups[[1]]$initial_lr, 0.3)

o$param_groups[[1]]$initial_lr = NULL
expect_equal(o$param_groups[[1]]$initial_lr, NULL)
expect_equal(o$state_dict()$param_groups[[1]]$initial_lr, NULL)

o = optim_ignite_adamw(params = list(
list(params = list(torch_tensor(1, requires_grad = TRUE)), lr = 0.1),
list(params = list(torch_tensor(1, requires_grad = TRUE)), lr = 0.2)
))

o$param_groups[[1]]$initial_lr = 0.1
o$param_groups[[2]]$initial_lr = 0.2
expect_equal(o$param_groups[[1]]$initial_lr, 0.1)
expect_equal(o$param_groups[[2]]$initial_lr, 0.2)
expect_equal(o$state_dict()$param_groups[[1]]$initial_lr, 0.1)
expect_equal(o$state_dict()$param_groups[[2]]$initial_lr, 0.2)
})
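
With this change, reading `param_groups` merges the C++-owned groups with the R-side extras stored in `.additional_param_groups`, and assignment only requires the core names (`params` plus the optimizer's `.config_names`) to be present. A small usage sketch of the new contract (values are illustrative):

opt <- optim_ignite_sgd(list(torch_tensor(1, requires_grad = TRUE)), lr = 0.1)
pg <- opt$param_groups[[1]]
pg$note <- "warmup"            # arbitrary extra entry, stored on the R side
opt$param_groups[[1]] <- pg    # accepted: all core names are still present
pg$lr <- NULL
# opt$param_groups[[1]] <- pg  # would error: core name 'lr' is missing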
