33submodule(mma) mma_device
35 use device_math,
only: device_copy, device_cmult, device_cadd, device_cfill, &
36 device_add2, device_add3s2, device_invcol2, device_col2, device_col3, &
37 device_sub2, device_sub3, device_add2s2, device_cadd2, device_pwmax, &
38 device_glsum, device_cmult2, device_pwmax
39 use device_mma_math,
only: device_maxval, device_norm, device_lcsc2, &
40 device_maxval2, device_maxval3, device_mma_gensub3, &
41 device_mma_gensub4, device_mma_max, device_max2, device_rex, &
42 device_relambda, device_delx, device_add2inv2, device_gg, device_diagx, &
43 device_bb, device_updatebb, device_aa, device_updateaa, device_dx, &
44 device_dy, device_dxsi, device_deta, device_kkt_rex, &
45 device_mma_gensub2, device_mattrans_v_mul, device_mma_dipsolvesub1, &
46 device_mma_ljjxinv, device_hess
48 use neko_config,
only: neko_bcknd_device
49 use device,
only: device_to_host
50 use comm,
only: neko_comm, pe_rank, mpi_real_precision
51 use mpi_f08,
only: mpi_in_place, mpi_max, mpi_min
!> Perform one MMA update step using the device implementation.
!!
!! Aborts through `neko_error` when the object has not been initialized,
!! generates the sub-problem with `mma_gensub_device`, and then dispatches to
!! the sub-solver named by `this%subsolver` ("dip" or "dpip"). Any other value
!! is reported through `neko_error`. On success `this%is_updated` is set.
!!
!! @param this   MMA object; must have `is_initialized` set.
!! @param iter   Iteration counter, forwarded to `mma_gensub_device`.
!! @param x      Device pointer to the design, forwarded to the sub-problem
!!               generator and handed to the sub-solver, which overwrites the
!!               data it points to.
!! @param df0dx  Device pointer forwarded unchanged to `mma_gensub_device`.
!! @param fval   Device pointer forwarded unchanged to `mma_gensub_device`.
!! @param dfdx   Device pointer forwarded unchanged to `mma_gensub_device`.
module subroutine mma_update_device(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(inout) :: x
  type(c_ptr), intent(in) :: df0dx, fval, dfdx

  if (.not. this%is_initialized) then
     call neko_error("The MMA object is not initialized.")
  end if

  ! Build the sub-problem around the current design.
  call mma_gensub_device(this, iter, x, df0dx, fval, dfdx)

  ! Exactly one branch runs; the error is reserved for unknown solver names.
  if (this%subsolver .eq. "dip") then
     call mma_subsolve_dip_device(this, x)
  else if (this%subsolver .eq. "dpip") then
     call mma_subsolve_dpip_device(this, x)
  else
     call neko_error("Unrecognized subsolver for MMA in mma_device.")
  end if

  this%is_updated = .true.
end subroutine mma_update_device
!> Evaluate the KKT residual measures with the routine that matches the
!! configured sub-solver.
!!
!! `this%subsolver == "dip"` selects `mma_dip_kkt_device`; every other value
!! falls through to `mma_dpip_kkt_device`. All arguments are forwarded
!! unchanged.
!!
!! @param this   MMA object; the selected routine stores its results in it.
!! @param x      Device pointer to the design.
!! @param df0dx  Device pointer, forwarded.
!! @param fval   Device pointer, forwarded.
!! @param dfdx   Device pointer, forwarded.
module subroutine mma_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  ! The two residual routines are mutually exclusive.
  if (this%subsolver .eq. "dip") then
     call mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
  else
     call mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
  end if
end subroutine mma_kkt_device
!> KKT residual measures for the "dip" sub-solver, evaluated on the device.
!!
!! Fills two work vectors of length `this%m` (`relambda`, `remu`) and stores
!!   - `this%residumax`  : the larger of their two `device_maxval` values,
!!   - `this%residunorm` : `sqrt` of the sum of their two `device_norm` values.
!! The work vectors are released before returning.
!!
!! @param this   MMA object; reads `a`, `z`, `y`, `mu`, `lambda`, `m` and
!!               writes `residumax`, `residunorm`.
!! @param x      Not referenced here; kept for interface compatibility.
!! @param df0dx  Not referenced here; kept for interface compatibility.
!! @param fval   Device pointer used as the first operand of `relambda`.
!! @param dfdx   Not referenced here; kept for interface compatibility.
module subroutine mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  type(vector_t) :: relambda, remu

  call relambda%init(this%m)
  call remu%init(this%m)

  ! Presumably relambda = fval - z * a - y + mu, going by the device_math
  ! naming (add3s2: scaled sum of two vectors) -- confirm against device_math.
  call device_add3s2(relambda%x_d, fval, this%a%x_d, 1.0_rp, -this%z, &
       this%m)
  call device_sub2(relambda%x_d, this%y%x_d, this%m)
  call device_add2(relambda%x_d, this%mu%x_d, this%m)

  ! remu is built from lambda and mu (presumably their entry-wise product).
  call device_col3(remu%x_d, this%lambda%x_d, this%mu%x_d, this%m)

  this%residumax = maxval([device_maxval(relambda%x_d, this%m), &
       device_maxval(remu%x_d, this%m)])
  this%residunorm = sqrt(device_norm(relambda%x_d, this%m) + &
       device_norm(remu%x_d, this%m))

  ! Release the temporaries; nothing else owns them.
  call relambda%free()
  call remu%free()
end subroutine mma_dip_kkt_device
!> KKT residual measures for the "dpip" sub-solver, evaluated on the device.
!!
!! Builds seven residual vectors (four of length `this%m`, three of length
!! `this%n`) and two scalar residuals (`rez`, `rezeta`), then stores
!!   - `this%residumax`  : the largest of all `device_maxval` values and of
!!                         `abs(rez)`, `abs(rezeta)`, reduced with MPI max,
!!   - `this%residunorm` : `sqrt` of the sum of all `device_norm` values plus
!!                         `rez**2 + rezeta**2`.
!! Only the contribution of the length-`n` vectors is summed across ranks.
!! The length-`m` terms and the two scalars are added once, after the
!! reduction. All work vectors are released before returning.
!!
!! @param this   MMA object; reads the current primal/dual state and bounds,
!!               writes `residumax` and `residunorm`.
!! @param x      Device pointer to the design.
!! @param df0dx  Device pointer passed to `device_kkt_rex`.
!! @param fval   Device pointer used as the first operand of `relambda`.
!! @param dfdx   Device pointer passed to `device_kkt_rex`.
module subroutine mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  real(kind=rp) :: rez, rezeta
  type(vector_t) :: rey, relambda, remu, res
  type(vector_t) :: rex, rexsi, reeta
  real(kind=rp) :: re_sq_norm
  integer :: ierr

  ! Work vectors of length m.
  call rey%init(this%m)
  call relambda%init(this%m)
  call remu%init(this%m)
  call res%init(this%m)

  ! Work vectors of length n.
  call rex%init(this%n)
  call rexsi%init(this%n)
  call reeta%init(this%n)

  call device_kkt_rex(rex%x_d, df0dx, dfdx, this%xsi%x_d, &
       this%eta%x_d, this%lambda%x_d, this%n, this%m)

  ! rey is assembled from d, y, c, lambda and mu
  ! (presumably d * y + c - lambda - mu -- confirm against device_math).
  call device_col3(rey%x_d, this%d%x_d, this%y%x_d, this%m)
  call device_add2(rey%x_d, this%c%x_d, this%m)
  call device_sub2(rey%x_d, this%lambda%x_d, this%m)
  call device_sub2(rey%x_d, this%mu%x_d, this%m)

  rez = this%a0 - this%zeta - device_lcsc2(this%lambda%x_d, this%a%x_d, &
       this%m)

  ! relambda is assembled from fval, a, z, y and s
  ! (presumably fval - z * a - y + s).
  call device_add3s2(relambda%x_d, fval, this%a%x_d, 1.0_rp, -this%z, &
       this%m)
  call device_sub2(relambda%x_d, this%y%x_d, this%m)
  call device_add2(relambda%x_d, this%s%x_d, this%m)

  ! rexsi from (x, xmin, xsi) and reeta from (xmax, x, eta).
  call device_sub3(rexsi%x_d, x, this%xmin%x_d, this%n)
  call device_col2(rexsi%x_d, this%xsi%x_d, this%n)

  call device_sub3(reeta%x_d, this%xmax%x_d, x, this%n)
  call device_col2(reeta%x_d, this%eta%x_d, this%n)

  call device_col3(remu%x_d, this%mu%x_d, this%y%x_d, this%m)

  rezeta = this%zeta * this%z

  call device_col3(res%x_d, this%lambda%x_d, this%s%x_d, this%m)

  ! The scalar residuals take part in both measures, mirroring the residual
  ! evaluation inside mma_subsolve_dpip_device.
  this%residumax = maxval([ &
       device_maxval(rex%x_d, this%n), &
       device_maxval(rey%x_d, this%m), &
       abs(rez), &
       device_maxval(relambda%x_d, this%m), &
       device_maxval(rexsi%x_d, this%n), &
       device_maxval(reeta%x_d, this%n), &
       device_maxval(remu%x_d, this%m), &
       abs(rezeta), &
       device_maxval(res%x_d, this%m)])

  re_sq_norm = device_norm(rex%x_d, this%n) + &
       device_norm(rexsi%x_d, this%n) + &
       device_norm(reeta%x_d, this%n)

  call mpi_allreduce(mpi_in_place, this%residumax, 1, &
       mpi_real_precision, mpi_max, neko_comm, ierr)

  call mpi_allreduce(mpi_in_place, re_sq_norm, 1, &
       mpi_real_precision, mpi_sum, neko_comm, ierr)

  ! Length-m terms and scalars are added once, after the reduction.
  this%residunorm = sqrt(( &
       device_norm(rey%x_d, this%m) + &
       rez**2 + &
       device_norm(relambda%x_d, this%m) + &
       device_norm(remu%x_d, this%m) + &
       rezeta**2 + &
       device_norm(res%x_d, this%m) &
       ) + re_sq_norm)

  call rey%free()
  call relambda%free()
  call remu%free()
  call res%free()
  call rex%free()
  call rexsi%free()
  call reeta%free()
end subroutine mma_dpip_kkt_device
!> Generate the MMA sub-problem data on the device.
!!
!! Steps, in order:
!!   1. `x_diff` is formed from `xmax` and `xmin` via `device_sub3` and copied
!!      to the host.
!!   2. The asymptotes `low` / `upp` are set: for `iter < 3` from `x` shifted
!!      by `-/+ asyinit * x_diff`; otherwise by `device_mma_gensub2`, which
!!      also receives `xold1`, `xold2`, `asydecr` and `asyincr`.
!!   3. `device_mma_gensub3` fills `alpha`, `beta`, `p0j`, `q0j`, `pij`, `qij`.
!!   4. `device_mma_gensub4` fills `bi`, which is then summed over all ranks
!!      on the host, copied back to the device, and has `fval` subtracted.
!!
!! @param this   MMA object whose sub-problem arrays are overwritten.
!! @param iter   Iteration counter selecting the asymptote rule.
!! @param x      Device pointer to the current design.
!! @param df0dx  Device pointer passed to `device_mma_gensub3`.
!! @param fval   Device pointer subtracted from `bi`.
!! @param dfdx   Device pointer passed to `device_mma_gensub3`.
subroutine mma_gensub_device(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x
  type(c_ptr), intent(in) :: df0dx
  type(c_ptr), intent(in) :: fval
  type(c_ptr), intent(in) :: dfdx
  integer, intent(in) :: iter

  integer :: ierr
  type(vector_t) :: x_diff

  call x_diff%init(this%n)
  call device_sub3(x_diff%x_d, this%xmax%x_d, this%xmin%x_d, this%n)
  call device_memcpy(x_diff%x, x_diff%x_d, this%n, &
       device_to_host, sync = .true.)

  ! The first two iterations have no usable history in xold1/xold2, so a
  ! fixed initial asymptote distance is used instead of the adaptive rule.
  if (iter .lt. 3) then
     call device_copy(this%low%x_d, x, this%n)
     call device_add2s2(this%low%x_d, x_diff%x_d, -this%asyinit, this%n)
     call device_copy(this%upp%x_d, x, this%n)
     call device_add2s2(this%upp%x_d, x_diff%x_d, this%asyinit, this%n)
  else
     call device_mma_gensub2(this%low%x_d, this%upp%x_d, x, &
          this%xold1%x_d, this%xold2%x_d, x_diff%x_d, &
          this%asydecr, this%asyincr, this%n)
  end if

  call device_mma_gensub3(x, df0dx, dfdx, this%low%x_d, &
       this%upp%x_d, this%xmin%x_d, this%xmax%x_d, this%alpha%x_d, &
       this%beta%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
       this%qij%x_d, this%n, this%m)

  call device_mma_gensub4(x, this%low%x_d, this%upp%x_d, this%pij%x_d, &
       this%qij%x_d, this%n, this%m, this%bi%x_d)

  ! bi is a per-rank partial result: reduce it on the host, then push the
  ! global sum back to the device.
  call device_memcpy(this%bi%x, this%bi%x_d, this%m, device_to_host, &
       sync = .true.)
  call mpi_allreduce(mpi_in_place, this%bi%x, this%m, &
       mpi_real_precision, mpi_sum, neko_comm, ierr)
  call device_memcpy(this%bi%x, this%bi%x_d, this%m, host_to_device, &
       sync = .true.)

  call device_sub2(this%bi%x_d, fval, this%m)

  ! x_diff is local scratch; release its storage.
  call x_diff%free()
end subroutine mma_gensub_device
!> Sub-problem solver called by mma_update_device when this%subsolver is
!! "dpip". It works on device arrays (the %x_d components); host copies (the
!! %x components) are used for the MPI reductions, for the host-side edits of
!! aa, and for the LAPACK solve.
!!
!! @param this       MMA object. Its sub-problem data (low, upp, alpha, beta,
!!                   p0j, q0j, pij, qij, a0, a, c, d, bi, epsimin, max_iter)
!!                   is read; at the end xold1, xold2, y, lambda, xsi, eta, mu
!!                   and s are overwritten with the final iterate.
!! @param designx_d  Device pointer to the design. Its data is first saved
!!                   into this%xold1 and then overwritten with the local x
!!                   (taking the first argument of device_copy to be the
!!                   destination -- confirm against device_math).
!!
!! NOTE(review): this listing of the routine looks incomplete. Several names
!! used below (bb, gg, aa, info, i) have no visible declaration, several work
!! vectors (y, s, mu, x, dx, dy, ds) have no visible init call, a number of
!! continued statements end in '&' with no continuation line, and no loop or
!! associate terminators are visible. Check the complete file before changing
!! any logic here.
284 subroutine mma_subsolve_dpip_device(this, designx_d)
285 class(mma_t),
intent(inout) :: this
286 type(c_ptr),
intent(in) :: designx_d
287 integer :: iter, itto, ierr
288 real(kind=rp) :: epsi, residual_max, residual_norm, z, zeta, rez, rezeta, &
289 delz, dz, dzeta, steg, zold, zetaold, new_residual
291 type(vector_t) :: y, lambda, s, mu, rey, relambda, remu, res, &
292 dely, dellambda, dy, dlambda, ds, dmu, yold, lambdaold, sold, muold
295 type(vector_t) :: x, xsi, eta, rex, rexsi, reeta, &
296 delx, diagx, dx, dxsi, deta, xold, xsiold, etaold
303 integer,
dimension(this%m+1) :: ipiv
304 real(kind=rp) :: re_sq_norm
308 real(kind=rp) :: minimal_epsilon
! Work vectors of length this%m.
311 call lambda%init(this%m)
314 call rey%init(this%m)
315 call relambda%init(this%m)
316 call remu%init(this%m)
317 call res%init(this%m)
318 call dely%init(this%m)
319 call dellambda%init(this%m)
321 call dlambda%init(this%m)
323 call dmu%init(this%m)
324 call yold%init(this%m)
325 call lambdaold%init(this%m)
326 call sold%init(this%m)
327 call muold%init(this%m)
! Work vectors of length this%n.
329 call xsi%init(this%n)
330 call eta%init(this%n)
331 call rex%init(this%n)
332 call rexsi%init(this%n)
333 call reeta%init(this%n)
334 call delx%init(this%n)
335 call diagx%init(this%n)
337 call dxsi%init(this%n)
338 call deta%init(this%n)
339 call xold%init(this%n)
340 call xsiold%init(this%n)
341 call etaold%init(this%n)
! Storage for the (m+1)-sized linear system: right-hand side bb, matrix aa,
! and the m-by-n array gg.
342 call bb%init(this%m+1)
344 call gg%init(this%m, this%n)
345 call aa%init(this%m+1, this%m+1)
! Starting point. x is built from alpha and beta with both weights 0.5;
! y, lambda and s are filled with ones; xsi, eta and mu come from the
! device_mma_max / device_max2 helpers.
352 call device_add3s2(x%x_d, this%alpha%x_d, this%beta%x_d, 0.5_rp, 0.5_rp, &
354 call device_cfill(y%x_d, 1.0_rp, this%m)
357 call device_cfill(lambda%x_d, 1.0_rp, this%m)
358 call device_cfill(s%x_d, 1.0_rp, this%m)
359 call device_mma_max(xsi%x_d, x%x_d, this%alpha%x_d, this%n)
360 call device_mma_max(eta%x_d, this%beta%x_d, x%x_d, this%n)
361 call device_max2(mu%x_d, 1.0_rp, this%c%x_d, 0.5_rp, this%m)
! Threshold at which the outer loop stops, never below 1.0e-12; the MPI min
! reduction makes every rank use the same value.
366 minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
367 call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
368 mpi_real_precision, mpi_min, neko_comm, ierr)
! Outer loop over epsi. NOTE(review): the initialization and update of epsi
! are not visible in this listing.
373 do while (epsi .gt. minimal_epsilon)
! Short aliases for the sub-problem data used in the residual evaluation.
380 associate(p0j => this%p0j, q0j => this%q0j, &
381 pij => this%pij, qij => this%qij, &
382 low => this%low, upp => this%upp, &
383 alpha => this%alpha, beta => this%beta, &
384 c => this%c, d => this%d, &
385 a0 => this%a0, a => this%a)
! Residuals at the current iterate: rex, rey, rez, relambda, rexsi, reeta,
! remu, rezeta and res.
387 call device_rex(rex%x_d, x%x_d, low%x_d, upp%x_d, &
388 pij%x_d, p0j%x_d, qij%x_d, q0j%x_d, &
389 lambda%x_d, xsi%x_d, eta%x_d, this%n, this%m)
391 call device_col3(rey%x_d, d%x_d, y%x_d, this%m)
392 call device_add2(rey%x_d, c%x_d, this%m)
393 call device_sub2(rey%x_d, lambda%x_d, this%m)
394 call device_sub2(rey%x_d, mu%x_d, this%m)
395 rez = a0 - zeta - device_lcsc2(lambda%x_d, a%x_d, this%m)
397 call device_cfill(relambda%x_d, 0.0_rp, this%m)
398 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
399 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
! relambda from device_relambda is a per-rank partial result: bring it to the
! host, sum it over all ranks, and send it back to the device.
407 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
409 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
410 mpi_real_precision, mpi_sum, neko_comm, ierr)
411 call device_memcpy(relambda%x, relambda%x_d, this%m, host_to_device, &
414 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
415 call device_sub2(relambda%x_d, y%x_d, this%m)
416 call device_add2(relambda%x_d, s%x_d, this%m)
417 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
! The complementarity-type residuals below each have epsi subtracted.
419 call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
420 call device_col2(rexsi%x_d, xsi%x_d, this%n)
421 call device_cadd(rexsi%x_d, - epsi, this%n)
423 call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
424 call device_col2(reeta%x_d, eta%x_d, this%n)
425 call device_cadd(reeta%x_d, - epsi, this%n)
427 call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
428 call device_cadd(remu%x_d, - epsi, this%m)
430 rezeta = zeta * z - epsi
432 call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
433 call device_cadd(res%x_d, - epsi, this%m)
! Largest residual entry over all blocks (made global by the MPI max below).
436 residual_max = maxval([device_maxval(rex%x_d, this%n), &
437 device_maxval(rey%x_d, this%m), abs(rez), &
438 device_maxval(relambda%x_d, this%m), &
439 device_maxval(rexsi%x_d, this%n), &
440 device_maxval(reeta%x_d, this%n), &
441 device_maxval(remu%x_d, this%m), abs(rezeta), &
442 device_maxval(res%x_d, this%m)])
! Contribution of the length-n residuals; only this part is summed over ranks.
444 re_sq_norm = device_norm(rex%x_d, this%n) + &
445 device_norm(rexsi%x_d, this%n) + device_norm(reeta%x_d, this%n)
447 call mpi_allreduce(mpi_in_place, residual_max, 1, &
448 mpi_real_precision, mpi_max, neko_comm, ierr)
450 call mpi_allreduce(mpi_in_place, re_sq_norm, &
451 1, mpi_real_precision, mpi_sum, neko_comm, ierr)
453 residual_norm = sqrt(device_norm(rey%x_d, this%m) + &
455 device_norm(relambda%x_d, this%m) + &
456 device_norm(remu%x_d, this%m)+ &
458 device_norm(res%x_d, this%m) &
! Inner iteration: at most this%max_iter steps, left as soon as the largest
! residual drops below epsi.
464 do iter = 1, this%max_iter
466 if (residual_max .lt. epsi)
exit
! Right-hand-side pieces delx, dely, delz and dellambda.
468 call device_delx(delx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
469 this%pij%x_d, this%qij%x_d, this%p0j%x_d, this%q0j%x_d, &
470 this%alpha%x_d, this%beta%x_d, lambda%x_d, epsi, this%n, &
473 call device_col3(dely%x_d, this%d%x_d, y%x_d, this%m)
474 call device_add2(dely%x_d, this%c%x_d, this%m)
475 call device_sub2(dely%x_d, lambda%x_d, this%m)
476 call device_add2inv2(dely%x_d, y%x_d, - epsi, this%m)
477 delz = this%a0 - device_lcsc2(lambda%x_d, this%a%x_d, this%m) - epsi/z
480 call device_cfill(dellambda%x_d, 0.0_rp, this%m)
481 call device_relambda(dellambda%x_d, x%x_d, this%upp%x_d, &
482 this%low%x_d, this%pij%x_d, this%qij%x_d, this%n, this%m)
! Same host round trip as for relambda: sum dellambda over all ranks.
484 call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
485 device_to_host, sync = .true.)
486 call mpi_allreduce(mpi_in_place, dellambda%x, this%m, &
487 mpi_real_precision, mpi_sum, neko_comm, ierr)
488 call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
489 host_to_device, sync = .true.)
491 call device_add3s2(dellambda%x_d, dellambda%x_d, this%a%x_d, &
493 call device_sub2(dellambda%x_d, y%x_d, this%m)
494 call device_sub2(dellambda%x_d, this%bi%x_d, this%m)
495 call device_add2inv2(dellambda%x_d, lambda%x_d, epsi, this%m)
! gg and diagx feed both the right-hand side bb and the matrix aa.
497 call device_gg(gg%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
498 this%pij%x_d, this%qij%x_d, this%n, this%m)
500 call device_diagx(diagx%x_d, x%x_d, xsi%x_d, this%low%x_d, &
501 this%upp%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
502 this%qij%x_d, this%alpha%x_d, this%beta%x_d, eta%x_d, &
503 lambda%x_d, this%n, this%m)
! bb: computed on the device, summed over ranks on the host, copied back and
! then completed by device_updatebb.
513 call device_bb(bb%x_d, gg%x_d, delx%x_d, diagx%x_d, this%n, &
516 call device_memcpy(bb%x, bb%x_d, this%m + 1, device_to_host, &
518 call mpi_allreduce(mpi_in_place, bb%x, this%m + 1, &
519 mpi_real_precision, mpi_sum, neko_comm, ierr)
520 call device_memcpy(bb%x, bb%x_d, this%m + 1, &
521 host_to_device, sync = .true.)
523 call device_updatebb(bb%x_d, dellambda%x_d, dely%x_d, &
524 this%d%x_d, mu%x_d, y%x_d, delz, this%m)
! aa: zeroed and filled on the device, then summed over ranks on the host.
534 call device_cfill(aa%x_d, 0.0_rp, (this%m+1) * (this%m+1))
535 call device_aa(aa%x_d, gg%x_d, diagx%x_d, this%n, this%m)
536 call device_memcpy(aa%x, aa%x_d, (this%m+1) * (this%m+1), &
537 device_to_host, sync = .true.)
539 call mpi_allreduce(mpi_in_place, aa%x, &
540 (this%m + 1)**2, mpi_real_precision, mpi_sum, neko_comm, ierr)
! Host copies of lambda, mu, y and s are needed for the diagonal update of aa.
542 call device_memcpy(lambda%x, lambda%x_d, this%m, device_to_host, &
544 call device_memcpy(mu%x, mu%x_d, this%m, device_to_host, &
546 call device_memcpy(y%x, y%x_d, this%m, device_to_host, &
548 call device_memcpy(s%x, s%x_d, this%m, device_to_host, &
! Diagonal correction of entry (i, i). NOTE(review): the loop over i that
! presumably encloses this statement is not visible here.
552 aa%x(i, i) = aa%x(i, i) &
553 + s%x(i) / lambda%x(i) &
554 + 1.0_rp / (this%d%x(i) + mu%x(i) / y%x(i))
! Border of aa: last column and last row hold a, the corner holds -zeta/z.
556 aa%x(1:this%m, this%m+1) = this%a%x
557 aa%x(this%m+1, 1:this%m) = this%a%x
558 aa%x(this%m+1, this%m+1) = - zeta/z
560 call device_memcpy(aa%x, aa%x_d, &
561 (this%m + 1) * (this%m + 1), host_to_device, sync = .true.)
! Solve the (m+1)-by-(m+1) system on the host with LAPACK dgesv; a non-zero
! info aborts through neko_error.
563 call device_memcpy(bb%x, bb%x_d, this%m+1, device_to_host, &
565 call dgesv(this%m + 1, 1, aa%x, this%m + 1, ipiv, bb%x, this%m + 1, &
568 if (info .ne. 0)
then
569 call neko_error(
"DGESV failed to solve the linear system in " // &
570 "mma_subsolve_dpip (device).")
573 call device_memcpy(bb%x, bb%x_d, this%m+1, host_to_device, &
! Split bb: the first m entries become dlambda, the last entry becomes dz.
576 dlambda%x = bb%x(1:this%m)
577 call device_memcpy(dlambda%x, dlambda%x_d, this%m, host_to_device, &
580 dz = bb%x(this%m + 1)
! Remaining search-direction components dx, dy, dxsi, deta, dmu, dzeta, ds.
583 call device_dx(dx%x_d, delx%x_d, diagx%x_d, gg%x_d, &
584 dlambda%x_d, this%n, this%m)
585 call device_dy(dy%x_d, dely%x_d, dlambda%x_d, this%d%x_d, &
586 mu%x_d, y%x_d, this%m)
587 call device_dxsi(dxsi%x_d, xsi%x_d, dx%x_d, x%x_d, &
588 this%alpha%x_d, epsi, this%n)
589 call device_deta(deta%x_d, eta%x_d, dx%x_d, x%x_d, &
590 this%beta%x_d, epsi, this%n)
592 call device_col3(dmu%x_d, mu%x_d, dy%x_d, this%m)
593 call device_cmult(dmu%x_d, -1.0_rp, this%m)
594 call device_cadd(dmu%x_d, epsi, this%m)
595 call device_invcol2(dmu%x_d, y%x_d, this%m)
596 call device_sub2(dmu%x_d, mu%x_d, this%m)
597 dzeta = -zeta + (epsi - zeta * dz) / z
598 call device_col3(ds%x_d, dlambda%x_d, s%x_d, this%m)
599 call device_cmult(ds%x_d, -1.0_rp, this%m)
600 call device_cadd(ds%x_d, epsi, this%m)
601 call device_invcol2(ds%x_d, lambda%x_d, this%m)
602 call device_sub2(ds%x_d, s%x_d, this%m)
! steg starts as the largest of the candidates below, with 1.0 as the floor.
! NOTE(review): any later rescaling of steg (e.g. inversion or halving) is
! not visible in this listing.
604 steg = maxval([1.0_rp, &
605 device_maxval2(dy%x_d, y%x_d, -1.01_rp, this%m), &
607 device_maxval2(dlambda%x_d, lambda%x_d, -1.01_rp, this%m), &
608 device_maxval2(dxsi%x_d, xsi%x_d, -1.01_rp, this%n), &
609 device_maxval2(deta%x_d, eta%x_d, -1.01_rp, this%n), &
610 device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, this%m), &
611 -1.01_rp * dzeta / zeta, &
612 device_maxval2(ds%x_d, s%x_d, -1.01_rp, this%m), &
613 device_maxval3(dx%x_d, x%x_d, this%alpha%x_d, -1.01_rp, this%n),&
614 device_maxval3(dx%x_d, this%beta%x_d, x%x_d, 1.01_rp, this%n)])
! Keep the current iterate so trial points can be formed as old + steg * d.
618 call device_copy(xold%x_d, x%x_d, this%n)
619 call device_copy(yold%x_d, y%x_d, this%m)
621 call device_copy(lambdaold%x_d, lambda%x_d, this%m)
622 call device_copy(xsiold%x_d, xsi%x_d, this%n)
623 call device_copy(etaold%x_d, eta%x_d, this%n)
624 call device_copy(muold%x_d, mu%x_d, this%m)
626 call device_copy(sold%x_d, s%x_d, this%m)
! Starting above residual_norm guarantees the loop below is entered.
628 new_residual = 2.0_rp * residual_norm
! Agree on steg and new_residual across ranks (minimum over all ranks).
631 call mpi_allreduce(mpi_in_place, steg, 1, &
632 mpi_real_precision, mpi_min, neko_comm, ierr)
633 call mpi_allreduce(mpi_in_place, new_residual, 1, &
634 mpi_real_precision, mpi_min, neko_comm, ierr)
! Step-acceptance loop: repeats while the trial residual norm exceeds the
! current one, for as long as itto stays below 50.
639 do while ((new_residual .gt. residual_norm) .and. (itto .lt. 50))
! Trial point: old iterate plus steg times the search direction.
643 call device_add3s2(x%x_d, xold%x_d, dx%x_d, 1.0_rp, steg, this%n)
644 call device_add3s2(y%x_d, yold%x_d, dy%x_d, 1.0_rp, steg, this%m)
646 call device_add3s2(lambda%x_d, lambdaold%x_d, &
647 dlambda%x_d, 1.0_rp, steg, this%m)
648 call device_add3s2(xsi%x_d, xsiold%x_d, dxsi%x_d, &
649 1.0_rp, steg, this%n)
650 call device_add3s2(eta%x_d, etaold%x_d, deta%x_d, &
651 1.0_rp, steg, this%n)
652 call device_add3s2(mu%x_d, muold%x_d, dmu%x_d, &
653 1.0_rp, steg, this%m)
654 zeta = zetaold + steg*dzeta
655 call device_add3s2(s%x_d, sold%x_d, ds%x_d, 1.0_rp, &
! Re-evaluate all residuals at the trial point (same sequence as above).
660 call device_rex(rex%x_d, x%x_d, this%low%x_d, &
661 this%upp%x_d, this%pij%x_d, this%p0j%x_d, &
662 this%qij%x_d, this%q0j%x_d, lambda%x_d, xsi%x_d, &
663 eta%x_d, this%n, this%m)
665 call device_col3(rey%x_d, this%d%x_d, y%x_d, this%m)
666 call device_add2(rey%x_d, this%c%x_d, this%m)
667 call device_sub2(rey%x_d, lambda%x_d, this%m)
668 call device_sub2(rey%x_d, mu%x_d, this%m)
670 rez = this%a0 - zeta - device_lcsc2(lambda%x_d, this%a%x_d, this%m)
673 call device_cfill(relambda%x_d, 0.0_rp, this%m)
674 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
675 this%low%x_d, this%pij%x_d, this%qij%x_d, &
678 call device_memcpy(relambda%x, relambda%x_d, this%m, &
679 device_to_host, sync = .true.)
680 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
681 mpi_real_precision, mpi_sum, neko_comm, ierr)
682 call device_memcpy(relambda%x, relambda%x_d, &
683 this%m, host_to_device, sync = .true.)
685 call device_add3s2(relambda%x_d, relambda%x_d, &
686 this%a%x_d, 1.0_rp, -z, this%m)
687 call device_sub2(relambda%x_d, y%x_d, this%m)
688 call device_add2(relambda%x_d, s%x_d, this%m)
689 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
691 call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
692 call device_col2(rexsi%x_d, xsi%x_d, this%n)
693 call device_cadd(rexsi%x_d, - epsi, this%n)
695 call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
696 call device_col2(reeta%x_d, eta%x_d, this%n)
697 call device_cadd(reeta%x_d, - epsi, this%n)
699 call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
700 call device_cadd(remu%x_d, - epsi, this%m)
702 rezeta = zeta*z - epsi
704 call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
705 call device_cadd(res%x_d, - epsi, this%m)
708 re_sq_norm = device_norm(rex%x_d, this%n) + &
709 device_norm(rexsi%x_d, this%n) + &
710 device_norm(reeta%x_d, this%n)
711 call mpi_allreduce(mpi_in_place, re_sq_norm, 1, &
712 mpi_real_precision, mpi_sum, neko_comm, ierr)
714 new_residual = sqrt(device_norm(rey%x_d, this%m) + &
716 device_norm(relambda%x_d, this%m) + &
717 device_norm(remu%x_d, this%m) + &
719 device_norm(res%x_d, this%m) + &
! Accept the trial point: its residual measures become the current ones.
728 residual_norm = new_residual
729 residual_max = maxval([ &
730 device_maxval(rex%x_d, this%n), &
731 device_maxval(rey%x_d, this%m), &
733 device_maxval(relambda%x_d, this%m), &
734 device_maxval(rexsi%x_d, this%n), &
735 device_maxval(reeta%x_d, this%n), &
736 device_maxval(remu%x_d, this%m), &
738 device_maxval(res%x_d, this%m)])
740 call mpi_allreduce(mpi_in_place, residual_max, 1, &
741 mpi_real_precision, mpi_max, neko_comm, ierr)
! Write-back: shift the design history (xold1 -> xold2, design -> xold1),
! store the new design, then store the primal/dual variables in `this`.
749 call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
750 call device_copy(this%xold1%x_d, designx_d, this%n)
751 call device_copy(designx_d, x%x_d, this%n)
754 call device_copy(this%y%x_d, y%x_d, this%m)
756 call device_copy(this%lambda%x_d, lambda%x_d, this%m)
758 call device_copy(this%xsi%x_d, xsi%x_d, this%n)
759 call device_copy(this%eta%x_d, eta%x_d, this%n)
760 call device_copy(this%mu%x_d, mu%x_d, this%m)
761 call device_copy(this%s%x_d, s%x_d, this%m)
! Release work storage. NOTE(review): only two free calls are visible here;
! confirm that every vector and matrix initialized above is released.
773 call dellambda%free()
779 call lambdaold%free()
798 end subroutine mma_subsolve_dpip_device
!> Sub-problem solver called by mma_update_device when this%subsolver is
!! "dip". The iteration variables are lambda and mu (length this%m); y, z and
!! x are recomputed from lambda after each update. Host copies (the %x
!! components) are used for the MPI reductions, for the host-side edits of
!! hess, and for the LAPACK solve.
!!
!! @param this       MMA object. Its sub-problem data (low, upp, alpha, beta,
!!                   p0j, q0j, pij, qij, a0, a, c, d, bi, epsimin, max_iter)
!!                   is read; at the end xold1, xold2, y, lambda and mu are
!!                   overwritten with the final iterate.
!! @param designx_d  Device pointer to the design. Its data is first saved
!!                   into this%xold1 and then overwritten with the local x
!!                   (taking the first argument of device_copy to be the
!!                   destination -- confirm against device_math).
!!
!! NOTE(review): this listing of the routine looks incomplete. Names such as
!! info and i have no visible declaration, several work vectors (y, mu, dd, x)
!! have no visible init call, the initial value of epsi is not shown, a number
!! of continued statements end in '&' with no continuation line, and no loop,
!! if or associate terminators are visible. Check the complete file before
!! changing any logic here.
802 subroutine mma_subsolve_dip_device(this, designx_d)
803 class(mma_t),
intent(inout) :: this
804 type(c_ptr),
intent(in) :: designx_d
805 integer :: iter, ierr
806 real(kind=rp) :: epsi, residumax, z, steg
808 type(vector_t) :: y, lambda, mu, relambda, remu, dlambda, dmu, &
809 gradlambda, zerom, dd, dummy_m
811 type(vector_t) :: x, pjlambda, qjlambda
814 type(vector_t) :: Ljjxinv
815 type(matrix_t) :: hijx
816 type(matrix_t) :: Hess
817 real(kind=rp) :: hesstrace
820 integer,
dimension(this%m+1) :: ipiv
823 real(kind=rp) :: minimal_epsilon
! Work vectors of length this%m.
826 call lambda%init(this%m)
828 call relambda%init(this%m)
829 call remu%init(this%m)
830 call dlambda%init(this%m)
831 call dmu%init(this%m)
832 call gradlambda%init(this%m)
833 call zerom%init(this%m)
835 call dummy_m%init(this%m)
! Work vectors of length this%n, plus the m-by-n array hijx and the m-by-m
! array hess.
838 call pjlambda%init(this%n)
839 call qjlambda%init(this%n)
841 call ljjxinv%init(this%n)
842 call hijx%init(this%m,this%n)
843 call hess%init(this%m,this%m)
! zerom is an all-zero vector; it is the second operand of device_pwmax when
! y is updated below.
845 call device_cfill(zerom%x_d, 0.0_rp, this%m)
! Starting values: y and mu are filled with ones; lambda is filled with ones
! and then combined with 0.5 * c through device_pwmax.
852 call device_cfill(y%x_d, 1.0_rp, this%m)
854 call device_cfill(lambda%x_d, 1.0_rp, this%m)
855 call device_cmult2(dummy_m%x_d, this%c%x_d, 0.5_rp, this%m)
856 call device_pwmax(lambda%x_d, dummy_m%x_d, this%m)
858 call device_cfill(mu%x_d, 1.0_rp, this%m)
! dd is derived from d with the constant 1.0e-8 (presumably d + 1.0e-8, which
! would keep the division by dd below away from zero -- confirm against
! device_cadd2).
862 call device_cadd2(dd%x_d, this%d%x_d, 1.0e-8_rp, this%m)
! Threshold at which the outer loop stops, never below 1.0e-12; the MPI min
! reduction makes every rank use the same value.
867 minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
868 call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
869 mpi_real_precision, mpi_min, neko_comm, ierr)
! Outer loop over epsi; epsi is multiplied by 0.1 further down.
874 outer:
do while (epsi .gt. minimal_epsilon)
! Short aliases for the sub-problem data.
878 associate(p0j => this%p0j, q0j => this%q0j, &
879 pij => this%pij, qij => this%qij, &
880 low => this%low, upp => this%upp, &
881 alpha => this%alpha, beta => this%beta, &
882 c => this%c, a0 => this%a0, a => this%a)
! y from lambda, c and dd, followed by device_pwmax against the zero vector.
890 call device_sub3(y%x_d, lambda%x_d, c%x_d, this%m)
892 call device_invcol2(y%x_d, dd%x_d, this%m)
893 call device_pwmax(y%x_d, zerom%x_d, this%m)
! z is 0 when a0 minus the device_glsum of the lambda/a combination is
! non-negative, and 1 otherwise.
899 call device_col3(dummy_m%x_d, lambda%x_d, a%x_d, this%m)
900 z = device_glsum(dummy_m%x_d, this%m)
901 z = merge(0.0_rp, 1.0_rp, a0 - z >= 0.0)
! pjlambda and qjlambda are built from pij / qij and lambda, with p0j / q0j
! added; device_mma_dipsolvesub1 then fills x from them.
907 call device_mattrans_v_mul(pjlambda%x_d, pij%x_d, lambda%x_d, this%m, this%n)
908 call device_mattrans_v_mul(qjlambda%x_d, qij%x_d, lambda%x_d, this%m, this%n)
909 call device_add2(pjlambda%x_d, p0j%x_d, this%n)
910 call device_add2(qjlambda%x_d, q0j%x_d, this%n)
912 call device_mma_dipsolvesub1(x%x_d, pjlambda%x_d, qjlambda%x_d, &
913 low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
! relambda from device_relambda is a per-rank partial result: sum it over all
! ranks on the host and send it back before the remaining terms are applied.
915 call device_cfill(relambda%x_d, 0.0_rp, this%m)
916 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
917 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
921 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
923 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
924 mpi_real_precision, mpi_sum, neko_comm, ierr)
925 call device_memcpy(relambda%x, relambda%x_d, this%m, &
926 host_to_device, sync = .true.)
928 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
929 call device_sub2(relambda%x_d, y%x_d, this%m)
930 call device_add2(relambda%x_d, mu%x_d, this%m)
931 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
! remu from mu and lambda, with epsi subtracted.
933 call device_col3(remu%x_d, mu%x_d, lambda%x_d, this%m)
934 call device_cadd(remu%x_d, -epsi, this%m)
! residumax is evaluated on the host as the largest absolute entry of
! relambda and remu.
938 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
940 call device_memcpy(remu%x, remu%x_d, this%m, device_to_host, &
942 residumax = maxval(abs([relambda%x, remu%x]))
! Inner iteration: at most this%max_iter steps, left as soon as residumax
! drops below epsi.
946 do iter = 1, this%max_iter
948 if (residumax .lt. epsi)
exit
! Right-hand side gradlambda, assembled from relambda, mu and epsi / lambda,
! then negated.
956 call device_copy(gradlambda%x_d, relambda%x_d, this%m)
957 call device_sub2(gradlambda%x_d, mu%x_d, this%m)
960 call device_cfill(dummy_m%x_d, epsi, this%m)
961 call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)
962 call device_add2(gradlambda%x_d, dummy_m%x_d, this%m)
963 call device_cmult(gradlambda%x_d, -1.0_rp, this%m)
! Ingredients of hess: ljjxinv (length n) and hijx (m-by-n).
969 call device_mma_ljjxinv(ljjxinv%x_d, pjlambda%x_d, qjlambda%x_d, &
970 x%x_d, low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
972 call device_gg(hijx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
973 this%pij%x_d, this%qij%x_d, this%n, this%m)
975 call device_memcpy(hijx%x, hijx%x_d, this%n*this%m, device_to_host, &
! hess: zeroed and filled on the device, then summed over ranks on the host.
978 call device_cfill(hess%x_d, 0.0_rp, (this%m) * (this%m) )
979 call device_hess(hess%x_d, hijx%x_d, ljjxinv%x_d, this%n, this%m)
982 call device_memcpy(hess%x, hess%x_d, this%m*this%m, device_to_host, &
984 call mpi_allreduce(mpi_in_place, hess%x, &
985 this%m*this%m, mpi_real_precision, mpi_sum, neko_comm, ierr)
! Host copies of lambda, mu and y are needed for the diagonal edits of hess.
1002 call device_memcpy(lambda%x, lambda%x_d, this%m, device_to_host, &
1004 call device_memcpy(mu%x, mu%x_d, this%m, device_to_host, &
1006 call device_memcpy(y%x, y%x_d, this%m, device_to_host, &
! Diagonal edits of hess on the host. NOTE(review): the loop over i, the
! branch taken when |d(i)| < 1.0e-15, and the else / end if lines are not
! visible here, so the exact conditions under which each update applies
! cannot be read off this listing.
1009 if (y%x(i) .gt. 0.0_rp)
then
1010 if (abs(this%d%x(i)) < 1.0e-15_rp)
then
1013 hess%x(i, i) = hess%x(i, i) - 1.0_rp/this%d%x(i)
1017 hess%x(i, i) = hess%x(i, i) - mu%x(i) / lambda%x(i)
! hesstrace accumulates diagonal entries of hess; a shift of
! max(-1.0e-4 * hesstrace / m, 1.0e-7) is then subtracted from the diagonal.
1024 hesstrace = hesstrace + hess%x(i, i)
1027 hess%x(i,i) = hess%x(i, i) - &
1028 max(-1.0e-4_rp*hesstrace/this%m, 1.0e-7_rp)
! Solve the m-by-m system on the host with LAPACK dgesv; the solution
! overwrites gradlambda%x and a non-zero info aborts through neko_error.
1031 call device_memcpy(gradlambda%x, gradlambda%x_d, this%m, device_to_host, &
1033 call dgesv(this%m , 1, hess%x, this%m , ipiv, &
1034 gradlambda%x, this%m, info)
1036 if (info .ne. 0)
then
1037 call neko_error(
"DGESV failed to solve the linear system in " // &
1038 "mma_subsolve_dip (device).")
1040 call device_memcpy(gradlambda%x, gradlambda%x_d, this%m, host_to_device, &
! The solution is the lambda search direction.
1043 call device_copy(dlambda%x_d, gradlambda%x_d, this%m)
! dmu is assembled from epsi / lambda, the dlambda * mu / lambda term held in
! dummy_m, and mu.
1046 call device_copy(dummy_m%x_d, dlambda%x_d, this%m)
1047 call device_col2(dummy_m%x_d, mu%x_d, this%m)
1048 call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)
1050 call device_cfill(dmu%x_d, epsi, this%m)
1051 call device_invcol2(dmu%x_d, lambda%x_d, this%m)
1052 call device_add2s2(dmu%x_d, dummy_m%x_d, -1.0_rp, this%m)
1053 call device_sub2(dmu%x_d, mu%x_d, this%m)
! Step length: steg is the largest candidate (floor 1.005) and is then
! inverted, so the applied step never exceeds 1 / 1.005.
1055 steg = maxval([1.005_rp, device_maxval2(dlambda%x_d, lambda%x_d, &
1056 -1.01_rp, this%m), device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, &
1058 steg = 1.0_rp / steg
! Update lambda and mu along their directions with the step steg.
1060 call device_add2s2(lambda%x_d, dlambda%x_d, steg, this%m)
1061 call device_add2s2(mu%x_d, dmu%x_d, steg, this%m)
1063 call device_memcpy(lambda%x, lambda%x_d, this%m, device_to_host, &
1065 call device_memcpy(mu%x, mu%x_d, this%m, device_to_host, &
! Recompute y, z, x and the residuals from the updated lambda and mu (same
! sequence as before the inner loop).
1075 call device_sub3(y%x_d, lambda%x_d, c%x_d, this%m)
1077 call device_invcol2(y%x_d, dd%x_d, this%m)
1078 call device_pwmax(y%x_d, zerom%x_d, this%m)
1084 call device_col3(dummy_m%x_d, lambda%x_d, a%x_d, this%m)
1085 z = device_glsum(dummy_m%x_d, this%m)
1086 z = merge(0.0_rp, 1.0_rp, a0 - z >= 0.0)
1092 call device_mattrans_v_mul(pjlambda%x_d, pij%x_d, lambda%x_d, this%m, this%n)
1093 call device_mattrans_v_mul(qjlambda%x_d, qij%x_d, lambda%x_d, this%m, this%n)
1094 call device_add2(pjlambda%x_d, p0j%x_d, this%n)
1095 call device_add2(qjlambda%x_d, q0j%x_d, this%n)
1097 call device_mma_dipsolvesub1(x%x_d, pjlambda%x_d, qjlambda%x_d, &
1098 low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
1102 call device_cfill(relambda%x_d, 0.0_rp, this%m)
1103 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
1104 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
1108 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
1110 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
1111 mpi_real_precision, mpi_sum, neko_comm, ierr)
1112 call device_memcpy(relambda%x, relambda%x_d, this%m, &
1113 host_to_device, sync = .true.)
1115 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
1116 call device_sub2(relambda%x_d, y%x_d, this%m)
1117 call device_add2(relambda%x_d, mu%x_d, this%m)
1118 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
1120 call device_col3(remu%x_d, mu%x_d, lambda%x_d, this%m)
1121 call device_cadd(remu%x_d, -epsi, this%m)
1126 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
1128 call device_memcpy(remu%x, remu%x_d, this%m, device_to_host, &
1130 residumax = maxval(abs([relambda%x, remu%x]))
! Tighten epsi by a factor of ten for the next outer pass.
1133 epsi = 0.1_rp * epsi
! Write-back: shift the design history (xold1 -> xold2, design -> xold1),
! store the new design, then store y, lambda and mu in `this`.
1137 call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
1138 call device_copy(this%xold1%x_d, designx_d, this%n)
1139 call device_copy(designx_d, x%x_d, this%n)
1142 call device_copy(this%y%x_d, y%x_d, this%m)
1144 call device_copy(this%lambda%x_d, lambda%x_d, this%m)
1145 call device_copy(this%mu%x_d, mu%x_d, this%m)
! Release work storage. NOTE(review): only four free calls are visible here;
! confirm that every vector and matrix initialized above is released.
1150 call relambda%free()
1154 call gradlambda%free()
1160 call pjlambda%free()
1161 call qjlambda%free()
1166 end subroutine mma_subsolve_dip_device
1168end submodule mma_device