Source code for minestudio.simulator.utils.gui

'''
Date: 2024-11-15 15:15:22
LastEditors: muzhancun muzhancun@stu.pku.edu.cn
LastEditTime: 2024-11-20 01:02:18
FilePath: /MineStudio/minestudio/simulator/utils/gui.py
'''
from minestudio.simulator.utils.constants import GUIConstants   

from collections import defaultdict
from typing import List, Any, Optional, Callable
import importlib
import cv2
import time
from rich import print
import numpy as np

def RecordDrawCall(info, **kwargs):
    """
    Overlay a blinking "Rec" indicator on the POV frame while recording.

    Nothing is drawn when ``info`` has no ``'R'`` key, when ``info['ESCAPE']``
    is truthy (command mode), or when ``info['R']`` is falsy. Otherwise a filled
    circle and the text "Rec" are drawn near the top-left corner of
    ``info['pov']``. The colour is ``(255, 0, 0)`` during even wall-clock
    seconds and ``(0, 255, 0)`` during odd ones (red / green when the frame is
    in RGB channel order, which is how ``MinecraftGUI`` uploads it).

    :param info: A dictionary holding the 'pov' image and the 'R' (recording) flag.
    :param kwargs: Additional keyword arguments (not used).
    :return: The same ``info`` dictionary, with the indicator drawn on 'pov' when active.
    """
    # Single guard: no flag, command mode active, or not recording -> untouched.
    if 'R' not in info or info.get('ESCAPE', False) or not info['R']:
        return info

    frame = info['pov']
    # Alternate the colour once per second so the indicator blinks.
    colour = (255, 0, 0) if int(time.time()) % 2 == 0 else (0, 255, 0)
    cv2.circle(frame, (20, 20), 10, colour, -1)
    cv2.putText(frame, 'Rec', (40, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, colour, 2)
    info['pov'] = frame
    return info
def CommandModeDrawCall(info, **kwargs):
    """
    Overlay a command-mode indicator on the POV frame.

    When ``info['ESCAPE']`` is present and truthy, the 'pov' image is converted
    to greyscale (and back to three channels) and the text "Command Mode" is
    drawn near the top-left corner in colour ``(0, 0, 255)``. Otherwise ``info``
    is returned untouched.

    :param info: A dictionary holding the 'pov' image and the 'ESCAPE' (command mode) flag.
    :param kwargs: Additional keyword arguments (not used).
    :return: The same ``info`` dictionary, with 'pov' replaced by the greyed frame when active.
    """
    # A missing key and a falsy flag are treated the same way: nothing to draw.
    if not info.get('ESCAPE'):
        return info

    # Grey out the screen, then restore three channels so coloured text can be drawn.
    # NOTE(review): the conversion codes assume BGR order while MinecraftGUI
    # uploads the frame as 'RGB' -- confirm the intended channel order.
    grey = cv2.cvtColor(info['pov'], cv2.COLOR_BGR2GRAY)
    frame = cv2.cvtColor(grey, cv2.COLOR_GRAY2BGR)
    cv2.putText(frame, 'Command Mode', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
    info['pov'] = frame
    return info
def PointDrawCall(info, **kwargs):
    """
    Mark the currently pointed-at location on the POV frame.

    When ``info`` contains a 'point', a filled circle of radius 10 is drawn at
    that location and the text "Pointing at <point>" is written near the
    top-left corner, both in colour ``(0, 0, 255)``.

    NOTE(review): ``(0, 0, 255)`` is blue when the frame is RGB (the order in
    which ``MinecraftGUI`` uploads it) and red only under BGR -- confirm which
    colour is intended.

    :param info: A dictionary holding the 'pov' image and the 'point' coordinates.
    :param kwargs: Additional keyword arguments (not used).
    :return: The same ``info`` dictionary, with the marker drawn on 'pov' when a point is set.
    """
    if 'point' not in info:
        return info

    target = info['point']
    frame = info['pov']
    # The point is given relative to the bottom-left corner of the frame, while
    # image rows count from the top, so the y coordinate is flipped.
    centre = (target[0], frame.shape[0] - target[1])
    cv2.circle(frame, centre, 10, (0, 0, 255), -1)
    cv2.putText(frame, f'Pointing at {target}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
    info['pov'] = frame
    return info
def MultiPointDrawCall(info, **kwargs):
    """
    Draw the positive and negative prompt points on the POV frame.

    Positive points are drawn as filled circles in colour ``(0, 255, 0)`` and
    negative points in colour ``(255, 0, 0)``. Each coordinate is rescaled with
    the 4-tuple ``remap_points`` taken from ``kwargs``:
    ``x * r[0] / r[1]`` and ``y * r[2] / r[3]`` (default ``(1, 1, 1, 1)``, i.e.
    no rescaling), then truncated to ``int``.

    Nothing is drawn when either list is missing from ``info`` or when the list
    of positive points is empty (negative points alone are not drawn).

    :param info: A dictionary holding 'positive_points', 'negative_points' and the 'pov' image.
    :param kwargs: Additional keyword arguments; only 'remap_points' is read.
    :return: The same ``info`` dictionary, with the points drawn on 'pov'.
    """
    if 'positive_points' not in info or 'negative_points' not in info:
        return info

    positives = info['positive_points']
    negatives = info['negative_points']
    if len(positives) == 0:
        return info

    frame = info['pov']
    scale = kwargs.get('remap_points', (1, 1, 1, 1))
    # Positive points are drawn first, then negative ones, each group in its own colour.
    for group, colour in ((positives, (0, 255, 0)), (negatives, (255, 0, 0))):
        for raw in group:
            centre = (
                int(raw[0] * scale[0] / scale[1]),
                int(raw[1] * scale[2] / scale[3]),
            )
            cv2.circle(frame, centre, 10, colour, -1)
    info['pov'] = frame
    return info
def SegmentDrawCall(info, **kwargs):
    """
    Blend a segmentation mask over the POV frame.

    When ``info['segment']`` is present and not ``None``, the mask is multiplied
    by the colour ``(0, 255, 0)``, cast to ``uint8``, resized to the size of the
    'pov' image and added on top of it with weight 0.5 (the frame itself keeps
    weight 1.0).

    :param info: A dictionary holding the 'segment' mask and the 'pov' image.
    :param kwargs: Additional keyword arguments (not used).
    :return: The same ``info`` dictionary, with 'pov' replaced by the blended frame.
    """
    # A missing key and an explicit None both mean "no mask available".
    seg = info.get('segment')
    if seg is None:
        return info

    frame = info['pov']
    # Channel order is reversed after building the colour; for pure green this
    # is a no-op, but it is kept so any other colour would behave as before.
    tint = np.array((0, 255, 0)).reshape(1, 1, 3)[:, :, ::-1]
    # assumes seg is a 2-D array so that seg[..., None] broadcasts to (H, W, 3) -- TODO confirm
    overlay = (seg[..., None] * tint).astype(np.uint8)
    # Resize the mask to the size of the observation (cv2 expects (width, height)).
    overlay = cv2.resize(
        overlay,
        dsize=(frame.shape[1], frame.shape[0]),
        interpolation=cv2.INTER_CUBIC,
    )
    info['pov'] = cv2.addWeighted(frame, 1.0, overlay, 0.5, 0.0)
    return info
class MinecraftGUI:
    """
    Manages the Pyglet-based GUI for the Minecraft simulator.

    Handles window creation, event processing (keyboard, mouse), rendering of
    the POV display, and displaying informational messages. It also supports
    custom draw calls for additional visual elements.
    """

    def __init__(self, extra_draw_call: Optional[List[Callable]] = None, show_info=True, **kwargs):
        """
        Initializes the MinecraftGUI and opens its window.

        :param extra_draw_call: Optional list of callables ``f(info, **kwargs) -> info``
            applied to the POV frame before it is displayed.
        :param show_info: Whether the window reserves space for the information panel.
        :param kwargs: Additional keyword arguments passed to the superclass.
        """
        super().__init__(**kwargs)
        self.constants = GUIConstants()
        # The GUI libraries are resolved at construction time through importlib
        # rather than at module import (presumably so importing this module
        # does not require them to be installed -- TODO confirm).
        self.pyglet = importlib.import_module('pyglet')
        self.imgui = importlib.import_module('imgui')
        self.key = importlib.import_module('pyglet.window.key')
        self.mouse = importlib.import_module('pyglet.window.mouse')
        self.PygletRenderer = importlib.import_module('imgui.integrations.pyglet').PygletRenderer
        self.extra_draw_call = extra_draw_call
        self.show_info = show_info
        self.mode = 'normal'
        self.create_window()

    def create_window(self):
        """
        Creates the Pyglet window and sets up event handlers and ImGui integration.

        The window is ``FRAME_HEIGHT`` tall, plus ``INFO_HEIGHT`` when the info
        panel is enabled. All input-tracking state is (re)initialised and a
        "Waiting for start." message is shown.
        """
        height = self.constants.FRAME_HEIGHT
        if self.show_info:
            # Extra room below the frame for the info panel.
            height += self.constants.INFO_HEIGHT
        self.window = self.pyglet.window.Window(
            width=self.constants.WINDOW_WIDTH,
            height=height,
            vsync=False,
            resizable=False,
        )

        self.imgui.create_context()
        self.imgui.get_io().display_size = self.constants.WINDOW_WIDTH, self.constants.WINDOW_HEIGHT
        self.renderer = self.PygletRenderer(self.window)

        # Keyboard / mouse-button state, keyed by pyglet symbol or button id.
        self.pressed_keys = defaultdict(lambda: False)
        self.released_keys = defaultdict(lambda: False)
        self.modifiers = None

        self.window.on_mouse_motion = self._on_mouse_motion
        self.window.on_mouse_drag = self._on_mouse_drag
        self.window.on_key_press = self._on_key_press
        self.window.on_key_release = self._on_key_release
        self.window.on_mouse_press = self._on_mouse_press
        self.window.on_mouse_release = self._on_mouse_release
        self.window.on_activate = self._on_window_activate
        self.window.on_deactivate = self._on_window_deactivate

        self.last_pov = None
        self.last_mouse_delta = [0, 0]
        self.capture_mouse = True
        self.mouse_position = None
        self.mouse_pressed = None
        self.chat_message = None
        self.command = None

        self.window.dispatch_events()
        self.window.switch_to()
        self.window.flip()
        self.window.clear()
        self._show_message("Waiting for start.")

    def _on_key_press(self, symbol, modifiers):
        """
        Handles key press events.

        :param symbol: The Pyglet key symbol.
        :param modifiers: Key modifiers (e.g., Shift, Ctrl).
        """
        self.pressed_keys[symbol] = True
        self.modifiers = modifiers

    def _on_key_release(self, symbol, modifiers):
        """
        Handles key release events.

        The key is marked as no longer pressed and flagged as released so that
        :meth:`_capture_all_keys` can report it once.

        :param symbol: The Pyglet key symbol.
        :param modifiers: Key modifiers.
        """
        self.pressed_keys[symbol] = False
        self.released_keys[symbol] = True
        self.modifiers = modifiers

    def _on_mouse_press(self, x, y, button, modifiers):
        """
        Handles mouse button press events.

        :param x: The x-coordinate of the mouse press.
        :param y: The y-coordinate of the mouse press.
        :param button: The mouse button pressed.
        :param modifiers: Key modifiers (not used).
        """
        # Mouse buttons share the pressed_keys table with keyboard symbols.
        self.pressed_keys[button] = True
        self.mouse_pressed = button
        self.mouse_position = (x, y)

    def _on_mouse_release(self, x, y, button, modifiers):
        """
        Handles mouse button release events.

        :param x: The x-coordinate of the mouse release (not used).
        :param y: The y-coordinate of the mouse release (not used).
        :param button: The mouse button released.
        :param modifiers: Key modifiers (not used).
        """
        self.pressed_keys[button] = False

    def _on_window_activate(self):
        """
        Handles window activation events (e.g., window gains focus).

        Hides the cursor and grabs the mouse exclusively.
        """
        self.window.set_mouse_visible(False)
        self.window.set_exclusive_mouse(True)

    def _on_window_deactivate(self):
        """
        Handles window deactivation events (e.g., window loses focus).

        Shows the cursor again and releases the exclusive mouse grab.
        """
        self.window.set_mouse_visible(True)
        self.window.set_exclusive_mouse(False)

    def _accumulate_mouse_delta(self, dx, dy):
        """
        Adds one mouse movement to the pending camera delta.

        ``last_mouse_delta[0]`` accumulates the inverted vertical movement and
        ``last_mouse_delta[1]`` the horizontal movement, both scaled by
        ``MOUSE_MULTIPLIER``.

        :param dx: The change in x-coordinate since the last event.
        :param dy: The change in y-coordinate since the last event.
        """
        scale = self.constants.MOUSE_MULTIPLIER
        # Inverted
        self.last_mouse_delta[0] -= dy * scale
        self.last_mouse_delta[1] += dx * scale

    def _on_mouse_motion(self, x, y, dx, dy):
        """
        Handles mouse motion events.

        Updates ``last_mouse_delta`` for camera control. Note that vertical
        mouse movement (dy) is inverted.

        :param x: The current x-coordinate of the mouse (not used).
        :param y: The current y-coordinate of the mouse (not used).
        :param dx: The change in x-coordinate since the last event.
        :param dy: The change in y-coordinate since the last event.
        """
        self._accumulate_mouse_delta(dx, dy)

    def _on_mouse_drag(self, x, y, dx, dy, buttons, modifiers):
        """
        Handles mouse drag events (mouse motion while a button is pressed).

        Updates ``last_mouse_delta`` exactly like :meth:`_on_mouse_motion`.

        :param x: The current x-coordinate of the mouse (not used).
        :param y: The current y-coordinate of the mouse (not used).
        :param dx: The change in x-coordinate since the last event.
        :param dy: The change in y-coordinate since the last event.
        :param buttons: The mouse buttons currently pressed (not used).
        :param modifiers: Key modifiers (not used).
        """
        self._accumulate_mouse_delta(dx, dy)

    def _show_message(self, text):
        """
        Displays a centered message on the screen and flips the window.

        Used for messages like "Waiting for start." or "Resetting environment...".

        :param text: The text to display.
        """
        document = self.pyglet.text.document.FormattedDocument(text)
        document.set_style(
            0,
            len(document.text),
            dict(font_name='Arial', font_size=32, color=(255, 255, 255, 255)),
        )
        document.set_paragraph_style(0, 100, dict(align='center'))
        layout = self.pyglet.text.layout.TextLayout(
            document,
            width=self.window.width // 2,
            height=self.window.height // 2,
            multiline=True,
            wrap_lines=True,
        )
        layout.update(x=self.window.width // 2, y=self.window.height // 2)
        layout.anchor_x = 'center'
        layout.anchor_y = 'center'
        layout.content_valign = 'center'
        layout.draw()
        self.window.flip()

    def _show_additional_message(self, message: List):
        """
        Displays additional messages in the info panel.

        The panel height (``INFO_HEIGHT``) is split evenly between the rows;
        each row is drawn horizontally centred, starting from the bottom.

        :param message: A list of lists/tuples, where each inner list/tuple
            represents a row of text elements to be joined by " | ".
        """
        if len(message) == 0:
            return
        line_height = self.constants.INFO_HEIGHT // len(message)
        y = line_height // 2
        for row in message:
            self.pyglet.text.Label(
                ' | '.join(row),
                font_size=7 * self.constants.SCALE,
                x=self.window.width // 2,
                y=y,
                anchor_x='center',
                anchor_y='center',
            ).draw()
            y += line_height

    def _render_pov(self, info, y_offset, draw_kwargs):
        """
        Clears the window and blits the (decorated) POV frame at ``(0, y_offset)``.

        Shared by :meth:`_update_image` and :meth:`_show_image`. The caller's
        ``info`` dictionary is shallow-copied, so its 'pov' entry is not rebound;
        the window is *not* flipped here.

        :param info: A dictionary containing the 'pov' image.
        :param y_offset: Vertical position (in window coordinates) of the frame.
        :param draw_kwargs: Dictionary of keyword arguments forwarded to every extra draw call.
        """
        self.window.switch_to()
        self.window.clear()
        # Based on scaled_image_display.py
        info = info.copy()
        info['pov'] = cv2.resize(
            info['pov'],
            dsize=(self.constants.WINDOW_WIDTH, self.constants.FRAME_HEIGHT),
            interpolation=cv2.INTER_CUBIC,
        )  # type: ignore

        if self.extra_draw_call is not None:
            for draw_call in self.extra_draw_call:
                info = draw_call(info, **draw_kwargs)

        arr = info['pov']
        # Negative pitch: per the pyglet ImageData documentation this marks the
        # rows as stored top-to-bottom, matching the array layout.
        image = self.pyglet.image.ImageData(
            arr.shape[1], arr.shape[0], 'RGB', arr.tobytes(), pitch=arr.shape[1] * -3
        )
        image.get_texture().blit(0, y_offset)

    def _update_image(self, info, message: Optional[List] = None, **kwargs):
        """
        Updates and renders the main game view (POV), info panel and chat box.

        Resizes the POV, applies extra draw calls, draws the frame above the
        info panel, displays additional messages when ``show_info`` is set, and
        renders the ImGui "Chat" window. The typed text is stored in
        ``self.command``; pressing "Send" moves it to ``self.chat_message`` and
        clears ``self.command``.

        :param info: A dictionary containing the 'pov' image.
        :param message: Rows of text for the info panel; ``None`` (default) means no rows.
        :param kwargs: Additional keyword arguments for extra draw calls.
        """
        if message is None:
            # Fresh list per call instead of a shared mutable default.
            message = []

        self._render_pov(info, self.constants.INFO_HEIGHT, kwargs)
        if self.show_info:
            self._show_additional_message(message)

        self.imgui.new_frame()
        self.imgui.begin("Chat", False, self.imgui.WINDOW_ALWAYS_AUTO_RESIZE)
        changed, command = self.imgui.input_text("Message", "")
        self.command = command
        if self.imgui.button("Send"):
            self.chat_message = command
            self.command = None
        self.imgui.end()

        self.imgui.render()
        self.renderer.render(self.imgui.get_draw_data())
        self.window.flip()

    def _show_image(self, info, **kwargs):
        """
        Displays the POV image without the info panel or ImGui chat.

        The frame is drawn at the bottom-left corner of the window and the
        window is flipped.

        :param info: A dictionary containing the 'pov' image.
        :param kwargs: Additional keyword arguments for extra draw calls.
        """
        self._render_pov(info, 0, kwargs)
        self.window.flip()

    def _get_human_action(self):
        """
        Reads keyboard and mouse state to form a human action dictionary.

        Every action name in ``MINERL_ACTION_TO_KEYBOARD`` maps to 1 or 0
        depending on whether its key is currently pressed; 'camera' holds the
        mouse delta accumulated since the previous call (zero when mouse
        capture is disabled). The accumulated delta is reset afterwards.

        :return: A dictionary representing the current human action.
        """
        # Keyboard actions
        action: dict[str, Any] = {
            name: int(self.pressed_keys[key])
            for name, key in self.constants.MINERL_ACTION_TO_KEYBOARD.items()
        }

        if not self.capture_mouse:
            self.last_mouse_delta = [0, 0]
        action["camera"] = self.last_mouse_delta
        # Rebind (rather than clear in place) so the returned list is not aliased.
        self.last_mouse_delta = [0, 0]
        return action

    def reset_gui(self):
        """
        Resets the GUI state, clears the window, and shows a "Resetting" message.
        """
        self.window.clear()
        self.pressed_keys = defaultdict(lambda: False)
        self._show_message("Resetting environment...")

    def _capture_all_keys(self):
        """
        Captures all keys that were released since the last call.

        Each reported key has its released flag cleared, so it is returned only once.

        :return: A set of string representations of the released keys.
        """
        released_keys = set()
        # Only values are rewritten while iterating, never the key set itself.
        for key in self.released_keys:
            if self.released_keys[key]:
                self.released_keys[key] = False
                released_keys.add(self.key.symbol_string(key))
        return released_keys

    def close_gui(self):
        """
        Closes the Pyglet window and exits the Pyglet application.
        """
        #! WARNING: This should be checked
        self.window.close()
        self.pyglet.app.exit()
# Manual smoke test: running this file as a script constructs the GUI with its
# default arguments, which opens the window (the constructor calls create_window()).
if __name__ == "__main__": gui = MinecraftGUI()