Train OCR model to recognize text in image
modelFileName = trainOCR(trainingData,modelName,baseModel,ocrOptions) trains a new OCR model by fine-tuning a pretrained base model using the hyperparameters specified in ocrOptions.
[modelFileName,info] = trainOCR(___) also returns a structure that contains information on training progress, such as the training root mean squared error (RMSE) and learning rate for each iteration, using the input arguments from the previous syntax. For a list of the returned error rates, see the info output argument.
[modelFileName,info] = trainOCR(trainingData,checkpoint,ocrOptions) resumes training from an OCR training checkpoint. Use this syntax to improve the accuracy of your OCR model by using additional training data or to perform more training.
Train OCR Model
This example shows how to train an OCR model that can recognize fourteen-segment characters. The training data contains word samples of fourteen-segment characters.
Unzip and extract training images.
datasetZip = 'dseg14.zip';
evalfiles = unzip(datasetZip);
The training images were annotated in the Image Labeler app: bounding boxes were drawn around the words, and the text of each word was added to its bounding box as an attribute. The labels were exported from the app as a groundTruth object and saved in the dseg14Gtruth.mat file.
ld = load("dseg14Gtruth.mat");
gTruth = ld.gTruth;
Create datastores that contain the images, bounding boxes, and text labels from the groundTruth object by using the ocrTrainingData function with the label and attribute names used during labeling.
labelName = "Text";
attributeName = "Word";
[imds,boxds,txtds] = ocrTrainingData(gTruth,labelName,attributeName);
Combine the datastores.
cds = combine(imds,boxds,txtds);
Split the data into training and validation sets, using about 90% of the samples for training and the remaining samples for validation.
% Set the random number seed for reproducibility.
rng(0);
% Compute the number of training samples.
trainingToValidationRatio = 0.9;
numSamples = height(ld.gTruth.LabelData);
numTrainSamples = ceil(trainingToValidationRatio*numSamples);
% Divide the dataset into training and validation.
indices = randperm(numSamples);
trainIndices = indices(1:numTrainSamples);
validationIndices = indices(numTrainSamples+1:end);
cdsTrain = subset(cds, trainIndices);
cdsValidation = subset(cds, validationIndices);
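The split above can be checked with a short sketch. It reproduces the same index logic for a hypothetical dataset of 10 samples (this numSamples value is an assumption, not read from dseg14Gtruth.mat) and verifies that the training and validation indices are disjoint and together cover every sample.

```matlab
% Sketch: reproduce the ~90/10 index split for a hypothetical dataset size
% and check that the two index sets partition 1..numSamples.
rng(0);                               % same seed as the example
numSamples = 10;                      % hypothetical number of labeled images
trainingToValidationRatio = 0.9;      % fraction of samples used for training
numTrainSamples = ceil(trainingToValidationRatio*numSamples);

indices = randperm(numSamples);
trainIndices = indices(1:numTrainSamples);
validationIndices = indices(numTrainSamples+1:end);

% No sample may appear in both sets, and no sample may be dropped.
assert(isempty(intersect(trainIndices, validationIndices)));
assert(isequal(sort([trainIndices validationIndices]), 1:numSamples));
fprintf("%d training / %d validation samples\n", ...
    numel(trainIndices), numel(validationIndices));
```

Because numTrainSamples is rounded up with ceil, very small datasets can end up with no validation samples at all; for example, numSamples = 5 gives ceil(4.5) = 5 training samples and an empty validation set.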
Specify the training options. Set the gradient decay factor for Adam optimization to 0.9 and the initial learning rate to 40e-4 (0.004). Set the verbose frequency to 160 iterations and the maximum number of training epochs to 5. Specify a checkpoint path to save checkpoints during training, validation data to enable validation, and an output location for the trained model.
outputDir = "OCRModel";
if ~exist(outputDir, "dir")
    mkdir(outputDir);
end
checkpointsDir = "Checkpoints";
if ~exist(checkpointsDir, "dir")
    mkdir(checkpointsDir);
end
ocrOptions = ocrTrainingOptions(GradientDecayFactor=0.9, ...
    InitialLearnRate=40e-4, MaxEpochs=5, VerboseFrequency=160, ...
    CheckpointPath=checkpointsDir, ValidationData=cdsValidation, ...
    OutputLocation=outputDir);
Train a new OCR model by fine-tuning the pretrained "english" model. Training takes several minutes; the exact time depends on your hardware.
outputModelName = "fourteenSegment";
baseModel = "english";
outputModel = trainOCR(cdsTrain, outputModelName, baseModel, ocrOptions);
*************************************************************************
Starting OCR training

Model Name: fourteenSegment
Base Model: english

Preparing training data... 100.00 % completed.
Preparing validation data... 100.00 % completed.

Character Set: +,-./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ

|==================================================================================================================================|
| Epoch | Iteration | Time Elapsed |          Training Statistics          |         Validation Statistics         | Base Learning |
|       |           |  (hh:mm:ss)  |  RMSE  | Character Error | Word Error |  RMSE  | Character Error | Word Error |     Rate      |
|==================================================================================================================================|
|   1   |     1     |   00:02:10   |  9.51  |     100.00      |   100.00   |  0.00  |      0.00       |    0.00    |    0.0040     |
|   1   |    160    |   00:02:56   |  2.43  |      18.46      |   38.12    |  1.69  |      14.19      |   27.78    |    0.0040     |
|   2   |    320    |   00:03:42   |  1.47  |      10.22      |   21.88    |  0.75  |      6.35       |   11.11    |    0.0040     |
|   3   |    480    |   00:04:27   |  1.06  |      6.94       |   15.00    |  0.49  |      5.56       |    5.56    |    0.0040     |
|   4   |    640    |   00:05:12   |  0.86  |      5.30       |   11.72    |  0.68  |      5.56       |    5.56    |    0.0040     |
|   5   |    800    |   00:05:57   |  0.73  |      4.25       |    9.50    |  0.49  |      6.35       |   11.11    |    0.0040     |
|   5   |    845    |   00:06:08   |  0.70  |      4.06       |    9.11    |  0.49  |      6.35       |   11.11    |    0.0040     |
|==================================================================================================================================|

OCR training complete.
Exit condition: Reached maximum epochs.
Model file name: OCRModel/fourteenSegment.traineddata
*************************************************************************
Use the trained model to perform OCR on a test image and visualize the results.
I = imread("DSEG14.png");
ocrResults = ocr(I,Language=outputModel);
Iocr = insertObjectAnnotation(I,"rectangle",...
    ocrResults.WordBoundingBoxes,ocrResults.Words,...
    LineWidth=2,FontSize=17);
imshow(Iocr)
trainingData — Ground truth data
Ground truth data, specified as a datastore that returns a cell array or a table when input to the read function. The cell array or table must contain at least these three columns:
1st column: A cell vector of logical, grayscale, or RGB images.
2nd column: A cell vector in which each cell corresponds to an image and contains an M-by-4 matrix. M is the number of bounding boxes in the image, and each row of the matrix specifies a bounding box in the form [x,y,width,height].
3rd column: A cell vector in which each cell corresponds to an image and contains N strings. N is the number of lines of text in the image, and each line must contain only text, without newline characters.
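The column layout above can be illustrated with one hand-built sample. This is a sketch with made-up values (a blank image, one box, one label); it checks each column against the stated requirements and shows how a box in the form [x,y,width,height] maps onto image rows and columns, assuming integer pixel coordinates.

```matlab
% Sketch: one hypothetical training sample in the {image, boxes, text} layout.
I = zeros(40, 200, 'uint8');          % grayscale image, 40 rows x 200 columns
boxes = [10 5 120 30];                % M-by-4, one box as [x,y,width,height]
txt = "HELLO WORLD";                  % one line of text, no newline
sample = {I, boxes, txt};

% Checks mirroring the three column requirements.
img = sample{1};
assert(ismatrix(img) || (ndims(img) == 3 && size(img, 3) == 3), ...
    'Column 1 must hold a logical, grayscale, or RGB image.');
assert(size(sample{2}, 2) == 4, ...
    'Column 2 must hold an M-by-4 matrix of boxes.');
assert(~any(contains(string(sample{3}), newline)), ...
    'Column 3 text must not contain newline characters.');

% x and width index columns; y and height index rows.
b = sample{2}(1, :);
rows = b(2) : b(2) + b(4) - 1;
cols = b(1) : b(1) + b(3) - 1;
patch = img(rows, cols);
disp(size(patch))                     % [height width] of the boxed region
```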
modelName — New model name
string scalar | character vector
New model name, specified as a string scalar or character vector. If the output folder already contains a file with the name specified by the modelName argument, then the trainOCR function overwrites it during training.
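Because an existing model file with the same name is overwritten, you might check for it before training. This sketch assumes the trained file is named <modelName>.traineddata inside the output folder, as in the example log (OCRModel/fourteenSegment.traineddata); the timestamp suffix is a hypothetical naming convention, and a temporary folder stands in for the real output location.

```matlab
% Sketch: choose a model name that does not overwrite an existing model file.
outputDir = tempname;                 % stand-in for the OutputLocation folder
mkdir(outputDir);
outputModelName = "fourteenSegment";

% Simulate a model file left over from an earlier run.
existingFile = fullfile(outputDir, outputModelName + ".traineddata");
fid = fopen(existingFile, 'w');
fclose(fid);

% Assumption: trainOCR writes <modelName>.traineddata into the output folder.
if isfile(fullfile(outputDir, outputModelName + ".traineddata"))
    stamp = string(datetime('now', 'Format', 'yyyyMMdd_HHmmss'));
    outputModelName = outputModelName + "_" + stamp;   % hypothetical suffix
end
fprintf("Using model name: %s\n", outputModelName);
```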
baseModel — Pretrained base model
string scalar | character vector
Pretrained base model, specified as a string scalar or character vector. You can specify any of these options:
A language model shipped in the Computer Vision Toolbox™, such as "english". Specify one of the supported languages described in the Model argument of the ocr function. You cannot use quantized models, such as "japanese-fast", as base models.
The full path to a custom trained model with a .traineddata file extension.
ocrOptions — Hyperparameters for training
Hyperparameters for training, specified as an ocrTrainingOptions object.
checkpoint — OCR training checkpoint
string scalar | character vector
OCR training checkpoint, specified as a string scalar or character vector. You must specify the path to a saved checkpoint file, such as a file saved in the folder specified by the CheckpointPath property of an ocrTrainingOptions object. When you specify a value for the CheckpointPath property of an ocrTrainingOptions object, the trainOCR function saves checkpoints at regular intervals during training. You can resume training from any one of these saved checkpoints.
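To resume training, you need the path of one saved checkpoint file. The following sketch picks the most recently modified file in a checkpoint folder by using dir. It assumes the folder contains only checkpoint files (no file extension filter is applied), and it uses a temporary folder with placeholder files so that it runs on its own. The final trainOCR call is left as a comment and assumes the resume syntax takes the training data, the checkpoint, and the options, in that order.

```matlab
% Sketch: select the most recently modified file in a checkpoint folder.
checkpointsDir = tempname;            % stand-in for ocrOptions.CheckpointPath
mkdir(checkpointsDir);

% Create two placeholder files with different modification times.
fid = fopen(fullfile(checkpointsDir, 'older_placeholder'), 'w');
fclose(fid);
pause(2);
fid = fopen(fullfile(checkpointsDir, 'newer_placeholder'), 'w');
fclose(fid);

files = dir(checkpointsDir);
files = files(~[files.isdir]);        % drop the "." and ".." entries
if isempty(files)
    error('No checkpoint files found in %s.', checkpointsDir);
end
[~, newest] = max([files.datenum]);   % latest modification time
latestCheckpoint = fullfile(files(newest).folder, files(newest).name);
disp(latestCheckpoint)

% With a real checkpoint (assumed argument order):
% [modelFileName, info] = trainOCR(cdsTrain, latestCheckpoint, ocrOptions);
```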
modelFileName — Model filename
Model filename, returned as a string scalar.
info — Information on training progress
Information on training progress, returned as a structure containing these fields:
BaseLearnRate: Learning rate at each iteration.
TrainingRMSE: Training RMSE at each iteration.
TrainingCharError: Training character error rate at each iteration.
TrainingWordError: Training word error rate at each iteration.
ValidationRMSE: Validation RMSE at each iteration.
ValidationCharError: Validation character error rate at each iteration.
ValidationWordError: Validation word error rate at each iteration.
FinalValidationRMSE: Final validation RMSE at the end of training.
OutputModelIteration: Iteration number of the returned model.
If you do not specify validation data by using the ValidationData property of the ocrTrainingOptions object, the validation fields of the structure are empty.
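The validation fields of info can be used to find where validation error was lowest. This sketch uses a small hand-made structure in place of a real info output; its values are copied from the validation columns of the example training log (iterations 160 through 845), and the loggedIterations vector is a hypothetical helper that maps entries to iteration numbers.

```matlab
% Sketch: locate the lowest validation character error in an info-like struct.
% Hypothetical stand-in for info, with values from the example training log.
loggedIterations = [160 320 480 640 800 845];
info.ValidationCharError = [14.19 6.35 5.56 5.56 6.35 6.35];
info.ValidationWordError = [27.78 11.11 5.56 5.56 11.11 11.11];

% min returns the first occurrence of the minimum and ignores NaN values.
[bestCharErr, k] = min(info.ValidationCharError);
fprintf("Lowest validation character error: %.2f at iteration %d\n", ...
    bestCharErr, loggedIterations(k));

% Compare with the error at the last logged iteration.
finalCharErr = info.ValidationCharError(end);
fprintf("Final validation character error: %.2f\n", finalCharErr);
```

With a real info output, where each field presumably holds one value per training iteration, the index k would itself be the iteration number, and you can compare it with info.OutputModelIteration.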
Training OCR models for right-to-left scripts, such as Arabic and Hebrew, is not supported.
The trainOCR function creates a temporary folder, "<modelName>Training/", where <modelName> is the value of the modelName argument, in the location specified by the OutputLocation property of the ocrTrainingOptions object. The folder contains training artifacts. If the folder does not exist before you run the trainOCR function, the function deletes it at the end of training. If the folder already exists before training, the function does not delete it.
Images read from trainingData must contain at least one word and at most one line of text. The trainOCR function does not support images that contain multiple lines of text.
The trainOCR function does not support on-the-fly data augmentation using a datastore transform. The function reads all image data from the training datastores once, at the start of training.
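Because augmentation cannot be applied through a datastore transform, one possible workaround (an assumption, not a documented workflow) is to generate augmented copies of the images before training and add them as extra samples. This sketch uses contrast inversion, a photometric change that leaves pixel positions unchanged, so the original bounding boxes and text labels stay valid for the augmented copy. The image, box, and label are hypothetical.

```matlab
% Sketch: create a contrast-inverted copy of a uint8 training image.
I = uint8(repmat(linspace(0, 255, 8), 4, 1));   % hypothetical 4-by-8 image
boxes = [2 1 5 3];                               % hypothetical [x,y,width,height]
txt = "AB";                                      % hypothetical text label

Iaug = intmax('uint8') - I;                      % invert intensities (255 - I)
augmentedSample = {Iaug, boxes, txt};            % boxes and text reused as-is
```

Geometric augmentations such as rotation or scaling would also require transforming the boxes, which is why a purely photometric change is the simpler case. To use such copies in training, you would write them to disk (for example with imwrite) and label them alongside the original images.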
Introduced in R2023a