Models#

We provide a template for Minecraft policies, and based on this template we have built several baseline models, currently supporting VPT, STEVE-1, GROOT, and ROCKET-1.

Quick Start#

Here is an example that shows how to load OpenAI's VPT policy in the Minecraft environment.

from minestudio.simulator import MinecraftSim
from minestudio.simulator.callbacks import RecordCallback
from minestudio.models import load_vpt_policy, VPTPolicy

# load the policy from the local model files
policy = load_vpt_policy(
    model_path="/path/to/foundation-model-2x.model", 
    weights_path="/path/to/foundation-model-2x.weights"
).to("cuda")

# or load the policy from the Hugging Face model hub
policy = VPTPolicy.from_pretrained("CraftJarvis/MineStudio_VPT.rl_from_early_game_2x").to("cuda")

policy.eval()

env = MinecraftSim(
    obs_size=(128, 128), 
    callbacks=[RecordCallback(record_path="./output", fps=30, frame_type="pov")]
)
memory = None
obs, info = env.reset()
for i in range(1200):
    action, memory = policy.get_action(obs, memory, input_shape='*')
    obs, reward, terminated, truncated, info = env.step(action)
env.close()

Hint

In this example, the recorded video will be saved in the ./output directory.

Policy Template#

Warning

You must implement your own policies based on this template so that they are compatible with our training and inference pipelines.

The policy template is minestudio.models.base_policy.MinePolicy. It consists of the following methods:

__init__(self, hiddim, action_space=None)

The constructor of the policy. It initializes the policy head and the value head. hiddim is the hidden dimension of the policy, and action_space is the action space of the environment; if it is not provided, the default action space of the Minecraft environment is used.

def __init__(self, hiddim, action_space=None) -> None:
    torch.nn.Module.__init__(self)
    if action_space is None:
        action_space = gymnasium.spaces.Dict({
            "camera": gymnasium.spaces.MultiDiscrete([121]), 
            "buttons": gymnasium.spaces.MultiDiscrete([8641]),
        })
    self.pi_head = make_action_head(action_space, hiddim, temperature=2.0)
    self.value_head = ScaledMSEHead(hiddim, 1, norm_type="ewma", norm_kwargs=None)

Hint

If users want to customize the pi_head and value_head modules, they can override them after calling the super().__init__ method.
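For example, a minimal sketch of swapping in a custom value head (the plain linear layer here is purely illustrative, not the recommended head):

class MyPolicy(MinePolicy):
    def __init__(self, hiddim, action_space=None) -> None:
        super().__init__(hiddim, action_space)
        # hypothetical: replace the default ScaledMSEHead with a plain linear layer
        self.value_head = torch.nn.Linear(hiddim, 1)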

forward(self, input, state_in, **kwargs)

The forward method of the policy. It takes the input and the state tensors and returns the latent tensors and the updated state tensors.

@abstractmethod
def forward(self, 
            input: Dict[str, Any], 
            state_in: Optional[List[torch.Tensor]] = None,
            **kwargs
) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]:
    """
    Returns:
        latents: containing `pi_logits` and `vpred` latent tensors.
        state_out: containing the updated state tensors.
    """
    pass

Note

This method should be implemented by the derived classes.

initial_state(self, batch_size=None)

Returns the initial state of the policy. Recurrent policies should return their initial memory tensors here, while stateless (Markov) policies can simply return None.

@abstractmethod
def initial_state(self, batch_size: Optional[int] = None) -> List[torch.Tensor]:
    pass
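As a sketch, a recurrent policy might return zero-initialized memory tensors. The example below is an assumption, not part of the template: it presumes the policy stores its hidden dimension as self.hiddim and keeps a single GRU-style hidden state.

def initial_state(self, batch_size=None):
    # hypothetical recurrent policy with one zero hidden state of size `hiddim`
    if batch_size is None:
        return [torch.zeros(1, self.hiddim)]
    return [torch.zeros(batch_size, self.hiddim)]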

get_action(self, input, state_in, deterministic, input_shape)

This method returns the action of the policy. It takes the input, the state tensors, and a deterministic flag, and returns the action tensor together with the updated state tensors. It is usually called during inference.

@torch.inference_mode()
def get_action(self,
                input: Dict[str, Any],
                state_in: Optional[List[torch.Tensor]],
                deterministic: bool = False,
                input_shape: str = "BT*",
                **kwargs, 
) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]:
    if input_shape == "*":
        input = dict_map(self._batchify, input)
        if state_in is not None:
            state_in = recursive_tensor_op(lambda x: x.unsqueeze(0), state_in)
    elif input_shape != "BT*":
        raise NotImplementedError
    latents, state_out = self.forward(input, state_in, **kwargs)
    action = self.pi_head.sample(latents['pi_logits'], deterministic)
    self.vpred = latents['vpred']
    if input_shape == "BT*":
        return action, state_out
    elif input_shape == "*":
        return dict_map(lambda tensor: tensor[0][0], action), recursive_tensor_op(lambda x: x[0], state_out)
    else:
        raise NotImplementedError

Note

deterministic is a flag that indicates whether the action is generated with argmax or stochastic sampling. We empirically find that setting deterministic=False can improve the performance of the policy.
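Both modes go through the same call, for example:

action, memory = policy.get_action(obs, memory, deterministic=True, input_shape='*')   # argmax
action, memory = policy.get_action(obs, memory, deterministic=False, input_shape='*')  # stochastic sampling (recommended)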

Note

input_shape is a string that indicates the shape of the input. It can be "BT*" or "*". "BT*" means the input is a batch of time sequences, and "*" means the input is a single sample. Generally speaking, in inference mode you feed one observation at a time, so you should set input_shape="*".
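A small sketch of the difference; the dummy image tensors below assume a 128x128 RGB observation and are for illustration only:

# "*": a single raw observation, e.g. an image of shape (H, W, C)
single_obs = {'image': torch.zeros(128, 128, 3)}
action, state = policy.get_action(single_obs, None, input_shape='*')

# "BT*": batched time sequences, e.g. an image of shape (B, T, H, W, C)
batched_obs = {'image': torch.zeros(1, 4, 128, 128, 3)}
actions, state = policy.get_action(batched_obs, None, input_shape='BT*')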

device(self)

This is a property method that returns the device of the policy.

@property
def device(self) -> torch.device:
    return next(self.parameters()).device
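This is handy, for example, when moving observation tensors onto the same device as the policy before calling it (a hypothetical pre-processing step):

obs = {k: v.to(policy.device) if isinstance(v, torch.Tensor) else v for k, v in obs.items()}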

Hint

The minimal set of methods you need to implement is forward and initial_state.

Your First Policy#

Load the necessary modules:

import torch
import torch.nn as nn
from einops import rearrange
from minestudio.models.base_policy import MinePolicy

To customize a condition-free policy, you can follow this example:

class MyConditionFreePolicy(MinePolicy):
    def __init__(self, hiddim, action_space=None) -> None:
        super().__init__(hiddim, action_space)
        # we use the original pi_head and value_head here.
        self.net = nn.Sequential(
            nn.Linear(128*128*3, hiddim), 
            nn.ReLU(),
            nn.Linear(hiddim, hiddim),
            nn.ReLU()
        )
        # we implement a simple mlp network here as the backbone. 

    def forward(self, input, state_in, **kwargs):
        x = rearrange(input['image'] / 255., 'b t h w c -> b t (h w c)')
        x = self.net(x)
        result = {
            'pi_logits': self.pi_head(x), 
            'vpred': self.value_head(x), 
        }
        return result, state_in

    def initial_state(self, batch_size=None):
        # this is a simple Markov policy, so the state is always None.
        return None
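A quick smoke test of the sketch above might look like this; the dummy observation shape matches the 128x128 RGB input assumed by the MLP:

policy = MyConditionFreePolicy(hiddim=256)
dummy_obs = {'image': torch.zeros(1, 1, 128, 128, 3)}  # (B, T, H, W, C)
action, state = policy.get_action(dummy_obs, None)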

To customize a condition-based policy, you can follow this example:

class MyConditionBasedPolicy(MinePolicy):
    def __init__(self, hiddim, action_space=None) -> None:
        super().__init__(hiddim, action_space)
        # we use the original pi_head and value_head here.
        self.net = nn.Sequential(
            nn.Linear(128*128*3, hiddim), 
            nn.ReLU(),
            nn.Linear(hiddim, hiddim),
            nn.ReLU()
        )
        self.condition_net = nn.Embedding(10, hiddim)
        # we implement a simple mlp network here as the backbone. 

    def forward(self, input, state_in, **kwargs):
        x = rearrange(input['image'] / 255., 'b t h w c -> b t (h w c)')
        x = self.net(x) # b t c
        y = self.condition_net(input['condition']) # b t -> b t c
        z = x + y # simple addition fusion
        result = {
            'pi_logits': self.pi_head(z), 
            'vpred': self.value_head(z), 
        }
        return result, state_in

    def initial_state(self, batch_size=None):
        # this is a simple Markov policy, so the state is always None.
        return None
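A corresponding smoke test; the integer condition id is an assumption of this toy embedding, and real conditions depend on your task:

policy = MyConditionBasedPolicy(hiddim=256)
dummy_obs = {
    'image': torch.zeros(1, 1, 128, 128, 3),           # (B, T, H, W, C)
    'condition': torch.zeros(1, 1, dtype=torch.long),  # condition id in [0, 10)
}
action, state = policy.get_action(dummy_obs, None)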

Warning

These examples are just for demonstration purposes and may perform poorly in practice.