Online API Documentation

Rollout

Env Worker

class minestudio.online.rollout.env_worker.EnvWorker(env_generator: Callable[[], MinecraftSim], conn: Connection, video_output_dir: str, video_fps: int, restart_interval: int | None = None, max_fast_reset: int = 10000, env_id: int = 0, rollout_worker_id: int = 0)

A class for running a Minecraft simulation environment in a separate process.

Parameters:
  • env_generator – A function that returns a MinecraftSim instance.

  • conn – A multiprocessing connection object for communication with the main process.

  • video_output_dir – The directory to save output videos to.

  • video_fps – The frames per second for output videos.

  • restart_interval – The interval in seconds after which to restart the environment.

  • max_fast_reset – The maximum number of fast resets to perform before the environment is fully restarted.

  • env_id – The ID of the environment.

  • rollout_worker_id – The ID of the rollout worker.

report_rewards(rewards: ndarray)

Sends the rewards for an episode to the main process.

Parameters:

rewards – A NumPy array of rewards for the episode.

Returns:

The result from the main process.

reset_state() → Dict[str, Tensor]

Sends a reset signal to the main process and receives the initial observation.

Returns:

The initial observation from the environment.

run() → None

The main loop of the environment worker process. Handles environment resets, steps, and video recording.

step_agent(obs: dict, last_reward: float, last_terminated: bool, last_truncated: bool, episode_uuid: str) → Tuple[Dict[str, Tensor], float]

Sends an observation to the main process and receives an action and predicted value.

Parameters:
  • obs – The current observation from the environment.

  • last_reward – The reward from the previous step.

  • last_terminated – Whether the episode terminated at the previous step.

  • last_truncated – Whether the episode was truncated at the previous step.

  • episode_uuid – The UUID of the current episode.

Returns:

A tuple containing the action and predicted value.
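The request/response cycle behind step_agent can be sketched with a multiprocessing Pipe. This is a minimal illustration, not EnvWorker's code: the ("step_agent", payload) message layout, the payload keys, and the (action, vpred) reply shape are assumptions, and the three helper functions are hypothetical. Only the message name is taken from the poll_environments description.

```python
from multiprocessing import Pipe


def send_step_request(conn, obs, last_reward, last_terminated, last_truncated, episode_uuid):
    """Env side: ship the current observation and last-step signals to the main process.

    The ("step_agent", payload) tuple is a hypothetical wire format.
    """
    payload = {
        "obs": obs,
        "last_reward": last_reward,
        "last_terminated": last_terminated,
        "last_truncated": last_truncated,
        "episode_uuid": episode_uuid,
    }
    conn.send(("step_agent", payload))


def serve_one_request(conn, policy):
    """Main side: answer one pending request with (action, vpred)."""
    kind, payload = conn.recv()
    assert kind == "step_agent"
    action, vpred = policy(payload["obs"])
    conn.send((action, vpred))


def receive_action(conn):
    """Env side: block until the main process replies."""
    action, vpred = conn.recv()
    return action, vpred


def dummy_policy(obs):
    # Stand-in for real policy inference.
    return {"forward": 1}, 0.5


if __name__ == "__main__":
    env_end, main_end = Pipe()
    send_step_request(env_end, {"pov": None}, 0.0, False, False, "ep-0")
    serve_one_request(main_end, dummy_policy)
    print(receive_action(env_end))  # prints ({'forward': 1}, 0.5)
```

Calling the three functions in sequence works in a single process because a small message fits in the pipe buffer; in the real setup each end lives in a different process.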

class minestudio.online.rollout.env_worker.VideoWriter(video_fps: int, queue_size: int = 200)

A class for writing video frames to a file in a separate thread.

Parameters:
  • video_fps – The frames per second of the output video.

  • queue_size – The maximum size of the command queue.

close_video()

Closes the current video file.

open_video(path: PosixPath)

Opens a video file for writing.

Parameters:

path – The path to the video file.

Raises:

AssertionError – if a video is already open.

run()

The main loop of the video writer thread. Processes commands from the queue to open, write, and close video files.
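A minimal sketch of the command-queue pattern VideoWriter describes: callers enqueue open/write/close commands, and a daemon thread applies them in order. To stay dependency-free, this hypothetical InMemoryVideoWriter appends frames to in-memory lists instead of encoding a video file, and the "stop" sentinel is an assumption; the precondition checks mirror the AssertionError cases documented for open_video and write_frame.

```python
import queue
import threading


class InMemoryVideoWriter(threading.Thread):
    """Hypothetical command-queue writer; frames go to lists keyed by path, not files."""

    def __init__(self, queue_size: int = 200):
        super().__init__(daemon=True)
        self.commands = queue.Queue(maxsize=queue_size)
        self.videos = {}        # path -> list of frames (stands in for files on disk)
        self._current = None    # path being written (touched only by the worker thread)
        self._is_open = False   # caller-side state used for the precondition checks

    # Caller-side API: validate, then enqueue a command for the worker thread.
    def open_video(self, path):
        assert not self._is_open, "a video is already open"
        self._is_open = True
        self.commands.put(("open", path))

    def write_frame(self, frame):
        assert self._is_open, "no video is open"
        self.commands.put(("write", frame))

    def close_video(self):
        self._is_open = False
        self.commands.put(("close", None))

    def stop(self):
        self.commands.put(("stop", None))

    def run(self):
        # Worker loop: apply commands in FIFO order until the stop sentinel arrives.
        while True:
            cmd, arg = self.commands.get()
            if cmd == "open":
                self._current = arg
                self.videos[arg] = []
            elif cmd == "write":
                self.videos[self._current].append(arg)
            elif cmd == "close":
                self._current = None
            elif cmd == "stop":
                break
```

Decoupling the caller from the (slow) encoder this way means environment stepping is only blocked when the bounded queue fills up.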

write_frame(frame: ndarray)

Writes a frame to the video file.

Parameters:

frame – The frame to write, as a NumPy array.

Raises:

AssertionError – if no video is open.

minestudio.online.rollout.env_worker.draw_vpred(img: ndarray, vpred: float, additional_text: str | None = '')

Draws the predicted value (vpred) and additional text on an image.

Parameters:
  • img – The input image as a NumPy array.

  • vpred – The predicted value to display.

  • additional_text – Optional additional text to display.

Returns:

The image with the text drawn on it.

Episode Statistics

Rollout Manager

Rollout Worker

class minestudio.online.rollout.rollout_worker.ProgressHandler(*args, **kwargs)

A protocol defining the structure for a progress handler function.

This handler is called by the RolloutWorker to report step-wise progress. Implementers of this protocol can use this to, for example, send data to a replay buffer.
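Because ProgressHandler is a protocol, any callable with a matching keyword-only signature can be passed to RolloutWorker. The sketch below assumes the protocol's signature mirrors RolloutWorker.progress_handler (state is typed List[Any] rather than List[Tensor] to avoid a torch dependency); ListBuffer is a hypothetical handler that only records steps, standing in for a replay buffer.

```python
from typing import Any, Dict, List, Protocol


class ProgressHandler(Protocol):
    # Keyword-only call signature, assumed to mirror RolloutWorker.progress_handler.
    def __call__(
        self,
        *,
        worker_uuid: str,
        obs: Dict[str, Any],
        env_spec: str,
        state: List[Any],
        action: Dict[str, Any],
        last_reward: float,
        last_terminated: bool,
        last_truncated: bool,
        episode_uuid: str,
    ) -> None: ...


class ListBuffer:
    """Hypothetical handler that records a summary of every step it receives."""

    def __init__(self) -> None:
        self.steps: List[Dict[str, Any]] = []

    def __call__(self, *, worker_uuid, obs, env_spec, state, action,
                 last_reward, last_terminated, last_truncated, episode_uuid) -> None:
        self.steps.append(
            {"worker_uuid": worker_uuid, "episode_uuid": episode_uuid, "reward": last_reward}
        )


# Structural typing: ListBuffer satisfies the protocol without inheriting from it.
handler: ProgressHandler = ListBuffer()
```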

class minestudio.online.rollout.rollout_worker.RolloutWorker(num_envs: int, policy_generator: Callable[[], MinePolicy], env_generator: Callable[[], MinecraftSim], use_normalized_vf: bool = False, model_device: str = 'cpu', next_model_version: int = 0, batch_size: int = 1, video_fps: int = 20, video_output_dir: str = './output', resume: str | None = None, restart_interval: int | None = None, moving_stat_duration: int = 300, log_interval: int | None = None, episode_statistics: ActorHandle | None = None, progress_handler: ProgressHandler | None = None, max_fast_reset: int = 10000, rollout_worker_id: int = 0)

Manages a set of environments and a policy to collect experience for reinforcement learning.

This class is responsible for:
  • Spawning and managing multiple environment worker processes (EnvWorker).
  • Performing inference using the policy model.
  • Stepping through the environments and collecting observations, actions, rewards, etc.
  • Communicating with an EpisodeStatistics actor to report episode-level metrics.
  • Optionally calling a progress_handler to process step-wise data (e.g., for a replay buffer).

Parameters:
  • num_envs – The number of parallel environments to run.

  • policy_generator – A callable that returns an instance of MinePolicy (the policy model).

  • env_generator – A callable that returns an instance of MinecraftSim (the environment).

  • use_normalized_vf – Whether to use a normalized value function. If True, vpreds will be denormalized.

  • model_device – The device to run the policy model on (e.g., “cpu”, “cuda:0”).

  • next_model_version – The initial version of the model.

  • batch_size – The number of environment steps to batch together for inference.

  • video_fps – Frames per second for video recording in EnvWorker.

  • video_output_dir – Directory to save videos in EnvWorker.

  • resume – Optional path to a checkpoint to resume training from.

  • restart_interval – Optional interval in seconds after which to restart EnvWorkers.

  • moving_stat_duration – Duration in seconds for calculating moving statistics (e.g., for pipeline monitoring).

  • log_interval – Optional interval in seconds for logging pipeline monitoring stats.

  • episode_statistics – Optional Ray ActorHandle for an EpisodeStatistics actor.

  • progress_handler – Optional callable that conforms to the ProgressHandler protocol.

  • max_fast_reset – Maximum number of fast resets for an EnvWorker before a full restart.

  • rollout_worker_id – Identifier for this rollout worker.

inference(requests: list) → Tuple[list, list, list]

Performs a batch of inference requests using the policy model.

Parameters:

requests – A list of tuples, where each tuple contains (worker_id, observation).

Returns:

A tuple containing:
  • result_actions: A list of actions for each request.
  • result_states: A list of next hidden states for each request.
  • result_vpreds: A list of predicted values (vpreds) for each request.

load_weights(weights: Dict[str, Tensor]) → None

Loads new weights into the policy model.

Parameters:

weights – A dictionary containing the state dictionary of the model.

loop() → None

Runs one iteration of the rollout loop.

This involves:
  1. Polling environments until enough requests are queued to fill a batch.
  2. Performing inference on the batch of requests.
  3. Sending actions back to the environments.
  4. Calling the progress_handler if it is set.
  5. Polling environments again to process any immediate responses.
  6. Logging statistics if the log_interval is met.
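Steps 1 to 3 of the loop can be sketched as a single batching function. This is an illustration under assumptions, not RolloutWorker's code: queued requests are (worker_id, observation) pairs as described for inference, while the policy callable, the per-worker hidden_states dict, and run_batched_inference itself are hypothetical. Each environment is assumed to have at most one outstanding request.

```python
from collections import deque
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical policy signature: (observations, states) -> (actions, next_states, vpreds)
Policy = Callable[[List[Any], List[Any]], Tuple[List[Any], List[Any], List[float]]]


def run_batched_inference(
    queued_requests: deque,
    batch_size: int,
    policy: Policy,
    hidden_states: Dict[int, Any],
) -> Dict[int, Tuple[Any, float]]:
    """Serve one batch; return {worker_id: (action, vpred)} or {} if the batch is not full."""
    if len(queued_requests) < batch_size:
        return {}  # keep polling environments until a full batch is queued
    batch = [queued_requests.popleft() for _ in range(batch_size)]
    worker_ids = [wid for wid, _ in batch]
    observations = [obs for _, obs in batch]
    states = [hidden_states.get(wid) for wid in worker_ids]  # None means a fresh episode
    actions, next_states, vpreds = policy(observations, states)
    replies = {}
    for wid, action, state, vpred in zip(worker_ids, actions, next_states, vpreds):
        hidden_states[wid] = state  # carry the recurrent state to this env's next step
        replies[wid] = (action, vpred)
    return replies
```

Batching trades a little latency per environment for much better accelerator utilization, which is why the loop waits for batch_size requests before running the model.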

poll_environments()

Polls all environment connections for messages and processes them.

Handles different message types:
  • “step_agent”: An environment is ready for a new action. The observation is added to queued_requests.
  • “reset_state”: An environment has reset. The agent’s hidden state for this environment is reset.
  • “report_rewards”: An environment has finished an episode. Rewards are reported to EpisodeStatistics.
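A hypothetical sketch of the dispatch that poll_environments performs over multiprocessing connections. The three message kinds come from the description above; the (kind, payload) tuple format, the argument names, and the None acknowledgement sent back for "reset_state" and "report_rewards" are assumptions.

```python
from multiprocessing import Pipe
from typing import Any, Dict, List


def poll_environments(conns: Dict[int, Any], queued_requests: List,
                      hidden_states: Dict[int, Any], reported_rewards: List) -> None:
    """Drain every connection that has a pending (kind, payload) message."""
    for worker_id, conn in conns.items():
        while conn.poll():  # non-blocking check for a pending message
            kind, payload = conn.recv()
            if kind == "step_agent":
                # The environment is waiting for an action: queue it for the next batch.
                queued_requests.append((worker_id, payload))
            elif kind == "reset_state":
                # New episode: drop this environment's recurrent state.
                hidden_states.pop(worker_id, None)
                conn.send(None)  # hypothetical acknowledgement
            elif kind == "report_rewards":
                reported_rewards.append((worker_id, payload))
                conn.send(None)  # hypothetical acknowledgement
            else:
                raise ValueError(f"unknown message kind: {kind!r}")


if __name__ == "__main__":
    env_end, main_end = Pipe()
    env_end.send(("step_agent", {"pov": None}))
    queued, states, rewards = [], {}, []
    poll_environments({0: main_end}, queued, states, rewards)
    print(queued)  # prints [(0, {'pov': None})]
```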

progress_handler(*, worker_uuid: str, obs: Dict[str, Any], env_spec: str, state: List[Tensor], action: Dict[str, Any], last_reward: float, last_terminated: bool, last_truncated: bool, episode_uuid: str) → None

Default progress handler, intended to be overridden or replaced. Its logic appears to duplicate RolloutWorkerWrapper.progress_handler, and it may go unused when a custom progress_handler is passed to RolloutWorker at initialization.

Parameters:
  • worker_uuid – The UUID of the worker environment.

  • obs – The observation from the environment.

  • env_spec – Specification of the environment, used for filtering.

  • state – The hidden state of the policy model.

  • action – The action taken by the policy.

  • last_reward – The reward received from the previous step.

  • last_terminated – Whether the episode terminated at the previous step.

  • last_truncated – Whether the episode was truncated at the previous step.

  • episode_uuid – The UUID of the current episode.

Raises:

AssertionError – if env_spec is in neither self.env_spec.config nor self.env_spec.config4test.

rollout(num_batches: int) → None

Runs the rollout loop for a specified number of batches.

Parameters:

num_batches – The number of inference batches to collect.

update_model_version(next_model_version: int)

Updates the model version number that will be associated with subsequently collected data.

Parameters:

next_model_version – The new model version number.

Start Manager

Date: 2025-05-22 21:43:52 LastEditors: caishaofei-mus1 1744260356@qq.com LastEditTime: 2025-05-23 11:39:21 FilePath: /MineStudio/var/minestudio/online/rollout/start_manager.py

minestudio.online.rollout.start_manager.start_rolloutmanager(policy_generator, env_generator, online_cfg, address='localhost:9899')

Initializes and starts a RolloutManager actor in a Ray cluster.

This function handles the creation or reuse of a RolloutManager actor. If a RolloutManager with the name “rollout_manager” already exists, it checks if its configuration matches the provided online_cfg. If the configurations differ, the existing actor is killed, and a new one is created. Otherwise, the existing actor is reused.

Parameters:
  • policy_generator – A callable that generates a policy model.

  • env_generator – A callable that generates a Minecraft simulation environment.

  • online_cfg – An OmegaConf DictConfig object containing the online training configuration. This includes sub-configs for train_config and rollout_config.

  • address – The address of the Ray cluster to connect to.

Replay Buffer

Actor

class minestudio.online.rollout.replay_buffer.actor.ChunkRecord(fragment_records: List[FragmentRecord], model_version: int, session_id: str, use_count: int)

Represents a chunk of fragment records, along with model version, session ID, and use count.

Parameters:
  • fragment_records – A list of FragmentRecord objects that form this chunk.

  • model_version – The model version associated with the fragments in this chunk.

  • session_id – The session ID associated with the fragments in this chunk.

  • use_count – How many times this chunk has been used for training.

class minestudio.online.rollout.replay_buffer.actor.FragmentManager(fragment_store: FragmentStore)

Manages fragments stored in a FragmentStore, primarily by tracking their reference counts.

Parameters:

fragment_store – An instance of FragmentStore where fragments are physically stored.

clean(fragment_id: str)

Removes a fragment from the reference count and deletes it from the FragmentStore.

This method is called when a fragment’s reference count drops to zero.

Parameters:

fragment_id – The unique identifier of the fragment to clean.

create_fragment_record(fragment_id: str, metadata: FragmentMetadata)

Creates a new FragmentRecord for a given fragment_id and its metadata.

Parameters:
  • fragment_id – The unique identifier of the fragment.

  • metadata – The FragmentMetadata associated with this fragment.

Returns:

A new FragmentRecord instance.

class minestudio.online.rollout.replay_buffer.actor.FragmentRecord(fragment_id: str, metadata: FragmentMetadata, manager)

Represents a record of a fragment with its metadata and a reference to the FragmentManager.

This class is used to track references to fragments in the FragmentStore. When a FragmentRecord instance is deleted (no longer referenced), it informs the FragmentManager to potentially clean up the fragment from the store if its reference count drops to zero.

Parameters:
  • fragment_id – The unique identifier of the fragment.

  • metadata – The FragmentMetadata associated with this fragment.

  • manager – The FragmentManager instance that manages this fragment.
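The reference-counting lifecycle of FragmentRecord and FragmentManager can be sketched with __del__. ToyFragmentManager and ToyFragmentRecord are hypothetical stand-ins that use a dict as the store. They rely on CPython's reference counting, which runs __del__ as soon as the last reference disappears; other interpreters may defer it.

```python
from typing import Dict


class ToyFragmentManager:
    """Hypothetical reference-counting manager over a dict-backed store."""

    def __init__(self, store: Dict[str, bytes]):
        self.store = store
        self.ref_count: Dict[str, int] = {}

    def create_fragment_record(self, fragment_id: str) -> "ToyFragmentRecord":
        self.ref_count[fragment_id] = self.ref_count.get(fragment_id, 0) + 1
        return ToyFragmentRecord(fragment_id, self)

    def release(self, fragment_id: str) -> None:
        self.ref_count[fragment_id] -= 1
        if self.ref_count[fragment_id] == 0:
            self.clean(fragment_id)

    def clean(self, fragment_id: str) -> None:
        # Last reference gone: forget the count and delete the stored payload.
        del self.ref_count[fragment_id]
        self.store.pop(fragment_id, None)


class ToyFragmentRecord:
    def __init__(self, fragment_id: str, manager: ToyFragmentManager):
        self.fragment_id = fragment_id
        self.manager = manager

    def __del__(self):
        # Runs when this record is garbage-collected; decrements the shared count.
        self.manager.release(self.fragment_id)
```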

Fragment Store

Date: 2025-05-20 12:09:48 LastEditors: caishaofei-mus1 1744260356@qq.com LastEditTime: 2025-05-23 11:39:33 FilePath: /MineStudio/var/minestudio/online/rollout/replay_buffer/fragment_store.py

class minestudio.online.rollout.replay_buffer.fragment_store.FragmentStore(**kwargs)

A class that provides an interface to either a local or a remote fragment store.

It checks if the current Ray node has a “database” resource. If so, it uses a LocalFragmentStoreImpl. Otherwise, it uses a RemoteFragmentStoreImpl actor.

Parameters:

kwargs – Keyword arguments to be passed to the underlying store implementation (LocalFragmentStoreImpl or RemoteFragmentStoreImpl).

Raises:

AssertionError – if the local status cannot be determined.

add_fragment(fragment: SampleFragment)

Adds a fragment to the store (either local or remote).

Parameters:

fragment – The SampleFragment object to store.

Returns:

A unique string ID for the stored fragment.

clear()

Removes all fragments from the store (either local or remote).

delete_fragment(fragment_uuid: str)

Deletes a fragment from the store (either local or remote) by its unique ID.

Parameters:

fragment_uuid – The unique ID of the fragment to delete.

get_disk_space()

Gets the total disk space used by the cache (either local or remote) in bytes.

Returns:

The disk space used by the cache.

get_fragment(fragment_uuid: str) → SampleFragment

Retrieves a fragment from the store (either local or remote) by its unique ID.

Parameters:

fragment_uuid – The unique ID of the fragment to retrieve.

Returns:

The retrieved SampleFragment object.

class minestudio.online.rollout.replay_buffer.fragment_store.LocalFragmentStoreImpl(path: str, num_shards: int)

A local implementation of a fragment store using diskcache.FanoutCache.

This class provides methods to add, get, delete, and clear fragments stored on the local disk.

Parameters:
  • path – The directory path where the cache will be stored.

  • num_shards – The number of shards to use for the FanoutCache.
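A dependency-free sketch of the interface LocalFragmentStoreImpl exposes. The hypothetical InMemoryFragmentStore swaps diskcache.FanoutCache for a dict of pickled bytes and generates keys with uuid4; both choices are assumptions, and get_disk_space here reports the size of the serialized payloads rather than actual on-disk usage.

```python
import pickle
import uuid
from typing import Any, Dict


class InMemoryFragmentStore:
    """Hypothetical dict-backed stand-in for a FanoutCache-based fragment store."""

    def __init__(self) -> None:
        self._cache: Dict[str, bytes] = {}

    def add_fragment(self, fragment: Any) -> str:
        fragment_uuid = uuid.uuid4().hex  # fresh unique key for every stored fragment
        self._cache[fragment_uuid] = pickle.dumps(fragment)
        return fragment_uuid

    def get_fragment(self, fragment_uuid: str) -> Any:
        return pickle.loads(self._cache[fragment_uuid])

    def delete_fragment(self, fragment_uuid: str) -> None:
        del self._cache[fragment_uuid]

    def clear(self) -> None:
        self._cache.clear()

    def get_disk_space(self) -> int:
        # Total bytes of serialized payloads currently held.
        return sum(len(blob) for blob in self._cache.values())
```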

add_fragment(fragment: SampleFragment)

Adds a fragment to the store and returns a unique ID for it.

Parameters:

fragment – The SampleFragment object to store.

Returns:

A unique string ID for the stored fragment.

clear()

Removes all fragments from the store.

delete_fragment(fragment_uuid: str)

Deletes a fragment from the store by its unique ID.

Parameters:

fragment_uuid – The unique ID of the fragment to delete.

get_disk_space()

Gets the total disk space used by the cache in bytes.

Returns:

The disk space used by the cache.

get_fragment(fragment_uuid: str)

Retrieves a fragment from the store by its unique ID.

Parameters:

fragment_uuid – The unique ID of the fragment to retrieve.

Returns:

The retrieved SampleFragment object.

Replay Buffer Interface

class minestudio.online.rollout.replay_buffer.interface.ReplayBufferInterface(config: omegaconf.DictConfig | None = None)

Provides an interface to interact with the ReplayBufferActor.

This class handles the creation or connection to a ReplayBufferActor named “replay_buffer”. It also initializes a FragmentStore based on the actor’s database configuration. All methods to interact with the replay buffer (add, load, fetch fragments, update model version) are routed through the ReplayBufferActor.

Parameters:

config – Optional DictConfig. If provided, a new ReplayBufferActor is created with this config. If None, it attempts to connect to an existing actor named “replay_buffer”.

Raises:

ValueError – If config is provided but an actor already exists, or if config is None and no actor exists.

add_fragment(fragment: SampleFragment, metadata: FragmentMetadata)

Adds a fragment to the FragmentStore and then informs the ReplayBufferActor.

Parameters:
  • fragment – The SampleFragment to add.

  • metadata – The FragmentMetadata associated with the fragment.

fetch_fragments(num_fragments: int) → List[Tuple[FragmentIndex, str]]

Fetches a list of fragment IDs and their indices from the ReplayBufferActor.

Parameters:

num_fragments – The number of fragments to fetch.

Returns:

A list of tuples, each containing a FragmentIndex and the fragment_id.

load_fragment(fragment_id: str) → SampleFragment

Loads a fragment directly from the FragmentStore.

Parameters:

fragment_id – The unique ID of the fragment to load.

Returns:

The loaded SampleFragment.

prepared_fragments() → List[Tuple[FragmentIndex, str]]

Retrieves the fragments that were prepared by the last call to fetch_fragments in the ReplayBufferActor.

Returns:

A list of tuples, each containing a FragmentIndex and the fragment_id.

update_model_version(session_id: str, model_version: int)

Updates the model version in the ReplayBufferActor.

Parameters:
  • session_id – The ID of the current training session.

  • model_version – The new model version.

Returns:

The result of the actor’s method call.

update_training_session()

Calls the update_training_session method of the ReplayBufferActor.

Returns:

The result of the actor’s method call.

Trainer

Base Trainer

class minestudio.online.trainer.basetrainer.BaseTrainer(rollout_manager: ActorHandle, policy_generator: Callable[[], MinePolicy], env_generator: Callable[[], MinecraftSim], num_workers: int, num_readers: int, num_cpus_per_reader: int, num_gpus_per_worker: int, prefetch_batches: int, discount: float, gae_lambda: float, context_length: int, use_normalized_vf: bool, inference_batch_size_per_gpu: int, resume: str | None, resume_optimizer: bool, **kwargs)

Base class for PPO-style trainers.

This class provides the core logic for distributed training, including:
  • Managing rollout workers and a replay buffer.
  • Fetching experience fragments and calculating GAE (Generalized Advantage Estimation).
  • Broadcasting model updates to rollout workers.
  • Setting up the training loop using Ray Train.

Subclasses should implement the setup_model_and_optimizer and train methods.

Parameters:
  • rollout_manager – ActorHandle for the RolloutManager.

  • policy_generator – Callable that returns a MinePolicy instance.

  • env_generator – Callable that returns a MinecraftSim instance.

  • num_workers – Number of training workers (Ray actors).

  • num_readers – Number of parallel data readers for fetching fragments.

  • num_cpus_per_reader – Number of CPUs allocated to each data reader.

  • num_gpus_per_worker – Number of GPUs allocated to each training worker.

  • prefetch_batches – Number of batches to prefetch during data loading.

  • discount – Discount factor (gamma) for GAE.

  • gae_lambda – Lambda parameter for GAE.

  • context_length – The length of the context window for processing sequences.

  • use_normalized_vf – Whether to use a normalized value function.

  • inference_batch_size_per_gpu – Batch size for inference on each GPU during GAE calculation.

  • resume – Optional path to a checkpoint directory to resume training from.

  • resume_optimizer – Whether to resume the optimizer state if resuming from a checkpoint.

  • kwargs – Additional keyword arguments.

broadcast_model_to_rollout_workers(new_version)

Broadcasts the current model state_dict to all rollout workers.

This is typically called after a model update. If new_version is True, the internal model version is incremented before broadcasting. Only the rank-0 worker performs the broadcast.

Parameters:

new_version – If True, increments the model version.

fetch_fragments_and_estimate_advantages(*, num_fragments: int) → Dict[str, Any]

Fetches a batch of fragments from the replay buffer and calculates advantages using GAE.

This method orchestrates the following steps:
  1. The rank-0 worker fetches num_fragments fragments from the ReplayBufferInterface.
  2. The fetched fragment records are distributed among the training workers.
  3. Each worker performs inference on its assigned fragments to obtain vpreds and logps.
  4. The information required for GAE (rewards, vpreds, next_done, next_vpred) is sent to a GAEWorker actor.
  5. The rank-0 worker triggers the GAE calculation in the GAEWorker.
  6. Each worker retrieves its computed td_targets and advantages from the GAEWorker.

Parameters:

num_fragments – The total number of fragments to fetch and process.

Returns:

A dictionary containing:
  • “records”: The list of (FragmentIndex, fragment_id) tuples processed by this worker.
  • “rewards”: A FragmentDataDict of summed rewards for each fragment.
  • “td_targets”: A FragmentDataDict of TD targets (GAE targets) for each step.
  • “advantages”: A FragmentDataDict of advantages for each step.
  • “old_logps”: A FragmentDataDict of log probabilities of actions under the policy used for rollout.
  • “old_pi_logits”: A FragmentDataDict of policy logits from the rollout.
  • “old_vpreds”: A FragmentDataDict of value predictions from the rollout.
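The GAE step can be written as the usual backward recursion: δ_t = r_t + γ·V(s_{t+1})·(1 − d_t) − V(s_t), A_t = δ_t + γλ·(1 − d_t)·A_{t+1}, and td_target_t = A_t + V(s_t), with γ = discount and λ = gae_lambda. This plain-Python sketch is not GAEWorker's code: the done-flag convention (dones[t] is true when the episode ended right after step t) and the use of next_vpred to bootstrap past the last step are assumptions, since the exact semantics of next_done are not documented here.

```python
from typing import List, Tuple


def compute_gae(
    rewards: List[float],
    vpreds: List[float],
    dones: List[bool],
    next_vpred: float,
    discount: float,
    gae_lambda: float,
) -> Tuple[List[float], List[float]]:
    """Return (advantages, td_targets) for one trajectory segment."""
    horizon = len(rewards)
    advantages = [0.0] * horizon
    last_gae = 0.0
    for t in reversed(range(horizon)):
        # Value of the following state: next step's vpred, or the bootstrap value at the end.
        next_value = vpreds[t + 1] if t + 1 < horizon else next_vpred
        nonterminal = 0.0 if dones[t] else 1.0
        # One-step TD error; nothing is bootstrapped across an episode boundary.
        delta = rewards[t] + discount * next_value * nonterminal - vpreds[t]
        last_gae = delta + discount * gae_lambda * nonterminal * last_gae
        advantages[t] = last_gae
    td_targets = [adv + v for adv, v in zip(advantages, vpreds)]
    return advantages, td_targets
```

With γ = λ = 1 and zero value estimates, the advantages reduce to reward-to-go within each episode, which is a quick sanity check for any GAE implementation.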

fit()

Initializes and runs the Ray TorchTrainer to start the distributed training process.

setup_model_and_optimizer() → Tuple[MinePolicy, Optimizer]

Abstract method to be implemented by subclasses.

This method should initialize and return the policy model and its optimizer.

Returns:

A tuple containing the initialized MinePolicy model and a torch.optim.Optimizer.

Raises:

NotImplementedError – If not implemented by a subclass.

train()

Abstract method for the main training loop, to be implemented by subclasses.

This method will contain the logic for iterating through training steps/epochs, fetching data, performing model updates, logging, and saving checkpoints.

Raises:

NotImplementedError – If not implemented by a subclass.

PPO Trainer

class minestudio.online.trainer.ppotrainer.PPOTrainer(num_iterations: int, learning_rate: float, anneal_lr_linearly: bool, weight_decay: float, adam_eps: float, batch_size_per_gpu: int, batches_per_iteration: int, gradient_accumulation: int, epochs_per_iteration: int, vf_warmup: int, ppo_clip: float, clip_vloss: bool, max_grad_norm: float, zero_initial_vf: bool, ppo_vf_coef: float, ppo_policy_coef: float, kl_divergence_coef_rho: float, entropy_bonus_coef: float, coef_rho_decay: float, normalize_advantage_full_batch: bool, record_video_interval: int, save_interval: int, save_path: str | None, keep_interval: int, log_ratio_range: float, enable_ref_update: False, whole_config: str, **kwargs)

Proximal Policy Optimization (PPO) trainer for reinforcement learning.

This class implements the PPO algorithm for training policies in Minecraft environments. It extends BaseTrainer and provides functionality for distributed training with Ray, gradient accumulation, value function warmup, and various PPO-specific optimizations including clipping, entropy bonuses, and KL divergence regularization.

The trainer supports both policy and value function training with configurable coefficients, learning rate annealing, and checkpoint saving capabilities.

ppo_update(records: List[Tuple[FragmentIndex, str]], td_targets: FragmentDataDict, advantages: FragmentDataDict, old_logps: FragmentDataDict, old_vpreds: FragmentDataDict, rewards: FragmentDataDict)

Perform PPO policy and value function updates on collected trajectory data.

Implements the core PPO algorithm, including:
  • Policy loss computation with probability-ratio clipping
  • Value function loss with optional clipping
  • KL divergence regularization against a reference policy
  • Entropy bonus for exploration
  • Advantage normalization and gradient accumulation
  • Distributed training synchronization and error handling

The method processes data in batches across multiple epochs, computing various metrics and losses while handling numerical stability issues. Includes checkpoint saving and performance logging.

Parameters:
  • records – List of (FragmentIndex, fragment_id) tuples identifying data fragments.

  • td_targets – Temporal-difference targets for value function training.

  • advantages – Computed advantages for policy gradient estimation.

  • old_logps – Log probabilities from the policy that collected the data.

  • old_vpreds – Value predictions from the policy that collected the data.

  • rewards – Reward values from the trajectory fragments.

Returns:

None

Raises:

AssertionError – If the reference model is None when KL regularization is enabled.
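The clipped objectives listed above can be sketched per sample in plain Python. This is the textbook PPO formulation, not necessarily PPOTrainer's exact code: clipping the value prediction with the same ppo_clip range, the loss signs, and the way the coefficients combine in the hypothetical total_loss helper are assumptions, and the entropy and KL terms are taken as precomputed scalars.

```python
import math
from typing import List, Tuple


def _mean(xs: List[float]) -> float:
    return sum(xs) / len(xs)


def ppo_losses(logps: List[float], old_logps: List[float], advantages: List[float],
               vpreds: List[float], old_vpreds: List[float], td_targets: List[float],
               ppo_clip: float, clip_vloss: bool) -> Tuple[float, float]:
    """Return (policy_loss, value_loss) averaged over samples."""
    policy_terms = []
    for logp, old_logp, adv in zip(logps, old_logps, advantages):
        ratio = math.exp(logp - old_logp)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - ppo_clip), 1.0 + ppo_clip)
        # Pessimistic bound: the larger of the two negated surrogates.
        policy_terms.append(max(-adv * ratio, -adv * clipped))

    value_terms = []
    for v, old_v, target in zip(vpreds, old_vpreds, td_targets):
        unclipped = (v - target) ** 2
        if clip_vloss:
            # Keep the new prediction within ppo_clip of the rollout-time prediction.
            v_clipped = old_v + min(max(v - old_v, -ppo_clip), ppo_clip)
            value_terms.append(0.5 * max(unclipped, (v_clipped - target) ** 2))
        else:
            value_terms.append(0.5 * unclipped)
    return _mean(policy_terms), _mean(value_terms)


def total_loss(policy_loss: float, value_loss: float, entropy: float, kl: float,
               ppo_policy_coef: float, ppo_vf_coef: float,
               entropy_bonus_coef: float, kl_divergence_coef_rho: float) -> float:
    # Entropy is a bonus (subtracted); KL to the reference policy is a penalty (added).
    return (ppo_policy_coef * policy_loss + ppo_vf_coef * value_loss
            - entropy_bonus_coef * entropy + kl_divergence_coef_rho * kl)
```

For example, with ppo_clip = 0.2, a sample whose probability ratio is 2 and whose advantage is 1 contributes -1.2 rather than -2, so a single large update cannot exploit a positive advantage beyond the clip range.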

setup_model_and_optimizer(policy_generator) → Tuple[MinePolicy, Optimizer]

Set up the model and optimizer for PPO training.

Creates the main policy model using the provided generator function and configures an AdamW optimizer with the specified hyperparameters. Optionally initializes value function weights to zero and sets up a reference model for KL divergence regularization if enabled.

Parameters:

policy_generator – Function that returns a new instance of the policy model

Returns:

Tuple of (model, optimizer) ready for training

Raises:

AssertionError – If the reference model setup fails when KL regularization is enabled.

train()

Execute the main PPO training loop.

Runs the complete training process for the specified number of iterations. Each iteration involves collecting rollout data, computing advantages, and performing PPO updates. Handles learning rate annealing, model broadcasting to rollout workers, and checkpoint management.

The method supports resuming from checkpoints by detecting the current learning rate and calculating the appropriate starting iteration. It also handles distributed training coordination and logging.
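One way this resume logic could work, assuming linear annealing (anneal_lr_linearly) from the base learning rate to zero over num_iterations: the current learning rate then determines how many iterations have already run. Both helpers are hypothetical illustrations of that inversion, not the trainer's code.

```python
def annealed_lr(base_lr: float, iteration: int, num_iterations: int) -> float:
    """Linear decay: base_lr at iteration 0, reaching 0 at num_iterations (assumed schedule)."""
    return base_lr * (1.0 - iteration / num_iterations)


def infer_start_iteration(current_lr: float, base_lr: float, num_iterations: int) -> int:
    """Invert the linear schedule to recover how many iterations already ran."""
    # round() absorbs floating-point error from the forward computation.
    return round((1.0 - current_lr / base_lr) * num_iterations)
```

Recovering progress from the optimizer's learning rate avoids storing a separate iteration counter, at the cost of depending on the schedule being exactly invertible.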

Returns:

None

train_iteration()

Execute a single training iteration of the PPO algorithm.

Performs one complete iteration consisting of:
  1. Fetching trajectory fragments collected by the rollout workers
  2. Computing Generalized Advantage Estimation (GAE)
  3. Performing PPO updates on the collected data
  4. Decaying the KL divergence coefficient

This method coordinates between data collection and policy optimization phases of the PPO algorithm.

Returns:

None
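A hypothetical skeleton of one training iteration, with the data-fetching and update stages passed in as callables. The multiplicative decay of kl_divergence_coef_rho by coef_rho_decay is an assumption suggested by the parameter names, not confirmed by this page.

```python
from typing import Any, Callable, Dict


def train_iteration(
    fetch_and_estimate: Callable[[], Dict[str, Any]],
    ppo_update: Callable[..., None],
    state: Dict[str, float],
    coef_rho_decay: float,
) -> None:
    """Fetch + GAE, PPO update, then decay the KL coefficient (hypothetical skeleton)."""
    batch = fetch_and_estimate()  # steps 1-2: fragments plus advantages and targets
    ppo_update(**batch)           # step 3: clipped policy/value updates
    # Step 4: shrink the KL penalty weight (multiplicative decay assumed).
    state["kl_divergence_coef_rho"] *= coef_rho_decay
```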

minestudio.online.trainer.ppotrainer.print_memory_usage()

Print current CUDA memory usage information.

This function displays the allocated and reserved memory on the current CUDA device in megabytes (MB). Useful for debugging memory issues during training.

Returns:

None (prints memory information to stdout)

Start Trainer

minestudio.online.trainer.start_trainer.start_trainer(policy_generator, env_generator, online_cfg, whole_config)

Starts the training process.

This function initializes and starts a training session, creating a trainer instance based on the provided configuration.

Parameters:
  • policy_generator – A function to generate the policy model.

  • env_generator – A function to generate the environment.

  • online_cfg – Online training configuration.

  • whole_config – The entire configuration as a string.

Utils

Rollout Utils

class minestudio.online.utils.rollout.datatypes.FragmentDataDict

A dictionary that maps FragmentIndex to arbitrary data. It provides a helper method to format a batch of fragments for model input.

format_batch(fragments: List[SampleFragment], device: device)

Formats a list of SampleFragments into a batch suitable for model input. It retrieves the corresponding data for each fragment from the dictionary, stacks them, and moves them to the specified device.

Parameters:
  • fragments – A list of SampleFragment objects.

  • device – The torch device to move the batch to.

Returns:

A batch of data ready for model input.
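How format_batch uses FragmentIndex keys can be sketched with plain lists. The frozen dataclass (which makes FragmentIndex hashable), ToyFragment, and ToyFragmentDataDict are hypothetical stand-ins; the real method also stacks the looked-up values and moves them to a torch device, which this sketch omits.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class FragmentIndex:
    # frozen=True generates __hash__, so instances can serve as dict keys.
    worker_uuid: str
    fid_in_worker: int


@dataclass
class ToyFragment:
    """Minimal stand-in for SampleFragment: only the fields needed for .index."""
    worker_uuid: str
    fid_in_worker: int

    @property
    def index(self) -> FragmentIndex:
        return FragmentIndex(self.worker_uuid, self.fid_in_worker)


class ToyFragmentDataDict(dict):
    def format_batch(self, fragments: List[ToyFragment]) -> List[List[float]]:
        # Look up each fragment's data in batch order (no stacking or device transfer here).
        return [self[fragment.index] for fragment in fragments]
```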

class minestudio.online.utils.rollout.datatypes.FragmentIndex(worker_uuid: str, fid_in_worker: int)

Represents a unique identifier for a SampleFragment.

Parameters:
  • worker_uuid – The unique identifier of the worker that generated the fragment.

  • fid_in_worker – The fragment’s ID within that worker.

class minestudio.online.utils.rollout.datatypes.FragmentMetadata(model_version: int, session_id: str, worker_uuid: str, fid_in_worker: int)

Metadata associated with a SampleFragment.

Parameters:
  • model_version – The version of the model used to generate this fragment.

  • session_id – The ID of the training session.

  • worker_uuid – The unique identifier of the worker that generated the fragment.

  • fid_in_worker – The fragment’s ID within that worker.

class minestudio.online.utils.rollout.datatypes.SampleFragment(obs: Dict[str, Any] | Tensor, action: Dict[str, Any] | Tensor, next_done: ndarray, reward: ndarray, first: ndarray, in_state: List[ndarray], worker_uuid: str, fid_in_worker: int, next_obs: Dict[str, Any], episode_uuids: List[str])

Represents a fragment of a trajectory, containing observations, actions, rewards, and other metadata.

Parameters:
  • obs – The observation from the environment.

  • action – The action taken by the agent.

  • next_done – A boolean indicating whether the episode terminated after this fragment.

  • reward – The reward received after taking the action.

  • first – A boolean indicating if this is the first fragment in an episode.

  • in_state – The recurrent state of the agent.

  • worker_uuid – The unique identifier of the worker that generated the fragment.

  • fid_in_worker – The fragment’s ID within that worker.

  • next_obs – The next observation from the environment.

  • episode_uuids – A list of unique identifiers for the episodes this fragment belongs to.

property index: FragmentIndex

Returns the FragmentIndex for this SampleFragment.

Returns:

The FragmentIndex object.

print() → None

Prints the contents of the SampleFragment for debugging purposes.

class minestudio.online.utils.rollout.datatypes.StepRecord(worker_uuid: str, obs: Dict[str, Any], state: List[ndarray] | None, action: Dict[str, Any], last_reward: float, last_terminated: bool, last_truncated: bool, model_version: int, episode_uuid: str, session_id: str)

Represents a single step taken in the environment.

Parameters:
  • worker_uuid – The unique identifier of the worker that generated this step.

  • obs – The observation from the environment.

  • state – The recurrent state of the agent.

  • action – The action taken by the agent.

  • last_reward – The reward received from the previous step.

  • last_terminated – Whether the episode terminated after the previous step.

  • last_truncated – Whether the episode was truncated after the previous step.

  • model_version – The version of the model used to generate this step.

  • episode_uuid – The unique identifier of the episode this step belongs to.

  • session_id – The ID of the training session.

class minestudio.online.utils.rollout.monitor.MovingStat(duration: int)

Calculates a moving statistic (average) over a specified duration.

Parameters:

duration – The duration in seconds over which to calculate the statistic.

average()

Calculates the average of the data points currently within the duration window.

Returns:

The average of the data points, or float('nan') if no data is available.

update(x: float)

Adds a new data point to the calculation.

Parameters:

x – The new data point.
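A time-windowed average like MovingStat can be kept with a deque of (timestamp, value) pairs, evicting entries older than duration. TimeWindowAverage is a hypothetical sketch, and the injectable clock (time.monotonic by default) is an assumption added to make it testable.

```python
import time
from collections import deque
from typing import Callable, Deque, Tuple


class TimeWindowAverage:
    """Moving average over the last `duration` seconds (hypothetical sketch)."""

    def __init__(self, duration: float, clock: Callable[[], float] = time.monotonic):
        self.duration = duration
        self.clock = clock
        self.points: Deque[Tuple[float, float]] = deque()  # (timestamp, value), oldest first

    def _evict(self) -> None:
        # Drop points that fall outside the window.
        cutoff = self.clock() - self.duration
        while self.points and self.points[0][0] < cutoff:
            self.points.popleft()

    def update(self, x: float) -> None:
        self.points.append((self.clock(), x))
        self._evict()

    def average(self) -> float:
        self._evict()
        if not self.points:
            return float("nan")
        return sum(v for _, v in self.points) / len(self.points)
```

Because points arrive in time order, eviction only ever touches the left end of the deque, so each update costs amortized O(1).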

class minestudio.online.utils.rollout.monitor.PipelineMonitor(stages: List[str], duration: int = 300)

Monitors the time spent in different stages of a pipeline.

Parameters:
  • stages – A list of strings representing the names of the pipeline stages.

  • duration – The duration in seconds over which to calculate moving statistics for each stage.

print()[source]#

Prints a table summarizing the average time spent in each stage for all monitored pipelines.

report_enter(stage: str, pipeline_id='default')[source]#

Reports that a pipeline has entered a new stage.

Parameters:
  • stage – The name of the stage being entered.

  • pipeline_id – The unique identifier of the pipeline instance.

Raises:

AssertionError – If the reported stage is not the expected next stage.
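The ordering check and per-stage timing described above can be illustrated with a simplified sketch. It is not PipelineMonitor's implementation: the class name StageTimer, the clock parameter, and the averages() method are hypothetical. For brevity it keeps every duration instead of a duration-second window, and it assumes the stages repeat cyclically, with the last stage followed by the first.

```python
import time
from typing import Callable, Dict, List, Optional, Tuple


class StageTimer:
    """Hypothetical stand-in for PipelineMonitor: times each stage of a cyclic pipeline."""

    def __init__(self, stages: List[str], clock: Callable[[], float] = time.monotonic):
        self.stages = stages
        self.clock = clock
        # pipeline_id -> (index of the current stage, time it was entered)
        self.current: Dict[str, Tuple[int, float]] = {}
        # stage name -> durations recorded for that stage
        self.durations: Dict[str, List[float]] = {s: [] for s in stages}

    def report_enter(self, stage: str, pipeline_id: str = "default") -> None:
        now = self.clock()
        idx = self.stages.index(stage)
        if pipeline_id in self.current:
            prev_idx, entered = self.current[pipeline_id]
            # Stages must be entered in order, wrapping back to the first stage.
            expected = (prev_idx + 1) % len(self.stages)
            assert idx == expected, f"expected {self.stages[expected]!r}, got {stage!r}"
            # Entering a new stage closes the previous one.
            self.durations[self.stages[prev_idx]].append(now - entered)
        self.current[pipeline_id] = (idx, now)

    def averages(self) -> Dict[str, Optional[float]]:
        return {s: (sum(d) / len(d) if d else None) for s, d in self.durations.items()}


now = [0.0]
mon = StageTimer(["infer", "env_step"], clock=lambda: now[0])
mon.report_enter("infer")
now[0] = 0.25
mon.report_enter("env_step")
now[0] = 1.0
mon.report_enter("infer")
print(mon.averages())  # {'infer': 0.25, 'env_step': 0.75}
```

Keying the state by pipeline_id lets several concurrent pipelines share one monitor while each is checked against its own expected next stage.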

Trainer Utils#

minestudio.online.utils.train.data.batchify_next_obs(next_obs: Dict[str, Any], device: device)[source]#

Converts a next_obs dictionary into a batch format suitable for model input.

It stacks the next_obs (assumed to be a single observation) and moves it to the specified device.

Parameters:
  • next_obs – A dictionary representing the next observation.

  • device – The torch device to move the batch to.

Returns:

The batchified next_obs as a torch Tensor.
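Batchifying a single observation amounts to giving every leaf a leading batch axis of size 1. The sketch below shows that idea with NumPy in place of torch and omits device placement. With torch, torch.as_tensor(x).unsqueeze(0).to(device) would play the same role for each leaf. The helper name add_batch_dim and the example observation keys are hypothetical, and the sketch assumes nested dictionaries should keep their structure.

```python
from typing import Any

import numpy as np


def add_batch_dim(obs: Any) -> Any:
    """Hypothetical helper: add a leading batch axis of size 1 to every leaf of a (nested) dict."""
    if isinstance(obs, dict):
        # Recurse so nested observation groups keep their structure.
        return {k: add_batch_dim(v) for k, v in obs.items()}
    return np.expand_dims(np.asarray(obs), axis=0)


next_obs = {
    "image": np.zeros((128, 128, 3), dtype=np.uint8),
    "inventory": {"dirt": np.array([3])},
}
batch = add_batch_dim(next_obs)
print(batch["image"].shape)  # (1, 128, 128, 3)
```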

minestudio.online.utils.train.data.create_loader_pool(num_readers: int, num_cpus_per_reader: int)[source]#

Creates a pool of FragmentLoader actors.

Parameters:
  • num_readers – The number of FragmentLoader actors to create in the pool.

  • num_cpus_per_reader – The number of CPUs to allocate to each FragmentLoader actor.

Returns:

An ActorPool of FragmentLoader actors.

minestudio.online.utils.train.data.data_iter(loader_pool: ActorPool, records: List[Tuple[FragmentIndex, str]], batch_size: int, prefetch_batches: int)[source]#

Creates an iterator that yields batches of SampleFragments loaded by the FragmentLoader pool.

It shuffles the records, prefetches data, and yields batches of a specified size.

Parameters:
  • loader_pool – The ActorPool of FragmentLoader actors.

  • records – A list of tuples, where each tuple contains a FragmentIndex and a fragment UUID.

  • batch_size – The number of fragments per batch.

  • prefetch_batches – The number of batches to prefetch.

Yields:

Batches of SampleFragments, where each batch is a list of dictionaries (output of FragmentLoader.load).
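The shuffle, prefetch, and batch pattern described above can be sketched without Ray. This sketch replaces the ActorPool of FragmentLoader actors with a concurrent.futures thread pool, and load_fn stands in for FragmentLoader.load. The function name batched_prefetch_iter and the num_workers and seed parameters are hypothetical. The sketch also yields a final short batch rather than dropping it, which the real data_iter may handle differently.

```python
import random
from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Deque, Iterator, List, Sequence, TypeVar

R = TypeVar("R")
T = TypeVar("T")


def batched_prefetch_iter(
    records: Sequence[R],
    load_fn: Callable[[R], T],
    batch_size: int,
    prefetch_batches: int,
    num_workers: int = 4,
    seed: int = 0,
) -> Iterator[List[T]]:
    """Hypothetical sketch: shuffle records, load them in parallel, yield batches in submission order."""
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        pending: Deque[List[Future]] = deque()  # submitted batches, oldest first
        for start in range(0, len(shuffled), batch_size):
            batch = shuffled[start:start + batch_size]
            pending.append([pool.submit(load_fn, r) for r in batch])
            # Once more than `prefetch_batches` batches are in flight, hand back the oldest.
            if len(pending) > prefetch_batches:
                yield [f.result() for f in pending.popleft()]
        # Drain whatever is still in flight.
        while pending:
            yield [f.result() for f in pending.popleft()]


sizes = [len(b) for b in batched_prefetch_iter(range(10), lambda r: r * 2, batch_size=4, prefetch_batches=2)]
print(sizes)  # [4, 4, 2]
```

Keeping several batches in flight lets loading overlap with whatever the consumer does with the current batch, at the cost of holding up to prefetch_batches + 1 batches in memory.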

minestudio.online.utils.train.data.prepare_batch(model, batch_fragments: List[SampleFragment])[source]#

Prepares a batch of SampleFragments for model input.

It extracts observations, states, actions, and first flags from the fragments, stacks them, and moves them to the model’s device.

Parameters:
  • model – The model for which the batch is being prepared (used to get the device and merge_state method).

  • batch_fragments – A list of SampleFragment objects.

Returns:

A dictionary containing the prepared batch (obs, state, action, first) as torch Tensors.

minestudio.online.utils.train.gae.get_last_fragment_indexes(fragment_indexs: List[FragmentIndex]) List[FragmentIndex][source]#

Identifies, for each worker, the last fragment index of every contiguous run of fragments in a list of fragment indexes.

A fragment is considered the last if it’s the last one from a worker or if the next fragment from the same worker is not contiguous.

Parameters:

fragment_indexs – A list of FragmentIndex objects.

Returns:

A list of FragmentIndex objects, each marking the end of a contiguous run of fragments from one worker.
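The rule above (a fragment is last if it is the worker's final fragment or the next one is not contiguous) can be sketched as follows. This is an illustration, not the library's code. FragIdx is a hypothetical stand-in for FragmentIndex, and its field names worker_uuid and fid_in_worker are assumptions about how a fragment's worker and position are stored.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class FragIdx:
    # Hypothetical stand-in for FragmentIndex; field names are assumed.
    worker_uuid: str
    fid_in_worker: int


def find_run_ends(fragment_indexs: List[FragIdx]) -> List[FragIdx]:
    """Return the last fragment of each contiguous run, grouped by worker."""
    by_worker: Dict[str, List[FragIdx]] = defaultdict(list)
    for idx in fragment_indexs:
        by_worker[idx.worker_uuid].append(idx)
    ends: List[FragIdx] = []
    for frags in by_worker.values():
        frags.sort(key=lambda f: f.fid_in_worker)
        for cur, nxt in zip(frags, frags[1:] + [None]):
            # A run ends at the worker's final fragment or wherever the next id skips ahead.
            if nxt is None or nxt.fid_in_worker != cur.fid_in_worker + 1:
                ends.append(cur)
    return ends


frags = [FragIdx("a", 5), FragIdx("a", 0), FragIdx("b", 3), FragIdx("a", 1), FragIdx("a", 6), FragIdx("a", 2)]
print([(f.worker_uuid, f.fid_in_worker) for f in find_run_ends(frags)])
# [('a', 2), ('a', 6), ('b', 3)]
```

Sorting each worker's fragments first means the input order does not matter; only gaps in fid_in_worker split a worker's fragments into separate runs.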

Date: 2025-05-20 18:18:38
LastEditors: caishaofei-mus1 1744260356@qq.com
LastEditTime: 2025-05-20 18:23:37
FilePath: /MineStudio/minestudio/online/utils/train/training_session.py

minestudio.online.utils.train.wandb_logger.define_metric(*args, **kwargs)[source]#

Defines a metric for the current Weights & Biases (wandb) training session.

This function retrieves the current training session and calls its define_metric method remotely. If no session is active or an error occurs, an error message is logged. It asserts that a training session must be active.

Parameters:
  • args – Positional arguments to pass to wandb.define_metric().

  • kwargs – Keyword arguments to pass to wandb.define_metric().

minestudio.online.utils.train.wandb_logger.log(*args, **kwargs)[source]#

Logs data to the current Weights & Biases (wandb) training session.

This function retrieves the current training session and calls its log method remotely. If no session is active or an error occurs during logging, an error message is logged.

Parameters:
  • args – Positional arguments to pass to wandb.log().

  • kwargs – Keyword arguments to pass to wandb.log().
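The fail-soft forwarding described above (look up the session, call its method, log an error instead of raising) can be sketched independently of Ray and wandb. The helper forward_to_session and its get_session callable are hypothetical. With Ray, the session would typically be an actor handle and the call would go through .remote(...), which this sketch leaves out.

```python
import logging
from typing import Any, Callable, Optional

logger = logging.getLogger(__name__)


def forward_to_session(get_session: Callable[[], Optional[Any]], method: str, *args: Any, **kwargs: Any) -> bool:
    """Hypothetical helper: call `method` on the active session; log instead of raising on failure."""
    try:
        session = get_session()
        if session is None:
            logger.error("No active training session; dropping %s call", method)
            return False
        getattr(session, method)(*args, **kwargs)
        return True
    except Exception:
        # Logging must never crash the training loop, so any failure is reported and swallowed.
        logger.exception("Failed to forward %s to the training session", method)
        return False
```

Returning a boolean is a choice made for the sketch so callers and tests can tell whether the call went through. The documented log function only states that an error message is logged.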

minestudio.online.utils.train.wandb_logger.log_video(data: Dict[str, Any], video_key: str, fps: int)[source]#

Logs a video to the current Weights & Biases (wandb) training session.

This function retrieves the current training session and calls its log_video method remotely. If no session is active or an error occurs during logging, an error message is logged.

Parameters:
  • data – A dictionary containing the data to log. The video itself should be under the video_key.

  • video_key – The key in the data dictionary that holds the video data.

  • fps – The frames per second of the video.