40#include <hip/hip_runtime.h>
43#include <neko/device/device_config.h>
44#include <neko/device/hip/check.h>
45#include <neko/math/bcknd/device/device_mpi_op.h>
48#include "mma_kernel.h"
53real * mma_bufred = NULL;
54real * mma_bufred_d = NULL;
/**
 * For every (i, j) with 0 <= i, j < *m, launches mmasumHess_kernel,
 * mmareduce_kernel and mma_copy_kernel in that order on glb_cmd_queue and
 * then waits for the stream. The copy launch receives Hess, mma_bufred_d,
 * the literal 1 and the offset i + j * (*m) (presumably a single-value
 * store into Hess -- confirm against mma_copy_kernel).
 *
 * NOTE(review): this block is damaged -- stale line numbers are fused to the
 * start of each line, closing braces are missing, and the argument line of
 * the mmareduce_kernel launch was dropped (the same launch in hip_relambda
 * passes mma_bufred_d, nb). Restore from the pristine file before editing.
 */
56void hip_Hess(
void* Hess,
void* hijx,
void* Ljjxinv,
int *n,
int *m) {
57 const dim3 nthrds(1024, 1, 1);
58 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
59 const int nb = ((*n) + 1024 - 1) / 1024;
60 const hipStream_t stream = (hipStream_t) glb_cmd_queue;
/* Return value of the synchronisation is not checked. */
61 hipStreamSynchronize(stream);
/* NOTE(review): the other reduction wrappers guard this reallocation with
 * "if (nb > mma_red_s)"; no such guard is visible here -- confirm whether
 * it was lost together with the other dropped lines. */
65 if (mma_bufred != NULL) {
66 HIP_CHECK(hipHostFree(mma_bufred));
67 HIP_CHECK(hipFree(mma_bufred_d));
69 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
70 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
73 for (
int i = 0; i < (*m); i++) {
74 for (
int j = 0; j < (*m); j++) {
75 hipLaunchKernelGGL(mmasumHess_kernel<real>, nblcks, nthrds, 0, stream,
76 (real*)hijx, (real*)Ljjxinv, mma_bufred_d, (*n), (*m), i, j);
77 HIP_CHECK(hipGetLastError());
/* NOTE(review): argument list of the next launch is cut off. */
79 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
81 HIP_CHECK(hipGetLastError());
83 hipLaunchKernelGGL(mma_copy_kernel, dim3(1), dim3(1), 0, stream,
84 (real*)Hess, mma_bufred_d, 1, i + j * (*m));
85 HIP_CHECK(hipGetLastError());
87 hipStreamSynchronize(stream);
92void mma_Ljjxinv_hip(
void* Ljjxinv,
void* pjlambda,
void* qjlambda,
void* x,
93 void* low,
void* upp,
void* alpha,
void* beta,
int* n) {
94 const dim3 nthrds(1024, 1, 1);
95 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
96 hipLaunchKernelGGL(mma_Ljjxinv_kernel<real>, nblcks, nthrds, 0,
97 (hipStream_t)glb_cmd_queue, (real*)Ljjxinv, (real*)pjlambda, (real*)qjlambda,
98 (real*)x, (real*)low, (real*)upp, (real*)alpha, (real*)beta, *n);
99 HIP_CHECK(hipGetLastError());
102void mma_dipsolvesub1_hip(
void* x,
void* pjlambda,
void* qjlambda,
void* low,
103 void* upp,
void* alpha,
void* beta,
int* n) {
104 const dim3 nthrds(1024, 1, 1);
105 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
106 hipLaunchKernelGGL(mma_dipsolvesub1_kernel<real>, nblcks, nthrds, 0,
107 (hipStream_t)glb_cmd_queue, (real*)x, (real*)pjlambda, (real*)qjlambda,
108 (real*)low, (real*)upp, (real*)alpha, (real*)beta, *n);
109 HIP_CHECK(hipGetLastError());
112void mattrans_v_mul_hip(
void* output,
void* pij,
void* lambda,
int* m,
int* n) {
113 const dim3 nthrds(1024, 1, 1);
114 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
115 hipLaunchKernelGGL(mattrans_v_mul_kernel<real>, nblcks, nthrds, 0,
116 (hipStream_t)glb_cmd_queue, (real*)output, (real*)pij, (real*)lambda, *m, *n);
117 HIP_CHECK(hipGetLastError());
/**
 * Launches mma_sub4_kernel once, then for each i in [0, *m) launches
 * mmasum_kernel on temp and mmareduce_kernel, copies one real from
 * mma_bufred_d to bi + i (device to device) and waits for the stream.
 *
 * NOTE(review): this block is damaged -- stale line numbers are fused to
 * each line, closing braces are gone, and several lines were dropped: the
 * size arguments of both buffer allocations, the declaration of temp, the
 * tail of the mma_sub4_kernel argument list and the mmareduce_kernel
 * arguments. Whether temp is released afterwards cannot be seen from here.
 */
120void mma_gensub4_hip(
void* x,
void* low,
void* upp,
void* pij,
void* qij,
121 int* n,
int* m,
void* bi) {
122 const dim3 nthrds(1024, 1, 1);
123 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
124 const int nb = ((*n) + 1024 - 1) / 1024;
125 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
/* Reallocate the shared reduction buffers when more blocks are needed. */
127 if (nb > mma_red_s) {
129 if (mma_bufred != NULL) {
/* NOTE(review): hip_Hess/hip_relambda release this buffer with
 * hipHostFree; hipFreeHost is the legacy spelling -- consider unifying. */
130 HIP_CHECK(hipFreeHost(mma_bufred));
131 HIP_CHECK(hipFree(mma_bufred_d));
133 HIP_CHECK(hipHostMalloc(&mma_bufred,
135 HIP_CHECK(hipMalloc(&mma_bufred_d,
140 real* bi_d = (real*)bi;
/* Allocation result is not wrapped in HIP_CHECK. */
141 hipMalloc(&temp, (*m) * (*n) *
sizeof(real));
143 hipLaunchKernelGGL(mma_sub4_kernel<real>, nblcks, nthrds, 0, stream,
144 (real*)x, (real*)low, (real*)upp, (real*)pij, (real*)qij,
147 for (
int i = 0; i < (*m); i++) {
148 hipLaunchKernelGGL(mmasum_kernel<real>, nblcks, nthrds, 0, stream,
149 temp, mma_bufred_d, (*n), (*m), i);
150 HIP_CHECK(hipGetLastError());
152 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
154 HIP_CHECK(hipGetLastError());
156 HIP_CHECK(hipMemcpyAsync(
157 bi_d + i, mma_bufred_d,
sizeof(real),
158 hipMemcpyDeviceToDevice, stream));
160 hipStreamSynchronize(stream);
/**
 * Launches mma_sub3_kernel<real> on glb_cmd_queue with the array arguments
 * cast to real*. Grid: ceil(*n / 1024) blocks of 1024 threads.
 *
 * NOTE(review): the kernel argument list is cut off after (real*)pij -- qij,
 * *n and *m are declared but never forwarded in the visible text -- and the
 * closing brace is missing. Restore from the pristine file.
 */
166void mma_gensub3_hip(
void* x,
void* df0dx,
void* dfdx,
void* low,
167 void* upp,
void* xmin,
void* xmax,
void* alpha,
168 void* beta,
void* p0j,
void* q0j,
void* pij,
169 void* qij,
int* n,
int* m) {
170 const dim3 nthrds(1024, 1, 1);
171 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
173 hipLaunchKernelGGL(mma_sub3_kernel<real>, nblcks, nthrds, 0,
174 (hipStream_t)glb_cmd_queue,
175 (real*)x, (real*)df0dx, (real*)dfdx, (real*)low,
176 (real*)upp, (real*)xmin, (real*)xmax, (real*)alpha,
177 (real*)beta, (real*)p0j, (real*)q0j, (real*)pij,
180 HIP_CHECK(hipGetLastError());
183void mma_gensub2_hip(
void* low,
void* upp,
void* x,
void* xold1,
184 void* xold2,
void* xdiff, real* asydecr,
185 real* asyincr,
int* n) {
186 const dim3 nthrds(1024, 1, 1);
187 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
189 hipLaunchKernelGGL(mma_sub2_kernel<real>, nblcks, nthrds, 0,
190 (hipStream_t)glb_cmd_queue,
191 (real*)low, (real*)upp, (real*)x, (real*)xold1,
192 (real*)xold2, (real*)xdiff, *asydecr, *asyincr, *n);
194 HIP_CHECK(hipGetLastError());
/**
 * Launches mma_sub1_kernel<real> on glb_cmd_queue with the array arguments
 * cast to real*. Grid: ceil(*n / 1024) blocks of 1024 threads.
 *
 * NOTE(review): the kernel argument list is cut off after (real*)xmax --
 * asyinit and n are declared but not forwarded in the visible text -- and
 * the closing brace is missing. Restore from the pristine file.
 */
197void mma_gensub1_hip(
void* low,
void* upp,
void* x,
void* xmin,
void* xmax,
198 real* asyinit,
int* n) {
199 const dim3 nthrds(1024, 1, 1);
200 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
202 hipLaunchKernelGGL(mma_sub1_kernel<real>, nblcks, nthrds, 0,
203 (hipStream_t)glb_cmd_queue,
204 (real*)low, (real*)upp, (real*)x, (real*)xmin, (real*)xmax,
207 HIP_CHECK(hipGetLastError());
210void hip_mma_max(
void* xsi,
void* x,
void* alpha,
int* n) {
211 const dim3 nthrds(1024, 1, 1);
212 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
214 hipLaunchKernelGGL(mma_max2_kernel<real>, nblcks, nthrds, 0,
215 (hipStream_t)glb_cmd_queue,
216 (real*)xsi, (real*)x, (real*)alpha, *n);
218 HIP_CHECK(hipGetLastError());
/**
 * Launches relambda_kernel into the scratch array temp, then for each i in
 * [0, *m) launches mmasum_kernel, mmareduce_kernel and mma_copy_kernel
 * (arguments: relambda, mma_bufred_d, the literal 1, offset i) and waits
 * for the stream.
 *
 * NOTE(review): this block is damaged -- stale line numbers are fused to
 * each line, closing braces are gone and the declaration of temp is not
 * visible. Whether temp is released afterwards cannot be seen from here.
 */
221void hip_relambda(
void* relambda,
void* x,
void* xupp,
void* xlow,
222 void* pij,
void* qij,
int* n,
int* m) {
223 const dim3 nthrds(1024, 1, 1);
224 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
225 const int nb = nblcks.x;
226 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
/* Reallocate the shared reduction buffers when more blocks are needed. */
228 if (nb > mma_red_s) {
230 if (mma_bufred != NULL) {
231 HIP_CHECK(hipHostFree(mma_bufred));
232 HIP_CHECK(hipFree(mma_bufred_d));
234 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
235 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
/* Scratch of (*n) * (*m) reals; the allocation result is not checked. */
239 hipMalloc(&temp, (*n) * (*m) *
sizeof(real));
/* No hipGetLastError check follows this launch in the visible text. */
241 hipLaunchKernelGGL(relambda_kernel<real>, nblcks, nthrds, 0, stream,
242 temp, (real*)x, (real*)xupp, (real*)xlow,
243 (real*)pij, (real*)qij, *n, *m);
245 for (
int i = 0; i < (*m); i++) {
246 hipLaunchKernelGGL(mmasum_kernel<real>, nblcks, nthrds, 0, stream,
247 temp, mma_bufred_d, (*n), (*m), i);
248 HIP_CHECK(hipGetLastError());
250 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0,
251 stream, mma_bufred_d, nb);
252 HIP_CHECK(hipGetLastError());
254 hipLaunchKernelGGL(mma_copy_kernel, dim3(1), dim3(1), 0, stream,
255 (real*)relambda, mma_bufred_d, 1, i);
256 HIP_CHECK(hipGetLastError());
258 hipStreamSynchronize(stream);
264void hip_sub2cons2(
void* a,
void* b,
void* c,
void* d, real* e,
int* n) {
265 const dim3 nthrds(1024, 1, 1);
266 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
268 hipLaunchKernelGGL(sub2cons2_kernel<real>, nblcks, nthrds, 0,
269 (hipStream_t)glb_cmd_queue,
270 (real*)a, (real*)b, (real*)c, (real*)d, *e, *n);
272 HIP_CHECK(hipGetLastError());
/**
 * Launches maxval_kernel over a and then max_reduce_kernel, copies one real
 * from mma_bufred_d to the host buffer mma_bufred, waits for the stream and
 * returns mma_bufred[0].
 *
 * NOTE(review): this block is damaged -- stale line numbers are fused to
 * each line, closing braces are gone and the argument line of the
 * max_reduce_kernel launch was dropped. Restore from the pristine file.
 */
275real hip_maxval(
void* a,
int* n) {
276 const dim3 nthrds(1024, 1, 1);
277 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
278 const int nb = nblcks.x;
279 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
/* Reallocate the shared reduction buffers when more blocks are needed. */
281 if (nb > mma_red_s) {
283 if (mma_bufred != NULL) {
284 HIP_CHECK(hipHostFree(mma_bufred));
285 HIP_CHECK(hipFree(mma_bufred_d));
287 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
288 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
291 hipLaunchKernelGGL(maxval_kernel<real>, nblcks, nthrds, 0, stream,
292 (real*)a, mma_bufred_d, (*n));
293 HIP_CHECK(hipGetLastError());
/* NOTE(review): argument list of the next launch is cut off. */
295 hipLaunchKernelGGL(max_reduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
297 HIP_CHECK(hipGetLastError());
299 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d,
sizeof(real),
300 hipMemcpyDeviceToHost, stream));
/* The asynchronous copy must complete before the host read below. */
301 hipStreamSynchronize(stream);
303 return mma_bufred[0];
307void hip_delx(
void* delx,
void* x,
void* xlow,
void* xupp,
void* pij,
308 void* qij,
void* p0j,
void* q0j,
void* alpha,
void* beta,
void* lambda,
309 real* epsi,
int* n,
int* m) {
310 const dim3 nthrds(1024, 1, 1);
311 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
313 hipLaunchKernelGGL(delx_kernel<real>, nblcks, nthrds, 0,
314 (hipStream_t)glb_cmd_queue,
315 (real*)delx, (real*)x, (real*)xlow, (real*)xupp, (real*)pij,
316 (real*)qij, (real*)p0j, (real*)q0j, (real*)alpha, (real*)beta,
317 (real*)lambda, *epsi, *n, *m);
318 HIP_CHECK(hipGetLastError());
/**
 * Launches GG_kernel<real> on glb_cmd_queue with the array arguments cast
 * to real*. Grid: ceil(*n / 1024) blocks of 1024 threads.
 *
 * NOTE(review): the kernel argument list is cut off after (real*)pij -- qij,
 * n and m are declared but not forwarded in the visible text -- and the
 * closing brace is missing. Restore from the pristine file.
 */
321void hip_GG(
void* GG,
void* x,
void* xlow,
void* xupp,
322 void* pij,
void* qij,
int* n,
int* m) {
323 const dim3 nthrds(1024, 1, 1);
324 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
326 hipLaunchKernelGGL(GG_kernel<real>, nblcks, nthrds, 0,
327 (hipStream_t)glb_cmd_queue,
328 (real*)GG, (real*)x, (real*)xlow, (real*)xupp, (real*)pij,
330 HIP_CHECK(hipGetLastError());
333void hip_diagx(
void* diagx,
void* x,
void* xsi,
void* xlow,
void* xupp,
334 void* p0j,
void* q0j,
void* pij,
void* qij,
void* alpha,
void* beta,
335 void* eta,
void* lambda,
int *n,
int *m) {
336 const dim3 nthrds(1024, 1, 1);
337 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
339 hipLaunchKernelGGL(diagx_kernel<real>, nblcks, nthrds, 0,
340 (hipStream_t)glb_cmd_queue,
341 (real*)diagx, (real*)x, (real*)xsi, (real*)xlow, (real*)xupp,
342 (real*)p0j, (real*)q0j, (real*)pij, (real*)qij, (real*)alpha,
343 (real*)beta, (real*)eta, (real*)lambda, *n, *m);
344 HIP_CHECK(hipGetLastError());
/**
 * For each i in [0, *m) launches mmasumbb_kernel, mmareduce_kernel and
 * mma_copy_kernel (destination bb) on glb_cmd_queue and waits for the
 * stream. The stream is also synchronised once before the loop.
 *
 * NOTE(review): this block is damaged -- stale line numbers are fused to
 * each line, closing braces are gone, and the argument lines of both the
 * mmareduce_kernel launch and the tail of the mma_copy_kernel launch were
 * dropped. Restore from the pristine file before editing.
 */
347void hip_bb(
void* bb,
void* GG,
void* delx,
void* diagx,
int *n,
int *m) {
348 const dim3 nthrds(1024, 1, 1);
349 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
350 const int nb = ((*n) + 1024 - 1)/ 1024;
351 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
353 hipStreamSynchronize(stream);
/* Reallocate the shared reduction buffers when more blocks are needed. */
355 if (nb > mma_red_s) {
357 if (mma_bufred != NULL) {
358 HIP_CHECK(hipHostFree(mma_bufred));
359 HIP_CHECK(hipFree(mma_bufred_d));
361 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
362 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
365 for (
int i = 0; i < (*m); i++) {
366 hipLaunchKernelGGL(mmasumbb_kernel<real>, nblcks, nthrds, 0, stream,
367 (real*)GG, (real*)delx, (real*)diagx, mma_bufred_d, *n, *m, i);
368 HIP_CHECK(hipGetLastError());
/* NOTE(review): argument list of the next launch is cut off. */
370 hipLaunchKernelGGL(mmareduce_kernel<real>, 1, 1024, 0, stream,
372 HIP_CHECK(hipGetLastError());
/* NOTE(review): argument list of the next launch is cut off after bb. */
374 hipLaunchKernelGGL(mma_copy_kernel, 1, 1, 0, stream, (real*)bb,
376 HIP_CHECK(hipGetLastError());
378 hipStreamSynchronize(stream);
/**
 * For every (i, j) with 0 <= i, j < *m, launches mmasumAA_kernel,
 * mmareduce_kernel and mma_copy_kernel on glb_cmd_queue and waits for the
 * stream. The copy launch receives AA, mma_bufred_d, the literal 1 and the
 * offset i + j * (*m + 1) (presumably AA has leading dimension m + 1 --
 * confirm against the caller).
 *
 * NOTE(review): this block is damaged -- stale line numbers are fused to
 * each line, closing braces are gone, and the argument line of the
 * mmareduce_kernel launch was dropped. Restore from the pristine file.
 */
382void hip_AA(
void* AA,
void* GG,
void* diagx,
int *n,
int *m) {
383 const dim3 nthrds(1024, 1, 1);
384 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
385 const int nb = ((*n) + 1024 - 1)/ 1024;
386 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
388 hipStreamSynchronize(stream);
/* Reallocate the shared reduction buffers when more blocks are needed. */
390 if (nb > mma_red_s) {
392 if (mma_bufred != NULL) {
393 HIP_CHECK(hipHostFree(mma_bufred));
394 HIP_CHECK(hipFree(mma_bufred_d));
396 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
397 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
400 for (
int i = 0; i < (*m); i++) {
401 for (
int j = 0; j < (*m); j++) {
402 hipLaunchKernelGGL(mmasumAA_kernel<real>, nblcks, nthrds, 0, stream,
403 (real*)GG, (real*)diagx, mma_bufred_d, *n, *m, i, j);
404 HIP_CHECK(hipGetLastError());
/* NOTE(review): argument list of the next launch is cut off. */
406 hipLaunchKernelGGL(mmareduce_kernel<real>, 1, 1024, 0, stream,
408 HIP_CHECK(hipGetLastError());
410 hipLaunchKernelGGL(mma_copy_kernel, 1, 1, 0, stream,
411 (real*)AA, mma_bufred_d, 1, i + j * (*m + 1));
412 HIP_CHECK(hipGetLastError());
414 hipStreamSynchronize(stream);
/**
 * Launches dx_kernel<real> on glb_cmd_queue with
 * (dx, delx, diagx, GG, dlambda, *n, *m).
 * Grid: ceil(*n / 1024) blocks of 1024 threads.
 *
 * NOTE(review): the end of the signature was dropped -- the body
 * dereferences n and m but neither parameter is declared in the visible
 * text -- and the closing brace is missing. Restore from the pristine file.
 */
419void hip_dx(
void* dx,
void* delx,
void* diagx,
void* GG,
void* dlambda,
421 const dim3 nthrds(1024, 1, 1);
422 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
424 hipLaunchKernelGGL(dx_kernel<real>, nblcks, nthrds, 0,
425 (hipStream_t)glb_cmd_queue,
426 (real*)dx, (real*)delx, (real*)diagx, (real*)GG, (real*)dlambda, *n, *m);
427 HIP_CHECK(hipGetLastError());
430void hip_dxsi(
void* dxsi,
void* xsi,
void* dx,
void* x,
431 void* alpha, real* epsi,
int* n) {
432 const dim3 nthrds(1024, 1, 1);
433 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
435 hipLaunchKernelGGL(dxsi_kernel<real>, nblcks, nthrds, 0,
436 (hipStream_t)glb_cmd_queue,
437 (real*)dxsi, (real*)xsi, (real*)dx, (real*)x, (real*)alpha, *epsi, *n);
438 HIP_CHECK(hipGetLastError());
441void hip_deta(
void* deta,
void* eta,
void* dx,
void* x,
442 void* beta, real* epsi,
int* n) {
443 const dim3 nthrds(1024, 1, 1);
444 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
446 hipLaunchKernelGGL(deta_kernel<real>, nblcks, nthrds, 0,
447 (hipStream_t)glb_cmd_queue,
448 (real*)deta, (real*)eta, (real*)dx, (real*)x, (real*)beta, *epsi, *n);
449 HIP_CHECK(hipGetLastError());
/**
 * Launches RexCalculation_kernel<real> on glb_cmd_queue with the eleven
 * array arguments cast to real*, followed by *n and *m.
 * Grid: ceil(*n / 1024) blocks of 1024 threads.
 *
 * NOTE(review): the end of the signature was dropped -- the body
 * dereferences n and m but neither parameter is declared in the visible
 * text -- and the closing brace is missing. Restore from the pristine file.
 */
452void hip_rex(
void* rex,
void* x,
void* xlow,
void* xupp,
void* pij,
453 void* p0j,
void* qij,
void* q0j,
void* lambda,
void* xsi,
void* eta,
455 const dim3 nthrds(1024, 1, 1);
456 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
458 hipLaunchKernelGGL(RexCalculation_kernel<real>, nblcks, nthrds, 0,
459 (hipStream_t)glb_cmd_queue,
460 (real*)rex, (real*)x, (real*)xlow, (real*)xupp, (real*)pij, (real*)p0j,
461 (real*)qij, (real*)q0j, (real*)lambda, (real*)xsi, (real*)eta, *n, *m);
462 HIP_CHECK(hipGetLastError());
465void hip_rey(
void* rey,
void* c,
void* d,
void* y,
void* lambda,
void* mu,
int* n) {
466 const dim3 nthrds(1024, 1, 1);
467 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
469 hipLaunchKernelGGL(rey_calculation_kernel<real>, nblcks, nthrds, 0,
470 (hipStream_t)glb_cmd_queue,
471 (real*)rey, (real*)c, (real*)d, (real*)y, (real*)lambda, (real*)mu, *n);
472 HIP_CHECK(hipGetLastError());
477void hip_sub2cons(
void *a,
void *b,
void *c, real *d,
int *n) {
478 const dim3 nthrds(1024, 1, 1);
479 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
480 hipLaunchKernelGGL(sub2cons_kernel<real>, nblcks, nthrds, 0,
481 (hipStream_t)glb_cmd_queue,
482 (real *)a, (real *)b, (real *)c, *d, *n);
483 HIP_CHECK(hipGetLastError());
488real hip_norm(
void* a,
int* n) {
489 const dim3 nthrds(1024, 1, 1);
490 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
491 const int nb = ((*n) + 1024 - 1) / 1024;
492 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
494 if (nb > mma_red_s) {
496 if (mma_bufred != NULL) {
497 HIP_CHECK(hipFreeHost(mma_bufred));
498 HIP_CHECK(hipFree(mma_bufred_d));
500 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
501 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
504 hipLaunchKernelGGL(norm_kernel<real>, nblcks, nthrds, 0, stream,
505 (real*)a, mma_bufred_d, (*n));
506 HIP_CHECK(hipGetLastError());
508 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
510 HIP_CHECK(hipGetLastError());
512 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d,
sizeof(real),
513 hipMemcpyDeviceToHost, stream));
515 hipStreamSynchronize(stream);
517 return mma_bufred[0];
521void hip_dely(
void* dely,
void* c,
void* d,
void* y,
void* lambda,
522 real* epsi,
int* n) {
523 const dim3 nthrds(1024, 1, 1);
524 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
525 hipLaunchKernelGGL(dely_kernel<real>, nblcks, nthrds, 0,
526 (hipStream_t)glb_cmd_queue,
527 (real*)dely, (real*)c, (real*)d, (real*)y, (real*)lambda, *epsi, *n);
528 HIP_CHECK(hipGetLastError());
/**
 * Launches maxval2_kernel over a and b with the scalar *cons, then
 * max_reduce_kernel, copies one real from mma_bufred_d to the host buffer
 * mma_bufred, waits for the stream and returns mma_bufred[0].
 *
 * NOTE(review): this block is damaged -- stale line numbers are fused to
 * each line, closing braces are gone and the argument line of the
 * max_reduce_kernel launch was dropped. Restore from the pristine file.
 */
532real hip_maxval2(
void* a,
void* b, real* cons,
int* n) {
533 const dim3 nthrds(1024, 1, 1);
534 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
535 const int nb = ((*n) + 1024 - 1) / 1024;
536 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
/* Reallocate the shared reduction buffers when more blocks are needed. */
538 if (nb > mma_red_s) {
540 if (mma_bufred != NULL) {
/* NOTE(review): hip_maxval releases this buffer with hipHostFree;
 * hipFreeHost is the legacy spelling -- consider unifying. */
541 HIP_CHECK(hipFreeHost(mma_bufred));
542 HIP_CHECK(hipFree(mma_bufred_d));
544 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
545 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
548 hipLaunchKernelGGL(maxval2_kernel<real>, nblcks, nthrds, 0, stream,
549 (real*)a, (real*)b, mma_bufred_d, *cons, *n);
550 HIP_CHECK(hipGetLastError());
/* NOTE(review): argument list of the next launch is cut off. */
552 hipLaunchKernelGGL(max_reduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
554 HIP_CHECK(hipGetLastError());
556 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d,
sizeof(real),
557 hipMemcpyDeviceToHost, stream));
/* The asynchronous copy must complete before the host read below. */
559 hipStreamSynchronize(stream);
561 return mma_bufred[0];
/**
 * Launches maxval3_kernel over a, b and c with the scalar *cons, then
 * max_reduce_kernel, copies one real from mma_bufred_d to the host buffer
 * mma_bufred, waits for the stream and returns mma_bufred[0].
 *
 * NOTE(review): this block is damaged -- stale line numbers are fused to
 * each line, closing braces are gone and the argument line of the
 * max_reduce_kernel launch was dropped. Restore from the pristine file.
 */
565real hip_maxval3(
void* a,
void* b,
void* c, real* cons,
int* n) {
566 const dim3 nthrds(1024, 1, 1);
567 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
568 const int nb = ((*n) + 1024 - 1) / 1024;
569 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
/* Reallocate the shared reduction buffers when more blocks are needed. */
571 if (nb > mma_red_s) {
573 if (mma_bufred != NULL) {
/* NOTE(review): hip_maxval releases this buffer with hipHostFree;
 * hipFreeHost is the legacy spelling -- consider unifying. */
574 HIP_CHECK(hipFreeHost(mma_bufred));
575 HIP_CHECK(hipFree(mma_bufred_d));
577 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
578 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
/* NOTE(review): unlike hip_maxval2, no HIP_CHECK(hipGetLastError())
 * follows this first launch (original lines 582 and 583 are adjacent). */
581 hipLaunchKernelGGL(maxval3_kernel<real>, nblcks, nthrds, 0, stream,
582 (real*)a, (real*)b, (real*)c, mma_bufred_d, *cons, *n);
/* NOTE(review): argument list of the next launch is cut off. */
583 hipLaunchKernelGGL(max_reduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
585 HIP_CHECK(hipGetLastError());
587 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d,
sizeof(real),
588 hipMemcpyDeviceToHost, stream));
/* The asynchronous copy must complete before the host read below. */
590 hipStreamSynchronize(stream);
592 return mma_bufred[0];
596void hip_kkt_rex(
void* rex,
void* df0dx,
void* dfdx,
void* xsi,
597 void* eta,
void* lambda,
int* n,
int* m) {
598 const dim3 nthrds(1024, 1, 1);
599 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
600 hipLaunchKernelGGL(kkt_rex_kernel<real>, nblcks, nthrds, 0,
601 (hipStream_t)glb_cmd_queue,
602 (real*)rex, (real*)df0dx, (real*)dfdx, (real*)xsi,
603 (real*)eta, (real*)lambda, *n, *m);
604 HIP_CHECK(hipGetLastError());
609void hip_maxcons(
void* a, real* b, real* c,
void* d,
int* n) {
610 const dim3 nthrds(1024, 1, 1);
611 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
612 hipLaunchKernelGGL(maxcons_kernel<real>, nblcks, nthrds, 0,
613 (hipStream_t)glb_cmd_queue,
614 (real*)a, *b, *c, (real*)d, *n);
615 HIP_CHECK(hipGetLastError());
619real hip_lcsc2(
void *a,
void*b,
int *n) {
620 const dim3 nthrds(1024, 1, 1);
621 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
622 const int nb = ((*n) + 1024 - 1) / 1024;
623 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
625 if (nb > mma_red_s) {
627 if (mma_bufred != NULL) {
628 HIP_CHECK(hipFreeHost(mma_bufred));
629 HIP_CHECK(hipFree(mma_bufred_d));
631 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
632 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
635 hipLaunchKernelGGL(glsc2_kernel<real>, nblcks, nthrds, 0, stream,
636 (real*)a, (real*)b, mma_bufred_d, (*n));
637 HIP_CHECK(hipGetLastError());
639 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
641 HIP_CHECK(hipGetLastError());
643 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d,
sizeof(real),
644 hipMemcpyDeviceToHost, stream));
646 hipStreamSynchronize(stream);
648 return mma_bufred[0];
/**
 * In-place sum-allreduce of the *n reals at a across ranks via
 * device_mpi_allreduce_inplace. Compiles to an empty function when
 * HAVE_DEVICE_MPI is not defined.
 *
 * @param a  buffer reduced in place, passed on as real*
 * @param n  pointer to the number of entries
 */
void hip_mpisum(void *a, int *n) {
#ifdef HAVE_DEVICE_MPI
  /* The stream handle was used without a declaration in this function;
   * take it from glb_cmd_queue as every other wrapper in this file does. */
  const hipStream_t stream = (hipStream_t)glb_cmd_queue;
  real* temp = (real*)a;

  /* Wait for work queued on the stream before handing the buffer to MPI. */
  HIP_CHECK(hipStreamSynchronize(stream));
  device_mpi_allreduce_inplace(temp, *n, sizeof(real), DEVICE_MPI_SUM);
#endif
}
661void hip_add2inv2(
void* a,
void* b, real* c,
int* n) {
662 const dim3 nthrds(1024, 1, 1);
663 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
664 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
666 hipLaunchKernelGGL(add2inv2_kernel<real>, nblcks, nthrds, 0, stream,
667 (real*)a, (real*)b, *c, *n);
668 HIP_CHECK(hipGetLastError());
671void hip_max2(
void* a, real* b,
void* c, real* d,
int* n) {
672 const dim3 nthrds(1024, 1, 1);
673 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
674 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
676 hipLaunchKernelGGL(max2_kernel<real>, nblcks, nthrds, 0, stream,
677 (real*)a, *b, (real*)c, *d, *n);
678 HIP_CHECK(hipGetLastError());
681void hip_updatebb(
void* bb,
void* dellambda,
void* dely,
void* d,
682 void* mu,
void* y, real* delz,
int* m) {
683 const dim3 nthrds(1024, 1, 1);
684 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
685 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
687 hipLaunchKernelGGL(updatebb_kernel<real>, nblcks, nthrds, 0, stream,
688 (real*)bb, (real*)dellambda, (real*)dely, (real*)d,
689 (real*)mu, (real*)y, *delz, *m);
690 HIP_CHECK(hipGetLastError());
693void hip_updateAA(
void* AA,
void* globaltmp_mm,
void* s,
void* lambda,
694 void* d,
void* mu,
void* y,
void* a,
695 real* zeta, real* z,
int* m) {
696 const dim3 nthrds(1024, 1, 1);
697 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
698 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
700 hipLaunchKernelGGL(updateAA_kernel<real>, nblcks, nthrds, 0, stream,
701 (real*)AA, (real*)globaltmp_mm, (real*)s,
702 (real*)lambda, (real*)d, (real*)mu,
703 (real*)y, (real*)a, *zeta, *z, *m);
704 HIP_CHECK(hipGetLastError());
/**
 * Launches dy_kernel<real> on glb_cmd_queue with
 * (dy, dely, dlambda, d, mu, y, *n), each array cast to real*.
 * Grid: ceil(*n / 1024) blocks of 1024 threads.
 *
 * NOTE(review): stale line numbers are fused to the start of several lines
 * and the closing brace lies beyond the visible text.
 */
707void hip_dy(
void* dy,
void* dely,
void* dlambda,
void* d,
708 void* mu,
void* y,
int* n) {
709 const dim3 nthrds(1024, 1, 1);
710 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
711 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
713 hipLaunchKernelGGL(dy_kernel<real>, nblcks, nthrds, 0, stream,
714 (real*)dy, (real*)dely, (real*)dlambda, (real*)d,
715 (real*)mu, (real*)y, *n);
/* Surface launch-configuration errors immediately. */
716 HIP_CHECK(hipGetLastError());