
class BOTHOptimizer


Description

Wrapper for performing bi-level optimization and gradient-based initialization optimization

BOTHOptimizer wraps the Bi-Level Optimization (BLO) and Initialization Optimization (initialization-based EGBR) processes. It builds the lower-level (LL), upper-level (UL), and initialization problem solvers from the corresponding method modules and uses them in the training phase. The optimization can also be performed by using the method packages directly.
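The LL/UL split can be illustrated without the library. The following is a minimal pure-Python sketch, assuming a toy problem whose LL objective is f(x, y) = 0.5 * (y - x)**2 and whose UL objective is F(y) = 0.5 * (y - 3)**2. The LL variable y is updated by a fixed number of unrolled gradient steps, and the UL variable x is updated with the hypergradient obtained by differentiating through those steps in forward mode. It does not call BOTHOptimizer; the functions unrolled_ll, ul_hypergradient, and solve are hypothetical names chosen for illustration.

```python
def unrolled_ll(x, steps=10, lr=0.5):
    """Unroll gradient descent on the toy LL objective f(x, y) = 0.5 * (y - x)**2.

    Returns the final LL iterate y_T and its derivative dy_T/dx, which is
    propagated alongside the iterates (forward-mode differentiation).
    """
    y, dy_dx = 0.0, 0.0
    for _ in range(steps):
        # LL step: y <- y - lr * df/dy, where df/dy = y - x
        y = y - lr * (y - x)
        # Derivative of the same update w.r.t. x: (1 - lr) * dy/dx + lr
        dy_dx = (1 - lr) * dy_dx + lr
    return y, dy_dx


def ul_hypergradient(x, target=3.0):
    """Hypergradient of the toy UL objective F(y) = 0.5 * (y - target)**2.

    F depends on x only through the unrolled LL solution, so by the chain rule
    dF/dx = (y_T - target) * dy_T/dx.
    """
    y, dy_dx = unrolled_ll(x)
    return (y - target) * dy_dx


def solve(x=0.0, ul_lr=0.5, iters=100):
    """Plain gradient descent on the UL variable x using the hypergradient."""
    for _ in range(iters):
        x -= ul_lr * ul_hypergradient(x)
    return x


x_star = solve()
y_star, _ = unrolled_ll(x_star)
print(f"x = {x_star:.5f}, y_T = {y_star:.5f}")
```

Because only 10 LL steps are unrolled, y_T = (1 - 0.5**10) * x, so the UL variable settles slightly above 3 instead of at the exact bi-level solution x = 3 (where y*(x) = x). Unrolling more LL steps reduces this bias at a higher computational cost.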


Parameters

  • method: str Defines the basic method for the training process; it must be one of ['Initial', 'Feature']. 'Initial' refers to meta-learning (initialization) optimization strategies, including MAML, FOMAML, TNet, WarpGrad, and L2F; 'Feature' refers to bi-level optimization strategies, including BDA, RHG, Truncated RHG, Onestage, BVFIM, IAPTT-GM, LS, NS, and GN.

  • ll_method: str, default=None Method chosen for solving the LL problem; one of ['Dynamic' | 'Implicit' | 'BVFIM'].

  • ul_method: str, default=None Method chosen for solving the UL problem; one of ['Recurrence', 'Onestage' | 'LS', 'NS', 'GN' | 'BVFIM'].

  • ll_objective: callable, default=None The LL optimization problem, which serves as the constraint of the UL problem.

    Callable with signature callable(state), defined according to how the specific problem is modeled. It computes the loss of the LL problem. The state object contains the following:

    • "data" Data used in the LL optimization phase.
    • "target" Target used in the LL optimization phase.
    • "ul_model" UL model of the bi-level model structure.
    • "ll_model" LL model of the bi-level model structure.
  • ul_objective: callable, default=None The UL problem, i.e. the main optimization problem of the hierarchical optimization problem.

    Callable with signature callable(state), defined according to how the specific problem is modeled. It computes the loss of the UL problem. The state object contains the following:

    • "data" Data used in the UL optimization phase.
    • "target" Target used in the UL optimization phase.
    • "ul_model" UL model of the bi-level model structure.
    • "ll_model" LL model of the bi-level model structure.
  • inner_objective: callable, default=None The inner loop optimization objective.

    Callable with signature callable(state), defined according to how the specific problem is modeled. It computes the loss of the inner objective. The state object contains the following:

    • "data" Data used in inner optimization phase.
    • "target" Target used in inner optimization phase.
    • "model" Meta model to be updated.
    • "updated_weights" Model weights updated in the inner loop, used for forward propagation.
  • outer_objective: callable, default=None The outer optimization objective.

    Callable with signature callable(state), defined according to how the specific problem is modeled. It computes the loss of the outer objective. The state object contains the following:

    • "data" Data used in outer optimization phase.
    • "target" Target used in outer optimization phase.
    • "model" Meta model to be updated.
    • "updated_weights" Model weights updated in the inner loop, used for forward propagation.
  • ll_model: Module, default=None LL model in the hierarchical model structure, whose parameters are updated during LL optimization.

  • ul_model: Module, default=None UL model in the hierarchical model structure, whose parameters are updated with the UL objective.

  • meta_model: MetaModel, default=None Model whose initial value will be optimized. If MAML is chosen, any user-defined torch.nn.Module can be used as long as its forward() function follows the standard definition; for the other derived methods, the internally defined both.utils.model.meta_model should be used because it provides the related additional modules.

  • total_iters: int, default=60000 Total iterations of the experiment, used to set weight decay.
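The callable(state) convention described above can be sketched as follows, assuming that state behaves like a dict keyed by the names listed for each objective. The toy "models" are plain Python functions standing in for torch.nn.Module instances, and the squared-error loss is an arbitrary illustrative choice, not something prescribed by BOTHOptimizer. For inner_objective, the model is assumed (hypothetically) to accept the updated weights as a second argument; the actual MetaModel forward signature may differ.

```python
def squared_error(preds, target):
    """Mean squared error between two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(preds, target)) / len(target)


def ll_objective(state):
    """LL loss: ul_model output is fed to ll_model; uses the LL-phase data."""
    ul_model, ll_model = state["ul_model"], state["ll_model"]
    preds = [ll_model(ul_model(x)) for x in state["data"]]
    return squared_error(preds, state["target"])


def ul_objective(state):
    """UL loss: same model composition, evaluated on the UL-phase data."""
    ul_model, ll_model = state["ul_model"], state["ll_model"]
    preds = [ll_model(ul_model(x)) for x in state["data"]]
    return squared_error(preds, state["target"])


def inner_objective(state):
    """Inner-loop loss for initialization-based methods.

    Hypothetical assumption: model(x, weights) runs a forward pass with the
    weights produced by the inner loop (state["updated_weights"]).
    """
    model, weights = state["model"], state["updated_weights"]
    preds = [model(x, weights) for x in state["data"]]
    return squared_error(preds, state["target"])


# Toy usage with stand-in models.
ll_state = {
    "data": [1.0, 2.0],
    "target": [3.0, 6.0],
    "ul_model": lambda x: 2 * x,  # stand-in UL model
    "ll_model": lambda z: z + 1,  # stand-in LL model
}
print(ll_objective(ll_state))  # 0.5
```

In this sketch ll_objective and ul_objective share the same body and differ only in which phase's data is passed in state; in practice the two losses can differ. An outer_objective would follow the same pattern as inner_objective.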


Methods