Trainer#

In online/trainer, we define PPOTrainer and other online trainers. Every online trainer must inherit from minestudio.online.trainer.basetrainer.BaseTrainer.

BaseTrainer is a foundational class in the MineStudio framework, designed to support online reinforcement learning. It manages the interaction between components such as rollout workers, environment generators, and the policy model. Developers can extend this class to implement custom trainers for specific algorithms such as PPO.

You can launch a trainer directly via minestudio.online.trainer.start_trainer.
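
The exact signature of start_trainer is not documented here; the call below is a purely illustrative sketch, and the argument names (trainer_class, policy_generator, env_generator) are assumptions rather than the confirmed API.

from minestudio.online.trainer import start_trainer

# Hypothetical invocation: the argument names are assumptions, not the
# confirmed API; consult the start_trainer docstring in your version.
start_trainer(
    trainer_class=PPOTrainer,      # assumed: the trainer class to launch
    policy_generator=make_policy,  # assumed: () -> MinePolicy
    env_generator=make_env,        # assumed: () -> MinecraftSim
)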


Constructor#

__init__#

The constructor initializes the trainer with various configurations and components; an instantiation sketch follows the parameter list below.

Parameters#

  • rollout_manager: ActorHandle
    Handle to the rollout manager actor, which coordinates the rollout workers.

  • policy_generator: Callable[[], MinePolicy]
    Function to generate the policy.

  • env_generator: Callable[[], MinecraftSim]
    Function to generate the environment.

  • num_workers: int
    Number of workers for parallel training.

  • num_readers: int
    Number of data readers.

  • num_cpus_per_reader: int
    Number of CPUs allocated per reader.

  • num_gpus_per_worker: int
    Number of GPUs allocated per worker.

  • prefetch_batches: int
    Number of batches to prefetch.

  • discount: float
    Discount factor for rewards.

  • gae_lambda: float
    Lambda for Generalized Advantage Estimation (GAE).

  • context_length: int
    Maximum context length for model input.

  • use_normalized_vf: bool
    Whether to normalize value function outputs.

  • inference_batch_size_per_gpu: int
    Batch size per GPU during inference.

  • resume: Optional[str]
    Path to checkpoint for resuming training.

  • resume_optimizer: bool
    Whether to resume optimizer state from the checkpoint.

  • kwargs
    Additional arguments.
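
As a concrete illustration, a subclass such as PPOTrainer could be instantiated as follows. All numeric values are placeholders, and rollout_manager, make_policy, and make_env are assumed to be defined in your launch script.

trainer = PPOTrainer(
    rollout_manager=rollout_manager,   # ActorHandle from your rollout setup
    policy_generator=make_policy,      # () -> MinePolicy
    env_generator=make_env,            # () -> MinecraftSim
    num_workers=2,                     # placeholder values from here on
    num_readers=4,
    num_cpus_per_reader=1,
    num_gpus_per_worker=1,
    prefetch_batches=2,
    discount=0.99,
    gae_lambda=0.95,
    context_length=128,
    use_normalized_vf=True,
    inference_batch_size_per_gpu=8,
    resume=None,
    resume_optimizer=False,
)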


Methods#

broadcast_model_to_rollout_workers#

Broadcasts the updated model to rollout workers.

Parameters#

  • new_version: bool
    Whether to increment the model version.
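
For example, a subclass would typically push freshly updated weights to the rollout workers after an optimization step (illustrative call):

# Inside a BaseTrainer subclass, after an optimizer step:
self.broadcast_model_to_rollout_workers(new_version=True)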


fetch_fragments_and_estimate_advantages#

Fetches fragments from the replay buffer, calculates advantages, and prepares data for training.

Parameters#

  • num_fragments: int
    Number of fragments to fetch.

Returns#

  • Dict[str, Any]: Processed data including records, TD targets, advantages, and old policy information.
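
For reference, the advantages follow the standard GAE recursion driven by the discount and gae_lambda constructor arguments. The helper below is an independent sketch of that computation for a single non-terminating trajectory (termination masking omitted), not BaseTrainer's internal implementation.

import numpy as np

def estimate_gae(rewards, values, last_value, discount=0.99, gae_lambda=0.95):
    # delta_t = r_t + discount * V(s_{t+1}) - V(s_t)
    # A_t = delta_t + discount * gae_lambda * A_{t+1}
    values = np.append(values, last_value)  # append bootstrap value V(s_T)
    advantages = np.zeros(len(rewards))
    adv = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + discount * values[t + 1] - values[t]
        adv = delta + discount * gae_lambda * adv
        advantages[t] = adv
    td_targets = advantages + values[:-1]  # regression targets for the value head
    return advantages, td_targets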


setup_model_and_optimizer#

Abstract method to define the model and optimizer.

Returns#

  • Tuple[MinePolicy, torch.optim.Optimizer]: Model and optimizer instances.


_train_func#

Main training loop function, executed by TorchTrainer.


train#

Abstract method to implement custom training logic.


fit#

Executes the training process using TorchTrainer.


Attributes#

rollout_manager#

Manages rollout workers.

policy_generator#

Function to create the policy model.

env_generator#

Function to generate environments.

gae_actor#

Actor that computes GAE advantages and TD targets.


Usage#

To use BaseTrainer, extend it and implement the abstract methods setup_model_and_optimizer and train. The skeleton below is a minimal sketch; the Adam optimizer and learning rate are illustrative placeholders, not requirements.

import torch

from minestudio.online.trainer.basetrainer import BaseTrainer

class PPOTrainer(BaseTrainer):
    def setup_model_and_optimizer(self):
        # Build the policy from the generator passed to the constructor;
        # the optimizer choice and learning rate are illustrative placeholders.
        model = self.policy_generator()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        return model, optimizer

    def train(self):
        # Custom training logic, e.g. fetch fragments and run PPO updates.
        pass
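
Once the subclass is defined, construct it with the arguments documented above and start training with fit:

trainer = PPOTrainer(...)  # constructor arguments as documented above
trainer.fit()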

Refer to minestudio.online.trainer.ppotrainer.PPOTrainer for an example.