
train

Train reinforcement learning agents within a specified environment

Description

trainStats = train(env,agents) trains one or more reinforcement learning agents within the environment env, using default training options, and returns training results in trainStats. Although agents is an input argument, after each training episode train updates the parameters of each agent specified in agents to maximize its expected long-term reward from the environment. This is possible because each agent is a handle object. When training terminates, agents reflects the state of each agent at the end of the final training episode.

Note

To train an off-policy agent offline using existing data, use trainFromData.

trainStats = train(agents,env) performs the same training as the previous syntax.


trainStats = train(___,trainOpts) trains agents within env, using the training options object trainOpts. Use training options to specify training parameters such as the criteria for terminating training, when to save agents, the maximum number of episodes to train, and the maximum number of steps per episode.


trainStats = train(___,prevTrainStats) resumes training from the last values of the agent parameters and training results contained in prevTrainStats, which is returned by the previous function call to train.


trainStats = train(___,Name=Value) trains agents with additional name-value arguments. Use this syntax to specify a logger or evaluator object to use during training. Logger objects allow you to periodically log results to disk, and evaluator objects allow you to periodically evaluate agents.

Examples


Train the agent configured in the Train PG Agent to Balance Cart-Pole System example, within the corresponding environment. The observation from the environment is a vector containing the position and velocity of a cart, as well as the angular position and velocity of the pole. The action is a scalar that can take one of two values (a force of either -10 or 10 newtons applied to the cart).

Load the file containing the environment and a PG agent already configured for it.

load RLTrainExample.mat

Specify some training parameters using rlTrainingOptions. These parameters include the maximum number of episodes to train, the maximum steps per episode, and the conditions for terminating training. For this example, use a maximum of 2000 episodes and 500 steps per episode. Instruct the training to stop when the average reward over the previous twenty episodes reaches 495. Create a default options set and use dot notation to change some of the parameter values.

trainOpts = rlTrainingOptions;

trainOpts.MaxEpisodes = 2000;
trainOpts.MaxStepsPerEpisode = 500;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 495;
trainOpts.ScoreAveragingWindowLength = 20;
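To see how the AverageReward stopping criterion interacts with ScoreAveragingWindowLength, the following base-MATLAB sketch computes a trailing average over a made-up reward sequence and finds the first episode at which the average reaches StopTrainingValue. It assumes the average is taken over the most recent 20 episodes (or over all episodes so far when fewer are available) and that the comparison is "greater than or equal to". The rewards vector is hypothetical, and train performs this bookkeeping internally.

```matlab
% Hypothetical episode rewards (not produced by train)
rewards = [200 350 480 500 500 500 500 500 500 500 ...
           500 500 500 500 500 500 500 500 500 500 500 500]';
windowLength = 20;   % same value as trainOpts.ScoreAveragingWindowLength
stopValue    = 495;  % same value as trainOpts.StopTrainingValue

avgReward   = zeros(size(rewards));
stopEpisode = [];
for k = 1:numel(rewards)
    first = max(1, k - windowLength + 1);      % start of trailing window
    avgReward(k) = mean(rewards(first:k));     % assumed averaging rule
    if isempty(stopEpisode) && avgReward(k) >= stopValue
        stopEpisode = k;                       % first episode meeting the criterion
    end
end
fprintf('Average reward first reaches %g at episode %d\n', stopValue, stopEpisode)
```

Because the three low-reward episodes stay inside the 20-episode window until they age out, the criterion is met only at episode 22, even though every episode from 4 onward scores 500.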

During training, the train command can save candidate agents that give good results. Further configure the training options to save an agent when the episode reward exceeds 499. Save the agent to a folder called savedAgents.

trainOpts.SaveAgentCriteria = "EpisodeReward";
trainOpts.SaveAgentValue = 499;
trainOpts.SaveAgentDirectory = "savedAgents";

Turn off the command-line display. Turn on the Reinforcement Learning Training Monitor so you can observe the training progress visually.

trainOpts.Verbose = false;
trainOpts.Plots = "training-progress";

You are now ready to train the PG agent. For the predefined cart-pole environment used in this example, you can use plot to generate a visualization of the cart-pole system.

plot(env)

When you run this example, both this visualization and the Reinforcement Learning Training Monitor update with each training episode. Place them side by side on your screen to observe the progress, and train the agent. (This computation can take 20 minutes or more.)

trainingInfo = train(agent,env,trainOpts);

The Reinforcement Learning Training Monitor shows that the training successfully reaches the termination condition, an average reward of at least 495 over the previous 20 episodes. At each training episode, train updates agent with the parameters learned in the previous episode. When training terminates, you can simulate the environment with the trained agent to evaluate its performance. The environment plot updates during simulation as it did during training.

simOptions = rlSimulationOptions(MaxSteps=500);
experience = sim(env,agent,simOptions);

During training, train saves to disk any agents that meet the condition specified with trainOpts.SaveAgentCriteria and trainOpts.SaveAgentValue. To test the performance of any of those agents, load the data from the files in the folder specified by trainOpts.SaveAgentDirectory, and simulate the environment with that agent.

This example shows how to periodically evaluate an agent during training using an rlEvaluator object.

Load the predefined environment object representing a cart-pole system with a discrete action space. For more information on this environment, see Load Predefined Control System Environments.

env = rlPredefinedEnv("CartPole-Discrete");

The agent networks are initialized randomly. Ensure reproducibility by fixing the seed of the random generator.

rng(0)

Create a DQN agent with default networks.

agent = rlDQNAgent(getObservationInfo(env),getActionInfo(env));

Use the standard DQN algorithm instead of double DQN, and configure the agent options for training.

agent.AgentOptions.UseDoubleDQN                               = false;
agent.AgentOptions.CriticOptimizerOptions.LearnRate           = 1e-3;
agent.AgentOptions.CriticOptimizerOptions.GradientThreshold   = 5;
agent.AgentOptions.MiniBatchSize                              = 128;
agent.AgentOptions.DiscountFactor                             = 0.99;
agent.AgentOptions.TargetSmoothFactor                         = 5e-3;
agent.AgentOptions.ExperienceBufferLength                     = 1e5;
agent.AgentOptions.SampleTime                                 = env.Ts;

To specify training options, create an rlTrainingOptions object. Configure training to stop when the average reward reaches 480.

tngOpts = rlTrainingOptions(...
    MaxEpisodes=5000, ...
    StopTrainingCriteria="AverageReward",...
    StopTrainingValue=480);

To evaluate the agent during training, create an rlEvaluator object. Configure the evaluator to run 5 consecutive evaluation episodes every 50 training episodes.

evl = rlEvaluator( ...
    NumEpisodes=5, ...
    EvaluationFrequency=50)
evl = 
  rlEvaluator with properties:

    EvaluationStatisticType: "MeanEpisodeReward"
                NumEpisodes: 5
         MaxStepsPerEpisode: []
       UseExplorationPolicy: 0
                RandomSeeds: 1
        EvaluationFrequency: 50

To train the agent using these evaluation options, pass evl to train. Training this agent is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

doTraining = false;
if doTraining
    results = train(agent, env, tngOpts, Evaluator=evl);
else
    load("DQNAgent.mat","agent","results");
end

The red stars on the plot indicate the statistic (for this example the average episode reward) collected for the evaluation episodes.

Display the reward accumulated during the last episode.

results.EpisodeReward(end)
ans = 500

This value means that the agent is able to balance the cart-pole system for the whole episode.

Display the size of the evaluation statistic vector, which contains one element for each training episode.

size(results.EvaluationStatistic)
ans = 1×2

   395     1

Display only the finite values, which correspond to the training episodes after which the five evaluation episodes were run.

results.EvaluationStatistic(isfinite(results.EvaluationStatistic))
ans = 7×1

    4.4000
  150.0000
  117.2000
  101.8000
  248.2000
  326.2000
  195.6000
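The finite entries can also be located by episode index. The following sketch rebuilds a placeholder vector with the same layout as results.EvaluationStatistic for this 395-episode run. It uses the seven values displayed above and assumes, based on EvaluationFrequency=50, that they sit at episodes 50, 100, ..., 350. It then recovers those indices with find and isfinite. With the real results object, find(isfinite(results.EvaluationStatistic)) does the same without the placeholder.

```matlab
numEpisodes   = 395;   % size(results.EvaluationStatistic,1) shown above
evalFrequency = 50;    % EvaluationFrequency of the evaluator evl
evalValues    = [4.4; 150; 117.2; 101.8; 248.2; 326.2; 195.6];  % values shown above

% Placeholder with the assumed layout: NaN for every episode that is
% not followed by evaluation episodes
evalStat = NaN(numEpisodes, 1);
evalStat(evalFrequency:evalFrequency:numEpisodes) = evalValues;

% Training episodes after which evaluation episodes were run
evaluatedEpisodes = find(isfinite(evalStat))
```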

Train the agents configured in the Train Multiple Agents to Perform Collaborative Task example, within the corresponding environment.

Run the script that loads the environment parameters.

rlCollaborativeTaskParams

Load the file containing the agents. For this example, load agents that have already been trained using decentralized learning.

load decentralizedAgents.mat

Create an environment object that uses the Simulink® model rlCollaborativeTask. Because the agent objects referenced by the agent blocks are already available in the MATLAB® workspace when the environment is created, you do not need to specify the observation and action specification arrays. For more information, see rlSimulinkEnv.

env = rlSimulinkEnv("rlCollaborativeTask", ...
    ["rlCollaborativeTask/Agent A", "rlCollaborativeTask/Agent B"])
env = 
SimulinkEnvWithAgent with properties:

           Model : rlCollaborativeTask
      AgentBlock : [
                     rlCollaborativeTask/Agent A
                     rlCollaborativeTask/Agent B
                   ]
        ResetFcn : []
  UseFastRestart : on

It is good practice to specify a reset function for the environment such that agents start from random initial positions at the beginning of each episode. For an example, see the resetRobots function defined in Train Multiple Agents to Perform Collaborative Task. For this example, however, do not define a reset function.

For this example, configure the training to be centralized.

  • Allocate both agents (with indices 1 and 2) in a single group. Do this by specifying the agent indices in the "AgentGroups" option.

  • Specify the "centralized" learning strategy.

  • For this example, run the training for 5 episodes, with each episode lasting at most 600 time steps.

  • Do not visualize training progress.

trainOpts = rlMultiAgentTrainingOptions(...
    AgentGroups={[1,2]},...
    LearningStrategy="centralized",...
    MaxEpisodes=5,...
    MaxStepsPerEpisode=600,...
    StopTrainingCriteria="none",...
    Plots="none");

Train the agents using the train function.

results = train([dcAgentA,dcAgentB],env,trainOpts);

Replaying the animation plot shows you how the agent behaves in the training.

This example shows how to resume training a Q-learning agent using existing training data. For more information on these agents, see Q-Learning Agents and SARSA Agents.

Create Grid World Environment

For this example, create the basic grid world environment.

env = rlPredefinedEnv("BasicGridWorld");

To start each episode from a randomly selected state, create a reset function that returns a state number chosen at random from the list x0 of allowed initial states.

x0 = [1:12 15:17 19:22 24];
env.ResetFcn = @() x0(randi(numel(x0)));
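The list x0 omits a few of the 25 state numbers. The following sketch shows which ones are omitted and converts them to grid positions. It assumes that BasicGridWorld numbers the cells of its 5-by-5 grid in column-major order, so that state 2 is cell [2,1], which matches the reset function @() 2 used for validation later in this example. Under that assumption, the omitted states are the obstacle cells and the terminal cell [5,5].

```matlab
x0 = [1:12 15:17 19:22 24];   % allowed initial states, as defined above
gridSize = [5 5];             % BasicGridWorld is a 5-by-5 grid

% States that are never used as initial states
excluded = setdiff(1:prod(gridSize), x0)

% Convert state numbers to [row,col] cells (assumed column-major numbering)
[row, col] = ind2sub(gridSize, excluded);
excludedCells = [row(:) col(:)]

% Cell corresponding to state 2, used later by env.ResetFcn = @() 2
[r2, c2] = ind2sub(gridSize, 2)
```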

Fix the random generator seed for reproducibility.

rng(1)

Create Q-Learning Agent

To create a Q-learning agent, first create a Q table using the observation and action specifications from the grid world environment, and then create a Q-value function critic from the table.

qTable = rlTable(getObservationInfo(env),getActionInfo(env));
qVf = rlQValueFunction(qTable,getObservationInfo(env),getActionInfo(env));

Next, create a Q-learning agent using this critic, and configure the epsilon-greedy exploration and the critic learning rate. For more information on creating Q-learning agents, see rlQAgent and rlQAgentOptions. Set the discount factor to 1 instead of the default value of 0.99.

agentOpts = rlQAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = 0.2;
agentOpts.CriticOptimizerOptions.LearnRate = 0.2;
agentOpts.EpsilonGreedyExploration.EpsilonDecay = 1e-3;
agentOpts.EpsilonGreedyExploration.EpsilonMin = 1e-3;
agentOpts.DiscountFactor = 1;
qAgent = rlQAgent(qVf,agentOpts);
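To get a rough idea of how long the agent keeps exploring with these options, the following sketch counts the training steps needed for epsilon to decay from 0.2 to EpsilonMin. It assumes a multiplicative schedule in which, after each training step and while epsilon is greater than EpsilonMin, epsilon is multiplied by (1 - EpsilonDecay). Check the rlQAgentOptions documentation for the exact rule in your release.

```matlab
% Values set in agentOpts.EpsilonGreedyExploration above
epsilon  = 0.2;    % Epsilon
epsDecay = 1e-3;   % EpsilonDecay
epsMin   = 1e-3;   % EpsilonMin

% Assumed schedule: epsilon = epsilon*(1 - EpsilonDecay) after each
% training step, while epsilon > EpsilonMin
numSteps = 0;
while epsilon > epsMin
    epsilon = epsilon*(1 - epsDecay);
    numSteps = numSteps + 1;
end
fprintf('Epsilon reaches EpsilonMin after about %d training steps\n', numSteps)
```

Under this assumption, the step count is about log(EpsilonMin/Epsilon)/log(1 - EpsilonDecay), which is on the order of a few thousand steps here.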

Train Q-Learning Agent for 100 Episodes

To train the agent, first specify the training options. For more information, see rlTrainingOptions.

trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 200;
trainOpts.MaxEpisodes = 1e6;
trainOpts.Plots = "none";
trainOpts.Verbose = false;

trainOpts.StopTrainingCriteria = "EpisodeCount";
trainOpts.StopTrainingValue = 100;
trainOpts.ScoreAveragingWindowLength = 30;

Train the Q-learning agent using the train function. Training can take several minutes to complete.

trainingStats = train(qAgent,env,trainOpts);

Display the index of the last episode.

trainingStats.EpisodeIndex(end)
ans = 100

Train Q-Learning Agent for 200 More Episodes

Set the training to stop after episode 300.

trainingStats.TrainingOptions.StopTrainingValue = 300;

Resume the training using the training data that exists in trainingStats.

trainingStats = train(qAgent,env,trainingStats);

Display the index of the last episode.

trainingStats.EpisodeIndex(end)
ans = 300

Plot the reward for each episode.

figure()
plot(trainingStats.EpisodeIndex,trainingStats.EpisodeReward)
title("Episode Reward")
xlabel("EpisodeIndex")
ylabel("EpisodeReward")

Display the final Q-value table.

qAgentFinalQ = getLearnableParameters(getCritic(qAgent));
qAgentFinalQ{1}
ans = 25x4 single matrix

   -5.9934    5.4707   10.0000    1.6349
    8.9968   -4.5969   -4.7967   -8.0369
   -3.9844    8.0000   -4.3924   -6.3623
   -4.4457   -3.4794    9.0000   -4.1959
   -4.4743   -2.3964    7.0000    1.7904
   -4.5117   -3.7606   11.0000   -1.3847
   -3.5016    6.8809   12.0000    4.0197
   11.0000   -3.8480    0.6307   -3.0320
   10.0000    7.0000   -1.5601   -3.4550
    3.0709    4.2059    8.0000    4.8305
      ⋮
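Each row of the Q table corresponds to a state and each column to an action, so the index of the largest value in each row is the action that a greedy policy selects in that state. The following sketch applies this to the first three rows displayed above. To do the same for all states, pass qAgentFinalQ{1} instead of the hard-coded matrix. Which move direction each column index represents depends on the action specification of the environment and is not shown here.

```matlab
% First three rows of the final Q table displayed above (states 1 to 3)
Q = [-5.9934  5.4707 10.0000  1.6349
      8.9968 -4.5969 -4.7967 -8.0369
     -3.9844  8.0000 -4.3924 -6.3623];

% Greedy action index and its Q-value for each state (row)
[bestQ, greedyAction] = max(Q, [], 2);
disp([(1:size(Q,1))' greedyAction bestQ])
```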

Validate Q-Learning Results

To validate the training results, simulate the agent in the training environment.

Before running the simulation, visualize the environment and configure the visualization to maintain a trace of the agent states.

plot(env)
env.ResetFcn = @() 2;
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;

Simulate the agent in the environment using the sim function.

sim(qAgent,env)

Input Arguments


agents

Agents to train, specified as a reinforcement learning agent object, such as rlACAgent or rlDDPGAgent, or as an array of such objects.

If env is a multi-agent environment, specify agents as an array. The order of the agents in the array must match the agent order used to create env.

Note

train updates the agents as training progresses. This is possible because each agent is a handle object. To preserve the original agent parameters for later use, save the agent to a MAT-file. (If you copy the agent into a new variable, the new variable also always points to the most recent agent version with updated parameters.) For more information about handle objects, see Handle Object Behavior.

Note

When training terminates, agents reflects the state of each agent at the end of the final training episode. The rewards obtained by the final agents are not necessarily the highest achieved during the training process, due to continuous exploration. To save agents during training, create an rlTrainingOptions object specifying the SaveAgentCriteria and SaveAgentValue properties and pass it to train as a trainOpts argument.

For more information about how to create and configure agents for reinforcement learning, see Reinforcement Learning Agents.

env

Environment in which the agents act, specified as one of the following kinds of reinforcement learning environment object:

  • A predefined MATLAB® or Simulink® environment created using rlPredefinedEnv.

  • A custom MATLAB environment you create with functions such as rlFunctionEnv or rlCreateEnvTemplate. This kind of environment does not support training multiple agents at the same time.

  • A Simulink environment you create using createIntegratedEnv. This kind of environment does not support training multiple agents at the same time.

  • A custom Simulink environment you create using rlSimulinkEnv. This kind of environment supports training multiple agents at the same time, and allows you to use multi-rate execution, so that each agent has its own execution rate.

  • A custom MATLAB environment you create using rlMultiAgentFunctionEnv or rlTurnBasedFunctionEnv. This kind of environment supports training multiple agents at the same time. In an rlMultiAgentFunctionEnv environment all agents execute in the same step, while in an rlTurnBasedFunctionEnv environment agents execute in turns.

For more information about creating and configuring environments, see:

When env is a Simulink environment, the environment object acts as an interface so that train calls the (compiled) Simulink model to generate experiences for the agents.

trainOpts

Training parameters and options, specified as either an rlTrainingOptions or an rlMultiAgentTrainingOptions object. Use this argument to specify parameters and options such as:

  • Criteria for ending training

  • Criteria for saving candidate agents

  • How to display training progress

  • Options for parallel computing

For details, see rlTrainingOptions and rlMultiAgentTrainingOptions.

prevTrainStats

Training episode data, specified as one of the following:

  • rlTrainingResult object, when training a single agent.

  • Array of rlTrainingResult objects when training multiple agents.

Use this argument to resume training from the exact point at which it stopped. Training restarts from the last values of the agent parameters and the training results obtained after the previous call to train. prevTrainStats contains, as one of its properties, the rlTrainingOptions or rlMultiAgentTrainingOptions object specifying the training options. Therefore, to restart training with updated options, first change the training options in prevTrainStats using dot notation. If the maximum number of episodes was already reached in the previous training session, you must increase the maximum number of episodes.

For details about the rlTrainingResult object properties, see the trainStats output argument.

Name-Value Arguments

Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.

Example: train(agent,env,Evaluator=myEval)

Logger

Logger object, specified as either a FileLogger or a MonitorLogger object. Use a logger object to periodically save data during training. For more information on reinforcement learning logger objects, see rlDataLogger.

Evaluator

Evaluator object, specified as either an rlEvaluator or an rlCustomEvaluator object. Use an evaluator object to periodically evaluate the agent during training. For more information on reinforcement learning evaluator objects, see rlEvaluator and rlCustomEvaluator.

Output Arguments


trainStats

Training episode data, returned as one of the following:

  • rlTrainingResult object, when training a single agent.

  • rlMultiAgentTrainingResult object when training multiple agents.

The following properties pertain to both the rlTrainingResult and rlMultiAgentTrainingResult objects:

EpisodeIndex

Episode numbers, returned as the column vector [1;2;…;N], where N is the number of episodes in the training run. This vector is useful if you want to plot the evolution of other quantities from episode to episode.

EpisodeReward

Reward for each episode, returned in a column vector of length N. Each entry contains the reward for the corresponding episode. For multiagent environments, this property contains a matrix in which each column corresponds to an agent.

EpisodeSteps

Number of steps in each episode, returned in a column vector of length N. Each entry contains the number of steps in the corresponding episode.

For multiagent environments, this property contains a matrix in which each column corresponds to an agent.

AverageReward

Average reward over the averaging window specified in trainOpts, returned as a column vector of length N. Each entry contains the average reward computed at the end of the corresponding episode.

For multiagent environments, this property contains a matrix in which each column corresponds to an agent.

TotalAgentSteps

Total number of agent steps in training, returned as a column vector of length N. Each entry contains the cumulative sum of the entries in EpisodeSteps up to that point.

For multiagent environments, this property contains a matrix in which each column corresponds to an agent.
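Because each TotalAgentSteps entry is the running sum of EpisodeSteps, you can reproduce it with cumsum. The following sketch uses hypothetical step counts for a five-episode run. With a real result, comparing trainStats.TotalAgentSteps against cumsum(trainStats.EpisodeSteps) is one way to check this relationship. Since cumsum operates down each column of a matrix, the same call also covers the per-agent columns of multiagent results.

```matlab
% Hypothetical per-episode step counts (not produced by train)
episodeSteps = [12; 40; 500; 500; 87];

% Running total of agent steps, episode by episode
totalAgentSteps = cumsum(episodeSteps)
```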

EpisodeQ0

Critic estimate of the expected discounted cumulative long-term reward using the current agent and the environment initial conditions, returned as a column vector of length N. Each entry is the critic estimate (Q0) for the agent at the beginning of the corresponding episode. This field is present only for agents that have critics, such as rlDDPGAgent and rlDQNAgent.

For multiagent environments, this property contains a matrix in which each column corresponds to an agent.

SimulationInfo

Environment simulation information, returned as:

  • A SimulationStorage object, if the SimulationStorageType training option is set to "memory" or "file".

  • An empty array, if SimulationStorageType is set to "none".

A SimulationStorage object contains environment information collected during simulation, which you can access by indexing into the object using the episode number.

For example, if res is an rlTrainingResult object returned by train, or an experience structure returned by sim, you can access the environment simulation information related to the second episode as:

mySimInfo2 = res.SimulationInfo(2);

  • For MATLAB environments, mySimInfo2 is a structure containing the field SimulationError. This structure contains any errors that occurred during simulation for the second episode.

  • For Simulink environments, mySimInfo2 is a Simulink.SimulationOutput object containing logged data from the Simulink model. Properties of this object include any signals and states that the model is configured to log, simulation metadata, and any errors that occurred during the second episode.

A SimulationStorage object also has the following read-only properties:

Total number of episodes run in the entire training or simulation, returned as a positive integer.

Example: 2670

Type of storage for the environment data, returned as either "memory" (data is stored in memory) or "file" (data is stored on disk). For more information, see the SimulationStorageType property of rlTrainingOptions and Address Memory Issues During Training.

Example: "file"

EvaluationStatistic

Evaluation statistic for each episode, returned as a column vector with as many elements as the number of episodes. When a (training) episode is followed by a number of consecutive evaluation episodes, the corresponding EvaluationStatistic element is a statistic (for example, mean, maximum, minimum, median) calculated from these evaluation episodes. Otherwise, when the episode is followed by another training episode, the EvaluationStatistic element corresponding to the episode is NaN. If no evaluator object is passed to train, each element of this vector is NaN. For more information, see rlEvaluator and rlCustomEvaluator.

For multiagent environments, this property contains a matrix in which each column corresponds to an agent.
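The choice of statistic can change the reported value considerably when evaluation episodes vary. The following sketch uses made-up rewards for five consecutive evaluation episodes to compare the mean (the MeanEpisodeReward default of rlEvaluator) with the maximum, minimum, and median. The reward values are hypothetical.

```matlab
% Hypothetical rewards from five consecutive evaluation episodes
evalRewards = [120 95 500 310 260];

stats.mean   = mean(evalRewards);     % 257
stats.max    = max(evalRewards);      % 500
stats.min    = min(evalRewards);      % 95
stats.median = median(evalRewards);   % 260
disp(stats)
```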

TrainingOptions

Training options set, returned as:

  • For a single agent environment, an rlTrainingOptions object. For more information, see rlTrainingOptions.

  • For a multiagent environment, an rlMultiAgentTrainingOptions object. For more information, see rlMultiAgentTrainingOptions.

Tips

  • train updates the agents as training progresses. To preserve the original agent parameters for later use, save the agents to a MAT-file.

  • By default, calling train opens the Reinforcement Learning Training Monitor, which lets you visualize the progress of the training. The Reinforcement Learning Training Monitor plot shows the reward for each episode, a running average reward value, and the critic estimate Q0 (for agents that have critics). The Reinforcement Learning Training Monitor also displays various episode and training statistics. To turn off the Reinforcement Learning Training Monitor, set the Plots option of trainOpts to "none".

  • If you use a predefined environment for which there is a visualization, you can use plot(env) to visualize the environment. If you call plot(env) before training, then the visualization updates during training to allow you to visualize the progress of each episode. (For custom environments, you must implement your own plot method.)

  • Training terminates when the conditions specified in trainOpts are satisfied. To terminate training in progress, in the Reinforcement Learning Training Monitor, click Stop Training. Because train updates the agent at each episode, you can resume training by calling train(agent,env,trainOpts) again, without losing the trained parameters learned during the first call to train.

  • During training, you can save candidate agents that meet conditions you specify with trainOpts. For instance, you can save any agent whose episode reward exceeds a certain value, even if the overall condition for terminating training is not yet satisfied. train stores saved agents in a MAT-file in the folder you specify with trainOpts. Saved agents can be useful, for instance, to allow you to test candidate agents generated during a long-running training process. For details about saving criteria and saving location, see rlTrainingOptions.

Algorithms

In general, train performs the following iterative steps:

  1. Initialize agent.

  2. For each episode:

    1. Reset the environment.

    2. Get the initial observation s0 from the environment.

    3. Compute the initial action a0 = μ(s0).

    4. Set the current action to the initial action (a ← a0) and set the current observation to the initial observation (s ← s0).

    5. While the episode is not finished or terminated:

      1. Step the environment with action a to obtain the next observation s' and the reward r.

      2. Learn from the experience set (s,a,r,s').

      3. Compute the next action a' = μ(s').

      4. Update the current action with the next action (a ← a') and update the current observation with the next observation (s ← s').

      5. Break if the episode termination conditions defined in the environment are met.

  3. If the training termination condition defined by trainOpts is met, terminate training. Otherwise, begin the next episode.

The specifics of how train performs these computations depend on your configuration of the agent and environment. For instance, resetting the environment at the start of each episode can include randomizing initial state values, if you configure your environment to do so.
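To make the iterative steps concrete, the following base-MATLAB sketch follows the same loop structure. It uses a hypothetical five-state chain environment, and a tabular Q-learning update stands in for "learn from the experience". It uses no Reinforcement Learning Toolbox objects. The environment, reward, epsilon-greedy policy, and hyperparameters are illustrative assumptions and are not what train does internally for any particular agent.

```matlab
% Hypothetical chain environment: states 1..5, start in state 1,
% action 1 moves left, action 2 moves right, reward -1 per step,
% and the episode terminates on reaching state 5.
numStates = 5; numActions = 2;
stepEnv  = @(s,a) min(max(s + 2*a - 3, 1), numStates);  % next observation
resetEnv = @() 1;                                       % reset function

% Step 1: initialize a tabular Q-learning agent with epsilon-greedy policy mu
Q = zeros(numStates, numActions);
alpha = 0.5; gamma = 1; epsilon = 0.1;
maxEpisodes = 500; maxSteps = 100;

for episode = 1:maxEpisodes                  % step 2: for each episode
    s = resetEnv();                          % steps 2.1-2.2: reset, get s0
    % Step 2.3: initial action a0 = mu(s0)
    if rand < epsilon, a = randi(numActions); else [~, a] = max(Q(s,:)); end
    for k = 1:maxSteps                       % step 2.5 (capped at maxSteps)
        sNext  = stepEnv(s, a);              % 2.5.1: step the environment
        r      = -1;
        isDone = (sNext == numStates);
        % 2.5.2: learn from (s,a,r,s') with the Q-learning update
        target = r + gamma*(~isDone)*max(Q(sNext,:));
        Q(s,a) = Q(s,a) + alpha*(target - Q(s,a));
        % 2.5.3: next action a' = mu(s')
        if rand < epsilon, aNext = randi(numActions); else [~, aNext] = max(Q(sNext,:)); end
        s = sNext; a = aNext;                % 2.5.4: update current s and a
        if isDone, break; end                % 2.5.5: episode termination
    end
end                                          % step 3: stop after maxEpisodes

% Greedy action in each non-terminal state after training
[~, greedyActions] = max(Q(1:numStates-1,:), [], 2)
```

In this sketch, the training termination condition is simply a fixed episode count, similar in spirit to StopTrainingCriteria="EpisodeCount".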

Extended Capabilities

Version History

Introduced in R2019a
