×

签到

分享到微信

打开微信,使用扫一扫进入页面后,点击右上角菜单,

点击“发送给朋友”或“分享到朋友圈”完成分享

【经验总结】Yolov5移植流程指南 徐浩然2021-08-03 14:58:26 回复 23 查看 经验交流
【经验总结】Yolov5移植流程指南
分享到:

本篇文章旨在以个人移植yolov5的经验,向有需要的开发者提供一个参考。由于yolov5至今已有多个版本,我们仅以最新的v5版本作为示例进行介绍,但是实际上多个版本的移植流程都比较相似,可以互为借鉴。


1.模型解压

  由于寒武纪提供的sdk中的pytoch版本一般为1.3,而yolov5使用了1.7或更高版本的pytorch,导致保存的pt文件不能直接被sdk中的pytorch读取,因此需要先将模型进行解压。

  修改models/experimental.py 中的attempt_load 函数:


def attempt_load(weights, map_location=None, inplace=True):
    from models.yolo import Detect, Model
    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
    model = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
    temp_model = torch.load(w,  map_location='cpu')['model']
  torch.save(temp_model.state_dict(),"./unzip.pt",_use_new_zipfile_serialization=False)
   ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
  model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())

 然后进行一次模型载入,即可获得解压后的模型文件。也可以修改train.py中保存模型时的参数,免去再次适配时进行解压的操作。


2.注册SiLU/HardSwish 激活函数


  由于pytorch1.3 中并没有yolov5使用到的SiLU/HardSwish激活函数,因此有些版本需要手动在/your/python/path/torch/nn/modules/activation.py中添加定义:


class Hardswish(Module):
  @staticmethod
  def forward(x):
        return x * F.hardtanh(x + 3, 0., 6.) / 6.
class SiLU(Module): 
  @staticmethod
  def forward(x):
        return x * torch.sigmoid(x)

以及在/your/python/path/torch/nn/modules/__init__.py中,在from .activation import 中添加SiLU, Hardswish,, 在__all__ = 中添加 ’SiLU’,  ‘Hardswish’即可。

3. 注释干扰代码

  有些代码在寒武纪sdk环境中会报错,但不影响程序运行,因此我们先将其注释。包括models/common.py中的from torch.cuda import amp,以及requirements.txt 中对torch和torchvision版本要求。

4.添加后处理大算子

  最新的寒武纪sdk中,yolov5的后处理部分的大算子集成在了pytorch框架内,因此可以直接使用该算子进行后处理,提升模型性能。

  修改models/yolo.py中class Detect的forward函数:

def forward(self, x):
    # x = x.copy()  # for profiling
    z = []  # inference output
    if x[0].device.type == 'mlu':
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            y = x[i].sigmoid()
            z.append(y)
            anchors_list = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326]
            num_anchors = len(anchors_list)
            img_h = 640
            img_w = 640
            conf_thres = 0.45
            iou_thres = 0.5max
            BoxNum = 1024
            detect_out = torch.ops.torch_mlu.yolov5_detection_output(z[0], z[1], z[2],
                              anchors_list, self.nc, num_anchors,
                              img_h, img_w, conf_thres, iou_thres, maxBoxNum)
      return detect_out
      for i in range(self.nl):
          x[i] = self.m[i](x[i])  # conv
      ……

  其中yolov5_detection_output算子的各个参数可依据实际需求进行修改。


 


5. 量化模型并生成离线模型


  量化和生成离线模型时,需要注意不能使用


model = attempt_load(weights, map_location=device)


的方式来获取模型了,因为我们需要使用未压缩的pth文件导入权重,因此需要修改为:


model = Model('/path/to/your.yaml')

state_dict = torch.load(‘unzip.pt’, map_location='cpu')

model.load_state_dict(state_dict)


剩余过程和其他pytorch的模型相似,在此就不再赘述了。


 


6.后处理


  大算子输出的结果包含了框的数量及相关信息,后处理的过程如下:



def get_boxes(prediction, batch_size=1):
    """
    Returns detections with shape:
    (x1, y1, x2, y2,  _conf, class)
    """
    reshape_value = torch.reshape(prediction, (-1, 1))
    num_boxes_final = reshape_value[0].item()
    print('num_boxes_final: ',num_boxes_final)
    all_list = [[] for _ in range(batch_size)]
    for i in range(int(num_boxes_final)):
        batch_idx = int(reshape_value[64 + i * 7 + 0].item())
        if batch_idx >= 0 and batch_idx  0 and bb -br > 0:
            all_list[batch_idx].append(bl)
            all_list[batch_idx].append(br)
            all_list[batch_idx].append(bt)all_
            list[batch_idx].append(bb)
            all_list[batch_idx].append(reshape_value[64 + i * 7 + 2].item())
            all_list[batch_idx].append(reshape_value[64 + i * 7 + 1].item())
    outputs = [torch.FloatTensor(all_list[i]).reshape(-1, 6) for i in range(batch_size)]
    return outputs


使用C++部署离线模型时,后处理也可以仿照实现。


 


以上就是yolov5移植的整个过程。


版权所有 © 2024 寒武纪 Cambricon.com 备案/许可证号:京ICP备17003415号-1
关闭