An Analysis of OpenCV's softcascade Code
// Headers involved: softcascade.hpp, core.hpp
// Implementation file: octave.cpp
// The walkthrough also touches ml.hpp and tree.cpp (the decision-tree implementation).
// The softcascade detector's training function looks like this:

bool BoostedSoftCascadeOctave::train(const Dataset* dataset, const FeaturePool* pool,
                                     int weaks, int treeDepth)
{
    // dataset is an object that gives access to the training set; pool is the
    // feature pool; weaks is the number of weak tree classifiers to train;
    // treeDepth is the depth of each classifier tree
    CV_Assert(treeDepth == 2); // the weak trees are required to have depth exactly 2
    CV_Assert(weaks > 0);

    params.max_depth  = treeDepth;
    params.weak_count = weaks;

    // 1. fill integrals and classes
    // This computes the integral images of every sample and assigns the labels:
    // the corresponding entries of the responses vector are set to 1 for
    // positives and to 0 for negatives.
    processPositives(dataset);
    generateNegatives(dataset);

    // 2. only simple case (all features used)
    int nfeatures = pool->size();
    cv::Mat varIdx(1, nfeatures, CV_32SC1);
    int* ptr = varIdx.ptr<int>(0);
    for (int x = 0; x < nfeatures; ++x)
        ptr[x] = x;

    // 3. only simple case (all samples used)
    int nsamples = npositives + nnegatives;
    cv::Mat sampleIdx(1, nsamples, CV_32SC1);
    ptr = sampleIdx.ptr<int>(0);
    for (int x = 0; x < nsamples; ++x)
        ptr[x] = x;

    // 4. ICF has an ordered response.
    // varType does not mean "sort by feature response"; it declares the type of
    // each variable. Every feature is marked CV_VAR_ORDERED (an ordered,
    // numerical variable). The one extra element at the end describes the
    // response column: it is CV_VAR_CATEGORICAL because the responses vector
    // passed to train() below holds class labels.
    cv::Mat varType(1, nfeatures + 1, CV_8UC1);
    uchar* uptr = varType.ptr<uchar>(0);
    for (int x = 0; x < nfeatures; ++x)
        uptr[x] = CV_VAR_ORDERED;
    uptr[nfeatures] = CV_VAR_CATEGORICAL;

    // Build the training matrix. As the double loop shows, each row corresponds
    // to a feature and each column to a sample; every element stores the
    // response of one feature on one sample, evaluated over the integral images.
    trainData.create(nfeatures, nsamples, CV_32FC1);
    for (int fi = 0; fi < nfeatures; ++fi)
    {
        float* dptr = trainData.ptr<float>(fi);
        for (int si = 0; si < nsamples; ++si)
        {
            dptr[si] = pool->apply(fi, si, integrals);
        }
    }

    cv::Mat missingMask;

    // The main training call. trainData is the feature-by-sample matrix,
    // responses holds the class labels, varIdx lists the features to learn
    // from, sampleIdx lists the training samples, varType declares the
    // variable types as above, and missingMask would mark missing feature
    // values (it does not record misclassified samples); here it is left empty.
    bool ok = train(trainData, responses, varIdx, sampleIdx, varType, missingMask);
    if (!ok)
        CV_Error(CV_StsInternal, "ERROR: tree can not be trained");
    return ok;
}

// The train() overload used above is shown next. It simply forwards to
// cv::Boost::train(), passing along the inherited params member (a
// CvBoostParams holding the training parameters set earlier: max_depth and
// weak_count). The trained weak classifiers themselves end up stored inside
// the boost object, in the weak sequence created in CvBoost::train() below.
// At this point we are still training the detector as a whole:
bool BoostedSoftCascadeOctave::train(const cv::Mat& _trainData, const cv::Mat& _responses,
                                     const cv::Mat& varIdx, const cv::Mat& sampleIdx,
                                     const cv::Mat& varType, const cv::Mat& missingDataMask)
{
    bool update = false;
    return cv::Boost::train(_trainData, CV_COL_SAMPLE, _responses, varIdx, sampleIdx,
                            varType, missingDataMask, params, update);
}
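// Before digging into cv::Boost::train(), it may help to see the same calling
// conventions in isolation. The following is a minimal standalone sketch (my
// own toy example with random data, not part of the softcascade sources) that
// drives the OpenCV 2.x CvBoost API with the layout used above: one feature per
// row, one sample per column (hence CV_COL_SAMPLE), ordered feature variables,
// and a categorical response. The boost type and parameter values are arbitrary
// choices for the sketch; the excerpt above does not show softcascade's actual
// parameter values.

#include <opencv2/core/core.hpp>
#include <opencv2/ml/ml.hpp>

int main()
{
    const int nfeatures = 4, nsamples = 10;

    // one row per feature, one column per sample (toy values here)
    cv::Mat trainData(nfeatures, nsamples, CV_32FC1);
    cv::randu(trainData, 0.f, 1.f);

    // one class label per sample: first half positives (1), rest negatives (0)
    cv::Mat responses(1, nsamples, CV_32SC1);
    for (int si = 0; si < nsamples; ++si)
        responses.at<int>(0, si) = si < nsamples / 2 ? 1 : 0;

    // variable types: every feature ordered (numerical); the extra last
    // element describes the response, which is categorical (a class label)
    cv::Mat varType(1, nfeatures + 1, CV_8UC1, cv::Scalar(CV_VAR_ORDERED));
    varType.at<uchar>(0, nfeatures) = CV_VAR_CATEGORICAL;

    // depth-2 trees, 16 weak classifiers, weight trim rate 0.95
    CvBoostParams params(CvBoost::GENTLE, 16, 0.95, 2, false, 0);

    CvBoost boost;
    bool ok = boost.train(trainData, CV_COL_SAMPLE, responses,
                          cv::Mat(), cv::Mat(), varType, cv::Mat(), params);
    return ok ? 0 : 1;
}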
// The cv::Boost::train() it forwards to is defined in boost.cpp:

bool CvBoost::train( const CvMat* _train_data, int _tflag,
             const CvMat* _responses, const CvMat* _var_idx,
             const CvMat* _sample_idx, const CvMat* _var_type,
             const CvMat* _missing_mask,
             CvBoostParams _params, bool _update )
{
    // Following the call chain: _train_data is the training matrix with one row
    // per feature and one column per sample; _tflag says that each column is a
    // sample (CV_COL_SAMPLE); _responses holds the per-sample class labels;
    // _var_idx are the feature indices and _sample_idx the sample indices;
    // _var_type declares each variable as ordered or categorical (it has
    // nothing to do with sorting by feature response); _missing_mask marks
    // missing feature values; _params carries the boosting parameters (it is
    // not where the trained detector is stored); _update is false. Each weak
    // classifier is trained through CvBoostTree::train() in the loop below.
    bool ok = false;
    CvMemStorage* storage = 0;

    CV_FUNCNAME( "CvBoost::train" );
    __BEGIN__;

    int i;

    set_params( _params ); // store the boosting parameters

    cvReleaseMat( &active_vars );
    cvReleaseMat( &active_vars_abs );

    if( !_update || !data )
    {
        // prepare the training data, require exactly two classes, and allocate
        // the storage that will hold the weak classifiers
        clear();
        data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
            _sample_idx, _var_type, _missing_mask, _params, true, true );
        // _params is needed here because CvBoostParams derives from
        // CvDTreeParams, which the tree-growing code reads; this answers the
        // question of why the boosting parameters are passed to the tree data
        if( data->get_num_classes() != 2 )
            CV_ERROR( CV_StsNotImplemented,
            "Boosted trees can only be used for 2-class classification." );
        CV_CALL( storage = cvCreateMemStorage() );
        weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
        // weak is the sequence in which CvBoost keeps its weak classifiers
        storage = 0;
    }
    else
    {
        data->set_data( _train_data, _tflag, _responses, _var_idx,
            _sample_idx, _var_type, _missing_mask, _params, true, true, true );
    }

    if ( (_params.boost_type == LOGIT) || (_params.boost_type == GENTLE) )
        data->do_responses_copy();

    update_weights( 0 ); // passing 0 initializes all sample weights uniformly

    for( i = 0; i < params.weak_count; i++ ) // train weak_count weak classifiers
    {
        CvBoostTree* tree = new CvBoostTree;
        // The main per-tree training call. subsample_mask starts out as a null
        // pointer, meaning all samples participate; after weight trimming it
        // marks the samples whose weights are still significant (it does not
        // record which samples were classified correctly). The third argument
        // is the ensemble the new weak classifier is attached to.
        if( !tree->train( data, subsample_mask, this ) )
        {
            delete tree;
            break;
        }
        //cvCheckArr( get_weak_response());
        cvSeqPush( weak, &tree );
        update_weights( tree ); // re-weight the samples according to how the new tree classifies them
        trim_weights();         // drop samples whose weight became negligible from subsample_mask
        if( cvCountNonZero(subsample_mask) == 0 )
            break;
    }

    if( weak->total > 0 )
    {
        get_active_vars(); // recompute active_vars* maps and condensed_idx's in the splits.
        data->is_classifier = true;
        data->free_train_data(); // the raw training data is no longer needed
        ok = true;
    }
    else
        clear();

    __END__;

    return ok;
}

// CvBoostTree::train() is defined as follows. It trains a single weak
// classifier by delegating to CvDTree::do_train():
bool CvBoostTree::train( CvDTreeTrainData* _train_data,
                    const CvMat* _subsample_idx, CvBoost* _ensemble )
{
    clear();
    ensemble = _ensemble;
    data = _train_data;
    data->shared = true;
    return do_train( _subsample_idx );
}

// CvDTree::do_train() is defined as follows (in tree.cpp; header ml.hpp):
bool CvDTree::do_train( const CvMat* _subsample_idx )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::do_train" );
    __BEGIN__;

    root = data->subsample_data( _subsample_idx ); // select the samples that take part in training

    CV_CALL( try_split_node(root));

    if( root->split )
    {
        CV_Assert( root->left );
        CV_Assert( root->right );

        if( data->params.cv_folds > 0 )
            CV_CALL( prune_cv() );

        if( !data->shared )
            data->free_train_data();

        result = true;
    }

    __END__;

    return result;
}
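// The questions about update_weights() above can be answered from the boosting
// algorithm itself: after each weak classifier is trained, the samples it
// misclassifies gain weight and the samples it classifies correctly lose
// weight, so the next tree concentrates on the hard samples; trim_weights()
// then clears subsample_mask for samples whose weight has become negligible
// (controlled by weight_trim_rate). The standalone sketch below illustrates
// the Discrete AdaBoost re-weighting rule only; OpenCV's update_weights() in
// boost.cpp implements this family of updates for the Discrete, Real, Logit,
// and Gentle variants, and its exact code differs from this illustration.

#include <algorithm>
#include <cmath>
#include <vector>

void adaboost_update(std::vector<double>& w,     // sample weights, summing to 1
                     const std::vector<int>& y,  // true labels in {-1, +1}
                     const std::vector<int>& h)  // weak-classifier outputs in {-1, +1}
{
    // weighted training error of the new weak classifier
    double err = 0.0;
    for (size_t i = 0; i < w.size(); ++i)
        if (h[i] != y[i]) err += w[i];
    err = std::min(std::max(err, 1e-10), 1.0 - 1e-10); // guard against log(0)

    // classifier weight alpha: the smaller the error, the larger alpha
    double alpha = 0.5 * std::log((1.0 - err) / err);

    // re-weight: misclassified samples (y*h = -1) are multiplied by exp(+alpha),
    // correctly classified ones by exp(-alpha); then renormalize
    double sum = 0.0;
    for (size_t i = 0; i < w.size(); ++i) {
        w[i] *= std::exp(-alpha * y[i] * h[i]);
        sum += w[i];
    }
    for (size_t i = 0; i < w.size(); ++i)
        w[i] /= sum;
}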
// One core helper of do_train() is subsample_data():
CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
{
    CvDTreeNode* root = 0;
    CvMat* isubsample_idx = 0;
    CvMat* subsample_co = 0;
    bool isMakeRootCopy = true;

    CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );
    __BEGIN__;

    if( !data_root )
        CV_ERROR( CV_StsError, "No training data has been set" );

    if( _subsample_idx )
    {
        // _subsample_idx must be a single row or a single column. It may hold
        // either the indices of the selected samples or a 0/1 mask of length
        // sample_count; cvPreprocessIndexArray normalizes it either way into a
        // sorted list of the selected sample indices.
        CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));

        // If the element count equals the sample count and the selection is the
        // identity 0..sample_count-1, the root node can simply be copied.
        if( isubsample_idx->cols + isubsample_idx->rows - 1 == sample_count )
        {
            const int* sidx = isubsample_idx->data.i;
            for( int i = 0; i < sample_count; i++ )
            {
                if( sidx[i] != i )
                {
                    isMakeRootCopy = false; // not the identity selection
                    break;
                }
            }
        }
        else
            isMakeRootCopy = false;
    }

    if( isMakeRootCopy )
    {
        // make a copy of the root node
        CvDTreeNode temp;
        int i;
        root = new_node( 0, 1, 0, 0 );
        temp = *root;
        *root = *data_root;
        root->num_valid = temp.num_valid;
        if( root->num_valid )
        {
            for( i = 0; i < var_count; i++ )
                root->num_valid[i] = data_root->num_valid[i];
        }
        root->cv_Tn = temp.cv_Tn;
        root->cv_node_risk = temp.cv_node_risk;
        root->cv_node_error = temp.cv_node_error;
    }
    else
    {
        int* sidx = isubsample_idx->data.i;
        // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
        int* co, cur_ofs = 0;
        int vi, i;
        int workVarCount = get_work_var_count(); // number of working variables
                                                 // (the features plus auxiliary columns)
        int count = isubsample_idx->rows + isubsample_idx->cols - 1;
        // count is the number of selected samples (the length of the index
        // vector), not the number of samples the weak classifier judged positive
        root = new_node( 0, count, 1, 0 );

        CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
        cvZero( subsample_co );
        co = subsample_co->data.i;
        for( i = 0; i < count; i++ )
            co[sidx[i]*2]++; // count: how many times sample sidx[i] was selected
        for( i = 0; i < sample_count; i++ )
        {
            if( co[i*2] )
            {
                co[i*2+1] = cur_ofs; // offset of this sample's slots in the subsample
                cur_ofs += co[i*2];
            }
            else
                co[i*2+1] = -1; // sample not selected
        }

        cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
        for( vi = 0; vi < workVarCount; vi++ )
        {
            int ci = get_var_type(vi);
            if( ci >= 0 || vi >= var_count )
            {
                // categorical variable (or auxiliary column): copy values directly
                int num_valid = 0;
                const int* src = CvDTreeTrainData::get_cat_var_data( data_root, vi, (int*)(uchar*)inn_buf );
                if (is_buf_16u)
                {
                    unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                        vi*sample_count + root->offset);
                    for( i = 0; i < count; i++ )
                    {
                        int val = src[sidx[i]];
                        udst[i] = (unsigned short)val;
                        num_valid += val >= 0;
                    }
                }
                else
                {
                    int* idst = buf->data.i + root->buf_idx*get_length_subbuf() +
                        vi*sample_count + root->offset;
                    for( i = 0; i < count; i++ )
                    {
                        int val = src[sidx[i]];
                        idst[i] = val;
                        num_valid += val >= 0;
                    }
                }
                if( vi < var_count )
                    root->set_num_valid(vi, num_valid);
            }
            else
            {
                // ordered variable: remap the value-sorted index arrays onto the
                // subsample, using the count/offset pairs to expand duplicates
                int *src_idx_buf = (int*)(uchar*)inn_buf;
                float *src_val_buf = (float*)(src_idx_buf + sample_count);
                int* sample_indices_buf = (int*)(src_val_buf + sample_count);
                const int* src_idx = 0;
                const float* src_val = 0;
                get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf );
                int j = 0, idx, count_i;
                int num_valid = data_root->get_num_valid(vi);
                if (is_buf_16u)
                {
                    unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                        vi*sample_count + data_root->offset);
                    for( i = 0; i < num_valid; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                udst_idx[j] = (unsigned short)cur_ofs;
                    }
                    root->set_num_valid(vi, j);
                    for( ; i < sample_count; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                udst_idx[j] = (unsigned short)cur_ofs;
                    }
                }
                else
                {
                    int* idst_idx = buf->data.i + root->buf_idx*get_length_subbuf() +
                        vi*sample_count + root->offset;
                    for( i = 0; i < num_valid; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                idst_idx[j] = cur_ofs;
                    }
                    root->set_num_valid(vi, j);
                    for( ; i < sample_count; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                idst_idx[j] = cur_ofs;
                    }
                }
            }
        }

        // sample indices subsampling
        const int* sample_idx_src = get_sample_indices(data_root, (int*)(uchar*)inn_buf);
        if (is_buf_16u)
        {
            unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                workVarCount*sample_count + root->offset);
            for (i = 0; i < count; i++)
                sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];
        }
        else
        {
            int* sample_idx_dst = buf->data.i + root->buf_idx*get_length_subbuf() +
                workVarCount*sample_count + root->offset;
            for (i = 0; i < count; i++)
                sample_idx_dst[i] = sample_idx_src[sidx[i]];
        }
    }

    __END__;

    cvReleaseMat( &isubsample_idx );
    cvReleaseMat( &subsample_co );

    return root;
}
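// The least obvious part of subsample_data() is the co array of count/offset
// pairs. The standalone snippet below (my own illustration, not OpenCV code)
// reproduces just that bookkeeping: given a selection sidx that may contain
// duplicates, it computes for every original sample how many slots it occupies
// in the subsample and where those slots start, exactly as in the loops above.

#include <cstdio>
#include <vector>

int main()
{
    const int sample_count = 6;
    std::vector<int> sidx = {1, 3, 3, 5};     // selected samples; duplicates allowed
    std::vector<int> co(sample_count * 2, 0); // (count, offset) pair per original sample

    for (size_t i = 0; i < sidx.size(); ++i)
        co[sidx[i] * 2]++;                    // count the occurrences of each sample

    int cur_ofs = 0;
    for (int i = 0; i < sample_count; ++i) {
        if (co[i * 2]) {
            co[i * 2 + 1] = cur_ofs;          // first slot of sample i in the subsample
            cur_ofs += co[i * 2];
        } else {
            co[i * 2 + 1] = -1;               // sample i was not selected
        }
    }

    // prints: sample 1 -> count=1 offset=0, sample 3 -> count=2 offset=1,
    //         sample 5 -> count=1 offset=3, all others -> count=0 offset=-1
    for (int i = 0; i < sample_count; ++i)
        std::printf("sample %d: count=%d offset=%d\n", i, co[i * 2], co[i * 2 + 1]);
    return 0;
}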
// The other core helper of do_train() is try_split_node():
void CvDTree::try_split_node( CvDTreeNode* node )
{
    CvDTreeSplit* best_split = 0;
    int i, n = node->sample_count, vi;
    bool can_split = true;
    double quality_scale;

    calc_node_value( node );

    if( node->sample_count <= data->params.min_sample_count ||
        node->depth >= data->params.max_depth )
        can_split = false;

    if( can_split && data->is_classifier )
    {
        // check if we have a "pure" node,
        // we assume that cls_count is filled by calc_node_value()
        int* cls_count = data->counts->data.i;
        int nz = 0, m = data->get_num_classes();
        for( i = 0; i < m; i++ )
            nz += cls_count[i] != 0;
        if( nz == 1 ) // there is only one class
            can_split = false;
    }
    else if( can_split )
    {
        if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
            can_split = false;
    }

    if( can_split )
    {
        best_split = find_best_split(node);
        // TODO: check the split quality ...
        node->split = best_split;
    }

    if( !can_split || !best_split )
    {
        data->free_node_data(node);
        return;
    }

    quality_scale = calc_node_dir( node );

    if( data->params.use_surrogates )
    {
        // find all the surrogate splits
        // and sort them by their similarity to the primary one
        for( vi = 0; vi < data->var_count; vi++ )
        {
            CvDTreeSplit* split;
            int ci = data->get_var_type(vi);

            if( vi == best_split->var_idx )
                continue;

            if( ci >= 0 )
                split = find_surrogate_split_cat( node, vi );
            else
                split = find_surrogate_split_ord( node, vi );

            if( split )
            {
                // insert the split
                CvDTreeSplit* prev_split = node->split;
                split->quality = (float)(split->quality*quality_scale);

                while( prev_split->next &&
                       prev_split->next->quality > split->quality )
                    prev_split = prev_split->next;
                split->next = prev_split->next;
                prev_split->next = split;
            }
        }
    }

    split_node_data( node );
    try_split_node( node->left );
    try_split_node( node->right );
}
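// To summarize the control flow of try_split_node(): compute the node value;
// stop if the node is too small, too deep, or pure; otherwise find the best
// split, optionally collect surrogate splits, partition the data, and recurse
// into both children. The much-simplified, self-contained sketch below uses
// hypothetical types and a dummy split() (it is not OpenCV's data structures,
// and it leaks its nodes for brevity) but mirrors that recursion and its
// stopping rules.

#include <cstdio>

struct Node {
    int sample_count;
    int depth;
    bool pure;    // all samples belong to one class
    Node* left;
    Node* right;
};

const int MIN_SAMPLE_COUNT = 10; // analogue of params.min_sample_count
const int MAX_DEPTH        = 2;  // analogue of params.max_depth

// stand-in for find_best_split()/split_node_data(): just halves the node
bool split(Node* n)
{
    n->left  = new Node{ n->sample_count / 2, n->depth + 1, false, 0, 0 };
    n->right = new Node{ n->sample_count - n->sample_count / 2, n->depth + 1, false, 0, 0 };
    return true;
}

void try_split(Node* node)
{
    // the same three stopping tests try_split_node() applies
    if (node->sample_count <= MIN_SAMPLE_COUNT ||
        node->depth >= MAX_DEPTH ||
        node->pure)
        return;                  // make a leaf

    if (!split(node))            // no split found: make a leaf
        return;

    try_split(node->left);       // recurse into both children,
    try_split(node->right);      // exactly like try_split_node()
}

int main()
{
    Node root{ 100, 0, false, 0, 0 };
    try_split(&root);
    std::printf("root split: %s\n", root.left ? "yes" : "no");
    return 0;
}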