海思 IVE ANN-MLP 使用说明（2）
原文:https://blog.csdn.net/brightming/article/details/50895356
5 完整示例
5.1 二维数据的训练与预测
5.1.1 训练二维数据
以y=kx直线进行划分,在直线以下的为类别0,其他为类别1。
在训练的时候,指定k,产生训练数据,同时将一部分作为测试数据。
5.1.1.1 训练入口
/**
- 以斜率=slope 来做分界,训练一个mlp模型
*/
extern "C" void train_2_class_slope(float slope){//(int useExistModel,float from_x,float end_x,float from_y,float end_y,float x_step,float y_step){
CvANN_MLP annMlp;
int outputClassCnt=2;
bool loadModelFromFile=false;
Mat training_datas;
Mat trainClasses;
Mat oriTrainDatas;
generateFix2ClassSlopeTrainData(slope,training_datas,trainClasses);
training_datas.convertTo(training_datas,CV_32FC1);
trainClasses.convertTo(trainClasses,CV_32FC1);
cout<<"training_datas=\n"<<training_datas<<",oridata="<<oriTrainDatas<<endl;
cout<<"trainClasses=\n"<<trainClasses<<endl;
//创建mlp
Mat layers(1, 3, CV_32SC1);
layers.at
cout<<"------------------------trainAnnModel.input sample cnt:"<<training_datas.rows<<",input layer features:"<<layers.at
layers.at
layers.at
cout<<"outputClassCnt="<<outputClassCnt<<endl;
annMlp.create(layers, CvANN_MLP::SIGMOID_SYM, 0.6667f, 1.7159f);
//--------训练mlp-----------//
// Set up BPNetwork‘s parameters
CvANN_MLP_TrainParams params;
params.train_method = CvANN_MLP_TrainParams::BACKPROP;
params.bp_dw_scale = 0.001;
params.bp_moment_scale = 0.0;
CvTermCriteria criteria;
criteria.max_iter = 300;
criteria.epsilon = 9.999999e-06;
criteria.type = CV_TERMCRIT_ITER | CV_TERMCRIT_EPS;
params.term_crit = criteria;
annMlp.train(training_datas, trainClasses, Mat(), Mat(), params);
cout<<"train finished"<<endl;
char _dstPath[256];
sprintf(_dstPath,"data/my/my_simple_2_class_20160307slope%.2f.xml",slope);
string dstPath(_dstPath);//="data/my/my_simple_2_class_20160307_slope_1.xml";
annMlp.save(dstPath.c_str());
cout<<"save model finished.model file="<<dstPath<<"\n";
//预测
Mat test_datas;
Mat testClasses;
int testCount=1;//每个象限的测试图片数量
Mat oriTestData;
generate2ClassSlopeTestData(slope,test_datas,testClasses,oriTestData);//,from_x,end_x,from_y,end_y);
test_datas.convertTo(test_datas,CV_32FC1);
testClasses.convertTo(testClasses,CV_32FC1);
cout<<"test_datas=\n"<<test_datas<<",oridata="<<oriTestData<<endl;
// cout<<"testClasses=\n"<<testClasses<<endl;
int correctCount=0;
int errorCount=0;
cout<<"test_datas size="<<test_datas.rows<<endl;
int totalTestSize=test_datas.rows;
bool right=false;
// TestData_2Feature* cur=testDataHead;
int expected_idx=0;
for(int i=0;i<totalTestSize;i++){
Mat predict_result(1, outputClassCnt, CV_32FC1);
annMlp.predict(test_datas.row(i), predict_result);
Point maxLoc;
double maxVal;
minMaxLoc(predict_result, 0, &maxVal, 0, &maxLoc);
right=false;
if(test_datas.row(i).at
expected_idx=0;
}else{
expected_idx=1;
}
if(expected_idx==maxLoc.x){
++correctCount;
right=true;
}else {
++errorCount;
}
cout<<"data:"<<test_datas.row(i)<<"("<<oriTestData.row(i)<<"),predict_result="<<predict_result<<",maxVal="<<maxVal<<",maxLoc.x="<<maxLoc.x<<",right?"<<right<<endl;
// cur=cur->next;
}
cout<<"total test data count="<<totalTestSize<<",correct count="<<correctCount<<",error count="<<errorCount<<",accurate="<<(correctCount)*1.0f/(totalTestSize)<<endl;
}
5.1.1.2 训练与测试数据产生方法
/**
 * Generate the training set for the 2-class "y = slope * x" problem.
 *
 * Walks the grid x, y = 0, 16, 32, ... (< 255), mirrors every point through the
 * origin (both coordinates multiplied by -1) and appends, per grid point:
 *   - to `mat`:    the 1x2 float row (fx, fy) = (-x, -y)
 *   - to `labels`: the one-hot 1x2 float row [1, 0] (class 0) when
 *                  fx * slope > fy, otherwise [0, 1] (class 1)
 * About every getDataInterval-th sample is also written as "label:fx fy" to
 * data/my/test2classdata_slope.list (the file is overwritten). Those lines are
 * a subset of the training samples, not held-out data.
 *
 * @param slope  slope k of the separating line y = k * x
 * @param mat    output matrix; one feature row is appended per sample
 * @param labels output matrix; one one-hot label row is appended per sample
 */
void generateFix2ClassSlopeTrainData(float slope, Mat& mat, Mat& labels){
    printf("generateFix2ClassSlopeTrainData begin\n");
    const int x_step = 16;
    const int y_step = 16;
    const int end_x = 255;
    const int end_y = 255;
    const int needTestSize = 10;  // approximate number of lines written to the test list
    // NOTE(review): every sample is mirrored into the negative quadrant. This looks
    // like a leftover experiment, but models trained so far depend on it - confirm intent.
    const float multi = -1.0f;

    // Exact number of grid points, so the sampling interval yields ~needTestSize lines.
    const int totalTrainSize = ((end_x + x_step - 1) / x_step) * ((end_y + y_step - 1) / y_step);
    int getDataInterval = totalTrainSize / needTestSize;
    if (getDataInterval < 1) {
        getDataInterval = 1;  // avoid a modulo by zero for very small grids
    }
    printf("getDataInterval=%d,totalTrainSize=%d\n", getDataInterval, totalTrainSize);

    std::ostringstream os;
    std::vector<float> dataVec;
    std::vector<float> labVec;
    int loopcnt = 0;
    for (int x = 0; x < end_x; x += x_step) {
        for (int y = 0; y < end_y; y += y_step) {
            ++loopcnt;
            const float tmp1 = multi * (float)x;
            const float tmp2 = multi * (float)y;
            dataVec.clear();
            dataVec.push_back(tmp1);
            dataVec.push_back(tmp2);
            mat.push_back(Mat(dataVec).reshape(1, 1).clone());

            // Class 0 lies below the line (tmp1 * slope > tmp2); one-hot encode the class.
            const bool isClass0 = tmp1 * slope > tmp2;
            labVec.clear();
            labVec.push_back(isClass0 ? 1.0f : 0.0f);
            labVec.push_back(isClass0 ? 0.0f : 1.0f);
            labels.push_back(Mat(labVec).reshape(1, 1).clone());

            if (loopcnt % getDataInterval == 0) {
                // Write the coordinates that were actually labelled, so that each
                // line's label matches its point.
                os << (isClass0 ? "0:" : "1:") << tmp1 << " " << tmp2 << std::endl;
            }
        }
    }

    // Dump part of the samples as a test list. Truncating on open replaces the old
    // shell "rm" + append sequence.
    const std::string testfile = "data/my/test2classdata_slope.list";
    std::ofstream ftxt(testfile.c_str(), std::ios::out | std::ios::trunc);
    if (!ftxt) {
        std::cout << "创建文件:" << testfile << " 失败!" << std::endl;
        getchar();
        return;
    }
    ftxt << os.str();
}
5.1.2 海思预测二维数据样本的所属类别
5.1.2.1 预测入口
/**
- 测a试?y=kx的?分?类え?情é况?
/
HI_VOID SAMPLE_IVE_Ann_predict_2class_slope(float slope){
// HI_CHAR pchBinFileName;
int height,width,image_type;
char pchBinFileName[256];
sprintf(pchBinFileName,"./data/my/my_simple_2_class_20160307slope%.2f.bin",slope);
// pchBinFileName = "./data/my/my_simple_2_class_20160307_slope_3.00.bin";
height=1;
width=2;
image_type=IVE_IMAGE_TYPE_S32C1;
HI_S32 s32Ret;
SAMPLE_IVE_ANN_INFO_S stAnnInfo;
printf("use model bin file:%s\n",pchBinFileName);
SAMPLE_COMM_IVE_CheckIveMpiInit();
s32Ret=SAMPLE_IVE_Ann_Mlp_2Class_Slope_Init(&stAnnInfo, pchBinFileName,image_type,height,width);
if (HI_SUCCESS != s32Ret)
{
SAMPLE_PRT("SAMPLE_IVE_Ann_Mlp__2Class_Init fail\n");
goto ANN_FAIL;
}
// predict2ClassData(&stAnnInfo,slope);
predict2ClassSlopeData(&stAnnInfo,slope);
//uninit
SAMPLE_IVE_Ann_Mlp_Uninit(&stAnnInfo);
ANN_FAIL:
SAMPLE_COMM_IVE_IveMpiExit();
}
5.1.2.2 初始化
/**
 * Initialise the IVE ANN-MLP context used by the 2-class slope predictor.
 *
 * Zeroes *pstAnnInfo and builds the activation lookup table. It then loads the
 * model from pchBinFileName and allocates the input and output buffers, sized
 * from the model's first and last layer node counts. On any failure,
 * SAMPLE_IVE_Ann_Mlp_Uninit is called to release whatever was already allocated.
 *
 * @param pstAnnInfo     context to initialise
 * @param pchBinFileName path of the IVE ANN-MLP model binary
 * @param image_type     currently unused
 * @param height         currently unused
 * @param width          currently unused
 * @return HI_SUCCESS, or the error code of the first call that failed
 */
static HI_S32 SAMPLE_IVE_Ann_Mlp_2Class_Slope_Init(SAMPLE_IVE_ANN_INFO_S* pstAnnInfo, HI_CHAR* pchBinFileName, int image_type, int height, int width)
{
    HI_S32 s32Ret = HI_SUCCESS;
    HI_U32 u32Size = 0;

    (void)image_type;
    (void)height;
    (void)width;

    SAMPLE_PRT("SAMPLE_IVE_Ann_Mlp_Init.....\n");
    memset(pstAnnInfo, 0, sizeof(SAMPLE_IVE_ANN_INFO_S));

    /*
     * Activation lookup table: the input range [s32TabInLower, s32TabInUpper] = [0, 1]
     * is sampled with 8-bit precision, i.e. (1 - 0) << 8 = 256 entries
     * (the range is split into 256 segments).
     */
    pstAnnInfo->stTable.s32TabInLower = 0;
    pstAnnInfo->stTable.s32TabInUpper = 1;
    pstAnnInfo->stTable.u8TabInPreci = 8;
    pstAnnInfo->stTable.u8TabOutNorm = 2;
    pstAnnInfo->stTable.u16ElemNum = (pstAnnInfo->stTable.s32TabInUpper - pstAnnInfo->stTable.s32TabInLower) << pstAnnInfo->stTable.u8TabInPreci;
    u32Size = pstAnnInfo->stTable.u16ElemNum * sizeof(HI_U16);
    s32Ret = SAMPLE_COMM_IVE_CreateMemInfo(&(pstAnnInfo->stTable.stTable), u32Size);
    if (s32Ret != HI_SUCCESS)
    {
        SAMPLE_PRT("SAMPLE_COMM_IVE_CreateMemInfo fail\n");
        goto ANN_INIT_FAIL;
    }

    /* Must match the SIGMOID_SYM parameters used for training: create(..., 0.6667f, 1.7159f). */
    s32Ret = SAMPLE_IVE_Ann_Mlp_CreateTable(&(pstAnnInfo->stTable), 0.6667f, 1.7159f);
    if (s32Ret != HI_SUCCESS)
    {
        SAMPLE_PRT("SAMPLE_IVE_Ann_Mlp_CreateTable fail\n");
        goto ANN_INIT_FAIL;
    }

    SAMPLE_PRT("begin to load model:%s\n", pchBinFileName);
    s32Ret = HI_MPI_IVE_ANN_MLP_LoadModel(pchBinFileName, &(pstAnnInfo->stAnnModel));
    if (s32Ret != HI_SUCCESS)
    {
        SAMPLE_PRT("HI_MPI_IVE_ANN_MLP_LoadModel fail,Error(%#x)\n", s32Ret);
        goto ANN_INIT_FAIL;
    }
    printf("finish load model:%s\n", pchBinFileName);

    /* Input buffer: number of input-layer nodes * size of one S16Q16 element. */
    u32Size = pstAnnInfo->stAnnModel.au16LayerCount[0] * sizeof(HI_S16Q16);
    printf("allocate memory for input,size=%u\n", (unsigned)u32Size);
    s32Ret = SAMPLE_COMM_IVE_CreateMemInfo(&(pstAnnInfo->stSrc), u32Size);
    if (s32Ret != HI_SUCCESS)
    {
        SAMPLE_PRT("SAMPLE_COMM_IVE_CreateMemInfo fail\n");
        goto ANN_INIT_FAIL;
    }

    /* Output buffer: number of output-layer nodes (classes) * size of one S16Q16 score. */
    u32Size = pstAnnInfo->stAnnModel.au16LayerCount[pstAnnInfo->stAnnModel.u8LayerNum - 1] * sizeof(HI_S16Q16);
    printf("allocate memory for output,size=%u\n", (unsigned)u32Size);
    s32Ret = SAMPLE_COMM_IVE_CreateMemInfo(&(pstAnnInfo->stDst), u32Size);
    if (s32Ret != HI_SUCCESS)
    {
        SAMPLE_PRT("SAMPLE_COMM_IVE_CreateMemInfo fail\n");
        goto ANN_INIT_FAIL;
    }

ANN_INIT_FAIL:
    /* Release partial allocations; the context was zeroed above, so untouched members stay empty. */
    if (HI_SUCCESS != s32Ret)
    {
        SAMPLE_IVE_Ann_Mlp_Uninit(pstAnnInfo);
    }
    return s32Ret;
}
5.1.2.3 预测
/**
- 预¤测ay=kx的?分?类え?
/
void predict2ClassSlopeData(SAMPLE_IVE_ANN_INFO_S pstAnnInfo,float slope){
char contFile="data/my/test2classdata_slope_eq_1.list";
printf("try to get file info:%s\n",contFile);
TestData_2Feature head=get2FeatureData(contFile);
printf("after read file:%s,head=%p\n",contFile,head);
if(!head){
printf("fail to read contFile:%s\n",contFile);
return;
}
// printStringNode(head,"1");
// printStringNode(head,"2");
HI_S32 i, k;
HI_S32 s32Ret;
HI_S32 s32ResponseCls;
HI_U16 u16LayerCount;
HI_S16Q16 ps16q16Dst;
HI_S16Q16 s16q16Response;
HI_BOOL bInstant = HI_TRUE;
HI_BOOL bFinish;
HI_BOOL bBlock = HI_TRUE;
// HI_CHAR achFileName[IVE_FILE_NAME_LEN];
FILE pFp = HI_NULL;
IVE_HANDLE iveHandle;
int xs[3]={-5,-4,3};
int ys[3]={99,-10,10};
srand(time(NULL));
int totalCount=0;
int correctCount=0;
TestData_2Feature* cur=head;
int cnt=0;
int expected_idx=0;
while(cur!=NULL){
// printf("flag=%d,filePath=%s,filenName=%s -->\n ",cur->flag,cur->fileFullPath,cur->fileName);
ps16q16Dst = (HI_S16Q16*)pstAnnInfo->stDst.pu8VirAddr;
s16q16Response = 0;
s32ResponseCls = -1;
HI_S16Q16 stSrc=(HI_S16Q16)pstAnnInfo->stSrc.pu8VirAddr;
stSrc[0]=changeFloatToS16Q16(cur->x1);//转换为以s16q16表示的数据
stSrc[1]=changeFloatToS16Q16(cur->x2);
s32Ret = HI_MPI_IVE_ANN_MLP_Predict(&iveHandle, &(pstAnnInfo->stSrc), \
& (pstAnnInfo->stTable), &(pstAnnInfo->stAnnModel), &(pstAnnInfo->stDst), bInstant);
if (s32Ret != HI_SUCCESS)
{
SAMPLE_PRT("HI_MPI_IVE_ANN_MLP_Predict fail,Error(%#x)\n", s32Ret);
break;
}
s32Ret = HI_MPI_IVE_Query(iveHandle, &bFinish, bBlock);
while (HI_ERR_IVE_QUERY_TIMEOUT == s32Ret)
{
usleep(100);
s32Ret = HI_MPI_IVE_Query(iveHandle, &bFinish, bBlock);
}
if (HI_SUCCESS != s32Ret)
{
SAMPLE_PRT("HI_MPI_IVE_Query fail,Error(%#x)\n", s32Ret);
break;
}
u16LayerCount = pstAnnInfo->stAnnModel.au16LayerCount[pstAnnInfo->stAnnModel.u8LayerNum - 1];
// SAMPLE_PRT("pstAnnInfo->CstAnnModel.u8LayerNum=%d,pstAnnInfo->stAnnModel.au16LayerCount[pstAnnInfo->stAnnModel.u8LayerNum - 1]=%d\n",pstAnnInfo->stAnnModel.u8LayerNum,pstAnnInfo->stAnnModel.au16LayerCount[pstAnnInfo->stAnnModel.u8LayerNum - 1]);
SAMPLE_PRT(" \n--predict2ClassSlopeData--Begin-- x1=%f(s16q16=%d),x2=%f(s16q16=%d)\n",cur->x1,changeFloatToS16Q16(cur->x1),cur->x2,changeFloatToS16Q16(cur->x2));
++totalCount;
for (k = 0; k < u16LayerCount; k++)
{
printf(" ps16q16Dst[%d]=%d,H16Q16=%f\n", k,ps16q16Dst[k],calculateS16Q16_c(ps16q16Dst[k]));
if (s16q16Response < ps16q16Dst[k])
{
s16q16Response = ps16q16Dst[k];
s32ResponseCls = k;
}
}
if(cur->x1*slope>cur->x2){
expected_idx=0;
}else{
expected_idx=1;
}
SAMPLE_PRT(" --predict2ClassSlopeData--End-- result:%s,flag:%d,class:%d ------\n\n",(expected_idx==s32ResponseCls?"right":"wrong"),expected_idx,s32ResponseCls);
cur=cur->next;
}
freeTestData_2FeatureNode(head);
}
附上斜率为 1 时的预测结果输出（测试数据来自 data/my/test2classdata_slope_eq_1.list）：
[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=0.220000(s16q16=14417),x2=0.100000(s16q16=6553)
ps16q16Dst[0]=46174,H16Q16=0.704559
ps16q16Dst[1]=20098,H16Q16=0.306671
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:right,flag:0,class:0 ------
[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=-1.000000(s16q16=-65536),x2=-3.000000(s16q16=-196608)
ps16q16Dst[0]=48919,H16Q16=0.746445
ps16q16Dst[1]=15412,H16Q16=0.235168
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:right,flag:0,class:0 ------
[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=1.000000(s16q16=65536),x2=0.200000(s16q16=13107)
ps16q16Dst[0]=48919,H16Q16=0.746445
ps16q16Dst[1]=15412,H16Q16=0.235168
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:right,flag:0,class:0 ------
[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=0.200000(s16q16=13107),x2=0.700000(s16q16=45875)
ps16q16Dst[0]=16687,H16Q16=0.254623
ps16q16Dst[1]=51450,H16Q16=0.785065
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:right,flag:1,class:1 ------
[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=0.400000(s16q16=26214),x2=0.900000(s16q16=58982)
ps16q16Dst[0]=16830,H16Q16=0.256805
ps16q16Dst[1]=51033,H16Q16=0.778702
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:right,flag:1,class:1 ------
[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=0.690196(s16q16=45232),x2=0.062745(s16q16=4112)
ps16q16Dst[0]=48919,H16Q16=0.746445
ps16q16Dst[1]=15412,H16Q16=0.235168
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:right,flag:0,class:0 ------
[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=224.000000(s16q16=14680064),x2=80.000000(s16q16=5242880)
ps16q16Dst[0]=17622,H16Q16=0.268890
ps16q16Dst[1]=45294,H16Q16=0.691132
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:wrong,flag:0,class:1 ------
[predict2ClassSlopeData]-830:
--predict2ClassSlopeData--Begin-- x1=-224.000000(s16q16=-14680064),x2=80.000000(s16q16=5242880)
ps16q16Dst[0]=17117,H16Q16=0.261185
ps16q16Dst[1]=51728,H16Q16=0.789307
[predict2ClassSlopeData]-847: --predict2ClassSlopeData--End-- result:right,flag:1,class:1 ------
- 分享
- 举报
-
浏览量:2474次2020-08-25 18:07:51
-
浏览量:4571次2020-06-19 15:56:33
-
浏览量:3209次2020-08-05 20:32:31
-
浏览量:5722次2020-10-13 17:14:09
-
浏览量:4876次2021-03-26 15:39:50
-
浏览量:1134次2023-08-29 15:52:13
-
浏览量:5222次2019-12-28 10:17:47
-
2024-08-22 21:17:40
-
浏览量:3272次2018-04-12 11:32:51
-
浏览量:3639次2017-11-16 11:30:55
-
浏览量:1695次2019-12-31 16:25:11
-
浏览量:2182次2020-08-05 21:02:35
-
浏览量:15318次2018-09-27 20:15:39
-
浏览量:6106次2019-12-28 10:35:51
-
浏览量:1270次2023-10-12 14:39:21
-
浏览量:2366次2020-08-30 12:39:35
-
浏览量:2508次2019-12-31 16:23:45
-
浏览量:5125次2020-12-19 16:14:06
-
浏览量:3522次2020-07-30 10:36:08
-
广告/SPAM
-
恶意灌水
-
违规内容
-
文不对题
-
重复发帖
在学了在学了!
感谢您的打赏,如若您也想被打赏,可前往 发表专栏 哦~
举报类型
- 内容涉黄/赌/毒
- 内容侵权/抄袭
- 政治相关
- 涉嫌广告
- 侮辱谩骂
- 其他
详细说明