Large difference in yolov5 inference output between Neuron compiled model vs expected #435

Closed
jeffhataws opened this issue Jun 20, 2022 · 17 comments

Comments

@jeffhataws
Contributor

jeffhataws commented Jun 20, 2022

Originally posted by @josebenitezg in #253 (comment)

Hi!

I was able to convert the YOLOv5 model to Neuron with the following code:

import torch
import torch_neuron
from torchvision import models

model = torch.hub.load('yolo5',
        'custom',
        path='yolov5.pt',
        source='local',
        force_reload=True)  # local repo

fake_image = torch.zeros([1, 3, 640, 640], dtype=torch.float32)
#fake_image = (torch.rand(3), torch.rand(3))
try:
    torch.neuron.analyze_model(model, example_inputs=[fake_image])
except Exception:
    torch.neuron.analyze_model(model, example_inputs=[fake_image])

model_neuron = torch.neuron.trace(model, 
                                example_inputs=[fake_image])

## Export to saved model
model_neuron.save("model_converted.pt")

Now that I am testing and comparing, the output tensors differ from the original YOLOv5 output, as follows:

Neuron Yolov5 Model:

[tensor([[-0.0356,  0.1790,  0.7456,  0.6292,  0.9359, 13.0000],
        [ 0.5830,  0.1404,  1.1279,  0.6628,  0.9359, 13.0000],
        [ 0.0823,  0.6350,  0.6272,  1.1599,  0.9315, 13.0000],
        [-0.1443,  0.1416,  0.2542,  0.5107,  0.9224, 13.0000],
        [ 0.3516,  0.6426,  0.7500,  1.0137,  0.9188, 13.0000],
        [ 0.3555,  0.1436,  0.7539,  0.5127,  0.9147, 13.0000]])]

Yolov5:

[tensor([[334.57495, 176.98302, 407.46155, 213.81169,   0.93721,  13.00000]])]

Inference script:

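import cv2
import numpy as np
import torch
import torch_neuron  # registers the Neuron ops that torch.jit.load needs

im = cv2.imread('test_img.jpg')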
img0 = im.copy()
im = cv2.resize(im, (640, 640), interpolation = cv2.INTER_AREA)
# Convert
im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im)
# Convert into torch
im = torch.from_numpy(im)
im = im.float()  # uint8 to fp16/32
im /= 255  # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
    im = im[None]  # expand for batch dim

# Load the compiled model
model = torch.jit.load('model_converted.pt')

# Inference
pred = model(im)
pred = non_max_suppression(pred) #nms function used same as yolov5 detect.py

#Process predictions
for i, det in enumerate(pred):  # per image
    im0 = img0.copy()
    color=(30, 30, 30)
    txt_color=(255, 255, 255)
    h_size, w_size = im.shape[-2:]
    print(h_size, w_size)
    lw = max(round(sum(im.shape) / 2 * 0.003), 2) 

    if len(det):
        # Write results
        for *xyxy, conf, cls in reversed(det):
            c = int(cls)  # integer class
            label = f'{CLASSES[c]} {conf:.2f}'
            print(label)
            box = xyxy 
            p1, p2 = (int(box[0]* w_size), int(box[1]* h_size)), (int(box[2]* w_size), int(box[3]* h_size))
            cv2.rectangle(im0, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
            tf = max(lw - 1, 1)  # font thickness
            w, h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=tf)[0]  # text width, height
            outside = p1[1] - h - 3 >= 0  # label fits outside box
            p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
            cv2.rectangle(im0, p1, p2, color, -1, cv2.LINE_AA)  # filled
            cv2.putText(im0,
                        label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                        0,
                        lw / 3,
                        txt_color,
                        thickness=tf,
                        lineType=cv2.LINE_AA)
    # Save results (image with detections)
    status = cv2.imwrite('out.jpg', im0)

Is there something wrong with how I convert the model or run inference? The label and the confidence seem to match the expected values, but the box coordinates do not.


@jeffhataws
Contributor Author

@josebenitezg,
Thanks for reporting the issue. We have already posted a solution for it in ultralytics/yolov5#7739 (comment). I have verified the fix using your script and the public yolov5 model ("ultralytics/yolov5"). Please let us know if you have further problems.

@josebenitezg

josebenitezg commented Jul 1, 2022

Hi @jeffhataws
Thank you for your help. I was able to convert the model successfully, but the conversion fails when I try the xlarge model.
The small, medium and large models are not a problem!

Sometimes the inf1 instance crashes and I have to reboot it when I try to convert the xlarge model.

I use the same code as above.

Traceback (most recent call last):
  File "compile_neuron.py", line 36, in <module>
    compile_to_neuron(model_to_compile, yolov5_neuron_path)
  File "compile_neuron.py", line 26, in compile_to_neuron
    trace = torch_neuron.trace(model.model, (example,), subgraph_builder_function=subgraph_builder_function)
  File "/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/convert.py", line 168, in trace
    cu.stats_post_compiler(neuron_graph)
  File "/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/convert.py", line 500, in stats_post_compiler
    "No operations were successfully partitioned and compiled to neuron for this model - aborting trace!")
RuntimeError: No operations were successfully partitioned and compiled to neuron for this model - aborting trace!
(aws_neuron_pytorch_p36) ubuntu@ip-172-31-19-19:~/some_path/neuron_test$ ...................... 

It keeps printing a few dots, as if it were still compiling.

@josebenitezg

Complete log

Python 3.7.0 required by YOLOv5, but Python 3.6.13 is currently installed
YOLOv5 🚀 v6.1-272-g8983324 Python-3.6.13 torch-1.10.1+cu102 CPU

Fusing layers... 
Model summary: 574 layers, 140153500 parameters, 0 gradients, 208.5 GFLOPs
Adding AutoShape... 
/home/ubuntu/repos/yolov5/models/yolo.py:62: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
/home/ubuntu/repos/yolov5/models/yolo.py:62: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
INFO:Neuron:There are 4 ops of 1 different types in the TorchScript that are not compiled by neuron-cc: aten::_convolution, (For more information see https://github.com/aws/aws-neuron-sdk/blob/master/release-notes/neuron-cc-ops/neuron-cc-ops-pytorch.md)
There are 4 ops of 1 different types in the TorchScript that are not compiled by neuron-cc: aten::_convolution, (For more information see https://github.com/aws/aws-neuron-sdk/blob/master/release-notes/neuron-cc-ops/neuron-cc-ops-pytorch.md)
INFO:Neuron:Number of arithmetic operators (pre-compilation) before = 463, fused = 374, percent fused = 80.78%
Number of arithmetic operators (pre-compilation) before = 463, fused = 374, percent fused = 80.78%
INFO:Neuron:Compiling function _NeuronGraph$1732 with neuron-cc
Compiling function _NeuronGraph$1732 with neuron-cc
INFO:Neuron:Compiling with command line: '/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/bin/neuron-cc compile /tmp/tmp1xf0_msa/graph_def.pb --framework TENSORFLOW --pipeline compile SaveTemps --output /tmp/tmp1xf0_msa/graph_def.neff --io-config {"inputs": {"0:0": [[1, 3, 640, 640], "float32"]}, "outputs": ["mul_122:0", "mul_134:0", "mul_146:0", "mul_158:0"]} --verbose 35'
Compiling with command line: '/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/bin/neuron-cc compile /tmp/tmp1xf0_msa/graph_def.pb --framework TENSORFLOW --pipeline compile SaveTemps --output /tmp/tmp1xf0_msa/graph_def.neff --io-config {"inputs": {"0:0": [[1, 3, 640, 640], "float32"]}, "outputs": ["mul_122:0", "mul_134:0", "mul_146:0", "mul_158:0"]} --verbose 35'
..........................INFO:Neuron:Compile command returned: -9
Compile command returned: -9
WARNING:Neuron:torch.neuron.trace failed on _NeuronGraph$1732; falling back to native python function call
torch.neuron.trace failed on _NeuronGraph$1732; falling back to native python function call
ERROR:Neuron:neuron-cc failed with the following command line call:
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/bin/neuron-cc compile /tmp/tmp1xf0_msa/graph_def.pb --framework TENSORFLOW --pipeline compile SaveTemps --output /tmp/tmp1xf0_msa/graph_def.neff --io-config '{"inputs": {"0:0": [[1, 3, 640, 640], "float32"]}, "outputs": ["mul_122:0", "mul_134:0", "mul_146:0", "mul_158:0"]}' --verbose 35
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/convert.py", line 389, in op_converter
    item, inputs, compiler_workdir=sg_workdir, **kwargs)
  File "/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/decorators.py", line 221, in trace
    'neuron-cc failed with the following command line call:\n{}'.format(command))
subprocess.SubprocessError: neuron-cc failed with the following command line call:
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/bin/neuron-cc compile /tmp/tmp1xf0_msa/graph_def.pb --framework TENSORFLOW --pipeline compile SaveTemps --output /tmp/tmp1xf0_msa/graph_def.neff --io-config '{"inputs": {"0:0": [[1, 3, 640, 640], "float32"]}, "outputs": ["mul_122:0", "mul_134:0", "mul_146:0", "mul_158:0"]}' --verbose 35
neuron-cc failed with the following command line call:
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/bin/neuron-cc compile /tmp/tmp1xf0_msa/graph_def.pb --framework TENSORFLOW --pipeline compile SaveTemps --output /tmp/tmp1xf0_msa/graph_def.neff --io-config '{"inputs": {"0:0": [[1, 3, 640, 640], "float32"]}, "outputs": ["mul_122:0", "mul_134:0", "mul_146:0", "mul_158:0"]}' --verbose 35
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/convert.py", line 389, in op_converter
    item, inputs, compiler_workdir=sg_workdir, **kwargs)
  File "/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/decorators.py", line 221, in trace
    'neuron-cc failed with the following command line call:\n{}'.format(command))
subprocess.SubprocessError: neuron-cc failed with the following command line call:
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/bin/neuron-cc compile /tmp/tmp1xf0_msa/graph_def.pb --framework TENSORFLOW --pipeline compile SaveTemps --output /tmp/tmp1xf0_msa/graph_def.neff --io-config '{"inputs": {"0:0": [[1, 3, 640, 640], "float32"]}, "outputs": ["mul_122:0", "mul_134:0", "mul_146:0", "mul_158:0"]}' --verbose 35
INFO:Neuron:Number of arithmetic operators (post-compilation) before = 463, compiled = 0, percent compiled = 0.0%
Number of arithmetic operators (post-compilation) before = 463, compiled = 0, percent compiled = 0.0%
INFO:Neuron:The neuron partitioner created 1 sub-graphs
The neuron partitioner created 1 sub-graphs
INFO:Neuron:Neuron successfully compiled 0 sub-graphs, Total fused subgraphs = 1, Percent of model sub-graphs successfully compiled = 0.0%
Neuron successfully compiled 0 sub-graphs, Total fused subgraphs = 1, Percent of model sub-graphs successfully compiled = 0.0%
INFO:Neuron:Compiled these operators (and operator counts) to Neuron:
Compiled these operators (and operator counts) to Neuron:
INFO:Neuron:Not compiled operators (and operator counts) to Neuron:
Not compiled operators (and operator counts) to Neuron:
INFO:Neuron: => aten::Int: 16 [supported]
 => aten::Int: 16 [supported]
INFO:Neuron: => aten::_convolution: 163 [supported]
 => aten::_convolution: 163 [supported]
INFO:Neuron: => aten::add: 36 [supported]
 => aten::add: 36 [supported]
INFO:Neuron: => aten::cat: 23 [supported]
 => aten::cat: 23 [supported]
INFO:Neuron: => aten::contiguous: 4 [supported]
 => aten::contiguous: 4 [supported]
INFO:Neuron: => aten::max_pool2d: 3 [supported]
 => aten::max_pool2d: 3 [supported]
INFO:Neuron: => aten::mul: 16 [supported]
 => aten::mul: 16 [supported]
INFO:Neuron: => aten::permute: 4 [supported]
 => aten::permute: 4 [supported]
INFO:Neuron: => aten::pow: 4 [supported]
 => aten::pow: 4 [supported]
INFO:Neuron: => aten::select: 4 [supported]
 => aten::select: 4 [supported]
INFO:Neuron: => aten::sigmoid: 4 [supported]
 => aten::sigmoid: 4 [supported]
INFO:Neuron: => aten::silu: 159 [supported]
 => aten::silu: 159 [supported]
INFO:Neuron: => aten::size: 12 [supported]
 => aten::size: 12 [supported]
INFO:Neuron: => aten::split_with_sizes: 4 [supported]
 => aten::split_with_sizes: 4 [supported]
INFO:Neuron: => aten::upsample_nearest2d: 3 [supported]
 => aten::upsample_nearest2d: 3 [supported]
INFO:Neuron: => aten::view: 8 [supported]
 => aten::view: 8 [supported]
Traceback (most recent call last):
  File "compile_neuron.py", line 36, in <module>
    compile_to_neuron(model_to_compile, yolov5_neuron_path)
  File "compile_neuron.py", line 26, in compile_to_neuron
    trace = torch_neuron.trace(model.model, (example,), subgraph_builder_function=subgraph_builder_function)
  File "/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/convert.py", line 168, in trace
    cu.stats_post_compiler(neuron_graph)
  File "/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/convert.py", line 500, in stats_post_compiler
    "No operations were successfully partitioned and compiled to neuron for this model - aborting trace!")
RuntimeError: No operations were successfully partitioned and compiled to neuron for this model - aborting trace!

@fishuke

fishuke commented Jul 4, 2022

I have the same problem. Any updates?

@RobinFrcd

RobinFrcd commented Jul 13, 2022

Same issue here, ultralytics/yolov5#7739 (comment) didn't help.

The output classes seem to match, but the other part of the tensor isn't the same at all:

jp = jit_model(img_tensor)
np = neuron_fix_model(img_tensor)

jp[0]
tensor([[[5.3412e+00, 4.9883e+00, 1.9919e+01,  ..., 1.0614e-02,
          6.9510e-02, 2.9334e-01],
         [1.1563e+01, 6.0070e+00, 2.5475e+01,  ..., 1.2538e-02,
          8.1747e-02, 2.6231e-01],
         [2.0204e+01, 6.7078e+00, 2.6301e+01,  ..., 1.2518e-02,
          8.8539e-02, 6.8642e-02],
         ...,
         [5.6400e+02, 6.1444e+02, 1.7986e+02,  ..., 1.5812e-01,
          2.1855e-01, 9.6946e-01],
         [5.9491e+02, 6.0066e+02, 1.1395e+02,  ..., 3.8019e-02,
          1.0645e-01, 9.6876e-01],
         [6.1129e+02, 6.1150e+02, 1.3300e+02,  ..., 3.8576e-02,
          8.6339e-02, 9.1997e-01]]])

np[0]
tensor([[[0.5859, 0.5625, 0.7031,  ..., 0.0107, 0.0708, 0.2969],
         [0.4727, 0.6250, 0.7969,  ..., 0.0125, 0.0825, 0.2637],
         [0.5117, 0.6680, 0.8086,  ..., 0.0125, 0.0889, 0.0693],
         ...,
         [0.5586, 0.3496, 0.3477,  ..., 0.1582, 0.2178, 0.9688],
         [0.5469, 0.1357, 0.2773,  ..., 0.0378, 0.1050, 0.9688],
         [0.3027, 0.3047, 0.2988,  ..., 0.0386, 0.0864, 0.9180]]])

jp[0][..., 4][jp[0][..., 4] > 0.45]
tensor([0.8304, 0.6804, 0.7879, 0.8152, 0.6366, 0.7737, 0.8117, 0.6276, 0.7688])

np[0][..., 4][np[0][..., 4] > 0.45]
tensor([0.8281, 0.6797, 0.7891, 0.8164, 0.6367, 0.7773, 0.8086, 0.6250, 0.7695])
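
One way to make this comparison less eyeball-driven is to measure the difference per column. The sketch below is plain Python on the values printed above; `max_abs_diff` is a hypothetical helper, not part of YOLOv5 or torch-neuron.

```python
# Objectness values copied from the two outputs above (entries > 0.45).
jit_conf = [0.8304, 0.6804, 0.7879, 0.8152, 0.6366, 0.7737, 0.8117, 0.6276, 0.7688]
neuron_conf = [0.8281, 0.6797, 0.7891, 0.8164, 0.6367, 0.7773, 0.8086, 0.6250, 0.7695]


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equal-length sequences."""
    if len(a) != len(b):
        raise ValueError("sequences must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))


conf_diff = max_abs_diff(jit_conf, neuron_conf)
# First three box values of the first row in each output above.
box_diff = max_abs_diff([5.3412, 4.9883, 19.919], [0.5859, 0.5625, 0.7031])
print(f"confidence diff: {conf_diff:.4f}, box diff: {box_diff:.4f}")
```

The confidence gap stays in the third decimal place, which is the kind of difference reduced-precision arithmetic could produce (an assumption; the thread does not say which precision was used), while the box columns differ by whole units, which points at a decoding problem rather than rounding.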

@jluntamazon
Contributor

@josebenitezg

Sometimes the inf1 instance crashes and I have to reboot it when I try to convert the xlarge model.

In the compiler logs you provided, the following error indicates an out-of-memory issue during compilation:

Compile command returned: -9

The error code will be given a better message in an upcoming release. For now the recommendation is to compile on a larger instance type.
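
For context on reading that code: Python's `subprocess` module reports a child that was killed by a signal as the negative signal number, so -9 corresponds to SIGKILL, which is the signal the Linux out-of-memory killer sends. A small sketch, assuming the logged value is the raw return code of the compiler process; `describe_returncode` is a hypothetical helper and not part of torch-neuron:

```python
import signal


def describe_returncode(returncode: int) -> str:
    """Describe a return code using the POSIX convention of Python's subprocess."""
    if returncode == 0:
        return "exited successfully"
    if returncode < 0:
        # subprocess reports "killed by signal N" as -N on POSIX systems.
        try:
            name = signal.Signals(-returncode).name
        except ValueError:
            name = f"signal {-returncode}"
        return f"killed by {name}"
    return f"exited with status {returncode}"


print(describe_returncode(-9))
```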

@RobinFrcd

The output classes seem to match, but the other part of the tensor isn’t the same at all:

Is it possible you are using an old version of the yolov5 repository? This mutation fix is only available starting in: https://github.com/ultralytics/yolov5/releases/tag/v6.0

If this is not the issue, do you have an example script similar to ultralytics/yolov5#7739 (comment) that easily allows the issue to be reproduced?

@RobinFrcd

Indeed, upgrading YOLO to v6.1 fixed the issue. The great thing is that v5.0 models are still compatible with v6, so I didn't have to retrain anything!

Thanks ! 🙏

@jluntamazon
Contributor

Reopening since there may be unresolved issues.

@fishuke Could you clarify which issues you have encountered?

@jluntamazon jluntamazon reopened this Jul 13, 2022
@fishuke

fishuke commented Jul 18, 2022

<omitting python frames> frame #25: __libc_start_main + 0xe7 (0x7eff0723cc87 in /lib/x86_64-linux-gnu/libc.so.6) Aborted (core dumped)

Hello, I encountered this problem while compiling yolov5s6. @RobinFrcd, can you provide a code snippet for compiling YOLOv5 v6.1?

@RobinFrcd

RobinFrcd commented Jul 18, 2022

To make it work, I did this:

  • Create a Python 3.7.10 virtualenv
  • Install the YOLOv5 requirements
  • pip install "torch-neuron==1.10.2.*" "neuron-cc[tensorflow]" "protobuf<4" torchvision==0.11.3 --extra-index-url https://pip.repos.neuron.amazonaws.com
  • Edit export.py:
    1. Replace torch.jit.trace with torch.neuron.trace
    2. Add
# Configure to use inplace flag for Neuron
for m in model.modules():
    if hasattr(m, 'inplace'):
        m.inplace = False
  • Run python export.py --weights model.pt --img 640 --batch 1 --include torchscript
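
The loop in step 2 can be wrapped in a small helper. The sketch below uses stand-in classes so that it runs without PyTorch; with a real model one would pass `model.modules()` instead. `disable_inplace` and the `Fake*` classes are hypothetical names used only for illustration.

```python
def disable_inplace(modules):
    """Set inplace=False on every object that has a truthy `inplace` attribute.

    Returns the number of objects that were changed.
    """
    changed = 0
    for m in modules:
        if hasattr(m, "inplace") and m.inplace:
            m.inplace = False
            changed += 1
    return changed


class FakeDetect:
    """Stand-in for a layer that carries an inplace flag."""

    def __init__(self):
        self.inplace = True


class FakeConv:
    """Stand-in for a layer without the flag."""


layers = [FakeConv(), FakeDetect(), FakeDetect()]
print(disable_inplace(layers))
```

Returning the count makes it easy to confirm that the flag was actually found on the model before tracing.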

@fishuke

fishuke commented Jul 18, 2022

Thanks, I successfully compiled yolov5 for Neuron, but I get an error when running inference.

2022-Jul-18 08:43:01.0682  8984:8984  ERROR   NRT:nrt_init                                Neuron Runtime 2.x requires Neuron driver(aws-neuron-dkms) version 2.1 or above.
2022-Jul-18 08:43:01.0682  8984:8984  ERROR   NRT:nrt_init                                Please make sure to upgrade to latest aws-neuron-dkms; for detailed installation instructions visit Neuron documentation.
2022-Jul-18 08:43:01.0682  8984:8984  ERROR   NRT:nrt_init                                If you still see this warning message after upgrading to latest aws-neuron-dkms, it can be due to an existing older runtime version (neuron-rtd), please visit trouble shooting guide for a solution.
Traceback (most recent call last):
  File "v5_inference.py", line 23, in <module>
    pred = model(im)
  File "/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/torch_neuron/runtime/___torch_mangle_550.py", line 9, in forward
    argument_1: Tensor) -> Tensor:
    _0 = getattr(self, "_NeuronGraph#0")
    _1 = ops.neuron.forward_1([argument_1], CONSTANTS.c0, CONSTANTS.c1, CONSTANTS.c2)
         ~~~~~~~~~~~~~~~~~~~~ <--- HERE
    return _1

Traceback of TorchScript, original code (most recent call last):
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/decorators.py(432): forward
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(709): _slow_forward
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(725): _call_impl
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/graph.py(548): __call__
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/graph.py(207): run_op
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/graph.py(196): __call__
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/runtime.py(69): forward
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(709): _slow_forward
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(725): _call_impl
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch/jit/_trace.py(940): trace_module
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch/jit/_trace.py(742): trace
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/tensorboard.py(307): tb_parse
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/tensorboard.py(533): tb_graph
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/decorators.py(482): maybe_generate_tb_graph_def
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/convert.py(513): maybe_determine_names_from_tensorboard
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/convert.py(200): trace
main.py(15): <module>
RuntimeError: The PyTorch Neuron Runtime could not be initialized. Neuron Driver issues are logged
to your system logs. See the Neuron Runtime's troubleshooting guide for help on this
topic: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/

@RobinFrcd

I see you're running Python 3.6 with outdated Neuron runtime versions, so maybe you could try upgrading everything first. Other than that, no idea, sorry!

@fishuke

fishuke commented Jul 18, 2022

I updated the Neuron runtime by following the instructions here.

I will update Python to 3.7 and see if it helps. The packages have been updated too.

@fishuke

fishuke commented Jul 18, 2022

We managed to run the project and it produces an output image, but nothing is drawn on it. It finds objects but fails to draw where they are. Here is my code:

import cv2
import torch
import numpy as np
import torch.neuron
from util_yolo import non_max_suppression

im = cv2.imread('img.jpg')
img0 = im.copy()
im = cv2.resize(im, (640, 384), interpolation = cv2.INTER_AREA)
# Convert
im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im)
# Convert into torch
im = torch.from_numpy(im)
im = im.float()  # uint8 to fp16/32
im /= 255  # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
    im = im[None]  # expand for batch dim

# Load the compiled model
model = torch.jit.load('neuron_yolov5s6.pt')

# Inference
pred = model(im)
pred = non_max_suppression(pred) #nms function used same as yolov5 detect.py

#Process predictions
for i, det in enumerate(pred):  # per image
    im0 = img0.copy()
    color=(30, 30, 30)
    txt_color=(255, 255, 255)
    h_size, w_size = im.shape[-2:]
    print(h_size, w_size)
    lw = max(round(sum(im.shape) / 2 * 0.003), 2)

    if len(det):
        # Write results
        for *xyxy, conf, cls in reversed(det):
            c = int(cls)  # integer class
            #label = f'{CLASSES[c]} {conf:.2f}'
            #print(label)
            label = "human"
            box = xyxy
            p1, p2 = (int(box[0]* w_size), int(box[1]* h_size)), (int(box[2]* w_size), int(box[3]* h_size))
            cv2.rectangle(im0, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
            print(f'p1={p1}, p2={p2}, box={xyxy}')
            tf = max(lw - 1, 1)  # font thickness
            w, h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=tf)[0]  # text width, height
            outside = p1[1] - h - 3 >= 0  # label fits outside box
            p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
            cv2.rectangle(im0, p1, p2, color, -1, cv2.LINE_AA)  # filled

            cv2.putText(im0,
                        label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                        0,
                        lw / 3,
                        txt_color,
                        thickness=tf,
                        lineType=cv2.LINE_AA)
    # Save results (image with detections)
    status = cv2.imwrite('out.jpg', im0)

Here is the console output:

384 640
p1=(35734, 62139), p2=(80772, 81189), box=[tensor(55.8359), tensor(161.8212), tensor(126.2065), tensor(211.4310), tensor(0.6836)]
p1=(382147, 34822), p2=(406332, 46585), box=[tensor(597.1053), tensor(90.6842), tensor(634.8947), tensor(121.3158), tensor(0.5312)]
p1=(215721, 42492), p2=(245078, 56580), box=[tensor(337.0656), tensor(110.6562), tensor(382.9344), tensor(147.3438), tensor(0.5781)]
p1=(60641, 48579), p2=(93330, 62780), box=[tensor(94.7518), tensor(126.5102), tensor(145.8289), tensor(163.4898), tensor(0.5859)]
p1=(378723, 131992), p2=(409756, 147559), box=[tensor(591.7553), tensor(343.7306), tensor(640.2447), tensor(384.2694), tensor(0.6133)]
p1=(105122, 53227), p2=(140637, 71956), box=[tensor(164.2532), tensor(138.6126), tensor(219.7468), tensor(187.3874), tensor(0.6016)]
p1=(275314, 54436), p2=(316045, 71688), box=[tensor(430.1792), tensor(141.7620), tensor(493.8208), tensor(186.6888), tensor(0.6797)]
p1=(439, 81220), p2=(61379, 106883), box=[tensor(0.6863), tensor(211.5111), tensor(95.9061), tensor(278.3429), tensor(0.7500)]
p1=(165981, 58546), p2=(200098, 74965), box=[tensor(259.3466), tensor(152.4642), tensor(312.6534), tensor(195.2239), tensor(0.7422)]
p1=(222846, 53525), p2=(253805, 68049), box=[tensor(348.1972), tensor(139.3899), tensor(396.5718), tensor(177.2129), tensor(0.7656)]
p1=(92734, 73943), p2=(137665, 97193), box=[tensor(144.8983), tensor(192.5620), tensor(215.1017), tensor(253.1082), tensor(0.7773)]
p1=(-507, 109050), p2=(39503, 144946), box=[tensor(-0.7928), tensor(283.9846), tensor(61.7244), tensor(377.4661), tensor(0.7969)]
p1=(63420, 103860), p2=(126157, 135396), box=[tensor(99.0947), tensor(270.4710), tensor(197.1204), tensor(352.5956), tensor(0.8125)]
p1=(260474, 128021), p2=(337878, 147164), box=[tensor(406.9920), tensor(333.3892), tensor(527.9354), tensor(383.2409), tensor(0.7930)]
p1=(318159, 91825), p2=(398849, 127253), box=[tensor(497.1239), tensor(239.1289), tensor(623.2028), tensor(331.3900), tensor(0.8164)]
p1=(225537, 69642), p2=(271102, 91004), box=[tensor(352.4025), tensor(181.3601), tensor(423.5975), tensor(236.9915), tensor(0.8164)]
p1=(164172, 77055), p2=(207027, 100864), box=[tensor(256.5191), tensor(200.6649), tensor(323.4809), tensor(262.6670), tensor(0.8281)]
p1=(285309, 70511), p2=(337195, 92304), box=[tensor(445.7965), tensor(183.6239), tensor(526.8674), tensor(240.3760), tensor(0.8516)]
p1=(149750, 108920), p2=(215432, 147192), box=[tensor(233.9858), tensor(283.6476), tensor(336.6125), tensor(383.3141), tensor(0.8711)]
p1=(231839, 95786), p2=(298080, 128469), box=[tensor(362.2498), tensor(249.4452), tensor(465.7502), tensor(334.5547), tensor(0.8828)]

@fishuke

fishuke commented Jul 18, 2022

My problem was solved by starting fresh as @RobinFrcd described, and by fixing the code above by removing this line:
p1, p2 = (int(box[0]* w_size), int(box[1]* h_size)), (int(box[2]* w_size), int(box[3]* h_size))

@jluntamazon
Contributor

Hi @fishuke, it’s good to see that you were able to get everything working!

I’m going to close this issue now that everything appears to be resolved, but please feel free to open a new issue if any other problems come up.

@hadilou

hadilou commented Aug 5, 2022

My problem was solved by starting fresh as @RobinFrcd described, and by fixing the code above by removing this line: p1, p2 = (int(box[0]* w_size), int(box[1]* h_size)), (int(box[2]* w_size), int(box[3]* h_size))

Yes, because the network's outputs are already scaled to pixel coordinates, so there is no need to scale them again.
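
To illustrate with numbers from the console output above: the first box has x1 = 55.8359, which is already in pixels of the 640x384 network input, and multiplying it by the width again gives the off-image value 35734. The boxes are relative to the resized input, though, so drawing on the original image may still need a rescale when its size differs. A minimal sketch, assuming a plain resize without letterbox padding as in the script above; `scale_box` and the 1280x768 original size are hypothetical:

```python
def scale_box(box, model_hw, orig_hw):
    """Map an (x1, y1, x2, y2) box from model-input pixels to original-image pixels.

    Assumes the image was resized directly to model_hw (no letterbox padding).
    """
    model_h, model_w = model_hw
    orig_h, orig_w = orig_hw
    sx, sy = orig_w / model_w, orig_h / model_h
    x1, y1, x2, y2 = box
    return (int(x1 * sx), int(y1 * sy), int(x2 * sx), int(y2 * sy))


box = (55.8359, 161.8212, 126.2065, 211.4310)  # first box from the log above

# The bug: treating pixel coordinates as normalized and multiplying by the width.
double_scaled_x1 = int(box[0] * 640)
print(double_scaled_x1)

# Same-size image: coordinates pass through, apart from truncation to int.
same_size = scale_box(box, model_hw=(384, 640), orig_hw=(384, 640))
print(same_size)

# Hypothetical 1280x768 original: every coordinate doubles before truncation.
doubled = scale_box(box, model_hw=(384, 640), orig_hw=(768, 1280))
print(doubled)
```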
