Tutorial: Inference with GROOT
To run inference with GROOT, you first need to download the reference videos and the pretrained checkpoints.
The example code is provided in minestudio/tutorials/inference/evaluate_groot/main.py.
Evaluating GROOT
from minestudio.simulator import MinecraftSim
from minestudio.simulator.callbacks import SpeedTestCallback, load_callbacks_from_config
from minestudio.models import GrootPolicy, load_groot_policy
from minestudio.inference import EpisodePipeline, MineGenerator, InfoBaseFilter
from minestudio.benchmark import prepare_task_configs
import ray
import numpy as np
import av
import os
from functools import partial
from rich import print

if __name__ == '__main__':
    ray.init()
    task_configs = prepare_task_configs("simple")
    config_file = task_configs["collect_wood"]
    # you can try: survive_plant, collect_wood, build_pillar, ... ; make sure the config file contains `reference_video` field
    print(config_file)

    env_generator = partial(
        MinecraftSim,
        obs_size = (224, 224),
        preferred_spawn_biome = "plains",
        callbacks = [
            SpeedTestCallback(50),
        ] + load_callbacks_from_config(config_file)
    )

    agent_generator = lambda: GrootPolicy.from_pretrained("CraftJarvis/MineStudio_GROOT.18w_EMA")

    worker_kwargs = dict(
        env_generator=env_generator,
        agent_generator=agent_generator,
        num_max_steps=600,
        num_episodes=2,
        tmpdir="./output",
        image_media="h264",
    )

    pipeline = EpisodePipeline(
        episode_generator=MineGenerator(
            num_workers=4,
            num_gpus=0.25,
            max_restarts=3,
            **worker_kwargs,
        ),
        episode_filter=InfoBaseFilter(
            key="mine_block",
            regex=".*log.*",
            num=1,
        ),
    )
    summary = pipeline.run()
    print(summary)
Since GROOT is an instruction-following policy, we need to specify the task, the corresponding config, and the demonstration video.
Supported tasks and configs can be found in minestudio/benchmark/task_configs, and a detailed explanation can be found in the benchmarking tutorial.
To pass the demonstration video to GROOT, we implement a DemonstrationCallback for the environment.
The DemonstrationCallback first tries to download the demonstration videos from the Hugging Face dataset if the local path reference_videos does not exist.
Then, given a task name, a video path is selected from the downloaded videos.
After the environment is initialized, the demonstration video path is passed to the 'ref_video_path' field of the observation and is then used to initialize the instruction for the agent.
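The selection step described above can be illustrated with a minimal sketch. It is not the DemonstrationCallback implementation: the helper name select_reference_video, the directory layout reference_videos/<task>/.../*.mp4 (inferred only from the sample path printed in the console output), and the random choice among candidates are all assumptions made for illustration.

```python
import random
import tempfile
from pathlib import Path
from typing import Optional


def select_reference_video(root: Path, task_name: str, seed: Optional[int] = None) -> Path:
    """Pick one .mp4 file under root/<task_name>/ (hypothetical layout and selection rule)."""
    task_dir = root / task_name
    # Collect candidates in a deterministic order so a fixed seed gives a fixed choice.
    candidates = sorted(task_dir.rglob("*.mp4")) if task_dir.is_dir() else []
    if not candidates:
        raise FileNotFoundError(f"no reference video for task {task_name!r} under {root}")
    return random.Random(seed).choice(candidates)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        # Build a tiny fake video tree instead of downloading anything.
        root = Path(tmp) / "reference_videos"
        video = root / "collect_wood" / "human" / "0.mp4"
        video.parent.mkdir(parents=True)
        video.touch()
        # The chosen path is what would end up in the observation field.
        obs = {"ref_video_path": str(select_reference_video(root, "collect_wood"))}
        print(obs["ref_video_path"])
```

The sketch only mirrors the data flow (task name, then video path, then observation field); the download step and the way the agent encodes the video are left out.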
Lines like the following will be printed to the console, showing the reference video in use and the shape, mean, and standard deviation of the computed latent.
=======================================================
"Ref video is from: ./reference_videos/collect_wood/human/0.mp4."
"Num frames: 1400"
=======================================================
[📚] latent shape: torch.Size([1, 1, 1024]) | mean: -0.028 | std: 1.308
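When running many tasks, it can be convenient to check these values automatically. The following is a small sketch that parses the latent line shown above using only the standard library; the function name parse_latent_line is hypothetical, and the regular expression assumes the log format stays exactly as in this sample.

```python
import re

# Matches e.g. "latent shape: torch.Size([1, 1, 1024]) | mean: -0.028 | std: 1.308"
LATENT_LINE = re.compile(
    r"latent shape: torch\.Size\(\[(?P<shape>[\d, ]+)\]\) \| "
    r"mean: (?P<mean>-?\d+\.\d+) \| std: (?P<std>-?\d+\.\d+)"
)


def parse_latent_line(line):
    """Return shape, mean and std from a latent log line, or None if it does not match."""
    match = LATENT_LINE.search(line)
    if match is None:
        return None
    shape = tuple(int(dim) for dim in match.group("shape").split(","))
    return {
        "shape": shape,
        "mean": float(match.group("mean")),
        "std": float(match.group("std")),
    }


if __name__ == "__main__":
    sample = "[📚] latent shape: torch.Size([1, 1, 1024]) | mean: -0.028 | std: 1.308"
    print(parse_latent_line(sample))
```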
For the inference pipeline parameters, we need to specify:
- the task, configs and demonstration video for the env_generator.
- the pretrained checkpoint for the agent_generator.
- the rollout steps, number of episodes and output path for worker_kwargs.
- the number of GPUs and workers for MineGenerator.
- an episode_filter to filter the episodes based on the key and value of the observation.
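As a quick sanity check on the parameters listed above, the totals implied by a configuration can be computed before launching the pipeline. This sketch assumes that num_episodes and num_gpus are per-worker values, which is how this tutorial describes its example; the helper name pipeline_totals is hypothetical and is not part of MineStudio.

```python
def pipeline_totals(num_workers, num_episodes, num_gpus, num_max_steps):
    """Compute aggregate figures, assuming num_episodes and num_gpus apply per worker."""
    return {
        "total_episodes": num_workers * num_episodes,
        "total_gpus": num_workers * num_gpus,
        # Upper bound only: episodes may terminate before num_max_steps.
        "max_steps_per_worker": num_episodes * num_max_steps,
    }


if __name__ == "__main__":
    # Values taken from the example code in this tutorial.
    print(pipeline_totals(num_workers=4, num_episodes=2, num_gpus=0.25, num_max_steps=600))
```

With the values from the example code, this gives 8 episodes in total and one full GPU shared across the four workers.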
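In the above example, we test the GROOT model on the task of collecting wood. Four workers are used with 0.25 GPU per worker, and each worker runs 2 episodes, for 8 episodes in total.
The code above caps each episode at 600 steps (num_max_steps=600); the sample output below appears to come from a run with 1200 steps per episode.
The episodes are filtered on the key mine_block with the regex .*log.*, which matches values such as oak_log.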
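To make the filtering rule concrete, here is a minimal sketch of a regex-based success check. It is not the InfoBaseFilter implementation: the function name episode_passes, the layout of the info dictionary (block name mapped to a count under the key), and the use of a full-string match are assumptions made for illustration.

```python
import re
from typing import Dict


def episode_passes(info: Dict[str, Dict[str, int]], key: str, regex: str, num: int) -> bool:
    """Return True if the counts under info[key] whose names match regex sum to at least num."""
    counts = info.get(key, {})
    matched = sum(count for name, count in counts.items() if re.fullmatch(regex, name))
    return matched >= num


if __name__ == "__main__":
    # Hypothetical final info dictionaries from two episodes.
    success = {"mine_block": {"oak_log": 3, "dirt": 1}}
    failure = {"mine_block": {"dirt": 5}}
    print(episode_passes(success, key="mine_block", regex=".*log.*", num=1))
    print(episode_passes(failure, key="mine_block", regex=".*log.*", num=1))
```

Because the pattern is .*log.* rather than a fixed name, any log type (for example birch_log) would count as a success under this rule, not only oak_log.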
After the pipeline finishes, a summary is printed to the console, showing the number of successful episodes, the total number of episodes, and the success rate, like the following:
...
(Worker pid=922019) Episode 2 saved at output/episode_2.mp4
(Worker pid=922013) Speed Test Status: [repeated 2x across cluster]
(Worker pid=922013) Average Time: 0.04 [repeated 2x across cluster]
(Worker pid=922013) Average FPS: 24.28 [repeated 2x across cluster]
(Worker pid=922013) Total Steps: 2400 [repeated 2x across cluster]
(Worker pid=922020) Episode 2 saved at output/episode_2.mp4
(Worker pid=922013) Episode 2 saved at output/episode_2.mp4
{'num_yes': 6, 'num_episodes': 8, 'yes_rate': '75.00%'}
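The fields in the summary relate to each other in a simple way: yes_rate is num_yes divided by num_episodes, formatted as a percentage. The sketch below reproduces that arithmetic from a list of per-episode pass/fail flags; the function name summarize is hypothetical, and the actual pipeline may build its summary differently.

```python
def summarize(results):
    """Build a summary dict from a list of per-episode booleans (True means the filter passed)."""
    num_episodes = len(results)
    num_yes = sum(1 for passed in results if passed)
    # Guard against division by zero when no episode was collected.
    rate = 100.0 * num_yes / num_episodes if num_episodes else 0.0
    return {
        "num_yes": num_yes,
        "num_episodes": num_episodes,
        "yes_rate": f"{rate:.2f}%",
    }


if __name__ == "__main__":
    # Six successes out of eight episodes, as in the sample output above.
    print(summarize([True] * 6 + [False] * 2))
```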