# Source code for minestudio.online.utils.train.gae
import ray
import numpy as np
from typing import Dict, List, Tuple, Optional, Any
from minestudio.online.utils.rollout.datatypes import FragmentDataDict, FragmentIndex
import minestudio.online.utils.train.wandb_logger as wandb_logger
from collections import defaultdict
import logging
def get_last_fragment_indexes(fragment_indexs: List[FragmentIndex]) -> List[FragmentIndex]:
    """
    Find the fragment that closes each contiguous run of fragments per worker.

    Fragments are ordered by ``fid_in_worker`` and scanned from the highest id
    downwards. A fragment is reported when it is the highest-numbered fragment
    of its worker, or when the fragment that would directly follow it
    (``fid_in_worker + 1`` from the same worker) is absent. A worker whose ids
    have gaps therefore contributes one fragment per contiguous run.

    The input list is not modified.

    :param fragment_indexs: A list of FragmentIndex objects.
    :returns: The run-closing FragmentIndex objects, in descending
        ``fid_in_worker`` order.
    """
    # sorted() is stable, so ties keep their input order (reversed below).
    ordered = sorted(fragment_indexs, key=lambda frag: frag.fid_in_worker)
    # worker_uuid -> fid of the most recently scanned (i.e. next-newer) fragment.
    newer_fid: Dict[Any, int] = {}
    run_ends: List[FragmentIndex] = []
    for frag in ordered[::-1]:
        worker = frag.worker_uuid
        continues_run = worker in newer_fid and newer_fid[worker] == frag.fid_in_worker + 1
        if not continues_run:
            run_ends.append(frag)
        newer_fid[worker] = frag.fid_in_worker
    return run_ends
@ray.remote
class GAEWorker:
    """
    A Ray remote actor for calculating Generalized Advantage Estimation (GAE) and TD-Lambda targets.

    Per-fragment inputs are registered with ``update_gae_infos``;
    ``calculate_target`` computes advantages and TD-Lambda targets for every
    registered fragment; ``get_target`` returns the results for a requested
    subset of fragments.

    :param discount: The discount factor (gamma) for future rewards.
    :param gae_lambda: The GAE lambda parameter for balancing bias and variance.
    """

    def __init__(self,
                 discount: float,
                 gae_lambda: float,
                 ):
        # Step counts of contiguous GAE runs; reported to wandb and cleared at
        # the end of every ``calculate_target`` call.
        self._gae_lengths: List[int] = []
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.reset()

    def reset(self):
        """
        Resets the internal state of the GAE worker, clearing stored GAE information.

        Only the registered inputs are cleared; targets and advantages produced
        by an earlier ``calculate_target`` call are kept.
        """
        self.gae_infos: Dict[FragmentIndex, Dict[str, Any]] = {}

    def update_gae_infos(self, gae_infos: Dict[FragmentIndex, Dict[str, Any]]):
        """
        Updates the GAE information with new data.

        Entries for fragments that are already registered are overwritten.

        :param gae_infos: A dictionary mapping FragmentIndex to a dictionary of GAE-related information.
            ``calculate_target`` reads the keys 'vpred', 'reward' and 'next_done' for every fragment,
            and 'next_vpred' for the last fragment of each contiguous run.
        """
        self.gae_infos.update(gae_infos)

    def calculate_target(self):
        """
        Calculates TD-Lambda targets and GAE advantages for all stored fragments.

        Fragments are grouped per worker into runs of consecutive ``fid_in_worker``
        values and processed from the newest fragment to the oldest. For every
        step ``t`` (iterating backwards)::

            delta_t = reward[t] + discount * V_next * (1 - next_done[t]) - vpred[t]
            A_t     = delta_t + discount * gae_lambda * (1 - next_done[t]) * A_next
            target  = A_t + vpred[t]

        ``V_next`` and ``A_next`` come from step ``t + 1``, and are carried across
        fragment boundaries inside a run. For the final step of a run, ``V_next``
        is the fragment's stored ``next_vpred`` and ``A_next`` is 0.

        Results are stored in ``self.advantages`` and ``self.td_targets``. The
        average run length (in steps) is logged through ``wandb_logger`` when at
        least one run was processed.

        :raises KeyError: If the last fragment of a run has no ``'next_vpred'`` entry.
        :raises ValueError: If a computed TD-Lambda target is NaN.
        """
        fragment_indexs = sorted(self.gae_infos.keys(), key=lambda x: x.fid_in_worker)
        self.td_targets, self.advantages = FragmentDataDict(), FragmentDataDict()
        _last_idx = {}     # worker_uuid -> fid of the previously processed (newer) fragment
        _gae_length = {}   # worker_uuid -> number of steps in the run currently being processed
        last_advantage = defaultdict(float)
        last_next_vpred = defaultdict(float)
        # Newest -> oldest, so the recursion can carry values backwards in time.
        for index in reversed(fragment_indexs):
            worker = index.worker_uuid
            info = self.gae_infos[index]
            if worker not in _last_idx or _last_idx[worker] != index.fid_in_worker + 1:
                # No directly following fragment: a new run starts here, so
                # bootstrap from the stored value of the state after this fragment.
                last_advantage[worker] = 0
                if 'next_vpred' not in info:
                    raise KeyError(
                        f"GAE info for fragment {index} is missing 'next_vpred', "
                        "which is required for the last fragment of a contiguous run."
                    )
                last_next_vpred[worker] = info['next_vpred']
                if worker in _gae_length:
                    # The previous run of this worker is complete; record its length.
                    self._gae_lengths.append(_gae_length[worker])
                _gae_length[worker] = 0
            _last_idx[worker] = index.fid_in_worker

            vpred: np.ndarray = info['vpred']
            assert len(vpred.shape) == 1
            reward = info['reward']
            next_done = info['next_done']
            next_vpred = last_next_vpred[worker]
            last_gae_adv = last_advantage[worker]
            advantages = np.zeros_like(vpred)
            td_targets = np.zeros_like(vpred)
            _gae_length[worker] += len(next_done)
            for t in range(len(next_done) - 1, -1, -1):
                # next_done[t] masks bootstrapping from step t + 1.
                next_nonterminal = 1.0 - next_done[t]
                delta = reward[t] + self.discount * next_vpred * next_nonterminal - vpred[t]
                last_gae_adv = delta + self.discount * self.gae_lambda * next_nonterminal * last_gae_adv
                advantages[t] = last_gae_adv
                td_targets[t] = last_gae_adv + vpred[t]
                next_vpred = vpred[t]
            # A NaN in vpred propagates into the target, so checking targets suffices.
            if np.isnan(td_targets).any():
                raise ValueError(f"NaN encountered while computing GAE targets for fragment {index}.")
            self.advantages[index] = advantages
            self.td_targets[index] = td_targets
            # Hand the recursion state to the next-older fragment of this worker.
            last_advantage[worker] = last_gae_adv
            last_next_vpred[worker] = next_vpred

        # Runs that are still open (the oldest run of each worker) are complete now.
        self._gae_lengths += list(_gae_length.values())
        if len(self._gae_lengths) > 0:
            wandb_logger.log({
                "GAEWorker/average_gae_length": np.mean(self._gae_lengths),
            })
        self._gae_lengths = []

    def get_target(self, indexs: List[FragmentIndex]) -> Tuple[FragmentDataDict, FragmentDataDict]:
        """
        Retrieves the calculated TD-Lambda targets and GAE advantages for a given list of fragment indexes.

        ``calculate_target`` must have been called beforehand.

        :param indexs: A list of FragmentIndex objects for which to retrieve the targets and advantages.
        :returns: A tuple containing two FragmentDataDicts: one for TD-Lambda targets and one for GAE advantages.
        """
        td_targets, advantages = FragmentDataDict(), FragmentDataDict()
        for index in indexs:
            td_targets[index] = self.td_targets[index]
            advantages[index] = self.advantages[index]
        return td_targets, advantages