'''
Date: 2024-11-10 13:44:13
LastEditors: muzhancun muzhancun@126.com
LastEditTime: 2025-01-18 13:52:32
FilePath: /MineStudio/minestudio/offline/trainer.py
'''
import os
import torch
import torch.nn as nn
import lightning as L
from rich import print
from minestudio.models import MinePolicy
from minestudio.offline.mine_callbacks import ObjectiveCallback
from typing import List

IMPORTANT_VARIABLES = [
    "MINESTUDIO_SAVE_DIR",
    "MINESTUDIO_DATABASE_DIR",
]

for var in IMPORTANT_VARIABLES:
    val = os.environ.get(var, "not found")
    print(f"[Env Variable] {var}: {val}")

def tree_detach(tree):
    """
    Detaches a tree of tensors from the computation graph.

    This function recursively traverses a nested structure (dictionary or list)
    and detaches any PyTorch tensors it encounters. This is useful for
    preventing gradients from flowing back through the detached tensors.

    :param tree: The nested structure (dict, list, or tensor) to detach.
    :type tree: dict | list | torch.Tensor
    :returns: The detached tree.
    :rtype: dict | list | torch.Tensor
    """
    if isinstance(tree, dict):
        return {k: tree_detach(v) for k, v in tree.items()}
    elif isinstance(tree, list):
        return [tree_detach(v) for v in tree]
    elif isinstance(tree, torch.Tensor):
        return tree.detach()
    else:
        return tree
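
# Illustrative usage of tree_detach (hypothetical tensors, not part of the module):
#   out = {"hidden": torch.randn(2, 8, requires_grad=True), "mask": [torch.ones(2)]}
#   detached = tree_detach(out)
#   detached["hidden"].requires_grad  # -> False; structure and values are unchanged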

class MineLightning(L.LightningModule):
    """
    A PyTorch Lightning module for training MinePolicy models.

    This class encapsulates the training, validation, and optimization logic
    for MinePolicy models. It handles memory management, batch processing,
    and integration with ObjectiveCallbacks for custom training objectives.
    """

    def __init__(
        self,
        mine_policy: MinePolicy,
        callbacks: List[ObjectiveCallback] = [],
        hyperparameters: dict = {},
        *,
        log_freq: int = 20,
        learning_rate: float = 1e-5,
        warmup_steps: int = 1000,
        weight_decay: float = 0.01,
    ):
        """
        Initializes the MineLightning module.

        :param mine_policy: The MinePolicy model to train.
        :type mine_policy: MinePolicy
        :param callbacks: A list of ObjectiveCallback instances for custom objectives.
        :type callbacks: List[ObjectiveCallback]
        :param hyperparameters: A dictionary of hyperparameters to save.
        :type hyperparameters: dict
        :param log_freq: The frequency (in batches) for logging metrics.
        :type log_freq: int
        :param learning_rate: The learning rate for the optimizer.
        :type learning_rate: float
        :param warmup_steps: The number of warmup steps for the learning rate scheduler.
        :type warmup_steps: int
        :param weight_decay: The weight decay for the optimizer.
        :type weight_decay: float
        """
        super().__init__()
        self.mine_policy = mine_policy
        self.callbacks = callbacks
        self.log_freq = log_freq
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps
        self.weight_decay = weight_decay
        self.memory_dict = {
            "memory": None,
            "init_memory": None,
            "last_timestamp": None,
        }
        self.automatic_optimization = True
        self.save_hyperparameters(hyperparameters)
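
    # Illustrative construction (the policy and callback objects are placeholders,
    # not defined in this module):
    #   module = MineLightning(
    #       mine_policy=policy,
    #       callbacks=[my_objective_callback],
    #       hyperparameters={"learning_rate": 1e-5},
    #       learning_rate=1e-5,
    #       warmup_steps=1000,
    #   )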

    def _make_memory(self, batch):
        """
        Manages the recurrent memory for the MinePolicy model.

        This function initializes and updates the memory state based on the
        batch data. It handles episode boundaries by resetting the memory
        when a new episode begins.

        :param batch: The input batch data.
        :type batch: dict
        :returns: The updated memory state for the current batch.
        :rtype: list
        """
        if self.memory_dict["init_memory"] is None:
            self.memory_dict["init_memory"] = self.mine_policy.initial_state(batch['image'].shape[0])
        if self.memory_dict["memory"] is None:
            self.memory_dict["memory"] = self.memory_dict["init_memory"]
        if self.memory_dict["last_timestamp"] is None:
            self.memory_dict["last_timestamp"] = torch.zeros(batch['image'].shape[0], dtype=torch.long).to(self.device)
        boe = batch["timestamp"][:, 0].ne(self.memory_dict["last_timestamp"] + 1)
        self.memory_dict["last_timestamp"] = batch["timestamp"][:, -1]
        # boe marks begin-of-episode samples: where boe is True the memory is
        # reset to init_memory, otherwise the running memory is kept.
        mem_cache = []
        for om, im in zip(self.memory_dict["memory"], self.memory_dict["init_memory"]):
            boe_f = boe[:, None, None].expand_as(om)
            mem_line = torch.where(boe_f, im, om)
            mem_cache.append(mem_line)
        self.memory_dict["memory"] = mem_cache
        return self.memory_dict["memory"]

    def _batch_step(self, batch, batch_idx, step_name):
        """
        Performs a single training or validation step.

        This function processes a batch of data, computes the model output,
        calculates the loss using the registered callbacks, and logs the metrics.

        :param batch: The input batch data.
        :type batch: dict
        :param batch_idx: The index of the current batch.
        :type batch_idx: int
        :param step_name: The name of the current step (e.g., 'train', 'val').
        :type step_name: str
        :returns: A dictionary containing the loss and other metrics.
        :rtype: dict
        """
        result = {'loss': 0}
        memory_in = self._make_memory(batch)
        for callback in self.callbacks:
            batch = callback.before_step(batch, batch_idx, step_name)
        # memory_in = None
        latents, memory_out = self.mine_policy(batch, memory_in)
        self.memory_dict["memory"] = tree_detach(memory_out)
        for callback in self.callbacks:
            call_result = callback(batch, batch_idx, step_name, latents, self.mine_policy)
            for key, val in call_result.items():
                result[key] = result.get(key, 0) + val
        if batch_idx % self.log_freq == 0:
            for key, val in result.items():
                prog_bar = ('loss' in key) and (step_name == 'train')
                self.log(f'{step_name}/{key}', val, sync_dist=False, prog_bar=prog_bar)
        return result
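
    # Example of how callback outputs accumulate (illustrative keys only): if one
    # callback returns {"loss": bc_loss, "bc_loss": bc_loss} and another returns
    # {"loss": kl_loss, "kl_loss": kl_loss}, the logged result is
    # {"loss": bc_loss + kl_loss, "bc_loss": bc_loss, "kl_loss": kl_loss}.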

    def training_step(self, batch, batch_idx):
        """
        Performs a training step.

        This function calls _batch_step with step_name='train'.

        :param batch: The input batch data.
        :type batch: dict
        :param batch_idx: The index of the current batch.
        :type batch_idx: int
        :returns: A dictionary containing the loss and other metrics.
        :rtype: dict
        """
        return self._batch_step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        """
        Performs a validation step.

        This function calls _batch_step with step_name='val'.

        :param batch: The input batch data.
        :type batch: dict
        :param batch_idx: The index of the current batch.
        :type batch_idx: int
        :returns: A dictionary containing the loss and other metrics.
        :rtype: dict
        """
        return self._batch_step(batch, batch_idx, 'val')
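
    # NOTE: learning_rate, warmup_steps, and weight_decay stored in __init__ feed the
    # optimizer/scheduler configuration, which is not shown in this excerpt. A minimal,
    # hypothetical sketch (the actual implementation may differ):
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.mine_policy.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        # Linear warmup from 0 to the target learning rate over warmup_steps steps.
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: min(1.0, (step + 1) / self.warmup_steps),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }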

if __name__ == '__main__':
    ...
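    # Hypothetical usage sketch (the policy, callback, and dataloader objects are
    # placeholders, not provided by this module):
    #
    #   module = MineLightning(policy, callbacks=[my_objective_callback])
    #   trainer = L.Trainer(max_epochs=1, precision="bf16-mixed")
    #   trainer.fit(module, train_dataloaders=train_loader)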