27 #ifndef RL_MATH_KDTREEBOUNDINGBOXNEARESTNEIGHBORS_H
28 #define RL_MATH_KDTREEBOUNDINGBOXNEARESTNEIGHBORS_H
37 #include <type_traits>
40 #include <boost/optional.hpp>
55 template<
typename MetricT>
76 typedef typename MetricT::Size
Size;
78 typedef typename MetricT::Value
Value;
102 template<
typename InputIterator>
111 this->
insert(first, last);
114 template<
typename InputIterator>
123 this->
insert(first, last);
139 ::std::vector<Value>
data()
const
141 ::std::vector<Value>
data;
162 template<
typename InputIterator>
163 void insert(InputIterator first, InputIterator last)
186 for (InputIterator i = first; i != last; ++i)
193 ::std::vector<Neighbor>
nearest(
const Value& query, const ::std::size_t& k,
const bool& sorted =
true)
const
195 return this->
search(query, &k,
nullptr, sorted);
207 for (::std::size_t i = 0; i < this->
boundingBox.size(); ++i)
216 for (::std::size_t i = 0; i < this->
boundingBox.size(); ++i)
252 swap(this->mean, other.mean);
254 swap(this->samples, other.samples);
257 swap(this->var, other.var);
301 return *(begin(other) + this->
index) < this->
value;
319 return *(begin(lhs) + this->
index) < *(begin(rhs) + this->
index);
336 return lhs.first < rhs.first;
368 ::std::array< ::std::unique_ptr<Node>, 2>
children;
377 template<
typename InputIterator>
382 for (::std::size_t i = 0; i <
boundingBox.size(); ++i)
384 Distance value = *(begin(*first) + i);
389 for (InputIterator i = first + 1; i != last; ++i)
391 auto start = begin(*i);
393 for (::std::size_t j = 0; j <
boundingBox.size(); ++j)
402 void data(
const Node& node, ::std::vector<Value>&
data)
const
404 data.insert(
data.end(), node.data.begin(), node.data.end());
406 for (::std::size_t i = 0; i < node.children.size(); ++i)
408 this->
data(*node.children[i],
data);
412 template<
typename InputIterator>
416 node.index = cut.index;
417 InputIterator split = ::std::partition(first, last, cut);
421 for (::std::size_t i = 0; i < 2; ++i)
423 InputIterator begin = 0 == i ? first : split;
424 InputIterator end = 0 == i ? split : last;
429 boundingBoxes[i][cut.index].high = cut.value;
432 boundingBoxes[i][cut.index].low = cut.value;
438 #if __cplusplus > 201103L || _MSC_VER >= 1800
439 node.children[i] = ::std::make_unique<Node>();
441 node.children[i].reset(
new Node());
446 this->
divide(*node.children[i], boundingBoxes[i], begin, end);
450 node.children[i]->data.insert(node.children[i]->data.end(), begin, end);
455 node.interval.low = boundingBoxes[0][cut.index].high;
456 node.interval.high = boundingBoxes[1][cut.index].low;
458 for (::std::size_t i = 0; i <
boundingBox.size(); ++i)
460 boundingBox[i].low = ::std::min(boundingBoxes[0][i].low, boundingBoxes[1][i].low);
461 boundingBox[i].high = ::std::max(boundingBoxes[0][i].high, boundingBoxes[1][i].high);
470 if (
nullptr == node.children[0] &&
nullptr == node.children[1])
472 node.data.push_back(value);
474 if (node.data.size() > this->nodeDataMax)
480 node.data.shrink_to_fit();
485 Distance tmp = *(begin(value) + node.index);
486 Distance diff0 = tmp - node.interval.low;
487 Distance diff1 = tmp - node.interval.high;
492 this->
push(*node.children[0], value);
493 node.interval.low = ::std::max(node.interval.low, tmp);
497 this->
push(*node.children[1], value);
498 node.interval.high = ::std::min(node.interval.high, tmp);
508 ::std::vector<Neighbor> neighbors;
515 neighbors.reserve(
nullptr != k ? *k : this->
values);
520 ::std::vector<Distance> sidedist(
size(query),
Distance());
522 for (::std::size_t i = 0; i < sidedist.size(); ++i)
524 Distance value = *(begin(query) + i);
529 mindist += sidedist[i];
535 mindist += sidedist[i];
539 ::std::vector<Branch> branches;
543 while (!branches.empty() && (!this->checks || checks < this->
checks))
545 Branch branch = branches.front();
546 ::std::pop_heap(branches.begin(), branches.end(), BranchCompare());
548 this->
search(*branch.node, query, k,
radius, branches, neighbors,
checks, branch.dist, branch.sidedist);
553 ::std::sort_heap(neighbors.begin(), neighbors.end(), NeighborCompare());
558 neighbors.shrink_to_fit();
564 void search(
const Node& node,
const Value& query, const ::std::size_t* k,
const Distance*
radius, ::std::vector<Branch>& branches, ::std::vector<Neighbor>& neighbors, ::std::size_t&
checks,
const Distance& mindist, const ::std::vector<Distance>& sidedist)
const
568 if (
nullptr == node.children[0] &&
nullptr == node.children[1])
570 for (::std::size_t i = 0; i < node.data.size(); ++i)
574 if (
nullptr == k || neighbors.size() < *k ||
distance < neighbors.front().first)
578 if (
nullptr != k && *k == neighbors.size())
580 ::std::pop_heap(neighbors.begin(), neighbors.end(), NeighborCompare());
581 neighbors.pop_back();
584 #if (defined(_MSC_VER) && _MSC_VER < 1800) || (defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ < 8)
585 neighbors.push_back(::std::make_pair(
distance, node.data[i]));
587 neighbors.emplace_back(::std::piecewise_construct, ::std::forward_as_tuple(
distance), ::std::forward_as_tuple(node.data[i]));
589 ::std::push_heap(neighbors.begin(), neighbors.end(), NeighborCompare());
593 if (this->checks && ++
checks > this->checks)
601 Distance value = *(begin(query) + node.index);
602 Distance diff0 = value - node.interval.low;
603 Distance diff1 = value - node.interval.high;
606 ::std::size_t best = diff < 0 ? 0 : 1;
607 ::std::size_t worst = diff < 0 ? 1 : 0;
609 this->
search(*node.children[best], query, k,
radius, branches, neighbors,
checks, mindist, sidedist);
611 Distance cutdist = this->
metric(value, diff < 0 ? node.interval.high : node.interval.low, node.index);
612 Distance newdist = mindist - sidedist[node.index] + cutdist;
614 if (
nullptr == k || neighbors.size() < *k || newdist <= neighbors.front().first)
616 ::std::vector<Distance> newsidedist(sidedist);
617 newsidedist[node.index] = cutdist;
621 this->
search(*node.children[worst], query, k,
radius, branches, neighbors,
checks, newdist, newsidedist);
625 #if defined(_MSC_VER) && _MSC_VER < 1800
626 branches.push_back(Branch(newdist, newsidedist, node.children[worst].get()));
628 branches.emplace_back(newdist, newsidedist, node.children[worst].get());
630 ::std::push_heap(branches.begin(), branches.end(), BranchCompare());
636 template<
typename InputIterator>
647 for (::std::size_t i = 1; i <
boundingBox.size(); ++i)
659 ::std::pair<InputIterator, InputIterator> minmax = ::std::minmax_element(first, last, IndexCompare(cut.index));
660 Distance min = *(begin(*minmax.first) + cut.index);
661 Distance max = *(begin(*minmax.second) + cut.index);
663 cut.value = (max + min) / 2;
666 Size index = cut.index;
668 for (::std::size_t i = 0; i <
boundingBox.size(); ++i)
676 minmax = ::std::minmax_element(first, last, IndexCompare(i));
677 min = *(begin(*minmax.first) + i);
678 max = *(begin(*minmax.second) + i);
685 cut.value = (max + min) / 2;
696 ::boost::optional< ::std::size_t>
checks;
709 #endif // RL_MATH_KDTREEBOUNDINGBOXNEARESTNEIGHBORS_H