This source file includes the following definitions.
- calculateError
- run
- run
- TEST
- TEST
#include "test_precomp.hpp"
using namespace std;
using namespace cv;
using namespace cv::ml;
static bool calculateError( const Mat& _p_labels, const Mat& _o_labels, float& error)
{
error = 0.0f;
float accuracy = 0.0f;
Mat _p_labels_temp;
Mat _o_labels_temp;
_p_labels.convertTo(_p_labels_temp, CV_32S);
_o_labels.convertTo(_o_labels_temp, CV_32S);
CV_Assert(_p_labels_temp.total() == _o_labels_temp.total());
CV_Assert(_p_labels_temp.rows == _o_labels_temp.rows);
accuracy = (float)countNonZero(_p_labels_temp == _o_labels_temp)/_p_labels_temp.rows;
error = 1 - accuracy;
return true;
}
class CV_LRTest : public cvtest::BaseTest
{
public:
CV_LRTest() {}
protected:
virtual void run( int start_from );
};
void CV_LRTest::run( int )
{
string dataFileName = ts->get_data_path() + "iris.data";
Ptr<TrainData> tdata = TrainData::loadFromCSV(dataFileName, 0);
Ptr<LogisticRegression> p = LogisticRegression::create();
p->setLearningRate(1.0);
p->setIterations(10001);
p->setRegularization(LogisticRegression::REG_L2);
p->setTrainMethod(LogisticRegression::BATCH);
p->setMiniBatchSize(10);
p->train(tdata);
Mat responses;
p->predict(tdata->getSamples(), responses);
int test_code = cvtest::TS::OK;
float error = 0.0f;
if(!calculateError(responses, tdata->getResponses(), error))
{
ts->printf(cvtest::TS::LOG, "Bad prediction labels\n" );
test_code = cvtest::TS::FAIL_INVALID_OUTPUT;
}
else if(error > 0.05f)
{
ts->printf(cvtest::TS::LOG, "Bad accuracy of (%f)\n", error);
test_code = cvtest::TS::FAIL_BAD_ACCURACY;
}
{
FileStorage s("debug.xml", FileStorage::WRITE);
s << "original" << tdata->getResponses();
s << "predicted1" << responses;
s << "learnt" << p->get_learnt_thetas();
s << "error" << error;
s.release();
}
ts->set_failed_test_info(test_code);
}
class CV_LRTest_SaveLoad : public cvtest::BaseTest
{
public:
CV_LRTest_SaveLoad(){}
protected:
virtual void run(int start_from);
};
void CV_LRTest_SaveLoad::run( int )
{
int code = cvtest::TS::OK;
string dataFileName = ts->get_data_path() + "iris.data";
Ptr<TrainData> tdata = TrainData::loadFromCSV(dataFileName, 0);
Mat responses1, responses2;
Mat learnt_mat1, learnt_mat2;
String filename = tempfile(".xml");
try
{
Ptr<LogisticRegression> lr1 = LogisticRegression::create();
lr1->setLearningRate(1.0);
lr1->setIterations(10001);
lr1->setRegularization(LogisticRegression::REG_L2);
lr1->setTrainMethod(LogisticRegression::BATCH);
lr1->setMiniBatchSize(10);
lr1->train(tdata);
lr1->predict(tdata->getSamples(), responses1);
learnt_mat1 = lr1->get_learnt_thetas();
lr1->save(filename);
}
catch(...)
{
ts->printf(cvtest::TS::LOG, "Crash in write method.\n" );
ts->set_failed_test_info(cvtest::TS::FAIL_EXCEPTION);
}
try
{
Ptr<LogisticRegression> lr2 = Algorithm::load<LogisticRegression>(filename);
lr2->predict(tdata->getSamples(), responses2);
learnt_mat2 = lr2->get_learnt_thetas();
}
catch(...)
{
ts->printf(cvtest::TS::LOG, "Crash in write method.\n" );
ts->set_failed_test_info(cvtest::TS::FAIL_EXCEPTION);
}
CV_Assert(responses1.rows == responses2.rows);
Mat comp_learnt_mats;
comp_learnt_mats = (learnt_mat1 == learnt_mat2);
comp_learnt_mats = comp_learnt_mats.reshape(1, comp_learnt_mats.rows*comp_learnt_mats.cols);
comp_learnt_mats.convertTo(comp_learnt_mats, CV_32S);
comp_learnt_mats = comp_learnt_mats/255;
float errorCount = 0.0;
errorCount += 1 - (float)countNonZero(responses1 == responses2)/responses1.rows;
errorCount += 1 - (float)sum(comp_learnt_mats)[0]/comp_learnt_mats.rows;
if(errorCount>0)
{
ts->printf( cvtest::TS::LOG, "Different prediction results before writing and after reading (errorCount=%d).\n", errorCount );
code = cvtest::TS::FAIL_BAD_ACCURACY;
}
remove( filename.c_str() );
ts->set_failed_test_info( code );
}
// Registers the logistic regression accuracy check as ML_LR.accuracy.
TEST(ML_LR, accuracy)
{
    CV_LRTest test;
    test.safe_run();
}
// Registers the logistic regression save/load check as ML_LR.save_load.
TEST(ML_LR, save_load)
{
    CV_LRTest_SaveLoad test;
    test.safe_run();
}