Source code for minestudio.offline.lightning_callbacks.speed_monitor

'''
Date: 2024-11-28 15:35:51
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
LastEditTime: 2024-11-28 15:37:52
FilePath: /MineStudio/minestudio/train/lightning_callbacks/speed_monitor.py
'''
import time
import lightning.pytorch as pl

class SpeedMonitorCallback(pl.Callback):
    """
    A PyTorch Lightning callback to monitor training speed.

    This callback logs the training speed in batches per second at regular
    intervals. It only logs on the global rank 0 process to avoid redundant
    logging in distributed training setups.
    """

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """
        Called at the end of each training batch.

        Calculates and logs the training speed every `INTERVAL` batches on rank 0.

        :param trainer: The PyTorch Lightning Trainer instance.
        :type trainer: pl.Trainer
        :param pl_module: The PyTorch Lightning LightningModule instance.
        :type pl_module: pl.LightningModule
        :param outputs: The outputs of the training step.
        :type outputs: Any
        :param batch: The current batch of data.
        :type batch: Any
        :param batch_idx: The index of the current batch.
        :type batch_idx: int
        """
        INTERVAL = 16
        if trainer.global_rank != 0 or batch_idx % INTERVAL != 0:
            return
        now = time.time()
        if hasattr(self, 'time_start'):
            time_cost = now - self.time_start
            trainer.logger.log_metrics(
                {'train/speed(batch/s)': INTERVAL / time_cost},
                step=trainer.global_step,
            )
        self.time_start = now
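
A minimal usage sketch: attaching `SpeedMonitorCallback` to a `pl.Trainer`. The `ToyModule` class and the random tensors below are illustrative placeholders, not part of MineStudio; any `LightningModule` and `DataLoader` would work. With `INTERVAL = 16`, the callback logs `train/speed(batch/s)` on rank 0 every 16 batches.

import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl
from minestudio.offline.lightning_callbacks.speed_monitor import SpeedMonitorCallback

class ToyModule(pl.LightningModule):
    """A throwaway LightningModule used only to exercise the callback."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-2)

# 512 samples / batch_size 4 = 128 batches, so the speed metric is
# logged several times per epoch (every 16th batch).
dataset = TensorDataset(torch.randn(512, 8), torch.randn(512, 1))
trainer = pl.Trainer(max_epochs=1, callbacks=[SpeedMonitorCallback()])
trainer.fit(ToyModule(), DataLoader(dataset, batch_size=4))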