Encountering 'dlnetwork/predict' error: Undefined function 'recordNary' for 'deep.internal.AcceleratedOp' in my deep RL code
Hello MATLAB community,
I'm working on a custom deep learning agent for a maintenance scheduling problem. I've been facing an issue where, starting from the first learning step, I consistently obtain zero gradients. This obviously prevents any meaningful updates to my network and hinders learning.
To further investigate this, I introduced a numerical gradient check in my agent's update method. However, I encountered an error when this check was run.
What I've tried:
- Checked initializations to ensure that they're not leading to saturated activations.
- Commented out the numerical gradient check, which made the error go away.
- Tried different perturbation sizes (epso) for the check, e.g. 1e-4 and 1e-6.
Questions:
- What are potential reasons for consistently obtaining zero gradients in custom deep learning agents?
- How can I properly implement and use a numerical gradient check in MATLAB to debug this issue?
- Are there best practices or other methods to investigate and resolve the vanishing gradient problem in custom agents?
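For the first question, one frequent cause of exactly-zero gradients in a custom dlgradient loss is a step that takes values out of the traced dlarray, for example extractdata or any conversion to a plain numeric array before the loss is formed. If the loss reaches the main network's weights only through such untraced values, dlgradient has no path back to those weights, which is consistent with all-zero gradients. The minimal sketch below shows one traced way to select Q(s,a) per sample, using linear indexing on the dlarray itself. Everything in it is hypothetical: a linear map `W * S` stands in for `predict(obj.network, states)`, and `actions`/`targets` are made-up values.

```matlab
% Sketch: select Q(s,a) for each sample without leaving the traced dlarray.
% Hypothetical stand-ins: Q = W*S replaces predict(obj.network, states).
rng(0);
numActions = 4; stateDim = 2; batchSize = 3;
W = dlarray(randn(numActions, stateDim));   % stand-in learnable weights
S = randn(stateDim, batchSize);             % stand-in states (plain numeric)
actions = [2 4 1];                          % chosen action for each column/sample
targets = randn(1, batchSize);              % stand-in TD targets

% q(actions(j), j) for every column j, via linear indexing on the dlarray
pickQ = @(q) q(sub2ind(size(q), actions, 1:batchSize));
lossFcn = @(w) mean((pickQ(w * S) - targets).^2, 'all');

% dlgradient must run inside dlfeval so that W is traced
g = dlfeval(@(w) dlgradient(lossFcn(w), w), W);
disp(extractdata(g))   % row 3 is all zeros: action 3 was never selected
```

Only the rows of the selected actions receive gradient, which is the expected behavior for a DQN loss; a whole row of zeros for an unused action is normal, but zeros everywhere suggests the trace was broken.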
Thank you for any insights or suggestions!
Code:
classdef DQNAgent < handle
    properties
        numWorkstations
        bufferSize
        batchSize
        gamma
        ...
        tau = 0.01;
        ...
    end
    methods
        function obj = DQNAgent(assemblyLine, numWorkstations, taskTimes, bufferSize, batchSize, gamma, alpha, epsilon, targetUpdateFrequency, agentRole) % Add agentRole as a parameter
            obj.numWorkstations = numWorkstations;
            obj.maxTasks = max(cellfun(@numel, taskTimes));
            obj.assemblyLine = assemblyLine;
            obj.bufferSize = bufferSize;
            obj.batchSize = batchSize;
            obj.gamma = gamma;
            obj.alpha = alpha;
            obj.epsilon = epsilon;
            obj.targetUpdateFrequency = targetUpdateFrequency;
            obj.taskTimes = taskTimes;
            obj.numModels = numel(taskTimes);
            obj.sumTaskTimes = sum(cellfun(@numel, obj.taskTimes));
            obj.lossValues = [];
            totalTasks = sum(cellfun(@length, taskTimes));
            obj.stateSize = numWorkstations * totalTasks + numWorkstations * 2;
            obj.maintenanceActionSize = numWorkstations;
            obj.reassignmentActionSize = obj.maxTasks * obj.numModels * obj.numWorkstations^2;
            obj.agentRole = agentRole; % Set the agentRole property
            if strcmp(agentRole, 'maintenance')
                obj.actionSize = obj.maintenanceActionSize;
                lgraph = obj.createNetwork(obj.actionSize, [32, 64], 'leakyrelu'); % Example usage
                obj.network = dlnetwork(lgraph);
                obj.maintenanceTargetNetwork = dlnetwork(lgraph); % Initialize the maintenance target network
                obj.updateTargetNetwork();
            elseif strcmp(agentRole, 'taskAssignment')
                obj.actionSize = obj.reassignmentActionSize;
                lgraph = obj.createNetwork(obj.actionSize, [32, 64], 'leakyrelu'); % Example usage
                obj.network = dlnetwork(lgraph);
                obj.taskAssignmentTargetNetwork = dlnetwork(lgraph); % Initialize the task assignment target network
                obj.updateTargetNetwork();
            end
            obj.replayBuffer = ReplayBuffer(bufferSize);
            obj.learnStep = 0;
            if strcmp(obj.agentRole, 'maintenance')
                disp(obj.maintenanceTargetNetwork.Layers);
            elseif strcmp(obj.agentRole, 'taskAssignment')
                disp(obj.taskAssignmentTargetNetwork.Layers);
            end
        end

        function normalizedState = normalizeState(obj, state)
            % Preallocate matrices for normalized values
            normalizedHealthStatusMatrix = zeros(size(state, 1), 1);
            normalizedMaintenanceCountsMatrix = zeros(size(state, 1), 1);
            % Loop through each row (workstation) and normalize values
            for i = 1:size(state, 1)
                % Normalize health status for current workstation
                healthStatus = state(i, end-1);
                normalizedHealthStatus = (healthStatus - obj.healthStatusMin) / (obj.healthStatusMax - obj.healthStatusMin);
                normalizedHealthStatusMatrix(i) = normalizedHealthStatus;
                % Normalize maintenance counts for current workstation
                maintenanceCounts = state(i, end);
                normalizedMaintenanceCounts = (maintenanceCounts - obj.maintenanceCountMin) / (obj.maintenanceCountMax - obj.maintenanceCountMin);
                normalizedMaintenanceCountsMatrix(i) = normalizedMaintenanceCounts;
            end
            % Concatenate the normalized values to the state matrix, excluding the original health status and maintenance counts
            normalizedState = [state(:, 1:end-2), normalizedHealthStatusMatrix, normalizedMaintenanceCountsMatrix];
        end

        function action = selectAction(obj, state)
            state = obj.normalizeState(state);
            flattenedState = state(:);
            dlState = dlarray(single(flattenedState), 'SSCB');
            qValues = predict(obj.network, dlState);
            if strcmp(obj.agentRole, 'maintenance')
                if rand() < obj.epsilon
                    action = randi(obj.maintenanceActionSize);
                else
                    [~, action] = max(qValues(1:obj.maintenanceActionSize));
                    action = extractdata(action);
                end
            elseif strcmp(obj.agentRole, 'taskAssignment')
                triedActions = []; % Initialize the list of tried actions
                useRandomAction = []; % Initialize the flag
                while true
                    if isempty(useRandomAction)
                        useRandomAction = rand() < obj.epsilon;
                    end
                    if useRandomAction
                        % Select a random task assignment action
                        sourceWorkstation = randi(obj.numWorkstations);
                        targetWorkstation = randi(obj.numWorkstations);
                        taskModel = randi(numel(obj.taskTimes));
                        taskIndex = randi(numel(obj.taskTimes{taskModel}));
                    else
                        % Select the task assignment action with the highest Q-value
                        qValuesTemp = qValues;
                        qValuesTemp(triedActions) = -Inf;
                        [~, actionIndex] = max(qValuesTemp);
                        actionIndex = extractdata(actionIndex);
                        triedActions = [triedActions, actionIndex]; % Update the list of tried actions
                        taskIndex = mod(actionIndex, obj.maxTasks);
                        if taskIndex == 0
                            taskIndex = obj.maxTasks;
                            actionIndex = actionIndex - obj.maxTasks;
                        end
                        tempValue = floor(actionIndex / obj.maxTasks);
                        taskModel = mod(tempValue, obj.numModels) + 1;
                        tempValue = floor(tempValue / obj.numModels);
                        sourceWorkstation = mod(tempValue, obj.numWorkstations) + 1;
                        targetWorkstation = floor(tempValue / obj.numWorkstations) + 1;
                    end
                    % Check if the action is valid
                    if sourceWorkstation <= obj.numWorkstations && targetWorkstation <= obj.numWorkstations && taskModel <= numel(obj.taskTimes) && taskIndex <= numel(obj.taskTimes{taskModel})
                        break;
                    end
                end
                % After you have determined the action components, encode them into a single number:
                actionIndex = taskIndex ...
                    + obj.maxTasks * (taskModel - 1) ...
                    + obj.maxTasks * obj.numModels * (sourceWorkstation - 1) ...
                    + obj.maxTasks * obj.numModels * obj.numWorkstations * (targetWorkstation - 1);
                % Return the encoded task assignment action index
                action = actionIndex;
            end
        end

        function storeExperience(obj, state, action, reward, nextState, done)
            state = obj.normalizeState(state);
            nextState = obj.normalizeState(nextState);
            obj.replayBuffer.add(state, action, reward, nextState, done);
        end

        function update(obj)
            if obj.replayBuffer.size() >= obj.batchSize
                [states, actions, rewards, nextStates, dones] = obj.replayBuffer.sample(obj.batchSize);
                actualBatchSize = size(states, 1);
                % Transpose the states and nextStates arrays
                states = states';
                nextStates = nextStates';
                states = single(reshape(states, [1, obj.stateSize, 1, actualBatchSize]));
                nextStates = single(reshape(nextStates, [1, obj.stateSize, 1, actualBatchSize]));
                states = dlarray(states, 'SSCB');
                nextStates = dlarray(nextStates, 'SSCB');
                obj.learnStep = obj.learnStep + 1;
                % Calculate loss and gradients
                [loss, gradients] = obj.qLoss(states, actions, rewards, nextStates, dones);
                % Debugging: Compute the numerical gradient and compare
                numericalGrads = obj.computeNumericalGradient(@obj.qLoss, states, actions, rewards, nextStates, dones);
                disp('Numerical Gradients:');
                disp(numericalGrads);
                disp('dlgradient Gradients:');
                disp(gradients);
                % Simple gradient descent update
                for i = 1:height(gradients)
                    layerName = gradients{i, 1};
                    paramName = gradients{i, 2};
                    gradValue = gradients{i, 3}{1};
                    %fprintf('Layer: %s, Parameter: %s, Max Gradient: %f, Min Gradient: %f\n', layerName, paramName, max(gradValue, [], 'all'), min(gradValue, [], 'all'));
                    %fprintf('Layer: %s, Max Gradient: %f, Min Gradient: %f\n', layerName, max(gradValue, [], 'all'), min(gradValue, [], 'all'));
                    idx = strcmp(obj.network.Learnables.Layer, layerName) & strcmp(obj.network.Learnables.Parameter, paramName);
                    currentValue = obj.network.Learnables.Value{idx};
                    updatedValue = currentValue - obj.alpha * gradValue;
                    obj.network.Learnables.Value{idx} = dlarray(updatedValue);
                end
                % Log the loss
                obj.lossValues(end+1) = double(gather(loss));
                % Update the target network at regular intervals
                obj.updateTargetNetwork();
            end
        end

        function [loss, gradients] = qLoss(obj, states, actions, rewards, nextStates, dones)
            % Define a nested function for the loss computation
            function [loss, grads] = computeLoss(w)
                obj.network.Learnables = w;
                % Predict Q-values for the next states using the target network
                if strcmp(obj.agentRole, 'maintenance')
                    qNext = predict(obj.maintenanceTargetNetwork, nextStates);
                elseif strcmp(obj.agentRole, 'taskAssignment')
                    qNext = predict(obj.taskAssignmentTargetNetwork, nextStates);
                end
                % Compute the target Q-values
                qNextMax = max(qNext, [], 1);
                targets = rewards' + obj.gamma * (1 - dones') .* qNextMax;
                % Predict Q-values for the current states using the main network
                q = predict(obj.network, states);
                qAction = zeros(1, numel(actions));
                for i = 1:numel(actions)
                    qAction(i) = q(actions(i), i);
                end
                qAction = dlarray(qAction, 'SSCB');
                qAction = squeeze(qAction);
                % Directly compute the loss
                loss = mean((qAction - targets).^2, 'all');
                % Compute the gradient of the loss with respect to the network's learnables
                grads = dlgradient(loss, obj.network.Learnables);
            end
            % Call the nested function using dlfeval to compute the loss and gradients
            [loss, gradients] = dlfeval(@computeLoss, obj.network.Learnables);
        end

        function updateTargetNetwork(obj)
            numLearnables = height(obj.network.Learnables); % Use height instead of length
            if strcmp(obj.agentRole, 'maintenance')
                for i = 1:numLearnables
                    currentValue = obj.network.Learnables{i, 'Value'}{1}; % Extract the contents of the cell
                    targetValue = obj.maintenanceTargetNetwork.Learnables{i, 'Value'}{1}; % Extract the contents of the cell
                    obj.maintenanceTargetNetwork.Learnables{i, 'Value'}{1} = ...
                        obj.tau * currentValue + ...
                        (1 - obj.tau) * targetValue;
                end
            elseif strcmp(obj.agentRole, 'taskAssignment')
                for i = 1:numLearnables
                    currentValue = obj.network.Learnables{i, 'Value'}{1}; % Extract the contents of the cell
                    targetValue = obj.taskAssignmentTargetNetwork.Learnables{i, 'Value'}{1}; % Extract the contents of the cell
                    obj.taskAssignmentTargetNetwork.Learnables{i, 'Value'}{1} = ...
                        obj.tau * currentValue + ...
                        (1 - obj.tau) * targetValue;
                end
            end
        end

        function [weights, bias] = customInitialization(obj, inputSize, outputSize)
            weights = sqrt(2/inputSize) * randn(outputSize, inputSize);
            bias = zeros(outputSize, 1);
        end

        function lgraph = createNetwork(obj, actionSize, layerDepths, activationFunction)
            layers = [
                imageInputLayer([1 obj.stateSize 1], 'Normalization', 'none', 'Name', 'state')
                ];
            inputSize = obj.stateSize; % Initial input size is the state size
            for i = 1:length(layerDepths)
                % Create the fully connected layer without weights
                fcLayer = fullyConnectedLayer(layerDepths(i), 'Name', ['fc' num2str(i)]);
                disp(['inputSize: ', num2str(inputSize), ', layerDepth: ', num2str(layerDepths(i))]);
                % Use the custom initialization function
                [fcLayer.Weights, fcLayer.Bias] = obj.customInitialization(inputSize, layerDepths(i));
                % Append the layer to the layers array
                layers = [layers; fcLayer];
                switch activationFunction
                    case 'relu'
                        layers = [layers; reluLayer('Name', ['relu' num2str(i)])];
                    case 'leakyrelu'
                        layers = [layers; leakyReluLayer('Name', ['leakyrelu' num2str(i)])];
                    case 'tanh'
                        layers = [layers; tanhLayer('Name', ['tanh' num2str(i)])];
                    case 'sigmoid'
                        layers = [layers; sigmoidLayer('Name', ['sigmoid' num2str(i)])];
                end
                % Update the input size for the next layer
                inputSize = layerDepths(i);
            end
            % For the final output layer
            fcLayer = fullyConnectedLayer(actionSize, 'Name', 'output');
            [fcLayer.Weights, fcLayer.Bias] = obj.customInitialization(inputSize, actionSize);
            layers = [layers; fcLayer];
            lgraph = layerGraph(layers);
        end

        function numericalGradients = computeNumericalGradient(obj, lossFunc, states, actions, rewards, nextStates, dones)
            % Compute the numerical gradient of a loss function with respect to learnables
            epso = 1e-5; % Small perturbation value
            numericalGradients = []; % Initialize the numerical gradients
            % Create a deep copy of the network using save and load
            tempFileName = [tempname, '.mat'];
            save(tempFileName, 'obj');
            loadedData = load(tempFileName);
            tempNetwork = loadedData.obj.network;
            delete(tempFileName); % Clean up the temporary file
            for i = 1:numel(tempNetwork.Learnables)
                originalValue = tempNetwork.Learnables.Value{i}; % Store the original value
                % Perturb the parameter positively
                tempNetwork.Learnables.Value{i} = originalValue + epso;
                lossPlus = lossFunc(states, actions, rewards, nextStates, dones);
                % Perturb the parameter negatively
                tempNetwork.Learnables.Value{i} = originalValue - epso;
                lossMinus = lossFunc(states, actions, rewards, nextStates, dones);
                % Compute the gradient for this parameter
                gradient = (lossPlus - lossMinus) / (2 * epso);
                numericalGradients = [numericalGradients; gradient];
            end
        end

        function saveNetworkWeights(obj, filename)
            qNetworkWeights = obj.network.Learnables; % Directly access the Learnables property
            if strcmp(obj.agentRole, 'maintenance')
                targetNetworkWeights = obj.maintenanceTargetNetwork.Learnables; % Directly access the Learnables property
            elseif strcmp(obj.agentRole, 'taskAssignment')
                targetNetworkWeights = obj.taskAssignmentTargetNetwork.Learnables; % Directly access the Learnables property
            end
            save(filename, 'qNetworkWeights', 'targetNetworkWeights');
        end

        function loadNetworkWeights(obj, qNetworkWeights, targetNetworkWeights)
            obj.network.Learnables = qNetworkWeights; % Directly set the Learnables property
            if strcmp(obj.agentRole, 'maintenance')
                obj.maintenanceTargetNetwork.Learnables = targetNetworkWeights; % Directly set the Learnables property
            elseif strcmp(obj.agentRole, 'taskAssignment')
                obj.taskAssignmentTargetNetwork.Learnables = targetNetworkWeights; % Directly set the Learnables property
            end
        end
    end
end
Error:
Error using dlnetwork/predict Execution failed during layer(s) ''. ...
Caused by: Undefined function 'recordNary' for input arguments of type 'deep.internal.AcceleratedOp'.
Answers (1)
Neha
on 11 Oct 2023
Hi Mohammadreza,
I understand that you are facing vanishing-gradient issues in your custom deep learning agent. To answer your first question: you already use ReLU and leaky ReLU activations, which help avoid zero gradients. You can also add a batch normalization layer between each fully connected layer and its activation so that the inputs to each layer stay normalized, which can help mitigate vanishing gradients. See the batchNormalizationLayer documentation for how to add it.
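A minimal sketch of that layout, assuming the same `[1 stateSize 1]` image input and `[32 64]` widths used in createNetwork; `stateSize`, `actionSize` and the batch of 8 random states are hypothetical placeholders. One detail worth knowing: with batch normalization, `predict` uses the stored statistics, so a custom loss should call `forward` and write the returned state back into the network.

```matlab
% Sketch: FC -> batch norm -> leaky ReLU blocks (sizes are hypothetical).
rng(0);
stateSize = 20; actionSize = 5; layerDepths = [32 64];
layers = imageInputLayer([1 stateSize 1], 'Normalization', 'none', 'Name', 'state');
for i = 1:numel(layerDepths)
    layers = [layers
        fullyConnectedLayer(layerDepths(i), 'Name', ['fc' num2str(i)])
        batchNormalizationLayer('Name', ['bn' num2str(i)])
        leakyReluLayer('Name', ['leakyrelu' num2str(i)])]; %#ok<AGROW>
end
layers = [layers; fullyConnectedLayer(actionSize, 'Name', 'output')];
net = dlnetwork(layers);

% forward (not predict) uses mini-batch statistics and returns the updated
% running mean/variance, which must be stored back in the network
X = dlarray(single(randn(1, stateSize, 1, 8)), 'SSCB');
[Y, state] = forward(net, X);
net.State = state;
```

In a training loop, the `forward` call and the state update would sit inside the function passed to `dlfeval`, with the state returned alongside the loss and gradients.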
Initialization schemes such as Glorot or He initialization help keep the weights in a range that avoids vanishing gradients. Glorot is the default weights initializer for fully connected layers, so you can try He initialization instead; note that your customInitialization (sqrt(2/inputSize) * randn) already draws He-style weights manually. See the fullyConnectedLayer documentation (the 'WeightsInitializer' option) for more information.
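The sketch below contrasts the two routes to He initialization, zero-mean normal weights with variance 2/fan-in: the manual formula from customInitialization, and the layer's built-in `'WeightsInitializer','he'` option. The sizes are hypothetical, and a `featureInputLayer` is used only to give the layer a known input size.

```matlab
% Sketch: He initialization done manually and via the built-in option.
rng(0);
inputSize = 1000; outputSize = 64;   % hypothetical fan-in / fan-out

% 1) Manually, as customInitialization does: N(0, 2/inputSize)
Wmanual = sqrt(2/inputSize) * randn(outputSize, inputSize);

% 2) Built-in: request He instead of the Glorot default
fc = fullyConnectedLayer(outputSize, 'WeightsInitializer', 'he');
net = dlnetwork([featureInputLayer(inputSize); fc]);
isW = strcmp(net.Learnables.Parameter, 'Weights');
Wbuiltin = extractdata(net.Learnables.Value{isW});

% Both sample standard deviations should be close to sqrt(2/inputSize)
fprintf('target %.4f | manual %.4f | built-in %.4f\n', ...
    sqrt(2/inputSize), std(Wmanual(:)), std(Wbuiltin(:)));
```

Because the two approaches are equivalent, switching initializers alone is unlikely to change the behavior of this particular agent.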
Regarding the numerical gradient check: the perturbed weights are never passed to the loss function. computeNumericalGradient perturbs tempNetwork, a separate copy of the network, but lossFunc (obj.qLoss) always reads obj.network, so every loss evaluation uses the original, unperturbed weights:
numericalGrads = obj.computeNumericalGradient(@obj.qLoss, states, actions, rewards, nextStates, dones);
% In the computeNumericalGradient function:
lossPlus = lossFunc(states, actions, rewards, nextStates, dones); % this is being called with the original weights in obj
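To compute the numerical gradients correctly, the loss must be evaluated with the perturbed weights, for example by writing the loss so that it takes the network or its learnables as an input. Two further details: numel(tempNetwork.Learnables) counts table elements (rows times variables) rather than parameters, so loop over height(...) instead; and adding epso to a whole parameter array at once yields a single directional derivative rather than a gradient per weight, so perturb one element at a time.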
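A minimal sketch of such a check, assuming a small stand-in network, random data and a plain mean-squared-error loss rather than the agent's qLoss. The loss takes the network as an explicit input, each scalar weight is perturbed on its own with central differences, and the result is compared with dlgradient. Since dlnetwork is a value object, `netP = net` already gives an independent copy, so no save/load round trip is needed. The learnables are single precision by default, so very small steps such as 1e-6 are dominated by rounding error; for central differences in single precision a step of roughly 1e-3 to 1e-2 is usually more reliable.

```matlab
% Sketch: finite-difference check of dlgradient for a dlnetwork.
% The tiny network and random data are hypothetical stand-ins for the agent.
rng(0);
layers = [
    featureInputLayer(4)
    fullyConnectedLayer(8)
    tanhLayer
    fullyConnectedLayer(3)];
net = dlnetwork(layers);
X = dlarray(single(randn(4, 5)), 'CB');   % 4 features, batch of 5
T = dlarray(single(randn(3, 5)), 'CB');   % regression targets

% The loss receives the network (hence its weights) as an explicit input
lossFcn = @(n, x, t) mean((forward(n, x) - t).^2, 'all');
gradFcn = @(n, x, t) dlgradient(lossFcn(n, x, t), n.Learnables);
gradA = dlfeval(gradFcn, net, X, T);       % analytic gradients (table)

epso = 1e-2;                               % step size suited to single precision
gNumAll = []; gAnaAll = [];
for p = 1:height(net.Learnables)           % height, not numel, counts parameters
    W = net.Learnables.Value{p};
    for k = 1:numel(W)                     % perturb one element at a time
        Lp = net.Learnables; Wp = W; Wp(k) = Wp(k) + epso; Lp.Value{p} = Wp;
        Lm = net.Learnables; Wm = W; Wm(k) = Wm(k) - epso; Lm.Value{p} = Wm;
        netP = net; netP.Learnables = Lp;  % value copies: net is unchanged
        netM = net; netM.Learnables = Lm;
        gNum = (extractdata(lossFcn(netP, X, T)) - extractdata(lossFcn(netM, X, T))) / (2 * epso);
        gNumAll(end+1, 1) = double(gNum); %#ok<AGROW>
    end
    gAnaAll = [gAnaAll; double(extractdata(gradA.Value{p}(:)))]; %#ok<AGROW>
end
relErr = norm(gNumAll - gAnaAll) / (norm(gNumAll) + norm(gAnaAll));
fprintf('Relative difference, numerical vs dlgradient: %.2e\n', relErr);
```

A small relative difference indicates that dlgradient and the loss agree; in the agent, the same pattern would mean giving the perturbed learnables to a loss such as computeLoss(w) directly, instead of calling @obj.qLoss.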
Hope this helps!