40#include <hip/hip_runtime.h>
41#include <hipsolver/hipsolver.h>
44#include <neko/device/device_config.h>
45#include <neko/device/hip/check.h>
46#include <neko/math/bcknd/device/device_mpi_op.h>
49#include "mma_kernel.h"
54real * mma_bufred = NULL;
55real * mma_bufred_d = NULL;
// Solve the dense n x n system A * x = b on the device with hipSOLVER:
// hipsolverDnDgetrf LU-factorizes A in place (pivots in ipiv), then
// hipsolverDnDgetrs solves for one right-hand side (HIPSOLVER_OP_N),
// overwriting b with the solution.
// NOTE(review): A and b are always cast to double*, so this assumes
// real == double; a single-precision build would need the S-variants.
// NOTE(review): the declarations of lwork/workspace/ipiv/info/host_info,
// any use of `jj`, the handling of host_info and the hipFree calls are on
// lines not visible here. Verify the three device buffers are released.
57void hipSOLVER_wrapper(
void* A,
void* b,
int n,
int* jj) {
58 hipsolverDnHandle_t handle;
59 hipsolverStatus_t status;
// NOTE(review): the status returned by hipsolverDnCreate is not checked.
60 hipsolverDnCreate(&handle);
// Query the getrf workspace size, then allocate the workspace, n pivot
// indices and a single device-side info flag (hipMalloc results unchecked).
69 status = hipsolverDnDgetrf_bufferSize(handle, n, n, (
double*)A, n, &lwork);
70 hipMalloc(&workspace, lwork *
sizeof(
double));
71 hipMalloc(&ipiv, n *
sizeof(
int));
72 hipMalloc(&info,
sizeof(
int));
// LU factorization of A (leading dimension n), in place.
75 hipsolverDnDgetrf(handle, n, n, (
double*)A, n, workspace, ipiv, info);
// Bring getrf's info flag back to the host.
78 hipMemcpy(&host_info, info,
sizeof(
int), hipMemcpyDeviceToHost);
// Triangular solves with the LU factors; b is overwritten with x.
82 hipsolverDnDgetrs(handle, HIPSOLVER_OP_N, n, 1, (
double*)A, n, ipiv, (
double*)b, n, info);
// Bring getrs's info flag back to the host.
84 hipMemcpy(&host_info, info,
sizeof(
int), hipMemcpyDeviceToHost);
94 hipsolverDnDestroy(handle);
// Launch mma_prepare_aa_matrix_kernel on glb_cmd_queue with 256-thread
// blocks (enough blocks to cover M threads). AA and the vector arguments
// are forwarded as real*, the scalars zeta and z by value.
// NOTE(review): the body uses `M` while the parameter is `m`; M is defined
// on a line not shown here (presumably M == m, confirm).
// NOTE(review): matrix_size is computed but not used in the visible lines.
97void mma_prepare_aa_matrix_hip(
void* AA,
void* s,
void* lambda,
98 void* d,
void* mu,
void* y,
99 void* a, real zeta, real z,
int m) {
101 const int matrix_size = M + 1;
102 const dim3 nthrds(256, 1, 1);
103 const dim3 nblcks((M + 256 - 1) / 256, 1, 1);
104 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
107 hipLaunchKernelGGL(mma_prepare_aa_matrix_kernel<real>,
108 nblcks, nthrds, 0, stream,
109 (real*)AA, (real*)s, (real*)lambda, (real*)d,
110 (real*)mu, (real*)y, (real*)a, zeta, z, M);
// Report launch-configuration errors; execution errors surface later.
112 HIP_CHECK(hipGetLastError());
// Update the diagonal of the M x M matrix Hess
// (mma_update_hessian_diagonal_kernel), then stabilize it. Visible paths:
// (a) a single-block mma_stabilize_hessian_single_kernel launch;
// (b) copy the M diagonal entries Hess[i*M + i] to the host, compute
//     lm_factor = fmax(-1e-4 * trace / M, 1e-7) (never below 1e-7), and
//     launch mma_stabilize_hessian_multi_kernel with it.
// NOTE(review): how the two paths are selected, the definition of M and
// the accumulation of `trace` are on lines not shown here. Confirm.
115void mma_prepare_hessian_hip(
void* Hess,
void* y,
void* d,
116 void* mu,
void* lambda,
int m) {
118 const dim3 nthrds(1024, 1, 1);
119 const dim3 nblcks((M + 1024 - 1) / 1024, 1, 1);
120 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
123 hipLaunchKernelGGL(mma_update_hessian_diagonal_kernel<real>,
124 nblcks, nthrds, 0, stream,
125 (real*)Hess, (real*)y, (real*)d, (real*)mu, (real*)lambda, M);
126 HIP_CHECK(hipGetLastError());
// Wait for the diagonal update before the stabilization step reads Hess.
129 HIP_CHECK(hipStreamSynchronize(stream));
// Single-block stabilization path (its kernel arguments are on a line
// not visible here).
134 const dim3 stab_nblcks(1, 1, 1);
135 hipLaunchKernelGGL(mma_stabilize_hessian_single_kernel<real>,
136 stab_nblcks, nthrds, 0, stream,
138 HIP_CHECK(hipGetLastError());
// NOTE(review): h_Hess comes from plain malloc (pageable, NULL not
// checked) and is filled by M one-element async copies. A single strided
// copy (hipMemcpy2DAsync with source pitch (M+1)*sizeof(real)) or a
// device-side trace reduction would avoid M transfers. free(h_Hess) is
// not visible here, so verify it exists.
142 real* h_Hess = (real*)malloc(M *
sizeof(real));
145 for (
int i = 0; i < M; i++) {
146 HIP_CHECK(hipMemcpyAsync(&h_Hess[i],
147 (real*)Hess + i * M + i,
149 hipMemcpyDeviceToHost, stream));
// The host buffer is only valid after this sync.
151 HIP_CHECK(hipStreamSynchronize(stream));
155 for (
int i = 0; i < M; i++) {
// Shift factor derived from the diagonal trace, clamped below at 1e-7.
158 real lm_factor = fmax(-1.0e-4 * trace / M, 1.0e-7);
161 hipLaunchKernelGGL(mma_stabilize_hessian_multi_kernel<real>,
162 nblcks, nthrds, 0, stream,
163 (real*)Hess, lm_factor, M);
164 HIP_CHECK(hipGetLastError());
// C-linkage entry point. Launches mma_small_lu_kernel (presumably an LU
// solve of the n x n system A x = b, per its name and arguments; confirm)
// as ONE block of 1024 threads on glb_cmd_queue, then queries launch
// errors with hipGetLastError.
// NOTE(review): how `err` and `info` are reported continues on lines not
// visible here.
171extern "C" void hip_custom_solver(
void* A,
void* b,
int n,
int* info) {
172 const hipStream_t stream = (hipStream_t) glb_cmd_queue;
// A single block: the kernel is presumably limited to small n.
178 const dim3 nthrds(1024, 1, 1);
179 const dim3 nblcks(1, 1, 1);
181 hipLaunchKernelGGL(mma_small_lu_kernel<real>, nblcks, nthrds, 0, stream,
182 (real*)A, (real*)b, n);
184 hipError_t err = hipGetLastError();
185 if (err == hipSuccess) {
192void delta_1dbeam_hip(
void* Delta, real* L_total, real* Le,
193 int* offset,
int* n) {
194 const dim3 nthrds(1024, 1, 1);
195 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
196 hipLaunchKernelGGL(delta_1dbeam_kernel<real>,
197 nblcks, nthrds, 0, (hipStream_t)glb_cmd_queue,
198 (real*)Delta, *L_total, *Le, *offset, *n);
199 HIP_CHECK(hipGetLastError());
// Fill the m x m matrix Hess on the device, one entry per (i, j) pair:
// mmasumHess_kernel writes nb partial results (one slot per block) of a
// reduction over the n entries of hijx/Ljjxinv into mma_bufred_d,
// mmareduce_kernel folds them into mma_bufred_d[0], and mma_copy_kernel
// stores that value at Hess[i + j*m]. That is 3*m*m launches per call.
// NOTE(review): the reduce kernel's argument line and any update of
// mma_red_s after regrowing the buffers are not visible here. Confirm.
202void hip_Hess(
void* Hess,
void* hijx,
void* Ljjxinv,
int *n,
int *m) {
203 const dim3 nthrds(1024, 1, 1);
204 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
205 const int nb = ((*n) + 1024 - 1) / 1024;
206 const hipStream_t stream = (hipStream_t) glb_cmd_queue;
// NOTE(review): this sync (and the one at the end) is not wrapped in HIP_CHECK.
207 hipStreamSynchronize(stream);
// Grow the shared host/device reduction buffers to nb entries if needed.
209 if (nb > mma_red_s) {
211 if (mma_bufred != NULL) {
212 HIP_CHECK(hipHostFree(mma_bufred));
213 HIP_CHECK(hipFree(mma_bufred_d));
215 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
216 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
219 for (
int i = 0; i < (*m); i++) {
220 for (
int j = 0; j < (*m); j++) {
221 hipLaunchKernelGGL(mmasumHess_kernel<real>, nblcks, nthrds, 0, stream,
222 (real*)hijx, (real*)Ljjxinv, mma_bufred_d, (*n), (*m), i, j);
223 HIP_CHECK(hipGetLastError());
225 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
227 HIP_CHECK(hipGetLastError());
// Single-thread copy of the reduced value into column-major slot (i, j).
229 hipLaunchKernelGGL(mma_copy_kernel, dim3(1), dim3(1), 0, stream,
230 (real*)Hess, mma_bufred_d, 1, i + j * (*m));
231 HIP_CHECK(hipGetLastError());
233 hipStreamSynchronize(stream);
238void mma_Ljjxinv_hip(
void* Ljjxinv,
void* pjlambda,
void* qjlambda,
void* x,
239 void* low,
void* upp,
void* alpha,
void* beta,
int* n) {
240 const dim3 nthrds(1024, 1, 1);
241 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
242 hipLaunchKernelGGL(mma_Ljjxinv_kernel<real>, nblcks, nthrds, 0,
243 (hipStream_t)glb_cmd_queue, (real*)Ljjxinv, (real*)pjlambda, (real*)qjlambda,
244 (real*)x, (real*)low, (real*)upp, (real*)alpha, (real*)beta, *n);
245 HIP_CHECK(hipGetLastError());
248void mma_dipsolvesub1_hip(
void* x,
void* pjlambda,
void* qjlambda,
void* low,
249 void* upp,
void* alpha,
void* beta,
int* n) {
250 const dim3 nthrds(1024, 1, 1);
251 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
252 hipLaunchKernelGGL(mma_dipsolvesub1_kernel<real>, nblcks, nthrds, 0,
253 (hipStream_t)glb_cmd_queue, (real*)x, (real*)pjlambda, (real*)qjlambda,
254 (real*)low, (real*)upp, (real*)alpha, (real*)beta, *n);
255 HIP_CHECK(hipGetLastError());
258void mattrans_v_mul_hip(
void* output,
void* pij,
void* lambda,
int* m,
int* n) {
259 const dim3 nthrds(1024, 1, 1);
260 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
261 hipLaunchKernelGGL(mattrans_v_mul_kernel<real>, nblcks, nthrds, 0,
262 (hipStream_t)glb_cmd_queue, (real*)output, (real*)pij, (real*)lambda, *m, *n);
263 HIP_CHECK(hipGetLastError());
// Compute the M entries of the device vector bi. mma_sub4_kernel fills
// an M*N device temporary `temp` from x, low, upp, pij and qij. Then, for
// each i in [0, M), mmasum_kernel + mmareduce_kernel reduce entry i into
// mma_bufred_d[0], which is copied device-to-device into bi[i]. temp is
// freed at the end.
// NOTE(review): N, M, `temp`, the rest of the mma_sub4_kernel argument
// list and any mma_red_s update are on lines not visible here.
266void mma_gensub4_hip(
const void* x,
const void* low,
const void* upp,
267 const void* pij,
const void* qij,
268 const int* n,
const int* m,
void* bi) {
273 const dim3 nthrds(1024, 1, 1);
274 const dim3 nblcks((N + 1023) / 1024, 1, 1);
275 const int nb = (N + 1023) / 1024;
276 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
// Grow the shared reduction buffers to nb entries if needed.
278 if (nb > mma_red_s) {
281 if (mma_bufred !=
nullptr) {
// NOTE(review): hipFreeHost is deprecated in HIP; the other wrappers in
// this file use hipHostFree. Switch for consistency.
282 HIP_CHECK(hipFreeHost(mma_bufred));
283 HIP_CHECK(hipFree(mma_bufred_d));
286 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
287 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
291 real* bi_d =
static_cast<real*
>(bi);
// Per-call M*N scratch allocation, released by the hipFree below.
292 HIP_CHECK(hipMalloc(&temp, M * N *
sizeof(real)));
294 hipLaunchKernelGGL(mma_sub4_kernel<real>, nblcks, nthrds, 0, stream,
295 static_cast<const real*
>(x),
296 static_cast<const real*
>(low),
297 static_cast<const real*
>(upp),
298 static_cast<const real*
>(pij),
299 static_cast<const real*
>(qij),
302 for (
int i = 0; i < M; ++i) {
303 hipLaunchKernelGGL(mmasum_kernel<real>, nblcks, nthrds, 0, stream,
304 temp, mma_bufred_d, N, M, i);
305 HIP_CHECK(hipGetLastError());
307 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
309 HIP_CHECK(hipGetLastError());
// Move the reduced scalar into bi[i] without a host round-trip.
311 HIP_CHECK(hipMemcpyAsync(
312 bi_d + i, mma_bufred_d,
sizeof(real),
313 hipMemcpyDeviceToDevice, stream));
// temp must not be freed while kernels that read it may still be running.
315 HIP_CHECK(hipStreamSynchronize(stream));
318 HIP_CHECK(hipFree(temp));
// Host launcher for mma_sub3_kernel: ceil(*n / 1024) blocks of 1024
// threads on glb_cmd_queue, with the device arrays forwarded as real*.
// NOTE(review): the final line of the kernel argument list (after pij) is
// not visible here; presumably (real*)qij, *n, *m. Confirm.
322void mma_gensub3_hip(
void* x,
void* df0dx,
void* dfdx,
void* low,
323 void* upp,
void* xmin,
void* xmax,
void* alpha,
324 void* beta,
void* p0j,
void* q0j,
void* pij,
325 void* qij,
int* n,
int* m) {
326 const dim3 nthrds(1024, 1, 1);
327 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
329 hipLaunchKernelGGL(mma_sub3_kernel<real>, nblcks, nthrds, 0,
330 (hipStream_t)glb_cmd_queue,
331 (real*)x, (real*)df0dx, (real*)dfdx, (real*)low,
332 (real*)upp, (real*)xmin, (real*)xmax, (real*)alpha,
333 (real*)beta, (real*)p0j, (real*)q0j, (real*)pij,
336 HIP_CHECK(hipGetLastError());
339void mma_gensub2_hip(
void* low,
void* upp,
void* x,
void* xold1,
340 void* xold2,
void* xdiff, real* asydecr,
341 real* asyincr,
int* n) {
342 const dim3 nthrds(1024, 1, 1);
343 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
345 hipLaunchKernelGGL(mma_sub2_kernel<real>, nblcks, nthrds, 0,
346 (hipStream_t)glb_cmd_queue,
347 (real*)low, (real*)upp, (real*)x, (real*)xold1,
348 (real*)xold2, (real*)xdiff, *asydecr, *asyincr, *n);
350 HIP_CHECK(hipGetLastError());
// Host launcher for mma_sub1_kernel: ceil(*n / 1024) blocks of 1024
// threads on glb_cmd_queue.
// NOTE(review): the tail of the kernel argument list (after xmax) is not
// visible here; presumably *asyinit, *n. Confirm.
353void mma_gensub1_hip(
void* low,
void* upp,
void* x,
void* xmin,
void* xmax,
354 real* asyinit,
int* n) {
355 const dim3 nthrds(1024, 1, 1);
356 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
358 hipLaunchKernelGGL(mma_sub1_kernel<real>, nblcks, nthrds, 0,
359 (hipStream_t)glb_cmd_queue,
360 (real*)low, (real*)upp, (real*)x, (real*)xmin, (real*)xmax,
363 HIP_CHECK(hipGetLastError());
366void hip_mma_max(
void* xsi,
void* x,
void* alpha,
int* n) {
367 const dim3 nthrds(1024, 1, 1);
368 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
370 hipLaunchKernelGGL(mma_max2_kernel<real>, nblcks, nthrds, 0,
371 (hipStream_t)glb_cmd_queue,
372 (real*)xsi, (real*)x, (real*)alpha, *n);
374 HIP_CHECK(hipGetLastError());
// Compute the m entries of relambda. relambda_kernel fills an n*m device
// temporary; then, for each i in [0, m), mmasum_kernel + mmareduce_kernel
// reduce it into mma_bufred_d[0], which mma_copy_kernel stores at
// relambda[i].
// NOTE(review): the declaration of `temp`, any error check right after the
// relambda_kernel launch, any mma_red_s update and any hipFree(temp) are on
// lines not visible here. If temp is never freed, this leaks n*m reals
// per call.
377void hip_relambda(
void* relambda,
void* x,
void* xupp,
void* xlow,
378 void* pij,
void* qij,
int* n,
int* m) {
379 const dim3 nthrds(1024, 1, 1);
380 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
381 const int nb = nblcks.x;
382 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
// Grow the shared reduction buffers to nb entries if needed.
384 if (nb > mma_red_s) {
386 if (mma_bufred != NULL) {
387 HIP_CHECK(hipHostFree(mma_bufred));
388 HIP_CHECK(hipFree(mma_bufred_d));
390 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
391 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
// NOTE(review): unlike the other allocations, this hipMalloc is not
// wrapped in HIP_CHECK.
395 hipMalloc(&temp, (*n) * (*m) *
sizeof(real));
397 hipLaunchKernelGGL(relambda_kernel<real>, nblcks, nthrds, 0, stream,
398 temp, (real*)x, (real*)xupp, (real*)xlow,
399 (real*)pij, (real*)qij, *n, *m);
401 for (
int i = 0; i < (*m); i++) {
402 hipLaunchKernelGGL(mmasum_kernel<real>, nblcks, nthrds, 0, stream,
403 temp, mma_bufred_d, (*n), (*m), i);
404 HIP_CHECK(hipGetLastError());
406 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0,
407 stream, mma_bufred_d, nb);
408 HIP_CHECK(hipGetLastError());
410 hipLaunchKernelGGL(mma_copy_kernel, dim3(1), dim3(1), 0, stream,
411 (real*)relambda, mma_bufred_d, 1, i);
412 HIP_CHECK(hipGetLastError());
414 hipStreamSynchronize(stream);
420void hip_sub2cons2(
void* a,
void* b,
void* c,
void* d, real* e,
int* n) {
421 const dim3 nthrds(1024, 1, 1);
422 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
424 hipLaunchKernelGGL(sub2cons2_kernel<real>, nblcks, nthrds, 0,
425 (hipStream_t)glb_cmd_queue,
426 (real*)a, (real*)b, (real*)c, (real*)d, *e, *n);
428 HIP_CHECK(hipGetLastError());
// Reduce the *n entries of device array a to a single value: maxval_kernel
// writes one partial per block into mma_bufred_d, max_reduce_kernel folds
// them into mma_bufred_d[0], and that value is copied into the pinned host
// buffer mma_bufred and returned. Presumably the maximum, per the kernel
// names; confirm.
// NOTE(review): the argument line of max_reduce_kernel and any mma_red_s
// update are not visible here.
431real hip_maxval(
void* a,
int* n) {
432 const dim3 nthrds(1024, 1, 1);
433 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
434 const int nb = nblcks.x;
435 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
// Grow the shared reduction buffers to nb entries if needed.
437 if (nb > mma_red_s) {
439 if (mma_bufred != NULL) {
440 HIP_CHECK(hipHostFree(mma_bufred));
441 HIP_CHECK(hipFree(mma_bufred_d));
443 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
444 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
447 hipLaunchKernelGGL(maxval_kernel<real>, nblcks, nthrds, 0, stream,
448 (real*)a, mma_bufred_d, (*n));
449 HIP_CHECK(hipGetLastError());
451 hipLaunchKernelGGL(max_reduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
453 HIP_CHECK(hipGetLastError());
455 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d,
sizeof(real),
456 hipMemcpyDeviceToHost, stream));
// The host copy is only valid after the stream drains (result unchecked).
457 hipStreamSynchronize(stream);
459 return mma_bufred[0];
463void hip_delx(
void* delx,
void* x,
void* xlow,
void* xupp,
void* pij,
464 void* qij,
void* p0j,
void* q0j,
void* alpha,
void* beta,
void* lambda,
465 real* epsi,
int* n,
int* m) {
466 const dim3 nthrds(1024, 1, 1);
467 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
469 hipLaunchKernelGGL(delx_kernel<real>, nblcks, nthrds, 0,
470 (hipStream_t)glb_cmd_queue,
471 (real*)delx, (real*)x, (real*)xlow, (real*)xupp, (real*)pij,
472 (real*)qij, (real*)p0j, (real*)q0j, (real*)alpha, (real*)beta,
473 (real*)lambda, *epsi, *n, *m);
474 HIP_CHECK(hipGetLastError());
// Host launcher for GG_kernel: ceil(*n / 1024) blocks of 1024 threads on
// glb_cmd_queue.
// NOTE(review): the tail of the kernel argument list (after pij) is not
// visible here; presumably (real*)qij, *n, *m. Confirm.
477void hip_GG(
void* GG,
void* x,
void* xlow,
void* xupp,
478 void* pij,
void* qij,
int* n,
int* m) {
479 const dim3 nthrds(1024, 1, 1);
480 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
482 hipLaunchKernelGGL(GG_kernel<real>, nblcks, nthrds, 0,
483 (hipStream_t)glb_cmd_queue,
484 (real*)GG, (real*)x, (real*)xlow, (real*)xupp, (real*)pij,
486 HIP_CHECK(hipGetLastError());
489void hip_diagx(
void* diagx,
void* x,
void* xsi,
void* xlow,
void* xupp,
490 void* p0j,
void* q0j,
void* pij,
void* qij,
void* alpha,
void* beta,
491 void* eta,
void* lambda,
int *n,
int *m) {
492 const dim3 nthrds(1024, 1, 1);
493 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
495 hipLaunchKernelGGL(diagx_kernel<real>, nblcks, nthrds, 0,
496 (hipStream_t)glb_cmd_queue,
497 (real*)diagx, (real*)x, (real*)xsi, (real*)xlow, (real*)xupp,
498 (real*)p0j, (real*)q0j, (real*)pij, (real*)qij, (real*)alpha,
499 (real*)beta, (real*)eta, (real*)lambda, *n, *m);
500 HIP_CHECK(hipGetLastError());
// Fill the m entries of bb. For each i in [0, m), mmasumbb_kernel writes
// per-block partials (from GG, delx, diagx) into mma_bufred_d,
// mmareduce_kernel folds them into mma_bufred_d[0], and mma_copy_kernel
// copies the result into bb.
// NOTE(review): the argument lines of the reduce and copy launches (the
// target index into bb) and any mma_red_s update are not visible here.
503void hip_bb(
void* bb,
void* GG,
void* delx,
void* diagx,
int *n,
int *m) {
504 const dim3 nthrds(1024, 1, 1);
505 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
506 const int nb = ((*n) + 1024 - 1)/ 1024;
507 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
// Drain prior work before the shared buffers may be reallocated (result unchecked).
509 hipStreamSynchronize(stream);
511 if (nb > mma_red_s) {
513 if (mma_bufred != NULL) {
514 HIP_CHECK(hipHostFree(mma_bufred));
515 HIP_CHECK(hipFree(mma_bufred_d));
517 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
518 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
521 for (
int i = 0; i < (*m); i++) {
522 hipLaunchKernelGGL(mmasumbb_kernel<real>, nblcks, nthrds, 0, stream,
523 (real*)GG, (real*)delx, (real*)diagx, mma_bufred_d, *n, *m, i);
524 HIP_CHECK(hipGetLastError());
526 hipLaunchKernelGGL(mmareduce_kernel<real>, 1, 1024, 0, stream,
528 HIP_CHECK(hipGetLastError());
530 hipLaunchKernelGGL(mma_copy_kernel, 1, 1, 0, stream, (real*)bb,
532 HIP_CHECK(hipGetLastError());
534 hipStreamSynchronize(stream);
// Fill the leading m x m block of AA. For each (i, j), mmasumAA_kernel
// writes per-block partials (from GG, diagx) into mma_bufred_d,
// mmareduce_kernel folds them into mma_bufred_d[0], and mma_copy_kernel
// stores the result at AA[i + j*(m+1)]. AA therefore has leading
// dimension m+1.
// NOTE(review): the reduce launch's argument line and any mma_red_s
// update are not visible here.
538void hip_AA(
void* AA,
void* GG,
void* diagx,
int *n,
int *m) {
539 const dim3 nthrds(1024, 1, 1);
540 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
541 const int nb = ((*n) + 1024 - 1)/ 1024;
542 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
// Drain prior work before the shared buffers may be reallocated (result unchecked).
544 hipStreamSynchronize(stream);
546 if (nb > mma_red_s) {
548 if (mma_bufred != NULL) {
549 HIP_CHECK(hipHostFree(mma_bufred));
550 HIP_CHECK(hipFree(mma_bufred_d));
552 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
553 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
556 for (
int i = 0; i < (*m); i++) {
557 for (
int j = 0; j < (*m); j++) {
558 hipLaunchKernelGGL(mmasumAA_kernel<real>, nblcks, nthrds, 0, stream,
559 (real*)GG, (real*)diagx, mma_bufred_d, *n, *m, i, j);
560 HIP_CHECK(hipGetLastError());
562 hipLaunchKernelGGL(mmareduce_kernel<real>, 1, 1024, 0, stream,
564 HIP_CHECK(hipGetLastError());
566 hipLaunchKernelGGL(mma_copy_kernel, 1, 1, 0, stream,
567 (real*)AA, mma_bufred_d, 1, i + j * (*m + 1));
568 HIP_CHECK(hipGetLastError());
570 hipStreamSynchronize(stream);
// Host launcher for dx_kernel: ceil(*n / 1024) blocks of 1024 threads on
// glb_cmd_queue, forwarding dx, delx, diagx, GG, dlambda, *n and *m.
// NOTE(review): the end of the parameter list is on a line not visible
// here; the body dereferences n and m, so it presumably declares
// `int* n, int* m) {`. Confirm.
575void hip_dx(
void* dx,
void* delx,
void* diagx,
void* GG,
void* dlambda,
577 const dim3 nthrds(1024, 1, 1);
578 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
580 hipLaunchKernelGGL(dx_kernel<real>, nblcks, nthrds, 0,
581 (hipStream_t)glb_cmd_queue,
582 (real*)dx, (real*)delx, (real*)diagx, (real*)GG, (real*)dlambda, *n, *m);
583 HIP_CHECK(hipGetLastError());
586void hip_dxsi(
void* dxsi,
void* xsi,
void* dx,
void* x,
587 void* alpha, real* epsi,
int* n) {
588 const dim3 nthrds(1024, 1, 1);
589 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
591 hipLaunchKernelGGL(dxsi_kernel<real>, nblcks, nthrds, 0,
592 (hipStream_t)glb_cmd_queue,
593 (real*)dxsi, (real*)xsi, (real*)dx, (real*)x, (real*)alpha, *epsi, *n);
594 HIP_CHECK(hipGetLastError());
597void hip_deta(
void* deta,
void* eta,
void* dx,
void* x,
598 void* beta, real* epsi,
int* n) {
599 const dim3 nthrds(1024, 1, 1);
600 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
602 hipLaunchKernelGGL(deta_kernel<real>, nblcks, nthrds, 0,
603 (hipStream_t)glb_cmd_queue,
604 (real*)deta, (real*)eta, (real*)dx, (real*)x, (real*)beta, *epsi, *n);
605 HIP_CHECK(hipGetLastError());
// Host launcher for RexCalculation_kernel: ceil(*n / 1024) blocks of 1024
// threads on glb_cmd_queue, forwarding eleven device arrays plus *n, *m.
// NOTE(review): the end of the parameter list is on a line not visible
// here; the body dereferences n and m, so it presumably declares
// `int* n, int* m) {`. Confirm.
608void hip_rex(
void* rex,
void* x,
void* xlow,
void* xupp,
void* pij,
609 void* p0j,
void* qij,
void* q0j,
void* lambda,
void* xsi,
void* eta,
611 const dim3 nthrds(1024, 1, 1);
612 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
614 hipLaunchKernelGGL(RexCalculation_kernel<real>, nblcks, nthrds, 0,
615 (hipStream_t)glb_cmd_queue,
616 (real*)rex, (real*)x, (real*)xlow, (real*)xupp, (real*)pij, (real*)p0j,
617 (real*)qij, (real*)q0j, (real*)lambda, (real*)xsi, (real*)eta, *n, *m);
618 HIP_CHECK(hipGetLastError());
621void hip_rey(
void* rey,
void* c,
void* d,
void* y,
void* lambda,
void* mu,
int* n) {
622 const dim3 nthrds(1024, 1, 1);
623 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
625 hipLaunchKernelGGL(rey_calculation_kernel<real>, nblcks, nthrds, 0,
626 (hipStream_t)glb_cmd_queue,
627 (real*)rey, (real*)c, (real*)d, (real*)y, (real*)lambda, (real*)mu, *n);
628 HIP_CHECK(hipGetLastError());
633void hip_sub2cons(
void *a,
void *b,
void *c, real *d,
int *n) {
634 const dim3 nthrds(1024, 1, 1);
635 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
636 hipLaunchKernelGGL(sub2cons_kernel<real>, nblcks, nthrds, 0,
637 (hipStream_t)glb_cmd_queue,
638 (real *)a, (real *)b, (real *)c, *d, *n);
639 HIP_CHECK(hipGetLastError());
// Reduce the *n entries of device array a to one value: norm_kernel writes
// one partial per block into mma_bufred_d, mmareduce_kernel folds them
// into mma_bufred_d[0], and the result is copied into the pinned host
// buffer mma_bufred and returned.
// NOTE(review): whether a square root is applied cannot be seen here. The
// reduce launch's argument line and any mma_red_s update are not visible.
644real hip_norm(
void* a,
int* n) {
645 const dim3 nthrds(1024, 1, 1);
646 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
647 const int nb = ((*n) + 1024 - 1) / 1024;
648 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
// Grow the shared reduction buffers to nb entries if needed.
650 if (nb > mma_red_s) {
652 if (mma_bufred != NULL) {
// NOTE(review): hipFreeHost is deprecated; use hipHostFree like the
// other wrappers in this file.
653 HIP_CHECK(hipFreeHost(mma_bufred));
654 HIP_CHECK(hipFree(mma_bufred_d));
656 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
657 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
660 hipLaunchKernelGGL(norm_kernel<real>, nblcks, nthrds, 0, stream,
661 (real*)a, mma_bufred_d, (*n));
662 HIP_CHECK(hipGetLastError());
664 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
666 HIP_CHECK(hipGetLastError());
668 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d,
sizeof(real),
669 hipMemcpyDeviceToHost, stream));
// The host copy is only valid after the stream drains (result unchecked).
671 hipStreamSynchronize(stream);
673 return mma_bufred[0];
677void hip_dely(
void* dely,
void* c,
void* d,
void* y,
void* lambda,
678 real* epsi,
int* n) {
679 const dim3 nthrds(1024, 1, 1);
680 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
681 hipLaunchKernelGGL(dely_kernel<real>, nblcks, nthrds, 0,
682 (hipStream_t)glb_cmd_queue,
683 (real*)dely, (real*)c, (real*)d, (real*)y, (real*)lambda, *epsi, *n);
684 HIP_CHECK(hipGetLastError());
// Reduce over the *n entries of device arrays a and b (with host scalar
// *cons) to one value: maxval2_kernel writes per-block partials into
// mma_bufred_d, max_reduce_kernel folds them into mma_bufred_d[0], and
// the result is copied into the pinned host buffer mma_bufred and returned.
// NOTE(review): the reduce launch's argument line and any mma_red_s update
// are not visible here.
688real hip_maxval2(
void* a,
void* b, real* cons,
int* n) {
689 const dim3 nthrds(1024, 1, 1);
690 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
691 const int nb = ((*n) + 1024 - 1) / 1024;
692 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
// Grow the shared reduction buffers to nb entries if needed.
694 if (nb > mma_red_s) {
696 if (mma_bufred != NULL) {
// NOTE(review): hipFreeHost is deprecated; use hipHostFree like the
// other wrappers in this file.
697 HIP_CHECK(hipFreeHost(mma_bufred));
698 HIP_CHECK(hipFree(mma_bufred_d));
700 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
701 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
704 hipLaunchKernelGGL(maxval2_kernel<real>, nblcks, nthrds, 0, stream,
705 (real*)a, (real*)b, mma_bufred_d, *cons, *n);
706 HIP_CHECK(hipGetLastError());
708 hipLaunchKernelGGL(max_reduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
710 HIP_CHECK(hipGetLastError());
712 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d,
sizeof(real),
713 hipMemcpyDeviceToHost, stream));
// The host copy is only valid after the stream drains (result unchecked).
715 hipStreamSynchronize(stream);
717 return mma_bufred[0];
// Reduce over the *n entries of device arrays a, b and c (with host scalar
// *cons) to one value: maxval3_kernel writes per-block partials into
// mma_bufred_d, max_reduce_kernel folds them into mma_bufred_d[0], and
// the result is copied into the pinned host buffer mma_bufred and returned.
// NOTE(review): the reduce launch's argument line and any mma_red_s update
// are not visible here.
721real hip_maxval3(
void* a,
void* b,
void* c, real* cons,
int* n) {
722 const dim3 nthrds(1024, 1, 1);
723 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
724 const int nb = ((*n) + 1024 - 1) / 1024;
725 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
// Grow the shared reduction buffers to nb entries if needed.
727 if (nb > mma_red_s) {
729 if (mma_bufred != NULL) {
// NOTE(review): hipFreeHost is deprecated; use hipHostFree like the
// other wrappers in this file.
730 HIP_CHECK(hipFreeHost(mma_bufred));
731 HIP_CHECK(hipFree(mma_bufred_d));
733 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
734 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
// NOTE(review): unlike every other launch in this file, maxval3_kernel is
// not followed by HIP_CHECK(hipGetLastError()). A failed launch would only
// be reported, misattributed, after the reduce launch below.
737 hipLaunchKernelGGL(maxval3_kernel<real>, nblcks, nthrds, 0, stream,
738 (real*)a, (real*)b, (real*)c, mma_bufred_d, *cons, *n);
739 hipLaunchKernelGGL(max_reduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
741 HIP_CHECK(hipGetLastError());
743 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d,
sizeof(real),
744 hipMemcpyDeviceToHost, stream));
// The host copy is only valid after the stream drains (result unchecked).
746 hipStreamSynchronize(stream);
748 return mma_bufred[0];
752void hip_kkt_rex(
void* rex,
void* df0dx,
void* dfdx,
void* xsi,
753 void* eta,
void* lambda,
int* n,
int* m) {
754 const dim3 nthrds(1024, 1, 1);
755 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
756 hipLaunchKernelGGL(kkt_rex_kernel<real>, nblcks, nthrds, 0,
757 (hipStream_t)glb_cmd_queue,
758 (real*)rex, (real*)df0dx, (real*)dfdx, (real*)xsi,
759 (real*)eta, (real*)lambda, *n, *m);
760 HIP_CHECK(hipGetLastError());
765void hip_maxcons(
void* a, real* b, real* c,
void* d,
int* n) {
766 const dim3 nthrds(1024, 1, 1);
767 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
768 hipLaunchKernelGGL(maxcons_kernel<real>, nblcks, nthrds, 0,
769 (hipStream_t)glb_cmd_queue,
770 (real*)a, *b, *c, (real*)d, *n);
771 HIP_CHECK(hipGetLastError());
// Reduce over the *n entries of device arrays a and b to one value:
// glsc2_kernel writes per-block partials into mma_bufred_d,
// mmareduce_kernel folds them into mma_bufred_d[0], and the result is
// copied into the pinned host buffer mma_bufred and returned. Presumably
// a local inner product sum(a*b), per the kernel name; confirm. No MPI
// reduction happens here.
// NOTE(review): the reduce launch's argument line and any mma_red_s update
// are not visible here.
775real hip_lcsc2(
void *a,
void*b,
int *n) {
776 const dim3 nthrds(1024, 1, 1);
777 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
778 const int nb = ((*n) + 1024 - 1) / 1024;
779 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
// Grow the shared reduction buffers to nb entries if needed.
781 if (nb > mma_red_s) {
783 if (mma_bufred != NULL) {
// NOTE(review): hipFreeHost is deprecated; use hipHostFree like the
// other wrappers in this file.
784 HIP_CHECK(hipFreeHost(mma_bufred));
785 HIP_CHECK(hipFree(mma_bufred_d));
787 HIP_CHECK(hipHostMalloc(&mma_bufred, nb *
sizeof(real)));
788 HIP_CHECK(hipMalloc(&mma_bufred_d, nb *
sizeof(real)));
791 hipLaunchKernelGGL(glsc2_kernel<real>, nblcks, nthrds, 0, stream,
792 (real*)a, (real*)b, mma_bufred_d, (*n));
793 HIP_CHECK(hipGetLastError());
795 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
797 HIP_CHECK(hipGetLastError());
799 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d,
sizeof(real),
800 hipMemcpyDeviceToHost, stream));
// The host copy is only valid after the stream drains (result unchecked).
802 hipStreamSynchronize(stream);
804 return mma_bufred[0];
/*
 * In-place sum of the *n reals in buffer a across MPI ranks, performed by
 * device_mpi_allreduce_inplace with DEVICE_MPI_SUM.
 * Compiles to an empty body unless HAVE_DEVICE_MPI is defined.
 */
void hip_mpisum(void *a, int *n) {
#ifdef HAVE_DEVICE_MPI
  real *temp = (real *)a;
  /* The previous version synchronized an undeclared `stream`. Every other
   * routine in this file enqueues its kernels on glb_cmd_queue, so that is
   * the queue to drain before MPI reads the buffer. */
  const hipStream_t stream = (hipStream_t)glb_cmd_queue;
  HIP_CHECK(hipStreamSynchronize(stream));
  device_mpi_allreduce_inplace(temp, *n, sizeof(real), DEVICE_MPI_SUM);
#endif
}
817void hip_add2inv2(
void* a,
void* b, real* c,
int* n) {
818 const dim3 nthrds(1024, 1, 1);
819 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
820 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
822 hipLaunchKernelGGL(add2inv2_kernel<real>, nblcks, nthrds, 0, stream,
823 (real*)a, (real*)b, *c, *n);
824 HIP_CHECK(hipGetLastError());
827void hip_max2(
void* a, real* b,
void* c, real* d,
int* n) {
828 const dim3 nthrds(1024, 1, 1);
829 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
830 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
832 hipLaunchKernelGGL(max2_kernel<real>, nblcks, nthrds, 0, stream,
833 (real*)a, *b, (real*)c, *d, *n);
834 HIP_CHECK(hipGetLastError());
837void hip_updatebb(
void* bb,
void* dellambda,
void* dely,
void* d,
838 void* mu,
void* y, real* delz,
int* m) {
839 const dim3 nthrds(1024, 1, 1);
840 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
841 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
843 hipLaunchKernelGGL(updatebb_kernel<real>, nblcks, nthrds, 0, stream,
844 (real*)bb, (real*)dellambda, (real*)dely, (real*)d,
845 (real*)mu, (real*)y, *delz, *m);
846 HIP_CHECK(hipGetLastError());
849void hip_updateAA(
void* AA,
void* globaltmp_mm,
void* s,
void* lambda,
850 void* d,
void* mu,
void* y,
void* a,
851 real* zeta, real* z,
int* m) {
852 const dim3 nthrds(1024, 1, 1);
853 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
854 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
856 hipLaunchKernelGGL(updateAA_kernel<real>, nblcks, nthrds, 0, stream,
857 (real*)AA, (real*)globaltmp_mm, (real*)s,
858 (real*)lambda, (real*)d, (real*)mu,
859 (real*)y, (real*)a, *zeta, *z, *m);
860 HIP_CHECK(hipGetLastError());
// Host launcher for dy_kernel: ceil(*n / 1024) blocks of 1024 threads on
// glb_cmd_queue, forwarding dy, dely, dlambda, d, mu, y and *n. Launch
// errors are surfaced through HIP_CHECK.
863void hip_dy(
void* dy,
void* dely,
void* dlambda,
void* d,
864 void* mu,
void* y,
int* n) {
865 const dim3 nthrds(1024, 1, 1);
866 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
867 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
869 hipLaunchKernelGGL(dy_kernel<real>, nblcks, nthrds, 0, stream,
870 (real*)dy, (real*)dely, (real*)dlambda, (real*)d,
871 (real*)mu, (real*)y, *n);
872 HIP_CHECK(hipGetLastError());