Deep Supervision with Recursion

Here, deep supervision refers mainly to the training paradigm used by the recent wave of recursive reasoning models such as HRM (Hierarchical Reasoning Model) and TRM (Tiny Recursive Model).

Given an input x, the model iteratively refines an answer y and a latent reasoning state z. One "full recursion process" looks like this:

for step in range(n):        # e.g., n=6 "think" steps
    z = f(x, y, z)           # update latent reasoning
y = g(y, z)                  # "act" - update answer
loss = compute_loss(y, target)   # supervision here
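
Written out with explicit step indices (the superscript notation is introduced here, not taken from TRM), one full recursion process starting from the state (y, z) is:

$$
z^{(0)} = z, \qquad z^{(i)} = f\big(x,\, y,\, z^{(i-1)}\big) \quad (i = 1, \dots, n), \qquad y' = g\big(y,\, z^{(n)}\big)
$$

and it hands the new state (y', z^{(n)}) to the next process. That is n + 1 network calls per process, so 7 calls for n = 6. Repeating the process T times costs T(n + 1) calls, e.g. 21 for a hypothetical T = 3, of which only the final n + 1 receive gradients under the training setup described next.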

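Within each supervision step, this process (minus the loss line) is first run T-1 times under torch.no_grad(). Those T-1 passes are pure state evolution: no loss and no gradients, just rolling the (y, z) state forward. The process then runs one final, T-th time with gradients enabled, and the loss is computed only on that pass's output, so backpropagation covers the whole final pass (all n + 1 of its network calls) and nothing earlier.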
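
A minimal PyTorch sketch of this gradient split, under loudly stated assumptions: `ToyRecursiveModel` is hypothetical, f and g are separate single linear layers with tanh (TRM itself reuses one tiny network), the sizes are toy values, the defaults n=6 and T=3 are illustrative, and an MSE loss against a random target stands in for the real answer loss. The point it shows is that the T-1 no_grad passes leave no trace in the autograd graph.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

D = 16  # hypothetical hidden size shared by x, y and z


class ToyRecursiveModel(nn.Module):
    """Hypothetical toy model: f updates the latent z, g updates the answer y."""

    def __init__(self, d: int = D):
        super().__init__()
        self.f = nn.Linear(3 * d, d)  # z = f(x, y, z)
        self.g = nn.Linear(2 * d, d)  # y = g(y, z)

    def full_recursion(self, x, y, z, n: int = 6):
        for _ in range(n):  # n "think" steps
            z = torch.tanh(self.f(torch.cat([x, y, z], dim=-1)))
        y = torch.tanh(self.g(torch.cat([y, z], dim=-1)))  # "act"
        return y, z

    def forward(self, x, y, z, n: int = 6, T: int = 3):
        # T-1 passes of pure state evolution: no graph is built, no gradients.
        with torch.no_grad():
            for _ in range(T - 1):
                y, z = self.full_recursion(x, y, z, n)
        # Final pass: the only one autograd records.
        return self.full_recursion(x, y, z, n)


torch.manual_seed(0)
model = ToyRecursiveModel()
x = torch.randn(4, D)
y0, z0 = torch.zeros(4, D), torch.zeros(4, D)
target = torch.randn(4, D)

y, z = model(x, y0, z0)
loss = F.mse_loss(y, target)  # supervision on the final pass's answer only
loss.backward()
print(y.requires_grad, model.f.weight.grad is not None)  # → True True
```

Because the first T-1 passes run under no_grad, the resulting gradients are identical to those of a single pass started from the rolled-forward state, which keeps memory proportional to one pass rather than T.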

Full pseudocode from TRM:

[Figure: full deep-supervision-with-recursion pseudocode from the TRM paper]
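
A minimal PyTorch sketch of the full loop, written in the spirit of that pseudocode rather than copied from it. Assumptions: a single hypothetical shared MLP `net` handles both updates, inputs are combined by element-wise addition, random toy data with an MSE loss replaces TRM's output head and cross-entropy, and `N_SUP` is a small hypothetical number of supervision steps. The TRM paper also trains a halting head for early stopping, which this sketch omits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
D, N_SUP, N, T = 16, 4, 6, 3  # hypothetical sizes; N_SUP kept small for the demo

# One shared tiny network for both updates; combining inputs by
# element-wise addition is an assumption of this sketch.
net = nn.Sequential(nn.Linear(D, D), nn.GELU(), nn.Linear(D, D))
opt = torch.optim.AdamW(net.parameters(), lr=1e-3)


def latent_recursion(x, y, z, n=N):
    for _ in range(n):
        z = net(x + y + z)  # think: update the latent reasoning state
    y = net(y + z)          # act: update the answer (x is not passed)
    return y, z


def deep_recursion(x, y, z, n=N, t=T):
    with torch.no_grad():   # t-1 passes of state evolution, no graph
        for _ in range(t - 1):
            y, z = latent_recursion(x, y, z, n)
    return latent_recursion(x, y, z, n)  # final pass with gradients


x = torch.randn(8, D)
target = torch.randn(8, D)
y, z = torch.zeros(8, D), torch.zeros(8, D)
losses = []
for step in range(N_SUP):           # deep supervision steps
    y, z = deep_recursion(x, y, z)
    loss = F.mse_loss(y, target)    # a loss at every supervision step
    loss.backward()                 # backprop through the final pass only
    opt.step()
    opt.zero_grad()
    losses.append(loss.item())
    y, z = y.detach(), z.detach()   # carry the state forward, cut the graph
print([round(v, 3) for v in losses])
```

Each supervision step takes its own optimizer step and then detaches (y, z), so the next step starts from the improved state without backpropagating into earlier steps.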

Whether this is novel is highly debatable. However, the specific setup of running T-1 passes to refine the state and then backpropagating only through the final pass does seem quite interesting and powerful.