Tutorial: Pre-Training ROCKET-1 from Scratch
Let’s break down the parameters required to train ROCKET-1 step by step.
We begin by importing the required dependencies:
import hydra
import torch
import torch.nn as nn
import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor
from einops import rearrange
from typing import Dict, Any, Tuple
from minestudio.data import MineDataModule
from minestudio.models import RocketPolicy
from minestudio.offline import MineLightning
from minestudio.offline.utils import convert_to_normal
from minestudio.offline.mine_callbacks import BehaviorCloneCallback
from minestudio.offline.lightning_callbacks import SmartCheckpointCallback, SpeedMonitorCallback, EMA
Note
The dependencies include SmartCheckpointCallback and EMA. SmartCheckpointCallback is responsible for saving model weights and checkpoints, while EMA implements Exponential Moving Average. EMA is typically used during training to smooth the model weights, enhancing the model's generalization capability; it maintains an additional, averaged copy of the model's weights.
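To make the role of EMA concrete, the update it applies is a simple exponential average of the weights. The following is a minimal sketch of that update rule, not MineStudio's EMA callback; the decay value mirrors the one used later in this tutorial.
import torch

@torch.no_grad()
def ema_update(ema_params, model_params, decay=0.999):
    # ema <- decay * ema + (1 - decay) * current, applied element-wise per tensor
    for ema_p, p in zip(ema_params, model_params):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)

# usage: keep a detached copy of the weights and update it every few optimizer steps
# ema_params = [p.detach().clone() for p in model.parameters()]
# ema_update(ema_params, list(model.parameters()), decay=0.999)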
Next, we proceed to construct the policy model for ROCKET-1:
rocket_policy = RocketPolicy(
backbone="efficientnet_b0.ra_in1k",
hiddim=1024,
num_heads=8,
num_layers=4,
timesteps=128,
mem_len=128,
)
Note
The timesteps parameter represents the maximum sequence length within a batch during training. The mem_len parameter specifies the memory length supported by TransformerXL, which stores cached key and value tensors during both inference and training. These cached values are directly involved in the attention computation.
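As a rough illustration (not ROCKET-1's actual attention module), the sketch below shows how a TransformerXL-style memory of length mem_len works: cached keys and values from previous segments are concatenated in front of the current segment's keys and values before attention is computed, and the most recent mem_len steps are kept as the new memory.
import torch

def attend_with_memory(q, k, v, mem_k, mem_v, mem_len=128):
    # q, k, v: (batch, timesteps, dim); mem_k, mem_v: cached keys/values, (batch, <=mem_len, dim)
    k_all = torch.cat([mem_k, k], dim=1)            # attend over memory + current segment
    v_all = torch.cat([mem_v, v], dim=1)
    scores = q @ k_all.transpose(1, 2) / q.shape[-1] ** 0.5
    out = torch.softmax(scores, dim=-1) @ v_all     # causal masking omitted for brevity
    # the most recent mem_len steps become the new memory; detach so no gradients flow through it
    new_mem_k = k_all[:, -mem_len:].detach()
    new_mem_v = v_all[:, -mem_len:].detach()
    return out, new_mem_k, new_mem_v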
Next, configure the offline lightning module:
mine_lightning = MineLightning(
mine_policy=rocket_policy,
log_freq=20,
learning_rate=0.00004,
weight_decay=0.001,
warmup_steps=2000,
callbacks=[
BehaviorCloneCallback(weight=args.objective_weight),
],
hyperparameters=convert_to_normal(args),
)
Note
The learning_rate and weight_decay parameters are set empirically. The warmup_steps parameter specifies the number of warmup steps for the learning rate, which is crucial when training a Transformer model from scratch.
The callbacks parameter defines the optimization objectives. The MineStudio offline training module supports setting multiple optimization objectives; for example, BehaviorCloneCallback specifies the optimization objective for behavior cloning.
The hyperparameters parameter records the run's hyperparameters. The convert_to_normal function converts the parameters from a Hydra configuration file into a standard dictionary, which is then logged to wandb.
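The args object referenced in these snippets (args.objective_weight, args.save_freq, and so on) is the Hydra configuration passed into the training script. A minimal sketch of such an entrypoint is shown below; the config path, config name, and field names are illustrative assumptions rather than MineStudio's actual file layout.
import hydra
from omegaconf import DictConfig

@hydra.main(config_path=".", config_name="rocket_config", version_base=None)  # hypothetical config file
def main(args: DictConfig):
    # fields such as args.objective_weight and args.save_freq come from the YAML config;
    # convert_to_normal(args) turns this DictConfig into a plain dict for wandb logging
    print(args.objective_weight, args.save_freq)

if __name__ == "__main__":
    main()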
Next, configure the dataloader:
mine_data = MineDataModule(
data_params=dict(
mode='raw',
dataset_dirs=[
'/nfs-shared-2/data/contractors/dataset_6xx',
'/nfs-shared-2/data/contractors/dataset_7xx',
'/nfs-shared-2/data/contractors/dataset_8xx',
'/nfs-shared-2/data/contractors/dataset_9xx',
'/nfs-shared-2/data/contractors/dataset_10xx'
],
frame_width=224,
frame_height=224,
win_len=128,
enable_segment=True,
),
batch_size=8,
num_workers=8,
prefetch_factor=4,
split_ratio=0.95,
shuffle_episodes=True,
episode_continuous_batch=True,
)
Note
The win_len parameter defines the sequence length within a batch and should be set to the same value as timesteps.
Setting enable_segment=True enables the use of semantic segmentation information from the data, which provides the conditional information required by the ROCKET-1 model.
Setting shuffle_episodes=True shuffles the order of trajectories in the original data, which affects the train-val split.
Setting episode_continuous_batch=True ensures that continuous episode segments are used within batches, which influences the sampling strategy of the dataloader.
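Before launching a full run, it can help to sanity-check the shapes produced by the data module. The snippet below assumes that MineDataModule follows the standard LightningDataModule interface (setup / train_dataloader) and that each batch is a dictionary of tensors; the exact keys depend on the dataset configuration.
import torch

mine_data.setup(stage="fit")                  # standard LightningDataModule hook
loader = mine_data.train_dataloader()
batch = next(iter(loader))
for key, value in batch.items():              # assumed: batch is a dict of tensors
    shape = tuple(value.shape) if isinstance(value, torch.Tensor) else type(value)
    print(key, shape)                         # sequence dimensions should match win_len / timesteps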
Configure the Trainer and initiate training:
callbacks=[
LearningRateMonitor(logging_interval='step'),
SpeedMonitorCallback(),
SmartCheckpointCallback(
dirpath='./weights', filename='weight-{epoch}-{step}', save_top_k=-1,
every_n_train_steps=args.save_freq, save_weights_only=True,
),
SmartCheckpointCallback(
dirpath='./checkpoints', filename='ckpt-{epoch}-{step}', save_top_k=1,
every_n_train_steps=args.save_freq+1, save_weights_only=False,
),
EMA(
decay=0.999,
validate_original_weights=True,
every_n_steps=8,
cpu_offload=False,
)
]
L.Trainer(
logger=WandbLogger(project="minestudio"),
devices=8,
precision='bf16',
strategy='ddp_find_unused_parameters_true',
use_distributed_sampler=not episode_continuous_batch,
callbacks=callbacks,
gradient_clip_val=1.0,
).fit(
model=mine_lightning,
datamodule=mine_data,
ckpt_path=ckpt_path,
)
To monitor the training process effectively, we use multiple callbacks: LearningRateMonitor records changes in the learning rate, SpeedMonitorCallback tracks the training speed, SmartCheckpointCallback saves model weights and checkpoints, and EMA implements Exponential Moving Average.
The argument use_distributed_sampler=not episode_continuous_batch reflects the fact that when episode_continuous_batch=True, the dataloader automatically uses our distributed batch sampler; in that case, the Trainer must be configured with use_distributed_sampler=False.
Finally, we start the training process with the fit function. The model parameter corresponds to the configured mine_lightning. The datamodule parameter refers to the configured mine_data. The ckpt_path parameter specifies the checkpoint path: to resume training from a specific checkpoint, set ckpt_path to the path of that checkpoint; otherwise, set it to None.
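For example, to resume an interrupted run from a full checkpoint produced by the second SmartCheckpointCallback (the filename below is hypothetical, following the ckpt-{epoch}-{step} pattern configured above):
ckpt_path = "./checkpoints/ckpt-epoch=2-step=10000.ckpt"  # hypothetical checkpoint filename
# or, to train from scratch:
# ckpt_path = None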