Offline Training#

Pre-training is a crucial approach for equipping policy models with diverse behaviors, as demonstrated by VPT. MineStudio supports pre-training on offline data, letting users launch it through a straightforward configuration file.

Note

The MineStudio offline module is built on PyTorch Lightning, providing high flexibility and enabling users to customize it to suit their specific needs.

Quick Start#

Basic Arguments#

minestudio.offline.trainer.MineLightning is the core class for offline training. It is a subclass of lightning.LightningModule and provides a simple interface for users to customize their training process.

Arguments:

mine_policy: The policy model to be trained.

callbacks: A list of objective callbacks to be used during training.

hyperparameters: A dictionary of hyperparameters to be logged to wandb.

log_freq: The frequency at which logs are uploaded to wandb.

learning_rate: The learning rate for the optimizer.

weight_decay: The weight decay for the optimizer.

warmup_steps: The number of warm-up steps for the learning-rate scheduler; warm-up is important when training transformer-like networks.

Note

We use AdamW as the default optimizer, together with a linear learning-rate scheduler for the warm-up stage.
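
Putting these arguments together, a minimal training script might look like the sketch below. The construction of the policy and data module is omitted: policy is assumed to be an instance of a MinePolicy subclass, data_module a Lightning-compatible data module built with MineStudio's data package, and all hyperparameter values are purely illustrative.

import lightning as L
from minestudio.offline.trainer import MineLightning
from minestudio.offline.mine_callbacks import BehaviorCloneCallback

mine_lightning = MineLightning(
    mine_policy=policy,                             # a MinePolicy subclass (construction omitted)
    callbacks=[BehaviorCloneCallback(weight=1.0)],  # objective callbacks, see below
    hyperparameters={'learning_rate': 4e-5},        # extra values to log to wandb
    log_freq=20,
    learning_rate=4e-5,
    weight_decay=0.01,
    warmup_steps=1000,
)

trainer = L.Trainer(max_steps=100_000, gradient_clip_val=1.0)
trainer.fit(mine_lightning, datamodule=data_module)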

Long-Trajectory Training

Thanks to our advanced data structure, the offline trainer seamlessly supports long-trajectory training. Set episode_continuous_batch=True when creating the data module and use a memory-based policy (for example, a TransformerXL-based policy), and the trainer will automatically manage memory iteration for you.
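
For example, assuming MineDataModule is the data module class from MineStudio's data package (see the data documentation for its full argument list; everything other than the flag below is a placeholder), enabling long-trajectory batches only requires:

from minestudio.data import MineDataModule

data_module = MineDataModule(
    ...,                            # dataset configuration (see the data module docs)
    episode_continuous_batch=True,  # deliver each episode's frames contiguously across consecutive batches
)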

Objective Callbacks#

The loss function is a key component that users often wish to customize when developing new algorithms. MineStudio standardizes this interface and offers a selection of built-in loss functions that users can utilize directly.

The objective callback template is simple:

class ObjectiveCallback:

    def __init__(self):
        ...

    def __call__(
        self, 
        batch: Dict[str, Any], 
        batch_idx: int, 
        step_name: str, 
        latents: Dict[str, torch.Tensor], 
        mine_policy: MinePolicy
    ) -> Dict[str, torch.Tensor]:
        return {
            'loss': ..., 
            'other_key': ...,
        }

Hint

latents is returned by the MinePolicy's forward function, so users can pass any objective-related information to the callback via the latents variable.

Warning

The loss term will be added to the final loss function; all other keys are only logged to wandb or other loggers.
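
As a hypothetical example of a custom objective, the sketch below regresses the value head towards pre-computed returns. Here vpred is assumed to be returned in latents by the policy's forward function (as in the forward examples further down), and return is an assumed batch key provided by your data pipeline; neither is part of the built-in callbacks.

from typing import Any, Dict

import torch

class ValueRegressionCallback(ObjectiveCallback):

    def __init__(self, weight: float = 1.0):
        super().__init__()
        self.weight = weight

    def __call__(
        self, 
        batch: Dict[str, Any], 
        batch_idx: int, 
        step_name: str, 
        latents: Dict[str, torch.Tensor], 
        mine_policy: MinePolicy, 
    ) -> Dict[str, torch.Tensor]:
        # `vpred` must be returned by the policy's forward function;
        # `return` is an assumed batch key holding the regression targets.
        vpred = latents['vpred'].squeeze(-1)
        value_loss = (vpred - batch['return']).pow(2).mean()
        return {
            'loss': value_loss * self.weight,  # added to the final loss
            'value_loss': value_loss,          # logged only
        }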

Here are some examples of built-in objective callbacks:

Behavior Cloning Callback

The built-in minestudio.offline.mine_callbacks.BehaviorCloneCallback looks like this:

class BehaviorCloneCallback(ObjectiveCallback):

    def __init__(self, weight: float=1.0):
        super().__init__()
        self.weight = weight

    def __call__(
        self, 
        batch: Dict[str, Any], 
        batch_idx: int, 
        step_name: str, 
        latents: Dict[str, torch.Tensor], 
        mine_policy: MinePolicy, 
    ) -> Dict[str, torch.Tensor]:
        assert 'agent_action' in batch, "key `agent_action` is required for behavior cloning."
        agent_action = batch['agent_action']
        pi_logits = latents['pi_logits']
        log_prob = mine_policy.pi_head.logprob(agent_action, pi_logits, return_dict=True)
        entropy  = mine_policy.pi_head.entropy(pi_logits, return_dict=True)
        # exclude frames whose camera action is bin 60 (the no-op camera bin in the
        # quantized camera space), so idle-camera frames do not dominate the camera loss
        camera_mask = (agent_action['camera'] != 60).float().squeeze(-1)
        # `mask` marks which frames should contribute to the loss; defaults to all ones
        global_mask = batch.get('mask', torch.ones_like(camera_mask))
        logp_camera = (log_prob['camera'] * global_mask * camera_mask).sum(-1)
        logp_buttons = (log_prob['buttons'] * global_mask).sum(-1)
        entropy_camera  = (entropy['camera'] * global_mask * camera_mask).sum(-1)
        entropy_buttons = (entropy['buttons'] * global_mask).sum(-1)
        camera_loss, button_loss = -logp_camera, -logp_buttons
        bc_loss = camera_loss + button_loss
        entropy = entropy_camera + entropy_buttons
        result = {
            'loss': bc_loss.mean() * self.weight,
            'camera_loss': camera_loss.mean(),
            'button_loss': button_loss.mean(),
            'entropy': entropy.mean(),
            'bc_weight': self.weight,
        }
        return result
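
To enable behavior cloning, pass an instance of this callback in the callbacks argument when constructing MineLightning, e.g. callbacks=[BehaviorCloneCallback(weight=1.0)], as in the quick-start sketch above.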

When subclassing MinePolicy, one needs to return the latents (here pi_logits) from the forward function:

def forward(self, input, state_in, **kwargs):
    B, T = input["image"].shape[:2]
    first = torch.tensor([[False]], device=self.device).repeat(B, T)
    state_in = self.initial_state(B) if state_in is None else state_in
    (pi_h, v_h), state_out = self.net(input, state_in, context={"first": first})
    pi_logits = self.pi_head(pi_h)
    vpred = self.value_head(v_h)
    latents = {'pi_logits': pi_logits, 'vpred': vpred}
    return latents, state_out

Kullback–Leibler Divergence Callback

The built-in minestudio.offline.mine_callbacks.KLDivergenceCallback looks like this:

class KLDivergenceCallback(ObjectiveCallback):
        
    def __init__(self, weight: float=1.0):
        super().__init__()
        self.weight = weight

    def __call__(
        self, 
        batch: Dict[str, Any], 
        batch_idx: int, 
        step_name: str, 
        latents: Dict[str, torch.Tensor], 
        mine_policy: MinePolicy, 
    ) -> Dict[str, torch.Tensor]:
        posterior_dist = latents['posterior_dist']
        prior_dist = latents['prior_dist']
        q_mu, q_log_var = posterior_dist['mu'], posterior_dist['log_var']
        p_mu, p_log_var = prior_dist['mu'], prior_dist['log_var']
        kl_div = self.kl_divergence(q_mu, q_log_var, p_mu, p_log_var)
        result = {
            'loss': kl_div.mean() * self.weight,
            'kl_div': kl_div.mean(),
            'kl_weight': self.weight,
        }
        return result

    def kl_divergence(self, q_mu, q_log_var, p_mu, p_log_var):
        # KL( N(q_mu, exp(q_log_var)) || N(p_mu, exp(p_log_var)) ) for diagonal Gaussians
        # inputs shape: (B, D)
        KL = -0.5 * torch.sum(
            1 + (q_log_var - p_log_var) - (q_log_var - p_log_var).exp() - (q_mu - p_mu).pow(2) / p_log_var.exp(), dim=-1
        ) # shape: (B)
        return KL

When subclassing MinePolicy, one needs to return the latents (here posterior_dist and prior_dist) from the forward function, taking GROOT's forward function as an example:

def forward(self, input: Dict, memory: Optional[List[torch.Tensor]] = None) -> Dict:
    ...
    posterior_dist = self.video_encoder(reference)
    prior_dist = self.image_encoder(reference[:, 0])
    ...
    x, memory = self.decoder(x, memory)
    pi_h = v_h = x
    pi_logits = self.pi_head(pi_h)
    vpred = self.value_head(v_h)
    latents = {
        "pi_logits": pi_logits, 
        "vpred": vpred, 
        "posterior_dist": posterior_dist, 
        "prior_dist": prior_dist
    }
    return latents, memory