Neko-TOP
A portable framework for high-order spectral element flow topology optimization.
Loading...
Searching...
No Matches
mma.hip
1/*
2 Copyright (c) 2021-2025, The Neko Authors
3 All rights reserved.
4
5 Redistribution and use in source and binary forms, with or without
6 modification, are permitted provided that the following conditions
7 are met:
8
9 * Redistributions of source code must retain the above copyright
10 notice, this list of conditions and the following disclaimer.
11
12 * Redistributions in binary form must reproduce the above
13 copyright notice, this list of conditions and the following
14 disclaimer in the documentation and/or other materials provided
15 with the distribution.
16
17 * Neither the name of the authors nor the names of its
18 contributors may be used to endorse or promote products derived
19 from this software without specific prior written permission.
20
21 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
24 FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
25 COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26 INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
27 BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
28 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
30 LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
31 ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32 POSSIBILITY OF SUCH DAMAGE.
33*/
34
35// System includes
36#include <stdio.h>
37#include <stdlib.h>
38
39// Device includes
40#include <hip/hip_runtime.h>
41
42// Neko includes
43#include <neko/device/device_config.h>
44#include <neko/device/hip/check.h>
45#include <neko/math/bcknd/device/device_mpi_op.h>
46
47// Local includes
48#include "mma_kernel.h"
49
50extern "C" {
51
// Scratch storage shared by every reduction wrapper in this file.
//   mma_red_s    - current capacity, in elements of type real, of both
//                  buffers below. It only ever grows (the wrappers
//                  reallocate when they need more slots than this).
//   mma_bufred   - pinned host buffer (hipHostMalloc) used to read back a
//                  reduced scalar.
//   mma_bufred_d - device buffer (hipMalloc) sized to the number of thread
//                  blocks of the first-stage reduction kernel.
// NOTE(review): these globals are not protected by any lock; calls from
// several host threads at once would race on them.
int mma_red_s = 0;
real * mma_bufred = NULL;
real * mma_bufred_d = NULL;
55
56void hip_Hess(void* Hess, void* hijx, void* Ljjxinv, int *n, int *m) {
57 const dim3 nthrds(1024, 1, 1);
58 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
59 const int nb = ((*n) + 1024 - 1) / 1024;
60 const hipStream_t stream = (hipStream_t) glb_cmd_queue;
61 hipStreamSynchronize(stream);
62
63 if (nb > mma_red_s) {
64 mma_red_s = nb;
65 if (mma_bufred != NULL) {
66 HIP_CHECK(hipHostFree(mma_bufred));
67 HIP_CHECK(hipFree(mma_bufred_d));
68 }
69 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
70 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
71 }
72
73 for (int i = 0; i < (*m); i++) {
74 for (int j = 0; j < (*m); j++) {
75 hipLaunchKernelGGL(mmasumHess_kernel<real>, nblcks, nthrds, 0, stream,
76 (real*)hijx, (real*)Ljjxinv, mma_bufred_d, (*n), (*m), i, j);
77 HIP_CHECK(hipGetLastError());
78
79 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
80 mma_bufred_d, nb);
81 HIP_CHECK(hipGetLastError());
82
83 hipLaunchKernelGGL(mma_copy_kernel, dim3(1), dim3(1), 0, stream,
84 (real*)Hess, mma_bufred_d, 1, i + j * (*m));
85 HIP_CHECK(hipGetLastError());
86
87 hipStreamSynchronize(stream);
88 }
89 }
90}
91
92void mma_Ljjxinv_hip(void* Ljjxinv, void* pjlambda, void* qjlambda, void* x,
93 void* low, void* upp, void* alpha, void* beta, int* n) {
94 const dim3 nthrds(1024, 1, 1);
95 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
96 hipLaunchKernelGGL(mma_Ljjxinv_kernel<real>, nblcks, nthrds, 0,
97 (hipStream_t)glb_cmd_queue, (real*)Ljjxinv, (real*)pjlambda, (real*)qjlambda,
98 (real*)x, (real*)low, (real*)upp, (real*)alpha, (real*)beta, *n);
99 HIP_CHECK(hipGetLastError());
100}
101
102void mma_dipsolvesub1_hip(void* x, void* pjlambda, void* qjlambda, void* low,
103 void* upp, void* alpha, void* beta, int* n) {
104 const dim3 nthrds(1024, 1, 1);
105 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
106 hipLaunchKernelGGL(mma_dipsolvesub1_kernel<real>, nblcks, nthrds, 0,
107 (hipStream_t)glb_cmd_queue, (real*)x, (real*)pjlambda, (real*)qjlambda,
108 (real*)low, (real*)upp, (real*)alpha, (real*)beta, *n);
109 HIP_CHECK(hipGetLastError());
110}
111
112void mattrans_v_mul_hip(void* output, void* pij, void* lambda, int* m, int* n) {
113 const dim3 nthrds(1024, 1, 1);
114 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
115 hipLaunchKernelGGL(mattrans_v_mul_kernel<real>, nblcks, nthrds, 0,
116 (hipStream_t)glb_cmd_queue, (real*)output, (real*)pij, (real*)lambda, *m, *n);
117 HIP_CHECK(hipGetLastError());
118}
119
120void mma_gensub4_hip(void* x, void* low, void* upp, void* pij, void* qij,
121 int* n, int* m, void* bi) {
122 const dim3 nthrds(1024, 1, 1);
123 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
124 const int nb = ((*n) + 1024 - 1) / 1024;
125 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
126
127 if (nb > mma_red_s) {
128 mma_red_s = nb;
129 if (mma_bufred != NULL) {
130 HIP_CHECK(hipFreeHost(mma_bufred));
131 HIP_CHECK(hipFree(mma_bufred_d));
132 }
133 HIP_CHECK(hipHostMalloc(&mma_bufred,
134 nb * sizeof(real)));
135 HIP_CHECK(hipMalloc(&mma_bufred_d,
136 nb * sizeof(real)));
137 }
138
139 real* temp;
140 real* bi_d = (real*)bi;
141 hipMalloc(&temp, (*m) * (*n) * sizeof(real));
142
143 hipLaunchKernelGGL(mma_sub4_kernel<real>, nblcks, nthrds, 0, stream,
144 (real*)x, (real*)low, (real*)upp, (real*)pij, (real*)qij,
145 temp, *n, *m);
146
147 for (int i = 0; i < (*m); i++) {
148 hipLaunchKernelGGL(mmasum_kernel<real>, nblcks, nthrds, 0, stream,
149 temp, mma_bufred_d, (*n), (*m), i);
150 HIP_CHECK(hipGetLastError());
151
152 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
153 mma_bufred_d, nb);
154 HIP_CHECK(hipGetLastError());
155
156 HIP_CHECK(hipMemcpyAsync(
157 bi_d + i, mma_bufred_d, sizeof(real),
158 hipMemcpyDeviceToDevice, stream));
159
160 hipStreamSynchronize(stream);
161 }
162
163 hipFree(temp);
164}
165
166void mma_gensub3_hip(void* x, void* df0dx, void* dfdx, void* low,
167 void* upp, void* xmin, void* xmax, void* alpha,
168 void* beta, void* p0j, void* q0j, void* pij,
169 void* qij, int* n, int* m) {
170 const dim3 nthrds(1024, 1, 1);
171 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
172
173 hipLaunchKernelGGL(mma_sub3_kernel<real>, nblcks, nthrds, 0,
174 (hipStream_t)glb_cmd_queue,
175 (real*)x, (real*)df0dx, (real*)dfdx, (real*)low,
176 (real*)upp, (real*)xmin, (real*)xmax, (real*)alpha,
177 (real*)beta, (real*)p0j, (real*)q0j, (real*)pij,
178 (real*)qij, *n, *m);
179
180 HIP_CHECK(hipGetLastError());
181}
182
183void mma_gensub2_hip(void* low, void* upp, void* x, void* xold1,
184 void* xold2, void* xdiff, real* asydecr,
185 real* asyincr, int* n) {
186 const dim3 nthrds(1024, 1, 1);
187 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
188
189 hipLaunchKernelGGL(mma_sub2_kernel<real>, nblcks, nthrds, 0,
190 (hipStream_t)glb_cmd_queue,
191 (real*)low, (real*)upp, (real*)x, (real*)xold1,
192 (real*)xold2, (real*)xdiff, *asydecr, *asyincr, *n);
193
194 HIP_CHECK(hipGetLastError());
195}
196
197void mma_gensub1_hip(void* low, void* upp, void* x, void* xmin, void* xmax,
198 real* asyinit, int* n) {
199 const dim3 nthrds(1024, 1, 1);
200 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
201
202 hipLaunchKernelGGL(mma_sub1_kernel<real>, nblcks, nthrds, 0,
203 (hipStream_t)glb_cmd_queue,
204 (real*)low, (real*)upp, (real*)x, (real*)xmin, (real*)xmax,
205 *asyinit, *n);
206
207 HIP_CHECK(hipGetLastError());
208}
209
210void hip_mma_max(void* xsi, void* x, void* alpha, int* n) {
211 const dim3 nthrds(1024, 1, 1);
212 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
213
214 hipLaunchKernelGGL(mma_max2_kernel<real>, nblcks, nthrds, 0,
215 (hipStream_t)glb_cmd_queue,
216 (real*)xsi, (real*)x, (real*)alpha, *n);
217
218 HIP_CHECK(hipGetLastError());
219}
220
221void hip_relambda(void* relambda, void* x, void* xupp, void* xlow,
222 void* pij, void* qij, int* n, int* m) {
223 const dim3 nthrds(1024, 1, 1);
224 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
225 const int nb = nblcks.x;
226 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
227
228 if (nb > mma_red_s) {
229 mma_red_s = nb;
230 if (mma_bufred != NULL) {
231 HIP_CHECK(hipHostFree(mma_bufred));
232 HIP_CHECK(hipFree(mma_bufred_d));
233 }
234 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
235 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
236 }
237
238 real* temp;
239 hipMalloc(&temp, (*n) * (*m) * sizeof(real));
240
241 hipLaunchKernelGGL(relambda_kernel<real>, nblcks, nthrds, 0, stream,
242 temp, (real*)x, (real*)xupp, (real*)xlow,
243 (real*)pij, (real*)qij, *n, *m);
244
245 for (int i = 0; i < (*m); i++) {
246 hipLaunchKernelGGL(mmasum_kernel<real>, nblcks, nthrds, 0, stream,
247 temp, mma_bufred_d, (*n), (*m), i);
248 HIP_CHECK(hipGetLastError());
249
250 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0,
251 stream, mma_bufred_d, nb);
252 HIP_CHECK(hipGetLastError());
253
254 hipLaunchKernelGGL(mma_copy_kernel, dim3(1), dim3(1), 0, stream,
255 (real*)relambda, mma_bufred_d, 1, i);
256 HIP_CHECK(hipGetLastError());
257
258 hipStreamSynchronize(stream);
259 }
260
261 hipFree(temp);
262}
263
264void hip_sub2cons2(void* a, void* b, void* c, void* d, real* e, int* n) {
265 const dim3 nthrds(1024, 1, 1);
266 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
267
268 hipLaunchKernelGGL(sub2cons2_kernel<real>, nblcks, nthrds, 0,
269 (hipStream_t)glb_cmd_queue,
270 (real*)a, (real*)b, (real*)c, (real*)d, *e, *n);
271
272 HIP_CHECK(hipGetLastError());
273}
274
275real hip_maxval(void* a, int* n) {
276 const dim3 nthrds(1024, 1, 1);
277 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
278 const int nb = nblcks.x;
279 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
280
281 if (nb > mma_red_s) {
282 mma_red_s = nb;
283 if (mma_bufred != NULL) {
284 HIP_CHECK(hipHostFree(mma_bufred));
285 HIP_CHECK(hipFree(mma_bufred_d));
286 }
287 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
288 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
289 }
290
291 hipLaunchKernelGGL(maxval_kernel<real>, nblcks, nthrds, 0, stream,
292 (real*)a, mma_bufred_d, (*n));
293 HIP_CHECK(hipGetLastError());
294
295 hipLaunchKernelGGL(max_reduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
296 mma_bufred_d, nb);
297 HIP_CHECK(hipGetLastError());
298
299 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d, sizeof(real),
300 hipMemcpyDeviceToHost, stream));
301 hipStreamSynchronize(stream);
302
303 return mma_bufred[0];
304}
305
306
307void hip_delx(void* delx, void* x, void* xlow, void* xupp, void* pij,
308 void* qij, void* p0j, void* q0j, void* alpha, void* beta, void* lambda,
309 real* epsi, int* n, int* m) {
310 const dim3 nthrds(1024, 1, 1);
311 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
312
313 hipLaunchKernelGGL(delx_kernel<real>, nblcks, nthrds, 0,
314 (hipStream_t)glb_cmd_queue,
315 (real*)delx, (real*)x, (real*)xlow, (real*)xupp, (real*)pij,
316 (real*)qij, (real*)p0j, (real*)q0j, (real*)alpha, (real*)beta,
317 (real*)lambda, *epsi, *n, *m);
318 HIP_CHECK(hipGetLastError());
319}
320
321void hip_GG(void* GG, void* x, void* xlow, void* xupp,
322 void* pij, void* qij, int* n, int* m) {
323 const dim3 nthrds(1024, 1, 1);
324 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
325
326 hipLaunchKernelGGL(GG_kernel<real>, nblcks, nthrds, 0,
327 (hipStream_t)glb_cmd_queue,
328 (real*)GG, (real*)x, (real*)xlow, (real*)xupp, (real*)pij,
329 (real*)qij, *n, *m);
330 HIP_CHECK(hipGetLastError());
331}
332
333void hip_diagx(void* diagx, void* x, void* xsi, void* xlow, void* xupp,
334 void* p0j, void* q0j, void* pij, void* qij, void* alpha, void* beta,
335 void* eta, void* lambda, int *n, int *m) {
336 const dim3 nthrds(1024, 1, 1);
337 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
338
339 hipLaunchKernelGGL(diagx_kernel<real>, nblcks, nthrds, 0,
340 (hipStream_t)glb_cmd_queue,
341 (real*)diagx, (real*)x, (real*)xsi, (real*)xlow, (real*)xupp,
342 (real*)p0j, (real*)q0j, (real*)pij, (real*)qij, (real*)alpha,
343 (real*)beta, (real*)eta, (real*)lambda, *n, *m);
344 HIP_CHECK(hipGetLastError());
345}
346
347void hip_bb(void* bb, void* GG, void* delx, void* diagx, int *n, int *m) {
348 const dim3 nthrds(1024, 1, 1);
349 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
350 const int nb = ((*n) + 1024 - 1)/ 1024;
351 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
352
353 hipStreamSynchronize(stream);
354
355 if (nb > mma_red_s) {
356 mma_red_s = nb;
357 if (mma_bufred != NULL) {
358 HIP_CHECK(hipHostFree(mma_bufred));
359 HIP_CHECK(hipFree(mma_bufred_d));
360 }
361 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
362 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
363 }
364
365 for (int i = 0; i < (*m); i++) {
366 hipLaunchKernelGGL(mmasumbb_kernel<real>, nblcks, nthrds, 0, stream,
367 (real*)GG, (real*)delx, (real*)diagx, mma_bufred_d, *n, *m, i);
368 HIP_CHECK(hipGetLastError());
369
370 hipLaunchKernelGGL(mmareduce_kernel<real>, 1, 1024, 0, stream,
371 mma_bufred_d, nb);
372 HIP_CHECK(hipGetLastError());
373
374 hipLaunchKernelGGL(mma_copy_kernel, 1, 1, 0, stream, (real*)bb,
375 mma_bufred_d, 1, i);
376 HIP_CHECK(hipGetLastError());
377
378 hipStreamSynchronize(stream);
379 }
380}
381
382void hip_AA(void* AA, void* GG, void* diagx, int *n, int *m) {
383 const dim3 nthrds(1024, 1, 1);
384 const dim3 nblcks(((*n)+1024 - 1)/ 1024, 1, 1);
385 const int nb = ((*n) + 1024 - 1)/ 1024;
386 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
387
388 hipStreamSynchronize(stream);
389
390 if (nb > mma_red_s) {
391 mma_red_s = nb;
392 if (mma_bufred != NULL) {
393 HIP_CHECK(hipHostFree(mma_bufred));
394 HIP_CHECK(hipFree(mma_bufred_d));
395 }
396 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
397 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
398 }
399
400 for (int i = 0; i < (*m); i++) {
401 for (int j = 0; j < (*m); j++) {
402 hipLaunchKernelGGL(mmasumAA_kernel<real>, nblcks, nthrds, 0, stream,
403 (real*)GG, (real*)diagx, mma_bufred_d, *n, *m, i, j);
404 HIP_CHECK(hipGetLastError());
405
406 hipLaunchKernelGGL(mmareduce_kernel<real>, 1, 1024, 0, stream,
407 mma_bufred_d, nb);
408 HIP_CHECK(hipGetLastError());
409
410 hipLaunchKernelGGL(mma_copy_kernel, 1, 1, 0, stream,
411 (real*)AA, mma_bufred_d, 1, i + j * (*m + 1));
412 HIP_CHECK(hipGetLastError());
413
414 hipStreamSynchronize(stream);
415 }
416 }
417}
418
419void hip_dx(void* dx, void* delx, void* diagx, void* GG, void* dlambda,
420 int* n, int* m) {
421 const dim3 nthrds(1024, 1, 1);
422 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
423
424 hipLaunchKernelGGL(dx_kernel<real>, nblcks, nthrds, 0,
425 (hipStream_t)glb_cmd_queue,
426 (real*)dx, (real*)delx, (real*)diagx, (real*)GG, (real*)dlambda, *n, *m);
427 HIP_CHECK(hipGetLastError());
428}
429
430void hip_dxsi(void* dxsi, void* xsi, void* dx, void* x,
431 void* alpha, real* epsi, int* n) {
432 const dim3 nthrds(1024, 1, 1);
433 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
434
435 hipLaunchKernelGGL(dxsi_kernel<real>, nblcks, nthrds, 0,
436 (hipStream_t)glb_cmd_queue,
437 (real*)dxsi, (real*)xsi, (real*)dx, (real*)x, (real*)alpha, *epsi, *n);
438 HIP_CHECK(hipGetLastError());
439}
440
441void hip_deta(void* deta, void* eta, void* dx, void* x,
442 void* beta, real* epsi, int* n) {
443 const dim3 nthrds(1024, 1, 1);
444 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
445
446 hipLaunchKernelGGL(deta_kernel<real>, nblcks, nthrds, 0,
447 (hipStream_t)glb_cmd_queue,
448 (real*)deta, (real*)eta, (real*)dx, (real*)x, (real*)beta, *epsi, *n);
449 HIP_CHECK(hipGetLastError());
450}
451
452void hip_rex(void* rex, void* x, void* xlow, void* xupp, void* pij,
453 void* p0j, void* qij, void* q0j, void* lambda, void* xsi, void* eta,
454 int* n, int* m) {
455 const dim3 nthrds(1024, 1, 1);
456 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
457
458 hipLaunchKernelGGL(RexCalculation_kernel<real>, nblcks, nthrds, 0,
459 (hipStream_t)glb_cmd_queue,
460 (real*)rex, (real*)x, (real*)xlow, (real*)xupp, (real*)pij, (real*)p0j,
461 (real*)qij, (real*)q0j, (real*)lambda, (real*)xsi, (real*)eta, *n, *m);
462 HIP_CHECK(hipGetLastError());
463}
464
465void hip_rey(void* rey, void* c, void* d, void* y, void* lambda, void* mu, int* n) {
466 const dim3 nthrds(1024, 1, 1);
467 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
468
469 hipLaunchKernelGGL(rey_calculation_kernel<real>, nblcks, nthrds, 0,
470 (hipStream_t)glb_cmd_queue,
471 (real*)rey, (real*)c, (real*)d, (real*)y, (real*)lambda, (real*)mu, *n);
472 HIP_CHECK(hipGetLastError());
473}
474
475
476 // a_d = b_d * c_d - d
477void hip_sub2cons(void *a, void *b, void *c, real *d, int *n) {
478 const dim3 nthrds(1024, 1, 1);
479 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
480 hipLaunchKernelGGL(sub2cons_kernel<real>, nblcks, nthrds, 0,
481 (hipStream_t)glb_cmd_queue,
482 (real *)a, (real *)b, (real *)c, *d, *n);
483 HIP_CHECK(hipGetLastError());
484}
485
486
487// sum(a^2)
488real hip_norm(void* a, int* n) {
489 const dim3 nthrds(1024, 1, 1);
490 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
491 const int nb = ((*n) + 1024 - 1) / 1024;
492 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
493
494 if (nb > mma_red_s) {
495 mma_red_s = nb;
496 if (mma_bufred != NULL) {
497 HIP_CHECK(hipFreeHost(mma_bufred));
498 HIP_CHECK(hipFree(mma_bufred_d));
499 }
500 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
501 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
502 }
503
504 hipLaunchKernelGGL(norm_kernel<real>, nblcks, nthrds, 0, stream,
505 (real*)a, mma_bufred_d, (*n));
506 HIP_CHECK(hipGetLastError());
507
508 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
509 mma_bufred_d, nb);
510 HIP_CHECK(hipGetLastError());
511
512 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d, sizeof(real),
513 hipMemcpyDeviceToHost, stream));
514
515 hipStreamSynchronize(stream);
516
517 return mma_bufred[0];
518}
519
520
521void hip_dely(void* dely, void* c, void* d, void* y, void* lambda,
522 real* epsi, int* n) {
523 const dim3 nthrds(1024, 1, 1);
524 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
525 hipLaunchKernelGGL(dely_kernel<real>, nblcks, nthrds, 0,
526 (hipStream_t)glb_cmd_queue,
527 (real*)dely, (real*)c, (real*)d, (real*)y, (real*)lambda, *epsi, *n);
528 HIP_CHECK(hipGetLastError());
529}
530
531
532real hip_maxval2(void* a, void* b, real* cons, int* n) {
533 const dim3 nthrds(1024, 1, 1);
534 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
535 const int nb = ((*n) + 1024 - 1) / 1024;
536 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
537
538 if (nb > mma_red_s) {
539 mma_red_s = nb;
540 if (mma_bufred != NULL) {
541 HIP_CHECK(hipFreeHost(mma_bufred));
542 HIP_CHECK(hipFree(mma_bufred_d));
543 }
544 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
545 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
546 }
547
548 hipLaunchKernelGGL(maxval2_kernel<real>, nblcks, nthrds, 0, stream,
549 (real*)a, (real*)b, mma_bufred_d, *cons, *n);
550 HIP_CHECK(hipGetLastError());
551
552 hipLaunchKernelGGL(max_reduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
553 mma_bufred_d, nb);
554 HIP_CHECK(hipGetLastError());
555
556 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d, sizeof(real),
557 hipMemcpyDeviceToHost, stream));
558
559 hipStreamSynchronize(stream);
560
561 return mma_bufred[0];
562}
563
564
565real hip_maxval3(void* a, void* b, void* c, real* cons, int* n) {
566 const dim3 nthrds(1024, 1, 1);
567 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
568 const int nb = ((*n) + 1024 - 1) / 1024;
569 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
570
571 if (nb > mma_red_s) {
572 mma_red_s = nb;
573 if (mma_bufred != NULL) {
574 HIP_CHECK(hipFreeHost(mma_bufred));
575 HIP_CHECK(hipFree(mma_bufred_d));
576 }
577 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
578 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
579 }
580
581 hipLaunchKernelGGL(maxval3_kernel<real>, nblcks, nthrds, 0, stream,
582 (real*)a, (real*)b, (real*)c, mma_bufred_d, *cons, *n);
583 hipLaunchKernelGGL(max_reduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
584 mma_bufred_d, nb);
585 HIP_CHECK(hipGetLastError());
586
587 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d, sizeof(real),
588 hipMemcpyDeviceToHost, stream));
589
590 hipStreamSynchronize(stream);
591
592 return mma_bufred[0];
593}
594
595
596void hip_kkt_rex(void* rex, void* df0dx, void* dfdx, void* xsi,
597 void* eta, void* lambda, int* n, int* m) {
598 const dim3 nthrds(1024, 1, 1);
599 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
600 hipLaunchKernelGGL(kkt_rex_kernel<real>, nblcks, nthrds, 0,
601 (hipStream_t)glb_cmd_queue,
602 (real*)rex, (real*)df0dx, (real*)dfdx, (real*)xsi,
603 (real*)eta, (real*)lambda, *n, *m);
604 HIP_CHECK(hipGetLastError());
605}
606
607
608// a_d = max(b, c * d_d)
609void hip_maxcons(void* a, real* b, real* c, void* d, int* n) {
610 const dim3 nthrds(1024, 1, 1);
611 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
612 hipLaunchKernelGGL(maxcons_kernel<real>, nblcks, nthrds, 0,
613 (hipStream_t)glb_cmd_queue,
614 (real*)a, *b, *c, (real*)d, *n);
615 HIP_CHECK(hipGetLastError());
616}
617
618
619real hip_lcsc2(void *a, void*b, int *n) {
620 const dim3 nthrds(1024, 1, 1);
621 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
622 const int nb = ((*n) + 1024 - 1) / 1024;
623 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
624
625 if (nb > mma_red_s) {
626 mma_red_s = nb;
627 if (mma_bufred != NULL) {
628 HIP_CHECK(hipFreeHost(mma_bufred));
629 HIP_CHECK(hipFree(mma_bufred_d));
630 }
631 HIP_CHECK(hipHostMalloc(&mma_bufred, nb * sizeof(real)));
632 HIP_CHECK(hipMalloc(&mma_bufred_d, nb * sizeof(real)));
633 }
634
635 hipLaunchKernelGGL(glsc2_kernel<real>, nblcks, nthrds, 0, stream,
636 (real*)a, (real*)b, mma_bufred_d, (*n));
637 HIP_CHECK(hipGetLastError());
638
639 hipLaunchKernelGGL(mmareduce_kernel<real>, dim3(1), dim3(1024), 0, stream,
640 mma_bufred_d, nb);
641 HIP_CHECK(hipGetLastError());
642
643 HIP_CHECK(hipMemcpyAsync(mma_bufred, mma_bufred_d, sizeof(real),
644 hipMemcpyDeviceToHost, stream));
645
646 hipStreamSynchronize(stream);
647
648 return mma_bufred[0];
649}
650
651
// In-place sum-allreduce of the n reals in device array a, via
// device_mpi_allreduce_inplace with DEVICE_MPI_SUM.
// Only active when HAVE_DEVICE_MPI is defined; otherwise it is a no-op.
void hip_mpisum(void *a, int *n) {
#ifdef HAVE_DEVICE_MPI
  real* temp = (real*)a;
  // All kernels in this file are queued on glb_cmd_queue; drain it so that
  // MPI reads the finished contents of a. There is no local named `stream`
  // in this function, so the queue is cast here directly.
  HIP_CHECK(hipStreamSynchronize((hipStream_t)glb_cmd_queue));
  device_mpi_allreduce_inplace(temp, *n, sizeof(real), DEVICE_MPI_SUM);
#else
  // Silence unused-parameter warnings in builds without device MPI.
  (void)a;
  (void)n;
#endif
}
659
660
661void hip_add2inv2(void* a, void* b, real* c, int* n) {
662 const dim3 nthrds(1024, 1, 1);
663 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
664 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
665
666 hipLaunchKernelGGL(add2inv2_kernel<real>, nblcks, nthrds, 0, stream,
667 (real*)a, (real*)b, *c, *n);
668 HIP_CHECK(hipGetLastError());
669}
670
671void hip_max2(void* a, real* b, void* c, real* d, int* n) {
672 const dim3 nthrds(1024, 1, 1);
673 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
674 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
675
676 hipLaunchKernelGGL(max2_kernel<real>, nblcks, nthrds, 0, stream,
677 (real*)a, *b, (real*)c, *d, *n);
678 HIP_CHECK(hipGetLastError());
679}
680
681void hip_updatebb(void* bb, void* dellambda, void* dely, void* d,
682 void* mu, void* y, real* delz, int* m) {
683 const dim3 nthrds(1024, 1, 1);
684 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
685 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
686
687 hipLaunchKernelGGL(updatebb_kernel<real>, nblcks, nthrds, 0, stream,
688 (real*)bb, (real*)dellambda, (real*)dely, (real*)d,
689 (real*)mu, (real*)y, *delz, *m);
690 HIP_CHECK(hipGetLastError());
691}
692
693void hip_updateAA(void* AA, void* globaltmp_mm, void* s, void* lambda,
694 void* d, void* mu, void* y, void* a,
695 real* zeta, real* z, int* m) {
696 const dim3 nthrds(1024, 1, 1);
697 const dim3 nblcks(((*m + 1) + 1024 - 1) / 1024, 1, 1);
698 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
699
700 hipLaunchKernelGGL(updateAA_kernel<real>, nblcks, nthrds, 0, stream,
701 (real*)AA, (real*)globaltmp_mm, (real*)s,
702 (real*)lambda, (real*)d, (real*)mu,
703 (real*)y, (real*)a, *zeta, *z, *m);
704 HIP_CHECK(hipGetLastError());
705}
706
707void hip_dy(void* dy, void* dely, void* dlambda, void* d,
708 void* mu, void* y, int* n) {
709 const dim3 nthrds(1024, 1, 1);
710 const dim3 nblcks(((*n) + 1024 - 1) / 1024, 1, 1);
711 const hipStream_t stream = (hipStream_t)glb_cmd_queue;
712
713 hipLaunchKernelGGL(dy_kernel<real>, nblcks, nthrds, 0, stream,
714 (real*)dy, (real*)dely, (real*)dlambda, (real*)d,
715 (real*)mu, (real*)y, *n);
716 HIP_CHECK(hipGetLastError());
717}
718}