
Classify Documents Using Document Embeddings

This example shows how to train a document classifier by converting documents to feature vectors using a document embedding.

Most machine learning techniques require feature vectors as input to train a classifier.

A document embedding maps each document to a fixed-length numeric vector. Given a labeled set of these document vectors, you can train a machine learning model to classify the documents.

Load Pretrained Document Embedding

Load the pretrained document embedding "all-MiniLM-L6-v2" using the documentEmbedding function. This model requires the Text Analytics Toolbox™ Model for all-MiniLM-L6-v2 Network support package. If this support package is not installed, then the function provides a download link.

emb = documentEmbedding(Model="all-MiniLM-L6-v2");

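For reproducibility, reset the random number generator using the rng function with the "default" option. This makes the random data partition and the t-SNE plot later in the example repeatable.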

rng("default");
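The following sketch (base MATLAB, illustrative variable names a and b) shows what resetting the generator buys you. After each call to rng("default"), the generator replays the same sequence, so any function that draws random numbers afterward behaves the same way from run to run.

```matlab
% Reset the generator, draw three numbers, reset again, and draw again.
rng("default");
a = rand(1,3);      % first draw after the reset
rng("default");
b = rand(1,3);      % same draw after a second reset
isequal(a,b)        % true: both draws are identical
```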

Load Training Data

Next, load the example data. The file factoryReports.csv contains factory reports, including a text description and categorical labels for each event.

filename = "factoryReports.csv";
data = readtable(filename,TextType="string");
head(data)
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

The goal of this example is to classify events by the label in the Category column. Extract the text from the Description column, and convert the labels in the Category column to categorical so that each distinct label defines a class.

str = data.Description;
labels = categorical(data.Category);
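To see what the conversion produces, this small sketch applies categorical to a few hypothetical Category values. The categories are the sorted unique labels, and countcats returns the number of observations in each category, which is a quick way to check class balance.

```matlab
% A few Category values of the kind found in the table
raw = ["Mechanical Failure"; "Electronic Failure"; "Leak"; "Electronic Failure"];
labelsDemo = categorical(raw);

categories(labelsDemo)   % {'Electronic Failure'; 'Leak'; 'Mechanical Failure'}
countcats(labelsDemo)    % [2; 1; 1], counts in category order
```

On the full data set, countcats(labels) gives the class counts for all the reports.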

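Next, split the data into a training partition and a held-out test partition. Hold out 30% of the data for testing. Because you pass the class labels to cvpartition, the split is stratified, so each partition has roughly the same class proportions as the full data set.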

cvp = cvpartition(labels,Holdout=0.3);
idxTrain = training(cvp);
idxTest = test(cvp);
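The idea behind a stratified holdout split can be sketched in base MATLAB: for each class, randomly select about 30% of its observations for the test set. This is a hypothetical illustration of the technique on made-up labels "A" and "B" (the names labelsDemo, idxTestDemo, and idxTrainDemo are illustrative), not the internal implementation of cvpartition.

```matlab
rng("default");
% Hypothetical labels: 10 observations of class A and 10 of class B
labelsDemo = categorical([repmat("A",10,1); repmat("B",10,1)]);
p = 0.3;                                    % holdout fraction

idxTestDemo = false(numel(labelsDemo),1);
cats = categories(labelsDemo);
for k = 1:numel(cats)
    members = find(labelsDemo == cats{k});  % rows that belong to class k
    nTest = round(p*numel(members));        % 30% of this class
    pick = members(randperm(numel(members),nTest));
    idxTestDemo(pick) = true;               % mark the selected rows as test
end
idxTrainDemo = ~idxTestDemo;                % everything else is training
```

Each class contributes 3 of its 10 observations to the test set, so the class proportions are the same in both partitions.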

Get the target labels for the training and test partitions.

labelsTrain = labels(idxTrain,:);
labelsTest = labels(idxTest,:);

Convert Documents to Feature Vectors

To convert the factory reports to vectors, use the embed function. You do not need to perform any text preprocessing on the documents.

embeddedDocumentsTrain = embed(emb,str(idxTrain,:));
embeddedDocumentsTest = embed(emb,str(idxTest,:));

View the size of the embedded test data.

size(embeddedDocumentsTest)
ans = 1×2

   144   384

The output for each of the 144 documents is a single 384-element vector that provides a semantic representation of the entire document. View the embedding vector for the first document in the test set.

embeddedDocumentsTest(1,:)
ans = 1×384

   -0.0141   -0.0434    0.0271   -0.0302   -0.1098   -0.0431   -0.0311   -0.0633    0.0388   -0.0577    0.0328   -0.0112   -0.0293   -0.0755   -0.0539    0.0484    0.0798   -0.0112   -0.0152   -0.0711   -0.0854    0.0378    0.0026    0.0957    0.0080    0.0720    0.0196    0.0605    0.0109   -0.0186    0.0441   -0.0159   -0.0111   -0.0404    0.1344   -0.0472   -0.0102    0.0745    0.0056   -0.1010    0.0479   -0.0117    0.0843   -0.0471   -0.0217    0.0362   -0.0030   -0.0579    0.1073   -0.0383

To visualize the embedding vectors, create a t-SNE plot. First, embed the vectors in two-dimensional space using tsne with the cosine distance metric. Then use gscatter to create a scatter plot of the test embedding vectors grouped by label.

Y = tsne(embeddedDocumentsTest,Distance="cosine");
gscatter(Y(:,1),Y(:,2),labelsTest)
title("Factory Report Embeddings")
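The Distance="cosine" option compares embedding vectors by angle rather than by magnitude. As a sketch, the cosine distance between two vectors is one minus their cosine similarity, which you can compute for all pairs of rows in base MATLAB. The three-element vectors here are hypothetical stand-ins for the 384-element embeddings.

```matlab
% Three hypothetical "embeddings", one per row
X = [2 0 0; 4 3 0; 0 0 1];
Xn = X ./ vecnorm(X,2,2);   % scale each row to unit length
S = Xn*Xn';                 % pairwise cosine similarities
D = 1 - S;                  % pairwise cosine distances
```

Rows 1 and 2 point in similar directions, so S(1,2) is 0.8 and D(1,2) is about 0.2, while rows 1 and 3 are orthogonal, so D(1,3) is 1. If you have Statistics and Machine Learning Toolbox, squareform(pdist(X,"cosine")) returns the same distance matrix.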

Train Document Classifier

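Train a multiclass linear classification model using fitcecoc. By default, fitcecoc uses a one-versus-one coding design, so for the four classes in this data set it trains six binary linear learners, one for each pair of classes.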

mdl = fitcecoc(embeddedDocumentsTrain,labelsTrain,Learners="linear")
mdl = 
  CompactClassificationECOC
      ResponseName: 'Y'
        ClassNames: [Electronic Failure    Leak    Mechanical Failure    Software Failure]
    ScoreTransform: 'none'
    BinaryLearners: {6×1 cell}
      CodingMatrix: [4×6 double]


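The 4-by-6 CodingMatrix in the output follows from the one-versus-one design. This sketch builds such a matrix in base MATLAB: each column corresponds to one binary learner, the two classes of the pair get +1 and -1, and the remaining classes get 0, meaning that learner ignores them. The result is assumed, not verified here, to match designecoc(4,"onevsone") from Statistics and Machine Learning Toolbox; compare it with mdl.CodingMatrix to confirm.

```matlab
K = 4;                          % number of classes in this example
pairs = nchoosek(1:K,2);        % each row is one pair of classes
M = zeros(K,size(pairs,1));     % K-by-L coding matrix, L = K*(K-1)/2
for j = 1:size(pairs,1)
    M(pairs(j,1),j) = 1;        % first class of the pair: positive
    M(pairs(j,2),j) = -1;       % second class of the pair: negative
end                             % other classes stay 0 for learner j
size(M)                         % 4 6
```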

Test Model

Predict the categories of the test documents. Compute the accuracy and plot a confusion matrix chart.

labelPredict = predict(mdl,embeddedDocumentsTest);
acc = mean(labelPredict == labelsTest)
acc = 0.9444
confusionchart(labelsTest,labelPredict)

Large values on the diagonal indicate accurate predictions for the corresponding class. Large off-diagonal values indicate that the model frequently confuses the corresponding pair of classes.
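To make the link between the accuracy and the confusion matrix explicit, this base MATLAB sketch counts hypothetical true and predicted labels into a matrix whose rows are true classes and whose columns are predicted classes. The accuracy is the fraction of counts that falls on the diagonal, which equals the mean of the element-wise comparison used above. The label values and the names trueDemo and predDemo are made up for illustration.

```matlab
classNames = ["Electronic Failure" "Leak" "Mechanical Failure"];
% Hypothetical true and predicted labels sharing the same category list
trueDemo = categorical(["Leak"; "Leak"; "Mechanical Failure"; ...
    "Electronic Failure"; "Mechanical Failure"], classNames);
predDemo = categorical(["Leak"; "Mechanical Failure"; "Mechanical Failure"; ...
    "Electronic Failure"; "Mechanical Failure"], classNames);

K = numel(classNames);
% double() returns each label's category index; accumarray counts each pair
C = accumarray([double(trueDemo) double(predDemo)], 1, [K K]);
accDemo = trace(C)/sum(C(:));           % diagonal share of all counts
accCheck = mean(predDemo == trueDemo);  % same value, computed directly
```

Here C is [1 0 0; 0 1 1; 0 0 2]: one "Leak" report is predicted as "Mechanical Failure", and both accuracy values are 0.8.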

See Also


Related Topics