Models#

We provide a template for Minecraft policies, MinePolicy, and build a range of baseline models on top of it. Currently, MineStudio supports VPT, STEVE-1, GROOT, and ROCKET-1, among others. This page details the MinePolicy template and how to create your own policies.

Quick Start#

Here is an example that shows how to load and use OpenAI’s VPT policy within the Minecraft environment provided by MineStudio.

from minestudio.simulator import MinecraftSim
from minestudio.simulator.callbacks import RecordCallback
from minestudio.models import load_vpt_policy, VPTPolicy

# Option 1: Load the policy from local model files
# Ensure you have the .model (architecture) and .weights (parameters) files.
policy = load_vpt_policy(
    model_path="/path/to/foundation-model-2x.model", 
    weights_path="/path/to/foundation-model-2x.weights"
).to("cuda") # Move the policy to GPU if available

# Option 2: Load the policy from the Hugging Face Model Hub
# This is a convenient way to get pre-trained models.
# policy = VPTPolicy.from_pretrained("CraftJarvis/MineStudio_VPT.rl_from_early_game_2x").to("cuda")

# Set the policy to evaluation mode. This is important for consistent behavior during inference.
policy.eval()

# Initialize the Minecraft Simulator
# obs_size specifies the resolution of the visual observations.
# callbacks allow for custom actions during the simulation, e.g., recording.
env = MinecraftSim(
    obs_size=(128, 128), 
    callbacks=[RecordCallback(record_path="./output", fps=30, frame_type="pov")]
)

# `memory` stores the recurrent state of the policy (e.g., for RNNs).
# For policies without memory (Markovian), it can be initialized to None.
memory = None 
obs, info = env.reset() # Reset the environment to get the initial observation.

# Simulation loop
for i in range(1200): # Run for 1200 steps
    # Get an action from the policy.
    # `obs`: The current observation from the environment.
    # `memory`: The current recurrent state of the policy.
    # `input_shape='*'`: Indicates that `obs` is a single sample (not a batch or time sequence).
    # The policy handles internal batching/unbatching for its forward pass.
    action, memory = policy.get_action(obs, memory, input_shape='*')
    
    # Apply the action to the environment.
    obs, reward, terminated, truncated, info = env.step(action)
    
    # Check if the episode has ended.
    if terminated or truncated:
        print("Episode finished after {} timesteps".format(i+1))
        break

env.close() # Close the environment when done.

Hint

In this example, the RecordCallback saves the recorded video to the ./output directory. The memory variable carries the policy's recurrent state, and input_shape='*' in get_action is the typical setting for single-instance inference.

Policy Template (MinePolicy)#

Warning

To ensure compatibility with MineStudio’s training and inference pipelines, custom policies must inherit from minestudio.models.base_policy.MinePolicy and implement its abstract methods.

The MinePolicy class serves as the base for all policies. Key methods and properties are described below:

__init__(self, hiddim, action_space=None, temperature=1.0, nucleus_prob=None)

The constructor for the policy.

  • hiddim (int): The hidden dimension size for the policy’s internal layers.

  • action_space (gymnasium.spaces.Space, optional): The action space of the environment. If None, a default Minecraft action space (camera and buttons) is used.

  • temperature (float, optional): Temperature for sampling actions from the policy head. Defaults to 1.0.

  • nucleus_prob (float, optional): Nucleus (top-p) probability for sampling actions. If None, standard sampling is used. Defaults to None.

It initializes the policy’s action head (self.pi_head) and value head (self.value_head).

# From minestudio.models.base_policy.py
def __init__(self, hiddim, action_space=None, temperature=1.0, nucleus_prob=None) -> None:
    torch.nn.Module.__init__(self)
    if action_space is None:
        action_space = gymnasium.spaces.Dict({
            "camera": gymnasium.spaces.MultiDiscrete([121]), 
            "buttons": gymnasium.spaces.MultiDiscrete([8641]),
        })
    # self.value_head is a ScaledMSEHead
    self.value_head = ScaledMSEHead(hiddim, 1, norm_type="ewma", norm_kwargs=None)
    # self.pi_head uses the provided action_space, hiddim, temperature, and nucleus_prob
    self.pi_head = make_action_head(action_space, hiddim, temperature=temperature, nucleus_prob=nucleus_prob)

Hint

Users can override self.pi_head and self.value_head after calling super().__init__(...) if custom head implementations are needed.
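
For example, a subclass can swap in its own value head right after the parent constructor has set up the defaults. A minimal sketch (not an actual MineStudio model):

import torch
from minestudio.models.base_policy import MinePolicy

class MyHeadSwapPolicy(MinePolicy):
    def __init__(self, hiddim, action_space=None) -> None:
        super().__init__(hiddim, action_space)  # sets up the default pi_head and value_head
        # Override the default ScaledMSEHead with a plain linear value head.
        self.value_head = torch.nn.Linear(hiddim, 1)
        # forward() and initial_state() still need to be implemented (see below).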

reset_parameters(self)

Resets the parameters of the policy’s action head (pi_head) and value head (value_head). This can be useful for re-initializing a policy.

# From minestudio.models.base_policy.py
def reset_parameters(self):
    """Resets the parameters of the policy and value heads."""
    self.pi_head.reset_parameters()
    self.value_head.reset_parameters()

forward(self, input: Dict[str, Any], state_in: Optional[List[torch.Tensor]] = None, **kwargs) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]

This is an abstract method and must be implemented by derived classes. It defines the main computation of the policy.

  • input (Dict[str, Any]): A dictionary of input tensors (e.g., observations from the environment).

  • state_in (Optional[List[torch.Tensor]]): A list of input recurrent state tensors. None if the episode is starting or the policy is Markovian.

  • **kwargs: Additional keyword arguments.

Returns:

  • latents (Dict[str, torch.Tensor]): A dictionary containing at least:

    • 'pi_logits' (torch.Tensor): The logits for the action distribution.

    • 'vpred' (torch.Tensor): The predicted value function.

  • state_out (List[torch.Tensor]): A list containing the updated recurrent state tensors. For a Markovian policy (no state), this should be an empty list ([]).

# From minestudio.models.base_policy.py
@abstractmethod
def forward(self,
            input: Dict[str, Any],
            state_in: Optional[List[torch.Tensor]] = None,
            **kwargs
) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]:
    pass

initial_state(self, batch_size: Optional[int] = None) -> List[torch.Tensor]

This is an abstract method and must be implemented by derived classes. It returns the initial recurrent state of the policy.

  • batch_size (Optional[int]): The batch size for which to create the initial state.

Returns:

  • (List[torch.Tensor]): A list of initial state tensors. For a Markovian policy (no state), this should be an empty list ([]).

# From minestudio.models.base_policy.py
@abstractmethod
def initial_state(self, batch_size: Optional[int] = None) -> List[torch.Tensor]:
    pass
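
For a recurrent policy, initial_state typically returns zero tensors sized by batch_size. A minimal sketch for a hypothetical LSTM-style policy (assuming the subclass stored its hidden size as self.hiddim):

def initial_state(self, batch_size: Optional[int] = None) -> List[torch.Tensor]:
    b = 1 if batch_size is None else batch_size
    # One zero-initialized hidden and cell state per sample; shapes depend on your backbone.
    return [
        torch.zeros(b, self.hiddim, device=self.device),  # hidden state
        torch.zeros(b, self.hiddim, device=self.device),  # cell state
    ]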

get_action(self, input: Dict[str, Any], state_in: Optional[List[torch.Tensor]], deterministic: bool = False, input_shape: str = "BT*", **kwargs) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]

This method computes and returns an action from the policy based on the current input and state. It’s typically used during inference.

  • input (Dict[str, Any]): Current observation from the environment.

  • state_in (Optional[List[torch.Tensor]]): Current recurrent state.

  • deterministic (bool): If True, samples actions deterministically (e.g., argmax). If False, samples stochastically. Defaults to False.

  • input_shape (str): Specifies the shape of the input.

    • "*": Single instance input (e.g., one observation at a time during inference). The method handles batching/unbatching internally.

    • "BT*": Batched sequence input (Batch, Time, …). Defaults to "BT*".

  • **kwargs: Additional keyword arguments passed to forward.

Returns:

  • action (Dict[str, torch.Tensor]): The sampled action.

  • state_out (List[torch.Tensor]): The updated recurrent state.

# Simplified from minestudio.models.base_policy.py
@torch.inference_mode()
def get_action(self,
                input: Dict[str, Any],
                state_in: Optional[List[torch.Tensor]],
                deterministic: bool = False,
                input_shape: str = "BT*",
                **kwargs,
) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]:
    if input_shape == "*":
        # Internal batching for single instance input
        input = dict_map(self._batchify, input)
        if state_in is not None:
            state_in = recursive_tensor_op(lambda x: x.unsqueeze(0), state_in)
    elif input_shape != "BT*":
        raise NotImplementedError("Unsupported input_shape")
    
    latents, state_out = self.forward(input, state_in, **kwargs)
    action = self.pi_head.sample(latents['pi_logits'], deterministic=deterministic)
    self.vpred = latents['vpred'] # Cache for potential later use
    
    if input_shape == "*":
        # Internal unbatching for single instance output
        action = dict_map(lambda tensor: tensor[0][0], action)
        state_out = recursive_tensor_op(lambda x: x[0], state_out)
        
    return action, state_out

Note

Empirically, stochastic sampling (deterministic=False) often yields better policy performance during evaluation than deterministic actions. Setting input_shape="*" is typical for inference, where the policy processes one observation at a time.
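
For example, switching between stochastic and deterministic sampling during single-step inference only requires flipping the flag:

# Stochastic sampling (default); often preferred for evaluation.
action, memory = policy.get_action(obs, memory, input_shape='*')

# Deterministic sampling, e.g. taking the most likely action.
action, memory = policy.get_action(obs, memory, deterministic=True, input_shape='*')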

device (property)

A property that returns the torch.device (e.g., ‘cpu’, ‘cuda:0’) on which the policy’s parameters are located.

# From minestudio.models.base_policy.py
@property
def device(self) -> torch.device:
    return next(self.parameters()).device
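
This is convenient for moving inputs onto the same device as the policy before calling forward. A minimal sketch, assuming every value in the input dict is a tensor:

# Move an observation batch onto whichever device the policy's parameters live on.
input = {key: value.to(policy.device) for key, value in input.items()}
latents, state_out = policy(input, state_in)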

Hint

At minimum, a custom policy must implement forward and initial_state.

Your First Policy#

Here are basic examples of how to create custom policies by inheriting from MinePolicy.

First, load the necessary modules:

import torch
import torch.nn as nn
from einops import rearrange
from typing import Dict, List, Optional, Tuple, Any

from minestudio.models.base_policy import MinePolicy
# Assuming make_action_head and ScaledMSEHead are accessible for custom heads,
# or rely on those initialized in MinePolicy's __init__.

Example 1: Condition-Free (Markovian) Policy#

This policy does not depend on any external condition beyond the current observation and has no recurrent state.

class MySimpleMarkovPolicy(MinePolicy):
    def __init__(self, hiddim, action_space=None, image_size=(64, 64), image_channels=3) -> None:
        super().__init__(hiddim, action_space) # Initializes self.pi_head and self.value_head
        
        # Example backbone: a simple MLP
        # Input image is flattened: image_size[0] * image_size[1] * image_channels
        self.feature_dim = image_size[0] * image_size[1] * image_channels
        self.net = nn.Sequential(
            nn.Linear(self.feature_dim, hiddim), 
            nn.ReLU(),
            nn.Linear(hiddim, hiddim),
            nn.ReLU()
        )
        # self.pi_head and self.value_head are already defined in the parent class.

    def forward(self, 
                input: Dict[str, Any], 
                state_in: Optional[List[torch.Tensor]] = None, # Will be None for Markovian
                **kwargs
    ) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]:
        # Assuming input['image'] is (B, T, H, W, C)
        # For a Markovian policy, we typically expect T=1 or process frames independently.
        # If T > 1, this example processes each time step independently.
        img_obs = input['image'] # Shape: (B, T, H, W, C)
        b, t, h, w, c = img_obs.shape
        
        # Flatten image: (B, T, H, W, C) -> (B*T, H*W*C)
        # Normalize image (example: scale to [0,1])
        x = rearrange(img_obs / 255.0, 'b t h w c -> (b t) (h w c)')
        
        features = self.net(x) # Shape: (B*T, hiddim)
        
        # Reshape for policy and value heads if they expect (B, T, hiddim)
        features_reshaped = rearrange(features, '(b t) d -> b t d', b=b, t=t)
        
        pi_logits = self.pi_head(features_reshaped) # pi_head handles (B, T, D) or (B*T, D)
        vpred = self.value_head(features_reshaped)  # value_head handles (B, T, D) or (B*T, D)
        
        result = {
            'pi_logits': pi_logits, 
            'vpred': vpred, 
        }
        # For a Markovian policy, state_out is an empty list
        return result, []

    def initial_state(self, batch_size: Optional[int] = None) -> List[torch.Tensor]:
        # Markovian policy has no state, return empty list
        return []
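
A quick way to sanity-check the policy before plugging it into the simulator is to run it on random tensors. A minimal sketch using the constructor defaults above ((64, 64) images with 3 channels):

policy = MySimpleMarkovPolicy(hiddim=256)
dummy_input = {'image': torch.rand(2, 1, 64, 64, 3) * 255}  # (B, T, H, W, C) in [0, 255]
latents, state_out = policy(dummy_input, policy.initial_state(batch_size=2))
assert state_out == []  # Markovian policy: no recurrent state is returned
print(latents['vpred'].shape)  # one value prediction per (batch, time) element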

Example 2: Condition-Based (Markovian) Policy#

This policy takes an additional ‘condition’ tensor as input.

class MySimpleConditionedPolicy(MinePolicy):
    def __init__(self, hiddim, action_space=None, image_size=(64, 64), image_channels=3, condition_dim=64) -> None:
        super().__init__(hiddim, action_space)
        
        self.feature_dim = image_size[0] * image_size[1] * image_channels
        self.net = nn.Sequential(
            nn.Linear(self.feature_dim, hiddim), 
            nn.ReLU(),
        )
        # Embedding for the condition
        self.condition_net = nn.Linear(condition_dim, hiddim) # Or nn.Embedding if condition is discrete
        
        # Fusion layer
        self.fusion_net = nn.Sequential(
            nn.Linear(hiddim * 2, hiddim), # Example: concatenate image and condition features
            nn.ReLU(),
            nn.Linear(hiddim, hiddim),
            nn.ReLU()
        )

    def forward(self, 
                input: Dict[str, Any], 
                state_in: Optional[List[torch.Tensor]] = None, 
                **kwargs
    ) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]:
        img_obs = input['image'] # Shape: (B, T, H, W, C)
        condition = input['condition'] # Shape: (B, T, condition_dim)
        
        b, t, h, w, c = img_obs.shape
        
        x_img = rearrange(img_obs / 255.0, 'b t h w c -> (b t) (h w c)')
        img_features = self.net(x_img) # Shape: (B*T, hiddim)
        
        # Process condition
        # Assuming condition is already (B, T, condition_dim) -> (B*T, condition_dim)
        cond_features = self.condition_net(rearrange(condition, 'b t d -> (b t) d')) # Shape: (B*T, hiddim)
        
        # Fuse features (example: concatenation)
        fused_features = torch.cat([img_features, cond_features], dim=-1)
        final_features = self.fusion_net(fused_features) # Shape: (B*T, hiddim)

        final_features_reshaped = rearrange(final_features, '(b t) d -> b t d', b=b, t=t)
        
        pi_logits = self.pi_head(final_features_reshaped)
        vpred = self.value_head(final_features_reshaped)
        
        result = {
            'pi_logits': pi_logits, 
            'vpred': vpred, 
        }
        return result, []

    def initial_state(self, batch_size: Optional[int] = None) -> List[torch.Tensor]:
        return []
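
When using such a policy, the condition must be supplied alongside the observation in the input dictionary; for example, with dummy tensors:

policy = MySimpleConditionedPolicy(hiddim=256, condition_dim=64)
dummy_input = {
    'image': torch.rand(2, 1, 64, 64, 3) * 255,  # (B, T, H, W, C)
    'condition': torch.rand(2, 1, 64),           # (B, T, condition_dim), e.g. a goal embedding
}
latents, state_out = policy(dummy_input, policy.initial_state(batch_size=2))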

Warning

These examples are simplified for demonstration. Real-world policies, especially for complex environments like Minecraft, often require more sophisticated architectures (e.g., CNNs for image processing, recurrent layers like LSTMs or Transformers for temporal dependencies if not Markovian). The input['image'] format and normalization (e.g., / 255.0) should match how your environment provides observations and how your model expects them.
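
As a starting point for a more realistic visual backbone, the MLP above could be replaced with a small convolutional encoder. A minimal sketch (not one of MineStudio's built-in models):

class MySimpleCNNPolicy(MinePolicy):
    def __init__(self, hiddim, action_space=None, image_channels=3) -> None:
        super().__init__(hiddim, action_space)
        # A small convolutional encoder followed by global average pooling.
        self.backbone = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, hiddim, kernel_size=3, stride=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )

    def forward(self, 
                input: Dict[str, Any], 
                state_in: Optional[List[torch.Tensor]] = None, 
                **kwargs
    ) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]:
        img_obs = input['image']  # (B, T, H, W, C)
        b, t, h, w, c = img_obs.shape
        # Rearrange to channels-first for the convolutions and normalize to [0, 1].
        x = rearrange(img_obs / 255.0, 'b t h w c -> (b t) c h w')
        features = rearrange(self.backbone(x), '(b t) d -> b t d', b=b, t=t)
        return {'pi_logits': self.pi_head(features), 'vpred': self.value_head(features)}, []

    def initial_state(self, batch_size: Optional[int] = None) -> List[torch.Tensor]:
        return []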