33submodule(mma) mma_device
35 use device_math,
only: device_copy, device_cmult, device_cadd, device_cfill, &
36 device_add2, device_add3s2, device_invcol2, device_col2, device_col3, &
37 device_sub2, device_sub3, device_add2s2, device_cadd2, device_pwmax2, &
38 device_glsum, device_cmult2
39 use device_mma_math,
only: device_maxval, device_norm, device_lcsc2, &
40 device_maxval2, device_maxval3, device_mma_gensub3, &
41 device_mma_gensub4, device_mma_max, device_max2, device_rex, &
42 device_relambda, device_delx, device_add2inv2, device_gg, device_diagx, &
43 device_bb, device_updatebb, device_aa, device_updateaa, device_dx, &
44 device_dy, device_dxsi, device_deta, device_kkt_rex, &
45 device_mma_gensub2, device_mattrans_v_mul, device_mma_dipsolvesub1, &
46 device_mma_ljjxinv, device_hess, device_solve_linear_system, &
47 device_prepare_hessian, device_prepare_aa_matrix
49 use neko_config,
only: neko_bcknd_device, neko_device_mpi
50 use device,
only: device_to_host
51 use comm,
only: neko_comm, pe_rank, mpi_real_precision
52 use mpi_f08,
only: mpi_in_place, mpi_max, mpi_min
53 use profiler,
only: profiler_start_region, profiler_end_region
54 use scratch_registry,
only: neko_scratch_registry
!> Perform one MMA update on the device.
!!
!! Builds the approximated subproblem with mma_gensub_device and then hands it
!! to the subsolver named by this%subsolver. Both stages are wrapped in
!! profiler regions. On return this%is_updated is set to .true..
!!
!! @param this   MMA object; must already be initialized.
!! @param iter   Iteration counter, forwarded to mma_gensub_device.
!! @param x      Device pointer to the design variables; forwarded to the
!!               subsolver.
!! @param df0dx  Device pointer forwarded to mma_gensub_device.
!! @param fval   Device pointer forwarded to mma_gensub_device.
!! @param dfdx   Device pointer forwarded to mma_gensub_device.
module subroutine mma_update_device(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(inout) :: x
  type(c_ptr), intent(in) :: df0dx, fval, dfdx

  if (.not. this%is_initialized) then
     call neko_error("The MMA object is not initialized.")
  end if

  ! Stage 1: generate the subproblem.
  call profiler_start_region("MMA gensub")
  call mma_gensub_device(this, iter, x, df0dx, fval, dfdx)
  call profiler_end_region("MMA gensub")

  ! Stage 2: solve it with the selected subsolver. The error call belongs to
  ! the final else branch only, so that a valid "dpip" selection does not
  ! fall through into it.
  call profiler_start_region("MMA subsolve")
  if (this%subsolver .eq. "dip") then
     call mma_subsolve_dip_device(this, x)
  else if (this%subsolver .eq. "dpip") then
     call mma_subsolve_dpip_device(this, x)
  else
     call neko_error("Unrecognized subsolver for MMA in mma_device.")
  end if
  call profiler_end_region("MMA subsolve")

  this%is_updated = .true.
end subroutine mma_update_device
!> Evaluate the KKT residuals on the device.
!!
!! Dispatches on this%subsolver: "dip" uses mma_dip_kkt_device, every other
!! value uses mma_dpip_kkt_device. All arguments are forwarded unchanged.
!!
!! @param this   MMA object; the callee stores the residuals in it.
!! @param x      Device pointer to the design variables.
!! @param df0dx  Device pointer, forwarded.
!! @param fval   Device pointer, forwarded.
!! @param dfdx   Device pointer, forwarded.
module subroutine mma_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  ! Exactly one of the two routines must run; without the else branch the
  ! "dip" case would evaluate both and the second would overwrite the first.
  if (this%subsolver .eq. "dip") then
     call mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
  else
     call mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
  end if
end subroutine mma_kkt_device
!> KKT residuals for the "dip" subsolver, evaluated on the device.
!!
!! Builds two residual vectors of length this%m in scratch storage and stores
!! their combined maximum in this%residumax and the square root of the summed
!! device_norm values in this%residunorm.
!!
!! @param this   MMA object; reads a, z, y, mu, lambda and m, writes residumax
!!               and residunorm.
!! @param x      Not referenced in this routine.
!! @param df0dx  Not referenced in this routine.
!! @param fval   Device pointer used as the first operand of relambda.
!! @param dfdx   Not referenced in this routine.
module subroutine mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  type(vector_t), pointer :: relambda, remu
  ! Scratch-registry handles, one per requested vector.
  integer :: ind(2)

  call neko_scratch_registry%request(relambda, ind(1), this%m, .false.)
  call neko_scratch_registry%request(remu, ind(2), this%m, .false.)

  ! relambda is assembled from fval, a (weighted by -z), y and mu.
  ! The trailing length argument follows the six-argument device_add3s2 calls
  ! used elsewhere in this submodule (a, b, c, c1, c2, n).
  call device_add3s2(relambda%x_d, fval, this%a%x_d, 1.0_rp, -this%z, this%m)
  call device_sub2(relambda%x_d, this%y%x_d, this%m)
  call device_add2(relambda%x_d, this%mu%x_d, this%m)

  ! remu is built from lambda and mu.
  call device_col3(remu%x_d, this%lambda%x_d, this%mu%x_d, this%m)

  this%residumax = maxval([device_maxval(relambda%x_d, this%m), &
       device_maxval(remu%x_d, this%m)])
  ! NOTE(review): device_norm is presumably a squared norm, since the sum is
  ! passed through sqrt here -- confirm against device_mma_math.
  this%residunorm = sqrt(device_norm(relambda%x_d, this%m) + &
       device_norm(remu%x_d, this%m))

  call neko_scratch_registry%relinquish(ind)
end subroutine mma_dip_kkt_device
!> KKT residuals for the "dpip" subsolver, evaluated on the device.
!!
!! Assembles the residual vectors rex, rey, relambda, rexsi, reeta, remu, res
!! and the scalar residuals rez and rezeta, then stores their maximum in
!! this%residumax and the square root of the summed squared contributions in
!! this%residunorm. Only the contributions of length this%n (rex, rexsi,
!! reeta) are sum-reduced over neko_comm; the maximum is max-reduced.
!!
!! @param this   MMA object; reads its primal/dual state and bounds, writes
!!               residumax and residunorm.
!! @param x      Device pointer to the design variables.
!! @param df0dx  Device pointer passed to device_kkt_rex.
!! @param fval   Device pointer used as the first operand of relambda.
!! @param dfdx   Device pointer passed to device_kkt_rex.
module subroutine mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  real(kind=rp) :: rez, rezeta
  type(vector_t), pointer :: rey, relambda, remu, res
  type(vector_t), pointer :: rex, rexsi, reeta
  integer :: ierr, ind(7)
  real(kind=rp) :: re_sq_norm

  ! Scratch vectors of length m ...
  call neko_scratch_registry%request(rey, ind(1), this%m, .false.)
  call neko_scratch_registry%request(relambda, ind(2), this%m, .false.)
  call neko_scratch_registry%request(remu, ind(3), this%m, .false.)
  call neko_scratch_registry%request(res, ind(4), this%m, .false.)
  ! ... and of length n.
  call neko_scratch_registry%request(rex, ind(5), this%n, .false.)
  call neko_scratch_registry%request(rexsi, ind(6), this%n, .false.)
  call neko_scratch_registry%request(reeta, ind(7), this%n, .false.)

  call device_kkt_rex(rex%x_d, df0dx, dfdx, this%xsi%x_d, &
       this%eta%x_d, this%lambda%x_d, this%n, this%m)

  ! rey is built from d, y, c, lambda and mu.
  call device_col3(rey%x_d, this%d%x_d, this%y%x_d, this%m)
  call device_add2(rey%x_d, this%c%x_d, this%m)
  call device_sub2(rey%x_d, this%lambda%x_d, this%m)
  call device_sub2(rey%x_d, this%mu%x_d, this%m)

  ! Scalar residual for z. The length argument matches the three-argument
  ! device_lcsc2 calls used elsewhere in this submodule.
  rez = this%a0 - this%zeta - &
       device_lcsc2(this%lambda%x_d, this%a%x_d, this%m)

  ! relambda is built from fval, a (weighted by -z), y and s.
  call device_add3s2(relambda%x_d, fval, this%a%x_d, 1.0_rp, -this%z, this%m)
  call device_sub2(relambda%x_d, this%y%x_d, this%m)
  call device_add2(relambda%x_d, this%s%x_d, this%m)

  ! rexsi is built from xsi and (x - xmin); reeta from eta and (xmax - x).
  call device_sub3(rexsi%x_d, x, this%xmin%x_d, this%n)
  call device_col2(rexsi%x_d, this%xsi%x_d, this%n)

  call device_sub3(reeta%x_d, this%xmax%x_d, x, this%n)
  call device_col2(reeta%x_d, this%eta%x_d, this%n)

  call device_col3(remu%x_d, this%mu%x_d, this%y%x_d, this%m)

  rezeta = this%zeta * this%z

  call device_col3(res%x_d, this%lambda%x_d, this%s%x_d, this%m)

  ! The scalar residuals enter through their absolute value, in the same
  ! positions as in the residual_max expression of the dpip subsolver.
  this%residumax = maxval([ &
       device_maxval(rex%x_d, this%n), &
       device_maxval(rey%x_d, this%m), &
       abs(rez), &
       device_maxval(relambda%x_d, this%m), &
       device_maxval(rexsi%x_d, this%n), &
       device_maxval(reeta%x_d, this%n), &
       device_maxval(remu%x_d, this%m), &
       abs(rezeta), &
       device_maxval(res%x_d, this%m)])

  re_sq_norm = device_norm(rex%x_d, this%n) + &
       device_norm(rexsi%x_d, this%n) + &
       device_norm(reeta%x_d, this%n)

  call mpi_allreduce(mpi_in_place, this%residumax, 1, &
       mpi_real_precision, mpi_max, neko_comm, ierr)

  call mpi_allreduce(mpi_in_place, re_sq_norm, 1, &
       mpi_real_precision, mpi_sum, neko_comm, ierr)

  ! NOTE(review): the rez**2, rezeta**2 and re_sq_norm terms are restored by
  ! inference -- rez, rezeta and the reduced re_sq_norm would otherwise never
  ! be used. Confirm against the reference implementation.
  this%residunorm = sqrt(( &
       device_norm(rey%x_d, this%m) + &
       rez**2 + &
       device_norm(relambda%x_d, this%m) + &
       device_norm(remu%x_d, this%m) + &
       rezeta**2 + &
       device_norm(res%x_d, this%m) &
       ) + re_sq_norm)

  call neko_scratch_registry%relinquish(ind)
end subroutine mma_dpip_kkt_device
!> Generate the MMA subproblem on the device.
!!
!! Updates the asymptotes this%low and this%upp, then calls
!! device_mma_gensub3 and device_mma_gensub4 to fill alpha, beta, p0j, q0j,
!! pij, qij and bi. Finally bi is summed over neko_comm and fval is subtracted
!! from it.
!!
!! @param this   MMA object holding the asymptotes, bounds and coefficients.
!! @param iter   Iteration counter; for iter < 3 the asymptotes are placed
!!               around x using this%asyinit, otherwise device_mma_gensub2
!!               updates them from xold1 and xold2.
!! @param x      Device pointer to the design variables.
!! @param df0dx  Device pointer passed to device_mma_gensub3.
!! @param fval   Device pointer subtracted from bi.
!! @param dfdx   Device pointer passed to device_mma_gensub3.
subroutine mma_gensub_device(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(in) :: x
  type(c_ptr), intent(in) :: df0dx
  type(c_ptr), intent(in) :: fval
  type(c_ptr), intent(in) :: dfdx

  type(vector_t), pointer :: x_diff
  ! Scratch-registry handle for x_diff and MPI status code.
  integer :: ind, ierr

  call neko_scratch_registry%request(x_diff, ind, this%n, .false.)

  ! x_diff is built from xmax and xmin and scales the asymptote offsets.
  call device_sub3(x_diff%x_d, this%xmax%x_d, this%xmin%x_d, this%n)

  if (iter .lt. 3) then
     ! First iterations: offset x by -/+ asyinit times x_diff.
     call device_copy(this%low%x_d, x, this%n)
     call device_add2s2(this%low%x_d, x_diff%x_d, - this%asyinit, this%n)
     call device_copy(this%upp%x_d, x, this%n)
     call device_add2s2(this%upp%x_d, x_diff%x_d, this%asyinit, this%n)
  else
     ! Later iterations: update from the two previous designs.
     call device_mma_gensub2(this%low%x_d, this%upp%x_d, x, &
          this%xold1%x_d, this%xold2%x_d, x_diff%x_d, &
          this%asydecr, this%asyincr, this%n)
  end if

  call device_mma_gensub3(x, df0dx, dfdx, this%low%x_d, &
       this%upp%x_d, this%xmin%x_d, this%xmax%x_d, this%alpha%x_d, &
       this%beta%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
       this%qij%x_d, this%n, this%m)

  call device_mma_gensub4(x, this%low%x_d, this%upp%x_d, this%pij%x_d, &
       this%qij%x_d, this%n, this%m, this%bi%x_d)

  ! Sum the rank-local bi over neko_comm: either directly on the device
  ! buffer, or staged through the host copy this%bi%x.
  if (neko_device_mpi) then
     ! NOTE(review): this passes the type(c_ptr) handle itself to
     ! mpi_allreduce; confirm that this reduces the device buffer as intended.
     call mpi_allreduce(mpi_in_place, this%bi%x_d, this%m, &
          mpi_real_precision, mpi_sum, neko_comm, ierr)
  else
     call device_memcpy(this%bi%x, this%bi%x_d, this%m, device_to_host, &
          sync = .true.)
     call mpi_allreduce(mpi_in_place, this%bi%x, this%m, &
          mpi_real_precision, mpi_sum, neko_comm, ierr)
     call device_memcpy(this%bi%x, this%bi%x_d, this%m, host_to_device, &
          sync = .true.)
  end if

  call device_sub2(this%bi%x_d, fval, this%m)

  call neko_scratch_registry%relinquish(ind)
end subroutine mma_gensub_device
!> Solve the MMA subproblem on the device with the "dpip" subsolver
!! (mma_update_device calls this routine when this%subsolver .eq. "dpip").
!!
!! All work vectors and matrices are borrowed from neko_scratch_registry. At
!! the end the iterate is copied back into the MMA object and into designx_d.
!!
!! @param this       MMA object. Reads low, upp, alpha, beta, p0j, q0j, pij,
!!                   qij, a0, a, c, d, bi, epsimin and max_iter; the final
!!                   device_copy calls target xold2, xold1, y, lambda, xsi,
!!                   eta, mu and s.
!! @param designx_d  Device pointer to the design variables. The pointer is
!!                   intent(in), but the memory behind it is the first
!!                   argument of the last device_copy of the history shift
!!                   (presumably the destination, i.e. it receives the new x
!!                   -- confirm against device_math).
!!
!! NOTE(review): as visible here the routine is incomplete. `ind` and `info`
!! are used without a visible declaration; epsi, z, zeta, zetaold and itto are
!! read without a visible assignment; several calls end in `&` with no
!! continuation line; and no `else`, `end if`, `end do` or `end associate` is
!! visible. Check against the complete file before relying on these comments.
290 subroutine mma_subsolve_dpip_device(this, designx_d)
291 class(mma_t),
intent(inout) :: this
292 type(c_ptr),
intent(in) :: designx_d
293 integer :: iter, itto, ierr
294 real(kind=rp) :: epsi, residual_max, residual_norm, z, zeta, rez, rezeta, &
295 delz, dz, dzeta, steg, zold, zetaold, new_residual
! Work vectors of length this%m (see the request calls below).
297 type(vector_t) ,
pointer :: y, lambda, s, mu, rey, relambda, remu, res, &
298 dely, dellambda, dy, dlambda, ds, dmu, yold, lambdaold, sold, muold
! Work vectors of length this%n.
301 type(vector_t),
pointer :: x, xsi, eta, rex, rexsi, reeta, &
302 delx, diagx, dx, dxsi, deta, xold, xsiold, etaold
! Right-hand side (length m+1) and matrices of the linear system.
304 type(vector_t),
pointer :: bb
305 type(matrix_t),
pointer :: GG
306 type(matrix_t),
pointer :: AA
309 real(kind=rp) :: re_sq_norm
313 real(kind=rp) :: minimal_epsilon
! Borrow scratch storage: ind(1:18) are vectors of length m, ind(19:32)
! vectors of length n, ind(33) has length m+1, ind(34) is m x n and ind(35)
! is (m+1) x (m+1).
315 call neko_scratch_registry%request(y, ind(1), this%m, .false.)
316 call neko_scratch_registry%request(lambda, ind(2), this%m, .false.)
317 call neko_scratch_registry%request(s, ind(3), this%m, .false.)
318 call neko_scratch_registry%request(mu, ind(4), this%m, .false.)
319 call neko_scratch_registry%request(rey, ind(5), this%m, .false.)
320 call neko_scratch_registry%request(relambda, ind(6), this%m, .false.)
321 call neko_scratch_registry%request(remu, ind(7), this%m, .false.)
322 call neko_scratch_registry%request(res, ind(8), this%m, .false.)
323 call neko_scratch_registry%request(dely, ind(9), this%m, .false.)
324 call neko_scratch_registry%request(dellambda, ind(10), this%m, .false.)
325 call neko_scratch_registry%request(dy, ind(11), this%m, .false.)
326 call neko_scratch_registry%request(dlambda, ind(12), this%m, .false.)
327 call neko_scratch_registry%request(ds, ind(13), this%m, .false.)
328 call neko_scratch_registry%request(dmu, ind(14), this%m, .false.)
329 call neko_scratch_registry%request(yold, ind(15), this%m, .false.)
330 call neko_scratch_registry%request(lambdaold, ind(16), this%m, .false.)
331 call neko_scratch_registry%request(sold, ind(17), this%m, .false.)
332 call neko_scratch_registry%request(muold, ind(18), this%m, .false.)
333 call neko_scratch_registry%request(x, ind(19), this%n, .false.)
334 call neko_scratch_registry%request(xsi, ind(20), this%n, .false.)
335 call neko_scratch_registry%request(eta, ind(21), this%n, .false.)
336 call neko_scratch_registry%request(rex, ind(22), this%n, .false.)
337 call neko_scratch_registry%request(rexsi, ind(23), this%n, .false.)
338 call neko_scratch_registry%request(reeta, ind(24), this%n, .false.)
339 call neko_scratch_registry%request(delx, ind(25), this%n, .false.)
340 call neko_scratch_registry%request(diagx, ind(26), this%n, .false.)
341 call neko_scratch_registry%request(dx, ind(27), this%n, .false.)
342 call neko_scratch_registry%request(dxsi, ind(28), this%n, .false.)
343 call neko_scratch_registry%request(deta, ind(29), this%n, .false.)
344 call neko_scratch_registry%request(xold, ind(30), this%n, .false.)
345 call neko_scratch_registry%request(xsiold, ind(31), this%n, .false.)
346 call neko_scratch_registry%request(etaold, ind(32), this%n, .false.)
347 call neko_scratch_registry%request(bb, ind(33), this%m+1, .false.)
349 call neko_scratch_registry%request(gg, ind(34), this%m, this%n, .false.)
350 call neko_scratch_registry%request(aa, ind(35), this%m+1, this%m+1, .false.)
! Starting point: x from alpha and beta with weights 0.5 and 0.5; y, lambda
! and s filled with 1; xsi, eta and mu from the device_mma_max / device_max2
! helpers.
! NOTE(review): the next call ends in `&` with no visible continuation.
357 call device_add3s2(x%x_d, this%alpha%x_d, this%beta%x_d, 0.5_rp, 0.5_rp, &
359 call device_cfill(y%x_d, 1.0_rp, this%m)
362 call device_cfill(lambda%x_d, 1.0_rp, this%m)
363 call device_cfill(s%x_d, 1.0_rp, this%m)
364 call device_mma_max(xsi%x_d, x%x_d, this%alpha%x_d, this%n)
365 call device_mma_max(eta%x_d, this%beta%x_d, x%x_d, this%n)
366 call device_max2(mu%x_d, 1.0_rp, this%c%x_d, 0.5_rp, this%m)
! Threshold for the outer loop: the larger of 0.9*epsimin and 1e-12, then
! min-reduced over neko_comm so every rank uses the same value.
371 minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
372 call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
373 mpi_real_precision, mpi_min, neko_comm, ierr)
! Outer loop: repeats while epsi is above minimal_epsilon.
! NOTE(review): no assignment to epsi is visible before this test.
378 do while (epsi .gt. minimal_epsilon)
! Short aliases for members of `this` used in the residual evaluation.
385 associate(p0j => this%p0j, q0j => this%q0j, &
386 pij => this%pij, qij => this%qij, &
387 low => this%low, upp => this%upp, &
388 alpha => this%alpha, beta => this%beta, &
389 c => this%c, d => this%d, &
390 a0 => this%a0, a => this%a)
! Residuals of the current iterate for the current epsi.
392 call device_rex(rex%x_d, x%x_d, low%x_d, upp%x_d, &
393 pij%x_d, p0j%x_d, qij%x_d, q0j%x_d, &
394 lambda%x_d, xsi%x_d, eta%x_d, this%n, this%m)
! rey is built from d, y, c, lambda and mu.
396 call device_col3(rey%x_d, d%x_d, y%x_d, this%m)
397 call device_add2(rey%x_d, c%x_d, this%m)
398 call device_sub2(rey%x_d, lambda%x_d, this%m)
399 call device_sub2(rey%x_d, mu%x_d, this%m)
400 rez = a0 - zeta - device_lcsc2(lambda%x_d, a%x_d, this%m)
! relambda starts from zero and is filled by device_relambda with this
! rank's contribution, which is then summed over neko_comm.
402 call device_cfill(relambda%x_d, 0.0_rp, this%m)
403 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
404 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
! Two reduction paths: directly on the device pointer when neko_device_mpi is
! set, or staged through the host copy relambda%x.
! NOTE(review): no `else` / `end if` is visible between the two paths, and
! both device_memcpy calls end in `&` without a continuation.
412 if (neko_device_mpi)
then
413 call mpi_allreduce(mpi_in_place, relambda%x_d, this%m, &
414 mpi_real_precision, mpi_sum, neko_comm, ierr)
416 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
418 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
419 mpi_real_precision, mpi_sum, neko_comm, ierr)
420 call device_memcpy(relambda%x, relambda%x_d, this%m, host_to_device, &
! Remaining terms of relambda: a (weighted by -z), y, s and bi.
424 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
425 call device_sub2(relambda%x_d, y%x_d, this%m)
426 call device_add2(relambda%x_d, s%x_d, this%m)
427 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
! rexsi, reeta, remu and res each have epsi subtracted; rezeta = zeta*z - epsi.
429 call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
430 call device_col2(rexsi%x_d, xsi%x_d, this%n)
431 call device_cadd(rexsi%x_d, - epsi, this%n)
433 call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
434 call device_col2(reeta%x_d, eta%x_d, this%n)
435 call device_cadd(reeta%x_d, - epsi, this%n)
437 call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
438 call device_cadd(remu%x_d, - epsi, this%m)
440 rezeta = zeta * z - epsi
442 call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
443 call device_cadd(res%x_d, - epsi, this%m)
! Largest residual: device_maxval of every residual vector together with
! abs(rez) and abs(rezeta); max-reduced over neko_comm below.
446 residual_max = maxval([device_maxval(rex%x_d, this%n), &
447 device_maxval(rey%x_d, this%m), abs(rez), &
448 device_maxval(relambda%x_d, this%m), &
449 device_maxval(rexsi%x_d, this%n), &
450 device_maxval(reeta%x_d, this%n), &
451 device_maxval(remu%x_d, this%m), abs(rezeta), &
452 device_maxval(res%x_d, this%m)])
! Contribution of the length-n residuals to the norm; sum-reduced over
! neko_comm below.
454 re_sq_norm = device_norm(rex%x_d, this%n) + &
455 device_norm(rexsi%x_d, this%n) + device_norm(reeta%x_d, this%n)
457 call mpi_allreduce(mpi_in_place, residual_max, 1, &
458 mpi_real_precision, mpi_max, neko_comm, ierr)
460 call mpi_allreduce(mpi_in_place, re_sq_norm, &
461 1, mpi_real_precision, mpi_sum, neko_comm, ierr)
! NOTE(review): the closing part of this expression is not visible; the use
! of rez, rezeta and re_sq_norm in the norm cannot be confirmed from here.
463 residual_norm = sqrt(device_norm(rey%x_d, this%m) + &
465 device_norm(relambda%x_d, this%m) + &
466 device_norm(remu%x_d, this%m)+ &
468 device_norm(res%x_d, this%m) &
! Inner iterations at fixed epsi, at most this%max_iter of them; the loop is
! left as soon as residual_max drops below epsi.
474 do iter = 1, this%max_iter
476 if (residual_max .lt. epsi)
exit
! Right-hand-side pieces delx, dely, delz and dellambda.
! NOTE(review): the device_delx call ends in `&` with no continuation.
478 call device_delx(delx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
479 this%pij%x_d, this%qij%x_d, this%p0j%x_d, this%q0j%x_d, &
480 this%alpha%x_d, this%beta%x_d, lambda%x_d, epsi, this%n, &
483 call device_col3(dely%x_d, this%d%x_d, y%x_d, this%m)
484 call device_add2(dely%x_d, this%c%x_d, this%m)
485 call device_sub2(dely%x_d, lambda%x_d, this%m)
486 call device_add2inv2(dely%x_d, y%x_d, - epsi, this%m)
487 delz = this%a0 - device_lcsc2(lambda%x_d, this%a%x_d, this%m) - epsi/z
490 call device_cfill(dellambda%x_d, 0.0_rp, this%m)
491 call device_relambda(dellambda%x_d, x%x_d, this%upp%x_d, &
492 this%low%x_d, this%pij%x_d, this%qij%x_d, this%n, this%m)
! Sum dellambda over neko_comm through the host copy. Unlike the relambda
! reduction above, no neko_device_mpi branch is visible here.
494 call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
495 device_to_host, sync = .true.)
496 call mpi_allreduce(mpi_in_place, dellambda%x, this%m, &
497 mpi_real_precision, mpi_sum, neko_comm, ierr)
498 call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
499 host_to_device, sync = .true.)
! NOTE(review): the next call ends in `&` with no visible continuation.
501 call device_add3s2(dellambda%x_d, dellambda%x_d, this%a%x_d, &
503 call device_sub2(dellambda%x_d, y%x_d, this%m)
504 call device_sub2(dellambda%x_d, this%bi%x_d, this%m)
505 call device_add2inv2(dellambda%x_d, lambda%x_d, epsi, this%m)
! gg (m x n) and diagx (length n) feed the assembly of bb and aa below.
507 call device_gg(gg%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
508 this%pij%x_d, this%qij%x_d, this%n, this%m)
510 call device_diagx(diagx%x_d, x%x_d, xsi%x_d, this%low%x_d, &
511 this%upp%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
512 this%qij%x_d, this%alpha%x_d, this%beta%x_d, eta%x_d, &
513 lambda%x_d, this%n, this%m)
! Right-hand side bb (length m+1): rank-local part from device_bb, summed
! over neko_comm through the host, then completed by device_updatebb.
! NOTE(review): the device_bb call and the first device_memcpy end in `&`
! with no visible continuation.
523 call device_bb(bb%x_d, gg%x_d, delx%x_d, diagx%x_d, this%n, &
526 call device_memcpy(bb%x, bb%x_d, this%m + 1, device_to_host, &
528 call mpi_allreduce(mpi_in_place, bb%x, this%m + 1, &
529 mpi_real_precision, mpi_sum, neko_comm, ierr)
530 call device_memcpy(bb%x, bb%x_d, this%m + 1, &
531 host_to_device, sync = .true.)
533 call device_updatebb(bb%x_d, dellambda%x_d, dely%x_d, &
534 this%d%x_d, mu%x_d, y%x_d, delz, this%m)
! System matrix aa ((m+1) x (m+1)): zeroed, filled by device_aa, summed over
! neko_comm through the host, then completed by device_prepare_aa_matrix.
543 call device_cfill(aa%x_d, 0.0_rp, (this%m+1) * (this%m+1))
544 call device_aa(aa%x_d, gg%x_d, diagx%x_d, this%n, this%m)
546 call device_memcpy(aa%x, aa%x_d, (this%m+1) * (this%m+1), &
547 device_to_host, sync = .true.)
548 call mpi_allreduce(mpi_in_place, aa%x, &
549 (this%m + 1)**2, mpi_real_precision, mpi_sum, neko_comm, ierr)
550 call device_memcpy(aa%x, aa%x_d, (this%m+1) * (this%m+1), &
551 host_to_device, sync = .true.)
553 call device_prepare_aa_matrix(aa%x_d, s%x_d, lambda%x_d, &
554 this%d%x_d, mu%x_d, y%x_d, this%a%x_d, zeta, z, this%m)
! Solve the (m+1)-sized system on the device; a nonzero info is fatal.
! NOTE(review): the error message is cut off after `//`.
557 call device_solve_linear_system(aa%x_d, bb%x_d, this%m + 1, info)
558 if (info .ne. 0)
then
559 call neko_error(
"Linear solver failed on the device in " // &
! The first m entries of bb are copied into dlambda; entry m+1 is read on the
! host as dz.
563 call device_copy(dlambda%x_d, bb%x_d, this%m)
567 call device_memcpy(bb%x, bb%x_d, this%m+1, device_to_host, &
569 dz = bb%x(this%m + 1)
! Remaining increments: dx, dy, dxsi, deta, dmu, dzeta and ds.
573 call device_dx(dx%x_d, delx%x_d, diagx%x_d, gg%x_d, &
574 dlambda%x_d, this%n, this%m)
575 call device_dy(dy%x_d, dely%x_d, dlambda%x_d, this%d%x_d, &
576 mu%x_d, y%x_d, this%m)
577 call device_dxsi(dxsi%x_d, xsi%x_d, dx%x_d, x%x_d, &
578 this%alpha%x_d, epsi, this%n)
579 call device_deta(deta%x_d, eta%x_d, dx%x_d, x%x_d, &
580 this%beta%x_d, epsi, this%n)
582 call device_col3(dmu%x_d, mu%x_d, dy%x_d, this%m)
583 call device_cmult(dmu%x_d, -1.0_rp, this%m)
584 call device_cadd(dmu%x_d, epsi, this%m)
585 call device_invcol2(dmu%x_d, y%x_d, this%m)
586 call device_sub2(dmu%x_d, mu%x_d, this%m)
587 dzeta = -zeta + (epsi - zeta * dz) / z
588 call device_col3(ds%x_d, dlambda%x_d, s%x_d, this%m)
589 call device_cmult(ds%x_d, -1.0_rp, this%m)
590 call device_cadd(ds%x_d, epsi, this%m)
591 call device_invcol2(ds%x_d, lambda%x_d, this%m)
592 call device_sub2(ds%x_d, s%x_d, this%m)
! steg is the largest of 1 and the device_maxval2 / device_maxval3 values for
! every variable, plus the scalar term for zeta.
! NOTE(review): steg is later used as the step length; any conversion (e.g.
! an inversion) between here and the line search is not visible.
594 steg = maxval([1.0_rp, &
595 device_maxval2(dy%x_d, y%x_d, -1.01_rp, this%m), &
597 device_maxval2(dlambda%x_d, lambda%x_d, -1.01_rp, this%m), &
598 device_maxval2(dxsi%x_d, xsi%x_d, -1.01_rp, this%n), &
599 device_maxval2(deta%x_d, eta%x_d, -1.01_rp, this%n), &
600 device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, this%m), &
601 -1.01_rp * dzeta / zeta, &
602 device_maxval2(ds%x_d, s%x_d, -1.01_rp, this%m), &
603 device_maxval3(dx%x_d, x%x_d, this%alpha%x_d, -1.01_rp, this%n),&
604 device_maxval3(dx%x_d, this%beta%x_d, x%x_d, 1.01_rp, this%n)])
! Save the current iterate so that trial points can be formed from it.
608 call device_copy(xold%x_d, x%x_d, this%n)
609 call device_copy(yold%x_d, y%x_d, this%m)
611 call device_copy(lambdaold%x_d, lambda%x_d, this%m)
612 call device_copy(xsiold%x_d, xsi%x_d, this%n)
613 call device_copy(etaold%x_d, eta%x_d, this%n)
614 call device_copy(muold%x_d, mu%x_d, this%m)
616 call device_copy(sold%x_d, s%x_d, this%m)
! Start above residual_norm so the line-search condition below holds on entry
! (for a positive residual_norm).
618 new_residual = 2.0_rp * residual_norm
! Both scalars are min-reduced over neko_comm so all ranks agree on them.
621 call mpi_allreduce(mpi_in_place, steg, 1, &
622 mpi_real_precision, mpi_min, neko_comm, ierr)
623 call mpi_allreduce(mpi_in_place, new_residual, 1, &
624 mpi_real_precision, mpi_min, neko_comm, ierr)
! Line search: repeats while the trial residual exceeds residual_norm, and
! only while itto is below 50.
! NOTE(review): neither the initialisation nor the increment of itto, nor any
! change of steg inside the loop, is visible.
629 do while ((new_residual .gt. residual_norm) .and. (itto .lt. 50))
! Trial point: saved iterate combined with steg times the increment (the
! explicit zeta update below shows the pattern old + steg*d).
633 call device_add3s2(x%x_d, xold%x_d, dx%x_d, 1.0_rp, steg, this%n)
634 call device_add3s2(y%x_d, yold%x_d, dy%x_d, 1.0_rp, steg, this%m)
636 call device_add3s2(lambda%x_d, lambdaold%x_d, &
637 dlambda%x_d, 1.0_rp, steg, this%m)
638 call device_add3s2(xsi%x_d, xsiold%x_d, dxsi%x_d, &
639 1.0_rp, steg, this%n)
640 call device_add3s2(eta%x_d, etaold%x_d, deta%x_d, &
641 1.0_rp, steg, this%n)
642 call device_add3s2(mu%x_d, muold%x_d, dmu%x_d, &
643 1.0_rp, steg, this%m)
644 zeta = zetaold + steg*dzeta
645 call device_add3s2(s%x_d, sold%x_d, ds%x_d, 1.0_rp, &
! Residuals at the trial point, same sequence as at the top of the outer
! loop, except that relambda is always reduced through the host here.
650 call device_rex(rex%x_d, x%x_d, this%low%x_d, &
651 this%upp%x_d, this%pij%x_d, this%p0j%x_d, &
652 this%qij%x_d, this%q0j%x_d, lambda%x_d, xsi%x_d, &
653 eta%x_d, this%n, this%m)
655 call device_col3(rey%x_d, this%d%x_d, y%x_d, this%m)
656 call device_add2(rey%x_d, this%c%x_d, this%m)
657 call device_sub2(rey%x_d, lambda%x_d, this%m)
658 call device_sub2(rey%x_d, mu%x_d, this%m)
660 rez = this%a0 - zeta - device_lcsc2(lambda%x_d, this%a%x_d, this%m)
663 call device_cfill(relambda%x_d, 0.0_rp, this%m)
664 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
665 this%low%x_d, this%pij%x_d, this%qij%x_d, &
668 call device_memcpy(relambda%x, relambda%x_d, this%m, &
669 device_to_host, sync = .true.)
670 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
671 mpi_real_precision, mpi_sum, neko_comm, ierr)
672 call device_memcpy(relambda%x, relambda%x_d, &
673 this%m, host_to_device, sync = .true.)
675 call device_add3s2(relambda%x_d, relambda%x_d, &
676 this%a%x_d, 1.0_rp, -z, this%m)
677 call device_sub2(relambda%x_d, y%x_d, this%m)
678 call device_add2(relambda%x_d, s%x_d, this%m)
679 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
681 call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
682 call device_col2(rexsi%x_d, xsi%x_d, this%n)
683 call device_cadd(rexsi%x_d, - epsi, this%n)
685 call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
686 call device_col2(reeta%x_d, eta%x_d, this%n)
687 call device_cadd(reeta%x_d, - epsi, this%n)
689 call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
690 call device_cadd(remu%x_d, - epsi, this%m)
692 rezeta = zeta*z - epsi
694 call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
695 call device_cadd(res%x_d, - epsi, this%m)
698 re_sq_norm = device_norm(rex%x_d, this%n) + &
699 device_norm(rexsi%x_d, this%n) + &
700 device_norm(reeta%x_d, this%n)
701 call mpi_allreduce(mpi_in_place, re_sq_norm, 1, &
702 mpi_real_precision, mpi_sum, neko_comm, ierr)
! NOTE(review): the tail of this expression and the end of the line-search
! loop are not visible.
704 new_residual = sqrt(device_norm(rey%x_d, this%m) + &
706 device_norm(relambda%x_d, this%m) + &
707 device_norm(remu%x_d, this%m) + &
709 device_norm(res%x_d, this%m) + &
! Adopt the trial residual and recompute the largest residual entry, which is
! then max-reduced over neko_comm.
! NOTE(review): unlike the first residual_max expression, the visible lines
! here contain no abs(rez) / abs(rezeta) terms -- confirm whether they were
! lost in this view.
718 residual_norm = new_residual
719 residual_max = maxval([ &
720 device_maxval(rex%x_d, this%n), &
721 device_maxval(rey%x_d, this%m), &
723 device_maxval(relambda%x_d, this%m), &
724 device_maxval(rexsi%x_d, this%n), &
725 device_maxval(reeta%x_d, this%n), &
726 device_maxval(remu%x_d, this%m), &
728 device_maxval(res%x_d, this%m)])
730 call mpi_allreduce(mpi_in_place, residual_max, 1, &
731 mpi_real_precision, mpi_max, neko_comm, ierr)
! History shift and result: the three copies chain xold1 -> xold2,
! designx_d -> xold1 and x -> designx_d (presumably destination-first).
739 call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
740 call device_copy(this%xold1%x_d, designx_d, this%n)
741 call device_copy(designx_d, x%x_d, this%n)
! Store the remaining variables back in the MMA object.
744 call device_copy(this%y%x_d, y%x_d, this%m)
746 call device_copy(this%lambda%x_d, lambda%x_d, this%m)
748 call device_copy(this%xsi%x_d, xsi%x_d, this%n)
749 call device_copy(this%eta%x_d, eta%x_d, this%n)
750 call device_copy(this%mu%x_d, mu%x_d, this%m)
751 call device_copy(this%s%x_d, s%x_d, this%m)
! Hand all scratch storage back to the registry.
754 call neko_scratch_registry%relinquish(ind)
755 end subroutine mma_subsolve_dpip_device
!> Solve the MMA subproblem on the device with the "dip" subsolver
!! (mma_update_device calls this routine when this%subsolver .eq. "dip").
!!
!! The iteration updates lambda and mu; y, z and x are recomputed from lambda
!! after every update. All work storage is borrowed from
!! neko_scratch_registry.
!!
!! @param this       MMA object. Reads p0j, q0j, pij, qij, low, upp, alpha,
!!                   beta, c, d, a0, a, bi, epsimin and max_iter; the final
!!                   device_copy calls target xold2, xold1, y, lambda and mu.
!! @param designx_d  Device pointer to the design variables. The pointer is
!!                   intent(in), but the memory behind it is the first
!!                   argument of the last device_copy of the history shift
!!                   (presumably the destination -- confirm).
!!
!! NOTE(review): as visible here the routine is incomplete. epsi is read
!! without a visible assignment; zerom is requested and used but never visibly
!! filled; several calls end in `&` with no continuation line; and no
!! `end if`, `end do` or `end associate` is visible. Check against the
!! complete file before relying on these comments.
759 subroutine mma_subsolve_dip_device(this, designx_d)
760 class(mma_t),
intent(inout) :: this
761 type(c_ptr),
intent(in) :: designx_d
762 integer :: iter, ierr
763 real(kind=rp) :: epsi, residumax, z, steg
! Work vectors of length this%m.
765 type(vector_t),
pointer :: y, lambda, mu, relambda, remu, dlambda, dmu, &
766 gradlambda, zerom, dd, dummy_m
! Work vectors of length this%n.
768 type(vector_t),
pointer :: x, pjlambda, qjlambda
771 type(vector_t),
pointer :: Ljjxinv
! hijx is m x n, Hess is m x m (see the request calls below).
772 type(matrix_t),
pointer :: hijx
773 type(matrix_t),
pointer :: Hess
! info: status of the linear solve; ind: scratch-registry handles.
775 integer :: info, ind(17)
777 real(kind=rp) :: minimal_epsilon
! Borrow scratch storage: ind(1:11) have length m, ind(12:15) length n.
779 call neko_scratch_registry%request(y, ind(1), this%m, .false.)
780 call neko_scratch_registry%request(lambda, ind(2), this%m, .false.)
781 call neko_scratch_registry%request(mu, ind(3), this%m, .false.)
782 call neko_scratch_registry%request(relambda, ind(4), this%m, .false.)
783 call neko_scratch_registry%request(remu, ind(5), this%m, .false.)
784 call neko_scratch_registry%request(dlambda, ind(6), this%m, .false.)
785 call neko_scratch_registry%request(dmu, ind(7), this%m, .false.)
786 call neko_scratch_registry%request(gradlambda, ind(8), this%m, .false.)
787 call neko_scratch_registry%request(zerom, ind(9), this%m, .false.)
788 call neko_scratch_registry%request(dd, ind(10), this%m, .false.)
789 call neko_scratch_registry%request(dummy_m, ind(11), this%m, .false.)
791 call neko_scratch_registry%request(x, ind(12), this%n, .false.)
792 call neko_scratch_registry%request(pjlambda,ind(13), this%n, .false.)
793 call neko_scratch_registry%request(qjlambda, ind(14), this%n, .false.)
795 call neko_scratch_registry%request(ljjxinv, ind(15), this%n, .false.)
797 call neko_scratch_registry%request(hijx, ind(16), this%m, this%n, .false.)
798 call neko_scratch_registry%request(hess, ind(17), this%m, this%m, .false.)
! Starting values: y, lambda and mu filled with 1; lambda is then combined
! with 0.5*c through device_pwmax2 (presumably a pointwise maximum).
805 call device_cfill(y%x_d, 1.0_rp, this%m)
807 call device_cfill(lambda%x_d, 1.0_rp, this%m)
808 call device_cmult2(dummy_m%x_d, this%c%x_d, 0.5_rp, this%m)
809 call device_pwmax2(lambda%x_d, dummy_m%x_d, this%m)
811 call device_cfill(mu%x_d, 1.0_rp, this%m)
! dd is built from d and the constant 1e-8 (presumably d + 1e-8); y is later
! divided by it.
815 call device_cadd2(dd%x_d, this%d%x_d, 1.0e-8_rp, this%m)
! Threshold for the outer loop: the larger of 0.9*epsimin and 1e-12, then
! min-reduced over neko_comm.
820 minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
821 call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
822 mpi_real_precision, mpi_min, neko_comm, ierr)
! Outer loop: repeats while epsi is above minimal_epsilon; epsi is multiplied
! by 0.1 near the end of the body.
! NOTE(review): no assignment to epsi is visible before this test.
827 outer:
do while (epsi .gt. minimal_epsilon)
! Short aliases for members of `this`.
831 associate(p0j => this%p0j, q0j => this%q0j, &
832 pij => this%pij, qij => this%qij, &
833 low => this%low, upp => this%upp, &
834 alpha => this%alpha, beta => this%beta, &
835 c => this%c, a0 => this%a0, a => this%a)
! y from lambda, c and dd, then combined with zerom through device_pwmax2.
! NOTE(review): zerom is never visibly initialised.
843 call device_sub3(y%x_d, lambda%x_d, c%x_d, this%m)
845 call device_invcol2(y%x_d, dd%x_d, this%m)
846 call device_pwmax2(y%x_d, zerom%x_d, this%m)
! z becomes 0 when a0 minus the device_glsum of the lambda/a column product
! is non-negative, and 1 otherwise.
852 call device_col3(dummy_m%x_d, lambda%x_d, a%x_d, this%m)
853 z = device_glsum(dummy_m%x_d, this%m)
854 z = merge(0.0_rp, 1.0_rp, a0 - z >= 0.0)
! pjlambda and qjlambda: device_mattrans_v_mul of pij / qij with lambda, plus
! p0j / q0j. They determine x through device_mma_dipsolvesub1.
860 call device_mattrans_v_mul(pjlambda%x_d, pij%x_d, lambda%x_d, this%m, this%n)
861 call device_mattrans_v_mul(qjlambda%x_d, qij%x_d, lambda%x_d, this%m, this%n)
862 call device_add2(pjlambda%x_d, p0j%x_d, this%n)
863 call device_add2(qjlambda%x_d, q0j%x_d, this%n)
865 call device_mma_dipsolvesub1(x%x_d, pjlambda%x_d, qjlambda%x_d, &
866 low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
! relambda: rank-local part from device_relambda, summed over neko_comm
! through the host copy, then combined with a (weighted by -z), y, mu and bi.
! NOTE(review): the first device_memcpy ends in `&` with no continuation.
868 call device_cfill(relambda%x_d, 0.0_rp, this%m)
869 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
870 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
874 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
876 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
877 mpi_real_precision, mpi_sum, neko_comm, ierr)
878 call device_memcpy(relambda%x, relambda%x_d, this%m, &
879 host_to_device, sync = .true.)
881 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
882 call device_sub2(relambda%x_d, y%x_d, this%m)
883 call device_add2(relambda%x_d, mu%x_d, this%m)
884 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
! remu is built from mu and lambda, with epsi subtracted.
886 call device_col3(remu%x_d, mu%x_d, lambda%x_d, this%m)
887 call device_cadd(remu%x_d, -epsi, this%m)
889 residumax = maxval([device_maxval(relambda%x_d, this%m), &
890 device_maxval(remu%x_d, this%m)])
! Inner iterations at fixed epsi, at most this%max_iter of them; the loop is
! left as soon as residumax drops below epsi.
894 do iter = 1, this%max_iter
896 if (residumax .lt. epsi)
exit
! Right-hand side gradlambda: relambda with mu removed and epsi/lambda
! added, then multiplied by -1.
904 call device_copy(gradlambda%x_d, relambda%x_d, this%m)
905 call device_sub2(gradlambda%x_d, mu%x_d, this%m)
908 call device_cfill(dummy_m%x_d, epsi, this%m)
909 call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)
910 call device_add2(gradlambda%x_d, dummy_m%x_d, this%m)
911 call device_cmult(gradlambda%x_d, -1.0_rp, this%m)
! Matrix hess (m x m): assembled by device_hess from hijx and ljjxinv, summed
! over neko_comm through the host, then completed by device_prepare_hessian.
917 call device_mma_ljjxinv(ljjxinv%x_d, pjlambda%x_d, qjlambda%x_d, &
918 x%x_d, low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
920 call device_gg(hijx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
921 this%pij%x_d, this%qij%x_d, this%n, this%m)
923 call device_cfill(hess%x_d, 0.0_rp, (this%m) * (this%m) )
924 call device_hess(hess%x_d, hijx%x_d, ljjxinv%x_d, this%n, this%m)
! NOTE(review): the next device_memcpy ends in `&` with no continuation.
927 call device_memcpy(hess%x, hess%x_d, this%m*this%m, device_to_host, &
929 call mpi_allreduce(mpi_in_place, hess%x, &
930 this%m*this%m, mpi_real_precision, mpi_sum, neko_comm, ierr)
933 call device_memcpy(hess%x, hess%x_d, this%m*this%m, &
934 host_to_device, sync = .true.)
949 call device_prepare_hessian(hess%x_d, y%x_d, this%d%x_d, &
950 mu%x_d, lambda%x_d, this%m)
! Solve the system on the device; the result in gradlambda is copied to
! dlambda and a nonzero info is fatal.
! NOTE(review): the solver call and the error message are both cut off.
953 call device_solve_linear_system(hess%x_d, gradlambda%x_d, &
955 if (info .ne. 0)
then
956 call neko_error(
"Linear solver failed on the device in " // &
960 call device_copy(dlambda%x_d, gradlambda%x_d, this%m)
! dmu from epsi, lambda, mu and dlambda (dummy_m holds the dlambda, mu and
! lambda combination).
963 call device_copy(dummy_m%x_d, dlambda%x_d, this%m)
964 call device_col2(dummy_m%x_d, mu%x_d, this%m)
965 call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)
967 call device_cfill(dmu%x_d, epsi, this%m)
968 call device_invcol2(dmu%x_d, lambda%x_d, this%m)
969 call device_add2s2(dmu%x_d, dummy_m%x_d, -1.0_rp, this%m)
970 call device_sub2(dmu%x_d, mu%x_d, this%m)
! steg is the largest of 1.005 and the device_maxval2 values for lambda and
! mu.
! NOTE(review): the expression is cut off, and any conversion of steg before
! it is used as the step length below is not visible.
972 steg = maxval([1.005_rp, device_maxval2(dlambda%x_d, lambda%x_d, &
973 -1.01_rp, this%m), device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, &
! Update lambda and mu with steg times their increments.
977 call device_add2s2(lambda%x_d, dlambda%x_d, steg, this%m)
978 call device_add2s2(mu%x_d, dmu%x_d, steg, this%m)
! Recompute y, z, x and the residuals from the updated lambda and mu, with
! the same sequence of calls as at the top of the outer loop.
987 call device_sub3(y%x_d, lambda%x_d, c%x_d, this%m)
989 call device_invcol2(y%x_d, dd%x_d, this%m)
990 call device_pwmax2(y%x_d, zerom%x_d, this%m)
996 call device_col3(dummy_m%x_d, lambda%x_d, a%x_d, this%m)
997 z = device_glsum(dummy_m%x_d, this%m)
998 z = merge(0.0_rp, 1.0_rp, a0 - z >= 0.0)
1004 call device_mattrans_v_mul(pjlambda%x_d, pij%x_d, lambda%x_d, this%m, this%n)
1005 call device_mattrans_v_mul(qjlambda%x_d, qij%x_d, lambda%x_d, this%m, this%n)
1006 call device_add2(pjlambda%x_d, p0j%x_d, this%n)
1007 call device_add2(qjlambda%x_d, q0j%x_d, this%n)
1009 call device_mma_dipsolvesub1(x%x_d, pjlambda%x_d, qjlambda%x_d, &
1010 low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
1014 call device_cfill(relambda%x_d, 0.0_rp, this%m)
1015 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
1016 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
! NOTE(review): the next device_memcpy ends in `&` with no continuation.
1020 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
1022 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
1023 mpi_real_precision, mpi_sum, neko_comm, ierr)
1024 call device_memcpy(relambda%x, relambda%x_d, this%m, &
1025 host_to_device, sync = .true.)
1027 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
1028 call device_sub2(relambda%x_d, y%x_d, this%m)
1029 call device_add2(relambda%x_d, mu%x_d, this%m)
1030 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
1032 call device_col3(remu%x_d, mu%x_d, lambda%x_d, this%m)
1033 call device_cadd(remu%x_d, -epsi, this%m)
1035 residumax = maxval([device_maxval(relambda%x_d, this%m), &
1036 device_maxval(remu%x_d, this%m)])
! Tighten epsi by a factor of ten for the next outer iteration.
1039 epsi = 0.1_rp * epsi
! History shift and result: the three copies chain xold1 -> xold2,
! designx_d -> xold1 and x -> designx_d (presumably destination-first).
1043 call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
1044 call device_copy(this%xold1%x_d, designx_d, this%n)
1045 call device_copy(designx_d, x%x_d, this%n)
! Store y, lambda and mu back in the MMA object.
1048 call device_copy(this%y%x_d, y%x_d, this%m)
1050 call device_copy(this%lambda%x_d, lambda%x_d, this%m)
1051 call device_copy(this%mu%x_d, mu%x_d, this%m)
! Hand all scratch storage back to the registry.
1053 call neko_scratch_registry%relinquish(ind)
1054 end subroutine mma_subsolve_dip_device
1056end submodule mma_device