Advanced Training with Callbacks
Introduction
This example presents an advanced configuration of the hyper-causal training framework. It extends the previous demo with external depth scheduling, freeze epochs, and adaptive gradient clipping controlled by recursion depth. The design emphasizes deterministic, depth-aware optimization: the learning rate, the finite-difference step, and the clipping threshold all adapt to the current recursion depth.
Key components include:
External DepthScheduler – adjusts recursion depth independently of the logger.
Freeze epochs – disable updates after depth transitions to stabilize new configurations.
Adaptive gradient clipping – scales gradient bounds dynamically with mean recursion depth.
Depth- and epoch-dependent decay – scales down both the learning rate and the finite-difference step as depth and epoch increase.
Parameter checkpointing – stores best-performing parameters in JSON format.
Robust metric evaluation – computes SMAPE, RMSE, Overshoot, and Robustness.
—
General Flow Structure
The training loop optimizes a three-node model chain in which each node wraps a DepthAwareBackend that combines a recursive transformation with a linear projection. Depth increases are scheduled externally, and each increase triggers a freeze epoch in which parameters remain static, allowing gradient statistics to stabilize before updates continue.
DepthAwareBackend: recursive causal unit using \(S_t^{(d)} = \tanh(W^{(d)}S_{t-1} + b^{(d)})\).
Projection step: generates K possible futures (at least two), with a span that shrinks as depth increases.
Central finite-difference gradients: numerical estimates with second-order truncation error that require no differentiable computation graph.
Adaptive learning control: combines recursion-based scaling with epoch decay.
Clipping: applied adaptively based on the mean depth to limit the L2 norm of the gradient.
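The depth schedule and the freeze rule can be sketched as follows. DepthScheduler's internals are not shown in this example, so the linear ramp with rounding below is an assumption (the function names are hypothetical); it does, however, reproduce the depth transitions in the output, with depth 2 from epoch 4 and depth 3 from epoch 12.

```python
def linear_depth_schedule(start: int, end: int, epochs: int, n_epochs: int) -> list[int]:
    """Hypothetical stand-in for DepthScheduler: linear ramp from start to end, rounded."""
    depths = []
    for e in range(n_epochs):
        frac = min(e / epochs, 1.0)  # progress through the ramp, capped at 1
        depths.append(int(round(start + (end - start) * frac)))
    return depths


def freeze_epochs(depths: list[int]) -> list[int]:
    """Epochs whose depth differs from the previous epoch; parameters are frozen there."""
    return [e for e in range(1, len(depths)) if depths[e] != depths[e - 1]]


EPOCHS = 16
depths = linear_depth_schedule(start=1, end=3, epochs=EPOCHS - 1, n_epochs=EPOCHS)
print(depths)                 # [1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3]
print(freeze_epochs(depths))  # [4, 12]
```

Passing epochs=EPOCHS - 1 mirrors the DepthScheduler call in the training function, so the final epoch reaches the end depth exactly.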
—
How to Run
# From the project root
python -m examples.ex_training_with_callbacks_advanced
# Or directly
python examples/ex_training_with_callbacks_advanced.py
—
Relevant Code Snippets
        np.ndarray
            Current state vector.
        """
        if params:
            self.set_params(params)
        x = self._require_input().astype(float)
        s = x
        for _ in range(max(1, int(self.depth))):
            s = np.tanh(self.w * s + self.b)
        return self._validate_state(s)

    def project_future(self, s_t: np.ndarray, branches: int = 2) -> np.ndarray:
        """
        Generate future states with depth-dependent span.

        Returns
        -------
        np.ndarray
            Projected future states of shape (K, D).
        """
        s = self._validate_state(s_t)
        k = max(2, int(branches))
        span = max(self._span_floor, self._base_span / (1.0 + 0.3 * (self.depth - 1)))
        self._projector = LinearProjector(weight=1.0, bias=0.0, span=span)
        fut = self._projector.project(s, branches=k)
        return self._validate_branches(fut)


# ----------------------------------------------------------------------
# Model-building utilities
# ----------------------------------------------------------------------
def build_model_chain(D=3):
    """Build a three-node model chain with depth-aware backends."""
    cfg = BackendConfig(output_dim=D, seed=11)
    b0 = DepthAwareBackend(cfg, w=0.90, b=0.03, proj_span=0.22)
    b1 = DepthAwareBackend(cfg, w=0.97, b=0.02, proj_span=0.25)
    b2 = DepthAwareBackend(cfg, w=1.05, b=0.00, proj_span=0.30)
    pol = MeanPolicy()
    n0, n1, n2 = HCNode(b0, pol), HCNode(b1, pol), HCNode(b2, pol)
    model = HCModel([n0, n1, n2])
    return model, [n0, n1, n2], [b0, b1, b2]


def params_pack(backends):
    """Flatten parameters of all backends into a single dictionary."""
    packed = {}
    for i, be in enumerate(backends):
        for k, v in be.get_params().items():
            packed[f"b{i}_{k}"] = np.array(v, dtype=float)
    return packed


def params_unpack(backends, packed):
    """Distribute flat parameters back to each backend."""
    for i, be in enumerate(backends):
        sub = {}
        for k in ("w", "b"):
            key = f"b{i}_{k}"
            if key in packed:
                sub[k] = packed[key]
        be.set_params(sub)


# ----------------------------------------------------------------------
# Finite-difference gradients (central)
# ----------------------------------------------------------------------
def central_diff_grads(loss_fn, params, apply_params_fn, eps: float):
    """
    Compute central finite-difference gradients for better stability.

    Gradient ~= (f(x + eps) - f(x - eps)) / (2 * eps)

    """
    grads = {}
    base = {k: v.copy() for k, v in params.items()}

    def setp(p): apply_params_fn(p)
    setp(base)
    _ = loss_fn()

    for k, v in base.items():
        vp = {kk: vv.copy() for kk, vv in base.items()}
        vm = {kk: vv.copy() for kk, vv in base.items()}
        vp[k] = v + eps
        vm[k] = v - eps
        setp(vp); lp = loss_fn()
        setp(vm); lm = loss_fn()
        g = (lp - lm) / (2.0 * eps)
        grads[k] = np.array([g], dtype=float)

    setp(base)
    return grads


def grad_norm(grads: dict) -> float:
    """Compute L2 norm of gradients."""
    sq = 0.0
    for g in grads.values():
        val = float(np.asarray(g).reshape(()))
        sq += val * val
    return float(np.sqrt(sq))


def clip_grads_adaptive(grads: dict, depth_mean: float) -> tuple[dict, float, float]:
    """
    Adaptive gradient clipping based on mean depth.

    Returns
    -------
    tuple
        (clipped_gradients, norm_before, norm_after)
    """
    if depth_mean < 1.5:
        max_norm = 5e-2
    elif depth_mean < 2.5:
        max_norm = 7.5e-2
    else:
        max_norm = 1e-1

    n_before = grad_norm(grads)
    if n_before <= max_norm or n_before == 0.0:
        return grads, n_before, n_before

    scale = max_norm / n_before
    clipped = {k: np.array([float(np.asarray(v).reshape(())) * scale], dtype=float) for k, v in grads.items()}
    n_after = grad_norm(clipped)
    return clipped, n_before, n_after


# ----------------------------------------------------------------------
# Training procedure
# ----------------------------------------------------------------------
def advanced_training_with_freeze():
    """
    Perform advanced hyper-causal training with depth freeze and adaptive clipping.

    Returns
    -------
    dict
        Dictionary containing best loss snapshot and final metrics.
    """
    D, K, T = 3, 5, 48
    EPOCHS = 16
    BASE_LR = 5e-2
    BASE_EPS = 1e-3
    LOG_PATH = Path("runs/telemetry_stable.jsonl")
    SAVE_BEST = True
    BEST_PATH = Path("runs/best_params.json")

    model, nodes, backends = build_model_chain(D=D)

    # Data
    t = np.arange(T, dtype=float)
    x_seq = np.stack([
        0.30 * np.sin(0.35 * t + 0.00),
        0.20 * np.sin(0.35 * t + 0.70),
        0.10 * np.cos(0.35 * t + 0.30),
    ], axis=1)
    target_seq = np.zeros((T, D), dtype=float)

    # Losses
    loss_task = MSELoss()
    loss_cons = ConsistencyLoss(alpha=0.8, beta=1.2)
    loss_coh = CoherenceLoss(mode="variance")

    # Optimizer and callbacks
    params = params_pack(backends)
    opt = make_gradient_descent(lr=BASE_LR)
    state = opt.initialize(params)

    callbacks = CallbackList([
        TelemetryLogger(path=LOG_PATH, flush_interval=8),
        MemoryLogger(),
    ])
    depth_cb = DepthScheduler(target_attr="depth", start=1, end=3, epochs=EPOCHS - 1)

    def apply_params_fn(packed): params_unpack(backends, packed)

    def forward_and_losses():
        total_task = total_cons = total_coh = 0.0
        s_tm1 = None
                # ... (lines omitted) ...
                "freeze": True,
            })
            if best is None or det0["total"] < best["total"]:
                best = {"epoch": int(epoch), **det0}
                best_params = {k: float(np.asarray(v).reshape(())) for k, v in params.items()}
                bad_epochs = 0
            else:
                bad_epochs += 1
            print(f"[Epoch {epoch}] FREEZE depth={depths} total={det0['total']:.6f} lr_eff={lr_eff:.3e} eps_eff={eps_eff:.3e}")
            if bad_epochs > patience:
                print(f"Early stopping activated at epoch {epoch} (freeze). Best total={best['total']:.6f} (epoch {best['epoch']}).")
                break
            # Freeze epoch: skip the gradient step and move on to the next epoch.
            continue

        # Regular epoch: estimate gradients, clip them, and take one optimizer step.
        def loss_wrapper():
            l, _, _ = forward_and_losses()
            return l

        grads = central_diff_grads(loss_wrapper, params, apply_params_fn, eps=eps_eff)
        grads, gnorm_before, gnorm_after = clip_grads_adaptive(grads, depth_mean)
        params, state = opt.step(params, grads, state)
        apply_params_fn(params)

        total1, det1, _ = forward_and_losses()

        callbacks.on_epoch_end(epoch, {
            "epoch": int(epoch),
            "loss_before": det0,
            "loss_after": det1,
            "grad_norm_before": float(gnorm_before),
            "grad_norm_after": float(gnorm_after),
            "freeze": False,
        })

        if best is None or det1["total"] < best["total"]:
            best = {"epoch": int(epoch), **det1}
            best_params = {k: float(np.asarray(v).reshape(())) for k, v in params.items()}
            bad_epochs = 0
        else:
            bad_epochs += 1

        print(
            f"[Epoch {epoch}] total_before={det0['total']:.6f} total_after={det1['total']:.6f} "
            f"depth={depths} lr_eff={lr_eff:.3e} eps_eff={eps_eff:.3e} "
            f"||g||_before={gnorm_before:.3e} ||g||_after={gnorm_after:.3e}"
        )

        if bad_epochs > patience:
            print(f"Early stopping activated at epoch {epoch}. Best total={best['total']:.6f} (epoch {best['epoch']}).")
            break

    if SAVE_BEST and best_params is not None:
        BEST_PATH.parent.mkdir(parents=True, exist_ok=True)
        with BEST_PATH.open("w") as f:
            json.dump(best_params, f, indent=2)
        print(f"\nBest parameters saved at: {BEST_PATH.resolve()}")

    _, _, y_pred_seq = forward_and_losses()
    smape_val = smape_safe(target_seq[:, 0], y_pred_seq[:, 0])
    rmse_val = rmse(target_seq[:, 0], y_pred_seq[:, 0])
    over_val = overshoot(target_seq[:, 0], y_pred_seq[:, 0])
    rob_val = robustness(target_seq[:, 0], y_pred_seq[:, 0])

    print("\n=== Final metrics (channel 0) ===")
    print(f"SMAPE: {smape_val:.6f} %")
    print(f"RMSE: {rmse_val:.6f}")
    print(f"Overshoot: {over_val:.6f}")
    print(f"Robustness: {rob_val:.6f}")
    print("\nBest epoch snapshot:", {
        "epoch": int(best["epoch"]),
        "task": float(best["task"]),
        "cons": float(best["cons"]),
        "coh": float(best["coh"]),
        "total": float(best["total"]),
    })

    if LOG_PATH.exists():
        print(f"\nTelemetry JSONL → {LOG_PATH.resolve()}")

    return {
        "best": {
            "epoch": int(best["epoch"]),
            "task": float(best["task"]),
            "cons": float(best["cons"]),
            "coh": float(best["coh"]),
            "total": float(best["total"]),
        },
        "metrics": {
            "smape": smape_val,
            "rmse": rmse_val,
            "overshoot": over_val,
            "robustness": rob_val,
        },
    }


# ----------------------------------------------------------------------
# Entry point
# ----------------------------------------------------------------------
if __name__ == "__main__":
    out = advanced_training_with_freeze()
    print("\nSummary:")
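The packing helpers and the JSON checkpoint format can be exercised without the framework. This is a minimal sketch that assumes a hypothetical StubBackend in place of DepthAwareBackend, exposing only get_params/set_params for scalar w and b; params_pack mirrors the version above, params_unpack is a condensed equivalent, and reloading best_params.json this way is an illustration rather than part of the example script.

```python
import json
import tempfile
from pathlib import Path

import numpy as np


class StubBackend:
    """Hypothetical backend exposing only get_params/set_params for scalar w and b."""

    def __init__(self, w: float, b: float):
        self.w, self.b = w, b

    def get_params(self):
        return {"w": self.w, "b": self.b}

    def set_params(self, params):
        # Keep the current value for any key that is absent.
        self.w = float(params.get("w", self.w))
        self.b = float(params.get("b", self.b))


def params_pack(backends):
    packed = {}
    for i, be in enumerate(backends):
        for k, v in be.get_params().items():
            packed[f"b{i}_{k}"] = np.array(v, dtype=float)
    return packed


def params_unpack(backends, packed):
    for i, be in enumerate(backends):
        sub = {k: packed[f"b{i}_{k}"] for k in ("w", "b") if f"b{i}_{k}" in packed}
        be.set_params(sub)


backends = [StubBackend(0.90, 0.03), StubBackend(0.97, 0.02), StubBackend(1.05, 0.00)]
print(sorted(params_pack(backends)))  # ['b0_b', 'b0_w', 'b1_b', 'b1_w', 'b2_b', 'b2_w']

# Round-trip through the same {key: float} JSON layout the example writes to best_params.json.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "best_params.json"
    path.write_text(json.dumps({k: float(v) for k, v in params_pack(backends).items()}, indent=2))
    loaded = {k: np.array(v, dtype=float) for k, v in json.loads(path.read_text()).items()}

fresh = [StubBackend(0.0, 0.0) for _ in range(3)]
params_unpack(fresh, loaded)
print([(be.w, be.b) for be in fresh])  # [(0.9, 0.03), (0.97, 0.02), (1.05, 0.0)]
```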
—
Functional Explanation
This training routine extends the baseline model with depth adaptation and parameter freezing, giving a more controlled optimization process that keeps gradient magnitudes bounded. All components operate deterministically (the backend configuration fixes seed=11, and finite-difference gradients involve no sampling), so results are reproducible across runs.
Recursive Backend Dynamics
Each backend updates its state using depth-controlled recursion:
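\[S_t^{(d)} = \tanh(W^{(d)} S_{t-1} + b^{(d)})\]
Depth \(d\) determines the number of recursive evaluations per step. In the DepthAwareBackend snippet above, this amounts to applying the same scalar parameters w and b max(1, depth) times, starting from the backend's current input. Increasing depth raises representational capacity but requires finer gradient control: in the run shown under Exact Output, the pre-clipping gradient norm is roughly 0.2 at depth 1, 0.5 at depth 2, and 0.9 at depth 3.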
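A standalone sketch of the depth recursion, mirroring the loop in the DepthAwareBackend snippet with scalar w and b. The function name is hypothetical and the input vector is arbitrary; the backend's validation helpers are omitted.

```python
import numpy as np


def depth_recursion(x: np.ndarray, w: float, b: float, depth: int) -> np.ndarray:
    """Apply s <- tanh(w * s + b) max(1, depth) times, starting from s = x."""
    s = np.asarray(x, dtype=float)
    for _ in range(max(1, int(depth))):
        s = np.tanh(w * s + b)
    return s


x = np.array([0.30, 0.20, 0.10])
for d in (1, 2, 3):
    # Each extra level feeds the previous output back through the same map.
    print(d, depth_recursion(x, w=0.90, b=0.03, depth=d))
```

Because every level passes through tanh, the state stays inside (-1, 1) at any depth, which is what keeps deeper recursion numerically bounded.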
Future Projection and Span Scaling
Future states are generated through a linear projection mechanism with depth-dependent span:
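\[S_{t+1}^{(k)} = S_t + \Delta_d \cdot \mathcal{P}_k(S_t)\]
where the span \(\Delta_d\) decreases with increasing depth to keep perturbations bounded. In project_future it is computed as
\[\Delta_d = \max\left(\Delta_{\min},\ \frac{\Delta_0}{1 + 0.3(d - 1)}\right)\]
with \(\Delta_0\) the backend's base span and \(\Delta_{\min}\) its span floor. This keeps branch diversity consistent without destabilizing deeper configurations, and the projector always produces at least two branches (k = max(2, branches)).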
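The span computation in project_future can be checked in isolation. LinearProjector's internals are not shown, so the branch layout below (K evenly spaced offsets in [-1, 1], scaled by the span) is an assumption for illustration only, as are the floor value of 0.05 and both function names.

```python
import numpy as np


def depth_span(base_span: float, depth: int, span_floor: float) -> float:
    """Span rule from project_future: shrink with depth, never below the floor."""
    return max(span_floor, base_span / (1.0 + 0.3 * (depth - 1)))


def project_branches(s: np.ndarray, span: float, branches: int = 2) -> np.ndarray:
    """Hypothetical projector: K futures s + span * offset, offsets evenly spaced in [-1, 1]."""
    k = max(2, int(branches))
    offsets = np.linspace(-1.0, 1.0, k)[:, None]  # shape (K, 1)
    return s[None, :] + span * offsets             # broadcasts to shape (K, D)


s = np.array([0.1, -0.2, 0.05])
for d in (1, 2, 3):
    span = depth_span(base_span=0.22, depth=d, span_floor=0.05)
    fut = project_branches(s, span, branches=5)
    print(d, round(span, 4), fut.shape)
```

With base_span=0.22 (the proj_span of the first backend) the span falls from 0.22 at depth 1 to 0.1375 at depth 3, so the spread between branches narrows as recursion deepens.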
Composite Loss Function
The loss combines predictive, consistency, and coherence terms:
\[\mathcal{L}_{total} = \mathcal{L}_{task} + 0.5\,\mathcal{L}_{consistency} + 0.3\,\mathcal{L}_{coherence}\]
Each term regulates a specific property:
- Task: prediction accuracy.
- Consistency: smooth temporal evolution.
- Coherence: branch uniformity at the projection level.
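As a quick consistency check, the weighted sum reproduces the total reported for the best epoch in the output. The component values are copied from the best-epoch snapshot; the function name is illustrative.

```python
def composite_loss(task: float, cons: float, coh: float) -> float:
    """L_total = L_task + 0.5 * L_consistency + 0.3 * L_coherence."""
    return task + 0.5 * cons + 0.3 * coh


# Component values from the best-epoch snapshot (epoch 15).
task, cons, coh = 0.01889342623335241, 0.0010430885080397983, 0.01243963263570178
total = composite_loss(task, cons, coh)
print(f"{total:.6f}")  # 0.023147
```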
Finite-Difference Gradient Estimation
Gradients are computed numerically:
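\[g_i = \frac{\mathcal{L}(\theta_i + \epsilon) - \mathcal{L}(\theta_i - \epsilon)}{2\epsilon}\]
where only the \(i\)-th parameter is perturbed. This avoids any dependency on differentiable computation graphs. Compared with a one-sided difference, the central difference has truncation error of order \(\epsilon^2\) rather than \(\epsilon\), at the cost of two loss evaluations per parameter.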
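A toy comparison on sin(x), whose derivative is known exactly, illustrates the accuracy gap between central and forward differences at the same step used as BASE_EPS (1e-3). The test function and helper names are illustrative, not part of the example.

```python
import math


def central_diff(f, x: float, eps: float) -> float:
    """Central difference: (f(x + eps) - f(x - eps)) / (2 * eps), error O(eps**2)."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


def forward_diff(f, x: float, eps: float) -> float:
    """One-sided difference: (f(x + eps) - f(x)) / eps, error O(eps)."""
    return (f(x + eps) - f(x)) / eps


x, eps = 1.0, 1e-3
exact = math.cos(x)  # d/dx sin(x)
err_central = abs(central_diff(math.sin, x, eps) - exact)
err_forward = abs(forward_diff(math.sin, x, eps) - exact)
print(f"central error={err_central:.1e}  forward error={err_forward:.1e}")
```

At this step size the central estimate is several orders of magnitude closer to the true derivative, which is why the example pays for the extra loss evaluation.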
Adaptive Learning and Perturbation Scaling
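The learning rate and the finite-difference step both decay with depth and epoch index:
\[\eta_{\text{eff}} = \frac{\eta_0}{(1 + 0.5(d - 1))(1 + 0.5e)}, \quad \epsilon_{\text{eff}} = \frac{\epsilon_0}{(1 + 0.3(d - 1))(1 + 0.3e)}\]
where \(d\) is the current depth and \(e\) the current epoch. This provides temporal damping, ensuring smaller updates as the system stabilizes. For example, at epoch 4 with \(d = 2\), \(\eta_{\text{eff}} = 0.05 / (1.5 \cdot 3) \approx 1.111\times10^{-2}\), matching the logged lr_eff.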
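The decay formulas can be checked directly against the logged values. The defaults match BASE_LR and BASE_EPS in the training function; the function name is hypothetical.

```python
def effective_rates(depth: int, epoch: int, base_lr: float = 5e-2, base_eps: float = 1e-3):
    """Depth- and epoch-damped learning rate and finite-difference step."""
    lr_eff = base_lr / ((1 + 0.5 * (depth - 1)) * (1 + 0.5 * epoch))
    eps_eff = base_eps / ((1 + 0.3 * (depth - 1)) * (1 + 0.3 * epoch))
    return lr_eff, eps_eff


# (depth, epoch) pairs from the run: first epoch, first freeze, second freeze, last epoch.
for depth, epoch in [(1, 0), (2, 4), (3, 12), (3, 15)]:
    lr_eff, eps_eff = effective_rates(depth, epoch)
    print(f"[Epoch {epoch}] depth={depth} lr_eff={lr_eff:.3e} eps_eff={eps_eff:.3e}")
```

Each printed pair agrees with the corresponding lr_eff and eps_eff in the Exact Output log.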
Freeze Epochs
After each depth increase, one epoch executes without parameter updates:
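\[\theta_{e+1} = \theta_e \quad \text{if the depth changed at epoch } e\]
This step prevents transient gradient noise from destabilizing new recursion levels. Losses are still evaluated and logged during a freeze epoch, and the epoch counts toward best-loss tracking and the early-stopping patience. In the run shown under Exact Output, freezes occur at epochs 4 and 12, and the total loss drops at each transition (from 0.029977 to 0.026492, and from 0.024707 to 0.023982) without any parameter update.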
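A toy loop illustrating the freeze rule. A quadratic loss stands in for the hyper-causal losses and the depth list is supplied by hand; both are assumptions for illustration, and the function name is hypothetical. Parameters are left untouched in any epoch whose depth differs from the previous one.

```python
import numpy as np


def train_with_freeze(depths, lr=0.1):
    """Toy loop: gradient descent on ||theta||^2 that skips the update when depth changes."""
    theta = np.array([1.0, -2.0])
    history = []
    prev_depth = None
    for epoch, depth in enumerate(depths):
        freeze = prev_depth is not None and depth != prev_depth
        prev_depth = depth
        if not freeze:
            grad = 2.0 * theta         # exact gradient of the toy loss
            theta = theta - lr * grad  # ordinary gradient step
        history.append((epoch, depth, freeze, theta.copy()))
    return history


for epoch, depth, freeze, theta in train_with_freeze([1, 1, 2, 2, 3]):
    print(epoch, depth, "FREEZE" if freeze else "update", theta)
```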
Adaptive Gradient Clipping
The clipping threshold increases with mean depth \(\bar{d}\):
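\[\tau = \begin{cases} 5\times10^{-2}, & \bar{d} < 1.5 \\ 7.5\times10^{-2}, & 1.5 \le \bar{d} < 2.5 \\ 1\times10^{-1}, & \bar{d} \ge 2.5 \end{cases}\]
The gradients are then rescaled (a zero gradient is left unchanged):
\[g_i' = g_i \cdot \min\left(1, \frac{\tau}{\|g\|_2}\right)\]
This provides consistent control across all recursion depths. In the run shown under Exact Output, the raw gradient norm exceeds \(\tau\) in every update epoch, so clipping is always active and \(\|g'\|_2 = \tau\).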
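The same rule can be written with a single vector norm. This sketch uses np.linalg.norm on a flat gradient vector instead of the per-key dictionary used in the example; the six entries correspond to the six packed parameters (w and b for three backends), and the random gradient is synthetic, with only its norm chosen to match the epoch-0 log line.

```python
import numpy as np


def clip_threshold(depth_mean: float) -> float:
    """Piecewise threshold tau, as in clip_grads_adaptive."""
    if depth_mean < 1.5:
        return 5e-2
    if depth_mean < 2.5:
        return 7.5e-2
    return 1e-1


def clip_by_norm(g: np.ndarray, depth_mean: float) -> np.ndarray:
    """Rescale g by min(1, tau / ||g||_2); a zero gradient is returned unchanged."""
    norm = float(np.linalg.norm(g))
    tau = clip_threshold(depth_mean)
    if norm == 0.0 or norm <= tau:
        return g
    return g * (tau / norm)


rng = np.random.default_rng(0)
g = rng.normal(size=6)
g *= 0.1971 / np.linalg.norm(g)  # same norm as the epoch-0 gradient in the log
print(np.linalg.norm(clip_by_norm(g, depth_mean=1.0)))  # ≈ 0.05
```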
Metric Evaluation
After training, four metrics summarize performance:
SMAPE: symmetric mean absolute percentage error.
RMSE: root mean square error.
Overshoot: excess deviation in prediction amplitude.
Robustness: correlation-based stability ratio.
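All metrics are computed on the first output channel (index 0). Because the target sequence is all zeros, SMAPE degenerates in this example: it no longer reflects the size of the error, and the reported 100% only indicates that the predictions are nonzero.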
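The SMAPE saturation is easy to reproduce. The exact definition of smape_safe is not shown in this example, so the bounded variant below (denominator |y| + |ŷ|, ranging from 0 to 100%) is an assumption chosen because it is consistent with the reported 100%; both function names are hypothetical.

```python
import numpy as np


def smape_bounded(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """SMAPE in percent with denominator |y| + |y_hat| (range 0..100); 0/0 terms count as 0."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    num = np.abs(y_pred - y_true)
    den = np.abs(y_true) + np.abs(y_pred)
    ratio = np.divide(num, den, out=np.zeros_like(num), where=den > 0)
    return float(100.0 * ratio.mean())


def rmse_simple(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Root mean square error."""
    diff = np.asarray(y_pred, dtype=float) - np.asarray(y_true, dtype=float)
    return float(np.sqrt(np.mean(diff ** 2)))


target = np.zeros(4)                      # all-zero target, as in the example
pred = np.array([0.2, -0.1, 0.15, -0.3])  # any nonzero predictions
print(smape_bounded(target, pred))        # 100.0 regardless of prediction size
print(rmse_simple(target, pred))
```

RMSE still distinguishes small from large predictions here, which makes it the more informative of the two for a zero target.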
—
Exact Output
[Epoch 0] total_before=0.031141 total_after=0.030663 depth=[1, 1, 1] lr_eff=5.000e-02 eps_eff=1.000e-03 ||g||_before=1.971e-01 ||g||_after=5.000e-02
[Epoch 1] total_before=0.030663 total_after=0.030362 depth=[1, 1, 1] lr_eff=3.333e-02 eps_eff=7.692e-04 ||g||_before=1.848e-01 ||g||_after=5.000e-02
[Epoch 2] total_before=0.030362 total_after=0.030145 depth=[1, 1, 1] lr_eff=2.500e-02 eps_eff=6.250e-04 ||g||_before=1.767e-01 ||g||_after=5.000e-02
[Epoch 3] total_before=0.030145 total_after=0.029977 depth=[1, 1, 1] lr_eff=2.000e-02 eps_eff=5.263e-04 ||g||_before=1.707e-01 ||g||_after=5.000e-02
[Epoch 4] FREEZE depth=[2, 2, 2] total=0.026492 lr_eff=1.111e-02 eps_eff=3.497e-04
[Epoch 5] total_before=0.026492 total_after=0.026122 depth=[2, 2, 2] lr_eff=9.524e-03 eps_eff=3.077e-04 ||g||_before=5.258e-01 ||g||_after=7.500e-02
[Epoch 6] total_before=0.026122 total_after=0.025806 depth=[2, 2, 2] lr_eff=8.333e-03 eps_eff=2.747e-04 ||g||_before=5.114e-01 ||g||_after=7.500e-02
[Epoch 7] total_before=0.025806 total_after=0.025532 depth=[2, 2, 2] lr_eff=7.407e-03 eps_eff=2.481e-04 ||g||_before=4.988e-01 ||g||_after=7.500e-02
[Epoch 8] total_before=0.025532 total_after=0.025291 depth=[2, 2, 2] lr_eff=6.667e-03 eps_eff=2.262e-04 ||g||_before=4.876e-01 ||g||_after=7.500e-02
[Epoch 9] total_before=0.025291 total_after=0.025076 depth=[2, 2, 2] lr_eff=6.061e-03 eps_eff=2.079e-04 ||g||_before=4.775e-01 ||g||_after=7.500e-02
[Epoch 10] total_before=0.025076 total_after=0.024883 depth=[2, 2, 2] lr_eff=5.556e-03 eps_eff=1.923e-04 ||g||_before=4.683e-01 ||g||_after=7.500e-02
[Epoch 11] total_before=0.024883 total_after=0.024707 depth=[2, 2, 2] lr_eff=5.128e-03 eps_eff=1.789e-04 ||g||_before=4.599e-01 ||g||_after=7.500e-02
[Epoch 12] FREEZE depth=[3, 3, 3] total=0.023982 lr_eff=3.571e-03 eps_eff=1.359e-04
[Epoch 13] total_before=0.023982 total_after=0.023682 depth=[3, 3, 3] lr_eff=3.333e-03 eps_eff=1.276e-04 ||g||_before=9.096e-01 ||g||_after=1.000e-01
[Epoch 14] total_before=0.023682 total_after=0.023404 depth=[3, 3, 3] lr_eff=3.125e-03 eps_eff=1.202e-04 ||g||_before=8.949e-01 ||g||_after=1.000e-01
[Epoch 15] total_before=0.023404 total_after=0.023147 depth=[3, 3, 3] lr_eff=2.941e-03 eps_eff=1.136e-04 ||g||_before=8.810e-01 ||g||_after=1.000e-01
Best parameters saved at: runs/best_params.json
=== Final metrics (channel 0) ===
SMAPE: 100.000000 %
RMSE: 0.167734
Overshoot: 0.000000
Robustness: 0.972635
Best epoch snapshot: {'epoch': 15, 'task': 0.01889342623335241, 'cons': 0.0010430885080397983, 'coh': 0.01243963263570178, 'total': 0.023146860278082843}
Telemetry JSONL → runs/telemetry_stable.jsonl
Summary:
{
"best": {
"epoch": 15,
"task": 0.01889342623335241,
"cons": 0.0010430885080397983,
"coh": 0.01243963263570178,
"total": 0.023146860278082843
},
"metrics": {
"smape": 100.0,
"rmse": 0.16773380360977846,
"overshoot": 0.0,
"robustness": 0.9726352677136917
}
}