class GN
Description
UL Variable Gradients Calculation with GN Method
Implements the ul problem optimization procedure of the approximated Bilevel Stochastic Gradient method (BSG-1) [1], which approximates the second-order terms of the ul gradient with first-order ones.
A wrapper of the ll model that has already been optimized in the ll optimization phase is used in this procedure.
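As a reading aid, the approximation can be written out in our own notation (an assumption, not defined on this page): f(x, y) is the ul objective, g(x, y) the ll objective, x the ul variables and y the ll variables. As we understand [1], BSG-1 replaces the second-order blocks of the implicit-function hypergradient with outer products of first-order gradients, in the spirit of a Gauss-Newton approximation (presumably the origin of the name GN):

```latex
\begin{aligned}
\nabla F(x) &= \nabla_x f - \nabla^2_{xy} g \,\bigl(\nabla^2_{yy} g\bigr)^{-1} \nabla_y f, \\
\nabla^2_{xy} g &\approx \nabla_x g \,\nabla_y g^{\top}, \qquad
\nabla^2_{yy} g \approx \nabla_y g \,\nabla_y g^{\top}, \\
\nabla F(x) &\approx \nabla_x f - \frac{\nabla_y g^{\top} \nabla_y f}{\nabla_y g^{\top} \nabla_y g}\,\nabla_x g .
\end{aligned}
```

Because the rank-one matrix \nabla_y g \nabla_y g^{\top} is not invertible, the last line uses the minimum-norm (Moore-Penrose pseudo-inverse) solution of the rank-one system; the result needs only first-order gradients of f and g.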
Parameters
ul_objective: callable
The main (upper-level) problem of the hierarchical optimization problem. A callable with signature callable(state), defined according to the specific problem to be solved, that computes the loss of the ul problem. The state object contains the following:
- "data" (Tensor) - Data used in the ul optimization phase.
- "target" (Tensor) - Target used in the ul optimization phase.
- "ul_model" (Module) - ul model of the bi-level model structure.
- "ll_model" (Module) - ll model of the bi-level model structure.
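A minimal sketch of a ul objective with this signature, assuming the state is a plain dict and, to keep the example dependency-free, that the models are dicts of scalar parameters rather than torch Module objects (both are assumptions, not the module's actual types). The hypothetical objective is the mean squared error of a one-weight linear ll model on the validation batch:

```python
def ul_objective(state):
    """Hypothetical ul objective: validation MSE of a linear ll model."""
    data = state["data"]          # validation inputs
    target = state["target"]      # validation labels
    w = state["ll_model"]["w"]    # ll parameter (stand-in for a Module)
    errors = [(w * x - t) ** 2 for x, t in zip(data, target)]
    return sum(errors) / len(errors)


state = {
    "data": [1.0, 2.0],
    "target": [2.0, 5.0],
    "ul_model": {"lam": 0.1},     # not used by this particular ul objective
    "ll_model": {"w": 2.0},
}
print(ul_objective(state))  # 0.5
```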
ul_model: Module
ul model in a hierarchical model structure whose parameters are updated using the ul objective and the trained ll model.
ll_objective: callable
An optimization problem that serves as the constraint of the ul problem. A callable with signature callable(state), defined according to the specific problem to be solved, that computes the loss of the ll problem. The state object contains the following:
- "data" (Tensor) - Data used in the ll optimization phase.
- "target" (Tensor) - Target used in the ll optimization phase.
- "ul_model" (Module) - ul model of the bi-level model structure.
- "ll_model" (Module) - ll model of the bi-level model structure.
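A matching sketch of an ll objective under the same assumptions (dict state, dict-of-scalar models, all names hypothetical). Here the ul model holds a regularization strength lam, so the ll loss on the training batch depends on both levels, which is what couples the two problems:

```python
def ll_objective(state):
    """Hypothetical ll objective: training MSE plus an ul-controlled L2 penalty."""
    data = state["data"]            # training inputs
    target = state["target"]        # training labels
    w = state["ll_model"]["w"]      # ll parameter
    lam = state["ul_model"]["lam"]  # ul parameter (regularization strength)
    mse = sum((w * x - t) ** 2 for x, t in zip(data, target)) / len(target)
    return mse + lam * w ** 2


state = {
    "data": [1.0, 2.0],
    "target": [2.0, 5.0],
    "ul_model": {"lam": 0.1},
    "ll_model": {"w": 2.0},
}
print(ll_objective(state))  # approximately 0.9 (0.5 MSE + 0.1 * 2.0**2)
```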
ll_model: Module
ll model in a hierarchical model structure whose parameters are updated with the ll objective during ll optimization.
ll_learning_rate: float
Step size for ll optimization.
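To illustrate the role of ll_learning_rate, the following sketch runs plain gradient descent on a toy ll objective (training MSE of a one-weight linear model plus a lam-weighted L2 penalty). It assumes the ll phase is ordinary gradient descent with a fixed number of steps; the optimizer and step count actually used with this class are not specified here, and the function names are hypothetical.

```python
def ll_grad(w, lam, data, target):
    """Derivative w.r.t. w of mean((w*x - t)**2) + lam * w**2."""
    n = len(target)
    mse_grad = sum(2 * x * (w * x - t) for x, t in zip(data, target)) / n
    return mse_grad + 2 * lam * w


def run_ll_optimization(w, lam, data, target, ll_learning_rate, steps):
    """Take `steps` gradient-descent steps of size ll_learning_rate on w."""
    for _ in range(steps):
        w = w - ll_learning_rate * ll_grad(w, lam, data, target)
    return w


data, target, lam = [1.0, 2.0], [2.0, 5.0], 0.1
w = run_ll_optimization(0.0, lam, data, target, ll_learning_rate=0.1, steps=50)
# Closed-form minimizer: w* = mean(x*t) / (mean(x**2) + lam) = 6 / 2.6
print(abs(w - 6 / 2.6) < 1e-9)  # True
```

For this quadratic, gradient descent converges only when ll_learning_rate < 1 / (mean(x**2) + lam), here about 0.385; a larger step size makes the iterates diverge.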
Methods
compute_gradients(validate_data, validate_target, auxiliary_model, train_data, train_target)
Computes the gradients of the ul variables on the validation samples in the batch using the ul objective. The gradients are stored in the ul model that was passed in; the update of the ul variables must then be performed outside this module.
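The gradient computation can be sketched with the first-order BSG-1 formula from [1], writing f for the ul objective and g for the ll objective. This is a dependency-free illustration on flat lists of floats with made-up input values, not the module's implementation: in the real class the four partial gradients would come from automatic differentiation through the ul model and the auxiliary ll model, and the eps guard for a vanishing ll gradient is our own assumption.

```python
def dot(a, b):
    """Inner product of two equal-length lists."""
    return sum(ai * bi for ai, bi in zip(a, b))


def bsg1_ul_grad(gx_f, gy_f, gx_g, gy_g, eps=1e-12):
    """First-order (BSG-1 style) approximation of the ul hypergradient.

    gx_f, gy_f: gradients of the ul objective w.r.t. ul and ll variables.
    gx_g, gy_g: gradients of the ll objective w.r.t. ul and ll variables.
    Returns grad_x f - (<gy_g, gy_f> / <gy_g, gy_g>) * grad_x g.
    """
    scale = dot(gy_g, gy_f) / (dot(gy_g, gy_g) + eps)  # eps: hypothetical guard
    return [a - scale * b for a, b in zip(gx_f, gx_g)]


# Only the gradient is computed here; the ul update itself happens
# separately, e.g. with a plain SGD step (step size 0.01 is arbitrary).
ul_params = [0.1]
ul_grads = bsg1_ul_grad(gx_f=[0.0], gy_f=[1.0], gx_g=[0.4], gy_g=[2.0])
ul_params = [p - 0.01 * g for p, g in zip(ul_params, ul_grads)]
print(ul_grads)  # approximately [-0.2]
```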
Parameters:
- validate_data (Tensor) - The validation data used for ul problem optimization.
- validate_target (Tensor) - The labels of the samples in the validation data.
- auxiliary_model (_MonkeyPatchBase) - Wrapper of the ll model created with the higher library; it has already been optimized in the ll optimization phase.
- train_data (Tensor) - The training data used for ll problem optimization.
- train_target (Tensor) - The labels of the samples in the training data.
Returns
ul_loss (Tensor) - The loss value of the ul objective.
References
[1] T. Giovannelli, G. Kent, L. N. Vicente, "Bilevel stochastic methods for optimization and machine learning: Bilevel stochastic descent and DARTS", arXiv preprint, 2021.