×

签到

分享到微信

打开微信,使用扫一扫进入页面后,点击右上角菜单,

点击“发送给朋友”或“分享到朋友圈”完成分享

【Pytorch Yolov7】Yolov7模型寒武纪200移植分享(下篇) MyAI2023-01-05 14:07:57 回复 查看 技术答疑 干货资源
【Pytorch Yolov7】Yolov7模型寒武纪200移植分享(下篇)
分享到:

上篇https://forum.cambricon.com/show-3-2292-1.html 转好模型之后就可以在MLU侧进行验证。

基于MLU验证

模型量化

代码修改(增加mlu/mlu_quant.py 文件)

## 模型量化
---
    # model = attempt_load(weights, map_location=device)  # load FP32 model
    from models.yolo import get_model
    model = get_model(opt)
    # print(model)

    # 配置量化参数
    import torch_mlu
    import torch_mlu.core.mlu_model as ct
    import torch_mlu.core.mlu_quantize as mlu_quantize
    qconfig={'use_avg':False, 'data_scale':1.0, 'firstconv':False, 'per_channel': False}
    # 调用量化接口
    quantized_net = mlu_quantize.quantize_dynamic_mlu(model,qconfig_spec=qconfig, dtype='int8', gen_quant=True)
    # 设置为推理模式
    quantized_net = quantized_net.eval().float()

    model = quantized_net
---

  # 保存量化模型

    qua_weight = opt.qua_weight
    print("SAVE quantize model:",qua_weight)
    torch.save(model.state_dict(),qua_weight)

---
#增加qua_weight参数
    parser.add_argument('--qua_weight', type=str,default='yolov7_intx.pth', help='model.pt path(s)')

运行命令

python mlu/mlu_quant.py --weights mlu/weight/yolov7_unzip.pt --conf 0.25 --img-size 640 --source inference/images/horses.jpg --cfg ./cfg/deploy/yolov7.yaml --qua_weight mlu/weight/yolov7_intx.pth --no-trace

基于MLU运行量化后的模型

代码修改参看(增加mlu/mlu_detect.py 文件)

 #模型加载
   from models.yolo import get_model
   model = get_model(opt)
   # print(model)

   import torch_mlu
   import torch_mlu.core.mlu_model as ct
   import torch_mlu.core.mlu_quantize as mlu_quantize

   from postprocess import MLU_PostProcessYoloV7,PostProcessPytorchYoloV7,draw_image
   from models.yolo import get_empty_model
   model = get_empty_model(opt)

   stride = 32

   # stride = int(model.stride.max())  # model stride
   imgsz = check_img_size(imgsz, s=stride)  # check img_size

   #配置 MLU core number
   ct.set_core_number(opt.core_number)
   # 设置输入图片的通道顺序,以决定首层卷积对三通道输入的补齐通道顺序。默认是 RGBA 顺序
   #ct.set_input_format(0)
   #配置MLU core类型
   ct.set_core_version(opt.mcore)
   torch.set_grad_enabled(False)

   if opt.fake_device:
       print("fake_device mode")
       ct.set_device(-1)

   device = ct.mlu_device()
   print("run on %s ..."%device)
   # 加载量化模型
   weight = weights[0]
   quantized_net = torch_mlu.core.mlu_quantize.quantize_dynamic_mlu(model)
   print('weight:',weight)
   state_dict = torch.load(weight)
   quantized_net.load_state_dict(state_dict, strict=False)
   # 设置为推理模式
   quantized_net = quantized_net.eval().float()
   quantized_net.to(device)

   model = quantized_net

   # 设置在线融合模式
   if opt.jit:
       if opt.save:
           ct.save_as_cambricon(opt.mname)

       example = torch.randn(opt.batch_size, 3, imgsz, imgsz,dtype=torch.float)
       trace_input = torch.randn(1, 3, imgsz, imgsz,dtype=torch.float)

       if opt.half_input:
           print('half_input ')
           trace_input = trace_input.type(torch.HalfTensor)
           example = example.type(torch.HalfTensor)

       print("jit trace example shape",example.shape)
       model = torch.jit.trace(model,trace_input.to(device),check_trace=False)

       #如果是生成离线模型,推理一次,直接退出,会保存离线模型
       if opt.save:
           print("save offline model mname: ",opt.mname)
           model(example.to(device))
           ct.save_as_cambricon('')
           exit(0)
   if opt.mlu_det:
       postproc = MLU_PostProcessYoloV7()
   else:
       postproc = PostProcessPytorchYoloV7(conf_thres=opt.conf_thres,iou_thres=opt.iou_thres)
   names = postproc.names
 
 ---
 #推理及后处理部分
 ---
       # Inference
       t1 = time_synchronized()
       with torch.no_grad():   # Calculating gradients would cause a GPU memory leak
           detect_out = model(img)
           # pred = model(img, augment=opt.augment)[0]
       t2 = time_synchronized()
       if len(detect_out) == 1:
           pred = detect_out.cpu().type(torch.FloatTensor) if opt.half_input else detect_out.cpu()
       else:
           pred = [out.cpu().type(torch.FloatTensor) for out in detect_out]
       # print("mlu pred:{} {}".format(type(pred),pred))

       # from mlu.tools.dump_npy import save_npy
       # save_npy(pred,"mlu")
       # # from postprocess import draw_image
       if opt.mlu_det:
           pred = postproc.get_boxes(pred)
           p, s, im0, = path, '', im0s, getattr(dataset, ' ', 0)
           p = Path(p)  # to Path
           save_path = str(save_dir / p.name)  # img.jpg
           print(save_path)
           draw_image(pred, img, im0s, path, save_path, names)
           exit(0)
           
       pred = postproc.yolo_det(pred)[0]


备注:后处理参看postprocess.py,mlu推理代码参看mlu_detect.py

运行命令(MLU+CPU NMS)

#逐层运行+CPU NMS
python mlu/mlu_detect.py --weights mlu/weight/yolov7_intx.pth --conf 0.25 --img-size 640 --source inference/images/horses.jpg --cfg ./cfg/deploy/yolov7.yaml
# 融合模型运行+CPU NMS
python mlu/mlu_detect.py --weights mlu/weight/yolov7_intx.pth --conf 0.25 --img-size 640 --source inference/images/horses.jpg --cfg ./cfg/deploy/yolov7.yaml --jit
#生成离线模型
python mlu/mlu_detect.py --weights mlu/weight/yolov7_intx.pth --conf 0.25 --img-size 640 --source inference/images/horses.jpg --cfg ./cfg/deploy/yolov7.yaml --jit --mcore MLU270 --mname yolov7_4b4c --core 4 --batch 4 --save

运行命令(MLU+mlu Yolo det)

#逐层运行 + MLU yolo detection output
python mlu/mlu_detect.py --weights mlu/weight/yolov7_intx.pth --conf 0.25 --img-size 640 --source inference/images/horses.jpg --cfg ./cfg/deploy/yolov7.yaml --mlu_det
#融合模型运行 + MLU yolo detection output
python mlu/mlu_detect.py --weights mlu/weight/yolov7_intx.pth --conf 0.25 --img-size 640 --source inference/images/horses.jpg --cfg ./cfg/deploy/yolov7.yaml --jit --mlu_det
#生成MLU270离线模型 4B4C
python mlu/mlu_detect.py --weights mlu/weight/yolov7_intx.pth --conf 0.25 --img-size 640 --source inference/images/horses.jpg --cfg ./cfg/deploy/yolov7.yaml --jit --mlu_det --mcore MLU270 --mname yolov7_4b4c --core 4 --batch 4 --save

#生成MLU270离线模型 1B4C
python mlu/mlu_detect.py --weights mlu/weight/yolov7_intx.pth --conf 0.25 --img-size 640 --source inference/images/horses.jpg --cfg ./cfg/deploy/yolov7.yaml --jit --mlu_det --mcore MLU270 --mname yolov7_1b4c --core 4 --batch 1 --save

#生成MLU220离线模型 4B4C
python mlu/mlu_detect.py --weights mlu/weight/yolov7_intx.pth --conf 0.25 --img-size 640 --source inference/images/horses.jpg --cfg ./cfg/deploy/yolov7.yaml --jit --mlu_det --mcore MLU220 --mname mlu220_yolov7_1b4c --core 4 --batch 1 --save


版权所有 © 2024 寒武纪 Cambricon.com 备案/许可证号:京ICP备17003415号-1
关闭