Source code for energym.wrappers.downsample_outputs

import numpy as np
from energym.envs.env import StepWrapper


class DownsampleOutputs(StepWrapper):
    r"""Run several steps of the wrapped environment and aggregate the outputs.

    Each call to :meth:`step` forwards the same ``inputs`` to the wrapped
    environment ``steps`` times, collects the value of every key listed in
    ``downsampling_dic`` and reduces the collected values with the callable
    registered for that key. :meth:`get_forecast` applies the same callables
    to the wrapped environment's forecast, grouped in windows of ``steps``
    samples.

    The callables are invoked in two different ways, so they must support
    both if both methods are used:

    * ``step`` calls ``func(values)`` with a plain list of ``steps`` values;
    * ``get_forecast`` calls ``func(windows, axis=1)`` with a 2-D NumPy array
      of shape ``(n_windows, steps)``.

    NumPy reductions such as ``np.mean`` or ``np.sum`` satisfy both forms.

    Example (the key name is illustrative only)::

        env = DownsampleOutputs(env, steps=4,
                                downsampling_dic={"some_output": np.mean})
        outputs = env.step(inputs)  # wrapped env is stepped 4 times

    Args:
        env (Env): environment to wrap.
        steps: number of wrapped-environment steps merged into one step of
            this wrapper; must be at least 1.
        downsampling_dic ({keys:callable}): maps each output key to the
            callable used to aggregate its values.

    Raises:
        ValueError: if ``steps`` is smaller than 1.
        TypeError: if a value of ``downsampling_dic`` is not callable.
    """

    def __init__(self, env, steps: int, downsampling_dic: dict):
        super().__init__(env)
        # Explicit checks instead of ``assert``: assertions are removed when
        # Python runs with -O, which would defer the failure to step().
        if steps < 1:
            # steps == 0 would make step() skip the simulation entirely and
            # get_forecast() divide by zero.
            raise ValueError(
                "steps must be a positive integer, got {!r}".format(steps)
            )
        for key, func in downsampling_dic.items():
            if not callable(func):
                raise TypeError(
                    "downsampling_dic[{!r}] must be callable, got {!r}".format(
                        key, func
                    )
                )
        self.downsampling_dic = downsampling_dic
        self.steps = steps

    def step(self, inputs) -> dict:
        """Advance the wrapped environment ``self.steps`` times and aggregate.

        Args:
            inputs: forwarded unchanged to every ``self.env.step`` call.

        Returns:
            dict: one entry per key of ``downsampling_dic`` holding the
            result of that key's callable applied to the list of values
            collected over the sub-steps. Keys of the wrapped outputs that
            are not in ``downsampling_dic`` are not returned.
        """
        collected = {key: [] for key in self.downsampling_dic}
        for _ in range(self.steps):
            outputs = self.env.step(inputs)
            for key, samples in collected.items():
                samples.append(outputs[key])
        return {
            key: self.downsampling_dic[key](samples)
            for key, samples in collected.items()
        }

    def get_forecast(self, forecast_length=24) -> dict:
        """Return the wrapped environment's forecast, downsampled by ``steps``.

        The wrapped environment is asked for ``forecast_length * self.steps``
        samples. Each series is cut to a whole number of windows of
        ``self.steps`` samples and every window is reduced with the key's
        callable (called with ``axis=1``).

        Args:
            forecast_length: number of forecast points requested at this
                wrapper's (downsampled) resolution.

        Returns:
            dict: the dict returned by the wrapped environment, with every
            value replaced in place by the list of downsampled values.

        Raises:
            KeyError: if the forecast contains a key that has no callable in
                ``downsampling_dic``.
        """
        forecast = self.env.get_forecast(
            forecast_length=forecast_length * self.steps
        )
        # Check all keys before touching the dict so that a missing
        # downsampler cannot leave the forecast half converted.
        missing = [key for key in forecast if key not in self.downsampling_dic]
        if missing:
            raise KeyError(
                "no downsampling function for forecast key(s): {}".format(
                    missing
                )
            )
        for key in forecast:
            series = forecast[key]
            # Drop trailing samples that do not fill a complete window.
            usable = (len(series) // self.steps) * self.steps
            windows = np.array(series)[:usable].reshape(-1, self.steps)
            forecast[key] = list(self.downsampling_dic[key](windows, axis=1))
        return forecast