SACRNNAgent (environment_info:ddopai.utils.MDPInfo,
hidden_layers_RNN:int=1, num_hidden_units_RNN:int=64,
RNN_cell:str='GRU', hidden_layers_MLP:List=None,
hidden_layers_input_MLP:List=None, activation:str='relu',
learning_rate_actor:float=0.0003,
learning_rate_critic:float|None=None,
initial_replay_size:int=64, max_replay_size:int=50000,
batch_size:int=64, warmup_transitions:int=100,
lr_alpha:float=0.0003, tau:float=0.005,
log_std_min:float=-20.0, log_std_max:float=2.0,
use_log_alpha_loss=False, target_entropy:float|None=None,
drop_prob:float=0.0, batch_norm:bool=False,
init_method:str='xavier_uniform', optimizer:str='Adam',
loss:str='MSE', obsprocessors:list|None=None,
device:str='cpu', agent_name:str|None='SAC',
observation_space_shape=None, action_space_shape=None)
# --- Synthetic newsvendor demo data -----------------------------------------
n_samples = 10000
# Split points derived from the dataset size (80% train / 10% val / 10% test),
# so changing n_samples keeps the split consistent.
val_index_start = round(0.8 * n_samples)   # 8000 for n_samples = 10000
test_index_start = round(0.9 * n_samples)  # 9000 for n_samples = 10000

# Two standard-normal features; only the first one drives the target.
X = np.random.standard_normal((n_samples, 2))

# Target: positive part of the first feature (truncated at 0) ...
Y = np.maximum(X[:, 0].reshape(-1, 1), 0)
# ... normalized so its maximum is 1. Guard against division by zero in the
# (unlikely) case that no sample is positive.
y_max = np.max(Y)
if y_max > 0:
    Y = Y / y_max

# Postprocessor that (by its name and arguments) clips actions to [0, 1],
# matching the range of the normalized target.
clip_action = ClipAction(0., 1.)
dataloader = XYDataLoader(
    X,
    Y,
    val_index_start,
    test_index_start,
    lag_window_params={'lag_window': 5, 'include_y': True, 'pre_calc': True},
)
environment = NewsvendorEnv(
    dataloader=dataloader,
    underage_cost=0.42857,
    overage_cost=1.0,
    gamma=0.999,
    horizon_train=365,
    q_bound_high=1.0,
    q_bound_low=-0.1,
    postprocessors=[clip_action],
)

agent = SACRNNAgent(
    environment.mdp_info,
    obsprocessors=None,  # default: []
    device="cpu",        # "cuda" or "cpu"
)

# Evaluate the freshly constructed (untrained) agent.
environment.test()
agent.eval()
R, J = test_agent(agent, environment)
print(R, J)

# Switch environment and agent back to training mode.
environment.train()
agent.train()
environment.print = False  # NOTE(review): presumably silences env printing during training — confirm

# Training is disabled here; uncomment the line below to fit the agent.
# While it stays commented out, the evaluation that follows scores the same
# untrained agent as the one above.
# run_experiment(agent, environment, n_epochs=50, n_steps=1000, run_id = "test", save_best=True, print_freq=1) # fit agent via run_experiment function

environment.test()
agent.eval()
R, J = test_agent(agent, environment)
print(R, J)
/Users/magnus/miniforge3/envs/inventory_gym_2/lib/python3.11/site-packages/gymnasium/spaces/box.py:130: UserWarning: WARN: Box bound precision lowered by casting to float32
gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
INFO:root:Actor network (mu network):
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
RNNActor [1, 1] --
├─RNNMLPHybrid: 1-1 [1, 1] --
│ └─Sequential: 2-1 [1, 6, 64] --
│ │ └─SpecificRNNWrapper: 3-1 [1, 6, 64] 13,248
│ │ └─ReLU: 3-2 [1, 6, 64] --
│ └─Sequential: 2-2 [1, 1] --
│ │ └─Linear: 3-3 [1, 64] 4,160
│ │ └─ReLU: 3-4 [1, 64] --
│ │ └─Dropout: 3-5 [1, 64] --
│ │ └─Linear: 3-6 [1, 64] 4,160
│ │ └─ReLU: 3-7 [1, 64] --
│ │ └─Dropout: 3-8 [1, 64] --
│ │ └─Linear: 3-9 [1, 1] 65
==========================================================================================
Total params: 21,633
Trainable params: 21,633
Non-trainable params: 0
Total mult-adds (M): 0.09
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.09
Estimated Total Size (MB): 0.09
==========================================================================================
INFO:root:################################################################################
INFO:root:Critic network:
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
RNNStateAction -- --
├─RNNMLPHybrid: 1-1 [1, 1] --
│ └─Sequential: 2-1 [1, 6, 64] --
│ │ └─SpecificRNNWrapper: 3-1 [1, 6, 64] 13,248
│ │ └─ReLU: 3-2 [1, 6, 64] --
│ └─Sequential: 2-2 [1, 1] --
│ │ └─Linear: 3-3 [1, 64] 4,224
│ │ └─ReLU: 3-4 [1, 64] --
│ │ └─Dropout: 3-5 [1, 64] --
│ │ └─Linear: 3-6 [1, 64] 4,160
│ │ └─ReLU: 3-7 [1, 64] --
│ │ └─Dropout: 3-8 [1, 64] --
│ │ └─Linear: 3-9 [1, 1] 65
==========================================================================================
Total params: 21,697
Trainable params: 21,697
Non-trainable params: 0
Total mult-adds (M): 0.09
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.09
Estimated Total Size (MB): 0.09
==========================================================================================
-383.1306977574299 -243.60956423506602
-383.1306977574299 -243.60956423506602