Source code for minestudio.online.rollout.replay_buffer.fragment_store

'''
Date: 2025-05-20 12:09:48
LastEditors: caishaofei-mus1 1744260356@qq.com
LastEditTime: 2025-05-23 11:39:33
FilePath: /MineStudio/var/minestudio/online/rollout/replay_buffer/fragment_store.py
'''
import ray
import logging
from diskcache import FanoutCache
from uuid import uuid4
from minestudio.online.utils.rollout.datatypes import SampleFragment

[docs] class LocalFragmentStoreImpl: """ A local implementation of a fragment store using diskcache.FanoutCache. This class provides methods to add, get, delete, and clear fragments stored on the local disk. :param path: The directory path where the cache will be stored. :param num_shards: The number of shards to use for the FanoutCache. """ def __init__(self, path: str, num_shards: int, ): self.cache = FanoutCache(path, shards=num_shards, eviction_policy="none")
[docs] def add_fragment(self, fragment: SampleFragment): """ Adds a fragment to the store and returns a unique ID for it. :param fragment: The SampleFragment object to store. :returns: A unique string ID for the stored fragment. """ fragment_uuid = str(uuid4()) self.cache[fragment_uuid] = fragment return fragment_uuid
[docs] def get_fragment(self, fragment_uuid: str): """ Retrieves a fragment from the store by its unique ID. :param fragment_uuid: The unique ID of the fragment to retrieve. :returns: The retrieved SampleFragment object. """ return self.cache[fragment_uuid]
[docs] def delete_fragment(self, fragment_uuid: str): """ Deletes a fragment from the store by its unique ID. :param fragment_uuid: The unique ID of the fragment to delete. """ del self.cache[fragment_uuid]
[docs] def clear(self): """ Removes all fragments from the store. """ self.cache.clear()
[docs] def get_disk_space(self): """ Gets the total disk space used by the cache in bytes. :returns: The disk space used by the cache. """ return self.cache.volume()
@ray.remote(resources={"database": 0.0001}) class RemoteFragmentStoreImpl: """ A Ray actor that wraps LocalFragmentStoreImpl to provide a remote fragment store. This allows the fragment store to be accessed from different nodes in a Ray cluster. It delegates all its methods to an instance of LocalFragmentStoreImpl. :param kwargs: Keyword arguments to be passed to the LocalFragmentStoreImpl constructor. """ def __init__(self, **kwargs): self.local_impl = LocalFragmentStoreImpl(**kwargs) def add_fragment(self, fragment: SampleFragment): """ Adds a fragment to the remote store. :param fragment: The SampleFragment object to store. :returns: A unique string ID for the stored fragment. """ return self.local_impl.add_fragment(fragment) def get_fragment(self, fragment_uuid: str): """ Retrieves a fragment from the remote store by its unique ID. :param fragment_uuid: The unique ID of the fragment to retrieve. :returns: The retrieved SampleFragment object. """ return self.local_impl.get_fragment(fragment_uuid) def delete_fragment(self, fragment_uuid: str): """ Deletes a fragment from the remote store by its unique ID. :param fragment_uuid: The unique ID of the fragment to delete. """ return self.local_impl.delete_fragment(fragment_uuid) def clear(self): """ Removes all fragments from the remote store. """ return self.local_impl.clear() def get_disk_space(self): """ Gets the total disk space used by the remote cache in bytes. :returns: The disk space used by the cache. """ return self.local_impl.get_disk_space()
[docs] class FragmentStore: """ A class that provides an interface to either a local or a remote fragment store. It checks if the current Ray node has a "database" resource. If so, it uses a LocalFragmentStoreImpl. Otherwise, it uses a RemoteFragmentStoreImpl actor. :param kwargs: Keyword arguments to be passed to the underlying store implementation (LocalFragmentStoreImpl or RemoteFragmentStoreImpl). :raises AssertionError: if the local status cannot be determined. """ def __init__(self, **kwargs): self.node_id = ray.get_runtime_context().get_node_id() self.local = None for node in ray.nodes(): if node["NodeID"] == self.node_id: resources = node["Resources"] if resources.get("database", 0) > 0: self.local = True else: logging.warn("Remote fragment store has not been tested yet") self.local = False break assert self.local is not None if not self.local: self.remote_impl = RemoteFragmentStoreImpl.options( placement_group=None, resources={"database": 0.0001} ).remote(**kwargs) # type: ignore else: self.local_impl = LocalFragmentStoreImpl(**kwargs)
[docs] def add_fragment(self, fragment: SampleFragment): """ Adds a fragment to the store (either local or remote). :param fragment: The SampleFragment object to store. :returns: A unique string ID for the stored fragment. """ if self.local: return self.local_impl.add_fragment(fragment) else: return ray.get(self.remote_impl.add_fragment.remote(fragment)) # type: ignore
[docs] def get_fragment(self, fragment_uuid: str) -> SampleFragment: """ Retrieves a fragment from the store (either local or remote) by its unique ID. :param fragment_uuid: The unique ID of the fragment to retrieve. :returns: The retrieved SampleFragment object. """ if self.local: return self.local_impl.get_fragment(fragment_uuid) # type: ignore else: return ray.get(self.remote_impl.get_fragment.remote(fragment_uuid)) # type: ignore
[docs] def delete_fragment(self, fragment_uuid: str): """ Deletes a fragment from the store (either local or remote) by its unique ID. :param fragment_uuid: The unique ID of the fragment to delete. """ if self.local: return self.local_impl.delete_fragment(fragment_uuid) else: return ray.get(self.remote_impl.delete_fragment.remote(fragment_uuid)) # type: ignore
[docs] def clear(self): """ Removes all fragments from the store (either local or remote). """ if self.local: return self.local_impl.clear() else: return ray.get(self.remote_impl.clear.remote()) # type: ignore
[docs] def get_disk_space(self): """ Gets the total disk space used by the cache (either local or remote) in bytes. :returns: The disk space used by the cache. """ if self.local: return self.local_impl.get_disk_space() else: return ray.get(self.remote_impl.get_disk_space.remote()) # type: ignore