Base agents
BaseAgent
BaseAgent (environment_info:ddopai.utils.MDPInfo, obsprocessors:list[object]|None=None, agent_name:str|None=None, receive_batch_dim:bool=False)
Base class for all agents to enforce a common interface. See below for a more detailed description of the requirements.
| | Type | Default | Details |
|---|---|---|---|
| environment_info | MDPInfo | | |
| obsprocessors | list[object] \| None | None | default is empty list |
| agent_name | str \| None | None | |
| receive_batch_dim | bool | False | |
Important notes:
Agents are, alongside the environments, the core element of this library. Agents are the algorithms that take actions in the environment; they can be of any type, ranging from optimization and supervised learning to reinforcement learning, and any combination thereof. The key to making all the different agents work is a common interface that allows them to interact with the environment.
Draw action:

The `draw_action` function is the main interface with the environment. It receives an observation as a Numpy array and returns an action as a Numpy array. The function `draw_action` is defined in [`BaseAgent`](https://opimwue.github.io/ddopai/30_agents/40_base_agents/base_agents.html#baseagent) and should not be overwritten, as it properly applies pre- and post-processing (see below).

Agents always expect the observation to be of shape (batch_size, observation_dim) or (batch_size, time_dim, observation_dim) to allow batch processing during training. Most environments do not have a batch dimension, as they apply the step function to a single observation. Hence, the agent will by default add an extra dimension to the observation. If this is not desired, the agent has an attribute `receive_batch_dim` that can be set to True to tell the agent that the observation already has a batch dimension.

To create an agent, the function `draw_action_` (note the underscore!) needs to be defined; it receives the pre-processed observation and returns the action for post-processing. This function should be overwritten in the derived class.
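The split described above can be sketched with a simplified stand-in class (this is an illustration of the pattern, not the real ddopai `BaseAgent`; the class names `SketchAgent` and `ConstantAgent` are made up for this example):

```python
import numpy as np

# Simplified stand-in illustrating the draw_action / draw_action_ split
# described above (not the real ddopai BaseAgent).
class SketchAgent:
    def __init__(self, obsprocessors=None, receive_batch_dim=False):
        self.obsprocessors = obsprocessors or []
        self.receive_batch_dim = receive_batch_dim

    def add_batch_dim(self, observation):
        # Most environments pass a single observation without a batch
        # dimension; add one unless told the input is already batched.
        if self.receive_batch_dim:
            return observation
        return np.expand_dims(observation, axis=0)

    def draw_action(self, observation):
        # Common entry point: add batch dim, apply obsprocessors in order,
        # then delegate to the subclass-specific draw_action_.
        observation = self.add_batch_dim(observation)
        for processor in self.obsprocessors:
            observation = processor(observation)
        return self.draw_action_(observation)

    def draw_action_(self, observation):
        raise NotImplementedError  # subclasses implement the agent logic here

class ConstantAgent(SketchAgent):
    def draw_action_(self, observation):
        # Trivial policy: one unit per batch element.
        return np.ones((observation.shape[0], 1))

agent = ConstantAgent()
action = agent.draw_action(np.array([1.0, 2.0, 3.0]))  # (3,) is batched to (1, 3)
print(action.shape)  # (1, 1)
```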
observation pre-processors and action post-processors:

Sometimes it is necessary to process the observation before giving it to the agent (e.g., changing its shape) or to process the action before giving it to the environment (e.g., rounding). To ensure compatibility with mushroom_rl, the pre-processors sit with the agent (they must be added to the agent and are applied in the agent's `draw_action()` method), while the post-processors sit with the environment and are applied in the environment's `step()` method. To differentiate the pre-processors here from the pre-processors used directly inside mushroom_rl, we call them obsprocessors, short for observation pre-processors.
During definition, one can already pass the obsprocessors as a list (via the argument `obsprocessors`). After instantiation, processors are added using the `add_obsprocessor` method. Note that processors are applied in the order they are added.
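Since obsprocessors are plain callables applied in order, their behavior can be shown without the library itself (a minimal sketch; the processor functions `flatten` and `clip_nonnegative` are invented for illustration):

```python
import numpy as np

# Obsprocessors are callables applied left to right, in the order added,
# before the observation reaches draw_action_.
def flatten(obs):
    return obs.reshape(obs.shape[0], -1)

def clip_nonnegative(obs):
    return np.clip(obs, 0.0, None)

processors = [flatten, clip_nonnegative]

obs = np.array([[[-1.0, 2.0], [3.0, -4.0]]])  # shape (1, 2, 2)
for p in processors:
    obs = p(obs)
print(obs)  # [[0. 2. 3. 0.]]
```

Order matters: clipping before flattening would give the same result here, but processors that change shape and processors that assume a shape generally do not commute.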
Training:

The [`run_experiment`](https://opimwue.github.io/ddopai/40_experiments/experiment_functions.html#run_experiment) function in this library currently supports three types of training processes:

- `train_directly`: The agent is trained by calling `agent.fit(X, Y)` directly. In this case, the agent must have a `fit` function that takes the input and target data.
- `train_epochs`: The agent is iteratively trained on the training data (e.g., via SGD). In this case, the function `fit_epoch` must be implemented. `fit_epoch` does not get any arguments; rather, the dataloader from the environment needs to be given to the agent during initialization. The agent will then call the dataloader iteratively to get the training data.
- `env_interaction`: The agent is trained by interacting with the environment (e.g., like all reinforcement learning agents). This case builds on the `Core` class from MushroomRL.
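The two supervised interfaces can be sketched as follows (hedged: `DirectFitAgent` and `EpochFitAgent` are toy stand-ins following the description above, not classes from the library):

```python
import numpy as np

class DirectFitAgent:
    """train_directly: fit(X, Y) receives the full dataset at once."""
    def fit(self, X, Y):
        # e.g., a closed-form least-squares fit on the whole training set
        self.weights, *_ = np.linalg.lstsq(X, Y, rcond=None)

class EpochFitAgent:
    """train_epochs: fit_epoch() takes no arguments; the dataloader is
    handed to the agent at initialization and queried iteratively."""
    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.n_seen = 0
    def fit_epoch(self):
        for X_batch, Y_batch in self.dataloader:
            self.n_seen += len(X_batch)  # a gradient step would go here

X = np.array([[1.0], [2.0], [3.0]])
Y = np.array([[2.0], [4.0], [6.0]])

direct = DirectFitAgent()
direct.fit(X, Y)           # learns weights ~ [[2.0]]

epoch = EpochFitAgent(dataloader=[(X, Y)])
epoch.fit_epoch()          # consumes one pass over the dataloader
```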
Loading and saving:
- All agents must implement save and load functions that allow saving and loading the agent’s parameters. See the Newsvendor ERM and (w)SAA agents for examples of different ways to save and load agents.
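One possible save/load pattern is pickling the agent's attributes (a hedged sketch only; the library's agents, such as the Newsvendor ERM and (w)SAA agents, each use their own formats, and `PickleSaveMixin`/`ToyAgent` are invented names):

```python
import os
import pickle
import tempfile

class PickleSaveMixin:
    # Illustrative save/load via pickle; real agents may store
    # model weights, fitted quantiles, etc. in dedicated files.
    def save(self, path):
        with open(path, "wb") as f:
            pickle.dump(self.__dict__, f)

    def load(self, path):
        with open(path, "rb") as f:
            self.__dict__.update(pickle.load(f))

class ToyAgent(PickleSaveMixin):
    def __init__(self, cu=0.5):
        self.cu = cu  # e.g., an underage cost parameter

path = os.path.join(tempfile.mkdtemp(), "agent.pkl")
ToyAgent(cu=0.9).save(path)

restored = ToyAgent()
restored.load(path)
print(restored.cu)  # 0.9
```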
Dynamic class loading:
- This package allows loading agents dynamically with the [`select_agent`](https://opimwue.github.io/ddopai/40_experiments/meta_experiment_functions.html#select_agent) function, which takes a string as input and returns the corresponding agent class. When creating new agents, make sure to add them to 10_AGENT_CLASSES.ipynb under the base agents folder with an appropriate name.
BaseAgent.draw_action
BaseAgent.draw_action (observation:numpy.ndarray)
Main interface to the environment. Applies preprocessors to the observation. The internal logic of the agent is to be implemented in the draw_action_ method.
| | Type | Details |
|---|---|---|
| observation | ndarray | |
| Returns | ndarray | |
BaseAgent.draw_action_
BaseAgent.draw_action_ (observation:numpy.ndarray)
Generate an action based on the observation - this is the core method that needs to be implemented by all agents.
| | Type | Details |
|---|---|---|
| observation | ndarray | |
| Returns | ndarray | |
BaseAgent.add_obsprocessor
BaseAgent.add_obsprocessor (obsprocessor:object)
Add a preprocessor to the agent
| | Type | Details |
|---|---|---|
| obsprocessor | object | pre-processor object that can be called via the “call” method |
BaseAgent.train
BaseAgent.train ()
Set the internal state of the agent to train
BaseAgent.eval
BaseAgent.eval ()
Set the internal state of the agent to eval. Note that for agents we do not differentiate between val and test modes.
BaseAgent.add_batch_dim
BaseAgent.add_batch_dim (input:numpy.ndarray|dict[str,numpy.ndarray])
Add a batch dimension to the input array if it doesn’t already have one. This is necessary because most environments do not have a batch dimension, but agents typically expect one. If the environment does have a batch dimension, the agent can set the receive_batch_dim attribute to True to skip this step.
| | Type | Details |
|---|---|---|
| input | numpy.ndarray \| dict[str, numpy.ndarray] | |
| Returns | numpy.ndarray \| dict[str, numpy.ndarray] | |
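The behavior described above amounts to a conditional `np.expand_dims`, applied per entry for dict observations (a minimal sketch under the stated assumptions; the standalone function name `add_batch_dim` mirrors the method but is not the library code):

```python
import numpy as np

# Sketch of the batch-dimension handling: a single observation of shape
# (observation_dim,) becomes (1, observation_dim); dict observations are
# expanded entry by entry; batched input is passed through unchanged.
def add_batch_dim(inp, receive_batch_dim=False):
    if receive_batch_dim:
        return inp  # already batched, nothing to do
    if isinstance(inp, dict):
        return {k: np.expand_dims(v, axis=0) for k, v in inp.items()}
    return np.expand_dims(inp, axis=0)

obs = np.array([1.0, 2.0, 3.0])
print(add_batch_dim(obs).shape)                      # (1, 3)
print(add_batch_dim({"demand": obs})["demand"].shape)  # (1, 3)
print(add_batch_dim(obs, receive_batch_dim=True).shape)  # (3,)
```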
BaseAgent.save
BaseAgent.save ()
Save the agent’s parameters to a file.
BaseAgent.load
BaseAgent.load ()
Load the agent’s parameters from a file.
BaseAgent.update_model_params
BaseAgent.update_model_params (default_params:dict, custom_params:dict)
Override default parameters with custom parameters in a dictionary.
| | Type | Details |
|---|---|---|
| default_params | dict | |
| custom_params | dict | |
| Returns | dict | |
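The described behavior is a dictionary merge where custom entries take precedence (a minimal sketch of the semantics only; the library's method may additionally validate keys):

```python
# Custom parameters override the defaults; keys absent from
# custom_params keep their default values.
def update_model_params(default_params: dict, custom_params: dict) -> dict:
    params = default_params.copy()  # leave the defaults untouched
    params.update(custom_params)
    return params

defaults = {"learning_rate": 0.001, "hidden_units": 64}
custom = {"learning_rate": 0.01}
print(update_model_params(defaults, custom))
# {'learning_rate': 0.01, 'hidden_units': 64}
```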
BaseAgent.convert_to_numpy_array
BaseAgent.convert_to_numpy_array (input:Union[numpy.ndarray,List,float,int,ddopai.utils.Parameter,NoneType])
Convert the input to a numpy array, or keep it as a Parameter.
| | Type | Details |
|---|---|---|
| input | Union | |