class Implicit
Description
Computes the gradient of the upper-level (ul) model variables with implicit gradient based methods.
Implements the ul optimization procedure of an implicit gradient based method (IGBM), the Neumann series based method (NS) [1].
This procedure uses a wrapper of the lower-level (ll) model that has already been optimized during ll optimization.
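The Neumann series idea referenced in [1] can be illustrated on a toy problem. The sketch below is not this class's implementation. It approximates an inverse-Hessian-vector product H^{-1} v with the truncated series step_size * sum_{j=0}^{num_terms} (I - step_size * H)^j v, which converges when the spectral radius of I - step_size * H is below 1. The function names, the dense list-of-lists Hessian, and the early-stopping rule on the size of the newest term are all assumptions made for illustration.

```python
def matvec(matrix, vector):
    """Multiply a dense matrix (list of rows) by a vector (list of floats)."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, vector)) for row in matrix]


def neumann_inverse_vector_product(hessian, vector, step_size, num_terms, tolerance=1e-10):
    """Approximate H^{-1} v by step_size * sum_{j=0}^{num_terms} (I - step_size * H)^j v."""
    term = list(vector)   # j = 0 term: (I - step_size * H)^0 v = v
    total = list(vector)  # running sum of all terms so far
    for _ in range(num_terms):
        h_term = matvec(hessian, term)
        # Next term: (I - step_size * H) applied to the previous term.
        term = [t - step_size * h for t, h in zip(term, h_term)]
        total = [s + t for s, t in zip(total, term)]
        # Illustrative early stop: the newest term no longer changes the sum much.
        if sum(t * t for t in term) ** 0.5 < tolerance:
            break
    return [step_size * s for s in total]


# Toy symmetric positive-definite Hessian; the exact value of H^{-1} v is [0.5, 0.25].
hessian = [[2.0, 0.0], [0.0, 4.0]]
v = [1.0, 1.0]
approx = neumann_inverse_vector_product(hessian, v, step_size=0.2, num_terms=200)
```

In a real bi-level setting the Hessian is never formed explicitly; each product with H would be obtained through automatic differentiation, and only the loop structure above carries over.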
Parameters
- ul_objective: callable
  The main optimization problem in a hierarchical optimization problem. Callable with signature callable(state). It is defined according to the specific problem being solved and computes the loss of the ul problem. The state object contains the following:
  - "data" (Tensor) - Data used in the ul optimization phase.
  - "target" (Tensor) - Target used in the ul optimization phase.
  - "ul_model" (Module) - ul model of the bi-level model structure.
  - "ll_model" (Module) - ll model of the bi-level model structure.
- ul_model: Module
  The ul model in a hierarchical model structure. Its parameters are updated using the ul objective and the trained ll model.
- ll_objective: callable
  An optimization problem that acts as the constraint of the ul problem. Callable with signature callable(state). It is defined according to the specific problem being solved and computes the loss of the ll problem. The state object contains the following:
  - "data" (Tensor) - Data used in the ll optimization phase.
  - "target" (Tensor) - Target used in the ll optimization phase.
  - "ul_model" (Module) - ul model of the bi-level model structure.
  - "ll_model" (Module) - ll model of the bi-level model structure.
- ll_model: Module
  The ll model in a hierarchical model structure. Its parameters are updated with the ll objective during ll optimization.
- ll_learning_rate: float
  Step size for ll optimization.
- k: int
  The maximum number of conjugate gradient iterations.
- tolerance: float, default=1e-10
  Stop the method early when the norm of the residual falls below tolerance.
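The two objective callables described above can be sketched as follows. This is a minimal illustration, assuming PyTorch and assuming the state object supports dictionary-style access by the quoted names. The WeightDecay module, the linear ll model, and the choice of cross-entropy are hypothetical and only serve to show one way a ul model can enter the ll loss.

```python
import torch
import torch.nn.functional as F


class WeightDecay(torch.nn.Module):
    """Hypothetical ul model: a single learnable log weight-decay coefficient."""

    def __init__(self):
        super().__init__()
        self.log_coeff = torch.nn.Parameter(torch.zeros(()))

    def forward(self, ll_model):
        # L2 penalty over all ll parameters, scaled by exp(log_coeff).
        penalty = sum((p ** 2).sum() for p in ll_model.parameters())
        return torch.exp(self.log_coeff) * penalty


def ul_objective(state):
    # Upper-level loss: plain loss of the ll model on the ul (validation) batch.
    logits = state["ll_model"](state["data"])
    return F.cross_entropy(logits, state["target"])


def ll_objective(state):
    # Lower-level loss: loss on the ll (training) batch plus the ul-controlled penalty.
    logits = state["ll_model"](state["data"])
    data_loss = F.cross_entropy(logits, state["target"])
    return data_loss + state["ul_model"](state["ll_model"])


# Assumption: a plain dict stands in for the state object.
torch.manual_seed(0)
state = {
    "data": torch.randn(8, 4),
    "target": torch.randint(0, 3, (8,)),
    "ul_model": WeightDecay(),
    "ll_model": torch.nn.Linear(4, 3),
}
ul_loss = ul_objective(state)
ll_loss = ll_objective(state)
```

Both callables return a scalar tensor, so either loss can be differentiated with respect to the parameters of the two models.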
Methods
compute_gradients(validate_data, validate_target, auxiliary_model, train_data, train_target)
Computes the gradients of the ul variables using the ul objective on the validation samples in the batch. The gradients are stored in the ul model that was passed in. The update of the ul variables must then be performed outside this module.
Parameters:
- validate_data (Tensor) - The validation data used for ul problem optimization.
- validate_target (Tensor) - The labels of the samples in the validation data.
- auxiliary_model (_MonkeyPatchBase) - Wrapper of the ll model created with the higher module. It has already been optimized in the ll optimization phase.
- train_data (Tensor) - The training data used for ll problem optimization.
- train_target (Tensor) - The labels of the samples in the training data.
Returns
ul_loss (Tensor) - The loss value of the ul objective.
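Because compute_gradients only stores gradients in the ul model, the caller is responsible for the parameter update. The sketch below shows that outer step in isolation, assuming PyTorch. It does not call this class: the gradients are written by hand as a stand-in for what compute_gradients would leave behind, and the linear ul model, SGD optimizer, and learning rate are hypothetical.

```python
import torch

# Hypothetical ul model and an optimizer that the caller owns.
ul_model = torch.nn.Linear(3, 1)
ul_optimizer = torch.optim.SGD(ul_model.parameters(), lr=0.1)

# Stand-in for compute_gradients: place a gradient on every ul parameter.
for param in ul_model.parameters():
    param.grad = torch.ones_like(param)

before = [p.detach().clone() for p in ul_model.parameters()]

ul_optimizer.step()       # apply the stored gradients to the ul variables
ul_optimizer.zero_grad()  # clear them before the next ul iteration

after = [p.detach().clone() for p in ul_model.parameters()]
```

With plain SGD and unit gradients, every ul parameter moves by exactly the learning rate, which makes it easy to confirm that the stored gradients were the ones applied.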
References
[1] J. Lorraine, P. Vicol, and D. Duvenaud, "Optimizing millions of hyperparameters by implicit differentiation", in AISTATS, 2020.