Source code for minestudio.online.utils.rollout.monitor

from typing import List
from collections import deque, defaultdict
import time
from rich.table import Table
import rich

[docs] class MovingStat: """ Calculates a moving statistic (average) over a specified duration. :param duration: The duration in seconds over which to calculate the statistic. """ def __init__(self, duration: int): self.duration = duration self.data_queue = deque() self.sum = 0.0 def _pop(self): """ Removes data points from the queue that are older than the specified duration. """ while len(self.data_queue) > 0 and time.time() - self.data_queue[0][0] > self.duration: self.sum -= self.data_queue[0][1] self.data_queue.popleft()
[docs] def update(self, x: float): """ Adds a new data point to the calculation. :param x: The new data point. """ self.data_queue.append((time.time(), x)) self.sum += x self._pop()
[docs] def average(self): """ Calculates the average of the data points currently within the duration window. :returns: The average of the data points, or float('nan') if no data is available. """ self._pop() if len(self.data_queue) == 0: return float('nan') return self.sum / len(self.data_queue)
[docs] class PipelineMonitor: """ Monitors the time spent in different stages of a pipeline. :param stages: A list of strings representing the names of the pipeline stages. :param duration: The duration in seconds over which to calculate moving statistics for each stage. """ def __init__(self, stages: List[str], duration: int = 300): self.stages = stages self.records = defaultdict(lambda: [MovingStat(duration) for _ in stages]) self.last_stage = defaultdict(lambda: len(stages) - 1) self.last_updated_time = {}
[docs] def report_enter(self, stage: str, pipeline_id="default"): """ Reports that a pipeline has entered a new stage. :param stage: The name of the stage being entered. :param pipeline_id: The unique identifier of the pipeline instance. :raises AssertionError: If the reported stage is not the expected next stage. """ stage_idx = self.stages.index(stage) assert stage_idx == (self.last_stage[pipeline_id] + 1) % len(self.stages) if pipeline_id in self.last_updated_time: self.records[pipeline_id][self.last_stage[pipeline_id]].update(time.time() - self.last_updated_time[pipeline_id]) self.last_updated_time[pipeline_id] = time.time() self.last_stage[pipeline_id] = stage_idx
[docs] def print(self): """ Prints a table summarizing the average time spent in each stage for all monitored pipelines. """ table = Table() table.add_column("id") for idx in range(len(self.stages)): table.add_column(f"{self.stages[idx]}->{self.stages[(idx + 1) % len(self.stages)]}") table.add_column("total") rows = [] for pipeline_id in self.records.keys(): record = self.records[pipeline_id] row = [stat.average() for stat in record] row.append(sum(row)) rows.append(row) table.add_row(pipeline_id, * ["%.6f" % r for r in row] ) summary = [sum(col) / len(col) for col in zip(*rows)] table.add_row("Summary", * ["%.6f" % r for r in summary] ) rich.print(table)
if __name__ == "__main__": pipeline = PipelineMonitor(['a', 'b', 'c']) pipeline.report_enter('a') pipeline.report_enter('b') pipeline.report_enter('c') time.sleep(1) pipeline.report_enter('a') pipeline.print()