Tutorial: Pre-Training ROCKET-1 from Scratch

Let’s break down the parameters required to train ROCKET-1 step by step.

We begin by importing the dependencies:

import hydra
import torch
import torch.nn as nn
import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor
from einops import rearrange
from typing import Dict, Any, Tuple

from minestudio.data import MineDataModule
from minestudio.models import RocketPolicy
from minestudio.offline import MineLightning
from minestudio.offline.utils import convert_to_normal
from minestudio.offline.mine_callbacks import BehaviorCloneCallback
from minestudio.offline.lightning_callbacks import SmartCheckpointCallback, SpeedMonitorCallback, EMA

Note

Notice that the dependencies include SmartCheckpointCallback and EMA. SmartCheckpointCallback is responsible for saving model weights and checkpoints, while EMA implements an Exponential Moving Average of the model weights. EMA is typically used during training to smooth the weights and improve generalization; to do so, it maintains an additional (shadow) copy of the model.
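
As a rough illustration of what the EMA callback maintains (a minimal sketch, not the MineStudio implementation), the shadow weights are blended toward the current model weights after every optimizer step:

def ema_update(shadow_params, model_params, decay=0.999):
    # shadow <- decay * shadow + (1 - decay) * current
    with torch.no_grad():
        for shadow, param in zip(shadow_params, model_params):
            shadow.mul_(decay).add_(param, alpha=1.0 - decay)

# usage sketch: keep a detached copy of the parameters as the shadow weights
# shadow_params = [p.detach().clone() for p in model.parameters()]
# and call ema_update(shadow_params, model.parameters()) after each optimizer step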

Next, we proceed to construct the policy model for ROCKET-1:

rocket_policy = RocketPolicy(
    backbone="efficientnet_b0.ra_in1k",
    hiddim=1024,
    num_heads=8,
    num_layers=4,
    timesteps=128,
    mem_len=128,
)

Note

The timesteps parameter is the maximum sequence length within a batch during training. The mem_len parameter specifies the memory length supported by TransformerXL, i.e. how many cached key and value entries are kept during both training and inference. These cached values participate directly in the attention computation.
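
Conceptually, a TransformerXL-style layer prepends the cached keys and values from earlier windows to the current window before computing attention, then keeps the most recent mem_len steps as the new cache. A minimal single-head sketch of this idea (causal masking omitted; this is not the actual ROCKET-1 attention code):

def attend_with_memory(q, k, v, mem_k, mem_v, mem_len=128):
    # q, k, v: (batch, timesteps, dim); mem_k, mem_v: (batch, <=mem_len, dim)
    k_all = torch.cat([mem_k, k], dim=1)  # cached keys join the attention
    v_all = torch.cat([mem_v, v], dim=1)
    scores = q @ k_all.transpose(-2, -1) / q.shape[-1] ** 0.5
    out = torch.softmax(scores, dim=-1) @ v_all
    # keep only the most recent mem_len steps as the new memory
    new_mem_k = k_all[:, -mem_len:].detach()
    new_mem_v = v_all[:, -mem_len:].detach()
    return out, new_mem_k, new_mem_v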

Next, configure the offline lightning module:

mine_lightning = MineLightning(
    mine_policy=rocket_policy, 
    log_freq=20,
    learning_rate=0.00004,
    weight_decay=0.001,
    warmup_steps=2000,
    callbacks=[
        BehaviorCloneCallback(weight=args.objective_weight),
    ], 
    hyperparameters=convert_to_normal(args),
)
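
The snippets in this tutorial reference an args object, which is the Hydra configuration passed to the training script's entry point. A minimal sketch of how such an entry point might look (the config path, config name, and field names here are hypothetical):

@hydra.main(config_path='.', config_name='rocket_config', version_base=None)
def main(args):
    # fields such as args.objective_weight, args.save_freq, and
    # args.episode_continuous_batch would be defined in the YAML config
    ...

if __name__ == '__main__':
    main()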

Note

The learning_rate and weight_decay parameters are empirically set. The warmup_steps parameter specifies the number of warmup steps for the learning rate, which is crucial when training a Transformer model from scratch.

The callbacks parameter defines the optimization objectives. The MineStudio offline training module supports setting multiple optimization objectives. For example, BehaviorCloneCallback is used to specify the optimization objective for behavior cloning.

The hyperparameters parameter is used to record various hyperparameters. The convert_to_normal function converts parameters from a Hydra configuration file into a standard dictionary format, which is then logged to wandb.
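
As a point of reference, a linear warmup ramps the learning rate from 0 to its target value over warmup_steps updates and then holds it. A minimal sketch of such a schedule with PyTorch's LambdaLR (the actual optimizer and scheduler are configured inside MineLightning, so this is for illustration only):

optimizer = torch.optim.AdamW(rocket_policy.parameters(), lr=0.00004, weight_decay=0.001)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: min(1.0, (step + 1) / 2000),  # reaches 4e-5 after 2000 steps
)
# scheduler.step() is called once per optimizer step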

Next, configure the dataloader:

mine_data = MineDataModule(
    data_params=dict(
        mode='raw',
        dataset_dirs=[
            '/nfs-shared-2/data/contractors/dataset_6xx',
            '/nfs-shared-2/data/contractors/dataset_7xx',
            '/nfs-shared-2/data/contractors/dataset_8xx',
            '/nfs-shared-2/data/contractors/dataset_9xx',
            '/nfs-shared-2/data/contractors/dataset_10xx'
        ],
        frame_width=224,
        frame_height=224,
        win_len=128,
        enable_segment=True,
    ),
    batch_size=8,
    num_workers=8,
    prefetch_factor=4,
    split_ratio=0.95, 
    shuffle_episodes=True,
    episode_continuous_batch=True,
)

Note

The win_len parameter defines the sequence length within a batch, which should be set to the same value as timesteps.

Setting enable_segment=True enables the use of semantic segmentation information from the data, which provides the conditional information required by the ROCKET-1 model.

Setting shuffle_episodes=True shuffles the order of trajectories in the original data, which affects how episodes fall into the train and validation splits.

Setting episode_continuous_batch=True ensures that continuous episode segments are used as batches, which influences the sampling strategy of the dataloader.
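
To make episode_continuous_batch concrete, the sketch below (hypothetical, not the MineStudio sampler) shows the intended batch layout: each of the batch_size slots follows a single episode, and consecutive batches yield consecutive win_len-frame windows from the same episodes, which is what allows the TransformerXL memory to be carried across batches:

def continuous_windows(episodes, win_len=128, batch_size=8):
    # slot i always continues episodes[i]; batch t+1 picks up where batch t stopped
    cursors = [0] * batch_size
    while all(cursors[i] < len(episodes[i]) for i in range(batch_size)):
        yield [episodes[i][cursors[i]:cursors[i] + win_len] for i in range(batch_size)]
        cursors = [c + win_len for c in cursors]

# toy check: with episodes given as lists of frame indices, slot 0 of the
# second batch starts at frame 128, continuing slot 0 of the first batch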

Configure the Trainer and initiate training:

callbacks=[
    LearningRateMonitor(logging_interval='step'), 
    SpeedMonitorCallback(),
    SmartCheckpointCallback(
        dirpath='./weights', filename='weight-{epoch}-{step}', save_top_k=-1, 
        every_n_train_steps=args.save_freq, save_weights_only=True,
    ), 
    SmartCheckpointCallback(
        dirpath='./checkpoints', filename='ckpt-{epoch}-{step}', save_top_k=1, 
        every_n_train_steps=args.save_freq+1, save_weights_only=False,
    ), 
    EMA(
        decay=0.999, 
        validate_original_weights=True, 
        every_n_steps=8, 
        cpu_offload=False, 
    )
]

L.Trainer(
    logger=WandbLogger(project="minestudio"), 
    devices=8, 
    precision='bf16', 
    strategy='ddp_find_unused_parameters_true', 
    use_distributed_sampler=not episode_continuous_batch,
    callbacks=callbacks, 
    gradient_clip_val=1.0, 
).fit(
    model=mine_lightning, 
    datamodule=mine_data, 
    ckpt_path=ckpt_path,
)

To monitor and manage the training process, we use several callbacks.

  • LearningRateMonitor records changes in the learning rate.

  • SpeedMonitorCallback tracks the training speed.

  • SmartCheckpointCallback saves model weights and checkpoints; here one instance saves weights only, while the other saves full checkpoints that can be used to resume training.

  • EMA implements Exponential Moving Average.

The argument use_distributed_sampler=not episode_continuous_batch reflects that when episode_continuous_batch=True, the dataloader already uses MineStudio's own distributed batch sampler, so the Trainer must be configured with use_distributed_sampler=False to keep Lightning from replacing it with its default DistributedSampler.

Finally, we start the training process using the fit function.

  • The model parameter corresponds to the configured mine_lightning.

  • The datamodule parameter refers to the configured mine_data.

  • The ckpt_path parameter specifies the checkpoint path. To resume training from a specific checkpoint, set ckpt_path to the path of that checkpoint. Otherwise, set it to None.
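
For example, to resume an interrupted run from a full checkpoint written by the second SmartCheckpointCallback (the file name below is hypothetical):

ckpt_path = './checkpoints/ckpt-epoch=2-step=10000.ckpt'  # resume from this checkpoint
# ckpt_path = None  # or train from scratch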