7 #ifndef MACHINE_LEARNING_LDA_HPP 8 #define MACHINE_LEARNING_LDA_HPP 12 #include "../include/matrix/Matrix.hpp" 21 MatrixD X,
y, eigenvalues, eigenvectors, transformedData;
// Model state:
//   X               - data matrix, moved in from the constructor's `data` argument.
//   y               - class labels (required to be a single column), moved in from `classes`.
//   eigenvalues,
//   eigenvectors    - result of eigen() on inverse(within-class scatter) * between-class scatter.
//   transformedData - X projected onto `eigenvectors`; this is what transform() returns.
28 LDA(MatrixD data, MatrixD classes) : X(
std::move(data)), y(
std::move(classes)) {
29 if (data.nRows() != classes.nRows())
30 throw invalid_argument(
"data and classes must have the same number of rows");
31 if (classes.nCols() != 1)
32 throw invalid_argument(
"classes must me a column vector");
36 MatrixD Sw = X.WithinClassScatter(y);
37 MatrixD Sb = X.BetweenClassScatter(y);
39 auto eigen = (Sw.inverse() * Sb).eigen();
41 eigenvalues = eigen.first;
42 eigenvectors = eigen.second;
44 transformedData = (eigenvectors.transpose() * X.transpose()).transpose();
// transform(): hands back the projection cached in transformedData. transform() is
// declared to return MatrixD by value, so the caller receives a copy.
52 return transformedData;
56 #endif //MACHINE_LEARNING_LDA_HPP
k-nearest neighbors algorithm, supporting both regression and classification.
Linear discriminant analysis algorithm.
LDA(MatrixD data, MatrixD classes)
Constructs a linear discriminant analysis model from a data matrix and a column vector of class labels (one label per data row).
MatrixD transform()
Transforms the data matrix using the eigenvectors found by fit()