Base Environment

Base environment class based on Gymnasium

BaseEnvironment

 BaseEnvironment (mdp_info:ddopai.utils.MDPInfo,
                  postprocessors:list[object]|None=None, mode:str='train',
                  return_truncation:bool=True,
                  horizon_train:int|str='use_all_data')

Base class for environments enforcing a common interface that works with MushroomRL, as well as other RL libraries.

Type Default Details
mdp_info MDPInfo MDPInfo object to ensure compatibility with the agents
postprocessors list[object] | None None None is treated as an empty list
mode str train Initial mode (train, val, test) of the environment
return_truncation bool True whether to return a truncated condition in the step function
horizon_train int | str use_all_data horizon of the training data
Returns None

Important notes:

init method:

  • When adding parameters to the environment, always add them via set_param(...). This ensures that all parameters have the correct types and shapes.

  • During initialization, every Gymnasium environment expects the action and observation spaces to be defined. For clarity, avoid defining them directly in __init__; instead, implement set_action_space() and set_observation_space() and call them from the __init__ method.

train, val, test, and horizon (episode length):

  • When the __init__ method is called, the environment executes one of the train(), val(), or test() methods. These methods must therefore work correctly during set-up.

  • The base class provides train(), val(), and test(), but they can be overridden if necessary. In any case, they must point the dataloader to the correct dataset to prevent data leakage. They must also update mdp_info so that the horizon (episode length) of the environment is correct.

  • The horizon for validation and testing equals the length of the respective dataset. For training, the parameter horizon_train is either the string “use_all_data” or an integer. In the former case, the horizon is the length of the training dataset; in the latter, the environment plays an episode of length horizon_train starting at a random point in the training dataset.
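
The horizon rules above can be sketched as a small function. `episode_horizon` and its arguments are hypothetical and not part of ddopai; the check that an integer horizon_train fits inside the training set is an added assumption.

```python
def episode_horizon(mode: str, n_train: int, n_val: int, n_test: int,
                    horizon_train="use_all_data") -> int:
    """Hypothetical helper: episode length (horizon) for each mode."""
    if mode == "val":
        return n_val  # validation always uses the full validation set
    if mode == "test":
        return n_test  # testing always uses the full test set
    if mode != "train":
        raise ValueError(f"Unknown mode: {mode!r}")
    if horizon_train == "use_all_data":
        return n_train  # one episode covers the whole training set
    # Assumption: an integer horizon must fit inside the training data.
    if isinstance(horizon_train, int) and 0 < horizon_train <= n_train:
        return horizon_train  # shorter episodes, started at a random index
    raise ValueError("horizon_train must be 'use_all_data' or an int in [1, n_train]")


print(episode_horizon("train", 1000, 200, 150))                    # 1000
print(episode_horizon("train", 1000, 200, 150, horizon_train=50))  # 50
print(episode_horizon("test", 1000, 200, 150))                     # 150
```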

step method:

  • The step method is the core of the environment: given an action, it calculates the next state (observation) and the reward. Some frameworks expect a truncation condition (now the standard implementation in Gymnasium) while others (e.g., mushroom_rl) do not, so the step function is implemented in the base class and handles both cases via the environment flag return_truncation. DO NOT OVERRIDE the step function; instead, implement the step_(self, action) method (with a trailing underscore) in the specific environment. step_ must always return a tuple of the form (observation, reward, terminated, truncated, info).

  • For clarity, the next state (we use the more general term observation to include POMDPs) is constructed in a separate method called get_observation(), which must be called inside step_. See the documentation below and the Newsvendor environment envs.inventory.NewsvendorEnv for an example.

  • The dataloader typically returns an (X, Y) pair, where X contains features and Y is typically demand. X is needed at the end of the step to construct the next observation returned to the agent, while Y only becomes relevant one step later, when the reward is calculated. Hence, Y is typically carried over to the next call of step_ via an instance attribute such as self.demand (see envs.inventory.NewsvendorEnv as an example).
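
A minimal sketch of this data flow, assuming plain Python lists instead of a ddopai dataloader. `ToyNewsvendor`, its `price`/`cost` parameters, and the profit-style reward are illustrative only; this is not the implementation of envs.inventory.NewsvendorEnv. It shows get_observation() storing Y as self.demand and step_ using that stored demand one step later.

```python
class ToyNewsvendor:
    """Hypothetical environment illustrating the step_/get_observation split."""

    def __init__(self, X, Y, price=5.0, cost=3.0):
        self.X, self.Y = X, Y  # features and demands, aligned by index
        self.price, self.cost = price, cost
        self.index = 0
        self.demand = None

    def get_observation(self):
        # X at the current index is returned to the agent; Y is stored
        # so that the next call to step_ can compute the reward.
        self.demand = self.Y[self.index]
        return self.X[self.index]

    def reset(self):
        self.index = 0
        return self.get_observation(), {}

    def step_(self, action):
        # The reward uses the demand stored by the previous get_observation().
        reward = self.price * min(action, self.demand) - self.cost * action
        self.index += 1
        truncated = self.index >= len(self.X)
        # No further (X, Y) pair exists after the last step, so the last
        # features are repeated (an illustrative choice).
        observation = self.X[-1] if truncated else self.get_observation()
        return observation, reward, False, truncated, {}


env = ToyNewsvendor(X=[[1.0], [2.0], [3.0]], Y=[4, 6, 5])
obs, _ = env.reset()
observation, reward, terminated, truncated, info = env.step_(5)
```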

observation pre-processors and action post-processors:

  • Sometimes the observation must be processed before it is passed to the agent (e.g., reshaping), or the action must be processed before it is passed to the environment (e.g., rounding). To ensure compatibility with mushroom_rl, the pre-processors (also called observationprocessors) sit with the agent: they must be added to the agent and are applied in the agent’s draw_action() method. The post-processors (also called actionprocessors) sit with the environment and are applied in the environment’s step() method.
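
The sketch below shows how actionprocessors could be chained inside step before the environment logic runs. It does not use ddopai: `RoundAction`, `ClipNonNegative`, and `EnvWithPostprocessors` are hypothetical, and applying postprocessors in the order they were added is our assumption. Each postprocessor is simply a callable object implementing `__call__`; the agent-side pre-processors are not shown.

```python
class RoundAction:
    """Hypothetical postprocessor: round continuous actions to whole units."""

    def __call__(self, action):
        return [round(a) for a in action]


class ClipNonNegative:
    """Hypothetical postprocessor: forbid negative order quantities."""

    def __call__(self, action):
        return [max(0, a) for a in action]


class EnvWithPostprocessors:
    """Hypothetical environment that applies postprocessors in step()."""

    def __init__(self, postprocessors=None):
        self.postprocessors = list(postprocessors or [])
        self.received = None

    def add_postprocessor(self, postprocessor):
        self.postprocessors.append(postprocessor)

    def step(self, action):
        # Assumption: postprocessors run in insertion order before step_.
        for postprocessor in self.postprocessors:
            action = postprocessor(action)
        return self.step_(action)

    def step_(self, action):
        # Record the action the environment logic actually sees.
        self.received = action
        return action


env = EnvWithPostprocessors([RoundAction()])
env.add_postprocessor(ClipNonNegative())
print(env.step([2.6, -0.4]))  # [3, 0]
```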

reset method:

  • The reset method may depend strongly on the environment dynamics, so it must be implemented for each specific environment. It must fulfill two requirements: 1) it must differentiate between train, val, and test mode, and 2) in training mode, it must take the horizon_train parameter into account.

  • At the end of the function, first call the reset_index() method (with either a specific integer index or the flag "random" as input), then call the get_observation() method to construct the first observation.
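
A minimal sketch of such a reset method, assuming a plain dictionary of dataset lengths in place of a dataloader. `ResetSketch` and its attributes are hypothetical; drawing the random start so that a full training episode fits is our assumption (the reset_index entry only states a random integer between 0 and the length of the training data).

```python
import random


class ResetSketch:
    """Hypothetical environment showing the two reset requirements."""

    def __init__(self, lengths, horizon_train="use_all_data", mode="train", seed=None):
        self.lengths = lengths  # e.g. {"train": 100, "val": 20, "test": 20}
        self.horizon_train = horizon_train
        self.mode = mode
        self.rng = random.Random(seed)
        self.index = None

    def reset_index(self, start_index):
        if start_index == "random":
            # Assumption: the whole episode must fit inside the training data.
            start_index = self.rng.randint(0, self.lengths["train"] - self.horizon_train)
        self.index = start_index

    def get_observation(self):
        # Placeholder: a real environment builds features and internal state here.
        return {"mode": self.mode, "index": self.index}

    def reset(self):
        # Random starts only for training with shorter episodes;
        # val, test, and "use_all_data" training start at index 0.
        if self.mode == "train" and self.horizon_train != "use_all_data":
            self.reset_index("random")
        else:
            self.reset_index(0)
        # Construct the first observation only after the index is set.
        return self.get_observation(), {}


env = ResetSketch({"train": 100, "val": 20, "test": 20}, horizon_train=10, seed=0)
observation, info = env.reset()
```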


BaseEnvironment.set_param

 BaseEnvironment.set_param (name:str,
                            input:Union[ddopai.utils.Parameter,int,float,
                                        numpy.ndarray,List,Dict,NoneType],
                            shape:tuple=(1,), new:bool=False)

Set a parameter for the environment. Scalar values are converted to NumPy arrays, ensuring that environment parameters are either objects of the Parameter class or NumPy arrays. If new is set to True, the function creates the parameter, or updates it if it already exists. If new is set to False, the function only updates an existing parameter and raises an error if the parameter does not exist.

Type Default Details
name str name of the parameter (will become the attribute name)
input Union input value of the parameter
shape tuple (1,) shape of the parameter
new bool False whether to create a new parameter or update an existing one
Returns None
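
The semantics of the new flag can be sketched with a hypothetical stand-in class. It is not ddopai's implementation: `ParamHolder` is illustrative, Python lists stand in for NumPy arrays, only 1-D shapes are handled, and the Parameter class is omitted.

```python
class ParamHolder:
    """Hypothetical stand-in illustrating the `new` flag of set_param."""

    def set_param(self, name, value, shape=(1,), new=False):
        # Without new=True, only existing parameters may be updated.
        if not new and not hasattr(self, name):
            raise AttributeError(f"Parameter {name!r} does not exist; pass new=True to create it")
        if isinstance(value, (int, float)):
            # Broadcast a scalar to the requested 1-D shape.
            value = [float(value)] * shape[0]
        elif len(value) != shape[0]:
            raise ValueError(f"{name!r} has length {len(value)}, expected shape {shape}")
        setattr(self, name, list(value))


params = ParamHolder()
params.set_param("price", 5, shape=(2,), new=True)  # creates [5.0, 5.0]
params.set_param("price", [4.0, 6.0], shape=(2,))   # updates the existing parameter
```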

BaseEnvironment.return_truncation_handler

 BaseEnvironment.return_truncation_handler (observation, reward,
                                            terminated, truncated, info)

Handle the return_truncation attribute of the environment. This function is called by the step function.
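
The handler's behavior can be sketched as a plain function. The name mirrors the method, but this signature (return_truncation passed explicitly) is hypothetical, and dropping the truncated flag in the mushroom_rl case is an assumption based on mushroom_rl's four-element step output (next state, reward, absorbing flag, info).

```python
def return_truncation_handler(return_truncation, observation, reward, terminated, truncated, info):
    """Hypothetical sketch: format the step() output for the target library."""
    if return_truncation:
        # Gymnasium-style 5-tuple, including the truncation flag.
        return observation, reward, terminated, truncated, info
    # mushroom_rl-style 4-tuple; the truncated flag is dropped (assumption:
    # mushroom_rl enforces the episode horizon itself).
    return observation, reward, terminated, info


print(return_truncation_handler(True, [1.0], 2.0, False, True, {}))   # ([1.0], 2.0, False, True, {})
print(return_truncation_handler(False, [1.0], 2.0, False, True, {}))  # ([1.0], 2.0, False, {})
```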


BaseEnvironment.step

 BaseEnvironment.step (action)

Step function of the environment. Do not override this function; implement the step_ function instead. Note that the postprocessors are applied here.


BaseEnvironment.add_postprocessor

 BaseEnvironment.add_postprocessor (postprocessor:object)

Add a postprocessor (also called actionprocessor) to the environment.

Type Details
postprocessor object post-processor object that can be called via its __call__ method

BaseEnvironment.step_

 BaseEnvironment.step_ (action)

Step function of the environment. This function contains the logic of the environment and must be provided. It is called by the step function, which applies the actionprocessors and handles the return_truncation attribute.


BaseEnvironment.mdp_info

 BaseEnvironment.mdp_info ()

Returns: The MDPInfo object of the environment.


BaseEnvironment.info

 BaseEnvironment.info ()

Returns: The MDPInfo object of the environment; an alternative to mdp_info, provided for mushroom_rl.


BaseEnvironment.mode

 BaseEnvironment.mode ()

Returns: A string with the current mode (train, val, or test) of the environment.


BaseEnvironment.set_action_space

 BaseEnvironment.set_action_space ()

Set the action space of the environment.


BaseEnvironment.set_observation_space

 BaseEnvironment.set_observation_space ()

Set the observation space of the environment. In general, this can also be a Dict space, but the agent must then have the appropriate pre-processor.


BaseEnvironment.get_observation

 BaseEnvironment.get_observation ()

Return the current observation. Typically constructed from the output of the dataloader and internal dynamics (such as inventory levels, pipeline vectors, etc.) of the environment.


BaseEnvironment.reset

 BaseEnvironment.reset ()

Reset the environment. This function must be provided and must use self.reset_index() to handle indexing. It needs to account for the current mode (train, val, or test) and handle the horizon_train parameter. See the reset function of NewsvendorEnv for an example.


BaseEnvironment.set_index

 BaseEnvironment.set_index (index=None)

Handle the index of the environment.


BaseEnvironment.get_start_index

 BaseEnvironment.get_start_index (start_index:int|str=None)

Determine whether the start index is random or 0, depending on the state of the environment and the training process (over the entire training set or in shorter episodes).

Type Default Details
start_index int | str None index to start from
Returns int

BaseEnvironment.reset_index

 BaseEnvironment.reset_index (start_index:Union[int,str])

Reset the index of the environment. If start_index is an integer, the index is set to this value. If start_index is “random”, the index is set to a random integer between 0 and the length of the training data.


BaseEnvironment.update_mdp_info

 BaseEnvironment.update_mdp_info (gamma=None, horizon=None)

Update the MDP info of the environment.


BaseEnvironment.train

 BaseEnvironment.train (update_mdp_info=True)

Set the environment in training mode by both setting the internal state self._train and the dataloader. If the horizon is set to “use_all_data”, the horizon is set to the length of the training data, otherwise it is set to the horizon_train attribute of the environment. Finally, the function updates the MDP info and resets with the new state.


BaseEnvironment.val

 BaseEnvironment.val (update_mdp_info=True)

Set the environment in validation mode by both setting the internal state self._val and the dataloader. The horizon of val is always set to the length of the validation data. Finally, the function updates the MDP info and resets with the new state.


BaseEnvironment.test

 BaseEnvironment.test (update_mdp_info=True)

Set the environment in testing mode by both setting the internal state self._test and the dataloader. The horizon of test is always set to the length of the test data. Finally, the function updates the MDP info and resets with the new state.


BaseEnvironment.set_return_truncation

 BaseEnvironment.set_return_truncation (return_truncation:bool)

Set the return_truncation attribute of the environment.

Type Details
return_truncation bool whether or not to return the truncated condition in the step function

BaseEnvironment.stop

 BaseEnvironment.stop ()

Stop the environment. This function is used to ensure compatibility with the Core of mushroom_rl.