Models#
We provide a template for Minecraft policies, MinePolicy, and based on this template we implement various baseline models. Currently, MineStudio supports VPT, STEVE-1, GROOT, and ROCKET-1, among others. This page details the MinePolicy template and how to create your own policies.
[Figure: MineStudio Models]
Quick Start#
Here is an example that shows how to load and use OpenAI’s VPT policy within the Minecraft environment provided by MineStudio.
from minestudio.simulator import MinecraftSim
from minestudio.simulator.callbacks import RecordCallback
from minestudio.models import load_vpt_policy, VPTPolicy

# Option 1: Load the policy from local model files.
# Ensure you have the .model (architecture) and .weights (parameters) files.
policy = load_vpt_policy(
    model_path="/path/to/foundation-model-2x.model",
    weights_path="/path/to/foundation-model-2x.weights"
).to("cuda")  # Move the policy to the GPU if available

# Option 2: Load the policy from the Hugging Face Model Hub.
# This is a convenient way to get pre-trained models.
# policy = VPTPolicy.from_pretrained("CraftJarvis/MineStudio_VPT.rl_from_early_game_2x").to("cuda")

# Set the policy to evaluation mode. This is important for consistent behavior during inference.
policy.eval()

# Initialize the Minecraft simulator.
# obs_size specifies the resolution of the visual observations.
# callbacks allow for custom actions during the simulation, e.g., recording.
env = MinecraftSim(
    obs_size=(128, 128),
    callbacks=[RecordCallback(record_path="./output", fps=30, frame_type="pov")]
)

# `memory` stores the recurrent state of the policy (e.g., for RNNs).
# For policies without memory (Markovian), it can be initialized to None.
memory = None
obs, info = env.reset()  # Reset the environment to get the initial observation.

# Simulation loop
for i in range(1200):  # Run for 1200 steps
    # Get an action from the policy.
    # `obs`: the current observation from the environment.
    # `memory`: the current recurrent state of the policy.
    # `input_shape='*'`: indicates that `obs` is a single sample (not a batch or time sequence);
    # the policy handles internal batching/unbatching for its forward pass.
    action, memory = policy.get_action(obs, memory, input_shape='*')
    # Apply the action to the environment.
    obs, reward, terminated, truncated, info = env.step(action)
    # Check whether the episode has ended.
    if terminated or truncated:
        print(f"Episode finished after {i + 1} timesteps")
        break

env.close()  # Close the environment when done.
Hint
In this example, if RecordCallback is used, the recorded video will be saved in the ./output directory. The memory variable handles the policy’s recurrent state, and input_shape='*' in get_action is typical for single-instance inference.
Policy Template (MinePolicy)#
Warning
To ensure compatibility with MineStudio’s training and inference pipelines, custom policies must inherit from minestudio.models.base_policy.MinePolicy
and implement its abstract methods.
The MinePolicy class serves as the base for all policies. Key methods and properties are described below:
__init__(self, hiddim, action_space=None, temperature=1.0, nucleus_prob=None)
The constructor for the policy.
- hiddim (int): The hidden dimension size for the policy’s internal layers.
- action_space (gymnasium.spaces.Space, optional): The action space of the environment. If None, a default Minecraft action space (camera and buttons) is used.
- temperature (float, optional): Temperature for sampling actions from the policy head. Defaults to 1.0.
- nucleus_prob (float, optional): Nucleus (top-p) probability for sampling actions. If None, standard sampling is used. Defaults to None.
It initializes the policy’s action head (self.pi_head) and value head (self.value_head).
# From minestudio/models/base_policy.py
def __init__(self, hiddim, action_space=None, temperature=1.0, nucleus_prob=None) -> None:
    torch.nn.Module.__init__(self)
    if action_space is None:
        action_space = gymnasium.spaces.Dict({
            "camera": gymnasium.spaces.MultiDiscrete([121]),
            "buttons": gymnasium.spaces.MultiDiscrete([8641]),
        })
    # self.value_head is a ScaledMSEHead
    self.value_head = ScaledMSEHead(hiddim, 1, norm_type="ewma", norm_kwargs=None)
    # self.pi_head uses the provided action_space, hiddim, temperature, and nucleus_prob
    self.pi_head = make_action_head(action_space, hiddim, temperature=temperature, nucleus_prob=nucleus_prob)
Hint
Users can override self.pi_head and self.value_head after calling super().__init__(...) if custom head implementations are needed.
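For instance, a subclass can swap in its own value head after the parent constructor runs. A minimal sketch, assuming a plain linear value head suffices (this class and its head choice are illustrative, not part of MineStudio; forward and initial_state are omitted for brevity):
import torch
from minestudio.models.base_policy import MinePolicy

class MyHeadSwapPolicy(MinePolicy):
    def __init__(self, hiddim):
        super().__init__(hiddim)  # sets up the default pi_head and value_head
        # Replace the default ScaledMSEHead with a plain linear value head
        # (assumption: downstream code only needs a (B, T, 1) value prediction).
        self.value_head = torch.nn.Linear(hiddim, 1)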
reset_parameters(self)
Resets the parameters of the policy’s action head (pi_head) and value head (value_head). This can be useful for re-initializing a policy.
# From minestudio/models/base_policy.py
def reset_parameters(self):
    """Resets the parameters of the policy and value heads."""
    self.pi_head.reset_parameters()
    self.value_head.reset_parameters()
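For example, to re-initialize both heads in place before a fresh training run:
policy.reset_parameters()  # re-initializes pi_head and value_head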
forward(self, input: Dict[str, Any], state_in: Optional[List[torch.Tensor]] = None, **kwargs) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]
This is an abstract method and must be implemented by derived classes. It defines the main computation of the policy.
- input (Dict[str, Any]): A dictionary of input tensors (e.g., observations from the environment).
- state_in (Optional[List[torch.Tensor]]): A list of input recurrent state tensors; None if the episode is starting or the policy is Markovian.
- **kwargs: Additional keyword arguments.
Returns:
- latents (Dict[str, torch.Tensor]): A dictionary containing at least:
  - 'pi_logits' (torch.Tensor): The logits for the action distribution.
  - 'vpred' (torch.Tensor): The predicted value function.
- state_out (List[torch.Tensor]): A list containing the updated recurrent state tensors. For a Markovian policy (no state), this should be an empty list ([]).
# From minestudio/models/base_policy.py
@abstractmethod
def forward(self,
            input: Dict[str, Any],
            state_in: Optional[List[torch.Tensor]] = None,
            **kwargs
) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]:
    pass
initial_state(self, batch_size: Optional[int] = None) -> List[torch.Tensor]
This is an abstract method and must be implemented by derived classes. It returns the initial recurrent state of the policy.
- batch_size (Optional[int]): The batch size for which to create the initial state.
Returns:
- (List[torch.Tensor]): A list of initial state tensors. For a Markovian policy (no state), this should be an empty list ([]).
# From minestudio/models/base_policy.py
@abstractmethod
def initial_state(self, batch_size: Optional[int] = None) -> List[torch.Tensor]:
    pass
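For a recurrent policy, initial_state typically returns zero-filled tensors. A minimal sketch of such a method on your policy class, assuming it stores its hidden size as self.hiddim and keeps a single state tensor (the state layout is an illustrative assumption; match it to whatever your forward consumes):
# Hypothetical recurrent variant (not part of the base class).
def initial_state(self, batch_size: Optional[int] = None) -> List[torch.Tensor]:
    if batch_size is None:
        # Unbatched state; get_action(input_shape='*') adds the batch dimension.
        return [torch.zeros(self.hiddim, device=self.device)]
    return [torch.zeros(batch_size, self.hiddim, device=self.device)]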
get_action(self, input: Dict[str, Any], state_in: Optional[List[torch.Tensor]], deterministic: bool = False, input_shape: str = "BT*", **kwargs) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]
This method computes and returns an action from the policy based on the current input and state. It’s typically used during inference.
- input (Dict[str, Any]): Current observation from the environment.
- state_in (Optional[List[torch.Tensor]]): Current recurrent state.
- deterministic (bool): If True, samples actions deterministically (e.g., argmax). If False, samples stochastically. Defaults to False.
- input_shape (str): Specifies the shape of the input.
  - "*": Single-instance input (e.g., one observation at a time during inference). The method handles batching/unbatching internally.
  - "BT*": Batched sequence input (Batch, Time, ...). Defaults to "BT*".
- **kwargs: Additional keyword arguments passed to forward.
Returns:
- action (Dict[str, torch.Tensor]): The sampled action.
- state_out (List[torch.Tensor]): The updated recurrent state.
# Simplified from minestudio/models/base_policy.py
@torch.inference_mode()
def get_action(self,
               input: Dict[str, Any],
               state_in: Optional[List[torch.Tensor]],
               deterministic: bool = False,
               input_shape: str = "BT*",
               **kwargs,
) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]:
    if input_shape == "*":
        # Internal batching for single-instance input
        input = dict_map(self._batchify, input)
        if state_in is not None:
            state_in = recursive_tensor_op(lambda x: x.unsqueeze(0), state_in)
    elif input_shape != "BT*":
        raise NotImplementedError("Unsupported input_shape")
    latents, state_out = self.forward(input, state_in, **kwargs)
    action = self.pi_head.sample(latents['pi_logits'], deterministic=deterministic)
    self.vpred = latents['vpred']  # Cache for potential later use
    if input_shape == "*":
        # Internal unbatching for single-instance output
        action = dict_map(lambda tensor: tensor[0][0], action)
        state_out = recursive_tensor_op(lambda x: x[0], state_out)
    return action, state_out
Note
Empirically, setting deterministic=False (stochastic sampling) can often improve policy performance during evaluation compared to deterministic actions. input_shape="*" is common for inference when processing one observation at a time.
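For example, in an inference loop (reusing obs and memory from the Quick Start above):
# Stochastic sampling (the default) vs. deterministic action selection.
action, memory = policy.get_action(obs, memory, input_shape='*')
action, memory = policy.get_action(obs, memory, deterministic=True, input_shape='*')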
device (property)
A property that returns the torch.device (e.g., 'cpu', 'cuda:0') on which the policy’s parameters are located.
# From minestudio/models/base_policy.py
@property
def device(self) -> torch.device:
    return next(self.parameters()).device
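This is convenient for placing inputs on the same device as the model, e.g. (a usage sketch, assuming obs is a dict of tensors):
# Move every tensor in an observation dict onto the policy's device.
obs = {k: v.to(policy.device) for k, v in obs.items()}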
Hint
The minimal set of methods you must implement in your custom policy is forward and initial_state.
Your First Policy#
Here are basic examples of how to create custom policies by inheriting from MinePolicy.
Load necessary modules:
import torch
import torch.nn as nn
from einops import rearrange
from typing import Dict, List, Optional, Tuple, Any
from minestudio.models.base_policy import MinePolicy
# Assuming make_action_head and ScaledMSEHead are accessible for custom heads,
# or rely on those initialized in MinePolicy's __init__.
Example 1: Condition-Free (Markovian) Policy#
This policy does not depend on any external condition beyond the current observation and has no recurrent state.
class MySimpleMarkovPolicy(MinePolicy):
    def __init__(self, hiddim, action_space=None, image_size=(64, 64), image_channels=3) -> None:
        super().__init__(hiddim, action_space)  # Initializes self.pi_head and self.value_head
        # Example backbone: a simple MLP.
        # The input image is flattened: image_size[0] * image_size[1] * image_channels.
        self.feature_dim = image_size[0] * image_size[1] * image_channels
        self.net = nn.Sequential(
            nn.Linear(self.feature_dim, hiddim),
            nn.ReLU(),
            nn.Linear(hiddim, hiddim),
            nn.ReLU()
        )
        # self.pi_head and self.value_head are already defined in the parent class.

    def forward(self,
                input: Dict[str, Any],
                state_in: Optional[List[torch.Tensor]] = None,  # None for a Markovian policy
                **kwargs
    ) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]:
        # Assuming input['image'] is (B, T, H, W, C).
        # A Markovian policy typically expects T=1; if T > 1, this example
        # processes each time step independently.
        img_obs = input['image']  # Shape: (B, T, H, W, C)
        b, t, h, w, c = img_obs.shape
        # Normalize (example: scale to [0, 1]) and flatten:
        # (B, T, H, W, C) -> (B*T, H*W*C)
        x = rearrange(img_obs / 255.0, 'b t h w c -> (b t) (h w c)')
        features = self.net(x)  # Shape: (B*T, hiddim)
        # Reshape for the policy and value heads if they expect (B, T, hiddim).
        features_reshaped = rearrange(features, '(b t) d -> b t d', b=b, t=t)
        pi_logits = self.pi_head(features_reshaped)  # pi_head handles (B, T, D) or (B*T, D)
        vpred = self.value_head(features_reshaped)   # value_head handles (B, T, D) or (B*T, D)
        result = {
            'pi_logits': pi_logits,
            'vpred': vpred,
        }
        # For a Markovian policy, state_out is an empty list.
        return result, []

    def initial_state(self, batch_size: Optional[int] = None) -> List[torch.Tensor]:
        # A Markovian policy has no state; return an empty list.
        return []
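A quick smoke test under the example’s assumptions (default action space, 64x64 RGB frames; the expected shapes are illustrative):
policy = MySimpleMarkovPolicy(hiddim=256)
dummy_input = {'image': torch.rand(2, 1, 64, 64, 3) * 255}  # (B=2, T=1, H, W, C)
latents, state_out = policy.forward(dummy_input)
print(latents['vpred'].shape)  # expected: (2, 1, 1)
print(state_out)               # expected: [] (no recurrent state)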
Example 2: Condition-Based (Markovian) Policy#
This policy takes an additional 'condition' tensor as input.
class MySimpleConditionedPolicy(MinePolicy):
    def __init__(self, hiddim, action_space=None, image_size=(64, 64), image_channels=3, condition_dim=64) -> None:
        super().__init__(hiddim, action_space)
        self.feature_dim = image_size[0] * image_size[1] * image_channels
        self.net = nn.Sequential(
            nn.Linear(self.feature_dim, hiddim),
            nn.ReLU(),
        )
        # Embedding for the condition
        self.condition_net = nn.Linear(condition_dim, hiddim)  # Or nn.Embedding if the condition is discrete
        # Fusion layer
        self.fusion_net = nn.Sequential(
            nn.Linear(hiddim * 2, hiddim),  # Example: concatenate image and condition features
            nn.ReLU(),
            nn.Linear(hiddim, hiddim),
            nn.ReLU()
        )

    def forward(self,
                input: Dict[str, Any],
                state_in: Optional[List[torch.Tensor]] = None,
                **kwargs
    ) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]:
        img_obs = input['image']        # Shape: (B, T, H, W, C)
        condition = input['condition']  # Shape: (B, T, condition_dim)
        b, t, h, w, c = img_obs.shape
        x_img = rearrange(img_obs / 255.0, 'b t h w c -> (b t) (h w c)')
        img_features = self.net(x_img)  # Shape: (B*T, hiddim)
        # Process the condition: (B, T, condition_dim) -> (B*T, condition_dim)
        cond_features = self.condition_net(rearrange(condition, 'b t d -> (b t) d'))  # Shape: (B*T, hiddim)
        # Fuse features (example: concatenation)
        fused_features = torch.cat([img_features, cond_features], dim=-1)
        final_features = self.fusion_net(fused_features)  # Shape: (B*T, hiddim)
        final_features_reshaped = rearrange(final_features, '(b t) d -> b t d', b=b, t=t)
        pi_logits = self.pi_head(final_features_reshaped)
        vpred = self.value_head(final_features_reshaped)
        result = {
            'pi_logits': pi_logits,
            'vpred': vpred,
        }
        return result, []

    def initial_state(self, batch_size: Optional[int] = None) -> List[torch.Tensor]:
        return []
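As with the first example, a quick check (the random condition vector is purely illustrative):
policy = MySimpleConditionedPolicy(hiddim=256, condition_dim=64)
dummy_input = {
    'image': torch.rand(2, 1, 64, 64, 3) * 255,  # (B=2, T=1, H, W, C)
    'condition': torch.rand(2, 1, 64),           # (B=2, T=1, condition_dim)
}
latents, _ = policy.forward(dummy_input)
print(latents['vpred'].shape)  # expected: (2, 1, 1)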
Warning
These examples are simplified for demonstration. Real-world policies, especially for complex environments like Minecraft, often require more sophisticated architectures (e.g., CNNs for image processing, and LSTMs or Transformers for temporal dependencies if the policy is not Markovian). The input['image'] format and normalization (e.g., / 255.0) should match how your environment provides observations and how your model expects them.