Log Training Data to Disk

This example shows how to log custom data to disk when training agents using the Reinforcement Learning Toolbox™ train function.

Overview

The general steps for data logging are:

  1. Create a data logger object using the rlDataLogger function.

  2. Configure the data logger object with callback functions to specify the data to log at different stages of the training process.

  3. Specify the logger object as a name-value input argument in the train function.
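
Putting these steps together, the following is a minimal sketch of the pattern. The callback name myEpisodeLogFcn and the logged field are illustrative placeholders; the rest of this example builds a complete version step by step.

% Minimal sketch of the logging workflow (illustrative names).
logger = rlDataLogger();                        % 1. Create the logger.
logger.EpisodeFinishedFcn = @myEpisodeLogFcn;   % 2. Attach a callback.
% results = train(agent,env,trainOpts,Logger=logger);  % 3. Pass the logger to train.

function dataToLog = myEpisodeLogFcn(data)
% Return a structure with the data to log after each episode.
dataToLog.EpisodeReward = data.EpisodeInfo.CumulativeReward;
end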

Create Data Logger

Create a file data logger object using the rlDataLogger function.

fileLogger = rlDataLogger()
fileLogger = 
  FileLogger with properties:

           LoggingOptions: [1x1 rl.logging.option.MATFileLoggingOptions]
       EpisodeFinishedFcn: []
     AgentStepFinishedFcn: []
    AgentLearnFinishedFcn: []

Optionally specify the data logging directory and a naming rule for the files to be saved.

% Specify a logging directory. You must have write 
% access for this directory.
logDir = fullfile(pwd,"myDataLog");
fileLogger.LoggingOptions.LoggingDirectory = logDir;

% Specify a naming rule for files. The naming rule episode<id>
% saves files as episode001.mat, episode002.mat and so on.
fileLogger.LoggingOptions.FileNameRule = "episode<id>";

Configure Data Logging

Training data of interest is generated at different stages of the training loop; for example, experience data is available after the completion of an episode. You can configure the logger object with callback functions to log data at these stages. The functions must return either a structure containing the data to log, or an empty array if no data needs to be logged at that stage.

The callback functions are:

  • EpisodeFinishedFcn - Callback function to log data such as experiences, logged Simulink® signals, or initial observations. The training loop executes this function after the completion of a training episode. The following is an example of the function.

function dataToLog = episodeLogFcn(data)
% episodeLogFcn logs data after every episode.
%
% data is a structure that contains the following fields:
%
% EpisodeCount: The current episode number.
% Environment: Environment object.
% Agent: Agent object.
% Experience: A structure containing the experiences 
%             from the current episode.
% EpisodeInfo: A structure containing the fields 
%              CumulativeReward, StepsTaken, and 
%              InitialObservation.
% SimulationInfo: Contains simulation information for the 
%                 current episode.
%                 For MATLAB environments this is a structure 
%                 with the field "SimulationError".
%                 For Simulink environments this is a 
%                 Simulink.SimulationOutput object.
%
% dataToLog is a structure containing the data to be logged 
% to disk.

% Write your code to log data to disk. For example, 
% dataToLog.Experience = data.Experience;

dataToLog.Experience = data.Experience;
dataToLog.EpisodeReward = data.EpisodeInfo.CumulativeReward;
if data.EpisodeInfo.StepsTaken > 0
    dataToLog.EpisodeQ0 = evaluateQ0(data.Agent, ...
        data.EpisodeInfo.InitialObservation);
else
    dataToLog.EpisodeQ0 = 0;
end
  • AgentStepFinishedFcn - Callback function to log data such as the state of exploration. The training loop executes this function after the completion of an agent step within an episode. The following is an example of the function.

function dataToLog = agentStepLogFcn(data)
% agentStepLogFcn logs data after every agent step.
%
% data is a structure that contains the following fields:
%
% EpisodeCount:   The current episode number.
% AgentStepCount: The cumulative number of steps taken by 
%                 the agent.
% SimulationTime: The current simulation time in the 
%                 environment.
% Agent:          Agent object.
%
% dataToLog is a structure containing the data to be logged 
% to disk.

% Write your code to log data to disk. For example, 
% noiseState = getState(getExplorationPolicy(data.Agent));
% dataToLog.noiseState = noiseState;

policy = getExplorationPolicy(data.Agent);
if isprop(policy,"NoiseType")
    state = getState(policy);
    switch policy.NoiseType
        case "ou"
            dataToLog.OUNoise = state.Noise{1};
            dataToLog.StandardDeviation = state.StandardDeviation{1};
        case "gaussian"
            dataToLog.StandardDeviation = state.StandardDeviation{1};
    end
else
    dataToLog = [];
end
  • AgentLearnFinishedFcn - Callback function to log data such as the actor and critic training losses. The training loop executes this function after updating the actor or critic networks. The following is an example of the function.

function dataToLog = agentLearnLogFcn(data)
% agentLearnLogFcn logs data after updating the agent's 
% representations (actor or critic).
%
% data is a structure that contains the following fields:
%
% Agent           : Agent object.
% EpisodeCount    : The current episode number.
% AgentStepCount  : The cumulative number of steps taken by 
%                   the agent.
%
% For agents with an actor:
% ActorLoss              : Training loss of the actor.
% ActorGradientStepCount : Cumulative number of actor 
%                          gradient computation steps.
%
% For agents with a critic:
% CriticLoss   : Training loss of the critic.
% CriticGradientStepCount : Cumulative number of critic 
%                gradient computation steps.
% TDTarget     : Future value of rewards as computed by
%                the target critic network.
% TDError      : Error between the critic and target 
%                critic estimates of the reward.
%
% For PPO/TRPO agents:
% Advantage    : Advantage values.
%
% For PPO agents:
% AdvantageLoss: Advantage loss value.
% EntropyLoss  : Entropy loss value.
% PolicyRatio  : Ratio between current and old policies.
%
% For model-based agents:
% EnvModelTrainingInfo: A structure containing the fields: 
%                       a. TransitionFcnLoss
%                       b. RewardFcnLoss
%                       c. IsDoneFcnLoss. 
%
% For off-policy agents with replay memory:
% SampleIndex  : Indices of experiences sampled from the 
%                replay memory for training.
%
% dataToLog is a structure containing the data to be logged 
% to disk.

% Write your code to log data to disk. For example, 
% dataToLog.ActorLoss = data.ActorLoss;

dataToLog.ActorLoss  = data.ActorLoss;
dataToLog.CriticLoss = data.CriticLoss;
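
If you define these callback functions (for example, as local functions or in their own files), you can attach any subset of them to the logger properties shown earlier. A brief sketch using the example function names above:

% Attach the example callbacks to the logger. Unassigned callbacks
% remain empty and no data is logged at those stages.
fileLogger.EpisodeFinishedFcn    = @episodeLogFcn;
fileLogger.AgentStepFinishedFcn  = @agentStepLogFcn;
fileLogger.AgentLearnFinishedFcn = @agentLearnLogFcn;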

For this example, configure only the AgentLearnFinishedFcn callback. The function logTrainingLoss logs the actor and critic training losses and is provided at the end of this example.

fileLogger.AgentLearnFinishedFcn = @logTrainingLoss;

Run Training

Create a predefined CartPole-continuous environment and a deep deterministic policy gradient (DDPG) agent for training.

% Set the random seed to facilitate reproducibility
rng(0);

% Create a CartPole-continuous environment
env = rlPredefinedEnv("CartPole-continuous");

% Create a DDPG agent
agent = rlDDPGAgent(getObservationInfo(env), getActionInfo(env));
agent.AgentOptions.NoiseOptions.StandardDeviationDecayRate = 0.001;

Specify training options to train the agent for 100 episodes, and visualize the training progress in the Reinforcement Learning Training Monitor.

Note that you can still use the SaveAgentCriteria, SaveAgentValue, and SaveAgentDirectory options of the rlTrainingOptions object to save the agent during training. These options do not affect (and are not affected by) the use of FileLogger or MonitorLogger objects.

trainOpts = rlTrainingOptions( ...
    MaxEpisodes=100, ...
    Plots="training-progress");
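
As noted above, agent-saving options can be set in the same options object without interfering with the logger. The following is a hedged sketch; the criterion and threshold values are illustrative and are not used in this example.

% Sketch only: also save the agent whenever the episode reward exceeds
% 100. These settings are independent of the FileLogger configuration.
saveOpts = rlTrainingOptions( ...
    MaxEpisodes=100, ...
    Plots="training-progress", ...
    SaveAgentCriteria="EpisodeReward", ...
    SaveAgentValue=100, ...
    SaveAgentDirectory=fullfile(pwd,"savedAgents"));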

Train the agent using the train function. Specify the file logger object in the Logger name-value option.

results = train(agent, env, trainOpts, Logger=fileLogger);

The logged data is saved within the directory specified by logDir.
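
To work with the logged results programmatically, you can read the MAT files back from that directory. The following is a minimal sketch; because the internal layout of each file is determined by the FileLogger, list the stored variables rather than assuming their names.

% Sketch only: list and load the logged episode files.
logFiles = dir(fullfile(logDir,"episode*.mat"));
for k = 1:numel(logFiles)
    fname = fullfile(logFiles(k).folder,logFiles(k).name);
    whos("-file",fname)   % Display the variables stored for this episode.
    S = load(fname);      % Load the logged data into a structure S.
end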

Visualize Logged Data

You can visualize data logged to disk using the interactive Reinforcement Learning Data Viewer graphical user interface. To open the visualization, click View Logged Data in the Reinforcement Learning Training Monitor window.

To create plots in the Reinforcement Learning Data Viewer, select a data item from the Data panel and a plot type from the toolstrip. The following image shows a plot of the ActorLoss data generated using the Trend plot type. The plot shows the logged data points and a moving average line.

On the toolstrip, navigate to the Trend tab to configure plot options. Set the window length for averaging data to 50. The plot updates with the new configuration.
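
If you prefer to reproduce a similar view programmatically, you can compute a moving average over the logged loss values. The following is a hedged sketch; it assumes you have already collected the logged ActorLoss values into a vector named actorLoss (a hypothetical variable, for example extracted from the loaded files above).

% Sketch only: plot logged values with a 50-point moving average,
% similar to the Trend plot in the Reinforcement Learning Data Viewer.
plot(actorLoss,".")                         % Logged data points.
hold on
plot(movmean(actorLoss,50),"LineWidth",2)   % Moving average line.
hold off
xlabel("Learn step")
ylabel("ActorLoss")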

Local Functions

function dataToLog = logTrainingLoss(data)

% Function to log the actor and critic training losses
dataToLog.ActorLoss = data.ActorLoss;
dataToLog.CriticLoss = data.CriticLoss;
end
