Source code for minestudio.inference.pipeline
'''
Date: 2024-11-25 07:29:21
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
LastEditTime: 2024-11-25 12:40:22
FilePath: /MineStudio/minestudio/inference/pipeline.py
'''
from typing import List, Optional, Sequence, Union

import ray

from minestudio.inference.filter.base_filter import EpisodeFilter
from minestudio.inference.generator.base_generator import EpisodeGenerator
from minestudio.inference.recorder.base_recorder import EpisodeRecorder
[docs]
class EpisodePipeline:
    """
    A pipeline for generating, filtering, and recording episodes.

    ``run`` obtains the stream from ``episode_generator.generate()``, passes it
    through each filter's ``filter`` method in order, and hands the result to
    ``episode_recorder.record``.
    """

    def __init__(
        self,
        episode_generator: EpisodeGenerator,
        episode_filter: Optional[Union[EpisodeFilter, Sequence[EpisodeFilter]]] = None,
        episode_recorder: Optional[EpisodeRecorder] = None,
    ):
        """
        Initializes the EpisodePipeline.

        :param episode_generator: An instance of EpisodeGenerator.
        :type episode_generator: EpisodeGenerator
        :param episode_filter: A single EpisodeFilter, or a list/tuple of
            EpisodeFilter instances applied in the given order. Defaults to
            None, which uses one default ``EpisodeFilter()``. An empty
            list/tuple means no filtering is applied.
        :type episode_filter: Optional[Union[EpisodeFilter, Sequence[EpisodeFilter]]]
        :param episode_recorder: An instance of EpisodeRecorder. Defaults to
            None, which uses a default ``EpisodeRecorder()``.
        :type episode_recorder: Optional[EpisodeRecorder]
        """
        if episode_filter is None:
            episode_filter = EpisodeFilter()
        if episode_recorder is None:
            episode_recorder = EpisodeRecorder()
        # Unpack tuples as well as lists. Otherwise a tuple would be wrapped as
        # a single "filter", and run() would fail calling .filter on it. Copy
        # into a fresh list so later mutation of the caller's sequence cannot
        # silently change this pipeline.
        if isinstance(episode_filter, (list, tuple)):
            self.episode_filter = list(episode_filter)
        else:
            self.episode_filter = [episode_filter]
        self.episode_generator = episode_generator
        self.episode_recorder = episode_recorder

    def run(self):
        """
        Runs the episode pipeline.

        Obtains the episode stream from the generator, wraps it with each
        filter in order (each filter receives the previous filter's output),
        and passes the final stream to the recorder.

        :returns: Whatever ``episode_recorder.record`` returns (a summary of
            the recorded episodes).
        :rtype: Any
        """
        stream = self.episode_generator.generate()
        for episode_filter in self.episode_filter:
            stream = episode_filter.filter(stream)
        return self.episode_recorder.record(stream)