'''
Date: 2024-11-10 12:27:01
LastEditors: caishaofei-mus1 1744260356@qq.com
LastEditTime: 2025-01-15 15:04:05
FilePath: /MineStudio/minestudio/data/minecraft/tools/convertion.py
'''
import ray
import time
import pickle
import ray.experimental.tqdm_ray as tqdm_ray
import lmdb
import shutil
from rich import print
from pathlib import Path
from typing import Optional, Union, Tuple, List, Dict, Any
from collections import OrderedDict
from minestudio.data.minecraft.callbacks import ModalConvertCallback


@ray.remote(num_cpus=1)
class ConvertWorker:
"""
A Ray remote actor for converting and writing data chunks to LMDB.
This worker processes a subset of episodes, converts their data using
the provided kernels, and writes the results to an LMDB database.
It also handles progress reporting via a remote tqdm instance.
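
    Example (illustrative sketch, not part of the original code; assumes Ray is
    initialized, ``my_kernel`` is a configured ``ModalConvertCallback``, and
    ``remote_tqdm`` is a ``tqdm_ray.tqdm`` actor)::

        worker = ConvertWorker.remote(
            write_path="./lmdb-out/part-0",
            convert_kernel=my_kernel,
            tasks={"episode-0": [(0, Path("episode-0.mp4"))]},
            chunk_size=32,
            remote_tqdm=remote_tqdm,
        )
        meta_info = ray.get(worker.run.remote())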
"""
def __init__(
self,
write_path: Union[str, Path],
convert_kernel: ModalConvertCallback,
tasks: Dict,
chunk_size: int,
remote_tqdm: Any,
thread_pool: int = 8,
filter_kernel: Optional[ModalConvertCallback]=None,
) -> None:
"""
Initialize the ConvertWorker.
:param write_path: The path to the output LMDB directory.
:type write_path: Union[str, Path]
:param convert_kernel: The kernel for converting modal data.
:type convert_kernel: ModalConvertCallback
:param tasks: A dictionary of tasks (episodes and their parts) to process.
:type tasks: Dict
:param chunk_size: The size of data chunks.
:type chunk_size: int
:param remote_tqdm: A Ray remote tqdm instance for progress tracking.
:type remote_tqdm: Any
:param thread_pool: The number of threads for parallel processing (currently unused).
:type thread_pool: int, optional
:param filter_kernel: An optional kernel for filtering data before conversion.
:type filter_kernel: Optional[ModalConvertCallback], optional
"""
self.tasks = tasks
self.write_path = write_path
self.chunk_size = chunk_size
self.remote_tqdm = remote_tqdm
self.thread_pool = thread_pool
self.convert_kernel = convert_kernel
self.filter_kernel = filter_kernel
if isinstance(write_path, str):
write_path = Path(write_path)
if write_path.is_dir():
print(f"Write path {write_path} exists, delete it. ")
shutil.rmtree(write_path)
write_path.mkdir(parents=True)
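        # map_size is an upper bound on the database size (here 1 TiB); on most
        # platforms LMDB allocates the file sparsely, so this does not reserve
        # disk space upfront.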
self.lmdb_handler = lmdb.open(str(write_path), map_size=1<<40)

    def run(self):
"""
Execute the conversion process for the assigned tasks.
Iterates through each episode, converts its data, and writes the
resulting chunks to the LMDB database. Also stores metadata about
the conversion process.
:returns: A dictionary containing metadata about the conversion,
including chunk information, number of episodes, and total frames.
:rtype: Dict
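
        The returned dictionary mirrors the metadata written to the LMDB shard;
        values below are illustrative::

            {
                "__chunk_size__": 32,
                "__chunk_infos__": [
                    {"episode": "episode-0", "episode_idx": 0, "num_frames": 128},
                ],
                "__num_episodes__": 1,
                "__num_total_frames__": 128,
            }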
"""
chunk_infos = []
num_total_frames = 0
eps_idx = 0
for eps, parts in self.tasks.items():
            eps_keys, eps_vals, cost = self.convert(eps=eps, parts=parts)
num_eps_frames = len(eps_keys) * self.chunk_size
            if num_eps_frames == 0:
                # Empty episode: still tick the shared progress bar before skipping,
                # otherwise the bar never reaches its total.
                self.remote_tqdm.update.remote(1)
                continue
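            # Each chunk is stored under the stringified (episode_idx, chunk_key)
            # tuple as its LMDB key.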
for key, val in zip(eps_keys, eps_vals):
with self.lmdb_handler.begin(write=True) as txn:
                    lmdb_key = str((eps_idx, key))
                    txn.put(lmdb_key.encode(), val)
chunk_info = {
"episode": eps,
"episode_idx": eps_idx,
"num_frames": num_eps_frames,
}
chunk_infos.append(chunk_info)
num_total_frames += num_eps_frames
eps_idx += 1
self.remote_tqdm.update.remote(1)
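        # Shard-level metadata; the dunder-style keys keep them distinct from the
        # stringified (episode_idx, chunk_key) data keys.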
meta_info = {
"__chunk_size__": self.chunk_size,
"__chunk_infos__": chunk_infos,
"__num_episodes__": eps_idx,
"__num_total_frames__": num_total_frames,
}
with self.lmdb_handler.begin(write=True) as txn:
for key, val in meta_info.items():
txn.put(key.encode(), pickle.dumps(val))
print(f"Worker finish: {self.write_path}. ")
return meta_info

    def convert(self, eps: str, parts: List[Tuple[int, Path]]) -> Tuple[List, List, float]:
"""
Convert data for a single episode.
Uses the convert_kernel to process the episode parts. If a filter_kernel
is provided, it generates skip flags for frames before conversion.
Measures and prints the time taken for conversion and the size of the output.
:param eps: The ID of the episode to convert.
:type eps: str
        :param parts: A list of ``(part_ord, file_path)`` tuples, one per part of
            the episode, as produced by :meth:`ConvertManager.prepare_tasks`.
        :type parts: List[Tuple[int, Path]]
:returns: A tuple containing:
- keys: A list of keys for the converted data chunks.
- vals: A list of the converted data chunks (pickled).
- cost: The time taken for conversion in seconds.
:rtype: Tuple[List, List, float]
"""
time_start = time.time()
skip_frames = []
modal_file_path = []
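        # Collect per-part file paths and, when a filter kernel is configured,
        # the per-part frame skip flags (None leaves a part unfiltered).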
        for part in parts:
            file_path = part[1]
            modal_file_path.append(file_path)
            if self.filter_kernel is not None:
                skip_frames.append(self.filter_kernel.gen_frame_skip_flags(file_path.stem))
            else:
                skip_frames.append(None)
keys, vals = self.convert_kernel.do_convert(eps, skip_frames, modal_file_path)
cost = time.time() - time_start
print(f"episode: {eps}, chunks: {len(keys)}, frames: {len(keys) * self.chunk_size}, "
f"size: {sum(len(x) for x in vals) / (1024*1024):.2f} MB, cost: {cost:.2f} sec")
return keys, vals, cost


class ConvertManager:
"""
Manages the overall data conversion process using multiple ConvertWorker actors.
This class is responsible for preparing tasks (episodes and their parts),
dispatching these tasks to ConvertWorker instances, and collecting the results.
It supports filtering of episodes and parts based on a filter_kernel.
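
    Example (illustrative sketch; assumes ``ray.init()`` has been called and
    ``my_kernel`` is a configured ``ModalConvertCallback``)::

        manager = ConvertManager(
            output_dir="./lmdb-out",
            convert_kernel=my_kernel,
            chunk_size=32,
            num_workers=16,
        )
        manager.prepare_tasks()
        manager.dispatch()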
"""
def __init__(
self,
output_dir: str,
convert_kernel: ModalConvertCallback,
filter_kernel: Optional[ModalConvertCallback]=None,
chunk_size: int=32,
num_workers: int=16,
) -> None:
"""
Initialize the ConvertManager.
:param output_dir: The root directory for storing the output LMDB files.
:type output_dir: str
:param convert_kernel: The kernel used for converting modal data.
:type convert_kernel: ModalConvertCallback
:param filter_kernel: An optional kernel for filtering data before conversion.
:type filter_kernel: Optional[ModalConvertCallback], optional
:param chunk_size: The size of data chunks.
:type chunk_size: int, optional
:param num_workers: The number of ConvertWorker actors to use for parallel processing.
:type num_workers: int, optional
"""
self.output_dir = output_dir
self.convert_kernel = convert_kernel
self.filter_kernel = filter_kernel
self.chunk_size = chunk_size
self.num_workers = num_workers
def prepare_tasks(self):
"""
Prepare the tasks (episodes and their parts) for conversion.
Loads episodes using the convert_kernel and, if provided, the filter_kernel.
Filters out episodes or parts of episodes that do not meet the criteria
defined by the filter_kernel.
The prepared tasks are stored in `self.loaded_episodes`.
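
        Resulting structure (file names illustrative)::

            OrderedDict({
                "episode-0": [(0, Path("episode-0.mp4"))],
            })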
"""
source_episodes = self.convert_kernel.load_episodes()
if self.filter_kernel is not None:
filter_episodes = self.filter_kernel.load_episodes()
loaded_episodes = OrderedDict()
num_removed_parts = 0
for eps, source_parts in source_episodes.items():
# 1. check if the episode is in the filter list
if self.filter_kernel is not None and eps not in filter_episodes:
num_removed_parts += len(source_parts)
continue
            for part_ord, source_path in source_parts:
                # 2. check if this part is kept by the filter kernel
                if self.filter_kernel is not None:
                    intersection = [part for part in filter_episodes[eps] if part[0] == part_ord]
                    if len(intersection) == 0:
                        num_removed_parts += 1
                        continue
                if eps not in loaded_episodes:
                    loaded_episodes[eps] = []
                loaded_episodes[eps].append((part_ord, source_path))
self.loaded_episodes = loaded_episodes
print(f"[ConvertManager] num of removed episode parts: {num_removed_parts}")
def dispatch(self):
"""
Dispatch the prepared tasks to ConvertWorker actors for processing.
Divides the loaded episodes among the specified number of workers.
Each worker processes its assigned episodes and writes the output to
a separate LMDB file. Collects and prints summary statistics after
all workers have completed.
"""
sub_tasks = OrderedDict()
workers = []
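        # Wrap tqdm_ray.tqdm in a Ray actor so every worker can report progress
        # to one shared bar.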
remote_tqdm = ray.remote(tqdm_ray.tqdm).remote(total=len(self.loaded_episodes))
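        # Ceiling division: split episodes evenly so at most num_workers shards
        # (one LMDB directory per worker) are created.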
num_episodes_per_file = (len(self.loaded_episodes) + self.num_workers - 1) // self.num_workers
for idx, (eps, parts) in enumerate(self.loaded_episodes.items()):
sub_tasks[eps] = parts
if (idx + 1) % num_episodes_per_file == 0 or (idx + 1) == len(self.loaded_episodes):
write_path = Path(self.output_dir) / f"part-{idx+1}"
worker = ConvertWorker.remote(
write_path=write_path,
convert_kernel=self.convert_kernel,
tasks=sub_tasks,
chunk_size=self.chunk_size,
remote_tqdm=remote_tqdm,
filter_kernel=self.filter_kernel,
)
workers.append(worker)
sub_tasks = OrderedDict()
        results = ray.get([worker.run.remote() for worker in workers])
        num_frames = sum(result['__num_total_frames__'] for result in results)
        num_episodes = sum(result['__num_episodes__'] for result in results)
ray.kill(remote_tqdm)
print(f"Total frames: {num_frames}, Total episodes: {num_episodes}")