# Source code for minestudio.simulator.callbacks.speed_test
'''
Date: 2024-11-11 15:59:38
LastEditors: muzhancun muzhancun@stu.pku.edu.cn
LastEditTime: 2024-11-17 21:43:39
FilePath: /Minestudio/minestudio/simulator/callbacks/speed_test.py
'''
import time
from minestudio.simulator.callbacks.callback import MinecraftCallback
# [docs]
class SpeedTestCallback(MinecraftCallback):
    """
    A callback for testing the speed of the simulator.

    This callback measures the time elapsed between its own ``before_step``
    and ``after_step`` hooks. Every ``interval`` steps it prints the average
    time per step and the average FPS accumulated so far.
    """

    def __init__(self, interval: int = 100):
        """
        Initializes the SpeedTestCallback.

        :param interval: The interval (in steps) at which to print speed test
            status. Must be a positive integer.
        :raises ValueError: If ``interval`` is not positive (a zero interval
            would otherwise fail later with a ZeroDivisionError on the modulo).
        """
        super().__init__()
        if interval <= 0:
            raise ValueError(f"interval must be a positive integer, got {interval!r}")
        self.interval = interval
        self.num_steps = 0
        # Accumulated measured step time, in seconds.
        self.total_times = 0.0
        # perf_counter() timestamp taken in before_step; None when no step is in flight.
        self.start_time = None

    def before_step(self, sim, action):
        """
        Records the start time before executing a step.

        :param sim: The Minecraft simulator.
        :param action: The action to be executed.
        :return: The action, unchanged.
        """
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # so measured intervals are never negative or distorted by clock changes.
        self.start_time = time.perf_counter()
        return action

    def after_step(self, sim, obs, reward, terminated, truncated, info):
        """
        Accumulates the step duration and prints the speed test status every
        ``interval`` steps.

        :param sim: The Minecraft simulator.
        :param obs: The observation from the simulator.
        :param reward: The reward from the simulator.
        :param terminated: Whether the episode has terminated.
        :param truncated: Whether the episode has been truncated.
        :param info: Additional information from the simulator.
        :return: The observation, reward, terminated, truncated, and info, unchanged.
        """
        end_time = time.perf_counter()
        if self.start_time is None:
            # No matching before_step was recorded, so there is nothing to measure.
            # Skip the sample rather than crash or count a bogus interval.
            return obs, reward, terminated, truncated, info
        self.num_steps += 1
        self.total_times += end_time - self.start_time
        # Consume the timestamp so a stray second after_step cannot reuse it.
        self.start_time = None
        if self.num_steps % self.interval == 0:
            avg_time = self.total_times / self.num_steps
            # Guard against a zero accumulated time (e.g. a coarse timer) to
            # avoid ZeroDivisionError; report infinite FPS in that case.
            avg_fps = (
                self.num_steps / self.total_times
                if self.total_times > 0
                else float('inf')
            )
            print(
                f'Speed Test Status: \n'
                f'Average Time: {avg_time:.2f} \n'
                f'Average FPS: {avg_fps:.2f} \n'
                f'Total Steps: {self.num_steps} \n'
            )
        return obs, reward, terminated, truncated, info