# Source code for qmlhc.optim.registry_numpy
# -*- coding: utf-8 -*-
"""
NumPy Optimizer Registry
------------------------
Factory for NumPy-based optimizers, wired to the project's Optimizer API.
Usage:
from qmlhc.optim.registry_numpy import create_optimizer_numpy
opt = create_optimizer_numpy("spsa", lr0=0.05, eps0=0.1)
Returned objects implement:
- initialize(params) -> state
- step_params(model, params, context) -> (new_params, state)
"""
from __future__ import annotations
from typing import Any, Callable, Dict
from .numpy_optim.finite_diff import HCFiniteDiffOptimizer
from .numpy_optim.spsa import HCSPSAOptimizer
from .numpy_optim.adam import HCAdam
from .numpy_optim.natural_grad import HCNaturalGrad
from .numpy_optim.trust_region import HCTrustRegion
from .numpy_optim.dual_ascent import HCDualAscent
from .numpy_optim.mpc import HCMPCShortHorizon
from .numpy_optim.kfac import HCKFACLike
_CREATORS: Dict[str, Callable[..., Any]] = {
"finite-diff": lambda **kw: HCFiniteDiffOptimizer(**kw),
"spsa": lambda **kw: HCSPSAOptimizer(**kw),
"adam": lambda **kw: HCAdam(**kw),
"natural-grad": lambda **kw: HCNaturalGrad(**kw),
"trust-kl": lambda **kw: HCTrustRegion(**kw),
"dual-ascent": lambda **kw: HCDualAscent(**kw),
"mpc": lambda **kw: HCMPCShortHorizon(**kw),
"kfac": lambda **kw: HCKFACLike(**kw),
}
# [docs]
def create_optimizer_numpy(name: str, **kwargs) -> Any:
    """
    Create a NumPy-based optimizer by name.

    Parameters
    ----------
    name : str
        One of {"finite-diff","spsa","adam","natural-grad","trust-kl",
        "dual-ascent","mpc","kfac"}. Surrounding whitespace is stripped and
        the name is lower-cased before lookup.
    kwargs : dict
        Optimizer hyperparameters, forwarded unchanged to the registered
        factory. For wrappers ("trust-kl","dual-ascent"), pass 'base_opt'
        (the underlying optimizer instance).

    Returns
    -------
    object
        Optimizer instance exposing initialize(...) and step_params(...).

    Raises
    ------
    KeyError
        If the normalized name is not a registered optimizer.
    """
    key = name.strip().lower()
    # Guard only the registry lookup. Keeping the constructor call outside the
    # ``try`` ensures a KeyError raised while building the optimizer is not
    # misreported as an unknown optimizer name.
    try:
        factory = _CREATORS[key]
    except KeyError as e:
        raise KeyError(f"Unknown optimizer '{name}'. Available: {list(_CREATORS)}") from e
    return factory(**kwargs)