diff options
Diffstat (limited to 'meowpp.test/src/KD_Tree.cpp')
-rw-r--r-- | meowpp.test/src/KD_Tree.cpp | 190 |
1 files changed, 0 insertions, 190 deletions
diff --git a/meowpp.test/src/KD_Tree.cpp b/meowpp.test/src/KD_Tree.cpp deleted file mode 100644 index 8d4232e..0000000 --- a/meowpp.test/src/KD_Tree.cpp +++ /dev/null @@ -1,190 +0,0 @@ -#include "meowpp/dsa/KD_Tree.h" -#include "meowpp/utility.h" - -#include "dsa.h" - -#include <vector> - -#include <cmath> -#include <cstdlib> -#include <algorithm> -#include <ctime> -#include <queue> - -static int N = 10000; -static int D = 5; - -static double dist2(std::vector<double> const& v1, std::vector<double> const& v2){ - double ret = 0; - for(int i = 0; i < D; i++){ - ret += meow::squ(v1[i] - v2[i]); - } - return ret; -} - -static std::vector< std::vector<double> > data; -static std::vector< double > dist; -static std::vector< int > order; - - -struct Answer{ - double dist; - int id; - Answer(double _dist, int _id): dist(_dist), id(_id){ } - bool operator<(Answer const& b) const{ - if(dist != b.dist) return (dist < b.dist); - return (id < b.id); - } -}; - - -static void find(std::vector<double> const& v, int k){ - std::priority_queue<Answer> qu; - for(int i = 0; i < k; i++){ - qu.push(Answer(dist2(v, data[i]), i)); - } - for(int i = k; i < N; i++){ - qu.push(Answer(dist2(v, data[i]), i)); - qu.pop(); - } - order.resize(k); - for(int i = qu.size() - 1; i >= 0; i--){ - order[i] = qu.top().id; - qu.pop(); - } -} - -static std::vector<double> v; - -/* -static bool sf(const int& a, const int& b){ - if(dist[a] != dist[b]) - return (dist[a] < dist[b]); - return (a < b); -} - -static void show(std::vector<double> const& ask, std::vector<int> kd, std::vector<int> me, int k){ - if(N <= 30 && D <= 3){ - printf("\nData:\n"); - for(int i = 0; i < N; i++){ - printf(" %2d) <", i); - for(int j = 0; j < D; j++){ - printf("%.7f", data[i][j]); - if(j < D - 1) printf(", "); - else printf(">"); - } - printf("\n"); - } - printf("Ask <"); - for(int j = 0; j < D; j++){ - printf("%.7f", ask[j]); - if(j < D - 1) printf(", "); - else printf(">"); - } - printf("\n"); - printf("MyAnswer: "); - for(int i = 0; i < k; i++) printf("%d ", me[i]); - printf("\n"); - printf("KdAnswer: "); - for(int i = 0; i < k; i++) printf("%d ", kd[i]); - printf("\n"); - order.resize(N); - dist .resize(N); - for(int i = 0; i < N; i++){ - dist [i] = dist2(ask, data[i]); - order[i] = i; - } - std::sort(order.begin(), order.end(), sf); - printf("Sorted:\n"); - for(int i = 0; i < N; i++){ - printf(" %2d) <", order[i]); - for(int j = 0; j < D; j++){ - printf("%.7f", data[order[i]][j]); - if(j < D - 1) printf(", "); - else printf(">"); - } - printf(" ((%.7f))", dist[order[i]]); - printf("\n"); - } - } -} -// */ - -struct Node{ - std::vector<double> v; - int id; - double& operator[](size_t d) { return v[d]; } - double operator[](size_t d) const { return v[d]; } - bool operator<(Node const& n) const{ return (id < n.id); } -}; - -TEST(KD_Tree, "It is very slow"){ - - int t0, t1, t2; - - meow::KD_Tree<Node, double> tree(D); - - meow::messagePrintf(1, "Create data (N = %d, D = %d)", N, D); - data.resize(N); - for(int i = 0; i < N; i++){ - data[i].resize(D); - Node nd; - nd.v.resize(D); - nd.id = i; - for(int j = 0; j < D; j++){ - data[i][j] = 12345.0 * (1.0 * rand() / RAND_MAX - 0.3); - nd[j] = data[i][j]; - } - tree.insert(nd); - } - meow::messagePrintf(-1, "ok"); - meow::messagePrintf(1, "build"); - t0 = clock(); - tree.build(); - meow::messagePrintf(-1, "ok, %.3f seconds", (clock() - t0) * 1.0 / CLOCKS_PER_SEC); - - meow::messagePrintf(1, "query..."); - v.resize(D); - meow::KD_Tree<Node, double>::Vectors ret; - for(int k = 1; k <= std::min(100, N); k++){ - meow::messagePrintf(1, "range k = %d", k); - t1 = t2 = 0; - for(int i = 0; i < 10; i++){ - Node nd; - nd.v.resize(D); - for(int d = 0; d < D; d++){ - v[d] = 12345.0 * (1.0 * rand() / RAND_MAX - 0.3); - nd[d] = v[d]; - } - t0 = clock(); - tree.build(); - ret = tree.query(nd, k, true); - t1 += clock() - t0; - - t0 = clock(); - find(v, k); - t2 += clock() - t0; - if((int)ret.size() != (int)std::min(k, N)){ - meow::messagePrintf(-1, "(%d)query fail, size error", i); - meow::messagePrintf(-1, "fail"); - return false; - } - for(int kk = 1; kk <= k; kk++){ - if(order[kk - 1] != ret[kk - 1].id){ - //show(v, ret, order, k); - meow::messagePrintf(-1, "(%d)query fail", i); - meow::messagePrintf(-1, "fail"); - return false; - } - } - } - meow::messagePrintf(-1, "ok %.3f/%.3f", - t1 * 1.0 / CLOCKS_PER_SEC, - t2 * 1.0 / CLOCKS_PER_SEC - ); - } - meow::messagePrintf(-1, "ok"); - - - return true; -} |