# Source code for minestudio.simulator.callbacks.point

'''
Date: 2024-11-18 20:37:50
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
LastEditTime: 2024-11-24 08:23:45
FilePath: /MineStudio/minestudio/simulator/callbacks/point.py
'''

from minestudio.simulator.callbacks.callback import MinecraftCallback
from minestudio.simulator.utils import MinecraftGUI, GUIConstants
from minestudio.simulator.utils.gui import PointDrawCall, SegmentDrawCall, MultiPointDrawCall

import time
from typing import Dict, Literal, Optional, Callable, Tuple, List, Any
from rich import print
import numpy as np
import cv2
import os


class PointCallback(MinecraftCallback):
    """Allows the player to select a point on the screen using the mouse.

    When activated (default: by pressing 'P'), this callback opens a GUI window
    displaying the current game view. The player can click on this window to
    select a 2D point. The selected point's coordinates are stored in
    ``info['point']``.
    """

    def __init__(self):
        """Initializes the PointCallback."""
        super().__init__()

    def after_reset(self, sim, obs: Dict, info: Dict) -> Tuple[Dict, Dict]:
        """Adds a message to the simulator to inform the user about the activation key.

        :param sim: The simulator instance.
        :param obs: The initial observation.
        :param info: The initial info dictionary.
        :returns: The passed `obs` and `info`.
        :rtype: Tuple[Dict, Dict]
        """
        sim.callback_messages.add("Press 'P' to start pointing.")
        return obs, info

    def after_step(self, sim, obs: Dict, reward: float, terminated: bool, truncated: bool, info: Dict) -> Tuple[Dict, float, bool, bool, Dict]:
        """Handles the point selection process if activated.

        If the 'P' key is pressed (i.e., ``info.get('P', False)`` is True):

        1. Opens a new GUI window.
        2. Enters a loop to capture the mouse position for point selection.
        3. Updates ``info['point']`` with the current mouse coordinates.
        4. Closes the GUI when 'ESCAPE' is released.

        :param sim: The simulator instance.
        :param obs: The observation after the step.
        :param reward: The reward received.
        :param terminated: Whether the episode has terminated.
        :param truncated: Whether the episode has been truncated.
        :param info: The info dictionary.
        :returns: The (potentially modified) obs, reward, terminated, truncated, and info.
        :rtype: Tuple[Dict, float, bool, bool, Dict]
        """
        if not info.get('P', False):
            return obs, reward, terminated, truncated, info
        print(f'[green]Start pointing[/green]')

        gui = MinecraftGUI(extra_draw_call=[PointDrawCall], show_info=False)
        gui.window.activate()
        try:
            while True:
                gui.window.dispatch_events()
                gui.window.switch_to()
                # Free the cursor so the user can aim at a pixel in the window.
                gui.window.set_mouse_visible(True)
                gui.window.set_exclusive_mouse(False)
                gui.window.flip()
                released_keys = gui._capture_all_keys()
                if 'ESCAPE' in released_keys:
                    break
                if gui.mouse_position is not None:
                    info['point'] = gui.mouse_position
                gui._show_image(info)
        finally:
            # Always release the window, even if the event loop raises.
            gui.close_gui()

        # 'point' is only written once a mouse position was seen. Use .get so
        # that closing the window immediately does not raise KeyError.
        if info.get('point') is not None:
            print(f'[red]Stop pointing at {info["point"]}[/red]')
        info['P'] = False
        return obs, reward, terminated, truncated, info
class PlaySegmentCallback(MinecraftCallback):
    """Integrates Segment Anything Model (SAM) for interactive object segmentation.

    This callback allows a human player to provide positive and negative point
    prompts on the game's POV to segment objects using a SAM2 model. It then
    tracks the segmented object across subsequent frames.

    **Note:** This callback should typically be placed *before* the
    `PlayCallback` in the callback list to ensure its GUI interactions are
    handled correctly.

    Key Features:
        - Loads a specified SAM2 model checkpoint.
        - Provides a GUI for adding positive/negative point prompts.
        - Generates an initial segmentation based on prompts.
        - Tracks the segmented object in subsequent frames.
        - Adds the segmentation mask to `obs['segment']['obj_mask']`.

    Activation:
        - Press 'S' to start/stop the segmentation process.

    GUI Controls (during segmentation):
        - Left Mouse Click: Add a positive point prompt.
        - Right Mouse Click: Add a negative point prompt.
        - 'C': Clear all current points.
        - 'Enter': Start tracking the current segmentation.
        - 'ESCAPE': Exit segmentation mode.

    :param sam_path: Path to the directory containing SAM2 model checkpoints and configs.
    :type sam_path: str
    :param sam_choice: Which SAM2 model to load ('base', 'large', 'small', 'tiny').
        Defaults to 'base'.
    :type sam_choice: str, optional
    """

    def __init__(self, sam_path: str, sam_choice: str = 'base'):
        """Initializes the PlaySegmentCallback.

        :param sam_path: Path to SAM2 model directory.
        :param sam_choice: SAM2 model variant to load.
        """
        super().__init__()
        self.sam_path = sam_path
        self._clear()
        self.sam_choice = sam_choice
        self._load_sam()

    # TODO: add different segment types
    def _load_sam(self):
        """Loads the specified SAM2 model checkpoint and configuration.

        Imports `build_sam2_camera_predictor` from `sam2.build_sam` lazily, so the
        dependency is only needed when this callback is constructed, and
        initializes the predictor.

        :raises KeyError: If `self.sam_choice` is not a known model variant.
        """
        ckpt_mapping = {
            'large': [os.path.join(self.sam_path, "sam2_hiera_large.pt"), "sam2_hiera_l.yaml"],
            'base': [os.path.join(self.sam_path, "sam2_hiera_base_plus.pt"), "sam2_hiera_b+.yaml"],
            'small': [os.path.join(self.sam_path, "sam2_hiera_small.pt"), "sam2_hiera_s.yaml"],
            'tiny': [os.path.join(self.sam_path, "sam2_hiera_tiny.pt"), "sam2_hiera_t.yaml"],
        }
        if self.sam_choice not in ckpt_mapping:
            # Same exception type as the plain dict lookup, but with a useful message.
            raise KeyError(
                f"Unknown sam_choice {self.sam_choice!r}; expected one of {sorted(ckpt_mapping)}"
            )
        sam_ckpt, model_cfg = ckpt_mapping[self.sam_choice]
        # First release the old predictor (if any) before building a new one.
        if hasattr(self, "predictor"):
            del self.predictor
        from sam2.build_sam import build_sam2_camera_predictor
        self.predictor = build_sam2_camera_predictor(model_cfg, sam_ckpt)
        print(f"Successfully loaded SAM2 from {sam_ckpt}")
        self.able_to_track = False

    def _get_message(self, info: Dict) -> Dict:
        """Builds the GUI status message describing the segmentation state.

        The entry ``'SegmentCallback'`` is written into ``info['message']`` (or a
        new dict if absent), and that message dict is returned.

        :param info: The current info dictionary.
        :type info: Dict
        :returns: The message dictionary containing the segmentation status line.
        :rtype: Dict
        """
        message = info.get('message', {})
        message['SegmentCallback'] = f'Segment: {"On" if self.tracking else "Off"}, Tracking Time: {self.tracking_time}'
        return message

    def _clear(self):
        """Resets all segmentation-related state variables."""
        self.positive_points = []
        self.negative_points = []
        self.segment = None
        self.able_to_track = False
        self.tracking = False
        self.tracking_time = 0

    def after_reset(self, sim, obs: Dict, info: Dict) -> Tuple[Dict, Dict]:
        """Clears segmentation state and adds a GUI message after environment reset.

        :param sim: The simulator instance.
        :param obs: The initial observation.
        :param info: The initial info dictionary.
        :returns: The passed `obs` and `info`.
        :rtype: Tuple[Dict, Dict]
        """
        self._clear()
        sim.callback_messages.add("Press 'S' to start/stop segmenting.")
        info['message'] = self._get_message(info)
        return obs, info

    def before_step(self, sim, action: Any) -> Any:
        """Prevents actions if segmentation GUI is active but not yet tracking.

        If the 'S' key was pressed to start segmentation but tracking hasn't
        begun (i.e., the user is still providing prompts), this returns a no-op
        action to pause game progression.

        :param sim: The simulator instance.
        :param action: The proposed action.
        :type action: Any
        :returns: A no-op action if the segmenting GUI is active, else the original action.
        :rtype: Any
        """
        if hasattr(sim, 'info') and sim.info.get('S', False) and not self.tracking:
            return sim.noop_action()
        return action

    def after_step(self, sim, obs: Dict, reward: float, terminated: bool, truncated: bool, info: Dict) -> Tuple[Dict, float, bool, bool, Dict]:
        """Manages the segmentation lifecycle based on the 'S' key and GUI interaction.

        Handles:
            - Starting the segmentation GUI when 'S' is pressed and not already tracking.
            - Stopping tracking when 'S' is released while tracking.
            - Updating the segmentation mask in `obs` if tracking is active.

        ``obs['segment']`` is always set: the tracked mask resized to the
        observation image with ``obj_id`` 2, or an all-zero mask with
        ``obj_id`` -1.

        :param sim: The simulator instance.
        :param obs: The observation after the step.
        :param reward: The reward received.
        :param terminated: Whether the episode has terminated.
        :param truncated: Whether the episode has been truncated.
        :param info: The info dictionary.
        :returns: The (potentially modified) obs, reward, terminated, truncated, and info.
        :rtype: Tuple[Dict, float, bool, bool, Dict]
        """
        if self.tracking and not info.get('S', False):
            # stop tracking
            print(f'[red]Stop tracking[/red]')
            self._clear()
            info['segment'] = None
        elif not self.tracking and info.get('S', False):
            # start segmenting: run the interactive GUI on a copy of info
            print(f'[green]Start segmenting[/green]')
            current_info = info.copy()
            current_info['segment'] = None
            current_info['positive_points'] = []
            current_info['negative_points'] = []
            current_info = self._segment_gui(current_info, sim)
            info['segment'] = current_info.get('segment')
            info['positive_points'] = current_info.get('positive_points', [])
            info['negative_points'] = current_info.get('negative_points', [])
            if not self.tracking:
                # GUI was exited without starting tracking: turn the toggle off.
                info['S'] = False
        elif self.tracking and info.get('S', False):
            self.tracking_time += 1
            info['segment'] = self._segment(info)

        height, width = obs['image'].shape[:2]
        if info.get('segment', None) is not None and self.tracking:
            # cv2.resize takes dsize as (width, height); the result is (height, width).
            segment = cv2.resize(
                info['segment'].astype(np.uint8),
                dsize=(width, height),
                interpolation=cv2.INTER_NEAREST,
            )
            obs['segment'] = {'obj_mask': segment, 'obj_id': 2}
        else:
            obs['segment'] = {
                'obj_mask': np.zeros((height, width), dtype=np.uint8),
                'obj_id': -1,
            }
        info['message'] = self._get_message(info)
        return obs, reward, terminated, truncated, info

    def _segment_gui(self, current_info: Dict, sim) -> Dict:
        """Manages the GUI for interactive point-based segmentation.

        This method creates a GUI window where the user can add positive and
        negative points on the current POV. It updates the segmentation mask
        in real time based on these prompts.

        Controls:
            - Left Click: Add positive point.
            - Right Click: Add negative point.
            - 'C': Clear points.
            - 'Enter': Finalize points and start tracking.
            - 'ESCAPE': Cancel and exit segmentation GUI.

        :param current_info: A copy of the current info dictionary, used for GUI display.
        :type current_info: Dict
        :param sim: The simulator instance, used as a fallback source of the POV
            (``sim.obs['pov']``) when ``current_info`` has none.
        :type sim: Any
        :returns: A copy of `current_info`, updated with the segmentation results
            (and with ``'pov'`` set to the frame that was segmented).
        :rtype: Dict
        """
        info = current_info.copy()

        # Resolve the POV frame once, before the loop. Its shape is needed to map
        # clicks into POV pixels and to remap drawn points on every redraw,
        # including redraws before the first click.
        pov_image = info.get('pov')
        if pov_image is None:
            sim_obs = getattr(sim, 'obs', None) or {}
            pov_image = sim_obs.get('pov')
        if pov_image is None:
            print("[red]POV image not found in info or sim.obs for segmentation GUI.[/red]")
            return info
        # _segment reads the frame from info['pov'], so store it even when it came from sim.obs.
        info['pov'] = pov_image
        pov_height, pov_width = pov_image.shape[:2]

        gui = MinecraftGUI(extra_draw_call=[SegmentDrawCall, MultiPointDrawCall], show_info=True)
        help_message = [["Press 'C' to clear points."],
                        ["Press mouse left button to add points."],
                        ["Press mouse right button to add negative points."],
                        ["Press 'Enter' to start tracking."],
                        ["Press 'ESC' to exit."]]
        gui.window.activate()
        refresh = False
        last_mouse_position = None
        try:
            while True:
                gui.window.dispatch_events()
                gui.window.switch_to()
                gui.window.set_mouse_visible(True)
                gui.window.set_exclusive_mouse(False)
                gui.window.flip()
                released_keys = gui._capture_all_keys()
                if 'ESCAPE' in released_keys:
                    self._clear()
                    info['segment'] = None
                    info['positive_points'] = self.positive_points
                    info['negative_points'] = self.negative_points
                    self.tracking = False
                    print('[red]Exit segmenting[/red]')
                    break
                if 'C' in released_keys:
                    self._clear()
                    info['segment'] = None
                    info['positive_points'] = self.positive_points
                    info['negative_points'] = self.negative_points
                    last_mouse_position = None
                    refresh = True
                    print('[red]Points cleared[/red]')
                if 'ENTER' in released_keys and self.able_to_track:
                    assert info['segment'] is not None, 'segment is not generated before tracking.'
                    print(f'[green]Start tracking[/green]')
                    self.tracking = True
                    break
                # mouse_pressed == 1 is the left button, 4 the right button.
                if gui.mouse_position is not None and gui.mouse_pressed in (1, 4):
                    if gui.mouse_position != last_mouse_position:
                        last_mouse_position = gui.mouse_position
                        # Window y grows upward; flip it into image rows below the info bar.
                        y_offset = gui.constants.INFO_HEIGHT if gui.show_info else 0
                        window_x = last_mouse_position[0]
                        window_y = gui.constants.FRAME_HEIGHT + y_offset - last_mouse_position[1]
                        # Scale from window coordinates to POV pixel coordinates.
                        position = (int(window_x * pov_width / gui.constants.WINDOW_WIDTH),
                                    int(window_y * pov_height / gui.constants.FRAME_HEIGHT))
                        if gui.mouse_pressed == 1:  # left button pressed
                            self.positive_points.append(position)
                            info['positive_points'] = self.positive_points
                            print(f'[green]Positive point added at {position}[/green]')
                        else:  # right button pressed
                            self.negative_points.append(position)
                            info['negative_points'] = self.negative_points
                            print(f'[red]Negative point added at {position}[/red]')
                        refresh = True
                    # Consume the click so it is handled only once.
                    gui.mouse_pressed = 0
                if len(self.positive_points) > 0:
                    self.able_to_track = True
                if self.able_to_track:
                    self._segment(info, refresh)
                    info['segment'] = self.segment
                    refresh = False
                gui._update_image(info, message=help_message,
                                  remap_points=(gui.constants.WINDOW_WIDTH, pov_width,
                                                gui.constants.FRAME_HEIGHT, pov_height))
        finally:
            # Always release the window, even if segmentation raises.
            gui.close_gui()
        return info

    def _segment(self, current_info: Dict, refresh: bool = False):
        """Performs segmentation using the SAM2 predictor.

        If `self.segment` is None or `refresh` is True, it loads the POV as the
        first frame and adds the current point prompts. Otherwise, it tracks the
        existing segment on the new POV.

        :param current_info: The dictionary containing the POV image ('pov').
        :type current_info: Dict
        :param refresh: Whether to re-initialize segmentation with current points.
            Defaults to False.
        :type refresh: bool, optional
        :returns: The boolean segmentation mask (logits > 0), or the previous
            segment if no POV is available.
        :rtype: np.ndarray
        """
        pov_image = current_info.get('pov')
        if pov_image is None:
            print("[red]POV image not found in current_info for _segment.[/red]")
            return self.segment  # Return existing segment or None
        if (self.segment is None) or refresh:
            assert len(self.positive_points) > 0
            points = self.positive_points + self.negative_points
            # Label 1 marks positive prompts, 0 marks negative prompts.
            labels = [1] * len(self.positive_points) + [0] * len(self.negative_points)
            self.predictor.load_first_frame(pov_image)
            _, _, out_segment_logits = self.predictor.add_new_prompt(
                frame_idx=0,
                obj_id=0,
                points=points,
                labels=labels,
            )
        else:
            _, out_segment_logits = self.predictor.track(pov_image)
        self.segment = (out_segment_logits[0, 0] > 0.0).cpu().numpy()
        return self.segment