Training with Callbacks Demo
Introduction
This example presents a stable training framework for a three-node hyper-causal model. It demonstrates a complete adaptive optimization cycle using depth-aware recursion, finite-difference gradients, and callback-driven monitoring. The goal is to illustrate how a causal system can be trained under dynamic conditions without relying on classical backpropagation.
Key mechanisms include:
DepthScheduler – controls recursion depth dynamically per epoch.
Adaptive learning parameters – rescale learning rate and perturbation with recursion depth.
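Finite-difference gradient estimation – estimates gradients numerically from loss evaluations, replacing backpropagation.
Gradient clipping – rescales the gradient so its norm never exceeds a threshold, which bounds each update and prevents divergence.
Callback telemetry – monitors all metrics and parameters across epochs.
Early stopping – halts training once the total loss has failed to improve for more than `patience` consecutive epochs.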
—
General Flow Structure
The training loop coordinates recursive backends connected through causal dependencies. Each backend adjusts its transformation depth, projects future states, and contributes to a composite loss. The callback system regulates both the internal optimization and the external logging.
DepthAwareBackend: applies recursive transformations \(S_t^{(d)} = \tanh(W^{(d)}S_{t-1} + b^{(d)})\).
Projection: expands each state into \(K\) future branches via a linear projector.
Loss aggregation: combines task, consistency, and coherence components.
Optimizer: updates parameters using finite-difference gradients and gradient clipping.
Scheduler: increases recursion depth gradually to control system complexity.
—
How to Run
# From the project root
python -m examples.ex_training_with_callbacks_demo
# Or directly
python examples/ex_training_with_callbacks_demo.py
—
Relevant Code Snippets
            self.w = float(np.asarray(params["w"]).reshape(()))
        if "b" in params:
            self.b = float(np.asarray(params["b"]).reshape(()))

    def run(self, params: dict | None = None) -> np.ndarray:
        """
        Apply depth-recursive tanh transformation.

        Parameters
        ----------
        params : dict or None, optional
            Optional parameter override for this call.

        Returns
        -------
        np.ndarray
            Validated current state vector.
        """
        if params:
            self.set_params(params)
        x = self._require_input().astype(float)
        s = x
        for _ in range(max(1, int(self.depth))):
            s = np.tanh(self.w * s + self.b)
        return self._validate_state(s)

    def project_future(self, s_t: np.ndarray, branches: int = 2) -> np.ndarray:
        """
        Project future states around ``s_t`` using a depth-adjusted span.

        Parameters
        ----------
        s_t : np.ndarray
            Current state vector.
        branches : int, optional
            Number of future branches (K), by default 2.

        Returns
        -------
        np.ndarray
            Future branch matrix with shape ``(K, D)``.
        """
        s = self._validate_state(s_t)
        k = max(2, int(branches))
        # span reduced with depth, but with a high floor (0.10)
        span = max(self._span_floor, self._base_span / (1.0 + 0.3 * (self.depth - 1)))
        self._projector = LinearProjector(weight=1.0, bias=0.0, span=span)
        fut = self._projector.project(s, branches=k)
        return self._validate_branches(fut)


# ------------------------ Construction utils ------------------------
def build_model_chain(D=3):
    """
    Build a three-node HCModel with depth-aware backends.

    Parameters
    ----------
    D : int, optional
        State dimensionality, by default 3.

    Returns
    -------
    tuple
        (model, nodes, backends)
    """
    cfg = BackendConfig(output_dim=D, seed=11)
    b0 = DepthAwareBackend(cfg, w=0.90, b=0.03, proj_span=0.22)
    b1 = DepthAwareBackend(cfg, w=0.97, b=0.02, proj_span=0.25)
    b2 = DepthAwareBackend(cfg, w=1.05, b=0.00, proj_span=0.30)
    pol = MeanPolicy()
    n0, n1, n2 = HCNode(b0, pol), HCNode(b1, pol), HCNode(b2, pol)
    model = HCModel([n0, n1, n2])
    return model, [n0, n1, n2], [b0, b1, b2]


def params_pack(backends):
    """
    Flatten parameters from all backends into a single dictionary.

    Parameters
    ----------
    backends : list
        List of backend instances.

    Returns
    -------
    dict
        Flattened parameter dictionary suitable for the optimizer.
    """
    packed = {}
    for i, be in enumerate(backends):
        for k, v in be.get_params().items():
            packed[f"b{i}_{k}"] = np.array(v, dtype=float)
    return packed


def params_unpack(backends, packed):
    """
    Distribute flat parameters back to their corresponding backends.

    Parameters
    ----------
    backends : list
        Backend instances to update.
    packed : dict
        Flattened parameter dictionary.
    """
    for i, be in enumerate(backends):
        sub = {}
        for k in ("w", "b"):
            key = f"b{i}_{k}"
            if key in packed:
                sub[k] = packed[key]
        be.set_params(sub)


# ------------------- CENTRAL finite-difference grads -------------------
def central_diff_grads(loss_fn, params, apply_params_fn, eps: float):
    """
    Central finite-difference gradients (more stable than forward diff).

    Gradient ≈ (f(x + eps) - f(x - eps)) / (2 * eps)

    Parameters
    ----------
    loss_fn : callable
        Function that recomputes the full loss given current parameters.
    params : dict
        Parameter dictionary (values are scalar arrays).
    apply_params_fn : callable
        Function to apply a parameter dictionary to the model.
    eps : float
[...]
    BASE_EPS = 1e-3
    LOG_PATH = Path("runs/telemetry_stable.jsonl")

    # Model
    model, nodes, backends = build_model_chain(D=D)

    # Data
    t = np.arange(T, dtype=float)
    x_seq = np.stack([
        0.30 * np.sin(0.35 * t + 0.00),
        0.20 * np.sin(0.35 * t + 0.70),
        0.10 * np.cos(0.35 * t + 0.30),
    ], axis=1)
    target_seq = np.zeros((T, D), dtype=float)

    # Losses
    loss_task = MSELoss()
    loss_cons = ConsistencyLoss(alpha=0.8, beta=1.2)
    loss_coh = CoherenceLoss(mode="variance")

    # Optimizer + telemetry
    params = params_pack(backends)
    opt = make_gradient_descent(lr=BASE_LR)  # base LR; recalibrated by depth each epoch
    state = opt.initialize(params)

    callbacks = CallbackList([
        TelemetryLogger(path=LOG_PATH, flush_interval=8),
        MemoryLogger(),
    ])
    # Real DepthScheduler (1 → 3 across EPOCHS-1; clamped at 3)
    depth_cb = DepthScheduler(target_attr="depth", start=1, end=3, epochs=EPOCHS - 1)

    def apply_params_fn(packed):
        params_unpack(backends, packed)

    def forward_and_losses():
        """
        Compute forward pass over the sequence and all loss terms.

        Returns
        -------
        tuple
            (total_loss, details_dict, last_state_sequence)
        """
        total_task = total_cons = total_coh = 0.0
        s_tm1 = None
        y_last = []
        for step in range(T):
            callbacks.on_step_begin(step, {"step": int(step)})
            s_t, s_hat, infos = model.forward_chain(x_seq[step], s_tm1=s_tm1, branches=K)
            y_last.append(s_t)
            total_task += loss_task(s_t, target_seq[step])
            if s_tm1 is not None:
                total_cons += loss_cons(s_tm1, s_t, s_hat)
            coh_vals = []
            for info in infos:
                br = info.get("branches", None)
                if isinstance(br, np.ndarray) and br.ndim == 2:
                    coh_vals.append(loss_coh(br))
            if coh_vals:
                total_coh += float(np.mean(coh_vals))
            s_tm1 = s_t
            callbacks.on_step_end(step, {"step": int(step)})
        task = total_task / T
        cons = total_cons / max(1, T - 1)
        coh = total_coh / T
        total = task + 0.5 * cons + 0.3 * coh
        return total, {"task": float(task), "cons": float(cons), "coh": float(coh), "total": float(total)}, np.asarray(y_last)

    best = None
    patience = 1
    bad_epochs = 0

    for epoch in range(EPOCHS):
        # 1) Adjust depth (real)
        for be in backends:
            depth_cb.on_epoch_begin(epoch, {"backend": be})
        depth_mean = float(np.mean([be.depth for be in backends]))

        # 2) Depth-adaptive LR and EPS
        lr_eff = BASE_LR / (depth_mean ** 2)
        eps_eff = BASE_EPS / (1.0 + 0.5 * (depth_mean - 1.0))
        opt = make_gradient_descent(lr=lr_eff)  # recreate with effective LR

        # 3) JSON-safe telemetry
        callbacks.on_epoch_begin(epoch, {"epoch": int(epoch), "depth": float(depth_mean), "lr_eff": float(lr_eff), "eps_eff": float(eps_eff)})

        # 4) Forward before update
        total0, det0, _ = forward_and_losses()

        # 5) Central gradients + clipping + step
        def loss_wrapper():
            l, _, _ = forward_and_losses()
            return l

        grads = central_diff_grads(loss_wrapper, params, apply_params_fn, eps=eps_eff)
        grads = clip_grads(grads, max_norm=5e-2)
        params, state = opt.step(params, grads, state)
        apply_params_fn(params)

        # 6) Forward after update
        total1, det1, _ = forward_and_losses()

        # 7) Telemetry end
        callbacks.on_epoch_end(epoch, {
            "epoch": int(epoch),
            "loss_before": det0,
            "loss_after": det1,
        })

        # 8) Early stopping
        if best is None or det1["total"] < best["total"]:
            best = {"epoch": int(epoch), **det1}
            bad_epochs = 0
        else:
            bad_epochs += 1

        depths = [int(be.depth) for be in backends]
        print(f"[Epoch {epoch}] total_before={det0['total']:.6f} total_after={det1['total']:.6f} depth={depths} lr_eff={lr_eff:.3e} eps_eff={eps_eff:.3e}")

        if bad_epochs > patience:
            print(f"Early stopping activated at epoch {epoch}. Best total={best['total']:.6f} (epoch {best['epoch']}).")
            break

    # Final metrics using the last forward
    _, _, y_pred_seq = forward_and_losses()
    smape_val = smape_safe(target_seq[:, 0], y_pred_seq[:, 0])
    rmse_val = rmse(target_seq[:, 0], y_pred_seq[:, 0])
    over_val = overshoot(target_seq[:, 0], y_pred_seq[:, 0])
    rob_val = robustness(target_seq[:, 0], y_pred_seq[:, 0])

    print("\n=== Final metrics (channel 0) ===")
    print(f"SMAPE: {smape_val:.6f} %")
    print(f"RMSE: {rmse_val:.6f}")
    print(f"Overshoot: {over_val:.6f}")
    print(f"Robustness: {rob_val:.6f}")
    print("\nBest epoch snapshot:", best)

    if LOG_PATH.exists():
        print(f"\nTelemetry JSONL → {LOG_PATH.resolve()}")

    return {"best": best, "metrics": {"smape": smape_val, "rmse": rmse_val, "overshoot": over_val, "robustness": rob_val}}


# ---------------------------- Entry point ----------------------------
if __name__ == "__main__":
    out = stable_training_demo()
    print("\nSummary:")
    print(json.dumps(out, indent=2))
—
Functional Explanation
The model trains through controlled recursion and adaptive numerical optimization. Each component has a defined mathematical role in stabilizing and guiding the learning process.
Recursive Depth Evolution
Each backend performs a recursive update of the internal state:
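\[S_t^{(d)} = \tanh(W^{(d)} S_{t-1} + b^{(d)})\]
The recursion depth \(d\) sets how many times this map is applied within a single forward call (the loop in `run` executes `max(1, depth)` iterations); the scheduler changes \(d\) between epochs. Increasing \(d\) allows the model to capture higher-order temporal dependencies.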
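The recursion can be sketched in a few lines of NumPy. This is a minimal standalone illustration that mirrors the loop in `run` with scalar `w` and `b`; the function name `depth_recursion` and the sample input are assumptions made for the example, not part of the library.

```python
import numpy as np


def depth_recursion(x, w, b, depth):
    """Apply s <- tanh(w * s + b) `depth` times, starting from the input x.

    Hypothetical helper: it mirrors the loop in `run`, including the
    `max(1, depth)` guard, but is not the library's implementation.
    """
    s = np.asarray(x, dtype=float)
    for _ in range(max(1, int(depth))):
        s = np.tanh(w * s + b)
    return s


# Sample input chosen for illustration only.
x = np.array([0.30, 0.20, 0.10])
for d in (1, 2, 3):
    print(d, depth_recursion(x, w=0.90, b=0.03, depth=d))
```

Because `tanh` is bounded, every component of the result stays strictly inside (-1, 1) regardless of depth.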
Future Projection
Each current state generates \(K\) predicted future states:
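\[S_{t+1}^{(k)} = S_t + \Delta_d \cdot \mathcal{P}_k(S_t)\]
where \(\mathcal{P}_k\) is a projection operator and \(\Delta_d\) is a span that shrinks as depth grows, bounded below by a floor: \(\Delta_d = \max\left(\Delta_{\min},\ \Delta_0 / (1 + 0.3\,(d - 1))\right)\), as computed in `project_future`. This projection step introduces local temporal uncertainty and allows causal branching.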
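A small sketch may make the depth-adjusted span concrete. The span formula follows `project_future`; the branch layout is an assumption, since the internals of `LinearProjector` are not shown here. The helper names `depth_span` and `project_branches`, and the evenly spaced offsets, are hypothetical.

```python
import numpy as np


def depth_span(base_span, depth, floor=0.10):
    """Span that shrinks with depth but never drops below `floor`."""
    return max(floor, base_span / (1.0 + 0.3 * (depth - 1)))


def project_branches(s_t, span, branches=2):
    """Return a (K, D) matrix of branches spread around s_t.

    Assumption: branches are s_t plus evenly spaced offsets in
    [-span, +span]. The real LinearProjector may behave differently.
    """
    s = np.asarray(s_t, dtype=float)
    k = max(2, int(branches))
    offsets = np.linspace(-1.0, 1.0, k)
    return s[None, :] + span * offsets[:, None]


s_t = np.array([0.1, 0.2, 0.3])
for d in (1, 3, 10):
    span = depth_span(0.22, d)
    print(d, round(span, 4), project_branches(s_t, span, branches=3).shape)
```

With a base span of 0.22, the span is 0.22 at depth 1, 0.1375 at depth 3, and is held at the 0.10 floor by depth 10.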
Loss Structure
The total loss integrates three objectives:
\[\mathcal{L}_{total} = \mathcal{L}_{task} + 0.5\,\mathcal{L}_{consistency} + 0.3\,\mathcal{L}_{coherence}\]
Task loss \(\mathcal{L}_{task} = \frac{1}{T}\sum_t \|S_t - Y_t\|^2\) minimizes prediction error.
Consistency loss maintains temporal smoothness: \(\mathcal{L}_{consistency} = \alpha\|S_t - S_{t-1}\|^2 + \beta\|S_t - \hat{S}_{t+1}\|^2\), with \(\alpha = 0.8\) and \(\beta = 1.2\) in this demo.
Coherence loss enforces similarity among projected branches: \(\mathcal{L}_{coherence} = \mathrm{Var}(S_{t+1}^{(k)})\), with the variance taken across the \(K\) branches.
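The three terms and their weighting can be sketched as follows. The weights 0.5 and 0.3 and the defaults for `alpha` and `beta` come from the demo code; the per-term reductions (mean versus sum) are assumptions, because the internals of `MSELoss`, `ConsistencyLoss`, and `CoherenceLoss` are not shown. All function names below are hypothetical.

```python
import numpy as np


def task_loss(s_t, y_t):
    """Mean squared error between state and target (assumed reduction)."""
    return float(np.mean((np.asarray(s_t) - np.asarray(y_t)) ** 2))


def consistency_loss(s_prev, s_t, s_hat, alpha=0.8, beta=1.2):
    """alpha * ||s_t - s_prev||^2 + beta * ||s_t - s_hat||^2."""
    s_prev, s_t, s_hat = map(np.asarray, (s_prev, s_t, s_hat))
    return float(alpha * np.sum((s_t - s_prev) ** 2)
                 + beta * np.sum((s_t - s_hat) ** 2))


def coherence_loss(branches):
    """Variance across the K branches, averaged over dimensions (assumed)."""
    return float(np.mean(np.var(np.asarray(branches), axis=0)))


def total_loss(task, cons, coh):
    """Weighted sum used by the demo: task + 0.5 * cons + 0.3 * coh."""
    return task + 0.5 * cons + 0.3 * coh


# Component values taken from the best-epoch snapshot in the logged output.
print(total_loss(0.018001068416014097,
                 0.0010439829584105837,
                 0.012454292234950454))
```

Plugging the logged `task`, `cons`, and `coh` values into the weighted sum reproduces the logged `total` of about 0.0222593, which confirms the weighting.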
Gradient Estimation
Finite differences are used to estimate local gradients:
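\[g_i = \frac{\mathcal{L}(\theta_i + \epsilon) - \mathcal{L}(\theta_i - \epsilon)}{2\epsilon}\]
Each parameter costs two loss evaluations, so no symbolic or automatic differentiation is required. The central form has truncation error of order \(\epsilon^2\), compared with order \(\epsilon\) for a forward difference.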
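As a rough illustration of the estimator, here is a self-contained sketch over a dictionary of scalar parameters. It is not the demo's `central_diff_grads` (whose body is not shown and which applies parameters to the model through `apply_params_fn`); the name `central_diff_sketch` and the toy loss are assumptions.

```python
def central_diff_sketch(loss_fn, params, eps=1e-3):
    """Estimate d(loss)/d(param) for each scalar entry of `params`.

    Hypothetical simplification: `loss_fn` takes the parameter dict
    directly instead of reading parameters from a model.
    """
    grads = {}
    for name, value in params.items():
        plus = dict(params)
        minus = dict(params)
        plus[name] = value + eps    # perturb one parameter upward
        minus[name] = value - eps   # and downward, others untouched
        grads[name] = (loss_fn(plus) - loss_fn(minus)) / (2.0 * eps)
    return grads


# Toy loss with known gradient: d/dw = 2w, d/db = 3.
def toy_loss(p):
    return p["w"] ** 2 + 3.0 * p["b"]


print(central_diff_sketch(toy_loss, {"w": 2.0, "b": 1.0}))
```

For the toy loss at `w = 2`, `b = 1`, the estimates are close to 4 and 3. The central difference is exact for quadratics up to floating-point rounding.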
Adaptive Learning Parameters
The effective parameters adjust with recursion depth:
\[\eta_{\text{eff}} = \frac{\eta_0}{d^2}, \qquad \epsilon_{\text{eff}} = \frac{\epsilon_0}{1 + 0.5(d - 1)}\]
Here \(d\) is the mean depth across the backends. These relations reduce step size and perturbation magnitude as depth increases, improving convergence stability for deeper causal recursions.
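The two rescaling rules are easy to check numerically. In the sketch below, `BASE_EPS = 1e-3` comes from the demo code, while `BASE_LR = 5e-2` is inferred from the logged `lr_eff` at depth 1 and should be read as an assumption; the helper names are hypothetical.

```python
BASE_LR = 5e-2   # assumption: inferred from the logged lr_eff at depth 1
BASE_EPS = 1e-3  # as defined in the demo


def effective_lr(depth_mean, base_lr=BASE_LR):
    """Learning rate scaled by the inverse square of the mean depth."""
    return base_lr / (depth_mean ** 2)


def effective_eps(depth_mean, base_eps=BASE_EPS):
    """Finite-difference perturbation, shrinking linearly with depth."""
    return base_eps / (1.0 + 0.5 * (depth_mean - 1.0))


for d in (1.0, 2.0, 3.0):
    print(f"depth={d:.0f} lr_eff={effective_lr(d):.3e} eps_eff={effective_eps(d):.3e}")
# depth=1 lr_eff=5.000e-02 eps_eff=1.000e-03
# depth=2 lr_eff=1.250e-02 eps_eff=6.667e-04
# depth=3 lr_eff=5.556e-03 eps_eff=5.000e-04
```

These values match the `lr_eff` and `eps_eff` columns in the logged output for depths 1, 2, and 3.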
Gradient Clipping
All gradient vectors are constrained within a trust region:
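\[g_i' = g_i \cdot \min\left(1, \frac{\tau}{\|g\|_2}\right)\]
where \(\tau\) is the clipping threshold (`max_norm=5e-2` in this demo). When \(\|g\|_2 \le \tau\) the gradient is left unchanged; otherwise it is rescaled to norm \(\tau\), which bounds the size of each parameter update.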
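A minimal sketch of clipping by global L2 norm, following the formula above. Whether the demo's `clip_grads` computes the norm over all parameters jointly is an assumption; the function name `clip_by_global_norm` is hypothetical.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients so their joint L2 norm is at most `max_norm`.

    Assumption: the norm is taken over every entry of every gradient,
    treating the dict as one flat vector.
    """
    total = float(np.sqrt(sum(float(np.sum(np.square(g))) for g in grads.values())))
    # Guard against division by zero when all gradients vanish.
    scale = min(1.0, max_norm / total) if total > 0.0 else 1.0
    return {name: np.asarray(g, dtype=float) * scale for name, g in grads.items()}


grads = {"a": np.array(3.0), "b": np.array(4.0)}  # joint norm is 5
print(clip_by_global_norm(grads, max_norm=1.0))   # rescaled to norm 1
print(clip_by_global_norm(grads, max_norm=10.0))  # already inside, unchanged
```

Rescaling every component by the same factor preserves the gradient direction, so clipping shortens the step without changing where it points.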
Callback Coordination
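DepthScheduler: raises recursion depth between epochs, here from 1 to 3 over the run (depth 1 for epochs 0 to 2, depth 2 for epochs 3 to 8, depth 3 for epochs 9 to 11 in the logged output).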
TelemetryLogger: records per-epoch statistics to JSONL.
MemoryLogger: stores metrics in memory for later visualization.
These components synchronize the optimization and provide complete training traceability.
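To illustrate how a depth schedule of this kind can work, the sketch below uses rounded linear interpolation from `start` to `end`. With `start=1`, `end=3`, and `epochs=11` (that is, `EPOCHS - 1` for the 12 logged epochs) it happens to reproduce the depths in the logged output. Whether `DepthScheduler` uses this rule internally is an assumption; `linear_depth` is a hypothetical name.

```python
def linear_depth(epoch, start=1, end=3, epochs=11):
    """Rounded linear interpolation from `start` to `end`, clamped at `end`.

    Hypothetical rule: it matches the logged schedule, but the real
    DepthScheduler may be implemented differently.
    """
    frac = min(1.0, max(0.0, epoch / max(1, epochs)))
    return int(round(start + (end - start) * frac))


schedule = [linear_depth(e) for e in range(12)]
print(schedule)
# [1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3]
```

Epochs past the end of the schedule stay clamped at the final depth, which matches the "clamped at 3" comment in the demo code.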
—
Exact Output
[Epoch 0] total_before=0.031141 total_after=0.030663 depth=[1, 1, 1] lr_eff=5.000e-02 eps_eff=1.000e-03
[Epoch 1] total_before=0.030663 total_after=0.030217 depth=[1, 1, 1] lr_eff=5.000e-02 eps_eff=1.000e-03
[Epoch 2] total_before=0.030217 total_after=0.029800 depth=[1, 1, 1] lr_eff=5.000e-02 eps_eff=1.000e-03
[Epoch 3] total_before=0.025941 total_after=0.025629 depth=[2, 2, 2] lr_eff=1.250e-02 eps_eff=6.667e-04
[Epoch 4] total_before=0.025629 total_after=0.025326 depth=[2, 2, 2] lr_eff=1.250e-02 eps_eff=6.667e-04
[Epoch 5] total_before=0.025326 total_after=0.025030 depth=[2, 2, 2] lr_eff=1.250e-02 eps_eff=6.667e-04
[Epoch 6] total_before=0.025030 total_after=0.024742 depth=[2, 2, 2] lr_eff=1.250e-02 eps_eff=6.667e-04
[Epoch 7] total_before=0.024742 total_after=0.024462 depth=[2, 2, 2] lr_eff=1.250e-02 eps_eff=6.667e-04
[Epoch 8] total_before=0.024462 total_after=0.024190 depth=[2, 2, 2] lr_eff=1.250e-02 eps_eff=6.667e-04
[Epoch 9] total_before=0.022960 total_after=0.022723 depth=[3, 3, 3] lr_eff=5.556e-03 eps_eff=5.000e-04
[Epoch 10] total_before=0.022723 total_after=0.022489 depth=[3, 3, 3] lr_eff=5.556e-03 eps_eff=5.000e-04
[Epoch 11] total_before=0.022489 total_after=0.022259 depth=[3, 3, 3] lr_eff=5.556e-03 eps_eff=5.000e-04

=== Final metrics (channel 0) ===
SMAPE: 100.000000 %
RMSE: 0.165016
Overshoot: 0.000000
Robustness: 0.973491

Best epoch snapshot: {'epoch': 11, 'task': 0.018001068416014097, 'cons': 0.0010439829584105837, 'coh': 0.012454292234950454, 'total': 0.022259347565704524}

Telemetry JSONL → runs/telemetry_stable.jsonl

Summary:
{
  "best": {
    "epoch": 11,
    "task": 0.018001068416014097,
    "cons": 0.0010439829584105837,
    "coh": 0.012454292234950454,
    "total": 0.022259347565704524
  },
  "metrics": {
    "smape": 100.0,
    "rmse": 0.1650163483217719,
    "overshoot": 0.0,
    "robustness": 0.9734914432630335
  }
}