"""Quantize an ``RnnAgent`` policy to int8 and run it on a Cambricon MLU270.

Steps performed by ``main()``:

1. Build ``RnnAgent`` on the CPU and load float weights from
   ``FLOAT_CHECKPOINT``.
2. Call ``mlu_quantize.quantize_dynamic_mlu(..., dtype='int8',
   gen_quant=True)``, run one forward pass on random inputs, and save the
   quantized model's state dict to ``QUANT_CHECKPOINT``.
3. Re-quantize the float model, load ``QUANT_CHECKPOINT`` into it, move it to
   ``ct.mlu_device()`` and run layer-by-layer inference.
4. With ``ct.save_as_cambricon(OFFLINE_MODEL_NAME)`` enabled, trace the MLU
   model with ``torch.jit.trace`` and run the traced (fusion) model.
"""
from functools import partial
from types import SimpleNamespace as SN

import torch
import torch as th
import torch.nn as nn
import torch_mlu  # NOTE(review): presumably imported for its side effects (MLU backend) — confirm
import torch_mlu.core.mlu_model as ct
import torch_mlu.core.mlu_quantize as mlu_quantize
import torchvision.models as models

# Run on one core of an MLU270; gradients are disabled because this script
# only performs inference.
ct.set_core_number(1)
ct.set_core_version("MLU270")
torch.set_grad_enabled(False)

# File names shared between the save and load steps in main().
FLOAT_CHECKPOINT = './cp_epoch50.pth'
QUANT_CHECKPOINT = 'policyNet_cp_quantization.pth'
OFFLINE_MODEL_NAME = 'policyNet_offline'


class RnnAgent(nn.Module):
    """Recurrent policy: MLP encoder -> ``nn.GRUCell`` -> linear output head.

    Args:
        obs_shape: number of input features of the first ``nn.Linear``.
        n_actions: number of output features of the final ``nn.Linear``.
        args: namespace providing ``n_layers`` (number of Linear+ReLU encoder
            layers; values below 1 still yield one layer) and ``hidden_size``
            (width of the encoder and of the GRU hidden state).
    """

    def __init__(self, obs_shape, n_actions, args):
        super(RnnAgent, self).__init__()
        self._n_layers = args.n_layers
        self._hidden_size = args.hidden_size
        layers = [nn.Linear(obs_shape, self._hidden_size), nn.ReLU()]
        for _ in range(self._n_layers - 1):
            layers += [nn.Linear(self._hidden_size, self._hidden_size), nn.ReLU()]
        # Attribute names (enc / rnn / f_out) are the state-dict key prefixes;
        # renaming them would break loading existing checkpoints.
        self.enc = nn.Sequential(*layers)
        self.rnn = nn.GRUCell(self._hidden_size, self._hidden_size)
        self.f_out = nn.Linear(self._hidden_size, n_actions)

    def init_hidden(self):
        """Return a zero hidden state of shape ``(1, hidden_size)``."""
        return th.zeros(1, self._hidden_size)

    def forward(self, x):
        """Run one recurrent step.

        Args:
            x: dict with ``'obs'`` (observations fed to the encoder) and
                ``'h'`` (previous GRU hidden state). ``main()`` passes tensors
                of shape ``(1, obs_shape)`` and ``(1, hidden_size)``.

        Returns:
            Tuple ``(y, h)``: the output-head result and the new hidden state.
        """
        y = self.enc(x['obs'])
        h = self.rnn(y, x['h'])
        y = self.f_out(h)
        return y, h


def _load_weights(net, path):
    """Load a CPU-saved state dict from *path* into *net* and report mismatches.

    Loading stays non-strict so that differing keys do not raise, but every
    missing/unexpected key is printed: a silent mismatch would leave those
    layers at their random initialisation with no visible sign.
    """
    missing, unexpected = net.load_state_dict(torch.load(path, map_location='cpu'), strict=False)
    if missing or unexpected:
        print('warning: loading', path, '- missing keys:', missing,
              '- unexpected keys:', unexpected)
    return net


def main():
    """Run CPU quantization, MLU layer-by-layer inference and MLU fusion inference."""
    args = SN(n_layers=2, hidden_size=256)
    obs_shape = 82
    n_actions = 5

    policy_net = RnnAgent(obs_shape=obs_shape, n_actions=n_actions, args=args)
    policy_net.eval()
    _load_weights(policy_net, FLOAT_CHECKPOINT)

    # Random example inputs, sized from the model configuration.
    input_o = torch.rand((1, obs_shape), dtype=torch.float)
    input_h = torch.rand((1, args.hidden_size), dtype=torch.float)

    # --- CPU quantization ------------------------------------------------
    net_quantization = mlu_quantize.quantize_dynamic_mlu(policy_net, dtype='int8', gen_quant=True)
    # Keep this forward pass even though its output is unused: presumably
    # gen_quant=True derives the quantization parameters from it — confirm.
    net_quantization({'obs': input_o, 'h': input_h})
    torch.save(net_quantization.state_dict(), QUANT_CHECKPOINT)

    # --- MLU layer-by-layer inference ------------------------------------
    mlu_device = ct.mlu_device()
    print(mlu_device)
    mlu_net = mlu_quantize.quantize_dynamic_mlu(policy_net)
    _load_weights(mlu_net, QUANT_CHECKPOINT)
    mlu_net.to(mlu_device)
    mlu_inputs = {'obs': input_o.to(mlu_device), 'h': input_h.to(mlu_device)}
    out, h_next = mlu_net(mlu_inputs)
    print(type(out), type(h_next))
    print(out.cpu())

    # --- MLU fusion inference --------------------------------------------
    ct.save_as_cambricon(OFFLINE_MODEL_NAME)
    traced_model = torch.jit.trace(mlu_net, mlu_inputs, check_trace=False)
    fused_out, _ = traced_model(mlu_inputs)
    # NOTE(review): Cambricon examples typically call ct.save_as_cambricon("")
    # after the traced run to stop offline-model generation — confirm whether
    # that is needed here.
    print("-------------_+++++++++++++++++++++++++-----------------------------------++++++++++++++++++++++++++++++++++++++++++++++++++=--------------------------", fused_out.cpu())


if __name__ == '__main__':
    main()
# "My code" — apparently a label left over from the web page this script was copied from.
# "Please log in to comment" — apparently web-page residue, not part of the script.