Conditional GAN Training Error for TrainGAN function
I am trying to make conditional GAN training work with a 2D matrix input of size 14×8.
I am mimicking the "GenerateSyntheticPumpSignalsUsingCGANExample" example, changing the vector input to a 2D matrix input.
The following error message appears:

It seems that there is a size mismatch in the modelGradients function, but since this comes from an official example, I have no idea how to revise it. Can someone give a hint?
The input data is attached as test.mat and the training script as untitled3.m; I have also pasted the script below.
clear;
%% Load the data
% LSTM_Reform_Data_SeriesData1_20210315_data001_for_GAN;
% load('LoadedData_20210315_data001_for_GAN.mat')
load('test.mat');
% load('test2.mat');
%% Generator Network
numFilters = 4;
numLatentInputs = 120;
projectionSize = [2 1 63];
numClasses = 2;
embeddingDimension = 120;
layersGenerator = [
imageInputLayer([1 1 numLatentInputs],'Normalization','none','Name','Input')
projectAndReshapeLayer(projectionSize,numLatentInputs,'ProjReshape')
concatenationLayer(3,2,'Name','Concate1')
transposedConv2dLayer([3 2],8*numFilters,'Stride',1,'Name','TransConv1') % 4*2*32
batchNormalizationLayer('Name','BN1','Epsilon',1e-5)
reluLayer('Name','Relu1')
transposedConv2dLayer([2 2],4*numFilters,'Stride',2,'Name','TransConv2') % 8*4*16
batchNormalizationLayer('Name','BN2','Epsilon',1e-5)
reluLayer('Name','Relu2')
transposedConv2dLayer([2 2],2*numFilters,'Stride',2,'Cropping',[2 1],'Name','TransConv3') % 12*6*8
batchNormalizationLayer('Name','BN3','Epsilon',1e-5)
reluLayer('Name','Relu3')
transposedConv2dLayer([3 3],1,'Stride',1,'Name','TransConv4') % 14*8*1: one filter, so the output matches the discriminator's [14 8 1] input
];
lgraphGenerator = layerGraph(layersGenerator);
layers = [
imageInputLayer([1 1],'Name','Labels','Normalization','none')
embedAndReshapeLayer(projectionSize(1:2),embeddingDimension,numClasses,'EmbedReshape1')];
lgraphGenerator = addLayers(lgraphGenerator,layers);
lgraphGenerator = connectLayers(lgraphGenerator,'EmbedReshape1','Concate1/in2');
subplot(1,2,1);
plot(lgraphGenerator);
dlnetGenerator = dlnetwork(lgraphGenerator);
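For reference, each transposed convolution's spatial output size follows out = stride*(in-1) + filterSize - 2*cropping. A small sketch (not part of the script; the anonymous function `tconv` is illustrative only) tracing the sizes claimed in the comments above:

```matlab
% Transposed-conv output size: out = stride*(in-1) + filterSize - 2*cropping
tconv = @(in,f,s,c) s.*(in-1) + f - 2*c;
sz = [2 1];                        % spatial size after projectAndReshapeLayer
sz = tconv(sz,[3 2],[1 1],[0 0]);  % TransConv1 -> 4x2
sz = tconv(sz,[2 2],[2 2],[0 0]);  % TransConv2 -> 8x4
sz = tconv(sz,[2 2],[2 2],[2 1]);  % TransConv3 -> 12x6
sz = tconv(sz,[3 3],[1 1],[0 0]);  % TransConv4 -> 14x8
```

The spatial sizes in the layer comments check out; note, however, that the final layer's filter count (not its spatial size) determines the number of output channels, which must be 1 to match the discriminator's [14 8 1] input.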
%% Discriminator Network
scale = 0.2;
Input_Num_Feature = [14 8 1]; % The input data is [14 8 1]
layersDiscriminator = [
imageInputLayer(Input_Num_Feature,'Normalization','none','Name','Input')
concatenationLayer(3,2,'Name','Concate2')
convolution2dLayer([2 2],4*numFilters,'Stride',1,'DilationFactor',2,'Padding',[0 0],'Name','Conv1')
leakyReluLayer(scale,'Name','LeakyRelu1')
convolution2dLayer([2 4],2*numFilters,'Stride',2,'DilationFactor',1,'Padding',[2 2],'Name','Conv2')
leakyReluLayer(scale,'Name','LeakyRelu2')
convolution2dLayer([2 2],numFilters,'Stride',2,'DilationFactor',1,'Padding',[0 0],'Name','Conv3')
leakyReluLayer(scale,'Name','LeakyRelu3')
convolution2dLayer([2 1],numFilters/2,'Stride',1,'DilationFactor',2,'Padding',[0 0],'Name','Conv4')
leakyReluLayer(scale,'Name','LeakyRelu4')
convolution2dLayer([2 2],numFilters/4,'Stride',1,'DilationFactor',1,'Padding',[0 0],'Name','Conv5')
];
lgraphDiscriminator = layerGraph(layersDiscriminator);
layers = [
imageInputLayer([1 1],'Name','Labels','Normalization','none')
embedAndReshapeLayer(Input_Num_Feature,embeddingDimension,numClasses,'EmbedReshape2')];
lgraphDiscriminator = addLayers(lgraphDiscriminator,layers);
lgraphDiscriminator = connectLayers(lgraphDiscriminator,'EmbedReshape2','Concate2/in2');
subplot(1,2,2);
plot(lgraphDiscriminator);
dlnetDiscriminator = dlnetwork(lgraphDiscriminator);
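One quick way to confirm the per-layer activation sizes in both networks, before writing any training code, is the Deep Learning Toolbox network analyzer (optional; it opens an interactive report):

```matlab
% Optional sanity check: inspect activation sizes at every layer
analyzeNetwork(lgraphGenerator);
analyzeNetwork(lgraphDiscriminator);
```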
%% Train model
params.numLatentInputs = numLatentInputs;
params.numClasses = numClasses;
params.sizeData = [Input_Num_Feature length(Series_Fused_Label)];
params.numEpochs = 50;
params.miniBatchSize = 512;
% Specify the options for Adam optimizer
params.learnRate = 0.0002;
params.gradientDecayFactor = 0.5;
params.squaredGradientDecayFactor = 0.999;
executionEnvironment = "cpu";
params.executionEnvironment = executionEnvironment;
% for test, 14*8*30779
[dlnetGenerator,dlnetDiscriminator] =...
trainGAN(dlnetGenerator,dlnetDiscriminator,Series_Fused_Expand_Norm_Input,Series_Fused_Label,params);
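A hypothetical debugging sketch (using the variable names from the script above; nothing here is part of the official example): pass one random latent sample and label through both networks and inspect the sizes. If the generator's output is not 14×8 with a single channel, the last transposedConv2dLayer has the wrong number of filters, and the mismatch will surface inside modelGradients when the fake images reach the discriminator.

```matlab
% Forward one random sample through both networks to locate the mismatch
dlZ = dlarray(randn(1,1,numLatentInputs,1,'single'),'SSCB'); % latent input
dlT = dlarray(single(1),'SSCB');                             % one class label
dlXGen = forward(dlnetGenerator,dlZ,dlT);
size(dlXGen)  % should be 14x8 spatial with 1 channel
dlYPred = forward(dlnetDiscriminator,dlXGen,dlT);
size(dlYPred)
```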