40#include <hip/hip_runtime.h>
43#include <neko/device/device_config.h>
44#include <neko/device/hip/check.h>
45#include <neko/math/bcknd/device/device_mpi_op.h>
46#include <neko/math/bcknd/device/device_mpi_reduce.h>
49#include "math_ext_kernel.h"
56void hip_copy_mask(
void* a,
void* b,
int* size,
int* mask,
int* mask_size) {
57 const dim3 nthrds(1024, 1, 1);
58 const dim3 nblcks(((*mask_size) + 1024 - 1) / 1024, 1, 1);
60 if(*mask_size == 0)
return;
61 hipLaunchKernelGGL(copy_mask_kernel<real>, nblcks, nthrds, 0,
62 (hipStream_t)glb_cmd_queue,
63 (real*)a, (real*)b, *size, mask, *mask_size);
64 HIP_CHECK(hipGetLastError());
70void hip_cadd_mask(
void* a, real* c,
int* size,
int* mask,
int* mask_size) {
71 const dim3 nthrds(1024, 1, 1);
72 const dim3 nblcks(((*mask_size) + 1024 - 1) / 1024, 1, 1);
74 if(*mask_size == 0)
return;
75 hipLaunchKernelGGL(cadd_mask_kernel<real>, nblcks, nthrds, 0,
76 (hipStream_t)glb_cmd_queue, (real*)a, *c, *size, mask, *mask_size);
77 HIP_CHECK(hipGetLastError());
83void hip_invcol1_mask(
void* a,
int* size,
int* mask,
int* mask_size) {
84 const dim3 nthrds(1024, 1, 1);
85 const dim3 nblcks(((*mask_size) + 1024 - 1) / 1024, 1, 1);
87 if(*mask_size == 0)
return;
88 hipLaunchKernelGGL(invcol1_mask_kernel<real>, nblcks, nthrds, 0,
89 (hipStream_t)glb_cmd_queue, (real*)a, *size, mask, *mask_size);
90 HIP_CHECK(hipGetLastError());
96void hip_col2_mask(
void* a,
void* b,
int* size,
int* mask,
int* mask_size) {
97 const dim3 nthrds(1024, 1, 1);
98 const dim3 nblcks(((*mask_size) + 1024 - 1) / 1024, 1, 1);
100 if(*mask_size == 0)
return;
101 hipLaunchKernelGGL(col2_mask_kernel<real>, nblcks, nthrds, 0,
102 (hipStream_t)glb_cmd_queue,
103 (real*)a, (real*)b, *size, mask, *mask_size);
104 HIP_CHECK(hipGetLastError());
// NOTE(review): the line carrying this function's return type and name is not
// present in this listing. Given the kernel it launches (col3_mask_kernel) and
// its siblings' naming, it is presumably `void hip_col3_mask(`; confirm.
// The function's closing brace is also missing from this listing.
//
// Enqueues col3_mask_kernel<real> on glb_cmd_queue with a, b, c cast to real*,
// the dereferenced *size, the mask pointer and the dereferenced *mask_size.
// size and mask_size are dereferenced on the host here. a, b, c and mask are
// handed to the kernel (presumably device pointers; confirm against callers).
111 void* a,
void* b,
void* c,
int* size,
int* mask,
int* mask_size) {
// Launch configuration: 1024-thread blocks, with enough blocks to cover
// *mask_size threads (ceiling division).
113 const dim3 nthrds(1024, 1, 1);
114 const dim3 nblcks(((*mask_size) + 1024 - 1) / 1024, 1, 1);
// Skip the launch for an empty mask.
// NOTE(review): only == 0 is guarded. A negative *mask_size gives a zero-block
// or wrapped grid above, and *mask_size + 1023 can overflow near INT_MAX.
// Consider `<= 0` and (n - 1) / 1024 + 1.
116 if(*mask_size == 0)
return;
117 hipLaunchKernelGGL(col3_mask_kernel<real>, nblcks, nthrds, 0,
118 (hipStream_t)glb_cmd_queue,
119 (real*)a, (real*)b, (real*)c, *size, mask, *mask_size);
// Check for an error from the launch. No synchronization is done here.
120 HIP_CHECK(hipGetLastError());
// NOTE(review): the line carrying this function's return type and name is not
// present in this listing. Given the kernel it launches (sub3_mask_kernel) and
// its siblings' naming, it is presumably `void hip_sub3_mask(`; confirm.
// The function's closing brace is also missing from this listing.
//
// Enqueues sub3_mask_kernel<real> on glb_cmd_queue with a, b, c cast to real*,
// the dereferenced *size, the mask pointer and the dereferenced *mask_size.
// size and mask_size are dereferenced on the host here. a, b, c and mask are
// handed to the kernel (presumably device pointers; confirm against callers).
127 void* a,
void* b,
void* c,
int* size,
int* mask,
int* mask_size) {
// Launch configuration: 1024-thread blocks, with enough blocks to cover
// *mask_size threads (ceiling division).
129 const dim3 nthrds(1024, 1, 1);
130 const dim3 nblcks(((*mask_size) + 1024 - 1) / 1024, 1, 1);
// Skip the launch for an empty mask.
// NOTE(review): only == 0 is guarded. A negative *mask_size gives a zero-block
// or wrapped grid above, and *mask_size + 1023 can overflow near INT_MAX.
// Consider `<= 0` and (n - 1) / 1024 + 1.
132 if(*mask_size == 0)
return;
133 hipLaunchKernelGGL(sub3_mask_kernel<real>, nblcks, nthrds, 0,
134 (hipStream_t)glb_cmd_queue,
135 (real*)a, (real*)b, (real*)c, *size, mask, *mask_size);
// Check for an error from the launch. No synchronization is done here.
136 HIP_CHECK(hipGetLastError());