How to create a channel attention layer in MATLAB?
classdef ChannelAttentionLayer < nnet.layer.Layer
properties
% Reduction ratio used in the channel attention mechanism
ReductionRatio
end
properties (Learnable)
% Layer learnable parameters
Weights1
Bias1
Weights2
Bias2
end
methods
function layer = ChannelAttentionLayer(reduction_ratio, input_channels, name)
% Constructor for ChannelAttentionLayer
layer.Name = name;
layer.ReductionRatio = reduction_ratio;
% Calculate reduced channels based on reduction ratio
reduced_channels = max(1, round(input_channels / reduction_ratio));
% Initialize weights and biases
layer.Weights1 = randn([1, 1, input_channels, reduced_channels], 'single');
layer.Bias1 = zeros([1, 1, reduced_channels], 'single');
layer.Weights2 = randn([1, 1, reduced_channels, input_channels], 'single');
layer.Bias2 = zeros([1, 1, input_channels], 'single');
end
function Z = forward(layer, X)
% Forward pass for training mode
% Ensure X is a dlarray
X = dlarray(X);
% Get the number of input channels
C = size(X, 3);
% Global Average Pooling (GAP)
avg_pool = mean(X, [1, 2]); % Mean over height and width
avg_pool = reshape(avg_pool, [1, 1, C]); % Reshape to [1, 1, Channels]
% Global Max Pooling (GMP)
max_pool = max(X, [], [1, 2]); % Max over height and width
max_pool = reshape(max_pool, [1, 1, C]); % Reshape to [1, 1, Channels]
% First fully connected layer applied to both avg and max pooled outputs
reduced_channels = size(layer.Weights1, 4); % reduced channel count set in the constructor
avg_out = fullyconnect(avg_pool, layer.Weights1, layer.Bias1, C, reduced_channels);
max_out = fullyconnect(max_pool, layer.Weights1, layer.Bias1, C, reduced_channels);
% Apply ReLU
avg_out = relu(avg_out);
max_out = relu(max_out);
% Second fully connected layer
avg_out = fullyconnect(avg_out, layer.Weights2, layer.Bias2, reduced_channels, C);
max_out = fullyconnect(max_out, layer.Weights2, layer.Bias2, reduced_channels, C);
% Combine average and max pooled outputs
Z = avg_out + max_out;
% Apply sigmoid to get attention weights
Z = sigmoid(Z);
% Reshape attention map and multiply with input
Z = reshape(Z, [1, 1, C]);
Z = X .* Z;
% Z is already an unformatted dlarray, so no further conversion is needed
end
function Z = predict(layer, X)
% Predict pass for inference mode
Z = forward(layer, X);
end
end
end
% Fully connected operation for 1x1 conv
function out = fullyconnect(input, weights, bias, input_channels, output_channels)
% Ensure the number of input channels matches the weights' channels
[H, W, C_in] = size(input);
[~, ~, C, ~] = size(weights);
if C_in ~= C
error('Number of channels in input and weights do not match.');
end
% Flatten input dimensions
input_reshaped = reshape(input, [], C_in); % Flatten spatial dimensions
% Perform matrix multiplication and add bias
weights_reshaped = reshape(weights, [C_in, output_channels]);
out = input_reshaped * weights_reshaped + reshape(bias, [1, output_channels]);
% Reshape back to original dimensions
out = reshape(out, [1, 1, output_channels]);
end
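A minimal sketch (not from the original post) showing how this layer could be exercised on a single observation; the 28-by-28 spatial size, 16 input channels, reduction ratio of 8, and the layer name are illustrative assumptions:
% Construct the layer: reduction ratio 8, 16 input channels (assumed values)
layer = ChannelAttentionLayer(8, 16, "chan_att");
% Random single observation, height x width x channels
X = dlarray(rand(28, 28, 16, 'single'));
% Run the layer; the output should have the same size as the input
Z = predict(layer, X);
size(Z)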
Answers (2)
Wilfrido Gomez-Flores
about 9 hours ago
classdef CBAMLayer < nnet.layer.Layer
properties
% Reduction ratio used in the channel attention mechanism
ReductionRatio
InputChannels
ReducedChannels
filterSize
end
properties (Learnable)
% Layer learnable parameters
Weights1
Bias1
Weights2
Bias2
Weights3
Bias3
end
methods
function layer = CBAMLayer(reduction_ratio, filter_size, name)
% Constructor for CBAMLayer
layer.Name = name;
layer.ReductionRatio = reduction_ratio;
layer.filterSize = filter_size;
end
function layer = initialize(layer,layout)
% Initialize learnable parameters from the networkDataLayout object
% supplied when the network is initialized
reduction_ratio = layer.ReductionRatio;
idx = finddim(layout,"C");
input_channels = layout.Size(idx);
% Calculate reduced channels based on reduction ratio
reduced_channels = max(1, round(input_channels / reduction_ratio));
layer.InputChannels = input_channels;
layer.ReducedChannels = reduced_channels;
% Initialize shared MLP weights and biases
if isempty(layer.Weights1)
layer.Weights1 = randn([1, 1, input_channels, reduced_channels], 'single');
end
if isempty(layer.Bias1)
layer.Bias1 = zeros([1, 1, reduced_channels], 'single');
end
if isempty(layer.Weights2)
layer.Weights2 = randn([1, 1, reduced_channels, input_channels], 'single');
end
if isempty(layer.Bias2)
layer.Bias2 = zeros([1, 1, input_channels], 'single');
end
filter_size = layer.filterSize;
% Initialize convolutional weights and biases
if isempty(layer.Weights3)
layer.Weights3 = randn([filter_size, filter_size, 2, 1], 'single');
end
if isempty(layer.Bias3)
layer.Bias3 = zeros([1, 1], 'single');
end
end
function Z = forward(layer, X)
% Forward pass for training mode
Z = ChannelAttentionModule(X,layer);
Z = SpatialAttention(Z,layer);
% Residual connection around the attention block
Z = X + Z;
end
function Z = predict(layer, X)
% Predict pass for inference mode
Z = forward(layer, X);
end
end
end
function Z = ChannelAttentionModule(X,layer)
% Channel attention: pool over the spatial dimensions, pass the pooled
% vectors through a shared two-layer MLP, and rescale the input channels
X = dlarray(X);
% Get minibatch size
B = size(X,4);
% Global Average Pooling (GAP)
avg_pool = mean(X, [1, 2]);
% Global Max Pooling (GMP)
max_pool = max(X, [], [1, 2]);
% First fully connected layer applied to both avg and max pooled outputs
avg_out = fullyconnect(avg_pool, layer.Weights1, layer.Bias1, layer.InputChannels, layer.ReducedChannels, B);
max_out = fullyconnect(max_pool, layer.Weights1, layer.Bias1, layer.InputChannels, layer.ReducedChannels, B);
% Apply ReLU
avg_out = relu(avg_out);
max_out = relu(max_out);
% Second fully connected layer
avg_out = fullyconnect(avg_out, layer.Weights2, layer.Bias2, layer.ReducedChannels, layer.InputChannels, B);
max_out = fullyconnect(max_out, layer.Weights2, layer.Bias2, layer.ReducedChannels, layer.InputChannels, B);
% Combine average and max pooled outputs
Z = avg_out + max_out;
% Apply sigmoid to get attention weights
Z = sigmoid(Z);
% Multiply with input
Z = X .* Z;
Z = dlarray(Z);
end
function Z = SpatialAttention(X,layer)
% Spatial attention: pool across channels, convolve the two-channel map,
% and rescale the input spatially
% Ensure X is a dlarray
X = dlarray(X);
% Average pooling across the channel dimension
avg_pool = mean(X, 3);
% Max pooling across the channel dimension
max_pool = max(X, [], 3);
% Concatenate
Y = cat(3,avg_pool,max_pool);
% Convolution
Z = dlconv(Y,layer.Weights3,layer.Bias3,DataFormat="SSCB",Padding="same");
% Apply sigmoid to get attention weights
Z = sigmoid(Z);
% Multiply with input
Z = X .* Z;
Z = dlarray(Z);
end
% Fully connected operation for 1x1 conv
function out = fullyconnect(input, weights, bias, input_channels, output_channels, batch)
if size(input,3) ~= size(weights,3)
error('Number of channels in input and weights do not match.');
end
input_reshaped = reshape(input, input_channels, batch)';
weights_reshaped = reshape(weights, [input_channels, output_channels]);
bias_reshaped = reshape(bias, [1, output_channels]);
out = input_reshaped * weights_reshaped + bias_reshaped;
% Transpose to [output_channels, batch] before reshaping so that the
% channel and batch dimensions are not mixed up
out = reshape(out.', [1, 1, output_channels, batch]);
end
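A minimal usage sketch for the CBAM layer above (not part of the original answer); the surrounding layers, image size, class count, and layer names are illustrative assumptions:
% Small classification backbone with a CBAM block after the first conv stage
layers = [
    imageInputLayer([32 32 3], Normalization="none", Name="in")
    convolution2dLayer(3, 16, Padding="same", Name="conv1")
    reluLayer(Name="relu1")
    CBAMLayer(8, 7, "cbam") % reduction ratio 8, 7x7 spatial-attention filter (assumed)
    globalAveragePooling2dLayer(Name="gap")
    fullyConnectedLayer(10, Name="fc")
    softmaxLayer(Name="sm")];
% Creating the dlnetwork calls the layer's initialize method with the data layout
net = dlnetwork(layers);
% Forward pass on a random mini-batch of 4 images
X = dlarray(rand(32, 32, 3, 4, 'single'), "SSCB");
Y = predict(net, X); % 10-by-4 class scores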