Source code for qmlhc.optim.numpy_optim.trust_region

# -*- coding: utf-8 -*-
"""
Trust-Region (KL over States)
-----------------------------
Wrapper that enforces a trust-region constraint measured as a (symmetric) KL
proxy over state branches. If the KL bound is exceeded, performs backtracking
line-search along the update direction until satisfied.

Interface:
    - initialize(params) -> state
    - step_params(model, params, context) -> (new_params, state)

Requires:
    - base optimizer exposing step_params(...)
    - context["kl_fn"](old_info, new_info) -> float (KL or proxy)
    - context["info"]: current state's info dict (with 'branches' if possible)
    - context["refresh_info"](model, params, context) -> info (callable)
"""

from __future__ import annotations
from typing import Any, Callable, Dict, Mapping, Tuple
import numpy as np


[docs] class HCTrustRegion: """Trust-region wrapper with KL constraint over state-space.""" def __init__( self, base_opt: Any, delta_kl: float = 2e-2, backtrack: float = 0.7, max_backtracks: int = 8, ): self.base = base_opt self.delta_kl = float(delta_kl) self.bt = float(backtrack) self.bt_max = int(max_backtracks) self._state: Dict[str, Any] = {}
[docs] def initialize(self, params: Mapping[str, Any]) -> Dict[str, Any]: if hasattr(self.base, "initialize"): base_state = self.base.initialize(params) else: base_state = {} self._state = {"base_state": base_state, "steps": 0} return dict(self._state)
[docs] def step_params( self, model: Any, params: Mapping[str, Any], context: Mapping[str, Any] ) -> Tuple[Dict[str, Any], Dict[str, Any]]: old_info = context.get("info", {}) kl_fn: Callable[[Mapping[str, Any], Mapping[str, Any]], float] = context["kl_fn"] refresh_info: Callable[[Any, Mapping[str, Any], Mapping[str, Any]], Mapping[str, Any]] = context["refresh_info"] # propose update proposal, base_state = self.base.step_params(model, params, context) # line search with KL guard alpha = 1.0 new_params = proposal for _ in range(self.bt_max + 1): new_info = refresh_info(model, new_params, context) kl_val = float(kl_fn(old_info, new_info)) if kl_val <= self.delta_kl: self._state = {"base_state": base_state, "steps": self._state.get("steps", 0) + 1, "kl": kl_val, "alpha_bt": alpha} return new_params, dict(self._state) # backtrack alpha *= self.bt blended = {} for k in params.keys(): blended[k] = (1 - alpha) * np.asarray(params[k], dtype=float) + alpha * np.asarray(proposal[k], dtype=float) new_params = blended # if failed to satisfy KL, return original params (conservative) self._state = {"base_state": base_state, "steps": self._state.get("steps", 0) + 1, "kl": float("inf"), "alpha_bt": 0.0} return dict(params), dict(self._state)