MATLAB Examples

CNNによるゴマ塩ノイズ除去

Contents

初期化

clear; close all force; clc

学習用画像生成

X = digitTrainCellArrayData;

% trainNetworkの仕様に合わせるため学習データは4-Dに変換
numOfSamples = size(X,2); % 画像サンプル数
imageSize = size(X{1});   % 画像サイズ
%XTrain = zeros([imageSize,1,numOfSamples]);
for i = 1:numOfSamples
    % 教師画像の作成
    XResponse(:,:,1,i) = X{i}; %#ok<SAGROW>
    % ノイズ画像の作成
    XTrain(:,:,1,i) = imnoise(X{i},'salt & pepper',0.1);  %#ok<SAGROW> % 入力は28x28x1x5000
end

% 学習画像数の確認
size(XTrain)
ans =
          28          28           1        5000

レイヤーの定義

layers = [imageInputLayer([28 28 1], 'Normalization','none','Name','input');          % 入力画像サイズ:28x28x1、入力で明るさの正規化なし  3次元目は1もしくは3
    convolution2dLayer(3,32,'Padding',[1 1 1 1],'Name','conv1');                % 3x3x1のフィルタを32セット(マップ) (出力:28x28x32) パディングあり
    reluLayer('Name','relu1');                                                  % ReLU(Rectified Linear Unit)活性化関数層
    maxPooling2dLayer(2,'Stride',2,'Name','mpool1','HasUnpoolingOutputs',true); % max pooling層:2x2の領域内の最大値を出力  (出力:14x14x32) 領域内の平均移動への対応
    convolution2dLayer(3,32,'Padding',[1 1 1 1],'Name','conv2');                % 3x531のフィルタを32セット(マップ) (出力:14x14x32) パディングあり
    reluLayer('Name','relu2');                                                  % ReLU(Rectified Linear Unit)活性化関数層
    maxUnpooling2dLayer('Name','upool1');                                       % max 逆pooling層:max pooling層で削除した次元に戻す
    convolution2dLayer(3,32,'Padding',[1 1 1 1],'Name','conv3');                % 3x3x1のフィルタを32セット(マップ) (出力:28x28x32) パディングあり
    reluLayer('Name','relu3');                                                  % ReLU(Rectified Linear Unit)活性化関数層
    convolution2dLayer(3,1,'Padding',[1 1 1 1],'Name','conv4');                 % 3x3x1のフィルタを1セット(マップ) (出力:28x28x1) パディングあり
    regressionLayer('Name','routput');                                          % 回帰層:RMSE
    ];

lgraph = layerGraph(layers);
% 最大逆プーリング層の追加のため、最大プーリング層で削ったインデックスとサイズの情報を接続
lgraph = connectLayers(lgraph,'mpool1/indices','upool1/indices');
lgraph = connectLayers(lgraph,'mpool1/size','upool1/size');

学習用オプションを、trainingOptions関数を用い設定

   必要であれば、GPUのメモリに応じてMiniBatchSizeを調整
options = trainingOptions('adam', 'MaxEpochs',15,'Plots','training-progress'); % 最大15世代まで学習

[学習] ネットワークを教師付き学習(SeriesNetwork クラスのオブジェクトが学習後に生成される)

if exist('trainNetwork_TMP.m','file')
    net = trainNetwork_TMP(XTrain, XResponse, lgraph, options);
else
    net = trainNetwork(XTrain, XResponse, lgraph, options);
end
%load('trainednet.mat');
単一の CPU で学習中。
|========================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (hh:mm:ss)   |     RMSE     |     Loss     |      Rate       |
|========================================================================================|
|       1 |           1 |       00:00:01 |         7.01 |         24.6 |          0.0010 |
|       2 |          50 |       00:00:29 |         3.32 |          5.5 |          0.0010 |
|       3 |         100 |       00:01:00 |         2.81 |          4.0 |          0.0010 |
|       4 |         150 |       00:01:26 |         2.66 |          3.5 |          0.0010 |
|       6 |         200 |       00:01:52 |         2.58 |          3.3 |          0.0010 |
|       7 |         250 |       00:02:18 |         2.43 |          3.0 |          0.0010 |
|       8 |         300 |       00:02:44 |         2.36 |          2.8 |          0.0010 |
|       9 |         350 |       00:03:09 |         2.30 |          2.6 |          0.0010 |
|      11 |         400 |       00:03:35 |         2.29 |          2.6 |          0.0010 |
|      12 |         450 |       00:03:58 |         2.21 |          2.4 |          0.0010 |
|      13 |         500 |       00:04:23 |         2.25 |          2.5 |          0.0010 |
|      15 |         550 |       00:05:01 |         2.23 |          2.5 |          0.0010 |
|      15 |         585 |       00:05:21 |         2.18 |          2.4 |          0.0010 |
|========================================================================================|

テスト画像の生成と評価

testI = imnoise(X{randi(5000)},'salt & pepper',0.1);
predI = predict(net,testI);
imshow([testI,predI(:,:,1,1)])

Copyright 2018 The MathWorks, Inc.