Inference

We provide a Ray-based inference framework for MineStudio that supports parallel and distributed inference. The framework consists of three parts (a generator, a filter and a recorder) that together form an inference pipeline for easily evaluating the performance of different agents.

Note

We strongly recommend reading the Ray documentation before using the inference framework.

Quick Start

Here is a minimal example of how to use the inference framework:

import ray
from minestudio.inference import EpisodePipeline, MineGenerator, InfoBaseFilter

from functools import partial
from minestudio.models import VPTPolicy
from minestudio.simulator import MinecraftSim

if __name__ == '__main__':
    ray.init()
    env_generator = partial(
        MinecraftSim, 
        obs_size = (128, 128), 
        preferred_spawn_biome = "forest", 
    ) # generate the environment
    agent_generator = lambda: VPTPolicy.from_pretrained("CraftJarvis/MineStudio_VPT.rl_from_early_game_2x") # generate the agent
    worker_kwargs = dict(
        env_generator = env_generator, 
        agent_generator = agent_generator,
        num_max_steps = 12000, # provide the maximum number of steps
        num_episodes = 2, # provide the number of episodes for each worker
        tmpdir = "./output",
        image_media = "h264",
    ) # provide the worker kwargs
    pipeline = EpisodePipeline(
        episode_generator = MineGenerator(
            num_workers = 8, # the number of workers
            num_gpus = 0.25, # GPUs allocated to each worker; 4 workers can share one GPU
            max_restarts = 3, # the maximum number of restarts for failed workers
            **worker_kwargs, 
        ),
        episode_filter = InfoBaseFilter(
            key = "mine_block",
            val = "diamond_ore",
            num = 1,
        ), # InfoBaseFilter labels episodes that mined at least 1 diamond_ore
    )
    summary = pipeline.run()
    print(summary)

Basic Components

Generator

The MineGenerator class is a subclass of EpisodeGenerator that uses Ray to manage a pool of distributed workers and generate episodes in parallel. It lets episode generation scale across multiple workers and GPUs.

Arguments:

  • num_workers (default: 1): Number of workers to create for parallel processing.

  • num_gpus (default: 0.5): Amount of GPU resources allocated to each worker. A fractional value lets several workers share one GPU.

  • max_restarts (default: 3): Maximum number of restarts allowed for a worker in case of failure.

  • worker_kwargs: Additional keyword arguments passed on to each worker's initialization.
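Because num_gpus is a per-worker request, the number of physical GPUs needed depends on how many workers fit on one device. The helper below is a hypothetical sketch (min_gpus_needed is not part of MineStudio or Ray) that assumes every worker requests the same amount and that Ray may place several fractional requests on the same GPU. With the Quick Start settings (8 workers at 0.25 GPU each) it gives 2 GPUs.

```python
import math
from fractions import Fraction


def min_gpus_needed(num_workers: int, num_gpus: float) -> int:
    """Hypothetical helper: smallest GPU count on which all workers fit.

    Assumes identical per-worker requests. Requests below 1 may share a GPU;
    requests of 1 or more are treated as exclusive whole GPUs.
    """
    # Convert to an exact fraction so 0.3 is handled as 3/10, not a float approximation.
    frac = Fraction(num_gpus).limit_denominator(10_000)
    if frac <= 0:
        return 0  # CPU-only workers
    if frac >= 1:
        return math.ceil(num_workers * frac)
    per_gpu = math.floor(1 / frac)  # how many workers can share one GPU
    return math.ceil(num_workers / per_gpu)


print(min_gpus_needed(8, 0.25))  # 2 (Quick Start settings)
print(min_gpus_needed(10, 0.3))  # 4 (three workers fit per GPU)
```

Note that the answer is not simply num_workers * num_gpus rounded up: with 0.3 GPU per worker only three workers fit on a device, so ten workers need four GPUs rather than three.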

The generate method returns a generator that yields episodes as the workers produce them. It uses Ray to assign tasks and collect results, following this workflow:

  1. Initializes a pool of workers with ray.remote.

  2. Assigns tasks to workers and collects results.

  3. Waits for workers to complete tasks, retrieves completed episodes, and reassigns new tasks.

  4. Continues yielding episodes until all tasks are completed.
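The workflow above resembles a common Ray pattern: keep one pending get_next call per worker, wait for whichever finishes first, yield its episode, and immediately hand that worker a new task. The sketch below reproduces the pattern with the standard library's concurrent.futures.wait(..., return_when=FIRST_COMPLETED) standing in for ray.wait, so it runs without a Ray cluster. ToyWorker and generate_episodes are hypothetical stand-ins, not MineStudio classes, and the convention that a finished worker returns None is an assumption.

```python
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from typing import Dict, Iterator, List, Optional


class ToyWorker:
    """Hypothetical worker that produces a fixed number of episodes."""

    def __init__(self, worker_id: int, num_episodes: int):
        self.worker_id = worker_id
        self.remaining = num_episodes

    def get_next(self) -> Optional[Dict]:
        # Return None once this worker has no episodes left (assumed convention).
        if self.remaining == 0:
            return None
        self.remaining -= 1
        return {"worker": self.worker_id, "episode": self.remaining}


def generate_episodes(workers: List[ToyWorker]) -> Iterator[Dict]:
    with ThreadPoolExecutor(max_workers=len(workers)) as pool:
        # Give every worker its first task; remember which future belongs to which worker.
        pending = {pool.submit(w.get_next): w for w in workers}
        while pending:
            # Block until at least one task finishes (ray.wait plays this role with Ray).
            done, _ = wait(pending, return_when=FIRST_COMPLETED)
            for future in done:
                worker = pending.pop(future)
                episode = future.result()
                if episode is None:
                    continue  # this worker is exhausted; do not reassign it
                yield episode
                # Reassign the now-idle worker before waiting again.
                pending[pool.submit(worker.get_next)] = worker


episodes = list(generate_episodes([ToyWorker(i, num_episodes=2) for i in range(4)]))
print(len(episodes))  # 8 (4 workers x 2 episodes)
```

Reassigning a worker as soon as its result arrives keeps fast workers busy instead of making every worker wait for the slowest one in a batch.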

Hint

  • Fault tolerance: each worker is created with the max_restarts option, so Ray can restart a worker that crashes, up to that many times.

  • Scalability: because num_gpus can be fractional, several workers can share one GPU, and adding workers spreads episode generation across the GPUs available to Ray. This suits computationally intensive tasks.

Each worker initialized by MineGenerator is an instance of the Worker class, which generates episodes in a simulation environment. It runs an agent in the environment, processes observations and actions, and stores the collected data in serialized files. Image observations can be saved in more than one format, so the output can be adapted to different data pipelines.

Here are the arguments for the Worker class:

  • env_generator (no default): A callable that creates the simulation environment.

  • agent_generator (no default): A callable that creates the agent. The agent must support the .to("cuda") and .eval() methods.

  • num_max_steps (no default): The maximum number of steps per episode.

  • num_episodes (no default): The number of episodes this worker generates.

  • tmpdir (default: None): Directory where serialized output files are stored.

  • image_media (default: "h264"): Format for saving image observations, either "h264" or "jpeg".

  • unused_kwargs: Extra keyword arguments that are accepted and ignored, which keeps the API compatible with callers that pass additional options.
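MineGenerator consumes num_workers, num_gpus and max_restarts itself and passes the remaining keyword arguments on to each worker, which is why the Quick Start collects them in a single worker_kwargs dict. The sketch below mimics that argument contract with a hypothetical WorkerConfig built by make_worker_config (neither exists in MineStudio): arguments without a default are required, tmpdir and image_media use the defaults listed above, and unrecognized keys are collected rather than raising an error, in the spirit of unused_kwargs.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional


@dataclass
class WorkerConfig:
    """Hypothetical mirror of the Worker arguments documented above."""

    env_generator: Callable[[], Any]
    agent_generator: Callable[[], Any]
    num_max_steps: int
    num_episodes: int
    tmpdir: Optional[str] = None
    image_media: str = "h264"
    unused_kwargs: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Only the two documented formats are accepted.
        if self.image_media not in ("h264", "jpeg"):
            raise ValueError(f"image_media must be 'h264' or 'jpeg', got {self.image_media!r}")


def make_worker_config(env_generator, agent_generator, num_max_steps, num_episodes,
                       tmpdir=None, image_media="h264", **unused_kwargs) -> WorkerConfig:
    # Extra keys end up in unused_kwargs instead of raising a TypeError.
    return WorkerConfig(env_generator, agent_generator, num_max_steps, num_episodes,
                        tmpdir, image_media, unused_kwargs)


cfg = make_worker_config(
    env_generator=lambda: "env",
    agent_generator=lambda: "agent",
    num_max_steps=12000,
    num_episodes=2,
    tmpdir="./output",
    verbose=True,  # not a Worker argument, so it is kept in unused_kwargs
)
print(cfg.image_media, cfg.unused_kwargs)  # h264 {'verbose': True}
```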

The main methods of the Worker class are:

  • def append_image_and_info(self, info: Dict, images: List, infos: List): Processes the current step’s metadata and observation image, appending them to respective lists.

  • def save_to_file(self, images: List, actions: List, infos: List): Serializes and saves the episode’s data (images, actions, and metadata) to disk.

  • def _run(self): Core logic for episode generation. Runs the agent in the environment, collects data, and yields serialized results.

    • It resets the environment, initializes storage for actions, images and metadata, and then collects data for up to num_max_steps steps. The finished episode is saved to files and the result is yielded. This process repeats for num_episodes episodes.

  • def get_next(self): Fetches the next serialized episode from the generator.
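The relationship between _run and get_next is that of a plain Python generator: _run yields one serialized episode at a time, and get_next advances it. The sketch below is a simplified, hypothetical ToyEpisodeWorker, not MineStudio's Worker: CountingEnv and the lambda policy are stand-ins for the simulator and agent, each episode is pickled into a single file in tmpdir, and h264/jpeg image encoding is omitted. Returning None from get_next once all episodes are produced is our assumption.

```python
import os
import pickle
import tempfile
from typing import Dict, Iterator, List, Optional


class CountingEnv:
    """Stand-in environment: the observation is the step index and episodes never end early."""

    def reset(self):
        self.t = 0
        return self.t, {"step": self.t}

    def step(self, action):
        self.t += 1
        # Gymnasium-style tuple: obs, reward, terminated, truncated, info.
        return self.t, 0.0, False, False, {"step": self.t}


class ToyEpisodeWorker:
    """Hypothetical, simplified worker; not MineStudio's Worker class."""

    def __init__(self, num_max_steps: int, num_episodes: int, tmpdir: str):
        self.env = CountingEnv()
        self.agent = lambda obs: obs % 2  # stand-in policy
        self.num_max_steps = num_max_steps
        self.num_episodes = num_episodes
        self.tmpdir = tmpdir
        self.generator = self._run()  # created once, advanced by get_next

    def append_image_and_info(self, obs, info: Dict, images: List, infos: List):
        images.append(obs)
        infos.append(info)

    def save_to_file(self, idx: int, images: List, actions: List, infos: List) -> Dict:
        # Simplification: everything goes into one pickle; no h264/jpeg encoding.
        path = os.path.join(self.tmpdir, f"episode_{idx}.pkl")
        with open(path, "wb") as f:
            pickle.dump({"images": images, "actions": actions, "infos": infos}, f)
        return {"episode_path": path}

    def _run(self) -> Iterator[Dict]:
        for idx in range(self.num_episodes):
            obs, info = self.env.reset()
            actions, images, infos = [], [], []
            self.append_image_and_info(obs, info, images, infos)
            for _ in range(self.num_max_steps):
                action = self.agent(obs)
                obs, _, terminated, truncated, info = self.env.step(action)
                actions.append(action)
                self.append_image_and_info(obs, info, images, infos)
                if terminated or truncated:
                    break
            yield self.save_to_file(idx, images, actions, infos)

    def get_next(self) -> Optional[Dict]:
        # Returns None once all episodes have been produced (our assumption).
        return next(self.generator, None)


tmp = tempfile.mkdtemp()
worker = ToyEpisodeWorker(num_max_steps=5, num_episodes=2, tmpdir=tmp)
results = [worker.get_next() for _ in range(3)]
print(results[2])  # None: both episodes were already produced
```

Keeping the generator alive between get_next calls means a remote caller can pull one episode per call while the environment state stays inside the worker.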

Filter

The EpisodeFilter and its subclass InfoBaseFilter provide a framework for filtering episodes generated by a Generator. These classes allow users to apply custom filtering logic to episodes based on their metadata or other criteria.

The InfoBaseFilter is initialized with key, val, num and label. It reads the metadata for each episode, checks the condition, and adds a label to episodes that satisfy the filter:

  1. For each episode, reads its metadata from a file (episode["info_path"]).

  2. Deserializes the metadata using pickle.

  3. Checks whether the value stored under the given key and val (for example, the mine_block count for diamond_ore) is at least num.

  4. Sets the label field of each passing episode to "yes".
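The four steps can be sketched in a few lines. ToyInfoBaseFilter below is a hypothetical re-implementation, not MineStudio's source. It assumes the pickled metadata is a single dict of cumulative counters such as {"mine_block": {"diamond_ore": 2}}, and that the label defaults to "status" so the recorder described below can count it; the real file layout and default may differ. Every episode is passed through, and only the passing ones are labeled.

```python
import os
import pickle
import tempfile
from typing import Dict, Iterable, Iterator


class ToyInfoBaseFilter:
    """Hypothetical stand-in for InfoBaseFilter."""

    def __init__(self, key: str, val: str, num: int, label: str = "status"):
        self.key, self.val, self.num, self.label = key, val, num, label

    def filter(self, episodes: Iterable[Dict]) -> Iterator[Dict]:
        for episode in episodes:
            # Steps 1-2: read and unpickle the episode's metadata.
            with open(episode["info_path"], "rb") as f:
                info = pickle.load(f)
            # Step 3: compare the counter for key/val against the threshold.
            count = info.get(self.key, {}).get(self.val, 0)
            if count >= self.num:
                episode[self.label] = "yes"  # Step 4: label passing episodes
            yield episode  # failing episodes are passed on unlabeled


tmp = tempfile.mkdtemp()
episodes = []
for i, mined in enumerate([0, 1, 3]):
    path = os.path.join(tmp, f"info_{i}.pkl")
    with open(path, "wb") as f:
        pickle.dump({"mine_block": {"diamond_ore": mined}}, f)
    episodes.append({"info_path": path})

labeled = list(ToyInfoBaseFilter("mine_block", "diamond_ore", num=1).filter(episodes))
print([ep.get("status") for ep in labeled])  # [None, 'yes', 'yes']
```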

Recorder

The EpisodeRecorder class is a utility designed to process episodes generated by a Generator, record relevant statistics, and calculate metrics. It analyzes the episodes to count the number of those marked with a specific status and provides an overall summary.

The record method iterates through the episodes from the episode_generator, counting the total number of episodes and those with "status" == "yes". Finally, it calculates the percentage of episodes with "status" == "yes" and prints the results.
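A minimal version of this counting logic, written as a hypothetical ToyEpisodeRecorder rather than MineStudio's class, looks like this. Returning the summary as a dict in addition to printing it is our choice, so the numbers can be reused programmatically.

```python
from typing import Dict, Iterable


class ToyEpisodeRecorder:
    """Hypothetical stand-in for EpisodeRecorder."""

    def record(self, episodes: Iterable[Dict]) -> Dict:
        total, yes = 0, 0
        for episode in episodes:
            total += 1
            if episode.get("status") == "yes":
                yes += 1
        # Guard against division by zero when no episodes were produced.
        rate = 100.0 * yes / total if total else 0.0
        print(f"{yes}/{total} episodes with status == 'yes' ({rate:.1f}%)")
        return {"num_episodes": total, "num_yes": yes, "yes_rate": rate}


summary = ToyEpisodeRecorder().record([{"status": "yes"}, {}, {"status": "yes"}, {}])
# prints: 2/4 episodes with status == 'yes' (50.0%)
```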

Hint

The record method is a simple example of how to process episodes and calculate metrics. Users can customize the logic to suit their specific requirements.

Pipeline

The EpisodePipeline class integrates episode generation, filtering, and recording into a single, streamlined process. It provides a flexible and extensible framework for applying filters and collecting statistics on generated episodes.

It takes an episode_generator, an episode_filter and an episode_recorder as arguments. The run method executes the pipeline: it generates episodes, applies the filter, and records statistics.
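Because the generator, filter and recorder all operate on streams of episode dicts, the pipeline reduces to function composition. The sketch below is hypothetical (toy_generate, toy_filter, toy_record and ToyPipeline are not MineStudio APIs) and only illustrates how episodes flow lazily through the three stages. It assumes that an omitted stage falls back to a default, loosely mirroring the Quick Start, which does not pass an episode_recorder.

```python
from typing import Callable, Dict, Iterable, Iterator, Optional


def toy_generate(n: int) -> Iterator[Dict]:
    # Stand-in generator: fake episodes with a diamond count.
    for i in range(n):
        yield {"id": i, "diamonds": i % 3}


def toy_filter(episodes: Iterable[Dict]) -> Iterator[Dict]:
    # Stand-in filter: label episodes with at least one diamond.
    for ep in episodes:
        if ep["diamonds"] >= 1:
            ep["status"] = "yes"
        yield ep


def toy_record(episodes: Iterable[Dict]) -> Dict:
    # Stand-in recorder: count all episodes and the labeled ones.
    episodes = list(episodes)
    yes = sum(ep.get("status") == "yes" for ep in episodes)
    return {"num_episodes": len(episodes), "num_yes": yes}


class ToyPipeline:
    def __init__(self, generate: Callable[[], Iterator[Dict]],
                 filter_fn: Optional[Callable[[Iterable[Dict]], Iterator[Dict]]] = None,
                 record_fn: Callable[[Iterable[Dict]], Dict] = toy_record):
        self.generate = generate
        self.filter_fn = filter_fn or (lambda eps: eps)  # pass-through when no filter is given
        self.record_fn = record_fn

    def run(self) -> Dict:
        # Nothing is produced until the recorder consumes the chained iterators.
        return self.record_fn(self.filter_fn(self.generate()))


print(ToyPipeline(lambda: toy_generate(6), toy_filter).run())
# {'num_episodes': 6, 'num_yes': 4}
```

Chaining generators this way means each episode is filtered and counted as soon as a worker finishes it, rather than after the whole run.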

For more examples, please refer to inference examples with provided agents.