Neko-TOP
A portable framework for high-order spectral element flow topology optimization.
Loading...
Searching...
No Matches
mma.hip
Go to the documentation of this file.
1
37// System includes
38#include <stdio.h>
39#include <stdlib.h>
40
41// Device includes
42#include <hip/hip_runtime.h>
43#include <hipsolver/hipsolver.h>
44
45// Neko includes
46#include <neko/device/device_config.h>
47#include <neko/device/hip/check.h>
48#include <neko/math/bcknd/device/device_mpi_op.h>
49
50// Local includes
51#include "mma_kernel.h"
52
53extern "C" {
54
// Shared reduction state used by the reduction-style wrappers in this file.
// mma_red_s is the number of reals the two buffers currently hold;
// mma_bufred is allocated with hipHostMalloc and mma_bufred_d with hipMalloc.
// All three start empty and are only grown when a caller's block count
// exceeds mma_red_s. No final release is visible in this excerpt.
55int mma_red_s = 0;
56real * mma_bufred = NULL;
57real * mma_bufred_d = NULL;
58
// Host wrapper that sets up a launch over the m*m entries of Hess:
// 1024 threads per block and ceil(m*m / 1024) blocks, on the stream obtained
// by casting glb_cmd_queue. Hess and a are forwarded as real* together with m.
// NOTE(review): the launch call that owns the trailing argument list is not
// visible in this excerpt, so the kernel name cannot be confirmed here.
59void mma_update_hessian_z_hip(void* Hess, void* a, int m) {
60 const int M = m;
61 const int total = M * M;
62
63 const dim3 nthrds(1024, 1, 1);
64 const dim3 nblcks((total + 1024 - 1) / 1024, 1, 1);
65 const hipStream_t stream = (hipStream_t)glb_cmd_queue; // same as CUDA stream
66
69 (real*)Hess, (real*)a, M);
71}
72
// Solves an n-by-n linear system with a single right-hand side through
// hipSOLVER: LU factorization (hipsolverDnDgetrf) followed by a solve
// (hipsolverDnDgetrs with HIPSOLVER_OP_N).
//   A  : matrix, cast to double* (leading dimension n)
//   b  : right-hand side, cast to double* (leading dimension n)
//   n  : system size
//   jj : receives host_info on return
// A and b are always treated as double here, independent of the `real` type.
// NOTE(review): `handle` and `status` are declared on lines that are not
// visible in this excerpt; the same holds for the device-to-host copies that
// fill host_info.
// NOTE(review): the hipMalloc/hipFree calls and the getrf/getrs return codes
// are not checked in the visible lines, and no release of `workspace` is
// visible here -- confirm against the full file.
73void hipSOLVER_wrapper(void* A, void* b, int n, int* jj) {
77
78 int lwork;
79 double *workspace;
80 int *ipiv;
81 int *info; // Device pointer for hipSOLVER info
82 int host_info = 0; // Host variable to store the info
83
84 // Workspace query
85 status = hipsolverDnDgetrf_bufferSize(handle, n, n, (double*)A, n, &lwork);
// Scratch buffers: LU workspace, n pivot indices and one device-side info int.
86 hipMalloc(&workspace, lwork * sizeof(double));
87 hipMalloc(&ipiv, n * sizeof(int));
88 hipMalloc(&info, sizeof(int));
89
90 // LU factorization and solve
91 hipsolverDnDgetrf(handle, n, n, (double*)A, n, workspace, ipiv, info);
92
93 // Copy info from device to host to check if factorization succeeded
95
96 if (host_info == 0) {
97 // Only solve if factorization was successful
98 hipsolverDnDgetrs(handle, HIPSOLVER_OP_N, n, 1, (double*)A, n, ipiv, (double*)b, n, info);
99 // Copy the final info value
101 }
102
103 // Return the actual info value through jj
104 *jj = host_info;
105
106 // Cleanup
108 hipFree(ipiv);
109 hipFree(info);
111}
112
// Host wrapper that sets up a launch with 256 threads per block and
// ceil(m / 256) blocks, forwarding AA, s, lambda, d, mu, y and a as real*
// together with the scalars zeta, z and m (all taken by value here).
// NOTE(review): matrix_size (= m + 1) is not referenced in the visible lines.
// NOTE(review): the kernel name and the declaration of `stream` are on lines
// that are not visible in this excerpt.
113void mma_prepare_aa_matrix_hip(void* AA, void* s, void* lambda,
114 void* d, void* mu, void* y,
115 void* a, real zeta, real z, int m) {
116 const int M = m;
117 const int matrix_size = M + 1;
118 const dim3 nthrds(256, 1, 1);
119 const dim3 nblcks((M + 256 - 1) / 256, 1, 1);
121
122 // Launch kernel to prepare AA matrix entirely on device
124 nblcks, nthrds, 0, stream,
125 (real*)AA, (real*)s, (real*)lambda, (real*)d,
126 (real*)mu, (real*)y, (real*)a, zeta, z, M);
127
129}
130
// Prepares the m-by-m matrix Hess in two stages:
//   1. a launch over ceil(m / 1024) blocks of 1024 threads that receives
//      Hess, y, mu, lambda and m (the "update diagonal" step);
//   2. a stabilization step whose path depends on m:
//        m <= 1024 : a single-block launch receiving (Hess, m);
//        m  > 1024 : the diagonal is read back one element at a time into a
//                    host buffer, summed to a trace, and a factor
//                    max(-1e-4 * trace / m, 1e-7) is passed to a second launch.
// NOTE(review): the kernel names, the declaration of `stream`, the
// synchronization call and the head/tail of the diagonal copy call are on
// lines that are not visible in this excerpt.
131void mma_prepare_hessian_hip(void* Hess, void* y,
132 void* mu, void* lambda, int m) {
133 const int M = m;
134 const dim3 nthrds(1024, 1, 1);
135 const dim3 nblcks((M + 1024 - 1) / 1024, 1, 1);
137
138 // Update diagonal elements
140 nblcks, nthrds, 0, stream,
141 (real*)Hess, (real*)y, (real*)mu, (real*)lambda, M);
143
144 // Synchronize to ensure diagonal updates are complete
146
147 // Choose kernel based on problem size
148 if (M <= 1024) {
149 // Single-block version (fast for small m)
150 const dim3 stab_nblcks(1, 1, 1);
153 (real*)Hess, M);
155 } else {
156 // Multi-block version (for large m)
157 // Compute trace on host (simple and reliable)
// NOTE(review): the malloc result is used without a NULL check.
158 real* h_Hess = (real*)malloc(M * sizeof(real));
159
160 // Extract diagonal elements
// One transfer of sizeof(real) per diagonal entry, addressed as Hess[i*M + i].
161 for (int i = 0; i < M; i++) {
163 (real*)Hess + i * M + i,
164 sizeof(real),
166 }
168
169 // Compute trace and LM factor
170 real trace = 0.0;
171 for (int i = 0; i < M; i++) {
172 trace += h_Hess[i];
173 }
// The factor is bounded below by 1e-7; the arithmetic uses double literals.
174 real lm_factor = fmax(-1.0e-4 * trace / M, 1.0e-7);
175
176 // Apply stabilization in parallel
178 nblcks, nthrds, 0, stream,
179 (real*)Hess, lm_factor, M);
181
182 free(h_Hess);
183 }
184}
185
186// Custom linear solver using kernel
// Sets up a single-block launch (1 block, 1024 threads) that receives A and b
// as real* together with n, then reports the outcome through *info:
//   *info =  0  when `err` equals hipSuccess,
//   *info = -1  when n <= 0 (returned before any launch) or on any other error.
// NOTE(review): the launch call and the statement that assigns `err` are on
// lines that are not visible in this excerpt.
187extern "C" void hip_custom_solver(void* A, void* b, int n, int* info) {
189
190 if (n <= 0) {
191 *info = -1; // Use CPU fallback
192 return;
193 }
194 const dim3 nthrds(1024, 1, 1);
195 const dim3 nblcks(1, 1, 1);
196
198 (real*)A, (real*)b, n);
199
201 if (err == hipSuccess) {
202 *info = 0; // GPU solver succeeded
203 } else {
204 *info = -1; // GPU failed
205 }
206}
207
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. Delta is
// forwarded as real*; L_total, Le, offset and n are dereferenced on the host
// and passed by value.
// NOTE(review): the launch call (kernel name, stream) is not visible in this
// excerpt.
208void delta_1dbeam_hip(void* Delta, real* L_total, real* Le,
209 int* offset, int* n) {
210 const dim3 nthrds(1024, 1, 1);
211 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
214 (real*)Delta, *L_total, *Le, *offset, *n);
216}
217
// Fills Hess entry by entry: for every pair (i, j) with 0 <= i, j < *m a
// launch receives hijx, Ljjxinv and the device reduction buffer, a second
// launch receives (mma_bufred_d, nb), and finally a one-thread
// mma_copy_kernel launch is given Hess, mma_bufred_d, 1 and the offset
// i + j * (*m) (presumably storing the reduced value at that position).
// The shared reduction buffers are reallocated first when nb = ceil(*n / 1024)
// exceeds mma_red_s.
// NOTE(review): the kernel names of the first two launches, the declaration
// of `stream` and any per-iteration synchronization are on lines that are not
// visible in this excerpt.
218void hip_Hess(void* Hess, void* hijx, void* Ljjxinv, int *n, int *m) {
219 const dim3 nthrds(1024, 1, 1);
220 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
221 const int nb = ((*n) + 1024 - 1) / 1024;
224
225 if (nb > mma_red_s) {
226 mma_red_s = nb;
227 if (mma_bufred != NULL) {
228 HIP_CHECK(hipHostFree(mma_bufred));
229 HIP_CHECK(hipFree(mma_bufred_d));
230 }
231 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
232 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
233 }
234
235 for (int i = 0; i < (*m); i++) {
236 for (int j = 0; j < (*m); j++) {
238 (real*)hijx, (real*)Ljjxinv, mma_bufred_d, (*n), (*m), i, j);
240
242 mma_bufred_d, nb);
244
245 hipLaunchKernelGGL(mma_copy_kernel, dim3(1), dim3(1), 0, stream,
246 (real*)Hess, mma_bufred_d, 1, i + j * (*m));
248
250 }
251 }
252}
253
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. The visible
// argument tail forwards x, low, upp, alpha and beta as real* plus *n.
// NOTE(review): the start of the launch call (kernel name, stream and the
// leading arguments, presumably Ljjxinv, pjlambda and qjlambda) is not visible
// in this excerpt.
254void mma_Ljjxinv_hip(void* Ljjxinv, void* pjlambda, void* qjlambda, void* x,
255 void* low, void* upp, void* alpha, void* beta, int* n) {
256 const dim3 nthrds(1024, 1, 1);
257 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
260 (real*)x, (real*)low, (real*)upp, (real*)alpha, (real*)beta, *n);
262}
263
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. The visible
// argument tail forwards low, upp, alpha and beta as real* plus *n.
// NOTE(review): the start of the launch call (kernel name, stream and the
// leading arguments, presumably x, pjlambda and qjlambda) is not visible in
// this excerpt.
264void mma_dipsolvesub1_hip(void* x, void* pjlambda, void* qjlambda, void* low,
265 void* upp, void* alpha, void* beta, int* n) {
266 const dim3 nthrds(1024, 1, 1);
267 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
270 (real*)low, (real*)upp, (real*)alpha, (real*)beta, *n);
272}
273
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks, on the stream
// obtained by casting glb_cmd_queue. output, pij and lambda are forwarded as
// real*; *m and *n are passed by value.
// NOTE(review): the head of the launch call (kernel name) is not visible in
// this excerpt; the function name suggests a transposed matrix-vector product,
// which should be confirmed against the kernel.
274void mattrans_v_mul_hip(void* output, void* pij, void* lambda, int* m, int* n) {
275 const dim3 nthrds(1024, 1, 1);
276 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
278 (hipStream_t)glb_cmd_queue, (real*)output, (real*)pij, (real*)lambda, *m, *n);
280}
281
// Produces one value per constraint index i in [0, *m) and stores it at
// bi[i]:
//   - a temporary device buffer of (*m) * (*n) reals is allocated;
//   - one launch receives x, low, upp, pij, qij and that buffer;
//   - for each i, a launch receives (temp, mma_bufred_d, N, M, i), a second
//     one receives (mma_bufred_d, nb), and sizeof(real) bytes are then copied
//     from mma_bufred_d to bi + i;
//   - the temporary buffer is released.
// The shared reduction buffers are reallocated first when nb = ceil(N / 1024)
// exceeds mma_red_s.
// NOTE(review): kernel names, the declaration of `stream`, the copy call head
// and any synchronization are on lines that are not visible in this excerpt.
// NOTE(review): this function releases the pinned buffer with hipFreeHost,
// while hip_Hess and others use hipHostFree -- consider unifying.
// NOTE(review): M * N is multiplied in int before the conversion to size_t.
282void mma_gensub4_hip(const void* x, const void* low, const void* upp,
283 const void* pij, const void* qij,
284 const int* n, const int* m, void* bi) {
285
286 const int N = *n;
287 const int M = *m;
288
289 const dim3 nthrds(1024, 1, 1);
290 const dim3 nblcks((N + 1023) / 1024, 1, 1);
291 const int nb = (N + 1023) / 1024;
293
294 if (nb > mma_red_s) {
295 mma_red_s = nb;
296
297 if (mma_bufred != nullptr) {
298 HIP_CHECK(hipFreeHost(mma_bufred));
299 HIP_CHECK(hipFree(mma_bufred_d));
300 }
301
302 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
303 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
304 }
305
306 real* temp;
307 real* bi_d = static_cast<real*>(bi);
308 HIP_CHECK(hipMalloc(&temp, M * N * sizeof(real)));
309
311 static_cast<const real*>(x),
312 static_cast<const real*>(low),
313 static_cast<const real*>(upp),
314 static_cast<const real*>(pij),
315 static_cast<const real*>(qij),
316 temp, N, M);
317
318 for (int i = 0; i < M; ++i) {
320 temp, mma_bufred_d, N, M, i);
322
324 mma_bufred_d, nb);
326
328 bi_d + i, mma_bufred_d, sizeof(real),
330
332 }
333
334 HIP_CHECK(hipFree(temp));
335}
336
337
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. All thirteen
// array arguments are forwarded as real*, followed by *n and *m by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
338void mma_gensub3_hip(void* x, void* df0dx, void* dfdx, void* low,
339 void* upp, void* xmin, void* xmax, void* alpha,
340 void* beta, void* p0j, void* q0j, void* pij,
341 void* qij, int* n, int* m) {
342 const dim3 nthrds(1024, 1, 1);
343 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
344
347 (real*)x, (real*)df0dx, (real*)dfdx, (real*)low,
348 (real*)upp, (real*)xmin, (real*)xmax, (real*)alpha,
349 (real*)beta, (real*)p0j, (real*)q0j, (real*)pij,
350 (real*)qij, *n, *m);
351
353}
354
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. low, upp, x,
// xold1, xold2 and xdiff are forwarded as real*; asydecr, asyincr and n are
// dereferenced on the host and passed by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
355void mma_gensub2_hip(void* low, void* upp, void* x, void* xold1,
356 void* xold2, void* xdiff, real* asydecr,
357 real* asyincr, int* n) {
358 const dim3 nthrds(1024, 1, 1);
359 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
360
363 (real*)low, (real*)upp, (real*)x, (real*)xold1,
364 (real*)xold2, (real*)xdiff, *asydecr, *asyincr, *n);
365
367}
368
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. low, upp, x,
// xmin and xmax are forwarded as real*; asyinit and n are dereferenced on the
// host and passed by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
369void mma_gensub1_hip(void* low, void* upp, void* x, void* xmin, void* xmax,
370 real* asyinit, int* n) {
371 const dim3 nthrds(1024, 1, 1);
372 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
373
376 (real*)low, (real*)upp, (real*)x, (real*)xmin, (real*)xmax,
377 *asyinit, *n);
378
380}
381
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. xsi, x and
// alpha are forwarded as real*, followed by *n by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
382void hip_mma_max(void* xsi, void* x, void* alpha, int* n) {
383 const dim3 nthrds(1024, 1, 1);
384 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
385
388 (real*)xsi, (real*)x, (real*)alpha, *n);
389
391}
392
// Produces one value per index i in [0, *m) of relambda:
//   - a temporary device buffer of (*n) * (*m) reals is allocated;
//   - one launch receives that buffer together with x, xupp, xlow, pij, qij;
//   - for each i, a launch receives (temp, mma_bufred_d, *n, *m, i), a second
//     one receives (mma_bufred_d, nb), and a one-thread mma_copy_kernel launch
//     is given relambda, mma_bufred_d, 1 and the offset i;
//   - the temporary buffer is released.
// The shared reduction buffers are reallocated first when nb exceeds
// mma_red_s.
// NOTE(review): kernel names, the declaration of `stream` and any
// synchronization are on lines that are not visible in this excerpt.
// NOTE(review): the hipMalloc/hipFree of `temp` are not wrapped in HIP_CHECK,
// unlike the other allocations here, and (*n) * (*m) is multiplied in int.
393void hip_relambda(void* relambda, void* x, void* xupp, void* xlow,
394 void* pij, void* qij, int* n, int* m) {
395 const dim3 nthrds(1024, 1, 1);
396 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
397 const int nb = nblcks.x;
399
400 if (nb > mma_red_s) {
401 mma_red_s = nb;
402 if (mma_bufred != NULL) {
403 HIP_CHECK(hipHostFree(mma_bufred));
404 HIP_CHECK(hipFree(mma_bufred_d));
405 }
406 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
407 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
408 }
409
410 real* temp;
411 hipMalloc(&temp, (*n) * (*m) * sizeof(real));
412
414 temp, (real*)x, (real*)xupp, (real*)xlow,
415 (real*)pij, (real*)qij, *n, *m);
416
417 for (int i = 0; i < (*m); i++) {
419 temp, mma_bufred_d, (*n), (*m), i);
421
423 stream, mma_bufred_d, nb);
425
426 hipLaunchKernelGGL(mma_copy_kernel, dim3(1), dim3(1), 0, stream,
427 (real*)relambda, mma_bufred_d, 1, i);
429
431 }
432
433 hipFree(temp);
434}
435
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. a, b, c and d
// are forwarded as real*; e and n are dereferenced on the host and passed by
// value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
436void hip_sub2cons2(void* a, void* b, void* c, void* d, real* e, int* n) {
437 const dim3 nthrds(1024, 1, 1);
438 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
439
442 (real*)a, (real*)b, (real*)c, (real*)d, *e, *n);
443
445}
446
// Reduction wrapper returning a single real computed from the *n entries of a.
// Steps visible here: grow the shared reduction buffers when
// nb = ceil(*n / 1024) exceeds mma_red_s; a first launch receives
// (a, mma_bufred_d, *n); a second receives (mma_bufred_d, nb); one real is
// copied asynchronously from mma_bufred_d to mma_bufred; mma_bufred[0] is
// returned.
// NOTE(review): the kernel names (a maximum, judging by the function name),
// the declaration of `stream`, and the tail of the hipMemcpyAsync call with
// any synchronization before the host read are on lines that are not visible
// in this excerpt.
447real hip_maxval(void* a, int* n) {
448 const dim3 nthrds(1024, 1, 1);
449 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
450 const int nb = nblcks.x;
452
453 if (nb > mma_red_s) {
454 mma_red_s = nb;
455 if (mma_bufred != NULL) {
456 HIP_CHECK(hipHostFree(mma_bufred));
457 HIP_CHECK(hipFree(mma_bufred_d));
458 }
459 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
460 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
461 }
462
464 (real*)a, mma_bufred_d, (*n));
466
468 mma_bufred_d, nb);
470
471 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d, sizeof(real),
474
475 return mma_bufred[0];
476}
477
478
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. The eleven
// array arguments are forwarded as real*; epsi, n and m are dereferenced on
// the host and passed by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
479void hip_delx(void* delx, void* x, void* xlow, void* xupp, void* pij,
480 void* qij, void* p0j, void* q0j, void* alpha, void* beta, void* lambda,
481 real* epsi, int* n, int* m) {
482 const dim3 nthrds(1024, 1, 1);
483 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
484
487 (real*)delx, (real*)x, (real*)xlow, (real*)xupp, (real*)pij,
488 (real*)qij, (real*)p0j, (real*)q0j, (real*)alpha, (real*)beta,
489 (real*)lambda, *epsi, *n, *m);
491}
492
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. GG, x, xlow,
// xupp, pij and qij are forwarded as real*, followed by *n and *m by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
493void hip_GG(void* GG, void* x, void* xlow, void* xupp,
494 void* pij, void* qij, int* n, int* m) {
495 const dim3 nthrds(1024, 1, 1);
496 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
497
500 (real*)GG, (real*)x, (real*)xlow, (real*)xupp, (real*)pij,
501 (real*)qij, *n, *m);
503}
504
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. The thirteen
// array arguments are forwarded as real*, followed by *n and *m by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
505void hip_diagx(void* diagx, void* x, void* xsi, void* xlow, void* xupp,
506 void* p0j, void* q0j, void* pij, void* qij, void* alpha, void* beta,
507 void* eta, void* lambda, int *n, int *m) {
508 const dim3 nthrds(1024, 1, 1);
509 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
510
513 (real*)diagx, (real*)x, (real*)xsi, (real*)xlow, (real*)xupp,
514 (real*)p0j, (real*)q0j, (real*)pij, (real*)qij, (real*)alpha,
515 (real*)beta, (real*)eta, (real*)lambda, *n, *m);
517}
518
// Produces one value per index i in [0, *m) of bb: for each i a launch
// receives (GG, delx, diagx, mma_bufred_d, *n, *m, i), a second launch
// receives (mma_bufred_d, nb), and a one-thread mma_copy_kernel launch is
// given bb, mma_bufred_d, 1 and the offset i.
// The shared reduction buffers are reallocated first when nb = ceil(*n / 1024)
// exceeds mma_red_s.
// NOTE(review): the kernel names of the first two launches, the declaration
// of `stream` and any synchronization are on lines that are not visible in
// this excerpt.
519void hip_bb(void* bb, void* GG, void* delx, void* diagx, int *n, int *m) {
520 const dim3 nthrds(1024, 1, 1);
521 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
522 const int nb = ((*n) + 1024 - 1)/ 1024;
524
526
527 if (nb > mma_red_s) {
528 mma_red_s = nb;
529 if (mma_bufred != NULL) {
530 HIP_CHECK(hipHostFree(mma_bufred));
531 HIP_CHECK(hipFree(mma_bufred_d));
532 }
533 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
534 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
535 }
536
537 for (int i = 0; i < (*m); i++) {
539 (real*)GG, (real*)delx, (real*)diagx, mma_bufred_d, *n, *m, i);
541
543 mma_bufred_d, nb);
545
546 hipLaunchKernelGGL(mma_copy_kernel, 1, 1, 0, stream, (real*)bb,
547 mma_bufred_d, 1, i);
549
551 }
552}
553
// Fills AA entry by entry: for every pair (i, j) with 0 <= i, j < *m a launch
// receives (GG, diagx, mma_bufred_d, *n, *m, i, j), a second launch receives
// (mma_bufred_d, nb), and a one-thread mma_copy_kernel launch is given AA,
// mma_bufred_d, 1 and the offset i + j * (*m + 1). The stride of *m + 1
// suggests AA has leading dimension m + 1 -- confirm against the caller.
// The shared reduction buffers are reallocated first when nb = ceil(*n / 1024)
// exceeds mma_red_s.
// NOTE(review): the kernel names of the first two launches, the declaration
// of `stream` and any synchronization are on lines that are not visible in
// this excerpt.
554void hip_AA(void* AA, void* GG, void* diagx, int *n, int *m) {
555 const dim3 nthrds(1024, 1, 1);
556 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
557 const int nb = ((*n) + 1024 - 1)/ 1024;
559
561
562 if (nb > mma_red_s) {
563 mma_red_s = nb;
564 if (mma_bufred != NULL) {
565 HIP_CHECK(hipHostFree(mma_bufred));
566 HIP_CHECK(hipFree(mma_bufred_d));
567 }
568 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
569 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
570 }
571
572 for (int i = 0; i < (*m); i++) {
573 for (int j = 0; j < (*m); j++) {
575 (real*)GG, (real*)diagx, mma_bufred_d, *n, *m, i, j);
577
579 mma_bufred_d, nb);
581
582 hipLaunchKernelGGL(mma_copy_kernel, 1, 1, 0, stream,
583 (real*)AA, mma_bufred_d, 1, i + j * (*m + 1));
585
587 }
588 }
589}
590
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. dx, delx,
// diagx, GG and dlambda are forwarded as real*, followed by *n and *m by
// value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
591void hip_dx(void* dx, void* delx, void* diagx, void* GG, void* dlambda,
592 int* n, int* m) {
593 const dim3 nthrds(1024, 1, 1);
594 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
595
598 (real*)dx, (real*)delx, (real*)diagx, (real*)GG, (real*)dlambda, *n, *m);
600}
601
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. dxsi, xsi, dx,
// x and alpha are forwarded as real*; epsi and n are dereferenced on the host
// and passed by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
602void hip_dxsi(void* dxsi, void* xsi, void* dx, void* x,
603 void* alpha, real* epsi, int* n) {
604 const dim3 nthrds(1024, 1, 1);
605 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
606
609 (real*)dxsi, (real*)xsi, (real*)dx, (real*)x, (real*)alpha, *epsi, *n);
611}
612
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. deta, eta, dx,
// x and beta are forwarded as real*; epsi and n are dereferenced on the host
// and passed by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
613void hip_deta(void* deta, void* eta, void* dx, void* x,
614 void* beta, real* epsi, int* n) {
615 const dim3 nthrds(1024, 1, 1);
616 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
617
620 (real*)deta, (real*)eta, (real*)dx, (real*)x, (real*)beta, *epsi, *n);
622}
623
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. The eleven
// array arguments are forwarded as real* in the declared parameter order,
// followed by *n and *m by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
624void hip_rex(void* rex, void* x, void* xlow, void* xupp, void* pij,
625 void* p0j, void* qij, void* q0j, void* lambda, void* xsi, void* eta,
626 int* n, int* m) {
627 const dim3 nthrds(1024, 1, 1);
628 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
629
632 (real*)rex, (real*)x, (real*)xlow, (real*)xupp, (real*)pij, (real*)p0j,
633 (real*)qij, (real*)q0j, (real*)lambda, (real*)xsi, (real*)eta, *n, *m);
635}
636
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. rey, c, d, y,
// lambda and mu are forwarded as real*, followed by *n by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
637void hip_rey(void* rey, void* c, void* d, void* y, void* lambda, void* mu, int* n) {
638 const dim3 nthrds(1024, 1, 1);
639 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
640
643 (real*)rey, (real*)c, (real*)d, (real*)y, (real*)lambda, (real*)mu, *n);
645}
646
647
648 // a_d = b_d * c_d - d
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. a, b and c are
// forwarded as real*; the scalar d and n are dereferenced on the host and
// passed by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
649void hip_sub2cons(void *a, void *b, void *c, real *d, int *n) {
650 const dim3 nthrds(1024, 1, 1);
651 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
654 (real *)a, (real *)b, (real *)c, *d, *n);
656}
657
658
659// sum(a^2)
// Reduction wrapper returning a single real computed from the *n entries of a.
// Steps visible here: grow the shared reduction buffers when
// nb = ceil(*n / 1024) exceeds mma_red_s; a first launch receives
// (a, mma_bufred_d, *n); a second receives (mma_bufred_d, nb); one real is
// copied asynchronously from mma_bufred_d to mma_bufred; mma_bufred[0] is
// returned.
// NOTE(review): kernel names, the declaration of `stream`, and the tail of the
// hipMemcpyAsync call with any synchronization before the host read are on
// lines that are not visible in this excerpt.
// NOTE(review): this function releases the pinned buffer with hipFreeHost,
// while hip_Hess and others use hipHostFree -- consider unifying.
660real hip_norm(void* a, int* n) {
661 const dim3 nthrds(1024, 1, 1);
662 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
663 const int nb = ((*n) + 1024 - 1) / 1024;
665
666 if (nb > mma_red_s) {
667 mma_red_s = nb;
668 if (mma_bufred != NULL) {
669 HIP_CHECK(hipFreeHost(mma_bufred));
670 HIP_CHECK(hipFree(mma_bufred_d));
671 }
672 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
673 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
674 }
675
677 (real*)a, mma_bufred_d, (*n));
679
681 mma_bufred_d, nb);
683
684 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d, sizeof(real),
686
688
689 return mma_bufred[0];
690}
691
692
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. dely, c, d, y
// and lambda are forwarded as real*; epsi and n are dereferenced on the host
// and passed by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
693void hip_dely(void* dely, void* c, void* d, void* y, void* lambda,
694 real* epsi, int* n) {
695 const dim3 nthrds(1024, 1, 1);
696 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
699 (real*)dely, (real*)c, (real*)d, (real*)y, (real*)lambda, *epsi, *n);
701}
702
703
// Reduction wrapper returning a single real computed from the *n entries of
// a and b and the scalar *cons.
// Steps visible here: grow the shared reduction buffers when
// nb = ceil(*n / 1024) exceeds mma_red_s; a first launch receives
// (a, b, mma_bufred_d, *cons, *n); a second receives (mma_bufred_d, nb); one
// real is copied asynchronously from mma_bufred_d to mma_bufred;
// mma_bufred[0] is returned.
// NOTE(review): kernel names (a maximum, judging by the function name), the
// declaration of `stream`, and the tail of the hipMemcpyAsync call with any
// synchronization before the host read are on lines that are not visible in
// this excerpt.
// NOTE(review): this function releases the pinned buffer with hipFreeHost,
// while hip_Hess and others use hipHostFree -- consider unifying.
704real hip_maxval2(void* a, void* b, real* cons, int* n) {
705 const dim3 nthrds(1024, 1, 1);
706 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
707 const int nb = ((*n) + 1024 - 1) / 1024;
709
710 if (nb > mma_red_s) {
711 mma_red_s = nb;
712 if (mma_bufred != NULL) {
713 HIP_CHECK(hipFreeHost(mma_bufred));
714 HIP_CHECK(hipFree(mma_bufred_d));
715 }
716 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
717 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
718 }
719
721 (real*)a, (real*)b, mma_bufred_d, *cons, *n);
723
725 mma_bufred_d, nb);
727
728 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d, sizeof(real),
730
732
733 return mma_bufred[0];
734}
735
736
// Reduction wrapper returning a single real computed from the *n entries of
// a, b and c and the scalar *cons.
// Steps visible here: grow the shared reduction buffers when
// nb = ceil(*n / 1024) exceeds mma_red_s; a first launch receives
// (a, b, c, mma_bufred_d, *cons, *n); a second receives (mma_bufred_d, nb);
// one real is copied asynchronously from mma_bufred_d to mma_bufred;
// mma_bufred[0] is returned.
// NOTE(review): kernel names (a maximum, judging by the function name), the
// declaration of `stream`, and the tail of the hipMemcpyAsync call with any
// synchronization before the host read are on lines that are not visible in
// this excerpt.
// NOTE(review): this function releases the pinned buffer with hipFreeHost,
// while hip_Hess and others use hipHostFree -- consider unifying.
737real hip_maxval3(void* a, void* b, void* c, real* cons, int* n) {
738 const dim3 nthrds(1024, 1, 1);
739 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
740 const int nb = ((*n) + 1024 - 1) / 1024;
742
743 if (nb > mma_red_s) {
744 mma_red_s = nb;
745 if (mma_bufred != NULL) {
746 HIP_CHECK(hipFreeHost(mma_bufred));
747 HIP_CHECK(hipFree(mma_bufred_d));
748 }
749 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
750 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
751 }
752
754 (real*)a, (real*)b, (real*)c, mma_bufred_d, *cons, *n);
756 mma_bufred_d, nb);
758
759 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d, sizeof(real),
761
763
764 return mma_bufred[0];
765}
766
767
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. rex, df0dx,
// dfdx, xsi, eta and lambda are forwarded as real*, followed by *n and *m by
// value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
768void hip_kkt_rex(void* rex, void* df0dx, void* dfdx, void* xsi,
769 void* eta, void* lambda, int* n, int* m) {
770 const dim3 nthrds(1024, 1, 1);
771 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
774 (real*)rex, (real*)df0dx, (real*)dfdx, (real*)xsi,
775 (real*)eta, (real*)lambda, *n, *m);
777}
778
779
780// a_d = max(b, c * d_d)
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. a and d are
// forwarded as real*; the scalars b, c and n are dereferenced on the host and
// passed by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
781void hip_maxcons(void* a, real* b, real* c, void* d, int* n) {
782 const dim3 nthrds(1024, 1, 1);
783 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
786 (real*)a, *b, *c, (real*)d, *n);
788}
789
790
// Reduction wrapper returning a single real computed from the *n entries of
// a and b.
// Steps visible here: grow the shared reduction buffers when
// nb = ceil(*n / 1024) exceeds mma_red_s; a first launch receives
// (a, b, mma_bufred_d, *n); a second receives (mma_bufred_d, nb); one real is
// copied asynchronously from mma_bufred_d to mma_bufred; mma_bufred[0] is
// returned.
// NOTE(review): kernel names, the declaration of `stream`, and the tail of the
// hipMemcpyAsync call with any synchronization before the host read are on
// lines that are not visible in this excerpt; which two-vector reduction is
// computed cannot be confirmed from the visible lines.
// NOTE(review): this function releases the pinned buffer with hipFreeHost,
// while hip_Hess and others use hipHostFree -- consider unifying.
791real hip_lcsc2(void *a, void*b, int *n) {
792 const dim3 nthrds(1024, 1, 1);
793 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
794 const int nb = ((*n) + 1024 - 1) / 1024;
796
797 if (nb > mma_red_s) {
798 mma_red_s = nb;
799 if (mma_bufred != NULL) {
800 HIP_CHECK(hipFreeHost(mma_bufred));
801 HIP_CHECK(hipFree(mma_bufred_d));
802 }
803 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
804 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
805 }
806
808 (real*)a, (real*)b, mma_bufred_d, (*n));
810
812 mma_bufred_d, nb);
814
815 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d, sizeof(real),
817
819
820 return mma_bufred[0];
821}
822
823
// Does nothing unless HAVE_DEVICE_MPI is defined. With it, `a` is viewed as
// real* through the local `temp`.
// NOTE(review): the statements that use `temp` and *n are on lines that are
// not visible in this excerpt (presumably a device-aware MPI sum via
// device_mpi_op.h -- confirm).
824void hip_mpisum(void *a, int *n) {
825#ifdef HAVE_DEVICE_MPI
826 real* temp = (real*)a;
829#endif
830}
831
832
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. a and b are
// forwarded as real*; the scalar c and n are dereferenced on the host and
// passed by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
833void hip_add2inv2(void* a, void* b, real* c, int* n) {
834 const dim3 nthrds(1024, 1, 1);
835 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
837
839 (real*)a, (real*)b, *c, *n);
841}
842
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. a and c are
// forwarded as real*; the scalars b, d and n are dereferenced on the host and
// passed by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
843void hip_max2(void* a, real* b, void* c, real* d, int* n) {
844 const dim3 nthrds(1024, 1, 1);
845 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
847
849 (real*)a, *b, (real*)c, *d, *n);
851}
852
// Host wrapper: 1024 threads per block and ceil((*m + 1) / 1024) blocks, i.e.
// the launch is sized for m + 1 entries. bb, dellambda, dely, d, mu and y are
// forwarded as real*; delz and m are dereferenced on the host and passed by
// value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
853void hip_updatebb(void* bb, void* dellambda, void* dely, void* d,
854 void* mu, void* y, real* delz, int* m) {
855 const dim3 nthrds(1024, 1, 1);
856 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
858
860 (real*)bb, (real*)dellambda, (real*)dely, (real*)d,
861 (real*)mu, (real*)y, *delz, *m);
863}
864
// Host wrapper: 1024 threads per block and ceil((*m + 1) / 1024) blocks, i.e.
// the launch is sized for m + 1 entries. AA, globaltmp_mm, s, lambda, d, mu, y
// and a are forwarded as real*; zeta, z and m are dereferenced on the host and
// passed by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
865void hip_updateAA(void* AA, void* globaltmp_mm, void* s, void* lambda,
866 void* d, void* mu, void* y, void* a,
867 real* zeta, real* z, int* m) {
868 const dim3 nthrds(1024, 1, 1);
869 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
871
873 (real*)AA, (real*)globaltmp_mm, (real*)s,
874 (real*)lambda, (real*)d, (real*)mu,
875 (real*)y, (real*)a, *zeta, *z, *m);
877}
878
// Host wrapper: 1024 threads per block, ceil(*n / 1024) blocks. dy, dely,
// dlambda, d, mu and y are forwarded as real*, followed by *n by value.
// NOTE(review): the head of the launch call (kernel name, stream) and any
// trailing error check are on lines that are not visible in this excerpt.
879void hip_dy(void* dy, void* dely, void* dlambda, void* d,
880 void* mu, void* y, int* n) {
881 const dim3 nthrds(1024, 1, 1);
882 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
884
886 (real*)dy, (real*)dely, (real*)dlambda, (real*)d,
887 (real*)mu, (real*)y, *n);
889}
890}
__global__ void heaviside_mapping_apply_kernel(const T beta, const T eta, T *__restrict__ X_out_d, T *__restrict__ X_in_d, const int n)