MATLAB Examples

CNNによるノイズ除去(縦線・横線)

カスタムミニバッチデータストアを使ってノイズ画像生成

Contents

初期化

clear; close all force; clc; rng('default');

フォルダから数字画像の読み込み

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos',...
    'nndatasets','DigitDataset');
digitData = imageDatastore(digitDatasetPath,...
    'IncludeSubfolders',true,'LabelSource','foldernames');

% シャッフル
digitData = shuffle(digitData);

% 20枚を表示
figure;
for k = 1:20
    subplot(4,5,k);
    imshow(digitData.Files{k});
end

学習データとテストデータの分離

trainNumFiles = 450;
testNumFiles = 50;
[trainDigitData,testDigitData] = splitEachLabel(digitData,trainNumFiles,...
    testNumFiles,'randomize');

劣化データ生成のためのカスタムミニバッチストア

miniBatchSize = 64;

% カスタムミニバッチストアを呼び出し
trainSource = RowColNoiseImageDatastore(trainDigitData,...
    'MiniBatchSize',miniBatchSize,...
    'BatchesPerImage',1);

% 劣化画像と残差を生成
inputBatch = read(trainSource);
summary(inputBatch)

% 可視化
figure, subplot(3,1,1);
imshow(imread(trainDigitData.Files{2}));
subplot(3,1,2);
imshow(inputBatch.noisyPatches{2});
subplot(3,1,3);
imshow(inputBatch.noiseComponents{2});

figure;
for k = 1:5
    subplot(5,3,3*(k-1)+1);
    imshow(trainDigitData.Files{k});
    title('原画像');
    subplot(5,3,3*(k-1)+2);
    imshow(inputBatch.noisyPatches{k});
    title('ノイズ付加画像');
    subplot(5,3,3*(k-1)+3);
    imshow(inputBatch.noiseComponents{k});
    title('差分画像(ノイズ成分)');
end
Variables:
    noisyPatches: 64×1 cell
    noiseComponents: 64×1 cell

テスト用のカスタムミニバッチストアを定義

testSource = RowColNoiseImageDatastore(testDigitData,...
    'MiniBatchSize',miniBatchSize,...
    'BatchesPerImage',1);

レイヤーの定義

layers = [imageInputLayer([28 28 1], 'Normalization','none','Name','input');          % 入力画像サイズ:28x28x1、入力で明るさの正規化なし  3次元目は1もしくは3
    convolution2dLayer(3,32,'Padding',[1 1 1 1],'Name','conv1');                % 3x3x1のフィルタを32セット(マップ) (出力:28x28x32) パディングあり
    reluLayer('Name','relu1');                                                  % ReLU(Rectified Linear Unit)活性化関数層
    maxPooling2dLayer(2,'Stride',2,'Name','mpool1','HasUnpoolingOutputs',true); % max pooling層:2x2の領域内の最大値を出力  (出力:14x14x32) 領域内の平均移動への対応
    convolution2dLayer(3,32,'Padding',[1 1 1 1],'Name','conv2');                % 3x531のフィルタを32セット(マップ) (出力:14x14x32) パディングあり
    reluLayer('Name','relu2');                                                  % ReLU(Rectified Linear Unit)活性化関数層
    maxUnpooling2dLayer('Name','upool1');                                       % max 逆pooling層:max pooling層で削除した次元に戻す
    convolution2dLayer(3,32,'Padding',[1 1 1 1],'Name','conv3');                % 3x3x1のフィルタを32セット(マップ) (出力:28x28x32) パディングあり
    reluLayer('Name','relu3');                                                  % ReLU(Rectified Linear Unit)活性化関数層
    convolution2dLayer(3,1,'Padding',[1 1 1 1],'Name','conv4');                 % 3x3x1のフィルタを1セット(マップ) (出力:28x28x1) パディングあり
    regressionLayer('Name','routput');                                          % 回帰層:RMSE
    ];

lgraph = layerGraph(layers);
% 最大逆プーリング層の追加のため、最大プーリング層で削ったインデックスとサイズの情報を接続
lgraph = connectLayers(lgraph,'mpool1/indices','upool1/indices');
lgraph = connectLayers(lgraph,'mpool1/size','upool1/size');
figure, plot(lgraph);

学習用オプションを、trainingOptions関数を用い設定

options = trainingOptions('adam', 'MaxEpochs',15,... % 最大15世代まで学習
    'Plots','training-progress',...
    'ValidationData',testSource);

[学習] ネットワークを教師付き学習

%(SeriesNetwork クラスのオブジェクトが学習後に生成される)

% 学習過程の画像も可視化するためにtrainNetworkをカスタマイズしたものを使用
if exist('trainNetwork_TMP.m','file')
    net = trainNetwork_TMP(trainSource, lgraph, options);
else
    net = trainNetwork(trainSource, lgraph, options);
end
if ~exist('trainednet_rowcol_minibatch.mat','file')
    save('trainednet_rowcol_minibatch','net');
end
% 学習済みのネットワークを読み込み
load('trainednet_rowcol_minibatch.mat');
単一の CPU で学習中。
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |     RMSE     |     RMSE     |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:02 |         6.95 |         6.96 |      24.1678 |      24.2410 |          0.0010 |
|       2 |          50 |       00:01:06 |         4.44 |         4.33 |       9.8621 |       9.3752 |          0.0010 |
|       3 |         100 |       00:02:12 |         1.86 |         1.85 |       1.7227 |       1.7118 |          0.0010 |
|       5 |         150 |       00:03:18 |         1.47 |         1.49 |       1.0803 |       1.1140 |          0.0010 |
|       6 |         200 |       00:04:25 |         1.22 |         1.26 |       0.7405 |       0.7954 |          0.0010 |
|       8 |         250 |       00:05:35 |         1.11 |         1.11 |       0.6129 |       0.6122 |          0.0010 |
|       9 |         300 |       00:06:45 |         0.98 |         1.03 |       0.4829 |       0.5256 |          0.0010 |
|      10 |         350 |       00:07:50 |         0.95 |         0.95 |       0.4516 |       0.4505 |          0.0010 |
|      12 |         400 |       00:08:55 |         0.88 |         0.88 |       0.3901 |       0.3885 |          0.0010 |
|      13 |         450 |       00:09:58 |         0.90 |         0.88 |       0.4008 |       0.3911 |          0.0010 |
|      15 |         500 |       00:11:06 |         0.84 |         0.84 |       0.3534 |       0.3535 |          0.0010 |
|      15 |         525 |       00:11:43 |         0.84 |         0.83 |       0.3557 |       0.3460 |          0.0010 |
|======================================================================================================================|

テスト画像を使った評価

% 学習したネットワークでノイズ除去
reset(testSource);
testSource.MiniBatchSize = 10;
noisyImages = read(testSource);
predI = predict(net,noisyImages);

figure;
for k = 1:1
    subplot(1,4,4*(k-1)+1);
    imshow(testDigitData.Files{k});
    title('原画像');
    subplot(1,4,4*(k-1)+2);
    imshow(noisyImages.noisyPatches{k});
    title('ノイズ付加画像');
    subplot(1,4,4*(k-1)+3);
    imshow(predI(:,:,1,k));
    title('ノイズ推定画像');
    subplot(1,4,4*(k-1)+4);
    denoisedImage = noisyImages.noisyPatches{k} - predI(:,:,1,k);
    imshow(denoisedImage);
    title('ノイズ除去画像');
end

Copyright 2018 The MathWorks, Inc.