# Offline Training
Pre-training is a crucial approach for equipping policy models with diverse behaviors, as demonstrated by VPT. MineStudio supports pre-training on offline data, letting users launch a run from a straightforward configuration file.
> **Note:** The MineStudio offline module is built on PyTorch Lightning, providing high flexibility and enabling users to customize it to suit their specific needs.
## Quick Start
### Basic Arguments
`minestudio.offline.trainer.MineLightning` is the core class for offline training. It is a subclass of `lightning.LightningModule` and provides a simple interface for users to customize their training process.
| Argument | Description |
|---|---|
| `mine_policy` | The policy model to be trained. |
| `callbacks` | A list of objective callbacks to be used during training. |
| `hyperparameters` | A dictionary of hyperparameters to be logged to `wandb`. |
| `log_freq` | The frequency (in training steps) at which logs are uploaded to `wandb`. |
| `learning_rate` | The learning rate for the optimizer. |
| `weight_decay` | The weight decay for the optimizer. |
| `warmup_steps` | The number of warm-up steps for the learning rate scheduler; warm-up is important when training transformer-like networks. |
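A minimal training script wires these arguments together roughly as follows. This is a sketch: `policy` and `data_module` are placeholders you must construct yourself, and the hyperparameter values are illustrative.

```python
import lightning as L
from minestudio.offline.trainer import MineLightning
from minestudio.offline.mine_callbacks import BehaviorCloneCallback

# `policy` is any MinePolicy subclass; `data_module` is a Lightning data
# module whose batches contain the keys your callbacks expect.
mine_lightning = MineLightning(
    mine_policy=policy,
    callbacks=[BehaviorCloneCallback(weight=1.0)],
    hyperparameters={'learning_rate': 1e-5},  # extra values to log
    log_freq=20,
    learning_rate=1e-5,
    weight_decay=0.01,
    warmup_steps=1000,
)

trainer = L.Trainer(max_epochs=1, precision='bf16-mixed')
trainer.fit(mine_lightning, datamodule=data_module)
```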
> **Note:** We use `AdamW` as the default optimizer, with a linear learning-rate scheduler for the warm-up stage.
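The warm-up schedule ramps the learning rate linearly from near zero to its target value over the first `warmup_steps` optimizer steps. Conceptually, it behaves like the following sketch built on PyTorch's `LambdaLR` (not MineStudio's exact implementation):

```python
import torch

policy = torch.nn.Linear(4, 4)  # stand-in module; in practice your MinePolicy
optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-5, weight_decay=0.01)

warmup_steps = 1000
# Scale the LR linearly from ~0 up to its target over the first
# `warmup_steps` optimizer steps, then hold it constant.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps)
)
```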
**Long-Trajectory Training.** Thanks to our advanced data structure, the offline trainer seamlessly supports long-trajectory training. Set `episode_continuous_batch=True` when creating the data module and implement a memory-based policy (for example, a TransformerXL-based policy); the trainer will then automatically manage memory iteration across batches for you, as sketched below.
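For illustration, enabling this on the data side might look like the following. `MineDataModule` and every argument except `episode_continuous_batch` follow common MineStudio usage but are assumptions here; check the data documentation for the exact signature.

```python
from minestudio.data import MineDataModule  # import path assumed

data_module = MineDataModule(
    data_params=dict(
        mode='raw',                       # illustrative dataset configuration
        dataset_dirs=['/path/to/dataset'],
        win_len=128,
    ),
    batch_size=8,
    # Consecutive batches continue the same episodes, so a memory-based
    # policy can carry its recurrent state across batch boundaries.
    episode_continuous_batch=True,
)
```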
## Objective Callbacks
The loss function is a key component that users often wish to customize when developing new algorithms. MineStudio standardizes this interface and offers a selection of built-in loss functions that users can utilize directly.
The objective callback template is simple:
```python
class ObjectiveCallback:

    def __init__(self):
        ...

    def __call__(
        self,
        batch: Dict[str, Any],
        batch_idx: int,
        step_name: str,
        latents: Dict[str, torch.Tensor],
        mine_policy: MinePolicy
    ) -> Dict[str, torch.Tensor]:
        return {
            'loss': ...,
            'other_key': ...,
        }
```
> **Hint:** `latents` is returned by the `MinePolicy` object, so users can pass any objective-related information to the callback via the `latents` dictionary.
> **Warning:** The `loss` term will be added to the final loss function; all other keys will only be logged to `wandb` or other loggers.
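To define your own objective, subclass `ObjectiveCallback` and return a dictionary containing at least `loss`. For instance, a hypothetical value-regression objective might look like this (a sketch: the `returns` batch key and the import path of `ObjectiveCallback` are assumptions for illustration):

```python
from typing import Dict

import torch

from minestudio.offline.mine_callbacks import ObjectiveCallback  # import path assumed


class ValueRegressionCallback(ObjectiveCallback):
    """Hypothetical objective: regress the value head toward a `returns`
    tensor that the dataloader is assumed to provide."""

    def __init__(self, weight: float = 1.0):
        super().__init__()
        self.weight = weight

    def __call__(self, batch, batch_idx, step_name, latents, mine_policy) -> Dict[str, torch.Tensor]:
        vpred = latents['vpred'].squeeze(-1)  # predicted values, shape (B, T)
        value_loss = (vpred - batch['returns']).pow(2).mean()
        return {
            'loss': value_loss * self.weight,  # added to the final loss
            'value_loss': value_loss,          # logged only
        }
```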
Here are some examples of built-in objective callbacks:
**Behavior Cloning Callback.** The built-in `minestudio.offline.mine_callbacks.BehaviorCloneCallback` looks like this:
```python
class BehaviorCloneCallback(ObjectiveCallback):

    def __init__(self, weight: float = 1.0):
        super().__init__()
        self.weight = weight

    def __call__(
        self,
        batch: Dict[str, Any],
        batch_idx: int,
        step_name: str,
        latents: Dict[str, torch.Tensor],
        mine_policy: MinePolicy,
    ) -> Dict[str, torch.Tensor]:
        assert 'agent_action' in batch, "key `agent_action` is required for behavior cloning."
        agent_action = batch['agent_action']
        pi_logits = latents['pi_logits']
        log_prob = mine_policy.pi_head.logprob(agent_action, pi_logits, return_dict=True)
        entropy = mine_policy.pi_head.entropy(pi_logits, return_dict=True)
        # Mask out null camera actions (bin 60 encodes "no camera movement"),
        # so only frames with actual camera motion contribute to the camera loss.
        camera_mask = (agent_action['camera'] != 60).float().squeeze(-1)
        global_mask = batch.get('mask', torch.ones_like(camera_mask))
        logp_camera = (log_prob['camera'] * global_mask * camera_mask).sum(-1)
        logp_buttons = (log_prob['buttons'] * global_mask).sum(-1)
        entropy_camera = (entropy['camera'] * global_mask * camera_mask).sum(-1)
        entropy_buttons = (entropy['buttons'] * global_mask).sum(-1)
        camera_loss, button_loss = -logp_camera, -logp_buttons
        bc_loss = camera_loss + button_loss
        entropy = entropy_camera + entropy_buttons
        result = {
            'loss': bc_loss.mean() * self.weight,
            'camera_loss': camera_loss.mean(),
            'button_loss': button_loss.mean(),
            'entropy': entropy.mean(),
            'bc_weight': self.weight,
        }
        return result
```
When subclassing `MinePolicy`, one needs to return the latents (`pi_logits`) from the `forward` function:
```python
def forward(self, input, state_in, **kwargs):
    B, T = input["image"].shape[:2]
    first = torch.tensor([[False]], device=self.device).repeat(B, T)
    state_in = self.initial_state(B) if state_in is None else state_in
    (pi_h, v_h), state_out = self.net(input, state_in, context={"first": first})
    pi_logits = self.pi_head(pi_h)
    vpred = self.value_head(v_h)
    # Expose everything the objective callbacks may need via `latents`.
    latents = {'pi_logits': pi_logits, 'vpred': vpred}
    return latents, state_out
```
**Kullback–Leibler Divergence Callback.** The built-in `minestudio.offline.mine_callbacks.KLDivergenceCallback` looks like this:
```python
class KLDivergenceCallback(ObjectiveCallback):

    def __init__(self, weight: float = 1.0):
        super().__init__()
        self.weight = weight

    def __call__(
        self,
        batch: Dict[str, Any],
        batch_idx: int,
        step_name: str,
        latents: Dict[str, torch.Tensor],
        mine_policy: MinePolicy,
    ) -> Dict[str, torch.Tensor]:
        posterior_dist = latents['posterior_dist']
        prior_dist = latents['prior_dist']
        q_mu, q_log_var = posterior_dist['mu'], posterior_dist['log_var']
        p_mu, p_log_var = prior_dist['mu'], prior_dist['log_var']
        kl_div = self.kl_divergence(q_mu, q_log_var, p_mu, p_log_var)
        result = {
            'loss': kl_div.mean() * self.weight,
            'kl_div': kl_div.mean(),
            'kl_weight': self.weight,
        }
        return result

    def kl_divergence(self, q_mu, q_log_var, p_mu, p_log_var):
        # KL( N(q_mu, exp(q_log_var)) || N(p_mu, exp(p_log_var)) ) between
        # diagonal Gaussians; inputs have shape (B, D), output has shape (B,).
        KL = -0.5 * torch.sum(
            1
            + (q_log_var - p_log_var)
            - (q_log_var - p_log_var).exp()
            - (q_mu - p_mu).pow(2) / p_log_var.exp(),
            dim=-1,
        )
        return KL
```
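For reference, `kl_divergence` computes the standard closed form of the KL divergence between two diagonal Gaussians $q=\mathcal{N}(\mu_q,\sigma_q^2)$ and $p=\mathcal{N}(\mu_p,\sigma_p^2)$, with `log_var` storing $\log\sigma^2$:

$$
D_{\mathrm{KL}}(q\,\|\,p) = -\frac{1}{2}\sum_{d=1}^{D}\left(1 + \log\frac{\sigma_{q,d}^2}{\sigma_{p,d}^2} - \frac{\sigma_{q,d}^2}{\sigma_{p,d}^2} - \frac{(\mu_{q,d}-\mu_{p,d})^2}{\sigma_{p,d}^2}\right)
$$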
When subclassing `MinePolicy`, one needs to return the latents (`posterior_dist` and `prior_dist`) from the `forward` function. Taking GROOT's `forward` function as an example:
```python
def forward(self, input: Dict, memory: Optional[List[torch.Tensor]] = None) -> Dict:
    ...
    posterior_dist = self.video_encoder(reference)
    prior_dist = self.image_encoder(reference[:, 0])
    ...
    x, memory = self.decoder(x, memory)
    pi_h = v_h = x
    pi_logits = self.pi_head(pi_h)
    vpred = self.value_head(v_h)
    latents = {
        "pi_logits": pi_logits,
        "vpred": vpred,
        "posterior_dist": posterior_dist,
        "prior_dist": prior_dist,
    }
    return latents, memory
```
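Because each callback's `loss` term is added to the final loss, objectives compose naturally. A sketch of training a GROOT-style policy with both built-in callbacks (`policy` is a placeholder and the weights are illustrative):

```python
from minestudio.offline.trainer import MineLightning
from minestudio.offline.mine_callbacks import BehaviorCloneCallback, KLDivergenceCallback

model = MineLightning(
    mine_policy=policy,  # any MinePolicy whose forward returns the latents above
    callbacks=[
        BehaviorCloneCallback(weight=1.0),
        KLDivergenceCallback(weight=0.01),  # illustrative weight
    ],
)
```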