Tutorial: Inference with VPT

Batch evaluation of a VPT model takes only a few steps. As an example, this tutorial measures the success rate of a policy model that OpenAI fine-tuned with reinforcement learning on the diamond-mining task.

First, import the necessary dependencies. Because our episode generator, MineGenerator, is built on Ray, Ray must be initialized before the generator is used.

import ray
from functools import partial

from minestudio.inference import EpisodePipeline, MineGenerator, InfoBaseFilter
from minestudio.models import VPTPolicy  # used by agent_generator below
from minestudio.simulator import MinecraftSim

ray.init()

Next, we define env_generator and agent_generator as callables, so that each worker can construct its own environment and agent.

env_generator = partial(
    MinecraftSim,
    obs_size=(128, 128),
    preferred_spawn_biome="forest",
)
agent_generator = lambda: VPTPolicy.from_pretrained("CraftJarvis/MineStudio_VPT.rl_from_early_game_2x")
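
The generator pattern can be tried without Minecraft installed. The following is a minimal sketch in which FakeSim is a hypothetical stand-in for MinecraftSim (it is not part of MineStudio); it only shows that functools.partial binds arguments now and defers construction until the callable is invoked.

```python
from functools import partial


class FakeSim:
    """Hypothetical stand-in for MinecraftSim; it only records its arguments."""

    def __init__(self, obs_size=(224, 224), preferred_spawn_biome=None):
        self.obs_size = obs_size
        self.preferred_spawn_biome = preferred_spawn_biome


# partial() binds the keyword arguments now but defers construction.
env_generator = partial(FakeSim, obs_size=(128, 128), preferred_spawn_biome="forest")

# Each call builds a fresh, independent object, as each worker would.
env_a = env_generator()
env_b = env_generator()

print(env_a.obs_size, env_a.preferred_spawn_biome)
print(env_a is env_b)
```

The zero-argument lambda used for agent_generator plays the same role: nothing is loaded until a worker calls it.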

Next, we configure the worker parameters, including:
  • A maximum of 12,000 steps per episode,

  • Each worker generating 2 episodes,

  • The output folder set to ./output,

  • The output video format set to h264.

worker_kwargs = dict(
    env_generator=env_generator,
    agent_generator=agent_generator,
    num_max_steps=12000,
    num_episodes=2,
    tmpdir="./output",
    image_media="h264",
)

Finally, we create an EpisodePipeline object, passing MineGenerator as the episode generator and InfoBaseFilter as the episode filter.

pipeline = EpisodePipeline(
    episode_generator=MineGenerator(
        num_workers=8,
        num_gpus=0.25,
        max_restarts=3,
        **worker_kwargs,
    ),
    episode_filter=InfoBaseFilter(
        key="mine_block",
        val="diamond_ore",
        num=1,
    ),
)
summary = pipeline.run()
print(summary)
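
As a rough sanity check, the configuration above can be turned into resource and workload estimates. This sketch rests on two assumptions that the tutorial does not state: that every worker reserves its GPU fraction for the whole run, and that the simulator advances 20 steps per second of game time.

```python
# Values copied from the tutorial's configuration.
num_workers = 8
num_gpus_per_worker = 0.25
num_episodes_per_worker = 2
num_max_steps = 12000

# Assumption: each worker reserves its GPU fraction for the whole run.
total_gpus = num_workers * num_gpus_per_worker

# Every worker generates the same number of episodes.
total_episodes = num_workers * num_episodes_per_worker

# Upper bound on simulated steps if no episode ends early.
max_total_steps = total_episodes * num_max_steps

# Assumption: 20 environment steps correspond to one second of game time.
max_minutes_per_episode = num_max_steps / 20 / 60

print(total_gpus, total_episodes, max_total_steps, max_minutes_per_episode)
```

Under these assumptions the run needs two GPUs in total and produces 16 episodes, which matches the num_episodes value in the sample summary.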

Note

We initialized 8 workers, with each worker utilizing 0.25 of a GPU and allowing up to 3 restarts.

The built-in InfoBaseFilter processes the generated episodes: for each episode, it checks whether a mine_block event with the value diamond_ore occurred.
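
To make the filtering idea concrete, here is a minimal sketch of an info-based success check. It is not MineStudio's implementation: the function episode_succeeded and the layout of the info dictionaries (a mine_block entry mapping block names to counts) are assumptions made for illustration only.

```python
def episode_succeeded(infos, key="mine_block", val="diamond_ore", num=1):
    """Return True if any step's info reports at least `num` of `val` under `key`.

    Assumes each info is a dict like {"mine_block": {"diamond_ore": 1}};
    this layout is an illustration, not MineStudio's documented schema.
    """
    for info in infos:
        counts = info.get(key, {})
        if counts.get(val, 0) >= num:
            return True
    return False


# Toy episode: the target block is mined on the final step.
toy_infos = [
    {"mine_block": {}},
    {"mine_block": {"stone": 3}},
    {"mine_block": {"stone": 3, "diamond_ore": 1}},
]

print(episode_succeeded(toy_infos))      # full episode
print(episode_succeeded(toy_infos[:2]))  # truncated before the target block
```

Counting episodes for which such a check passes, divided by the total number of episodes, gives a success rate.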

When the pipeline finishes, the summary is printed to the console, showing the number of episodes and the success rate. The output looks like the following:

... ...
{'num_yes': 4, 'num_episodes': 16, 'yes_rate': '25.00%'}
(Worker pid=1011772) Speed Test Status:
(Worker pid=1011772) Average Time: 0.02
(Worker pid=1011772) Average FPS: 56.11
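
The fields of the summary are related by simple arithmetic. The sketch below assumes the summary is a plain dict with the keys shown in the sample output, and recomputes the success rate from the two counts.

```python
# Sample summary copied from the console output above.
summary = {'num_yes': 4, 'num_episodes': 16, 'yes_rate': '25.00%'}

# Success rate = successful episodes / total episodes.
rate = summary['num_yes'] / summary['num_episodes']
formatted = f"{rate * 100:.2f}%"

print(formatted)
```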