Source code for energym.wrappers.filter_outputs

import copy

from energym import spaces
from energym.envs.env import OutputsWrapper


class FilterOutputs(OutputsWrapper):
    """Filter dictionary outputs by their keys.

    Only the entries of the wrapped environment's outputs whose keys appear in
    ``filter_keys`` are kept, and ``output_space`` is narrowed to the same
    keys (in the order of the wrapped output space).

    Args:
        env: The environment to wrap. Its ``output_space`` must be a
            ``spaces.Dict``.
        filter_keys: Iterable of keys to keep in the outputs. If ``None``,
            every key of the wrapped output space is kept.

    Raises:
        AssertionError: If ``env.output_space`` is not a ``spaces.Dict``.
        ValueError: If any of the ``filter_keys`` is not included in the
            original ``env``'s output space.
    """

    def __init__(self, env, filter_keys=None):
        super().__init__(env)

        wrapped_output_space = env.output_space
        # NOTE(review): kept as an ``assert`` so callers still see an
        # AssertionError; be aware it is stripped when Python runs with -O.
        assert isinstance(
            wrapped_output_space, spaces.Dict
        ), "FilterOutputs is only usable with dict outputs."

        output_keys = wrapped_output_space.spaces.keys()

        # Materialize the keys exactly once: they are iterated several times
        # below, so a one-shot iterator (e.g. a generator) would otherwise be
        # exhausted by the first pass and silently produce an empty filter.
        if filter_keys is None:
            filter_keys = tuple(output_keys)
        else:
            filter_keys = tuple(filter_keys)

        missing_keys = {key for key in filter_keys if key not in output_keys}
        if missing_keys:
            raise ValueError(
                "All the filter_keys must be included in the "
                "original output space.\n"
                "Filter keys: {filter_keys}\n"
                "Observation keys: {output_keys}\n"
                "Missing keys: {missing_keys}".format(
                    filter_keys=filter_keys,
                    output_keys=output_keys,
                    missing_keys=missing_keys,
                )
            )

        # Rebuild a space of the same Dict type holding deep copies of only
        # the selected sub-spaces, so the wrapped env's spaces are not shared.
        self.output_space = type(wrapped_output_space)(
            [
                (name, copy.deepcopy(space))
                for name, space in wrapped_output_space.spaces.items()
                if name in filter_keys
            ]
        )

        # NOTE(review): OutputsWrapper may already store ``env``; this private
        # reference is kept for backward compatibility.
        self._env = env
        self._filter_keys = filter_keys

    def outputs(self, outputs):
        """Return ``outputs`` restricted to the configured filter keys."""
        return self._filter_outputs(outputs)

    def _filter_outputs(self, outputs):
        """Build a mapping of the same type as ``outputs`` with only the kept keys.

        Keys listed in the filter but absent from ``outputs`` are simply
        omitted. The type of ``outputs`` must be constructible from a ``dict``.
        """
        return type(outputs)(
            {
                name: value
                for name, value in outputs.items()
                if name in self._filter_keys
            }
        )