From 9283e25b138d5b2a09315bea7a187513de101014 Mon Sep 17 00:00:00 2001
From: "qiangliu.7@outlook.com"
Date: Tue, 29 Oct 2024 10:24:01 +0100
Subject: [PATCH] add more options to get and apply the gradient from/to the neural networks

---
 conflictfree/utils.py | 98 +++++++++++++++++++++++++++++++++++--------
 docs/api/utils.md     |  1 +
 2 files changed, 82 insertions(+), 18 deletions(-)

diff --git a/conflictfree/utils.py b/conflictfree/utils.py
index 5c9b6cb..af1c3c5 100644
--- a/conflictfree/utils.py
+++ b/conflictfree/utils.py
@@ -1,11 +1,11 @@
 # usr/bin/python3
 # -*- coding: UTF-8 -*-
 from . import *
-from warnings import warn
 import numpy as np
+from typing import Literal


-def get_para_vector(network) -> torch.Tensor:
+def get_para_vector(network: torch.nn.Module) -> torch.Tensor:
     """
     Returns the parameter vector of the given network.

@@ -26,15 +26,22 @@
     return para_vec


-def get_gradient_vector(network, jump_none=True) -> torch.Tensor:
+def get_gradient_vector(
+    network: torch.nn.Module, none_grad_mode: Literal["raise", "zero", "skip"] = "skip"
+) -> torch.Tensor:
     """
     Returns the gradient vector of the given network.

     Args:
         network (torch.nn.Module): The network for which to compute the gradient vector.
-        jump_none (bool): Whether to skip the None gradients. default: True
-            This is useful when part of your neural network is frozen or not trainable.
-            You should set the same value to `apply_gradient_vector` when applying the gradient vector.
+        none_grad_mode (Literal['raise', 'zero', 'skip']): The mode to handle None gradients. default: 'skip'
+            - 'raise': Raise an error when the gradient of a parameter is None.
+            - 'zero': Replace the None gradient with a zero tensor.
+            - 'skip': Skip the None gradient.
+            A None gradient usually occurs when part of the network is not trainable (e.g., during fine-tuning)
+            or when a weight is not used to calculate the current loss (e.g., different parts of the network calculate different losses).
+            If all of your losses are calculated using the same part of the network, you should set `none_grad_mode` to 'skip'.
+            If your losses are calculated using different parts of the network, you should set `none_grad_mode` to 'zero' so that all gradient vectors have the same length.

     Returns:
         torch.Tensor: The gradient vector of the network.
@@ -43,9 +50,16 @@
     grad_vec = None
     for par in network.parameters():
         if par.grad is None:
-            if jump_none:
+            if none_grad_mode == "raise":
+                raise RuntimeError("None gradient detected.")
+            elif none_grad_mode == "zero":
+                viewed = torch.zeros_like(par.data.view(-1))
+            elif none_grad_mode == "skip":
                 continue
-            viewed = par.grad.data.view(-1)
+            else:
+                raise ValueError(f"Invalid none_grad_mode '{none_grad_mode}'.")
+        else:
+            viewed = par.grad.data.view(-1)
         if grad_vec is None:
             grad_vec = viewed
         else:
@@ -54,27 +68,74 @@


 def apply_gradient_vector(
-    network: torch.nn.Module, grad_vec: torch.Tensor, jump_none=True
+    network: torch.nn.Module,
+    grad_vec: torch.Tensor,
+    none_grad_mode: Literal["zero", "skip"] = "skip",
+    zero_grad_mode: Literal["skip", "pad_zero", "pad_value"] = "pad_value",
 ) -> None:
     """
     Applies a gradient vector to the network's parameters.
+    This function requires that the network already holds gradient information in order to apply the gradient vector.
+    If your network does not contain gradient information, consider using the `apply_gradient_vector_para_based` function instead.

     Args:
         network (torch.nn.Module): The network to apply the gradient vector to.
         grad_vec (torch.Tensor): The gradient vector to apply.
-        jump_none (bool): Whether to skip the None gradients. default: True
-            This is useful when part of your neural network is frozen or not trainable.
-            You should set the same value to `get_gradient_vector` when applying the gradient vector.
+        none_grad_mode (Literal['zero', 'skip']): The mode to handle None gradients. default: 'skip'
+            You should set this parameter to the same value as the one used in the `get_gradient_vector` function.
+        zero_grad_mode (Literal['skip', 'pad_zero', 'pad_value']): How to set the value of the gradient if your `none_grad_mode` is 'zero'. default: 'pad_value'
+            - 'skip': Skip the parameter and leave its gradient as None.
+            - 'pad_zero': Replace the None gradient with a zero tensor.
+            - 'pad_value': Replace the None gradient with the corresponding value in `grad_vec`.
+            If you set `none_grad_mode` to 'zero', zeros were padded into `grad_vec` for every parameter whose gradient was None when the gradient vector was collected.
+            When you apply the gradient vector back to the network, the values in `grad_vec` corresponding to those previously None gradients may no longer be zero due to the operations applied to the gradient.
+            Thus, you need to decide whether to keep such a gradient as None, set it to zero, or set it to the corresponding value in `grad_vec`.
+            If you are not sure what you are doing, it is safer to set it to 'pad_value'.
     """
+    if none_grad_mode == "zero" and zero_grad_mode == "pad_value":
+        apply_gradient_vector_para_based(network, grad_vec)
     with torch.no_grad():
         start = 0
         for par in network.parameters():
             if par.grad is None:
-                if jump_none:
+                if none_grad_mode == "skip":
                     continue
-            end = start + par.grad.data.view(-1).shape[0]
-            par.grad.data = grad_vec[start:end].view(par.grad.data.shape)
-            start = end
+                elif none_grad_mode == "zero":
+                    start = start + par.data.view(-1).shape[0]
+                    if zero_grad_mode == "pad_zero":
+                        par.grad = torch.zeros_like(par.data)
+                    elif zero_grad_mode == "skip":
+                        continue
+                    else:
+                        raise ValueError(f"Invalid zero_grad_mode '{zero_grad_mode}'.")
+                else:
+                    raise ValueError(f"Invalid none_grad_mode '{none_grad_mode}'.")
+            else:
+                end = start + par.data.view(-1).shape[0]
+                par.grad.data = grad_vec[start:end].view(par.data.shape)
+                start = end
+
+
+def apply_gradient_vector_para_based(
+    network: torch.nn.Module,
+    grad_vec: torch.Tensor,
+) -> None:
+    """
+    Applies a gradient vector to the network's parameters.
+    Please only use this function when you are sure that the length of `grad_vec` matches the total number of parameters in your network.
+    This is the case when you use `get_gradient_vector` with `none_grad_mode` set to 'zero',
+    or when `none_grad_mode` is 'skip' but all of the parameters in your network are involved in the loss calculation.
+
+    Args:
+        network (torch.nn.Module): The network to apply the gradient vector to.
+        grad_vec (torch.Tensor): The gradient vector to apply.
+ """ + with torch.no_grad(): + start = 0 + for par in network.parameters(): + end = start + par.data.view(-1).shape[0] + par.grad = grad_vec[start:end].view(par.data.shape) start = end @@ -109,7 +170,7 @@ def get_cos_similarity(vector1: torch.Tensor, vector2: torch.Tensor) -> torch.Te return torch.dot(vector1, vector2) / vector1.norm() / vector2.norm() -def unit_vector(vector: torch.Tensor, warn_zero=False) -> torch.Tensor: +def unit_vector(vector: torch.Tensor, warn_zero: bool = False) -> torch.Tensor: """ Compute the unit vector of a given tensor. @@ -259,7 +320,9 @@ def select( Returns: Tuple[Sequence,Union[float,Sequence]]: A tuple containing the indexes of the selected slice and the selected slice. """ - assert n <= len(source_sequence), "n can not be larger than or equal to the length of the source sequence" + assert n <= len( + source_sequence + ), "n can not be larger than or equal to the length of the source sequence" indexes = np.random.choice(len(source_sequence), n, replace=False) if len(indexes) == 1: return indexes, source_sequence[indexes[0]] diff --git a/docs/api/utils.md b/docs/api/utils.md index 793d4b5..38c9d24 100644 --- a/docs/api/utils.md +++ b/docs/api/utils.md @@ -5,6 +5,7 @@ The `utils` module contains utility functions for the ConFIG algorithm. ::: conflictfree.utils.apply_para_vector ::: conflictfree.utils.get_gradient_vector ::: conflictfree.utils.apply_gradient_vector +::: conflictfree.utils.apply_gradient_vector_para_based ## Math Utility Functions ::: conflictfree.utils.get_cos_similarity