Source code for minestudio.online.trainer.start_trainer
import ray
from minestudio.online.utils.rollout import get_rollout_manager
import minestudio.online.trainer
from minestudio.online.utils.train.training_session import TrainingSession
[docs]
def start_trainer(policy_generator, env_generator, online_cfg, whole_config):
    """Start the online training process.

    Creates the named Ray actor ``"training_session"`` (unless one already
    exists), instantiates the trainer class named by
    ``online_cfg.trainer_name`` from ``minestudio.online.trainer``, and calls
    its ``fit()`` method.

    :param policy_generator: A function to generate the policy model.
    :param env_generator: A function to generate the environment.
    :param online_cfg: Online training configuration. Attributes read here:
        ``trainer_name`` (name of a class in ``minestudio.online.trainer``),
        ``logger_config`` (passed to ``TrainingSession``) and
        ``train_config`` (a mapping unpacked as keyword arguments into the
        trainer constructor).
    :param whole_config: The entire configuration as a string; forwarded
        unchanged to the trainer constructor.
    :return: None. Returns early (after printing a message) if a
        ``"training_session"`` actor already exists.
    :raises ValueError: If ``online_cfg.trainer_name`` does not name an
        attribute of ``minestudio.online.trainer``.
    """
    try:
        existing_session = ray.get_actor("training_session")
    except ValueError:
        # ray.get_actor raises ValueError when no actor with this name exists,
        # which is the normal "nothing running yet" case.
        existing_session = None
    if existing_session is not None:
        print("Trainer already running!")
        return

    # Resolve the trainer class before creating the session actor, so a
    # misconfigured name fails fast with a clear message instead of a
    # "'NoneType' object is not callable" TypeError raised after the named
    # "training_session" actor already exists.
    trainer_class_name = online_cfg.trainer_name  # a string, e.g. "DQNTrainer"
    trainer_class = getattr(minestudio.online.trainer, trainer_class_name, None)
    if trainer_class is None:
        raise ValueError(
            f"Unknown trainer {trainer_class_name!r}: no such attribute in "
            "minestudio.online.trainer"
        )

    training_session = TrainingSession.options(name="training_session").remote(hyperparams=online_cfg, logger_config=online_cfg.logger_config)  # type: ignore
    # Block until the session actor is up so it exists before the trainer is built.
    ray.get(training_session.get_session_id.remote())

    print("Making trainer")
    rollout_manager = get_rollout_manager()
    trainer = trainer_class(
        rollout_manager=rollout_manager,
        policy_generator=policy_generator,
        env_generator=env_generator,
        **online_cfg.train_config,
        whole_config=whole_config,
    )
    trainer.fit()