Source code for minestudio.simulator.callbacks.demonstration

'''
Date: 2025-01-07 05:58:26
LastEditors: caishaofei-mus1 1744260356@qq.com
LastEditTime: 2025-01-15 20:40:34
FilePath: /ROCKET-2/var/nfs-shared/shaofei/nfs-workspace/MineStudio/minestudio/simulator/callbacks/demonstration.py
'''
import random
import numpy as np
import os

from minestudio.simulator.callbacks.callback import MinecraftCallback
from minestudio.utils import get_mine_studio_dir
from minestudio.utils.register import Registers

def download_reference_videos():
    """Fetch the MineStudio reference-video dataset from the Hugging Face Hub.

    The ``CraftJarvis/MinecraftReferenceVideos`` dataset repository is
    mirrored into a ``reference_videos`` folder located under the MineStudio
    root directory (as returned by ``get_mine_studio_dir``).
    """
    # Imported lazily so the dependency is only needed when a download happens.
    import huggingface_hub

    target_dir = os.path.join(get_mine_studio_dir(), "reference_videos")
    print(f"Downloading reference videos to {target_dir}")
    huggingface_hub.snapshot_download(
        repo_id='CraftJarvis/MinecraftReferenceVideos',
        repo_type='dataset',
        local_dir=target_dir,
    )
@Registers.simulator_callback.register
class DemonstrationCallback(MinecraftCallback):
    """Provides demonstration data, primarily for GROOT.

    Manages access to task-specific reference videos, including offering to
    download them if absent, and exposes the chosen video's path to the
    agent through ``obs['ref_video_path']``.

    :param task: Name of the task for demonstration data.
    :type task: str
    """

    @staticmethod
    def create_from_conf(source):
        """Creates a DemonstrationCallback from a configuration.

        Loads data from the source via ``MinecraftCallback.load_data_from_conf``.
        A DemonstrationCallback is built only if the loaded data contains a
        ``'reference_video'`` entry, which is used as the task name.

        :param source: Configuration source (e.g. a file path or dict).
        :type source: any
        :returns: DemonstrationCallback instance, or None if no
            ``'reference_video'`` key is present.
        :rtype: Optional[DemonstrationCallback]
        """
        data = MinecraftCallback.load_data_from_conf(source)
        if 'reference_video' in data:
            return DemonstrationCallback(data['reference_video'])
        return None

    @staticmethod
    def _confirm_download():
        """Ask on stdin whether the missing reference videos should be downloaded.

        Re-prompts until the answer is Y/y or N/n. Surrounding whitespace is
        ignored, so an answer like ``"y "`` is accepted.

        :returns: True if the user agreed to download, False otherwise.
        :rtype: bool
        """
        response = input("Detecting missing reference videos, do you want to download them from huggingface (Y/N)?\n")
        while True:
            answer = response.strip().lower()
            if answer == 'y':
                return True
            if answer == 'n':
                return False
            response = input("Please input Y or N:\n")

    def __init__(self, task):
        """Initializes DemonstrationCallback.

        Locates the reference-video directory, offers to download it if it is
        missing, and selects a random ``.mp4`` from ``<task>/human``.

        :param task: The task name (sub-directory of the reference videos).
        :type task: str
        :raises AssertionError: If the task's reference video directory doesn't
            exist, or if it contains no ``.mp4`` files under ``human``.
        """
        root_dir = get_mine_studio_dir()
        reference_videos_dir = os.path.join(root_dir, "reference_videos")
        if not os.path.exists(reference_videos_dir) and self._confirm_download():
            download_reference_videos()

        self.task = task
        ref_video_name = task
        task_dir = os.path.join(reference_videos_dir, ref_video_name)
        # Raised explicitly instead of via ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for backward compatibility.
        if not os.path.exists(task_dir):
            raise AssertionError(f"Reference video {ref_video_name} does not exist.")

        human_dir = os.path.join(task_dir, "human")
        # Sort because os.listdir order is filesystem-dependent; this makes the
        # choice reproducible under a fixed ``random`` seed.
        candidates = sorted(f for f in os.listdir(human_dir) if f.endswith('.mp4'))
        if not candidates:
            raise AssertionError(f"No .mp4 reference videos found in {human_dir}.")
        self.ref_video_path = os.path.join(human_dir, random.choice(candidates))

    def after_reset(self, sim, obs, info):
        """Adds the reference video path to the observation after a reset.

        Sets ``obs['ref_video_path']`` to the selected demonstration video.

        :param sim: The simulation environment.
        :param obs: The observation dictionary (mutated in place).
        :param info: Additional information dictionary.
        :returns: The modified ``obs`` and ``info``.
        :rtype: tuple[dict, dict]
        """
        obs['ref_video_path'] = self.ref_video_path
        return obs, info

    def after_step(self, sim, obs, reward, terminated, truncated, info):
        """Adds the reference video path to the observation after each step.

        Sets ``obs['ref_video_path']`` to the selected demonstration video.

        :param sim: The simulation environment.
        :param obs: The observation dictionary (mutated in place).
        :param reward: The reward from the current step.
        :param terminated: Whether the episode has terminated.
        :param truncated: Whether the episode has been truncated.
        :param info: Additional information dictionary.
        :returns: The modified ``obs``, ``reward``, ``terminated``,
            ``truncated``, and ``info``.
        :rtype: tuple[dict, float, bool, bool, dict]
        """
        obs['ref_video_path'] = self.ref_video_path
        return obs, reward, terminated, truncated, info

    def __repr__(self):
        """Returns a string representation of DemonstrationCallback.

        :returns: String representation including the task name.
        :rtype: str
        """
        return f"DemonstrationCallback(task={self.task})"