×

签到

分享到微信

打开微信,使用扫一扫进入页面后,点击右上角菜单,

点击“发送给朋友”或“分享到朋友圈”完成分享

参数 已解决 sys2021-10-10 10:00:04 回复 1 查看 技术答疑 使用求助
参数
分享到:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_mlu
import torch_mlu.core.mlu_model as ct
import torch_mlu.core.mlu_quantize as mlu_quantize
import torchvision.models as models
from common_utils import Timer
import logging

# Global MLU / PyTorch / logging configuration for this benchmark script.
ct.set_core_number(1)          # set the number of MLU cores to use
ct.set_core_version("MLU270")  # set the target MLU core version
torch.set_grad_enabled(False)  # disable autograd globally; this script only runs forward passes

# Log at INFO level with timestamp, source path, line number and level.
logging.basicConfig(level=logging.INFO,
                    format=
                    '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')
logger = logging.getLogger("TestNets")


class Net(nn.Module):
    """Tiny one-layer convolutional network used as a benchmark workload.

    The network is a single 3x3 convolution (1 input channel, 64 output
    channels, stride 1, no padding) followed by ReLU, flattening of all
    non-batch dimensions, and a log-softmax over the flattened features.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Arguments: in_channels, out_channels, kernel size, stride.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: input tensor; ``conv1`` requires exactly 1 input channel
                (dimension 1 of ``x``).

        Returns:
            A 2-D tensor of log-probabilities with one row per batch
            element; each row is the log-softmax of the flattened,
            ReLU-activated convolution output.
        """
        activated = F.relu(self.conv1(x))
        # Collapse every dimension except the batch dimension.
        flat = torch.flatten(activated, 1)
        return F.log_softmax(flat, dim=1)


# Benchmark parameters.
# batch_size is the number of images per forward pass; img_num is the total
# number of images pushed through the network. Only one batch is run here,
# so the two are equal and must match the batch dimension of input_data.
batch_size = 1
img_num = batch_size

# Save the randomly initialised weights.
net = Net().eval()
torch.save(net.state_dict(), 'test_weights.pth')
# Input data layout: (batch, channel, height, width).
# The channel dimension must be 1 because Net.conv1 is Conv2d(1, 64, ...);
# feeding a 64-channel tensor makes the convolution fail with a channel
# mismatch.
input_data = torch.rand((batch_size, 1, 28, 28), dtype=torch.float)
# Quantize the model weights with the MLU quantization tool (one forward pass
# on CPU is run with gen_quant=True) and save the quantized weights.
net.load_state_dict(torch.load('test_weights.pth', map_location='cpu'), False)
net_quantization = mlu_quantize.quantize_dynamic_mlu(net, { 'firstconv':False}, dtype='int8', gen_quant=True)
output = net_quantization(input_data)
torch.save(net_quantization.state_dict(), 'test_quantization.pth')

# Accumulated timings.
# NOTE(review): assumes common_utils.Timer.elapsed() returns seconds since the
# Timer was constructed - confirm against common_utils.
total_e2e = 0
total_hardware = 0

# step1: load the quantized weights
net_quantization = mlu_quantize.quantize_dynamic_mlu(net)
net_quantization.load_state_dict(torch.load('test_quantization.pth'))

# step2: move the model to the MLU
net_mlu = net_quantization.to(ct.mlu_device())

# step3: move the input to the MLU.
# The end-to-end timer starts before the host-to-device copy so that
# throughput reflects the full per-batch cost, not only the forward pass.
e2e_timer = Timer()
input_mlu = input_data.to(ct.mlu_device())

# step4: run the forward pass; the hardware timer covers only this call.
hardware_timer = Timer()
output = net_mlu(input_mlu)
total_hardware += hardware_timer.elapsed()

# Copy the result back to the host inside the end-to-end window so the
# device-to-host transfer is included in total_e2e.
output_cpu = output.cpu()
total_e2e += e2e_timer.elapsed()

# latency: forward-pass time per batch, in milliseconds
#          (total_hardware / number_of_batches * 1000).
logger.info('latency: %s', batch_size / (img_num / total_hardware) * 1000)

# throughput: images processed per second, measured end to end.
logger.info('throughput: %s', img_num / total_e2e)


您好,我想请问:如果想计算 latency 和 throughput,代码是否可以这样写?另外,throughput 的输出含义是什么?

版权所有 © 2024 寒武纪 Cambricon.com 备案/许可证号:京ICP备17003415号-1
关闭