A saved, trained GAN model for image generation does not generate the same accurate images after the GPU is reset
When I train the flower image generation example, everything goes well as long as the parameters stay loaded in GPU memory: I obtain images of easily recognizable flowers, as shown in the example. However, if I save the complete training workspace with the 'save' command (for example, save('GANWorkspacefile.mat')), which also includes netG, then clear the GPU memory (reset), and subsequently load the saved workspace (load('GANWorkspacefile.mat')), the images generated with 'predict' come out blurry, with no flowers at all, resembling the ones generated at the beginning of training. The same issue occurs when I transfer the saved workspace to another machine with the same version of MATLAB (R2022b) and load it there. It seems that something is missing when the workspace variables are loaded, which prevents the images from being generated the way they are at the end of training. If anyone has an idea of what I'm doing wrong, I would appreciate a comment.
Thank you.
Accepted Answer
Ben
on 8 Apr 2024
Moved: Walter Roberson
on 23 Apr 2024
I believe this is due to a bug in the R2022b version of the custom projectAndReshapeLayer attached to the example. Specifically, the initialize method replaces layer.Weights and layer.Bias with freshly initialized values even when they already hold trained values. Because the initialize method is called when you load the saved generator network, the trained parameters of that layer are overwritten on load.
You can update the initialize method in the custom layer to the following:
function layer = initialize(layer,layout)
    % layer = initialize(layer,layout) initializes the layer
    % learnable parameters.
    %
    % Inputs:
    %     layer  - Layer to initialize
    %     layout - Data layout, specified as a
    %              networkDataLayout object
    %
    % Outputs:
    %     layer - Initialized layer

    % Layer output size.
    outputSize = layer.OutputSize;

    % Initialize fully connected weights only if they are still empty,
    % so that trained values are kept.
    if isempty(layer.Weights)
        % Find number of channels.
        idx = finddim(layout,"C");
        numChannels = layout.Size(idx);

        % Initialize using Glorot.
        sz = [prod(outputSize) numChannels];
        numOut = prod(outputSize);
        numIn = numChannels;
        layer.Weights = initializeGlorot(sz,numOut,numIn);
    end

    % Initialize fully connected bias only if it is still empty.
    if isempty(layer.Bias)
        % Initialize with zeros.
        layer.Bias = initializeZeros([prod(outputSize) 1]);
    end
end
The if isempty(layer.Weights) and if isempty(layer.Bias) checks ensure that the trained projection is not lost on load.
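One way to confirm that the trained parameters now survive a save/load round trip is to compare the generator output before and after reloading. The following is only a minimal sketch. It assumes that netG (the trained dlnetwork) and numLatentInputs are still in the workspace as in the example, that the generator takes "CB"-formatted latent vectors, and that the updated projectAndReshapeLayer is on the path. The file name netG_check.mat is made up for illustration.

```matlab
% Sketch: check that the generator produces the same images after a
% save/load round trip. Assumes netG and numLatentInputs exist in the
% workspace, as in the example. The file name is hypothetical.

% Fixed latent inputs so that both predictions use identical data.
Z = dlarray(randn(numLatentInputs,4,"single"),"CB");

% Images from the generator currently in memory.
XBefore = predict(netG,Z);

% Round trip through a MAT-file; loading is where initialize is called.
save("netG_check.mat","netG");
S = load("netG_check.mat","netG");

% Images from the reloaded generator.
XAfter = predict(S.netG,Z);

% Largest element-wise difference; gather copies GPU data to host memory.
maxDiff = max(abs(gather(extractdata(XBefore)) - gather(extractdata(XAfter))),[],"all");
fprintf("Max abs difference after reload: %g\n",maxDiff);
```

If the trained weights and bias are preserved, the reported difference should be zero or at the level of floating-point noise. A large difference would suggest that some layer is still being re-initialized on load.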