Offline API Documentation#
Trainer#
Date: 2024-11-10 13:44:13 LastEditors: muzhancun muzhancun@126.com LastEditTime: 2025-01-18 13:52:32 FilePath: /MineStudio/minestudio/offline/trainer.py
- class minestudio.offline.trainer.MineLightning(mine_policy: MinePolicy, callbacks: List[ObjectiveCallback] = [], hyperparameters: dict = {}, *, log_freq: int = 20, learning_rate: float = 1e-05, warmup_steps: int = 1000, weight_decay: float = 0.01)[source]#
A PyTorch Lightning module for training MinePolicy models.
This class encapsulates the training, validation, and optimization logic for MinePolicy models. It handles memory management, batch processing, and integration with ObjectiveCallbacks for custom training objectives.
- configure_optimizers()[source]#
Configures the optimizer and learning rate scheduler.
This function sets up an AdamW optimizer and a linear warmup learning rate scheduler.
- Returns:
A dictionary containing the optimizer and learning rate scheduler.
- Return type:
dict
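The linear warmup mentioned above can be illustrated with a small, dependency-free sketch. The exact schedule used by MineLightning is not shown in this reference, so the ramp below (starting at 1/warmup_steps and holding at 1.0 afterwards) is an assumption, and `linear_warmup_factor` is a hypothetical helper name.

```python
def linear_warmup_factor(step: int, warmup_steps: int) -> float:
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if warmup_steps <= 0:
        # No warmup requested: use the full learning rate from the start.
        return 1.0
    # Ramp linearly from 1/warmup_steps up to 1.0, then hold at 1.0.
    return min(1.0, (step + 1) / warmup_steps)


base_lr = 1e-5        # documented default for learning_rate
warmup_steps = 1000   # documented default for warmup_steps

for step in (0, 499, 999, 5000):
    print(step, base_lr * linear_warmup_factor(step, warmup_steps))
```

A function of this shape could be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR` together with an AdamW optimizer, which is one common way to build such a scheduler.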
- training_step(batch, batch_idx)[source]#
Performs a training step.
This function calls _batch_step with step_name='train'.
- Parameters:
batch (dict) – The input batch data.
batch_idx (int) – The index of the current batch.
- Returns:
A dictionary containing the loss and other metrics.
- Return type:
dict
- validation_step(batch, batch_idx)[source]#
Performs a validation step.
This function calls _batch_step with step_name='val'.
- Parameters:
batch (dict) – The input batch data.
batch_idx (int) – The index of the current batch.
- Returns:
A dictionary containing the loss and other metrics.
- Return type:
dict
- minestudio.offline.trainer.tree_detach(tree)[source]#
Detaches a tree of tensors from the computation graph.
This function recursively traverses a nested structure (dictionary or list) and detaches any PyTorch tensors it encounters. This is useful for preventing gradients from flowing back through the detached tensors.
- Parameters:
tree (dict | list | torch.Tensor) – The nested structure (dict, list, or tensor) to detach.
- Returns:
The detached tree.
- Return type:
dict | list | torch.Tensor
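The recursive traversal described for tree_detach can be sketched roughly as follows. This is not the library's implementation: `tree_detach_sketch` and `FakeTensor` are hypothetical names, `FakeTensor` stands in for `torch.Tensor` so the example runs without PyTorch, and returning non-tensor leaves unchanged is an assumption.

```python
def tree_detach_sketch(tree):
    """Recursively call .detach() on every leaf that supports it."""
    if isinstance(tree, dict):
        return {key: tree_detach_sketch(value) for key, value in tree.items()}
    if isinstance(tree, list):
        return [tree_detach_sketch(item) for item in tree]
    if hasattr(tree, "detach"):
        # torch.Tensor.detach() returns a tensor cut off from the autograd graph.
        return tree.detach()
    # Assumption: leaves that are not tensors are passed through untouched.
    return tree


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor."""

    def __init__(self, value, attached=True):
        self.value = value
        self.attached = attached

    def detach(self):
        return FakeTensor(self.value, attached=False)


memory = {"hidden": [FakeTensor(1.0), FakeTensor(2.0)], "step": 3}
detached = tree_detach_sketch(memory)
print(detached["hidden"][0].attached)  # False
print(memory["hidden"][0].attached)    # True
```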
Utils#
Date: 2024-11-26 06:26:26 LastEditors: caishaofei caishaofei@stu.pku.edu.cn LastEditTime: 2024-11-26 06:28:27 FilePath: /MineStudio/minestudio/train/utils.py
- minestudio.offline.utils.convert_to_normal(obj)[source]#
Recursively converts OmegaConf DictConfig and ListConfig objects to standard Python dicts and lists.
This function is useful when working with configurations loaded by OmegaConf, as it allows you to convert them to native Python types for easier manipulation or serialization.
- Parameters:
obj (Any) – The object to convert. Can be a DictConfig, ListConfig, or any other type.
- Returns:
The converted object, with DictConfig and ListConfig instances replaced by dicts and lists respectively.
- Return type:
Any
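A rough sketch of the recursive conversion is shown below. To keep it runnable without OmegaConf, `FakeDictConfig` and `FakeListConfig` are hypothetical stand-ins for `omegaconf.DictConfig` and `omegaconf.ListConfig`; whether the real function also descends into plain dicts and lists is not stated here, so this sketch only handles the config types.

```python
class FakeDictConfig(dict):
    """Hypothetical stand-in for omegaconf.DictConfig."""


class FakeListConfig(list):
    """Hypothetical stand-in for omegaconf.ListConfig."""


def convert_to_normal_sketch(obj):
    """Replace config containers with plain dicts and lists, recursively."""
    if isinstance(obj, FakeDictConfig):
        return {key: convert_to_normal_sketch(value) for key, value in obj.items()}
    if isinstance(obj, FakeListConfig):
        return [convert_to_normal_sketch(item) for item in obj]
    # Any other type is returned as-is.
    return obj


cfg = FakeDictConfig(
    learning_rate=1e-5,
    callbacks=FakeListConfig([FakeDictConfig(name="behavior_clone")]),
)
plain = convert_to_normal_sketch(cfg)
print(type(plain), type(plain["callbacks"]), type(plain["callbacks"][0]))
```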
Lightning Callbacks#
EMA Callback#
- class minestudio.offline.lightning_callbacks.ema.EMA(decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False)[source]#
Implements Exponential Moving Averaging (EMA) for PyTorch Lightning.
This callback maintains moving averages of the trained parameters during training. When evaluating or testing, it can swap the original weights with the EMA weights. When saving a checkpoint, it saves an additional set of parameters with the prefix ema.
- Parameters:
decay (float) – The exponential decay factor used for calculating the moving average. Must be between 0 and 1.
validate_original_weights (bool) – If True, validates the original weights instead of the EMA weights. Defaults to False.
every_n_steps (int) – Apply EMA every N training steps. Defaults to 1.
cpu_offload (bool) – If True, offloads EMA weights to CPU to save GPU memory. Defaults to False.
- Raises:
MisconfigurationException – If the decay value is not between 0 and 1.
- on_fit_start(trainer: Trainer, pl_module: LightningModule) -> None [source]#
Called when the fit begins.
Wraps the optimizers with EMAOptimizer.
- Parameters:
trainer (pl.Trainer) – The PyTorch Lightning Trainer instance.
pl_module (pl.LightningModule) – The PyTorch Lightning LightningModule instance.
- on_load_checkpoint(trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) -> None [source]#
Called when a checkpoint is loaded.
Handles loading of EMA weights and optimizer states. If an EMA checkpoint (e.g., model-EMA.ckpt) is loaded, it treats the EMA weights as the main weights. If a regular checkpoint is loaded, it looks for an associated EMA checkpoint and restores the EMA state from it.
- Parameters:
trainer (pl.Trainer) – The PyTorch Lightning Trainer instance.
pl_module (pl.LightningModule) – The PyTorch Lightning LightningModule instance.
checkpoint (Dict[str, Any]) – The loaded checkpoint dictionary.
- Raises:
MisconfigurationException – If a regular checkpoint is loaded but its associated EMA checkpoint is not found.
- on_test_end(trainer: Trainer, pl_module: LightningModule) -> None [source]#
Called when the test loop ends.
Swaps back to original weights if EMA weights were used for testing.
- Parameters:
trainer (pl.Trainer) – The PyTorch Lightning Trainer instance.
pl_module (pl.LightningModule) – The PyTorch Lightning LightningModule instance.
- on_test_start(trainer: Trainer, pl_module: LightningModule) -> None [source]#
Called when the test loop begins.
Swaps to EMA weights if validate_original_weights is False.
- Parameters:
trainer (pl.Trainer) – The PyTorch Lightning Trainer instance.
pl_module (pl.LightningModule) – The PyTorch Lightning LightningModule instance.
- on_validation_end(trainer: Trainer, pl_module: LightningModule) -> None [source]#
Called when the validation loop ends.
Swaps back to original weights if EMA weights were used for validation.
- Parameters:
trainer (pl.Trainer) – The PyTorch Lightning Trainer instance.
pl_module (pl.LightningModule) – The PyTorch Lightning LightningModule instance.
- on_validation_start(trainer: Trainer, pl_module: LightningModule) -> None [source]#
Called when the validation loop begins.
Swaps to EMA weights if validate_original_weights is False.
- Parameters:
trainer (pl.Trainer) – The PyTorch Lightning Trainer instance.
pl_module (pl.LightningModule) – The PyTorch Lightning LightningModule instance.
- save_ema_model(trainer: Trainer)[source]#
A context manager to save an EMA copy of the model and EMA optimizer states.
Temporarily swaps to EMA weights, yields, and then swaps back.
- Parameters:
trainer (pl.Trainer) – The PyTorch Lightning Trainer instance.
- save_original_optimizer_state(trainer: Trainer)[source]#
A context manager to temporarily set the save_original_optimizer_state flag in EMAOptimizers.
This is used to ensure that the original optimizer state is saved instead of the EMA optimizer state.
- Parameters:
trainer (pl.Trainer) – The PyTorch Lightning Trainer instance.
- swap_model_weights(trainer: Trainer, saving_ema_model: bool = False)[source]#
Swaps the model’s main parameters with the EMA parameters.
- Parameters:
trainer (pl.Trainer) – The PyTorch Lightning Trainer instance.
saving_ema_model (bool) – If True, indicates that the EMA model is being saved. Defaults to False.
- class minestudio.offline.lightning_callbacks.ema.EMAOptimizer(optimizer: Optimizer, device: device, decay: float = 0.9999, every_n_steps: int = 1, current_step: int = 0)[source]#
Wraps a PyTorch optimizer to compute Exponential Moving Average (EMA) of model parameters.
EMA parameters are updated after each optimizer step using the formula: ema_weight = decay * ema_weight + (1 - decay) * training_weight
Use the swap_ema_weights() context manager to temporarily swap the model’s regular parameters with the EMA parameters, typically for evaluation.
Note
EMAOptimizer is not compatible with APEX AMP O2.
- Parameters:
optimizer (torch.optim.Optimizer) – The PyTorch optimizer to wrap.
device (torch.device) – The device to store EMA parameters on (e.g., 'cuda', 'cpu').
decay (float) – The EMA decay factor. Defaults to 0.9999.
every_n_steps (int) – Apply EMA update every N optimizer steps. Defaults to 1.
current_step (int) – The initial training step. Defaults to 0.
- add_param_group(param_group)[source]#
Adds a parameter group to the underlying optimizer.
Also flags that EMA parameters need to be rebuilt to include parameters from the new group.
- Parameters:
param_group (dict) – The parameter group to add.
- all_parameters() -> Iterable[Tensor] [source]#
Returns an iterator over all parameters managed by the optimizer.
- Returns:
An iterator over all parameters.
- Return type:
Iterable[torch.Tensor]
- load_state_dict(state_dict)[source]#
Loads the EMAOptimizer state.
Restores the state of the underlying optimizer, EMA parameters, current_step, decay, and every_n_steps.
- Parameters:
state_dict (dict) – The EMAOptimizer state dictionary to load.
- state_dict()[source]#
Returns the state of the EMAOptimizer.
Includes the state of the underlying optimizer, the EMA parameters, the current step, decay, and every_n_steps. If save_original_optimizer_state is True, only the original optimizer’s state is returned.
- Returns:
A dictionary containing the EMAOptimizer state.
- Return type:
dict
- step(closure=None, grad_scaler=None, **kwargs)[source]#
Performs a single optimization step.
This method calls the underlying optimizer’s step() method and then, if applicable, updates the EMA parameters.
- Parameters:
closure (callable, optional) – A closure that re-evaluates the model and returns the loss. Optional for most optimizers.
grad_scaler (torch.cuda.amp.GradScaler, optional) – A torch.cuda.amp.GradScaler instance for mixed-precision training. Defaults to None.
- Returns:
The loss computed by the closure, or None if no closure is provided.
- swap_ema_weights(enabled: bool = True)[source]#
A context manager to in-place swap regular model parameters with EMA parameters.
Swaps back to the original regular parameters upon exiting the context.
- Parameters:
enabled (bool) – If False, the swap is not performed. Defaults to True.
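The swap-and-restore behaviour of this context manager can be sketched with plain Python lists standing in for parameter tensors. The names `swap_in_place` and `swap_ema_weights_sketch` are hypothetical, and the real method operates on the optimizer's own parameters rather than on arguments passed in.

```python
from contextlib import contextmanager


def swap_in_place(a, b):
    """Exchange the contents of two equal-length lists without rebinding them."""
    a[:], b[:] = b[:], a[:]


@contextmanager
def swap_ema_weights_sketch(params, ema_params, enabled=True):
    """Temporarily put the EMA values into `params`; restore them on exit."""
    if enabled:
        swap_in_place(params, ema_params)
    try:
        yield
    finally:
        # Swap back even if the body raised, so training weights are restored.
        if enabled:
            swap_in_place(params, ema_params)


params = [1.0, 2.0]       # regular (training) weights
ema_params = [0.9, 1.8]   # averaged weights

with swap_ema_weights_sketch(params, ema_params):
    inside = list(params)  # evaluation would run here on the EMA values

print(inside, params)
```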
- swap_tensors(tensor1, tensor2)[source]#
Swaps the data of two tensors in-place.
- Parameters:
tensor1 (torch.Tensor) – The first tensor.
tensor2 (torch.Tensor) – The second tensor.
- switch_main_parameter_weights(saving_ema_model: bool = False)[source]#
Swaps the main model parameters with the EMA parameters.
This method is called by the EMA callback or the swap_ema_weights context manager.
- Parameters:
saving_ema_model (bool) – If True, indicates that the EMA model is being saved. This affects how state_dict behaves. Defaults to False.
- minestudio.offline.lightning_callbacks.ema.ema_update(ema_model_tuple, current_model_tuple, decay)[source]#
Performs the EMA update step.
Updates the EMA parameters using the formula: ema_weight = decay * ema_weight + (1 - decay) * current_weight
This function uses torch._foreach_mul_ and torch._foreach_add_ for efficient element-wise operations on tuples of tensors.
- Parameters:
ema_model_tuple (tuple[torch.Tensor]) – A tuple of EMA parameter tensors.
current_model_tuple (tuple[torch.Tensor]) – A tuple of current model parameter tensors.
decay (float) – The EMA decay factor.
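As a worked instance of the update formula, the sketch below applies it element by element to plain Python floats. It is only an illustration of the arithmetic: `ema_update_sketch` is a hypothetical name, and the documented function works on tuples of tensors with `torch._foreach_*` operations instead of a Python loop.

```python
def ema_update_sketch(ema_weights, current_weights, decay):
    """In-place update: ema = decay * ema + (1 - decay) * current."""
    for i, (ema_w, cur_w) in enumerate(zip(ema_weights, current_weights)):
        ema_weights[i] = decay * ema_w + (1.0 - decay) * cur_w


ema = [0.0, 10.0]
current = [1.0, 20.0]
ema_update_sketch(ema, current, decay=0.9)
print(ema)  # each entry moves 10% of the way towards the current weight
```

With a decay close to 1 (such as the documented default of 0.9999 for EMAOptimizer), each step moves the average only slightly, which is what makes the EMA weights a smoothed version of the training trajectory.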
- minestudio.offline.lightning_callbacks.ema.run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None)[source]#
Runs the EMA update on the CPU.
This function is typically used when EMA parameters are offloaded to the CPU. It synchronizes with a CUDA stream if provided, then calls ema_update.
- Parameters:
ema_model_tuple (tuple[torch.Tensor]) – A tuple of EMA parameter tensors.
current_model_tuple (tuple[torch.Tensor]) – A tuple of current model parameter tensors.
decay (float) – The EMA decay factor.
pre_sync_stream (torch.cuda.Stream | None) – A CUDA stream to synchronize with before the update. Defaults to None.
Smart Checkpoint Callback#
Date: 2024-11-28 15:37:18 LastEditors: muzhancun muzhancun@stu.pku.edu.cn LastEditTime: 2025-05-27 14:03:44 FilePath: /MineStudio/minestudio/offline/lightning_callbacks/smart_checkpoint.py
- class minestudio.offline.lightning_callbacks.smart_checkpoint.SmartCheckpointCallback(**kwargs)[source]#
A PyTorch Lightning ModelCheckpoint callback that is aware of EMA (Exponential Moving Average) weights.
This callback extends the standard ModelCheckpoint to also save and remove EMA weights if an EMA callback is present in the trainer. EMA checkpoints are saved with an '-EMA' suffix before the file extension.
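The naming rule for EMA checkpoints can be illustrated with a short path helper. `ema_checkpoint_path` is a hypothetical function written for this example and is not part of the callback's documented interface.

```python
from pathlib import Path


def ema_checkpoint_path(filepath: str, suffix: str = "-EMA") -> str:
    """Insert the suffix between the file stem and its extension."""
    path = Path(filepath)
    return str(path.with_name(path.stem + suffix + path.suffix))


ema_path = ema_checkpoint_path("checkpoints/step-1000.ckpt")
print(Path(ema_path).name)  # step-1000-EMA.ckpt
```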
Speed Monitor Callback#
Date: 2024-11-28 15:35:51 LastEditors: caishaofei caishaofei@stu.pku.edu.cn LastEditTime: 2024-11-28 15:37:52 FilePath: /MineStudio/minestudio/train/lightning_callbacks/speed_monitor.py
- class minestudio.offline.lightning_callbacks.speed_monitor.SpeedMonitorCallback[source]#
A PyTorch Lightning callback to monitor training speed.
This callback logs the training speed in batches per second at regular intervals. It only logs on the global rank 0 process to avoid redundant logging in distributed training setups.
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]#
Called at the end of each training batch.
Calculates and logs the training speed every INTERVAL batches on rank 0.
- Parameters:
trainer (pl.Trainer) – The PyTorch Lightning Trainer instance.
pl_module (pl.LightningModule) – The PyTorch Lightning LightningModule instance.
outputs (Any) – The outputs of the training step.
batch (Any) – The current batch of data.
batch_idx (int) – The index of the current batch.
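The periodic speed measurement can be sketched as below. `SpeedMeter` is a hypothetical helper, the interval value and the injectable clock are assumptions made so the example is testable, and the rank-0 check and Lightning logging call are left out.

```python
import time


class SpeedMeter:
    """Hypothetical helper: reports batches per second every `interval` batches."""

    def __init__(self, interval=16, clock=time.perf_counter):
        self.interval = interval
        self.clock = clock
        self.last_time = None

    def on_batch_end(self, batch_idx):
        # Only measure on every `interval`-th batch.
        if batch_idx % self.interval != 0:
            return None
        now = self.clock()
        speed = None
        if self.last_time is not None and now > self.last_time:
            # `interval` batches were processed since the previous measurement.
            speed = self.interval / (now - self.last_time)
        self.last_time = now
        return speed


# Deterministic demo: a fake clock that reports 0.0 s and then 2.0 s.
fake_times = iter([0.0, 2.0])
meter = SpeedMeter(interval=4, clock=lambda: next(fake_times))
speeds = [meter.on_batch_end(i) for i in range(5)]
print(speeds)  # [None, None, None, None, 2.0]
```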
Mine Callbacks#
Behavior Clone Callback#
Date: 2024-11-12 13:59:08 LastEditors: caishaofei caishaofei@stu.pku.edu.cn LastEditTime: 2024-12-09 15:51:34 FilePath: /MineStudio/minestudio/train/mine_callbacks/behavior_clone.py
- class minestudio.offline.mine_callbacks.behavior_clone.BehaviorCloneCallback(weight: float = 1.0)[source]#
A callback for behavior cloning.
This callback calculates the behavior cloning loss, which is the negative log-likelihood of the agent’s actions under the policy’s action distribution. It also calculates the entropy of the policy’s action distribution.
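For a single discrete action, the two quantities named above reduce to simple arithmetic. The sketch below assumes a categorical action distribution given as a list of probabilities; `behavior_clone_terms` is a hypothetical helper, and the real callback works on batched tensors from the policy.

```python
import math


def behavior_clone_terms(probs, action):
    """Return (negative log-likelihood of `action`, entropy of `probs`)."""
    nll = -math.log(probs[action])
    # Zero-probability entries contribute nothing to the entropy.
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    return nll, entropy


probs = [0.7, 0.2, 0.1]   # policy's distribution over three actions
nll, entropy = behavior_clone_terms(probs, action=0)
print(f"nll={nll:.4f} entropy={entropy:.4f}")
```

Minimizing the negative log-likelihood pushes probability mass towards the demonstrated action, while the entropy indicates how spread out the policy's distribution remains.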
Callback#
Date: 2024-11-12 10:57:29 LastEditors: muzhancun muzhancun@126.com LastEditTime: 2025-01-18 13:53:17 FilePath: /MineStudio/minestudio/offline/mine_callbacks/callback.py
- class minestudio.offline.mine_callbacks.callback.ObjectiveCallback[source]#
Base class for objective callbacks used in MineLightning training.
Objective callbacks are used to define and calculate specific loss components or metrics during the training or validation step. Subclasses should implement the __call__ method to compute their specific objective.
- before_step(batch, batch_idx, step_name)[source]#
A hook called before the main batch step processing.
This can be used to modify the batch or perform other actions before the model forward pass and objective calculations.
- Parameters:
batch (Dict[str, Any]) – The input batch data.
batch_idx (int) – The index of the current batch.
step_name (str) – The name of the current step (e.g., 'train', 'val').
- Returns:
The (potentially modified) batch data.
- Return type:
Dict[str, Any]
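A subclass might combine the before_step hook with a `__call__` implementation along the following lines. Everything here is illustrative: `ObjectiveCallbackStub` is a stand-in for the real base class, the `__call__` signature is an assumption because its arguments are not listed in this reference, and the batch keys `prediction` and `target` are invented for the example.

```python
from typing import Any, Dict


class ObjectiveCallbackStub:
    """Hypothetical stand-in for ObjectiveCallback."""

    def before_step(self, batch: Dict[str, Any], batch_idx: int, step_name: str) -> Dict[str, Any]:
        return batch

    def __call__(self, batch, batch_idx, step_name, **kwargs) -> Dict[str, Any]:
        # Assumed signature; the documented base class may differ.
        return {}


class ClippedTargetCallback(ObjectiveCallbackStub):
    """Clips targets before the step, then reports a weighted squared error."""

    def __init__(self, weight: float = 1.0, max_target: float = 1.0):
        self.weight = weight
        self.max_target = max_target

    def before_step(self, batch, batch_idx, step_name):
        batch = dict(batch)  # shallow copy so the caller's batch is left untouched
        batch["target"] = [min(t, self.max_target) for t in batch["target"]]
        return batch

    def __call__(self, batch, batch_idx, step_name, **kwargs):
        errors = [(p - t) ** 2 for p, t in zip(batch["prediction"], batch["target"])]
        return {"loss": self.weight * sum(errors) / len(errors)}


raw_batch = {"prediction": [0.5, 1.0], "target": [0.5, 3.0]}
callback = ClippedTargetCallback(weight=2.0, max_target=1.0)
prepared = callback.before_step(raw_batch, batch_idx=0, step_name="train")
result = callback(prepared, batch_idx=0, step_name="train")
print(prepared["target"], result)
```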
KL Divergence Callback#
Date: 2024-12-12 13:10:58 LastEditors: muzhancun muzhancun@stu.pku.edu.cn LastEditTime: 2025-05-27 14:14:37 FilePath: /MineStudio/minestudio/offline/mine_callbacks/kl_divergence.py
- class minestudio.offline.mine_callbacks.kl_divergence.KLDivergenceCallback(weight: float = 1.0)[source]#
A callback to compute the KL divergence between two Gaussian distributions.
This callback is typically used in Variational Autoencoders (VAEs) or similar models where an approximate posterior distribution is regularized towards a prior distribution. The KL divergence is calculated between a posterior (q) and a prior (p) distribution, both assumed to be Gaussian and defined by their means (mu) and log variances (log_var).
- kl_divergence(q_mu, q_log_var, p_mu, p_log_var)[source]#
Computes the KL divergence between two Gaussian distributions q and p.
KL(q || p) = -0.5 * sum(1 + log(sigma_q^2 / sigma_p^2) - (sigma_q^2 / sigma_p^2) - ((mu_q - mu_p)^2 / sigma_p^2)) where sigma^2 = exp(log_var).
- Parameters:
q_mu (torch.Tensor) – Mean of the posterior distribution q. Shape: (B, D)
q_log_var (torch.Tensor) – Log variance of the posterior distribution q. Shape: (B, D)
p_mu (torch.Tensor) – Mean of the prior distribution p. Shape: (B, D)
p_log_var (torch.Tensor) – Log variance of the prior distribution p. Shape: (B, D)
- Returns:
The KL divergence for each element in the batch. Shape: (B,)
- Return type:
torch.Tensor
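The closed-form expression above can be checked numerically with a small pure-Python sketch. Nested lists of shape (B, D) stand in for tensors, the sum is taken over the D dimension (which is what the documented output shape suggests), and `kl_divergence_sketch` is a hypothetical name rather than the library's method.

```python
import math


def kl_divergence_sketch(q_mu, q_log_var, p_mu, p_log_var):
    """KL(q || p) between diagonal Gaussians, one value per batch element."""
    results = []
    for q_m, q_lv, p_m, p_lv in zip(q_mu, q_log_var, p_mu, p_log_var):
        total = 0.0
        for mq, lq, mp, lp in zip(q_m, q_lv, p_m, p_lv):
            var_ratio = math.exp(lq - lp)  # sigma_q^2 / sigma_p^2
            # log(sigma_q^2 / sigma_p^2) is simply lq - lp.
            total += 1.0 + (lq - lp) - var_ratio - (mq - mp) ** 2 / math.exp(lp)
        results.append(-0.5 * total)
    return results


# Batch of two 2-D Gaussians compared against a standard normal prior.
q_mu = [[0.0, 0.0], [1.0, 1.0]]
q_log_var = [[0.0, 0.0], [0.0, 0.0]]
p_mu = [[0.0, 0.0], [0.0, 0.0]]
p_log_var = [[0.0, 0.0], [0.0, 0.0]]

kl = kl_divergence_sketch(q_mu, q_log_var, p_mu, p_log_var)
print(kl)
```

The first batch element has identical q and p, so its divergence is zero; the second differs from the prior by a unit mean shift in each of two dimensions, contributing 0.5 per dimension.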