# Source code for qmlhc.predictors.anticipator

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Contrafactual Anticipators
--------------------------
Wrappers that synthesize structured counterfactual futures on top of a base
projector, optionally augmented with user-supplied perturbations (and their
mirrored counterparts) around the projected center.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Sequence

import numpy as np

from ..core.types import Array, TensorLike
from .projector import Projector

# Type of a perturbation callable: it receives a ``(D,)`` vector and must
# return a ``(D,)`` vector.
Perturb = Callable[[Array], Array]


@dataclass(frozen=True)
class AnticipatorConfig:
    """
    Immutable settings that govern how counterfactual futures are generated.

    Parameters
    ----------
    branches : int, optional
        How many base branches ``K`` the projector produces. Defaults to
        ``3``. Downstream usage expects ``K >= 2``.
    symmetric : bool, optional
        If ``True`` (the default), every perturbed variant ``v`` built around
        the center is accompanied by its reflection ``(2 * center - v)``.
    """

    # Number of base branches K requested from the projector.
    branches: int = 3
    # Whether each perturbation also contributes its mirrored counterpart.
    symmetric: bool = True
class ContrafactualAnticipator:
    """
    Generate structured counterfactual futures on top of a base projector.

    Given a current state ``s_t``, first obtains a base set of futures from
    ``Projector.project``. Then, for each user-provided perturbation, it adds
    a single variant (and optionally its symmetric mirror) around the base
    center.

    Parameters
    ----------
    projector : Projector
        Object whose ``project(s_t, branches=K)`` method returns the base
        future set.
    perturbations : Sequence[Perturb] or None, optional
        Callables applied to the base center. ``None`` (the default) means
        no perturbations are applied.
    config : AnticipatorConfig or None, optional
        Generation settings. A default ``AnticipatorConfig()`` is used when
        ``None``.

    Notes
    -----
    The final future set is the concatenation of:
      - the projector's base set ``(K, D)``,
      - one row per perturbation,
      - (optionally) one mirrored row per perturbation.
    """

    def __init__(
        self,
        projector: Projector,
        perturbations: Sequence[Perturb] | None = None,
        config: AnticipatorConfig | None = None,
    ) -> None:
        self._proj = projector
        # Explicit ``None`` checks: ``perturbations or []`` would evaluate the
        # truthiness of the sequence itself, which raises for multi-element
        # NumPy object arrays of callables.
        self._perts = [] if perturbations is None else list(perturbations)
        self._cfg = AnticipatorConfig() if config is None else config

    def generate(self, s_t: TensorLike) -> Array:
        """
        Produce a combined future set ``(K', D)`` from base projection and variants.

        Steps
        -----
        1. Call the base projector to obtain ``base_set`` with shape ``(K, D)``.
        2. If perturbations are provided, compute the center as
           ``mean(base_set, axis=0)``.
        3. For each perturbation ``p``, append ``p(center)`` as a new row.
           Every call receives its own copy of the center, so a perturbation
           that modifies its argument in place cannot affect the center,
           the mirrored rows, or the other perturbations.
        4. If ``symmetric`` is enabled, also append the mirrored row
           ``2 * center - p(center)``.

        Parameters
        ----------
        s_t : TensorLike
            Current state vector ``(D,)``.

        Returns
        -------
        Array
            Freshly allocated futures matrix with shape ``(K', D)``, where
            ``K' = K + P`` (or ``K + 2 * P`` when ``symmetric``) for ``P``
            perturbations.

        Raises
        ------
        ValueError
            If a perturbation returns an array whose shape differs from the
            shape of the center.
        """
        base_set = np.asarray(
            self._proj.project(s_t, branches=self._cfg.branches)
        )
        variants = [base_set]
        if self._perts:
            center = base_set.mean(axis=0)
            for idx, p in enumerate(self._perts):
                # Hand each perturbation a private copy: otherwise an in-place
                # perturbation would mutate ``center`` (and alias the rows
                # already appended) before concatenation.
                v = np.asarray(p(center.copy()))
                if v.shape != center.shape:
                    raise ValueError(
                        f"perturbation {idx} returned shape {v.shape}, "
                        f"expected {center.shape}"
                    )
                variants.append(np.expand_dims(v, axis=0))
                if self._cfg.symmetric:
                    # Reflection of ``v`` through the center.
                    variants.append(np.expand_dims(2 * center - v, axis=0))
        # Concatenate even when only the base set is present so callers
        # always receive a new array rather than the projector's buffer.
        return np.concatenate(variants, axis=0)