# Source code for minestudio.online.rollout.start_manager

'''
Date: 2025-05-22 21:43:52
LastEditors: caishaofei-mus1 1744260356@qq.com
LastEditTime: 2025-05-23 11:39:21
FilePath: /MineStudio/var/minestudio/online/rollout/start_manager.py
'''
from numpy import roll
from omegaconf import OmegaConf
import hydra
import logging
from minestudio.online.rollout.rollout_manager import RolloutManager
from minestudio.online.utils.rollout import get_rollout_manager
import ray
import wandb
import uuid
import torch

def _build_rollout_manager_kwargs(policy_generator, env_generator, online_cfg):
    """Assemble the RolloutManager constructor kwargs from ``online_cfg``.

    ``resume``, ``discount`` and ``use_normalized_vf`` are read from
    ``online_cfg.train_config``; every key of ``online_cfg.rollout_config`` is
    forwarded as an additional keyword argument.
    """
    train_cfg = online_cfg.train_config
    return dict(
        policy_generator=policy_generator,
        env_generator=env_generator,
        resume=train_cfg.resume,
        discount=train_cfg.discount,
        use_normalized_vf=train_cfg.use_normalized_vf,
        **online_cfg.rollout_config,
    )


def start_rolloutmanager(policy_generator, env_generator, online_cfg, address="localhost:9899"):
    """
    Initializes and starts a RolloutManager actor in a Ray cluster.

    This function handles the creation or reuse of a RolloutManager actor.
    If a RolloutManager with the name "rollout_manager" already exists, its
    saved configuration is compared with the keyword arguments built from
    `online_cfg`. If they differ, the existing actor is killed and a new one
    is created; otherwise the existing actor is reused. In either case the
    actor's ``start`` method is then invoked and awaited.

    :param policy_generator: A callable that generates a policy model.
    :param env_generator: A callable that generates a Minecraft simulation environment.
    :param online_cfg: An OmegaConf DictConfig object containing the online training
        configuration. It must provide ``train_config`` (with ``resume``,
        ``discount`` and ``use_normalized_vf``), ``rollout_config`` and
        ``detach_rollout_manager``.
    :param address: The address of the Ray cluster to connect to.
    :raises ValueError: If ``online_cfg`` is None.
    """
    # Fail fast before connecting to Ray: every path below reads online_cfg.
    if online_cfg is None:
        raise ValueError("online_cfg must be provided to start the rollout manager")

    ray.init(address=address, ignore_reinit_error=True, namespace="online")
    logger = logging.getLogger("Main")
    torch.backends.cudnn.benchmark = False  # type: ignore

    rollout_manager = get_rollout_manager()
    rollout_manager_kwargs = _build_rollout_manager_kwargs(policy_generator, env_generator, online_cfg)
    logger.info("rollout_manager_kwargs %s", rollout_manager_kwargs)

    if rollout_manager is not None:
        # NOTE(review): the kwargs include the generator callables. If the actor's
        # saved config holds deserialized copies, plain functions compare by
        # identity and this check may always report a change — confirm against
        # RolloutManager.get_saved_config.
        saved_config = ray.get(rollout_manager.get_saved_config.remote())
        if saved_config != rollout_manager_kwargs:
            logger.warning("Rollout manager config changed, killing and restarting rollout manager")
            ray.kill(rollout_manager)
            rollout_manager = None
        else:
            logger.info("Reusing existing rollout manager")

    if rollout_manager is None:
        actor_options = {"name": "rollout_manager"}
        if online_cfg.detach_rollout_manager:
            # A detached actor outlives the driver process that created it.
            actor_options["lifetime"] = "detached"
        rollout_manager = RolloutManager.options(**actor_options).remote(**rollout_manager_kwargs)  # type: ignore

    ray.get(rollout_manager.start.remote())
if __name__ == "__main__":
    # Without a configured handler, INFO records from the "Main" logger are
    # dropped (logging's last-resort handler only emits WARNING and above).
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("Main")
    logger.info("Starting rollout manager")
    try:
        # NOTE(review): placeholder arguments — start_rolloutmanager needs a real
        # policy generator, env generator and online config; with None for
        # online_cfg this call cannot succeed. TODO wire in the real config.
        start_rolloutmanager(None, None, None)
        logger.info("Rollout manager started")
    finally:
        # Disconnect from Ray even when starting the manager failed.
        ray.shutdown()
        logger.info("Ray shutdown")