Machine learning algorithms in C++
LDA.hpp
Go to the documentation of this file.
1 
7 #ifndef MACHINE_LEARNING_LDA_HPP
8 #define MACHINE_LEARNING_LDA_HPP
9 
10 #include <utility>
11 
12 #include "../include/matrix/Matrix.hpp"
13 
14 using namespace std;
15 
19 class LDA {
20  private:
21  MatrixD X, y, eigenvalues, eigenvectors, transformedData;
22  public:
28  LDA(MatrixD data, MatrixD classes) : X(std::move(data)), y(std::move(classes)) {
29  if (data.nRows() != classes.nRows())
30  throw invalid_argument("data and classes must have the same number of rows");
31  if (classes.nCols() != 1)
32  throw invalid_argument("classes must me a column vector");
33  }
34 
35  void fit() {
36  MatrixD Sw = X.WithinClassScatter(y);
37  MatrixD Sb = X.BetweenClassScatter(y);
38 
39  auto eigen = (Sw.inverse() * Sb).eigen();
40 
41  eigenvalues = eigen.first;
42  eigenvectors = eigen.second;
43 
44  transformedData = (eigenvectors.transpose() * X.transpose()).transpose();
45  }
46 
51  MatrixD transform() {
52  return transformedData;
53  }
54 };
55 
56 #endif //MACHINE_LEARNING_LDA_HPP
MatrixD y
Definition: LDA.hpp:21
k-nearest neighbors algorithm, able to do regression and classification
Linear discriminant analysis algorithm.
Definition: LDA.hpp:19
LDA(MatrixD data, MatrixD classes)
Constructs a linear discriminant analysis model from a data matrix and a column vector of class labels.
Definition: LDA.hpp:28
MatrixD transform()
Transforms the data matrix using the eigenvectors found by fit()
Definition: LDA.hpp:51
void fit()
Definition: LDA.hpp:35