42#include <hip/hip_runtime.h>
43#include <hipsolver/hipsolver.h>
46#include <neko/device/device_config.h>
47#include <neko/device/hip/check.h>
48#include <neko/math/bcknd/device/device_mpi_op.h>
59void hipSOLVER_wrapper(
void*
A,
void*
b,
int n,
int*
jj) {
99void mma_prepare_aa_matrix_hip(
void*
AA,
void* s,
void* lambda,
100 void* d,
void* mu,
void* y,
101 void* a,
real zeta,
real z,
int m) {
117void mma_prepare_hessian_hip(
void*
Hess,
void* y,
void* d,
118 void* mu,
void* lambda,
int m) {
147 for (
int i = 0; i <
M; i++) {
157 for (
int i = 0; i <
M; i++) {
173extern "C" void hip_custom_solver(
void*
A,
void*
b,
int n,
int*
info) {
197 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
206 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
207 const int nb = ((*n) + 1024 - 1) / 1024;
211 if (
nb > mma_red_s) {
213 if (mma_bufred !=
NULL) {
221 for (
int i = 0; i < (*m); i++) {
222 for (
int j = 0;
j < (*m);
j++) {
232 (
real*)
Hess, mma_bufred_d, 1, i +
j * (*m));
241 void* low,
void* upp,
void* alpha,
void* beta,
int* n) {
243 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
250void mma_dipsolvesub1_hip(
void* x,
void*
pjlambda,
void*
qjlambda,
void* low,
251 void* upp,
void* alpha,
void* beta,
int* n) {
253 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
260void mattrans_v_mul_hip(
void* output,
void* pij,
void* lambda,
int* m,
int* n) {
262 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
268void mma_gensub4_hip(
const void* x,
const void* low,
const void* upp,
269 const void* pij,
const void* qij,
270 const int* n,
const int* m,
void* bi) {
277 const int nb = (
N + 1023) / 1024;
280 if (
nb > mma_red_s) {
283 if (mma_bufred !=
nullptr) {
297 static_cast<const real*
>(x),
298 static_cast<const real*
>(low),
299 static_cast<const real*
>(upp),
300 static_cast<const real*
>(pij),
301 static_cast<const real*
>(qij),
304 for (
int i = 0; i <
M; ++i) {
306 temp, mma_bufred_d,
N,
M, i);
314 bi_d + i, mma_bufred_d,
sizeof(
real),
324void mma_gensub3_hip(
void* x,
void*
df0dx,
void*
dfdx,
void* low,
325 void* upp,
void* xmin,
void* xmax,
void* alpha,
326 void* beta,
void* p0j,
void* q0j,
void* pij,
327 void* qij,
int* n,
int* m) {
329 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
341void mma_gensub2_hip(
void* low,
void* upp,
void* x,
void* xold1,
343 real* asyincr,
int* n) {
345 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
355void mma_gensub1_hip(
void* low,
void* upp,
void* x,
void* xmin,
void* xmax,
356 real* asyinit,
int* n) {
358 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
368void hip_mma_max(
void* xsi,
void* x,
void* alpha,
int* n) {
370 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
380 void* pij,
void* qij,
int* n,
int* m) {
382 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
386 if (
nb > mma_red_s) {
388 if (mma_bufred !=
NULL) {
403 for (
int i = 0; i < (*m); i++) {
405 temp, mma_bufred_d, (*n), (*m), i);
422void hip_sub2cons2(
void* a,
void*
b,
void* c,
void* d,
real*
e,
int* n) {
424 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
433real hip_maxval(
void* a,
int* n) {
435 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
439 if (
nb > mma_red_s) {
441 if (mma_bufred !=
NULL) {
450 (
real*)a, mma_bufred_d, (*n));
461 return mma_bufred[0];
465void hip_delx(
void*
delx,
void* x,
void*
xlow,
void*
xupp,
void* pij,
466 void* qij,
void* p0j,
void* q0j,
void* alpha,
void* beta,
void* lambda,
469 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
479void hip_GG(
void*
GG,
void* x,
void*
xlow,
void*
xupp,
480 void* pij,
void* qij,
int* n,
int* m) {
482 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
491void hip_diagx(
void*
diagx,
void* x,
void* xsi,
void*
xlow,
void*
xupp,
492 void* p0j,
void* q0j,
void* pij,
void* qij,
void* alpha,
void* beta,
493 void* eta,
void* lambda,
int *n,
int *m) {
495 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
505void hip_bb(
void*
bb,
void*
GG,
void*
delx,
void*
diagx,
int *n,
int *m) {
507 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
508 const int nb = ((*n) + 1024 - 1)/ 1024;
513 if (
nb > mma_red_s) {
515 if (mma_bufred !=
NULL) {
523 for (
int i = 0; i < (*m); i++) {
540void hip_AA(
void*
AA,
void*
GG,
void*
diagx,
int *n,
int *m) {
542 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
543 const int nb = ((*n) + 1024 - 1)/ 1024;
548 if (
nb > mma_red_s) {
550 if (mma_bufred !=
NULL) {
558 for (
int i = 0; i < (*m); i++) {
559 for (
int j = 0;
j < (*m);
j++) {
569 (
real*)
AA, mma_bufred_d, 1, i +
j * (*m + 1));
580 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
588void hip_dxsi(
void*
dxsi,
void* xsi,
void*
dx,
void* x,
591 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
599void hip_deta(
void*
deta,
void* eta,
void*
dx,
void* x,
602 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
610void hip_rex(
void*
rex,
void* x,
void*
xlow,
void*
xupp,
void* pij,
611 void* p0j,
void* qij,
void* q0j,
void* lambda,
void* xsi,
void* eta,
614 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
623void hip_rey(
void*
rey,
void* c,
void* d,
void* y,
void* lambda,
void* mu,
int* n) {
625 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
635void hip_sub2cons(
void *a,
void *
b,
void *c,
real *d,
int *n) {
637 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
646real hip_norm(
void* a,
int* n) {
648 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
649 const int nb = ((*n) + 1024 - 1) / 1024;
652 if (
nb > mma_red_s) {
654 if (mma_bufred !=
NULL) {
663 (
real*)a, mma_bufred_d, (*n));
675 return mma_bufred[0];
679void hip_dely(
void*
dely,
void* c,
void* d,
void* y,
void* lambda,
682 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
692 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
693 const int nb = ((*n) + 1024 - 1) / 1024;
696 if (
nb > mma_red_s) {
698 if (mma_bufred !=
NULL) {
719 return mma_bufred[0];
723real hip_maxval3(
void* a,
void*
b,
void* c,
real*
cons,
int* n) {
725 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
726 const int nb = ((*n) + 1024 - 1) / 1024;
729 if (
nb > mma_red_s) {
731 if (mma_bufred !=
NULL) {
750 return mma_bufred[0];
754void hip_kkt_rex(
void*
rex,
void*
df0dx,
void*
dfdx,
void* xsi,
755 void* eta,
void* lambda,
int* n,
int* m) {
757 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
767void hip_maxcons(
void* a,
real*
b,
real* c,
void* d,
int* n) {
769 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
777real hip_lcsc2(
void *a,
void*
b,
int *n) {
779 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
780 const int nb = ((*n) + 1024 - 1) / 1024;
783 if (
nb > mma_red_s) {
785 if (mma_bufred !=
NULL) {
806 return mma_bufred[0];
810void hip_mpisum(
void *a,
int *n) {
811#ifdef HAVE_DEVICE_MPI
819void hip_add2inv2(
void* a,
void*
b,
real* c,
int* n) {
821 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
829void hip_max2(
void* a,
real*
b,
void* c,
real* d,
int* n) {
831 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
840 void* mu,
void* y,
real*
delz,
int* m) {
842 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
851void hip_updateAA(
void*
AA,
void*
globaltmp_mm,
void* s,
void* lambda,
852 void* d,
void* mu,
void* y,
void* a,
855 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
866 void* mu,
void* y,
int* n) {
868 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
__global__ void convex_down_RAMP_mapping_apply_kernel(const T f_min, const T f_max, const T q, T *__restrict__ X_out_d, T *__restrict__ X_in_d, const int n)