What this is
The Training API supports two ways to compute loss:
- Built-in losses via forward_backward with a string identifier (e.g. "cross_entropy"): fastest, since no extra forward pass is needed.
- Custom losses via forward_backward_custom with an arbitrary Python function: flexible, supporting any differentiable objective at the cost of an additional forward pass.
Built-in loss: cross_entropy
For supervised fine-tuning, use the built-in cross_entropy loss via forward_backward. When you use cross_entropy, the SDK backfills result.metrics["response_tokens"], so you can compute a mean loss from sum-style metrics when needed.
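As a minimal sketch, the mean loss is the summed loss divided by the backfilled token count. The metric key "loss:sum" below is an assumption for illustration. Inspect result.metrics to see which sum-style keys your SDK version actually returns.

```python
def mean_loss_from_metrics(metrics: dict[str, float], sum_key: str = "loss:sum") -> float:
    """Convert a summed loss metric into a per-token mean.

    sum_key is a hypothetical key name; "response_tokens" is the count the
    SDK backfills when the built-in cross_entropy loss is used.
    """
    n_tokens = metrics["response_tokens"]
    if n_tokens <= 0:
        raise ValueError("response_tokens must be positive to compute a mean")
    return metrics[sum_key] / n_tokens


# Hand-written stand-in for result.metrics (values are illustrative only).
metrics = {"loss:sum": 42.0, "response_tokens": 21.0}
print(mean_loss_from_metrics(metrics))  # 2.0
```

In real code you would pass result.metrics from a forward_backward call instead of the hand-written dict.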
You can also run a forward-only pass, e.g. to compute reference logprobs without updating weights.
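A common next step after a forward-only pass is to reduce the per-token reference logprobs to one number per sequence, counting only response tokens. The sketch below assumes you have already extracted the logprobs and weights as plain Python lists. How to pull them out of the forward result is SDK-specific and not shown, and the helper name is hypothetical.

```python
def masked_logprob_sum(logprobs: list[float], weights: list[float]) -> float:
    """Sum per-token logprobs, each multiplied by its token weight.

    Prompt tokens carry weight 0 and response tokens weight 1, so this is the
    total log-probability of the response. zip() stops at the shorter list,
    which guards against off-by-one length mismatches.
    """
    return sum(lp * w for lp, w in zip(logprobs, weights))


# Illustrative values: two prompt tokens (weight 0), two response tokens (weight 1).
ref_logprobs = [-3.0, -2.5, -0.5, -1.0]
weights = [0.0, 0.0, 1.0, 1.0]
print(masked_logprob_sum(ref_logprobs, weights))  # -1.5
```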
Custom losses: forward_backward_custom
forward_backward_custom lets you implement any objective function in Python. You provide the loss computation; the API handles the forward pass on remote GPUs, passes logprobs back to your function, then sends the computed gradients back for the backward pass.
How it works
- You call training_client.forward_backward_custom(datums, loss_fn).
- The trainer runs a forward pass on the GPU and returns per-token logprobs.
- The logprobs are converted to PyTorch tensors with requires_grad=True.
- Your loss_fn is called with the datums and logprobs.
- The API calls loss.backward() to compute d_loss/d_logprob gradients.
- Gradients are sent back to the trainer GPU for the model backward pass.
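This is a local sketch of the client-side steps above: tensor conversion, the loss_fn call and loss.backward(). It needs no GPU and no SDK. The function simulate_custom_backward and the toy loss are hypothetical illustrations, not part of the Training API.

```python
import torch


def simulate_custom_backward(data, raw_logprobs, loss_fn):
    """Mimic the client-side half of forward_backward_custom locally.

    raw_logprobs holds per-datum lists of floats, standing in for what the
    trainer's forward pass returns. Returns the loss value, the metrics, and
    the d_loss/d_logprob gradients that would be sent back to the trainer.
    """
    # Wrap each datum's logprobs in a leaf tensor that tracks gradients.
    logprobs_list = [
        torch.tensor(lp, dtype=torch.float32, requires_grad=True) for lp in raw_logprobs
    ]
    # Call the user's loss function.
    loss, metrics = loss_fn(data, logprobs_list)
    # Backpropagate to the logprobs only (no model weights exist here).
    loss.backward()
    grads = [
        lp.grad.tolist() if lp.grad is not None else [0.0] * lp.numel()
        for lp in logprobs_list
    ]
    return loss.item(), metrics, grads


def sum_nll(data, logprobs_list):
    # Toy loss: negative sum of every logprob (ignores data).
    loss = -sum(lp.sum() for lp in logprobs_list)
    return loss, {"nll": loss.item()}


loss, metrics, grads = simulate_custom_backward(None, [[-1.0, -2.0], [-0.5]], sum_nll)
print(loss, grads)  # 3.5 [[-1.0, -1.0], [-1.0]]
```

In the real flow, the gradients returned here are what the API sends back for the model backward pass.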
forward_backward_custom does an extra forward pass compared to forward_backward, requiring about 1.5x the FLOPs and up to about 3x the wall time per step.
Embedding-space custom losses
For objectives that operate on pooled hidden states instead of logprobs, pass output="embedding" and pooling="mean" or pooling="last".
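Conceptually, the two pooling modes reduce a sequence of per-token hidden states to a single vector. The sketch below shows what "mean" and "last" compute and adds a hypothetical pairwise loss on the pooled vectors. It assumes plain [T, D] hidden states. The server's pooling may differ in details such as padding handling, and the exact inputs an embedding loss_fn receives are not shown on this page.

```python
import torch
import torch.nn.functional as F


def pool(hidden: torch.Tensor, mode: str) -> torch.Tensor:
    """Pool per-token hidden states of shape [T, D] into one vector of shape [D].

    "mean" averages over all T positions; "last" takes the final position.
    Illustration only, not the server's implementation.
    """
    if mode == "mean":
        return hidden.mean(dim=0)
    if mode == "last":
        return hidden[-1]
    raise ValueError(f"unknown pooling mode: {mode!r}")


def cosine_pair_loss(emb_a: torch.Tensor, emb_b: torch.Tensor) -> torch.Tensor:
    # Hypothetical objective: pull two pooled embeddings together (1 - cosine).
    return 1.0 - F.cosine_similarity(emb_a, emb_b, dim=0)


hidden = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(pool(hidden, "mean"))  # tensor([0.6667, 0.6667])
print(pool(hidden, "last"))  # tensor([1., 1.])
```

Mean pooling reflects the whole sequence. Last pooling relies on the final token having attended to everything before it.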
Loss function signature
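The rules below imply the shape of a loss function: it receives the list of datums and a list of per-datum logprob tensors, and returns a scalar loss plus a dict of float metrics. The LossFn alias and check_loss_output helper are hypothetical conveniences for checking that contract locally before sending a step to the trainer.

```python
from typing import Any, Callable, Sequence

import torch

# data[i] is the i-th datum; logprobs_list[i] holds its per-token logprobs.
LossFn = Callable[[Sequence[Any], list[torch.Tensor]], tuple[torch.Tensor, dict[str, float]]]


def check_loss_output(loss: torch.Tensor, metrics: dict[str, float]) -> None:
    """Raise if a loss_fn's return value breaks the expected contract."""
    if not isinstance(loss, torch.Tensor) or loss.dim() != 0:
        raise TypeError("loss must be a scalar (0-dim) tensor")
    for key, value in metrics.items():
        if not isinstance(key, str) or not isinstance(value, float):
            raise TypeError(f"metric {key!r} must map a str to a float")


def example_loss(
    data: Sequence[Any], logprobs_list: list[torch.Tensor]
) -> tuple[torch.Tensor, dict[str, float]]:
    # Simplest conforming loss: mean negative logprob over all tokens.
    loss = -torch.cat(logprobs_list).mean()
    return loss, {"nll": loss.item()}


fn: LossFn = example_loss
lp = [torch.tensor([-1.0, -3.0], requires_grad=True)]
check_loss_output(*fn([], lp))  # passes silently
```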
Key rules
- logprobs_list[i] has requires_grad=True; your loss must be differentiable through it.
- Use torch.dot() to compute weighted sums; this correctly propagates gradients through the logprobs.
- Return a scalar tensor as the loss, and a dict[str, float] as metrics.
- Access token weights via data[i].loss_fn_inputs["weights"].data; these are 0 for prompt tokens and 1 for response tokens.
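This small standalone demonstration shows why torch.dot fits here. The weighted sum stays in the autograd graph, and each token's gradient is the negative of its weight, so prompt tokens contribute nothing. The tensors are made-up stand-ins for logprobs_list[i] and for data[i].loss_fn_inputs["weights"].data after conversion to a tensor.

```python
import torch

# Per-token logprobs for one datum, as loss_fn would receive them.
logprobs = torch.tensor([-2.0, -1.5, -0.25, -0.75], requires_grad=True)
# Token weights: 0 for the two prompt tokens, 1 for the two response tokens.
weights = torch.tensor([0.0, 0.0, 1.0, 1.0])

# Weighted sum via torch.dot keeps the result connected to logprobs.
loss = -torch.dot(logprobs, weights)
loss.backward()

print(loss.item())  # 1.0
# d(loss)/d(logprob_t) = -weight_t: zero for prompt tokens, -1 for response tokens.
print(logprobs.grad)
```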
Building datums
Using tinker_cookbook (weight-based)
datum_from_model_input_weights constructs datums with explicit token weights.
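The weights mark which tokens count toward the loss. The sketch below builds them for a prompt/response pair using the convention stated in the Key rules: 0 for prompt tokens, 1 for response tokens. The helper name and token ids are hypothetical, and the exact arguments datum_from_model_input_weights accepts are not shown here.

```python
def prompt_response_weights(
    prompt_tokens: list[int], response_tokens: list[int]
) -> tuple[list[int], list[float]]:
    """Concatenate prompt and response tokens and build matching weights.

    Returns (tokens, weights) with exactly one weight per token: 0.0 for
    prompt positions and 1.0 for response positions.
    """
    tokens = prompt_tokens + response_tokens
    weights = [0.0] * len(prompt_tokens) + [1.0] * len(response_tokens)
    assert len(tokens) == len(weights)
    return tokens, weights


tokens, weights = prompt_response_weights([101, 2023, 2003], [7592, 102])
print(weights)  # [0.0, 0.0, 0.0, 1.0, 1.0]
```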
Using tinker.Datum directly (target-token-based)
For RL-style objectives where you need per-completion control (e.g. routing matrices, custom loss_fn_inputs), construct datums directly.
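When you build datums by hand, the usual causal-LM convention is that the input at position t predicts the token at t + 1, so targets are the input shifted by one. The sketch below assumes that convention applies to your datums, which you should verify against the SDK reference. The helper is hypothetical and returns plain lists that you would then place into the datum's model input and loss_fn_inputs.

```python
def next_token_pairs(
    tokens: list[int], weights: list[float]
) -> tuple[list[int], list[int], list[float]]:
    """Split a full sequence into (input_tokens, target_tokens, target_weights).

    input_tokens[t] is fed to the model and target_tokens[t] = tokens[t + 1]
    is what it should predict. Weights follow the targets, so each weight
    describes the token being predicted.
    """
    if len(tokens) != len(weights):
        raise ValueError("tokens and weights must have the same length")
    if len(tokens) < 2:
        raise ValueError("need at least two tokens to form a prediction pair")
    return tokens[:-1], tokens[1:], weights[1:]


inputs, targets, target_weights = next_token_pairs([10, 11, 12, 13], [0.0, 0.0, 1.0, 1.0])
print(inputs, targets, target_weights)  # [10, 11, 12] [11, 12, 13] [0.0, 1.0, 1.0]
```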
Multi-target cross-entropy
For sparse distillation objectives, the built-in cross_entropy loss also supports multiple candidate target tokens per model position. In this mode, target_tokens has shape [N, K], where:
- N is the number of model input positions.
- K is the number of candidate targets per position.
- target_tokens.data is flattened row-major and must contain N * K token ids.

If you also pass weights, they must describe the same flattened target entries as target_tokens.data. That means weights.data must contain exactly as many values as target_tokens.data (N * K values), in the same row-major order, with one weight per candidate target.
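The sketch below prepares the flattened arrays from a nested [N, K] candidate list and checks the length invariants described above. The function and the example values are hypothetical. The resulting lists are what would go into target_tokens.data and weights.data, both with shape [N, K].

```python
from typing import Optional


def flatten_multi_target(
    candidates: list[list[int]],
    candidate_weights: Optional[list[list[float]]] = None,
) -> tuple[tuple[int, int], list[int], Optional[list[float]]]:
    """Flatten an [N, K] candidate grid row-major.

    Returns ((N, K), flat_tokens, flat_weights). flat_weights is None when no
    weights are given; otherwise it has the same N * K length and order.
    """
    n = len(candidates)
    k = len(candidates[0]) if n else 0
    if any(len(row) != k for row in candidates):
        raise ValueError("every position must have the same number K of candidates")
    flat_tokens = [tok for row in candidates for tok in row]  # row-major

    flat_weights = None
    if candidate_weights is not None:
        if len(candidate_weights) != n or any(len(row) != k for row in candidate_weights):
            raise ValueError("weights must have the same [N, K] shape as target tokens")
        flat_weights = [w for row in candidate_weights for w in row]

    assert len(flat_tokens) == n * k
    return (n, k), flat_tokens, flat_weights


shape, flat_tokens, flat_weights = flatten_multi_target(
    [[5, 9], [7, 3], [2, 8]],               # N=3 positions, K=2 candidates each
    [[0.7, 0.3], [0.9, 0.1], [0.5, 0.5]],   # illustrative per-candidate weights
)
print(shape, flat_tokens)  # (3, 2) [5, 9, 7, 3, 2, 8]
```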
Example: simple cross-entropy
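The sketch below writes supervised cross-entropy as a custom loss: a mean negative log-likelihood over response tokens, similar in spirit to the built-in cross_entropy. It is mainly a starting point for variants. The SimpleNamespace datums are hypothetical stand-ins, assumed to expose loss_fn_inputs["weights"].data as the Key rules describe.

```python
from types import SimpleNamespace

import torch


def cross_entropy_loss(data, logprobs_list):
    """Mean negative log-likelihood over response tokens (weight 1)."""
    total_nll = torch.zeros(())
    total_weight = 0.0
    for datum, logprobs in zip(data, logprobs_list):
        weights = torch.as_tensor(datum.loss_fn_inputs["weights"].data, dtype=logprobs.dtype)
        # Truncate to the shorter length to avoid token-weight misalignment.
        n = min(len(logprobs), len(weights))
        total_nll = total_nll - torch.dot(logprobs[:n], weights[:n])
        total_weight += float(weights[:n].sum())
    loss = total_nll / (total_weight if total_weight > 0 else 1.0)
    return loss, {"loss": loss.item(), "response_tokens": total_weight}


def fake_datum(weights):
    # Hypothetical stand-in exposing loss_fn_inputs["weights"].data.
    return SimpleNamespace(loss_fn_inputs={"weights": SimpleNamespace(data=weights)})


data = [fake_datum([0.0, 1.0, 1.0]), fake_datum([0.0, 1.0])]
logprobs_list = [
    torch.tensor([-5.0, -1.0, -2.0], requires_grad=True),
    torch.tensor([-4.0, -3.0], requires_grad=True),
]
loss, metrics = cross_entropy_loss(data, logprobs_list)
print(loss.item())  # 2.0, i.e. (1 + 2 + 3) / 3 response tokens
```

In a real step you would pass cross_entropy_loss as the loss_fn argument of forward_backward_custom.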
Example: GRPO with KL penalty
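The sketch below implements a GRPO-style objective as a custom loss. It uses group-standardized advantages, a token-level importance ratio clipped to [0.8, 1.2], token-mean aggregation and a KL penalty toward a reference model with coefficient 0.001, loosely mirroring the grpo row of the built-in methods table. It assumes you supply, per datum, a scalar advantage, the sampler's per-token logprobs and reference logprobs (e.g. from a forward-only pass). These inputs, the make_grpo_loss factory and the choice of the k3 KL estimator are assumptions of this sketch, not the managed implementation.

```python
from types import SimpleNamespace

import torch


def group_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    """Standardize the rewards of one prompt's group of samples."""
    r = torch.tensor(rewards)
    std = ((r - r.mean()) ** 2).mean().sqrt()
    return ((r - r.mean()) / (std + eps)).tolist()


def make_grpo_loss(advantages, sampling_logprobs, ref_logprobs,
                   clip_low=0.8, clip_high=1.2, kl_coef=0.001):
    """Return a GRPO-style loss_fn (a sketch, not the built-in method)."""
    def loss_fn(data, logprobs_list):
        total_obj, total_kl, n_tokens = torch.zeros(()), torch.zeros(()), 0.0
        for i, (datum, lp) in enumerate(zip(data, logprobs_list)):
            w = torch.as_tensor(datum.loss_fn_inputs["weights"].data, dtype=lp.dtype)
            n = min(len(lp), len(w), len(sampling_logprobs[i]), len(ref_logprobs[i]))
            lp, w = lp[:n], w[:n]
            old = sampling_logprobs[i][:n].detach()
            ref = ref_logprobs[i][:n].detach()
            # Token-level importance ratio with symmetric clipping (PPO-style min).
            ratio = torch.exp(lp - old)
            adv = advantages[i]
            surrogate = torch.minimum(ratio * adv, ratio.clamp(clip_low, clip_high) * adv)
            # k3 estimator of KL(policy || reference); non-negative per token.
            log_r = ref - lp
            kl = torch.exp(log_r) - log_r - 1.0
            total_obj = total_obj + torch.dot(surrogate, w)
            total_kl = total_kl + torch.dot(kl, w)
            n_tokens += float(w.sum())
        denom = n_tokens if n_tokens > 0 else 1.0
        # Token-mean aggregation over all response tokens in the batch.
        loss = (-total_obj + kl_coef * total_kl) / denom
        return loss, {"loss": loss.item(), "kl": (total_kl / denom).item()}

    return loss_fn


def fake_datum(weights):
    # Hypothetical stand-in exposing loss_fn_inputs["weights"].data.
    return SimpleNamespace(loss_fn_inputs={"weights": SimpleNamespace(data=weights)})


advantages = group_advantages([1.0, 0.0])  # roughly [1.0, -1.0]
data = [fake_datum([0.0, 1.0, 1.0]), fake_datum([0.0, 1.0, 1.0])]
logprobs_list = [torch.tensor([-2.0, -0.5, -0.7], requires_grad=True),
                 torch.tensor([-2.0, -1.2, -0.9], requires_grad=True)]
# On-policy and reference-equal: ratio = 1 and KL = 0 for every token.
frozen = [lp.detach().clone() for lp in logprobs_list]
loss_fn = make_grpo_loss(advantages, sampling_logprobs=frozen, ref_logprobs=frozen)
loss, metrics = loss_fn(data, logprobs_list)
loss.backward()
```

With ratio = 1, gradients push up logprobs of the positively rewarded sample and down those of the negatively rewarded one.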
Example: DPO margin loss
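The sketch below implements a DPO loss as a custom loss function. It assumes datums arrive as alternating chosen/rejected pairs and that the reference model's summed response logprobs for each pair were computed beforehand, e.g. with a forward-only pass. The pairing layout, make_dpo_loss and beta=0.1 are assumptions of this sketch. SimpleNamespace datums are hypothetical stand-ins.

```python
from types import SimpleNamespace

import torch
import torch.nn.functional as F


def make_dpo_loss(ref_chosen, ref_rejected, beta=0.1):
    """Return a DPO loss_fn; data alternates chosen, rejected, chosen, ..."""
    def seq_logprob(datum, lp):
        # Summed response logprob: prompt tokens have weight 0.
        w = torch.as_tensor(datum.loss_fn_inputs["weights"].data, dtype=lp.dtype)
        n = min(len(lp), len(w))
        return torch.dot(lp[:n], w[:n])

    def loss_fn(data, logprobs_list):
        if len(data) % 2 != 0:
            raise ValueError("expected chosen/rejected pairs (an even number of datums)")
        losses, margins = [], []
        for j in range(len(data) // 2):
            pi_c = seq_logprob(data[2 * j], logprobs_list[2 * j])
            pi_r = seq_logprob(data[2 * j + 1], logprobs_list[2 * j + 1])
            # Policy log-ratio gain over the reference, chosen minus rejected.
            margin = beta * ((pi_c - ref_chosen[j]) - (pi_r - ref_rejected[j]))
            losses.append(-F.logsigmoid(margin))
            margins.append(margin.item())
        loss = torch.stack(losses).mean()
        accuracy = sum(m > 0 for m in margins) / len(margins)
        return loss, {"loss": loss.item(), "margin": sum(margins) / len(margins),
                      "accuracy": float(accuracy)}

    return loss_fn


def fake_datum(weights):
    # Hypothetical stand-in exposing loss_fn_inputs["weights"].data.
    return SimpleNamespace(loss_fn_inputs={"weights": SimpleNamespace(data=weights)})


data = [fake_datum([0.0, 1.0, 1.0]), fake_datum([0.0, 1.0, 1.0])]
logprobs_list = [torch.tensor([-1.0, -0.5, -0.5], requires_grad=True),
                 torch.tensor([-1.0, -1.0, -2.0], requires_grad=True)]
# Policy equals reference here, so the margin is 0 and the loss is log(2).
loss_fn = make_dpo_loss(ref_chosen=[-1.0], ref_rejected=[-3.0])
loss, metrics = loss_fn(data, logprobs_list)
loss.backward()
```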
Built-in loss methods: GRPO vs DAPO vs GSPO-token
When using the managed RFT flow or the cookbook’s RL recipe, three built-in loss methods are available via --rl-loss-method:
| Method | Clipping | KL penalty | Loss aggregation | Importance sampling |
|---|---|---|---|---|
| grpo (default) | Symmetric [0.8, 1.2] | Yes (0.001) | Token-mean | Token-level |
| dapo | Asymmetric [0.8, 1.28] | No | Token-mean | Token-level |
| gspo-token | Very tight [1 - 3e-4, 1 + 4e-4] | No | Seq-mean-token-mean | Sequence-level |
Seq-mean-token-mean aggregation first averages token losses within each sequence, then averages across sequences, so longer responses do not dominate the loss.
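A small worked example of the two aggregation modes follows, using made-up per-token loss values for a 1-token response and a 3-token response.

```python
def token_mean(per_seq_token_losses: list[list[float]]) -> float:
    """Average over every token in the batch; long sequences weigh more."""
    all_tokens = [x for seq in per_seq_token_losses for x in seq]
    return sum(all_tokens) / len(all_tokens)


def seq_mean_token_mean(per_seq_token_losses: list[list[float]]) -> float:
    """Average within each sequence first, then across sequences."""
    per_seq = [sum(seq) / len(seq) for seq in per_seq_token_losses]
    return sum(per_seq) / len(per_seq)


# A 1-token response with loss 1.0 next to a 3-token response with loss 0.0 per token.
batch = [[1.0], [0.0, 0.0, 0.0]]
print(token_mean(batch))           # 0.25
print(seq_mean_token_mean(batch))  # 0.5
```

With token-mean, the 3-token response carries three quarters of the weight. With seq-mean-token-mean, each response counts equally regardless of length.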
For Training API users implementing custom loss functions via forward_backward_custom, these methods serve as reference implementations. You can replicate or modify their behavior in your custom loss function. See Parameter Tuning for detailed guidance on when to choose each method.
Applying the optimizer step
After forward_backward_custom, call optim_step to update weights. To accumulate gradients over several batches, call forward_backward_custom multiple times before calling optim_step.
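The sketch below runs one training step: it accumulates over micro-batches, then steps the optimizer, calling .result() on every future. Two parts are placeholders. optim_params stands in for whatever argument optim_step takes in your SDK version, which this page does not show. FakeTrainingClient is a hypothetical stand-in so the sketch runs without a live trainer.

```python
from typing import Any, Callable, Sequence


def train_step(training_client: Any, micro_batches: Sequence[list],
               loss_fn: Callable, optim_params: Any) -> list:
    """Run forward_backward_custom on each micro-batch, then one optim_step."""
    fb_results = []
    for datums in micro_batches:
        # .result() waits for the future and surfaces server-side errors.
        fb_results.append(training_client.forward_backward_custom(datums, loss_fn).result())
    training_client.optim_step(optim_params).result()
    return fb_results  # keep these to log per-step metrics


class _Resolved:
    """Hypothetical stand-in for an already-completed future."""

    def __init__(self, value):
        self._value = value

    def result(self):
        return self._value


class FakeTrainingClient:
    """Hypothetical stand-in that records the call order instead of training."""

    def __init__(self):
        self.calls = []

    def forward_backward_custom(self, datums, loss_fn):
        self.calls.append(("forward_backward_custom", len(datums)))
        return _Resolved({"num_datums": len(datums)})

    def optim_step(self, optim_params):
        self.calls.append(("optim_step", optim_params))
        return _Resolved(None)


client = FakeTrainingClient()
results = train_step(client, [["d1", "d2"], ["d3"]], loss_fn=None,
                     optim_params="optim-params-placeholder")
print(client.calls)
```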
Advanced optimizer-step controls such as server-side gradient accumulation normalization are intentionally kept out of this user-facing guide. See the cookbook skill reference for agent-facing operational guidance.
Common pitfalls
- Token-weight misalignment can silently break objective semantics. Always verify that len(logprobs) and len(weights) are compatible (truncate to min_len).
- Ignoring per-step diagnostics makes instability hard to attribute. Log metrics from every train step.
- Forgetting .result(): all Tinker API calls return futures. Without .result(), errors are silently swallowed.
- Non-differentiable loss: if your loss doesn't depend on logprobs_list entries through differentiable ops, gradients will be zero.
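Two of these pitfalls can be caught with cheap local checks. In the sketch below, the hypothetical helper aligned warns on a length mismatch before truncating. The hypothetical is_differentiable flags a loss that has been cut off from the autograd graph; the .item() round-trip shown is one common way this happens.

```python
import warnings

import torch


def aligned(logprobs: torch.Tensor, weights: torch.Tensor):
    """Truncate both tensors to the shorter length, warning if they differ."""
    if len(logprobs) != len(weights):
        warnings.warn(
            f"length mismatch: {len(logprobs)} logprobs vs {len(weights)} weights; truncating"
        )
    n = min(len(logprobs), len(weights))
    return logprobs[:n], weights[:n]


def is_differentiable(loss: torch.Tensor) -> bool:
    """True if the loss is still attached to the autograd graph."""
    return loss.requires_grad and loss.grad_fn is not None


logprobs = torch.tensor([-1.0, -2.0, -0.5], requires_grad=True)
weights = torch.tensor([1.0, 1.0])  # one entry short: triggers the warning
lp, w = aligned(logprobs, weights)

# Wrong: .item() turns the value into a Python float, detaching it from the graph.
bad_loss = torch.tensor(-torch.dot(lp, w).item())
# Right: stay in torch ops so gradients reach the logprobs.
good_loss = -torch.dot(lp, w)
print(is_differentiable(bad_loss), is_differentiable(good_loss))  # False True
```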
Related guides
- Training and Sampling — end-to-end workflow
- Saving and Loading — checkpoint and weight sync
- Cookbook RL recipe — GRPO with full reward pipeline
- Cookbook DPO recipe — DPO with preference data