# Source code for minestudio.online.utils.train.wandb_logger

from typing import Dict, Any
import ray
import logging
from minestudio.online.utils.train import get_current_session

logger = logging.getLogger("ray")

def log(*args, **kwargs):
    """
    Logs data to the current Weights & Biases (wandb) training session.

    This function retrieves the current training session and, if one exists,
    calls its `log` method remotely and waits for the result with ``ray.get``.
    If no session is active, the call is a silent no-op (nothing is logged).
    Any exception raised while dispatching or awaiting the remote call is
    logged together with its traceback and is not propagated to the caller.

    :param args: Positional arguments to pass to `wandb.log()`.
    :param kwargs: Keyword arguments to pass to `wandb.log()`.
    """
    training_session = get_current_session()
    if training_session is None:
        # No active session: there is nowhere to send the data.
        return
    try:
        ray.get(training_session.log.remote(*args, **kwargs))
    except Exception as e:
        # Best-effort: report the failure (with traceback) instead of raising.
        logger.error("Error logging to wandb: %s", e, exc_info=True)
def define_metric(*args, **kwargs):
    """
    Defines a metric for the current Weights & Biases (wandb) training session.

    This function retrieves the current training session and calls its
    `define_metric` method remotely, waiting for the result with ``ray.get``.
    An active training session is required; if none is active an
    ``AssertionError`` is raised. Any exception raised while dispatching or
    awaiting the remote call is logged together with its traceback and is not
    propagated to the caller.

    :param args: Positional arguments to pass to `wandb.define_metric()`.
    :param kwargs: Keyword arguments to pass to `wandb.define_metric()`.
    :raises AssertionError: If no training session is active.
    """
    # Look the session up outside of any ``assert``: under ``python -O`` the
    # whole assert statement (including a walrus assignment inside it) is
    # stripped, which would leave ``training_session`` unbound.
    training_session = get_current_session()
    if training_session is None:
        # Raised explicitly so the check survives -O; AssertionError is kept
        # for backward compatibility with callers of the original assert.
        raise AssertionError("define_metric requires an active training session")
    try:
        ray.get(training_session.define_metric.remote(*args, **kwargs))
    except Exception as e:
        # Best-effort: report the failure (with traceback) instead of raising.
        logger.error("Error defining metric to wandb: %s", e, exc_info=True)
def log_video(data: Dict[str, Any], video_key: str, fps: int):
    """
    Logs a video to the current Weights & Biases (wandb) training session.

    This function retrieves the current training session and, if one exists,
    calls its `log_video` method remotely and waits for the result with
    ``ray.get``. If no session is active, the call is a silent no-op (nothing
    is logged). Any exception raised while dispatching or awaiting the remote
    call is logged together with its traceback and is not propagated.

    :param data: A dictionary containing the data to log. The video itself
        should be under the `video_key`.
    :param video_key: The key in the `data` dictionary that holds the video data.
    :param fps: The frames per second of the video.
    """
    training_session = get_current_session()
    if training_session is None:
        # No active session: there is nowhere to send the video.
        return
    try:
        ray.get(training_session.log_video.remote(data, video_key, fps))
    except Exception as e:
        # Best-effort: report the failure (with traceback) instead of raising.
        logger.error("Error logging video to wandb: %s", e, exc_info=True)