打开微信,使用“扫一扫”进入页面后,点击右上角菜单,
点击“发送给朋友”或“分享到朋友圈”完成分享。
import torch
import torch_mlu
import torch_mlu.core.mlu_quantize as mlu_quantize
import torch_mlu.core.mlu_model as ct
from torch_mlu.core.device.queue import Queue
from torch_mlu.core.mlu_model import synchronize
from common_utils import Timer
import threading
import time
import logging
from models.AlexNet import *
from models.Vgg11 import *
from models.mobilenet import *
from models.squeezenet import *
from models.ResNet50 import *
from models.googlenet import *
from models.densenet import *
# Set the MLU core version.
ct.set_core_version("MLU270")

# Log at INFO level with timestamp, source path, line number and level.
logging.basicConfig(
    level=logging.INFO,
    format=(
        '%(asctime)s - %(pathname)s[line:%(lineno)d]'
        ' - %(levelname)s: %(message)s'
    ),
)
logger = logging.getLogger("TestNets")

# Queue object that is handed to the worker thread as its argument.
queue1 = Queue(-1)
def task1(queue1):
    """Quantize AlexNet to int8, run one forward pass on the MLU, log timings.

    Steps performed, in order: save randomly initialised weights, build a
    quantized model and save its weights, reload those weights into a fresh
    quantized model, move model and input to the MLU device, run a single
    forward pass, then log latency and throughput computed from two
    ``Timer`` readings.

    Args:
        queue1: Queue object supplied by the caller. It is not referenced
            anywhere in this function body.
    """
    # Set the MLU core number.
    ct.set_core_number(1)

    batch_size = 1
    img_num = 1

    # Save the randomly initialised weights.
    float_net = AlexNet().eval()
    torch.save(float_net.state_dict(), 'test_weights.pth')

    # Input data: batch 1, 3 channels, height 256, width 256.
    input_data = torch.rand((1, 3, 256, 256), dtype=torch.float)

    # Use the quantization tool to quantize the model weights, then save the
    # quantized weights.
    mean = [0, 0, 0]
    std = [1/255, 1/255, 1/255]
    first_conv_config = {'mean': mean, 'std': std, 'firstconv': True}
    plain_config = {'firstconv': False}
    quantized_net = mlu_quantize.quantize_dynamic_mlu(
        float_net, first_conv_config, dtype='int8', gen_quant=True)
    # NOTE(review): the result of the call above is replaced by the call
    # below before it is used -- confirm which configuration is intended.
    quantized_net = mlu_quantize.quantize_dynamic_mlu(
        float_net, plain_config, dtype='int8', gen_quant=True)
    cpu_output = quantized_net(input_data)
    torch.save(quantized_net.state_dict(), 'test_quantization.pth')

    # Step 1: load the quantized weights into a fresh quantized model.
    float_net = AlexNet().eval()
    quantized_net = mlu_quantize.quantize_dynamic_mlu(float_net)
    quantized_net.load_state_dict(torch.load('test_quantization.pth'))

    # Step 2: move the model to the MLU.
    net_mlu = quantized_net.to(ct.mlu_device())

    # Step 3: move the input data to the MLU.
    input_mlu = input_data.to(ct.mlu_device())

    # Step 4: run the forward pass and take both timer readings afterwards.
    e2e_timer = Timer()
    hardware_timer = Timer()
    mlu_output = net_mlu(input_mlu)
    hardware_time = hardware_timer.elapsed()
    e2e_time = e2e_timer.elapsed()

    images_per_hardware_time = img_num / hardware_time
    latency = batch_size / images_per_hardware_time * 1000
    throughput = img_num / e2e_time
    logger.info('latency: ' + str(latency))
    logger.info('throughput: ' + str(throughput))
if __name__ == '__main__':
    # Time the whole run: start task1 on a worker thread, wait for it to
    # finish, then log the elapsed time multiplied by 1000.
    total_hardware = 0
    timer = Timer()
    t1 = threading.Thread(target=task1, args=(queue1,))
    t1.start()
    # Wait for the worker thread to return before synchronizing and reading
    # the timer. Without the join, the main thread can reach synchronize()
    # and the timer reading before task1 has finished (or even submitted)
    # its work, so the reported total would not cover the task.
    t1.join()
    # NOTE(review): presumably blocks until work queued on the MLU has
    # completed -- confirm against the torch_mlu documentation.
    synchronize()
    total_hardware += timer.elapsed()
    logger.info('总耗时:' + str(total_hardware * 1000))
您好,我想实现队列内任务的同步,但运行时出现了下面的问题,不知道应该怎么修改。感谢!
热门帖子
精华帖子