miniprojet/tests/src/knn.cpp

92 lines
2 KiB
C++

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <map>
#include <queue>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <opencv2/opencv.hpp>

#include "file.hpp"
#include "math.hpp"
double distance(math::csignal& v1, math::csignal& v2, int n){
if (v1.size() != v2.size()) {
throw std::runtime_error("les deux vecteurs doivent être de même longueur");
}
double d = 0;
auto v1_it = v1.begin();
auto v2_it = v2.begin();
while (v1_it != v1.end()) {
double dist = std::abs(*(v1_it++) - *(v2_it++));
d += std::pow(dist, n);
}
return std::pow(d, 1/n);
};
// Index of the largest element of v. On ties the first (lowest) index wins.
// Throws std::invalid_argument if v is empty, since there is no maximum.
int argmax(const std::vector<int>& v){
    if (v.empty()) {
        throw std::invalid_argument("argmax: empty vector");
    }
    // std::max_element returns the FIRST greatest element, preserving the
    // strict '>' tie-breaking of a hand-written scan.
    return static_cast<int>(std::max_element(v.begin(), v.end()) - v.begin());
}
// Comparator for std::priority_queue<std::pair<double, std::string>> that
// puts the pair with the SMALLEST .first (distance) on top, i.e. a min-heap
// on distance. Only .first is compared; equal distances are equivalent.
struct pair_comp {
    // Parameters are taken by const reference: heap operations call this
    // O(log n) times per push/pop, and by-value would copy both strings.
    bool operator()(const std::pair<double, std::string>& a,
                    const std::pair<double, std::string>& b) const {
        // priority_queue keeps the "greatest" element on top, so ordering
        // with '>' surfaces the nearest neighbour first.
        return a.first > b.first;
    }
};
int main(int argc, char** argv) {
int k = 20;
int size = 100;
std::string path;
int cmax = 10;
int threshold = 20;
if (argc > 2) {
path = argv[1];
threshold = atoi(argv[2]);
} else {
std::cout << "Invalid number of arguments" << std::endl;
return 0;
}
math::dataset references = get_data(path, size, cmax, threshold);
math::csignal sample = math::img2desc(path+"/arret/arret0199.jpg", cmax, threshold);
std::priority_queue<std::pair<double, std::string>, std::vector<std::pair<double, std::string>>, pair_comp> neighbors;
std::map<std::string, int> labels;
for (auto desc: references) {
double d = distance(desc.first, sample, 1);
neighbors.push({d, desc.second});
}
for (int i=0; i<k; ++i) {
std::pair<double, std::string> nearest = neighbors.top();
neighbors.pop();
labels[nearest.second] += 1;
}
int max = 0;
std::string label;
for (auto val: labels) {
if (val.second > max) {
max = val.second;
label = val.first;
}
}
std::cout << label << std::endl;
};