Source code for minestudio.online.rollout.replay_buffer.interface
import ray
from omegaconf import DictConfig
from minestudio.online.rollout.replay_buffer.fragment_store import FragmentStore
from minestudio.online.rollout.replay_buffer.actor import ReplayBufferActor
from minestudio.online.utils.rollout.datatypes import FragmentMetadata
from minestudio.online.utils.rollout.datatypes import FragmentIndex, SampleFragment
from typing import List, Optional, Tuple
[docs]
class ReplayBufferInterface:
"""
Provides an interface to interact with the ReplayBufferActor.
This class handles the creation or connection to a ReplayBufferActor named "replay_buffer".
It also initializes a FragmentStore based on the actor's database configuration.
All methods to interact with the replay buffer (add, load, fetch fragments, update model version)
are routed through the ReplayBufferActor.
:param config: Optional DictConfig. If provided, a new ReplayBufferActor is created with this config.
If None, it attempts to connect to an existing actor named "replay_buffer".
:raises ValueError: If config is provided but an actor already exists, or if config is None and no actor exists.
"""
def __init__(self, config: Optional[DictConfig] = None):
existing_actor = None
try:
existing_actor = ray.get_actor("replay_buffer")
except ValueError:
pass
if config is not None:
if existing_actor is not None:
raise ValueError("Replay buffer already initialized")
self.actor = ReplayBufferActor.options(name="replay_buffer").remote(** config) # type: ignore
else:
if existing_actor is None:
raise ValueError("Replay buffer not initialized")
self.actor = existing_actor
self.database_config = ray.get(self.actor.get_database_config.remote())
self.store = FragmentStore(** self.database_config)
[docs]
def update_training_session(self):
"""
Calls the update_training_session method of the ReplayBufferActor.
:returns: The result of the actor's method call.
"""
return ray.get(self.actor.update_training_session.remote())
[docs]
def add_fragment(self, fragment: SampleFragment, metadata: FragmentMetadata):
"""
Adds a fragment to the FragmentStore and then informs the ReplayBufferActor.
:param fragment: The SampleFragment to add.
:param metadata: The FragmentMetadata associated with the fragment.
"""
fragment_id = self.store.add_fragment(fragment)
ray.get(
self.actor.add_fragment.remote(
fragment_id=fragment_id,
metadata=metadata,
)
)
[docs]
def load_fragment(self, fragment_id: str) -> SampleFragment:
"""
Loads a fragment directly from the FragmentStore.
:param fragment_id: The unique ID of the fragment to load.
:returns: The loaded SampleFragment.
"""
return self.store.get_fragment(fragment_id)
[docs]
def fetch_fragments(self, num_fragments: int) -> List[Tuple[FragmentIndex, str]]:
"""
Fetches a list of fragment IDs and their indices from the ReplayBufferActor.
:param num_fragments: The number of fragments to fetch.
:returns: A list of tuples, each containing a FragmentIndex and the fragment_id.
"""
return ray.get(
self.actor.fetch_fragments.remote(num_fragments=num_fragments)
)
[docs]
def prepared_fragments(self) -> List[Tuple[FragmentIndex, str]]:
"""
Retrieves the fragments that were prepared by the last call to fetch_fragments in the ReplayBufferActor.
:returns: A list of tuples, each containing a FragmentIndex and the fragment_id.
"""
return ray.get(
self.actor.prepared_fragments.remote()
)
[docs]
def update_model_version(self, session_id: str, model_version: int):
"""
Updates the model version in the ReplayBufferActor.
:param session_id: The ID of the current training session.
:param model_version: The new model version.
:returns: The result of the actor's method call.
"""
return ray.get(
self.actor.update_model_version.remote(
session_id=session_id,
model_version=model_version
)
)