7 #ifndef MACHINE_LEARNING_CLASSIFIERUTILS_HPP 8 #define MACHINE_LEARNING_CLASSIFIERUTILS_HPP 13 for (
size_t i = 0; i < y.nRows(); i++)
/// Builds the collection of class labels from the true and predicted labels.
///
/// Visible steps: copy yTrue, append yPred to the copy with addColumn(), then
/// replace the copy with the result of unique() on it.
/// NOTE(review): unique() presumably keeps one entry per distinct value -- confirm
/// against the matrix class.
/// NOTE(review): the return statement is not visible in this excerpt; presumably
/// allClasses is returned -- confirm.
///
/// @param yTrue true labels
/// @param yPred predicted labels
20 static MatrixD
getAllClasses(
const MatrixD yTrue,
const MatrixD yPred) {
21 MatrixD allClasses = yTrue;
// Put the predicted labels next to the true ones so both feed into unique().
22 allClasses.addColumn(yPred);
23 allClasses = allClasses.unique();
/// Validates that the two label matrices are compatible with each other.
///
/// @param yTrue true labels
/// @param yPred predicted labels
/// @throws invalid_argument if either matrix does not have exactly one column
/// @throws invalid_argument if the two matrices have different numbers of rows
29 static void checkLabels(
const MatrixD yTrue,
const MatrixD yPred) {
// Both inputs must be single-column matrices.
30 if (yTrue.nCols() != 1 or yPred.nCols() != 1)
31 throw invalid_argument(
"Labels must be column vectors");
// There must be one prediction per true label.
32 if (yTrue.nRows() != yPred.nRows())
33 throw invalid_argument(
"True labels and predicted labels must have the same size (number of rows).");
// Binary-label validation: each label matrix must satisfy isBinary(); otherwise
// invalid_argument is thrown with a message naming the offending input.
// NOTE(review): the enclosing function header and any preceding statements are
// not visible in this excerpt (presumably checkBinaryLabels) -- confirm.
38 if (!yTrue.isBinary())
39 throw invalid_argument(
"True labels must be composed of only two classes");
40 if (!yPred.isBinary())
41 throw invalid_argument(
"Predicted labels must be composed of only two classes");
/// Compares the matrix m against a single label value.
///
/// Returns the result of `m == trueLabel` as a MatrixI.
/// NOTE(review): presumably an element-wise comparison yielding 1 where the
/// element equals trueLabel and 0 elsewhere -- confirm against the matrix
/// class's operator==.
///
/// @param m         matrix of labels
/// @param trueLabel label value to compare against
44 static MatrixI
binarize(MatrixD m,
double trueLabel) {
45 return m == trueLabel;
// Confusion-matrix accumulation.
// NOTE(review): the enclosing function header and the definition of allClasses
// are not visible in this excerpt (presumably confusionMatrix, with allClasses
// coming from getAllClasses) -- confirm.
// result is a square, zero-initialised matrix with one row and one column per
// row of allClasses.
53 MatrixI result = MatrixI::zeros(allClasses.nRows(), allClasses.nRows());
// For every sample, look up the positions of its true and predicted labels and
// count the pair.
55 for (
size_t i = 0; i < yTrue.nRows(); i++) {
56 size_t tLabel =
findLabel(allClasses, yTrue(i, 0));
57 size_t pLabel =
findLabel(allClasses, yPred(i, 0));
// Orientation: the row index is the predicted label, the column index is the true label.
59 result(pLabel, tLabel) += 1;
/// Accuracy of predicted labels against true labels.
///
/// The visible loop adds 1 to `accuracy` for every row whose true and predicted
/// labels compare equal with ==.
/// NOTE(review): the declaration of `accuracy` and any final normalisation or
/// return are not visible in this excerpt -- presumably the count is divided
/// by the number of rows; confirm.
///
/// @param yTrue true labels (only column 0 is read)
/// @param yPred predicted labels (only column 0 is read)
65 static double accuracy(MatrixD yTrue, MatrixD yPred) {
68 for (
size_t i = 0; i < yTrue.nRows(); i++)
// The bool result of == is added as 0 or 1.
69 accuracy += yTrue(i, 0) == yPred(i, 0);
/// Precision computed from the matrix `cm`.
///
/// Returns cm(1, 1) / (cm(1, 1) + cm(1, 0)); the cast makes the division
/// happen in double precision.
/// NOTE(review): `cm` is defined on a line not shown here. If it is the
/// confusion matrix indexed as (predicted, true), then cm(1, 1) is the true
/// positives and cm(1, 0) the false positives -- confirm, including that
/// index 1 is the positive class.
/// No guard against a zero denominator is visible on this line.
///
/// @param yTrue true labels
/// @param yPred predicted labels
74 static double precision(MatrixD yTrue, MatrixD yPred) {
77 return cm(1, 1) / ((double) cm(1, 1) + cm(1, 0));
/// Recall computed from the matrix `cm`.
///
/// Returns cm(1, 1) / (cm(1, 1) + cm(0, 1)); the cast makes the division
/// happen in double precision.
/// NOTE(review): `cm` is defined on a line not shown here. If it is the
/// confusion matrix indexed as (predicted, true), then cm(1, 1) is the true
/// positives and cm(0, 1) the false negatives -- confirm, including that
/// index 1 is the positive class.
/// No guard against a zero denominator is visible on this line.
///
/// @param yTrue true labels
/// @param yPred predicted labels
80 static double recall(MatrixD yTrue, MatrixD yPred) {
83 return cm(1, 1) / ((double) cm(1, 1) + cm(0, 1));
/// F-score combining `p` and `r`.
///
/// Returns 2 * (p * r) / (p + r), i.e. the harmonic mean of p and r.
/// NOTE(review): `p` and `r` are defined on lines not shown here; presumably
/// they are precision(yTrue, yPred) and recall(yTrue, yPred) -- confirm.
/// No guard for p + r == 0 is visible on this line.
///
/// @param yTrue true labels
/// @param yPred predicted labels
86 static double f_score(MatrixD yTrue, MatrixD yPred) {
89 return 2 * ((p * r) / (p + r));
93 #endif //MACHINE_LEARNING_CLASSIFIERUTILS_HPP static MatrixI binarize(MatrixD m, double trueLabel)
static double recall(MatrixD yTrue, MatrixD yPred)
static void checkBinaryLabels(const MatrixD yTrue, const MatrixD yPred)
static double f_score(MatrixD yTrue, MatrixD yPred)
static MatrixD getAllClasses(const MatrixD yTrue, const MatrixD yPred)
static MatrixI confusionMatrix(MatrixD yTrue, MatrixD yPred)
static void checkLabels(const MatrixD yTrue, const MatrixD yPred)
static double precision(MatrixD yTrue, MatrixD yPred)
static size_t findLabel(MatrixD y, double label)
static double accuracy(MatrixD yTrue, MatrixD yPred)