Source code for energym.wrappers.rescale_inputs
import numpy as np
import energym
from energym.spaces.box import Box
from energym.envs.env import InputsWrapper
from copy import deepcopy
class RescaleInputs(InputsWrapper):
    r"""Map normalized continuous inputs onto per-key value ranges.

    For every key of the wrapped environment's input space whose space is a
    ``Box``, :meth:`inputs` computes ``lower + (upper - lower) * value``, so
    a value of 0 maps to the lower bound and a value of 1 to the upper bound.
    :meth:`revert_inputs` applies the inverse transformation. Entries whose
    key has no stored bound, or whose space is not a ``Box``, are returned
    unchanged.
    """

    def __init__(self, env: energym.envs.env, lower_bound=None, upper_bound=None):
        """Store the lower/upper bound used for each ``Box`` input.

        Args:
            env: Environment to wrap; forwarded to the parent constructor.
            lower_bound: Optional dict mapping an input key to the lower
                bound to use for it. Keys that are absent fall back to the
                first element of the corresponding ``Box``'s ``low``.
            upper_bound: Optional dict mapping an input key to the upper
                bound to use for it. Keys that are absent fall back to the
                first element of the corresponding ``Box``'s ``high``.
        """
        super().__init__(env)

        if upper_bound is None:
            upper_bound = {}
        if lower_bound is None:
            lower_bound = {}

        # Only Box (continuous) inputs receive default bounds.
        box_keys = [
            name
            for name in self.input_space.spaces.keys()
            if isinstance(self.input_space[name], Box)
        ]

        # NOTE(review): only element 0 of low/high is used, so a Box whose
        # bounds differ per dimension is rescaled with its first bound only
        # -- confirm this is intended.
        default_upper = {
            name: self.input_space[name].high[0]
            for name in box_keys
            if name not in upper_bound
        }
        default_lower = {
            name: self.input_space[name].low[0]
            for name in box_keys
            if name not in lower_bound
        }

        # User-supplied bounds and defaults have disjoint keys by construction.
        self.lower_bound = {**lower_bound, **default_lower}
        self.upper_bound = {**upper_bound, **default_upper}

    def _rescalable_keys(self, inputs):
        """Return the keys of ``inputs`` that have a stored bound and a Box space."""
        return [
            name
            for name in inputs
            if name in self.lower_bound and isinstance(self.input_space[name], Box)
        ]

    def inputs(self, inputs: dict) -> dict:
        """Convert normalized inputs to the stored ``[lower, upper]`` ranges.

        Args:
            inputs: Dict mapping input keys to sequences of values.

        Returns:
            A deep copy of ``inputs`` in which every rescalable entry is
            replaced by the list ``lower + (upper - lower) * value``. The
            argument itself is not modified.
        """
        scaled = deepcopy(inputs)
        for key in self._rescalable_keys(scaled):
            low = self.lower_bound[key]
            span = self.upper_bound[key] - low
            scaled[key] = list(low + span * np.array(scaled[key]))
        return scaled

    def revert_inputs(self, inputs):
        """Convert inputs from the stored ``[lower, upper]`` ranges back to normalized values.

        This is the inverse of :meth:`inputs`.

        Args:
            inputs: Dict mapping input keys to sequences of values.

        Returns:
            A deep copy of ``inputs`` in which every rescalable entry is
            replaced by the list ``(value - lower) / (upper - lower)``. The
            argument itself is not modified.
        """
        normalized = deepcopy(inputs)
        for key in self._rescalable_keys(normalized):
            low = self.lower_bound[key]
            # NOTE(review): there is no guard for upper == lower (zero span).
            span = self.upper_bound[key] - low
            normalized[key] = list((np.array(normalized[key]) - low) / span)
        return normalized