Trainer#
This document describes the trainer in the online training module, with the primary implementation being the PPOTrainer
.
Design Principle#
The PPOTrainer
(Proximal Policy Optimization Trainer) is the core component responsible for updating the policy and value function models. It leverages the experience data collected by rollout workers and is architected for efficient, distributed training using Ray Train
.
The design of the PPOTrainer
is guided by several key principles. At its heart is the PPO Algorithm, a well-regarded actor-critic method in reinforcement learning known for its balance of sample efficiency, stability, and ease of implementation. To handle the demands of large-scale training, the trainer is built for Distributed Training, utilizing Ray Train
for data parallelism. This enables the training process to scale across multiple GPUs and even multiple nodes in a cluster.
Effective Data Handling is crucial. The trainer ingests SampleFragment
s (segments of trajectories) from a ReplayBuffer
(a process managed by its BaseTrainer
parent class). It then computes Generalized Advantage Estimation (GAE) and TD-targets, which are vital for PPO’s policy and value function updates. A data_iter
is employed to efficiently load and prepare batches of data for the training epochs.
Model Management encompasses several responsibilities. The trainer initializes the MinePolicy
model and an AdamW optimizer. It can optionally maintain a separate ref_model
for KL divergence regularization, which helps stabilize training by penalizing large deviations from a previous policy. Robust model checkpointing capabilities are included, allowing for saving and loading of training progress. Furthermore, updated model weights are periodically broadcast to the rollout workers to ensure they are using the latest policy for experience collection.
The Loss Calculation in PPO is a composite objective. This typically includes:
A Policy Loss, based on a clipped surrogate objective to ensure stable policy updates.
A Value Function Loss, usually a mean squared error to train the value function, which can also be clipped.
An optional Entropy Bonus, which encourages exploration by discouraging the policy from becoming too deterministic too quickly.
An optional KL Divergence Penalty, based on the KL divergence between the current policy and the ref_model.
To aid convergence, the trainer supports Learning Rate Scheduling, often in the form of linear annealing. Gradient Management techniques are also incorporated, such as gradient accumulation (to simulate larger batch sizes) and gradient norm clipping (to prevent exploding gradients). Finally, comprehensive Logging and Monitoring are achieved using wandb_logger
for Weights & Biases integration and torchmetrics
for accumulating various performance statistics.
Logic#
The training process orchestrated by the PPOTrainer
unfolds in several key stages:
1. Initialization#
Training begins with the setup phase, primarily within the __init__
and setup_model_and_optimizer
methods. First, all necessary Hyperparameters for the PPO algorithm, optimizer settings, batch sizes, and the training schedule are configured.
Following this, the Model and Optimizer are Instantiated. This involves creating the primary policy model (an instance of MinePolicy
) and initializing the AdamW optimizer with the configured learning rate and weight decay.
Note
If the zero_initial_vf
option is enabled, the weights of the value function head within the policy model are explicitly initialized to zero. This technique can sometimes provide a better starting point for the value function in the early stages of training.
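The zero-initialization described in this note can be sketched as follows. This is a minimal illustration, not the actual MinePolicy code: `TinyPolicy`, its `value_head` attribute, and the optimizer hyperparameter values are hypothetical stand-ins chosen only to make the example self-contained.

```python
import torch
from torch import nn


class TinyPolicy(nn.Module):
    """Hypothetical stand-in for MinePolicy: shared backbone, policy head, value head."""

    def __init__(self, obs_dim=8, hidden=16, num_actions=4):
        super().__init__()
        self.backbone = nn.Linear(obs_dim, hidden)
        self.pi_head = nn.Linear(hidden, num_actions)
        self.value_head = nn.Linear(hidden, 1)

    def forward(self, obs):
        h = torch.tanh(self.backbone(obs))
        return self.pi_head(h), self.value_head(h).squeeze(-1)


def zero_value_head(policy):
    # Zero both weight and bias so every initial value prediction is exactly 0.
    nn.init.zeros_(policy.value_head.weight)
    nn.init.zeros_(policy.value_head.bias)


policy = TinyPolicy()
zero_value_head(policy)  # what a zero_initial_vf option might trigger (assumption)

# AdamW with a configured learning rate and weight decay (values are illustrative).
optimizer = torch.optim.AdamW(policy.parameters(), lr=2e-5, weight_decay=0.04)

_, vpred = policy(torch.randn(3, 8))
```

With the head zeroed, `vpred` is zero for any input, so early value errors equal the raw TD-targets rather than depending on a random initialization.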
If KL divergence regularization is active (i.e., kl_divergence_coef_rho
is non-zero), a separate Reference Model (ref_model
) is also set up. This model typically holds an older version of the policy and serves as a baseline for the KL penalty. Lastly, the model is prepared for Distributed Setup with Ray Train
, which usually involves wrapping it with torch.nn.parallel.DistributedDataParallel
if multiple GPUs or nodes are part of the training cluster.
2. Main Training Loop#
The train
method orchestrates the overarching training loop, which runs for a predefined num_iterations
. Within each iteration, if Learning Rate Annealing (anneal_lr_linearly
) is enabled, the learning rate for the optimizer is adjusted, typically decreased linearly as training progresses. The core work of data processing and model updates for each cycle happens within the train_iteration
method.
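Linear annealing of the learning rate can be sketched as below. The helper names and the choice to anneal toward zero are assumptions for illustration; the actual schedule used when anneal_lr_linearly is enabled may differ. The `param_groups` list mimics the list of dictionaries that PyTorch optimizers expose under the same name.

```python
def linearly_annealed_lr(base_lr, iteration, num_iterations):
    """Learning rate that falls linearly from base_lr toward 0 over training."""
    frac = 1.0 - iteration / num_iterations
    return base_lr * frac


def apply_lr(param_groups, lr):
    # PyTorch optimizers store the learning rate per parameter group under "lr".
    for group in param_groups:
        group["lr"] = lr


# Plain dicts stand in for optimizer.param_groups so the sketch needs no dependencies.
param_groups = [{"lr": 1e-4}]
num_iterations = 4
for iteration in range(num_iterations):
    apply_lr(param_groups, linearly_annealed_lr(1e-4, iteration, num_iterations))
    # ... train_iteration() would run here ...
```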
A critical aspect of distributed training is Model Broadcasting. This is typically handled by the rank 0 worker (the chief worker). After an initial vf_warmup
period (during which the value function might be trained more aggressively or exclusively to stabilize it), the updated model weights are broadcast to all rollout workers. This synchronization is facilitated by the broadcast_model_to_rollout_workers
method, often part of the BaseTrainer
class, ensuring that data collection agents are working with the most up-to-date policy.
3. Single Training Iteration#
Each call to the train_iteration
method performs one complete pass of training. It begins with Data Acquisition and Preprocessing. The fetch_fragments_and_estimate_advantages
method (inherited or called from BaseTrainer
) is invoked. This is a crucial step that retrieves a collection of SampleFragment
s from the replay buffer, computes Generalized Advantage Estimation (GAE) values, and calculates TD-lambda targets for value function training.
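To make the advantage estimation concrete, here is a minimal sketch of GAE over a single fragment, written in plain Python for readability. The function name, argument layout, and default gamma and lambda values are assumptions; the real fetch_fragments_and_estimate_advantages method works on batches of SampleFragment objects and is not reproduced here.

```python
def compute_gae(rewards, values, next_value, dones, gamma=0.99, lam=0.95):
    """Return (advantages, td_lambda_targets) for one trajectory fragment.

    rewards, values, dones are per-step lists; next_value bootstraps the
    value of the state that follows the last step of the fragment.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        v_next = next_value if t == len(rewards) - 1 else values[t + 1]
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error; the bootstrap is dropped when the episode ends at t.
        delta = rewards[t] + gamma * v_next * not_done - values[t]
        # Exponentially weighted sum of TD errors, reset at episode boundaries.
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
    # TD(lambda) targets for the value function are advantage plus baseline.
    td_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, td_targets


advantages, td_targets = compute_gae(
    rewards=[1.0, 1.0], values=[0.0, 0.0], next_value=0.0,
    dones=[False, True], gamma=1.0, lam=1.0,
)
```

With gamma and lambda both set to 1, the advantage reduces to the Monte Carlo return minus the value estimate, which is a convenient sanity check.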
Once the data is prepared and advantages are estimated, the PPO Update Execution commences by calling the ppo_update
method with this processed information. Additionally, if a KL divergence penalty is used, its coefficient (kl_divergence_coef_rho
) is decayed according to coef_rho_decay
, often to gradually reduce its influence as training stabilizes.
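The coefficient decay might look like the sketch below. It assumes a multiplicative decay applied once per iteration; the source does not state the exact rule, so treat both the form and the numeric values as illustrative.

```python
def decay_kl_coef(kl_divergence_coef_rho, coef_rho_decay):
    """Assumed multiplicative decay of the KL penalty coefficient."""
    return kl_divergence_coef_rho * coef_rho_decay


rho = 0.2
for _ in range(3):  # three training iterations
    # ... ppo_update(...) would use the current rho here ...
    rho = decay_kl_coef(rho, coef_rho_decay=0.5)
```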
4. PPO Update Step#
The ppo_update
method is where the core PPO algorithm refines the model parameters. This process itself has several sub-stages. The update iterates for a specified number of Epochs (epochs_per_iteration
) over the currently collected batch of data. Within each epoch, a data_iter
provides Mini-Batches of SampleFragment
s.
For each mini-batch, several operations occur during Mini-Batch Processing. First, prepare_batch
converts the list of SampleFragment
s into a batched PyTorch tensor format, making it suitable for feeding into the neural network. Data is often processed in fixed-length temporal chunks, defined by context_length
.
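Splitting a fragment into temporal chunks of context_length steps can be sketched as follows. The helper is hypothetical, and it assumes that a trailing remainder is emitted as a shorter chunk; an implementation that requires strictly fixed-length chunks would pad or drop it instead.

```python
def iter_chunks(fragment_length, context_length):
    """Yield (start, end) index pairs covering a fragment in temporal order."""
    for start in range(0, fragment_length, context_length):
        yield start, min(start + context_length, fragment_length)


chunks = list(iter_chunks(fragment_length=10, context_length=4))
```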
The heart of the update is the Forward Pass. The current policy model processes the data chunk to obtain new action log probabilities (new_logp
) for the actions taken in the fragments, value predictions (vpred
), and the raw policy logits. If KL regularization is active, the ref_model
also performs a forward pass on the same data to get its corresponding action log probabilities or logits.
Next comes the Loss Calculation. The PPO objective is a composite of several terms:
Policy Loss (Actor Loss): This is calculated using the PPO clipped surrogate objective. It involves computing the probability ratio $r_t(\theta) = \exp(\log \pi_\theta(a_t|s_t) - \log \pi_{\theta_{\text{old}}}(a_t|s_t))$. The loss is then derived from $\min(r_t(\theta)\hat{A}_t, \mathrm{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t)$, where $\hat{A}_t$ is the advantage estimate. This clipping mechanism is key to PPO’s stability.
Value Loss (Critic Loss): This is typically the Mean Squared Error (MSE) between the predicted values (vpred) and the calculated TD-targets. The value loss can also be clipped (if clip_vloss is true) by comparing it to a version where value predictions are clipped around the old_vpred from the previous iteration.
Entropy Bonus: An entropy term, derived from the policy logits, is added to the objective. This encourages exploration by penalizing policies that become too deterministic too quickly.
KL Divergence Loss: If enabled, this measures the KL divergence between the current policy’s action distribution and that of the ref_model. It acts as an additional regularization term.
The Total Loss is then a weighted sum of these components.
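The loss terms listed above can be combined as in the following sketch. It is an illustration under stated assumptions rather than the trainer's actual code: the function name, the coefficient names `vf_coef`, `ent_coef`, and `kl_coef`, their default values, and the reuse of `clip_eps` for value clipping are all hypothetical.

```python
import torch


def ppo_loss(new_logp, old_logp, advantages, vpred, old_vpred, td_targets,
             entropy, kl_to_ref=None, clip_eps=0.2, vf_coef=0.5,
             ent_coef=0.01, kl_coef=0.0, clip_vloss=True):
    """Return (total_loss, policy_loss, value_loss) for one mini-batch."""
    # Probability ratio r_t = exp(log pi_new - log pi_old).
    ratio = torch.exp(new_logp - old_logp)
    surr_unclipped = ratio * advantages
    surr_clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    # The surrogate is maximized, so the loss is its negation.
    policy_loss = -torch.min(surr_unclipped, surr_clipped).mean()

    vloss_unclipped = (vpred - td_targets) ** 2
    if clip_vloss:
        # Keep the new prediction within clip_eps of the old one, then take
        # the more pessimistic of the two squared errors.
        vpred_clipped = old_vpred + torch.clamp(vpred - old_vpred, -clip_eps, clip_eps)
        vloss_clipped = (vpred_clipped - td_targets) ** 2
        value_loss = 0.5 * torch.max(vloss_unclipped, vloss_clipped).mean()
    else:
        value_loss = 0.5 * vloss_unclipped.mean()

    # Entropy is a bonus, so it is subtracted from the loss.
    total = policy_loss + vf_coef * value_loss - ent_coef * entropy.mean()
    if kl_to_ref is not None:
        total = total + kl_coef * kl_to_ref.mean()
    return total, policy_loss, value_loss


# A ratio of 2 with a positive advantage is clipped to 1 + clip_eps = 1.2.
zeros = torch.zeros(1)
_, clipped_policy_loss, _ = ppo_loss(
    new_logp=torch.log(torch.tensor([2.0])), old_logp=zeros,
    advantages=torch.ones(1), vpred=zeros, old_vpred=zeros,
    td_targets=zeros, entropy=zeros,
)
```

The example call shows the clipping at work: the unclipped surrogate would be 2, but the clipped objective caps the incentive at 1.2, so the policy loss is approximately -1.2.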
Tip
During an initial vf_warmup
phase, the training might focus more on the value function. In such cases, the total loss might be dominated by, or solely consist of, the value function loss (and KL loss if active). This helps in stabilizing the value estimates before the policy is trained more aggressively.
With the total loss computed, the Backward Pass and Optimization step follows. The total_loss
is backpropagated to compute gradients. If Gradient Accumulation is used (i.e., gradient_accumulation
> 1), gradients are summed over several mini-batches before an optimizer step is taken, effectively simulating a larger batch size which can sometimes stabilize training. Gradient Clipping is then applied, where gradients are clipped to a maximum norm (max_grad_norm
) using torch.nn.utils.clip_grad_norm_
. This is a common technique to prevent issues with exploding gradients. Finally, the optimizer (e.g., AdamW) updates the model parameters using these accumulated and clipped gradients.
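Gradient accumulation followed by norm clipping can be sketched like this. The tiny linear model, the random data, and the hyperparameter values are placeholders; only `torch.nn.utils.clip_grad_norm_` and the AdamW optimizer come from the description above.

```python
import torch
from torch import nn

model = nn.Linear(4, 1)  # placeholder for the policy model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
gradient_accumulation = 2  # mini-batches per optimizer step
max_grad_norm = 0.5

torch.manual_seed(0)
mini_batches = [(torch.randn(8, 4), torch.randn(8)) for _ in range(4)]
num_optimizer_steps = 0

optimizer.zero_grad()
for i, (x, y) in enumerate(mini_batches, start=1):
    loss = ((model(x).squeeze(-1) - y) ** 2).mean()
    # Scale so the accumulated gradient is an average over the mini-batches.
    (loss / gradient_accumulation).backward()
    if i % gradient_accumulation == 0:
        # Rescale gradients in place so their global norm is at most max_grad_norm.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        num_optimizer_steps += 1
```

Dividing each loss by `gradient_accumulation` is one common convention; summing unscaled losses also works but changes the effective learning rate.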
Throughout this PPO update cycle, various Metrics are Aggregated. These include individual loss components, the approximate KL divergence between the old and new policies, the fraction of samples where the PPO clipping was active, and explained variance. These metrics are computed and tracked using torchmetrics
.
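For reference, the diagnostics named above are commonly defined as in the sketch below. It uses plain Python lists instead of torchmetrics to stay dependency-free, and the exact estimators used by the trainer (for example, which approximate-KL formula) are assumptions.

```python
import math


def approx_kl(old_logp, new_logp):
    """Simple estimator of KL(old || new): mean of (log pi_old - log pi_new)."""
    return sum(o - n for o, n in zip(old_logp, new_logp)) / len(old_logp)


def clip_fraction(old_logp, new_logp, clip_eps=0.2):
    """Fraction of samples whose probability ratio falls outside the clip range."""
    clipped = [abs(math.exp(n - o) - 1.0) > clip_eps
               for o, n in zip(old_logp, new_logp)]
    return sum(clipped) / len(clipped)


def explained_variance(vpred, targets):
    """1 - Var(targets - vpred) / Var(targets); 1.0 means perfect value predictions."""
    def variance(xs):
        mean = sum(xs) / len(xs)
        return sum((x - mean) ** 2 for x in xs) / len(xs)

    var_targets = variance(targets)
    if var_targets == 0:
        return float("nan")
    residuals = [t - v for t, v in zip(targets, vpred)]
    return 1.0 - variance(residuals) / var_targets


ev = explained_variance(vpred=[1.0, 2.0, 3.0], targets=[1.0, 2.0, 3.0])
cf = clip_fraction(old_logp=[0.0, 0.0], new_logp=[0.0, math.log(2.0)])
```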
5. Logging and Checkpointing#
Essential for monitoring and recovery, logging and checkpointing operations are typically performed by the rank 0 worker. After each ppo_update
(meaning, after all epochs for a given batch of data have been processed), the aggregated training Metrics are Logged using wandb_logger
. This sends the data to Weights & Biases, allowing for real-time monitoring and later analysis of the training run.
Model Checkpointing is performed periodically. The state of the model (its weights), the optimizer’s state, and the current number of updates are saved to disk. The frequency of these saves is governed by save_interval
, and save_path
specifies the directory for these checkpoints. To manage disk space, Old Checkpoint Management ensures that older checkpoints are removed based on the keep_interval
, for instance, by keeping only the last N checkpoints.
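A "keep only the last N" retention policy could be implemented along these lines. The `ckpt_<num_updates>.pt` file naming and the `prune_checkpoints` helper are hypothetical, and the sketch creates empty files in a temporary directory rather than saving real model state.

```python
import tempfile
from pathlib import Path


def prune_checkpoints(save_path, keep_last):
    """Delete all but the newest keep_last checkpoints; return the kept file names."""
    ckpts = sorted(
        Path(save_path).glob("ckpt_*.pt"),
        # Sort numerically by update count so 1000 comes after 300.
        key=lambda p: int(p.stem.split("_")[1]),
    )
    stale = ckpts[:-keep_last] if keep_last > 0 else ckpts
    for path in stale:
        path.unlink()
    return [p.name for p in ckpts[len(stale):]]


with tempfile.TemporaryDirectory() as save_path:
    for num_updates in (100, 200, 300, 1000):
        (Path(save_path) / f"ckpt_{num_updates}.pt").write_bytes(b"")
    kept = prune_checkpoints(save_path, keep_last=2)
    remaining = sorted(p.name for p in Path(save_path).iterdir())
```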
6. Reference Model Update#
The _update_ref_model
method is responsible for the synchronization of the reference model, which is used when KL regularization is active. If enable_ref_update
is true, the weights of the ref_model
are periodically updated to match the weights of the current trained policy model.
Note
This update of the reference model is often coordinated with the broadcast_model_to_rollout_workers
calls. This ensures the reference policy doesn’t lag too far behind the current policy, yet doesn’t update too frequently, which could diminish its regularizing effect by making it too similar to the current policy.
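Synchronizing the reference model amounts to copying the current policy weights into a frozen copy. The sketch below uses a small linear layer as a placeholder for the policy, and the `update_ref_model` helper is a hypothetical simplification of what _update_ref_model might do.

```python
import copy

import torch
from torch import nn

policy = nn.Linear(4, 2)  # placeholder for the trained policy model
ref_model = copy.deepcopy(policy)
for p in ref_model.parameters():
    p.requires_grad_(False)  # the reference model is never trained directly

# Simulate some training progress on the policy.
with torch.no_grad():
    for p in policy.parameters():
        p.add_(1.0)


def update_ref_model(ref_model, policy, enable_ref_update=True):
    """Copy the current policy weights into the reference model if enabled."""
    if enable_ref_update:
        ref_model.load_state_dict(policy.state_dict())


update_ref_model(ref_model, policy)
```

Because `load_state_dict` copies values into the existing parameters, the reference model stays frozen after the update.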
This comprehensive cycle of data ingestion, advantage estimation, policy optimization, and model management continues for the specified num_iterations
. The ultimate aim is to train an agent that can master complex tasks within the Minecraft environment.