Base class
class LowerOptimize
Description
Base class for all lower-level (LL) optimization procedures.
Parameters
- ll_objective: callable
  The objective of the LL optimization problem, which acts as the constraint of the bi-level problem. It is a callable with signature callable(state) that computes the loss of the LL problem, and it is defined according to how the specific problem to be solved is modeled. The state object contains the following:
  - "data" (Tensor) - Data used in the LL optimization phase.
  - "target" (Tensor) - Target used in the LL optimization phase.
  - "upper_model" (Module) - UL model of the bi-level model structure.
  - "lower_model" (Module) - LL model of the bi-level model structure.
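To illustrate the callable(state) contract, here is a minimal sketch of an ll_objective. It assumes the state object is a dict indexed by the keys listed above and, so that it runs without PyTorch, uses plain Python lists and functions as stand-ins for Tensors and Modules. The squared-error loss and the toy models are arbitrary illustrative choices, not part of the library.

```python
def ll_objective(state):
    """Hypothetical LL objective: mean squared error of the LL model's
    predictions, computed on features produced by the UL model."""
    data = state["data"]
    target = state["target"]
    features = state["upper_model"](data)   # UL model maps data to features
    preds = state["lower_model"](features)  # LL model maps features to predictions
    # Mean squared error over the batch
    return sum((p - t) ** 2 for p, t in zip(preds, target)) / len(target)


# Usage with plain-Python stand-ins for the UL and LL Modules (assumptions).
def upper(xs):
    return [2.0 * x for x in xs]


def lower(fs):
    return [f + 1.0 for f in fs]


state = {
    "data": [1.0, 2.0],
    "target": [3.0, 4.0],
    "upper_model": upper,
    "lower_model": lower,
}
loss = ll_objective(state)
```

With real Tensors one would more likely return something like torch.nn.functional.mse_loss(preds, target), so that the loss stays differentiable with respect to the LL parameters.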
- lower_loop: int
  Number of update iterations performed during LL optimization.
- ul_model: Module
  UL model in the hierarchical model structure, whose parameters are updated with the upper-level objective.
- ll_model: Module
  LL model in the hierarchical model structure, whose parameters are updated with the LL objective during LL optimization.
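The parameters above suggest roughly the following shape for the base class. This is a hypothetical outline only: the class name LowerOptimizeSketch, the constructor argument order, and the validation checks are assumptions, not the library's actual code. Concrete procedures would subclass it and implement optimize.

```python
from abc import ABC, abstractmethod


class LowerOptimizeSketch(ABC):
    """Hypothetical outline of the LowerOptimize interface.
    Argument order and validation are assumptions."""

    def __init__(self, ll_objective, lower_loop, ul_model, ll_model):
        if not callable(ll_objective):
            raise TypeError("ll_objective must be callable(state)")
        if not isinstance(lower_loop, int) or lower_loop < 1:
            raise ValueError("lower_loop must be a positive int")
        self.ll_objective = ll_objective  # computes the LL loss from a state
        self.lower_loop = lower_loop      # number of LL update iterations
        self.ul_model = ul_model          # updated with the upper objective
        self.ll_model = ll_model          # updated with the LL objective

    @abstractmethod
    def optimize(self, train_data, train_target, auxiliary_model, auxiliary_opt, **kwargs):
        """Run the LL optimization; implemented by concrete subclasses."""


# Usage: a trivial (hypothetical) concrete subclass that performs no updates.
class NoOpLower(LowerOptimizeSketch):
    def optimize(self, train_data, train_target, auxiliary_model, auxiliary_opt, **kwargs):
        return None


procedure = NoOpLower(ll_objective=lambda state: 0.0, lower_loop=5, ul_model=None, ll_model=None)
```

Making optimize abstract means the base class itself cannot be instantiated, which matches its role as a common interface for all LL procedures.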
Methods
optimize(train_data, train_target, auxiliary_model, auxiliary_opt, **kwargs)
Execute the LL optimization procedure on the training samples using the LL objective. The passed-in wrapper of the LL model is updated.
Parameters:
- train_data (Tensor) - The training data used for LL problem optimization.
- train_target (Tensor) - The labels of the samples in the training data.
- auxiliary_model (_MonkeyPatchBase) - Wrapper of the LL model created by the higher library; it is optimized during the LL optimization procedure.
- auxiliary_opt (DifferentiableOptimizer) - Wrapper of the LL optimizer created by the higher library; it is used during the LL optimization procedure.
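A concrete optimize typically runs lower_loop update steps on the wrapped model. The sketch below shows that control flow as a standalone function under stated assumptions: state is a dict with the keys listed for ll_objective, the wrapper (auxiliary_model) is passed as "lower_model" so that the updated weights are the ones evaluated, and auxiliary_opt exposes step(loss) as higher's DifferentiableOptimizer does. The function name lower_level_loop and the RecordingOpt stub are hypothetical; the stub only records losses, so no real parameter update happens here.

```python
def lower_level_loop(ll_objective, lower_loop, ul_model, train_data, train_target,
                     auxiliary_model, auxiliary_opt):
    """Hypothetical body of a concrete optimize(): run `lower_loop`
    updates of the wrapped LL model and return the losses seen."""
    losses = []
    for _ in range(lower_loop):
        state = {
            "data": train_data,
            "target": train_target,
            "upper_model": ul_model,
            # Evaluate through the wrapper, since it is the object being updated.
            "lower_model": auxiliary_model,
        }
        loss = ll_objective(state)
        # In higher, step(loss) differentiates the loss with respect to the
        # wrapper's fast weights and applies a differentiable update to them.
        auxiliary_opt.step(loss)
        losses.append(loss)
    return losses


class RecordingOpt:
    """Stub standing in for a DifferentiableOptimizer; it only records losses."""

    def __init__(self):
        self.seen = []

    def step(self, loss):
        self.seen.append(loss)


# Usage with stand-ins (assumptions): a placeholder wrapper and a toy objective.
wrapper = object()


def toy_objective(state):
    assert state["lower_model"] is wrapper  # the wrapper, not the raw LL model
    return sum(state["data"])


opt = RecordingOpt()
losses = lower_level_loop(toy_objective, 3, None, [1.0, 2.0], [0.0, 1.0], wrapper, opt)
```

In real use, auxiliary_model and auxiliary_opt would typically come from the context manager higher.innerloop_ctx(ll_model, optimizer), which yields the patched model together with its differentiable optimizer.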