#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Hypercausal Model Composition
-----------------------------
High-level orchestration of hypercausal nodes in single-node or chained setups.
This module defines the ``HCModel`` class, which manages the execution of one
or multiple hypercausal nodes across steps or temporal sequences. Each node
follows the hypercausal contract: given an input ``x_t`` and optional previous
state ``s_{t-1}``, it returns the current state ``s_t``, a projected future
state ``ŝ_{t+1}``, and auxiliary information.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, List, Mapping, Sequence, Tuple
import numpy as np
from .types import Array, HypercausalNode, TensorLike
@dataclass(frozen=True)
class ModelConfig:
    """
    Static configuration defining model execution semantics.

    Instances are immutable (``frozen=True``): assigning to a field after
    construction raises ``dataclasses.FrozenInstanceError``.

    Parameters
    ----------
    default_branches : int, optional
        Number of candidate future branches (K) used by default
        when no explicit value is provided, by default 2.
    """

    # Fallback K used whenever a call site does not pass ``branches``.
    # The value is not validated here; HCModel checks it when it is used.
    default_branches: int = 2
class HCModel:
    """
    Composes one or more hypercausal nodes into an executable model.

    Provides both single-step and multi-step execution methods that support:

    - Single-node execution (``forward``)
    - Sequential chaining of multiple nodes (``forward_chain``)
    - Temporal sequence processing (``predict_sequence``)

    Every node is invoked as ``node.forward(x_t, s_tm1=..., branches=K)`` and
    is expected to return a ``(s_t, ŝ_{t+1}, info)`` triple.
    """

    def __init__(self, nodes: Sequence[HypercausalNode], config: ModelConfig | None = None):
        """
        Initialize an HCModel with a sequence of hypercausal nodes.

        Parameters
        ----------
        nodes : Sequence[HypercausalNode]
            Ordered collection of hypercausal nodes to be executed. Any
            iterable is accepted; it is copied into an internal list.
        config : ModelConfig or None, optional
            Model configuration. If ``None``, uses default settings.

        Raises
        ------
        ValueError
            If no nodes are provided.
        """
        # Materialise BEFORE the emptiness check: an empty generator is
        # truthy and a multi-element ndarray has no unambiguous truth value,
        # so testing ``nodes`` directly is unreliable.
        node_list: List[HypercausalNode] = [] if nodes is None else list(nodes)
        if not node_list:
            raise ValueError("HCModel requires at least one node.")
        self._nodes: List[HypercausalNode] = node_list
        # Explicit None test so the fallback never depends on the
        # truthiness of a caller-supplied config object.
        self._cfg = ModelConfig() if config is None else config

    # ------------------------------------------------------------------
    # Single-step API
    # ------------------------------------------------------------------
    def forward(
        self,
        x_t: TensorLike,
        s_tm1: TensorLike | None = None,
        branches: int | None = None,
    ) -> Tuple[Array, Array, Mapping[str, Any]]:
        """
        Execute only the first node.

        The inputs are handed to the node unchanged; no array coercion is
        performed here.

        Parameters
        ----------
        x_t : TensorLike
            Current input vector at time t.
        s_tm1 : TensorLike or None, optional
            Previous state (t−1), by default ``None``.
        branches : int or None, optional
            Number of future branches (K). Uses ``default_branches`` if ``None``.

        Returns
        -------
        tuple
            (``s_t``, ``ŝ_{t+1}``, ``info``)

            ``s_t`` : Array
                Current state.
            ``ŝ_{t+1}`` : Array
                Projected next-state prediction.
            ``info`` : dict
                Additional node diagnostics.

        Raises
        ------
        ValueError
            If the resolved branch count is below 2.
        """
        k = self._resolve_branches(branches)
        s_t, s_tp1_hat, info = self._nodes[0].forward(x_t, s_tm1=s_tm1, branches=k)
        return s_t, s_tp1_hat, info

    # ------------------------------------------------------------------
    # Multi-node chain API
    # ------------------------------------------------------------------
    def forward_chain(
        self,
        x_t: TensorLike,
        s_tm1: TensorLike | None = None,
        branches: int | None = None,
    ) -> Tuple[Array, Array, List[Mapping[str, Any]]]:
        """
        Execute all nodes sequentially in a forward chain.

        Each node receives the output ``s_t`` of the previous node as its
        next input ``x_t``. The previous state reference (``s_tm1``) is
        passed to the first node only; subsequent nodes receive the previous
        node's ``s_t`` as their ``s_tm1`` as well.

        ``x_t`` and ``s_tm1`` are converted to flat float arrays before the
        first node is called.

        Parameters
        ----------
        x_t : TensorLike
            Current input vector at time t.
        s_tm1 : TensorLike or None, optional
            Previous state (t-1), by default ``None``.
        branches : int or None, optional
            Number of future branches (K). Uses ``default_branches`` if ``None``.

        Returns
        -------
        tuple
            (``s_t``, ``ŝ_{t+1}``, ``infos``)

            ``s_t`` : Array
                Final state after the last node.
            ``ŝ_{t+1}`` : Array
                Projected next-state prediction from the last node.
            ``infos`` : list[dict]
                Per-node diagnostic information; each entry carries a
                ``"node_index"`` key merged with that node's own info.

        Raises
        ------
        ValueError
            If the resolved branch count is below 2.
        """
        k = self._resolve_branches(branches)
        infos: List[Mapping[str, Any]] = []
        current_x = np.asarray(x_t, dtype=float).reshape(-1)
        prev_state = None if s_tm1 is None else np.asarray(s_tm1, dtype=float).reshape(-1)
        # __init__ guarantees at least one node, so ``s_tp1_hat`` is always
        # bound once the loop finishes.
        for idx, node in enumerate(self._nodes):
            s_t, s_tp1_hat, info = node.forward(current_x, s_tm1=prev_state, branches=k)
            infos.append({"node_index": idx, **info})
            # The state just produced is both the next node's input and its
            # previous-state reference.
            current_x = s_t
            prev_state = s_t
        return current_x, s_tp1_hat, infos

    # ------------------------------------------------------------------
    # Sequence-level API
    # ------------------------------------------------------------------
    def predict_sequence(
        self,
        x_seq: Sequence[TensorLike],
        s0: TensorLike | None = None,
        branches: int | None = None,
        use_chain: bool = False,
    ) -> Tuple[List[Array], List[Array], List[Any]]:
        """
        Process an entire temporal sequence of inputs.

        If ``use_chain`` is False, applies only the first node across all steps.
        If True, applies the full multi-node chain at each time step. The
        state returned at step t is fed back as ``s_tm1`` at step t+1.

        Parameters
        ----------
        x_seq : Sequence[TensorLike]
            Input sequence (T × D).
        s0 : TensorLike or None, optional
            Initial state, by default ``None``. When given, it is converted
            to a flat float array.
        branches : int or None, optional
            Number of future branches (K). Uses ``default_branches`` if ``None``.
        use_chain : bool, optional
            Whether to execute all nodes sequentially per time step, by default ``False``.

        Returns
        -------
        tuple
            (``states``, ``futures``, ``infos``)

            ``states`` : list[Array]
                Sequence of current states.
            ``futures`` : list[Array]
                Sequence of projected next states.
            ``infos`` : list[Any]
                Per-step diagnostic information: one mapping per step when
                ``use_chain`` is False, one list of per-node mappings per
                step when it is True.

        Raises
        ------
        ValueError
            If the resolved branch count is below 2.
        """
        # Resolve K up front so an invalid value fails even for an empty
        # sequence.
        k = self._resolve_branches(branches)
        # The execution mode is fixed for the whole sequence; choose it once.
        step = self.forward_chain if use_chain else self.forward

        states: List[Array] = []
        futures: List[Array] = []
        infos: List[Any] = []
        s_tm1 = None if s0 is None else np.asarray(s0, dtype=float).reshape(-1)
        for x_t in x_seq:
            s_t, s_tp1_hat, info = step(x_t, s_tm1=s_tm1, branches=k)
            states.append(s_t)
            futures.append(s_tp1_hat)
            infos.append(info)
            s_tm1 = s_t  # carry the state into the next time step
        return states, futures, infos

    # ------------------------------------------------------------------
    # Internal utilities
    # ------------------------------------------------------------------
    def _resolve_branches(self, branches: int | None) -> int:
        """
        Determine the number of branches (K) for projection.

        Parameters
        ----------
        branches : int or None
            Optional override value. ``None`` selects the configured
            ``default_branches``; any other value is passed through ``int()``.

        Returns
        -------
        int
            Final branch count (≥ 2).

        Raises
        ------
        ValueError
            If branches < 2.
        """
        k = self._cfg.default_branches if branches is None else int(branches)
        if k < 2:
            raise ValueError("branches must be >= 2")
        return k