/// @brief Fallback implementation of nll_loss_forward that computes on the CPU.
///
/// Copies the inputs to the CPU, runs at::nll_loss_forward there, and moves
/// both resulting tensors to the MLU device.
///
/// @param self          Input tensor, forwarded to at::nll_loss_forward.
/// @param target        Target tensor, forwarded to at::nll_loss_forward.
/// @param weight        Optional weight tensor; may be undefined when the
///                      caller supplied no weights.
/// @param reduction     Reduction mode, forwarded unchanged.
/// @param ignore_index  Ignored target index, forwarded unchanged.
/// @return The two tensors produced by at::nll_loss_forward, both placed on
///         the MLU device.
std::tuple<at::Tensor, at::Tensor> OpMethods::nll_loss_forward(
    const at::Tensor& self, const at::Tensor& target, const at::Tensor& weight,
    int64_t reduction, int64_t ignore_index) {
  TRACE_INFO(__FUNCTION__);
  const auto self_cpu = self.cpu();
  const auto target_cpu = target.cpu();
  // An undefined tensor stands for "no weight". It has no device to convert
  // from, so pass it through as-is instead of calling cpu() on it.
  const auto weight_cpu = weight.defined() ? weight.cpu() : weight;
  const auto output_cpu = at::nll_loss_forward(
      self_cpu, target_cpu, weight_cpu, reduction, ignore_index);
  // Both results go back to the same MLU device descriptor.
  const at::Device mlu_device(at::Device::Type::MLU);
  return std::make_tuple(std::get<0>(output_cpu).to(mlu_device),
                         std::get<1>(output_cpu).to(mlu_device));
}
// (Web page footer residue: "Please log in to comment.")