Hypercausal Chain Demo

Introduction

This example demonstrates a multi-node hypercausal chain simulation composed of three connected nodes (HCNode) arranged in a sequential temporal flow. Each node uses a parametric backend (ParametricBackend); the nodes are coupled through causal propagation and trained with gradient-based optimization.

General Flow Structure

The model represents a temporal hypercausal system in which each node contributes to a sequential information chain:

  • Parametric Backend: transforms inputs via \(S_t = \tanh(w \cdot x_t + b)\) and projects multiple possible futures.

  • Linear Projector: expands \(S_t\) into \(K\) candidate branches \(S_{t+1}^{(k)}\).

  • Loss functions: combine task accuracy, temporal consistency, and branch coherence for optimization.

  • Gradient Descent: updates the parameters \((w, b)\) of every backend using forward finite-difference gradient estimates.
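
As a rough illustration of the first two components, the sketch below reproduces the tanh transformation and a fan of \(K\) branches in plain NumPy. The names `tanh_backend` and `linear_fan` are hypothetical, and the evenly spaced offsets within ±span are only an assumption about how a projector centered on \(S_t\) might behave; the actual LinearProjector in qmlhc may work differently.

```python
import numpy as np


def tanh_backend(x, w=0.9, b=0.05):
    """Compute S_t = tanh(w * x + b) elementwise."""
    return np.tanh(w * np.asarray(x, dtype=float) + b)


def linear_fan(s_t, branches=5, span=0.25):
    """Hypothetical stand-in for a projector centered on s_t.

    Assumption: the K branches are s_t shifted by offsets spaced
    evenly in [-span, +span]. This is NOT taken from qmlhc.
    """
    offsets = np.linspace(-span, span, branches)  # shape (K,)
    return s_t[None, :] + offsets[:, None]        # shape (K, D)


x_t = np.array([0.3, 0.2, 0.1])        # one input vector with D = 3
s_t = tanh_backend(x_t)                # current state, bounded in (-1, 1)
futures = linear_fan(s_t, branches=5)  # K = 5 candidate next states
print(s_t.shape, futures.shape)
```

Because the offsets are symmetric around zero, the mean of the branches equals \(S_t\) under this assumption, which is one simple reading of "centered on the current state".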

How to Run

# From the project root
python -m examples.ex_hypercausal_chain_demo

# Or directly
python examples/ex_hypercausal_chain_demo.py

Relevant Code Snippets

Definition of the ParametricBackend class (tanh transformation and projection)

import numpy as np

# BaseBackend and BackendConfig are provided by qmlhc; their import lines
# are not part of this excerpt.
from qmlhc.predictors import LinearProjector
# Losses
from qmlhc.loss import MSELoss, ConsistencyLoss, CoherenceLoss
# Optimizer
from qmlhc.optim import make_gradient_descent


# ============================================================================
# 1) Parametric backend with projection via LinearProjector
# ============================================================================
class ParametricBackend(BaseBackend):
    """
    Deterministic backend with per-node parameters (w, b).

    The backend applies a tanh transformation and generates future projections
    centered around the current state.

    Methods
    -------
    run(params=None)
        Computes ``S_t = tanh(w * x + b)``.
    project_future(s_t, branches=2)
        Uses a LinearProjector centered on ``s_t`` to generate K future branches.
    """

    def __init__(
        self,
        config: BackendConfig,
        w: float = 0.9,
        b: float = 0.05,
        proj_span: float = 0.25,
    ):
        super().__init__(config)
        self.w = float(w)
        self.b = float(b)
        # Internal linear projector: uses S_t as projection base (not x)
        self._projector = LinearProjector(weight=1.0, bias=0.0, span=float(proj_span))

    def get_params(self) -> dict:
        """
        Return parameters as arrays for compatibility with the Optimizer API.

        Returns
        -------
        dict
            Dictionary with keys ``"w"`` and ``"b"`` as NumPy arrays.
        """
        return {
            "w": np.array([self.w], dtype=float),
            "b": np.array([self.b], dtype=float),
        }

    def set_params(self, params: dict) -> None:
        """
        Update backend parameters if provided.

        Parameters
        ----------
        params : dict
            Dictionary that may contain keys ``"w"`` and/or ``"b"``.
        """
        if "w" in params:
            self.w = float(np.asarray(params["w"]).reshape(()))
        if "b" in params:
            self.b = float(np.asarray(params["b"]).reshape(()))

    def run(self, params: dict | None = None) -> np.ndarray:
        """
        Apply the backend transformation ``S_t = tanh(w * x + b)``.

        Parameters
        ----------
        params : dict or None, optional
            Optional parameter override for this run.

        Returns
        -------
        np.ndarray
            Transformed state vector ``S_t``.
        """
        if params:
            self.set_params(params)
        x = self._require_input()
        s_t = np.tanh(self.w * x + self.b)
        s_t = self._validate_state(s_t)
        return s_t

    def project_future(self, s_t: np.ndarray, branches: int = 2) -> np.ndarray:
        """
        Generate future states around ``s_t`` using a linear projector.

        Parameters
        ----------
        s_t : np.ndarray
            Current state vector.
        branches : int, optional
            Number of future branches (K). Default is 2.

(The excerpt ends inside the project_future docstring; the rest of the method is not shown.)
Main function chain_demo_step() (chain flow, loss computation, and optimization)

# (The excerpt begins inside the finite-difference gradient helper;
#  its def line is not shown.)
    grads = {}
    base_params = {k: v.copy() for k, v in params.items()}
    apply_params_fn(base_params)
    base_loss = loss_fn()

    for k, v in base_params.items():
        perturbed = {kk: vv.copy() for kk, vv in base_params.items()}
        perturbed[k] = v + eps
        apply_params_fn(perturbed)
        l_eps = loss_fn()
        grad = (l_eps - base_loss) / eps
        grads[k] = np.array([grad], dtype=float)

    apply_params_fn(base_params)
    return grads


def dict_to_scalars(d: dict) -> dict:
    """
    Convert scalar ndarray values to safe Python floats for printing.

    Parameters
    ----------
    d : dict
        Dictionary of parameter arrays.

    Returns
    -------
    dict
        Dictionary with all values converted to floats.
    """
    out = {}
    for k, v in d.items():
        arr = np.asarray(v)
        if arr.shape == () or arr.size == 1:
            out[k] = float(arr.reshape(()).item())
        else:
            out[k] = arr.tolist()
    return out


def grad_l2_norm(grads: dict) -> float:
    """
    Compute L2 norm of all scalar gradients.

    Parameters
    ----------
    grads : dict
        Dictionary with scalar gradient arrays.

    Returns
    -------
    float
        L2 norm of gradients.
    """
    sq_sum = 0.0
    for v in grads.values():
        g = float(np.asarray(v).reshape(()).item())
        sq_sum += g * g
    return float(np.sqrt(sq_sum))


# ============================================================================
# 3) Hyper-causal chain demo + optimization step
# ============================================================================
def chain_demo_step():
    """
    Run a hyper-causal chain demo with one optimization step.

    Builds a sequential model of three nodes, executes multiple time steps,
    computes task, consistency, and coherence losses, and applies a single
    gradient-descent update using finite-difference gradients.

    Returns
    -------
    tuple
        (losses_before, losses_after)
    """
    D = 3
    K = 5
    T = 6  # Temporal sequence length

    model, nodes, backends = build_model_chain(D=D, K=K)

    # Input data (simple oscillatory pattern) and task targets
    t = np.arange(T, dtype=float)
    x_seq = np.stack(
        [
            0.3 * np.sin(0.7 * t + 0.0),
            0.2 * np.sin(0.7 * t + 0.8),
            0.1 * np.cos(0.7 * t + 0.3),
        ],
        axis=1,
    )

    target_seq = np.zeros((T, D), dtype=float)

    mse = MSELoss()
    cons = ConsistencyLoss(alpha=0.8, beta=1.2)
    coh = CoherenceLoss(mode="variance")

(The excerpt ends here; the remainder of chain_demo_step() is not shown.)
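
The three loss objects instantiated at the end of the excerpt can be approximated in plain NumPy to make the individual terms concrete. The following is a minimal sketch, not the qmlhc implementation: the function names are hypothetical, the task loss sums squared errors per time step and averages over time, the consistency loss weights the past and predicted-future distances by alpha and beta, and the coherence loss is assumed to be the variance across branches averaged over state dimensions. The 0.5 and 0.3 weights are the ones used in the total loss of this demo.

```python
import numpy as np


def task_loss(states, targets):
    """(1/T) * sum_t ||S_t - Y_t||^2 for arrays of shape (T, D)."""
    return float(np.sum((states - targets) ** 2, axis=1).mean())


def consistency_loss(s_prev, s_t, s_next_hat, alpha=0.8, beta=1.2):
    """alpha * ||S_t - S_{t-1}||^2 + beta * ||S_t - S_hat_{t+1}||^2."""
    past = np.sum((s_t - s_prev) ** 2)
    future = np.sum((s_t - s_next_hat) ** 2)
    return float(alpha * past + beta * future)


def coherence_loss(branches):
    """Variance across the K branches (axis 0), averaged over dimensions.

    Assumption: this is one plausible reading of mode="variance".
    """
    return float(np.var(branches, axis=0).mean())


def total_loss(task, cons, coh):
    """Weighted sum used in this demo: task + 0.5 * cons + 0.3 * coh."""
    return task + 0.5 * cons + 0.3 * coh


# Tiny hand-checkable example with D = 2.
s_prev = np.array([0.0, 0.0])
s_t = np.array([0.1, 0.2])
s_next_hat = np.array([0.1, 0.2])               # prediction equals S_t
branches = np.array([[0.0, 0.0], [0.2, 0.2]])   # K = 2 branches

task = task_loss(s_t[None, :], np.zeros((1, 2)))
cons = consistency_loss(s_prev, s_t, s_next_hat)
coh = coherence_loss(branches)
total = total_loss(task, cons, coh)
print(task, cons, coh, total)
```

In this toy case the future term of the consistency loss vanishes because the predicted state equals the current one, so only the alpha-weighted past term contributes.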

Functional Explanation

The hypercausal chain operates as a multi-node causal model, where each node processes, projects, and corrects its state based on local losses and temporal dependencies.

  1. Parametric Transformation

    Each node computes its local state:

    \[S_t = \tanh(w \cdot x_t + b)\]

    Here, \(w\) and \(b\) are node-specific parameters learned through gradient descent. The \(\tanh\) activation bounds every component of the state within \((-1, 1)\), which keeps states from growing without limit as they propagate along the chain.

  2. Future Projection (Linear Projector)

    Each state generates \(K\) possible futures using a linear projector centered at the current state:

    \[S_{t+1}^{(k)} = \text{LinearProjector}(S_t), \quad k \in \{1, \dots, K\}\]

    This projection expands the local state into a hypercausal “fan” of possibilities, representing multiple potential outcomes for the next time step.

  3. Loss Composition

    The total loss combines three complementary objectives:

    \[\mathcal{L}_{total} = \mathcal{L}_{task} + 0.5 \, \mathcal{L}_{consistency} + 0.3 \, \mathcal{L}_{coherence}\]

    • Task Loss (MSE):

      \[\mathcal{L}_{task} = \frac{1}{T} \sum_{t=1}^{T} \| S_t - Y_t \|^2\]

      Measures how close the node’s output is to the desired target trajectory.

    • Consistency Loss (Triadic):

      \[\mathcal{L}_{consistency} = \alpha \| S_t - S_{t-1} \|^2 + \beta \| S_t - \hat{S}_{t+1} \|^2\]

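      Encourages smooth temporal evolution across the past state \(S_{t-1}\), the present state \(S_t\), and the predicted future state \(\hat{S}_{t+1}\). The demo instantiates this loss with \(\alpha = 0.8\) and \(\beta = 1.2\).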

    • Coherence Loss:

      \[\mathcal{L}_{coherence} = \text{Var}_{k}\big(S_{t+1}^{(k)}\big)\]

      Penalizes excessive divergence among projected branches, maintaining causal stability.

  4. Gradient Estimation and Parameter Update

    Instead of backpropagation, the example uses a forward finite-difference gradient estimator, which perturbs one parameter at a time by \(\epsilon\) and re-evaluates the loss:

    \[g_i = \frac{\mathcal{L}(\theta_i + \epsilon) - \mathcal{L}(\theta_i)}{\epsilon}\]

    Each parameter update follows a simple gradient-descent rule:

    \[\theta_i \leftarrow \theta_i - \eta \, g_i\]

    where \(\eta\) is the learning rate.

  5. Optimization Loop

    • The model runs for multiple time steps (\(T = 6\)), accumulating losses.

    • The optimizer (make_gradient_descent) applies one parameter update across all nodes.

    • Losses and parameters are printed before and after the update so the effect of the step can be compared directly.
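
The gradient estimation and update rule from steps 4 and 5 can be sketched on a toy problem. The code below is a simplified stand-in rather than the demo's own helper: `finite_diff_grads` and `gd_step` are hypothetical names, the loss is a plain mean of squared tanh states for a single backend, and the learning rate of 0.05 and step size `eps` are illustrative values only.

```python
import numpy as np


def finite_diff_grads(params, loss_fn, eps=1e-6):
    """Forward difference: g_i = (L(theta_i + eps) - L(theta_i)) / eps."""
    base = loss_fn(params)
    grads = {}
    for name, value in params.items():
        perturbed = dict(params)       # copy so only one entry changes
        perturbed[name] = value + eps
        grads[name] = (loss_fn(perturbed) - base) / eps
    return grads


def gd_step(params, grads, lr=0.05):
    """Plain gradient descent: theta_i <- theta_i - lr * g_i."""
    return {name: params[name] - lr * grads[name] for name in params}


x = np.array([0.3, 0.2, 0.1])  # fixed toy input


def loss_fn(p):
    """Toy loss: mean squared state of a single tanh backend."""
    s = np.tanh(p["w"] * x + p["b"])
    return float(np.mean(s ** 2))


params = {"w": 0.9, "b": 0.05}
grads = finite_diff_grads(params, loss_fn)
new_params = gd_step(params, grads)
print(loss_fn(params), "->", loss_fn(new_params))
```

Each gradient entry costs one extra loss evaluation, so the estimator scales linearly with the number of parameters. That is acceptable for the six scalars in this demo, though it would not be for a large model.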

Exact Output

=== Hypercausal Chain Demo ===
D=3, K=5, T=6

Parameters (before):
{'b0_w': 0.9, 'b0_b': 0.05, 'b1_w': 0.95, 'b1_b': 0.02, 'b2_w': 1.05, 'b2_b': 0.0}

Losses BEFORE update:
{'task': 0.02533261887044361, 'cons': 0.006637497192550465, 'coh': 0.029526745779767466, 'total': 0.037509391200649084}

Updating parameters with GD (finite-diff grads)...
||grad||_2 ≈ 3.556270e-01

Parameters (after):
{'b0_w': 0.8979223132592518, 'b0_b': 0.040790416564867205, 'b1_w': 0.9474490593467247, 'b1_b': 0.009730820141449475, 'b2_w': 1.0473899389089636, 'b2_b': -0.010405162840116874}

Losses AFTER update:
{'task': 0.019773491535848887, 'cons': 0.006664418126248786, 'coh': 0.02973869208514816, 'total': 0.03202730822451773}

Summary:
total BEFORE:  0.037509
total AFTER :  0.032027
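
The reported numbers can be cross-checked against the formulas in Functional Explanation. The short script below recomputes the weighted total from the printed loss components and compares the size of the parameter step with the printed gradient norm. The ratio is only an implied learning rate under the assumption that the update is plain gradient descent with a single shared rate; the value actually passed to make_gradient_descent is not shown on this page.

```python
import math

before = {"task": 0.02533261887044361, "cons": 0.006637497192550465,
          "coh": 0.029526745779767466, "total": 0.037509391200649084}
after = {"task": 0.019773491535848887, "cons": 0.006664418126248786,
         "coh": 0.02973869208514816, "total": 0.03202730822451773}

params_before = {"b0_w": 0.9, "b0_b": 0.05, "b1_w": 0.95,
                 "b1_b": 0.02, "b2_w": 1.05, "b2_b": 0.0}
params_after = {"b0_w": 0.8979223132592518, "b0_b": 0.040790416564867205,
                "b1_w": 0.9474490593467247, "b1_b": 0.009730820141449475,
                "b2_w": 1.0473899389089636, "b2_b": -0.010405162840116874}
grad_norm = 3.556270e-01  # printed ||grad||_2


def weighted_total(losses):
    """task + 0.5 * cons + 0.3 * coh, as in the total-loss formula."""
    return losses["task"] + 0.5 * losses["cons"] + 0.3 * losses["coh"]


# L2 norm of the parameter change produced by the single update.
step_norm = math.sqrt(
    sum((params_after[k] - params_before[k]) ** 2 for k in params_before)
)
implied_lr = step_norm / grad_norm  # assumes theta <- theta - lr * g

print("total before:", weighted_total(before))
print("total after :", weighted_total(after))
print("implied learning rate:", implied_lr)
```

Both recomputed totals agree with the printed ones, and the step-to-gradient ratio comes out close to 0.05. The output also shows that the drop in the total is driven by the task loss, while the consistency and coherence terms rise slightly after the update.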