Base Environment
BaseEnvironment
BaseEnvironment (mdp_info:ddopai.utils.MDPInfo, postprocessors:list[object]|None=None, mode:str='train', return_truncation:str=True, horizon_train:int|str='use_all_data')
Base class for environments enforcing a common interface that works with MushroomRL, as well as other RL libraries.
Parameter | Type | Default | Details
---|---|---|---
mdp_info | MDPInfo | | MDPInfo object to ensure compatibility with the agents
postprocessors | list[object] \| None | None | default is an empty list
mode | str | train | initial mode (train, val, test) of the environment
return_truncation | str | True | whether to return a truncated condition in the step function
horizon_train | int \| str | use_all_data | horizon of the training data
Returns | None | |
Important notes:
init method:
When adding parameters to the environment, always add them via set_param(...). This ensures that all parameters have the correct types and shapes.
During the init method, any Gymnasium environment expects the action and observation spaces to be defined. For clarity, avoid defining them directly in the init method; instead, implement set_action_space() and set_observation_space() and call them in the __init__ method.
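The init pattern described above can be sketched as follows. This is a minimal, self-contained illustration that imports neither ddopai nor Gymnasium: ToyBase, ToyEnv, the parameter names, and the dictionary-based spaces are hypothetical stand-ins, and the real set_param additionally enforces types and shapes.

```python
class ToyBase:
    """Hypothetical stand-in for BaseEnvironment (not the ddopai class)."""

    def set_param(self, name, value, shape=(1,), new=False):
        # Assumption: mimic the documented contract in the simplest way.
        # new=True creates the attribute; new=False only updates an existing one.
        if not new and not hasattr(self, name):
            raise AttributeError(f"parameter {name!r} does not exist")
        setattr(self, name, value)


class ToyEnv(ToyBase):
    def __init__(self, underage_cost, overage_cost, num_features):
        # Register every parameter through set_param instead of plain assignment.
        self.set_param("underage_cost", underage_cost, new=True)
        self.set_param("overage_cost", overage_cost, new=True)
        self.set_param("num_features", num_features, new=True)

        # Define the spaces in dedicated methods and call them from __init__.
        self.set_action_space()
        self.set_observation_space()

    def set_action_space(self):
        # Placeholder for a Gymnasium space: one non-negative order quantity.
        self.action_space = {"low": 0.0, "high": float("inf"), "shape": (1,)}

    def set_observation_space(self):
        # Placeholder for a Gymnasium space: one entry per feature.
        self.observation_space = {"shape": (self.num_features,)}


env = ToyEnv(underage_cost=2.0, overage_cost=1.0, num_features=3)
```

Keeping the space definitions in their own methods means a subclass can change one space without touching the rest of the init logic.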
train, val, test, and horizon (episode length):
When the __init__ method is called, the environment executes the train(), val(), or test() method. These methods must therefore be implemented so that they already work during set-up.
The train(), val(), and test() methods are provided in the base class but can be overwritten if necessary. In any case, they must set the dataloader to the correct dataset to prevent data leakage. They also need to update mdp_info so that it reflects the horizon (episode length) of the environment.
The horizon for validation and testing equals the length of the respective dataset. For training, the parameter horizon_train contains either the string “use_all_data” or an integer. In the former case, the horizon is the length of the training dataset. In the latter case, the environment plays an episode of length horizon_train starting at a random point of the training dataset.
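The horizon rules above can be restated as a small helper. This is only a sketch: episode_window is a hypothetical function that is not part of ddopai, and it assumes that a random start must leave room for a full episode of horizon_train steps, which the text does not state explicitly.

```python
import random


def episode_window(mode, dataset_lengths, horizon_train="use_all_data", rng=random):
    """Return (start_index, horizon) for one episode.

    Hypothetical helper: it restates the rules in the text and is not part of ddopai.
    """
    length = dataset_lengths[mode]
    if mode in ("val", "test"):
        # Validation and test episodes always cover the full dataset.
        return 0, length
    if horizon_train == "use_all_data":
        # Training over the full dataset: start at the beginning.
        return 0, length
    # Shorter training episodes: pick a random start that leaves room
    # for horizon_train steps (assumption: the episode must fit in the data).
    horizon = int(horizon_train)
    if horizon > length:
        raise ValueError("horizon_train exceeds the training data length")
    start = rng.randint(0, length - horizon)
    return start, horizon


lengths = {"train": 1000, "val": 200, "test": 300}
val_window = episode_window("val", lengths)
full_train_window = episode_window("train", lengths)
start, horizon = episode_window("train", lengths, horizon_train=50)
```

Here val_window and full_train_window span the whole validation and training sets, while the last call yields a 50-step episode with a randomly drawn start.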
step method:
The step method is the core of the environment: it calculates the next state (observation) and the reward for a given action. Some frameworks expect a truncation condition (now the standard in Gymnasium) while others (e.g., mushroom_rl) do not. The step function is therefore implemented in the base class and handles this difference via an environment flag called return_truncation. DO NOT OVERWRITE the step function; instead, implement the step_(self, action) (trailing underscore) method in the specific environment. This method must always return a tuple of the form (observation, reward, terminated, truncated, info).
For clarity, the construction of the next state (which we call, more generally, the observation, to include POMDPs) is done in a separate method called get_observation() that must be called inside the step_ method. See the documentation below and the Newsvendor environment envs.inventory.NewsvendorEnv for an example.
The dataloader typically returns an (X, Y) pair, where X are features and Y is typically demand. X is needed at the end of the step to construct the next observation returned to the agent. Y only becomes relevant one step later, when the reward is calculated. Hence, Y is typically carried over to the next step via an instance attribute such as self.demand (see envs.inventory.NewsvendorEnv for an example).
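The split between step and step_, and the hand-over of Y from one step to the next, might look roughly like the toy below. ToyNewsvendor is a hypothetical class, not the actual envs.inventory.NewsvendorEnv, and the four-value return used when return_truncation is False is an assumption about how the flag behaves.

```python
class ToyNewsvendor:
    """Hypothetical toy environment; it mirrors the step/step_ split only."""

    def __init__(self, data, underage_cost=2.0, overage_cost=1.0, return_truncation=True):
        self.data = data                      # list of (features, demand) pairs
        self.underage_cost = underage_cost
        self.overage_cost = overage_cost
        self.return_truncation = return_truncation
        self.index = 0
        # Y (demand) is stored now and only used when the next action arrives.
        self.features, self.demand = self.data[self.index]

    def get_observation(self):
        # Load the next (X, Y) pair: X goes to the agent, Y is kept for later.
        self.features, self.demand = self.data[self.index]
        return self.features

    def step_(self, action):
        # The reward uses the demand stored when the current observation was built.
        shortage = max(self.demand - action, 0.0)
        surplus = max(action - self.demand, 0.0)
        reward = -(self.underage_cost * shortage + self.overage_cost * surplus)

        self.index += 1
        truncated = self.index >= len(self.data)
        observation = None if truncated else self.get_observation()
        return observation, reward, False, truncated, {}

    def step(self, action):
        # In the real design this wrapper lives in the base class and is not overridden.
        observation, reward, terminated, truncated, info = self.step_(action)
        if self.return_truncation:
            return observation, reward, terminated, truncated, info
        # Assumption: without truncation, report a single end-of-episode flag.
        return observation, reward, terminated or truncated, info


data = [([1.0], 10.0), ([2.0], 12.0)]
env = ToyNewsvendor(data)
obs, reward, terminated, truncated, info = env.step(8.0)
```

The first action (8.0) is scored against the demand of the first data row (10.0), while the returned observation already holds the features of the second row.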
observation pre-processors and action post-processors:
Sometimes it is necessary to process the observation before passing it to the agent (e.g., changing its shape) or to process the action before passing it to the environment (e.g., rounding). To ensure compatibility with mushroom_rl, the pre-processors (also called observationprocessors) sit with the agent: they must be added to the agent and are applied in the agent’s draw_action() method. The post-processors (also called actionprocessors) sit with the environment and are applied in the environment’s step() method.
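As an illustration of action post-processors on the environment side, the sketch below chains two of them inside step. RoundAction, ClipNonNegative, and ToyEnv are hypothetical names, and the sketch assumes post-processors are plain callables that take and return an action.

```python
class RoundAction:
    """Hypothetical post-processor: rounds each action entry to a whole number."""

    def __call__(self, action):
        return [float(round(a)) for a in action]


class ClipNonNegative:
    """Hypothetical post-processor: forbids negative order quantities."""

    def __call__(self, action):
        return [max(a, 0.0) for a in action]


class ToyEnv:
    def __init__(self, postprocessors=None):
        # Mirror the documented default: None becomes an empty list.
        self.postprocessors = list(postprocessors) if postprocessors else []

    def add_postprocessor(self, postprocessor):
        self.postprocessors.append(postprocessor)

    def step(self, action):
        # Post-processors run in order before the environment logic sees the action.
        for postprocessor in self.postprocessors:
            action = postprocessor(action)
        return self.step_(action)

    def step_(self, action):
        # Environment logic would go here; the toy just echoes the action.
        return action


env = ToyEnv()
env.add_postprocessor(RoundAction())
env.add_postprocessor(ClipNonNegative())
processed = env.step([2.6, -0.4])
```

Because the processors are applied inside step, the environment-specific step_ only ever receives actions that are already rounded and clipped.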
reset method:
The reset method may depend strongly on the environment dynamics, so it must be implemented for the specific environment. It needs to fulfill two requirements: 1) it must differentiate between train, val, and test mode, and 2) in training mode, it must take the horizon_train parameter into account.
At the end of the function, first call the reset_index() method (with either a specific integer index or the flag "random" as input) and then the get_observation() method to construct the first observation.
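The order of calls in a reset implementation could be sketched as below. ToyEnv is a hypothetical class; the rule for choosing between 0 and "random", the bounds of the random start, and the simplified return value (the observation only) are assumptions made for illustration.

```python
import random


class ToyEnv:
    """Hypothetical toy environment showing the order of calls in reset()."""

    def __init__(self, datasets, mode="train", horizon_train="use_all_data"):
        self.datasets = datasets          # {"train": [...], "val": [...], "test": [...]}
        self.mode = mode
        self.horizon_train = horizon_train
        self.index = 0

    def reset_index(self, start_index):
        if start_index == "random":
            # Assumption: leave room for a full episode of horizon_train steps.
            last_start = len(self.datasets["train"]) - self.horizon_train
            self.index = random.randint(0, last_start)
        else:
            self.index = int(start_index)

    def get_observation(self):
        return self.datasets[self.mode][self.index]

    def reset(self):
        # 1) Distinguish the mode, 2) respect horizon_train when training.
        if self.mode == "train" and self.horizon_train != "use_all_data":
            start_index = "random"
        else:
            start_index = 0
        # First reset the index, then build the first observation.
        self.reset_index(start_index)
        return self.get_observation()


datasets = {
    "train": list(range(100)),
    "val": list(range(100, 120)),
    "test": list(range(120, 150)),
}
first_val_obs = ToyEnv(datasets, mode="val").reset()
first_train_obs = ToyEnv(datasets, mode="train", horizon_train=10).reset()
```

In validation mode the episode starts at the first row, while the training environment with horizon_train=10 starts at a randomly drawn row.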
BaseEnvironment.set_param
BaseEnvironment.set_param (name:str, input:Union[ddopai.utils.Parameter,int,float,numpy.ndarray,List,Dict,NoneType], shape:tuple=(1,), new:bool=False)
Set a parameter for the environment. It converts scalar values to NumPy arrays and ensures that environment parameters are either instances of the Parameter class or NumPy arrays. If new is set to True, the function creates a new parameter. If new is set to False, the function updates an existing parameter and raises an error if the parameter does not exist.
Parameter | Type | Default | Details
---|---|---|---
name | str | | name of the parameter (will become the attribute name)
input | Union | | input value of the parameter
shape | tuple | (1,) | shape of the parameter
new | bool | False | whether to create a new parameter or update an existing one
Returns | None | |
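A rough sketch of the behavior described for set_param is shown below. ToyEnv is a hypothetical stand-in that handles only scalars and array-likes (no Parameter objects or dicts); the exception types and the strict shape check are assumptions, not the documented behavior of ddopai.

```python
import numpy as np


class ToyEnv:
    """Hypothetical stand-in; the real set_param also accepts Parameter objects."""

    def set_param(self, name, input, shape=(1,), new=False):
        # The argument name "input" follows the documented signature.
        if not new and not hasattr(self, name):
            raise AttributeError(f"parameter {name!r} does not exist")
        if input is None:
            value = None
        elif np.isscalar(input):
            # Broadcast a scalar to the requested shape.
            value = np.full(shape, input, dtype=float)
        else:
            value = np.asarray(input, dtype=float)
            # Assumption: a mismatching shape is treated as an error.
            if value.shape != tuple(shape):
                raise ValueError(f"expected shape {shape}, got {value.shape}")
        setattr(self, name, value)


env = ToyEnv()
env.set_param("underage_cost", 2, new=True)                     # scalar -> array of shape (1,)
env.set_param("lead_times", [1, 2, 3], shape=(3,), new=True)    # list -> array of shape (3,)
env.set_param("underage_cost", 5)                               # update an existing parameter
```

Calling set_param with new=False for a name that was never created raises an error in this sketch, which catches misspelled parameter names early.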
BaseEnvironment.return_truncation_handler
BaseEnvironment.return_truncation_handler (observation, reward, terminated, truncated, info)
Handle the return_truncation attribute of the environment. This function is called by the step function.
BaseEnvironment.step
BaseEnvironment.step (action)
Step function of the environment. Do not overwrite this function. Instead, write the step_ function. Note that the postprocessor is applied here.
BaseEnvironment.add_postprocessor
BaseEnvironment.add_postprocessor (postprocessor:object)
Add a postprocessor (also called actionprocessor) to the environment.
Parameter | Type | Details
---|---|---
postprocessor | object | post-processor object that can be called via the “call” method
BaseEnvironment.step_
BaseEnvironment.step_ (action)
Step function of the environment. This function contains the logic of the environment and must be provided. It will be called by the step function that applies the actionprocessor and handles the return_truncation attribute.
BaseEnvironment.mdp_info
BaseEnvironment.mdp_info ()
Returns: The MDPInfo object of the environment.
BaseEnvironment.info
BaseEnvironment.info ()
Returns: The MDPInfo object of the environment. This is an alternative accessor to mdp_info, provided for mushroom_rl.
BaseEnvironment.mode
BaseEnvironment.mode ()
Returns: A string with the current mode (train, val, test) of the environment.
BaseEnvironment.set_action_space
BaseEnvironment.set_action_space ()
Set the action space of the environment.
BaseEnvironment.set_observation_space
BaseEnvironment.set_observation_space ()
Set the observation space of the environment. In general, this can also be a dict space, but the agent must then have the appropriate pre-processor.
BaseEnvironment.get_observation
BaseEnvironment.get_observation ()
Return the current observation. Typically constructed from the output of the dataloader and internal dynamics (such as inventory levels, pipeline vectors, etc.) of the environment.
BaseEnvironment.reset
BaseEnvironment.reset ()
Reset the environment. This function must be provided and must use self.reset_index() to handle indexing. It needs to account for the current mode (train, val, or test) and handle the horizon_train parameter. See the reset function of NewsvendorEnv for an example.
BaseEnvironment.set_index
BaseEnvironment.set_index (index=None)
Handle the index of the environment.
BaseEnvironment.get_start_index
BaseEnvironment.get_start_index (start_index:int|str=None)
Determine if the start index is random or 0, depending on the state of the environment and training process (over entire train set or in shorter episodes).
Parameter | Type | Default | Details
---|---|---|---
start_index | int \| str | None | index to start from
Returns | int | |
BaseEnvironment.reset_index
BaseEnvironment.reset_index (start_index:Union[int,str])
Reset the index of the environment. If start_index is an integer, the index is set to this value. If start_index is “random”, the index is set to a random integer between 0 and the length of the training data.
BaseEnvironment.update_mdp_info
BaseEnvironment.update_mdp_info (gamma=None, horizon=None)
Update the MDP info of the environment.
BaseEnvironment.train
BaseEnvironment.train (update_mdp_info=True)
Set the environment in training mode by setting both the internal state self._train and the dataloader. If horizon_train is set to “use_all_data”, the horizon is set to the length of the training data; otherwise it is set to the horizon_train attribute of the environment. Finally, the function updates the MDP info and resets with the new state.
BaseEnvironment.val
BaseEnvironment.val (update_mdp_info=True)
Set the environment in validation mode by both setting the internal state self._val and the dataloader. The horizon of val is always set to the length of the validation data. Finally, the function updates the MDP info and resets with the new state.
BaseEnvironment.test
BaseEnvironment.test (update_mdp_info=True)
Set the environment in testing mode by both setting the internal state self._test and the dataloader. The horizon of test is always set to the length of the test data. Finally, the function updates the MDP info and resets with the new state.
BaseEnvironment.set_return_truncation
BaseEnvironment.set_return_truncation (return_truncation:bool)
Set the return_truncation attribute of the environment.
Parameter | Type | Details
---|---|---
return_truncation | bool | whether or not to return the truncated condition in the step function
BaseEnvironment.stop
BaseEnvironment.stop ()
Stop the environment. This function is used to ensure compatibility with the Core of mushroom_rl.