#ifndef OPENCV_FLANN_KDTREE_SINGLE_INDEX_H_
#define OPENCV_FLANN_KDTREE_SINGLE_INDEX_H_
#include <algorithm>
#include <map>
#include <cassert>
#include <cstring>
#include "general.h"
#include "nn_index.h"
#include "matrix.h"
#include "result_set.h"
#include "heap.h"
#include "allocator.h"
#include "random.h"
#include "saving.h"
namespace cvflann
{
struct KDTreeSingleIndexParams : public IndexParams
{
KDTreeSingleIndexParams(int leaf_max_size = 10, bool reorder = true, int dim = -1)
{
(*this)["algorithm"] = FLANN_INDEX_KDTREE_SINGLE;
(*this)["leaf_max_size"] = leaf_max_size;
(*this)["reorder"] = reorder;
(*this)["dim"] = dim;
}
};
/**
 * Nearest-neighbour index that organises the dataset in one kd-tree.
 *
 * The tree is built by recursively splitting the permutable index array
 * vind_ (see divideTree) and is searched depth-first with bounding-box based
 * pruning (see searchLevel). When the "reorder" parameter is set, the points
 * are copied into data_ in leaf order so that a leaf covers a contiguous row
 * range of data_.
 *
 * Ownership rule used throughout the class: data_.data is owned by the index
 * (and released with delete[]) exactly when reorder_ is true; otherwise data_
 * is a copy of the dataset_ matrix header.
 */
template <typename Distance>
class KDTreeSingleIndex : public NNIndex<Distance>
{
public:
    typedef typename Distance::ElementType ElementType;
    typedef typename Distance::ResultType DistanceType;

    /**
     * Stores the dataset header and reads the build parameters.
     *
     * @param inputData dataset matrix, one point per row
     * @param params    parameters; "dim" (>0 overrides the column count),
     *                  "leaf_max_size" and "reorder" are read here
     * @param d         distance functor
     */
    KDTreeSingleIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KDTreeSingleIndexParams(),
                      Distance d = Distance() ) :
        dataset_(inputData), index_params_(params), distance_(d)
    {
        size_ = dataset_.rows;
        dim_ = dataset_.cols;
        // No tree exists until buildIndex()/loadIndex() has run.
        root_node_ = NULL;

        int dim_param = get_param(params,"dim",-1);
        if (dim_param>0) dim_ = dim_param;

        leaf_max_size_ = get_param(params,"leaf_max_size",10);
        // A leaf must be allowed to hold at least one point; with a smaller
        // limit divideTree() would keep splitting a one-point range forever.
        if (leaf_max_size_<1) {
            leaf_max_size_ = 1;
            index_params_["leaf_max_size"] = leaf_max_size_;
        }
        reorder_ = get_param(params,"reorder",true);

        // Identity permutation of the row indices; permuted while building.
        vind_.resize(size_);
        for (size_t i = 0; i < size_; i++) {
            vind_[i] = (int)i;
        }
    }

    // Declared but not defined in this header: copying is not supported.
    KDTreeSingleIndex(const KDTreeSingleIndex&);
    KDTreeSingleIndex& operator=(const KDTreeSingleIndex&);

    /**
     * Releases the reordered copy of the data when the index owns one.
     */
    ~KDTreeSingleIndex()
    {
        if (reorder_) delete[] data_.data;
    }

    /**
     * Builds the tree over the whole dataset and, when reorder_ is set,
     * fills data_ with the points in the order given by vind_.
     */
    void buildIndex()
    {
        computeBoundingBox(root_bbox_);
        root_node_ = divideTree(0, (int)size_, root_bbox_ );

        if (reorder_) {
            // Allocate the new buffer before releasing the old one so that a
            // failed allocation cannot leave data_ pointing at freed memory.
            ElementType* reordered = new ElementType[size_*dim_];
            delete[] data_.data;
            data_ = cvflann::Matrix<ElementType>(reordered, size_, dim_);
            for (size_t i=0; i<size_; ++i) {
                for (size_t j=0; j<dim_; ++j) {
                    data_[i][j] = dataset_[vind_[i]][j];
                }
            }
        }
        else {
            data_ = dataset_;
        }
    }

    /** @return FLANN_INDEX_KDTREE_SINGLE */
    flann_algorithm_t getType() const
    {
        return FLANN_INDEX_KDTREE_SINGLE;
    }

    /**
     * Writes the index to @p stream: header values, the permutation, the
     * reordered data (only when reorder_ is set) and the tree in pre-order.
     * The index must have been built or loaded first.
     */
    void saveIndex(FILE* stream)
    {
        assert(root_node_ != NULL);
        save_value(stream, size_);
        save_value(stream, dim_);
        save_value(stream, root_bbox_);
        save_value(stream, reorder_);
        save_value(stream, leaf_max_size_);
        save_value(stream, vind_);
        if (reorder_) {
            save_value(stream, data_);
        }
        save_tree(stream, root_node_);
    }

    /**
     * Reads an index from @p stream in the layout written by saveIndex()
     * and refreshes the stored parameters accordingly.
     */
    void loadIndex(FILE* stream)
    {
        // The values loaded below overwrite reorder_ and data_; release the
        // reordered copy owned by the current state first so it is not leaked.
        if (reorder_) {
            delete[] data_.data;
            data_.data = NULL;
        }

        load_value(stream, size_);
        load_value(stream, dim_);
        load_value(stream, root_bbox_);
        load_value(stream, reorder_);
        load_value(stream, leaf_max_size_);
        load_value(stream, vind_);
        if (reorder_) {
            load_value(stream, data_);
        }
        else {
            data_ = dataset_;
        }
        load_tree(stream, root_node_);

        index_params_["algorithm"] = getType();
        index_params_["leaf_max_size"] = leaf_max_size_;
        index_params_["reorder"] = reorder_;
    }

    /** @return number of points in the index */
    size_t size() const
    {
        return size_;
    }

    /** @return number of dimensions used per point */
    size_t veclen() const
    {
        return dim_;
    }

    /**
     * @return memory reported by the node pool (used + wasted) plus one int
     *         per dataset row for the index permutation
     */
    int usedMemory() const
    {
        return (int)(pool_.usedMemory+pool_.wastedMemory+dataset_.rows*sizeof(int));
    }

    /**
     * Runs findNeighbors() for every row of @p queries.
     *
     * @param queries query points, one per row; must have veclen() columns
     * @param indices receives the neighbour indices, one row per query
     * @param dists   receives the neighbour distances, one row per query
     * @param knn     number of neighbours requested
     * @param params  search parameters forwarded to findNeighbors()
     */
    void knnSearch(const Matrix<ElementType>& queries, Matrix<int>& indices, Matrix<DistanceType>& dists, int knn, const SearchParams& params)
    {
        assert(queries.cols == veclen());
        assert(indices.rows >= queries.rows);
        assert(dists.rows >= queries.rows);
        assert(int(indices.cols) >= knn);
        assert(int(dists.cols) >= knn);

        KNNSimpleResultSet<DistanceType> resultSet(knn);
        for (size_t i = 0; i < queries.rows; i++) {
            resultSet.init(indices[i], dists[i]);
            findNeighbors(resultSet, queries[i], params);
        }
    }

    /** @return copy of the parameters stored in the index */
    IndexParams getParameters() const
    {
        return index_params_;
    }

    /**
     * Searches the tree for neighbours of @p vec and reports them to
     * @p result. The index must have been built or loaded first.
     *
     * @param result       result set collecting the neighbours
     * @param vec          query point with dim_ components
     * @param searchParams "eps" is read; the pruning bound is scaled by 1+eps
     */
    void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams)
    {
        assert(root_node_ != NULL);
        float epsError = 1+get_param(searchParams,"eps",0.0f);

        // Per-dimension distance contributions from the query to the cell
        // currently being visited; starts with the root bounding box.
        std::vector<DistanceType> dists(dim_,0);
        DistanceType distsq = computeInitialDistances(vec, dists);
        searchLevel(result, vec, root_node_, distsq, dists, epsError);
    }

private:
    /**
     * Tree node. Leaves (both children NULL) use left/right as a half-open
     * range into vind_; inner nodes use divfeat/divlow/divhigh.
     */
    struct Node
    {
        int left, right;
        int divfeat;
        DistanceType divlow, divhigh;
        Node* child1, * child2;
    };
    typedef Node* NodePtr;

    /** Closed interval covered along one dimension. */
    struct Interval
    {
        DistanceType low, high;
    };
    typedef std::vector<Interval> BoundingBox;

    typedef BranchStruct<NodePtr, DistanceType> BranchSt;
    typedef BranchSt* Branch;

    /** Writes @p tree and its descendants to @p stream in pre-order. */
    void save_tree(FILE* stream, NodePtr tree)
    {
        save_value(stream, *tree);
        if (tree->child1!=NULL) {
            save_tree(stream, tree->child1);
        }
        if (tree->child2!=NULL) {
            save_tree(stream, tree->child2);
        }
    }

    /**
     * Reads a subtree written by save_tree(). The stored child pointers only
     * tell whether a child exists; they are replaced by freshly allocated
     * nodes during the recursion.
     */
    void load_tree(FILE* stream, NodePtr& tree)
    {
        tree = pool_.allocate<Node>();
        load_value(stream, *tree);
        if (tree->child1!=NULL) {
            load_tree(stream, tree->child1);
        }
        if (tree->child2!=NULL) {
            load_tree(stream, tree->child2);
        }
    }

    /**
     * Computes the per-dimension min/max over all dataset rows. For an empty
     * dataset every interval is set to [0, 0].
     */
    void computeBoundingBox(BoundingBox& bbox)
    {
        bbox.resize(dim_);
        if (dataset_.rows==0) {
            // No row 0 to seed the box from.
            for (size_t i=0; i<dim_; ++i) {
                bbox[i].low = (DistanceType)0;
                bbox[i].high = (DistanceType)0;
            }
            return;
        }
        for (size_t i=0; i<dim_; ++i) {
            bbox[i].low = (DistanceType)dataset_[0][i];
            bbox[i].high = (DistanceType)dataset_[0][i];
        }
        for (size_t k=1; k<dataset_.rows; ++k) {
            for (size_t i=0; i<dim_; ++i) {
                if (dataset_[k][i]<bbox[i].low) bbox[i].low = (DistanceType)dataset_[k][i];
                if (dataset_[k][i]>bbox[i].high) bbox[i].high = (DistanceType)dataset_[k][i];
            }
        }
    }

    /**
     * Builds the subtree for vind_[left, right).
     *
     * @param left  first position in vind_ covered by the subtree
     * @param right one past the last position covered
     * @param bbox  in: box used to choose the split; out: tight box of the
     *              points in the subtree (left untouched for an empty range)
     * @return the subtree root, allocated from pool_
     */
    NodePtr divideTree(int left, int right, BoundingBox& bbox)
    {
        NodePtr node = pool_.allocate<Node>();

        if ( (right-left) <= leaf_max_size_) {
            // Few enough points: make a leaf.
            node->child1 = node->child2 = NULL;
            node->left = left;
            node->right = right;

            // Tighten the box to the leaf's points. An empty range (empty
            // dataset) has no point to seed the box from, so skip it.
            if (left<right) {
                for (size_t i=0; i<dim_; ++i) {
                    bbox[i].low = (DistanceType)dataset_[vind_[left]][i];
                    bbox[i].high = (DistanceType)dataset_[vind_[left]][i];
                }
                for (int k=left+1; k<right; ++k) {
                    for (size_t i=0; i<dim_; ++i) {
                        if (bbox[i].low>dataset_[vind_[k]][i]) bbox[i].low=(DistanceType)dataset_[vind_[k]][i];
                        if (bbox[i].high<dataset_[vind_[k]][i]) bbox[i].high=(DistanceType)dataset_[vind_[k]][i];
                    }
                }
            }
        }
        else {
            int idx;
            int cutfeat;
            DistanceType cutval;
            middleSplit_(&vind_[0]+left, right-left, idx, cutfeat, cutval, bbox);

            node->divfeat = cutfeat;

            BoundingBox left_bbox(bbox);
            left_bbox[cutfeat].high = cutval;
            node->child1 = divideTree(left, left+idx, left_bbox);

            BoundingBox right_bbox(bbox);
            right_bbox[cutfeat].low = cutval;
            node->child2 = divideTree(left+idx, right, right_bbox);

            // The children return tightened boxes, so these are the actual
            // extents of the two halves along the split dimension.
            node->divlow = left_bbox[cutfeat].high;
            node->divhigh = right_bbox[cutfeat].low;

            for (size_t i=0; i<dim_; ++i) {
                bbox[i].low = std::min(left_bbox[i].low, right_bbox[i].low);
                bbox[i].high = std::max(left_bbox[i].high, right_bbox[i].high);
            }
        }

        return node;
    }

    /**
     * Finds the smallest and largest coordinate along dimension @p dim among
     * the @p count points referenced by @p ind (count must be at least 1).
     */
    void computeMinMax(int* ind, int count, int dim, ElementType& min_elem, ElementType& max_elem)
    {
        min_elem = dataset_[ind[0]][dim];
        max_elem = dataset_[ind[0]][dim];
        for (int i=1; i<count; ++i) {
            ElementType val = dataset_[ind[i]][dim];
            if (val<min_elem) min_elem = val;
            if (val>max_elem) max_elem = val;
        }
    }

    /**
     * Alternative split rule (not called by divideTree): starts from the
     * dimension with the widest box span, then prefers any dimension whose
     * actual point spread is larger, and cuts at the midpoint of that spread.
     */
    void middleSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval, const BoundingBox& bbox)
    {
        // Widest dimension according to the bounding box.
        ElementType max_span = bbox[0].high-bbox[0].low;
        cutfeat = 0;
        cutval = (bbox[0].high+bbox[0].low)/2;
        for (size_t i=1; i<dim_; ++i) {
            ElementType span = bbox[i].high-bbox[i].low;
            if (span>max_span) {
                max_span = span;
                cutfeat = (int)i;
                cutval = (bbox[i].high+bbox[i].low)/2;
            }
        }

        // Replace the box span by the real spread of the points.
        ElementType min_elem, max_elem;
        computeMinMax(ind, count, cutfeat, min_elem, max_elem);
        cutval = (min_elem+max_elem)/2;
        max_span = max_elem - min_elem;

        // Check whether another dimension has a larger real spread.
        size_t k = cutfeat;
        for (size_t i=0; i<dim_; ++i) {
            if (i==k) continue;
            ElementType span = bbox[i].high-bbox[i].low;
            if (span>max_span) {
                computeMinMax(ind, count, (int)i, min_elem, max_elem);
                span = max_elem - min_elem;
                if (span>max_span) {
                    max_span = span;
                    cutfeat = (int)i;
                    cutval = (min_elem+max_elem)/2;
                }
            }
        }

        int lim1, lim2;
        planeSplit(ind, count, cutfeat, cutval, lim1, lim2);

        if (lim1>count/2) index = lim1;
        else if (lim2<count/2) index = lim2;
        else index = count/2;
    }

    /**
     * Split rule used by divideTree().
     *
     * Among the dimensions whose box span is within a factor (1-EPS) of the
     * widest span, picks the one with the largest actual point spread, cuts
     * at the box midpoint clamped to the points' range, and partitions
     * @p ind around that value.
     *
     * @param ind     indices of the points to split (permuted in place)
     * @param count   number of entries in @p ind
     * @param index   out: number of entries assigned to the left half
     * @param cutfeat out: split dimension
     * @param cutval  out: split value
     * @param bbox    bounding box of the points
     */
    void middleSplit_(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval, const BoundingBox& bbox)
    {
        const float EPS=0.00001f;

        DistanceType max_span = bbox[0].high-bbox[0].low;
        for (size_t i=1; i<dim_; ++i) {
            DistanceType span = bbox[i].high-bbox[i].low;
            if (span>max_span) {
                max_span = span;
            }
        }

        DistanceType max_spread = -1;
        cutfeat = 0;
        for (size_t i=0; i<dim_; ++i) {
            DistanceType span = bbox[i].high-bbox[i].low;
            if (span>(DistanceType)((1-EPS)*max_span)) {
                ElementType min_elem, max_elem;
                // Measure the spread along the candidate dimension i (not
                // along the current best cutfeat), otherwise the comparison
                // below never looks at the candidate's own spread.
                computeMinMax(ind, count, (int)i, min_elem, max_elem);
                DistanceType spread = (DistanceType)(max_elem-min_elem);
                if (spread>max_spread) {
                    cutfeat = (int)i;
                    max_spread = spread;
                }
            }
        }

        // Cut at the middle of the box, clamped to where the points are.
        DistanceType split_val = (bbox[cutfeat].low+bbox[cutfeat].high)/2;
        ElementType min_elem, max_elem;
        computeMinMax(ind, count, cutfeat, min_elem, max_elem);

        if (split_val<min_elem) cutval = (DistanceType)min_elem;
        else if (split_val>max_elem) cutval = (DistanceType)max_elem;
        else cutval = split_val;

        int lim1, lim2;
        planeSplit(ind, count, cutfeat, cutval, lim1, lim2);

        // Take the split position closest to the middle that is consistent
        // with the three-way partition produced by planeSplit().
        if (lim1>count/2) index = lim1;
        else if (lim2<count/2) index = lim2;
        else index = count/2;
    }

    /**
     * Partitions @p ind in place into three consecutive groups by the
     * coordinate along @p cutfeat:
     *   ind[0 .. lim1)      values <  cutval
     *   ind[lim1 .. lim2)   values == cutval
     *   ind[lim2 .. count)  values >  cutval
     */
    void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
    {
        // First pass: move everything below cutval to the front.
        int left = 0;
        int right = count-1;
        for (;; ) {
            while (left<=right && dataset_[ind[left]][cutfeat]<cutval) ++left;
            while (left<=right && dataset_[ind[right]][cutfeat]>=cutval) --right;
            if (left>right) break;
            std::swap(ind[left], ind[right]); ++left; --right;
        }
        lim1 = left;

        // Second pass over the remainder: move values equal to cutval next.
        right = count-1;
        for (;; ) {
            while (left<=right && dataset_[ind[left]][cutfeat]<=cutval) ++left;
            while (left<=right && dataset_[ind[right]][cutfeat]>cutval) --right;
            if (left>right) break;
            std::swap(ind[left], ind[right]); ++left; --right;
        }
        lim2 = left;
    }

    /**
     * Fills @p dists with the per-dimension distance from @p vec to the root
     * bounding box (entries stay untouched where vec lies inside the box
     * along that dimension) and returns their sum.
     */
    DistanceType computeInitialDistances(const ElementType* vec, std::vector<DistanceType>& dists)
    {
        DistanceType distsq = 0.0;

        for (size_t i = 0; i < dim_; ++i) {
            if (vec[i] < root_bbox_[i].low) {
                dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].low, (int)i);
                distsq += dists[i];
            }
            if (vec[i] > root_bbox_[i].high) {
                dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].high, (int)i);
                distsq += dists[i];
            }
        }

        return distsq;
    }

    /**
     * Recursive search of the subtree rooted at @p node.
     *
     * @param result_set collects the neighbours
     * @param vec        query point
     * @param node       subtree root
     * @param mindistsq  sum of the entries of @p dists for this subtree
     * @param dists      per-dimension distance contributions; modified while
     *                   visiting the far child and restored before returning
     * @param epsError   factor applied to the bound before comparing it with
     *                   the current worst accepted distance
     */
    void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindistsq,
                     std::vector<DistanceType>& dists, const float epsError)
    {
        // Leaf: test every point in the leaf's range.
        if ((node->child1 == NULL)&&(node->child2 == NULL)) {
            DistanceType worst_dist = result_set.worstDist();
            for (int i=node->left; i<node->right; ++i) {
                // With reordering, row i of data_ already is point vind_[i].
                int index = reorder_ ? i : vind_[i];
                DistanceType dist = distance_(vec, data_[index], dim_, worst_dist);
                if (dist<worst_dist) {
                    result_set.addPoint(dist,vind_[i]);
                }
            }
            return;
        }

        // Decide which child is closer to the query along the split dimension.
        int idx = node->divfeat;
        ElementType val = vec[idx];
        DistanceType diff1 = val - node->divlow;
        DistanceType diff2 = val - node->divhigh;

        NodePtr bestChild;
        NodePtr otherChild;
        DistanceType cut_dist;
        if ((diff1+diff2)<0) {
            bestChild = node->child1;
            otherChild = node->child2;
            cut_dist = distance_.accum_dist(val, node->divhigh, idx);
        }
        else {
            bestChild = node->child2;
            otherChild = node->child1;
            cut_dist = distance_.accum_dist( val, node->divlow, idx);
        }

        // Nearer child first.
        searchLevel(result_set, vec, bestChild, mindistsq, dists, epsError);

        // Visit the farther child only if its bound can still beat the
        // current worst accepted distance.
        DistanceType dst = dists[idx];
        mindistsq = mindistsq + cut_dist - dst;
        dists[idx] = cut_dist;
        if (mindistsq*epsError<=result_set.worstDist()) {
            searchLevel(result_set, vec, otherChild, mindistsq, dists, epsError);
        }
        dists[idx] = dst;
    }

private:
    // Header of the caller's dataset matrix.
    const Matrix<ElementType> dataset_;

    IndexParams index_params_;

    // Largest number of points stored in one leaf (at least 1).
    int leaf_max_size_;

    // Whether data_ holds an owned copy of the points in leaf order.
    bool reorder_;

    // Permutation of row indices; leaves reference ranges of this array.
    std::vector<int> vind_;

    // Points used during search: owned reordered copy or alias of dataset_.
    Matrix<ElementType> data_;

    size_t size_;
    size_t dim_;

    // Tree root; NULL until the index has been built or loaded.
    NodePtr root_node_;

    BoundingBox root_bbox_;

    // Allocator that provides the tree nodes.
    PooledAllocator pool_;

    Distance distance_;
};
}
#endif