Base class
class UpperGrad
Description
Base class for all methods that compute gradients of the upper-level (UL) variables.
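A minimal sketch of the interface described here, assuming Python's abc module. UpperGradSketch is a hypothetical stand-in, not the library's actual class. Its constructor simply stores the three documented parameters, and compute_gradients is left abstract for concrete gradient methods to implement.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict


class UpperGradSketch(ABC):
    """Hypothetical stand-in for the UL gradient base class."""

    def __init__(self, ul_objective: Callable[[Dict[str, Any]], Any],
                 ul_model: Any, ll_model: Any) -> None:
        self.ul_objective = ul_objective  # maps a state dict to the UL loss
        self.ul_model = ul_model          # receives the computed UL gradients
        self.ll_model = ll_model          # trained during LL optimization

    @abstractmethod
    def compute_gradients(self, validate_data: Any, validate_target: Any,
                          auxiliary_model: Any, **kwargs: Any) -> Any:
        """Store UL gradients on self.ul_model and return the UL loss."""
        raise NotImplementedError
```

Because compute_gradients is abstract, instantiating the sketch directly raises TypeError; only subclasses that implement it can be constructed.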
Parameters
- ul_objective: callable
The main (upper-level) optimization problem of a hierarchical optimization problem, called as ul_objective(state). It is defined according to the specific problem being solved and computes the loss of the UL problem. The state object contains the following:
- "data" (Tensor) - Data used in the UL optimization phase.
- "target" (Tensor) - Targets used in the UL optimization phase.
- "ul_model" (Module) - UL model of the bi-level model structure.
- "ll_model" (Module) - LL model of the bi-level model structure.
- ul_model: Module
UL model of the hierarchical model structure; its parameters are updated using the UL objective and the trained LL model.
- ll_model: Module
LL model of the hierarchical model structure; its parameters are updated with the LL objective during LL optimization.
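A hypothetical ul_objective that follows the state contract above. To keep the sketch dependency-free, plain Python lists stand in for Tensors, a dict with an assumed key "lam" stands in for the UL Module, and a lambda stands in for the LL Module. The MSE-plus-penalty loss is likewise an illustrative assumption, not a loss prescribed by the library.

```python
from typing import Any, Dict, List


def ul_objective(state: Dict[str, Any]) -> float:
    """Hypothetical UL objective: MSE of the LL model on the UL batch plus an
    L2 penalty on the assumed UL parameter "lam"."""
    data: List[float] = state["data"]
    target: List[float] = state["target"]
    ll_model = state["ll_model"]  # stand-in for the LL Module: a callable
    ul_model = state["ul_model"]  # stand-in for the UL Module: a dict of parameters
    preds = [ll_model(x) for x in data]
    mse = sum((p - t) ** 2 for p, t in zip(preds, target)) / len(data)
    return mse + 0.01 * ul_model["lam"] ** 2


state = {
    "data": [1.0, 2.0],
    "target": [2.0, 4.0],
    "ul_model": {"lam": 0.0},
    "ll_model": lambda x: 2.0 * x,  # fits the targets exactly
}
print(ul_objective(state))  # 0.0
```

With a real model the same shape applies: read the four entries from the state, run the models, and return a scalar loss.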
Methods
compute_gradients(validate_data, validate_target, auxiliary_model, **kwargs)
Computes the gradients of the UL variables on a batch of validation samples using the UL objective. The gradients are stored in the UL model that was passed in; the UL variables must then be updated outside this module.
Parameters:
- validate_data(Tensor) - The validation data used for UL problem optimization.
- validate_target(Tensor) - The labels of the samples in the validation data.
- auxiliary_model(_MonkeyPatchBase) - Wrapper of the LL model created with the higher library; it has already been optimized in the LL optimization phase.
Returns
ul_loss(Tensor) - The loss value of the UL objective.
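The contract of compute_gradients (store the UL gradients on the UL model, return the UL loss, leave the update to the caller) can be illustrated with a scalar toy problem. This is a hypothetical sketch, not one of the library's gradient methods. It assumes an LL problem whose solution w* = lam is known in closed form, so dw*/dlam = 1 by the implicit function theorem, and it uses dicts in place of Modules and of the higher-wrapped auxiliary_model.

```python
from typing import Any, Dict, List


class ToyImplicitUpperGrad:
    """Hypothetical scalar illustration of the compute_gradients contract.

    Assumed LL problem: w*(lam) = argmin_w 0.5 * (w - lam) ** 2, so w* = lam
    and, by the implicit function theorem, dw*/dlam = 1.
    Assumed UL loss: mean((w* * x - y) ** 2) over the validation batch.
    """

    def __init__(self, ul_model: Dict[str, float]) -> None:
        # Stand-in for the UL Module: holds "lam" and its gradient "grad".
        self.ul_model = ul_model

    def compute_gradients(self, validate_data: List[float], validate_target: List[float],
                          auxiliary_model: Dict[str, float], **kwargs: Any) -> float:
        w = auxiliary_model["w"]  # LL weight after LL optimization
        n = len(validate_data)
        residuals = [w * x - y for x, y in zip(validate_data, validate_target)]
        ul_loss = sum(r * r for r in residuals) / n
        # Derivative of the UL loss with respect to the LL weight.
        dloss_dw = sum(2.0 * r * x for r, x in zip(residuals, validate_data)) / n
        dw_dlam = 1.0  # closed form for the assumed LL problem
        # Chain rule; the result is stored on the UL model, not applied.
        self.ul_model["grad"] = dloss_dw * dw_dlam
        return ul_loss


ul_model = {"lam": 1.0, "grad": 0.0}
upper_grad = ToyImplicitUpperGrad(ul_model)
aux = {"w": ul_model["lam"]}  # exact LL solution w* = lam
loss = upper_grad.compute_gradients([1.0, 2.0], [2.0, 4.0], aux)
ul_model["lam"] -= 0.1 * ul_model["grad"]  # UL update done outside the class
print(loss, ul_model)  # 2.5 {'lam': 1.5, 'grad': -5.0}
```

The single update moves lam from 1.0 toward the UL optimum lam = 2, where the LL solution fits y = 2x exactly. Keeping the step outside compute_gradients mirrors the separation described above, so any optimizer can apply the stored gradient.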