Training with Callbacks Demo

Introduction

This example presents a stable training framework for a three-node hyper-causal model. It demonstrates a complete adaptive optimization cycle using depth-aware recursion, finite-difference gradients, and callback-driven monitoring. The goal is to show how a causal system can be trained without classical backpropagation while its recursion depth grows from epoch to epoch.

Key mechanisms include:

  1. DepthScheduler – controls recursion depth dynamically per epoch.

  2. Adaptive learning parameters – rescale learning rate and perturbation with recursion depth.

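  3. Finite-difference gradient estimation – central differences stand in for analytic (backpropagated) gradients.

  4. Gradient clipping – bounds the gradient norm, acting as a simple trust region that prevents divergence.

  5. Callback telemetry – records the losses, recursion depth, effective learning rate, and perturbation size at every epoch.

  6. Early stopping – halts training once the total loss has failed to improve for more than a set number of consecutive epochs.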

General Flow Structure

The training loop drives three recursive backends chained through causal dependencies. At every time step each backend applies its transformation at the depth set by the scheduler, projects future states, and contributes to a composite loss. Callbacks handle both the depth schedule and the logging.

  • DepthAwareBackend: applies the tanh update \(s \leftarrow \tanh(w\,s + b)\) \(d\) times to its input, where \(d\) is the current recursion depth and \(w\), \(b\) are scalar parameters.

  • Projection: expands each state into \(K\) future branches via a linear projector.

  • Loss aggregation: combines task, consistency, and coherence components.

  • Optimizer: updates parameters using finite-difference gradients and gradient clipping.

  • Scheduler: increases recursion depth gradually to control system complexity.
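The callback arrangement described above follows a simple fan-out pattern: the loop calls one hook, and the callback list forwards it to every registered callback that implements it. The classes below (HookList, InMemoryLog) are hypothetical stand-ins written for illustration, not the project's CallbackList or MemoryLogger, whose internals are not shown here.

```python
from typing import Any


class InMemoryLog:
    """Hypothetical callback that stores every event it receives."""

    def __init__(self) -> None:
        self.events: list[tuple[str, int, dict[str, Any]]] = []

    def on_epoch_begin(self, epoch: int, logs: dict[str, Any]) -> None:
        self.events.append(("epoch_begin", epoch, dict(logs)))

    def on_epoch_end(self, epoch: int, logs: dict[str, Any]) -> None:
        self.events.append(("epoch_end", epoch, dict(logs)))


class HookList:
    """Hypothetical dispatcher: forwards each hook call to all callbacks that define it."""

    def __init__(self, callbacks: list[Any]) -> None:
        self.callbacks = list(callbacks)

    def _dispatch(self, hook: str, index: int, logs: dict[str, Any]) -> None:
        for cb in self.callbacks:
            method = getattr(cb, hook, None)
            if callable(method):  # a callback may implement only some hooks
                method(index, logs)

    def on_epoch_begin(self, epoch: int, logs: dict[str, Any]) -> None:
        self._dispatch("on_epoch_begin", epoch, logs)

    def on_epoch_end(self, epoch: int, logs: dict[str, Any]) -> None:
        self._dispatch("on_epoch_end", epoch, logs)


log = InMemoryLog()
hooks = HookList([log])
for epoch in range(2):
    hooks.on_epoch_begin(epoch, {"epoch": epoch, "depth": 1.0})
    hooks.on_epoch_end(epoch, {"epoch": epoch, "total": 0.03})
print([name for name, _, _ in log.events])
```

Because dispatch looks hooks up by name, a scheduler-style callback and a logging callback can sit in the same list and each react only to the hooks it defines.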

How to Run

# From the project root
python -m examples.ex_training_with_callbacks_demo

# Or directly
python examples/ex_training_with_callbacks_demo.py

Relevant Code Snippets

Definition of the DepthAwareBackend class (recursive tanh transformation and projection)
        # (excerpt starts partway through the parameter setter; earlier lines omitted)
            self.w = float(np.asarray(params["w"]).reshape(()))
        if "b" in params:
            self.b = float(np.asarray(params["b"]).reshape(()))

    def run(self, params: dict | None = None) -> np.ndarray:
        """
        Apply depth-recursive tanh transformation.

        Parameters
        ----------
        params : dict or None, optional
            Optional parameter override for this call.

        Returns
        -------
        np.ndarray
            Validated current state vector.
        """
        if params:
            self.set_params(params)
        x = self._require_input().astype(float)
        s = x
        for _ in range(max(1, int(self.depth))):
            s = np.tanh(self.w * s + self.b)
        return self._validate_state(s)

    def project_future(self, s_t: np.ndarray, branches: int = 2) -> np.ndarray:
        """
        Project future states around ``s_t`` using a depth-adjusted span.

        Parameters
        ----------
        s_t : np.ndarray
            Current state vector.
        branches : int, optional
            Number of future branches (K), by default 2.

        Returns
        -------
        np.ndarray
            Future branch matrix with shape ``(K, D)``.
        """
        s = self._validate_state(s_t)
        k = max(2, int(branches))
        # span reduced with depth, but with a high floor (0.10)
        span = max(self._span_floor, self._base_span / (1.0 + 0.3 * (self.depth - 1)))
        self._projector = LinearProjector(weight=1.0, bias=0.0, span=span)
        fut = self._projector.project(s, branches=k)
        return self._validate_branches(fut)


# ------------------------ Construction utils ------------------------
def build_model_chain(D=3):
    """
    Build a three-node HCModel with depth-aware backends.

    Parameters
    ----------
    D : int, optional
        State dimensionality, by default 3.

    Returns
    -------
    tuple
        (model, nodes, backends)
    """
    cfg = BackendConfig(output_dim=D, seed=11)
    b0 = DepthAwareBackend(cfg, w=0.90, b=0.03, proj_span=0.22)
    b1 = DepthAwareBackend(cfg, w=0.97, b=0.02, proj_span=0.25)
    b2 = DepthAwareBackend(cfg, w=1.05, b=0.00, proj_span=0.30)
    pol = MeanPolicy()
    n0, n1, n2 = HCNode(b0, pol), HCNode(b1, pol), HCNode(b2, pol)
    model = HCModel([n0, n1, n2])
    return model, [n0, n1, n2], [b0, b1, b2]


def params_pack(backends):
    """
    Flatten parameters from all backends into a single dictionary.

    Parameters
    ----------
    backends : list
        List of backend instances.

    Returns
    -------
    dict
        Flattened parameter dictionary suitable for the optimizer.
    """
    packed = {}
    for i, be in enumerate(backends):
        for k, v in be.get_params().items():
            packed[f"b{i}_{k}"] = np.array(v, dtype=float)
    return packed


def params_unpack(backends, packed):
    """
    Distribute flat parameters back to their corresponding backends.

    Parameters
    ----------
    backends : list
        Backend instances to update.
    packed : dict
        Flattened parameter dictionary.
    """
    for i, be in enumerate(backends):
        sub = {}
        for k in ("w", "b"):
            key = f"b{i}_{k}"
            if key in packed:
                sub[k] = packed[key]
        be.set_params(sub)


# ------------------- CENTRAL finite-difference grads -------------------
def central_diff_grads(loss_fn, params, apply_params_fn, eps: float):
    """
    Central finite-difference gradients (more stable than forward diff).

    Gradient ≈ (f(x + eps) - f(x - eps)) / (2 * eps)

    Parameters
    ----------
    loss_fn : callable
        Function that recomputes the full loss given current parameters.
    params : dict
        Parameter dictionary (values are scalar arrays).
    apply_params_fn : callable
        Function to apply a parameter dictionary to the model.
    eps : float
Main function stable_training_demo() (training loop with callbacks and adaptive learning)
import json
from pathlib import Path

import numpy as np

# The project classes and helpers used below (BackendConfig, HCNode, HCModel,
# MeanPolicy, the loss classes, CallbackList, TelemetryLogger, MemoryLogger,
# DepthScheduler, make_gradient_descent, clip_grads and the metric functions)
# are defined or imported elsewhere in the example file; those lines are not shown.


def stable_training_demo():
    # D, T, K, EPOCHS and BASE_LR are defined earlier (omitted from this excerpt).
    BASE_EPS = 1e-3
    LOG_PATH = Path("runs/telemetry_stable.jsonl")

    # Model
    model, nodes, backends = build_model_chain(D=D)

    # Data
    t = np.arange(T, dtype=float)
    x_seq = np.stack([
        0.30 * np.sin(0.35 * t + 0.00),
        0.20 * np.sin(0.35 * t + 0.70),
        0.10 * np.cos(0.35 * t + 0.30),
    ], axis=1)
    target_seq = np.zeros((T, D), dtype=float)

    # Losses
    loss_task = MSELoss()
    loss_cons = ConsistencyLoss(alpha=0.8, beta=1.2)
    loss_coh = CoherenceLoss(mode="variance")

    # Optimizer + telemetry
    params = params_pack(backends)
    opt = make_gradient_descent(lr=BASE_LR)  # base LR; recalibrated by depth each epoch
    state = opt.initialize(params)

    callbacks = CallbackList([
        TelemetryLogger(path=LOG_PATH, flush_interval=8),
        MemoryLogger(),
    ])
    # Real DepthScheduler (1 → 3 across EPOCHS-1; clamped at 3)
    depth_cb = DepthScheduler(target_attr="depth", start=1, end=3, epochs=EPOCHS - 1)

    def apply_params_fn(packed):
        params_unpack(backends, packed)

    def forward_and_losses():
        """
        Compute forward pass over the sequence and all loss terms.

        Returns
        -------
        tuple
            (total_loss, details_dict, last_state_sequence)
        """
        total_task = total_cons = total_coh = 0.0
        s_tm1 = None
        y_last = []
        for step in range(T):
            callbacks.on_step_begin(step, {"step": int(step)})
            s_t, s_hat, infos = model.forward_chain(x_seq[step], s_tm1=s_tm1, branches=K)
            y_last.append(s_t)
            total_task += loss_task(s_t, target_seq[step])
            if s_tm1 is not None:
                total_cons += loss_cons(s_tm1, s_t, s_hat)
            coh_vals = []
            for info in infos:
                br = info.get("branches", None)
                if isinstance(br, np.ndarray) and br.ndim == 2:
                    coh_vals.append(loss_coh(br))
            if coh_vals:
                total_coh += float(np.mean(coh_vals))
            s_tm1 = s_t
            callbacks.on_step_end(step, {"step": int(step)})
        task = total_task / T
        cons = total_cons / max(1, T - 1)
        coh = total_coh / T
        total = task + 0.5 * cons + 0.3 * coh
        return total, {"task": float(task), "cons": float(cons), "coh": float(coh), "total": float(total)}, np.asarray(y_last)

    best = None
    patience = 1
    bad_epochs = 0

    for epoch in range(EPOCHS):
        # 1) Adjust depth (real)
        for be in backends:
            depth_cb.on_epoch_begin(epoch, {"backend": be})
        depth_mean = float(np.mean([be.depth for be in backends]))

        # 2) Depth-adaptive LR and EPS
        lr_eff = BASE_LR / (depth_mean ** 2)
        eps_eff = BASE_EPS / (1.0 + 0.5 * (depth_mean - 1.0))
        opt = make_gradient_descent(lr=lr_eff)  # recreate with effective LR

        # 3) JSON-safe telemetry
        callbacks.on_epoch_begin(epoch, {"epoch": int(epoch), "depth": float(depth_mean), "lr_eff": float(lr_eff), "eps_eff": float(eps_eff)})

        # 4) Forward before update
        total0, det0, _ = forward_and_losses()

        # 5) Central gradients + clipping + step
        def loss_wrapper():
            l, _, _ = forward_and_losses()
            return l

        grads = central_diff_grads(loss_wrapper, params, apply_params_fn, eps=eps_eff)
        grads = clip_grads(grads, max_norm=5e-2)
        params, state = opt.step(params, grads, state)
        apply_params_fn(params)

        # 6) Forward after update
        total1, det1, _ = forward_and_losses()

        # 7) Telemetry end
        callbacks.on_epoch_end(epoch, {
            "epoch": int(epoch),
            "loss_before": det0,
            "loss_after": det1,
        })

        # 8) Early stopping
        if best is None or det1["total"] < best["total"]:
            best = {"epoch": int(epoch), **det1}
            bad_epochs = 0
        else:
            bad_epochs += 1

        depths = [int(be.depth) for be in backends]
        print(f"[Epoch {epoch}] total_before={det0['total']:.6f} total_after={det1['total']:.6f} depth={depths} lr_eff={lr_eff:.3e} eps_eff={eps_eff:.3e}")

        if bad_epochs > patience:
            print(f"Early stopping activated at epoch {epoch}. Best total={best['total']:.6f} (epoch {best['epoch']}).")
            break

    # Final metrics using the last forward
    _, _, y_pred_seq = forward_and_losses()
    smape_val = smape_safe(target_seq[:, 0], y_pred_seq[:, 0])
    rmse_val = rmse(target_seq[:, 0], y_pred_seq[:, 0])
    over_val = overshoot(target_seq[:, 0], y_pred_seq[:, 0])
    rob_val = robustness(target_seq[:, 0], y_pred_seq[:, 0])

    print("\n=== Final metrics (channel 0) ===")
    print(f"SMAPE:      {smape_val:.6f} %")
    print(f"RMSE:       {rmse_val:.6f}")
    print(f"Overshoot:  {over_val:.6f}")
    print(f"Robustness: {rob_val:.6f}")
    print("\nBest epoch snapshot:", best)

    if LOG_PATH.exists():
        print(f"\nTelemetry JSONL → {LOG_PATH.resolve()}")

    return {"best": best, "metrics": {"smape": smape_val, "rmse": rmse_val, "overshoot": over_val, "robustness": rob_val}}


# ---------------------------- Entry point ----------------------------
if __name__ == "__main__":
    out = stable_training_demo()
    print("\nSummary:")
    print(json.dumps(out, indent=2))
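The early-stopping rule in the loop counts consecutive epochs without a new best total loss and breaks once that count exceeds patience, so with patience = 1 training ends after two non-improving epochs in a row. The helper below isolates that rule on a list of per-epoch totals; the function name early_stop_epoch and the loss values are illustrative, not part of the demo.

```python
from typing import Optional


def early_stop_epoch(totals: list[float], patience: int = 1) -> Optional[int]:
    """Return the epoch at which the demo's loop would break, or None if it never stops.

    Mirrors the demo's rule: reset the counter on a strict improvement,
    otherwise increment it, and stop once it exceeds ``patience``.
    """
    best = None
    bad_epochs = 0
    for epoch, total in enumerate(totals):
        if best is None or total < best:
            best = total
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            return epoch
    return None


# Strictly decreasing totals never trigger the stop.
print(early_stop_epoch([0.031, 0.030, 0.029]))         # None
# Two non-improving epochs in a row trigger it with patience=1.
print(early_stop_epoch([0.031, 0.030, 0.032, 0.033]))  # 3
```

In the recorded run every epoch sets a new best total, so the counter never leaves zero and all twelve epochs (0 to 11) execute.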

Functional Explanation

The model trains through controlled recursion and adaptive numerical optimization. Each component has a defined mathematical role in stabilizing and guiding the learning process.

  1. Recursive Depth Evolution

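    Each backend applies the same tanh update repeatedly to its input \(x\):

    \[s^{(0)} = x, \qquad s^{(j)} = \tanh\big(w\, s^{(j-1)} + b\big), \quad j = 1, \dots, d\]

    The recursion depth \(d\) sets how many times the update is composed in each call to run(), so it counts internal evaluations per time step, not per epoch. Increasing \(d\) produces a deeper nonlinear map at every step; the scheduler raises \(d\) from 1 to 3 over the course of training.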
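The depth recursion in run() reduces to a loop over one scalar update. The standalone function below reproduces that loop on a plain NumPy vector; the backend's input handling and state validation are omitted, and the name depth_recursive_tanh is illustrative. It makes it easy to see how the output changes as \(d\) grows.

```python
import numpy as np


def depth_recursive_tanh(x: np.ndarray, w: float, b: float, depth: int) -> np.ndarray:
    """Apply s <- tanh(w * s + b) `depth` times, starting from s = x (at least once)."""
    s = np.asarray(x, dtype=float)
    for _ in range(max(1, int(depth))):  # same guard as run(): never fewer than 1 pass
        s = np.tanh(w * s + b)
    return s


x = np.array([0.30, 0.20, 0.10])  # example input vector, D = 3
for d in (1, 2, 3):
    # w and b of backend b0 in build_model_chain
    print(d, np.round(depth_recursive_tanh(x, w=0.90, b=0.03, depth=d), 4))
```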

  2. Future Projection

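    Each current state generates \(K\) predicted future states:

    \[S_{t+1}^{(k)} = S_t + \Delta_d \cdot \mathcal{P}_k(S_t), \qquad k = 1, \dots, K\]

    where \(\mathcal{P}_k\) is the projection operator of the linear projector and \(\Delta_d\) is a span that shrinks with depth down to a floor:

    \[\Delta_d = \max\left(\Delta_{\min}, \frac{\Delta_0}{1 + 0.3(d - 1)}\right), \qquad \Delta_{\min} = 0.10\]

    with \(\Delta_0\) the backend's base span (_base_span in the code). Spreading the branches around \(S_t\) represents local uncertainty about the next state and gives the coherence loss a set of causal branches to compare.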
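The span formula can be evaluated directly. The sketch below assumes the base span equals each backend's proj_span argument (0.22, 0.25, 0.30 in build_model_chain). The branch layout in project_branches, with offsets spread evenly over \([-\Delta_d, \Delta_d]\), is a hypothetical stand-in for LinearProjector, whose internals are not shown here.

```python
import numpy as np

SPAN_FLOOR = 0.10  # floor quoted in the comment inside project_future


def span_for_depth(base_span: float, depth: int, floor: float = SPAN_FLOOR) -> float:
    """Depth-adjusted span: base_span / (1 + 0.3 * (depth - 1)), never below floor."""
    return max(floor, base_span / (1.0 + 0.3 * (depth - 1)))


def project_branches(s_t: np.ndarray, span: float, branches: int = 2) -> np.ndarray:
    """Hypothetical projector: K copies of s_t shifted by offsets evenly spaced in [-span, span]."""
    k = max(2, int(branches))  # project_future also enforces at least 2 branches
    offsets = np.linspace(-span, span, k)[:, None]  # shape (K, 1), broadcast over D
    return s_t[None, :] + offsets  # shape (K, D)


for base in (0.22, 0.25, 0.30):
    print(base, [round(span_for_depth(base, d), 4) for d in (1, 2, 3)])

fut = project_branches(np.array([0.1, 0.2, 0.3]), span=span_for_depth(0.30, 3), branches=4)
print(fut.shape)  # (4, 3)
```

Under this assumption none of the three backends reaches the 0.10 floor by depth 3 (the smallest span is 0.22 / 1.6 = 0.1375), so the floor only matters for deeper schedules or smaller base spans.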

  3. Loss Structure

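    The total loss integrates three objectives:

    \[\mathcal{L}_{total} = \mathcal{L}_{task} + 0.5\,\mathcal{L}_{consistency} + 0.3\,\mathcal{L}_{coherence}\]

    • Task loss \(\mathcal{L}_{task} = \frac{1}{T}\sum_t \|S_t - Y_t\|^2\) penalizes prediction error against the target \(Y_t\), which is all zeros in this demo.

    • Consistency loss keeps successive states and their projections aligned: \(\mathcal{L}_{consistency} = \alpha\|S_t - S_{t-1}\|^2 + \beta\|S_t - \hat{S}_{t+1}\|^2\), with \(\alpha = 0.8\) and \(\beta = 1.2\), averaged over the \(T - 1\) steps that have a previous state.

    • Coherence loss keeps the projected branches close to one another: \(\mathcal{L}_{coherence} = \mathrm{Var}(S_{t+1}^{(k)})\), averaged over the nodes that return branches and then over time.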
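As a quick check on the weights, the best-epoch snapshot printed in the Exact Output section can be recombined by hand; the three components reproduce the reported total to floating-point precision.

```python
# Component values copied from the "Best epoch snapshot" line of the demo output.
best = {
    "task": 0.018001068416014097,
    "cons": 0.0010439829584105837,
    "coh": 0.012454292234950454,
    "total": 0.022259347565704524,
}

# L_total = L_task + 0.5 * L_consistency + 0.3 * L_coherence
recomputed = best["task"] + 0.5 * best["cons"] + 0.3 * best["coh"]
print(f"{recomputed:.12f}")                     # 0.022259347566
print(abs(recomputed - best["total"]) < 1e-12)  # True
```

At that epoch the task term accounts for roughly 81% of the total, the weighted coherence term for about 17%, and the weighted consistency term for about 2%.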

  4. Gradient Estimation

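    Central finite differences estimate each partial derivative by perturbing one parameter at a time:

    \[g_i = \frac{\mathcal{L}(\theta + \epsilon e_i) - \mathcal{L}(\theta - \epsilon e_i)}{2\epsilon}\]

    where \(e_i\) is the \(i\)-th unit vector. The method needs only loss evaluations, so no symbolic or automatic differentiation is required, and its truncation error is \(O(\epsilon^2)\) versus \(O(\epsilon)\) for forward differences. The price is two full forward passes per parameter: with three backends of two parameters each (\(w\), \(b\)), a straightforward implementation makes 12 passes over the sequence per epoch.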
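The excerpt shows only the signature and docstring of central_diff_grads. A minimal sketch of the same idea follows, assuming (as in the demo) a dict of scalar parameters, a loss_fn that takes no arguments and reads the current model state, and an apply_params_fn that writes a parameter dict into the model. It is an illustration, not the project's implementation; the name central_diff_grads_sketch and the toy quadratic loss are hypothetical.

```python
import numpy as np


def central_diff_grads_sketch(loss_fn, params, apply_params_fn, eps):
    """Estimate dL/dtheta_i ~ (L(theta + eps*e_i) - L(theta - eps*e_i)) / (2*eps) per key."""
    base = {k: np.array(v, dtype=float) for k, v in params.items()}
    grads = {}
    for key in base:
        plus = {k: v.copy() for k, v in base.items()}
        minus = {k: v.copy() for k, v in base.items()}
        plus[key] = plus[key] + eps    # perturb only this parameter upward
        minus[key] = minus[key] - eps  # ... and downward

        apply_params_fn(plus)
        loss_plus = float(loss_fn())
        apply_params_fn(minus)
        loss_minus = float(loss_fn())

        grads[key] = np.array((loss_plus - loss_minus) / (2.0 * eps))
    apply_params_fn(base)  # restore the unperturbed parameters
    return grads


# Toy model: L(w, b) = (w - 1)^2 + 3 * b^2, exact gradient (2(w - 1), 6b).
model_state = {}


def apply(p):
    model_state.update(p)


def loss():
    return (model_state["w"] - 1.0) ** 2 + 3.0 * model_state["b"] ** 2


theta = {"w": np.array(0.5), "b": np.array(0.2)}
g = central_diff_grads_sketch(loss, theta, apply, eps=1e-3)
print({k: round(float(v), 6) for k, v in g.items()})  # {'w': -1.0, 'b': 1.2}
```

For a quadratic loss the central difference is exact up to rounding, which is why the estimates match the analytic gradient here; for the tanh chain they carry an \(O(\epsilon^2)\) error.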

  5. Adaptive Learning Parameters

    The effective parameters adjust with recursion depth:

    \[\eta_{\text{eff}} = \frac{\eta_0}{d^2}, \qquad \epsilon_{\text{eff}} = \frac{\epsilon_0}{1 + 0.5(d - 1)}\]

    These relations reduce step size and perturbation magnitude as depth increases, improving convergence stability for deeper causal recursions.
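Plugging in \(\epsilon_0\) = BASE_EPS = 1e-3 from the code and \(\eta_0 = 0.05\), the base rate implied by the lr_eff printed at depth 1 (BASE_LR itself is not shown in the excerpt), reproduces the lr_eff and eps_eff columns of the Exact Output section at each depth. The helper name effective_params is illustrative.

```python
BASE_LR = 0.05   # inferred from "lr_eff=5.000e-02" at depth 1 in the demo output
BASE_EPS = 1e-3  # as set in stable_training_demo()


def effective_params(depth: float) -> tuple[float, float]:
    """Depth-adaptive learning rate and finite-difference step."""
    lr_eff = BASE_LR / depth ** 2
    eps_eff = BASE_EPS / (1.0 + 0.5 * (depth - 1.0))
    return lr_eff, eps_eff


for d in (1, 2, 3):
    lr, eps = effective_params(d)
    print(f"depth={d} lr_eff={lr:.3e} eps_eff={eps:.3e}")
# depth=1 lr_eff=5.000e-02 eps_eff=1.000e-03
# depth=2 lr_eff=1.250e-02 eps_eff=6.667e-04
# depth=3 lr_eff=5.556e-03 eps_eff=5.000e-04
```

The learning rate falls by a factor of 9 between depth 1 and depth 3, while the perturbation only halves, so the gradient estimates stay resolvable even as the steps become small.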

  6. Gradient Clipping

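    The full gradient vector (all parameters together) is rescaled whenever its norm exceeds a threshold, which acts as a simple trust region:

    \[g' = g \cdot \min\left(1, \frac{\tau}{\|g\|_2}\right)\]

    where \(\tau\) is the clipping threshold (max_norm = 5e-2 in the demo). Assuming make_gradient_descent applies a plain step \(\theta \leftarrow \theta - \eta_{\text{eff}}\, g'\), each update moves the parameters by at most \(\eta_{\text{eff}}\,\tau\) in Euclidean norm.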
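A minimal sketch of global-norm clipping over a dict of gradients, the form the optimizer consumes in the demo. It assumes clip_grads uses a single L2 norm across all parameters, as the clipping formula does; the function name clip_by_global_norm is hypothetical.

```python
import numpy as np


def clip_by_global_norm(grads: dict, max_norm: float) -> dict:
    """Scale every gradient by min(1, max_norm / ||g||_2), with ||g|| taken over all entries."""
    total_norm = float(np.sqrt(sum(float(np.sum(np.square(g))) for g in grads.values())))
    if total_norm <= max_norm:
        return {k: np.array(g, dtype=float) for k, g in grads.items()}  # already inside the region
    scale = max_norm / total_norm
    return {k: np.array(g, dtype=float) * scale for k, g in grads.items()}


grads = {"b0_w": np.array(0.3), "b0_b": np.array(0.4)}  # global norm 0.5
clipped = clip_by_global_norm(grads, max_norm=5e-2)     # the demo's threshold
norm_after = np.sqrt(sum(float(v) ** 2 for v in clipped.values()))
print(round(norm_after, 6))  # 0.05
```

Scaling all entries by one factor preserves the gradient direction (here the ratio 0.3 : 0.4 is unchanged), which is what distinguishes norm clipping from clipping each component separately.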

  7. Callback Coordination

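    • DepthScheduler: sets each backend's depth attribute at the start of every epoch, moving it from 1 to 3 over EPOCHS - 1 epochs and clamping it at 3.

    • TelemetryLogger: records per-epoch statistics (losses before and after the update, depth, lr_eff, eps_eff) to runs/telemetry_stable.jsonl, with flush_interval=8.

    • MemoryLogger: keeps the same records in memory for later visualization.

    Together they keep the depth schedule in step with the optimizer and leave a per-epoch record of the run.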
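DepthScheduler's exact interpolation rule is not shown in the excerpt. One rule consistent with the depth column of the Exact Output section is linear interpolation from start to end, rounded half up and clamped at end. The function below is a hypothetical sketch of that rule, assuming EPOCHS = 12 (epochs 0 to 11 appear in the output, with no early stop), so the scheduler spans 11 epochs.

```python
import math


def linear_depth(epoch: int, start: int = 1, end: int = 3, epochs: int = 11) -> int:
    """Hypothetical schedule: interpolate start -> end over `epochs`, round half up, clamp at end."""
    frac = min(max(epoch, 0), epochs) / epochs  # progress in [0, 1]
    depth = start + (end - start) * frac
    return min(end, math.floor(depth + 0.5))


print([linear_depth(e) for e in range(12)])
# [1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3]
```

A truncating rule (int() instead of rounding) would keep depth 1 until epoch 5, which does not match the printed run, where depth 2 already appears at epoch 3.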

Exact Output

[Epoch 0] total_before=0.031141 total_after=0.030663 depth=[1, 1, 1] lr_eff=5.000e-02 eps_eff=1.000e-03
[Epoch 1] total_before=0.030663 total_after=0.030217 depth=[1, 1, 1] lr_eff=5.000e-02 eps_eff=1.000e-03
[Epoch 2] total_before=0.030217 total_after=0.029800 depth=[1, 1, 1] lr_eff=5.000e-02 eps_eff=1.000e-03
[Epoch 3] total_before=0.025941 total_after=0.025629 depth=[2, 2, 2] lr_eff=1.250e-02 eps_eff=6.667e-04
[Epoch 4] total_before=0.025629 total_after=0.025326 depth=[2, 2, 2] lr_eff=1.250e-02 eps_eff=6.667e-04
[Epoch 5] total_before=0.025326 total_after=0.025030 depth=[2, 2, 2] lr_eff=1.250e-02 eps_eff=6.667e-04
[Epoch 6] total_before=0.025030 total_after=0.024742 depth=[2, 2, 2] lr_eff=1.250e-02 eps_eff=6.667e-04
[Epoch 7] total_before=0.024742 total_after=0.024462 depth=[2, 2, 2] lr_eff=1.250e-02 eps_eff=6.667e-04
[Epoch 8] total_before=0.024462 total_after=0.024190 depth=[2, 2, 2] lr_eff=1.250e-02 eps_eff=6.667e-04
[Epoch 9] total_before=0.022960 total_after=0.022723 depth=[3, 3, 3] lr_eff=5.556e-03 eps_eff=5.000e-04
[Epoch 10] total_before=0.022723 total_after=0.022489 depth=[3, 3, 3] lr_eff=5.556e-03 eps_eff=5.000e-04
[Epoch 11] total_before=0.022489 total_after=0.022259 depth=[3, 3, 3] lr_eff=5.556e-03 eps_eff=5.000e-04

=== Final metrics (channel 0) ===
SMAPE:      100.000000 %
RMSE:       0.165016
Overshoot:  0.000000
Robustness: 0.973491

Best epoch snapshot: {'epoch': 11, 'task': 0.018001068416014097, 'cons': 0.0010439829584105837, 'coh': 0.012454292234950454, 'total': 0.022259347565704524}

Telemetry JSONL → runs/telemetry_stable.jsonl

Summary:
{
  "best": {
    "epoch": 11,
    "task": 0.018001068416014097,
    "cons": 0.0010439829584105837,
    "coh": 0.012454292234950454,
    "total": 0.022259347565704524
  },
  "metrics": {
    "smape": 100.0,
    "rmse": 0.1650163483217719,
    "overshoot": 0.0,
    "robustness": 0.9734914432630335
  }
}
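SMAPE sits at exactly 100% because the target sequence is all zeros: with \(y_t = 0\), every nonzero prediction drives each SMAPE term to its maximum, so the metric cannot register improvement in this demo and RMSE is the more informative number. The sketch below assumes smape_safe uses the \(|y| + |\hat y|\) denominator variant, which is consistent with the printed 100%; the common \((|y| + |\hat y|)/2\) variant would report 200% instead. The function name smape_sum_denominator is hypothetical.

```python
import numpy as np


def smape_sum_denominator(y_true, y_pred, eps: float = 1e-12) -> float:
    """SMAPE variant: 100 * mean(|y - y_hat| / (|y| + |y_hat|)), guarded against 0/0."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    denom = np.abs(y_true) + np.abs(y_pred)
    return float(100.0 * np.mean(np.abs(y_true - y_pred) / np.maximum(denom, eps)))


target = np.zeros(5)  # the demo's target is all zeros
for scale in (0.2, 0.02, 0.002):
    pred = scale * np.array([1.0, -0.5, 0.3, 0.8, -0.1])  # any nonzero predictions
    print(scale, smape_sum_denominator(target, pred))  # 100.0 each time
```

Shrinking the predictions by a factor of 100 leaves the score unchanged, which is why the reported SMAPE says nothing about training progress here.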