42#include <hip/hip_runtime.h>
43#include <hipsolver/hipsolver.h>
46#include <neko/device/device_config.h>
47#include <neko/device/hip/check.h>
48#include <neko/math/bcknd/device/device_mpi_op.h>
59void mma_update_hessian_z_hip(
void*
Hess,
void* a,
int m) {
73void hipSOLVER_wrapper(
void*
A,
void*
b,
int n,
int*
jj) {
113void mma_prepare_aa_matrix_hip(
void*
AA,
void* s,
void* lambda,
114 void* d,
void* mu,
void* y,
115 void* a,
real zeta,
real z,
int m) {
131void mma_prepare_hessian_hip(
void*
Hess,
void* y,
132 void* mu,
void* lambda,
int m) {
161 for (
int i = 0; i <
M; i++) {
171 for (
int i = 0; i <
M; i++) {
187extern "C" void hip_custom_solver(
void*
A,
void*
b,
int n,
int*
info) {
211 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
220 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
221 const int nb = ((*n) + 1024 - 1) / 1024;
225 if (
nb > mma_red_s) {
227 if (mma_bufred !=
NULL) {
235 for (
int i = 0; i < (*m); i++) {
236 for (
int j = 0;
j < (*m);
j++) {
246 (
real*)
Hess, mma_bufred_d, 1, i +
j * (*m));
255 void* low,
void* upp,
void* alpha,
void* beta,
int* n) {
257 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
264void mma_dipsolvesub1_hip(
void* x,
void*
pjlambda,
void*
qjlambda,
void* low,
265 void* upp,
void* alpha,
void* beta,
int* n) {
267 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
274void mattrans_v_mul_hip(
void* output,
void* pij,
void* lambda,
int* m,
int* n) {
276 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
282void mma_gensub4_hip(
const void* x,
const void* low,
const void* upp,
283 const void* pij,
const void* qij,
284 const int* n,
const int* m,
void* bi) {
291 const int nb = (
N + 1023) / 1024;
294 if (
nb > mma_red_s) {
297 if (mma_bufred !=
nullptr) {
311 static_cast<const real*
>(x),
312 static_cast<const real*
>(low),
313 static_cast<const real*
>(upp),
314 static_cast<const real*
>(pij),
315 static_cast<const real*
>(qij),
318 for (
int i = 0; i <
M; ++i) {
320 temp, mma_bufred_d,
N,
M, i);
328 bi_d + i, mma_bufred_d,
sizeof(
real),
338void mma_gensub3_hip(
void* x,
void*
df0dx,
void*
dfdx,
void* low,
339 void* upp,
void* xmin,
void* xmax,
void* alpha,
340 void* beta,
void* p0j,
void* q0j,
void* pij,
341 void* qij,
int* n,
int* m) {
343 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
355void mma_gensub2_hip(
void* low,
void* upp,
void* x,
void* xold1,
357 real* asyincr,
int* n) {
359 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
369void mma_gensub1_hip(
void* low,
void* upp,
void* x,
void* xmin,
void* xmax,
370 real* asyinit,
int* n) {
372 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
382void hip_mma_max(
void* xsi,
void* x,
void* alpha,
int* n) {
384 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
394 void* pij,
void* qij,
int* n,
int* m) {
396 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
400 if (
nb > mma_red_s) {
402 if (mma_bufred !=
NULL) {
417 for (
int i = 0; i < (*m); i++) {
419 temp, mma_bufred_d, (*n), (*m), i);
436void hip_sub2cons2(
void* a,
void*
b,
void* c,
void* d,
real*
e,
int* n) {
438 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
447real hip_maxval(
void* a,
int* n) {
449 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
453 if (
nb > mma_red_s) {
455 if (mma_bufred !=
NULL) {
464 (
real*)a, mma_bufred_d, (*n));
475 return mma_bufred[0];
479void hip_delx(
void*
delx,
void* x,
void*
xlow,
void*
xupp,
void* pij,
480 void* qij,
void* p0j,
void* q0j,
void* alpha,
void* beta,
void* lambda,
483 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
493void hip_GG(
void*
GG,
void* x,
void*
xlow,
void*
xupp,
494 void* pij,
void* qij,
int* n,
int* m) {
496 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
505void hip_diagx(
void*
diagx,
void* x,
void* xsi,
void*
xlow,
void*
xupp,
506 void* p0j,
void* q0j,
void* pij,
void* qij,
void* alpha,
void* beta,
507 void* eta,
void* lambda,
int *n,
int *m) {
509 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
519void hip_bb(
void*
bb,
void*
GG,
void*
delx,
void*
diagx,
int *n,
int *m) {
521 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
522 const int nb = ((*n) + 1024 - 1)/ 1024;
527 if (
nb > mma_red_s) {
529 if (mma_bufred !=
NULL) {
537 for (
int i = 0; i < (*m); i++) {
554void hip_AA(
void*
AA,
void*
GG,
void*
diagx,
int *n,
int *m) {
556 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
557 const int nb = ((*n) + 1024 - 1)/ 1024;
562 if (
nb > mma_red_s) {
564 if (mma_bufred !=
NULL) {
572 for (
int i = 0; i < (*m); i++) {
573 for (
int j = 0;
j < (*m);
j++) {
583 (
real*)
AA, mma_bufred_d, 1, i +
j * (*m + 1));
594 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
602void hip_dxsi(
void*
dxsi,
void* xsi,
void*
dx,
void* x,
605 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
613void hip_deta(
void*
deta,
void* eta,
void*
dx,
void* x,
616 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
624void hip_rex(
void*
rex,
void* x,
void*
xlow,
void*
xupp,
void* pij,
625 void* p0j,
void* qij,
void* q0j,
void* lambda,
void* xsi,
void* eta,
628 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
637void hip_rey(
void*
rey,
void* c,
void* d,
void* y,
void* lambda,
void* mu,
int* n) {
639 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
649void hip_sub2cons(
void *a,
void *
b,
void *c,
real *d,
int *n) {
651 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
660real hip_norm(
void* a,
int* n) {
662 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
663 const int nb = ((*n) + 1024 - 1) / 1024;
666 if (
nb > mma_red_s) {
668 if (mma_bufred !=
NULL) {
677 (
real*)a, mma_bufred_d, (*n));
689 return mma_bufred[0];
693void hip_dely(
void*
dely,
void* c,
void* d,
void* y,
void* lambda,
696 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
706 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
707 const int nb = ((*n) + 1024 - 1) / 1024;
710 if (
nb > mma_red_s) {
712 if (mma_bufred !=
NULL) {
733 return mma_bufred[0];
737real hip_maxval3(
void* a,
void*
b,
void* c,
real*
cons,
int* n) {
739 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
740 const int nb = ((*n) + 1024 - 1) / 1024;
743 if (
nb > mma_red_s) {
745 if (mma_bufred !=
NULL) {
764 return mma_bufred[0];
768void hip_kkt_rex(
void*
rex,
void*
df0dx,
void*
dfdx,
void* xsi,
769 void* eta,
void* lambda,
int* n,
int* m) {
771 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
781void hip_maxcons(
void* a,
real*
b,
real* c,
void* d,
int* n) {
783 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
791real hip_lcsc2(
void *a,
void*
b,
int *n) {
793 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
794 const int nb = ((*n) + 1024 - 1) / 1024;
797 if (
nb > mma_red_s) {
799 if (mma_bufred !=
NULL) {
820 return mma_bufred[0];
824void hip_mpisum(
void *a,
int *n) {
825#ifdef HAVE_DEVICE_MPI
833void hip_add2inv2(
void* a,
void*
b,
real* c,
int* n) {
835 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
843void hip_max2(
void* a,
real*
b,
void* c,
real* d,
int* n) {
845 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
854 void* mu,
void* y,
real*
delz,
int* m) {
856 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
865void hip_updateAA(
void*
AA,
void*
globaltmp_mm,
void* s,
void* lambda,
866 void* d,
void* mu,
void* y,
void* a,
869 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
880 void* mu,
void* y,
int* n) {
882 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
__global__ void heaviside_mapping_apply_kernel(const T beta, const T eta, T *__restrict__ X_out_d, T *__restrict__ X_in_d, const int n)