Source code for minestudio.models.steve_one.body

from copy import deepcopy
import hashlib
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from huggingface_hub import PyTorchModelHubMixin

from minestudio.models.base_policy import MinePolicy
from minestudio.utils.mineclip_lib.mineclip import MineCLIP
from minestudio.utils.vpt_lib.impala_cnn import ImpalaCNN
from minestudio.utils.vpt_lib.util import FanInInitReLULayer, ResidualRecurrentBlocks

class TranslatorVAE(torch.nn.Module):
    """
    Variational Autoencoder for translating between text and visual embeddings.

    This VAE learns to map text embeddings to visual embeddings through a latent space,
    enabling text-to-visual translation for multimodal learning tasks. The encoder takes
    concatenated visual and text embeddings to produce latent representations, while the
    decoder reconstructs visual embeddings from latent codes and text.
    """

    def __init__(self, input_dim=512, hidden_dim=256, latent_dim=256):
        """
        Initialize the TranslatorVAE with specified dimensions.

        :param input_dim: Dimension of the input visual and text embeddings
        :param hidden_dim: Dimension of the hidden layers in the encoder and decoder
        :param latent_dim: Dimension of the latent space
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_dim * 2, hidden_dim),
            torch.nn.LayerNorm(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.LayerNorm(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 2 * latent_dim),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim + input_dim, hidden_dim),
            torch.nn.LayerNorm(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.LayerNorm(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, input_dim),
        )
    def encode(self, visual_embeddings, text_embeddings):
        """
        Encode concatenated visual and text embeddings into latent parameters.

        :param visual_embeddings: Visual embedding tensor of shape (B, input_dim)
        :param text_embeddings: Text embedding tensor of shape (B, input_dim)
        :returns: Encoded tensor containing the mean and log variance, shape (B, 2*latent_dim)
        """
        # Concatenate the visual and text embeddings.
        x = torch.cat([visual_embeddings, text_embeddings], dim=1)
        # Encode the concatenated embeddings into a latent vector.
        return self.encoder(x)
    def sample(self, mu, logvar):
        """
        Sample a latent vector using the reparameterization trick.

        Applies the reparameterization trick to sample from a Gaussian distribution
        defined by mu and logvar, enabling backpropagation through stochastic sampling.

        :param mu: Mean tensor of shape (B, latent_dim)
        :param logvar: Log variance tensor of shape (B, latent_dim)
        :returns: Sampled latent vector of shape (B, latent_dim)
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    def decode(self, latent_vector, text_embeddings):
        """
        Decode a latent vector and text embeddings into visual embeddings.

        :param latent_vector: Latent representation tensor of shape (B, latent_dim)
        :param text_embeddings: Text embedding tensor of shape (B, input_dim)
        :returns: Reconstructed visual embeddings of shape (B, input_dim)
        """
        # Concatenate the latent vector and text embeddings.
        x = torch.cat([latent_vector, text_embeddings], dim=1)
        # Decode the concatenated embeddings into a visual embedding.
        return self.decoder(x)
    def forward(self, text_embeddings, deterministic=False):
        """
        Generate visual embeddings from text embeddings using the prior distribution.

        Uses a zero-mean, unit-variance prior to sample latent vectors and decode them
        into visual embeddings conditioned on the input text embeddings.

        :param text_embeddings: Input text embeddings of shape (B, input_dim)
        :param deterministic: If True, use the mean of the prior; if False, sample from the prior
        :returns: Generated visual embeddings of shape (B, input_dim)
        """
        # Use the prior as the mean and logvar.
        mu = torch.zeros(text_embeddings.shape[0], self.latent_dim).to(text_embeddings.device)
        logvar = torch.zeros(text_embeddings.shape[0], self.latent_dim).to(text_embeddings.device)
        # Sample a latent vector from the mu and logvar.
        if deterministic:
            latent_vector = mu
        else:
            latent_vector = self.sample(mu, logvar)
        # Decode the latent vector into a visual embedding.
        pred_visual_embeddings = self.decode(latent_vector, text_embeddings)
        return pred_visual_embeddings
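
# Illustrative sketch (not part of the original module): with the default dimensions,
# the prior maps a batch of MineCLIP-style text embeddings to visual-style embeddings
# of the same width. The random tensor below is a stand-in for real MineCLIP outputs.
def _example_translator_vae():
    prior = TranslatorVAE(input_dim=512, hidden_dim=256, latent_dim=256)
    text_embeds = torch.randn(4, 512)      # (B, input_dim) stand-in text embeddings
    visual_embeds = prior(text_embeds)     # (B, input_dim) predicted visual embeddings
    return visual_embeds
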
class ImgPreprocessing(torch.nn.Module):
    """
    Image preprocessing module for normalization and scaling.

    Normalizes incoming images using either provided statistics (mean/std) or simple
    scaling. Supports both statistical normalization from pre-computed statistics and
    basic scaling by a constant factor.

    :param img_statistics: Path to an npz file containing mean and std statistics.
        If specified, images are normalized using these statistics.
    :param scale_img: If True and img_statistics is not specified, scale images by 1/255.
    """

    def __init__(self, img_statistics: Optional[str] = None, scale_img: bool = True):
        """
        Initialize image preprocessing with normalization parameters.

        :param img_statistics: Optional path to an npz file with mean/std statistics
        :param scale_img: Whether to scale images by 1/255 when no statistics are provided
        """
        super().__init__()
        self.img_mean = None
        if img_statistics is not None:
            img_statistics = dict(**np.load(img_statistics))
            self.img_mean = torch.nn.Parameter(torch.Tensor(img_statistics["mean"]), requires_grad=False)
            self.img_std = torch.nn.Parameter(torch.Tensor(img_statistics["std"]), requires_grad=False)
        else:
            self.ob_scale = 255.0 if scale_img else 1.0
    def forward(self, img):
        """
        Apply preprocessing normalization to input images.

        :param img: Input image tensor of shape (B, C, H, W) or (B, T, C, H, W)
        :returns: Normalized image tensor with the same shape as the input
        """
        x = img
        if self.img_mean is not None:
            x = (x - self.img_mean) / self.img_std
        else:
            x = x / self.ob_scale
        return x
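
# Illustrative sketch (not part of the original module): without an img_statistics
# file, ImgPreprocessing simply rescales raw pixel values into [0, 1].
def _example_img_preprocessing():
    prep = ImgPreprocessing(img_statistics=None, scale_img=True)
    frames = torch.randint(0, 256, (1, 8, 128, 128, 3)).float()  # (B, T, H, W, C) raw pixels
    return prep(frames)  # same shape, values divided by 255
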
class ImgObsProcess(torch.nn.Module):
    """
    Image observation processing using an ImpalaCNN followed by a linear layer.

    Processes image observations through an Impala CNN architecture and then applies
    a linear transformation to produce the final output embeddings.

    :param cnn_outsize: Output dimension of the Impala CNN
    :param output_size: Output size of the final linear layer
    :param dense_init_norm_kwargs: Initialization kwargs for the linear FanInInitReLULayer
    :param init_norm_kwargs: Initialization kwargs for 2D and 3D conv FanInInitReLULayers
    """

    def __init__(
        self,
        cnn_outsize: int,
        output_size: int,
        dense_init_norm_kwargs: Dict = {},
        init_norm_kwargs: Dict = {},
        **kwargs,
    ):
        """
        Initialize the image observation processing module.

        :param cnn_outsize: Output dimension of the CNN backbone
        :param output_size: Final output embedding dimension
        :param dense_init_norm_kwargs: Kwargs for dense layer initialization
        :param init_norm_kwargs: Kwargs for convolution layer initialization
        :param kwargs: Additional arguments passed to ImpalaCNN
        """
        super().__init__()
        self.cnn = ImpalaCNN(
            outsize=cnn_outsize,
            init_norm_kwargs=init_norm_kwargs,
            dense_init_norm_kwargs=dense_init_norm_kwargs,
            **kwargs,
        )
        self.linear = FanInInitReLULayer(
            cnn_outsize,
            output_size,
            layer_type="linear",
            **dense_init_norm_kwargs,
        )
    def forward(self, img):
        """
        Process an image through the CNN and linear transformation.

        :param img: Input image tensor of shape (B, C, H, W)
        :returns: Processed image features of shape (B, output_size)
        """
        return self.linear(self.cnn(img))
class MinecraftPolicy(torch.nn.Module):
    """
    Neural network policy for Minecraft gameplay with multimodal inputs.

    This policy combines visual and textual information processing through CNN and
    recurrent architectures. It supports various recurrence types including LSTM,
    masked LSTM, and transformer variants for sequential decision making.

    The architecture processes images through an ImpalaCNN, combines them with MineCLIP
    embeddings for text conditioning, and uses recurrent layers for temporal modeling.

    :param recurrence_type: Type of recurrent architecture:

        - 'multi_layer_lstm': Multi-layer LSTM (no ragged batching support)
        - 'multi_layer_bilstm': Bidirectional multi-layer LSTM
        - 'multi_masked_lstm': Multi-layer LSTM with ragged batching support
        - 'transformer': Dense transformer architecture
        - 'none': No recurrence
    :param init_norm_kwargs: Initialization kwargs for all FanInInitReLULayers
    """

    def __init__(
        self,
        recurrence_type="lstm",
        impala_width=1,
        impala_chans=(16, 32, 32),
        obs_processing_width=256,
        hidsize=512,
        single_output=False,  # True if we don't need separate outputs for action/value outputs
        img_shape=None,
        scale_input_img=True,
        only_img_input=False,
        init_norm_kwargs={},
        impala_kwargs={},
        # Unused argument assumed by forc.
        input_shape=None,  # pylint: disable=unused-argument
        active_reward_monitors=None,
        img_statistics=None,
        first_conv_norm=False,
        diff_mlp_embedding=False,
        attention_mask_style="clipped_causal",
        attention_heads=8,
        attention_memory_size=2048,
        use_pointwise_layer=True,
        pointwise_ratio=4,
        pointwise_use_activation=False,
        n_recurrence_layers=1,
        recurrence_is_residual=True,
        timesteps=None,
        use_pre_lstm_ln=True,  # Not needed for transformer
        mineclip_embed_dim=512,  # MODIFIED (added this)
        **unused_kwargs,
    ):
        """
        Initialize the Minecraft policy network.

        :param recurrence_type: Type of recurrent layer ('multi_layer_lstm',
            'multi_layer_bilstm', 'multi_masked_lstm', 'transformer', 'none')
        :param impala_width: Width multiplier for Impala CNN channels
        :param impala_chans: Base channel sizes for Impala CNN layers
        :param obs_processing_width: Width for observation processing layers
        :param hidsize: Hidden state size for recurrent layers and output
        :param single_output: Whether to use a single output for both policy and value
        :param img_shape: Expected input image shape
        :param scale_input_img: Whether to scale input images by 1/255
        :param only_img_input: Whether to use only image inputs (unused)
        :param init_norm_kwargs: Initialization kwargs for normalization layers
        :param impala_kwargs: Additional kwargs for the Impala CNN
        :param input_shape: Legacy parameter (unused)
        :param active_reward_monitors: Dictionary of active reward monitors
        :param img_statistics: Path to image normalization statistics
        :param first_conv_norm: Whether to apply normalization to the first conv layer
        :param diff_mlp_embedding: Whether to use a different MLP embedding (unused)
        :param attention_mask_style: Attention masking style for the transformer
        :param attention_heads: Number of attention heads for the transformer
        :param attention_memory_size: Memory size for the attention mechanism
        :param use_pointwise_layer: Whether to use pointwise layers in the transformer
        :param pointwise_ratio: Ratio for pointwise layer dimensions
        :param pointwise_use_activation: Whether to use an activation in pointwise layers
        :param n_recurrence_layers: Number of recurrent layers to stack
        :param recurrence_is_residual: Whether recurrent layers use residual connections
        :param timesteps: Number of timesteps for sequence processing
        :param use_pre_lstm_ln: Whether to use layer norm before the LSTM
        :param mineclip_embed_dim: Dimension of MineCLIP embeddings
        :param unused_kwargs: Additional unused arguments
        """
        super().__init__()
        assert recurrence_type in [
            "multi_layer_lstm",
            "multi_layer_bilstm",
            "multi_masked_lstm",
            "transformer",
            "none",
        ]

        active_reward_monitors = active_reward_monitors or {}

        self.single_output = single_output

        chans = tuple(int(impala_width * c) for c in impala_chans)
        self.hidsize = hidsize

        # Dense init kwargs replaces batchnorm/groupnorm with layernorm
        self.init_norm_kwargs = init_norm_kwargs
        self.dense_init_norm_kwargs = deepcopy(init_norm_kwargs)
        if self.dense_init_norm_kwargs.get("group_norm_groups", None) is not None:
            self.dense_init_norm_kwargs.pop("group_norm_groups", None)
            self.dense_init_norm_kwargs["layer_norm"] = True
        if self.dense_init_norm_kwargs.get("batch_norm", False):
            self.dense_init_norm_kwargs.pop("batch_norm", False)
            self.dense_init_norm_kwargs["layer_norm"] = True

        # Setup inputs
        self.img_preprocess = ImgPreprocessing(img_statistics=img_statistics, scale_img=scale_input_img)
        self.img_process = ImgObsProcess(
            cnn_outsize=256,
            output_size=hidsize,
            inshape=img_shape,
            chans=chans,
            nblock=2,
            dense_init_norm_kwargs=self.dense_init_norm_kwargs,
            init_norm_kwargs=init_norm_kwargs,
            first_conv_norm=first_conv_norm,
            **impala_kwargs,
        )

        self.pre_lstm_ln = torch.nn.LayerNorm(hidsize) if use_pre_lstm_ln else None
        self.diff_obs_process = None

        self.recurrence_type = recurrence_type

        self.recurrent_layer = None
        self.recurrent_layer = ResidualRecurrentBlocks(
            hidsize=hidsize,
            timesteps=timesteps,
            recurrence_type=recurrence_type,
            is_residual=recurrence_is_residual,
            use_pointwise_layer=use_pointwise_layer,
            pointwise_ratio=pointwise_ratio,
            pointwise_use_activation=pointwise_use_activation,
            attention_mask_style=attention_mask_style,
            attention_heads=attention_heads,
            attention_memory_size=attention_memory_size,
            n_block=n_recurrence_layers,
        )

        self.lastlayer = FanInInitReLULayer(hidsize, hidsize, layer_type="linear", **self.dense_init_norm_kwargs)
        self.final_ln = torch.nn.LayerNorm(hidsize)

        # MODIFIED (added this)
        self.mineclip_embed_linear = torch.nn.Linear(mineclip_embed_dim, hidsize)
    def output_latent_size(self):
        """
        Get the size of the output latent representation.

        :returns: Integer size of the hidden/latent dimension
        """
        return self.hidsize
    def forward(self, ob, state_in, context):
        """
        Forward pass through the Minecraft policy network.

        Processes multimodal observations (images and MineCLIP embeddings) through CNN
        and recurrent layers to produce policy and value latent representations.

        :param ob: Observation dictionary containing 'img' and 'mineclip_embed' keys
        :param state_in: Input recurrent state from the previous timestep
        :param context: Context dictionary containing 'first' episode flags
        :returns: Tuple of ((pi_latent, vf_latent), state_out) where the latents are the
            policy and value representations and state_out is the updated recurrent state
        """
        b, t = ob["img"].shape[:2]
        first = context["first"].bool()

        x = self.img_preprocess(ob["img"])
        x = self.img_process(x)

        if self.diff_obs_process:
            processed_obs = self.diff_obs_process(ob["diff_goal"])
            x = processed_obs + x

        # MODIFIED (added this)
        mineclip_embed = ob["mineclip_embed"].reshape(b * t, -1)
        # Normalize mineclip_embed (doesn't work because the norm is way too small then?)
        # mineclip_embed = F.normalize(mineclip_embed, dim=-1)
        mineclip_embed = self.mineclip_embed_linear(mineclip_embed)
        mineclip_embed = mineclip_embed.reshape(b, t, -1)
        x = x + mineclip_embed

        if self.pre_lstm_ln is not None:
            x = self.pre_lstm_ln(x)

        if self.recurrent_layer is not None:
            x, state_out = self.recurrent_layer(x, first, state_in)
        else:
            state_out = state_in

        x = F.relu(x, inplace=False)

        x = self.lastlayer(x)
        x = self.final_ln(x)
        pi_latent = vf_latent = x
        if self.single_output:
            return pi_latent, state_out
        return (pi_latent, vf_latent), state_out
    def initial_state(self, batchsize):
        """
        Initialize the recurrent state for a new episode.

        :param batchsize: Batch size for state initialization
        :returns: Initial recurrent state tensors, or None if there is no recurrence
        """
        if self.recurrent_layer:
            return self.recurrent_layer.initial_state(batchsize)
        else:
            return None
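
# Illustrative sketch (not part of the original module): the input contract expected by
# MinecraftPolicy.forward. The shapes are assumptions based on the defaults above:
# frames are (B, T, H, W, C) pixels and conditioning embeddings are (B, T, mineclip_embed_dim).
def _example_minecraft_policy_inputs(batch_size=2, timesteps=8):
    ob = {
        "img": torch.zeros(batch_size, timesteps, 128, 128, 3),        # raw frames
        "mineclip_embed": torch.zeros(batch_size, timesteps, 512),     # MineCLIP conditioning
    }
    context = {"first": torch.zeros(batch_size, timesteps, dtype=torch.bool)}  # episode-start flags
    return ob, context
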
class SteveOnePolicy(MinePolicy, PyTorchModelHubMixin):
    """
    Complete STEVE-1 policy combining visual processing, text conditioning, and action prediction.

    This policy integrates MineCLIP for multimodal understanding, a TranslatorVAE for
    text-to-visual translation, and a MinecraftPolicy network for decision making.
    It supports both text and video conditioning through classifier-free guidance.

    The architecture enables text-conditioned gameplay by translating textual instructions
    into visual embeddings that guide the policy's behavior in Minecraft environments.
    """

    def __init__(
        self,
        mineclip_kwargs: dict = {},
        prior_kwargs: dict = {},
        policy_kwargs: dict = {},
        freeze_mineclip: bool = True,
        action_space=None,
    ):
        """
        Initialize the STEVE-1 policy with all components.

        :param mineclip_kwargs: Configuration for the MineCLIP multimodal encoder
        :param prior_kwargs: Configuration for the TranslatorVAE prior model
        :param policy_kwargs: Configuration for the MinecraftPolicy network
        :param freeze_mineclip: Whether to freeze MineCLIP parameters during training
        :param action_space: Action space specification for the environment
        """
        net = MinecraftPolicy(**policy_kwargs)
        super().__init__(hiddim=net.hidsize, action_space=action_space)
        self.net = net
        self.prior = TranslatorVAE(**prior_kwargs)
        self.mineclip = MineCLIP(**mineclip_kwargs)

        if freeze_mineclip:
            for param in self.mineclip.parameters():
                param.requires_grad = False
    def prepare_condition(self, instruction: Dict[str, Any], deterministic: bool = False) -> Dict[str, Any]:
        """
        Prepare conditioning information from text or video instructions.

        Processes either text instructions (via the TranslatorVAE) or video demonstrations
        (via direct MineCLIP encoding) to create conditioning embeddings for the policy.

        :param instruction: Dictionary containing either a 'text' or 'video' key plus 'cond_scale'
        :param deterministic: Whether to use deterministic sampling for text conditioning
        :returns: Dictionary with 'cond_scale' and 'mineclip_embeds' for policy conditioning
        :raises AssertionError: If the instruction lacks required keys or mixes modalities
        """
        assert 'cond_scale' in instruction, "instruction must have 'cond_scale' key."
        if 'video' in instruction:
            assert 'text' not in instruction, "cannot have both text and video in instruction"
            video = instruction['video']
            if isinstance(video, np.ndarray):
                video = torch.from_numpy(video).to(self.device)
            if video.dim() == 4:
                video = rearrange(video, 'T H W C -> 1 T C H W')
            if video.shape[2] != 3:
                video = rearrange(video, 'B T H W C -> B T C H W')
            assert video.dim() == 5 and video.shape[2] == 3, \
                "video must be a 5D tensor with shape (B, T, C, H, W) or (B, T, H, W, C)"
            B, T, C, H, W = video.shape

            if video.dtype == torch.uint8:
                mineclip_inputs = video.float()
            elif video.dtype == torch.float32:
                assert video.abs().max() <= 1.0, "float32 video must be in range [-1, 1]"
                mineclip_inputs = video * 255.0
            else:
                raise ValueError("video must be either uint8 or float32.")
            mineclip_inputs = rearrange(
                torch.nn.functional.interpolate(
                    rearrange(mineclip_inputs, 'B T C H W -> (B T) C H W'),
                    size=(160, 256),
                    mode='bilinear',
                    align_corners=False
                ),
                '(B T) C H W -> B T C H W', B=B, T=T
            )
            mineclip_embeds = self.mineclip.encode_video(mineclip_inputs)
        else:
            assert 'text' in instruction, "instruction must have either text or video."
            texts = instruction['text']
            if isinstance(texts, str):
                texts = [texts]
            assert isinstance(texts, list) and isinstance(texts[0], str), "text must be a string or a list of strings."
            text_embeds = self.mineclip.encode_text(texts)
            mineclip_embeds = self.prior(text_embeds, deterministic=deterministic)

        return {
            'cond_scale': instruction['cond_scale'],
            'mineclip_embeds': mineclip_embeds,
        }
    def forward(
        self,
        input: Dict[str, Any],
        state_in: Optional[List[torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]:
        """
        Forward pass with classifier-free guidance for conditioned generation.

        Performs a forward pass through the policy network with optional classifier-free
        guidance. When cond_scale > 0, runs both conditioned and unconditioned inference
        and combines the outputs using the specified guidance scale.

        :param input: Dictionary with an 'image' key containing observation images and a
            'condition' key holding 'cond_scale' and (optionally) 'mineclip_embeds'
        :param state_in: Optional list of recurrent state tensors from the previous timestep
        :param kwargs: Additional keyword arguments (unused)
        :returns: Tuple of (latents_dict, state_out) where latents_dict contains 'pi_logits'
            and 'vpred' and state_out is the updated recurrent state
        """
        condition = input['condition'].copy()
        if 'mineclip_embeds' not in condition:
            condition = self.prepare_condition(condition)
        if state_in is None:
            state_in = self.initial_state(condition, input['image'].shape[0])
        input, state_in = input.copy(), state_in.copy()

        mineclip_embeds = condition['mineclip_embeds']
        if mineclip_embeds.shape[0] == 1 and input['image'].shape[0] > 1:
            mineclip_embeds = repeat(mineclip_embeds, '1 ... -> b ...', b=input['image'].shape[0])

        if condition['cond_scale'] is not None and condition['cond_scale'] != 0:
            # Duplicate the batch: one conditioned and one unconditioned copy per sample.
            state_in = [rearrange(x, "b ... c -> (b c) ...") for x in state_in]
            images = repeat(input['image'], "b ... -> (b c) ...", c=2)
            mineclip_embeds = rearrange(
                torch.stack([mineclip_embeds, torch.zeros_like(mineclip_embeds)]),
                'c b ... -> (b c) ...'
            )
        else:
            images = input['image']

        dummy_first = torch.zeros((images.shape[0], images.shape[1]), dtype=torch.bool, device=self.device)
        if images.shape[-1] != 3:
            images = rearrange(images, 'b t c h w -> b t h w c')
        if images.dtype == torch.uint8:
            images = images.float()
        elif images.dtype == torch.float32:
            assert images.abs().max() <= 1.0, "float32 image must be in range [-1, 1]"
            images = images * 255.0
        else:
            raise ValueError("image must be either uint8 or float32.")

        (pi_latent, vf_latent), state_out = self.net(
            ob={"img": images, "mineclip_embed": repeat(mineclip_embeds, 'b c -> b t c', t=images.shape[1])},
            context={"first": dummy_first},
            state_in=state_in
        )

        pi_logits = self.pi_head(pi_latent)
        vpred = self.value_head(vf_latent)

        if condition['cond_scale'] is not None and condition['cond_scale'] != 0:
            # Classifier-free guidance: separate the conditioned (index 0) and unconditioned
            # (index 1) branches, then blend the logits with the guidance scale.
            pi_logits = {k: rearrange(v, '(b c) ... -> b c ...', c=2) for k, v in pi_logits.items()}
            vpred = rearrange(vpred, '(b c) ... -> b c ...', c=2)
            state_out = [rearrange(x, '(b c) ... -> b ... c', c=2) for x in state_out]

            pi_logits = {
                k: (1 + condition['cond_scale']) * v[:, 0] - condition['cond_scale'] * v[:, 1]
                for k, v in pi_logits.items()
            }
            vpred = vpred[:, 0]

        latents = {
            "pi_logits": pi_logits,
            "vpred": vpred,
        }
        return latents, state_out
    def initial_state(
        self,
        condition: Dict[str, Any],
        batch_size: Optional[int] = None
    ) -> List[torch.Tensor]:
        """
        Initialize the recurrent state for policy inference.

        Creates initial state tensors for the recurrent layers. When classifier-free
        guidance is enabled (cond_scale > 0), duplicates the states for the conditioned
        and unconditioned inference branches.

        :param condition: Conditioning dictionary with a 'cond_scale' key
        :param batch_size: Batch size for state initialization
        :returns: List of initial state tensors
        """
        initial_state = self.net.initial_state(batch_size)
        if condition['cond_scale'] == 0.0 or condition['cond_scale'] is None:
            return initial_state
        else:
            return [torch.stack([x, x], dim=-1) for x in initial_state]
    def reset_parameters(self):
        """
        Reset all trainable parameters in the policy network.

        :raises NotImplementedError: This method is not yet implemented
        """
        super().reset_parameters()
        self.net.reset_parameters()
        raise NotImplementedError()
    @property
    def device(self) -> torch.device:
        """
        Get the device of the policy model.

        :returns: torch.device where the model parameters are located
        """
        return next(self.parameters()).device
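
# Illustrative restatement (not part of the original module) of the classifier-free
# guidance arithmetic used in SteveOnePolicy.forward: the conditioned and unconditioned
# logits are blended as (1 + cond_scale) * conditioned - cond_scale * unconditioned.
def _example_cfg_combine(cond_logits: torch.Tensor, uncond_logits: torch.Tensor, cond_scale: float) -> torch.Tensor:
    return (1 + cond_scale) * cond_logits - cond_scale * uncond_logits
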
def load_steve_one_policy(ckpt_path: str) -> SteveOnePolicy:
    """
    Load a pre-trained STEVE-1 policy from a checkpoint path.

    Convenience function to load a STEVE-1 policy using the HuggingFace Hub
    model loading mechanism.

    :param ckpt_path: Path or HuggingFace model identifier for the checkpoint
    :returns: Loaded SteveOnePolicy instance
    """
    return SteveOnePolicy.from_pretrained(ckpt_path)
if __name__ == '__main__':
    model = SteveOnePolicy.from_pretrained("CraftJarvis/MineStudio_STEVE-1.official").to("cuda")
    model.eval()
    # Condition on a (random, stand-in) reference video with a guidance scale of 4.0.
    condition = model.prepare_condition(
        {
            'cond_scale': 4.0,
            'video': np.random.randint(0, 255, (2, 16, 224, 224, 3)).astype(np.uint8)
        }
    )
    output, memory = model(
        input={
            'image': torch.zeros(2, 8, 128, 128, 3).to("cuda"),
            'condition': condition,
        },
        state_in=model.initial_state(condition, 2)
    )
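
# Text conditioning works analogously (illustrative sketch, not part of the original
# script): replace the 'video' key with a 'text' prompt and the TranslatorVAE prior
# supplies the visual-style embedding, e.g.
#
#     condition = model.prepare_condition({'cond_scale': 4.0, 'text': 'collect wood'})
#     output, memory = model(
#         input={'image': frames, 'condition': condition},
#         state_in=model.initial_state(condition, frames.shape[0]),
#     )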