# Inference
We provide a Ray-based inference framework for MineStudio to support parallel and distributed inference. The framework consists of three parts: `generator`, `filter`, and `recorder`, which together form an inference pipeline for easily evaluating the performance of different agents.
> **Note:** We highly recommend reading the Ray documentation before using the inference framework.
## Quick Start
Here is a minimal example of how to use the inference framework:
```python
import ray
from functools import partial

from minestudio.inference import EpisodePipeline, MineGenerator, InfoBaseFilter
from minestudio.models import VPTPolicy
from minestudio.simulator import MinecraftSim

if __name__ == '__main__':
    ray.init()
    env_generator = partial(
        MinecraftSim,
        obs_size=(128, 128),
        preferred_spawn_biome="forest",
    )  # generate the environment
    agent_generator = lambda: VPTPolicy.from_pretrained(
        "CraftJarvis/MineStudio_VPT.rl_from_early_game_2x"
    )  # generate the agent
    worker_kwargs = dict(
        env_generator=env_generator,
        agent_generator=agent_generator,
        num_max_steps=12000,  # the maximum number of steps per episode
        num_episodes=2,       # the number of episodes for each worker
        tmpdir="./output",
        image_media="h264",
    )
    pipeline = EpisodePipeline(
        episode_generator=MineGenerator(
            num_workers=8,    # the number of workers
            num_gpus=0.25,    # the fraction of a GPU allocated to each worker
            max_restarts=3,   # the maximum number of restarts for failed workers
            **worker_kwargs,
        ),
        episode_filter=InfoBaseFilter(
            key="mine_block",
            val="diamond_ore",
            num=1,
        ),  # label episodes that mine at least one diamond_ore
    )
    summary = pipeline.run()
    print(summary)
```
## Basic Components
### Generator
The `MineGenerator` class is a subclass of `EpisodeGenerator` designed to manage distributed workers that generate episodes in parallel using the `ray` library. It provides a framework for efficiently utilizing multiple workers and GPUs to produce episodes in a scalable fashion.
| Argument | Default | Description |
|---|---|---|
| `num_workers` | 1 | Number of workers to create for parallel processing. |
| `num_gpus` | 0.5 | Amount of GPU resources allocated to each worker. A fractional value allows sharing a GPU between workers. |
| `max_restarts` | 3 | Maximum number of restarts allowed for a worker in case of failure. |
| `worker_kwargs` | | Additional keyword arguments to pass to the worker initialization. |
The `generate` method creates a generator that yields episodes produced by the workers. The method manages task assignment and result collection using Ray. The workflow is as follows:

1. Initializes a pool of workers with `ray.remote`.
2. Assigns tasks to workers and collects results.
3. Waits for workers to complete tasks, retrieves completed episodes, and reassigns new tasks.
4. Continues yielding episodes until all tasks are completed.
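The assign-and-refill loop above is a standard worker-pool pattern. Below is a minimal sketch of the same scheduling idea, using Python's `concurrent.futures` in place of `ray` so it runs without a cluster; `rollout` and the bookkeeping names are illustrative stand-ins, not the actual `MineGenerator` internals:

```python
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

def rollout(worker_id: int, task_id: int) -> str:
    # Stand-in for a worker producing one episode per assigned task.
    return f"episode-{task_id}-from-worker-{worker_id}"

def generate(num_workers: int, num_tasks: int):
    # Assign one task per worker up front, then refill workers as tasks
    # finish, yielding each completed episode -- the same schedule that
    # MineGenerator implements with ray.remote / ray.wait.
    next_task = 0
    pending = {}  # future -> worker_id
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        for worker_id in range(min(num_workers, num_tasks)):
            pending[pool.submit(rollout, worker_id, next_task)] = worker_id
            next_task += 1
        while pending:
            done, _ = wait(pending, return_when=FIRST_COMPLETED)
            for future in done:
                worker_id = pending.pop(future)
                yield future.result()
                if next_task < num_tasks:  # reassign a task to the idle worker
                    pending[pool.submit(rollout, worker_id, next_task)] = worker_id
                    next_task += 1

episodes = list(generate(num_workers=2, num_tasks=5))
print(len(episodes))  # 5
```

With `ray`, the same loop is expressed by submitting tasks to `ray.remote` workers and using `ray.wait` to pick off completed ones.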
> **Hint:**
> - **Fault tolerance:** Workers are configured with a `max_restarts` parameter to handle failures gracefully, ensuring system robustness.
> - **Scalability:** The design supports scalable processing across multiple GPUs, making it suitable for computationally intensive tasks.
Each worker initialized by `MineGenerator` is an instance of the `Worker` class.
The `Worker` class handles the generation of episodes in a simulation environment. It integrates an agent and an environment to collect data, process observations and actions, and store the results in serialized files. It supports saving image-based observations in multiple formats, making it adaptable to different data pipelines.

Here are the arguments for the `Worker` class:
| Argument | Default | Description |
|---|---|---|
| `env_generator` | | A callable that creates the simulation environment. |
| `agent_generator` | | A callable that creates the agent. The agent must support a … |
| `num_max_steps` | | The maximum number of steps per episode. |
| `num_episodes` | | The total number of episodes to generate. |
| `tmpdir` | None | Directory to store serialized output files. |
| `image_media` | "h264" | Format for saving image observations: "h264" or "jpeg". |
| `unused_kwargs` | | Additional unused arguments, allowing flexibility in API compatibility. |
We list the main methods of the `Worker` class below:

- `append_image_and_info(self, info: Dict, images: List, infos: List)`: processes the current step's metadata and observation image, appending them to the respective lists.
- `save_to_file(self, images: List, actions: List, infos: List)`: serializes and saves the episode's data (images, actions, and metadata) to disk.
- `_run(self)`: core logic for episode generation; runs the agent in the environment, collects data, and yields serialized results. It resets the environment, initializes storage for actions, images, and metadata, collects data for `num_max_steps` steps per episode, then saves the episode to files and yields the result. The process is repeated for `num_episodes` episodes.
- `get_next(self)`: fetches the next serialized episode from the generator.
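The `_run` loop described above can be sketched as follows. This is a simplified, self-contained stand-in with dummy environment and agent classes; the file layout, the `get_action` interface, and the helper names are assumptions for illustration, not the actual `Worker` implementation (which additionally encodes images as h264 or jpeg):

```python
import pickle
import tempfile
from pathlib import Path

class DummyEnv:
    # Minimal stand-in for the simulation environment.
    def reset(self):
        return 0
    def step(self, action):
        return action + 1, {"mine_block": {}}  # (observation, info)

class DummyAgent:
    # Minimal stand-in for the agent; real agents act on image observations.
    def get_action(self, obs):
        return obs

def run_episodes(env, agent, num_episodes: int, num_max_steps: int, tmpdir: str):
    # Simplified _run: reset, collect num_max_steps steps, serialize the
    # episode, yield its result, and repeat for num_episodes episodes.
    for episode in range(num_episodes):
        obs = env.reset()
        images, actions, infos = [], [], []
        for _ in range(num_max_steps):
            action = agent.get_action(obs)
            obs, info = env.step(action)
            images.append(obs)
            actions.append(action)
            infos.append(info)
        info_path = Path(tmpdir) / f"episode_{episode}.pkl"
        with open(info_path, "wb") as f:
            # The real Worker stores images separately (h264 video or jpeg).
            pickle.dump({"actions": actions, "infos": infos}, f)
        yield {"episode": episode, "info_path": str(info_path)}

with tempfile.TemporaryDirectory() as d:
    results = list(run_episodes(DummyEnv(), DummyAgent(),
                                num_episodes=2, num_max_steps=3, tmpdir=d))
    print([r["episode"] for r in results])  # [0, 1]
```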
### Filter
The `EpisodeFilter` class and its subclass `InfoBaseFilter` provide a framework for filtering episodes generated by a generator. These classes allow users to apply custom filtering logic to episodes based on their metadata or other criteria.
The `InfoBaseFilter` is initialized with `key`, `val`, `num`, and `label`. It reads the metadata of each episode, checks the condition, and adds a label to episodes that satisfy the filter:

1. For each episode, reads its metadata from a file (`episode["info_path"]`).
2. Deserializes the metadata using `pickle`.
3. Checks whether the value associated with the given `key` and `val` meets or exceeds the `num` threshold.
4. Adds the `label` with the value `"yes"` to episodes that pass the filter.
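That logic can be sketched as a small stand-alone function. The metadata layout (a pickled list of info dicts whose last entry holds cumulative counters) and the default `label="status"` are assumptions chosen to match the description here, not the exact `InfoBaseFilter` code:

```python
import pickle
import tempfile
from pathlib import Path

def filter_episodes(episodes, key, val, num, label="status"):
    # Load each episode's pickled metadata, check whether the cumulative
    # count for (key, val) reaches the num threshold, and label passers.
    for episode in episodes:
        with open(episode["info_path"], "rb") as f:
            infos = pickle.load(f)
        # Assumed layout: the last info dict holds cumulative counters,
        # e.g. {"mine_block": {"diamond_ore": 2, ...}}.
        count = infos[-1].get(key, {}).get(val, 0)
        if count >= num:
            episode[label] = "yes"
        yield episode

# Usage with a synthetic metadata file:
with tempfile.TemporaryDirectory() as d:
    path = Path(d) / "episode_0.pkl"
    with open(path, "wb") as f:
        pickle.dump([{"mine_block": {"diamond_ore": 2}}], f)
    labeled = list(filter_episodes([{"info_path": str(path)}],
                                   key="mine_block", val="diamond_ore", num=1))
    print(labeled[0]["status"])  # yes
```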
### Recorder
The `EpisodeRecorder` class is a utility designed to process episodes generated by a generator, record relevant statistics, and calculate metrics. It analyzes the episodes to count those marked with a specific status and provides an overall summary.

The `record` method iterates through the episodes in the `episode_generator`, counting the total number of episodes and those with `"status" == "yes"`. Finally, it calculates the percentage of episodes with `"status" == "yes"` and prints the results.
> **Hint:** The `record` method is a simple example of how to process episodes and calculate metrics. Users can customize the logic to suit their specific requirements.
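A `record` method along these lines can be sketched as follows (an illustrative reimplementation of the counting logic described above, not the actual `EpisodeRecorder` code):

```python
def record(episodes):
    # Count all episodes and those labeled "status" == "yes",
    # then report the success rate.
    total = 0
    positive = 0
    for episode in episodes:
        total += 1
        if episode.get("status") == "yes":
            positive += 1
    ratio = positive / total if total else 0.0
    print(f"{positive}/{total} episodes passed ({ratio:.1%})")
    return {"total": total, "yes": positive, "ratio": ratio}

summary = record([{"status": "yes"}, {"status": "yes"}, {}])
# prints: 2/3 episodes passed (66.7%)
```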
### Pipeline
The `EpisodePipeline` class integrates episode generation, filtering, and recording into a single, streamlined process. It provides a flexible and extensible framework for applying filters and collecting statistics on generated episodes.

It takes an `episode_generator`, an `episode_filter`, and an `episode_recorder` as arguments. The `run` method executes the pipeline: it generates episodes, applies the filters, and records statistics.
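The overall control flow can be sketched as a three-stage chain. The stage implementations below are toy stand-ins (the real pipeline works on serialized episodes produced by the workers), but the structure mirrors what `run` does:

```python
class SimplePipeline:
    # Chains generation -> filtering -> recording, mirroring EpisodePipeline.
    def __init__(self, episode_generator, episode_filter, episode_recorder):
        self.episode_generator = episode_generator
        self.episode_filter = episode_filter
        self.episode_recorder = episode_recorder

    def run(self):
        episodes = self.episode_generator()        # yield raw episodes
        episodes = self.episode_filter(episodes)   # label qualifying episodes
        return self.episode_recorder(episodes)     # aggregate statistics

def toy_generator():
    yield {"id": 0, "diamonds": 2}
    yield {"id": 1, "diamonds": 0}

def toy_filter(episodes):
    for episode in episodes:
        if episode["diamonds"] >= 1:
            episode["status"] = "yes"
        yield episode

def toy_recorder(episodes):
    episodes = list(episodes)
    return {"total": len(episodes),
            "yes": sum(e.get("status") == "yes" for e in episodes)}

summary = SimplePipeline(toy_generator, toy_filter, toy_recorder).run()
print(summary)  # {'total': 2, 'yes': 1}
```

Because each stage is a generator, episodes stream through the pipeline one at a time rather than being materialized all at once.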
For more examples, please refer to the inference examples with the provided agents.