OpenCV machine learning algorithms
OpenCV implements eight of these machine learning algorithms, all of which inherit from the StatModel class:
- Artificial neural networks
- Random trees
- Expectation maximization
- k-nearest neighbors
- Logistic regression
- Normal Bayes classifiers
- Support vector machines
- Stochastic gradient descent SVMs
OpenCV version 3 supports deep learning at a basic level, but in version 4 the deep learning module (dnn) is stable and better supported. We will cover deep learning in detail in later chapters.
The following diagram shows the machine learning class hierarchy:
The StatModel class is the base class for all machine learning algorithms. It provides the prediction function and the read and write functions, which are essential for saving and loading our machine learning parameters and training data.
In machine learning, training is the most time-consuming and computationally expensive part. Training can take anywhere from seconds to weeks or months for large datasets and complex model structures. For example, in deep learning, large neural networks trained on datasets of more than 100,000 images can take a long time to train. With deep learning algorithms, it is common to use parallel hardware such as GPUs with CUDA to reduce training time, as well as newer dedicated chips such as Intel Movidius. This means that we cannot retrain our algorithm every time we run our application, so it's recommended to save the trained model with all of its learned parameters. In future executions, we only have to load the saved model instead of training again, unless we need to update it with more sample data.
StatModel is the base class of all machine learning classes, such as SVM or ANN, except for the deep learning methods. It is an abstract class that defines the two most important functions: train and predict. The train method is responsible for learning the model parameters from a training dataset. It has the following three possible calls:
bool train(const Ptr<TrainData>& trainData, int flags=0);
bool train(InputArray samples, int layout, InputArray responses);
template<typename _Tp> static Ptr<_Tp> train(const Ptr<TrainData>& data, int flags=0);
The train function has the following parameters:
- trainData: Training data, loaded or created with the TrainData class. This class was introduced in OpenCV 3 and helps developers create training data that is independent of the machine learning algorithm. This is needed because different algorithms, such as ANN, require differently structured arrays for training and prediction.
- samples: The training samples, stored in an array in the format required by the machine learning algorithm.
- layout: ROW_SAMPLE (training samples are the matrix rows) or COL_SAMPLE (training samples are the matrix columns).
- responses: Vector of responses associated with the sample data.
- flags: Optional flags defined by each method.
The last train method is a static template that creates and trains a model of the _Tp class type. Only classes that implement a static create method with no parameters (or with default values for all parameters) are accepted.
The predict method is much simpler and has only one possible call:
float StatModel::predict(InputArray samples, OutputArray results=noArray(), int flags=0) const
The predict function has the following parameters:
- samples: The input samples to predict; this can be a single sample or multiple samples.
- results: The result for each input sample (row), computed from the previously trained model.
- flags: These optional flags are model-dependent. Some models, such as Boost and SVM, recognize the StatModel::RAW_OUTPUT flag, which makes the method return the raw result (the sum) rather than the class label.
The StatModel class also provides an interface for other useful methods:
- isTrained() returns true if the model is trained
- isClassifier() returns true if the model is a classifier, or false in the case of regression
- getVarCount() returns the number of variables in training samples
- save(const string& filename) saves the model to the file named filename
- Ptr<_Tp> load(const string& filename) loads the model from a file, for example: Ptr<SVM> svm = StatModel::load<SVM>("my_svm_model.xml")
- calcError(const Ptr<TrainData>& data, bool test, OutputArray resp) calculates the prediction error on data. If the test parameter is true, the method calculates the error on the test subset of the data; if it's false, it calculates the error on the training subset. resp is an optional output array that receives the predicted responses.
Next, we will look at how a basic computer vision application that uses machine learning is structured.