How do RL algorithms work with RNNs?

Tech Logg Ding on 10 Feb 2021
Hi,
I noticed that MATLAB 2021a allows users to use RL algorithms, such as DDPG, with RNNs in the deep neural network structure. This is great, as it could benefit continuous control problems with time delays and time-dependent parameters.
However, I am wondering about the algorithm MATLAB uses for the RL RNN learning process. RNNs learn through backpropagation through time (BPTT), so the states sampled for BPTT must form a consecutive series. RL algorithms such as DDPG, on the other hand, learn by drawing random samples from the experience buffer, so the two do not fit together as naturally as they do with a conventional MLP network. How does MATLAB handle this? Is there any paper that I can reference?
Next, I am also curious about how BPTT is executed for RNNs in MATLAB. In RL, an episode can have hundreds to thousands of time steps, and the RNN is usually expected to keep the states of every time step in memory (referring to the unrolled structure) in order to learn the weights and biases for its internal state. Does the series terminate at the end of every episode to update the RNN? Will this consume significantly more memory?
Thank you very much.
  1 Comment
Tech Logg Ding on 23 Feb 2021
Bumping this question. After looking into the documentation, I have not found any information on how updates work when the DNN contains an RNN. This paper (https://academic.oup.com/jigpal/article/18/5/620/751594?login=true) also describes that random episodes should be sampled with a short series in order to train its LSTM network effectively. Does the RL toolbox include this?


Accepted Answer

Takeshi Takahashi on 24 Feb 2021
Hi,
rlDDPGAgent with an RNN first randomly samples B sequences (trajectories) from the experience buffer, where B is MiniBatchSize. Then, for each sampled sequence that is longer than L, where L is the SequenceLength you specified, it randomly selects a starting point. The end point follows from the starting point and L, so that the resulting sequence has length L.
Suppose some sequences sampled from the experience buffer are shorter than L. In that case, those sequences are padded with fake samples so that all sequences in the batch have the same length (L). We apply masking to the padded samples, so they do not affect BPTT.
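The sampling and padding steps described above can be sketched roughly as follows. This is an illustrative Python sketch, not the Reinforcement Learning Toolbox implementation: the function name `sample_sequence_batch`, the representation of the buffer as a list of episodes, sampling with replacement, padding at the end of a sequence, and the `None` pad value are all assumptions made for the example.

```python
import random

def sample_sequence_batch(episodes, batch_size, seq_len, pad_value=None, rng=random):
    """Return (batch, mask): batch_size windows of exactly seq_len steps each.

    episodes   -- list of trajectories, each a list of per-step experiences
    batch_size -- B in the description above (MiniBatchSize)
    seq_len    -- L in the description above (SequenceLength)
    """
    batch, mask = [], []
    for _ in range(batch_size):
        # Assumption: trajectories are drawn uniformly, with replacement.
        episode = rng.choice(episodes)
        if len(episode) > seq_len:
            # Random starting point; the end point follows from start + seq_len.
            start = rng.randrange(len(episode) - seq_len + 1)
            window = list(episode[start:start + seq_len])
            valid = [True] * seq_len
        else:
            # Short trajectory: keep it whole and pad up to seq_len.
            # Assumption: padding is appended at the end.
            n_pad = seq_len - len(episode)
            window = list(episode) + [pad_value] * n_pad
            valid = [True] * len(episode) + [False] * n_pad
        batch.append(window)
        mask.append(valid)
    return batch, mask

# Example: one long and one short trajectory, B = 4, L = 5.
episodes = [list(range(10)), list(range(100, 103))]
batch, mask = sample_sequence_batch(episodes, batch_size=4, seq_len=5,
                                    rng=random.Random(0))
```

Every row of `batch` has length L, and `mask` marks which steps are real experiences (`True`) and which are padding (`False`).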
We use these short sequences as a batch for BPTT. MiniBatchSize and SequenceLength control the size of the batch. Bigger MiniBatchSize and SequenceLength require more memory space during BPTT.
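As a rough illustration of the two points above (masked samples not contributing to the update, and batch size growing with B and L), here is a small Python sketch. It is not the toolbox's code: `masked_mean_loss` and `batch_elements` are hypothetical helpers, averaging over valid steps is only one common way to apply a mask, and the element count covers the stored inputs only, not the hidden states kept during BPTT.

```python
def masked_mean_loss(per_step_loss, mask):
    """Average per-step losses over valid (unpadded) steps only."""
    total, count = 0.0, 0
    for losses, valid in zip(per_step_loss, mask):
        for loss, keep in zip(losses, valid):
            if keep:  # padded steps are skipped, so they add nothing
                total += loss
                count += 1
    return total / count if count else 0.0

def batch_elements(batch_size, seq_len, feature_dim):
    """Input values held for one batch: B * L * feature_dim."""
    return batch_size * seq_len * feature_dim

# The padded third step (loss 99.0) is ignored by the mask.
print(masked_mean_loss([[1.0, 3.0, 99.0]], [[True, True, False]]))  # 2.0

# Doubling either B or L doubles the number of stored inputs.
print(batch_elements(32, 20, 4))  # 2560
print(batch_elements(32, 40, 4))  # 5120
```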
I hope this clarifies your question.
Thank you.
  2 Comments
Tech Logg Ding on 28 Feb 2021
Got it! Thank you very much! Do the other algorithms, such as TD3 and SAC, use the same sampling method?
Takeshi Takahashi on 2 Mar 2021
Yes. They use the same method.


