打开微信,使用扫一扫进入页面后,点击右上角菜单,
点击“发送给朋友”或“分享到朋友圈”完成分享
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_mlu
import torch_mlu.core.mlu_model as ct
import torch_mlu.core.mlu_quantize as mlu_quantize
import torchvision.models as models
from common_utils import Timer
import logging
# MLU runtime and logging configuration for this script.
ct.set_core_number(1) # Set the number of MLU cores to use.
ct.set_core_version("MLU270") # Set the target MLU core version.
# NOTE(review): the comment originally attached to the next line described the
# input image channel order (channel padding of the first convolution for
# 3-channel input, RGBA by default). That does not match this call, which only
# turns off gradient tracking; it presumably belonged to a separate
# input-format setting that is not present here - TODO confirm.
torch.set_grad_enabled(False) # Disable autograd globally; this script only runs forward passes.
# Log timestamp, source path, line number and level with every message.
logging.basicConfig(level=logging.INFO,
format=
'%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')
logger = logging.getLogger("TestNets")
class Net(nn.Module):
    """Minimal single-convolution network.

    The network applies one 3x3 convolution (1 input channel, 64 output
    channels, stride 1, no padding), a ReLU, flattens everything except the
    batch dimension, and returns log-softmax scores over the flattened
    features.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Conv2d arguments: in_channels, out_channels, kernel size, stride.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1)

    def forward(self, x):
        """Return log-softmax scores for a batch of single-channel images.

        Args:
            x: Float tensor of shape (batch, 1, height, width); height and
                width must be at least 3 because the 3x3 convolution uses
                no padding.

        Returns:
            Tensor of shape (batch, 64 * (height - 2) * (width - 2)); each
            row is log-softmax normalised along dim 1.
        """
        x = self.conv1(x)
        x = F.relu(x)
        # Keep the batch dimension, collapse channels and spatial dims.
        x = torch.flatten(x, 1)
        output = F.log_softmax(x, dim=1)
        return output
# Save the randomly initialised weights so the quantization step below starts
# from a checkpoint file.
net = Net().eval()
torch.save(net.state_dict(), 'test_weights.pth')

# Input layout: (batch=1, channel=1, height=28, width=28).
# conv1 is Conv2d(1, 64, ...), so the input must have exactly one channel;
# the previous shape (5, 64, 254, 254) had 64 channels and made the forward
# pass below fail with a channel-mismatch error.
input_data = torch.rand((1, 1, 28, 28), dtype=torch.float)

# Quantize the model weights with the MLU quantization tool and save the
# quantized weights.
net.load_state_dict(
    torch.load('test_weights.pth', map_location='cpu'), strict=False)
net_quantization = mlu_quantize.quantize_dynamic_mlu(
    net, {'firstconv': False}, dtype='int8', gen_quant=True)
# Run one forward pass on the host before saving.
# NOTE(review): presumably this pass is what lets the quantizer compute its
# quantization parameters (gen_quant=True) - confirm against torch_mlu docs.
output = net_quantization(input_data)
torch.save(net_quantization.state_dict(), 'test_quantization.pth')
total_e2e = 0
total_hardware = 0
# Derive the counts from the tensor that is actually fed to the network so
# the reported numbers cannot drift from the real batch.
batch_size = input_data.shape[0]
img_num = batch_size

# Step 1: rebuild the quantized model and load the quantized weights.
net_quantization = mlu_quantize.quantize_dynamic_mlu(net)
net_quantization.load_state_dict(torch.load('test_quantization.pth'))

# Step 2: move the model to the MLU device.
net_mlu = net_quantization.to(ct.mlu_device())

# Step 3: move the input to the MLU device.
# The end-to-end timer starts before the host-to-device copy so that it covers
# more than the forward pass; started right next to timer1 it would measure
# the same span as the hardware timer.
timer = Timer()
input_mlu = input_data.to(ct.mlu_device())

# Step 4: run the forward pass; timer1 covers only this call.
# NOTE(review): if device execution is asynchronous, the result must be
# synchronised (e.g. copied back to the host) before reading the timers -
# confirm against the torch_mlu documentation.
timer1 = Timer()
output = net_mlu(input_mlu)
total_hardware += timer1.elapsed()
total_e2e += timer.elapsed()

# NOTE(review): assuming Timer.elapsed() returns seconds, "latency" is the
# forward-pass time per batch in milliseconds and "throughput" is images per
# second of end-to-end time - confirm against common_utils.Timer.
logger.info('latency: %s', batch_size / (img_num / total_hardware) * 1000)
logger.info('throughput: %s', img_num / total_e2e)
您好,我想请问:如果想计算 latency 和 throughput,代码是否可以这样写?另外,throughput 输出值的含义是什么?
热门帖子
精华帖子