Neko-TOP
A portable framework for high-order spectral element flow topology optimization.
Loading...
Searching...
No Matches
mma.hip
1/*
2 Copyright (c) 2021-2025, The Neko Authors
3 All rights reserved.
4
5 Redistribution and use in source and binary forms, with or without
6 modification, are permitted provided that the following conditions
7 are met:
8
9 * Redistributions of source code must retain the above copyright
10 notice, this list of conditions and the following disclaimer.
11
12 * Redistributions in binary form must reproduce the above
13 copyright notice, this list of conditions and the following
14 disclaimer in the documentation and/or other materials provided
15 with the distribution.
16
17 * Neither the name of the authors nor the names of its
18 contributors may be used to endorse or promote products derived
19 from this software without specific prior written permission.
20
21 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
24 FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
25 COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26 INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
27 BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
28 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
30 LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
31 ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32 POSSIBILITY OF SUCH DAMAGE.
33*/
34
35#include <hip/hip_runtime.h>
36#include "device/hip/check.h"
37#include "mma_kernel.h"
38#include <stdio.h>
39#include <stdlib.h>
40
41extern "C" {
42#include "math/bcknd/device/device_mpi_reduce.h"
43#include "math/bcknd/device/device_mpi_op.h"
44#include "device/device_config.h"
45
46int mma_red_s = 0;
47real * mma_bufred = NULL;
48real * mma_bufred_d = NULL;
50#include <hip/hip_runtime.h>
51
52void hip_Hess(void* Hess, void* hijx, void* Ljjxinv, int *n, int *m) {
53 const dim3 nthrds(1024, 1, 1);
54 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
55 const int nb = ((*n) + 1024 - 1) / 1024;
56 const hipStream_t stream = (hipStream_t) glb_cmd_queue;
57 hipStreamSynchronize(stream);
58
59 if (nb > mma_red_s) {
60 mma_red_s = nb;
61 if (mma_bufred != NULL) {
62 HIP_CHECK(hipHostFree(mma_bufred));
63 HIP_CHECK(hipFree(mma_bufred_d));
64 }
65 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
66 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
67 }
68
69 for (int i = 0; i < (*m); i++) {
70 for (int j = 0; j < (*m); j++) {
71 hipLaunchKernelGGL(mmasumHess_kernel<real>, nblcks, nthrds, 0, stream,
72 (real*)hijx, (real*)Ljjxinv, mma_bufred_d, (*n), (*m), i, j);
73 HIP_CHECK(hipGetLastError());
74
75 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
76 mma_bufred_d, nb);
77 HIP_CHECK(hipGetLastError());
78
79 hipLaunchKernelGGL(mma_copy_kernel, dim3(1), dim3(1), 0, stream,
80 (real*)Hess, mma_bufred_d, 1, i + j * (*m));
81 HIP_CHECK(hipGetLastError());
82
83 hipStreamSynchronize(stream);
84 }
85 }
86}
87
88void mma_Ljjxinv_hip(void* Ljjxinv, void* pjlambda, void* qjlambda, void* x,
89 void* low, void* upp, void* alpha, void* beta, int* n) {
90 const dim3 nthrds(1024, 1, 1);
91 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
92 hipLaunchKernelGGL(mma_Ljjxinv_kernel<real>, nblcks, nthrds, 0,
93 (hipStream_t)glb_cmd_queue, (real*)Ljjxinv, (real*)pjlambda, (real*)qjlambda,
94 (real*)x, (real*)low, (real*)upp, (real*)alpha, (real*)beta, *n);
95 HIP_CHECK(hipGetLastError());
96}
97
98void mma_dipsolvesub1_hip(void* x, void* pjlambda, void* qjlambda, void* low,
99 void* upp, void* alpha, void* beta, int* n) {
100 const dim3 nthrds(1024, 1, 1);
101 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
102 hipLaunchKernelGGL(mma_dipsolvesub1_kernel<real>, nblcks, nthrds, 0,
103 (hipStream_t)glb_cmd_queue, (real*)x, (real*)pjlambda, (real*)qjlambda,
104 (real*)low, (real*)upp, (real*)alpha, (real*)beta, *n);
105 HIP_CHECK(hipGetLastError());
106}
107
108void mattrans_v_mul_hip(void* output, void* pij, void* lambda, int* m, int* n) {
109 const dim3 nthrds(1024, 1, 1);
110 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
111 hipLaunchKernelGGL(mattrans_v_mul_kernel<real>, nblcks, nthrds, 0,
112 (hipStream_t)glb_cmd_queue, (real*)output, (real*)pij, (real*)lambda, *m, *n);
113 HIP_CHECK(hipGetLastError());
114}
115
116void mma_gensub4_hip(void* x, void* low, void* upp, void* pij, void* qij,
117 int* n, int* m, void* bi) {
118 const dim3 nthrds(1024, 1, 1);
119 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
120 const int nb = ((*n) + 1024 - 1) / 1024;
121 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
122
123 if (nb > mma_red_s) {
124 mma_red_s = nb;
125 if (mma_bufred != NULL) {
126 HIP_CHECK(hipFreeHost(mma_bufred));
127 HIP_CHECK(hipFree(mma_bufred_d));
128 }
129 HIP_CHECK(hipHostMalloc(&mma_bufred,
130 nb * sizeof(real)));
131 HIP_CHECK(hipMalloc(&mma_bufred_d,
132 nb * sizeof(real)));
133 }
134
135 real* temp;
136 real* bi_d = (real*)bi;
137 hipMalloc(&temp, (*m) * (*n) * sizeof(real));
138
139 hipLaunchKernelGGL(mma_sub4_kernel<real>, nblcks, nthrds, 0, stream,
140 (real*)x, (real*)low, (real*)upp, (real*)pij, (real*)qij,
141 temp, *n, *m);
142
143 for (int i = 0; i < (*m); i++) {
144 hipLaunchKernelGGL(mmasum_kernel<real>, nblcks, nthrds, 0, stream,
145 temp, mma_bufred_d, (*n), (*m), i);
146 HIP_CHECK(hipGetLastError());
147
148 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
149 mma_bufred_d, nb);
150 HIP_CHECK(hipGetLastError());
151
152 HIP_CHECK(hipMemcpyAsync(
153 bi_d + i, mma_bufred_d, sizeof(real),
154 hipMemcpyDeviceToDevice, stream));
155
156 hipStreamSynchronize(stream);
157 }
158
159 hipFree(temp);
160}
161
162void mma_gensub3_hip(void* x, void* df0dx, void* dfdx, void* low,
163 void* upp, void* xmin, void* xmax, void* alpha,
164 void* beta, void* p0j, void* q0j, void* pij,
165 void* qij, int* n, int* m) {
166 const dim3 nthrds(1024, 1, 1);
167 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
168
169 hipLaunchKernelGGL(mma_sub3_kernel<real>, nblcks, nthrds, 0,
170 (hipStream_t)glb_cmd_queue,
171 (real*)x, (real*)df0dx, (real*)dfdx, (real*)low,
172 (real*)upp, (real*)xmin, (real*)xmax, (real*)alpha,
173 (real*)beta, (real*)p0j, (real*)q0j, (real*)pij,
174 (real*)qij, *n, *m);
175
176 HIP_CHECK(hipGetLastError());
177}
178
179void mma_gensub2_hip(void* low, void* upp, void* x, void* xold1,
180 void* xold2, void* xdiff, real* asydecr,
181 real* asyincr, int* n) {
182 const dim3 nthrds(1024, 1, 1);
183 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
184
185 hipLaunchKernelGGL(mma_sub2_kernel<real>, nblcks, nthrds, 0,
186 (hipStream_t)glb_cmd_queue,
187 (real*)low, (real*)upp, (real*)x, (real*)xold1,
188 (real*)xold2, (real*)xdiff, *asydecr, *asyincr, *n);
189
190 HIP_CHECK(hipGetLastError());
191}
192
193void mma_gensub1_hip(void* low, void* upp, void* x, void* xmin, void* xmax,
194 real* asyinit, int* n) {
195 const dim3 nthrds(1024, 1, 1);
196 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
197
198 hipLaunchKernelGGL(mma_sub1_kernel<real>, nblcks, nthrds, 0,
199 (hipStream_t)glb_cmd_queue,
200 (real*)low, (real*)upp, (real*)x, (real*)xmin, (real*)xmax,
201 *asyinit, *n);
202
203 HIP_CHECK(hipGetLastError());
204}
205
206void hip_mma_max(void* xsi, void* x, void* alpha, int* n) {
207 const dim3 nthrds(1024, 1, 1);
208 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
209
210 hipLaunchKernelGGL(mma_max2_kernel<real>, nblcks, nthrds, 0,
211 (hipStream_t)glb_cmd_queue,
212 (real*)xsi, (real*)x, (real*)alpha, *n);
213
214 HIP_CHECK(hipGetLastError());
215}
216
217void hip_relambda(void* relambda, void* x, void* xupp, void* xlow,
218 void* pij, void* qij, int* n, int* m) {
219 const dim3 nthrds(1024, 1, 1);
220 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
221 const int nb = nblcks.x;
222 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
223
224 if (nb > mma_red_s) {
225 mma_red_s = nb;
226 if (mma_bufred != NULL) {
227 HIP_CHECK(hipHostFree(mma_bufred));
228 HIP_CHECK(hipFree(mma_bufred_d));
229 }
230 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
231 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
232 }
233
234 real* temp;
235 hipMalloc(&temp, (*n) * (*m) * sizeof(real));
236
237 hipLaunchKernelGGL(relambda_kernel<real>, nblcks, nthrds, 0, stream,
238 temp, (real*)x, (real*)xupp, (real*)xlow,
239 (real*)pij, (real*)qij, *n, *m);
240
241 for (int i = 0; i < (*m); i++) {
242 hipLaunchKernelGGL(mmasum_kernel<real>, nblcks, nthrds, 0, stream,
243 temp, mma_bufred_d, (*n), (*m), i);
244 HIP_CHECK(hipGetLastError());
245
246 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0,
247 stream, mma_bufred_d, nb);
248 HIP_CHECK(hipGetLastError());
249
250 hipLaunchKernelGGL(mma_copy_kernel, dim3(1), dim3(1), 0, stream,
251 (real*)relambda, mma_bufred_d, 1, i);
252 HIP_CHECK(hipGetLastError());
253
254 hipStreamSynchronize(stream);
255 }
256
257 hipFree(temp);
258}
259
260void hip_sub2cons2(void* a, void* b, void* c, void* d, real* e, int* n) {
261 const dim3 nthrds(1024, 1, 1);
262 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
263
264 hipLaunchKernelGGL(sub2cons2_kernel<real>, nblcks, nthrds, 0,
265 (hipStream_t)glb_cmd_queue,
266 (real*)a, (real*)b, (real*)c, (real*)d, *e, *n);
267
268 HIP_CHECK(hipGetLastError());
269}
270
271real hip_maxval(void* a, int* n) {
272 const dim3 nthrds(1024, 1, 1);
273 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
274 const int nb = nblcks.x;
275 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
276
277 if (nb > mma_red_s) {
278 mma_red_s = nb;
279 if (mma_bufred != NULL) {
280 HIP_CHECK(hipHostFree(mma_bufred));
281 HIP_CHECK(hipFree(mma_bufred_d));
282 }
283 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
284 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
285 }
286
287 hipLaunchKernelGGL(maxval_kernel<real>, nblcks, nthrds, 0, stream,
288 (real*)a, mma_bufred_d, (*n));
289 HIP_CHECK(hipGetLastError());
290
291 hipLaunchKernelGGL(max_reduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
292 mma_bufred_d, nb);
293 HIP_CHECK(hipGetLastError());
294
295 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d, sizeof(real),
296 hipMemcpyDeviceToHost, stream));
297 hipStreamSynchronize(stream);
298
299 return mma_bufred[0];
300}
301
302
303void hip_delx(void* delx, void* x, void* xlow, void* xupp, void* pij,
304 void* qij, void* p0j, void* q0j, void* alpha, void* beta, void* lambda,
305 real* epsi, int* n, int* m) {
306 const dim3 nthrds(1024, 1, 1);
307 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
308
309 hipLaunchKernelGGL(delx_kernel<real>, nblcks, nthrds, 0,
310 (hipStream_t)glb_cmd_queue,
311 (real*)delx, (real*)x, (real*)xlow, (real*)xupp, (real*)pij,
312 (real*)qij, (real*)p0j, (real*)q0j, (real*)alpha, (real*)beta,
313 (real*)lambda, *epsi, *n, *m);
314 HIP_CHECK(hipGetLastError());
315}
316
317void hip_GG(void* GG, void* x, void* xlow, void* xupp,
318 void* pij, void* qij, int* n, int* m) {
319 const dim3 nthrds(1024, 1, 1);
320 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
321
322 hipLaunchKernelGGL(GG_kernel<real>, nblcks, nthrds, 0,
323 (hipStream_t)glb_cmd_queue,
324 (real*)GG, (real*)x, (real*)xlow, (real*)xupp, (real*)pij,
325 (real*)qij, *n, *m);
326 HIP_CHECK(hipGetLastError());
327}
328
329void hip_diagx(void* diagx, void* x, void* xsi, void* xlow, void* xupp,
330 void* p0j, void* q0j, void* pij, void* qij, void* alpha, void* beta,
331 void* eta, void* lambda, int *n, int *m) {
332 const dim3 nthrds(1024, 1, 1);
333 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
334
335 hipLaunchKernelGGL(diagx_kernel<real>, nblcks, nthrds, 0,
336 (hipStream_t)glb_cmd_queue,
337 (real*)diagx, (real*)x, (real*)xsi, (real*)xlow, (real*)xupp,
338 (real*)p0j, (real*)q0j, (real*)pij, (real*)qij, (real*)alpha,
339 (real*)beta, (real*)eta, (real*)lambda, *n, *m);
340 HIP_CHECK(hipGetLastError());
341}
342
343void hip_bb(void* bb, void* GG, void* delx, void* diagx, int *n, int *m) {
344 const dim3 nthrds(1024, 1, 1);
345 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
346 const int nb = ((*n) + 1024 - 1)/ 1024;
347 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
348
349 hipStreamSynchronize(stream);
350
351 if (nb > mma_red_s) {
352 mma_red_s = nb;
353 if (mma_bufred != NULL) {
354 HIP_CHECK(hipHostFree(mma_bufred));
355 HIP_CHECK(hipFree(mma_bufred_d));
356 }
357 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
358 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
359 }
360
361 for (int i = 0; i < (*m); i++) {
362 hipLaunchKernelGGL(mmasumbb_kernel<real>, nblcks, nthrds, 0, stream,
363 (real*)GG, (real*)delx, (real*)diagx, mma_bufred_d, *n, *m, i);
364 HIP_CHECK(hipGetLastError());
365
366 hipLaunchKernelGGL(mmareduce_kernel<real>, 1, 1024, 0, stream,
367 mma_bufred_d, nb);
368 HIP_CHECK(hipGetLastError());
369
370 hipLaunchKernelGGL(mma_copy_kernel, 1, 1, 0, stream, (real*)bb,
371 mma_bufred_d, 1, i);
372 HIP_CHECK(hipGetLastError());
373
374 hipStreamSynchronize(stream);
375 }
376}
377
378void hip_AA(void* AA, void* GG, void* diagx, int *n, int *m) {
379 const dim3 nthrds(1024, 1, 1);
380 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
381 const int nb = ((*n) + 1024 - 1)/ 1024;
382 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
383
384 hipStreamSynchronize(stream);
385
386 if (nb > mma_red_s) {
387 mma_red_s = nb;
388 if (mma_bufred != NULL) {
389 HIP_CHECK(hipHostFree(mma_bufred));
390 HIP_CHECK(hipFree(mma_bufred_d));
391 }
392 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
393 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
394 }
395
396 for (int i = 0; i < (*m); i++) {
397 for (int j = 0; j < (*m); j++) {
398 hipLaunchKernelGGL(mmasumAA_kernel<real>, nblcks, nthrds, 0, stream,
399 (real*)GG, (real*)diagx, mma_bufred_d, *n, *m, i, j);
400 HIP_CHECK(hipGetLastError());
401
402 hipLaunchKernelGGL(mmareduce_kernel<real>, 1, 1024, 0, stream,
403 mma_bufred_d, nb);
404 HIP_CHECK(hipGetLastError());
405
406 hipLaunchKernelGGL(mma_copy_kernel, 1, 1, 0, stream,
407 (real*)AA, mma_bufred_d, 1, i + j * (*m + 1));
408 HIP_CHECK(hipGetLastError());
409
410 hipStreamSynchronize(stream);
411 }
412 }
413}
414
415void hip_dx(void* dx, void* delx, void* diagx, void* GG, void* dlambda,
416 int* n, int* m) {
417 const dim3 nthrds(1024, 1, 1);
418 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
419
420 hipLaunchKernelGGL(dx_kernel<real>, nblcks, nthrds, 0,
421 (hipStream_t)glb_cmd_queue,
422 (real*)dx, (real*)delx, (real*)diagx, (real*)GG, (real*)dlambda, *n, *m);
423 HIP_CHECK(hipGetLastError());
424}
425
426void hip_dxsi(void* dxsi, void* xsi, void* dx, void* x,
427 void* alpha, real* epsi, int* n) {
428 const dim3 nthrds(1024, 1, 1);
429 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
430
431 hipLaunchKernelGGL(dxsi_kernel<real>, nblcks, nthrds, 0,
432 (hipStream_t)glb_cmd_queue,
433 (real*)dxsi, (real*)xsi, (real*)dx, (real*)x, (real*)alpha, *epsi, *n);
434 HIP_CHECK(hipGetLastError());
435}
436
437void hip_deta(void* deta, void* eta, void* dx, void* x,
438 void* beta, real* epsi, int* n) {
439 const dim3 nthrds(1024, 1, 1);
440 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
441
442 hipLaunchKernelGGL(deta_kernel<real>, nblcks, nthrds, 0,
443 (hipStream_t)glb_cmd_queue,
444 (real*)deta, (real*)eta, (real*)dx, (real*)x, (real*)beta, *epsi, *n);
445 HIP_CHECK(hipGetLastError());
446}
447
448void hip_rex(void* rex, void* x, void* xlow, void* xupp, void* pij,
449 void* p0j, void* qij, void* q0j, void* lambda, void* xsi, void* eta,
450 int* n, int* m) {
451 const dim3 nthrds(1024, 1, 1);
452 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
453
454 hipLaunchKernelGGL(RexCalculation_kernel<real>, nblcks, nthrds, 0,
455 (hipStream_t)glb_cmd_queue,
456 (real*)rex, (real*)x, (real*)xlow, (real*)xupp, (real*)pij, (real*)p0j,
457 (real*)qij, (real*)q0j, (real*)lambda, (real*)xsi, (real*)eta, *n, *m);
458 HIP_CHECK(hipGetLastError());
459}
460
461void hip_rey(void* rey, void* c, void* d, void* y, void* lambda, void* mu, int* n) {
462 const dim3 nthrds(1024, 1, 1);
463 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
464
465 hipLaunchKernelGGL(rey_calculation_kernel<real>, nblcks, nthrds, 0,
466 (hipStream_t)glb_cmd_queue,
467 (real*)rey, (real*)c, (real*)d, (real*)y, (real*)lambda, (real*)mu, *n);
468 HIP_CHECK(hipGetLastError());
469}
470
471
473void hip_sub2cons(void *a, void *b, void *c, real *d, int *n) {
474 const dim3 nthrds(1024, 1, 1);
475 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
476 hipLaunchKernelGGL(sub2cons_kernel<real>, nblcks, nthrds, 0,
477 (hipStream_t)glb_cmd_queue,
478 (real *)a, (real *)b, (real *)c, *d, *n);
479 HIP_CHECK(hipGetLastError());
480}
481
482
484real hip_norm(void* a, int* n) {
485 const dim3 nthrds(1024, 1, 1);
486 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
487 const int nb = ((*n) + 1024 - 1) / 1024;
488 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
489
490 if (nb > mma_red_s) {
491 mma_red_s = nb;
492 if (mma_bufred != NULL) {
493 HIP_CHECK(hipFreeHost(mma_bufred));
494 HIP_CHECK(hipFree(mma_bufred_d));
495 }
496 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
497 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
498 }
499
500 hipLaunchKernelGGL(norm_kernel<real>, nblcks, nthrds, 0, stream,
501 (real*)a, mma_bufred_d, (*n));
502 HIP_CHECK(hipGetLastError());
503
504 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
505 mma_bufred_d, nb);
506 HIP_CHECK(hipGetLastError());
507
508 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d, sizeof(real),
509 hipMemcpyDeviceToHost, stream));
510
511 hipStreamSynchronize(stream);
512
513 return mma_bufred[0];
514}
515
516
517void hip_dely(void* dely, void* c, void* d, void* y, void* lambda,
518 real* epsi, int* n) {
519 const dim3 nthrds(1024, 1, 1);
520 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
521 hipLaunchKernelGGL(dely_kernel<real>, nblcks, nthrds, 0,
522 (hipStream_t)glb_cmd_queue,
523 (real*)dely, (real*)c, (real*)d, (real*)y, (real*)lambda, *epsi, *n);
524 HIP_CHECK(hipGetLastError());
525}
526
527
528real hip_maxval2(void* a, void* b, real* cons, int* n) {
529 const dim3 nthrds(1024, 1, 1);
530 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
531 const int nb = ((*n) + 1024 - 1) / 1024;
532 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
533
534 if (nb > mma_red_s) {
535 mma_red_s = nb;
536 if (mma_bufred != NULL) {
537 HIP_CHECK(hipFreeHost(mma_bufred));
538 HIP_CHECK(hipFree(mma_bufred_d));
539 }
540 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
541 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
542 }
543
544 hipLaunchKernelGGL(maxval2_kernel<real>, nblcks, nthrds, 0, stream,
545 (real*)a, (real*)b, mma_bufred_d, *cons, *n);
546 HIP_CHECK(hipGetLastError());
547
548 hipLaunchKernelGGL(max_reduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
549 mma_bufred_d, nb);
550 HIP_CHECK(hipGetLastError());
551
552 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d, sizeof(real),
553 hipMemcpyDeviceToHost, stream));
554
555 hipStreamSynchronize(stream);
556
557 return mma_bufred[0];
558}
559
560
561real hip_maxval3(void* a, void* b, void* c, real* cons, int* n) {
562 const dim3 nthrds(1024, 1, 1);
563 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
564 const int nb = ((*n) + 1024 - 1) / 1024;
565 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
566
567 if (nb > mma_red_s) {
568 mma_red_s = nb;
569 if (mma_bufred != NULL) {
570 HIP_CHECK(hipFreeHost(mma_bufred));
571 HIP_CHECK(hipFree(mma_bufred_d));
572 }
573 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
574 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
575 }
576
577 hipLaunchKernelGGL(maxval3_kernel<real>, nblcks, nthrds, 0, stream,
578 (real*)a, (real*)b, (real*)c, mma_bufred_d, *cons, *n);
579 hipLaunchKernelGGL(max_reduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
580 mma_bufred_d, nb);
581 HIP_CHECK(hipGetLastError());
582
583 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d, sizeof(real),
584 hipMemcpyDeviceToHost, stream));
585
586 hipStreamSynchronize(stream);
587
588 return mma_bufred[0];
589}
590
591
592void hip_kkt_rex(void* rex, void* df0dx, void* dfdx, void* xsi,
593 void* eta, void* lambda, int* n, int* m) {
594 const dim3 nthrds(1024, 1, 1);
595 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
596 hipLaunchKernelGGL(kkt_rex_kernel<real>, nblcks, nthrds, 0,
597 (hipStream_t)glb_cmd_queue,
598 (real*)rex, (real*)df0dx, (real*)dfdx, (real*)xsi,
599 (real*)eta, (real*)lambda, *n, *m);
600 HIP_CHECK(hipGetLastError());
601}
602
603
605void hip_maxcons(void* a, real* b, real* c, void* d, int* n) {
606 const dim3 nthrds(1024, 1, 1);
607 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
608 hipLaunchKernelGGL(maxcons_kernel<real>, nblcks, nthrds, 0,
609 (hipStream_t)glb_cmd_queue,
610 (real*)a, *b, *c, (real*)d, *n);
611 HIP_CHECK(hipGetLastError());
612}
613
614
615real hip_lcsc2(void *a, void*b, int *n) {
616 const dim3 nthrds(1024, 1, 1);
617 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
618 const int nb = ((*n) + 1024 - 1) / 1024;
619 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
620
621 if (nb > mma_red_s) {
622 mma_red_s = nb;
623 if (mma_bufred != NULL) {
624 HIP_CHECK(hipFreeHost(mma_bufred));
625 HIP_CHECK(hipFree(mma_bufred_d));
626 }
627 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
628 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
629 }
630
631 hipLaunchKernelGGL(glsc2_kernel<real>, nblcks, nthrds, 0, stream,
632 (real*)a, (real*)b, mma_bufred_d, (*n));
633 HIP_CHECK(hipGetLastError());
634
635 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
636 mma_bufred_d, nb);
637 HIP_CHECK(hipGetLastError());
638
639 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d, sizeof(real),
640 hipMemcpyDeviceToHost, stream));
641
642 hipStreamSynchronize(stream);
643
644 return mma_bufred[0];
645}
646
647
/**
 * In-place sum-allreduce of the buffer `a` (`*n` entries of `real`) via
 * `device_mpi_allreduce_inplace`.
 *
 * The stream held in `glb_cmd_queue` is synchronized first so that work
 * queued on it has finished before the buffer is handed to MPI. The
 * function is a no-op when `HAVE_DEVICE_MPI` is not defined.
 *
 * @param a  buffer to reduce in place, treated as `real*`.
 * @param n  pointer to the number of entries.
 */
void hip_mpisum(void *a, int *n) {
#ifdef HAVE_DEVICE_MPI
  // `stream` was previously used here without being declared, which broke
  // the build whenever HAVE_DEVICE_MPI was defined.
  const hipStream_t stream = (hipStream_t) glb_cmd_queue;
  real* temp = (real*) a;

  HIP_CHECK(hipStreamSynchronize(stream));
  device_mpi_allreduce_inplace(temp, *n, sizeof(real), DEVICE_MPI_SUM);
#endif
}
655
656
657void hip_add2inv2(void* a, void* b, real* c, int* n) {
658 const dim3 nthrds(1024, 1, 1);
659 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
660 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
661
662 hipLaunchKernelGGL(add2inv2_kernel<real>, nblcks, nthrds, 0, stream,
663 (real*)a, (real*)b, *c, *n);
664 HIP_CHECK(hipGetLastError());
665}
666
667void hip_max2(void* a, real* b, void* c, real* d, int* n) {
668 const dim3 nthrds(1024, 1, 1);
669 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
670 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
671
672 hipLaunchKernelGGL(max2_kernel<real>, nblcks, nthrds, 0, stream,
673 (real*)a, *b, (real*)c, *d, *n);
674 HIP_CHECK(hipGetLastError());
675}
676
677void hip_updatebb(void* bb, void* dellambda, void* dely, void* d,
678 void* mu, void* y, real* delz, int* m) {
679 const dim3 nthrds(1024, 1, 1);
680 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
681 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
682
683 hipLaunchKernelGGL(updatebb_kernel<real>, nblcks, nthrds, 0, stream,
684 (real*)bb, (real*)dellambda, (real*)dely, (real*)d,
685 (real*)mu, (real*)y, *delz, *m);
686 HIP_CHECK(hipGetLastError());
687}
688
689void hip_updateAA(void* AA, void* globaltmp_mm, void* s, void* lambda,
690 void* d, void* mu, void* y, void* a,
691 real* zeta, real* z, int* m) {
692 const dim3 nthrds(1024, 1, 1);
693 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
694 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
695
696 hipLaunchKernelGGL(updateAA_kernel<real>, nblcks, nthrds, 0, stream,
697 (real*)AA, (real*)globaltmp_mm, (real*)s,
698 (real*)lambda, (real*)d, (real*)mu,
699 (real*)y, (real*)a, *zeta, *z, *m);
700 HIP_CHECK(hipGetLastError());
701}
702
703void hip_dy(void* dy, void* dely, void* dlambda, void* d,
704 void* mu, void* y, int* n) {
705 const dim3 nthrds(1024, 1, 1);
706 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
707 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
708
709 hipLaunchKernelGGL(dy_kernel<real>, nblcks, nthrds, 0, stream,
710 (real*)dy, (real*)dely, (real*)dlambda, (real*)d,
711 (real*)mu, (real*)y, *n);
712 HIP_CHECK(hipGetLastError());
713}
714
715}