打开微信,使用扫一扫进入页面后,点击右上角菜单,
点击“发送给朋友”或“分享到朋友圈”完成分享
该量化后的模型可以在cpu上完成推理,在MLU上推理量化模型时报错,报错源码是在后处理锚框时postprocess()出现,源码如下,x为CV2读取照片的值,错误为索引错误,
def postprocess(x, anchors, regression, classification, regressBoxes, clipBoxes, threshold, iou_threshold): transformed_anchors = regressBoxes(anchors, regression) transformed_anchors = clipBoxes(transformed_anchors, x) scores = torch.max(classification, dim=2, keepdim=True)[0] scores_over_thresh = (scores > threshold)[:, :, 0] out = [] for i in range(x.shape[0]): if scores_over_thresh[i].sum() == 0: out.append({ 'rois': np.array(()), 'class_ids': np.array(()), 'scores': np.array(()), }) continue #a = torch.zeros(49104) #scores_over_thresh = scores_over_thresh(dtype=torch.long) #socres_over_thresh[i, :] classification_per = classification[i, scores_over_thresh[i, :], ...].permute(1, 0) #出错 transformed_anchors_per = transformed_anchors[i, scores_over_thresh[i, :], ...] scores_per = scores[i, scores_over_thresh[i, :], ...] scores_, classes_ = classification_per.max(dim=0) anchors_nms_idx = batched_nms(transformed_anchors_per, scores_per[:, 0], classes_, iou_threshold=iou_threshold)
调试时发现scores_over_thresh[i, :]的值为49104个零,是float类型,尝试转换类型也不起作用,而且mlu上也不支持long类型,请问这个错该如何解决呢,如何把它变为mlu支持的正确的索引呢?
热门帖子
精华帖子