#ifndef OPENCV_FLANN_KDTREE_INDEX_H_
#define OPENCV_FLANN_KDTREE_INDEX_H_
#include <algorithm>
#include <map>
#include <cassert>
#include <cstring>
#include "general.h"
#include "nn_index.h"
#include "dynamic_bitset.h"
#include "matrix.h"
#include "result_set.h"
#include "heap.h"
#include "allocator.h"
#include "random.h"
#include "saving.h"
namespace cvflann
{
struct KDTreeIndexParams : public IndexParams
{
KDTreeIndexParams(int trees = 4)
{
(*this)["algorithm"] = FLANN_INDEX_KDTREE;
(*this)["trees"] = trees;
}
};
template <typename Distance>
class KDTreeIndex : public NNIndex<Distance>
{
public:
// Element type of the dataset vectors, as declared by the distance functor.
typedef typename Distance::ElementType ElementType;
// Type in which distances are computed and accumulated, as declared by the distance functor.
typedef typename Distance::ResultType DistanceType;
/**
 * Prepares the index state for a dataset. No tree is built here; that is
 * done by buildIndex().
 *
 * @param inputData matrix holding one vector per row
 * @param params    index parameters; the "trees" entry is read (default 4)
 * @param d         distance functor
 */
KDTreeIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KDTreeIndexParams(),
            Distance d = Distance() ) :
    dataset_(inputData), index_params_(params), distance_(d)
{
    size_ = dataset_.rows;
    veclen_ = dataset_.cols;

    trees_ = get_param(index_params_, "trees", 4);
    tree_roots_ = new NodePtr[trees_];

    // Start from the identity permutation of the row indices.
    vind_.resize(size_);
    size_t row = 0;
    while (row < size_) {
        vind_[row] = int(row);
        ++row;
    }

    // Per-dimension scratch buffers used when choosing a split.
    mean_ = new DistanceType[veclen_];
    var_ = new DistanceType[veclen_];
}
// Copy construction and copy assignment are only declared; no definitions appear in this file.
// NOTE(review): presumably left undefined on purpose so that copying (which would
// double-delete the owned arrays) fails to link - confirm.
KDTreeIndex(const KDTreeIndex&);
KDTreeIndex& operator=(const KDTreeIndex&);
/**
 * Releases the root-pointer array and the mean/variance scratch buffers.
 */
~KDTreeIndex()
{
    // delete[] on a null pointer is a no-op, so no explicit guard is needed.
    delete[] tree_roots_;
    delete[] mean_;
    delete[] var_;
}
void buildIndex()
{
for (int i = 0; i < trees_; i++) {
std::random_shuffle(vind_.begin(), vind_.end());
tree_roots_[i] = divideTree(&vind_[0], int(size_) );
}
}
/**
 * @return the algorithm identifier of this index, FLANN_INDEX_KDTREE
 */
flann_algorithm_t getType() const
{
    const flann_algorithm_t kind = FLANN_INDEX_KDTREE;
    return kind;
}
/**
 * Writes the tree count followed by every tree to the stream.
 *
 * @param stream destination stream
 */
void saveIndex(FILE* stream)
{
    save_value(stream, trees_);

    int tree = 0;
    while (tree < trees_) {
        save_tree(stream, tree_roots_[tree]);
        ++tree;
    }
}
/**
 * Reads the tree count and every tree from the stream, replacing the
 * current root array, then refreshes the stored index parameters.
 *
 * @param stream source stream, laid out as written by saveIndex()
 */
void loadIndex(FILE* stream)
{
    load_value(stream, trees_);

    // Drop the old root array and null the pointer before reallocating, so a
    // throwing allocation cannot leave a dangling pointer for the destructor.
    delete[] tree_roots_;
    tree_roots_ = NULL;
    tree_roots_ = new NodePtr[trees_];
    for (int i = 0; i < trees_; ++i) {
        load_tree(stream, tree_roots_[i]);
    }

    index_params_["algorithm"] = getType();
    // Store the tree COUNT. The root array pointer was stored here before,
    // which put a Node** under a key that is read back as an int.
    index_params_["trees"] = trees_;
}
/**
 * @return the number of vectors in the dataset (its row count)
 */
size_t size() const
{
    const size_t rows = size_;
    return rows;
}
/**
 * @return the length of each vector (the dataset's column count)
 */
size_t veclen() const
{
    const size_t cols = veclen_;
    return cols;
}
/**
 * @return the pool's used plus wasted memory, plus one int per dataset row,
 *         narrowed to int
 */
int usedMemory() const
{
    // NOTE(review): the per-row term presumably accounts for vind_ - confirm.
    const size_t indexBytes = dataset_.rows*sizeof(int);
    return int(pool_.usedMemory + pool_.wastedMemory + indexBytes);
}
/**
 * Searches the index for the neighbours of a query vector.
 *
 * @param result       result set that receives the neighbours
 * @param vec          query vector
 * @param searchParams "checks" (default 32) bounds the number of points
 *                     examined; "eps" (default 0) is added to 1 to form the
 *                     error factor passed to the search routines
 */
void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams)
{
    const int maxChecks = get_param(searchParams, "checks", 32);
    const float epsError = 1 + get_param(searchParams, "eps", 0.0f);

    // A bounded check budget selects the approximate search.
    if (maxChecks != FLANN_CHECKS_UNLIMITED) {
        getNeighbors(result, vec, maxChecks, epsError);
        return;
    }
    getExactNeighbors(result, vec, epsError);
}
/**
 * @return a copy of the parameters held by this index
 */
IndexParams getParameters() const
{
    IndexParams snapshot(index_params_);
    return snapshot;
}
private:
/**
 * One kd-tree node. A node whose two children are both NULL is a leaf.
 * The field order is kept as-is because nodes are written to and read from
 * streams as whole structs.
 */
struct Node
{
    /** Interior node: dimension used for the split. Leaf: row index of the vector. */
    int divfeat;
    /** Interior node: split value along divfeat. */
    DistanceType divval;
    /** Subtree searched first when the query value minus divval is negative. */
    Node* child1;
    /** Subtree searched first otherwise. */
    Node* child2;
};
typedef Node* NodePtr;
/** A not-yet-explored subtree paired with a distance bound, as queued on the search heap. */
typedef BranchStruct<NodePtr, DistanceType> BranchSt;
typedef BranchSt* Branch;
void save_tree(FILE* stream, NodePtr tree)
{
save_value(stream, *tree);
if (tree->child1!=NULL) {
save_tree(stream, tree->child1);
}
if (tree->child2!=NULL) {
save_tree(stream, tree->child2);
}
}
void load_tree(FILE* stream, NodePtr& tree)
{
tree = pool_.allocate<Node>();
load_value(stream, *tree);
if (tree->child1!=NULL) {
load_tree(stream, tree->child1);
}
if (tree->child2!=NULL) {
load_tree(stream, tree->child2);
}
}
/**
 * Recursively builds the subtree for a run of row indices.
 *
 * A single index becomes a leaf that stores the row index in divfeat. A
 * longer run is partitioned in place by meanSplit() and both parts are
 * built recursively.
 *
 * @param ind   first of the row indices to organise (reordered in place)
 * @param count number of indices in the run; must be at least 1
 * @return the root of the new subtree, allocated from the pool
 */
NodePtr divideTree(int* ind, int count)
{
    // An empty run would never reach the leaf case below: meanSplit() yields
    // index == 0 and the recursion would not terminate.
    assert(count > 0);

    NodePtr node = pool_.allocate<Node>();

    if (count == 1) {
        // Leaf: no children, divfeat holds the row index of the vector.
        node->child1 = node->child2 = NULL;
        node->divfeat = *ind;
        // Never read for a leaf, but save_tree() serialises the whole Node,
        // so give it a defined value instead of leaving it indeterminate.
        node->divval = DistanceType();
    }
    else {
        int idx;
        int cutfeat;
        DistanceType cutval;
        meanSplit(ind, count, idx, cutfeat, cutval);

        node->divfeat = cutfeat;
        node->divval = cutval;
        node->child1 = divideTree(ind, idx);
        node->child2 = divideTree(ind+idx, count-idx);
    }

    return node;
}
void meanSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval)
{
memset(mean_,0,veclen_*sizeof(DistanceType));
memset(var_,0,veclen_*sizeof(DistanceType));
int cnt = std::min((int)SAMPLE_MEAN+1, count);
for (int j = 0; j < cnt; ++j) {
ElementType* v = dataset_[ind[j]];
for (size_t k=0; k<veclen_; ++k) {
mean_[k] += v[k];
}
}
for (size_t k=0; k<veclen_; ++k) {
mean_[k] /= cnt;
}
for (int j = 0; j < cnt; ++j) {
ElementType* v = dataset_[ind[j]];
for (size_t k=0; k<veclen_; ++k) {
DistanceType dist = v[k] - mean_[k];
var_[k] += dist * dist;
}
}
cutfeat = selectDivision(var_);
cutval = mean_[cutfeat];
int lim1, lim2;
planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
if (lim1>count/2) index = lim1;
else if (lim2<count/2) index = lim2;
else index = count/2;
if ((lim1==count)||(lim2==0)) index = count/2;
}
/**
 * Picks at random one of the (up to RAND_DIM) dimensions with the largest
 * values in v.
 *
 * @param v per-dimension scores, veclen_ entries
 * @return the chosen dimension
 */
int selectDivision(DistanceType* v)
{
    size_t topind[RAND_DIM];   // best dimensions so far, in descending order of v
    int num = 0;               // number of valid entries in topind

    for (size_t dim = 0; dim < veclen_; ++dim) {
        const bool hasRoom = num < RAND_DIM;
        // With a full list, only a value beating the current last entry gets in.
        if (!hasRoom && !(v[dim] > v[topind[num-1]])) {
            continue;
        }

        if (hasRoom) {
            topind[num++] = dim;
        }
        else {
            topind[num-1] = dim;
        }

        // Move the new last entry forward to its sorted position.
        for (int pos = num - 1; pos > 0 && v[topind[pos]] > v[topind[pos-1]]; --pos) {
            std::swap(topind[pos], topind[pos-1]);
        }
    }

    const int pick = rand_int(num);
    return (int)topind[pick];
}
/**
 * Partitions row indices in place into three consecutive runs by their value
 * along cutfeat: below cutval, equal to cutval, above cutval.
 *
 * @param ind     row indices to partition (reordered in place)
 * @param count   number of indices
 * @param cutfeat dimension to compare along
 * @param cutval  value to compare against
 * @param lim1    receives the number of entries below cutval
 * @param lim2    receives the number of entries below or equal to cutval
 */
void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
{
    int lo = 0;
    int hi = count-1;

    // Pass 1: move every entry that is strictly below cutval to the front.
    while (true) {
        while (lo<=hi && dataset_[ind[lo]][cutfeat]<cutval) {
            ++lo;
        }
        while (lo<=hi && dataset_[ind[hi]][cutfeat]>=cutval) {
            --hi;
        }
        if (lo>hi) {
            break;
        }
        std::swap(ind[lo], ind[hi]);
        ++lo;
        --hi;
    }
    lim1 = lo;

    // Pass 2: within the remainder, move the entries equal to cutval first.
    hi = count-1;
    while (true) {
        while (lo<=hi && dataset_[ind[lo]][cutfeat]<=cutval) {
            ++lo;
        }
        while (lo<=hi && dataset_[ind[hi]][cutfeat]>cutval) {
            --hi;
        }
        if (lo>hi) {
            break;
        }
        std::swap(ind[lo], ind[hi]);
        ++lo;
        --hi;
    }
    lim2 = lo;
}
/**
 * Exact search: walks the first tree with searchLevelExact().
 *
 * @param result   result set that receives the neighbours
 * @param vec      query vector
 * @param epsError error factor forwarded to searchLevelExact()
 */
void getExactNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, float epsError)
{
    const bool severalTrees = trees_ > 1;
    const bool anyTree = trees_ > 0;

    if (severalTrees) {
        fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
    }
    if (anyTree) {
        // Only the first tree is consulted.
        searchLevelExact(result, vec, tree_roots_[0], 0.0, epsError);
    }
    assert(result.full());
}
/**
 * Approximate search across all trees.
 *
 * Every tree is first descended once from its root. The unexplored branches
 * recorded on the way are then taken from a shared heap, via popMin(), until
 * the heap is empty or maxCheck points have been examined and the result
 * set is full.
 *
 * @param result   result set that receives the neighbours
 * @param vec      query vector
 * @param maxCheck budget on the number of points examined
 * @param epsError error factor forwarded to searchLevel()
 */
void getNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, int maxCheck, float epsError)
{
    int checkCount = 0;
    // Automatic storage: the heap is released on every exit path, including
    // an exception thrown from searchLevel(). It was a raw new/delete pair before.
    Heap<BranchSt> heap((int)size_);
    DynamicBitset checked(size_);

    // One descent per tree, from the root down to a leaf.
    for (int i = 0; i < trees_; ++i) {
        searchLevel(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, &heap, checked);
    }

    // Continue with the queued branches.
    BranchSt branch;
    while ( heap.popMin(branch) && (checkCount < maxCheck || !result.full() )) {
        searchLevel(result, vec, branch.node, branch.mindist, checkCount, maxCheck, epsError, &heap, checked);
    }

    assert(result.full());
}
/**
 * Descends from a node to a leaf, always taking the child on the query's
 * side of the split and queueing the other child on the heap when it may
 * still contain a closer point.
 *
 * @param result_set result set that receives examined points
 * @param vec        query vector
 * @param node       node to start from
 * @param mindist    distance bound attached to this branch
 * @param checkCount running count of examined points (updated)
 * @param maxCheck   budget on examined points
 * @param epsError   error factor applied to the other branch's bound
 * @param heap       heap collecting unexplored branches
 * @param checked    marks row indices that were already examined
 */
void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, NodePtr node, DistanceType mindist, int& checkCount, int maxCheck,
                 float epsError, Heap<BranchSt>* heap, DynamicBitset& checked)
{
    // Iterative form of the descent: each pass handles one level, and
    // mindist stays the same all the way down.
    for (;;) {
        // The branch cannot improve on the current worst result.
        if (result_set.worstDist()<mindist) {
            return;
        }

        const bool isLeaf = (node->child1 == NULL)&&(node->child2 == NULL);
        if (isLeaf) {
            const int index = node->divfeat;
            // Skip points already examined (possibly via another tree), and
            // stop once the budget is spent and the result set is full.
            if (checked.test(index)) {
                return;
            }
            if ((checkCount>=maxCheck)&& result_set.full()) {
                return;
            }
            checked.set(index);
            checkCount++;

            DistanceType dist = distance_(dataset_[index], vec, veclen_);
            result_set.addPoint(dist,index);
            return;
        }

        // Decide which side of the split the query falls on.
        ElementType val = vec[node->divfeat];
        DistanceType diff = val - node->divval;
        const bool below = diff < 0;
        NodePtr nearChild = below ? node->child1 : node->child2;
        NodePtr farChild = below ? node->child2 : node->child1;

        // Bound for the side not taken; queue it if it can still matter.
        DistanceType farDist = mindist + distance_.accum_dist(val, node->divval, node->divfeat);
        if ((farDist*epsError < result_set.worstDist())|| !result_set.full()) {
            heap->insert( BranchSt(farChild, farDist) );
        }

        node = nearChild;
    }
}
/**
 * Exact recursive search of a subtree: the child on the query's side is
 * searched first, and the other child only if its distance bound does not
 * exceed the current worst result.
 *
 * @param result_set result set that receives examined points
 * @param vec        query vector
 * @param node       root of the subtree to search
 * @param mindist    distance bound attached to this subtree
 * @param epsError   error factor applied to the other child's bound
 */
void searchLevelExact(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindist, const float epsError)
{
    const bool isLeaf = (node->child1 == NULL)&&(node->child2 == NULL);
    if (isLeaf) {
        // For a leaf, divfeat holds the row index of the stored vector.
        const int index = node->divfeat;
        DistanceType dist = distance_(dataset_[index], vec, veclen_);
        result_set.addPoint(dist,index);
        return;
    }

    ElementType val = vec[node->divfeat];
    DistanceType diff = val - node->divval;
    const bool below = diff < 0;
    NodePtr nearChild = below ? node->child1 : node->child2;
    NodePtr farChild = below ? node->child2 : node->child1;

    DistanceType farDist = mindist + distance_.accum_dist(val, node->divval, node->divfeat);

    // Near side first, so the worst result is as tight as possible before
    // deciding about the far side.
    searchLevelExact(result_set, vec, nearChild, mindist, epsError);

    if (farDist*epsError<=result_set.worstDist()) {
        searchLevelExact(result_set, vec, farChild, farDist, epsError);
    }
}
private:
enum
{
// meanSplit() estimates mean and variance from at most SAMPLE_MEAN+1 points.
SAMPLE_MEAN = 100,
// Number of top-scoring dimensions selectDivision() keeps as candidates
// before picking one of them at random.
RAND_DIM=5
};
// Number of trees; also the length of tree_roots_.
int trees_;
// Permutation of the dataset row indices, shuffled and partitioned in place while trees are built.
std::vector<int> vind_;
// The indexed dataset, one vector per row.
const Matrix<ElementType> dataset_;
// Parameters this index was created with (updated by loadIndex()).
IndexParams index_params_;
// Number of vectors (dataset rows).
size_t size_;
// Length of each vector (dataset columns).
size_t veclen_;
// Scratch buffer of veclen_ entries: per-dimension mean computed in meanSplit().
DistanceType* mean_;
// Scratch buffer of veclen_ entries: per-dimension squared-deviation sum computed in meanSplit().
DistanceType* var_;
// Array of trees_ root pointers, one per tree.
NodePtr* tree_roots_;
// Allocator from which every tree node is obtained.
PooledAllocator pool_;
// Distance functor used to compare the query with dataset vectors.
Distance distance_;
};
}
#endif