33submodule(mma) mma_device
35 use device_math,
only: device_copy, device_cmult, device_cadd, device_cfill, &
36 device_add2, device_add3s2, device_invcol2, device_col2, device_col3, &
37 device_sub2, device_sub3, device_add2s2, device_cadd2, device_pwmax2, &
38 device_glsum, device_cmult2
39 use device_mma_math,
only: device_maxval, device_norm, device_lcsc2, &
40 device_maxval2, device_maxval3, device_mma_gensub3, &
41 device_mma_gensub4, device_mma_max, device_max2, device_rex, &
42 device_relambda, device_delx, device_add2inv2, device_gg, device_diagx, &
43 device_bb, device_updatebb, device_aa, device_updateaa, device_dx, &
44 device_dy, device_dxsi, device_deta, device_kkt_rex, &
45 device_mma_gensub2, device_mattrans_v_mul, device_mma_dipsolvesub1, &
46 device_mma_ljjxinv, device_hess
48 use neko_config,
only: neko_bcknd_device
49 use device,
only: device_to_host
50 use comm,
only: neko_comm, pe_rank, mpi_real_precision
51 use mpi_f08,
only: mpi_in_place, mpi_max, mpi_min
52 use profiler,
only: profiler_start_region, profiler_end_region
!> Perform one MMA design update on the device backend.
!!
!! First builds the convex MMA subproblem around the current design
!! (mma_gensub_device). It then solves that subproblem with the solver
!! named by this%subsolver ("dip" or "dpip"). The subsolver writes the new
!! design into the device array behind x. Finally this%is_updated is raised.
!!
!! @param this  MMA state; must already be initialized.
!! @param iter  Outer optimization iteration, forwarded to the
!!              subproblem generator.
!! @param x     Device pointer to the design variables (overwritten).
!! @param df0dx Device pointer forwarded unchanged to mma_gensub_device.
!! @param fval  Device pointer forwarded unchanged to mma_gensub_device.
!! @param dfdx  Device pointer forwarded unchanged to mma_gensub_device.
module subroutine mma_update_device(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(inout) :: x
  type(c_ptr), intent(in) :: df0dx, fval, dfdx

  ! Guard: nothing can be done before the MMA state has been set up.
  if (.not. this%is_initialized) then
    call neko_error("The MMA object is not initialized.")
  end if

  ! Stage 1: approximate subproblem (asymptotes, bounds, p/q terms).
  call profiler_start_region("MMA gensub")
  call mma_gensub_device(this, iter, x, df0dx, fval, dfdx)
  call profiler_end_region("MMA gensub")

  ! Stage 2: solve the subproblem; an unknown solver name is fatal.
  call profiler_start_region("MMA subsolve")
  select case (this%subsolver)
  case ("dip")
    call mma_subsolve_dip_device(this, x)
  case ("dpip")
    call mma_subsolve_dpip_device(this, x)
  case default
    call neko_error("Unrecognized subsolver for MMA in mma_device.")
  end select
  call profiler_end_region("MMA subsolve")

  this%is_updated = .true.
end subroutine mma_update_device
!> Evaluate the KKT residual of the current MMA iterate on the device.
!!
!! Dispatches on this%subsolver. "dip" uses mma_dip_kkt_device; every
!! other value falls through to mma_dpip_kkt_device. The selected routine
!! stores its results in this%residumax and this%residunorm.
!!
!! @param this  MMA state holding the current primal/dual variables.
!! @param x     Device pointer to the current design.
!! @param df0dx Device pointer forwarded to the residual routine.
!! @param fval  Device pointer forwarded to the residual routine.
!! @param dfdx  Device pointer forwarded to the residual routine.
module subroutine mma_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  select case (this%subsolver)
  case ("dip")
    call mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
  case default
    call mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
  end select
end subroutine mma_kkt_device
! ------------------------------------------------------------------
! KKT residual for the "dip" subsolver (device backend).
! Works on two length-m device work vectors:
!   relambda : built from fval, this%a (scaled by -this%z), this%y
!              and this%mu (device_add3s2 / device_sub2 / device_add2)
!   remu     : elementwise product this%lambda * this%mu (device_col3)
! Writes this%residumax (max of device_maxval over both vectors) and
! this%residunorm (sqrt of the summed device_norm of both vectors).
! NOTE(review): x, df0dx and dfdx are not referenced in the visible
! lines. device_norm looks like a squared norm, since its results are
! summed and then square-rooted; confirm in device_mma_math.
! ------------------------------------------------------------------
107 module subroutine mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
108 class(mma_t),
intent(inout) :: this
109 type(c_ptr),
intent(in) :: x, df0dx, fval, dfdx
111 type(vector_t) :: relambda, remu
113 call relambda%init(this%m)
114 call remu%init(this%m)
! relambda ~ fval - z*a - y + mu (assuming device_add3s2(a,b,c,c1,c2,n)
! forms a = c1*b + c2*c -- confirm in device_math)
117 call device_add3s2(relambda%x_d, fval, this%a%x_d, 1.0_rp, -this%z, &
119 call device_sub2(relambda%x_d, this%y%x_d, this%m)
120 call device_add2(relambda%x_d, this%mu%x_d, this%m)
! Complementarity residual remu = lambda * mu (elementwise).
123 call device_col3 (remu%x_d, this%lambda%x_d, this%mu%x_d, this%m)
126 this%residumax = maxval([device_maxval(relambda%x_d, this%m), &
127 device_maxval(remu%x_d, this%m)])
128 this%residunorm = sqrt(device_norm(relambda%x_d, this%m)+ &
129 device_norm(remu%x_d, this%m))
133 end subroutine mma_dip_kkt_device
! ------------------------------------------------------------------
! KKT residual for the "dpip" subsolver (device backend).
! Assembles the residual pieces from the MMA state held in `this`:
!   rex      (length n) : device_kkt_rex from df0dx, dfdx, xsi, eta, lambda
!   rey      (length m) : d*y + c - lambda - mu
!   rez      (scalar)   : a0 - zeta - lambda . a   (device_lcsc2)
!   relambda (length m) : from fval, a (scaled by -z), y and s
!   rexsi    (length n) : (x - xmin) * xsi
!   reeta    (length n) : (xmax - x) * eta
!   remu     (length m) : mu * y
!   rezeta   (scalar)   : zeta * z
!   res      (length m) : lambda * s
! The maximum over these is reduced with MPI_MAX into this%residumax.
! The n-length squared-norm contributions (re_sq_norm) are summed over
! neko_comm (MPI_SUM) before this%residunorm is formed.
! NOTE(review): `ierr`, some terms of the max/norm expressions and the
! cleanup lines are not visible in this view.
! ------------------------------------------------------------------
137 module subroutine mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
138 class(mma_t),
intent(inout) :: this
139 type(c_ptr),
intent(in) :: x, df0dx, fval, dfdx
141 real(kind=rp) :: rez, rezeta
142 type(vector_t) :: rey, relambda, remu, res
143 type(vector_t) :: rex, rexsi, reeta
145 real(kind=rp) :: re_sq_norm
147 call rey%init(this%m)
148 call relambda%init(this%m)
149 call remu%init(this%m)
150 call res%init(this%m)
152 call rex%init(this%n)
153 call rexsi%init(this%n)
154 call reeta%init(this%n)
! Stationarity residual in x (length n).
156 call device_kkt_rex(rex%x_d, df0dx, dfdx, this%xsi%x_d, &
157 this%eta%x_d, this%lambda%x_d, this%n, this%m)
! rey = d*y + c - lambda - mu
159 call device_col3(rey%x_d, this%d%x_d, this%y%x_d, this%m)
160 call device_add2(rey%x_d, this%c%x_d, this%m)
161 call device_sub2(rey%x_d, this%lambda%x_d, this%m)
162 call device_sub2(rey%x_d, this%mu%x_d, this%m)
164 rez = this%a0 - this%zeta - device_lcsc2(this%lambda%x_d, this%a%x_d, &
167 call device_add3s2(relambda%x_d, fval, this%a%x_d, 1.0_rp, -this%z, &
169 call device_sub2(relambda%x_d, this%y%x_d, this%m)
170 call device_add2(relambda%x_d, this%s%x_d, this%m)
! Bound complementarity: rexsi = (x - xmin)*xsi, reeta = (xmax - x)*eta.
172 call device_sub3(rexsi%x_d, x, this%xmin%x_d, this%n)
173 call device_col2(rexsi%x_d, this%xsi%x_d, this%n)
175 call device_sub3(reeta%x_d, this%xmax%x_d, x, this%n)
176 call device_col2(reeta%x_d, this%eta%x_d, this%n)
! Remaining complementarity terms: mu*y, zeta*z, lambda*s.
178 call device_col3(remu%x_d, this%mu%x_d, this%y%x_d, this%m)
180 rezeta = this%zeta * this%z
182 call device_col3(res%x_d, this%lambda%x_d, this%s%x_d, this%m)
! Rank-local maximum over all residual pieces (reduced below).
184 this%residumax = maxval([ &
185 device_maxval(rex%x_d, this%n), &
186 device_maxval(rey%x_d, this%m), &
188 device_maxval(relambda%x_d, this%m), &
189 device_maxval(rexsi%x_d, this%n), &
190 device_maxval(reeta%x_d, this%n), &
191 device_maxval(remu%x_d, this%m), &
193 device_maxval(res%x_d, this%m)])
! Only the n-length pieces are summed across ranks.
195 re_sq_norm = device_norm(rex%x_d, this%n) + &
196 device_norm(rexsi%x_d, this%n) + &
197 device_norm(reeta%x_d, this%n)
! Global reductions: max for residumax, sum for the n-length norms.
199 call mpi_allreduce(mpi_in_place, this%residumax, 1, &
200 mpi_real_precision, mpi_max, neko_comm, ierr)
202 call mpi_allreduce(mpi_in_place, re_sq_norm, 1, &
203 mpi_real_precision, mpi_sum, neko_comm, ierr)
205 this%residunorm = sqrt(( &
206 device_norm(rey%x_d, this%m) + &
208 device_norm(relambda%x_d, this%m) + &
209 device_norm(remu%x_d, this%m) + &
211 device_norm(res%x_d, this%m) &
221 end subroutine mma_dpip_kkt_device
! ------------------------------------------------------------------
! Build the MMA convex approximation around the current design x on
! the device. Updates in `this`:
!   low/upp                          : asymptotes (init or gensub2 update)
!   alpha, beta, p0j, q0j, pij, qij  : passed to device_mma_gensub3
!                                      (appear to be its outputs)
!   bi                               : device_mma_gensub4, summed over
!                                      neko_comm, then minus fval
! iter : outer iteration; for iter < 3 the asymptotes are placed at
!        x -/+ asyinit*(xmax - xmin).
! x    : device pointer to the current design (length n), read only.
! NOTE(review): the host copy of x_diff made below is not read by any
! visible line. `ierr` and the cleanup lines are not visible here.
! ------------------------------------------------------------------
227 subroutine mma_gensub_device(this, iter, x, df0dx, fval, dfdx)
233 class(mma_t),
intent(inout) :: this
234 type(c_ptr),
intent(in) :: x
235 type(c_ptr),
intent(in) :: df0dx
236 type(c_ptr),
intent(in) :: fval
237 type(c_ptr),
intent(in) :: dfdx
239 integer,
intent(in) :: iter
242 type(vector_t):: x_diff
244 call x_diff%init(this%n)
! x_diff = xmax - xmin (design range per variable).
245 call device_sub3 (x_diff%x_d, this%xmax%x_d, this%xmin%x_d, this%n)
246 call device_memcpy(x_diff%x, x_diff%x_d, this%n, &
247 device_to_host, sync = .true.)
! Early iterations: symmetric asymptotes at x -/+ asyinit*x_diff.
252 if (iter .lt. 3)
then
253 call device_copy(this%low%x_d, x, this%n)
254 call device_add2s2(this%low%x_d, x_diff%x_d, - this%asyinit, this%n)
255 call device_copy(this%upp%x_d, x, this%n)
256 call device_add2s2(this%upp%x_d, x_diff%x_d, this%asyinit, this%n)
! Other branch (the `else` is not visible): adapt the asymptotes from
! the last two designs xold1/xold2 using asydecr/asyincr.
258 call device_mma_gensub2(this%low%x_d, this%upp%x_d, x, &
259 this%xold1%x_d, this%xold2%x_d, x_diff%x_d, &
260 this%asydecr, this%asyincr, this%n)
! Move limits and the p/q coefficients of the approximation.
266 call device_mma_gensub3(x, df0dx, dfdx, this%low%x_d, &
267 this%upp%x_d, this%xmin%x_d, this%xmax%x_d, this%alpha%x_d, &
268 this%beta%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
269 this%qij%x_d, this%n, this%m)
! Rank-local contribution to bi (length m).
274 call device_mma_gensub4(x, this%low%x_d, this%upp%x_d, this%pij%x_d, &
275 this%qij%x_d, this%n, this%m, this%bi%x_d)
! Sum bi over all ranks on the host, then push it back to the device.
277 call device_memcpy(this%bi%x, this%bi%x_d, this%m, device_to_host, &
279 call mpi_allreduce(mpi_in_place, this%bi%x, this%m, &
280 mpi_real_precision, mpi_sum, neko_comm, ierr)
281 call device_memcpy(this%bi%x, this%bi%x_d, this%m, host_to_device, &
! bi = bi - fval
283 call device_sub2(this%bi%x_d, fval, this%m)
285 end subroutine mma_gensub_device
! ------------------------------------------------------------------
! Interior-point solver for the MMA subproblem using the "dpip"
! subsolver, device backend.
!
! designx_d : device pointer to the current design vector (length n).
!             On exit its contents are overwritten with the new design.
!             The previous design is shifted into this%xold1, and
!             this%xold1 into this%xold2.
!
! Outline of the visible code:
!  * start from x built from 0.5*alpha + 0.5*beta (device_add3s2),
!    y = lambda = s = 1, xsi/eta from device_mma_max and mu from
!    device_max2 with c;
!  * outer loop while epsi > minimal_epsilon, where minimal_epsilon =
!    max(0.9*epsimin, 1e-12) reduced with MPI_MIN over neko_comm;
!  * inner Newton loop, at most this%max_iter steps, exiting once
!    residual_max < epsi. Each step:
!      - assembles an (m+1)x(m+1) system aa*sol = bb on the host;
!      - solves it with LAPACK dgesv;
!      - recovers the full search direction on the device;
!      - picks a step length steg;
!      - retries trial points while new_residual > residual_norm
!        and itto < 50;
!  * copy x, y, lambda, xsi, eta, mu, s back into `this`.
!
! Norms and maxima of the length-n pieces are reduced over neko_comm.
! The m-length sums built from length-n data (relambda, dellambda, bb,
! aa) are summed with MPI_SUM on the host.
! NOTE(review): several statements are not visible in this view:
! declarations of i/info and of the vectors initialised on hidden
! lines, loop ends, the epsi/z/zeta/zold/zetaold updates and the
! step-halving line.
! ------------------------------------------------------------------
289 subroutine mma_subsolve_dpip_device(this, designx_d)
290 class(mma_t),
intent(inout) :: this
291 type(c_ptr),
intent(in) :: designx_d
292 integer :: iter, itto, ierr
293 real(kind=rp) :: epsi, residual_max, residual_norm, z, zeta, rez, rezeta, &
294 delz, dz, dzeta, steg, zold, zetaold, new_residual
296 type(vector_t) :: y, lambda, s, mu, rey, relambda, remu, res, &
297 dely, dellambda, dy, dlambda, ds, dmu, yold, lambdaold, sold, muold
300 type(vector_t) :: x, xsi, eta, rex, rexsi, reeta, &
301 delx, diagx, dx, dxsi, deta, xold, xsiold, etaold
308 integer,
dimension(this%m+1) :: ipiv
309 real(kind=rp) :: re_sq_norm
313 real(kind=rp) :: minimal_epsilon
! Allocate the m-, n- and (m+1)-sized work vectors/matrices.
316 call lambda%init(this%m)
319 call rey%init(this%m)
320 call relambda%init(this%m)
321 call remu%init(this%m)
322 call res%init(this%m)
323 call dely%init(this%m)
324 call dellambda%init(this%m)
326 call dlambda%init(this%m)
328 call dmu%init(this%m)
329 call yold%init(this%m)
330 call lambdaold%init(this%m)
331 call sold%init(this%m)
332 call muold%init(this%m)
334 call xsi%init(this%n)
335 call eta%init(this%n)
336 call rex%init(this%n)
337 call rexsi%init(this%n)
338 call reeta%init(this%n)
339 call delx%init(this%n)
340 call diagx%init(this%n)
342 call dxsi%init(this%n)
343 call deta%init(this%n)
344 call xold%init(this%n)
345 call xsiold%init(this%n)
346 call etaold%init(this%n)
347 call bb%init(this%m+1)
349 call gg%init(this%m, this%n)
350 call aa%init(this%m+1, this%m+1)
! Initial interior point.
357 call device_add3s2(x%x_d, this%alpha%x_d, this%beta%x_d, 0.5_rp, 0.5_rp, &
359 call device_cfill(y%x_d, 1.0_rp, this%m)
362 call device_cfill(lambda%x_d, 1.0_rp, this%m)
363 call device_cfill(s%x_d, 1.0_rp, this%m)
364 call device_mma_max(xsi%x_d, x%x_d, this%alpha%x_d, this%n)
365 call device_mma_max(eta%x_d, this%beta%x_d, x%x_d, this%n)
366 call device_max2(mu%x_d, 1.0_rp, this%c%x_d, 0.5_rp, this%m)
! Lower bound on the barrier parameter, agreed on by all ranks.
371 minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
372 call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
373 mpi_real_precision, mpi_min, neko_comm, ierr)
! Outer loop over the barrier parameter epsi.
378 do while (epsi .gt. minimal_epsilon)
385 associate(p0j => this%p0j, q0j => this%q0j, &
386 pij => this%pij, qij => this%qij, &
387 low => this%low, upp => this%upp, &
388 alpha => this%alpha, beta => this%beta, &
389 c => this%c, d => this%d, &
390 a0 => this%a0, a => this%a)
! Residuals at the current point: rex (n), rey = d*y + c - lambda - mu,
! rez = a0 - zeta - lambda . a.
392 call device_rex(rex%x_d, x%x_d, low%x_d, upp%x_d, &
393 pij%x_d, p0j%x_d, qij%x_d, q0j%x_d, &
394 lambda%x_d, xsi%x_d, eta%x_d, this%n, this%m)
396 call device_col3(rey%x_d, d%x_d, y%x_d, this%m)
397 call device_add2(rey%x_d, c%x_d, this%m)
398 call device_sub2(rey%x_d, lambda%x_d, this%m)
399 call device_sub2(rey%x_d, mu%x_d, this%m)
400 rez = a0 - zeta - device_lcsc2(lambda%x_d, a%x_d, this%m)
! Rank-local part of relambda, accumulated from length-n data.
402 call device_cfill(relambda%x_d, 0.0_rp, this%m)
403 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
404 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
! Sum the rank-local contributions over neko_comm on the host.
412 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
414 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
415 mpi_real_precision, mpi_sum, neko_comm, ierr)
416 call device_memcpy(relambda%x, relambda%x_d, this%m, host_to_device, &
419 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
420 call device_sub2(relambda%x_d, y%x_d, this%m)
421 call device_add2(relambda%x_d, s%x_d, this%m)
422 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
! Perturbed complementarity residuals (each shifted by -epsi).
424 call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
425 call device_col2(rexsi%x_d, xsi%x_d, this%n)
426 call device_cadd(rexsi%x_d, - epsi, this%n)
428 call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
429 call device_col2(reeta%x_d, eta%x_d, this%n)
430 call device_cadd(reeta%x_d, - epsi, this%n)
432 call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
433 call device_cadd(remu%x_d, - epsi, this%m)
435 rezeta = zeta * z - epsi
437 call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
438 call device_cadd(res%x_d, - epsi, this%m)
! Global residual max (MPI_MAX) and norm (n-length parts via MPI_SUM).
441 residual_max = maxval([device_maxval(rex%x_d, this%n), &
442 device_maxval(rey%x_d, this%m), abs(rez), &
443 device_maxval(relambda%x_d, this%m), &
444 device_maxval(rexsi%x_d, this%n), &
445 device_maxval(reeta%x_d, this%n), &
446 device_maxval(remu%x_d, this%m), abs(rezeta), &
447 device_maxval(res%x_d, this%m)])
449 re_sq_norm = device_norm(rex%x_d, this%n) + &
450 device_norm(rexsi%x_d, this%n) + device_norm(reeta%x_d, this%n)
452 call mpi_allreduce(mpi_in_place, residual_max, 1, &
453 mpi_real_precision, mpi_max, neko_comm, ierr)
455 call mpi_allreduce(mpi_in_place, re_sq_norm, &
456 1, mpi_real_precision, mpi_sum, neko_comm, ierr)
458 residual_norm = sqrt(device_norm(rey%x_d, this%m) + &
460 device_norm(relambda%x_d, this%m) + &
461 device_norm(remu%x_d, this%m)+ &
463 device_norm(res%x_d, this%m) &
469 do iter = 1, this%max_iter
471 if (residual_max .lt. epsi)
exit
! Newton step: right-hand-side pieces delx, dely, delz, dellambda.
473 call device_delx(delx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
474 this%pij%x_d, this%qij%x_d, this%p0j%x_d, this%q0j%x_d, &
475 this%alpha%x_d, this%beta%x_d, lambda%x_d, epsi, this%n, &
478 call device_col3(dely%x_d, this%d%x_d, y%x_d, this%m)
479 call device_add2(dely%x_d, this%c%x_d, this%m)
480 call device_sub2(dely%x_d, lambda%x_d, this%m)
481 call device_add2inv2(dely%x_d, y%x_d, - epsi, this%m)
482 delz = this%a0 - device_lcsc2(lambda%x_d, this%a%x_d, this%m) - epsi/z
485 call device_cfill(dellambda%x_d, 0.0_rp, this%m)
486 call device_relambda(dellambda%x_d, x%x_d, this%upp%x_d, &
487 this%low%x_d, this%pij%x_d, this%qij%x_d, this%n, this%m)
489 call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
490 device_to_host, sync = .true.)
491 call mpi_allreduce(mpi_in_place, dellambda%x, this%m, &
492 mpi_real_precision, mpi_sum, neko_comm, ierr)
493 call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
494 host_to_device, sync = .true.)
496 call device_add3s2(dellambda%x_d, dellambda%x_d, this%a%x_d, &
498 call device_sub2(dellambda%x_d, y%x_d, this%m)
499 call device_sub2(dellambda%x_d, this%bi%x_d, this%m)
500 call device_add2inv2(dellambda%x_d, lambda%x_d, epsi, this%m)
! Constraint Jacobian gg (m x n) and diagonal diagx (n).
502 call device_gg(gg%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
503 this%pij%x_d, this%qij%x_d, this%n, this%m)
505 call device_diagx(diagx%x_d, x%x_d, xsi%x_d, this%low%x_d, &
506 this%upp%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
507 this%qij%x_d, this%alpha%x_d, this%beta%x_d, eta%x_d, &
508 lambda%x_d, this%n, this%m)
! Right-hand side bb (m+1): rank-local part summed over neko_comm,
! then completed on the device by device_updatebb.
518 call device_bb(bb%x_d, gg%x_d, delx%x_d, diagx%x_d, this%n, &
521 call device_memcpy(bb%x, bb%x_d, this%m + 1, device_to_host, &
523 call mpi_allreduce(mpi_in_place, bb%x, this%m + 1, &
524 mpi_real_precision, mpi_sum, neko_comm, ierr)
525 call device_memcpy(bb%x, bb%x_d, this%m + 1, &
526 host_to_device, sync = .true.)
528 call device_updatebb(bb%x_d, dellambda%x_d, dely%x_d, &
529 this%d%x_d, mu%x_d, y%x_d, delz, this%m)
! System matrix aa ((m+1) x (m+1)): rank-local part from gg/diagx,
! summed over neko_comm, then diagonal and border filled on the host.
539 call device_cfill(aa%x_d, 0.0_rp, (this%m+1) * (this%m+1))
540 call device_aa(aa%x_d, gg%x_d, diagx%x_d, this%n, this%m)
541 call device_memcpy(aa%x, aa%x_d, (this%m+1) * (this%m+1), &
542 device_to_host, sync = .true.)
544 call mpi_allreduce(mpi_in_place, aa%x, &
545 (this%m + 1)**2, mpi_real_precision, mpi_sum, neko_comm, ierr)
547 call device_memcpy(lambda%x, lambda%x_d, this%m, device_to_host, &
549 call device_memcpy(mu%x, mu%x_d, this%m, device_to_host, &
551 call device_memcpy(y%x, y%x_d, this%m, device_to_host, &
553 call device_memcpy(s%x, s%x_d, this%m, device_to_host, &
557 aa%x(i, i) = aa%x(i, i) &
558 + s%x(i) / lambda%x(i) &
559 + 1.0_rp / (this%d%x(i) + mu%x(i) / y%x(i))
! Border row/column with a, and the (m+1,m+1) entry -zeta/z.
561 aa%x(1:this%m, this%m+1) = this%a%x
562 aa%x(this%m+1, 1:this%m) = this%a%x
563 aa%x(this%m+1, this%m+1) = - zeta/z
565 call device_memcpy(aa%x, aa%x_d, &
566 (this%m + 1) * (this%m + 1), host_to_device, sync = .true.)
568 call device_memcpy(bb%x, bb%x_d, this%m+1, device_to_host, &
570 call dgesv(this%m + 1, 1, aa%x, this%m + 1, ipiv, bb%x, this%m + 1, &
573 if (info .ne. 0)
then
574 call neko_error(
"DGESV failed to solve the linear system in " // &
575 "mma_subsolve_dpip (device).")
578 call device_memcpy(bb%x, bb%x_d, this%m+1, host_to_device, &
581 dlambda%x = bb%x(1:this%m)
582 call device_memcpy(dlambda%x, dlambda%x_d, this%m, host_to_device, &
585 dz = bb%x(this%m + 1)
! Recover the remaining search directions from dlambda and dz.
588 call device_dx(dx%x_d, delx%x_d, diagx%x_d, gg%x_d, &
589 dlambda%x_d, this%n, this%m)
590 call device_dy(dy%x_d, dely%x_d, dlambda%x_d, this%d%x_d, &
591 mu%x_d, y%x_d, this%m)
592 call device_dxsi(dxsi%x_d, xsi%x_d, dx%x_d, x%x_d, &
593 this%alpha%x_d, epsi, this%n)
594 call device_deta(deta%x_d, eta%x_d, dx%x_d, x%x_d, &
595 this%beta%x_d, epsi, this%n)
597 call device_col3(dmu%x_d, mu%x_d, dy%x_d, this%m)
598 call device_cmult(dmu%x_d, -1.0_rp, this%m)
599 call device_cadd(dmu%x_d, epsi, this%m)
600 call device_invcol2(dmu%x_d, y%x_d, this%m)
601 call device_sub2(dmu%x_d, mu%x_d, this%m)
602 dzeta = -zeta + (epsi - zeta * dz) / z
603 call device_col3(ds%x_d, dlambda%x_d, s%x_d, this%m)
604 call device_cmult(ds%x_d, -1.0_rp, this%m)
605 call device_cadd(ds%x_d, epsi, this%m)
606 call device_invcol2(ds%x_d, lambda%x_d, this%m)
607 call device_sub2(ds%x_d, s%x_d, this%m)
! Step-length candidates: max of 1 and the -1.01 scaled ratios of each
! direction to its variable (and of dx to the alpha/beta bounds).
609 steg = maxval([1.0_rp, &
610 device_maxval2(dy%x_d, y%x_d, -1.01_rp, this%m), &
612 device_maxval2(dlambda%x_d, lambda%x_d, -1.01_rp, this%m), &
613 device_maxval2(dxsi%x_d, xsi%x_d, -1.01_rp, this%n), &
614 device_maxval2(deta%x_d, eta%x_d, -1.01_rp, this%n), &
615 device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, this%m), &
616 -1.01_rp * dzeta / zeta, &
617 device_maxval2(ds%x_d, s%x_d, -1.01_rp, this%m), &
618 device_maxval3(dx%x_d, x%x_d, this%alpha%x_d, -1.01_rp, this%n),&
619 device_maxval3(dx%x_d, this%beta%x_d, x%x_d, 1.01_rp, this%n)])
! Save the current iterate so each trial point starts from it.
623 call device_copy(xold%x_d, x%x_d, this%n)
624 call device_copy(yold%x_d, y%x_d, this%m)
626 call device_copy(lambdaold%x_d, lambda%x_d, this%m)
627 call device_copy(xsiold%x_d, xsi%x_d, this%n)
628 call device_copy(etaold%x_d, eta%x_d, this%n)
629 call device_copy(muold%x_d, mu%x_d, this%m)
631 call device_copy(sold%x_d, s%x_d, this%m)
! Start above residual_norm so the trial loop below runs at least once
! (when residual_norm > 0). steg is made consistent with MPI_MIN; it is
! presumably inverted (steg = 1/steg) on a line not visible here.
633 new_residual = 2.0_rp * residual_norm
636 call mpi_allreduce(mpi_in_place, steg, 1, &
637 mpi_real_precision, mpi_min, neko_comm, ierr)
638 call mpi_allreduce(mpi_in_place, new_residual, 1, &
639 mpi_real_precision, mpi_min, neko_comm, ierr)
! Trial points old + steg*direction, while the residual norm grows
! (at most 50 trials).
644 do while ((new_residual .gt. residual_norm) .and. (itto .lt. 50))
648 call device_add3s2(x%x_d, xold%x_d, dx%x_d, 1.0_rp, steg, this%n)
649 call device_add3s2(y%x_d, yold%x_d, dy%x_d, 1.0_rp, steg, this%m)
651 call device_add3s2(lambda%x_d, lambdaold%x_d, &
652 dlambda%x_d, 1.0_rp, steg, this%m)
653 call device_add3s2(xsi%x_d, xsiold%x_d, dxsi%x_d, &
654 1.0_rp, steg, this%n)
655 call device_add3s2(eta%x_d, etaold%x_d, deta%x_d, &
656 1.0_rp, steg, this%n)
657 call device_add3s2(mu%x_d, muold%x_d, dmu%x_d, &
658 1.0_rp, steg, this%m)
659 zeta = zetaold + steg*dzeta
660 call device_add3s2(s%x_d, sold%x_d, ds%x_d, 1.0_rp, &
665 call device_rex(rex%x_d, x%x_d, this%low%x_d, &
666 this%upp%x_d, this%pij%x_d, this%p0j%x_d, &
667 this%qij%x_d, this%q0j%x_d, lambda%x_d, xsi%x_d, &
668 eta%x_d, this%n, this%m)
! Re-evaluate the residuals at the trial point (same formulas as above).
670 call device_col3(rey%x_d, this%d%x_d, y%x_d, this%m)
671 call device_add2(rey%x_d, this%c%x_d, this%m)
672 call device_sub2(rey%x_d, lambda%x_d, this%m)
673 call device_sub2(rey%x_d, mu%x_d, this%m)
675 rez = this%a0 - zeta - device_lcsc2(lambda%x_d, this%a%x_d, this%m)
678 call device_cfill(relambda%x_d, 0.0_rp, this%m)
679 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
680 this%low%x_d, this%pij%x_d, this%qij%x_d, &
683 call device_memcpy(relambda%x, relambda%x_d, this%m, &
684 device_to_host, sync = .true.)
685 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
686 mpi_real_precision, mpi_sum, neko_comm, ierr)
687 call device_memcpy(relambda%x, relambda%x_d, &
688 this%m, host_to_device, sync = .true.)
690 call device_add3s2(relambda%x_d, relambda%x_d, &
691 this%a%x_d, 1.0_rp, -z, this%m)
692 call device_sub2(relambda%x_d, y%x_d, this%m)
693 call device_add2(relambda%x_d, s%x_d, this%m)
694 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
696 call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
697 call device_col2(rexsi%x_d, xsi%x_d, this%n)
698 call device_cadd(rexsi%x_d, - epsi, this%n)
700 call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
701 call device_col2(reeta%x_d, eta%x_d, this%n)
702 call device_cadd(reeta%x_d, - epsi, this%n)
704 call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
705 call device_cadd(remu%x_d, - epsi, this%m)
707 rezeta = zeta*z - epsi
709 call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
710 call device_cadd(res%x_d, - epsi, this%m)
713 re_sq_norm = device_norm(rex%x_d, this%n) + &
714 device_norm(rexsi%x_d, this%n) + &
715 device_norm(reeta%x_d, this%n)
716 call mpi_allreduce(mpi_in_place, re_sq_norm, 1, &
717 mpi_real_precision, mpi_sum, neko_comm, ierr)
719 new_residual = sqrt(device_norm(rey%x_d, this%m) + &
721 device_norm(relambda%x_d, this%m) + &
722 device_norm(remu%x_d, this%m) + &
724 device_norm(res%x_d, this%m) + &
733 residual_norm = new_residual
734 residual_max = maxval([ &
735 device_maxval(rex%x_d, this%n), &
736 device_maxval(rey%x_d, this%m), &
738 device_maxval(relambda%x_d, this%m), &
739 device_maxval(rexsi%x_d, this%n), &
740 device_maxval(reeta%x_d, this%n), &
741 device_maxval(remu%x_d, this%m), &
743 device_maxval(res%x_d, this%m)])
745 call mpi_allreduce(mpi_in_place, residual_max, 1, &
746 mpi_real_precision, mpi_max, neko_comm, ierr)
! Shift the design history, write the new design into designx_d and
! store the solved variables back into the MMA object.
754 call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
755 call device_copy(this%xold1%x_d, designx_d, this%n)
756 call device_copy(designx_d, x%x_d, this%n)
759 call device_copy(this%y%x_d, y%x_d, this%m)
761 call device_copy(this%lambda%x_d, lambda%x_d, this%m)
763 call device_copy(this%xsi%x_d, xsi%x_d, this%n)
764 call device_copy(this%eta%x_d, eta%x_d, this%n)
765 call device_copy(this%mu%x_d, mu%x_d, this%m)
766 call device_copy(this%s%x_d, s%x_d, this%m)
! Release work storage (most free calls are not visible in this view).
778 call dellambda%free()
784 call lambdaold%free()
803 end subroutine mma_subsolve_dpip_device
! ------------------------------------------------------------------
! Interior-point solver for the MMA subproblem using the "dip"
! subsolver, device backend.
! It iterates on the m variables lambda and mu. y, z and the design x
! are recomputed from lambda at each step.
!
! designx_d : device pointer to the current design vector (length n).
!             On exit its contents are overwritten with the new design.
!             The previous design is shifted into this%xold1, and
!             this%xold1 into this%xold2.
!
! Loop structure:
!  * the outer loop runs while epsi > minimal_epsilon, with
!    minimal_epsilon = max(0.9*epsimin, 1e-12) reduced with MPI_MIN;
!    epsi is scaled by 0.1 after the inner loop;
!  * the inner Newton loop runs at most this%max_iter times and stops
!    once residumax < epsi.
! The m x m Newton matrix (hess) is assembled from rank-local
! contributions, summed over neko_comm and solved on the host with
! LAPACK dgesv.
!
! Fix: z depends on lambda . a, an m-length dot product whose operands
! are identical on every rank. It was computed with device_glsum,
! which in Neko is a global sum over neko_comm, so the value was
! multiplied by the number of ranks. It now uses device_lcsc2, the
! local dot product the dpip routines in this file use for the same
! quantity.
! NOTE(review): declarations/initialisation of epsi, i, info, y, mu,
! x, dd and the loop terminators are on lines not visible here.
! ------------------------------------------------------------------
807 subroutine mma_subsolve_dip_device(this, designx_d)
808 class(mma_t),
intent(inout) :: this
809 type(c_ptr),
intent(in) :: designx_d
810 integer :: iter, ierr
811 real(kind=rp) :: epsi, residumax, z, steg
813 type(vector_t) :: y, lambda, mu, relambda, remu, dlambda, dmu, &
814 gradlambda, zerom, dd, dummy_m
816 type(vector_t) :: x, pjlambda, qjlambda
819 type(vector_t) :: Ljjxinv
820 type(matrix_t) :: hijx
821 type(matrix_t) :: Hess
822 real(kind=rp) :: hesstrace
825 integer,
dimension(this%m+1) :: ipiv
828 real(kind=rp) :: minimal_epsilon
! Allocate work storage.
831 call lambda%init(this%m)
833 call relambda%init(this%m)
834 call remu%init(this%m)
835 call dlambda%init(this%m)
836 call dmu%init(this%m)
837 call gradlambda%init(this%m)
838 call zerom%init(this%m)
840 call dummy_m%init(this%m)
843 call pjlambda%init(this%n)
844 call qjlambda%init(this%n)
846 call ljjxinv%init(this%n)
847 call hijx%init(this%m,this%n)
848 call hess%init(this%m,this%m)
! zerom is a length-m zero vector used as the lower clamp for y.
850 call device_cfill(zerom%x_d, 0.0_rp, this%m)
! Initial point: y = 1, lambda = max(1, 0.5*c), mu = 1.
857 call device_cfill(y%x_d, 1.0_rp, this%m)
859 call device_cfill(lambda%x_d, 1.0_rp, this%m)
860 call device_cmult2(dummy_m%x_d, this%c%x_d, 0.5_rp, this%m)
861 call device_pwmax2(lambda%x_d, dummy_m%x_d, this%m)
863 call device_cfill(mu%x_d, 1.0_rp, this%m)
! dd = d + 1e-8, which keeps the division by dd below away from zero.
867 call device_cadd2(dd%x_d, this%d%x_d, 1.0e-8_rp, this%m)
! Lower bound on the barrier parameter, agreed on by all ranks.
872 minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
873 call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
874 mpi_real_precision, mpi_min, neko_comm, ierr)
879 outer:
do while (epsi .gt. minimal_epsilon)
883 associate(p0j => this%p0j, q0j => this%q0j, &
884 pij => this%pij, qij => this%qij, &
885 low => this%low, upp => this%upp, &
886 alpha => this%alpha, beta => this%beta, &
887 c => this%c, a0 => this%a0, a => this%a)
! y = max((lambda - c) / dd, 0)
895 call device_sub3(y%x_d, lambda%x_d, c%x_d, this%m)
897 call device_invcol2(y%x_d, dd%x_d, this%m)
898 call device_pwmax2(y%x_d, zerom%x_d, this%m)
! z = 0 if a0 >= lambda . a, else 1. lambda and a are replicated on
! every rank, so use the local dot product (no MPI reduction).
905 z = device_lcsc2(lambda%x_d, a%x_d, this%m)
906 z = merge(0.0_rp, 1.0_rp, a0 - z >= 0.0_rp)
! pjlambda = p0j + pij^T lambda, qjlambda = q0j + qij^T lambda; x follows.
912 call device_mattrans_v_mul(pjlambda%x_d, pij%x_d, lambda%x_d, this%m, this%n)
913 call device_mattrans_v_mul(qjlambda%x_d, qij%x_d, lambda%x_d, this%m, this%n)
914 call device_add2(pjlambda%x_d, p0j%x_d, this%n)
915 call device_add2(qjlambda%x_d, q0j%x_d, this%n)
917 call device_mma_dipsolvesub1(x%x_d, pjlambda%x_d, qjlambda%x_d, &
918 low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
! Rank-local part of relambda, accumulated from length-n data.
920 call device_cfill(relambda%x_d, 0.0_rp, this%m)
921 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
922 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
! Sum the rank-local contributions over neko_comm on the host.
926 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
928 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
929 mpi_real_precision, mpi_sum, neko_comm, ierr)
930 call device_memcpy(relambda%x, relambda%x_d, this%m, &
931 host_to_device, sync = .true.)
! relambda = relambda - z*a - y + mu - bi
933 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
934 call device_sub2(relambda%x_d, y%x_d, this%m)
935 call device_add2(relambda%x_d, mu%x_d, this%m)
936 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
! remu = mu*lambda - epsi
938 call device_col3(remu%x_d, mu%x_d, lambda%x_d, this%m)
939 call device_cadd(remu%x_d, -epsi, this%m)
! Max absolute residual, evaluated on host copies.
943 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
945 call device_memcpy(remu%x, remu%x_d, this%m, device_to_host, &
947 residumax = maxval(abs([relambda%x, remu%x]))
951 do iter = 1, this%max_iter
953 if (residumax .lt. epsi)
exit
! gradlambda = -(relambda - mu + epsi/lambda)
961 call device_copy(gradlambda%x_d, relambda%x_d, this%m)
962 call device_sub2(gradlambda%x_d, mu%x_d, this%m)
965 call device_cfill(dummy_m%x_d, epsi, this%m)
966 call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)
967 call device_add2(gradlambda%x_d, dummy_m%x_d, this%m)
968 call device_cmult(gradlambda%x_d, -1.0_rp, this%m)
! Newton matrix: rank-local contribution from hijx and Ljjxinv,
! summed over neko_comm on the host.
974 call device_mma_ljjxinv(ljjxinv%x_d, pjlambda%x_d, qjlambda%x_d, &
975 x%x_d, low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
977 call device_gg(hijx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
978 this%pij%x_d, this%qij%x_d, this%n, this%m)
980 call device_memcpy(hijx%x, hijx%x_d, this%n*this%m, device_to_host, &
983 call device_cfill(hess%x_d, 0.0_rp, (this%m) * (this%m) )
984 call device_hess(hess%x_d, hijx%x_d, ljjxinv%x_d, this%n, this%m)
987 call device_memcpy(hess%x, hess%x_d, this%m*this%m, device_to_host, &
989 call mpi_allreduce(mpi_in_place, hess%x, &
990 this%m*this%m, mpi_real_precision, mpi_sum, neko_comm, ierr)
! Host-side diagonal corrections that depend on y, d, mu and lambda.
1007 call device_memcpy(lambda%x, lambda%x_d, this%m, device_to_host, &
1009 call device_memcpy(mu%x, mu%x_d, this%m, device_to_host, &
1011 call device_memcpy(y%x, y%x_d, this%m, device_to_host, &
1014 if (y%x(i) .gt. 0.0_rp)
then
1015 if (abs(this%d%x(i)) < 1.0e-15_rp)
then
1018 hess%x(i, i) = hess%x(i, i) - 1.0_rp/this%d%x(i)
1022 hess%x(i, i) = hess%x(i, i) - mu%x(i) / lambda%x(i)
! Diagonal shift of max(-1e-4*trace/m, 1e-7) applied to hess.
1029 hesstrace = hesstrace + hess%x(i, i)
1032 hess%x(i,i) = hess%x(i, i) - &
1033 max(-1.0e-4_rp*hesstrace/this%m, 1.0e-7_rp)
! Solve hess * dlambda = gradlambda on the host (dgesv overwrites
! gradlambda with the solution).
1036 call device_memcpy(gradlambda%x, gradlambda%x_d, this%m, device_to_host, &
1038 call dgesv(this%m , 1, hess%x, this%m , ipiv, &
1039 gradlambda%x, this%m, info)
1041 if (info .ne. 0)
then
1042 call neko_error(
"DGESV failed to solve the linear system in " // &
1043 "mma_subsolve_dip (device).")
1045 call device_memcpy(gradlambda%x, gradlambda%x_d, this%m, host_to_device, &
1048 call device_copy(dlambda%x_d, gradlambda%x_d, this%m)
! dmu = epsi/lambda - dlambda*mu/lambda - mu
1051 call device_copy(dummy_m%x_d, dlambda%x_d, this%m)
1052 call device_col2(dummy_m%x_d, mu%x_d, this%m)
1053 call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)
1055 call device_cfill(dmu%x_d, epsi, this%m)
1056 call device_invcol2(dmu%x_d, lambda%x_d, this%m)
1057 call device_add2s2(dmu%x_d, dummy_m%x_d, -1.0_rp, this%m)
1058 call device_sub2(dmu%x_d, mu%x_d, this%m)
! Step length: 1 / max(1.005, scaled ratios of direction to variable).
1060 steg = maxval([1.005_rp, device_maxval2(dlambda%x_d, lambda%x_d, &
1061 -1.01_rp, this%m), device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, &
1063 steg = 1.0_rp / steg
! Update lambda and mu on the device.
1065 call device_add2s2(lambda%x_d, dlambda%x_d, steg, this%m)
1066 call device_add2s2(mu%x_d, dmu%x_d, steg, this%m)
1068 call device_memcpy(lambda%x, lambda%x_d, this%m, device_to_host, &
1070 call device_memcpy(mu%x, mu%x_d, this%m, device_to_host, &
1080 call device_sub3(y%x_d, lambda%x_d, c%x_d, this%m)
1082 call device_invcol2(y%x_d, dd%x_d, this%m)
1083 call device_pwmax2(y%x_d, zerom%x_d, this%m)
! Same local (non-reduced) lambda . a as above.
1090 z = device_lcsc2(lambda%x_d, a%x_d, this%m)
1091 z = merge(0.0_rp, 1.0_rp, a0 - z >= 0.0_rp)
! Recompute x and the residuals at the updated lambda/mu.
1097 call device_mattrans_v_mul(pjlambda%x_d, pij%x_d, lambda%x_d, this%m, this%n)
1098 call device_mattrans_v_mul(qjlambda%x_d, qij%x_d, lambda%x_d, this%m, this%n)
1099 call device_add2(pjlambda%x_d, p0j%x_d, this%n)
1100 call device_add2(qjlambda%x_d, q0j%x_d, this%n)
1102 call device_mma_dipsolvesub1(x%x_d, pjlambda%x_d, qjlambda%x_d, &
1103 low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
1107 call device_cfill(relambda%x_d, 0.0_rp, this%m)
1108 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
1109 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
1113 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
1115 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
1116 mpi_real_precision, mpi_sum, neko_comm, ierr)
1117 call device_memcpy(relambda%x, relambda%x_d, this%m, &
1118 host_to_device, sync = .true.)
1120 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
1121 call device_sub2(relambda%x_d, y%x_d, this%m)
1122 call device_add2(relambda%x_d, mu%x_d, this%m)
1123 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
1125 call device_col3(remu%x_d, mu%x_d, lambda%x_d, this%m)
1126 call device_cadd(remu%x_d, -epsi, this%m)
1131 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
1133 call device_memcpy(remu%x, remu%x_d, this%m, device_to_host, &
1135 residumax = maxval(abs([relambda%x, remu%x]))
! Tighten the barrier parameter for the next outer pass.
1138 epsi = 0.1_rp * epsi
! Shift the design history, write the new design into designx_d and
! store y, lambda and mu back into the MMA object.
1142 call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
1143 call device_copy(this%xold1%x_d, designx_d, this%n)
1144 call device_copy(designx_d, x%x_d, this%n)
1147 call device_copy(this%y%x_d, y%x_d, this%m)
1149 call device_copy(this%lambda%x_d, lambda%x_d, this%m)
1150 call device_copy(this%mu%x_d, mu%x_d, this%m)
! Release work storage (most free calls are not visible in this view).
1155 call relambda%free()
1159 call gradlambda%free()
1165 call pjlambda%free()
1166 call qjlambda%free()
1171 end subroutine mma_subsolve_dip_device
1173end submodule mma_device