Trainer#
In online/trainer
, we have defined PPOTrainer
and other online trainers. These trainers need to inherit from minestudio.online.trainer.basetrainer.BaseTrainer
.
BaseTrainer
is a foundational class in the MineStudio framework, designed to facilitate online reinforcement learning training. It manages the interaction between various components such as rollout workers, environment generators, and the policy model. Developers can extend this class to implement custom trainers for specific algorithms like PPO.
You can launch Trainer
directly by minestudio.online.trainer.start_trainer
.
Constructor#
__init__
#
The constructor initializes the trainer with various configurations and components.
Parameters#
rollout_manager: ActorHandle
Handles rollout workers.policy_generator: Callable[[], MinePolicy]
Function to generate the policy.env_generator: Callable[[], MinecraftSim]
Function to generate the environment.num_workers: int
Number of workers for parallel training.num_readers: int
Number of data readers.num_cpus_per_reader: int
Number of CPUs allocated per reader.num_gpus_per_worker: int
Number of GPUs allocated per worker.prefetch_batches: int
Number of batches to prefetch.discount: float
Discount factor for rewards.gae_lambda: float
Lambda for Generalized Advantage Estimation (GAE).context_length: int
Maximum context length for model input.use_normalized_vf: bool
Whether to normalize value function outputs.inference_batch_size_per_gpu: int
Batch size per GPU during inference.resume: Optional[str]
Path to checkpoint for resuming training.resume_optimizer: bool
Whether to resume optimizer state from the checkpoint.kwargs
Additional arguments.
Methods#
broadcast_model_to_rollout_workers
#
Broadcasts the updated model to rollout workers.
Parameters#
new_version: bool
Whether to increment the model version.
fetch_fragments_and_estimate_advantages
#
Fetches fragments from the replay buffer, calculates advantages, and prepares data for training.
Parameters#
num_fragments: int
Number of fragments to fetch.
Returns#
Dict[str, Any]
: Processed data including records, TD targets, advantages, and old policy information.
setup_model_and_optimizer
#
Abstract method to define the model and optimizer.
Returns#
Tuple[MinecraftSim, torch.optim.Optimizer]
: Model and optimizer instances.
_train_func
#
Main training loop function, executed by TorchTrainer
.
train
#
Abstract method to implement custom training logic.
fit
#
Executes the training process using TorchTrainer
.
Attributes#
rollout_manager
#
Manages rollout workers.
policy_generator
#
Function to create the policy model.
env_generator
#
Function to generate environments.
gae_actor
#
Actor for calculating GAE and reward targets.
Usage#
To use BaseTrainer
, extend it and implement the abstract methods setup_model_and_optimizer
and train
.
class PPOTrainer(BaseTrainer):
def setup_model_and_optimizer(self):
# Define model and optimizer
pass
def train(self):
# Custom training logic
pass
Refer to minestudio.online.trainer.ppotrainer.PPOTrainer
for an example.