Can you create a CNN with two images as inputs and a 6x1 array as an output?

Matthew Blomquist on 31 Jul 2023
Hello,
I am getting started with MATLAB's Deep Network Designer, but I have not been able to find any applications similar to what I am hoping to do. I have thousands of images from a video containing two 3D objects that are moving relative to one another. My goal for the neural network is to input two consecutive images and output the relative change in motion between the two objects (an array of translations and rotations). Is it possible to create this in MATLAB (i.e., input = two images, output = 6x1 array)?
I have the "ground truth" data for how the two objects move relative to one another, but I'm not sure how to start setting up the CNN, or whether it is even possible.
Thank you very much for your help, I really appreciate it!

Answers (1)

Mrutyunjaya Hiremath on 31 Jul 2023
Here is some sample code to try.
% Load your dataset and ground truth data here.
% Replace 'input_images' and 'ground_truth' with your actual data.
% imageInputLayer accepts only a 2- or 3-element size [height width channels],
% so the two consecutive frames of each pair are stacked along the channel (3rd) dimension:
% 'input_images' should be a 4-D array of size [height x width x 2*channels x numPairs]
% (two grayscale frames give 2 channels, two RGB frames give 6 channels).
% 'ground_truth' should be a 2-D array of size [numPairs x 6]; row k holds the
% 3 translations and 3 rotations between the two frames of pair k.

% Split the data into training (70%), validation (20%) and testing (10%) sets
numPairs = size(input_images, 4);
numTraining = round(0.7 * numPairs);
numValidation = round(0.2 * numPairs);
numTesting = numPairs - numTraining - numValidation;
% For regression with trainNetwork, responses are numObservations x 6 (one row per pair)
trainingData = input_images(:, :, :, 1:numTraining);
trainingLabels = ground_truth(1:numTraining, :);
validationData = input_images(:, :, :, numTraining+1:numTraining+numValidation);
validationLabels = ground_truth(numTraining+1:numTraining+numValidation, :);
testingData = input_images(:, :, :, numTraining+numValidation+1:end);
testingLabels = ground_truth(numTraining+numValidation+1:end, :);

% Create the CNN architecture
layers = [
    imageInputLayer([size(input_images, 1) size(input_images, 2) size(input_images, 3)], 'Name', 'input') % One stacked image pair per observation
    convolution2dLayer(3, 16, 'Padding', 'same', 'Name', 'conv1') % Convolutional layer with 16 filters of size 3x3
    batchNormalizationLayer('Name', 'bn1') % Batch normalization
    reluLayer('Name', 'relu1') % ReLU activation
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool1') % Max pooling
    convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'conv2') % Convolutional layer with 32 filters of size 3x3
    batchNormalizationLayer('Name', 'bn2') % Batch normalization
    reluLayer('Name', 'relu2') % ReLU activation
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool2') % Max pooling
    fullyConnectedLayer(128, 'Name', 'fc1') % Fully connected layer with 128 neurons
    reluLayer('Name', 'relu3') % ReLU activation
    fullyConnectedLayer(6, 'Name', 'fc2') % Output layer with 6 neurons (translations and rotations)
    regressionLayer('Name', 'output') % Regression layer for continuous output
    ];

% Set the training options
options = trainingOptions('adam', ...
    'MaxEpochs', 20, ...
    'MiniBatchSize', 32, ...
    'ValidationData', {validationData, validationLabels}, ...
    'Plots', 'training-progress');

% Create and train the CNN
net = trainNetwork(trainingData, trainingLabels, layers, options);

% Test the CNN on the testing dataset; predictedLabels is numTesting x 6
predictedLabels = predict(net, testingData);

% Mean squared error (MSE) over all 6 outputs, plus the MSE of each output (1 x 6)
mse = mean((predictedLabels - testingLabels).^2, 'all');
msePerOutput = mean((predictedLabels - testingLabels).^2, 1);
disp(['Mean squared error: ' num2str(mse)]);
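Because imageInputLayer only accepts a [height width channels] size, the two frames of a pair have to be merged into one multi-channel image before they reach the network. A minimal sketch of building `input_images` that way, assuming the frames are already in memory as a height x width x channels x numFrames array. The random `frames` array is a hypothetical stand-in for real video data, and the 64x64 size is a placeholder.

```matlab
% Hypothetical stand-in for real video frames: 10 grayscale 64x64 frames,
% stored as height x width x channels x numFrames.
frames = rand(64, 64, 1, 10, 'single');

% Pair each frame with the next one by concatenating along dimension 3
% (the channel dimension). Pair k contains frame k (channel 1) and
% frame k+1 (channel 2).
input_images = cat(3, frames(:, :, :, 1:end-1), frames(:, :, :, 2:end));

% input_images is now 64 x 64 x 2 x 9, so the input layer size is [64 64 2]
inputSize = [size(input_images, 1) size(input_images, 2) size(input_images, 3)];
inputLayer = imageInputLayer(inputSize, 'Name', 'input');
disp(size(input_images))
```

For real footage, the frames could be collected with VideoReader and readFrame (optionally converted with im2gray and im2single) into the same 4-D layout. With this pairing, row k of ground_truth has to describe the motion from frame k to frame k+1, so ground_truth has numFrames - 1 rows.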
1 comment
Matthew Blomquist on 1 Aug 2023
Thank you for the help, Mrutyunjaya Hiremath!
I am currently getting an error while running the imageInputLayer function:
Error using imageInputLayer
imageInputLayer( [size(input_images, 1) size(input_images, 2) size(input_images, 3) 2] )
Invalid argument at position 1. Expected input image size to be a 2 or 3 element vector.
So it looks like I'm not able to input two separate images that way, correct?
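The error arises because imageInputLayer takes a 2- or 3-element size [height width channels], so a single input layer cannot be given a fourth "number of images" dimension. Besides stacking the pair along the channel dimension, another option is a network with two input layers whose feature maps are merged with a concatenationLayer. This sketch only builds the layer graph; the 64x64 grayscale frame size and all layer names are placeholders, not values from this thread.

```matlab
% Placeholder input size: one grayscale 64x64 frame per branch
inputSize = [64 64 1];

% One feature-extraction branch per frame (weights are NOT shared between
% the two branches in this sketch)
branch1 = [
    imageInputLayer(inputSize, 'Name', 'frameA')
    convolution2dLayer(3, 16, 'Padding', 'same', 'Name', 'convA')
    reluLayer('Name', 'reluA')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'poolA')];
branch2 = [
    imageInputLayer(inputSize, 'Name', 'frameB')
    convolution2dLayer(3, 16, 'Padding', 'same', 'Name', 'convB')
    reluLayer('Name', 'reluB')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'poolB')];

% Head: concatenate both feature maps along the channel dimension (3),
% then regress the 6 motion parameters
head = [
    concatenationLayer(3, 2, 'Name', 'concat')
    convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'conv2')
    reluLayer('Name', 'relu2')
    fullyConnectedLayer(128, 'Name', 'fc1')
    reluLayer('Name', 'relu3')
    fullyConnectedLayer(6, 'Name', 'fc2')
    regressionLayer('Name', 'output')];

% Assemble the graph and wire each branch into one concatenation input
lgraph = layerGraph(branch1);
lgraph = addLayers(lgraph, branch2);
lgraph = addLayers(lgraph, head);
lgraph = connectLayers(lgraph, 'poolA', 'concat/in1');
lgraph = connectLayers(lgraph, 'poolB', 'concat/in2');
```

analyzeNetwork(lgraph) can be used to check the layer output sizes. Training a multiple-input network with trainNetwork requires the data to come from a datastore (for example, a combined datastore built with combine from one arrayDatastore per input plus one for the responses) rather than plain numeric arrays. Stacking the frames into channels keeps the simpler single-input trainNetwork call from the answer; the two-branch layout processes each frame separately before fusing them.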


Categories

Deep Learning Toolbox

Version

R2023a
