// root/modules/ml/src/tree.cpp

/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"
#include <ctype.h>

namespace cv {
namespace ml {

using std::vector;

TreeParams::TreeParams()
{
    maxDepth = INT_MAX;
    minSampleCount = 10;
    regressionAccuracy = 0.01f;
    useSurrogates = false;
    maxCategories = 10;
    CVFolds = 10;
    use1SERule = true;
    truncatePrunedTree = true;
    priors = Mat();
}

TreeParams::TreeParams(int _maxDepth, int _minSampleCount,
                       double _regressionAccuracy, bool _useSurrogates,
                       int _maxCategories, int _CVFolds,
                       bool _use1SERule, bool _truncatePrunedTree,
                       const Mat& _priors)
{
    maxDepth = _maxDepth;
    minSampleCount = _minSampleCount;
    regressionAccuracy = (float)_regressionAccuracy;
    useSurrogates = _useSurrogates;
    maxCategories = _maxCategories;
    CVFolds = _CVFolds;
    use1SERule = _use1SERule;
    truncatePrunedTree = _truncatePrunedTree;
    priors = _priors;
}
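
/*
    Note: TreeParams is the internal parameter holder. User code normally sets
    these values through the public DTrees interface; a minimal sketch:

        Ptr<DTrees> dtree = DTrees::create();
        dtree->setMaxDepth(8);        // default INT_MAX: effectively unlimited
        dtree->setMinSampleCount(10); // do not split nodes with fewer samples
        dtree->setCVFolds(0);         // values > 1 enable cross-validation pruning
*/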

DTrees::Node::Node()
{
    classIdx = 0;
    value = 0;
    parent = left = right = split = defaultDir = -1;
}

DTrees::Split::Split()
{
    varIdx = 0;
    inversed = false;
    quality = 0.f;
    next = -1;
    c = 0.f;
    subsetOfs = 0;
}


DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data)
{
    data = _data;
    Mat sidx0 = _data->getTrainSampleIdx();
    if( !sidx0.empty() )
    {
        sidx0.copyTo(sidx);
        std::sort(sidx.begin(), sidx.end());
    }
    else
    {
        int n = _data->getNSamples();
        setRangeVector(sidx, n);
    }

    maxSubsetSize = 0;
}

DTreesImpl::DTreesImpl() {}
DTreesImpl::~DTreesImpl() {}
void DTreesImpl::clear()
{
    varIdx.clear();
    compVarIdx.clear();
    varType.clear();
    catOfs.clear();
    catMap.clear();
    roots.clear();
    nodes.clear();
    splits.clear();
    subsets.clear();
    classLabels.clear();

    w.release();
    _isClassifier = false;
}

void DTreesImpl::startTraining( const Ptr<TrainData>& data, int )
{
    clear();
    w = makePtr<WorkData>(data);

    Mat vtype = data->getVarType();
    vtype.copyTo(varType);

    data->getCatOfs().copyTo(catOfs);
    data->getCatMap().copyTo(catMap);
    data->getDefaultSubstValues().copyTo(missingSubst);

    int nallvars = data->getNAllVars();

    Mat vidx0 = data->getVarIdx();
    if( !vidx0.empty() )
        vidx0.copyTo(varIdx);
    else
        setRangeVector(varIdx, nallvars);

    initCompVarIdx();

    w->maxSubsetSize = 0;

    int i, nvars = (int)varIdx.size();
    for( i = 0; i < nvars; i++ )
        w->maxSubsetSize = std::max(w->maxSubsetSize, getCatCount(varIdx[i]));

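    // convert the largest category count into the number of 32-bit words
    // needed to hold a category bitmask (at least one word)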
    w->maxSubsetSize = std::max((w->maxSubsetSize + 31)/32, 1);

    data->getSampleWeights().copyTo(w->sample_weights);

    _isClassifier = data->getResponseType() == VAR_CATEGORICAL;

    if( _isClassifier )
    {
        data->getNormCatResponses().copyTo(w->cat_responses);
        data->getClassLabels().copyTo(classLabels);
        int nclasses = (int)classLabels.size();

        Mat class_weights = params.priors;
        if( !class_weights.empty() )
        {
            if( class_weights.type() != CV_64F || !class_weights.isContinuous() )
            {
                Mat temp;
                class_weights.convertTo(temp, CV_64F);
                class_weights = temp;
            }
            CV_Assert( class_weights.checkVector(1, CV_64F) == nclasses );

            int nsamples = (int)w->cat_responses.size();
            const double* cw = class_weights.ptr<double>();
            CV_Assert( (int)w->sample_weights.size() == nsamples );

            for( i = 0; i < nsamples; i++ )
            {
                int ci = w->cat_responses[i];
                CV_Assert( 0 <= ci && ci < nclasses );
                w->sample_weights[i] *= cw[ci];
            }
        }
    }
    else
        data->getResponses().copyTo(w->ord_responses);
}


void DTreesImpl::initCompVarIdx()
{
    int nallvars = (int)varType.size();
    compVarIdx.assign(nallvars, -1);
    int i, nvars = (int)varIdx.size(), prevIdx = -1;
    for( i = 0; i < nvars; i++ )
    {
        int vi = varIdx[i];
        CV_Assert( 0 <= vi && vi < nallvars && vi > prevIdx );
        prevIdx = vi;
        compVarIdx[vi] = i;
    }
}

void DTreesImpl::endTraining()
{
    w.release();
}

bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags )
{
    startTraining(trainData, flags);
    bool ok = addTree( w->sidx ) >= 0;
    endTraining();
    return ok;
}

const vector<int>& DTreesImpl::getActiveVars()
{
    return varIdx;
}

int DTreesImpl::addTree(const vector<int>& sidx )
{
    int maxdepth0 = std::min(params.getMaxDepth(), 20); // clamp: 1 << INT_MAX is undefined
    size_t n = (size_t)(maxdepth0 > 0 ? (1 << maxdepth0) : 1024) + w->wnodes.size();

    w->wnodes.reserve(n);
    w->wsplits.reserve(n);
    w->wsubsets.reserve(n*w->maxSubsetSize);
    w->wnodes.clear();
    w->wsplits.clear();
    w->wsubsets.clear();

    int cv_n = params.getCVFolds();

    if( cv_n > 0 )
    {
        w->cv_Tn.resize(n*cv_n);
        w->cv_node_error.resize(n*cv_n);
        w->cv_node_risk.resize(n*cv_n);
    }

    // build the tree recursively
    int w_root = addNodeAndTrySplit(-1, sidx);
    int maxdepth = INT_MAX; // pruning disabled: pruneCV(root) is not called here

    int w_nidx = w_root, pidx = -1, depth = 0;
    int root = (int)nodes.size();

    for(;;)
    {
        const WNode& wnode = w->wnodes[w_nidx];
        Node node;
        node.parent = pidx;
        node.classIdx = wnode.class_idx;
        node.value = wnode.value;
        node.defaultDir = wnode.defaultDir;

        int wsplit_idx = wnode.split;
        if( wsplit_idx >= 0 )
        {
            const WSplit& wsplit = w->wsplits[wsplit_idx];
            Split split;
            split.c = wsplit.c;
            split.quality = wsplit.quality;
            split.inversed = wsplit.inversed;
            split.varIdx = wsplit.varIdx;
            split.subsetOfs = -1;
            if( wsplit.subsetOfs >= 0 )
            {
                int ssize = getSubsetSize(split.varIdx);
                split.subsetOfs = (int)subsets.size();
                subsets.resize(split.subsetOfs + ssize);
                // When ssize == 0 the resize above is a no-op, so no memory is
                // actually allocated for this subset; guarding the memcpy keeps
                // the access safe and skips a useless call with a zero size.
                if(ssize > 0)
                {
                    memcpy(&subsets[split.subsetOfs], &w->wsubsets[wsplit.subsetOfs], ssize*sizeof(int));
                }
            }
            node.split = (int)splits.size();
            splits.push_back(split);
        }
        int nidx = (int)nodes.size();
        nodes.push_back(node);
        if( pidx >= 0 )
        {
            int w_pidx = w->wnodes[w_nidx].parent;
            if( w->wnodes[w_pidx].left == w_nidx )
            {
                nodes[pidx].left = nidx;
            }
            else
            {
                CV_Assert(w->wnodes[w_pidx].right == w_nidx);
                nodes[pidx].right = nidx;
            }
        }

        if( wnode.left >= 0 && depth+1 < maxdepth )
        {
            w_nidx = wnode.left;
            pidx = nidx;
            depth++;
        }
        else
        {
            int w_pidx = wnode.parent;
            while( w_pidx >= 0 && w->wnodes[w_pidx].right == w_nidx )
            {
                w_nidx = w_pidx;
                w_pidx = w->wnodes[w_pidx].parent;
                nidx = pidx;
                pidx = nodes[pidx].parent;
                depth--;
            }

            if( w_pidx < 0 )
                break;

            w_nidx = w->wnodes[w_pidx].right;
            CV_Assert( w_nidx >= 0 );
        }
    }
    roots.push_back(root);
    return root;
}

void DTreesImpl::setDParams(const TreeParams& _params)
{
    params = _params;
}

int DTreesImpl::addNodeAndTrySplit( int parent, const vector<int>& sidx )
{
    w->wnodes.push_back(WNode());
    int nidx = (int)(w->wnodes.size() - 1);
    WNode& node = w->wnodes.back();

    node.parent = parent;
    node.depth = parent >= 0 ? w->wnodes[parent].depth + 1 : 0;
    int nfolds = params.getCVFolds();

    if( nfolds > 0 )
    {
        w->cv_Tn.resize((nidx+1)*nfolds);
        w->cv_node_error.resize((nidx+1)*nfolds);
        w->cv_node_risk.resize((nidx+1)*nfolds);
    }

    int i, n = node.sample_count = (int)sidx.size();
    bool can_split = true;
    vector<int> sleft, sright;

    calcValue( nidx, sidx );

    if( n <= params.getMinSampleCount() || node.depth >= params.getMaxDepth() )
        can_split = false;
    else if( _isClassifier )
    {
        const int* responses = &w->cat_responses[0];
        const int* s = &sidx[0];
        int first = responses[s[0]];
        for( i = 1; i < n; i++ )
            if( responses[s[i]] != first )
                break;
        if( i == n )
            can_split = false;
    }
    else
    {
        if( sqrt(node.node_risk) < params.getRegressionAccuracy() )
            can_split = false;
    }

    if( can_split )
        node.split = findBestSplit( sidx );

    //printf("depth=%d, nidx=%d, parent=%d, n=%d, %s, value=%.1f, risk=%.1f\n", node.depth, nidx, node.parent, n, (node.split < 0 ? "leaf" : varType[w->wsplits[node.split].varIdx] == VAR_CATEGORICAL ? "cat" : "ord"), node.value, node.node_risk);

    if( node.split >= 0 )
    {
        node.defaultDir = calcDir( node.split, sidx, sleft, sright );
        if( params.useSurrogates )
            CV_Error( CV_StsNotImplemented, "surrogate splits are not implemented yet");

        int left = addNodeAndTrySplit( nidx, sleft );
        int right = addNodeAndTrySplit( nidx, sright );
        w->wnodes[nidx].left = left;
        w->wnodes[nidx].right = right;
        CV_Assert( w->wnodes[nidx].left > 0 && w->wnodes[nidx].right > 0 );
    }

    return nidx;
}

int DTreesImpl::findBestSplit( const vector<int>& _sidx )
{
    const vector<int>& activeVars = getActiveVars();
    int splitidx = -1;
    int vi_, nv = (int)activeVars.size();
    AutoBuffer<int> buf(w->maxSubsetSize*2);
    int *subset = buf, *best_subset = subset + w->maxSubsetSize;
    WSplit split, best_split;
    best_split.quality = 0.;

    for( vi_ = 0; vi_ < nv; vi_++ )
    {
        int vi = activeVars[vi_];
        if( varType[vi] == VAR_CATEGORICAL )
        {
            if( _isClassifier )
                split = findSplitCatClass(vi, _sidx, 0, subset);
            else
                split = findSplitCatReg(vi, _sidx, 0, subset);
        }
        else
        {
            if( _isClassifier )
                split = findSplitOrdClass(vi, _sidx, 0);
            else
                split = findSplitOrdReg(vi, _sidx, 0);
        }
        if( split.quality > best_split.quality )
        {
            best_split = split;
            std::swap(subset, best_subset);
        }
    }

    if( best_split.quality > 0 )
    {
        int best_vi = best_split.varIdx;
        CV_Assert( compVarIdx[best_split.varIdx] >= 0 && best_vi >= 0 );
        int i, prevsz = (int)w->wsubsets.size(), ssize = getSubsetSize(best_vi);
        w->wsubsets.resize(prevsz + ssize);
        for( i = 0; i < ssize; i++ )
            w->wsubsets[prevsz + i] = best_subset[i];
        best_split.subsetOfs = prevsz;
        w->wsplits.push_back(best_split);
        splitidx = (int)(w->wsplits.size()-1);
    }

    return splitidx;
}

void DTreesImpl::calcValue( int nidx, const vector<int>& _sidx )
{
    WNode* node = &w->wnodes[nidx];
    int i, j, k, n = (int)_sidx.size(), cv_n = params.getCVFolds();
    int m = (int)classLabels.size();

    cv::AutoBuffer<double> buf(std::max(m, 3)*(cv_n+1));

    if( cv_n > 0 )
    {
        size_t sz = w->cv_Tn.size();
        w->cv_Tn.resize(sz + cv_n);
        w->cv_node_risk.resize(sz + cv_n);
        w->cv_node_error.resize(sz + cv_n);
    }

    if( _isClassifier )
    {
        // in case of classification tree:
        //  * node value is the label of the class that has the largest weight in the node.
        //  * node risk is the weighted number of misclassified samples,
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated as the weighted number of
        //    misclassified samples with cv_labels(*)==j.
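        //
        // Worked micro-example (hypothetical numbers): with per-class weights
        // of {0.2, 0.5, 0.3} in the node, the node value is classLabels[1]
        // (the heaviest class) and the node risk is 1.0 - 0.5 = 0.5,
        // i.e. total_weight - max_val.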

        // compute the total weight of samples in each class
        double* cls_count = buf;
        double* cv_cls_count = cls_count + m;

        double max_val = -1, total_weight = 0;
        int max_k = -1;

        for( k = 0; k < m; k++ )
            cls_count[k] = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                cls_count[w->cat_responses[si]] += w->sample_weights[si];
            }
        }
        else
        {
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cv_cls_count[j*m + k] = 0;

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si]; k = w->cat_responses[si];
                cv_cls_count[j*m + k] += w->sample_weights[si];
            }

            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cls_count[k] += cv_cls_count[j*m + k];
        }

        for( k = 0; k < m; k++ )
        {
            double val = cls_count[k];
            total_weight += val;
            if( max_val < val )
            {
                max_val = val;
                max_k = k;
            }
        }

        node->class_idx = max_k;
        node->value = classLabels[max_k];
        node->node_risk = total_weight - max_val;

        for( j = 0; j < cv_n; j++ )
        {
            double sum_k = 0, sum = 0, max_val_k = 0;
            max_val = -1; max_k = -1;

            for( k = 0; k < m; k++ )
            {
                double val_k = cv_cls_count[j*m + k];
                double val = cls_count[k] - val_k;
                sum_k += val_k;
                sum += val;
                if( max_val < val )
                {
                    max_val = val;
                    max_val_k = val_k;
                    max_k = k;
                }
            }

            w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            w->cv_node_risk[nidx*cv_n + j] = sum - max_val;
            w->cv_node_error[nidx*cv_n + j] = sum_k - max_val_k;
        }
    }
    else
    {
        // in case of regression tree:
        //  * node value is the weighted mean of the responses:
        //    sum_i(w_i*Y_i)/sum_i(w_i) over the samples i in the node.
        //  * node risk is the weighted sum of squared errors: sum_i(w_i*(Y_i - <node_value>)^2)
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated
        //    using samples with cv_labels(*)==j as the test subset:
        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
        //    where node_value_j is the node value calculated
        //    as described in the previous bullet, and summation is done
        //    over the samples with cv_labels(*)==j.
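        //
        // The risk below is computed via the identity
        //   sum_i w_i*(Y_i - mean)^2 = sum_i w_i*Y_i^2 - (sum_i w_i*Y_i)^2 / sum_i w_i
        // i.e. node_risk = sum2 - sum^2/sumw, which avoids a second pass over the data.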
        double sum = 0, sum2 = 0, sumw = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                sum += t*wval;
                sum2 += t*t*wval;
                sumw += wval;
            }
        }
        else
        {
            double *cv_sum = buf, *cv_sum2 = cv_sum + cv_n;
            double* cv_count = (double*)(cv_sum2 + cv_n);

            for( j = 0; j < cv_n; j++ )
            {
                cv_sum[j] = cv_sum2[j] = 0.;
                cv_count[j] = 0;
            }

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                cv_sum[j] += t*wval;
                cv_sum2[j] += t*t*wval;
                cv_count[j] += wval;
            }

            for( j = 0; j < cv_n; j++ )
            {
                sum += cv_sum[j];
                sum2 += cv_sum2[j];
                sumw += cv_count[j];
            }

            for( j = 0; j < cv_n; j++ )
            {
                double s = sum - cv_sum[j], si = sum - s;
                double s2 = sum2 - cv_sum2[j], s2i = sum2 - s2;
                double c = cv_count[j], ci = sumw - c;
                double r = si/std::max(ci, DBL_EPSILON);
                w->cv_node_risk[nidx*cv_n + j] = s2i - r*r*ci;
                w->cv_node_error[nidx*cv_n + j] = s2 - 2*r*s + c*r*r;
                w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            }
        }

        node->node_risk = sum2 - (sum/sumw)*sum;
        node->value = sum/sumw;
    }
}

DTreesImpl::WSplit DTreesImpl::findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality )
{
    const double epsilon = FLT_EPSILON*2;
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();

    cv::AutoBuffer<uchar> buf(n*(sizeof(float) + sizeof(int)) + m*2*sizeof(double));
    const int* sidx = &_sidx[0];
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];
    double* lcw = (double*)(uchar*)buf;
    double* rcw = lcw + m;
    float* values = (float*)(rcw + m);
    int* sorted_idx = (int*)(values + n);
    int i, best_i = -1;
    double best_val = initQuality;

    for( i = 0; i < m; i++ )
        lcw[i] = rcw[i] = 0.;

    w->data->getValues( vi, _sidx, values );

    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        int si = sidx[i];
        rcw[responses[si]] += weights[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

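    // Scan the sorted samples; at each candidate threshold the split quality is
    //   val = lsum2/L + rsum2/R
    // (written below as (lsum2*R + rsum2*L)/(L*R)), where lsum2/rsum2 are the
    // sums of squared per-class weights on each side. Maximizing this value is
    // equivalent to minimizing the weighted Gini impurity of the split.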
    double L = 0, R = 0, lsum2 = 0, rsum2 = 0;
    for( i = 0; i < m; i++ )
    {
        double wval = rcw[i];
        R += wval;
        rsum2 += wval*wval;
    }

    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        int si = sidx[curr];
        double wval = weights[si], w2 = wval*wval;
        L += wval; R -= wval;
        int idx = responses[si];
        double lv = lcw[idx], rv = rcw[idx];
        lsum2 += 2*lv*wval + w2;
        rsum2 -= 2*rv*wval - w2;
        lcw[idx] = lv + wval; rcw[idx] = rv - wval;

        if( values[curr] + epsilon < values[next] )
        {
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}

// simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector.
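// Each "vector" is a row of m per-class weights for one category; the distance
// used below is the squared Euclidean distance between the L1-normalized
// vectors: dist2(v, s) = sum_j (v_j/|v|_1 - s_j/|s|_1)^2.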
void DTreesImpl::clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels )
{
    int iters = 0, max_iters = 100;
    int i, j, idx;
    cv::AutoBuffer<double> buf(n + k);
    double *v_weights = buf, *c_weights = buf + n;
    bool modified = true;
    RNG r((uint64)-1);

    // assign labels randomly
    for( i = 0; i < n; i++ )
    {
        double sum = 0;
        const double* v = vectors + i*m;
        labels[i] = i < k ? i : r.uniform(0, k);

        // compute weight of each vector
        for( j = 0; j < m; j++ )
            sum += v[j];
        v_weights[i] = sum ? 1./sum : 0.;
    }

    for( i = 0; i < n; i++ )
    {
        int i1 = r.uniform(0, n);
        int i2 = r.uniform(0, n);
        std::swap( labels[i1], labels[i2] );
    }

    for( iters = 0; iters <= max_iters; iters++ )
    {
        // calculate csums
        for( i = 0; i < k; i++ )
        {
            for( j = 0; j < m; j++ )
                csums[i*m + j] = 0;
        }

        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double* s = csums + labels[i]*m;
            for( j = 0; j < m; j++ )
                s[j] += v[j];
        }

        // exit the loop here, when we have up-to-date csums
        if( iters == max_iters || !modified )
            break;

        modified = false;

        // calculate weight of each cluster
        for( i = 0; i < k; i++ )
        {
            const double* s = csums + i*m;
            double sum = 0;
            for( j = 0; j < m; j++ )
                sum += s[j];
            c_weights[i] = sum ? 1./sum : 0;
        }

        // now for each vector determine the closest cluster
        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double alpha = v_weights[i];
            double min_dist2 = DBL_MAX;
            int min_idx = -1;

            for( idx = 0; idx < k; idx++ )
            {
                const double* s = csums + idx*m;
                double dist2 = 0., beta = c_weights[idx];
                for( j = 0; j < m; j++ )
                {
                    double t = v[j]*alpha - s[j]*beta;
                    dist2 += t*t;
                }
                if( min_dist2 > dist2 )
                {
                    min_dist2 = dist2;
                    min_idx = idx;
                }
            }

            if( min_idx != labels[i] )
                modified = true;
            labels[i] = min_idx;
        }
    }
}

DTreesImpl::WSplit DTreesImpl::findSplitCatClass( int vi, const vector<int>& _sidx,
                                                  double initQuality, int* subset )
{
    int _mi = getCatCount(vi), mi = _mi;
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();

    int base_size = m*(3 + mi) + mi + 1;
    if( m > 2 && mi > params.getMaxCategories() )
        base_size += m*std::min(params.getMaxCategories(), n) + mi;
    else
        base_size += mi;
    AutoBuffer<double> buf(base_size + n);

    double* lc = (double*)buf;
    double* rc = lc + m;
    double* _cjk = rc + m*2, *cjk = _cjk;
    double* c_weights = cjk + m*mi;

    int* labels = (int*)(buf + base_size);
    w->data->getNormCatValues(vi, _sidx, labels);
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];

    int* cluster_labels = 0;
    double** dbl_ptr = 0;
    int i, j, k, si, idx;
    double L = 0, R = 0;
    double best_val = initQuality;
    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;

    // init array of counters:
    // c_{jk} - total weight of samples that have vi-th input variable = j and response = k.
    for( j = -1; j < mi; j++ )
        for( k = 0; k < m; k++ )
            cjk[j*m + k] = 0;

    for( i = 0; i < n; i++ )
    {
        si = _sidx[i];
        j = labels[i];
        k = responses[si];
        cjk[j*m + k] += weights[si];
    }

    if( m > 2 )
    {
        if( mi > params.getMaxCategories() )
        {
            mi = std::min(params.getMaxCategories(), n);
            cjk = c_weights + _mi;
            cluster_labels = (int*)(cjk + m*mi);
            clusterCategories( _cjk, _mi, m, cjk, mi, cluster_labels );
        }
        subset_i = 1;
        subset_n = 1 << mi;
    }
    else
    {
        CV_Assert( m == 2 );
        dbl_ptr = (double**)(c_weights + _mi);
        for( j = 0; j < mi; j++ )
            dbl_ptr[j] = cjk + j*2 + 1;
        std::sort(dbl_ptr, dbl_ptr + mi, cmp_lt_ptr<double>());
        subset_i = 0;
        subset_n = mi;
    }

    for( k = 0; k < m; k++ )
    {
        double sum = 0;
        for( j = 0; j < mi; j++ )
            sum += cjk[j*m + k];
        CV_Assert(sum > 0);
        rc[k] = sum;
        lc[k] = 0;
    }

    for( j = 0; j < mi; j++ )
    {
        double sum = 0;
        for( k = 0; k < m; k++ )
            sum += cjk[j*m + k];
        c_weights[j] = sum;
        R += c_weights[j];
    }

    for( ; subset_i < subset_n; subset_i++ )
    {
        double lsum2 = 0, rsum2 = 0;

        if( m == 2 )
            idx = (int)(dbl_ptr[subset_i] - cjk)/2;
        else
        {
            int graycode = (subset_i>>1)^subset_i;
            int diff = graycode ^ prevcode;

            // determine index of the changed bit.
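            // The trick below computes log2 of the (power-of-two) difference
            // without a loop: converting the value to float puts floor(log2(x))
            // into the IEEE-754 exponent field, extracted as (u.i >> 23) - 127.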
            Cv32suf u;
            idx = diff >= (1 << 16) ? 16 : 0;
            u.f = (float)(((diff >> 16) | diff) & 65535);
            idx += (u.i >> 23) - 127;
            subtract = graycode < prevcode;
            prevcode = graycode;
        }

        double* crow = cjk + idx*m;
        double weight = c_weights[idx];
        if( weight < FLT_EPSILON )
            continue;

        if( !subtract )
        {
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] + t;
                double rval = rc[k] - t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L += weight;
            R -= weight;
        }
        else
        {
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] - t;
                double rval = rc[k] + t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L -= weight;
            R += weight;
        }

        if( L > FLT_EPSILON && R > FLT_EPSILON )
        {
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int) );
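        // The subset is a bitmask over the categories, packed 32 per int;
        // CV_DTREE_CAT_DIR turns a category's bit into a branch direction
        // (-1 = left, +1 = right), as used in calcDir and predictTrees.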
        if( m == 2 )
        {
            for( i = 0; i <= best_subset; i++ )
            {
                idx = (int)(dbl_ptr[i] - cjk) >> 1;
                subset[idx >> 5] |= 1 << (idx & 31);
            }
        }
        else
        {
            for( i = 0; i < _mi; i++ )
            {
                idx = cluster_labels ? cluster_labels[i] : i;
                if( best_subset & (1 << idx) )
                    subset[i >> 5] |= 1 << (i & 31);
            }
        }
    }
    return split;
}

DTreesImpl::WSplit DTreesImpl::findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality )
{
    const float epsilon = FLT_EPSILON*2;
    const double* weights = &w->sample_weights[0];
    int n = (int)_sidx.size();

    AutoBuffer<uchar> buf(n*(sizeof(int) + sizeof(float)));

    float* values = (float*)(uchar*)buf;
    int* sorted_idx = (int*)(values + n);
    w->data->getValues(vi, _sidx, values);
    const double* responses = &w->ord_responses[0];

    int i, si, best_i = -1;
    double L = 0, R = 0;
    double best_val = initQuality, lsum = 0, rsum = 0;

    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        si = _sidx[i];
        R += weights[si];
        rsum += weights[si]*responses[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    // find the optimal split
    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        si = _sidx[curr];
        double wval = weights[si];
        double t = responses[si]*wval;
        L += wval; R -= wval;
        lsum += t; rsum -= t;

        if( values[curr] + epsilon < values[next] )
        {
            double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}

DTreesImpl::WSplit DTreesImpl::findSplitCatReg( int vi, const vector<int>& _sidx,
                                                double initQuality, int* subset )
{
    const double* weights = &w->sample_weights[0];
    const double* responses = &w->ord_responses[0];
    int n = (int)_sidx.size();
    int mi = getCatCount(vi);

    AutoBuffer<double> buf(3*mi + 3 + n);
    double* sum = (double*)buf + 1;
    double* counts = sum + mi + 1;
    double** sum_ptr = (double**)(counts + mi);
    int* cat_labels = (int*)(sum_ptr + mi);

    w->data->getNormCatValues(vi, _sidx, cat_labels);

    double L = 0, R = 0, best_val = initQuality, lsum = 0, rsum = 0;
    int i, si, best_subset = -1, subset_i;

    for( i = -1; i < mi; i++ )
        sum[i] = counts[i] = 0;

    // calculate sum response and weight of each category of the input var
    for( i = 0; i < n; i++ )
    {
        int idx = cat_labels[i];
        si = _sidx[i];
        double wval = weights[si];
        sum[idx] += responses[si]*wval;
        counts[idx] += wval;
    }

    // calculate average response in each category
    for( i = 0; i < mi; i++ )
    {
        R += counts[i];
        rsum += sum[i];
        sum[i] = fabs(counts[i]) > DBL_EPSILON ? sum[i]/counts[i] : 0;
        sum_ptr[i] = sum + i;
    }

    std::sort(sum_ptr, sum_ptr + mi, cmp_lt_ptr<double>());

    // revert to unnormalized sums
    // (there should be very little loss in accuracy)
    for( i = 0; i < mi; i++ )
        sum[i] *= counts[i];

    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        int idx = (int)(sum_ptr[subset_i] - sum);
        double ni = counts[idx];

        if( ni > FLT_EPSILON )
        {
            double s = sum[idx];
            lsum += s; L += ni;
            rsum -= s; R -= ni;

            if( L > FLT_EPSILON && R > FLT_EPSILON )
            {
                double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int));
        for( i = 0; i <= best_subset; i++ )
        {
            int idx = (int)(sum_ptr[i] - sum);
            subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    return split;
}

int DTreesImpl::calcDir( int splitidx, const vector<int>& _sidx,
                         vector<int>& _sleft, vector<int>& _sright )
{
    WSplit split = w->wsplits[splitidx];
    int i, si, n = (int)_sidx.size(), vi = split.varIdx;
    _sleft.reserve(n);
    _sright.reserve(n);
    _sleft.clear();
    _sright.clear();

    AutoBuffer<float> buf(n);
    int mi = getCatCount(vi);
    double wleft = 0, wright = 0;
    const double* weights = &w->sample_weights[0];

    if( mi <= 0 ) // split on an ordered variable
    {
        float c = split.c;
        float* values = buf;
        w->data->getValues(vi, _sidx, values);

        for( i = 0; i < n; i++ )
        {
            si = _sidx[i];
            if( values[i] <= c )
            {
                _sleft.push_back(si);
                wleft += weights[si];
            }
            else
            {
                _sright.push_back(si);
                wright += weights[si];
            }
        }
    }
    else
    {
        const int* subset = &w->wsubsets[split.subsetOfs];
        int* cat_labels = (int*)(float*)buf;
        w->data->getNormCatValues(vi, _sidx, cat_labels);

        for( i = 0; i < n; i++ )
        {
            si = _sidx[i];
            unsigned u = cat_labels[i];
            if( CV_DTREE_CAT_DIR(u, subset) < 0 )
            {
                _sleft.push_back(si);
                wleft += weights[si];
            }
            else
            {
                _sright.push_back(si);
                wright += weights[si];
            }
        }
    }
    CV_Assert( (int)_sleft.size() < n && (int)_sright.size() < n );
    return wleft > wright ? -1 : 1;
}

int DTreesImpl::pruneCV( int root )
{
    vector<double> ab;

    // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
    // 2. choose the best tree index (if need, apply 1SE rule).
    // 3. store the best index and cut the branches.
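    //
    // The per-node alpha computed in updateTreeRNC is the classical CART
    // cost-complexity measure: alpha(t) = (R(t) - R(T_t)) / (|leaves(T_t)| - 1),
    // where R(t) is the risk of node t as a leaf and R(T_t) the total risk of
    // the subtree rooted at t; the weakest links (smallest alpha) are cut first.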

    int ti, tree_count = 0, j, cv_n = params.getCVFolds(), n = w->wnodes[root].sample_count;
    // currently, 1SE for regression is not implemented
    bool use_1se = params.use1SERule != 0 && _isClassifier;
    double min_err = 0, min_err_se = 0;
    int min_idx = -1;

    // build the main tree sequence, calculate alpha's
    for(;;tree_count++)
    {
        double min_alpha = updateTreeRNC(root, tree_count, -1);
        if( cutTree(root, tree_count, -1, min_alpha) )
            break;

        ab.push_back(min_alpha);
    }

    if( tree_count > 0 )
    {
        ab[0] = 0.;

        for( ti = 1; ti < tree_count-1; ti++ )
            ab[ti] = std::sqrt(ab[ti]*ab[ti+1]);
        ab[tree_count-1] = DBL_MAX*0.5;

        Mat err_jk(cv_n, tree_count, CV_64F);

        for( j = 0; j < cv_n; j++ )
        {
            int tj = 0, tk = 0;
            for( ; tj < tree_count; tj++ )
            {
                double min_alpha = updateTreeRNC(root, tj, j);
                if( cutTree(root, tj, j, min_alpha) )
                    min_alpha = DBL_MAX;

                for( ; tk < tree_count; tk++ )
                {
                    if( ab[tk] > min_alpha )
                        break;
                    err_jk.at<double>(j, tk) = w->wnodes[root].tree_error;
                }
            }
        }

        for( ti = 0; ti < tree_count; ti++ )
        {
            double sum_err = 0;
            for( j = 0; j < cv_n; j++ )
                sum_err += err_jk.at<double>(j, ti);
            if( ti == 0 || sum_err < min_err )
            {
                min_err = sum_err;
                min_idx = ti;
                if( use_1se )
                    min_err_se = sqrt( sum_err*(n - sum_err) );
            }
            else if( sum_err < min_err + min_err_se )
                min_idx = ti;
        }
    }

    return min_idx;
}

double DTreesImpl::updateTreeRNC( int root, double T, int fold )
{
    int nidx = root, pidx = -1, cv_n = params.getCVFolds();
    double min_alpha = DBL_MAX;

    for(;;)
    {
        WNode *node = 0, *parent = 0;

        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )
            {
                node->complexity = 1;
                node->tree_risk = node->node_risk;
                node->tree_error = 0.;
                if( fold >= 0 )
                {
                    node->tree_risk = w->cv_node_risk[nidx*cv_n + fold];
                    node->tree_error = w->cv_node_error[nidx*cv_n + fold];
                }
                break;
            }
            nidx = node->left;
        }

        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
        {
            node = &w->wnodes[nidx];
            parent = &w->wnodes[pidx];
            parent->complexity += node->complexity;
            parent->tree_risk += node->tree_risk;
            parent->tree_error += node->tree_error;

            parent->alpha = ((fold >= 0 ? w->cv_node_risk[pidx*cv_n + fold] : parent->node_risk)
                             - parent->tree_risk)/(parent->complexity - 1);
            min_alpha = std::min( min_alpha, parent->alpha );
        }

        if( pidx < 0 )
            break;

        node = &w->wnodes[nidx];
        parent = &w->wnodes[pidx];
        parent->complexity = node->complexity;
        parent->tree_risk = node->tree_risk;
        parent->tree_error = node->tree_error;
        nidx = parent->right;
    }

    return min_alpha;
}

bool DTreesImpl::cutTree( int root, double T, int fold, double min_alpha )
{
    int cv_n = params.getCVFolds(), nidx = root, pidx = -1;
    WNode* node = &w->wnodes[root];
    if( node->left < 0 )
        return true;

    for(;;)
    {
        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )
                break;
            if( node->alpha <= min_alpha + FLT_EPSILON )
            {
                if( fold >= 0 )
                    w->cv_Tn[nidx*cv_n + fold] = T;
                else
                    node->Tn = T;
                if( nidx == root )
                    return true;
                break;
            }
            nidx = node->left;
        }

        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
            ;

        if( pidx < 0 )
            break;

        nidx = w->wnodes[pidx].right;
    }

    return false;
}

float DTreesImpl::predictTrees( const Range& range, const Mat& sample, int flags ) const
{
    CV_Assert( sample.type() == CV_32F );

    int predictType = flags & PREDICT_MASK;
    int nvars = (int)varIdx.size();
    if( nvars == 0 )
        nvars = (int)varType.size();
    int i, ncats = (int)catOfs.size(), nclasses = (int)classLabels.size();
    int catbufsize = ncats > 0 ? nvars : 0;
    AutoBuffer<int> buf(nclasses + catbufsize + 1);
    int* votes = buf;
    int* catbuf = votes + nclasses;
    const int* cvidx = (flags & (COMPRESSED_INPUT|PREPROCESSED_INPUT)) == 0 && !varIdx.empty() ? &compVarIdx[0] : 0;
    const uchar* vtype = &varType[0];
    const Vec2i* cofs = !catOfs.empty() ? &catOfs[0] : 0;
    const int* cmap = !catMap.empty() ? &catMap[0] : 0;
    const float* psample = sample.ptr<float>();
    const float* missingSubstPtr = !missingSubst.empty() ? &missingSubst[0] : 0;
    size_t sstep = sample.isContinuous() ? 1 : sample.step/sizeof(float);
    double sum = 0.;
    int lastClassIdx = -1;
    const float MISSED_VAL = TrainData::missingValue();

    for( i = 0; i < catbufsize; i++ )
        catbuf[i] = -1;

    if( predictType == PREDICT_AUTO )
    {
        predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
            PREDICT_SUM : PREDICT_MAX_VOTE;
    }

    if( predictType == PREDICT_MAX_VOTE )
    {
        for( i = 0; i < nclasses; i++ )
            votes[i] = 0;
    }

    for( int ridx = range.start; ridx < range.end; ridx++ )
    {
        int nidx = roots[ridx], prev = nidx, c = 0;

        for(;;)
        {
            prev = nidx;
            const Node& node = nodes[nidx];
            if( node.split < 0 )
                break;
            const Split& split = splits[node.split];
            int vi = split.varIdx;
            int ci = cvidx ? cvidx[vi] : vi;
            float val = psample[ci*sstep];
            if( val == MISSED_VAL )
            {
                if( !missingSubstPtr )
                {
                    nidx = node.defaultDir < 0 ? node.left : node.right;
                    continue;
                }
                val = missingSubstPtr[vi];
            }

            if( vtype[vi] == VAR_ORDERED )
                nidx = val <= split.c ? node.left : node.right;
            else
            {
                if( flags & PREPROCESSED_INPUT )
                    c = cvRound(val);
                else
                {
                    c = catbuf[ci];
                    if( c < 0 )
                    {
                        int a = c = cofs[vi][0];
                        int b = cofs[vi][1];

                        int ival = cvRound(val);
                        if( ival != val )
                            CV_Error( CV_StsBadArg,
                                     "one of the input categorical variables is not an integer" );

                        while( a < b )
                        {
                            c = (a + b) >> 1;
                            if( ival < cmap[c] )
                                b = c;
                            else if( ival > cmap[c] )
                                a = c+1;
                            else
                                break;
                        }

                        CV_Assert( c >= 0 && ival == cmap[c] );

                        c -= cofs[vi][0];
                        catbuf[ci] = c;
                    }
                    const int* subset = &subsets[split.subsetOfs];
                    unsigned u = c;
                    nidx = CV_DTREE_CAT_DIR(u, subset) < 0 ? node.left : node.right;
                }
            }
        }

        if( predictType == PREDICT_SUM )
            sum += nodes[prev].value;
        else
        {
            lastClassIdx = nodes[prev].classIdx;
            votes[lastClassIdx]++;
        }
    }

    if( predictType == PREDICT_MAX_VOTE )
    {
        int best_idx = lastClassIdx;
        if( range.end - range.start > 1 )
        {
            best_idx = 0;
            for( i = 1; i < nclasses; i++ )
                if( votes[best_idx] < votes[i] )
                    best_idx = i;
        }
        sum = (flags & RAW_OUTPUT) ? (float)best_idx : classLabels[best_idx];
    }

    return (float)sum;
}


float DTreesImpl::predict( InputArray _samples, OutputArray _results, int flags ) const
{
    CV_Assert( !roots.empty() );
    Mat samples = _samples.getMat(), results;
    int i, nsamples = samples.rows;
    int rtype = CV_32F;
    bool needresults = _results.needed();
    float retval = 0.f;
    bool iscls = isClassifier();
    float scale = !iscls ? 1.f/(int)roots.size() : 1.f;

    if( iscls && (flags & PREDICT_MASK) == PREDICT_MAX_VOTE )
        rtype = CV_32S;

    if( needresults )
    {
        _results.create(nsamples, 1, rtype);
        results = _results.getMat();
    }
    else
        nsamples = std::min(nsamples, 1);

    for( i = 0; i < nsamples; i++ )
    {
        float val = predictTrees( Range(0, (int)roots.size()), samples.row(i), flags )*scale;
        if( needresults )
        {
            if( rtype == CV_32F )
                results.at<float>(i) = val;
            else
                results.at<int>(i) = cvRound(val);
        }
        if( i == 0 )
            retval = val;
    }
    return retval;
}

void DTreesImpl::writeTrainingParams(FileStorage& fs) const
{
    fs << "use_surrogates" << (params.useSurrogates ? 1 : 0);
    fs << "max_categories" << params.getMaxCategories();
    fs << "regression_accuracy" << params.getRegressionAccuracy();

    fs << "max_depth" << params.getMaxDepth();
    fs << "min_sample_count" << params.getMinSampleCount();
    fs << "cross_validation_folds" << params.getCVFolds();

    if( params.getCVFolds() > 1 )
        fs << "use_1se_rule" << (params.use1SERule ? 1 : 0);

    if( !params.priors.empty() )
        fs << "priors" << params.priors;
}

void DTreesImpl::writeParams(FileStorage& fs) const
{
    fs << "is_classifier" << isClassifier();
    fs << "var_all" << (int)varType.size();
    fs << "var_count" << getVarCount();

    int ord_var_count = 0, cat_var_count = 0;
    int i, n = (int)varType.size();
    for( i = 0; i < n; i++ )
        if( varType[i] == VAR_ORDERED )
            ord_var_count++;
        else
            cat_var_count++;
    fs << "ord_var_count" << ord_var_count;
    fs << "cat_var_count" << cat_var_count;

    fs << "training_params" << "{";
    writeTrainingParams(fs);

    fs << "}";

    if( !varIdx.empty() )
    {
        fs << "global_var_idx" << 1;
        fs << "var_idx" << varIdx;
    }

    fs << "var_type" << varType;

    if( !catOfs.empty() )
        fs << "cat_ofs" << catOfs;
    if( !catMap.empty() )
        fs << "cat_map" << catMap;
    if( !classLabels.empty() )
        fs << "class_labels" << classLabels;
    if( !missingSubst.empty() )
        fs << "missing_subst" << missingSubst;
}

void DTreesImpl::writeSplit( FileStorage& fs, int splitidx ) const
{
    const Split& split = splits[splitidx];

    fs << "{:";

    int vi = split.varIdx;
    fs << "var" << vi;
    fs << "quality" << split.quality;

    if( varType[vi] == VAR_CATEGORICAL ) // split on a categorical var
    {
        int i, n = getCatCount(vi), to_right = 0;
        const int* subset = &subsets[split.subsetOfs];
        for( i = 0; i < n; i++ )
            to_right += CV_DTREE_CAT_DIR(i, subset) > 0;

        // ad-hoc rule for when to use the inverse categorical split notation,
        // to achieve a more compact and clear representation
        int default_dir = to_right <= 1 || to_right <= std::min(3, n/2) || to_right <= n/3 ? -1 : 1;

        fs << (default_dir*(split.inversed ? -1 : 1) > 0 ? "in" : "not_in") << "[:";

        for( i = 0; i < n; i++ )
        {
            int dir = CV_DTREE_CAT_DIR(i, subset);
            if( dir*default_dir < 0 )
                fs << i;
        }

        fs << "]";
    }
    else
        fs << (!split.inversed ? "le" : "gt") << split.c;

    fs << "}";
}

void DTreesImpl::writeNode( FileStorage& fs, int nidx, int depth ) const
{
    const Node& node = nodes[nidx];
    fs << "{";
    fs << "depth" << depth;
    fs << "value" << node.value;

    if( _isClassifier )
        fs << "norm_class_idx" << node.classIdx;

    if( node.split >= 0 )
    {
        fs << "splits" << "[";

        for( int splitidx = node.split; splitidx >= 0; splitidx = splits[splitidx].next )
            writeSplit( fs, splitidx );

        fs << "]";
    }

    fs << "}";
}

void DTreesImpl::writeTree( FileStorage& fs, int root ) const
{
    fs << "nodes" << "[";

    int nidx = root, pidx = 0, depth = 0;
    const Node *node = 0;

    // traverse the tree and save all the nodes in depth-first order
    for(;;)
    {
        for(;;)
        {
            writeNode( fs, nidx, depth );
            node = &nodes[nidx];
            if( node->left < 0 )
                break;
            nidx = node->left;
            depth++;
        }

        for( pidx = node->parent; pidx >= 0 && nodes[pidx].right == nidx;
             nidx = pidx, pidx = nodes[pidx].parent )
            depth--;

        if( pidx < 0 )
            break;

        nidx = nodes[pidx].right;
    }

    fs << "]";
}

void DTreesImpl::write( FileStorage& fs ) const
{
    writeParams(fs);
    writeTree(fs, roots[0]);
}

void DTreesImpl::readParams( const FileNode& fn )
{
    _isClassifier = (int)fn["is_classifier"] != 0;
    /*int var_all = (int)fn["var_all"];
    int var_count = (int)fn["var_count"];
    int cat_var_count = (int)fn["cat_var_count"];
    int ord_var_count = (int)fn["ord_var_count"];*/

    FileNode tparams_node = fn["training_params"];

    TreeParams params0 = TreeParams();

    if( !tparams_node.empty() ) // training parameters are optional
    {
        params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
        params0.setMaxCategories((int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]));
        params0.setRegressionAccuracy((float)tparams_node["regression_accuracy"]);
        params0.setMaxDepth((int)tparams_node["max_depth"]);
        params0.setMinSampleCount((int)tparams_node["min_sample_count"]);
        params0.setCVFolds((int)tparams_node["cross_validation_folds"]);

        if( params0.getCVFolds() > 1 )
        {
            params0.use1SERule = (int)tparams_node["use_1se_rule"] != 0;
        }

        tparams_node["priors"] >> params0.priors;
    }

    readVectorOrMat(fn["var_idx"], varIdx);
    fn["var_type"] >> varType;

    int format = 0;
    fn["format"] >> format;
    bool isLegacy = format < 3;

    int varAll = (int)fn["var_all"];
    if (isLegacy && (int)varType.size() <= varAll)
    {
        std::vector<uchar> extendedTypes(varAll + 1, 0);

        int i = 0, n;
        if (!varIdx.empty())
        {
            n = (int)varIdx.size();
            for (; i < n; ++i)
            {
                int var = varIdx[i];
                extendedTypes[var] = varType[i];
            }
        }
        else
        {
            n = (int)varType.size();
            for (; i < n; ++i)
            {
                extendedTypes[i] = varType[i];
            }
        }
        extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED);
        extendedTypes.swap(varType);
    }

    readVectorOrMat(fn["cat_map"], catMap);

    if (isLegacy)
    {
        // generating "catOfs" from "cat_count"
        catOfs.clear();
        classLabels.clear();
        std::vector<int> counts;
        readVectorOrMat(fn["cat_count"], counts);
        unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1;
        for (; i < size; ++i)
        {
            Vec2i newOffsets(0, 0);
            if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap
            {
                newOffsets[0] = curShift;
                curShift += counts[j];
                newOffsets[1] = curShift;
                ++j;
            }
            catOfs.push_back(newOffsets);
        }
        // other elements in "catMap" are "classLabels"
        if (curShift < catMap.size())
        {
            classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end());
            catMap.erase(catMap.begin() + curShift, catMap.end());
        }
    }
    else
    {
        fn["cat_ofs"] >> catOfs;
        fn["missing_subst"] >> missingSubst;
        fn["class_labels"] >> classLabels;
    }

    // init var mapping for node reading (var indexes or varIdx indexes)
    bool globalVarIdx = false;
    fn["global_var_idx"] >> globalVarIdx;
    if (globalVarIdx || varIdx.empty())
        setRangeVector(varMapping, (int)varType.size());
    else
        varMapping = varIdx;

    initCompVarIdx();
    setDParams(params0);
}

int DTreesImpl::readSplit( const FileNode& fn )
{
    Split split;

    int vi = (int)fn["var"];
    CV_Assert( 0 <= vi && vi < (int)varMapping.size() );
    vi = varMapping[vi]; // convert to varIdx if needed
    split.varIdx = vi;

    if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var
    {
        int i, val, ssize = getSubsetSize(vi);
        split.subsetOfs = (int)subsets.size();
        for( i = 0; i < ssize; i++ )
            subsets.push_back(0);
        int* subset = &subsets[split.subsetOfs];
        FileNode fns = fn["in"];
        if( fns.empty() )
        {
            fns = fn["not_in"];
            split.inversed = true;
        }

        if( fns.isInt() )
        {
            val = (int)fns;
            subset[val >> 5] |= 1 << (val & 31);
        }
        else
        {
            FileNodeIterator it = fns.begin();
            int n = (int)fns.size();
            for( i = 0; i < n; i++, ++it )
            {
                val = (int)*it;
                subset[val >> 5] |= 1 << (val & 31);
            }
        }

        // for categorical splits we do not use inversed splits,
        // instead we inverse the variable set in the split
        if( split.inversed )
        {
            for( i = 0; i < ssize; i++ )
                subset[i] ^= -1;
            split.inversed = false;
        }
    }
    else
    {
        FileNode cmpNode = fn["le"];
        if( cmpNode.empty() )
        {
            cmpNode = fn["gt"];
            split.inversed = true;
        }
        split.c = (float)cmpNode;
    }

    split.quality = (float)fn["quality"];
    splits.push_back(split);

    return (int)(splits.size() - 1);
}

int DTreesImpl::readNode( const FileNode& fn )
{
    Node node;
    node.value = (double)fn["value"];

    if( _isClassifier )
        node.classIdx = (int)fn["norm_class_idx"];

    FileNode sfn = fn["splits"];
    if( !sfn.empty() )
    {
        int i, n = (int)sfn.size(), prevsplit = -1;
        FileNodeIterator it = sfn.begin();

        for( i = 0; i < n; i++, ++it )
        {
            int splitidx = readSplit(*it);
            if( splitidx < 0 )
                break;
            if( prevsplit < 0 )
                node.split = splitidx;
            else
                splits[prevsplit].next = splitidx;
            prevsplit = splitidx;
        }
    }
    nodes.push_back(node);
    return (int)(nodes.size() - 1);
}

int DTreesImpl::readTree( const FileNode& fn )
{
    int i, n = (int)fn.size(), root = -1, pidx = -1;
    FileNodeIterator it = fn.begin();

    for( i = 0; i < n; i++, ++it )
    {
        int nidx = readNode(*it);
        if( nidx < 0 )
            break;
        Node& node = nodes[nidx];
        node.parent = pidx;
        if( pidx < 0 )
            root = nidx;
        else
        {
            Node& parent = nodes[pidx];
            if( parent.left < 0 )
                parent.left = nidx;
            else
                parent.right = nidx;
        }
        if( node.split >= 0 )
            pidx = nidx;
        else
        {
            while( pidx >= 0 && nodes[pidx].right >= 0 )
                pidx = nodes[pidx].parent;
        }
    }
    roots.push_back(root);
    return root;
}

void DTreesImpl::read( const FileNode& fn )
{
    clear();
    readParams(fn);

    FileNode fnodes = fn["nodes"];
    CV_Assert( !fnodes.empty() );
    readTree(fnodes);
}

Ptr<DTrees> DTrees::create()
{
    return makePtr<DTreesImpl>();
}
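
/*
    Minimal end-to-end usage sketch (the sample data is hypothetical and only
    illustrates the expected shapes; integer responses make it a classifier):

        Mat samples = (Mat_<float>(4, 2) << 0.f, 0.f,  0.f, 1.f,
                                            1.f, 0.f,  1.f, 1.f);
        Mat responses = (Mat_<int>(4, 1) << 0, 1, 1, 0);

        Ptr<DTrees> dtree = DTrees::create();
        dtree->setMaxDepth(8);
        dtree->setMinSampleCount(1);
        dtree->setCVFolds(0); // disable built-in cross-validation pruning
        dtree->train(TrainData::create(samples, ROW_SAMPLE, responses));

        Mat query = (Mat_<float>(1, 2) << 0.f, 1.f);
        float prediction = dtree->predict(query);
*/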

}
}

/* End of file. */
