MATLAB Examples

CNNによるノイズ除去(縦線・横線)

Contents

初期化

clear; close all; clc

数字画像(5000枚)の読み込み

X = digitTrainCellArrayData;
% 64枚を表示
montage(X(1:64))

学習用・テスト画像生成

Test = X(4501:5000); % テスト画像群(500枚)
X = X(1:4500);       % 学習用画像群(4500枚)

% trainNetworkの仕様に合わせるため学習データは4-Dに変換
numOfSamples = size(X,2); % 画像サンプル数
imageSize = size(X{1});   % 画像サイズ
%XTrain = zeros([imageSize,1,numOfSamples]);
for i = 1:numOfSamples
    % 教師画像の作成
    XResponse(:,:,1,i) = X{i}; %#ok<SAGROW>
    % ノイズ画像の作成
    randrow = randi(imageSize(1));
    randcol = randi(imageSize(2));
    XTrain(:,:,1,i) = X{i}; %#ok<SAGROW>
    XTrain(randrow,:,1,i) = 1; %#ok<SAGROW>
    XTrain(:,randcol,1,i) = 1; %#ok<SAGROW>
end

ノイズの画像と元画像の表示

imshowpair(XTrain(:,:,1,1),XResponse(:,:,1,1),'montage')

学習画像数の確認

size(XTrain)
ans =
          28          28           1        4500

レイヤーの定義

layers = [imageInputLayer([28 28 1], 'Normalization','none','Name','input');          % 入力画像サイズ:28x28x1、入力で明るさの正規化なし  3次元目は1もしくは3
    convolution2dLayer(3,32,'Padding',[1 1 1 1],'Name','conv1');                % 3x3x1のフィルタを32セット(マップ) (出力:28x28x32) パディングあり
    reluLayer('Name','relu1');                                                  % ReLU(Rectified Linear Unit)活性化関数層
    maxPooling2dLayer(2,'Stride',2,'Name','mpool1','HasUnpoolingOutputs',true); % max pooling層:2x2の領域内の最大値を出力  (出力:14x14x32) 領域内の平均移動への対応
    convolution2dLayer(3,32,'Padding',[1 1 1 1],'Name','conv2');                % 3x531のフィルタを32セット(マップ) (出力:14x14x32) パディングあり
    reluLayer('Name','relu2');                                                  % ReLU(Rectified Linear Unit)活性化関数層
    maxUnpooling2dLayer('Name','upool1');                                       % max 逆pooling層:max pooling層で削除した次元に戻す
    convolution2dLayer(3,32,'Padding',[1 1 1 1],'Name','conv3');                % 3x3x1のフィルタを32セット(マップ) (出力:28x28x32) パディングあり
    reluLayer('Name','relu3');                                                  % ReLU(Rectified Linear Unit)活性化関数層
    convolution2dLayer(3,1,'Padding',[1 1 1 1],'Name','conv4');                 % 3x3x1のフィルタを1セット(マップ) (出力:28x28x1) パディングあり
    regressionLayer('Name','routput');                                          % 回帰層:RMSE
    ];

lgraph = layerGraph(layers);
% 最大逆プーリング層の追加のため、最大プーリング層で削ったインデックスとサイズの情報を接続
lgraph = connectLayers(lgraph,'mpool1/indices','upool1/indices');
lgraph = connectLayers(lgraph,'mpool1/size','upool1/size');
figure, plot(lgraph);

学習用オプションを、trainingOptions関数を用い設定

options = trainingOptions('adam', 'MaxEpochs',15,...
    'Plots','training-progress'); % 最大15世代まで学習

[学習] ネットワークを教師付き学習(SeriesNetwork クラスのオブジェクトが学習後に生成される)

% 学習過程の画像も可視化するためにtrainNetworkをカスタマイズしたものを使用
if exist('trainNetwork_TMP.m','file')
    net = trainNetwork_TMP(XTrain, XResponse, lgraph, options);
else
    net = trainNetwork(XTrain, XResponse, lgraph, options);
end
% 学習済みのネットワークを読み込み
%load('trainednet_rowcol.mat');
単一の CPU で学習中。
|========================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (hh:mm:ss)   |     RMSE     |     Loss     |      Rate       |
|========================================================================================|
|       1 |           1 |       00:00:01 |         6.76 |         22.8 |          0.0010 |
|       2 |          50 |       00:00:30 |         3.88 |          7.5 |          0.0010 |
|       3 |         100 |       00:00:58 |         3.39 |          5.8 |          0.0010 |
|       5 |         150 |       00:01:26 |         2.95 |          4.3 |          0.0010 |
|       6 |         200 |       00:01:57 |         2.70 |          3.6 |          0.0010 |
|       8 |         250 |       00:02:28 |         2.55 |          3.3 |          0.0010 |
|       9 |         300 |       00:02:58 |         2.44 |          3.0 |          0.0010 |
|      10 |         350 |       00:03:26 |         2.26 |          2.6 |          0.0010 |
|      12 |         400 |       00:03:53 |         2.03 |          2.1 |          0.0010 |
|      13 |         450 |       00:04:21 |         1.97 |          1.9 |          0.0010 |
|      15 |         500 |       00:04:58 |         1.70 |          1.5 |          0.0010 |
|      15 |         525 |       00:05:17 |         1.68 |          1.4 |          0.0010 |
|========================================================================================|

テスト画像の生成と評価

% ノイズ画像の作成
randrow = randi(imageSize(1));
randcol = randi(imageSize(2));
testI = Test{randi(500)}; % テスト画像群から1枚を選択
testI(randrow,:) = 1; % 横線の挿入
testI(:,randcol) = 1; % 縦線の挿入

% 学習したネットワークでノイズ除去
predI = predict(net,testI);
imshow([testI,predI(:,:,1,1)])

Copyright 2018 The MathWorks, Inc.