Source code for minestudio.simulator.callbacks.rewards
'''
Date: 2024-11-11 17:44:15
LastEditors: muzhancun muzhancun@stu.pku.edu.cn
LastEditTime: 2024-11-14 20:09:56
FilePath: /Minestudio/minestudio/simulator/callbacks/rewards.py
'''
import numpy as np
from minestudio.simulator.callbacks.callback import MinecraftCallback
[docs]
class RewardsCallback(MinecraftCallback):
"""
A callback for calculating rewards based on in-game events.
This callback allows defining custom rewards for various events, such as
killing entities.
"""
def __init__(self, reward_cfg):
"""
Initializes the RewardsCallback.
:param reward_cfg: A list of reward configurations.
Each configuration is a dictionary with keys:
"event": The type of event (e.g., "kill_entity").
"identity": A unique identifier for the reward.
"objects": A list of objects related to the event.
"reward": The reward value.
"max_reward_times": The maximum number of times this reward can be given.
"""
super().__init__()
"""
Examples:
reward_cfg = [{
"event": "kill_entity",
"identity": "kill sheep or cow",
"objects": ["sheep", "cow"],
"reward": 1.0,
"max_reward_times": 5,
}]
"""
self.reward_cfg = reward_cfg
self.prev_info = {}
self.reward_memory = {}
self.current_step = 0
[docs]
def after_reset(self, sim, obs, info):
"""
Resets the reward memory and current step count.
:param sim: The Minecraft simulator.
:param obs: The observation from the simulator.
:param info: Additional information from the simulator.
:return: The observation and info.
"""
self.prev_info = info.copy()
self.reward_memory = {}
self.current_step = 0
return obs, info
[docs]
def after_step(self, sim, obs, reward, terminated, truncated, info):
"""
Calculates and overrides the reward based on the reward configuration.
:param sim: The Minecraft simulator.
:param obs: The observation from the simulator.
:param reward: The original reward from the simulator.
:param terminated: Whether the episode has terminated.
:param truncated: Whether the episode has been truncated.
:param info: Additional information from the simulator.
:return: The modified observation, overridden reward, terminated, truncated, and info.
"""
override_reward = 0.
for reward_info in self.reward_cfg:
event_type = reward_info['event']
delta = 0
for obj in reward_info['objects']:
delta += self._get_obj_num(info, event_type, obj) - self._get_obj_num(self.prev_info, event_type, obj)
if delta <= 0:
continue
already_reward_times = self.reward_memory.get(reward_info['identity'], 0)
if already_reward_times < reward_info['max_reward_times']:
override_reward += reward_info['reward']
self.reward_memory[reward_info['identity']] = already_reward_times + 1
break
self.prev_info = info.copy()
self.current_step += 1
return obs, override_reward, terminated, truncated, info
def _get_obj_num(self, info, event_type, obj):
"""
Gets the number of objects of a specific type for a given event.
:param info: The info dictionary.
:param event_type: The type of event.
:param obj: The object to count.
:return: The number of objects.
"""
if event_type not in info:
return 0.
if obj not in info[event_type]:
return 0.
res = info[event_type][obj]
return res.item() if isinstance(res, np.ndarray) else res