class Onestage
Description
Calculation of the gradient of the ul model variables with the DARTS method.
Implements the ul optimization procedure of DARTS [1], which approximates the second-order term with a finite difference of first-order gradients and therefore requires neither second-order derivatives nor explicit matrix-vector products.
A wrapper of the ll model, already optimized in the ll optimization phase, is used in this procedure.
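The finite-difference procedure can be illustrated on a toy problem. The sketch below is a hypothetical helper, `darts_ul_grad`, and not this class's implementation. It works on plain Python lists and assumes hand-written gradient functions in place of autograd, with \mathcal L_{val} standing in for ul_objective and \mathcal L_{train} for ll_objective. Following DARTS [1], with \xi set to lower_learning_rate and \epsilon = r/\|\nabla_{w'}\mathcal L_{val}(w',\alpha)\|_2, the ul gradient is approximated as \nabla_\alpha\mathcal L_{val}(w',\alpha) - \xi\,[\nabla_\alpha\mathcal L_{train}(w^+,\alpha) - \nabla_\alpha\mathcal L_{train}(w^-,\alpha)]/(2\epsilon).

```python
import math
from typing import Callable, List

Vec = List[float]
GradFn = Callable[[Vec, Vec], Vec]  # (ll weights, ul variables) -> gradient


def darts_ul_grad(
    w: Vec,
    w_prime: Vec,
    alpha: Vec,
    lower_learning_rate: float,
    r: float,
    grad_w_val: GradFn,
    grad_alpha_val: GradFn,
    grad_alpha_train: GradFn,
) -> Vec:
    """Hypothetical toy version of the finite-difference DARTS ul gradient."""
    # Gradients of the validation (ul) loss at the optimized ll weights w'.
    g_w = grad_w_val(w_prime, alpha)
    g_alpha = grad_alpha_val(w_prime, alpha)
    norm = math.sqrt(sum(g * g for g in g_w))
    if norm == 0.0:
        # Zero perturbation direction: the second-order correction vanishes.
        return list(g_alpha)
    eps = r / norm
    # Perturb the ll weights from *before* the ll update: w± = w ± eps * g_w.
    w_plus = [wi + eps * gi for wi, gi in zip(w, g_w)]
    w_minus = [wi - eps * gi for wi, gi in zip(w, g_w)]
    g_plus = grad_alpha_train(w_plus, alpha)
    g_minus = grad_alpha_train(w_minus, alpha)
    # Central difference approximates the mixed second derivative times g_w.
    hvp = [(p - m) / (2.0 * eps) for p, m in zip(g_plus, g_minus)]
    return [ga - lower_learning_rate * h for ga, h in zip(g_alpha, hvp)]


# Toy problem: L_train = 0.5 * (w - a)^2, L_val = 0.5 * (w - 1)^2 + 0.1 * a^2.
xi = 0.1
w, alpha = [0.0], [2.0]
w_prime = [w[0] - xi * (w[0] - alpha[0])]  # one ll gradient step on L_train
grad = darts_ul_grad(
    w, w_prime, alpha, xi, 1e-2,
    grad_w_val=lambda v, a: [v[0] - 1.0],
    grad_alpha_val=lambda v, a: [0.2 * a[0]],
    grad_alpha_train=lambda v, a: [a[0] - v[0]],
)
print(grad)  # approximately [0.32]
```

Because L_train is quadratic here, the central difference is exact up to rounding, so the result matches the exact one-step hypergradient 0.2\alpha + \xi(w' - 1) = 0.32.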
Parameters
ul_objective: callable
The main (upper-level) problem of the hierarchical optimization problem. Callable with signature callable(state). Defined by modeling the specific problem to be solved; computes the loss of the ul problem. The state object contains the following:
- "data"(Tensor) - Data used in the ul optimization phase.
- "target"(Tensor) - Target used in the ul optimization phase.
- "ul_model"(Module) - UL model of the bi-level model structure.
- "ll_model"(Module) - LL model of the bi-level model structure.
ul_model: Module
The ul model in the hierarchical model structure, whose parameters are updated using the ul objective and the trained ll model.
ll_objective: callable
An optimization problem that acts as the constraint of the ul problem. Callable with signature callable(state). Defined by modeling the specific problem to be solved; computes the loss of the ll problem. The state object contains the following:
- "data"(Tensor) - Data used in the ll optimization phase.
- "target"(Tensor) - Target used in the ll optimization phase.
- "ul_model"(Module) - ul model of the bi-level model structure.
- "ll_model"(Module) - ll model of the bi-level model structure.
ll_model: Module
The ll model in the hierarchical model structure, whose parameters are updated using the ll objective during the ll optimization.
lower_learning_rate: float
Step size for the ll optimization.
r (optional): float, default=1e-2
Parameter that sets the scalar \epsilon as \epsilon = r/\|\nabla_{w'}\mathcal L_{val}(w',\alpha)\|_2, where \epsilon is used in the perturbation w^\pm = w \pm \epsilon\nabla_{w'}\mathcal L_{val}(w',\alpha). Here w' denotes the ll weights after the ll optimization, w the ll weights before it, and \alpha the ul variables. The paper [1] reports that r = 0.01 is sufficiently accurate.
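Both objectives share the callable(state) signature described above. The sketch below is hypothetical: it treats state as a plain dict with the four documented keys and replaces Tensor and Module with Python floats and a tiny `Scale` callable so that it runs without PyTorch. The composition ll_model(ul_model(x)) and the weight-decay term in `ll_objective` are illustrative choices, not requirements of this class. With real modules, the body would typically apply a loss such as torch.nn.functional.cross_entropy to the model output and state["target"].

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class Scale:
    """Toy stand-in for a Module: multiplies its input by one weight."""
    weight: float

    def __call__(self, x: float) -> float:
        return self.weight * x


def _mse(preds: List[float], target: List[float]) -> float:
    # Mean squared error over the batch.
    return sum((p - t) ** 2 for p, t in zip(preds, target)) / len(preds)


def _forward(state: Dict[str, Any]) -> List[float]:
    # Assumed composition: ul_model produces features, ll_model predicts from them.
    return [state["ll_model"](state["ul_model"](x)) for x in state["data"]]


def ul_objective(state: Dict[str, Any]) -> float:
    # ul loss: plain error on the (validation) data carried by the state.
    return _mse(_forward(state), state["target"])


WEIGHT_DECAY = 1e-2  # hypothetical regularization strength for the ll problem


def ll_objective(state: Dict[str, Any]) -> float:
    # ll loss: error on the (training) data plus an L2 penalty on the ll weight.
    return _mse(_forward(state), state["target"]) + WEIGHT_DECAY * state["ll_model"].weight ** 2


ul_model, ll_model = Scale(2.0), Scale(1.0)
val_state = {"data": [1.0, 2.0], "target": [2.0, 4.0], "ul_model": ul_model, "ll_model": ll_model}
train_state = {"data": [1.0, 3.0], "target": [3.0, 6.0], "ul_model": ul_model, "ll_model": ll_model}
print(ul_objective(val_state))  # 0.0
print(ll_objective(train_state))  # approximately 0.51
```

Keeping the model composition in one shared helper means both objectives always evaluate the same forward pass and differ only in the data they receive and in any ll-specific regularization.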
Methods
compute_gradients(validate_data, validate_target, auxiliary_model, train_data, train_target)
Compute the gradients of the ul variables on the validation samples in the batch, using the ul objective. The gradients are stored in the ul model that was passed in.
Note that this ul optimization procedure only computes the gradients of the ul variables; the update of the ul variables must be performed outside this module.
Parameters:
- validate_data(Tensor) - The validation data used for the ul problem optimization.
- validate_target(Tensor) - The labels of the samples in the validation data.
- auxiliary_model(_MonkeyPatchBase) - Wrapper of the ll model created by the higher library, already optimized in the ll optimization phase.
- train_data(Tensor) - The training data used for the ll problem optimization.
- train_target(Tensor) - The labels of the samples in the training data.
Returns
ul_loss(Tensor) - The loss value of the ul objective.
References
[1] H. Liu, K. Simonyan, Y. Yang, "DARTS: Differentiable Architecture Search", in ICLR, 2019.