# Source code for minestudio.offline.mine_callbacks.behavior_clone

'''
Date: 2024-11-12 13:59:08
LastEditors: caishaofei caishaofei@stu.pku.edu.cn
LastEditTime: 2024-12-09 15:51:34
FilePath: /MineStudio/minestudio/offline/mine_callbacks/behavior_clone.py
'''

import torch
from typing import Dict, Any
from minestudio.models import MinePolicy
from minestudio.offline.mine_callbacks.callback import ObjectiveCallback

class BehaviorCloneCallback(ObjectiveCallback):
    """
    A callback for behavior cloning.

    Computes the behavior-cloning loss as the negative log-likelihood of the
    agent's actions (camera and buttons) under the policy's action
    distribution, and reports the entropy of that distribution for monitoring.
    """

    def __init__(self, weight: float = 1.0, camera_null_idx: int = 60):
        """
        Initializes the BehaviorCloneCallback.

        :param weight: The weight to apply to the behavior cloning loss. Defaults to 1.0.
        :type weight: float
        :param camera_null_idx: Index of the "null"/padding camera action; timesteps
            whose camera action equals this index are masked out of the camera loss
            and camera entropy. Defaults to 60, the value previously hard-coded.
        :type camera_null_idx: int
        """
        super().__init__()
        self.weight = weight
        self.camera_null_idx = camera_null_idx

    def __call__(
        self,
        batch: Dict[str, Any],
        batch_idx: int,
        step_name: str,
        latents: Dict[str, torch.Tensor],
        mine_policy: MinePolicy,
    ) -> Dict[str, torch.Tensor]:
        """
        Calculates the behavior cloning loss and entropy.

        The loss is computed as the negative log-probability of the agent's
        actions (both camera and buttons) under the current policy. A mask is
        applied to ignore padding in camera actions (see ``camera_null_idx``),
        and an optional per-timestep ``mask`` in the batch gates all terms.
        The entropy of the policy's action distribution is also computed.

        :param batch: A dictionary containing the batch data. Must include 'agent_action'.
        :type batch: Dict[str, Any]
        :param batch_idx: The index of the current batch.
        :type batch_idx: int
        :param step_name: The name of the current step (e.g., 'train', 'val').
        :type step_name: str
        :param latents: A dictionary containing the policy's latent outputs, including 'pi_logits'.
        :type latents: Dict[str, torch.Tensor]
        :param mine_policy: The MinePolicy model.
        :type mine_policy: MinePolicy
        :returns: A dictionary containing the calculated losses and metrics:
            'loss': The weighted behavior cloning loss.
            'camera_loss': The camera action loss.
            'button_loss': The button action loss.
            'entropy': The entropy of the action distribution.
            'bc_weight': The weight used for the behavior cloning loss.
        :rtype: Dict[str, torch.Tensor]
        :raises AssertionError: If 'agent_action' is not in the batch.
        """
        assert 'agent_action' in batch, "key `agent_action` is required for behavior cloning."
        agent_action = batch['agent_action']
        pi_logits = latents['pi_logits']
        log_prob = mine_policy.pi_head.logprob(agent_action, pi_logits, return_dict=True)
        entropy = mine_policy.pi_head.entropy(pi_logits, return_dict=True)
        # Zero out padded camera timesteps so they contribute neither loss nor entropy.
        camera_mask = (agent_action['camera'] != self.camera_null_idx).float().squeeze(-1)
        # Optional batch-level validity mask; defaults to all-ones when absent.
        global_mask = batch.get('mask', torch.ones_like(camera_mask))
        logp_camera = (log_prob['camera'] * global_mask * camera_mask).sum(-1)
        logp_buttons = (log_prob['buttons'] * global_mask).sum(-1)
        entropy_camera = (entropy['camera'] * global_mask * camera_mask).sum(-1)
        entropy_buttons = (entropy['buttons'] * global_mask).sum(-1)
        camera_loss, button_loss = -logp_camera, -logp_buttons
        bc_loss = camera_loss + button_loss
        # Distinct name: `entropy` above is the per-head dict returned by the policy.
        total_entropy = entropy_camera + entropy_buttons
        result = {
            'loss': bc_loss.mean() * self.weight,
            'camera_loss': camera_loss.mean(),
            'button_loss': button_loss.mean(),
            'entropy': total_entropy.mean(),
            'bc_weight': self.weight,
        }
        return result