aboutsummaryrefslogtreecommitdiffstats
path: root/meowpp/dsa/KD_Tree.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'meowpp/dsa/KD_Tree.hpp')
-rw-r--r--meowpp/dsa/KD_Tree.hpp390
1 files changed, 230 insertions, 160 deletions
diff --git a/meowpp/dsa/KD_Tree.hpp b/meowpp/dsa/KD_Tree.hpp
index ac9f868..f0e97a9 100644
--- a/meowpp/dsa/KD_Tree.hpp
+++ b/meowpp/dsa/KD_Tree.hpp
@@ -2,196 +2,266 @@
#include <list>
#include <vector>
#include <algorithm>
-#include <set>
-#include "utility.h"
+#include <queue>
+#include "../utility.h"
namespace meow{
- template<class Keys, class Key, class Value>
- inline KD_Tree<Keys,Key,Value>::Node::Node(Keys _key,
- Value _value,
- ssize_t _l_child,
- ssize_t _r_child):
- key(_key), value(_value), lChild(_l_child), rChild(_r_child){ }
- //
- template<class Keys, class Key, class Value>
- inline KD_Tree<Keys, Key, Value>::Sorter::Sorter(KD_Tree::Nodes const& _nodes,
- size_t _cmp):
- nodes(_nodes), cmp(_cmp){ }
- template<class Keys, class Key, class Value>
+ ////////////////////////////////////////////////////////////////////
+ // **# Node #** //
+ ////////////////////////////////////////////////////////////////////
+ template<class Vector, class Scalar>
+ inline
+ KD_Tree<Vector, Scalar>::Node::Node(Vector __vector,
+ ssize_t __lChild, ssize_t __rChild):
+ _vector(__vector), _lChild(__lChild), _rChild(__rChild){
+ }
+ ////////////////////////////////////////////////////////////////////
+ // **# Sorter #** //
+ ////////////////////////////////////////////////////////////////////
+ template<class Vector, class Scalar>
+ inline
+ KD_Tree<Vector, Scalar>::Sorter::Sorter(Nodes const* __nodes, size_t __cmp):
+ _nodes(__nodes), _cmp(__cmp){
+ }
+ template<class Vector, class Scalar>
inline bool
- KD_Tree<Keys, Key, Value>::Sorter::operator()(size_t const& a,
- size_t const& b) const{
- if(nodes[a].key[cmp] != nodes[b].key[cmp]){
- return (nodes[a].key[cmp] < nodes[b].key[cmp]);
+ KD_Tree<Vector, Scalar>::Sorter::operator()(size_t const& __a,
+ size_t const& __b) const{
+ if((*_nodes)[__a]._vector[_cmp] != (*_nodes)[__b]._vector[_cmp]){
+ return ((*_nodes)[__a]._vector[_cmp] < (*_nodes)[__b]._vector[_cmp]);
}
- return (nodes[a].value < nodes[b].value);
+ return ((*_nodes)[__a]._vector < (*_nodes)[__b]._vector);
+ }
+ ////////////////////////////////////////////////////////////////////
+ // **# Answer / Answer's Compare class #** //
+ ////////////////////////////////////////////////////////////////////
+ template<class Vector, class Scalar>
+ inline
+ KD_Tree<Vector, Scalar>::Answer::Answer(ssize_t __index, Scalar __dist2):
+ _index(__index), _dist2(__dist2){
+ }
+ template<class Vector, class Scalar>
+ inline
+ KD_Tree<Vector, Scalar>::Answer::Answer(Answer const& __answer2):
+ _index(__answer2._index), _dist2(__answer2._dist2){
}
//
- template<class Keys, class Key, class Value>
- inline KD_Tree<Keys, Key, Value>::Answer::Answer(Node const& _node,
- Key _dist2):
- node(_node), dist2(_dist2){ }
- template<class Keys, class Key, class Value>
+ template<class Vector, class Scalar>
+ inline
+ KD_Tree<Vector, Scalar>::AnswerCompare::AnswerCompare(Nodes const* __nodes,
+ bool __cmpValue):
+ _nodes(__nodes), _cmpValue(__cmpValue){
+ }
+ template<class Vector, class Scalar>
inline bool
- KD_Tree<Keys, Key, Value>::Answer::operator<(Answer const& b) const{
- if(dist2 != b.dist2) return (dist2 < b.dist2);
- return (node.value < b.node.value);
+ KD_Tree<Vector, Scalar>::AnswerCompare::operator()(Answer const& __a,
+ Answer const& __b) const{
+ if(_cmpValue == true && __a._dist2 == __b._dist2){
+ return ((*_nodes)[__a._index]._vector < (*_nodes)[__b._index]._vector);
+ }
+ return (__a._dist2 < __b._dist2);
}
- //
- template<class Keys, class Key, class Value>
- inline Key KD_Tree<Keys, Key, Value>::distance2(Keys const& k1,
- Keys const& k2) const{
- Key ret(0);
- for(size_t i = 0; i < dimension; i++)
- ret += squ(k1[i] - k2[i]);
+ ////////////////////////////////////////////////////////////////////
+ // **# distance2() #** //
+ ////////////////////////////////////////////////////////////////////
+ template<class Vector, class Scalar>
+ inline Scalar
+ KD_Tree<Vector, Scalar>::distance2(Vector const& __v1,
+ Vector const& __v2) const{
+ Scalar ret(0);
+ for(size_t i = 0; i < _dimension; i++){
+ ret += squ(__v1[i] - __v2[i]);
+ }
return ret;
}
- template<class Keys, class Key, class Value>
- inline size_t KD_Tree<Keys, Key, Value>::query(Keys const& key,
- size_t k,
- size_t index,
- int depth,
- std::vector<Key>& dist2_v,
- Key dist2_s,
- KD_Tree::AnswerList* ret) const{
- if(index == NIL){
- return 0;
- }
- size_t cmp = depth % dimension;
- ssize_t right_side, opposite, size;
- ssize_t sz, other;
- if(key[cmp] <= nodes[index].key[cmp]){
- right_side = nodes[index].lChild;
- opposite = nodes[index].rChild;
+ ////////////////////////////////////////////////////////////////////
+ // **# query() #** //
+ ////////////////////////////////////////////////////////////////////
+ template<class Vector, class Scalar>
+ inline void
+ KD_Tree<Vector, Scalar>::query(Vector const& __vector,
+ size_t __nearestNumber,
+ AnswerCompare const& __answerCompare,
+ size_t __index,
+ int __depth,
+ std::vector<Scalar>& __dist2Vector,
+ Scalar __dist2Minimum,
+ Answers *__out) const{
+ if(__index == _NIL) return ;
+ size_t cmp = __depth % _dimension;
+ ssize_t this_side, that_side;
+ if(!(_nodes[__index]._vector[cmp] < __vector[cmp])){
+ this_side = _nodes[__index]._lChild;
+ that_side = _nodes[__index]._rChild;
}else{
- right_side = nodes[index].rChild;
- opposite = nodes[index].lChild;
+ this_side = _nodes[__index]._rChild;
+ that_side = _nodes[__index]._lChild;
}
- size = query(key, k, right_side, depth + 1, dist2_v, dist2_s, ret);
- Answer my_ans(nodes[index], distance2(nodes[index].key, key));
- if(size < k || my_ans < *(ret->rbegin())){
- KD_Tree::AnswerListIterator it = ret->begin();
- while(it != ret->end() && !(my_ans < *it)) it++;
- ret->insert(it, my_ans);
- size++;
+ query(__vector, __nearestNumber, __answerCompare,
+ this_side, __depth + 1,
+ __dist2Vector, __dist2Minimum,
+ __out);
+ Answer my_ans(__index, distance2(_nodes[__index]._vector, __vector));
+ if(__out->size() < __nearestNumber ||
+ __answerCompare(my_ans, __out->top())){
+ __out->push(my_ans);
+ if(__out->size() > __nearestNumber) __out->pop();
}
- Key dist2_old = dist2_v[cmp];
- dist2_v[cmp] = squ(nodes[index].key[cmp] - key[cmp]);
- dist2_s += dist2_v[cmp] - dist2_old;
- if(size < k || (*(ret->rbegin())).dist2 >= dist2_s){
- KD_Tree::AnswerList ret2;
- size += query(key, k, opposite, depth + 1, dist2_v, dist2_s, &ret2);
- KD_Tree::AnswerListIterator it1, it2;
- for(it1 = ret->begin(), it2 = ret2.begin(); it2 != ret2.end(); it2++){
- while(it1 != ret->end() && *it1 < *it2) it1++;
- it1 = ++(ret->insert(it1, *it2));
- }
+ Scalar dist2_old = __dist2Vector[cmp];
+ __dist2Vector[cmp] = squ(_nodes[__index]._vector[cmp] - __vector[cmp]);
+ Scalar dist2Minimum = __dist2Minimum + __dist2Vector[cmp] - dist2_old;
+ if(__out->size() < __nearestNumber ||
+ !(__out->top()._dist2 < dist2Minimum)){
+ query(__vector, __nearestNumber, __answerCompare,
+ that_side, __depth + 1,
+ __dist2Vector, dist2Minimum,
+ __out);
}
- if(size > k){
- for(int i = size - k; i--; ){
- ret->pop_back();
- }
- size = k;
- }
- dist2_v[cmp] = dist2_old;
- return size;
- }
- template<class Keys, class Key, class Value>
- inline ssize_t KD_Tree<Keys, Key, Value>::build(ssize_t beg,
- ssize_t end,
- std::vector<size_t>* orders,
- int depth){
- if(beg > end){
- return NIL;
+ __dist2Vector[cmp] = dist2_old;
+ }
+ ////////////////////////////////////////////////////////////////////
+ // **# build() #** //
+ ////////////////////////////////////////////////////////////////////
+ template<class Vector, class Scalar>
+ inline ssize_t
+ KD_Tree<Vector, Scalar>::build(ssize_t __beg,
+ ssize_t __end,
+ std::vector<size_t>* __orders,
+ int __depth){
+ if(__beg > __end) return _NIL;
+ size_t tmp_order = _dimension;
+ size_t which_side = _dimension + 1;
+ ssize_t mid = (__beg + __end) / 2;
+ size_t cmp = __depth % _dimension;
+ for(ssize_t i = __beg; i <= mid; i++){
+ __orders[which_side][__orders[cmp][i]] = 0;
}
- ssize_t mid = (beg + end) / 2;
- size_t cmp = depth % 2;
- std::set<size_t> right;
- for(ssize_t i = mid + 1; i <= end; i++){
- right.insert(orders[cmp][i]);
+ for(ssize_t i = mid + 1; i <= __end; i++){
+ __orders[which_side][__orders[cmp][i]] = 1;
}
- for(int i = 0; i < dimension; i++){
+ for(int i = 0; i < _dimension; i++){
if(i == cmp) continue;
- size_t aa = beg, bb = mid + 1;
- for(int j = beg; j <= end; j++){
- if(orders[i][j] == orders[cmp][mid]){
- orders[dimension][mid] = orders[i][j];
- }else if(right.find(orders[i][j]) != right.end()){
- orders[dimension][bb++] = orders[i][j];
+ size_t left = __beg, right = mid + 1;
+ for(int j = __beg; j <= __end; j++){
+ size_t ask = __orders[i][j];
+ if(ask == __orders[cmp][mid]){
+ __orders[tmp_order][mid] = ask;
+ }else if(__orders[which_side][ask] == 1){
+ __orders[tmp_order][right++] = ask;
}else{
- orders[dimension][aa++] = orders[i][j];
+ __orders[tmp_order][left++] = ask;
}
}
- for(int j = beg; j <= end; j++){
- orders[i][j] = orders[dimension][j];
+ for(int j = __beg; j <= __end; j++){
+ __orders[i][j] = __orders[tmp_order][j];
}
}
- nodes[orders[cmp][mid]].lChild = build(beg, mid - 1, orders, depth + 1);
- nodes[orders[cmp][mid]].rChild = build(mid + 1, end, orders, depth + 1);
- return orders[cmp][mid];
- }
- template<class Keys, class Key, class Value>
- inline KD_Tree<Keys, Key, Value>::KD_Tree():
- NIL(-1), root(NIL), needRebuild(false), dimension(1){ }
- template<class Keys, class Key, class Value>
- inline KD_Tree<Keys, Key, Value>::KD_Tree(size_t _dimension):
- NIL(-1), root(NIL), needRebuild(false), dimension(_dimension){ }
- template<class Keys, class Key, class Value>
- inline KD_Tree<Keys, Key, Value>::~KD_Tree(){ }
- template<class Keys, class Key, class Value>
- inline void KD_Tree<Keys, Key, Value>::insert(Keys const& key, Value value){
- nodes.push_back(Node(key, value, NIL, NIL));
- needRebuild = true;
- }
- template<class Keys, class Key, class Value>
- inline void KD_Tree<Keys, Key, Value>::build(){
- if(needRebuild){
- std::vector<size_t> *orders = new std::vector<size_t>[dimension + 1];
- for(int j = 0; j < dimension + 1; j++){
- orders[j].resize(nodes.size());
- }
- for(int j = 0; j < dimension; j++){
- for(size_t i = 0, I = nodes.size(); i < I; i++){
- orders[j][i] = i;
+ _nodes[__orders[cmp][mid]]._lChild=build(__beg,mid-1,__orders,__depth+1);
+ _nodes[__orders[cmp][mid]]._rChild=build(mid+1,__end,__orders,__depth+1);
+ return __orders[cmp][mid];
+ }
+ ////////////////////////////////////////////////////////////////////
+ // **# constructures/destructures #** //
+ ////////////////////////////////////////////////////////////////////
+ template<class Vector, class Scalar>
+ inline
+ KD_Tree<Vector, Scalar>::KD_Tree():
+ _NIL(-1), _root(_NIL), _needRebuild(false), _dimension(1){
+ }
+ template<class Vector, class Scalar>
+ inline
+ KD_Tree<Vector, Scalar>::KD_Tree(size_t __dimension):
+ _NIL(-1), _root(_NIL), _needRebuild(false), _dimension(__dimension){
+ }
+ template<class Vector, class Scalar>
+ inline
+ KD_Tree<Vector, Scalar>::~KD_Tree(){
+ }
+ ////////////////////////////////////////////////////////////////////
+ // **# insert, build #** //
+ ////////////////////////////////////////////////////////////////////
+ template<class Vector, class Scalar>
+ inline void
+ KD_Tree<Vector, Scalar>::insert(Vector const& __vector){
+ _nodes.push_back(Node(__vector, _NIL, _NIL));
+ _needRebuild = true;
+ }
+ template<class Vector, class Scalar>
+ inline bool
+ KD_Tree<Vector, Scalar>::erase(Vector const& __vector){
+ for(size_t i = 0, I = _nodes.size(); i < I; i++){
+ if(_nodes[i] == __vector){
+ if(i != I - 1){
+ std::swap(_nodes[i], _nodes[I - 1]);
}
- std::sort(orders[j].begin(), orders[j].end(), Sorter(nodes, j));
+ _needRebuild = true;
+ return true;
}
- root = build(0, (ssize_t)nodes.size() - 1, orders, 0);
- needRebuild = false;
- delete [] orders;
}
+ return false;
}
- template<class Keys, class Key, class Value>
- inline Value KD_Tree<Keys, Key, Value>::query(Keys const& point, int k) const{
- ((KD_Tree*)this)->build();
- KD_Tree::AnswerList ret;
- std::vector<Key> tmp(dimension, Key(0));
- query(point, k, root, 0, tmp, Key(0), &ret);
- return (*(ret.rbegin())).node.value;
- }
- template<class Keys, class Key, class Value>
- inline typename KD_Tree<Keys, Key, Value>::Values
- KD_Tree<Keys, Key, Value>::rangeQuery(Keys const& point, int k) const{
+ template<class Vector, class Scalar>
+ inline void
+ KD_Tree<Vector, Scalar>::build(){
+ if(_needRebuild){
+ forceBuild();
+ }
+ }
+ template<class Vector, class Scalar>
+ inline void
+ KD_Tree<Vector, Scalar>::forceBuild(){
+ std::vector<size_t> *orders = new std::vector<size_t>[_dimension + 2];
+ for(int j = 0; j < _dimension + 2; j++){
+ orders[j].resize(_nodes.size());
+ }
+ for(int j = 0; j < _dimension; j++){
+ for(size_t i = 0, I = _nodes.size(); i < I; i++){
+ orders[j][i] = i;
+ }
+ std::sort(orders[j].begin(), orders[j].end(), Sorter(&_nodes, j));
+ }
+ _root = build(0, (ssize_t)_nodes.size() - 1, orders, 0);
+ delete [] orders;
+ _needRebuild = false;
+ }
+ ////////////////////////////////////////////////////////////////////
+ // **# query #** //
+ ////////////////////////////////////////////////////////////////////
+ template<class Vector, class Scalar>
+ inline typename KD_Tree<Vector, Scalar>::Vectors
+ KD_Tree<Vector, Scalar>::query(Vector const& __vector,
+ size_t __nearestNumber,
+ bool __compareWholeVector) const{
((KD_Tree*)this)->build();
- KD_Tree::AnswerList ret;
- std::vector<Key> tmp(dimension, Key(0));
- query(point, k, root, 0, tmp, Key(0), &ret);
- KD_Tree::Values ret_val(ret.size());
- int i = 0;
- for(KD_Tree::AnswerListIterator it = ret.begin(); it != ret.end(); it++){
- ret_val[i++] = (*it).node.value;
+ AnswerCompare answer_compare(&_nodes, __compareWholeVector);
+ Answers answer_set(answer_compare);
+ std::vector<Scalar> tmp(_dimension, 0);
+ query(__vector, __nearestNumber,
+ answer_compare,
+ _root, 0,
+ tmp, Scalar(0),
+ &answer_set);
+ Vectors ret(answer_set.size());
+ for(int i = (ssize_t)answer_set.size() - 1; i >= 0; i--){
+ ret[i] = _nodes[answer_set.top()._index]._vector;
+ answer_set.pop();
}
- return ret_val;
+ return ret;
}
- template<class Keys, class Key, class Value>
- inline void KD_Tree<Keys, Key, Value>::clear(){
- root = NIL;
- nodes.clear();
- needRebuild = false;
+ ////////////////////////////////////////////////////////////////////
+ // **# clear, reset #** //
+ ////////////////////////////////////////////////////////////////////
+ template<class Vector, class Scalar>
+ inline void
+ KD_Tree<Vector, Scalar>::clear(){
+ _root = _NIL;
+ _nodes.clear();
+ _needRebuild = false;
}
- template<class Keys, class Key, class Value>
- inline void KD_Tree<Keys, Key, Value>::reset(size_t _dimension){
+ template<class Vector, class Scalar>
+ inline void
+ KD_Tree<Vector, Scalar>::reset(size_t __dimension){
clear();
- dimension = _dimension;
+ _dimension = __dimension;
}
}