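// [Web-page residue, not part of the program] Open WeChat, use "Scan" to open the page, tap the menu in the top-right corner,
// then tap "Send to Chat" or "Share on Moments" to complete the share.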
#include #include #include #include #include #if 1
// Element count (floats) of each __nram__ working buffer declared in the kernels below.
#define BLOCK_SIZE (1024*16)
// Only referenced by the commented-out __mlu_shared__ buffer declarations in the visible code.
#define CORE_DIM 4
// Multiplier applied to BLOCK_SIZE when sizing nbuffb (and the commented-out shr_buffb).
#define STAGE 2
// Element-wise multiply-accumulate over `count` floats. The scalar equivalent
// (kept from the original author's reference loop) is:
//
//     for (int i = 0; i < count; i++) nC[i] += nA[i] * nB[i];
//
// pipe  : pipeline object handed to both bang::vector calls.
// nC    : accumulator, read and written in place.
// nA/nB : the two factor arrays.
// count : number of elements to process.
__mlu_func__ void kernel_row_mul_addb(bang::pipeline& pipe,
                                      float* __restrict__ nC,
                                      float* __restrict__ nA,
                                      float* __restrict__ nB,
                                      int count) {
  // Scratch for the products nA[i] * nB[i]. It holds BLOCK_SIZE floats, so
  // presumably callers keep count <= BLOCK_SIZE — TODO confirm.
  __nram__ float products[BLOCK_SIZE];

  bang::vector::mul(pipe, products, nA, nB, count);
  bang::vector::add(pipe, nC, nC, products, count);
}
// Walks rows [rowA1, rowA2) of gA. For each row it calls __bang_write_zero on
// nbuffc and asynchronously copies colNum floats of that row from gA into
// nbuffa, then waits for the copy before entering an inner loop over `rowb`.
//
// NOTE(review): the text of this function is truncated in the middle (see the
// note at the inner loop below), so what it does with gB/gC and the remaining
// buffers and pipelines cannot be determined from this file as it stands.
//
// gA, gB, gC : float arrays; only gA is read in the part that survived.
// colNum     : number of floats copied per row of gA
//              (assumes colNum <= BLOCK_SIZE so a row fits in nbuffa — TODO confirm).
// rowNum     : not used in the surviving part of this function.
// rowA1/A2   : half-open range of rows of gA to process.
__mlu_func__ void kernel_row_mul_add(float* __restrict__ gA,
float* __restrict__ gB,
float* __restrict__ gC,
int colNum, int rowNum, int rowA1, int rowA2) {
//__mlu_shared__ float shr_buffb[STAGE * CORE_DIM * BLOCK_SIZE];
//__mlu_shared__ float shr_buffc[CORE_DIM * BLOCK_SIZE];
// Working buffers: nbuffa receives one row of gA; nbuffc is cleared per row;
// nbuffb is STAGE times larger (its use is in the missing part of the body).
__nram__ float nbuffa[BLOCK_SIZE];
__nram__ float nbuffc[BLOCK_SIZE];
__nram__ float nbuffb[BLOCK_SIZE*STAGE];
// Four pipeline objects constructed with ids 0..3; only pipe4 is used in the
// surviving part of the body.
bang::pipeline pipe0(0);
bang::pipeline pipe1(1);
bang::pipeline pipe4(2);
bang::pipeline pipe5(3);
for ( int row = rowA1; row < rowA2; row++ ) {
// Byte size of one row of colNum floats.
size_t cp_sizeb = colNum * sizeof(float);
__bang_write_zero(nbuffc, BLOCK_SIZE);
// Fetch row `row` of gA into nbuffa and block until the copy has landed.
bang::memcpy_async(pipe4, nbuffa,
gA+row*colNum, cp_sizeb, GDRAM2NRAM);
pipe4.wait_copy_dram_to_nram();
// NOTE(review): the next line is not valid C++ as written. The text between
// the inner `rowb` loop condition and `= rowNum) return;` appears to have been
// lost (possibly everything between a '<' and a later '>'). `start_row` and
// `end_row` below are not declared anywhere in the visible text, and the final
// call targets kernel_row_mul_add itself, so the trailing lines presumably
// belong to a separate kernel entry point (the host code launches `kernelMM`).
// Restore this region from the original source before building.
for (int rowb = 0; rowb = rowNum) return;
if (end_row > rowNum) end_row = rowNum;
kernel_row_mul_add(gA, gB, gC, colNum, rowNum, start_row, end_row);
}
int matrixMul(float* ptrC, float* ptrA, float* ptrB, int rowNum, int colNum) {
float* deva, *devb, *devc;
constexpr int loops = 1;
cnrtQueue_t queue;
CNRT_CHECK( cnrtCreateQueue(&queue) );
const size_t mem_size = rowNum * colNum * sizeof(float);
CNRT_CHECK( cnrtMalloc((void**)(&deva), mem_size) );
CNRT_CHECK( cnrtMalloc((void**)(&devb), mem_size) );
CNRT_CHECK( cnrtMalloc((void**)(&devc), mem_size) );
CNRT_CHECK( cnrtMemcpy(deva, ptrA, mem_size, CNRT_MEM_TRANS_DIR_HOST2DEV) );
CNRT_CHECK( cnrtMemcpy(devb, ptrB, mem_size, CNRT_MEM_TRANS_DIR_HOST2DEV) );
CNRT_CHECK( cnrtMemcpy(devc, ptrC, mem_size, CNRT_MEM_TRANS_DIR_HOST2DEV) );
/*CNRT_CHECK( cnrtMemcpyAsync(deva, ptrA, mem_size, queue, CNRT_MEM_TRANS_DIR_HOST2DEV) );
CNRT_CHECK( cnrtMemcpyAsync(devb, ptrB, mem_size, queue, CNRT_MEM_TRANS_DIR_HOST2DEV) );
CNRT_CHECK( cnrtMemcpyAsync(devc, ptrC, mem_size, queue, CNRT_MEM_TRANS_DIR_HOST2DEV) );*/
// Create cnrt event handles.
cnrtNotifier_t start, stop;
CNRT_CHECK( cnrtCreateNotifier(&start) );
CNRT_CHECK( cnrtCreateNotifier(&stop) );
// Record timer start
CNRT_CHECK( cnrtPlaceNotifier(start, queue) );
cnrtDim3_t dim = {16, 1, 1};
cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION4;
for (int i = 0; i < loops; i++) {
kernelMM<<>>(deva, devb, devc, rowNum, colNum);
}
CNRT_CHECK( cnrtSyncQueue(queue) );
// Record timer stop
CNRT_CHECK( cnrtPlaceNotifier(stop, queue) );
CNRT_CHECK( cnrtMemcpy(ptrC, devc, mem_size, CNRT_MEM_TRANS_DIR_DEV2HOST) );
float usecTotal = 0.0f;
CNRT_CHECK(cnrtNotifierDuration(start, stop, &usecTotal));
printf("time spent executing by the MLU: %.3f ms\n", usecTotal / (1000*loops) );
CNRT_CHECK( cnrtFree((void*)deva) );
CNRT_CHECK( cnrtFree((void*)devb) );
CNRT_CHECK( cnrtFree((void*)devc) );
CNRT_CHECK( cnrtDestroyQueue(queue) );
return 0;
}
int bang_dev_init() {
cnrtDev_t dev;
CNRT_CHECK(cnrtInit(0));
CNRT_CHECK(cnrtGetDeviceHandle(&dev, 0));
CNRT_CHECK(cnrtSetCurrentDevice(dev));
return 0;
}
// Counterpart of bang_dev_init(): calls cnrtDestroy() and always returns 0.
int bang_dev_deinit() {
  cnrtDestroy();

  return 0;
}
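#endif  // #if 1
// [Web-page residue, not part of the program] "Popular posts"
// [Web-page residue, not part of the program] "Featured posts"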