diff options
Diffstat (limited to 'meowpp/dsa/KD_Tree.hpp')
-rw-r--r-- | meowpp/dsa/KD_Tree.hpp | 390 |
1 files changed, 230 insertions, 160 deletions
diff --git a/meowpp/dsa/KD_Tree.hpp b/meowpp/dsa/KD_Tree.hpp index ac9f868..f0e97a9 100644 --- a/meowpp/dsa/KD_Tree.hpp +++ b/meowpp/dsa/KD_Tree.hpp @@ -2,196 +2,266 @@ #include <list> #include <vector> #include <algorithm> -#include <set> -#include "utility.h" +#include <queue> +#include "../utility.h" namespace meow{ - template<class Keys, class Key, class Value> - inline KD_Tree<Keys,Key,Value>::Node::Node(Keys _key, - Value _value, - ssize_t _l_child, - ssize_t _r_child): - key(_key), value(_value), lChild(_l_child), rChild(_r_child){ } - // - template<class Keys, class Key, class Value> - inline KD_Tree<Keys, Key, Value>::Sorter::Sorter(KD_Tree::Nodes const& _nodes, - size_t _cmp): - nodes(_nodes), cmp(_cmp){ } - template<class Keys, class Key, class Value> + //////////////////////////////////////////////////////////////////// + // **# Node #** // + //////////////////////////////////////////////////////////////////// + template<class Vector, class Scalar> + inline + KD_Tree<Vector, Scalar>::Node::Node(Vector __vector, + ssize_t __lChild, ssize_t __rChild): + _vector(__vector), _lChild(__lChild), _rChild(__rChild){ + } + //////////////////////////////////////////////////////////////////// + // **# Sorter #** // + //////////////////////////////////////////////////////////////////// + template<class Vector, class Scalar> + inline + KD_Tree<Vector, Scalar>::Sorter::Sorter(Nodes const* __nodes, size_t __cmp): + _nodes(__nodes), _cmp(__cmp){ + } + template<class Vector, class Scalar> inline bool - KD_Tree<Keys, Key, Value>::Sorter::operator()(size_t const& a, - size_t const& b) const{ - if(nodes[a].key[cmp] != nodes[b].key[cmp]){ - return (nodes[a].key[cmp] < nodes[b].key[cmp]); + KD_Tree<Vector, Scalar>::Sorter::operator()(size_t const& __a, + size_t const& __b) const{ + if((*_nodes)[__a]._vector[_cmp] != (*_nodes)[__b]._vector[_cmp]){ + return ((*_nodes)[__a]._vector[_cmp] < (*_nodes)[__b]._vector[_cmp]); } - return (nodes[a].value < nodes[b].value); + return ((*_nodes)[__a]._vector < (*_nodes)[__b]._vector); + } + //////////////////////////////////////////////////////////////////// + // **# Answer / Answer's Compare class #** // + //////////////////////////////////////////////////////////////////// + template<class Vector, class Scalar> + inline + KD_Tree<Vector, Scalar>::Answer::Answer(ssize_t __index, Scalar __dist2): + _index(__index), _dist2(__dist2){ + } + template<class Vector, class Scalar> + inline + KD_Tree<Vector, Scalar>::Answer::Answer(Answer const& __answer2): + _index(__answer2._index), _dist2(__answer2._dist2){ } // - template<class Keys, class Key, class Value> - inline KD_Tree<Keys, Key, Value>::Answer::Answer(Node const& _node, - Key _dist2): - node(_node), dist2(_dist2){ } - template<class Keys, class Key, class Value> + template<class Vector, class Scalar> + inline + KD_Tree<Vector, Scalar>::AnswerCompare::AnswerCompare(Nodes const* __nodes, + bool __cmpValue): + _nodes(__nodes), _cmpValue(__cmpValue){ + } + template<class Vector, class Scalar> inline bool - KD_Tree<Keys, Key, Value>::Answer::operator<(Answer const& b) const{ - if(dist2 != b.dist2) return (dist2 < b.dist2); - return (node.value < b.node.value); + KD_Tree<Vector, Scalar>::AnswerCompare::operator()(Answer const& __a, + Answer const& __b) const{ + if(_cmpValue == true && __a._dist2 == __b._dist2){ + return ((*_nodes)[__a._index]._vector < (*_nodes)[__b._index]._vector); + } + return (__a._dist2 < __b._dist2); } - // - template<class Keys, class Key, class Value> - inline Key KD_Tree<Keys, Key, Value>::distance2(Keys const& k1, - Keys const& k2) const{ - Key ret(0); - for(size_t i = 0; i < dimension; i++) - ret += squ(k1[i] - k2[i]); + //////////////////////////////////////////////////////////////////// + // **# distance2() #** // + //////////////////////////////////////////////////////////////////// + template<class Vector, class Scalar> + inline Scalar + KD_Tree<Vector, Scalar>::distance2(Vector const& __v1, + Vector const& __v2) const{ + Scalar ret(0); + for(size_t i = 0; i < _dimension; i++){ + ret += squ(__v1[i] - __v2[i]); + } return ret; } - template<class Keys, class Key, class Value> - inline size_t KD_Tree<Keys, Key, Value>::query(Keys const& key, - size_t k, - size_t index, - int depth, - std::vector<Key>& dist2_v, - Key dist2_s, - KD_Tree::AnswerList* ret) const{ - if(index == NIL){ - return 0; - } - size_t cmp = depth % dimension; - ssize_t right_side, opposite, size; - ssize_t sz, other; - if(key[cmp] <= nodes[index].key[cmp]){ - right_side = nodes[index].lChild; - opposite = nodes[index].rChild; + //////////////////////////////////////////////////////////////////// + // **# query() #** // + //////////////////////////////////////////////////////////////////// + template<class Vector, class Scalar> + inline void + KD_Tree<Vector, Scalar>::query(Vector const& __vector, + size_t __nearestNumber, + AnswerCompare const& __answerCompare, + size_t __index, + int __depth, + std::vector<Scalar>& __dist2Vector, + Scalar __dist2Minimum, + Answers *__out) const{ + if(__index == _NIL) return ; + size_t cmp = __depth % _dimension; + ssize_t this_side, that_side; + if(!(_nodes[__index]._vector[cmp] < __vector[cmp])){ + this_side = _nodes[__index]._lChild; + that_side = _nodes[__index]._rChild; }else{ - right_side = nodes[index].rChild; - opposite = nodes[index].lChild; + this_side = _nodes[__index]._rChild; + that_side = _nodes[__index]._lChild; } - size = query(key, k, right_side, depth + 1, dist2_v, dist2_s, ret); - Answer my_ans(nodes[index], distance2(nodes[index].key, key)); - if(size < k || my_ans < *(ret->rbegin())){ - KD_Tree::AnswerListIterator it = ret->begin(); - while(it != ret->end() && !(my_ans < *it)) it++; - ret->insert(it, my_ans); - size++; + query(__vector, __nearestNumber, __answerCompare, + this_side, __depth + 1, + __dist2Vector, __dist2Minimum, + __out); + Answer my_ans(__index, distance2(_nodes[__index]._vector, __vector)); + if(__out->size() < __nearestNumber || + __answerCompare(my_ans, __out->top())){ + __out->push(my_ans); + if(__out->size() > __nearestNumber) __out->pop(); } - Key dist2_old = dist2_v[cmp]; - dist2_v[cmp] = squ(nodes[index].key[cmp] - key[cmp]); - dist2_s += dist2_v[cmp] - dist2_old; - if(size < k || (*(ret->rbegin())).dist2 >= dist2_s){ - KD_Tree::AnswerList ret2; - size += query(key, k, opposite, depth + 1, dist2_v, dist2_s, &ret2); - KD_Tree::AnswerListIterator it1, it2; - for(it1 = ret->begin(), it2 = ret2.begin(); it2 != ret2.end(); it2++){ - while(it1 != ret->end() && *it1 < *it2) it1++; - it1 = ++(ret->insert(it1, *it2)); - } + Scalar dist2_old = __dist2Vector[cmp]; + __dist2Vector[cmp] = squ(_nodes[__index]._vector[cmp] - __vector[cmp]); + Scalar dist2Minimum = __dist2Minimum + __dist2Vector[cmp] - dist2_old; + if(__out->size() < __nearestNumber || + !(__out->top()._dist2 < dist2Minimum)){ + query(__vector, __nearestNumber, __answerCompare, + that_side, __depth + 1, + __dist2Vector, dist2Minimum, + __out); } - if(size > k){ - for(int i = size - k; i--; ){ - ret->pop_back(); - } - size = k; - } - dist2_v[cmp] = dist2_old; - return size; - } - template<class Keys, class Key, class Value> - inline ssize_t KD_Tree<Keys, Key, Value>::build(ssize_t beg, - ssize_t end, - std::vector<size_t>* orders, - int depth){ - if(beg > end){ - return NIL; + __dist2Vector[cmp] = dist2_old; + } + //////////////////////////////////////////////////////////////////// + // **# build() #** // + //////////////////////////////////////////////////////////////////// + template<class Vector, class Scalar> + inline ssize_t + KD_Tree<Vector, Scalar>::build(ssize_t __beg, + ssize_t __end, + std::vector<size_t>* __orders, + int __depth){ + if(__beg > __end) return _NIL; + size_t tmp_order = _dimension; + size_t which_side = _dimension + 1; + ssize_t mid = (__beg + __end) / 2; + size_t cmp = __depth % _dimension; + for(ssize_t i = __beg; i <= mid; i++){ + __orders[which_side][__orders[cmp][i]] = 0; } - ssize_t mid = (beg + end) / 2; - size_t cmp = depth % 2; - std::set<size_t> right; - for(ssize_t i = mid + 1; i <= end; i++){ - right.insert(orders[cmp][i]); + for(ssize_t i = mid + 1; i <= __end; i++){ + __orders[which_side][__orders[cmp][i]] = 1; } - for(int i = 0; i < dimension; i++){ + for(int i = 0; i < _dimension; i++){ if(i == cmp) continue; - size_t aa = beg, bb = mid + 1; - for(int j = beg; j <= end; j++){ - if(orders[i][j] == orders[cmp][mid]){ - orders[dimension][mid] = orders[i][j]; - }else if(right.find(orders[i][j]) != right.end()){ - orders[dimension][bb++] = orders[i][j]; + size_t left = __beg, right = mid + 1; + for(int j = __beg; j <= __end; j++){ + size_t ask = __orders[i][j]; + if(ask == __orders[cmp][mid]){ + __orders[tmp_order][mid] = ask; + }else if(__orders[which_side][ask] == 1){ + __orders[tmp_order][right++] = ask; }else{ - orders[dimension][aa++] = orders[i][j]; + __orders[tmp_order][left++] = ask; } } - for(int j = beg; j <= end; j++){ - orders[i][j] = orders[dimension][j]; + for(int j = __beg; j <= __end; j++){ + __orders[i][j] = __orders[tmp_order][j]; } } - nodes[orders[cmp][mid]].lChild = build(beg, mid - 1, orders, depth + 1); - nodes[orders[cmp][mid]].rChild = build(mid + 1, end, orders, depth + 1); - return orders[cmp][mid]; - } - template<class Keys, class Key, class Value> - inline KD_Tree<Keys, Key, Value>::KD_Tree(): - NIL(-1), root(NIL), needRebuild(false), dimension(1){ } - template<class Keys, class Key, class Value> - inline KD_Tree<Keys, Key, Value>::KD_Tree(size_t _dimension): - NIL(-1), root(NIL), needRebuild(false), dimension(_dimension){ } - template<class Keys, class Key, class Value> - inline KD_Tree<Keys, Key, Value>::~KD_Tree(){ } - template<class Keys, class Key, class Value> - inline void KD_Tree<Keys, Key, Value>::insert(Keys const& key, Value value){ - nodes.push_back(Node(key, value, NIL, NIL)); - needRebuild = true; - } - template<class Keys, class Key, class Value> - inline void KD_Tree<Keys, Key, Value>::build(){ - if(needRebuild){ - std::vector<size_t> *orders = new std::vector<size_t>[dimension + 1]; - for(int j = 0; j < dimension + 1; j++){ - orders[j].resize(nodes.size()); - } - for(int j = 0; j < dimension; j++){ - for(size_t i = 0, I = nodes.size(); i < I; i++){ - orders[j][i] = i; + _nodes[__orders[cmp][mid]]._lChild=build(__beg,mid-1,__orders,__depth+1); + _nodes[__orders[cmp][mid]]._rChild=build(mid+1,__end,__orders,__depth+1); + return __orders[cmp][mid]; + } + //////////////////////////////////////////////////////////////////// + // **# constructures/destructures #** // + //////////////////////////////////////////////////////////////////// + template<class Vector, class Scalar> + inline + KD_Tree<Vector, Scalar>::KD_Tree(): + _NIL(-1), _root(_NIL), _needRebuild(false), _dimension(1){ + } + template<class Vector, class Scalar> + inline + KD_Tree<Vector, Scalar>::KD_Tree(size_t __dimension): + _NIL(-1), _root(_NIL), _needRebuild(false), _dimension(__dimension){ + } + template<class Vector, class Scalar> + inline + KD_Tree<Vector, Scalar>::~KD_Tree(){ + } + //////////////////////////////////////////////////////////////////// + // **# insert, build #** // + //////////////////////////////////////////////////////////////////// + template<class Vector, class Scalar> + inline void + KD_Tree<Vector, Scalar>::insert(Vector const& __vector){ + _nodes.push_back(Node(__vector, _NIL, _NIL)); + _needRebuild = true; + } + template<class Vector, class Scalar> + inline bool + KD_Tree<Vector, Scalar>::erase(Vector const& __vector){ + for(size_t i = 0, I = _nodes.size(); i < I; i++){ + if(_nodes[i] == __vector){ + if(i != I - 1){ + std::swap(_nodes[i], _nodes[I - 1]); } - std::sort(orders[j].begin(), orders[j].end(), Sorter(nodes, j)); + _needRebuild = true; + return true; } - root = build(0, (ssize_t)nodes.size() - 1, orders, 0); - needRebuild = false; - delete [] orders; } + return false; } - template<class Keys, class Key, class Value> - inline Value KD_Tree<Keys, Key, Value>::query(Keys const& point, int k) const{ - ((KD_Tree*)this)->build(); - KD_Tree::AnswerList ret; - std::vector<Key> tmp(dimension, Key(0)); - query(point, k, root, 0, tmp, Key(0), &ret); - return (*(ret.rbegin())).node.value; - } - template<class Keys, class Key, class Value> - inline typename KD_Tree<Keys, Key, Value>::Values - KD_Tree<Keys, Key, Value>::rangeQuery(Keys const& point, int k) const{ + template<class Vector, class Scalar> + inline void + KD_Tree<Vector, Scalar>::build(){ + if(_needRebuild){ + forceBuild(); + } + } + template<class Vector, class Scalar> + inline void + KD_Tree<Vector, Scalar>::forceBuild(){ + std::vector<size_t> *orders = new std::vector<size_t>[_dimension + 2]; + for(int j = 0; j < _dimension + 2; j++){ + orders[j].resize(_nodes.size()); + } + for(int j = 0; j < _dimension; j++){ + for(size_t i = 0, I = _nodes.size(); i < I; i++){ + orders[j][i] = i; + } + std::sort(orders[j].begin(), orders[j].end(), Sorter(&_nodes, j)); + } + _root = build(0, (ssize_t)_nodes.size() - 1, orders, 0); + delete [] orders; + _needRebuild = false; + } + //////////////////////////////////////////////////////////////////// + // **# query #** // + //////////////////////////////////////////////////////////////////// + template<class Vector, class Scalar> + inline typename KD_Tree<Vector, Scalar>::Vectors + KD_Tree<Vector, Scalar>::query(Vector const& __vector, + size_t __nearestNumber, + bool __compareWholeVector) const{ ((KD_Tree*)this)->build(); - KD_Tree::AnswerList ret; - std::vector<Key> tmp(dimension, Key(0)); - query(point, k, root, 0, tmp, Key(0), &ret); - KD_Tree::Values ret_val(ret.size()); - int i = 0; - for(KD_Tree::AnswerListIterator it = ret.begin(); it != ret.end(); it++){ - ret_val[i++] = (*it).node.value; + AnswerCompare answer_compare(&_nodes, __compareWholeVector); + Answers answer_set(answer_compare); + std::vector<Scalar> tmp(_dimension, 0); + query(__vector, __nearestNumber, + answer_compare, + _root, 0, + tmp, Scalar(0), + &answer_set); + Vectors ret(answer_set.size()); + for(int i = (ssize_t)answer_set.size() - 1; i >= 0; i--){ + ret[i] = _nodes[answer_set.top()._index]._vector; + answer_set.pop(); } - return ret_val; + return ret; } - template<class Keys, class Key, class Value> - inline void KD_Tree<Keys, Key, Value>::clear(){ - root = NIL; - nodes.clear(); - needRebuild = false; + //////////////////////////////////////////////////////////////////// + // **# clear, reset #** // + //////////////////////////////////////////////////////////////////// + template<class Vector, class Scalar> + inline void + KD_Tree<Vector, Scalar>::clear(){ + _root = _NIL; + _nodes.clear(); + _needRebuild = false; } - template<class Keys, class Key, class Value> - inline void KD_Tree<Keys, Key, Value>::reset(size_t _dimension){ + template<class Vector, class Scalar> + inline void + KD_Tree<Vector, Scalar>::reset(size_t __dimension){ clear(); - dimension = _dimension; + _dimension = __dimension; } } |