×
分享到微信

打开微信,使用扫一扫进入页面后,点击右上角菜单,

点击“发送给朋友”或“分享到朋友圈”完成分享

NRAM中定义数组过大时,运行出错(不是编译链接出错) 官方回复 hxf02232021-06-29 15:05:54 回复 6 查看 使用求助 经验交流
NRAM中定义数组过大时,运行出错(不是编译链接出错)
分享到:

#include #include #include #include #include #if 1


#define BLOCK_SIZE      (1024*16)
#define CORE_DIM        4
#define STAGE           2


__mlu_func__ void kernel_row_mul_addb(bang::pipeline& pipe, 
    float* __restrict__ nC, float* __restrict__ nA, 
    float* __restrict__ nB, int count) {
    /*for (int i = 0; i < count; i++) {
        nC[i] += nA[i] * nB[i];
    }*/
    __nram__ float temp_buff[BLOCK_SIZE];
    bang::vector::mul(pipe, temp_buff, nA, nB, count);
    bang::vector::add(pipe, nC, nC, temp_buff, count);
}

__mlu_func__ void kernel_row_mul_add(float* __restrict__ gA, 
                                        float* __restrict__ gB, 
                                        float* __restrict__ gC, 
                                        int colNum, int rowNum, int rowA1, int rowA2) {
    //__mlu_shared__ float shr_buffb[STAGE * CORE_DIM * BLOCK_SIZE];
    //__mlu_shared__ float shr_buffc[CORE_DIM * BLOCK_SIZE];
    __nram__ float nbuffa[BLOCK_SIZE];
    __nram__ float nbuffc[BLOCK_SIZE];
    __nram__ float nbuffb[BLOCK_SIZE*STAGE];
    bang::pipeline pipe0(0);
    bang::pipeline pipe1(1);
    bang::pipeline pipe4(2);
    bang::pipeline pipe5(3);

    for ( int row = rowA1; row < rowA2; row++ ) {
        size_t cp_sizeb = colNum * sizeof(float);
        __bang_write_zero(nbuffc, BLOCK_SIZE);

        bang::memcpy_async(pipe4, nbuffa, 
            gA+row*colNum, cp_sizeb, GDRAM2NRAM);
        pipe4.wait_copy_dram_to_nram();

        for (int rowb = 0; rowb = rowNum) return;
    if (end_row > rowNum) end_row = rowNum;

    kernel_row_mul_add(gA, gB, gC, colNum, rowNum, start_row, end_row);
}


int matrixMul(float* ptrC, float* ptrA, float* ptrB, int rowNum, int colNum) {
    float* deva, *devb, *devc;
    constexpr int loops = 1;

    cnrtQueue_t queue;
    CNRT_CHECK( cnrtCreateQueue(&queue) );

    const size_t mem_size = rowNum * colNum * sizeof(float);
    CNRT_CHECK( cnrtMalloc((void**)(&deva), mem_size) );
    CNRT_CHECK( cnrtMalloc((void**)(&devb), mem_size) );
    CNRT_CHECK( cnrtMalloc((void**)(&devc), mem_size) );

    CNRT_CHECK( cnrtMemcpy(deva, ptrA, mem_size, CNRT_MEM_TRANS_DIR_HOST2DEV) );
    CNRT_CHECK( cnrtMemcpy(devb, ptrB, mem_size, CNRT_MEM_TRANS_DIR_HOST2DEV) );
    CNRT_CHECK( cnrtMemcpy(devc, ptrC, mem_size, CNRT_MEM_TRANS_DIR_HOST2DEV) );
    /*CNRT_CHECK( cnrtMemcpyAsync(deva, ptrA, mem_size, queue, CNRT_MEM_TRANS_DIR_HOST2DEV) );
    CNRT_CHECK( cnrtMemcpyAsync(devb, ptrB, mem_size, queue, CNRT_MEM_TRANS_DIR_HOST2DEV) );
    CNRT_CHECK( cnrtMemcpyAsync(devc, ptrC, mem_size, queue, CNRT_MEM_TRANS_DIR_HOST2DEV) );*/

    // Create cnrt event handles.
    cnrtNotifier_t start, stop;
    CNRT_CHECK( cnrtCreateNotifier(&start) );
    CNRT_CHECK( cnrtCreateNotifier(&stop) );

    // Record timer start
    CNRT_CHECK( cnrtPlaceNotifier(start, queue) );

    cnrtDim3_t dim = {16, 1, 1};
    cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_UNION4;
    for (int i = 0; i < loops; i++) {
        kernelMM<<>>(deva, devb, devc, rowNum, colNum);
    }

    CNRT_CHECK( cnrtSyncQueue(queue) );

    // Record timer stop
    CNRT_CHECK( cnrtPlaceNotifier(stop, queue) );

    
    CNRT_CHECK( cnrtMemcpy(ptrC, devc, mem_size, CNRT_MEM_TRANS_DIR_DEV2HOST) );

    float usecTotal = 0.0f;
    CNRT_CHECK(cnrtNotifierDuration(start, stop, &usecTotal));
    printf("time spent executing by the MLU: %.3f ms\n", usecTotal / (1000*loops) );

    CNRT_CHECK( cnrtFree((void*)deva) );
    CNRT_CHECK( cnrtFree((void*)devb) );
    CNRT_CHECK( cnrtFree((void*)devc) );
    CNRT_CHECK( cnrtDestroyQueue(queue) );

    return 0;
}




int bang_dev_init() {
    cnrtDev_t dev;
    CNRT_CHECK(cnrtInit(0));
    CNRT_CHECK(cnrtGetDeviceHandle(&dev, 0));
    CNRT_CHECK(cnrtSetCurrentDevice(dev));
    return 0;
}


int bang_dev_deinit() {
    cnrtDestroy();
    return 0;
}


#endif


版权所有 © 2021 寒武纪 Cambricon.com 备案/许可证号:京ICP备17003415号-1
关闭