Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

yolov8seg的640x640输入的推理能正常跑了,但是onnx在服务器4个目标结果,在mnn只出来2个。什么原因? #3030

Open
jamesBaiyuhao opened this issue Sep 13, 2024 · 2 comments
Labels
User The user ask question about how to use. Or don't use MNN correctly and cause bug.

Comments

@jamesBaiyuhao
Copy link

其实input就是 640x640的bgr的输入。
我的模型的input是 1,3,640,640, data应该是int的。 output是1,42,8400,应该是float。
您看是我input不对吗? pretreat->convert((uint8_t*)inputImage, width, height, 0, input);
还是output不对? 我直接output_t->host() 直接处理输出的tensor?

int main(int argc, const char* argv[]) {
if (argc < 4) {
MNN_PRINT("Usage: ./segment.out model.mnn input.jpg output.jpg\n");
return 0;
}
std::shared_ptr net;
net.reset(Interpreter::createFromFile(argv[1]));
if (net == nullptr) {
MNN_ERROR("Invalid Model\n");
return 0;
}
ScheduleConfig config;
auto session = net->createSession(config);
auto input = net->getSessionInput(session, "images");
MNN_PRINT("origin size: line %d\n", LINE);
// input->print();
MNN_PRINT("origin size: line %d\n", LINE);
input->printShape();
MNN_PRINT("origin size: line %d\n", LINE);

auto shape = input->shape();
{
    int input_h = 160;
    int input_w = 160;
    if(shape[2] > 0 && shape[2] > 0)
    {

    }
    else
    {
        shape[2] = input_h;
        shape[3] = input_w;
    }
}
net->resizeTensor(input, shape);
net->resizeSession(session);
input = net->getSessionInput(session, "images");
    MNN_PRINT("origin size: line %d\n", __LINE__);
// input->print();
    MNN_PRINT("origin size: line %d\n", __LINE__);
input->printShape();
    MNN_PRINT("origin size: line %d\n", __LINE__);



{
    auto ow              = input->width();
    auto oh              = input->height();
    auto ochannel             = input->channel();
    auto dimensionFormat = input->getDimensionType();
    MNN_PRINT("input---: ow:%d , oh:%d, ochannel: %d, dimensionFormat %d\n", 
        ow, oh, ochannel, dimensionFormat);

    int size_w   = 0;
    int size_h   = 0;
    int bpp      = 3;


    if (bpp == 0)
        bpp = 1;
    size_w = shape[2] ;
    size_h = shape[3] ;
    MNN_PRINT("input---: w:%d , h:%d, bpp: %d\n", size_w, size_h, bpp);

    auto inputPatch = argv[2];
    int width, height, channel;
    auto inputImage = stbi_load(inputPatch, &width, &height, &channel, 4);
    if (nullptr == inputImage) {
        MNN_ERROR("Can't open %s\n", inputPatch);
        return 0;
    }
    MNN_PRINT("origin size: %d, %d\n", width, height);
    Matrix trans;
    // Set scale, from dst scale to src
    trans.setScale((float)(width-1) / (size_w-1), (float)(height-1) / (size_h-1));
    ImageProcess::Config config;
    config.filterType = CV::BILINEAR;
    //        float mean[3]     = {103.94f, 116.78f, 123.68f};
    //        float normals[3] = {0.017f, 0.017f, 0.017f};
    // float mean[3]     = {127.5f, 127.5f, 127.5f};
    // float normals[3] = {0.00785f, 0.00785f, 0.00785f};
    const float mean[3] = {127.5f, 127.5f, 127.5f};
    const float normals[3] = {2.0f / 255.0f, 2.0f / 255.0f, 2.0f / 255.0f};
    ::memcpy(config.mean, mean, sizeof(mean));
    ::memcpy(config.normal, normals, sizeof(normals));
    config.sourceFormat = RGBA;
    config.destFormat   = BGR;

    MNN_PRINT("origin size: line %d %d, %d\n", __LINE__,width, height);
    std::shared_ptr<ImageProcess> pretreat(ImageProcess::create(config));
    pretreat->setMatrix(trans);
    pretreat->convert((uint8_t*)inputImage, width, height, 0, input);
    stbi_image_free(inputImage);
    MNN_PRINT("origin size: line %d %d, %d\n", __LINE__,width, height);
}
// Run model
    MNN_PRINT("origin size: line %d\n", __LINE__);
net->runSession(session);
    MNN_PRINT("origin size: line %d\n", __LINE__);


int numClasses = 6;//4;
float  confThreshold = 0.35;
float nmsThreshold = 0.5;

// get output
auto output_t = net->getSessionOutput(session, "output0");
    MNN_PRINT("origin size: line %d\n", __LINE__);
printf_value(output_t->host<float>(),"output_ptr",10);
    MNN_PRINT("origin size: line %d\n", __LINE__);
output_t->printShape();
    MNN_PRINT("origin size: line %d\n", __LINE__);

{
    struct timeval tstartNms, tendNms;
    gettimeofday(&tstartNms, NULL);

    int channel  = output_t->channel();
    int height   = output_t->height();
    int width    = output_t->width();
    printf("[%s:%d] channel %d height %d width %d\n", __FUNCTION__, __LINE__, channel, height, width);

    width = height;//8400
    height = channel;//42

    auto scoresPtr = output_t->host<float>();
    float *scoreHead = NULL;
    int iLayer = 0;
    std::vector<BoxInfo_t> outBoxes;


    float cx, cy, w, h;
    float *cxHead = (float *)scoresPtr;
    float *cyHead = (float *)scoresPtr + 1 * width;
    float *wHead = (float *)scoresPtr + 2 * width;
    float *hHead = (float *)scoresPtr + 3 * width;

    printf("[%s:%d] width %d\n", __FUNCTION__, __LINE__, width);
    for (int iBox = 0; iBox < width; iBox++)
    {
        scoreHead = (float *)scoresPtr + 4 * width + iBox;

        float maxScore = -INFINITY;
        int maxInd = -1;

        for (int iClass = 0; iClass < numClasses; iClass++)
        {
            //printf("[%s:%d] iBox %d iClass %d maxScore %f *scoreHead %f\n", __FUNCTION__, __LINE__, iBox, iClass, maxScore, *scoreHead);
            if (*scoreHead >= confThreshold && *scoreHead > maxScore) {
                maxScore = *scoreHead;
                maxInd = iClass;
                printf("[%s:%d] iBox %d iClass %d maxScore %f maxInd %d\n", __FUNCTION__, __LINE__, iBox, iClass, maxScore, maxInd);
            }
            scoreHead += width;
        }

        if (maxInd >= 0)
        {
            cx = cxHead[iBox];
            cy = cyHead[iBox];
            w = wHead[iBox];
            h = hHead[iBox];
            BoxInfo_t rBoxInfo;
            rBoxInfo.x1 = cx - w / 2.0;
            rBoxInfo.y1 = cy - h / 2.0;
            rBoxInfo.x2 = cx + w / 2.0;
            rBoxInfo.y2 = cy + h / 2.0;
            rBoxInfo.score = maxScore;
            rBoxInfo.label = maxInd;

            // printf(" iBox %d maxInd %d: x1 %f %f %f %f w: %f %f label %d score %f\n", iBox, maxInd,
            //     rBoxInfo.x1 , rBoxInfo.y1, rBoxInfo.x2 , rBoxInfo.y2,
            //     (rBoxInfo.x2 - rBoxInfo.x1) , (rBoxInfo.y2 - rBoxInfo.y1), rBoxInfo.label , rBoxInfo.score );

            outBoxes.push_back(rBoxInfo);
        }
    }


    gettimeofday(&tendNms, NULL);
    printf("loop time: %f ms\n",
           (tendNms.tv_sec - tstartNms.tv_sec) * 1000.0 + (tendNms.tv_usec - tstartNms.tv_usec) / 1000.0);
   CheckBoxes(outBoxes, shape[2], shape[3]);
    gettimeofday(&tendNms, NULL);
    printf("CheckBoxes time: %f ms\n",
           (tendNms.tv_sec - tstartNms.tv_sec) * 1000.0 + (tendNms.tv_usec - tstartNms.tv_usec) / 1000.0);
    Nms(outBoxes, nmsThreshold);
    gettimeofday(&tendNms, NULL);
    printf("Nms time: %f ms\n",
           (tendNms.tv_sec - tstartNms.tv_sec) * 1000.0 + (tendNms.tv_usec - tstartNms.tv_usec) / 1000.0);


    for (int i = 0; i < int(outBoxes.size()); ++i)
    {
        printf(" %d: x1 %f %f %f %f w: %f %f label %d score %f\n", i, 
              outBoxes[i].x1 , outBoxes[i].y1, outBoxes[i].x2 , outBoxes[i].y2,
               (outBoxes[i].x2 - outBoxes[i].x1) , (outBoxes[i].y2 - outBoxes[i].y1), outBoxes[i].label , outBoxes[i].score );
    }
}

return 0;

}

平台(如果交叉编译请再附上交叉编译目标平台):

Platform(Include target platform as well if cross-compiling):

macOS

Github版本:

Github Version:

20240912最新版本

直接下载ZIP包请提供下载日期以及压缩包注释里的git版本(可通过7z l zip包路径命令并在输出信息中搜索Comment 获得,形如Comment = bc80b11110cd440aacdabbf59658d630527a7f2b)。 git clone请提供 git commit 第一行的commit id

Provide the download date (or, better yet, the git revision from the comment section of the zip, obtainable by running 7z l PATH/TO/ZIP and searching for "Comment" in the output) if you downloaded the source as a zip; otherwise, provide the commit id from the first line of the `git log` output.

编译方式:

Compiling Method

cmake -DMNN_BUILD_DEMO=ON -DMNN_BUILD_CONVERTER=ON -DMNN_BUILD_TOOL=ON -DMNN_BUILD_BENCHMARK=ON -DMNN_BUILD_QUANTOOLS=ON -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON ..

请在这里粘贴cmake参数或使用的cmake脚本路径以及完整输出
Paste the CMake arguments or the path of the build script used here, along with the full log of the CMake process (here or via pastebin)

编译日志:

Build Log:

默认
[  2%] Built target MNNCore
[ 2%] Built target MNNCV
[ 2%] Built target MNNMath
[ 14%] Built target MNNTransform
[ 14%] Built target MNNUtils
[ 25%] Built target MNNCPU
[ 25%] Built target MNNX8664
[ 26%] Built target MNNAVX
[ 27%] Built target MNNAVXFMA
[ 28%] Built target MNNSSE
[ 28%] Built target MNN
[ 30%] Built target MNN_Express
[ 34%] Built target MNNTrain
[ 36%] Built target MNNTrainUtils
[ 37%] Built target pictureRecognition.out
[ 37%] Built target pictureRecognition_module.out
[ 37%] Built target multithread_imgrecog.out
[ 37%] Built target pictureRecognition_batch.out
[ 37%] Built target pictureRotate.out
[ 37%] Built target multiPose.out
[ 37%] Built target segment.out
[ 37%] Building CXX object CMakeFiles/yolov8s-seg.out.dir/demo/exec/yolov8s-seg.cpp.o
[ 38%] Linking CXX executable yolov8s-seg.out
[ 38%] Built target yolov8s-seg.out
[ 38%] Built target expressDemo.out
[ 38%] Built target expressMakeModel.out
[ 38%] Built target transformerDemo.out
[ 38%] Built target nluDemo.out
[ 38%] Built target GetMNNInfo
[ 38%] Built target ModuleBasic.out
[ 38%] Built target SequenceModuleTest.out
[ 38%] Built target mergeInplaceForCPU
[ 39%] Built target modelCompare.out
[ 40%] Built target MNNV2Basic.out
[ 40%] Built target mobilenetTest.out
[ 40%] Built target backendTest.out
[ 41%] Built target testModel.out
[ 41%] Built target testModel_expr.out
[ 41%] Built target testModelWithDescribe.out
[ 42%] Built target getPerformance.out
[ 43%] Built target checkInvalidValue.out
[ 44%] Built target timeProfile.out
[ 44%] Built target testTrain.out
[ 44%] Built target fuseTest
[ 44%] Built target LoRA
[ 44%] Built target checkDir.out
[ 44%] Built target checkFile.out
[ 44%] Built target winogradExample.out
[ 44%] Built target benchmark.out
[ 45%] Built target benchmarkExprModels.out
[ 46%] Built target quantized.out
[ 49%] Built target libprotobuf-lite
[ 58%] Built target libprotobuf
[ 58%] Built target transformer
[ 58%] Built target extractForInfer
[ 59%] Built target runTrainDemo.out
[ 59%] Built target MNNCompress
[ 67%] Built target MNNConverterTF
[ 72%] Built target MNNConverterONNX
[ 72%] Built target OnnxClip
[ 76%] Built target MNNConverterCaffe
[ 76%] Built target MNNConverterMNN
[ 92%] Built target MNNConverterOpt
[ 96%] Built target MNNConverterTFL
[ 97%] Built target MNNConvertDeps
[ 98%] Built target MNNConvert
[ 98%] Built target MNNRevert2Buffer
[ 98%] Built target MNNDump2Json
[ 99%] Built target TestConvertResult
[ 99%] Built target TestPassManager
[100%] Built target MNNOpenCV

粘贴在这里
Paste log here or pastebin
@jamesBaiyuhao
Copy link
Author

转模型是:
./MNNConvert -f ONNX --modelFile yolov8s-seg-bedNurse913_1.onnx --MNNModel yyolov8s-seg-bedNurse913_1.mnn --bizCode MNN --keepInputFormat=0

@jxt1234
Copy link
Collaborator

jxt1234 commented Sep 13, 2024

  1. 换用 Module API ,在 load module 时传入 onnx 里面的4个输出名字 https://mnn-docs.readthedocs.io/en/latest/inference/module.html
  2. 创建 Session 传入 saveTensors 或者设定 path ,指定输出
    https://mnn-docs.readthedocs.io/en/latest/inference/session.html

建议换用 Module API

@jxt1234 jxt1234 added the User The user ask question about how to use. Or don't use MNN correctly and cause bug. label Sep 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
User The user ask question about how to use. Or don't use MNN correctly and cause bug.
Projects
None yet
Development

No branches or pull requests

2 participants