#include <hip/hip_runtime.h>
#include <neko/device/device_config.h>
#include <neko/device/hip/check.h>
#include <neko/math/bcknd/device/device_mpi_op.h>
#include "mma_kernel.h"
53real * mma_bufred = NULL;
54real * mma_bufred_d = NULL;
56void delta_1dbeam_hip(
void* Delta, real* L_total, real* Le,
57 int* offset,
int* n) {
58 const dim3 nthrds(1024, 1, 1);
59 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
60 hipLaunchKernelGGL(delta_1dbeam_kernel<real>,
61 nblcks, nthrds, 0, (hipStream_t)glb_cmd_queue,
62 (real*)Delta, *L_total, *Le, *offset, *n);
63 HIP_CHECK(hipGetLastError());
66void hip_Hess(
void* Hess,
void* hijx,
void* Ljjxinv,
int *n,
int *m) {
67 const dim3 nthrds(1024, 1, 1);
68 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
69 const int nb = ((*n) + 1024 - 1) / 1024;
70 const hipStream_t stream = (hipStream_t) glb_cmd_queue;
71 hipStreamSynchronize(stream);
75 if (mma_bufred != NULL) {
76 HIP_CHECK(hipHostFree(mma_bufred));
77 HIP_CHECK(hipFree(mma_bufred_d));
79 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
80 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
83 for (
int i = 0; i < (*m); i++) {
84 for (
int j = 0; j < (*m); j++) {
85 hipLaunchKernelGGL(mmasumHess_kernel<real>, nblcks, nthrds, 0, stream,
86 (real*)hijx, (real*)Ljjxinv, mma_bufred_d, (*n), (*m), i, j);
87 HIP_CHECK(hipGetLastError());
89 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
91 HIP_CHECK(hipGetLastError());
93 hipLaunchKernelGGL(mma_copy_kernel, dim3(1), dim3(1), 0, stream,
94 (real*)Hess, mma_bufred_d, 1, i + j * (*m));
95 HIP_CHECK(hipGetLastError());
97 hipStreamSynchronize(stream);
102void mma_Ljjxinv_hip(
void* Ljjxinv,
void* pjlambda,
void* qjlambda,
void* x,
103 void* low,
void* upp,
void* alpha,
void* beta,
int* n) {
104 const dim3 nthrds(1024, 1, 1);
105 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
106 hipLaunchKernelGGL(mma_Ljjxinv_kernel<real>, nblcks, nthrds, 0,
107 (hipStream_t)glb_cmd_queue, (real*)Ljjxinv, (real*)pjlambda, (real*)qjlambda,
108 (real*)x, (real*)low, (real*)upp, (real*)alpha, (real*)beta, *n);
109 HIP_CHECK(hipGetLastError());
112void mma_dipsolvesub1_hip(
void* x,
void* pjlambda,
void* qjlambda,
void* low,
113 void* upp,
void* alpha,
void* beta,
int* n) {
114 const dim3 nthrds(1024, 1, 1);
115 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
116 hipLaunchKernelGGL(mma_dipsolvesub1_kernel<real>, nblcks, nthrds, 0,
117 (hipStream_t)glb_cmd_queue, (real*)x, (real*)pjlambda, (real*)qjlambda,
118 (real*)low, (real*)upp, (real*)alpha, (real*)beta, *n);
119 HIP_CHECK(hipGetLastError());
122void mattrans_v_mul_hip(
void* output,
void* pij,
void* lambda,
int* m,
int* n) {
123 const dim3 nthrds(1024, 1, 1);
124 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
125 hipLaunchKernelGGL(mattrans_v_mul_kernel<real>, nblcks, nthrds, 0,
126 (hipStream_t)glb_cmd_queue, (real*)output, (real*)pij, (real*)lambda, *m, *n);
127 HIP_CHECK(hipGetLastError());
130void mma_gensub4_hip(
const void* x,
const void* low,
const void* upp,
131 const void* pij,
const void* qij,
132 const int* n,
const int* m,
void* bi) {
137 const dim3 nthrds(1024, 1, 1);
138 const dim3 nblcks((N + 1023) / 1024, 1, 1);
139 const int nb = (N + 1023) / 1024;
140 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
142 if (nb > mma_red_s) {
145 if (mma_bufred !=
nullptr) {
146 HIP_CHECK(hipFreeHost(mma_bufred));
147 HIP_CHECK(hipFree(mma_bufred_d));
150 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
151 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
155 real* bi_d =
static_cast<real*
>(bi);
156 HIP_CHECK(hipMalloc(&temp, M * N *
sizeof(real)));
158 hipLaunchKernelGGL(mma_sub4_kernel<real>, nblcks, nthrds, 0, stream,
159 static_cast<const real*
>(x),
160 static_cast<const real*
>(low),
161 static_cast<const real*
>(upp),
162 static_cast<const real*
>(pij),
163 static_cast<const real*
>(qij),
166 for (
int i = 0; i < M; ++i) {
167 hipLaunchKernelGGL(mmasum_kernel<real>, nblcks, nthrds, 0, stream,
168 temp, mma_bufred_d, N, M, i);
169 HIP_CHECK(hipGetLastError());
171 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
173 HIP_CHECK(hipGetLastError());
175 HIP_CHECK(hipMemcpyAsync(
176 bi_d + i, mma_bufred_d,
sizeof(real),
177 hipMemcpyDeviceToDevice, stream));
179 HIP_CHECK(hipStreamSynchronize(stream));
182 HIP_CHECK(hipFree(temp));
186void mma_gensub3_hip(
void* x,
void* df0dx,
void* dfdx,
void* low,
187 void* upp,
void* xmin,
void* xmax,
void* alpha,
188 void* beta,
void* p0j,
void* q0j,
void* pij,
189 void* qij,
int* n,
int* m) {
190 const dim3 nthrds(1024, 1, 1);
191 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
193 hipLaunchKernelGGL(mma_sub3_kernel<real>, nblcks, nthrds, 0,
194 (hipStream_t)glb_cmd_queue,
195 (real*)x, (real*)df0dx, (real*)dfdx, (real*)low,
196 (real*)upp, (real*)xmin, (real*)xmax, (real*)alpha,
197 (real*)beta, (real*)p0j, (real*)q0j, (real*)pij,
200 HIP_CHECK(hipGetLastError());
203void mma_gensub2_hip(
void* low,
void* upp,
void* x,
void* xold1,
204 void* xold2,
void* xdiff, real* asydecr,
205 real* asyincr,
int* n) {
206 const dim3 nthrds(1024, 1, 1);
207 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
209 hipLaunchKernelGGL(mma_sub2_kernel<real>, nblcks, nthrds, 0,
210 (hipStream_t)glb_cmd_queue,
211 (real*)low, (real*)upp, (real*)x, (real*)xold1,
212 (real*)xold2, (real*)xdiff, *asydecr, *asyincr, *n);
214 HIP_CHECK(hipGetLastError());
217void mma_gensub1_hip(
void* low,
void* upp,
void* x,
void* xmin,
void* xmax,
218 real* asyinit,
int* n) {
219 const dim3 nthrds(1024, 1, 1);
220 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
222 hipLaunchKernelGGL(mma_sub1_kernel<real>, nblcks, nthrds, 0,
223 (hipStream_t)glb_cmd_queue,
224 (real*)low, (real*)upp, (real*)x, (real*)xmin, (real*)xmax,
227 HIP_CHECK(hipGetLastError());
230void hip_mma_max(
void* xsi,
void* x,
void* alpha,
int* n) {
231 const dim3 nthrds(1024, 1, 1);
232 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
234 hipLaunchKernelGGL(mma_max2_kernel<real>, nblcks, nthrds, 0,
235 (hipStream_t)glb_cmd_queue,
236 (real*)xsi, (real*)x, (real*)alpha, *n);
238 HIP_CHECK(hipGetLastError());
241void hip_relambda(
void* relambda,
void* x,
void* xupp,
void* xlow,
242 void* pij,
void* qij,
int* n,
int* m) {
243 const dim3 nthrds(1024, 1, 1);
244 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
245 const int nb = nblcks.x;
246 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
248 if (nb > mma_red_s) {
250 if (mma_bufred != NULL) {
251 HIP_CHECK(hipHostFree(mma_bufred));
252 HIP_CHECK(hipFree(mma_bufred_d));
254 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
255 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
259 hipMalloc(&temp, (*n) * (*m) *
sizeof(real));
261 hipLaunchKernelGGL(relambda_kernel<real>, nblcks, nthrds, 0, stream,
262 temp, (real*)x, (real*)xupp, (real*)xlow,
263 (real*)pij, (real*)qij, *n, *m);
265 for (
int i = 0; i < (*m); i++) {
266 hipLaunchKernelGGL(mmasum_kernel<real>, nblcks, nthrds, 0, stream,
267 temp, mma_bufred_d, (*n), (*m), i);
268 HIP_CHECK(hipGetLastError());
270 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0,
271 stream, mma_bufred_d, nb);
272 HIP_CHECK(hipGetLastError());
274 hipLaunchKernelGGL(mma_copy_kernel, dim3(1), dim3(1), 0, stream,
275 (real*)relambda, mma_bufred_d, 1, i);
276 HIP_CHECK(hipGetLastError());
278 hipStreamSynchronize(stream);
284void hip_sub2cons2(
void* a,
void* b,
void* c,
void* d, real* e,
int* n) {
285 const dim3 nthrds(1024, 1, 1);
286 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
288 hipLaunchKernelGGL(sub2cons2_kernel<real>, nblcks, nthrds, 0,
289 (hipStream_t)glb_cmd_queue,
290 (real*)a, (real*)b, (real*)c, (real*)d, *e, *n);
292 HIP_CHECK(hipGetLastError());
295real hip_maxval(
void* a,
int* n) {
296 const dim3 nthrds(1024, 1, 1);
297 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
298 const int nb = nblcks.x;
299 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
301 if (nb > mma_red_s) {
303 if (mma_bufred != NULL) {
304 HIP_CHECK(hipHostFree(mma_bufred));
305 HIP_CHECK(hipFree(mma_bufred_d));
307 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
308 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
311 hipLaunchKernelGGL(maxval_kernel<real>, nblcks, nthrds, 0, stream,
312 (real*)a, mma_bufred_d, (*n));
313 HIP_CHECK(hipGetLastError());
315 hipLaunchKernelGGL(max_reduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
317 HIP_CHECK(hipGetLastError());
319 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d,
sizeof(real),
320 hipMemcpyDeviceToHost, stream));
321 hipStreamSynchronize(stream);
323 return mma_bufred[0];
327void hip_delx(
void* delx,
void* x,
void* xlow,
void* xupp,
void* pij,
328 void* qij,
void* p0j,
void* q0j,
void* alpha,
void* beta,
void* lambda,
329 real* epsi,
int* n,
int* m) {
330 const dim3 nthrds(1024, 1, 1);
331 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
333 hipLaunchKernelGGL(delx_kernel<real>, nblcks, nthrds, 0,
334 (hipStream_t)glb_cmd_queue,
335 (real*)delx, (real*)x, (real*)xlow, (real*)xupp, (real*)pij,
336 (real*)qij, (real*)p0j, (real*)q0j, (real*)alpha, (real*)beta,
337 (real*)lambda, *epsi, *n, *m);
338 HIP_CHECK(hipGetLastError());
341void hip_GG(
void* GG,
void* x,
void* xlow,
void* xupp,
342 void* pij,
void* qij,
int* n,
int* m) {
343 const dim3 nthrds(1024, 1, 1);
344 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
346 hipLaunchKernelGGL(GG_kernel<real>, nblcks, nthrds, 0,
347 (hipStream_t)glb_cmd_queue,
348 (real*)GG, (real*)x, (real*)xlow, (real*)xupp, (real*)pij,
350 HIP_CHECK(hipGetLastError());
353void hip_diagx(
void* diagx,
void* x,
void* xsi,
void* xlow,
void* xupp,
354 void* p0j,
void* q0j,
void* pij,
void* qij,
void* alpha,
void* beta,
355 void* eta,
void* lambda,
int *n,
int *m) {
356 const dim3 nthrds(1024, 1, 1);
357 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
359 hipLaunchKernelGGL(diagx_kernel<real>, nblcks, nthrds, 0,
360 (hipStream_t)glb_cmd_queue,
361 (real*)diagx, (real*)x, (real*)xsi, (real*)xlow, (real*)xupp,
362 (real*)p0j, (real*)q0j, (real*)pij, (real*)qij, (real*)alpha,
363 (real*)beta, (real*)eta, (real*)lambda, *n, *m);
364 HIP_CHECK(hipGetLastError());
367void hip_bb(
void* bb,
void* GG,
void* delx,
void* diagx,
int *n,
int *m) {
368 const dim3 nthrds(1024, 1, 1);
369 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
370 const int nb = ((*n) + 1024 - 1)/ 1024;
371 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
373 hipStreamSynchronize(stream);
375 if (nb > mma_red_s) {
377 if (mma_bufred != NULL) {
378 HIP_CHECK(hipHostFree(mma_bufred));
379 HIP_CHECK(hipFree(mma_bufred_d));
381 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
382 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
385 for (
int i = 0; i < (*m); i++) {
386 hipLaunchKernelGGL(mmasumbb_kernel<real>, nblcks, nthrds, 0, stream,
387 (real*)GG, (real*)delx, (real*)diagx, mma_bufred_d, *n, *m, i);
388 HIP_CHECK(hipGetLastError());
390 hipLaunchKernelGGL(mmareduce_kernel<real>, 1, 1024, 0, stream,
392 HIP_CHECK(hipGetLastError());
394 hipLaunchKernelGGL(mma_copy_kernel, 1, 1, 0, stream, (real*)bb,
396 HIP_CHECK(hipGetLastError());
398 hipStreamSynchronize(stream);
402void hip_AA(
void* AA,
void* GG,
void* diagx,
int *n,
int *m) {
403 const dim3 nthrds(1024, 1, 1);
404 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
405 const int nb = ((*n) + 1024 - 1)/ 1024;
406 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
408 hipStreamSynchronize(stream);
410 if (nb > mma_red_s) {
412 if (mma_bufred != NULL) {
413 HIP_CHECK(hipHostFree(mma_bufred));
414 HIP_CHECK(hipFree(mma_bufred_d));
416 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
417 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
420 for (
int i = 0; i < (*m); i++) {
421 for (
int j = 0; j < (*m); j++) {
422 hipLaunchKernelGGL(mmasumAA_kernel<real>, nblcks, nthrds, 0, stream,
423 (real*)GG, (real*)diagx, mma_bufred_d, *n, *m, i, j);
424 HIP_CHECK(hipGetLastError());
426 hipLaunchKernelGGL(mmareduce_kernel<real>, 1, 1024, 0, stream,
428 HIP_CHECK(hipGetLastError());
430 hipLaunchKernelGGL(mma_copy_kernel, 1, 1, 0, stream,
431 (real*)AA, mma_bufred_d, 1, i + j * (*m + 1));
432 HIP_CHECK(hipGetLastError());
434 hipStreamSynchronize(stream);
439void hip_dx(
void* dx,
void* delx,
void* diagx,
void* GG,
void* dlambda,
441 const dim3 nthrds(1024, 1, 1);
442 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
444 hipLaunchKernelGGL(dx_kernel<real>, nblcks, nthrds, 0,
445 (hipStream_t)glb_cmd_queue,
446 (real*)dx, (real*)delx, (real*)diagx, (real*)GG, (real*)dlambda, *n, *m);
447 HIP_CHECK(hipGetLastError());
450void hip_dxsi(
void* dxsi,
void* xsi,
void* dx,
void* x,
451 void* alpha, real* epsi,
int* n) {
452 const dim3 nthrds(1024, 1, 1);
453 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
455 hipLaunchKernelGGL(dxsi_kernel<real>, nblcks, nthrds, 0,
456 (hipStream_t)glb_cmd_queue,
457 (real*)dxsi, (real*)xsi, (real*)dx, (real*)x, (real*)alpha, *epsi, *n);
458 HIP_CHECK(hipGetLastError());
461void hip_deta(
void* deta,
void* eta,
void* dx,
void* x,
462 void* beta, real* epsi,
int* n) {
463 const dim3 nthrds(1024, 1, 1);
464 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
466 hipLaunchKernelGGL(deta_kernel<real>, nblcks, nthrds, 0,
467 (hipStream_t)glb_cmd_queue,
468 (real*)deta, (real*)eta, (real*)dx, (real*)x, (real*)beta, *epsi, *n);
469 HIP_CHECK(hipGetLastError());
472void hip_rex(
void* rex,
void* x,
void* xlow,
void* xupp,
void* pij,
473 void* p0j,
void* qij,
void* q0j,
void* lambda,
void* xsi,
void* eta,
475 const dim3 nthrds(1024, 1, 1);
476 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
478 hipLaunchKernelGGL(RexCalculation_kernel<real>, nblcks, nthrds, 0,
479 (hipStream_t)glb_cmd_queue,
480 (real*)rex, (real*)x, (real*)xlow, (real*)xupp, (real*)pij, (real*)p0j,
481 (real*)qij, (real*)q0j, (real*)lambda, (real*)xsi, (real*)eta, *n, *m);
482 HIP_CHECK(hipGetLastError());
485void hip_rey(
void* rey,
void* c,
void* d,
void* y,
void* lambda,
void* mu,
int* n) {
486 const dim3 nthrds(1024, 1, 1);
487 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
489 hipLaunchKernelGGL(rey_calculation_kernel<real>, nblcks, nthrds, 0,
490 (hipStream_t)glb_cmd_queue,
491 (real*)rey, (real*)c, (real*)d, (real*)y, (real*)lambda, (real*)mu, *n);
492 HIP_CHECK(hipGetLastError());
497void hip_sub2cons(
void *a,
void *b,
void *c, real *d,
int *n) {
498 const dim3 nthrds(1024, 1, 1);
499 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
500 hipLaunchKernelGGL(sub2cons_kernel<real>, nblcks, nthrds, 0,
501 (hipStream_t)glb_cmd_queue,
502 (real *)a, (real *)b, (real *)c, *d, *n);
503 HIP_CHECK(hipGetLastError());
508real hip_norm(
void* a,
int* n) {
509 const dim3 nthrds(1024, 1, 1);
510 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
511 const int nb = ((*n) + 1024 - 1) / 1024;
512 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
514 if (nb > mma_red_s) {
516 if (mma_bufred != NULL) {
517 HIP_CHECK(hipFreeHost(mma_bufred));
518 HIP_CHECK(hipFree(mma_bufred_d));
520 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
521 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
524 hipLaunchKernelGGL(norm_kernel<real>, nblcks, nthrds, 0, stream,
525 (real*)a, mma_bufred_d, (*n));
526 HIP_CHECK(hipGetLastError());
528 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
530 HIP_CHECK(hipGetLastError());
532 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d,
sizeof(real),
533 hipMemcpyDeviceToHost, stream));
535 hipStreamSynchronize(stream);
537 return mma_bufred[0];
541void hip_dely(
void* dely,
void* c,
void* d,
void* y,
void* lambda,
542 real* epsi,
int* n) {
543 const dim3 nthrds(1024, 1, 1);
544 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
545 hipLaunchKernelGGL(dely_kernel<real>, nblcks, nthrds, 0,
546 (hipStream_t)glb_cmd_queue,
547 (real*)dely, (real*)c, (real*)d, (real*)y, (real*)lambda, *epsi, *n);
548 HIP_CHECK(hipGetLastError());
552real hip_maxval2(
void* a,
void* b, real* cons,
int* n) {
553 const dim3 nthrds(1024, 1, 1);
554 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
555 const int nb = ((*n) + 1024 - 1) / 1024;
556 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
558 if (nb > mma_red_s) {
560 if (mma_bufred != NULL) {
561 HIP_CHECK(hipFreeHost(mma_bufred));
562 HIP_CHECK(hipFree(mma_bufred_d));
564 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
565 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
568 hipLaunchKernelGGL(maxval2_kernel<real>, nblcks, nthrds, 0, stream,
569 (real*)a, (real*)b, mma_bufred_d, *cons, *n);
570 HIP_CHECK(hipGetLastError());
572 hipLaunchKernelGGL(max_reduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
574 HIP_CHECK(hipGetLastError());
576 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d,
sizeof(real),
577 hipMemcpyDeviceToHost, stream));
579 hipStreamSynchronize(stream);
581 return mma_bufred[0];
585real hip_maxval3(
void* a,
void* b,
void* c, real* cons,
int* n) {
586 const dim3 nthrds(1024, 1, 1);
587 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
588 const int nb = ((*n) + 1024 - 1) / 1024;
589 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
591 if (nb > mma_red_s) {
593 if (mma_bufred != NULL) {
594 HIP_CHECK(hipFreeHost(mma_bufred));
595 HIP_CHECK(hipFree(mma_bufred_d));
597 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
598 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
601 hipLaunchKernelGGL(maxval3_kernel<real>, nblcks, nthrds, 0, stream,
602 (real*)a, (real*)b, (real*)c, mma_bufred_d, *cons, *n);
603 hipLaunchKernelGGL(max_reduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
605 HIP_CHECK(hipGetLastError());
607 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d,
sizeof(real),
608 hipMemcpyDeviceToHost, stream));
610 hipStreamSynchronize(stream);
612 return mma_bufred[0];
616void hip_kkt_rex(
void* rex,
void* df0dx,
void* dfdx,
void* xsi,
617 void* eta,
void* lambda,
int* n,
int* m) {
618 const dim3 nthrds(1024, 1, 1);
619 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
620 hipLaunchKernelGGL(kkt_rex_kernel<real>, nblcks, nthrds, 0,
621 (hipStream_t)glb_cmd_queue,
622 (real*)rex, (real*)df0dx, (real*)dfdx, (real*)xsi,
623 (real*)eta, (real*)lambda, *n, *m);
624 HIP_CHECK(hipGetLastError());
629void hip_maxcons(
void* a, real* b, real* c,
void* d,
int* n) {
630 const dim3 nthrds(1024, 1, 1);
631 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
632 hipLaunchKernelGGL(maxcons_kernel<real>, nblcks, nthrds, 0,
633 (hipStream_t)glb_cmd_queue,
634 (real*)a, *b, *c, (real*)d, *n);
635 HIP_CHECK(hipGetLastError());
639real hip_lcsc2(
void *a,
void*b,
int *n) {
640 const dim3 nthrds(1024, 1, 1);
641 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
642 const int nb = ((*n) + 1024 - 1) / 1024;
643 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
645 if (nb > mma_red_s) {
647 if (mma_bufred != NULL) {
648 HIP_CHECK(hipFreeHost(mma_bufred));
649 HIP_CHECK(hipFree(mma_bufred_d));
651 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
652 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
655 hipLaunchKernelGGL(glsc2_kernel<real>, nblcks, nthrds, 0, stream,
656 (real*)a, (real*)b, mma_bufred_d, (*n));
657 HIP_CHECK(hipGetLastError());
659 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
661 HIP_CHECK(hipGetLastError());
663 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d,
sizeof(real),
664 hipMemcpyDeviceToHost, stream));
666 hipStreamSynchronize(stream);
668 return mma_bufred[0];
/**
 * In-place device all-reduce (DEVICE_MPI_SUM) of the n reals at a.
 *
 * The call is compiled in only when HAVE_DEVICE_MPI is defined; otherwise
 * this function does nothing. The compute queue is drained first so the
 * reduction sees finished data.
 */
void hip_mpisum(void* a, int* n) {
#ifdef HAVE_DEVICE_MPI
  real* temp = (real*)a;
  // Fixed: the original synchronized an undeclared `stream`; the queue used
  // by every other wrapper in this file is glb_cmd_queue.
  HIP_CHECK(hipStreamSynchronize((hipStream_t)glb_cmd_queue));
  device_mpi_allreduce_inplace(temp, *n, sizeof(real), DEVICE_MPI_SUM);
#endif
}
681void hip_add2inv2(
void* a,
void* b, real* c,
int* n) {
682 const dim3 nthrds(1024, 1, 1);
683 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
684 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
686 hipLaunchKernelGGL(add2inv2_kernel<real>, nblcks, nthrds, 0, stream,
687 (real*)a, (real*)b, *c, *n);
688 HIP_CHECK(hipGetLastError());
691void hip_max2(
void* a, real* b,
void* c, real* d,
int* n) {
692 const dim3 nthrds(1024, 1, 1);
693 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
694 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
696 hipLaunchKernelGGL(max2_kernel<real>, nblcks, nthrds, 0, stream,
697 (real*)a, *b, (real*)c, *d, *n);
698 HIP_CHECK(hipGetLastError());
701void hip_updatebb(
void* bb,
void* dellambda,
void* dely,
void* d,
702 void* mu,
void* y, real* delz,
int* m) {
703 const dim3 nthrds(1024, 1, 1);
704 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
705 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
707 hipLaunchKernelGGL(updatebb_kernel<real>, nblcks, nthrds, 0, stream,
708 (real*)bb, (real*)dellambda, (real*)dely, (real*)d,
709 (real*)mu, (real*)y, *delz, *m);
710 HIP_CHECK(hipGetLastError());
713void hip_updateAA(
void* AA,
void* globaltmp_mm,
void* s,
void* lambda,
714 void* d,
void* mu,
void* y,
void* a,
715 real* zeta, real* z,
int* m) {
716 const dim3 nthrds(1024, 1, 1);
717 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
718 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
720 hipLaunchKernelGGL(updateAA_kernel<real>, nblcks, nthrds, 0, stream,
721 (real*)AA, (real*)globaltmp_mm, (real*)s,
722 (real*)lambda, (real*)d, (real*)mu,
723 (real*)y, (real*)a, *zeta, *z, *m);
724 HIP_CHECK(hipGetLastError());
// Launch dy_kernel<real> on glb_cmd_queue with a 1-D grid of ceil(n / 1024)
// blocks of 1024 threads. The device arrays dy, dely, dlambda, d, mu and y
// are forwarded as real*, along with *n. Launch errors are reported via
// HIP_CHECK(hipGetLastError()).
// NOTE(review): this definition continues past the visible chunk (its closing
// brace is not shown), and the lines carry fused numeric prefixes such as
// "727". Confirm against the real source before compiling.
727void hip_dy(
void* dy,
void* dely,
void* dlambda,
void* d,
728 void* mu,
void* y,
int* n) {
729 const dim3 nthrds(1024, 1, 1);
730 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
731 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
733 hipLaunchKernelGGL(dy_kernel<real>, nblcks, nthrds, 0, stream,
734 (real*)dy, (real*)dely, (real*)dlambda, (real*)d,
735 (real*)mu, (real*)y, *n);
736 HIP_CHECK(hipGetLastError());