Source code for minestudio.simulator.callbacks.judgereset
from minestudio.simulator.callbacks.callback import MinecraftCallback
from minestudio.simulator.utils import MinecraftGUI, GUIConstants
from minestudio.simulator.utils.gui import PointDrawCall
import time
from typing import Dict, Literal, Optional, Callable, Tuple
import cv2
[docs]
class JudgeResetCallback(MinecraftCallback):
    """Force an episode reset on natural termination or when a step limit is hit.

    The callback counts the steps taken since the last reset. When the
    environment reports ``terminated`` or the count reaches ``time_limit``,
    :meth:`after_step` returns ``terminated=True`` and clears the counter, so
    at most ``time_limit`` steps are taken between resets.

    :param time_limit: The maximum number of steps per episode before a reset
        is forced. Defaults to 600. A value of 1 or less makes every step end
        the episode.
    :type time_limit: int, optional
    """

    def __init__(self, time_limit: int = 600):
        """Initialize the JudgeResetCallback.

        :param time_limit: Maximum steps per episode.
        """
        super().__init__()
        self.time_limit = time_limit
        # Steps taken since the last reset or forced termination.
        self.time_step = 0

    def after_reset(self, sim, obs: Dict, info: Dict) -> Tuple[Dict, Dict]:
        """Reset the internal step counter after an environment reset.

        :param sim: The simulator instance.
        :param obs: The initial observation after reset.
        :param info: The initial info dictionary after reset.
        :returns: The passed ``obs`` and ``info``, unchanged.
        :rtype: Tuple[Dict, Dict]
        """
        self.time_step = 0
        print("Environment reset:", self.time_step)
        return obs, info

    def after_step(
        self,
        sim,
        obs: Dict,
        reward: float,
        terminated: bool,
        truncated: bool,
        info: Dict,
    ) -> Tuple[Dict, float, bool, bool, Dict]:
        """Count the step and flag the episode for reset when it should end.

        The step counter is incremented first. If ``terminated`` is already
        True, or the counter has reached ``self.time_limit``, ``terminated`` is
        forced to True and the counter is cleared.

        :param sim: The simulator instance.
        :param obs: The observation after the step.
        :param reward: The reward received.
        :param terminated: Whether the episode has terminated.
        :param truncated: Whether the episode has been truncated (passed
            through unchanged).
        :param info: The info dictionary.
        :returns: The (potentially modified) obs, reward, terminated,
            truncated, and info.
        :rtype: Tuple[Dict, float, bool, bool, Dict]
        """
        self.time_step += 1
        # ``>=`` means an episode ends on exactly the ``time_limit``-th step.
        time_limit_reached = self.time_step >= self.time_limit
        if terminated or time_limit_reached:
            # Report the actual cause instead of always blaming the time limit.
            if time_limit_reached:
                print("Time limit reached, resetting the environment:", self.time_step)
            else:
                print("Episode terminated, resetting the environment:", self.time_step)
            self.time_step = 0
            # NOTE(review): a time-limit cut is signalled via ``terminated``
            # rather than ``truncated``; presumably the simulator only resets
            # on ``terminated`` — confirm before switching to ``truncated``.
            terminated = True
        return obs, reward, terminated, truncated, info