35submodule(
mma) mma_device
37 use device_math,
only: device_copy, device_cmult, device_cadd, device_cfill, &
38 device_add2, device_add3s2, device_invcol2, device_col2, device_col3, &
39 device_sub2, device_sub3, device_add2s2, device_cadd2, device_pwmax2, &
40 device_glsum, device_cmult2
41 use device_mma_math,
only: device_maxval, device_norm, device_lcsc2, &
42 device_maxval2, device_maxval3, device_mma_gensub3, &
43 device_mma_gensub4, device_mma_max, device_max2, device_rex, &
44 device_relambda, device_delx, device_add2inv2, device_gg, device_diagx, &
45 device_bb, device_updatebb, device_aa, device_updateaa, device_dx, &
46 device_dy, device_dxsi, device_deta, device_kkt_rex, &
47 device_mma_gensub2, device_mattrans_v_mul, device_mma_dipsolvesub1, &
48 device_mma_ljjxinv, device_hess, device_solve_linear_system, &
49 device_prepare_hessian, device_prepare_aa_matrix, device_update_hessian_z
51 use neko_config,
only: neko_bcknd_device, neko_device_mpi
52 use device,
only: device_to_host
53 use comm,
only: neko_comm, pe_rank, mpi_real_precision
54 use mpi_f08,
only: mpi_in_place, mpi_max, mpi_min
55 use profiler,
only: profiler_start_region, profiler_end_region
!> Run one MMA update step on the device.
!!
!! The step has two profiled phases: `mma_gensub_device` builds the
!! sub-problem for the current iterate, then the sub-solver selected by
!! `this%subsolver` ("dip" or "dpip") is invoked. Afterwards the object is
!! flagged as updated.
!!
!! @param this  MMA object; `neko_error` is called if it is not initialized.
!! @param iter  Iteration counter, forwarded to `mma_gensub_device`.
!! @param x     Device pointer, forwarded to the generator and the sub-solver.
!! @param df0dx Device pointer, forwarded to `mma_gensub_device`.
!! @param fval  Device pointer, forwarded to `mma_gensub_device`.
!! @param dfdx  Device pointer, forwarded to `mma_gensub_device`.
module subroutine mma_update_device(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(inout) :: x
  type(c_ptr), intent(in) :: df0dx, fval, dfdx

  ! Guard: the object has to be set up before it can be used.
  if (.not. this%is_initialized) &
       call neko_error("The MMA object is not initialized.")

  ! Phase 1: build the sub-problem.
  call profiler_start_region("MMA gensub")
  call mma_gensub_device(this, iter, x, df0dx, fval, dfdx)
  call profiler_end_region("MMA gensub")

  ! Phase 2: dispatch to the requested sub-solver.
  call profiler_start_region("MMA subsolve")
  select case (this%subsolver)
  case ("dip")
     call mma_subsolve_dip_device(this, x)
  case ("dpip")
     call mma_subsolve_dpip_device(this, x)
  case default
     call neko_error("Unrecognized subsolver for MMA in mma_device.")
  end select
  call profiler_end_region("MMA subsolve")

  this%is_updated = .true.
end subroutine mma_update_device
!> Evaluate the KKT residuals on the device.
!!
!! Thin dispatcher: the "dip" sub-solver has its own residual routine, every
!! other value of `this%subsolver` is handled by the "dpip" routine.
!!
!! @param this  MMA object whose residual fields are updated by the callee.
!! @param x     Device pointer, forwarded unchanged.
!! @param df0dx Device pointer, forwarded unchanged.
!! @param fval  Device pointer, forwarded unchanged.
!! @param dfdx  Device pointer, forwarded unchanged.
module subroutine mma_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  select case (this%subsolver)
  case ("dip")
     call mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
  case default
     ! Anything that is not "dip" falls through to the dpip residuals.
     call mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
  end select
end subroutine mma_kkt_device
!> KKT residuals for the "dip" sub-solver, evaluated on the device.
!!
!! Builds two length-m residual vectors (relambda, remu) in scratch storage
!! and stores their maximum and a norm in this%residumax / this%residunorm.
!! The arguments x, df0dx and dfdx are not referenced in the visible body.
!!
!! NOTE(review): this listing appears to be missing lines (no declaration of
!! `ind` is visible and one statement ends in a dangling `&`); confirm
!! against the complete file.
110 module subroutine mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
111 class(mma_t),
intent(inout) :: this
112 type(c_ptr),
intent(in) :: x, df0dx, fval, dfdx
114 type(vector_t),
pointer :: relambda, remu
! Borrow two length-m work vectors from the scratch registry.
117 call this%scratch%request(relambda, ind(1), this%m, .false.)
118 call this%scratch%request(remu, ind(2), this%m, .false.)
! relambda is assembled from fval, a (weighted by -z), y and mu.
! Presumably relambda = fval - z*a - y + mu, going by the Neko device_math
! naming conventions -- TODO confirm.
! NOTE(review): the continuation of the next statement is not visible here.
121 call device_add3s2(relambda%x_d, fval, this%a%x_d, 1.0_rp, -this%z, &
123 call device_sub2(relambda%x_d, this%y%x_d, this%m)
124 call device_add2(relambda%x_d, this%mu%x_d, this%m)
! remu combines lambda and mu (presumably the element-wise product).
127 call device_col3(remu%x_d, this%lambda%x_d, this%mu%x_d, this%m)
! Largest of the two per-vector device_maxval results.
129 this%residumax = maxval([device_maxval(relambda%x_d, this%m), &
130 device_maxval(remu%x_d, this%m)])
! Square root of the summed device_norm values (device_norm presumably
! returns a squared norm -- confirm).
131 this%residunorm = sqrt(device_norm(relambda%x_d, this%m)+ &
132 device_norm(remu%x_d, this%m))
! Hand the work vectors back.
134 call this%scratch%relinquish(ind)
135 end subroutine mma_dip_kkt_device
!> KKT residuals for the "dpip" sub-solver, evaluated on the device.
!!
!! Builds seven residual vectors in scratch storage (rey, relambda, remu,
!! res of length m; rex, rexsi, reeta of length n) plus the scalars rez and
!! rezeta, then stores a maximum in this%residumax and a norm in
!! this%residunorm. The maximum and the partial squared norm re_sq_norm are
!! reduced over neko_comm.
!!
!! NOTE(review): several statements in this listing end in `&` without a
!! visible continuation, and rez / rezeta do not appear in the visible
!! maxval / sqrt expressions; lines seem to be missing -- confirm against
!! the complete file.
139 module subroutine mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
140 class(mma_t),
intent(inout) :: this
141 type(c_ptr),
intent(in) :: x, df0dx, fval, dfdx
143 real(kind=rp) :: rez, rezeta
144 type(vector_t),
pointer :: rey, relambda, remu, res
145 type(vector_t),
pointer :: rex, rexsi, reeta
146 integer :: ierr, ind(7)
147 real(kind=rp) :: re_sq_norm
! Length-m work vectors.
149 call this%scratch%request(rey, ind(1), this%m, .false.)
150 call this%scratch%request(relambda, ind(2), this%m, .false.)
151 call this%scratch%request(remu, ind(3), this%m, .false.)
152 call this%scratch%request(res, ind(4), this%m, .false.)
! Length-n work vectors.
154 call this%scratch%request(rex, ind(5), this%n, .false.)
155 call this%scratch%request(rexsi, ind(6), this%n, .false.)
156 call this%scratch%request(reeta, ind(7), this%n, .false.)
! rex is produced by a dedicated kernel from df0dx, dfdx, xsi, eta, lambda.
158 call device_kkt_rex(rex%x_d, df0dx, dfdx, this%xsi%x_d, &
159 this%eta%x_d, this%lambda%x_d, this%n, this%m)
! rey is assembled from d, y, c, lambda and mu
! (presumably rey = d*y + c - lambda - mu -- TODO confirm helper semantics).
161 call device_col3(rey%x_d, this%d%x_d, this%y%x_d, this%m)
162 call device_add2(rey%x_d, this%c%x_d, this%m)
163 call device_sub2(rey%x_d, this%lambda%x_d, this%m)
164 call device_sub2(rey%x_d, this%mu%x_d, this%m)
! Scalar residual rez. NOTE(review): continuation line not visible.
166 rez = this%a0 - this%zeta - device_lcsc2(this%lambda%x_d, this%a%x_d, &
! relambda is assembled from fval, a (weighted by -z), y and s.
! NOTE(review): continuation line of the next statement not visible.
169 call device_add3s2(relambda%x_d, fval, this%a%x_d, 1.0_rp, -this%z, &
171 call device_sub2(relambda%x_d, this%y%x_d, this%m)
172 call device_add2(relambda%x_d, this%s%x_d, this%m)
! rexsi is built from x, xmin and xsi (presumably xsi*(x - xmin)).
174 call device_sub3(rexsi%x_d, x, this%xmin%x_d, this%n)
175 call device_col2(rexsi%x_d, this%xsi%x_d, this%n)
! reeta is built from xmax, x and eta (presumably eta*(xmax - x)).
177 call device_sub3(reeta%x_d, this%xmax%x_d, x, this%n)
178 call device_col2(reeta%x_d, this%eta%x_d, this%n)
! remu combines mu and y.
180 call device_col3(remu%x_d, this%mu%x_d, this%y%x_d, this%m)
182 rezeta = this%zeta * this%z
! res combines lambda and s.
184 call device_col3(res%x_d, this%lambda%x_d, this%s%x_d, this%m)
! Local maximum over all residual vectors.
186 this%residumax = maxval([ &
187 device_maxval(rex%x_d, this%n), &
188 device_maxval(rey%x_d, this%m), &
190 device_maxval(relambda%x_d, this%m), &
191 device_maxval(rexsi%x_d, this%n), &
192 device_maxval(reeta%x_d, this%n), &
193 device_maxval(remu%x_d, this%m), &
195 device_maxval(res%x_d, this%m)])
! Contribution of the length-n residuals, summed across ranks below.
197 re_sq_norm = device_norm(rex%x_d, this%n) + &
198 device_norm(rexsi%x_d, this%n) + &
199 device_norm(reeta%x_d, this%n)
! Global maximum of the residual.
201 call mpi_allreduce(mpi_in_place, this%residumax, 1, &
202 mpi_real_precision, mpi_max, neko_comm, ierr)
! Global sum of the length-n contribution.
204 call mpi_allreduce(mpi_in_place, re_sq_norm, 1, &
205 mpi_real_precision, mpi_sum, neko_comm, ierr)
! Norm over the length-m residuals.
! NOTE(review): the closing part of this expression is not visible here.
207 this%residunorm = sqrt(( &
208 device_norm(rey%x_d, this%m) + &
210 device_norm(relambda%x_d, this%m) + &
211 device_norm(remu%x_d, this%m) + &
213 device_norm(res%x_d, this%m) &
! Hand the work vectors back.
216 call this%scratch%relinquish(ind)
217 end subroutine mma_dpip_kkt_device
!> Generate the MMA sub-problem data on the device.
!!
!! Updates this%low and this%upp, calls device_mma_gensub3 (which receives
!! alpha, beta, p0j, q0j, pij, qij) and device_mma_gensub4 (which receives
!! bi), sums this%bi over neko_comm through a host round-trip and finally
!! subtracts fval from it.
!!
!! NOTE(review): no declaration of `ind` / `ierr`, no `else` / `end if` for
!! the `iter .lt. 3` test and no continuation for the device_memcpy calls
!! are visible in this listing; lines seem to be missing -- confirm against
!! the complete file.
223 subroutine mma_gensub_device(this, iter, x, df0dx, fval, dfdx)
229 class(mma_t),
intent(inout) :: this
230 type(c_ptr),
intent(in) :: x
231 type(c_ptr),
intent(in) :: df0dx
232 type(c_ptr),
intent(in) :: fval
233 type(c_ptr),
intent(in) :: dfdx
235 integer,
intent(in) :: iter
238 type(vector_t),
pointer :: x_diff
! Length-n work vector built from xmax and xmin (presumably xmax - xmin).
241 call this%scratch%request(x_diff, ind, this%n, .false.)
243 call device_sub3(x_diff%x_d, this%xmax%x_d, this%xmin%x_d, this%n)
! For the first iterations low / upp are built directly from x and x_diff
! using -asyinit / +asyinit as the scale factor.
248 if (iter .lt. 3)
then
249 call device_copy(this%low%x_d, x, this%n)
250 call device_add2s2(this%low%x_d, x_diff%x_d, - this%asyinit, this%n)
251 call device_copy(this%upp%x_d, x, this%n)
252 call device_add2s2(this%upp%x_d, x_diff%x_d, this%asyinit, this%n)
! Update of low / upp that also uses xold1, xold2, asydecr and asyincr
! (presumably the branch for later iterations -- confirm).
254 call device_mma_gensub2(this%low%x_d, this%upp%x_d, x, &
255 this%xold1%x_d, this%xold2%x_d, x_diff%x_d, &
256 this%asydecr, this%asyincr, this%n)
! Kernel receiving alpha, beta and the p / q coefficient arrays.
262 call device_mma_gensub3(x, df0dx, dfdx, this%low%x_d, &
263 this%upp%x_d, this%xmin%x_d, this%xmax%x_d, this%alpha%x_d, &
264 this%beta%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
265 this%qij%x_d, this%n, this%m)
! Kernel receiving bi.
270 call device_mma_gensub4(x, this%low%x_d, this%upp%x_d, this%pij%x_d, &
271 this%qij%x_d, this%n, this%m, this%bi%x_d)
! Device -> host, sum across ranks, host -> device.
273 call device_memcpy(this%bi%x, this%bi%x_d, this%m, device_to_host, &
275 call mpi_allreduce(mpi_in_place, this%bi%x, this%m, &
276 mpi_real_precision, mpi_sum, neko_comm, ierr)
277 call device_memcpy(this%bi%x, this%bi%x_d, this%m, host_to_device, &
! Remove fval from the reduced bi.
280 call device_sub2(this%bi%x_d, fval, this%m)
282 call this%scratch%relinquish(ind)
283 end subroutine mma_gensub_device
!> Solve the MMA sub-problem with the "dpip" sub-solver on the device.
!!
!! Works on scratch copies of the variables (x, y, lambda, xsi, eta, mu, s,
!! plus the scalars z and zeta). An outer loop runs while epsi is above
!! minimal_epsilon; an inner loop of at most this%max_iter steps assembles
!! and solves an (m+1) x (m+1) linear system, derives the update directions
!! and applies a step-length search of at most 50 trials. On exit the
!! scratch variables are copied into the MMA object and into designx_d.
!!
!! @param this      MMA object; its y, lambda, xsi, eta, mu, s, xold1 and
!!                  xold2 members are overwritten at the end.
!! @param designx_d Device pointer; it is the destination of a device_copy
!!                  at the end (the pointer itself is intent(in)).
!!
!! NOTE(review): this listing appears to be missing lines. No declaration of
!! `ind` / `info`, no initialization of epsi, z, zeta and itto, no update of
!! epsi / steg / itto, and no loop / associate terminators are visible, and
!! several statements end in `&` with no continuation. Confirm against the
!! complete file before relying on these comments.
287 subroutine mma_subsolve_dpip_device(this, designx_d)
288 class(mma_t),
intent(inout) :: this
289 type(c_ptr),
intent(in) :: designx_d
290 integer :: iter, itto, ierr
291 real(kind=rp) :: epsi, residual_max, residual_norm, z, zeta, rez, rezeta, &
292 delz, dz, dzeta, steg, zold, zetaold, new_residual
! Length-m work vectors.
294 type(vector_t) ,
pointer :: y, lambda, s, mu, rey, relambda, remu, res, &
295 dely, dellambda, dy, dlambda, ds, dmu, yold, lambdaold, sold, muold
! Length-n work vectors.
298 type(vector_t),
pointer :: x, xsi, eta, rex, rexsi, reeta, &
299 delx, diagx, dx, dxsi, deta, xold, xsiold, etaold
! Right-hand side (m+1), and the matrices GG (m x n) and AA ((m+1) x (m+1)).
301 type(vector_t),
pointer :: bb
302 type(matrix_t),
pointer :: GG
303 type(matrix_t),
pointer :: AA
306 real(kind=rp) :: re_sq_norm
310 real(kind=rp) :: minimal_epsilon
! Borrow all work storage from the scratch registry.
312 call this%scratch%request(y, ind(1), this%m, .false.)
313 call this%scratch%request(lambda, ind(2), this%m, .false.)
314 call this%scratch%request(s, ind(3), this%m, .false.)
315 call this%scratch%request(mu, ind(4), this%m, .false.)
316 call this%scratch%request(rey, ind(5), this%m, .false.)
317 call this%scratch%request(relambda, ind(6), this%m, .false.)
318 call this%scratch%request(remu, ind(7), this%m, .false.)
319 call this%scratch%request(res, ind(8), this%m, .false.)
320 call this%scratch%request(dely, ind(9), this%m, .false.)
321 call this%scratch%request(dellambda, ind(10), this%m, .false.)
322 call this%scratch%request(dy, ind(11), this%m, .false.)
323 call this%scratch%request(dlambda, ind(12), this%m, .false.)
324 call this%scratch%request(ds, ind(13), this%m, .false.)
325 call this%scratch%request(dmu, ind(14), this%m, .false.)
326 call this%scratch%request(yold, ind(15), this%m, .false.)
327 call this%scratch%request(lambdaold, ind(16), this%m, .false.)
328 call this%scratch%request(sold, ind(17), this%m, .false.)
329 call this%scratch%request(muold, ind(18), this%m, .false.)
330 call this%scratch%request(x, ind(19), this%n, .false.)
331 call this%scratch%request(xsi, ind(20), this%n, .false.)
332 call this%scratch%request(eta, ind(21), this%n, .false.)
333 call this%scratch%request(rex, ind(22), this%n, .false.)
334 call this%scratch%request(rexsi, ind(23), this%n, .false.)
335 call this%scratch%request(reeta, ind(24), this%n, .false.)
336 call this%scratch%request(delx, ind(25), this%n, .false.)
337 call this%scratch%request(diagx, ind(26), this%n, .false.)
338 call this%scratch%request(dx, ind(27), this%n, .false.)
339 call this%scratch%request(dxsi, ind(28), this%n, .false.)
340 call this%scratch%request(deta, ind(29), this%n, .false.)
341 call this%scratch%request(xold, ind(30), this%n, .false.)
342 call this%scratch%request(xsiold, ind(31), this%n, .false.)
343 call this%scratch%request(etaold, ind(32), this%n, .false.)
344 call this%scratch%request(bb, ind(33), this%m+1, .false.)
346 call this%scratch%request(gg, ind(34), this%m, this%n, .false.)
347 call this%scratch%request(aa, ind(35), this%m+1, this%m+1, .false.)
! Starting point: x from alpha and beta with weights 0.5 / 0.5; y, lambda
! and s filled with 1; xsi, eta and mu from dedicated max-type kernels.
! NOTE(review): continuation of the next statement not visible.
354 call device_add3s2(x%x_d, this%alpha%x_d, this%beta%x_d, 0.5_rp, 0.5_rp, &
356 call device_cfill(y%x_d, 1.0_rp, this%m)
359 call device_cfill(lambda%x_d, 1.0_rp, this%m)
360 call device_cfill(s%x_d, 1.0_rp, this%m)
361 call device_mma_max(xsi%x_d, x%x_d, this%alpha%x_d, this%n)
362 call device_mma_max(eta%x_d, this%beta%x_d, x%x_d, this%n)
363 call device_max2(mu%x_d, 1.0_rp, this%c%x_d, 0.5_rp, this%m)
! Stopping threshold for epsi, floored at 1e-12 and made identical on all
! ranks by taking the minimum.
368 minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
369 call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
370 mpi_real_precision, mpi_min, neko_comm, ierr)
! ---- Outer loop over epsi ----
375 do while (epsi .gt. minimal_epsilon)
! Short aliases for the sub-problem data held by the MMA object.
382 associate(p0j => this%p0j, q0j => this%q0j, &
383 pij => this%pij, qij => this%qij, &
384 low => this%low, upp => this%upp, &
385 alpha => this%alpha, beta => this%beta, &
386 c => this%c, d => this%d, &
387 a0 => this%a0, a => this%a)
! Residuals at the current point for the current epsi.
389 call device_rex(rex%x_d, x%x_d, low%x_d, upp%x_d, &
390 pij%x_d, p0j%x_d, qij%x_d, q0j%x_d, &
391 lambda%x_d, xsi%x_d, eta%x_d, this%n, this%m)
393 call device_col3(rey%x_d, d%x_d, y%x_d, this%m)
394 call device_add2(rey%x_d, c%x_d, this%m)
395 call device_sub2(rey%x_d, lambda%x_d, this%m)
396 call device_sub2(rey%x_d, mu%x_d, this%m)
397 rez = a0 - zeta - device_lcsc2(lambda%x_d, a%x_d, this%m)
! relambda: cleared, filled by the kernel, then summed across ranks
! through a host round-trip.
399 call device_cfill(relambda%x_d, 0.0_rp, this%m)
400 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
401 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
409 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
411 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
412 mpi_real_precision, mpi_sum, neko_comm, ierr)
413 call device_memcpy(relambda%x, relambda%x_d, this%m, host_to_device, &
! Remaining terms of relambda use a (weighted by -z), y, s and bi.
416 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
417 call device_sub2(relambda%x_d, y%x_d, this%m)
418 call device_add2(relambda%x_d, s%x_d, this%m)
419 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
! The following residuals each have epsi subtracted at the end.
421 call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
422 call device_col2(rexsi%x_d, xsi%x_d, this%n)
423 call device_cadd(rexsi%x_d, - epsi, this%n)
425 call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
426 call device_col2(reeta%x_d, eta%x_d, this%n)
427 call device_cadd(reeta%x_d, - epsi, this%n)
429 call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
430 call device_cadd(remu%x_d, - epsi, this%m)
432 rezeta = zeta * z - epsi
434 call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
435 call device_cadd(res%x_d, - epsi, this%m)
! Local maximum over all residuals, including |rez| and |rezeta|.
438 residual_max = maxval([device_maxval(rex%x_d, this%n), &
439 device_maxval(rey%x_d, this%m), abs(rez), &
440 device_maxval(relambda%x_d, this%m), &
441 device_maxval(rexsi%x_d, this%n), &
442 device_maxval(reeta%x_d, this%n), &
443 device_maxval(remu%x_d, this%m), abs(rezeta), &
444 device_maxval(res%x_d, this%m)])
! Contribution of the length-n residuals to the norm.
446 re_sq_norm = device_norm(rex%x_d, this%n) + &
447 device_norm(rexsi%x_d, this%n) + device_norm(reeta%x_d, this%n)
! Make the maximum and the length-n contribution global.
449 call mpi_allreduce(mpi_in_place, residual_max, 1, &
450 mpi_real_precision, mpi_max, neko_comm, ierr)
452 call mpi_allreduce(mpi_in_place, re_sq_norm, &
453 1, mpi_real_precision, mpi_sum, neko_comm, ierr)
! NOTE(review): the closing part of this expression is not visible here.
455 residual_norm = sqrt(device_norm(rey%x_d, this%m) + &
457 device_norm(relambda%x_d, this%m) + &
458 device_norm(remu%x_d, this%m)+ &
460 device_norm(res%x_d, this%m) &
! ---- Inner loop for the current epsi ----
466 do iter = 1, this%max_iter
! Converged for this epsi once the largest residual is below epsi.
468 if (residual_max .lt. epsi)
exit
! Right-hand-side pieces delx, dely, delz and dellambda.
! NOTE(review): continuation of the next statement not visible.
470 call device_delx(delx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
471 this%pij%x_d, this%qij%x_d, this%p0j%x_d, this%q0j%x_d, &
472 this%alpha%x_d, this%beta%x_d, lambda%x_d, epsi, this%n, &
475 call device_col3(dely%x_d, this%d%x_d, y%x_d, this%m)
476 call device_add2(dely%x_d, this%c%x_d, this%m)
477 call device_sub2(dely%x_d, lambda%x_d, this%m)
478 call device_add2inv2(dely%x_d, y%x_d, - epsi, this%m)
479 delz = this%a0 - device_lcsc2(lambda%x_d, this%a%x_d, this%m) - epsi/z
! dellambda starts from the same kernel as relambda and is summed across
! ranks through a host round-trip.
482 call device_cfill(dellambda%x_d, 0.0_rp, this%m)
483 call device_relambda(dellambda%x_d, x%x_d, this%upp%x_d, &
484 this%low%x_d, this%pij%x_d, this%qij%x_d, this%n, this%m)
486 call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
487 device_to_host, sync = .true.)
488 call mpi_allreduce(mpi_in_place, dellambda%x, this%m, &
489 mpi_real_precision, mpi_sum, neko_comm, ierr)
490 call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
491 host_to_device, sync = .true.)
! NOTE(review): continuation of the next statement not visible.
493 call device_add3s2(dellambda%x_d, dellambda%x_d, this%a%x_d, &
495 call device_sub2(dellambda%x_d, y%x_d, this%m)
496 call device_sub2(dellambda%x_d, this%bi%x_d, this%m)
497 call device_add2inv2(dellambda%x_d, lambda%x_d, epsi, this%m)
! Matrix GG and vector diagx from dedicated kernels.
499 call device_gg(gg%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
500 this%pij%x_d, this%qij%x_d, this%n, this%m)
502 call device_diagx(diagx%x_d, x%x_d, xsi%x_d, this%low%x_d, &
503 this%upp%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
504 this%qij%x_d, this%alpha%x_d, this%beta%x_d, eta%x_d, &
505 lambda%x_d, this%n, this%m)
! Right-hand side bb: local part from the kernel, summed across ranks,
! then completed with rank-independent terms.
! NOTE(review): continuation of the next two statements not visible.
515 call device_bb(bb%x_d, gg%x_d, delx%x_d, diagx%x_d, this%n, &
518 call device_memcpy(bb%x, bb%x_d, this%m + 1, device_to_host, &
520 call mpi_allreduce(mpi_in_place, bb%x, this%m + 1, &
521 mpi_real_precision, mpi_sum, neko_comm, ierr)
522 call device_memcpy(bb%x, bb%x_d, this%m + 1, &
523 host_to_device, sync = .true.)
525 call device_updatebb(bb%x_d, dellambda%x_d, dely%x_d, &
526 this%d%x_d, mu%x_d, y%x_d, delz, this%m)
! System matrix AA: cleared, local part from the kernel, summed across
! ranks, then completed by device_prepare_aa_matrix.
535 call device_cfill(aa%x_d, 0.0_rp, (this%m+1) * (this%m+1))
536 call device_aa(aa%x_d, gg%x_d, diagx%x_d, this%n, this%m)
538 call device_memcpy(aa%x, aa%x_d, (this%m+1) * (this%m+1), &
539 device_to_host, sync = .true.)
540 call mpi_allreduce(mpi_in_place, aa%x, &
541 (this%m + 1)**2, mpi_real_precision, mpi_sum, neko_comm, ierr)
542 call device_memcpy(aa%x, aa%x_d, (this%m+1) * (this%m+1), &
543 host_to_device, sync = .true.)
545 call device_prepare_aa_matrix(aa%x_d, s%x_d, lambda%x_d, &
546 this%d%x_d, mu%x_d, y%x_d, this%a%x_d, zeta, z, this%m)
! Solve AA * sol = bb on the device; bb is read back as the solution
! below. A non-zero info is reported through neko_error.
549 call device_solve_linear_system(aa%x_d, bb%x_d, this%m + 1, info)
550 if (info .ne. 0)
then
! NOTE(review): the rest of the error message is not visible here.
551 call neko_error(
"Linear solver failed on the device in " // &
! First m entries of bb go to dlambda; entry m+1 is read on the host as dz.
555 call device_copy(dlambda%x_d, bb%x_d, this%m)
559 call device_memcpy(bb%x, bb%x_d, this%m+1, device_to_host, &
561 dz = bb%x(this%m + 1)
! Remaining directions are derived from dlambda, dz and dx.
565 call device_dx(dx%x_d, delx%x_d, diagx%x_d, gg%x_d, &
566 dlambda%x_d, this%n, this%m)
567 call device_dy(dy%x_d, dely%x_d, dlambda%x_d, this%d%x_d, &
568 mu%x_d, y%x_d, this%m)
569 call device_dxsi(dxsi%x_d, xsi%x_d, dx%x_d, x%x_d, &
570 this%alpha%x_d, epsi, this%n)
571 call device_deta(deta%x_d, eta%x_d, dx%x_d, x%x_d, &
572 this%beta%x_d, epsi, this%n)
! dmu: presumably (epsi - mu*dy)/y - mu -- TODO confirm helper semantics.
574 call device_col3(dmu%x_d, mu%x_d, dy%x_d, this%m)
575 call device_cmult(dmu%x_d, -1.0_rp, this%m)
576 call device_cadd(dmu%x_d, epsi, this%m)
577 call device_invcol2(dmu%x_d, y%x_d, this%m)
578 call device_sub2(dmu%x_d, mu%x_d, this%m)
579 dzeta = -zeta + (epsi - zeta * dz) / z
! ds: same pattern as dmu with (lambda, s, dlambda) in place of (y, mu, dy).
580 call device_col3(ds%x_d, dlambda%x_d, s%x_d, this%m)
581 call device_cmult(ds%x_d, -1.0_rp, this%m)
582 call device_cadd(ds%x_d, epsi, this%m)
583 call device_invcol2(ds%x_d, lambda%x_d, this%m)
584 call device_sub2(ds%x_d, s%x_d, this%m)
! Step-length bound: largest of 1 and the scaled direction/variable
! quantities returned by the maxval2 / maxval3 kernels (factor 1.01).
586 steg = maxval([1.0_rp, &
587 device_maxval2(dy%x_d, y%x_d, -1.01_rp, this%m), &
589 device_maxval2(dlambda%x_d, lambda%x_d, -1.01_rp, this%m), &
590 device_maxval2(dxsi%x_d, xsi%x_d, -1.01_rp, this%n), &
591 device_maxval2(deta%x_d, eta%x_d, -1.01_rp, this%n), &
592 device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, this%m), &
593 -1.01_rp * dzeta / zeta, &
594 device_maxval2(ds%x_d, s%x_d, -1.01_rp, this%m), &
595 device_maxval3(dx%x_d, x%x_d, this%alpha%x_d, -1.01_rp, this%n),&
596 device_maxval3(dx%x_d, this%beta%x_d, x%x_d, 1.01_rp, this%n)])
! Save the current point so every trial step restarts from it.
600 call device_copy(xold%x_d, x%x_d, this%n)
601 call device_copy(yold%x_d, y%x_d, this%m)
603 call device_copy(lambdaold%x_d, lambda%x_d, this%m)
604 call device_copy(xsiold%x_d, xsi%x_d, this%n)
605 call device_copy(etaold%x_d, eta%x_d, this%n)
606 call device_copy(muold%x_d, mu%x_d, this%m)
608 call device_copy(sold%x_d, s%x_d, this%m)
! Seed new_residual above residual_norm so the search loop is entered.
610 new_residual = 2.0_rp * residual_norm
! Agree on steg and new_residual across ranks (minimum).
613 call mpi_allreduce(mpi_in_place, steg, 1, &
614 mpi_real_precision, mpi_min, neko_comm, ierr)
615 call mpi_allreduce(mpi_in_place, new_residual, 1, &
616 mpi_real_precision, mpi_min, neko_comm, ierr)
! ---- Step-length search: at most 50 trials ----
621 do while ((new_residual .gt. residual_norm) .and. (itto .lt. 50))
! Trial point = saved point + steg * direction.
625 call device_add3s2(x%x_d, xold%x_d, dx%x_d, 1.0_rp, steg, this%n)
626 call device_add3s2(y%x_d, yold%x_d, dy%x_d, 1.0_rp, steg, this%m)
628 call device_add3s2(lambda%x_d, lambdaold%x_d, &
629 dlambda%x_d, 1.0_rp, steg, this%m)
630 call device_add3s2(xsi%x_d, xsiold%x_d, dxsi%x_d, &
631 1.0_rp, steg, this%n)
632 call device_add3s2(eta%x_d, etaold%x_d, deta%x_d, &
633 1.0_rp, steg, this%n)
634 call device_add3s2(mu%x_d, muold%x_d, dmu%x_d, &
635 1.0_rp, steg, this%m)
636 zeta = zetaold + steg*dzeta
! NOTE(review): continuation of the next statement not visible.
637 call device_add3s2(s%x_d, sold%x_d, ds%x_d, 1.0_rp, &
! Re-evaluate all residuals at the trial point (same sequence as at the
! top of the outer loop).
642 call device_rex(rex%x_d, x%x_d, this%low%x_d, &
643 this%upp%x_d, this%pij%x_d, this%p0j%x_d, &
644 this%qij%x_d, this%q0j%x_d, lambda%x_d, xsi%x_d, &
645 eta%x_d, this%n, this%m)
647 call device_col3(rey%x_d, this%d%x_d, y%x_d, this%m)
648 call device_add2(rey%x_d, this%c%x_d, this%m)
649 call device_sub2(rey%x_d, lambda%x_d, this%m)
650 call device_sub2(rey%x_d, mu%x_d, this%m)
652 rez = this%a0 - zeta - device_lcsc2(lambda%x_d, this%a%x_d, this%m)
655 call device_cfill(relambda%x_d, 0.0_rp, this%m)
! NOTE(review): continuation of the next statement not visible.
656 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
657 this%low%x_d, this%pij%x_d, this%qij%x_d, &
660 call device_memcpy(relambda%x, relambda%x_d, this%m, &
661 device_to_host, sync = .true.)
662 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
663 mpi_real_precision, mpi_sum, neko_comm, ierr)
664 call device_memcpy(relambda%x, relambda%x_d, &
665 this%m, host_to_device, sync = .true.)
667 call device_add3s2(relambda%x_d, relambda%x_d, &
668 this%a%x_d, 1.0_rp, -z, this%m)
669 call device_sub2(relambda%x_d, y%x_d, this%m)
670 call device_add2(relambda%x_d, s%x_d, this%m)
671 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
673 call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
674 call device_col2(rexsi%x_d, xsi%x_d, this%n)
675 call device_cadd(rexsi%x_d, - epsi, this%n)
677 call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
678 call device_col2(reeta%x_d, eta%x_d, this%n)
679 call device_cadd(reeta%x_d, - epsi, this%n)
681 call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
682 call device_cadd(remu%x_d, - epsi, this%m)
684 rezeta = zeta*z - epsi
686 call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
687 call device_cadd(res%x_d, - epsi, this%m)
! Global length-n contribution to the trial residual norm.
690 re_sq_norm = device_norm(rex%x_d, this%n) + &
691 device_norm(rexsi%x_d, this%n) + &
692 device_norm(reeta%x_d, this%n)
693 call mpi_allreduce(mpi_in_place, re_sq_norm, 1, &
694 mpi_real_precision, mpi_sum, neko_comm, ierr)
! NOTE(review): the closing part of this expression is not visible here.
696 new_residual = sqrt(device_norm(rey%x_d, this%m) + &
698 device_norm(relambda%x_d, this%m) + &
699 device_norm(remu%x_d, this%m) + &
701 device_norm(res%x_d, this%m) + &
! NOTE(review): new_residual is reduced with mpi_sum here, whereas the
! reduction of new_residual before this loop uses mpi_min, and
! re_sq_norm has already been summed across ranks above. Confirm that
! mpi_sum is the intended operator for comparing against residual_norm.
704 call mpi_allreduce(mpi_in_place, new_residual, 1, &
705 mpi_real_precision, mpi_sum, neko_comm, ierr)
! Accept the trial point: refresh the residual norm and the global maximum.
713 residual_norm = new_residual
714 residual_max = maxval([ &
715 device_maxval(rex%x_d, this%n), &
716 device_maxval(rey%x_d, this%m), &
718 device_maxval(relambda%x_d, this%m), &
719 device_maxval(rexsi%x_d, this%n), &
720 device_maxval(reeta%x_d, this%n), &
721 device_maxval(remu%x_d, this%m), &
723 device_maxval(res%x_d, this%m)])
725 call mpi_allreduce(mpi_in_place, residual_max, 1, &
726 mpi_real_precision, mpi_max, neko_comm, ierr)
! ---- Publish the result ----
! Shift the history (xold2 <- xold1 <- designx_d) and write the new x into
! designx_d, assuming device_copy(dst, src, n) -- confirm argument order.
734 call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
735 call device_copy(this%xold1%x_d, designx_d, this%n)
736 call device_copy(designx_d, x%x_d, this%n)
! Store the remaining variables in the MMA object.
739 call device_copy(this%y%x_d, y%x_d, this%m)
741 call device_copy(this%lambda%x_d, lambda%x_d, this%m)
743 call device_copy(this%xsi%x_d, xsi%x_d, this%n)
744 call device_copy(this%eta%x_d, eta%x_d, this%n)
745 call device_copy(this%mu%x_d, mu%x_d, this%m)
746 call device_copy(this%s%x_d, s%x_d, this%m)
! Hand all work storage back.
749 call this%scratch%relinquish(ind)
750 end subroutine mma_subsolve_dpip_device
!> Solve the MMA sub-problem with the "dip" sub-solver on the device.
!!
!! Iterates on the length-m variables lambda and mu; y, z and x are
!! recomputed from lambda in every pass. An outer loop runs while epsi is
!! above minimal_epsilon (epsi is multiplied by 0.1 per pass); an inner loop
!! of at most this%max_iter steps assembles an m x m matrix, solves for the
!! lambda direction and applies a step. On exit x, y, lambda and mu are
!! copied into designx_d and the MMA object.
!!
!! @param this      MMA object; its y, lambda, mu, xold1 and xold2 members
!!                  are overwritten at the end.
!! @param designx_d Device pointer; it is the destination of a device_copy
!!                  at the end (the pointer itself is intent(in)).
!!
!! NOTE(review): this listing appears to be missing lines. No initialization
!! of epsi, no loop / associate / if terminators are visible and several
!! statements end in `&` with no continuation. Confirm against the complete
!! file before relying on these comments.
754 subroutine mma_subsolve_dip_device(this, designx_d)
755 class(mma_t),
intent(inout) :: this
756 type(c_ptr),
intent(in) :: designx_d
757 integer :: iter, ierr
758 real(kind=rp) :: epsi, residumax, z, steg
! Length-m work vectors.
760 type(vector_t),
pointer :: y, lambda, mu, relambda, remu, dlambda, dmu, &
761 gradlambda, zerom, dd, dummy_m
! Length-n work vectors.
763 type(vector_t),
pointer :: x, pjlambda, qjlambda
766 type(vector_t),
pointer :: Ljjxinv
! hijx is m x n, Hess is m x m (see the request calls below).
767 type(matrix_t),
pointer :: hijx
768 type(matrix_t),
pointer :: Hess
770 integer :: info, ind(17)
772 real(kind=rp) :: minimal_epsilon
! Borrow all work storage from the scratch registry.
774 call this%scratch%request(y, ind(1), this%m, .false.)
775 call this%scratch%request(lambda, ind(2), this%m, .false.)
776 call this%scratch%request(mu, ind(3), this%m, .false.)
777 call this%scratch%request(relambda, ind(4), this%m, .false.)
778 call this%scratch%request(remu, ind(5), this%m, .false.)
779 call this%scratch%request(dlambda, ind(6), this%m, .false.)
780 call this%scratch%request(dmu, ind(7), this%m, .false.)
781 call this%scratch%request(gradlambda, ind(8), this%m, .false.)
! NOTE(review): no visible line writes to zerom before it is passed to
! device_pwmax2 below; confirm it is zero-filled in a line not shown here.
782 call this%scratch%request(zerom, ind(9), this%m, .false.)
783 call this%scratch%request(dd, ind(10), this%m, .false.)
784 call this%scratch%request(dummy_m, ind(11), this%m, .false.)
786 call this%scratch%request(x, ind(12), this%n, .false.)
787 call this%scratch%request(pjlambda,ind(13), this%n, .false.)
788 call this%scratch%request(qjlambda, ind(14), this%n, .false.)
790 call this%scratch%request(ljjxinv, ind(15), this%n, .false.)
792 call this%scratch%request(hijx, ind(16), this%m, this%n, .false.)
793 call this%scratch%request(hess, ind(17), this%m, this%m, .false.)
! Starting values: y and mu filled with 1; lambda starts at 1 and is then
! combined with 0.5*c through device_pwmax2 (presumably a point-wise max).
800 call device_cfill(y%x_d, 1.0_rp, this%m)
802 call device_cfill(lambda%x_d, 1.0_rp, this%m)
803 call device_cmult2(dummy_m%x_d, this%c%x_d, 0.5_rp, this%m)
804 call device_pwmax2(lambda%x_d, dummy_m%x_d, this%m)
806 call device_cfill(mu%x_d, 1.0_rp, this%m)
! Stopping threshold for epsi, floored at 1e-12 and made identical on all
! ranks by taking the minimum.
812 minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
813 call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
814 mpi_real_precision, mpi_min, neko_comm, ierr)
! ---- Outer loop over epsi ----
819 outer:
do while (epsi .gt. minimal_epsilon)
! Short aliases for the sub-problem data held by the MMA object.
823 associate(p0j => this%p0j, q0j => this%q0j, &
824 pij => this%pij, qij => this%qij, &
825 low => this%low, upp => this%upp, &
826 alpha => this%alpha, beta => this%beta, &
827 c => this%c, a0 => this%a0, a => this%a)
! y from lambda and c, combined with zerom (presumably max(lambda - c, 0)).
835 call device_sub3(y%x_d, lambda%x_d, c%x_d, this%m)
836 call device_pwmax2(y%x_d, zerom%x_d, this%m)
! z = max(0, glsum(lambda*a) - a0).
841 call device_col3(dummy_m%x_d, lambda%x_d, a%x_d, this%m)
842 z = device_glsum(dummy_m%x_d, this%m)
843 z = max(0.0_rp, z - a0)
! pjlambda / qjlambda: matrix-transpose-times-lambda products with p0j /
! q0j added; x is then obtained from them by a dedicated kernel.
849 call device_mattrans_v_mul(pjlambda%x_d, pij%x_d, lambda%x_d, this%m, this%n)
850 call device_mattrans_v_mul(qjlambda%x_d, qij%x_d, lambda%x_d, this%m, this%n)
851 call device_add2(pjlambda%x_d, p0j%x_d, this%n)
852 call device_add2(qjlambda%x_d, q0j%x_d, this%n)
854 call device_mma_dipsolvesub1(x%x_d, pjlambda%x_d, qjlambda%x_d, &
855 low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
! relambda: cleared, filled by the kernel, summed across ranks through a
! host round-trip, then completed with a (weighted by -z), y, mu and bi.
857 call device_cfill(relambda%x_d, 0.0_rp, this%m)
858 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
859 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
! NOTE(review): continuation of the next statement not visible.
863 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
865 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
866 mpi_real_precision, mpi_sum, neko_comm, ierr)
867 call device_memcpy(relambda%x, relambda%x_d, this%m, &
868 host_to_device, sync = .true.)
870 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
871 call device_sub2(relambda%x_d, y%x_d, this%m)
872 call device_add2(relambda%x_d, mu%x_d, this%m)
873 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
! remu from mu and lambda, with epsi subtracted.
875 call device_col3(remu%x_d, mu%x_d, lambda%x_d, this%m)
876 call device_cadd(remu%x_d, -epsi, this%m)
! Largest residual entry; no MPI reduction is applied to it here.
878 residumax = maxval([device_maxval(relambda%x_d, this%m), &
879 device_maxval(remu%x_d, this%m)])
! ---- Inner loop for the current epsi ----
883 do iter = 1, this%max_iter
! Converged for this epsi once the largest residual is below epsi.
885 if (residumax .lt. epsi)
exit
! Right-hand side: presumably -(relambda - mu + epsi/lambda) -- TODO
! confirm helper semantics.
893 call device_copy(gradlambda%x_d, relambda%x_d, this%m)
894 call device_sub2(gradlambda%x_d, mu%x_d, this%m)
897 call device_cfill(dummy_m%x_d, epsi, this%m)
898 call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)
899 call device_add2(gradlambda%x_d, dummy_m%x_d, this%m)
900 call device_cmult(gradlambda%x_d, -1.0_rp, this%m)
! Building blocks of the m x m matrix.
906 call device_mma_ljjxinv(ljjxinv%x_d, pjlambda%x_d, qjlambda%x_d, &
907 x%x_d, low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
909 call device_gg(hijx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
910 this%pij%x_d, this%qij%x_d, this%n, this%m)
! hess: cleared, local part from the kernel, summed across ranks through
! a host round-trip.
912 call device_cfill(hess%x_d, 0.0_rp, (this%m) * (this%m) )
913 call device_hess(hess%x_d, hijx%x_d, ljjxinv%x_d, this%n, this%m)
! NOTE(review): continuation of the next statement not visible.
916 call device_memcpy(hess%x, hess%x_d, this%m*this%m, device_to_host, &
918 call mpi_allreduce(mpi_in_place, hess%x, &
919 this%m*this%m, mpi_real_precision, mpi_sum, neko_comm, ierr)
920 call device_memcpy(hess%x, hess%x_d, this%m*this%m, &
921 host_to_device, sync = .true.)
! Extra update of hess with a, applied only when glsum(lambda*a) > 0.
926 call device_col3(dummy_m%x_d, lambda%x_d, a%x_d, this%m)
927 if (device_glsum(dummy_m%x_d, this%m) .gt. 0.0_rp)
then
928 call device_update_hessian_z(hess%x_d, a%x_d, this%m)
! Final preparation of hess, then solve with gradlambda as right-hand
! side; gradlambda is read back as the solution below.
! NOTE(review): continuation of the next two statements not visible.
939 call device_prepare_hessian(hess%x_d, y%x_d, mu%x_d, lambda%x_d, &
943 call device_solve_linear_system(hess%x_d, gradlambda%x_d, &
945 if (info .ne. 0)
then
! NOTE(review): the rest of the error message is not visible here.
946 call neko_error(
"Linear solver failed on the device in " // &
! Solution goes to dlambda.
950 call device_copy(dlambda%x_d, gradlambda%x_d, this%m)
! dmu: presumably epsi/lambda - dlambda*mu/lambda - mu -- TODO confirm.
953 call device_copy(dummy_m%x_d, dlambda%x_d, this%m)
954 call device_col2(dummy_m%x_d, mu%x_d, this%m)
955 call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)
957 call device_cfill(dmu%x_d, epsi, this%m)
958 call device_invcol2(dmu%x_d, lambda%x_d, this%m)
959 call device_add2s2(dmu%x_d, dummy_m%x_d, -1.0_rp, this%m)
960 call device_sub2(dmu%x_d, mu%x_d, this%m)
! Step-length bound.
! NOTE(review): the end of this expression is not visible here.
962 steg = maxval([1.005_rp, device_maxval2(dlambda%x_d, lambda%x_d, &
963 -1.01_rp, this%m), device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, &
! Apply the step to lambda and mu (scale factor steg).
967 call device_add2s2(lambda%x_d, dlambda%x_d, steg, this%m)
968 call device_add2s2(mu%x_d, dmu%x_d, steg, this%m)
! Recompute y, z, x and the residuals at the new (lambda, mu); same
! sequence as at the top of the outer loop.
976 call device_sub3(y%x_d, lambda%x_d, c%x_d, this%m)
977 call device_pwmax2(y%x_d, zerom%x_d, this%m)
982 call device_col3(dummy_m%x_d, lambda%x_d, a%x_d, this%m)
983 z = device_glsum(dummy_m%x_d, this%m)
984 z = max(0.0_rp, z - a0)
990 call device_mattrans_v_mul(pjlambda%x_d, pij%x_d, lambda%x_d, this%m, this%n)
991 call device_mattrans_v_mul(qjlambda%x_d, qij%x_d, lambda%x_d, this%m, this%n)
992 call device_add2(pjlambda%x_d, p0j%x_d, this%n)
993 call device_add2(qjlambda%x_d, q0j%x_d, this%n)
995 call device_mma_dipsolvesub1(x%x_d, pjlambda%x_d, qjlambda%x_d, &
996 low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
1000 call device_cfill(relambda%x_d, 0.0_rp, this%m)
1001 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
1002 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
! NOTE(review): continuation of the next statement not visible.
1006 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
1008 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
1009 mpi_real_precision, mpi_sum, neko_comm, ierr)
1010 call device_memcpy(relambda%x, relambda%x_d, this%m, &
1011 host_to_device, sync = .true.)
1013 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
1014 call device_sub2(relambda%x_d, y%x_d, this%m)
1015 call device_add2(relambda%x_d, mu%x_d, this%m)
1016 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
1018 call device_col3(remu%x_d, mu%x_d, lambda%x_d, this%m)
1019 call device_cadd(remu%x_d, -epsi, this%m)
1021 residumax = maxval([device_maxval(relambda%x_d, this%m), &
1022 device_maxval(remu%x_d, this%m)])
! Tighten epsi by a factor of ten for the next outer pass.
1025 epsi = 0.1_rp * epsi
! ---- Publish the result ----
! Shift the history (xold2 <- xold1 <- designx_d) and write the new x into
! designx_d, assuming device_copy(dst, src, n) -- confirm argument order.
1029 call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
1030 call device_copy(this%xold1%x_d, designx_d, this%n)
1031 call device_copy(designx_d, x%x_d, this%n)
! Store y, lambda and mu in the MMA object.
1034 call device_copy(this%y%x_d, y%x_d, this%m)
1036 call device_copy(this%lambda%x_d, lambda%x_d, this%m)
1037 call device_copy(this%mu%x_d, mu%x_d, this%m)
! Hand all work storage back.
1039 call this%scratch%relinquish(ind)
1040 end subroutine mma_subsolve_dip_device
1042end submodule mma_device