from ddopai.envs.inventory.single_period import NewsvendorEnv
from ddopai.dataloaders.tabular import XYDataLoader
from ddopai.experiments.experiment_functions import run_experiment, test_agent
SAA based agents
BaseSAAagent
BaseSAAagent (environment_info:ddopai.utils.MDPInfo, obsprocessors:Optional[List[object]]=None, agent_name:str|None=None)
Base class for Sample Average Approximation agents, implementing the main method to find the quantile of a (weighted) empirical distribution.
BaseSAAagent._validate_X_predict
BaseSAAagent._validate_X_predict (X)
Validate X data before prediction
BaseSAAagent.find_weighted_quantiles
BaseSAAagent.find_weighted_quantiles (weights, weightPosIndices, sl, y)
Find the weighted quantile of a range of data y. It assumes that all arrays are of shape (n_samples, n_outputs). Note that it has not been tested for n_outputs > 1.
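To make the idea of a weighted quantile concrete, here is a minimal sketch for a single output in plain Python. The function name `weighted_quantile` and its argument order are hypothetical and do not mirror `find_weighted_quantiles`; the library's handling of ties, of `weightPosIndices`, and of 2-D arrays may differ.

```python
def weighted_quantile(y, weights, sl):
    """Return the smallest value in y whose cumulative normalized weight reaches sl.

    Illustrative sketch only (hypothetical helper, 1-D inputs).
    """
    if len(y) != len(weights) or len(y) == 0:
        raise ValueError("y and weights must be non-empty and of equal length")
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must sum to a positive number")

    # Sort observations by value, carrying their weights along.
    pairs = sorted(zip(y, weights))

    cumulative = 0.0
    for value, weight in pairs:
        cumulative += weight / total
        # Small tolerance guards against floating-point round-off.
        if cumulative >= sl - 1e-12:
            return value
    return pairs[-1][0]


# Equal weights reduce to the ordinary empirical quantile.
print(weighted_quantile([10, 20, 30, 40], [0.25, 0.25, 0.25, 0.25], 0.5))
# Putting most weight on the largest observation shifts the quantile upward.
print(weighted_quantile([10, 20, 30, 40], [0.1, 0.1, 0.1, 0.7], 0.5))
```

With equal weights this is the plain SAA quantile; non-uniform weights are what the wSAA agents further down supply.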
NewsvendorSAAagent
NewsvendorSAAagent (environment_info:ddopai.utils.MDPInfo, cu:float|numpy.ndarray, co:float|numpy.ndarray, obsprocessors:list[object]|None=None, agent_name:str='SAA')
Newsvendor agent that uses Sample Average Approximation to find the quantile of the empirical distribution
| | Type | Default | Details |
|---|---|---|---|
| environment_info | MDPInfo | | |
| cu | float \| numpy.ndarray | | underage cost |
| co | float \| numpy.ndarray | | overage cost |
| obsprocessors | list[object] \| None | None | |
| agent_name | str | SAA | |
Further information:
References:
[1] Levi, Retsef, Georgia Perakis, and Joline Uichanco. "The data-driven newsvendor problem: new bounds and insights." Operations Research 63.6 (2015): 1294-1306.
NewsvendorSAAagent.fit
NewsvendorSAAagent.fit (X:numpy.ndarray, Y:numpy.ndarray)
Fit the agent to the data. The agent will find the quantile of the empirical distribution of the data.
| | Type | Details |
|---|---|---|
| X | ndarray | features will be ignored |
| Y | ndarray | |
| Returns | None | |
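As an illustration of what fitting amounts to, the sketch below computes the standard newsvendor critical ratio sl = cu / (cu + co) and picks the matching empirical quantile of past demand. The function `saa_order_quantity` and the sample data are hypothetical; the agent's actual quantile convention (for example at ties or between order statistics) is not stated here and may differ.

```python
import math


def saa_order_quantity(demand, cu, co):
    """Order quantity as the sl-quantile of observed demand (illustrative sketch)."""
    if not demand:
        raise ValueError("demand must not be empty")
    # Critical ratio (service level) of the newsvendor problem.
    sl = cu / (cu + co)
    ordered = sorted(demand)
    # Smallest order statistic whose empirical CDF value reaches sl.
    k = max(1, math.ceil(sl * len(ordered) - 1e-12))
    return ordered[k - 1]


demand = [12, 7, 9, 15, 10, 8, 11, 14, 13, 6]  # made-up demand history
# cu=0.42857, co=1.0 gives a critical ratio of roughly 0.3.
print(saa_order_quantity(demand, cu=0.42857, co=1.0))
```

Features play no role in this calculation, which is why `X` is ignored by `fit`.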
NewsvendorSAAagent.draw_action_
NewsvendorSAAagent.draw_action_ (observation:numpy.ndarray)
Draw an action from the quantile of the empirical distribution.
| | Type | Details |
|---|---|---|
| observation | ndarray | |
| Returns | ndarray | |
NewsvendorSAAagent.save
NewsvendorSAAagent.save (path:str, overwrite:bool=True)
Save the quantiles to a file in the specified directory.
| | Type | Default | Details |
|---|---|---|---|
| path | str | | The directory where the file will be saved. |
| overwrite | bool | True | Allow overwriting; if False, a FileExistsError will be raised if the file exists. |
NewsvendorSAAagent.load
NewsvendorSAAagent.load (path:str)
Load the quantiles from a file.
| | Type | Details |
|---|---|---|
| path | str | Only the path to the folder is needed, not the file itself |
BasewSAAagent
BasewSAAagent (environment_info:ddopai.utils.MDPInfo, cu:float|numpy.ndarray, co:float|numpy.ndarray, obsprocessors:list[object]|None=None, agent_name:str='wSAA')
Base class for weighted Sample Average Approximation (wSAA) Agents
BasewSAAagent.fit
BasewSAAagent.fit (X:numpy.ndarray, Y:numpy.ndarray)
Fit the agent to the data. The function calls _get_fitted_model, which trains a machine learning model (e.g., kNN, DT, RF) to determine the sample weights.

| | Type | Details |
|---|---|---|
| X | ndarray | |
| Y | ndarray | |
BaseAgent.draw_action
BaseAgent.draw_action (observation:numpy.ndarray)
Main interface to the environment. Applies the preprocessors to the observation. The internal logic of the agent is implemented in the draw_action_ method.

| | Type | Details |
|---|---|---|
| observation | ndarray | |
| Returns | ndarray | |
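The split between `draw_action` and `draw_action_` follows a template-method pattern, sketched below. The classes `SketchAgent` and `SumAgent` are hypothetical, and treating each obsprocessor as a plain callable is an assumption made for illustration; the real obsprocessor objects may expose a different interface.

```python
class SketchAgent:
    """Hypothetical base class: preprocess the observation, then delegate."""

    def __init__(self, obsprocessors=None):
        # Assumption: each obsprocessor is a callable mapping obs -> obs.
        self.obsprocessors = obsprocessors or []

    def draw_action(self, observation):
        # Public entry point: apply every preprocessor in order.
        for processor in self.obsprocessors:
            observation = processor(observation)
        return self.draw_action_(observation)

    def draw_action_(self, observation):
        # Subclasses supply the decision logic here.
        raise NotImplementedError


class SumAgent(SketchAgent):
    """Toy agent whose action is the sum of the processed observation."""

    def draw_action_(self, observation):
        return sum(observation)


agent = SumAgent(obsprocessors=[lambda obs: [v / 10 for v in obs]])
print(agent.draw_action([10, 20]))
```

Keeping preprocessing in the base class means every subclass sees observations in the same form without repeating that code.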
BasewSAAagent._get_fitted_model
BasewSAAagent._get_fitted_model (X, y)
Initialise and fit the underlying model; the implementation depends on the machine learning model used.
BasewSAAagent._calc_weights
BasewSAAagent._calc_weights (sample)
Calculate the sample weights; the implementation depends on the machine learning model used.
BasewSAAagent.predict
BasewSAAagent.predict (X:numpy.ndarray)
Predict values for X by finding the quantiles of the empirical distribution, weighted by the sample weights that the underlying machine learning model assigns.

| | Type | Details |
|---|---|---|
| X | ndarray | |
| Returns | ndarray | |
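A rough end-to-end picture of a weighted SAA prediction, using k-nearest-neighbour weights because they are the simplest to write down. Both helpers (`knn_weights`, `wsaa_predict`) and the toy data are hypothetical; this is a sketch of the general wSAA idea, not the code path of `BasewSAAagent.predict`.

```python
def knn_weights(X_train, x, k):
    """Weight 1/k on the k training points closest to x, 0 elsewhere."""
    distances = [sum((a - b) ** 2 for a, b in zip(row, x)) for row in X_train]
    nearest = sorted(range(len(X_train)), key=lambda i: distances[i])[:k]
    weights = [0.0] * len(X_train)
    for i in nearest:
        weights[i] = 1.0 / k
    return weights


def wsaa_predict(X_train, y_train, x, cu, co, k):
    """Weighted sl-quantile of y_train, with weights depending on the features x."""
    sl = cu / (cu + co)  # newsvendor critical ratio
    weights = knn_weights(X_train, x, k)
    cumulative = 0.0
    pairs = sorted(zip(y_train, weights))
    for value, weight in pairs:
        cumulative += weight
        if cumulative >= sl - 1e-12:
            return value
    return pairs[-1][0]


# Two clusters of feature values with very different demand levels.
X_train = [[0.0], [0.1], [0.2], [5.0], [5.1], [5.2]]
y_train = [1, 2, 3, 10, 20, 30]
print(wsaa_predict(X_train, y_train, [5.05], cu=1.0, co=1.0, k=3))
print(wsaa_predict(X_train, y_train, [0.05], cu=1.0, co=1.0, k=3))
```

Unlike plain SAA, the prediction changes with the feature vector because only nearby training samples receive weight.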
BasewSAAagent.save
BasewSAAagent.save (path:str, overwrite:bool=True)
Save the scikit-learn model to a file in the specified directory.
| | Type | Default | Details |
|---|---|---|---|
| path | str | | The directory where the file will be saved. |
| overwrite | bool | True | Allow overwriting; if False, a FileExistsError will be raised if the file exists. |
BasewSAAagent.load
BasewSAAagent.load (path:str)
Load the scikit-learn model from a file.
| | Type | Details |
|---|---|---|
| path | str | Only the path to the folder is needed, not the file itself |
NewsvendorRFwSAAagent
NewsvendorRFwSAAagent (environment_info:ddopai.utils.MDPInfo, cu:float|numpy.ndarray, co:float|numpy.ndarray, obsprocessors:list[object]|None=None, n_estimators:int=100, criterion:str='squared_error', max_depth:int|None=None, min_samples_split:int=2, min_samples_leaf:int=1, min_weight_fraction_leaf:float=0.0, max_features:int|float|str|None=1.0, max_leaf_nodes:int|None=None, min_impurity_decrease:float=0.0, bootstrap:bool=True, oob_score:bool=False, n_jobs:int|None=None, random_state:int|numpy.random.mtrand.RandomState|None=None, verbose:int=0, warm_start:bool=False, ccp_alpha:float=0.0, max_samples:int|float|None=None, monotonic_cst:numpy.ndarray|None=None, agent_name:str='wSAA')
Newsvendor agent that uses weighted Sample Average Approximation based on Random Forest
| | Type | Default | Details |
|---|---|---|---|
| environment_info | MDPInfo | | |
| cu | float \| numpy.ndarray | | underage cost |
| co | float \| numpy.ndarray | | overage cost |
| obsprocessors | list[object] \| None | None | List of obsprocessors to apply to the observation |
| n_estimators | int | 100 | The number of trees in the forest. |
| criterion | str | squared_error | Function to measure the quality of a split. |
| max_depth | int \| None | None | Maximum depth of the tree; None means unlimited. |
| min_samples_split | int | 2 | Minimum samples required to split a node. |
| min_samples_leaf | int | 1 | Minimum samples required to be at a leaf node. |
| min_weight_fraction_leaf | float | 0.0 | Minimum weighted fraction of the total weights at a leaf node. |
| max_features | int \| float \| str \| None | 1.0 | Number of features to consider when looking for the best split. |
| max_leaf_nodes | int \| None | None | Maximum number of leaf nodes; None means unlimited. |
| min_impurity_decrease | float | 0.0 | Minimum impurity decrease required to split a node. |
| bootstrap | bool | True | Whether to use bootstrap samples when building trees. |
| oob_score | bool | False | Whether to use out-of-bag samples to estimate R^2 on unseen data. |
| n_jobs | int \| None | None | Number of jobs to run in parallel; None means 1. |
| random_state | int \| numpy.random.mtrand.RandomState \| None | None | Controls randomness for bootstrapping and feature sampling. |
| verbose | int | 0 | Controls the verbosity when fitting and predicting. |
| warm_start | bool | False | If True, reuse solution from previous fit and add more estimators. |
| ccp_alpha | float | 0.0 | Complexity parameter for Minimal Cost-Complexity Pruning. |
| max_samples | int \| float \| None | None | Number of samples to draw when bootstrap is True. |
| monotonic_cst | numpy.ndarray \| None | None | Monotonic constraints for features. |
| agent_name | str | wSAA | Defaults to wSAA; change it to distinguish between agents built on different ML models. |
Further information:
Notes
The default values for the parameters controlling the size of the trees (e.g. max_depth, min_samples_leaf, etc.) lead to fully grown and unpruned trees, which can potentially be very large on some data sets. To reduce memory consumption, the complexity and size of the trees should be controlled by setting those parameter values.
The features are always randomly permuted at each split. Therefore, the best found split may vary, even with the same training data, max_features=n_features and bootstrap=False, if the improvement of the criterion is identical for several splits enumerated during the search for the best split. To obtain deterministic behaviour during fitting, random_state has to be fixed.
References
[1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001.
[2] P. Geurts, D. Ernst, and L. Wehenkel, "Extremely randomized trees", Machine Learning, 63(1), 3-42, 2006.
[3] Bertsimas, Dimitris, and Nathan Kallus, "From predictive to prescriptive analytics." arXiv preprint arXiv:1402.5481 (2014).
[4] scikit-learn, RandomForestRegressor, <https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/ensemble/_forest.py>
[5] Scornet, Erwan. "Random forests and kernel methods." IEEE Transactions on Information Theory 62.3 (2016): 1485-1500.
NewsvendorRFwSAAagent._get_fitted_model
NewsvendorRFwSAAagent._get_fitted_model (X:numpy.ndarray, Y:numpy.ndarray)
Fit the underlying machine learning model using all X and Y data in the train set.
| | Type | Details |
|---|---|---|
| X | ndarray | |
| Y | ndarray | |
NewsvendorRFwSAAagent._calc_weights
NewsvendorRFwSAAagent._calc_weights (sample:numpy.ndarray)
Calculate the sample weights based on the Random Forest model.
| | Type | Details |
|---|---|---|
| sample | ndarray | |
| Returns | tuple | |
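One common way to derive sample weights from a random forest, following the leaf co-occurrence idea in Bertsimas and Kallus, is sketched below: a training sample gets weight from each tree in which it lands in the same leaf as the query point. The function `forest_weights` and the leaf-index inputs are hypothetical; in scikit-learn such leaf indices can be obtained with `RandomForestRegressor.apply`. Whether `_calc_weights` uses exactly this formula, and what its returned tuple contains, is not stated in this section.

```python
def forest_weights(train_leaves, sample_leaves):
    """Leaf co-occurrence weights (illustrative sketch).

    train_leaves[i][t] is the leaf index of training sample i in tree t;
    sample_leaves[t] is the leaf index of the query point in tree t.
    """
    n_samples = len(train_leaves)
    n_trees = len(sample_leaves)
    weights = [0.0] * n_samples
    for t in range(n_trees):
        # Training samples sharing the query point's leaf in tree t.
        same_leaf = [i for i in range(n_samples)
                     if train_leaves[i][t] == sample_leaves[t]]
        if not same_leaf:
            continue  # no training sample in this leaf; tree contributes nothing
        for i in same_leaf:
            # Each tree spreads a total weight of 1/n_trees over its leaf.
            weights[i] += 1.0 / (len(same_leaf) * n_trees)
    return weights


# Four training samples, two trees (made-up leaf indices).
train_leaves = [[1, 4], [1, 5], [2, 5], [2, 5]]
sample_leaves = [1, 5]
weights = forest_weights(train_leaves, sample_leaves)
print(weights, sum(weights))
```

The resulting weights are non-negative and sum to one whenever every tree's leaf contains at least one training sample, so they can be used directly in a weighted quantile.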
Example usage:
import numpy as np  # needed for the random data below

val_index_start = 800 #90_000
test_index_start = 900 #100_000

X = np.random.rand(1000, 2)
Y = np.random.rand(1000, 1)

dataloader = XYDataLoader(X, Y, val_index_start, test_index_start)

environment = NewsvendorEnv(
    dataloader = dataloader,
    underage_cost = 0.42857,
    overage_cost = 1.0,
    gamma = 0.999,
    horizon_train = 365,
)

agent = NewsvendorSAAagent(environment.mdp_info, cu=0.42857, co=1.0)
agent = NewsvendorRFwSAAagent(environment.mdp_info, cu=0.42857, co=1.0)  # replaces the SAA agent above

# evaluate the untrained agent on the test set
environment.test()
agent.eval()

R, J = test_agent(agent, environment)

print(R, J)

run_experiment(agent, environment, 100, run_id = "test", save_best=True) # fit agent via run_experiment function

# evaluate again after fitting
environment.test()
agent.eval()

R, J = test_agent(agent, environment)

print(R, J)
-18.01888542213257 -17.142493964355882
WARNING:root:Overwriting file results/test/saved_models/best/model.joblib
results
-15.763567080255545 -15.022369246527656 -15.763567080255545 -15.022369246527656
-17.334785352427232 -16.554914069406784