Source code for minestudio.offline.mine_callbacks.kl_divergence

'''
Date: 2024-12-12 13:10:58
LastEditors: muzhancun muzhancun@stu.pku.edu.cn
LastEditTime: 2025-05-27 14:14:37
FilePath: /MineStudio/minestudio/offline/mine_callbacks/kl_divergence.py
'''

import torch
from typing import Dict, Any
from minestudio.models import MinePolicy
from minestudio.offline.mine_callbacks.callback import ObjectiveCallback

class KLDivergenceCallback(ObjectiveCallback):
    """
    A callback to compute the KL divergence between two Gaussian distributions.

    This callback is typically used in Variational Autoencoders (VAEs) or similar models
    where a prior distribution is regularized towards a posterior distribution.
    The KL divergence is calculated between a posterior (q) and a prior (p) distribution,
    both assumed to be Gaussian and defined by their means (mu) and log variances (log_var).
    """

    def __init__(self, weight: float = 1.0):
        """
        Initializes the KLDivergenceCallback.

        :param weight: The weight to apply to the KL divergence loss. Defaults to 1.0.
        :type weight: float
        """
        super().__init__()
        self.weight = weight

    def __call__(
        self,
        batch: Dict[str, Any],
        batch_idx: int,
        step_name: str,
        latents: Dict[str, torch.Tensor],
        mine_policy: MinePolicy,
    ) -> Dict[str, torch.Tensor]:
        """
        Calculates the KL divergence loss.

        It retrieves the parameters (mu and log_var) of the posterior and prior
        distributions from the `latents` dictionary, computes the KL divergence using
        the `kl_divergence` method, and returns it as part of the loss dictionary.

        :param batch: A dictionary containing the batch data.
        :type batch: Dict[str, Any]
        :param batch_idx: The index of the current batch.
        :type batch_idx: int
        :param step_name: The name of the current step (e.g., 'train', 'val').
        :type step_name: str
        :param latents: A dictionary containing the policy's latent outputs.
            Must include 'posterior_dist' and 'prior_dist', each with 'mu' and 'log_var' keys.
        :type latents: Dict[str, torch.Tensor]
        :param mine_policy: The MinePolicy model.
        :type mine_policy: MinePolicy
        :returns: A dictionary containing the calculated losses and metrics:
            'loss': the weighted KL divergence loss,
            'kl_div': the mean KL divergence,
            'kl_weight': the weight used for the KL divergence loss.
        :rtype: Dict[str, torch.Tensor]
        """
        posterior_dist = latents['posterior_dist']
        prior_dist = latents['prior_dist']
        q_mu, q_log_var = posterior_dist['mu'], posterior_dist['log_var']
        p_mu, p_log_var = prior_dist['mu'], prior_dist['log_var']
        kl_div = self.kl_divergence(q_mu, q_log_var, p_mu, p_log_var)
        result = {
            'loss': kl_div.mean() * self.weight,
            'kl_div': kl_div.mean(),
            'kl_weight': self.weight,
        }
        return result
    def kl_divergence(self, q_mu, q_log_var, p_mu, p_log_var):
        """
        Computes the KL divergence between two Gaussian distributions q and p.

        KL(q || p) = -0.5 * sum(1 + log(sigma_q^2 / sigma_p^2)
                                - (sigma_q^2 / sigma_p^2)
                                - ((mu_q - mu_p)^2 / sigma_p^2))

        where sigma^2 = exp(log_var).

        :param q_mu: Mean of the posterior distribution q. Shape: (B, D)
        :type q_mu: torch.Tensor
        :param q_log_var: Log variance of the posterior distribution q. Shape: (B, D)
        :type q_log_var: torch.Tensor
        :param p_mu: Mean of the prior distribution p. Shape: (B, D)
        :type p_mu: torch.Tensor
        :param p_log_var: Log variance of the prior distribution p. Shape: (B, D)
        :type p_log_var: torch.Tensor
        :returns: The KL divergence for each element in the batch. Shape: (B)
        :rtype: torch.Tensor
        """
        # Sum the per-dimension KL terms over the feature dimension; shape: (B, D) -> (B)
        KL = -0.5 * torch.sum(
            1 + (q_log_var - p_log_var)
            - (q_log_var - p_log_var).exp()
            - (q_mu - p_mu).pow(2) / p_log_var.exp(),
            dim=-1
        )
        # shape: (B)
        return KL
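
# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of how this callback could be exercised on its own.
# The nested `latents` layout ('posterior_dist'/'prior_dist' with 'mu' and 'log_var')
# follows the contract documented in `__call__`; the batch contents, tensor shapes,
# weight value, and the `mine_policy=None` placeholder are illustrative assumptions,
# not the real MineStudio training loop.
#
# B, D = 4, 32
# callback = KLDivergenceCallback(weight=0.1)
# latents = {
#     'posterior_dist': {'mu': torch.randn(B, D), 'log_var': torch.zeros(B, D)},
#     'prior_dist': {'mu': torch.zeros(B, D), 'log_var': torch.zeros(B, D)},
# }
# losses = callback(batch={}, batch_idx=0, step_name='train',
#                   latents=latents, mine_policy=None)
# print(losses['loss'], losses['kl_div'], losses['kl_weight'])
#
# # Cross-check the closed-form expression against torch.distributions:
# # with scale = exp(0.5 * log_var), kl_divergence(Normal_q, Normal_p) summed over
# # the feature dimension should match `callback.kl_divergence(...)`.
# from torch.distributions import Normal, kl_divergence as torch_kl
# q = Normal(latents['posterior_dist']['mu'], (0.5 * latents['posterior_dist']['log_var']).exp())
# p = Normal(latents['prior_dist']['mu'], (0.5 * latents['prior_dist']['log_var']).exp())
# reference = torch_kl(q, p).sum(dim=-1)  # shape: (B)
# analytic = callback.kl_divergence(
#     latents['posterior_dist']['mu'], latents['posterior_dist']['log_var'],
#     latents['prior_dist']['mu'], latents['prior_dist']['log_var'],
# )
# assert torch.allclose(reference, analytic)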