Advanced Training with Callbacks

Introduction

This example presents an advanced configuration of the hyper-causal training framework. It extends the previous demo with external depth scheduling, freeze epochs, and adaptive gradient clipping whose threshold depends on recursion depth. The design emphasizes deterministic, depth-aware optimization. In the run shown under Exact Output, the total loss decreases at every epoch, including across both depth transitions.

Key components include:

  1. External DepthScheduler – adjusts recursion depth independently of the logger.

  2. Freeze epochs – skip parameter updates for one epoch after each depth transition so the new configuration can settle.

  3. Adaptive gradient clipping – scales gradient bounds dynamically with mean recursion depth.

  4. Depth- and epoch-dependent decay – scales both the learning rate and the finite-difference step by recursion depth and epoch index.

  5. Parameter checkpointing – stores best-performing parameters in JSON format.

  6. Robust metric evaluation – computes SMAPE, RMSE, Overshoot, and Robustness.
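
Component 5 writes a flat JSON object that maps packed parameter names to floats. Below is a minimal sketch of the save/load round trip. It assumes each backend contributes w and b, so the keys look like b0_w and b0_b, which is what params_pack would produce under that assumption. save_best and load_best are hypothetical helpers, not part of the framework.

```python
import json
import tempfile
from pathlib import Path

import numpy as np


def save_best(params: dict, path: Path) -> None:
    """Hypothetical helper: store each scalar parameter as a plain float in JSON."""
    flat = {k: float(np.asarray(v).reshape(())) for k, v in params.items()}
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w") as f:
        json.dump(flat, f, indent=2)


def load_best(path: Path) -> dict:
    """Hypothetical helper: reload the floats as 0-d float arrays."""
    with path.open() as f:
        flat = json.load(f)
    return {k: np.array(v, dtype=float) for k, v in flat.items()}


# Initial w/b values of the first two backends in build_model_chain (assumed keys).
params = {"b0_w": np.array(0.90), "b0_b": np.array(0.03),
          "b1_w": np.array(0.97), "b1_b": np.array(0.02)}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "runs" / "best_params.json"
    save_best(params, path)
    restored = load_best(path)
```

JSON stores Python floats with enough digits to round-trip exactly, so the reloaded values should compare equal to the saved ones.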

General Flow Structure

The training loop optimizes a chain of three backends, each combining a recursive transformation with a linear projection. Depth increases are scheduled externally. Each increase triggers a freeze epoch in which the loss is evaluated but the parameters stay fixed, which gives the new recursion depth one epoch to settle before updates resume.

  • DepthAwareBackend: recursive causal unit using \(S_t^{(d)} = \tanh(W^{(d)}S_{t-1} + b^{(d)})\).

  • Projection step: generates K candidate future states whose span shrinks as depth increases, down to a fixed floor.

  • Central finite-difference gradients: numerical gradient estimates with truncation error of order \(\epsilon^2\).

  • Adaptive learning control: combines recursion-based scaling with epoch decay.

  • Clipping: applied adaptively based on the mean depth to limit the L2 norm of the gradient.

How to Run

# From the project root
python -m examples.ex_training_with_callbacks_advanced

# Or directly
python examples/ex_training_with_callbacks_advanced.py

Relevant Code Snippets

Definition of the DepthAwareBackend class and gradient control utilities.
        np.ndarray
            Current state vector.
        """
        if params:
            self.set_params(params)
        x = self._require_input().astype(float)
        s = x
        for _ in range(max(1, int(self.depth))):
            s = np.tanh(self.w * s + self.b)
        return self._validate_state(s)

    def project_future(self, s_t: np.ndarray, branches: int = 2) -> np.ndarray:
        """
        Generate future states with depth-dependent span.

        Returns
        -------
        np.ndarray
            Projected future states of shape (K, D).
        """
        s = self._validate_state(s_t)
        k = max(2, int(branches))
        span = max(self._span_floor, self._base_span / (1.0 + 0.3 * (self.depth - 1)))
        self._projector = LinearProjector(weight=1.0, bias=0.0, span=span)
        fut = self._projector.project(s, branches=k)
        return self._validate_branches(fut)


# ----------------------------------------------------------------------
# Model-building utilities
# ----------------------------------------------------------------------
def build_model_chain(D=3):
    """Build a three-node model chain with depth-aware backends."""
    cfg = BackendConfig(output_dim=D, seed=11)
    b0 = DepthAwareBackend(cfg, w=0.90, b=0.03, proj_span=0.22)
    b1 = DepthAwareBackend(cfg, w=0.97, b=0.02, proj_span=0.25)
    b2 = DepthAwareBackend(cfg, w=1.05, b=0.00, proj_span=0.30)
    pol = MeanPolicy()
    n0, n1, n2 = HCNode(b0, pol), HCNode(b1, pol), HCNode(b2, pol)
    model = HCModel([n0, n1, n2])
    return model, [n0, n1, n2], [b0, b1, b2]


def params_pack(backends):
    """Flatten parameters of all backends into a single dictionary."""
    packed = {}
    for i, be in enumerate(backends):
        for k, v in be.get_params().items():
            packed[f"b{i}_{k}"] = np.array(v, dtype=float)
    return packed


def params_unpack(backends, packed):
    """Distribute flat parameters back to each backend."""
    for i, be in enumerate(backends):
        sub = {}
        for k in ("w", "b"):
            key = f"b{i}_{k}"
            if key in packed:
                sub[k] = packed[key]
        be.set_params(sub)


# ----------------------------------------------------------------------
# Finite-difference gradients (central)
# ----------------------------------------------------------------------
def central_diff_grads(loss_fn, params, apply_params_fn, eps: float):
    """
    Compute central finite-difference gradients for better stability.

    Gradient ~= (f(x + eps) - f(x - eps)) / (2 * eps)

    """
    grads = {}
    base = {k: v.copy() for k, v in params.items()}

    def setp(p): apply_params_fn(p)
    setp(base)
    _ = loss_fn()

    for k, v in base.items():
        vp = {kk: vv.copy() for kk, vv in base.items()}
        vm = {kk: vv.copy() for kk, vv in base.items()}
        vp[k] = v + eps
        vm[k] = v - eps
        setp(vp); lp = loss_fn()
        setp(vm); lm = loss_fn()
        g = (lp - lm) / (2.0 * eps)
        grads[k] = np.array([g], dtype=float)

    setp(base)
    return grads


def grad_norm(grads: dict) -> float:
    """Compute L2 norm of gradients."""
    sq = 0.0
    for g in grads.values():
        val = float(np.asarray(g).reshape(()))
        sq += val * val
    return float(np.sqrt(sq))


def clip_grads_adaptive(grads: dict, depth_mean: float) -> tuple[dict, float, float]:
    """
    Adaptive gradient clipping based on mean depth.

    Returns
    -------
    tuple
        (clipped_gradients, norm_before, norm_after)
    """
    if depth_mean < 1.5:
        max_norm = 5e-2
    elif depth_mean < 2.5:
        max_norm = 7.5e-2
    else:
        max_norm = 1e-1

    n_before = grad_norm(grads)
    if n_before <= max_norm or n_before == 0.0:
        return grads, n_before, n_before

    scale = max_norm / n_before
    clipped = {k: np.array([float(np.asarray(v).reshape(())) * scale], dtype=float) for k, v in grads.items()}
    n_after = grad_norm(clipped)
    return clipped, n_before, n_after


# ----------------------------------------------------------------------
# Training procedure
# ----------------------------------------------------------------------
def advanced_training_with_freeze():
    """
    Perform advanced hyper-causal training with depth freeze and adaptive clipping.

    Returns
    -------
    dict
        Dictionary containing best loss snapshot and final metrics.
    """
    D, K, T = 3, 5, 48
    EPOCHS = 16
    BASE_LR = 5e-2
    BASE_EPS = 1e-3
    LOG_PATH = Path("runs/telemetry_stable.jsonl")
    SAVE_BEST = True
    BEST_PATH = Path("runs/best_params.json")

    model, nodes, backends = build_model_chain(D=D)

    # Data
    t = np.arange(T, dtype=float)
    x_seq = np.stack([
        0.30 * np.sin(0.35 * t + 0.00),
        0.20 * np.sin(0.35 * t + 0.70),
        0.10 * np.cos(0.35 * t + 0.30),
    ], axis=1)
    target_seq = np.zeros((T, D), dtype=float)

    # Losses
    loss_task = MSELoss()
    loss_cons = ConsistencyLoss(alpha=0.8, beta=1.2)
    loss_coh = CoherenceLoss(mode="variance")

    # Optimizer and callbacks
    params = params_pack(backends)
    opt = make_gradient_descent(lr=BASE_LR)
    state = opt.initialize(params)

    callbacks = CallbackList([
        TelemetryLogger(path=LOG_PATH, flush_interval=8),
        MemoryLogger(),
    ])
    depth_cb = DepthScheduler(target_attr="depth", start=1, end=3, epochs=EPOCHS - 1)

    def apply_params_fn(packed): params_unpack(backends, packed)

    def forward_and_losses():
        total_task = total_cons = total_coh = 0.0
        s_tm1 = None
Main training procedure advanced_training_with_freeze() implementing freeze logic and adaptive clipping.
                "freeze": True,
            })
            if best is None or det0["total"] < best["total"]:
                best = {"epoch": int(epoch), **det0}
                best_params = {k: float(np.asarray(v).reshape(())) for k, v in params.items()}
                bad_epochs = 0
            else:
                bad_epochs += 1
            print(f"[Epoch {epoch}] FREEZE depth={depths} total={det0['total']:.6f} lr_eff={lr_eff:.3e} eps_eff={eps_eff:.3e}")
            if bad_epochs > patience:
                print(f"Early stopping activated at epoch {epoch} (freeze). Best total={best['total']:.6f} (epoch {best['epoch']}).")
                break
            continue

        def loss_wrapper():
            l, _, _ = forward_and_losses()
            return l

        grads = central_diff_grads(loss_wrapper, params, apply_params_fn, eps=eps_eff)
        grads, gnorm_before, gnorm_after = clip_grads_adaptive(grads, depth_mean)
        params, state = opt.step(params, grads, state)
        apply_params_fn(params)

        total1, det1, _ = forward_and_losses()

        callbacks.on_epoch_end(epoch, {
            "epoch": int(epoch),
            "loss_before": det0,
            "loss_after": det1,
            "grad_norm_before": float(gnorm_before),
            "grad_norm_after": float(gnorm_after),
            "freeze": False,
        })

        if best is None or det1["total"] < best["total"]:
            best = {"epoch": int(epoch), **det1}
            best_params = {k: float(np.asarray(v).reshape(())) for k, v in params.items()}
            bad_epochs = 0
        else:
            bad_epochs += 1

        print(
            f"[Epoch {epoch}] total_before={det0['total']:.6f} total_after={det1['total']:.6f} "
            f"depth={depths} lr_eff={lr_eff:.3e} eps_eff={eps_eff:.3e} "
            f"||g||_before={gnorm_before:.3e} ||g||_after={gnorm_after:.3e}"
        )

        if bad_epochs > patience:
            print(f"Early stopping activated at epoch {epoch}. Best total={best['total']:.6f} (epoch {best['epoch']}).")
            break

    if SAVE_BEST and best_params is not None:
        BEST_PATH.parent.mkdir(parents=True, exist_ok=True)
        with BEST_PATH.open("w") as f:
            json.dump(best_params, f, indent=2)
        print(f"\nBest parameters saved at: {BEST_PATH.resolve()}")

    _, _, y_pred_seq = forward_and_losses()
    smape_val = smape_safe(target_seq[:, 0], y_pred_seq[:, 0])
    rmse_val = rmse(target_seq[:, 0], y_pred_seq[:, 0])
    over_val = overshoot(target_seq[:, 0], y_pred_seq[:, 0])
    rob_val = robustness(target_seq[:, 0], y_pred_seq[:, 0])

    print("\n=== Final metrics (channel 0) ===")
    print(f"SMAPE:      {smape_val:.6f} %")
    print(f"RMSE:       {rmse_val:.6f}")
    print(f"Overshoot:  {over_val:.6f}")
    print(f"Robustness: {rob_val:.6f}")
    print("\nBest epoch snapshot:", {
        "epoch": int(best["epoch"]),
        "task": float(best["task"]),
        "cons": float(best["cons"]),
        "coh":  float(best["coh"]),
        "total": float(best["total"]),
    })

    if LOG_PATH.exists():
        print(f"\nTelemetry JSONL → {LOG_PATH.resolve()}")

    return {
        "best": {
            "epoch": int(best["epoch"]),
            "task": float(best["task"]),
            "cons": float(best["cons"]),
            "coh":  float(best["coh"]),
            "total": float(best["total"]),
        },
        "metrics": {
            "smape": smape_val,
            "rmse": rmse_val,
            "overshoot": over_val,
            "robustness": rob_val
        }
    }


# ----------------------------------------------------------------------
# Entry point
# ----------------------------------------------------------------------
if __name__ == "__main__":
    out = advanced_training_with_freeze()
    print("\nSummary:")

Functional Explanation

This routine extends the baseline model with depth adaptation and parameter freezing, which keeps updates bounded as recursion depth grows. All components are deterministic, and the backend configuration uses a fixed seed (seed=11), so repeated runs reproduce the same results.

  1. Recursive Backend Dynamics

    Each backend updates its state using depth-controlled recursion:

    \[S_t^{(d)} = \tanh(W^{(d)} S_{t-1} + b^{(d)})\]

    Depth \(d\) sets how many times this update is applied per step. The backend loops max(1, int(depth)) times, so the effective depth is never below one. Increasing depth raises representational capacity but requires finer gradient control.
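
    The following is a minimal NumPy sketch of the depth loop. It mirrors DepthAwareBackend with scalar, element-wise w and b and starts from the given input vector, as the backend's loop does. The function name depth_recursion is hypothetical.

```python
import numpy as np


def depth_recursion(x, w: float, b: float, depth: int) -> np.ndarray:
    """Apply s <- tanh(w * s + b) element-wise, max(1, depth) times."""
    s = np.asarray(x, dtype=float)
    for _ in range(max(1, int(depth))):
        s = np.tanh(w * s + b)
    return s


x = np.array([0.30, 0.20, 0.10])
s1 = depth_recursion(x, w=0.90, b=0.03, depth=1)  # one application
s3 = depth_recursion(x, w=0.90, b=0.03, depth=3)  # three nested applications
```

    Because every application passes through tanh, the state stays inside (-1, 1) at any depth.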

  2. Future Projection and Span Scaling

    Future states are generated through a linear projection mechanism with depth-dependent span:

    \[S_{t+1}^{(k)} = S_t + \Delta_d \cdot \mathcal{P}_k(S_t)\]

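    where \(\Delta_d = \max\left(\Delta_{\min}, \frac{\Delta_0}{1 + 0.3(d - 1)}\right)\), with \(\Delta_0\) the backend's base span and \(\Delta_{\min}\) a fixed floor (_base_span and _span_floor in project_future). The span shrinks as depth increases, which keeps branch perturbations bounded. The floor prevents the branches from collapsing onto a single state.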
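
    The span rule can be checked in isolation. This sketch copies the expression from project_future. It assumes that proj_span=0.22 (backend b0) is the base span. The floor of 0.05 is a hypothetical value, because _span_floor is not shown.

```python
def projection_span(base_span: float, depth: int, span_floor: float) -> float:
    """Span used by project_future: max(floor, base / (1 + 0.3 * (depth - 1)))."""
    return max(span_floor, base_span / (1.0 + 0.3 * (depth - 1)))


# base_span=0.22 is b0's proj_span; span_floor=0.05 is a hypothetical value.
for depth in (1, 2, 3, 20):
    print(depth, round(projection_span(0.22, depth, span_floor=0.05), 4))
```

    Depths 1 to 3 give spans of about 0.22, 0.169, and 0.1375. At depth 20 the floor takes over.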

  3. Composite Loss Function

    The loss combines predictive, consistency, and coherence terms:

    \[\mathcal{L}_{total} = \mathcal{L}_{task} + 0.5\,\mathcal{L}_{consistency} + 0.3\,\mathcal{L}_{coherence}\]

    Each term regulates a specific property:

    • Task: prediction accuracy.

    • Consistency: smooth temporal evolution.

    • Coherence: branch uniformity at projection level.
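
    As a worked check, the weighted sum reproduces the total in the best-epoch snapshot under Exact Output from its three components. composite_loss is a hypothetical helper; the weights are the 0.5 and 0.3 from the formula above.

```python
def composite_loss(task: float, cons: float, coh: float,
                   w_cons: float = 0.5, w_coh: float = 0.3) -> float:
    """L_total = L_task + w_cons * L_consistency + w_coh * L_coherence."""
    return task + w_cons * cons + w_coh * coh


# Component values from the best-epoch snapshot (epoch 15) in the run log.
total = composite_loss(task=0.01889342623335241,
                       cons=0.0010430885080397983,
                       coh=0.01243963263570178)
print(f"{total:.6f}")  # prints 0.023147, the reported total
```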

  4. Finite-Difference Gradient Estimation

    Gradients are computed numerically:

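    \[g_i = \frac{\mathcal{L}(\theta + \epsilon u_i) - \mathcal{L}(\theta - \epsilon u_i)}{2\epsilon}\]

    where \(u_i\) is the unit vector for parameter \(i\), so each parameter is perturbed separately. This requires no differentiable computation graph. The central difference has truncation error \(O(\epsilon^2)\), compared with \(O(\epsilon)\) for a one-sided difference, at the cost of two loss evaluations per scalar parameter.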
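
    The following applies the central-difference formula to a function with a known gradient: the sum of sines, whose gradient is cos θ. It uses a flat NumPy vector instead of the dictionary layout of central_diff_grads and compares the result against a one-sided difference. The names central_diff and forward_diff are hypothetical.

```python
import numpy as np


def central_diff(f, theta, eps: float) -> np.ndarray:
    """g_i = (f(theta + eps*u_i) - f(theta - eps*u_i)) / (2*eps), one coordinate at a time."""
    theta = np.asarray(theta, dtype=float)
    grad = np.zeros_like(theta)
    for i in range(theta.size):
        u_i = np.zeros_like(theta)
        u_i.flat[i] = 1.0
        grad.flat[i] = (f(theta + eps * u_i) - f(theta - eps * u_i)) / (2.0 * eps)
    return grad


def forward_diff(f, theta, eps: float) -> np.ndarray:
    """One-sided estimate (f(theta + eps*u_i) - f(theta)) / eps, for comparison."""
    theta = np.asarray(theta, dtype=float)
    grad = np.zeros_like(theta)
    for i in range(theta.size):
        u_i = np.zeros_like(theta)
        u_i.flat[i] = 1.0
        grad.flat[i] = (f(theta + eps * u_i) - f(theta)) / eps
    return grad


def f(theta: np.ndarray) -> float:
    """Test function with known gradient cos(theta)."""
    return float(np.sum(np.sin(theta)))


theta = np.array([0.90, 0.03])
exact = np.cos(theta)
err_central = float(np.max(np.abs(central_diff(f, theta, 1e-3) - exact)))
err_forward = float(np.max(np.abs(forward_diff(f, theta, 1e-3) - exact)))
print(f"central error {err_central:.1e}, forward error {err_forward:.1e}")
```

    With ε = 1e-3, the central estimate should be accurate to roughly 1e-7 here, while the one-sided estimate is off by a few times 1e-4.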

  5. Adaptive Learning and Perturbation Scaling

    The learning rate and the finite-difference step both decay with depth and epoch index:

    \[\eta_{\text{eff}} = \frac{\eta_0}{(1 + 0.5(d - 1))(1 + 0.5e)}, \quad \epsilon_{\text{eff}} = \frac{\epsilon_0}{(1 + 0.3(d - 1))(1 + 0.3e)}\]

    where \(\eta_0\) = BASE_LR = 5e-2, \(\epsilon_0\) = BASE_EPS = 1e-3, \(d\) is the current depth, and \(e\) is the zero-based epoch index. Both quantities shrink as training progresses, so updates become smaller as the system stabilizes.
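
    The two schedules can be evaluated directly. effective_lr and effective_eps are hypothetical helpers that implement the formulas above with the example's base values. The printed values match the lr_eff and eps_eff entries for epochs 0, 4, and 15 in the training log.

```python
def effective_lr(eta0: float, depth: int, epoch: int) -> float:
    """eta_eff = eta0 / ((1 + 0.5 (d - 1)) (1 + 0.5 e))."""
    return eta0 / ((1.0 + 0.5 * (depth - 1)) * (1.0 + 0.5 * epoch))


def effective_eps(eps0: float, depth: int, epoch: int) -> float:
    """eps_eff = eps0 / ((1 + 0.3 (d - 1)) (1 + 0.3 e))."""
    return eps0 / ((1.0 + 0.3 * (depth - 1)) * (1.0 + 0.3 * epoch))


# (depth, epoch) pairs taken from the training log; BASE_LR=5e-2, BASE_EPS=1e-3.
for depth, epoch in [(1, 0), (2, 4), (3, 15)]:
    print(f"epoch {epoch}: lr_eff={effective_lr(5e-2, depth, epoch):.3e} "
          f"eps_eff={effective_eps(1e-3, depth, epoch):.3e}")
# epoch 0: lr_eff=5.000e-02 eps_eff=1.000e-03
# epoch 4: lr_eff=1.111e-02 eps_eff=3.497e-04
# epoch 15: lr_eff=2.941e-03 eps_eff=1.136e-04
```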

  6. Freeze Epochs

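    After each depth increase, one epoch runs without a parameter update:

    \[\theta_{e+1} = \theta_e \quad \text{if the depth changed at epoch } e\]

    The loss is still evaluated during a freeze epoch. It is logged with freeze=True and counts toward best-snapshot tracking and early stopping. Skipping the update keeps transient gradient noise at the new recursion level from destabilizing the parameters. In the run below, freeze epochs occur at epochs 4 and 12.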
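
    The DepthScheduler implementation is not shown in the excerpt. Purely as an illustration, a linear schedule rounded to the nearest integer, round(start + (end - start) · e / epochs), can be combined with the arguments passed to DepthScheduler above (start=1, end=3, epochs=15). This combination reproduces the depth transitions in the log at epochs 4 and 12. The rounding rule is an assumption, and linear_depth and freeze_epochs are hypothetical names.

```python
def linear_depth(epoch: int, start: int = 1, end: int = 3, epochs: int = 15) -> int:
    """Hypothetical schedule: linear interpolation from start to end, rounded."""
    frac = min(max(epoch / epochs, 0.0), 1.0)
    return int(round(start + (end - start) * frac))


def freeze_epochs(n_epochs: int, **kw) -> list[int]:
    """Epochs whose depth differs from the previous epoch's (no update there)."""
    depths = [linear_depth(e, **kw) for e in range(n_epochs)]
    return [e for e in range(1, n_epochs) if depths[e] != depths[e - 1]]


schedule = [linear_depth(e) for e in range(16)]
print(schedule)          # [1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3]
print(freeze_epochs(16))  # [4, 12]
```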

  7. Adaptive Gradient Clipping

    The clipping threshold increases with mean depth \(\bar{d}\):

    \[\begin{split}\tau = \begin{cases} 5\times10^{-2}, & \bar{d} < 1.5 \\ 7.5\times10^{-2}, & 1.5 \le \bar{d} < 2.5 \\ 1\times10^{-1}, & \bar{d} \ge 2.5 \end{cases}\end{split}\]

    The gradients are then rescaled:

    \[g_i' = g_i \cdot \min\left(1, \frac{\tau}{\|g\|_2}\right)\]

    providing consistent control across all recursion depths.
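
    Below is a vectorized sketch of the same rule. It operates on a flat NumPy gradient vector rather than the per-parameter dictionary used by clip_grads_adaptive. The function names and the example gradient values are illustrative only.

```python
import numpy as np


def clip_threshold(depth_mean: float) -> float:
    """Piecewise threshold tau: 5e-2, 7.5e-2 or 1e-1 depending on mean depth."""
    if depth_mean < 1.5:
        return 5e-2
    if depth_mean < 2.5:
        return 7.5e-2
    return 1e-1


def clip_by_norm(g, depth_mean: float) -> np.ndarray:
    """Return g * min(1, tau / ||g||_2); a zero vector is returned unchanged."""
    g = np.asarray(g, dtype=float)
    norm = float(np.linalg.norm(g))
    if norm == 0.0:
        return g
    return g * min(1.0, clip_threshold(depth_mean) / norm)


g = np.array([0.12, -0.09, 0.10, 0.05, -0.04, 0.02])  # illustrative values, ||g|| ~ 0.192
for d_bar in (1.0, 2.0, 3.0):
    clipped = clip_by_norm(g, d_bar)
    print(d_bar, round(float(np.linalg.norm(clipped)), 6))
```

    Since ||g|| exceeds every threshold here, the clipped norm equals τ for each depth band, as the ||g||_after column in the log does. A gradient already below τ passes through unchanged.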

  8. Metric Evaluation

    After training, four metrics summarize performance:

    • SMAPE: symmetric mean absolute percentage error.

    • RMSE: root mean square error.

    • Overshoot: excess deviation in prediction amplitude.

    • Robustness: correlation-based stability ratio.

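    All metrics are computed on the first output channel only. Because target_seq is identically zero in this example, SMAPE takes the same value for any nonzero prediction (100% in the output below) and carries no information about fit quality. RMSE reduces to the root-mean-square magnitude of the predictions.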
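
smape_safe and rmse are not shown in the excerpt. The sketch below assumes the SMAPE variant with denominator \(|y| + |\hat{y}|\), which is bounded by 100% and matches the 100% reported for an all-zero target. It also assumes that terms where both values are zero count as zero error. Other SMAPE conventions divide by \((|y| + |\hat{y}|)/2\) and top out at 200%. The names smape_pct and rmse_np are hypothetical.

```python
import numpy as np


def smape_pct(y_true, y_pred, eps: float = 1e-12) -> float:
    """SMAPE in percent: 100 * mean(|y - y_hat| / (|y| + |y_hat|))."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    num = np.abs(y_true - y_pred)
    den = np.abs(y_true) + np.abs(y_pred)
    # Terms where both values are (near) zero contribute zero error.
    ratio = np.divide(num, den, out=np.zeros_like(num), where=den > eps)
    return float(100.0 * ratio.mean())


def rmse_np(y_true, y_pred) -> float:
    """Root mean square error."""
    diff = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean(diff ** 2)))


y_true = np.zeros(4)                       # all-zero target, like target_seq[:, 0]
y_pred = np.array([0.1, -0.2, 0.3, 0.05])
print(smape_pct(y_true, y_pred))           # 100.0: every term equals 1
print(rmse_np(y_true, y_pred))             # RMS magnitude of y_pred
```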

Exact Output

[Epoch 0] total_before=0.031141 total_after=0.030663 depth=[1, 1, 1] lr_eff=5.000e-02 eps_eff=1.000e-03 ||g||_before=1.971e-01 ||g||_after=5.000e-02
[Epoch 1] total_before=0.030663 total_after=0.030362 depth=[1, 1, 1] lr_eff=3.333e-02 eps_eff=7.692e-04 ||g||_before=1.848e-01 ||g||_after=5.000e-02
[Epoch 2] total_before=0.030362 total_after=0.030145 depth=[1, 1, 1] lr_eff=2.500e-02 eps_eff=6.250e-04 ||g||_before=1.767e-01 ||g||_after=5.000e-02
[Epoch 3] total_before=0.030145 total_after=0.029977 depth=[1, 1, 1] lr_eff=2.000e-02 eps_eff=5.263e-04 ||g||_before=1.707e-01 ||g||_after=5.000e-02
[Epoch 4] FREEZE depth=[2, 2, 2] total=0.026492 lr_eff=1.111e-02 eps_eff=3.497e-04
[Epoch 5] total_before=0.026492 total_after=0.026122 depth=[2, 2, 2] lr_eff=9.524e-03 eps_eff=3.077e-04 ||g||_before=5.258e-01 ||g||_after=7.500e-02
[Epoch 6] total_before=0.026122 total_after=0.025806 depth=[2, 2, 2] lr_eff=8.333e-03 eps_eff=2.747e-04 ||g||_before=5.114e-01 ||g||_after=7.500e-02
[Epoch 7] total_before=0.025806 total_after=0.025532 depth=[2, 2, 2] lr_eff=7.407e-03 eps_eff=2.481e-04 ||g||_before=4.988e-01 ||g||_after=7.500e-02
[Epoch 8] total_before=0.025532 total_after=0.025291 depth=[2, 2, 2] lr_eff=6.667e-03 eps_eff=2.262e-04 ||g||_before=4.876e-01 ||g||_after=7.500e-02
[Epoch 9] total_before=0.025291 total_after=0.025076 depth=[2, 2, 2] lr_eff=6.061e-03 eps_eff=2.079e-04 ||g||_before=4.775e-01 ||g||_after=7.500e-02
[Epoch 10] total_before=0.025076 total_after=0.024883 depth=[2, 2, 2] lr_eff=5.556e-03 eps_eff=1.923e-04 ||g||_before=4.683e-01 ||g||_after=7.500e-02
[Epoch 11] total_before=0.024883 total_after=0.024707 depth=[2, 2, 2] lr_eff=5.128e-03 eps_eff=1.789e-04 ||g||_before=4.599e-01 ||g||_after=7.500e-02
[Epoch 12] FREEZE depth=[3, 3, 3] total=0.023982 lr_eff=3.571e-03 eps_eff=1.359e-04
[Epoch 13] total_before=0.023982 total_after=0.023682 depth=[3, 3, 3] lr_eff=3.333e-03 eps_eff=1.276e-04 ||g||_before=9.096e-01 ||g||_after=1.000e-01
[Epoch 14] total_before=0.023682 total_after=0.023404 depth=[3, 3, 3] lr_eff=3.125e-03 eps_eff=1.202e-04 ||g||_before=8.949e-01 ||g||_after=1.000e-01
[Epoch 15] total_before=0.023404 total_after=0.023147 depth=[3, 3, 3] lr_eff=2.941e-03 eps_eff=1.136e-04 ||g||_before=8.810e-01 ||g||_after=1.000e-01

Best parameters saved at: runs/best_params.json

=== Final metrics (channel 0) ===
SMAPE:      100.000000 %
RMSE:       0.167734
Overshoot:  0.000000
Robustness: 0.972635

Best epoch snapshot: {'epoch': 15, 'task': 0.01889342623335241, 'cons': 0.0010430885080397983, 'coh': 0.01243963263570178, 'total': 0.023146860278082843}

Telemetry JSONL → runs/telemetry_stable.jsonl

Summary:
{
  "best": {
    "epoch": 15,
    "task": 0.01889342623335241,
    "cons": 0.0010430885080397983,
    "coh": 0.01243963263570178,
    "total": 0.023146860278082843
  },
  "metrics": {
    "smape": 100.0,
    "rmse": 0.16773380360977846,
    "overshoot": 0.0,
    "robustness": 0.9726352677136917
  }
}