Models API Documentation
BasePolicy
Date: 2024-11-11 15:59:37
LastEditors: muzhancun muzhancun@stu.pku.edu.cn
LastEditTime: 2025-05-26 21:41:12
FilePath: /MineStudio/minestudio/models/base_policy.py
- class minestudio.models.base_policy.MinePolicy(hiddim, action_space=None, temperature=1.0, nucleus_prob=None)
Abstract base class for Minecraft policies.
This class defines the basic interface for a policy, including methods for getting actions, computing initial states, and resetting parameters. It also handles batching and unbatching of inputs and states.
- Parameters:
hiddim – The hidden dimension size.
action_space – The action space of the environment. Defaults to a predefined Dict space for “camera” and “buttons” if None.
temperature – Temperature parameter for sampling actions from the policy head. Defaults to 1.0.
nucleus_prob – Nucleus probability for sampling actions. Defaults to None.
- property device: torch.device
Gets the device of the policy’s parameters.
- Returns:
The device (e.g., ‘cpu’, ‘cuda’).
- Return type:
torch.device
- abstract forward(input: Dict[str, Any], state_in: List[Tensor] | None = None, **kwargs) → Tuple[Dict[str, Tensor], List[Tensor]]
Abstract method for the forward pass of the policy.
Subclasses must implement this method to define the policy’s computation.
- Parameters:
input (Dict[str, Any]) – A dictionary of input tensors.
state_in (Optional[List[torch.Tensor]]) – An optional list of input state tensors.
kwargs – Additional keyword arguments.
- Returns:
A tuple containing:
latents (Dict[str, torch.Tensor]) – A dictionary containing the pi_logits and vpred latent tensors.
state_out (List[torch.Tensor]) – A list of the updated state tensors.
- Return type:
Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]
- get_action(input: Dict[str, Any], state_in: List[Tensor] | None, deterministic: bool = False, input_shape: str = 'BT*', **kwargs) → Tuple[Dict[str, Tensor], List[Tensor]]
Gets an action from the policy.
This method performs a forward pass, samples an action, and handles different input shapes.
- Parameters:
input (Dict[str, Any]) – A dictionary of input tensors.
state_in (Optional[List[torch.Tensor]]) – An optional list of input state tensors.
deterministic (bool) – Whether to sample actions deterministically. Defaults to False.
input_shape (str) – The shape of the input. Can be “*” (single instance) or “BT*” (batched sequence). Defaults to “BT*”.
kwargs – Additional keyword arguments.
- Returns:
A tuple containing:
action (Dict[str, torch.Tensor]) – The sampled action.
state_out (List[torch.Tensor]) – The updated state.
- Return type:
Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]
- Raises:
NotImplementedError – if input_shape is not “*” or “BT*”.
- abstract initial_state(batch_size: int | None = None) → List[Tensor]
Abstract method to get the initial state of the policy.
Subclasses must implement this method.
- Parameters:
batch_size (Optional[int]) – The batch size for the initial state. Defaults to None.
- Returns:
A list of initial state tensors.
- Return type:
List[torch.Tensor]
- merge_input(inputs) → tensor
Abstract method to merge multiple inputs.
Subclasses should implement this if they support merging inputs for, e.g., batched inference across multiple environments.
- Parameters:
inputs – The inputs to merge.
- Returns:
The merged input tensor.
- Return type:
torch.Tensor
- Raises:
NotImplementedError – if not implemented by the subclass.
- merge_state(states) → List[Tensor] | None
Abstract method to merge multiple states.
Subclasses should implement this if they support merging states.
- Parameters:
states – The states to merge.
- Returns:
The merged state as an optional list of tensors.
- Return type:
Optional[List[torch.Tensor]]
- Raises:
NotImplementedError – if not implemented by the subclass.
- split_action(action, split_num) → List[Dict[str, Tensor]] | None
Splits a batched action into a list of individual actions.
Handles actions as dictionaries of tensors, single tensors, or lists. Converts tensors to numpy arrays on CPU.
- Parameters:
action – The batched action to split.
split_num – The number of individual actions in the batch.
- Returns:
A list of individual actions, or the original action list if already a list.
- Return type:
Optional[List[Dict[str, torch.Tensor]]]
- Raises:
NotImplementedError – if the action type is not supported.
- split_state(state, split_num) → List[List[Tensor]] | None
Abstract method to split a state into multiple states.
Subclasses should implement this if they support splitting states.
- Parameters:
state – The state to split.
split_num – The number of states to split the state into.
- Returns:
An optional list of split states.
- Return type:
Optional[List[List[torch.Tensor]]]
- Raises:
NotImplementedError – if not implemented by the subclass.
- minestudio.models.base_policy.dict_map(fn, d)
Recursively apply a function to all values in a dictionary or DictConfig.
- Parameters:
fn – The function to apply.
d – The dictionary or DictConfig.
- Returns:
A new dictionary with the function applied to its values.
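As a rough illustration of dict_map, the minimal sketch below handles only plain Python dicts (DictConfig support is omitted) and assumes that nested mappings are recursed into while every other value is treated as a leaf. The function here is a hypothetical stand-in with the same name, not the library's code.

```python
from typing import Any, Callable, Mapping


def dict_map(fn: Callable[[Any], Any], d: Mapping) -> dict:
    """Sketch: recursively apply fn to every non-mapping value of d."""
    result = {}
    for key, value in d.items():
        if isinstance(value, Mapping):
            # Nested dictionary: recurse so the structure is preserved.
            result[key] = dict_map(fn, value)
        else:
            # Leaf value: apply the function.
            result[key] = fn(value)
    return result


obs = {"image": 3, "extra": {"compass": 1.5, "voxels": 2}}
print(dict_map(lambda v: v * 2, obs))
# {'image': 6, 'extra': {'compass': 3.0, 'voxels': 4}}
```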
- minestudio.models.base_policy.recursive_tensor_op(fn, d: T) → T
Recursively apply a function to all tensors in a nested structure of lists, tuples, or dictionaries.
- Parameters:
fn – The function to apply to tensors.
d (TypeVar("T")) – The nested structure containing tensors.
- Returns:
A new nested structure with the function applied to its tensors.
- Return type:
TypeVar(“T”)
- Raises:
ValueError – if an unexpected type is encountered.
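The traversal described above can be sketched in plain Python. To keep the example runnable without PyTorch, numbers stand in for torch.Tensor leaves (an assumption of this sketch); with real tensors, fn would typically be something like lambda x: x.unsqueeze(0), which adds a leading batch dimension to every tensor in a nested recurrent state.

```python
from numbers import Number
from typing import Any, Callable


def recursive_tensor_op(fn: Callable[[Any], Any], d: Any) -> Any:
    """Sketch: apply fn to every leaf of a nested list/tuple/dict structure.

    Numbers play the role of tensor leaves here; the real helper targets tensors.
    """
    if isinstance(d, Number):
        return fn(d)
    if isinstance(d, list):
        return [recursive_tensor_op(fn, x) for x in d]
    if isinstance(d, tuple):
        return tuple(recursive_tensor_op(fn, x) for x in d)
    if isinstance(d, dict):
        return {k: recursive_tensor_op(fn, v) for k, v in d.items()}
    # Mirror the documented behaviour for unsupported types.
    raise ValueError(f"Unexpected type {type(d)}")


state = [1.0, (2.0, {"h": 3.0})]
# Wrap every leaf in a list, loosely mimicking "add a leading batch dimension".
print(recursive_tensor_op(lambda x: [x], state))
# [[1.0], ([2.0], {'h': [3.0]})]
```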
Action Head
- class minestudio.utils.vpt_lib.action_head.ActionHead(*args: Any, **kwargs: Any)
Abstract base class for action heads. Action heads are responsible for converting network outputs into action probability distributions and providing methods for sampling, calculating log probabilities, entropy, and KL divergence.
- entropy(pd_params)
Calculates the entropy of the probability distribution described by pd_params.
- Parameters:
pd_params (Any) – Parameters describing the probability distribution.
- Returns:
The entropy of the distribution.
- Return type:
torch.Tensor
- Raises:
NotImplementedError – This method must be implemented by subclasses.
- forward(input_data, **kwargs) → Any
Performs a forward pass through the action head.
- Parameters:
input_data (torch.Tensor) – The input tensor from the policy network.
**kwargs – Additional keyword arguments.
- Returns:
Parameters describing the probability distribution of actions.
- Return type:
Any
- Raises:
NotImplementedError – This method must be implemented by subclasses.
- kl_divergence(params_q, params_p)
Calculates the KL divergence KL(Q || P) between the two distributions described by params_q and params_p.
- Parameters:
params_q (Any) – Parameters of the first distribution (Q).
params_p (Any) – Parameters of the second distribution (P).
- Returns:
The KL divergence between the two distributions.
- Return type:
torch.Tensor
- Raises:
NotImplementedError – This method must be implemented by subclasses.
- logprob(action_sample, pd_params, **kwargs)
Calculates the logarithm of the probability of sampling action_sample from a probability distribution described by pd_params.
- Parameters:
action_sample (Any) – The sampled action.
pd_params (Any) – Parameters describing the probability distribution.
**kwargs – Additional keyword arguments.
- Returns:
The log probability of the action sample.
- Return type:
torch.Tensor
- Raises:
NotImplementedError – This method must be implemented by subclasses.
- sample(pd_params, deterministic: bool = False) → Any
Draws a sample from the probability distribution given by pd_params.
- Parameters:
pd_params (Any) – Parameters of a probability distribution.
deterministic (bool) – If True, returns the deterministic mode of the distribution; if False, returns a stochastic sample.
- Returns:
A sampled action.
- Return type:
Any
- Raises:
NotImplementedError – This method must be implemented by subclasses.
- class minestudio.utils.vpt_lib.action_head.CategoricalActionHead(*args: Any, **kwargs: Any)
Action head for categorical (discrete) actions. It uses a linear layer to produce logits for each action. Supports temperature scaling and nucleus sampling.
- entropy(logits: torch.Tensor) → torch.Tensor
Calculates the entropy of the categorical distribution defined by logits. Entropy = - sum(probs * log_probs). The result is summed if the action space has multiple dimensions.
- Parameters:
logits (torch.Tensor) – The log probabilities (output of the forward pass). Shape: (…, *self.output_shape)
- Returns:
The entropy of the distribution. Shape: (…)
- Return type:
torch.Tensor
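A minimal sketch of the entropy computation, using plain Python lists in place of tensors. As with the forward pass output, the logits are assumed to already be log-probabilities; log_softmax and multi_dim_entropy are illustrative helpers introduced here, not library functions.

```python
import math
from typing import List


def log_softmax(scores: List[float]) -> List[float]:
    """Turn raw scores into log-probabilities (numerically stable)."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def categorical_entropy(logits: List[float]) -> float:
    """Entropy = -sum(p_i * log p_i), with logits given as log-probabilities."""
    return -sum(math.exp(l) * l for l in logits)


def multi_dim_entropy(logits_per_dim: List[List[float]]) -> float:
    """Sum per-dimension entropies, as for a multi-dimensional action space."""
    return sum(categorical_entropy(l) for l in logits_per_dim)


uniform = log_softmax([0.0, 0.0, 0.0, 0.0])
print(round(categorical_entropy(uniform), 4))  # log(4): 1.3863
```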
- forward(input_data: torch.Tensor, mask=None, **kwargs) → Any
Computes the log probabilities (logits) for each action. Applies temperature scaling and masking if provided.
- Parameters:
input_data (torch.Tensor) – The input tensor from the policy network.
mask (Optional[torch.Tensor]) – An optional boolean mask. Logits for masked-out actions are set to a very small number (LOG0). Shape should be broadcastable to the logits shape before the num_actions dimension.
**kwargs – Additional keyword arguments.
- Returns:
Logits for each action after processing. Shape: (…, *self.output_shape)
- Return type:
torch.Tensor
- kl_divergence(logits_q: torch.Tensor, logits_p: torch.Tensor) → torch.Tensor
Calculates the KL divergence KL(Q || P) between two categorical distributions Q and P, defined by their logits. Formula: sum(exp(Q_i) * (Q_i - P_i)). The result is summed if the action space has multiple dimensions.
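The KL formula above can be checked numerically with a short sketch. Each list holds the log-probabilities of a single categorical dimension, a simplification of the tensor shapes the head works with; categorical_kl is a hypothetical helper name.

```python
import math
from typing import List


def categorical_kl(logits_q: List[float], logits_p: List[float]) -> float:
    """KL(Q || P) = sum_i exp(q_i) * (q_i - p_i), with q_i, p_i log-probabilities."""
    return sum(math.exp(q) * (q - p) for q, p in zip(logits_q, logits_p))


q = [math.log(0.5), math.log(0.5)]
p = [math.log(0.9), math.log(0.1)]
print(round(categorical_kl(q, p), 4))  # 0.5108
```

Note that KL is asymmetric: swapping the arguments gives a different value.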
- logprob(actions: torch.Tensor, logits: torch.Tensor) → torch.Tensor
Calculates the log probability of the given actions based on the logits. It gathers the log probabilities corresponding to the chosen actions and sums them if the action space has multiple dimensions (e.g., for MultiDiscrete).
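- Parameters:
actions (torch.Tensor) – The chosen actions (indices). Shape: (…, *self.output_shape[:-1])
logits (torch.Tensor) – The log probabilities for each action (output of the forward pass). Shape: (…, *self.output_shape)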
- Returns:
The sum of log probabilities for the chosen actions. Shape: (…)
- Return type:
torch.Tensor
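For a multi-dimensional (e.g. MultiDiscrete) space, the gather-and-sum step amounts to picking the log-probability of the chosen index in each dimension and adding them up. The sketch below uses ragged Python lists as a simplification of the tensor version; categorical_logprob is a hypothetical name.

```python
import math
from typing import List


def categorical_logprob(actions: List[int], logits: List[List[float]]) -> float:
    """Sum of the log-probabilities of the chosen index in each action dimension.

    actions[d] is the index chosen in dimension d; logits[d] holds that
    dimension's log-probabilities.
    """
    return sum(dim_logits[a] for a, dim_logits in zip(actions, logits))


logits = [
    [math.log(0.7), math.log(0.3)],                 # dimension 0: two choices
    [math.log(0.2), math.log(0.5), math.log(0.3)],  # dimension 1: three choices
]
lp = categorical_logprob([0, 1], logits)
print(round(math.exp(lp), 4))  # 0.7 * 0.5 = 0.35
```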
- nucleus_sample(logits: torch.Tensor, deterministic: bool = False, p: float = 0.85, **kwargs) → Any
Samples an action using nucleus (top-p) sampling: it draws only from the smallest set of most likely actions whose cumulative probability exceeds p. If deterministic is True, it instead falls back to vanilla_sample with deterministic=True.
- Parameters:
logits (torch.Tensor) – The log probabilities for each action. Shape: (…, *self.output_shape)
deterministic (bool) – If True, uses vanilla deterministic sampling.
p (float) – The cumulative probability threshold for nucleus sampling.
**kwargs – Additional keyword arguments (passed to vanilla_sample if deterministic).
- Returns:
A sampled action. Shape: (…, *self.output_shape[:-1])
- Return type:
torch.Tensor
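Top-p sampling over one categorical dimension can be sketched as follows: sort actions by probability, keep the shortest prefix whose cumulative mass exceeds p, and sample within that prefix in proportion to the original probabilities. This is a plain-Python illustration, not the library implementation; the rng argument is an addition for reproducibility.

```python
import math
import random
from typing import List, Optional


def nucleus_sample(logits: List[float], deterministic: bool = False, p: float = 0.85,
                   rng: Optional[random.Random] = None) -> int:
    """Top-p sampling for one categorical dimension given log-probabilities."""
    probs = [math.exp(l) for l in logits]
    if deterministic:
        # Deterministic path: return the most likely action.
        return max(range(len(probs)), key=probs.__getitem__)
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    nucleus, cumulative = [], 0.0
    for idx in order:
        nucleus.append(idx)
        cumulative += probs[idx]
        if cumulative > p:
            # Smallest prefix whose mass exceeds p.
            break
    rng = rng or random.Random()
    # random.choices renormalises the weights implicitly.
    return rng.choices(nucleus, weights=[probs[i] for i in nucleus], k=1)[0]


log_probs = [math.log(0.6), math.log(0.3), math.log(0.1)]
rng = random.Random(0)
draws = {nucleus_sample(log_probs, p=0.85, rng=rng) for _ in range(1000)}
print(sorted(draws))  # action 2 falls outside the 0.85 nucleus
```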
- sample(logits: torch.Tensor, deterministic: bool = False, **kwargs) → Any
Samples an action from the categorical distribution. Uses nucleus sampling if self.nucleus_prob is set, otherwise uses vanilla sampling.
- Parameters:
logits (torch.Tensor) – The log probabilities for each action. Shape: (…, *self.output_shape)
deterministic (bool) – If True, returns the most likely action. If False, returns a stochastic sample.
**kwargs – Additional keyword arguments for the specific sampling method.
- Returns:
A sampled action. Shape: (…, *self.output_shape[:-1])
- Return type:
torch.Tensor
- vanilla_sample(logits: torch.Tensor, deterministic: bool = False, **kwargs) → Any
Samples an action from the categorical distribution using the Gumbel-Max trick for stochastic sampling, or argmax for deterministic sampling. This is the original sampling method from the VPT library.
- Parameters:
logits (torch.Tensor) – The log probabilities for each action. Shape: (…, *self.output_shape)
deterministic (bool) – If True, returns the action with the highest logit (argmax). If False, returns a stochastic sample using Gumbel-Max.
**kwargs – Additional keyword arguments (not used).
- Returns:
A sampled action. Shape: (…, *self.output_shape[:-1])
- Return type:
torch.Tensor
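The Gumbel-Max trick adds independent Gumbel(0, 1) noise, -log(-log(U)) with U uniform on (0, 1), to each logit and takes the argmax; the resulting index is distributed according to softmax(logits). A plain-Python sketch (gumbel_max_sample is a hypothetical name), with empirical frequencies as a sanity check:

```python
import math
import random
from typing import List, Optional


def gumbel_max_sample(logits: List[float], deterministic: bool = False,
                      rng: Optional[random.Random] = None) -> int:
    """Draw an index distributed as softmax(logits) using the Gumbel-Max trick."""
    if deterministic:
        # Mode of the distribution: plain argmax over the logits.
        return max(range(len(logits)), key=logits.__getitem__)
    rng = rng or random.Random()
    perturbed = []
    for logit in logits:
        # Keep U strictly inside (0, 1) so both logarithms are finite.
        u = rng.uniform(1e-12, 1.0 - 1e-12)
        perturbed.append(logit - math.log(-math.log(u)))
    return max(range(len(perturbed)), key=perturbed.__getitem__)


rng = random.Random(0)
log_probs = [math.log(0.7), math.log(0.2), math.log(0.1)]
n = 20000
counts = [0, 0, 0]
for _ in range(n):
    counts[gumbel_max_sample(log_probs, rng=rng)] += 1
print([round(c / n, 2) for c in counts])  # close to [0.7, 0.2, 0.1]
```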
- class minestudio.utils.vpt_lib.action_head.DiagGaussianActionHead(*args: Any, **kwargs: Any)
Action head for normally distributed, uncorrelated continuous actions. Means are predicted by a linear layer, while standard deviations are learnable parameters.
- entropy(pd_params: torch.Tensor) → torch.Tensor
Calculates the entropy of the Gaussian distribution. For a diagonal Gaussian, entropy is 0.5 * sum(log(2 * pi * e * sigma_i^2)).
- Parameters:
pd_params (torch.Tensor) – Parameters of the Gaussian distribution (means and log_stds). Shape: (…, num_dimensions, 2)
- Returns:
The entropy of the distribution. Shape: (…)
- Return type:
torch.Tensor
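Because 0.5 * log(2 * pi * e * sigma_i^2) = 0.5 * log(2 * pi * e) + log(sigma_i), the entropy depends only on the log standard deviations. The sketch below mirrors the (…, num_dimensions, 2) layout with a list of (mean, log_std) pairs; diag_gaussian_entropy is a hypothetical name.

```python
import math
from typing import List, Tuple


def diag_gaussian_entropy(pd_params: List[Tuple[float, float]]) -> float:
    """Entropy of a diagonal Gaussian given one (mean, log_std) pair per dimension."""
    # The means are ignored: entropy is translation invariant.
    return sum(0.5 * math.log(2 * math.pi * math.e) + log_std for _, log_std in pd_params)


print(round(diag_gaussian_entropy([(0.0, 0.0)]), 4))  # standard normal: 1.4189
```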
- forward(input_data: torch.Tensor, mask=None, **kwargs) → torch.Tensor
Computes the means and log standard deviations of the Gaussian distribution.
- Parameters:
input_data (torch.Tensor) – The input tensor from the policy network.
mask (Optional[torch.Tensor]) – An optional mask (not used in this head).
**kwargs – Additional keyword arguments.
- Returns:
A tensor where the last dimension contains means and log_stds. Shape: (…, num_dimensions, 2)
- Return type:
torch.Tensor
- Raises:
AssertionError – If a mask is provided.
- kl_divergence(params_q: torch.Tensor, params_p: torch.Tensor) → torch.Tensor
Calculates the KL divergence KL(Q || P) between two diagonal Gaussian distributions Q and P. Formula: log(sigma_p/sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2) / (2 * sigma_p^2) - 0.5, summed over dimensions.
- Parameters:
params_q (torch.Tensor) – Parameters of the first Gaussian distribution Q (means_q, log_std_q). Shape: (…, num_dimensions, 2)
params_p (torch.Tensor) – Parameters of the second Gaussian distribution P (means_p, log_std_p). Shape: (…, num_dimensions, 2)
- Returns:
The KL divergence. Shape: (…, 1)
- Return type:
torch.Tensor
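The closed-form KL above can be written out term by term. This sketch uses (mean, log_std) pairs per dimension and returns a plain float rather than a (…, 1) tensor; diag_gaussian_kl is a hypothetical name.

```python
import math
from typing import List, Tuple


def diag_gaussian_kl(params_q: List[Tuple[float, float]],
                     params_p: List[Tuple[float, float]]) -> float:
    """KL(Q || P) for diagonal Gaussians given (mean, log_std) pairs per dimension."""
    total = 0.0
    for (mu_q, log_std_q), (mu_p, log_std_p) in zip(params_q, params_p):
        var_q = math.exp(2.0 * log_std_q)
        var_p = math.exp(2.0 * log_std_p)
        # log(sigma_p / sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2) / (2 sigma_p^2) - 0.5
        total += (log_std_p - log_std_q) + (var_q + (mu_q - mu_p) ** 2) / (2.0 * var_p) - 0.5
    return total


print(diag_gaussian_kl([(0.0, 0.0)], [(1.0, 0.0)]))  # unit variances, means 1 apart: 0.5
```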
- logprob(action_sample: torch.Tensor, pd_params: torch.Tensor) → torch.Tensor
Calculates the log-likelihood of action_sample given the distribution parameters. The distribution is a multivariate Gaussian with a diagonal covariance matrix.
- Parameters:
action_sample (torch.Tensor) – The sampled actions. Shape: (…, num_dimensions)
pd_params (torch.Tensor) – Parameters of the Gaussian distribution (means and log_stds). Shape: (…, num_dimensions, 2)
- Returns:
The log probability of the action samples. Shape: (…)
- Return type:
torch.Tensor
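With a diagonal covariance, the log-likelihood is a sum of independent 1-D normal log-densities, -0.5 * z^2 - log(sigma) - 0.5 * log(2 * pi) with z = (x - mu) / sigma. A plain-Python sketch (diag_gaussian_logprob is a hypothetical name):

```python
import math
from typing import List, Tuple


def diag_gaussian_logprob(action: List[float],
                          pd_params: List[Tuple[float, float]]) -> float:
    """Log-density of `action` under a diagonal Gaussian with (mean, log_std) pairs."""
    total = 0.0
    for x, (mu, log_std) in zip(action, pd_params):
        z = (x - mu) / math.exp(log_std)  # standardised residual
        total += -0.5 * z * z - log_std - 0.5 * math.log(2.0 * math.pi)
    return total


print(round(diag_gaussian_logprob([0.0], [(0.0, 0.0)]), 4))  # -0.9189
```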
- sample(pd_params: torch.Tensor, deterministic: bool = False) → torch.Tensor
Samples an action from the Gaussian distribution.
- Parameters:
pd_params (torch.Tensor) – Parameters of the Gaussian distribution (means and log_stds). Shape: (…, num_dimensions, 2)
deterministic (bool) – If True, returns the mean (mode) of the distribution. If False, returns a stochastic sample.
- Returns:
A sampled action. Shape: (…, num_dimensions)
- Return type:
torch.Tensor
- class minestudio.utils.vpt_lib.action_head.DictActionHead(*args: Any, **kwargs: Any)
An action head that combines multiple sub-action heads, where actions are structured as a dictionary. Each key-value pair in the dictionary corresponds to an action from a named sub-head. Inherits from nn.ModuleDict to manage sub-heads and ActionHead for the interface.
- entropy(logits: Dict[str, torch.Tensor], return_dict=False) → torch.Tensor | Dict[str, torch.Tensor]
Calculates the entropy for each sub-distribution. Can return a dictionary of entropies or their sum.
- Parameters:
logits (Dict[str, torch.Tensor]) – A dictionary of pd_params from each sub-head, keyed by sub-head name.
return_dict (bool) – If True, returns a dictionary of entropies. If False, returns the sum of entropies.
- Returns:
Either a sum of entropies (Tensor) or a dictionary of entropies.
- Return type:
Union[torch.Tensor, Dict[str, torch.Tensor]]
- forward(input_data: torch.Tensor, **kwargs) → Any
Passes input data through each sub-head. Allows passing specific keyword arguments to individual sub-heads based on their keys.
Example: if this ModuleDict has submodules keyed by 'A', 'B', and 'C', calling
forward(input_data, foo={'A': True, 'C': False}, bar={'A': 7})
calls the children as:
A: subhead_A(input_data, foo=True, bar=7)
B: subhead_B(input_data)
C: subhead_C(input_data, foo=False)
- Parameters:
input_data (torch.Tensor) – The input tensor from the policy network.
**kwargs – Keyword arguments. If a kwarg’s value is a dictionary, its items are passed to sub-heads matching the keys.
- Returns:
A dictionary where keys are sub-head names and values are their outputs (pd_params).
- Return type:
Dict[str, Any]
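The keyword routing in the example above can be sketched as a small helper that, for each sub-head name, collects the kwargs whose dictionary value has an entry for that name. route_kwargs is hypothetical, and ignoring non-dict kwarg values is an assumption of this sketch.

```python
from typing import Any, Dict, Iterable


def route_kwargs(head_names: Iterable[str], **kwargs: Any) -> Dict[str, Dict[str, Any]]:
    """Map each sub-head name to the keyword arguments it should receive."""
    routed = {}
    for name in head_names:
        routed[name] = {
            kwarg_name: value[name]
            for kwarg_name, value in kwargs.items()
            # Only dict-valued kwargs with an entry for this head are forwarded.
            if isinstance(value, dict) and name in value
        }
    return routed


print(route_kwargs(["A", "B", "C"], foo={"A": True, "C": False}, bar={"A": 7}))
# {'A': {'foo': True, 'bar': 7}, 'B': {}, 'C': {'foo': False}}
```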
- kl_divergence(logits_q: Dict[str, torch.Tensor], logits_p: Dict[str, torch.Tensor]) → torch.Tensor
Calculates the KL divergence for each pair of sub-distributions (Q_k || P_k) and returns their sum.
- Parameters:
logits_q (Dict[str, torch.Tensor]) – A dictionary of parameters for the first set of distributions (Q).
logits_p (Dict[str, torch.Tensor]) – A dictionary of parameters for the second set of distributions (P).
- Returns:
The sum of KL divergences from all sub-heads.
- Return type:
torch.Tensor
- logprob(actions: Dict[str, torch.Tensor], logits: Dict[str, torch.Tensor], return_dict=False) → torch.Tensor | Dict[str, torch.Tensor]
Calculates log probabilities for actions from each sub-head. Can return a dictionary of log probabilities or their sum.
- Parameters:
actions (Dict[str, torch.Tensor]) – A dictionary of sampled actions, keyed by sub-head name.
logits (Dict[str, torch.Tensor]) – A dictionary of pd_params from each sub-head, keyed by sub-head name.
return_dict (bool) – If True, returns a dictionary of log probabilities. If False, returns the sum of log probabilities.
- Returns:
Either a sum of log probabilities (Tensor) or a dictionary of log probabilities.
- Return type:
Union[torch.Tensor, Dict[str, torch.Tensor]]
- sample(logits: Dict[str, torch.Tensor], deterministic: bool = False) → Any
Samples an action from each sub-head and returns a dictionary of these actions.
- Parameters:
logits (Dict[str, torch.Tensor]) – A dictionary of pd_params from each sub-head, keyed by sub-head name.
deterministic (bool) – Whether to perform deterministic sampling for each sub-head.
- Returns:
A dictionary of sampled actions, keyed by sub-head name.
- Return type:
Dict[str, Any]
- class minestudio.utils.vpt_lib.action_head.MSEActionHead(*args: Any, **kwargs: Any)
Action head for continuous actions where the loss is Mean Squared Error (MSE) between the predicted actions (means) and the target actions. This head essentially predicts the mean of a distribution with fixed, infinitesimal variance.
- entropy(pd_params: torch.Tensor) → torch.Tensor
Returns zero entropy, as this head represents a deterministic prediction (delta distribution).
- Parameters:
pd_params (torch.Tensor) – The predicted mean actions.
- Returns:
A tensor of zeros with the same batch shape as pd_params. Shape: (…)
- Return type:
torch.Tensor
- forward(input_data: torch.Tensor, mask=None, **kwargs) → torch.Tensor
Computes the predicted mean actions using a linear layer.
- Parameters:
input_data (torch.Tensor) – The input tensor from the policy network.
mask (Optional[torch.Tensor]) – An optional mask (not used in this head).
**kwargs – Additional keyword arguments.
- Returns:
The predicted mean actions. Shape: (…, num_dimensions)
- Return type:
torch.Tensor
- Raises:
AssertionError – If a mask is provided.
- kl_divergence(params_q: torch.Tensor, params_p: torch.Tensor) → torch.Tensor
KL divergence is not well defined in general for this action head, because it represents a delta distribution.
- Parameters:
params_q (torch.Tensor) – Parameters of the first distribution.
params_p (torch.Tensor) – Parameters of the second distribution.
- Raises:
NotImplementedError – This method is not implemented.
- logprob(action_sample: torch.Tensor, pd_params: torch.Tensor) → torch.Tensor
Calculates a pseudo log-probability, which is the negative squared error. This is not a true log-probability but is used for compatibility in some RL frameworks.
- Parameters:
action_sample (torch.Tensor) – The target actions. Shape: (…, num_dimensions)
pd_params (torch.Tensor) – The predicted mean actions (output of the forward pass). Shape: (…, num_dimensions)
- Returns:
The negative sum of squared errors. Shape: (…)
- Return type:
torch.Tensor
- sample(pd_params: torch.Tensor, deterministic: bool = False, **kwargs) → torch.Tensor
Returns the predicted mean actions, as this head is deterministic.
- Parameters:
pd_params (torch.Tensor) – The predicted mean actions (output of the forward pass). Shape: (…, num_dimensions)
deterministic (bool) – Ignored, as sampling is always deterministic.
**kwargs – Additional keyword arguments (not used).
- Returns:
The predicted mean actions. Shape: (…, num_dimensions)
- Return type:
torch.Tensor
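The MSE head's pseudo log-probability and its sampling rule are simple enough to sketch directly; both helper names are hypothetical. Maximising the pseudo log-probability is the same as minimising the squared error.

```python
from typing import List


def mse_pseudo_logprob(action_sample: List[float], pred_mean: List[float]) -> float:
    """Negative sum of squared errors; not a normalised log-density."""
    return -sum((a - m) ** 2 for a, m in zip(action_sample, pred_mean))


def mse_sample(pred_mean: List[float], deterministic: bool = False) -> List[float]:
    """Always return the prediction itself; `deterministic` is ignored."""
    return list(pred_mean)


print(mse_pseudo_logprob([0.0, 0.0], [1.0, 2.0]))  # -5.0
```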
- class minestudio.utils.vpt_lib.action_head.TupleActionHead(*args: Any, **kwargs: Any)
An action head that combines multiple sub-action heads, where actions are structured as a tuple. Each element of the tuple corresponds to an action from a sub-head. Inherits from nn.ModuleList to manage sub-heads and ActionHead for the interface.
- entropy(logits: Tuple[torch.Tensor]) → torch.Tensor
Calculates the entropy for each sub-distribution and returns a tuple of entropies.
- Parameters:
logits (Tuple[torch.Tensor, ...]) – A tuple of probability distribution parameters from each sub-head.
- Returns:
A tuple of entropies, one for each sub-distribution.
- Return type:
Tuple[torch.Tensor, …]
- forward(input_data: torch.Tensor, **kwargs) → Any
Passes the input data through each sub-head and returns a tuple of their outputs.
- Parameters:
input_data (torch.Tensor) – The input tensor from the policy network.
**kwargs – Additional keyword arguments (passed to each sub-head).
- Returns:
A tuple where each element is the output (pd_params) of a sub-head.
- Return type:
Tuple[Any, …]
- kl_divergence(logits_q: Tuple[torch.Tensor], logits_p: Tuple[torch.Tensor]) → torch.Tensor
Calculates the KL divergence for each pair of sub-distributions (Q_k || P_k) and returns their sum.
- Parameters:
logits_q (Tuple[torch.Tensor, ...]) – A tuple of parameters for the first set of distributions (Q).
logits_p (Tuple[torch.Tensor, ...]) – A tuple of parameters for the second set of distributions (P).
- Returns:
The sum of KL divergences from all sub-heads.
- Return type:
torch.Tensor
- logprob(actions: Tuple[torch.Tensor], logits: Tuple[torch.Tensor]) → torch.Tensor
Calculates the log probability for each action in the tuple using the corresponding sub-head and its logits. Returns a tuple of log probabilities.
- Parameters:
actions (Tuple[torch.Tensor, ...]) – A tuple of sampled actions, one for each sub-head.
logits (Tuple[torch.Tensor, ...]) – A tuple of probability distribution parameters (e.g., logits) from each sub-head.
- Returns:
A tuple of log probabilities, one for each sub-action.
- Return type:
Tuple[torch.Tensor, …]
- sample(logits: Tuple[torch.Tensor], deterministic: bool = False) → Any
Samples an action from each sub-head and returns a tuple of these actions.
- Parameters:
logits (Tuple[torch.Tensor, ...]) – A tuple of probability distribution parameters from each sub-head.
deterministic (bool) – Whether to perform deterministic sampling for each sub-head.
- Returns:
A tuple of sampled actions.
- Return type:
Tuple[Any, …]
- minestudio.utils.vpt_lib.action_head.fan_in_linear(module: torch.nn.Module, scale=1.0, bias=True)
Initializes the weights of a linear module using fan-in initialization: each row of the weight matrix (the incoming weights of one output unit) is multiplied by scale / norm, where norm is that row's L2 norm, so every row ends up with L2 norm equal to scale. If bias is True, the biases are set to zero.
- Parameters:
module (nn.Module) – The linear module to initialize.
scale (float) – The scaling factor for the weights.
bias (bool) – Whether to initialize biases to zero.
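The row-wise rescaling can be illustrated on a weight matrix stored as nested lists (rows are output units, as in nn.Linear's (out_features, in_features) layout). fan_in_rows is a hypothetical helper; in PyTorch the same rescaling could be expressed with weight.norm(dim=1, p=2, keepdim=True). The sketch assumes no row is all zeros.

```python
import math
from typing import List


def fan_in_rows(weight: List[List[float]], scale: float = 1.0) -> List[List[float]]:
    """Rescale each output row so that its L2 norm equals `scale`."""
    scaled = []
    for row in weight:
        norm = math.sqrt(sum(w * w for w in row))  # fan-in norm of this unit
        scaled.append([w * scale / norm for w in row])
    return scaled


w = [[3.0, 4.0], [1.0, 0.0]]
print(fan_in_rows(w, scale=0.5))
# [[0.3, 0.4], [0.5, 0.0]]
```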
- minestudio.utils.vpt_lib.action_head.make_action_head(ac_space: gym3.types.ValType, pi_out_size: int, temperature: float = 1.0, **kwargs)
Helper function to create an appropriate action head based on the action space type. Supports gymnasium.spaces and some gym.spaces.
- Parameters:
ac_space (Union[gymnasium.spaces.Space, gym.spaces.Space, ValType]) – The action space of the environment.
pi_out_size (int) – The output size of the policy network feature extractor.
temperature (float) – Temperature for categorical action heads.
**kwargs – Additional keyword arguments to pass to the action head constructor.
- Returns:
An initialized action head.
- Return type:
ActionHead
- Raises:
NotImplementedError – If the action space type is not supported.
Value Head
- class minestudio.utils.vpt_lib.scaled_mse_head.ScaledMSEHead(input_size: int, output_size: int, norm_type: str | None = 'ewma', norm_kwargs: Dict | None = None)
A linear output layer whose training targets are normalized to have (approximately) a mean of 0 and a standard deviation of 1. The targets are passed through an exponentially weighted moving average (EWMA) normalizer before the Mean Squared Error (MSE) loss is computed. The layer's predictions therefore live in the normalized space; use denormalize to map them back to the original scale.
- denormalize(input_data)
Converts an input value from the normalized space back into the original space using the inverse operation of the internal normalizer.
- Parameters:
input_data (torch.Tensor) – The data in the normalized space.
- Returns:
The data in the original space.
- Return type:
torch.Tensor
- forward(input_data)
Performs a forward pass through the linear layer.
- Parameters:
input_data (torch.Tensor) – The input tensor.
- Returns:
The output of the linear layer (predictions in the normalized space).
- Return type:
torch.Tensor
- loss(prediction, target, reduction='mean')
Calculates the Mean Squared Error (MSE) loss between the prediction and the target. The target is first normalized using the internal EWMA normalizer. The loss is computed in this normalized space.
- Parameters:
prediction (torch.Tensor) – The predicted output from the forward pass (in the normalized space).
target (torch.Tensor) – The target values (in original space).
reduction (str) – Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied, ‘mean’: the sum of the output will be divided by the number of elements in the output, ‘sum’: the output will be summed.
- Returns:
The MSE loss.
- Return type:
torch.Tensor
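To make the normalize-then-MSE idea concrete, here is a scalar sketch of a debiased EWMA normalizer and the resulting loss. Everything in it is an assumption for illustration: the class and function names, beta=0.99, the debiasing term, and the variance floor min_var are not taken from the library.

```python
import math
from typing import List, Tuple


class EwmaNormalizer:
    """Hypothetical scalar running normaliser with debiased EWMA statistics."""

    def __init__(self, beta: float = 0.99, eps: float = 1e-5, min_var: float = 1e-2):
        self.beta, self.eps, self.min_var = beta, eps, min_var
        self.mean = 0.0     # EWMA of x
        self.mean_sq = 0.0  # EWMA of x ** 2
        self.debias = 0.0   # EWMA of 1, corrects for the zero initialisation

    def update(self, batch: List[float]) -> None:
        b_mean = sum(batch) / len(batch)
        b_mean_sq = sum(x * x for x in batch) / len(batch)
        self.mean = self.beta * self.mean + (1 - self.beta) * b_mean
        self.mean_sq = self.beta * self.mean_sq + (1 - self.beta) * b_mean_sq
        self.debias = self.beta * self.debias + (1 - self.beta)

    def stats(self) -> Tuple[float, float]:
        d = max(self.debias, self.eps)
        mean = self.mean / d
        var = max(self.mean_sq / d - mean ** 2, self.min_var)  # floor the variance
        return mean, var

    def normalize(self, x: float) -> float:
        mean, var = self.stats()
        return (x - mean) / math.sqrt(var)

    def denormalize(self, y: float) -> float:
        mean, var = self.stats()
        return y * math.sqrt(var) + mean


def scaled_mse_loss(predictions: List[float], targets: List[float],
                    norm: EwmaNormalizer) -> float:
    """Mean squared error between predictions and normalised targets."""
    norm.update(targets)  # refresh running statistics with the new targets
    errs = [(p - norm.normalize(t)) ** 2 for p, t in zip(predictions, targets)]
    return sum(errs) / len(errs)


norm = EwmaNormalizer()
print(scaled_mse_loss([-1.0, 1.0], [10.0, 20.0], norm))  # tiny: targets map to -1 and +1
```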
GrootOne
Date: 2024-11-25 07:03:41
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
LastEditTime: 2025-01-07 14:21:06
FilePath: /MineStudio/minestudio/models/groot_one/body.py
- class minestudio.models.groot_one.body.Decoder(hiddim: int, num_heads: int = 8, num_layers: int = 4, timesteps: int = 128, mem_len: int = 128)
Decodes a sequence of latent vectors using a recurrent Transformer architecture.
This module is typically used to generate sequences for policy and value estimation.
- Parameters:
hiddim (int) – The hidden dimension of the model.
num_heads (int) – Number of attention heads in the recurrent Transformer blocks. Defaults to 8.
num_layers (int) – Number of recurrent Transformer blocks. Defaults to 4.
timesteps (int) – The number of timesteps the recurrent model processes at once. Defaults to 128.
mem_len (int) – The length of the memory used by the causal attention mechanism. Defaults to 128.
- forward(x: Tensor, memory: List) → Tuple[Tensor, List]
Forward pass of the Decoder.
Processes the input sequence x using the recurrent Transformer blocks, updating the memory (recurrent state).
- Parameters:
x (torch.Tensor) – Input tensor of shape (b, t, c), where b=batch_size, t=sequence_length, c=features.
memory (List[torch.Tensor]) – The recurrent state from the previous step. If None, an initial state is created.
- Returns:
A tuple containing:
x (torch.Tensor) – The output tensor of shape (b, t, c).
memory (List[torch.Tensor]) – The updated recurrent state.
- Return type:
Tuple[torch.Tensor, List[torch.Tensor]]
- initial_state(batch_size: int | None = None) → List[Tensor]
Returns the initial recurrent state for the decoder.
- Parameters:
batch_size (Optional[int]) – The batch size for the initial state. If None, returns state for batch_size=1. Defaults to None.
- Returns:
A list of tensors representing the initial recurrent state, moved to the model’s device.
- Return type:
List[torch.Tensor]
- class minestudio.models.groot_one.body.GrootPolicy(*args: Any, **kwargs: Any)
GrootPolicy model for Minecraft, combining visual encoders and a recurrent decoder.
This policy uses a pre-trained backbone (e.g., EfficientNet, ViT) to extract features from images. It has separate encoders for video sequences (reference trajectory) and single images (current observation). The features are fused and then processed by a recurrent decoder to produce policy and value outputs.
- Parameters:
backbone (str) – Name of the timm model to use as a backbone (e.g., ‘efficientnet_b0.ra_in1k’).
freeze_backbone (bool) – Whether to freeze the weights of the pre-trained backbone. Defaults to True.
hiddim (int) – Hidden dimension for the policy network. Defaults to 1024.
video_encoder_kwargs (Dict) – Keyword arguments for the VideoEncoder. Defaults to {}.
image_encoder_kwargs (Dict) – Keyword arguments for the ImageEncoder. Defaults to {}.
decoder_kwargs (Dict) – Keyword arguments for the Decoder. Defaults to {}.
action_space (Optional[Any]) – The action space definition. Passed to MinePolicy.
- encode_video(ref_video_path: str, resolution: Tuple[int, int] = (224, 224)) → Dict
Encodes a reference video from a file path into prior and posterior latent distributions.
Reads a video file, extracts frames, preprocesses them using the backbone, and then uses the VideoEncoder and ImageEncoder to get latent distributions.
- Parameters:
ref_video_path (str) – Path to the reference video file.
resolution (Tuple[int, int]) – Target resolution (width, height) to reformat video frames. Defaults to (224, 224).
- Returns:
A dictionary containing:
‘posterior_dist’ (Dict) – Latent distribution from the VideoEncoder.
‘prior_dist’ (Dict) – Latent distribution from the ImageEncoder (computed from the first frame).
- Return type:
Dict[str, Dict[str, torch.Tensor]]
- forward(input: Dict, memory: List[Tensor] | None = None) → Dict
Forward pass of the GrootPolicy.
Processes the current image observation. If a ref_video_path is provided in the input (inference mode), it encodes the reference video (or uses a cached encoding) to get a condition z. If not (training mode), z is derived from the current batch of images. The image features and z are fused and passed to the decoder to get policy and value outputs.
- Parameters:
input (Dict) – A dictionary of inputs, expected to contain:
‘image’ (torch.Tensor) – Current image observations of shape (b, t, h, w, c).
‘ref_video_path’ (Optional[str] or Optional[List[str]]) – Path to a reference video used for conditioning (inference).
memory (Optional[List[torch.Tensor]]) – The recurrent state for the decoder. Defaults to None (initial state will be used).
- Returns:
A tuple containing:
latents (Dict) – A dictionary with ‘pi_logits’, ‘vpred’, ‘posterior_dist’, and ‘prior_dist’.
memory (List[torch.Tensor]) – The updated recurrent state from the decoder.
- Return type:
Tuple[Dict[str, Any], List[torch.Tensor]]
- initial_state(*args, **kwargs) → Any
Returns the initial recurrent state for the policy (from the decoder).
- Parameters:
args – Positional arguments passed to the decoder’s initial_state method.
kwargs – Keyword arguments passed to the decoder’s initial_state method.
- Returns:
The initial recurrent state.
- Return type:
Any
- class minestudio.models.groot_one.body.ImageEncoder(hiddim: int, num_layers: int = 2, num_heads: int = 8, dropout: float = 0.1)[source]#
Encodes a single image into a latent distribution.
Uses a Transformer encoder for spatial pooling, followed by a LatentSpace module.
- Parameters:
hiddim (int) – The hidden dimension for the model.
num_layers (int) – Number of Transformer encoder layers for pooling. Defaults to 2.
num_heads (int) – Number of attention heads in Transformer layers. Defaults to 8.
dropout (float) – Dropout rate in Transformer layers. Defaults to 0.1.
- forward(image: Tensor) Dict [source]#
Encodes a batch of images.
- Parameters:
image (torch.Tensor) – A tensor of images with shape (b, c, h, w), where b=batch_size, c=channels, h=height, w=width.
- Returns:
A dictionary representing the latent distribution from LatentSpace, containing ‘mu’, ‘log_var’, and ‘z’.
- Return type:
Dict[str, torch.Tensor]
- class minestudio.models.groot_one.body.LatentSpace(hiddim: int)[source]#
A module for creating a latent space with mean and log variance.
This module takes an input tensor, projects it to a mean (mu) and a log variance (log_var), and then samples from the resulting Gaussian distribution during training. During evaluation, it returns the mean.
- Parameters:
hiddim (int) – The hidden dimension of the input and latent space.
- forward(x: Tensor) Tensor [source]#
Forward pass to compute latent variable z, its mean mu, and log variance log_var.
During training, z is sampled from the distribution. During evaluation, z is mu.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
A dictionary containing:
‘mu’ (torch.Tensor): The mean of the latent distribution.
‘log_var’ (torch.Tensor): The log variance of the latent distribution.
‘z’ (torch.Tensor): The sampled (training) or mean (evaluation) latent variable.
- Return type:
Dict[str, torch.Tensor]
- sample(mu: Tensor, log_var: Tensor) Tensor [source]#
Samples from a Gaussian distribution defined by mu and log_var.
- Parameters:
mu (torch.Tensor) – The mean of the Gaussian distribution.
log_var (torch.Tensor) – The logarithm of the variance of the Gaussian distribution.
- Returns:
A tensor sampled from the N(mu, exp(log_var)) distribution.
- Return type:
torch.Tensor
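The sampling described for LatentSpace.sample matches the standard reparameterization trick: z = mu + exp(0.5 * log_var) * eps, with eps ~ N(0, 1). The sketch below restates it in plain Python over lists of floats, while the real module operates on torch tensors. `latent_forward` and its `training` flag are hypothetical stand-ins for the module's train/eval mode.

```python
import math
import random
from typing import Dict, List, Optional


def sample(mu: List[float], log_var: List[float],
           rng: Optional[random.Random] = None) -> List[float]:
    """Reparameterization trick: z = mu + exp(0.5 * log_var) * eps, eps ~ N(0, 1)."""
    rng = rng or random.Random()
    # exp(0.5 * log_var) is the standard deviation, since variance = exp(log_var).
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, log_var)]


def latent_forward(mu: List[float], log_var: List[float], training: bool,
                   rng: Optional[random.Random] = None) -> Dict[str, List[float]]:
    """Hypothetical mirror of the documented behaviour: sample in training, use mu in eval."""
    z = sample(mu, log_var, rng) if training else list(mu)
    return {"mu": mu, "log_var": log_var, "z": z}
```

Because the randomness is isolated in eps, gradients can flow through mu and log_var. This is why the trick is used during training.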
- class minestudio.models.groot_one.body.VideoEncoder(hiddim: int, num_spatial_layers: int = 2, num_temporal_layers: int = 2, num_heads: int = 8, dropout: float = 0.1)[source]#
Encodes a sequence of video frames into a latent distribution.
It uses Transformer encoders for spatial pooling within frames and temporal encoding across frames, followed by a LatentSpace module to get a distribution.
- Parameters:
hiddim (int) – The hidden dimension for the model.
num_spatial_layers (int) – Number of Transformer encoder layers for spatial pooling. Defaults to 2.
num_temporal_layers (int) – Number of Transformer encoder layers for temporal encoding. Defaults to 2.
num_heads (int) – Number of attention heads in Transformer layers. Defaults to 8.
dropout (float) – Dropout rate in Transformer layers. Defaults to 0.1.
- forward(images: Tensor) Dict [source]#
Encodes a batch of video frames.
- Parameters:
images (torch.Tensor) – A tensor of video frames with shape (b, t, c, h, w), where b=batch_size, t=time_steps, c=channels, h=height, w=width.
- Returns:
A dictionary representing the latent distribution from LatentSpace, containing ‘mu’, ‘log_var’, and ‘z’.
- Return type:
Dict[str, torch.Tensor]
- minestudio.models.groot_one.body.load_groot_policy(ckpt_path: str | None = None)[source]#
Loads a GrootPolicy model.
If ckpt_path is provided, it loads the model from the checkpoint. Otherwise, it loads a pre-trained model from Hugging Face Hub.
- Parameters:
ckpt_path (Optional[str]) – Path to a .ckpt model checkpoint file. Defaults to None.
- Returns:
The loaded GrootPolicy model.
- Return type:
GrootPolicy
- minestudio.models.groot_one.body.peel_off(item: List | str) str [source]#
Recursively extracts a string from a potentially nested list of strings.
If the item is a list, it calls itself with the first element of the list. If the item is a string, it returns the string.
- Parameters:
item (Union[List, str]) – The item to peel, which can be a string or a list containing strings or other lists.
- Returns:
The innermost string found.
- Return type:
str
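The recursion described above fits in a few lines. This is a sketch written from the description, not the library source. In this version, an empty list raises IndexError.

```python
from typing import List, Union


def peel_off(item: Union[List, str]) -> str:
    """Recursively follow the first element of nested lists until a string is reached."""
    if isinstance(item, list):
        return peel_off(item[0])  # raises IndexError on an empty list
    return item
```

For example, `peel_off([["a.mp4", "b.mp4"], "c.mp4"])` returns `"a.mp4"`. Only the first branch at each level is ever inspected.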
RocketOne#
Date: 2024-11-10 15:52:16 LastEditors: caishaofei-mus1 1744260356@qq.com LastEditTime: 2025-01-15 17:08:36 FilePath: /MineStudio/minestudio/models/rocket_one/body.py
- class minestudio.models.rocket_one.body.RocketPolicy(*args: Any, **kwargs: Any)[source]#
RocketPolicy model for Minecraft, using a Vision Transformer (ViT) backbone.
This policy processes an RGB image concatenated with an object mask. It uses a pre-trained ViT backbone for feature extraction, followed by Transformer-based pooling and recurrent blocks for temporal processing. It also incorporates an embedding for interaction types.
- Parameters:
backbone (str) – Name of the timm model to use as the backbone. Defaults to ‘timm/vit_base_patch16_224.dino’.
hiddim (int) – Hidden dimension for the policy network. Defaults to 1024.
num_heads (int) – Number of attention heads in Transformer layers. Defaults to 8.
num_layers (int) – Number of recurrent Transformer blocks. Defaults to 4.
timesteps (int) – The number of timesteps the recurrent model processes at once. Defaults to 128.
mem_len (int) – The length of the memory used by the causal attention mechanism in recurrent blocks. Defaults to 128.
action_space (Optional[Any]) – The action space definition. Passed to MinePolicy. Defaults to None.
nucleus_prob (float) – Nucleus probability for sampling actions. Passed to MinePolicy. Defaults to 0.85.
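`nucleus_prob` refers to nucleus (top-p) sampling. The sketch below is a generic illustration of that technique, not MinePolicy's actual implementation, which works on the policy head's logits. It keeps the smallest set of most-probable outcomes whose cumulative probability reaches p, renormalizes within that set, and draws one index. `nucleus_sample` is a hypothetical name.

```python
import math
import random
from typing import List, Optional


def nucleus_sample(logits: List[float], p: float, temperature: float = 1.0,
                   rng: Optional[random.Random] = None) -> int:
    """Top-p sampling over a single categorical distribution given as logits."""
    rng = rng or random.Random()
    # Temperature-scaled softmax (subtract the max for numerical stability).
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the most probable indices until their cumulative mass reaches p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept: List[int] = []
    mass = 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= p:
            break
    # random.choices renormalizes the weights of the kept indices internally.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
```

Smaller values of p make sampling more conservative. With p close to 0, only the single most likely action survives.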
- forward(input: Dict, memory: List[Tensor] | None = None) Dict [source]#
Forward pass of the RocketPolicy.
Processes the input image and segmentation mask, extracts features, and passes them through recurrent layers to produce policy and value predictions.
The input dictionary is expected to contain:
‘image’: (b, t, h, w, c) tensor of RGB images.
‘segment’ or ‘segmentation’: Dictionary containing:
  ‘obj_mask’: (b, t, h, w) tensor of object masks.
  ‘obj_id’: (b, t) tensor of object interaction type IDs.
- Parameters:
input (Dict) – Dictionary of input tensors.
memory (Optional[List[torch.Tensor]]) – Optional list of recurrent state tensors. If None, an initial state is used.
- Returns:
A tuple containing:
latents (Dict): Dictionary with ‘pi_logits’ and ‘vpred’.
memory (List[torch.Tensor]): Updated list of recurrent state tensors.
- Return type:
Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]
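Before calling forward, it can help to check that the input follows the layout listed above. This is a minimal shape-only sketch. `check_rocket_input` is hypothetical and accepts anything with a `.shape` attribute, such as torch tensors or NumPy arrays. The returned (b, t, h, w, c + 1) shape assumes the mask is appended as one extra channel in channel-last layout, which this page does not state.

```python
from typing import Any, Mapping, Tuple


def check_rocket_input(batch: Mapping[str, Any]) -> Tuple[int, int, int, int, int]:
    """Hypothetical shape check for a RocketPolicy-style input dict."""
    b, t, h, w, c = batch["image"].shape
    # The segmentation dict may be stored under either documented key name.
    seg = batch.get("segment", batch.get("segmentation"))
    if seg is None:
        raise KeyError("expected a 'segment' or 'segmentation' entry")
    if tuple(seg["obj_mask"].shape) != (b, t, h, w):
        raise ValueError(f"obj_mask shape {tuple(seg['obj_mask'].shape)} != {(b, t, h, w)}")
    if tuple(seg["obj_id"].shape) != (b, t):
        raise ValueError(f"obj_id shape {tuple(seg['obj_id'].shape)} != {(b, t)}")
    # Assumption: image and mask are fused as c + 1 channels (channel-last).
    return (b, t, h, w, c + 1)
```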
- initial_state(batch_size: int | None = None) List[Tensor] [source]#
Returns the initial recurrent state for the policy.
- Parameters:
batch_size (Optional[int]) – The batch size for the initial state. If None, returns state for batch_size=1. Defaults to None.
- Returns:
A list of tensors representing the initial recurrent state, moved to the model’s device.
- Return type:
List[torch.Tensor]
- minestudio.models.rocket_one.body.load_rocket_policy(ckpt_path: str | None = None)[source]#
Loads a RocketPolicy model.
If ckpt_path is provided, it loads the model from the checkpoint. Otherwise, it loads a pre-trained model from Hugging Face Hub.
- Parameters:
ckpt_path (Optional[str]) – Path to a .ckpt model checkpoint file. Defaults to None.
- Returns:
The loaded RocketPolicy model.
- Return type:
RocketPolicy
SteveOne#
- class minestudio.models.steve_one.body.ImgObsProcess(cnn_outsize: int, output_size: int, dense_init_norm_kwargs: Dict = {}, init_norm_kwargs: Dict = {}, **kwargs)[source]#
Image observation processing using ImpalaCNN followed by a linear layer.
Processes image observations through an Impala CNN architecture and then applies a linear transformation to produce the final output embeddings.
- Parameters:
cnn_outsize – Output dimension of the Impala CNN
output_size – Output size of the final linear layer
dense_init_norm_kwargs – Initialization kwargs for linear FanInInitReLULayer
init_norm_kwargs – Initialization kwargs for 2D and 3D conv FanInInitReLULayer
- class minestudio.models.steve_one.body.ImgPreprocessing(img_statistics: str | None = None, scale_img: bool = True)[source]#
Image preprocessing module for normalization and scaling.
Normalizes incoming images either with pre-computed statistics (mean/std) or by simple scaling with a constant factor.
- Parameters:
img_statistics – Path to npz file containing mean and std statistics. If specified, normalize images using these statistics.
scale_img – If True and img_statistics not specified, scale images by 1/255.
- class minestudio.models.steve_one.body.MinecraftPolicy(recurrence_type='lstm', impala_width=1, impala_chans=(16, 32, 32), obs_processing_width=256, hidsize=512, single_output=False, img_shape=None, scale_input_img=True, only_img_input=False, init_norm_kwargs={}, impala_kwargs={}, input_shape=None, active_reward_monitors=None, img_statistics=None, first_conv_norm=False, diff_mlp_embedding=False, attention_mask_style='clipped_causal', attention_heads=8, attention_memory_size=2048, use_pointwise_layer=True, pointwise_ratio=4, pointwise_use_activation=False, n_recurrence_layers=1, recurrence_is_residual=True, timesteps=None, use_pre_lstm_ln=True, mineclip_embed_dim=512, **unused_kwargs)[source]#
Neural network policy for Minecraft gameplay with multimodal inputs.
This policy combines visual and textual information processing through CNN and recurrent architectures. It supports various recurrence types including LSTM, masked LSTM, and transformer variants for sequential decision making.
The architecture processes images through ImpalaCNN, combines with MineCLIP embeddings for text conditioning, and uses recurrent layers for temporal modeling.
- Parameters:
recurrence_type – Type of recurrent architecture:
‘multi_layer_lstm’: Multi-layer LSTM (no ragged batching support).
‘multi_layer_bilstm’: Bidirectional multi-layer LSTM.
‘multi_masked_lstm’: Multi-layer LSTM with ragged batching support.
‘transformer’: Dense transformer architecture.
‘none’: No recurrence.
init_norm_kwargs – Initialization kwargs for all FanInInitReLULayers
- forward(ob, state_in, context)[source]#
Forward pass through the Minecraft policy network.
Processes multimodal observations (images and MineCLIP embeddings) through CNN and recurrent layers to produce policy and value latent representations.
- Parameters:
ob – Observation dictionary containing ‘img’ and ‘mineclip_embed’ keys
state_in – Input recurrent state from previous timestep
context – Context dictionary containing ‘first’ episode flags
- Returns:
Tuple of ((pi_latent, vf_latent), state_out), where pi_latent and vf_latent are the policy and value latent representations and state_out is the updated recurrent state
- class minestudio.models.steve_one.body.SteveOnePolicy(*args: Any, **kwargs: Any)[source]#
Complete STEVE-1 policy combining visual processing, text conditioning, and action prediction.
This policy integrates MineCLIP for multimodal understanding, a TranslatorVAE for text-to-visual translation, and a MinecraftPolicy network for decision making. It supports both text and video conditioning through classifier-free guidance.
The architecture enables text-conditioned gameplay by translating textual instructions into visual embeddings that guide the policy’s behavior in Minecraft environments.
- property device: device#
Get the device of the policy model.
- Returns:
torch.device where the model parameters are located
- forward(input: Dict[str, Any], state_in: List[Tensor] | None = None, **kwargs) Tuple[Dict[str, Tensor], List[Tensor]] [source]#
Forward pass with classifier-free guidance for conditioned generation.
Performs a forward pass through the policy network with optional classifier-free guidance. When cond_scale > 0, runs both conditioned and unconditioned inference and combines the outputs using the specified guidance scale.
- Parameters:
condition – Dictionary with ‘cond_scale’ and ‘mineclip_embeds’
input – Dictionary with ‘image’ key containing observation images
state_in – Optional list of recurrent state tensors from previous timestep
kwargs – Additional keyword arguments (unused)
- Returns:
Tuple of (latents_dict, state_out), where latents_dict contains ‘pi_logits’ and ‘vpred’ and state_out is the updated recurrent state
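The exact combination rule is not spelled out here. The sketch below assumes the form used in the STEVE-1 paper, logits = (1 + λ) · cond − λ · uncond. That form reduces to the conditioned logits at λ = 0, which fits running the unconditioned branch only when cond_scale > 0. `guided_logits` is hypothetical, per-head logits are plain lists, and the head names in the usage are illustrative.

```python
from typing import Dict, List


def guided_logits(cond: Dict[str, List[float]], uncond: Dict[str, List[float]],
                  cond_scale: float) -> Dict[str, List[float]]:
    """Classifier-free guidance, assuming the STEVE-1 form (1 + s) * cond - s * uncond."""
    if cond_scale == 0:
        # No guidance: the unconditioned pass is not needed at all.
        return {k: list(v) for k, v in cond.items()}
    return {k: [(1 + cond_scale) * c - cond_scale * u for c, u in zip(cond[k], uncond[k])]
            for k in cond}
```

Larger cond_scale values push the action distribution further away from the unconditioned policy and toward the instruction.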
- initial_state(condition: Dict[str, Any], batch_size: int | None = None) List[Tensor] [source]#
Initialize recurrent state for policy inference.
Creates initial state tensors for the recurrent layers. When classifier-free guidance is enabled (cond_scale > 0), duplicates states for both conditioned and unconditioned inference branches.
- Parameters:
condition – Conditioning dictionary with ‘cond_scale’ key
batch_size – Batch size for state initialization
- Returns:
List of initial state tensors
- prepare_condition(instruction: Dict[str, Any], deterministic: bool = False) Dict[str, Any] [source]#
Prepare conditioning information from text or video instructions.
Processes either text instructions (via TranslatorVAE) or video demonstrations (via direct MineCLIP encoding) to create conditioning embeddings for the policy.
- Parameters:
instruction – Dictionary containing either ‘text’ or ‘video’ key plus ‘cond_scale’
deterministic – Whether to use deterministic sampling for text conditioning
- Returns:
Dictionary with ‘cond_scale’ and ‘mineclip_embeds’ for policy conditioning
- Raises:
AssertionError – If instruction lacks required keys or has conflicting modalities
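The documented assertions amount to two rules: ‘cond_scale’ must be present, and exactly one of ‘text’ or ‘video’ must be given. The pre-check below is a hypothetical sketch of those rules (`check_instruction` is not part of MineStudio). The instruction values in the usage are placeholders.

```python
from typing import Any, Dict


def check_instruction(instruction: Dict[str, Any]) -> str:
    """Hypothetical pre-check mirroring the documented assertions.

    Returns the modality ('text' or 'video') that would be used for conditioning.
    """
    assert "cond_scale" in instruction, "instruction must include 'cond_scale'"
    has_text, has_video = "text" in instruction, "video" in instruction
    # Conflicting (both) or missing (neither) modalities are rejected.
    assert has_text != has_video, "provide exactly one of 'text' or 'video'"
    return "text" if has_text else "video"
```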
- class minestudio.models.steve_one.body.TranslatorVAE(input_dim=512, hidden_dim=256, latent_dim=256)[source]#
Variational Autoencoder for translating between text and visual embeddings.
This VAE learns to map text embeddings to visual embeddings through a latent space, enabling text-to-visual translation for multimodal learning tasks. The encoder takes concatenated visual and text embeddings to produce latent representations, while the decoder reconstructs visual embeddings from latent codes and text.
- decode(latent_vector, text_embeddings)[source]#
Decode latent vector and text embeddings into visual embeddings.
- Parameters:
latent_vector – Latent representation tensor of shape (B, latent_dim)
text_embeddings – Text embedding tensor of shape (B, input_dim)
- Returns:
Reconstructed visual embeddings of shape (B, input_dim)
- encode(visual_embeddings, text_embeddings)[source]#
Encode concatenated visual and text embeddings into latent parameters.
- Parameters:
visual_embeddings – Visual embedding tensor of shape (B, input_dim)
text_embeddings – Text embedding tensor of shape (B, input_dim)
- Returns:
Encoded tensor of shape (B, 2*latent_dim) holding the mean and log variance
- forward(text_embeddings, deterministic=False)[source]#
Generate visual embeddings from text embeddings using prior distribution.
Uses a zero-mean, unit-variance prior to sample latent vectors and decode them into visual embeddings conditioned on the input text embeddings.
- Parameters:
text_embeddings – Input text embeddings of shape (B, input_dim)
deterministic – If True, use mean of prior; if False, sample from prior
- Returns:
Generated visual embeddings of shape (B, input_dim)
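Generation from the prior can be sketched with plain lists in place of tensors. `prior_generate` is hypothetical and takes the decoder as a callable. In TranslatorVAE, that role is played by decode(latent_vector, text_embeddings). The toy decoder in the usage is purely illustrative.

```python
import random
from typing import Callable, List, Optional

Vector = List[float]


def prior_generate(text_embeddings: List[Vector], latent_dim: int,
                   decode: Callable[[List[Vector], List[Vector]], List[Vector]],
                   deterministic: bool = False,
                   rng: Optional[random.Random] = None) -> List[Vector]:
    """Sketch: draw z from N(0, I) (or use its mean, 0), then decode(z, text)."""
    rng = rng or random.Random()
    batch = len(text_embeddings)
    if deterministic:
        latent = [[0.0] * latent_dim for _ in range(batch)]  # mean of the prior
    else:
        latent = [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(batch)]
    # Assumption: decode takes (latent batch, text batch) and returns one vector per item.
    return decode(latent, text_embeddings)
```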
- sample(mu, logvar)[source]#
Sample a latent vector using the reparameterization trick.
Applies the reparameterization trick to sample from a Gaussian distribution defined by mu and logvar, enabling backpropagation through stochastic sampling.
- Parameters:
mu – Mean tensor of shape (B, latent_dim)
logvar – Log variance tensor of shape (B, latent_dim)
- Returns:
Sampled latent vector of shape (B, latent_dim)
- minestudio.models.steve_one.body.load_steve_one_policy(ckpt_path: str) SteveOnePolicy [source]#
Load a pre-trained STEVE-1 policy from a checkpoint path.
Convenience function to load a STEVE-1 policy using the HuggingFace Hub model loading mechanism.
- Parameters:
ckpt_path – Path or HuggingFace model identifier for the checkpoint
- Returns:
Loaded SteveOnePolicy instance
VPT#
Date: 2024-11-11 20:54:15 LastEditors: muzhancun muzhancun@stu.pku.edu.cn LastEditTime: 2025-05-28 14:29:56 FilePath: /MineStudio/minestudio/models/vpt/body.py
- class minestudio.models.vpt.body.ImgObsProcess(cnn_outsize: int, output_size: int, dense_init_norm_kwargs: Dict = {}, init_norm_kwargs: Dict = {}, **kwargs)[source]#
ImpalaCNN followed by a linear layer.
- Parameters:
cnn_outsize – impala output dimension
output_size – output size of the linear layer.
dense_init_norm_kwargs – kwargs for linear FanInInitReLULayer
init_norm_kwargs – kwargs for 2d and 3d conv FanInInitReLULayer
- class minestudio.models.vpt.body.ImgPreprocessing(img_statistics: str | None = None, scale_img: bool = True)[source]#
Normalize incoming images.
- Parameters:
img_statistics – remote path to an npz file with a mean and std image. If specified, normalize images using these statistics.
scale_img – If True and img_statistics is not specified, scale incoming images by 1/255.
- forward(img)[source]#
Apply image preprocessing.
Normalizes the input image tensor. If img_statistics was provided during initialization, it uses the mean and std from the file. Otherwise, it scales the image by 1.0 / self.ob_scale.
- Parameters:
img (torch.Tensor) – The input image tensor.
- Returns:
The preprocessed image tensor.
- Return type:
torch.Tensor
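The two normalization paths can be sketched element-wise over a flattened image. `preprocess` is hypothetical. It assumes the stored mean and std images have the same flattened layout as the input. It also assumes ob_scale is 255 when scale_img is True, as the class description implies, and it would presumably be 1 otherwise.

```python
from typing import List, Optional


def preprocess(img: List[float], mean: Optional[List[float]] = None,
               std: Optional[List[float]] = None, ob_scale: float = 255.0) -> List[float]:
    """Normalize with per-element mean/std if given, else scale by 1 / ob_scale."""
    if mean is not None and std is not None:
        return [(x - m) / s for x, m, s in zip(img, mean, std)]
    # Dividing by ob_scale equals multiplying by 1.0 / ob_scale up to float rounding.
    return [x / ob_scale for x in img]
```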
- class minestudio.models.vpt.body.MinecraftPolicy(recurrence_type='transformer', impala_width=1, impala_chans=(16, 32, 32), obs_processing_width=256, hidsize=512, single_output=False, img_shape=None, scale_input_img=True, only_img_input=False, init_norm_kwargs={}, impala_kwargs={}, input_shape=None, active_reward_monitors=None, img_statistics=None, first_conv_norm=False, diff_mlp_embedding=False, attention_mask_style='clipped_causal', attention_heads=8, attention_memory_size=2048, use_pointwise_layer=True, pointwise_ratio=4, pointwise_use_activation=False, n_recurrence_layers=1, recurrence_is_residual=True, timesteps=None, use_pre_lstm_ln=True, **unused_kwargs)[source]#
- Parameters:
recurrence_type –
None: No recurrence; adds no extra layers.
lstm: (Deprecated) Single LSTM.
multi_layer_lstm: Multi-layer LSTM. Uses n_recurrence_layers to determine the number of consecutive LSTMs. Does NOT support ragged batching.
multi_masked_lstm: Multi-layer LSTM that supports ragged batching via the first vector. This model is slower. Uses n_recurrence_layers to determine the number of consecutive LSTMs.
transformer: Dense transformer.
init_norm_kwargs – kwargs for all FanInInitReLULayers.
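The first vector marks the timesteps where a new episode begins. A model that supports ragged batching therefore has to reset its recurrent state at those steps instead of carrying memory across episode boundaries. The toy sketch below shows that masking with a scalar recurrence standing in for the LSTM. `masked_recurrence` and `decay` are hypothetical, and the real layers reset vector-valued hidden and cell states.

```python
from typing import List


def masked_recurrence(inputs: List[float], first: List[bool], state: float,
                      init_state: float = 0.0, decay: float = 0.5) -> List[float]:
    """Toy recurrence h_t = decay * h_{t-1} + x_t that resets h where first[t] is True."""
    outputs = []
    for x, is_first in zip(inputs, first):
        if is_first:
            state = init_state  # new episode: discard the previous episode's memory
        state = decay * state + x
        outputs.append(state)
    return outputs
```

With inputs [1, 1, 1, 1] and an episode boundary at step 2, the outputs are [1.0, 1.5, 1.0, 1.5]. The memory built up in the first episode does not leak into the second.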
- forward(ob, state_in, context)[source]#
Forward pass of the MinecraftPolicy.
Processes image observations, passes them through recurrent layers, and produces latent representations.
- Parameters:
ob (Dict[str, torch.Tensor]) – Dictionary of observations, expected to contain “image”.
state_in (Any) – Input recurrent state. Its type depends on recurrence_type.
context (Dict[str, torch.Tensor]) – Context dictionary, expected to contain “first” (a tensor indicating episode starts).
- Returns:
A tuple containing:
pi_latent_or_tuple (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]): If single_output is True, a single tensor for both policy and value. Otherwise, a tuple (pi_latent, vf_latent).
state_out (Any): Output recurrent state.
- Return type:
Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], Any]
- class minestudio.models.vpt.body.VPTPolicy(*args: Any, **kwargs: Any)[source]#
VPT (Video PreTraining) Policy.
This class wraps the MinecraftPolicy network and integrates it with the MinePolicy base class, providing methods for action selection, initial state, and state/input merging/splitting for batched online inference. It also supports loading from Hugging Face Hub.
- Parameters:
policy_kwargs (Dict) – Keyword arguments to initialize the MinecraftPolicy (self.net).
action_space (Optional[gymnasium.spaces.Space]) – The action space of the environment. Passed to MinePolicy constructor. Defaults to None.
kwargs – Additional keyword arguments passed to the MinePolicy constructor (e.g., temperature).
- forward(input, state_in, **kwargs)[source]#
Forward pass of the VPTPolicy.
Takes observations and recurrent state, passes them through the underlying MinecraftPolicy network, and then through policy and value heads.
- Parameters:
input (Dict[str, torch.Tensor]) – Dictionary of input observations, expected to contain “image”. The “image” tensor should have shape (B, T, H, W, C) or similar.
state_in (Optional[List[torch.Tensor]]) – Input recurrent state. If None, an initial state is generated.
kwargs – Additional keyword arguments (not directly used in this method but part of signature).
- Returns:
A tuple containing:
latents (Dict[str, torch.Tensor]): Dictionary with ‘pi_logits’ and ‘vpred’.
state_out (List[torch.Tensor]): Output recurrent state.
- Return type:
Tuple[Dict[str, torch.Tensor], List[torch.Tensor]]
- initial_state(batch_size: int | None = None)[source]#
Get the initial recurrent state for a given batch size.
Caches initial states for frequently used batch sizes.
- Parameters:
batch_size (Optional[int]) – The batch size. If None, returns state for batch size 1 (squeezed). Defaults to None.
- Returns:
A list of initial state tensors for the recurrent network, moved to the correct device.
- Return type:
List[torch.Tensor]
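Caching by batch size can be sketched with a dictionary. `InitialStateCache` and `make_state` are hypothetical. Lists stand in for batch-first tensors, and the `None` case returns batch size 1 with the batch axis removed, following the description above.

```python
from typing import Callable, Dict, List, Optional


class InitialStateCache:
    """Hypothetical cache of initial recurrent states keyed by batch size."""

    def __init__(self, make_state: Callable[[int], List[list]]):
        # make_state(n) must return a list of batch-first "tensors" for batch size n.
        self.make_state = make_state
        self._cache: Dict[int, List[list]] = {}

    def get(self, batch_size: Optional[int] = None) -> List[list]:
        size = 1 if batch_size is None else batch_size
        if size not in self._cache:
            self._cache[size] = self.make_state(size)  # build once per batch size
        state = self._cache[size]
        if batch_size is None:
            # Batch size 1 with the leading batch axis squeezed away.
            return [tensor[0] for tensor in state]
        return state
```

Because cached objects are shared between callers, a real implementation should avoid in-place updates to them or return copies.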
- merge_input(inputs) tensor [source]#
Merge a list of individual inputs into a single batched input dictionary.
Handles inputs where “image” is 3D (single frame) or 4D (already batched/sequence). All inputs are moved to the policy’s device.
- Parameters:
inputs (List[Dict[str, Any]]) – A list of input dictionaries, each expected to have an “image” key. Values are typically np.ndarray or torch.Tensor.
- Returns:
A batched input dictionary with “image” as a torch.Tensor.
- Return type:
Dict[str, torch.Tensor]
- merge_state(states) List[Tensor] | None [source]#
Merge a list of individual recurrent states into a single batched state.
Concatenates corresponding state tensors along the batch dimension.
- Parameters:
states (List[List[torch.Tensor]]) – A list of recurrent states. Each state is a list of tensors. Example: [[s1_env1, s2_env1, …], [s1_env2, s2_env2, …], …]
- Returns:
The batched recurrent state, where each tensor is a concatenation of the corresponding tensors from the input states.
- Return type:
Optional[List[torch.Tensor]]
- split_state(states, split_num) List[List[Tensor]] | None [source]#
Split a batched recurrent state into a list of individual states.
- Parameters:
states (List[torch.Tensor]) – The batched recurrent state (a list of tensors, where the first dimension of each tensor is the batch size).
split_num (int) – The number of individual states to split into (should match batch size).
- Returns:
A list of individual recurrent states. Each state in the list corresponds to one item from the original batch. Example: [[s1_item1, s2_item1, …], [s1_item2, s2_item2, …], …]
- Return type:
Optional[List[List[torch.Tensor]]]
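In the list-of-lists notation used above, merging concatenates state tensor i from every environment, and splitting gives each item one batch row. The sketch below uses plain lists in place of tensors, while the real methods concatenate torch tensors along the batch dimension. The function bodies are assumptions written from the descriptions. Two of those assumptions are that each split item keeps a batch dimension of size 1 and that empty or None input yields None.

```python
from itertools import chain
from typing import List, Optional

Tensor = list  # stand-in: the first axis is the batch dimension


def merge_state(states: List[List[Tensor]]) -> Optional[List[Tensor]]:
    """Concatenate the i-th tensor of every per-env state along the batch axis."""
    if not states:
        return None  # assumption: nothing to merge
    return [list(chain.from_iterable(state[i] for state in states))
            for i in range(len(states[0]))]


def split_state(states: Optional[List[Tensor]], split_num: int) -> Optional[List[List[Tensor]]]:
    """Undo merge_state: item j receives row j of every tensor, kept as a batch of size 1."""
    if states is None:
        return None
    return [[tensor[j:j + 1] for tensor in states] for j in range(split_num)]
```

Splitting a merged state returns the original per-environment states. This round trip is what batched online inference relies on when environments are stepped together.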
- minestudio.models.vpt.body.load_vpt_policy(model_path: str, weights_path: str | None = None)[source]#
Load a VPTPolicy model.
Can load from a local pickle file and optionally apply weights from a .ckpt file, or load a pretrained model from Hugging Face Hub if model_path is None.
- Parameters:
model_path (Optional[str]) – Path to the .model pickle file containing policy configuration. If None, attempts to load from Hugging Face Hub.
weights_path (Optional[str]) – Path to the .ckpt file containing model weights. If None, weights are not loaded separately (e.g., if part of .model or Hub model). Defaults to None.
- Raises:
ValueError – If model_path is None, weights_path is also None, and no default Hugging Face Hub repo_id applies.
- Returns:
The loaded VPTPolicy model.
- Return type:
VPTPolicy
Utils#
Download#
Date: 2024-12-14 01:46:36 LastEditors: muzhancun muzhancun@126.com LastEditTime: 2024-12-14 02:00:17 FilePath: /MineStudio/minestudio/models/utils/download.py
- minestudio.models.utils.download.download_model(model_name: str, local_dir: str = 'downloads') str [source]#
Downloads a specified model from Hugging Face Hub if it doesn’t exist locally.
Prompts the user for confirmation before downloading. The model is saved to a subdirectory within local_dir named after model_name.
- Parameters:
model_name (str) – The name of the model to download. Valid names are “ROCKET-1”, “VPT”, “GROOT”, “STEVE-1”.
local_dir (str) – The base directory in which downloaded models are saved. Defaults to “downloads”, resolved relative to the parent directory of this script’s location.
- Returns:
The local path to the model directory if the download succeeds, or if the user declines the download and the directory already exists. Returns None if the user declines the download and the directory does not exist, or if an error occurs.
- Return type:
str
- Raises:
AssertionError – if model_name is not one of the recognized names.
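The name check and target-directory layout described above can be sketched with pathlib. `plan_download` is a hypothetical helper, not part of minestudio.models.utils. It takes the base directory as an explicit argument instead of deriving it from the script’s location, and it only reports whether a download would be needed. The confirmation prompt and the actual Hugging Face Hub download are left out.

```python
from pathlib import Path
from typing import Tuple, Union

# The model names listed as valid in the description above.
VALID_MODELS = ("ROCKET-1", "VPT", "GROOT", "STEVE-1")


def plan_download(model_name: str, base_dir: Union[str, Path],
                  local_dir: str = "downloads") -> Tuple[Path, bool]:
    """Hypothetical: validate the name and return (target dir, whether a download is needed)."""
    assert model_name in VALID_MODELS, f"unknown model: {model_name}"
    # The model is stored in a subdirectory of local_dir named after model_name.
    target = Path(base_dir) / local_dir / model_name
    return target, not target.exists()
```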