# Source code for qmlhc.loss.coherence

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Inter-branch Coherence Loss
---------------------------
Penalizes dispersion among K projected futures (candidate branches) to encourage coherent, stable projections.

Two modes are supported:

- ``"variance"`` (default): mean squared deviation from the per-dimension mean.
- ``"mad"``: mean absolute deviation from the per-dimension mean.


"""

from __future__ import annotations

import numpy as np

from ..core.types import Array, LossFn


class CoherenceLoss(LossFn):
    """
    Penalize dispersion across candidate future branches.

    Parameters
    ----------
    mode : str, optional
        Dispersion metric to use (case-insensitive). Options:

        - ``"variance"`` (scale-sensitive, smooth)
        - ``"mad"`` (mean absolute deviation)

        Default is ``"variance"``.

    Notes
    -----
    ``mode`` is validated when the loss is called, not at construction,
    so building an instance with an unsupported mode does not raise.
    """

    def __init__(self, mode: str = "variance"):
        # Mode selects the dispersion metric; stored in lowercase so the
        # comparison in ``__call__`` is case-insensitive.
        self._mode = str(mode).lower()

    def __repr__(self) -> str:
        return f"{type(self).__name__}(mode={self._mode!r})"

    def __call__(self, futures: Array) -> float:
        """
        Compute the coherence penalty for a set of candidate futures.

        Parameters
        ----------
        futures : Array
            Matrix of candidate futures with shape ``(K, D)``.

        Returns
        -------
        float
            Dispersion penalty (lower is better). Returns ``0.0`` when
            ``futures`` holds no elements (``K == 0`` or ``D == 0``),
            since there is no dispersion to penalize.

        Raises
        ------
        ValueError
            If ``futures`` is not a 2-D array, or if ``mode`` is unsupported.
        """
        fut = np.asarray(futures, dtype=float)
        if fut.ndim != 2:
            raise ValueError("futures must have shape (K, D)")
        # Validate the mode before any early return so an invalid mode is
        # reported even for empty input.
        if self._mode not in ("variance", "mad"):
            raise ValueError(f"unsupported coherence mode: {self._mode}")

        # On an empty (0, D) or (K, 0) array, ``np.mean`` emits a
        # "Mean of empty slice" RuntimeWarning and yields NaN, which would
        # silently poison any downstream objective.
        if fut.size == 0:
            return 0.0

        # Per-dimension mean, kept 2-D for broadcasting (K, D) - (1, D).
        mu = fut.mean(axis=0, keepdims=True)
        dev = fut - mu

        if self._mode == "variance":
            # Mean of squared deviations (scale-sensitive and differentiable).
            return float(np.mean(dev ** 2))

        # "mad": mean absolute deviation from the mean (more robust to outliers).
        return float(np.mean(np.abs(dev)))