Source code for qmlhc.optim.numpy_optim.kfac

# -*- coding: utf-8 -*-
"""
K-FAC-like Preconditioner (Branch-Factored)
-------------------------------------------
Kronecker-Factored Approximate Curvature inspired preconditioning using
branch-wise covariance factors as a low-cost second-order proxy.

This is a simplified K-FAC-like optimizer tailored to hypercausal state
branches: assumes block-diagonal structure across branch groups.

Interface:
    - initialize(params) -> state
    - step_params(model, params, context) -> (new_params, state)

Context:
    - context["info"]["branches"]: (K x D) matrix of state samples
    - context['grads'] or a grad_estimator for parameter gradients
"""

from __future__ import annotations
from typing import Any, Callable, Dict, Mapping, Tuple
import numpy as np
from .utils import flatten_params, deflatten_params, cov_empirical, cg_solve


[docs] class HCKFACLike: """Branch-factored curvature preconditioning, simplified K-FAC variant.""" def __init__( self, lr: float = 5e-3, damp: float = 1e-3, blocks: int = 4, grad_estimator: Callable[[Any, Mapping[str, Any], Mapping[str, Any]], np.ndarray] | None = None, clip: float | None = None, seed: int = 2027, ): self.lr = float(lr) self.damp = float(damp) self.blocks = int(blocks) self.grad_estimator = grad_estimator self.clip = clip self._rng = np.random.default_rng(seed) self._state: Dict[str, Any] = {}
[docs] def initialize(self, params: Mapping[str, Any]) -> Dict[str, Any]: self._state = {"steps": 0} return dict(self._state)
[docs] def step_params( self, model: Any, params: Mapping[str, Any], context: Mapping[str, Any] ) -> Tuple[Dict[str, Any], Dict[str, Any]]: theta, layout = flatten_params(params) # gradient if self.grad_estimator is not None: g = np.asarray(self.grad_estimator(model, params, context), dtype=float).reshape(theta.shape) else: grads_dict = context.get("grads", {}) if not grads_dict: raise ValueError("HCKFACLike expects either grad_estimator or context['grads']") g_chunks = [] for k, n in layout: v = np.atleast_1d(np.asarray(grads_dict[k], dtype=float)).reshape(-1) if v.size != n: raise ValueError(f"Gradient size mismatch for key '{k}'") g_chunks.append(v) g = np.concatenate(g_chunks) # state branches covariance and factorization into blocks info = context.get("info", {}) B = info.get("branches", None) if B is None or np.asarray(B).ndim != 2: precond = g else: B = np.asarray(B, dtype=float) D = B.shape[1] C = cov_empirical(B) + self.damp * np.eye(D) # block partition of state-dim blocks = max(1, min(self.blocks, D)) sizes = np.full(blocks, D // blocks, dtype=int) sizes[: D % blocks] += 1 cuts = np.cumsum(sizes) # random linear map param->state (thin) P = self._rng.normal(size=(theta.size, D)) / np.sqrt(D) g_state = P.T @ g v_star = np.zeros_like(g_state) start = 0 for c in cuts: sl = slice(start, c) Cb = C[sl, sl] gb = g_state[sl] # solve (Cb) v = gb v_star[sl] = np.linalg.solve(Cb, gb) start = c precond = P @ v_star theta_new = theta - self.lr * precond if self.clip is not None: theta_new = np.clip(theta_new, -self.clip, self.clip) new_params = deflatten_params(theta_new, layout, params) self._state = {"steps": self._state.get("steps", 0) + 1, "precond_norm": float(np.linalg.norm(precond))} return new_params, dict(self._state)