
class Onestage


Description

Computes the gradient of the UL model variables with the DARTS method.

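Implements the UL optimization procedure of DARTS [1], an approximation method that needs only first-order gradients: the second-order term of the hypergradient is estimated with a finite-difference approximation, so neither second-order derivatives nor matrix-vector products are computed explicitly.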

A wrapper of the LL model, already optimized during LL optimization, is used in this procedure.
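The DARTS procedure described above takes a virtual LL step $w' = w - \xi\nabla_w\mathcal L_{train}(w,\alpha)$ and estimates the UL gradient as $\nabla_\alpha\mathcal L_{val}(w',\alpha) - \xi\,\frac{\nabla_\alpha\mathcal L_{train}(w^+,\alpha) - \nabla_\alpha\mathcal L_{train}(w^-,\alpha)}{2\epsilon}$, where the central difference stands in for the mixed second-order term. The following is a minimal sketch of that estimate on a hypothetical quadratic toy problem, not the library's implementation: the toy losses, the target `T`, the function names, and the assumption that the step size $\xi$ corresponds to `lower_learning_rate` are all illustrative. Vectors are plain lists so every gradient can be written by hand.

```python
import math

# Hypothetical toy bilevel problem (illustration only, not library code):
#   L_train(w, alpha) = 0.5 * ||w - alpha||^2   (LL objective)
#   L_val(w, alpha)   = 0.5 * ||w - T||^2       (UL objective, fixed target T)
T = [3.0, -1.0]


def grad_w_train(w, alpha):
    return [wi - ai for wi, ai in zip(w, alpha)]


def grad_alpha_train(w, alpha):
    return [ai - wi for wi, ai in zip(w, alpha)]


def grad_w_val(w, alpha):
    return [wi - ti for wi, ti in zip(w, T)]


def grad_alpha_val(w, alpha):
    return [0.0 for _ in alpha]


def darts_hypergrad(w, alpha, xi, r=1e-2):
    """Finite-difference DARTS estimate of the UL gradient w.r.t. alpha.

    xi is assumed to play the role of lower_learning_rate; r sets epsilon.
    """
    # Virtual LL step: w' = w - xi * grad_w L_train(w, alpha)
    w_prime = [wi - xi * gi for wi, gi in zip(w, grad_w_train(w, alpha))]
    # v = grad_{w'} L_val(w', alpha) and its Euclidean norm
    v = grad_w_val(w_prime, alpha)
    norm_v = math.sqrt(sum(vi * vi for vi in v))
    direct = grad_alpha_val(w_prime, alpha)
    if norm_v == 0.0:
        return direct  # v = 0, so the correction term vanishes
    # epsilon = r / ||v||_2 and perturbed weights w+- = w +- epsilon * v
    eps = r / norm_v
    w_plus = [wi + eps * vi for wi, vi in zip(w, v)]
    w_minus = [wi - eps * vi for wi, vi in zip(w, v)]
    # Central difference approximating the mixed second-order term
    g_plus = grad_alpha_train(w_plus, alpha)
    g_minus = grad_alpha_train(w_minus, alpha)
    correction = [(gp - gm) / (2.0 * eps) for gp, gm in zip(g_plus, g_minus)]
    return [d - xi * c for d, c in zip(direct, correction)]


approx = darts_hypergrad([1.0, 0.0], [2.0, 1.0], xi=0.1)
```

Because both toy losses are quadratic, the central difference is exact here, so the estimate matches the analytic hypergradient $\xi(w' - T)$ up to floating-point rounding. For non-quadratic losses the finite difference is only an approximation whose accuracy depends on $r$.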


Parameters

  • ul_objective: callable
    The upper-level (UL) objective, i.e. the main problem of the hierarchical optimization.

    Callable with signature callable(state), defined according to the specific problem to be solved; it computes the loss of the UL problem. The state object contains the following:

    • "data"(Tensor) - Data used in the ul optimization phase.
    • "target"(Tensor) - Target used in the ul optimization phase.
    • "ul_model"(Module) - UL model of the bi-level model structure.
    • "ll_model"(Module) - LL model of the bi-level model structure.
  • ul_model: Module
    UL model in the hierarchical model structure, whose parameters are updated using the UL objective and the trained LL model.

  • ll_objective: callable
    The lower-level (LL) objective, which acts as the constraint of the UL problem.

    Callable with signature callable(state), defined according to the specific problem to be solved; it computes the loss of the LL problem. The state object contains the following:

    • "data"(Tensor) - Data used in the LL optimization phase.
    • "target"(Tensor) - Target used in the LL optimization phase.
    • "ul_model"(Module) - UL model of the bi-level model structure.
    • "ll_model"(Module) - LL model of the bi-level model structure.
  • ll_model: Module
    LL model in the hierarchical model structure, whose parameters are updated with the LL objective during LL optimization.

  • lower_learning_rate: float
    Step size for LL optimization.

  • r (optional): float, default=1e-2
    Parameter that sets the scalar $\epsilon$ as $\epsilon = r/\|\nabla_{w'}\mathcal L_{val}(w',\alpha)\|_2$, where $\epsilon$ is used to form the perturbed weights $w^\pm = w \pm \epsilon\nabla_{w'}\mathcal L_{val}(w',\alpha)$. The paper [1] reports that $r = 0.01$ is sufficiently accurate.
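The parameter list above describes ul_objective and ll_objective only by their signature callable(state). The following is a minimal sketch of a matching pair of callables, assuming state behaves like a dict keyed by "data", "target", "ul_model" and "ll_model". `ScalarModel`, `mse`, and the choice of losses (an LL training loss with an L2 penalty whose strength is set by the UL model, and a UL validation loss) are hypothetical; in the library the entries would be Module and Tensor objects rather than plain Python values.

```python
import math


class ScalarModel:
    """Hypothetical stand-in for a one-parameter Module: y = weight * x."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x):
        return self.weight * x


def mse(model, data, target):
    # Mean squared error of model over paired (data, target) values
    errors = [(model(x) - t) ** 2 for x, t in zip(data, target)]
    return sum(errors) / len(errors)


def ll_objective(state):
    """LL loss: training MSE of the LL model plus an L2 penalty whose
    strength exp(ul_model.weight) is controlled by the UL model."""
    ll_model, ul_model = state["ll_model"], state["ul_model"]
    penalty = math.exp(ul_model.weight) * ll_model.weight ** 2
    return mse(ll_model, state["data"], state["target"]) + penalty


def ul_objective(state):
    """UL loss: MSE of the (trained) LL model on held-out UL data."""
    return mse(state["ll_model"], state["data"], state["target"])


ll_model, ul_model = ScalarModel(1.0), ScalarModel(0.0)
ll_state = {"data": [1.0, 2.0], "target": [2.0, 4.0],
            "ul_model": ul_model, "ll_model": ll_model}
ul_state = {"data": [3.0], "target": [6.0],
            "ul_model": ul_model, "ll_model": ll_model}
ll_loss = ll_objective(ll_state)
ul_loss = ul_objective(ul_state)
```

Both callables read the same keys, but each receives the data of its own phase, matching the two state descriptions above.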


References

[1] H. Liu, K. Simonyan, Y. Yang, "DARTS: Differentiable Architecture Search", in ICLR, 2019.