27 #ifndef RL_MATH_GNATNEARESTNEIGHBORS_H
28 #define RL_MATH_GNATNEARESTNEIGHBORS_H
34 #include <type_traits>
37 #include <boost/optional.hpp>
// Template parameter list: the structure is parameterised on the metric type.
// NOTE(review): the class header that follows this template line is not part of this excerpt.
52 template<
typename MetricT>
// Value is the metric's nested Value type; it is the element type of the
// vectors returned by data() and of the query passed to nearest().
73 typedef typename MetricT::Value
Value;
// Two member templates taking an InputIterator type.
// NOTE(review): their declaration lines are not visible in this excerpt;
// presumably these are range constructors — confirm against the full file.
103 template<
typename InputIterator>
// True when the root holds more values than nodeDataMax and more than its own
// degree. NOTE(review): the guarded statement is not shown; presumably the root is split.
115 if (this->
root.
data.size() > this->nodeDataMax && this->root.data.size() > this->root.degree)
// Second InputIterator template with the identical root overflow test.
121 template<
typename InputIterator>
133 if (this->
root.
data.size() > this->nodeDataMax && this->root.data.size() > this->root.degree)
// Returns the stored values as a vector, by value.
// NOTE(review): the statements that fill and return the vector are not visible;
// presumably it delegates to the data(node, data) helper starting at the root.
152 ::std::vector<Value>
data()
const
// Local result vector; it shares its name with the member function.
154 ::std::vector<Value>
data;
// Range insertion over [first, last).
// NOTE(review): only fragments of the body are visible in this excerpt.
190 template<
typename InputIterator>
191 void insert(InputIterator first, InputIterator last)
// Root overflow test: more values than nodeDataMax and more than the root's
// degree. The same condition is applied to leaf nodes after appending a value.
// NOTE(review): the guarded statement is not shown; presumably the root is split here.
197 if (this->
root.
data.size() > this->nodeDataMax && this->root.data.size() > this->root.degree)
// Walks the input range one element at a time; the loop body is not shown.
206 for (InputIterator i = first; i != last; ++i)
// k-nearest-neighbour query.
//
// query:  value to search around.
// k:      neighbour count, forwarded to search() by address.
// sorted: forwarded unchanged to search(); defaults to true.
//
// Returns the Neighbor vector produced by search().
213 ::std::vector<Neighbor>
nearest(
const Value& query, const ::std::size_t& k,
const bool& sorted =
true)
const
// A non-null k together with a null radius is passed to search().
215 return this->
search(query, &k,
nullptr, sorted);
// Takes a seed value of the std::mt19937 result type.
// NOTE(review): the body is not visible; presumably it reseeds this->generator,
// the engine choose() draws its first center from — confirm.
229 void seed(const ::std::mt19937::result_type& value)
// A pending subtree: a distance paired with a pointer to the node it was measured for.
// The per-node search pushes (distance from the query to the child's pivot, &child).
285 typedef ::std::pair<Distance, const Node*>
Branch;
// Orders by (distance minus the node's own max entry), using '>'.
// NOTE(review): the enclosing declaration is not visible; presumably this is
// BranchCompare::operator(). With the std heap algorithms a '>' comparator
// keeps the smallest such key at the front.
291 return lhs.first - lhs.second->max[lhs.second->index] > rhs.first - rhs.second->max[rhs.second->index];
// Orders by the first pair element, using '<'.
// NOTE(review): presumably NeighborCompare::operator(); a '<' comparator keeps
// the largest distance at the front of a std heap, which matches the
// `distance < neighbors.front().first` tests in the per-node search.
299 return lhs.first < rhs.first;
// Node constructor.
//
// index, siblings, degree: NOTE(review): the initializer list is not visible;
//           presumably index is this node's position among its parent's
//           children, as split() passes the loop counter i — confirm.
// capacity: data capacity hint; split() passes nodeDataMax.
// removed:  defaults to false.
305 Node(const ::std::size_t&
index, const ::std::size_t& siblings, const ::std::size_t&
degree, const ::std::size_t& capacity,
const bool&
removed =
false) :
// One slot beyond capacity is reserved. Leaf insertion appends first and tests
// for overflow afterwards, so a node can briefly hold one value too many.
316 this->
data.reserve(capacity + 1);
// Overload that additionally takes an iterator range [first, last).
// NOTE(review): how the range is consumed is not visible in this excerpt.
319 template<
typename InputIterator>
320 Node(InputIterator first, InputIterator last, const ::std::size_t&
index, const ::std::size_t& siblings, const ::std::size_t&
degree, const ::std::size_t& capacity,
const bool&
removed =
false) :
// Same capacity + 1 reservation as the overload above.
331 this->
data.reserve(capacity + 1);
// Per-sibling upper distance bound. Insertion and split raise entry [j] to the
// largest distance seen between this node's pivot and a value routed to
// child j of the same parent.
364 ::std::vector<Distance>
max;
// Per-sibling lower distance bound, maintained alongside max with ::std::min.
366 ::std::vector<Distance>
min;
// Picks node.degree center indices into node.data and fills the distance table.
//
// node:      node whose data the centers are drawn from.
// centers:   output; centers[i] is an index into node.data.
// distances: output; distances[i][j] is the metric distance between center i
//            and node.data[j].
373 void choose(
const Node& node, ::std::vector< ::std::size_t>& centers, ::std::vector< ::std::vector<Distance>>& distances)
// NOTE(review): k - 1 below wraps around if node.degree is 0, and
// node.data.size() - 1 wraps if data is empty; assumes callers guarantee both
// are at least 1 — confirm.
375 ::std::size_t k = node.degree;
// min[j]: smallest distance from data element j to any center chosen so far.
376 ::std::vector<Distance> min(node.data.size(), ::std::numeric_limits<Distance>::infinity());
// The first center is drawn uniformly from the valid indices (bounds are inclusive).
378 ::std::uniform_int_distribution< ::std::size_t> distribution(0, node.data.size() - 1);
379 centers[0] = distribution(this->
generator);
// One pass for every center except the last.
381 for (::std::size_t i = 0; i < k - 1; ++i)
385 for (::std::size_t j = 0; j < node.data.size(); ++j)
// A center's distance to itself is set to 0 without calling the metric.
387 distances[i][j] = j != centers[i] ? this->
metric(node.data[j], node.data[centers[i]]) : 0;
388 min[j] = ::std::min(min[j], distances[i][j]);
// NOTE(review): the lines that derive centers[i + 1] from min are not part of
// this excerpt; presumably the element farthest from all chosen centers is taken.
// The last center only needs its distance row filled in.
398 for (::std::size_t j = 0; j < node.data.size(); ++j)
400 distances[k - 1][j] = this->
metric(node.data[j], node.data[centers[k - 1]]);
// Appends the values held by `node` to the output vector `data`.
//
// node: node to read from.
// data: output vector; existing contents are kept.
404 void data(
const Node& node, ::std::vector<Value>&
data)
const
// The node's own values come first.
406 data.insert(
data.end(), node.data.begin(), node.data.end());
// Each child's pivot is a stored value too, so it is appended as well.
// NOTE(review): the rest of the loop body is not shown; presumably it recurses into the child.
408 for (::std::size_t i = 0; i < node.children.size(); ++i)
410 data.push_back(node.children[i].pivot);
// Body of the insertion routine for a single value.
// NOTE(review): the function header is not visible; the recursive call at the
// end names it push(node, value).
// Leaf case: the value is stored directly in the node.
417 if (node.children.empty())
419 node.data.push_back(value);
// Overflow test after appending: more values than nodeDataMax and more than
// the node's degree. NOTE(review): the guarded statement is not shown; presumably split(node).
421 if (node.data.size() > this->nodeDataMax && node.data.size() > node.degree)
// Inner-node case: measure the value against every child's pivot.
428 ::std::vector<Distance> distances(node.children.size());
429 ::std::size_t index = 0;
430 Distance min = ::std::numeric_limits<Distance>::infinity();
432 for (::std::size_t i = 0; i < node.children.size(); ++i)
434 distances[i] = this->
metric(value, node.children[i].pivot);
// NOTE(review): the guarded statements are not shown; presumably they record
// the closest pivot in min and index.
436 if (distances[i] < min)
// Every child i widens its [min, max] range for child `index` so that it
// covers the distance from child i's pivot to the new value.
443 for (::std::size_t i = 0; i < node.children.size(); ++i)
445 node.children[i].max[index] = ::std::max(node.children[i].max[index], distances[i]);
446 node.children[i].min[index] = ::std::min(node.children[i].min[index], distances[i]);
// Descend into the selected child.
449 this->
push(node.children[index], value);
// Body of the top-level search driver.
// NOTE(review): the header is not visible; nearest() calls it as
// search(query, &k, nullptr, sorted). The declarations of the local `checks`
// and `distance` used below are not part of this excerpt either.
455 ::std::vector<Neighbor> neighbors;
// Reserve k slots for a k-nearest query, otherwise this->values slots.
// NOTE(review): this->values is presumably the number of stored values — confirm.
456 neighbors.reserve(
nullptr != k ? *k : this->
values);
// Heap of subtrees that still have to be visited.
460 ::std::vector<Branch> branches;
// Continue while branches remain and either no check limit is set or the
// running check counter is still below it.
463 while (!branches.empty() && (!this->checks || checks < this->
checks))
// Copy the front branch, then move it to the back of the heap range.
// NOTE(review): the matching pop_back is not shown in this excerpt.
465 Branch branch = branches.front();
466 ::std::pop_heap(branches.begin(), branches.end(), BranchCompare());
// Entered when there is no k, or when k neighbours have already been collected.
469 if (
nullptr == k || *k == neighbors.size())
// True when the interval [branch.first - distance, branch.first + distance]
// lies entirely outside the branch node's own [min, max] range.
// NOTE(review): the guarded statement is not shown; presumably the branch is skipped.
473 if (branch.first -
distance > branch.second->max[branch.second->index] ||
474 branch.first + distance < branch.second->min[branch.second->index])
// Turns the neighbour heap into a range sorted by NeighborCompare.
// NOTE(review): presumably guarded by the `sorted` flag on a line not shown.
485 ::std::sort_heap(neighbors.begin(), neighbors.end(), NeighborCompare());
// Requests release of any unused reserved capacity.
490 neighbors.shrink_to_fit();
// Per-node search step: examines one node, updates the neighbour heap and
// queues children that may still contain results.
//
// node:      node to examine.
// query:     value being searched for.
// k:         neighbour count for a k-nearest query, or null.
// radius:    NOTE(review): not referenced in the visible lines; presumably the
//            bound for a radius query, or null — confirm.
// branches:  in/out heap of subtrees still to visit (ordered by BranchCompare).
// neighbors: in/out heap of results so far (ordered by NeighborCompare).
// checks:    in/out running counter compared against this->checks.
//
// NOTE(review): the local `distance` used below is declared on lines that are
// not part of this excerpt.
496 void search(
const Node& node,
const Value& query, const ::std::size_t* k,
const Distance*
radius, ::std::vector<Branch>& branches, ::std::vector<Neighbor>& neighbors, ::std::size_t&
checks)
const
// Leaf: test every stored value.
498 if (node.children.empty())
500 for (::std::size_t i = 0; i < node.data.size(); ++i)
// Accept when there is no k, fewer than k neighbours have been collected, or
// this value beats the neighbour at the heap front.
504 if (
nullptr == k || neighbors.size() < *k ||
distance < neighbors.front().first)
// The heap is already at k entries: drop the front entry to make room.
508 if (
nullptr != k && *k == neighbors.size())
510 ::std::pop_heap(neighbors.begin(), neighbors.end(), NeighborCompare());
511 neighbors.pop_back();
// Old MSVC (< 1800) and GCC 4.x before 4.8 take the push_back/make_pair path.
// The #else/#endif lines fall in the gaps of this excerpt.
514 #if (defined(_MSC_VER) && _MSC_VER < 1800) || (defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ < 8)
515 neighbors.push_back(::std::make_pair(
distance, node.data[i]));
517 neighbors.emplace_back(::std::piecewise_construct, ::std::forward_as_tuple(
distance), ::std::forward_as_tuple(node.data[i]));
// Restore the heap property after appending.
519 ::std::push_heap(neighbors.begin(), neighbors.end(), NeighborCompare());
// Count this check; true once a configured limit has been exceeded.
// NOTE(review): the guarded statement is not shown; presumably an early return.
523 if (this->checks && ++
checks > this->checks)
// Inner node: measure the query against the child pivots and prune siblings.
531 ::std::vector<Distance> distances(node.children.size());
// Local pruning flags, one per child. These are distinct from the per-node
// `removed` member tested below.
532 ::std::vector<bool> removed(node.children.size(),
false);
534 for (::std::size_t i = 0; i < node.children.size(); ++i)
538 distances[i] = this->
metric(query, node.children[i].pivot);
// The pivot is only offered as a neighbour when the child's `removed` member
// is false; the acceptance rule is the same as for leaf values.
540 if (!node.children[i].removed)
542 if (
nullptr == k || neighbors.size() < *k || distances[i] < neighbors.front().first)
546 if (
nullptr != k && *k == neighbors.size())
548 ::std::pop_heap(neighbors.begin(), neighbors.end(), NeighborCompare());
549 neighbors.pop_back();
552 #if (defined(_MSC_VER) && _MSC_VER < 1800) || (defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ < 8)
553 neighbors.push_back(::std::make_pair(distances[i], node.children[i].pivot));
555 neighbors.emplace_back(::std::piecewise_construct, ::std::forward_as_tuple(distances[i]), ::std::forward_as_tuple(node.children[i].pivot));
557 ::std::push_heap(neighbors.begin(), neighbors.end(), NeighborCompare());
// Entered when there is no k, or when k neighbours have already been collected.
562 if (
nullptr == k || *k == neighbors.size())
566 for (::std::size_t j = 0; j < node.children.size(); ++j)
568 if (i != j && !removed[j])
// True when [distances[i] - distance, distances[i] + distance] does not
// overlap the range child i has recorded for sibling j.
// NOTE(review): the guarded statement is not shown; presumably removed[j] = true.
570 if (distances[i] -
distance > node.children[i].max[j] ||
571 distances[i] +
distance < node.children[i].min[j])
// Same check accounting as in the leaf branch.
579 if (this->checks && ++
checks > this->checks)
// Queue the children whose own range (entry [i] of child i) overlaps the
// interval around the query.
// NOTE(review): a filter on removed[i] may sit on a line not shown.
586 for (::std::size_t i = 0; i < node.children.size(); ++i)
592 if (distances[i] -
distance <= node.children[i].max[i] &&
593 distances[i] +
distance >= node.children[i].min[i])
// Old MSVC takes the push_back path; the #else/#endif lines are not shown.
595 #if defined(_MSC_VER) && _MSC_VER < 1800
596 branches.push_back(::std::make_pair(distances[i], &node.children[i]));
598 branches.emplace_back(distances[i], &node.children[i]);
600 ::std::push_heap(branches.begin(), branches.end(), BranchCompare());
// Body of the node-splitting routine: turns an over-full node into an inner
// node with node.degree children.
// NOTE(review): the function header is not visible; the recursive call at the
// end names it split(node).
// distances[c][v] receives the distance between center c and node.data[v].
610 ::std::vector< ::std::vector<Distance>> distances(node.degree, ::std::vector<Distance>(node.data.size()));
611 ::std::vector< ::std::size_t> centers(node.degree);
612 this->
choose(node, centers, distances);
// Create one child per center and move the center value into it as the pivot.
614 for (::std::size_t i = 0; i < centers.size(); ++i)
// Old MSVC takes the push_back path; the #else/#endif lines are not shown.
616 #if defined(_MSC_VER) && _MSC_VER < 1800
617 node.children.push_back(Node(i, node.degree - 1, this->nodeDegree, this->nodeDataMax));
619 node.children.emplace_back(i, node.degree - 1, this->nodeDegree, this->nodeDataMax);
// node.data[centers[i]] is left in a moved-from state by this line.
621 node.children[i].pivot = ::std::move(node.data[centers[i]]);
// Distribute every value of the node to one child.
624 for (::std::size_t i = 0; i < node.data.size(); ++i)
626 ::std::size_t index = 0;
627 Distance min = ::std::numeric_limits<Distance>::infinity();
// NOTE(review): the body of this loop is not shown; presumably it selects the
// closest center into index and min.
629 for (::std::size_t j = 0; j < centers.size(); ++j)
// Every child j widens its range for child `index` to cover the distance from
// center j to value i.
640 for (::std::size_t j = 0; j < centers.size(); ++j)
644 node.children[j].max[index] = ::std::max(node.children[j].max[index], distances[j][i]);
645 node.children[j].min[index] = ::std::min(node.children[j].min[index], distances[j][i]);
// The center itself was already moved into the pivot above, so it is skipped.
649 if (i != centers[index])
651 node.children[index].data.push_back(::std::move(node.data[i]));
655 for (::std::size_t i = 0; i < node.children.size(); ++i)
// Child degree is proportional to its share of the parent's values, clamped
// to [nodeDegreeMin, nodeDegreeMax].
657 node.children[i].degree = ::std::min(::std::max(this->
nodeDegree * node.children[i].data.size() / node.data.size(), this->nodeDegreeMin), this->nodeDegreeMax);
// A child that received no values gets value-initialised bounds for its own entry.
659 if (node.children[i].data.empty())
661 node.children[i].max[i] =
Distance();
662 node.children[i].min[i] =
Distance();
// The size is captured here because it is still needed for the OpenMP `if` clause.
// NOTE(review): node.data is presumably cleared on a line not shown.
666 ::std::size_t
size = node.data.size();
669 node.data.shrink_to_fit();
// Children are split in parallel only when the node held more than twice nodeDataMax values.
671 #pragma omp parallel for if (size > 2 * this->nodeDataMax)
// OpenMP versions before 3.0 (200805) need a signed loop variable, hence the ptrdiff_t variant.
672 #if defined(_OPENMP) && _OPENMP < 200805
673 for (::std::ptrdiff_t i = 0; i < node.children.size(); ++i)
675 for (::std::size_t i = 0; i < node.children.size(); ++i)
// Recurse into any child that is itself over-full.
678 if (node.children[i].data.size() > this->nodeDataMax && node.children[i].data.size() > node.children[i].degree)
680 this->
split(node.children[i]);
// Optional limit on the number of checks per search. When unset, the search
// loops impose no limit (they test `!this->checks` / `this->checks &&` before
// comparing their running counter against it).
685 ::boost::optional< ::std::size_t>
checks;
706 #endif // RL_MATH_GNATNEARESTNEIGHBORS_H