Neko-TOP
A portable framework for high-order spectral element flow topology optimization.
Loading...
Searching...
No Matches
mma_device.f90
1! Copyright (c) 2025, The Neko-TOP Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32
33submodule(mma) mma_device
34
35 use device_math, only: device_copy, device_cmult, device_cadd, device_cfill, &
36 device_add2, device_add3s2, device_invcol2, device_col2, device_col3, &
37 device_sub2, device_sub3, device_add2s2, device_cadd2, device_pwmax, &
38 device_glsum, device_cmult2, device_pwmax
39 use device_mma_math, only: device_maxval, device_norm, device_lcsc2, &
40 device_maxval2, device_maxval3, device_mma_gensub3, &
41 device_mma_gensub4, device_mma_max, device_max2, device_rex, &
42 device_relambda, device_delx, device_add2inv2, device_gg, device_diagx, &
43 device_bb, device_updatebb, device_aa, device_updateaa, device_dx, &
44 device_dy, device_dxsi, device_deta, device_kkt_rex, &
45 device_mma_gensub2, device_mattrans_v_mul, device_mma_dipsolvesub1, &
46 device_mma_ljjxinv, device_hess
47
48 use neko_config, only: neko_bcknd_device
49 use device, only: device_to_host
50 use comm, only: neko_comm, pe_rank, mpi_real_precision
51 use mpi_f08, only: mpi_in_place, mpi_max, mpi_min
52
53 implicit none
54
55contains
56
module subroutine mma_update_device(this, iter, x, df0dx, fval, dfdx)
  ! ----------------------------------------------------------------- !
  ! Perform one MMA design update on the device.                      !
  !                                                                   !
  ! A convex separable approximation of the problem is generated      !
  ! around the current design x, and then solved with the interior    !
  ! point subsolver selected by this%subsolver ("dip" or "dpip").     !
  ! On return, the array pointed to by x holds the updated design.    !
  ! Called once per iteration of the optimization loop.               !
  ! ----------------------------------------------------------------- !
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(inout) :: x
  type(c_ptr), intent(in) :: df0dx, fval, dfdx

  ! Refuse to run on an object that has not been set up.
  if (.not. this%is_initialized) then
     call neko_error("The MMA object is not initialized.")
  end if

  ! Build the convex subproblem around the current design.
  call mma_gensub_device(this, iter, x, df0dx, fval, dfdx)

  ! Dispatch to the requested interior point subsolver.
  select case (this%subsolver)
  case ("dip")
     call mma_subsolve_dip_device(this, x)
  case ("dpip")
     call mma_subsolve_dpip_device(this, x)
  case default
     call neko_error("Unrecognized subsolver for MMA in mma_device.")
  end select

  this%is_updated = .true.
end subroutine mma_update_device
88
module subroutine mma_kkt_device(this, x, df0dx, fval, dfdx)
  ! Evaluate the KKT residual of the current iterate on the device.
  ! The "dip" subsolver uses its own residual definition; any other value
  ! of this%subsolver falls through to the dpip residual.
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  select case (this%subsolver)
  case ("dip")
     call mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
  case default
     call mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
  end select
end subroutine mma_kkt_device
99
!> Compute the KKT residual (max entry and norm) for the dual interior
!! point method (dip) subsolve of the MMA algorithm.
!! Only length-m quantities enter this residual and no MPI reduction is
!! performed. The arguments x, df0dx and dfdx are not referenced here.
module subroutine mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  type(vector_t) :: r_lam, r_mu
  real(kind=rp) :: max_lam, max_mu, sq_lam, sq_mu

  call r_lam%init(this%m)
  call r_mu%init(this%m)

  ! Constraint residual: r_lam = fval - a*z - y + mu
  call device_add3s2(r_lam%x_d, fval, this%a%x_d, 1.0_rp, -this%z, this%m)
  call device_sub2(r_lam%x_d, this%y%x_d, this%m)
  call device_add2(r_lam%x_d, this%mu%x_d, this%m)

  ! Residual for mu (eta in the paper): r_mu = lambda * mu
  call device_col3(r_mu%x_d, this%lambda%x_d, this%mu%x_d, this%m)

  max_lam = device_maxval(r_lam%x_d, this%m)
  max_mu = device_maxval(r_mu%x_d, this%m)

  ! device_norm results are added before the square root, i.e. they are
  ! used as squared 2-norms.
  sq_lam = device_norm(r_lam%x_d, this%m)
  sq_mu = device_norm(r_mu%x_d, this%m)

  this%residumax = maxval([max_lam, max_mu])
  this%residunorm = sqrt(sq_lam + sq_mu)

  call r_lam%free()
  call r_mu%free()
end subroutine mma_dip_kkt_device
129
!> Compute the KKT residual (max entry and norm) for the dual-primal
!! interior point method (dpip) subsolve of the MMA algorithm.
!! The largest entry is reduced over all ranks with MPI_MAX. For the norm,
!! only the squared norms of the length-n residuals are summed over ranks;
!! the length-m and scalar residuals are added once, locally.
module subroutine mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  ! residuals of length n
  type(vector_t) :: r_x, r_xsi, r_eta
  ! residuals of length m
  type(vector_t) :: r_y, r_lam, r_mu, r_s
  ! scalar residuals for z and zeta
  real(kind=rp) :: r_z, r_zeta
  ! sum of squared norms of the length-n residuals
  real(kind=rp) :: n_sq_norm
  integer :: ierr

  call r_x%init(this%n)
  call r_xsi%init(this%n)
  call r_eta%init(this%n)
  call r_y%init(this%m)
  call r_lam%init(this%m)
  call r_mu%init(this%m)
  call r_s%init(this%m)

  ! -- Stationarity with respect to x
  call device_kkt_rex(r_x%x_d, df0dx, dfdx, this%xsi%x_d, &
       this%eta%x_d, this%lambda%x_d, this%n, this%m)

  ! -- Complementarity of the bounds: xsi*(x - xmin) and eta*(xmax - x)
  call device_sub3(r_xsi%x_d, x, this%xmin%x_d, this%n)
  call device_col2(r_xsi%x_d, this%xsi%x_d, this%n)
  call device_sub3(r_eta%x_d, this%xmax%x_d, x, this%n)
  call device_col2(r_eta%x_d, this%eta%x_d, this%n)

  ! -- Stationarity with respect to y: d*y + c - lambda - mu
  call device_col3(r_y%x_d, this%d%x_d, this%y%x_d, this%m)
  call device_add2(r_y%x_d, this%c%x_d, this%m)
  call device_sub2(r_y%x_d, this%lambda%x_d, this%m)
  call device_sub2(r_y%x_d, this%mu%x_d, this%m)

  ! -- Constraint residual: fval - a*z - y + s
  call device_add3s2(r_lam%x_d, fval, this%a%x_d, 1.0_rp, -this%z, this%m)
  call device_sub2(r_lam%x_d, this%y%x_d, this%m)
  call device_add2(r_lam%x_d, this%s%x_d, this%m)

  ! -- Complementarity: mu*y and lambda*s
  call device_col3(r_mu%x_d, this%mu%x_d, this%y%x_d, this%m)
  call device_col3(r_s%x_d, this%lambda%x_d, this%s%x_d, this%m)

  ! -- Scalar residuals for z and zeta
  r_z = this%a0 - this%zeta - device_lcsc2(this%lambda%x_d, this%a%x_d, &
       this%m)
  r_zeta = this%zeta * this%z

  ! Largest residual entry (same element order as the KKT system)
  this%residumax = maxval([ &
       device_maxval(r_x%x_d, this%n), &
       device_maxval(r_y%x_d, this%m), &
       abs(r_z), &
       device_maxval(r_lam%x_d, this%m), &
       device_maxval(r_xsi%x_d, this%n), &
       device_maxval(r_eta%x_d, this%n), &
       device_maxval(r_mu%x_d, this%m), &
       abs(r_zeta), &
       device_maxval(r_s%x_d, this%m)])

  n_sq_norm = device_norm(r_x%x_d, this%n) + &
       device_norm(r_xsi%x_d, this%n) + &
       device_norm(r_eta%x_d, this%n)

  call mpi_allreduce(mpi_in_place, this%residumax, 1, &
       mpi_real_precision, mpi_max, neko_comm, ierr)
  call mpi_allreduce(mpi_in_place, n_sq_norm, 1, &
       mpi_real_precision, mpi_sum, neko_comm, ierr)

  this%residunorm = sqrt(( &
       device_norm(r_y%x_d, this%m) + &
       r_z**2 + &
       device_norm(r_lam%x_d, this%m) + &
       device_norm(r_mu%x_d, this%m) + &
       r_zeta**2 + &
       device_norm(r_s%x_d, this%m) &
       ) + n_sq_norm)

  call r_y%free()
  call r_lam%free()
  call r_mu%free()
  call r_s%free()
  call r_x%free()
  call r_xsi%free()
  call r_eta%free()
end subroutine mma_dpip_kkt_device
217
218 !============================================================================!
219 ! private internal subroutines
220
!> Generate the MMA approximation subproblem on the device.
!! Computes the lower and upper asymptotes (low, upp), the move limits
!! (alpha, beta), the coefficients p0j, q0j, pij, qij and the constraint
!! offsets bi around the current design x.
!! @param iter  Outer MMA iteration counter. For iter < 3 the asymptotes are
!!              placed at x -/+ asyinit*(xmax - xmin); afterwards they are
!!              updated from the two previous designs.
!! @param x     Device pointer to the current design (length n).
!! @param df0dx Device pointer to the objective gradient.
!! @param fval  Device pointer to the constraint values (length m).
!! @param dfdx  Device pointer to the constraint gradients.
subroutine mma_gensub_device(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x
  type(c_ptr), intent(in) :: df0dx
  type(c_ptr), intent(in) :: fval
  type(c_ptr), intent(in) :: dfdx

  integer, intent(in) :: iter
  integer :: ierr

  type(vector_t) :: x_diff

  ! ------------------------------------------------------------------------ !
  ! Setup the current asymptotes

  ! Width of the design box, x_diff = xmax - xmin. It is consumed only by
  ! device kernels, so no (synchronizing) host copy is made.
  call x_diff%init(this%n)
  call device_sub3(x_diff%x_d, this%xmax%x_d, this%xmin%x_d, this%n)

  if (iter .lt. 3) then
     ! low = x - asyinit*x_diff, upp = x + asyinit*x_diff
     call device_copy(this%low%x_d, x, this%n)
     call device_add2s2(this%low%x_d, x_diff%x_d, - this%asyinit, this%n)
     call device_copy(this%upp%x_d, x, this%n)
     call device_add2s2(this%upp%x_d, x_diff%x_d, this%asyinit, this%n)
  else
     ! Update the asymptotes from the two previous designs (xold1, xold2)
     ! using the asydecr/asyincr factors.
     call device_mma_gensub2(this%low%x_d, this%upp%x_d, x, &
          this%xold1%x_d, this%xold2%x_d, x_diff%x_d, &
          this%asydecr, this%asyincr, this%n)
  end if

  ! x_diff is not used past this point. Like every other local vector in this
  ! file it must be freed explicitly so its device allocation does not
  ! outlive the call.
  call x_diff%free()

  ! ------------------------------------------------------------------------ !
  ! Calculate p0j, q0j, pij, qij, alpha, and beta

  call device_mma_gensub3(x, df0dx, dfdx, this%low%x_d, &
       this%upp%x_d, this%xmin%x_d, this%xmax%x_d, this%alpha%x_d, &
       this%beta%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
       this%qij%x_d, this%n, this%m)

  ! ------------------------------------------------------------------------ !
  ! Computing bi as defined in page 5

  ! The kernel accumulates the contribution of the local design variables;
  ! the partial sums are then added over all ranks on the host before fval
  ! is subtracted on the device.
  call device_mma_gensub4(x, this%low%x_d, this%upp%x_d, this%pij%x_d, &
       this%qij%x_d, this%n, this%m, this%bi%x_d)

  call device_memcpy(this%bi%x, this%bi%x_d, this%m, device_to_host, &
       sync = .true.)
  call mpi_allreduce(mpi_in_place, this%bi%x, this%m, &
       mpi_real_precision, mpi_sum, neko_comm, ierr)
  call device_memcpy(this%bi%x, this%bi%x_d, this%m, host_to_device, &
       sync = .true.)
  call device_sub2(this%bi%x_d, fval, this%m)

end subroutine mma_gensub_device
281
!> Solve the MMA subproblem with the dual-primal interior point method
!! (dpip) on the device.
!!
!! Follows the primal-dual Newton scheme of "MMA and GCMMA"
!! (https://people.kth.se/~krille/mmagcmma.pdf). The barrier parameter epsi
!! starts at 1 and is divided by 10 until it reaches minimal_epsilon. For
!! each epsi, Newton steps on the relaxed KKT system are taken, each with a
!! backtracking line search. This continues until the largest residual entry
!! drops below epsi or this%max_iter steps have been taken.
!!
!! @param designx_d Device pointer to the current design. On return the
!!        array it points to holds the new design. The previous design is
!!        shifted into this%xold1, and this%xold1 into this%xold2.
subroutine mma_subsolve_dpip_device(this, designx_d)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: designx_d
  integer :: iter, itto, ierr
  real(kind=rp) :: epsi, residual_max, residual_norm, z, zeta, rez, rezeta, &
       delz, dz, dzeta, steg, zold, zetaold, new_residual
  ! vectors with size m
  type(vector_t) :: y, lambda, s, mu, rey, relambda, remu, res, &
       dely, dellambda, dy, dlambda, ds, dmu, yold, lambdaold, sold, muold

  ! vectors with size n
  type(vector_t) :: x, xsi, eta, rex, rexsi, reeta, &
       delx, diagx, dx, dxsi, deta, xold, xsiold, etaold

  ! Reduced Newton system AA * [dlambda; dz] = bb of size m+1, and the
  ! m x n matrix GG used to assemble it (eq. 5.20)
  type(vector_t) :: bb
  type(matrix_t) :: GG
  type(matrix_t) :: AA

  integer :: info
  integer, dimension(this%m+1) :: ipiv
  real(kind=rp) :: re_sq_norm

  integer :: i

  real(kind=rp) :: minimal_epsilon

  call y%init(this%m)
  call lambda%init(this%m)
  call s%init(this%m)
  call mu%init(this%m)
  call rey%init(this%m)
  call relambda%init(this%m)
  call remu%init(this%m)
  call res%init(this%m)
  call dely%init(this%m)
  call dellambda%init(this%m)
  call dy%init(this%m)
  call dlambda%init(this%m)
  call ds%init(this%m)
  call dmu%init(this%m)
  call yold%init(this%m)
  call lambdaold%init(this%m)
  call sold%init(this%m)
  call muold%init(this%m)
  call x%init(this%n)
  call xsi%init(this%n)
  call eta%init(this%n)
  call rex%init(this%n)
  call rexsi%init(this%n)
  call reeta%init(this%n)
  call delx%init(this%n)
  call diagx%init(this%n)
  call dx%init(this%n)
  call dxsi%init(this%n)
  call deta%init(this%n)
  call xold%init(this%n)
  call xsiold%init(this%n)
  call etaold%init(this%n)
  call bb%init(this%m+1)

  call gg%init(this%m, this%n)
  call aa%init(this%m+1, this%m+1)

  ! ------------------------------------------------------------------------ !
  ! initial value for the parameters in the subsolve based on
  ! page 15 of "https://people.kth.se/~krille/mmagcmma.pdf"

  epsi = 1.0_rp !100
  call device_add3s2(x%x_d, this%alpha%x_d, this%beta%x_d, 0.5_rp, 0.5_rp, &
       this%n)
  call device_cfill(y%x_d, 1.0_rp, this%m)
  z = 1.0_rp
  zeta = 1.0_rp
  call device_cfill(lambda%x_d, 1.0_rp, this%m)
  call device_cfill(s%x_d, 1.0_rp, this%m)
  call device_mma_max(xsi%x_d, x%x_d, this%alpha%x_d, this%n)
  call device_mma_max(eta%x_d, this%beta%x_d, x%x_d, this%n)
  call device_max2(mu%x_d, 1.0_rp, this%c%x_d, 0.5_rp, this%m)

  ! ------------------------------------------------------------------------ !
  ! Computing the minimal epsilon and choose the most conservative one

  minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
  call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
       mpi_real_precision, mpi_min, neko_comm, ierr)

  ! ------------------------------------------------------------------------ !
  ! The main loop of the dual-primal interior point method.

  do while (epsi .gt. minimal_epsilon)

     ! --------------------------------------------------------------------- !
     ! Calculating residuals based on
     ! "https://people.kth.se/~krille/mmagcmma.pdf" for the variables
     ! x, y, z, lambda residuals based on eq(5.9a)-(5.9d), respectively.

     associate(p0j => this%p0j, q0j => this%q0j, &
          pij => this%pij, qij => this%qij, &
          low => this%low, upp => this%upp, &
          alpha => this%alpha, beta => this%beta, &
          c => this%c, d => this%d, &
          a0 => this%a0, a => this%a)

       call device_rex(rex%x_d, x%x_d, low%x_d, upp%x_d, &
            pij%x_d, p0j%x_d, qij%x_d, q0j%x_d, &
            lambda%x_d, xsi%x_d, eta%x_d, this%n, this%m)

       call device_col3(rey%x_d, d%x_d, y%x_d, this%m)
       call device_add2(rey%x_d, c%x_d, this%m)
       call device_sub2(rey%x_d, lambda%x_d, this%m)
       call device_sub2(rey%x_d, mu%x_d, this%m)
       rez = a0 - zeta - device_lcsc2(lambda%x_d, a%x_d, this%m)

       call device_cfill(relambda%x_d, 0.0_rp, this%m)
       call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
            low%x_d, pij%x_d, qij%x_d, this%n, this%m)

     end associate

     ! --------------------------------------------------------------------- !
     ! Computing the norm of the residuals

     ! Complete the computations of lambda residuals: the kernel only sums
     ! over the local design variables, so the partial sums are added over
     ! all ranks before the length-m terms are applied.
     call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
          sync = .true.)
     call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
          mpi_real_precision, mpi_sum, neko_comm, ierr)
     call device_memcpy(relambda%x, relambda%x_d, this%m, host_to_device, &
          sync = .true.)

     call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
     call device_sub2(relambda%x_d, y%x_d, this%m)
     call device_add2(relambda%x_d, s%x_d, this%m)
     call device_sub2(relambda%x_d, this%bi%x_d, this%m)

     call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
     call device_col2(rexsi%x_d, xsi%x_d, this%n)
     call device_cadd(rexsi%x_d, - epsi, this%n)

     call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
     call device_col2(reeta%x_d, eta%x_d, this%n)
     call device_cadd(reeta%x_d, - epsi, this%n)

     call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
     call device_cadd(remu%x_d, - epsi, this%m)

     rezeta = zeta * z - epsi

     call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
     call device_cadd(res%x_d, - epsi, this%m)

     ! Setup vectors of residuals and their norms
     residual_max = maxval([device_maxval(rex%x_d, this%n), &
          device_maxval(rey%x_d, this%m), abs(rez), &
          device_maxval(relambda%x_d, this%m), &
          device_maxval(rexsi%x_d, this%n), &
          device_maxval(reeta%x_d, this%n), &
          device_maxval(remu%x_d, this%m), abs(rezeta), &
          device_maxval(res%x_d, this%m)])

     ! Only the length-n residuals are summed over ranks; the length-m and
     ! scalar terms are added once below.
     re_sq_norm = device_norm(rex%x_d, this%n) + &
          device_norm(rexsi%x_d, this%n) + device_norm(reeta%x_d, this%n)

     call mpi_allreduce(mpi_in_place, residual_max, 1, &
          mpi_real_precision, mpi_max, neko_comm, ierr)

     call mpi_allreduce(mpi_in_place, re_sq_norm, &
          1, mpi_real_precision, mpi_sum, neko_comm, ierr)

     residual_norm = sqrt(device_norm(rey%x_d, this%m) + &
          rez**2 + &
          device_norm(relambda%x_d, this%m) + &
          device_norm(remu%x_d, this%m)+ &
          rezeta**2 + &
          device_norm(res%x_d, this%m) &
          + re_sq_norm)

     ! --------------------------------------------------------------------- !
     ! Internal loop

     do iter = 1, this%max_iter

        if (residual_max .lt. epsi) exit

        call device_delx(delx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
             this%pij%x_d, this%qij%x_d, this%p0j%x_d, this%q0j%x_d, &
             this%alpha%x_d, this%beta%x_d, lambda%x_d, epsi, this%n, &
             this%m)

        call device_col3(dely%x_d, this%d%x_d, y%x_d, this%m)
        call device_add2(dely%x_d, this%c%x_d, this%m)
        call device_sub2(dely%x_d, lambda%x_d, this%m)
        call device_add2inv2(dely%x_d, y%x_d, - epsi, this%m)
        delz = this%a0 - device_lcsc2(lambda%x_d, this%a%x_d, this%m) - epsi/z

        ! Accumulate sums for dellambda (the term gi(x))
        call device_cfill(dellambda%x_d, 0.0_rp, this%m)
        call device_relambda(dellambda%x_d, x%x_d, this%upp%x_d, &
             this%low%x_d, this%pij%x_d, this%qij%x_d, this%n, this%m)

        call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
             device_to_host, sync = .true.)
        call mpi_allreduce(mpi_in_place, dellambda%x, this%m, &
             mpi_real_precision, mpi_sum, neko_comm, ierr)
        call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
             host_to_device, sync = .true.)

        call device_add3s2(dellambda%x_d, dellambda%x_d, this%a%x_d, &
             1.0_rp, -z, this%m)
        call device_sub2(dellambda%x_d, y%x_d, this%m)
        call device_sub2(dellambda%x_d, this%bi%x_d, this%m)
        call device_add2inv2(dellambda%x_d, lambda%x_d, epsi, this%m)

        call device_gg(gg%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
             this%pij%x_d, this%qij%x_d, this%n, this%m)

        call device_diagx(diagx%x_d, x%x_d, xsi%x_d, this%low%x_d, &
             this%upp%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
             this%qij%x_d, this%alpha%x_d, this%beta%x_d, eta%x_d, &
             lambda%x_d, this%n, this%m)

        ! Here we only consider the case m<n in the matlab code
        ! assembling the right hand side matrix based on eq(5.20)
        ! bb = [dellambda + dely/(this%d%x + &
        !      (mu/y)) - matmul(GG,delx/diagx), delz ]

        !--------------------------------------------------------------------!
        ! for MPI computation of bb

        call device_bb(bb%x_d, gg%x_d, delx%x_d, diagx%x_d, this%n, &
             this%m)

        call device_memcpy(bb%x, bb%x_d, this%m + 1, device_to_host, &
             sync = .true.)
        call mpi_allreduce(mpi_in_place, bb%x, this%m + 1, &
             mpi_real_precision, mpi_sum, neko_comm, ierr)
        call device_memcpy(bb%x, bb%x_d, this%m + 1, &
             host_to_device, sync = .true.)

        call device_updatebb(bb%x_d, dellambda%x_d, dely%x_d, &
             this%d%x_d, mu%x_d, y%x_d, delz, this%m)

        !--------------------------------------------------------------------!
        ! assembling the coefficients matrix AA based on eq(5.20)
        ! AA(1:this%m,1:this%m) = &
        !    matmul(matmul(GG,mma_diag(1/diagx)), transpose(GG))
        ! !update diag(AA)
        ! AA(1:this%m,1:this%m) = AA(1:this%m,1:this%m) + &
        !    mma_diag(s/lambda + 1.0/(this%d%x + (mu/y)))

        call device_cfill(aa%x_d, 0.0_rp, (this%m+1) * (this%m+1))
        call device_aa(aa%x_d, gg%x_d, diagx%x_d, this%n, this%m)
        call device_memcpy(aa%x, aa%x_d, (this%m+1) * (this%m+1), &
             device_to_host, sync = .true.)

        call mpi_allreduce(mpi_in_place, aa%x, &
             (this%m + 1)**2, mpi_real_precision, mpi_sum, neko_comm, ierr)

        ! The diagonal update and the bordering row/column are assembled on
        ! the host, so bring the length-m iterates over first.
        call device_memcpy(lambda%x, lambda%x_d, this%m, device_to_host, &
             sync = .false.)
        call device_memcpy(mu%x, mu%x_d, this%m, device_to_host, &
             sync = .false.)
        call device_memcpy(y%x, y%x_d, this%m, device_to_host, &
             sync = .false.)
        call device_memcpy(s%x, s%x_d, this%m, device_to_host, &
             sync = .true.)
        do i = 1, this%m
           ! update the diag AA
           aa%x(i, i) = aa%x(i, i) &
                + s%x(i) / lambda%x(i) &
                + 1.0_rp / (this%d%x(i) + mu%x(i) / y%x(i))
        end do
        aa%x(1:this%m, this%m+1) = this%a%x
        aa%x(this%m+1, 1:this%m) = this%a%x
        aa%x(this%m+1, this%m+1) = - zeta/z

        call device_memcpy(aa%x, aa%x_d, &
             (this%m + 1) * (this%m + 1), host_to_device, sync = .true.)

        ! Solve the (m+1) x (m+1) reduced system on the host with LAPACK;
        ! dgesv overwrites bb%x with the solution [dlambda; dz].
        call device_memcpy(bb%x, bb%x_d, this%m+1, device_to_host, &
             sync = .true.)
        call dgesv(this%m + 1, 1, aa%x, this%m + 1, ipiv, bb%x, this%m + 1, &
             info)

        if (info .ne. 0) then
           call neko_error("DGESV failed to solve the linear system in " // &
                "mma_subsolve_dpip (device).")
        end if

        call device_memcpy(bb%x, bb%x_d, this%m+1, host_to_device, &
             sync = .true.)

        dlambda%x = bb%x(1:this%m)
        call device_memcpy(dlambda%x, dlambda%x_d, this%m, host_to_device, &
             sync = .true.)

        dz = bb%x(this%m + 1)

        ! based on eq(5.19)
        call device_dx(dx%x_d, delx%x_d, diagx%x_d, gg%x_d, &
             dlambda%x_d, this%n, this%m)
        call device_dy(dy%x_d, dely%x_d, dlambda%x_d, this%d%x_d, &
             mu%x_d, y%x_d, this%m)
        call device_dxsi(dxsi%x_d, xsi%x_d, dx%x_d, x%x_d, &
             this%alpha%x_d, epsi, this%n)
        call device_deta(deta%x_d, eta%x_d, dx%x_d, x%x_d, &
             this%beta%x_d, epsi, this%n)

        ! dmu = (epsi - mu*dy)/y - mu
        call device_col3(dmu%x_d, mu%x_d, dy%x_d, this%m)
        call device_cmult(dmu%x_d, -1.0_rp, this%m)
        call device_cadd(dmu%x_d, epsi, this%m)
        call device_invcol2(dmu%x_d, y%x_d, this%m)
        call device_sub2(dmu%x_d, mu%x_d, this%m)
        dzeta = -zeta + (epsi - zeta * dz) / z
        ! ds = (epsi - s*dlambda)/lambda - s
        call device_col3(ds%x_d, dlambda%x_d, s%x_d, this%m)
        call device_cmult(ds%x_d, -1.0_rp, this%m)
        call device_cadd(ds%x_d, epsi, this%m)
        call device_invcol2(ds%x_d, lambda%x_d, this%m)
        call device_sub2(ds%x_d, s%x_d, this%m)

        ! Largest admissible step keeping all positive variables positive
        ! (with a 1.01 safety factor) and x strictly inside [alpha, beta].
        steg = maxval([1.0_rp, &
             device_maxval2(dy%x_d, y%x_d, -1.01_rp, this%m), &
             -1.01_rp * dz / z, &
             device_maxval2(dlambda%x_d, lambda%x_d, -1.01_rp, this%m), &
             device_maxval2(dxsi%x_d, xsi%x_d, -1.01_rp, this%n), &
             device_maxval2(deta%x_d, eta%x_d, -1.01_rp, this%n), &
             device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, this%m), &
             -1.01_rp * dzeta / zeta, &
             device_maxval2(ds%x_d, s%x_d, -1.01_rp, this%m), &
             device_maxval3(dx%x_d, x%x_d, this%alpha%x_d, -1.01_rp, this%n),&
             device_maxval3(dx%x_d, this%beta%x_d, x%x_d, 1.01_rp, this%n)])

        steg = 1.0_rp / steg

        call device_copy(xold%x_d, x%x_d, this%n)
        call device_copy(yold%x_d, y%x_d, this%m)
        zold = z
        call device_copy(lambdaold%x_d, lambda%x_d, this%m)
        call device_copy(xsiold%x_d, xsi%x_d, this%n)
        call device_copy(etaold%x_d, eta%x_d, this%n)
        call device_copy(muold%x_d, mu%x_d, this%m)
        zetaold = zeta
        call device_copy(sold%x_d, s%x_d, this%m)

        ! Start above residual_norm so the line search runs at least once
        new_residual = 2.0_rp * residual_norm

        ! Share the new_residual and steg values
        call mpi_allreduce(mpi_in_place, steg, 1, &
             mpi_real_precision, mpi_min, neko_comm, ierr)
        call mpi_allreduce(mpi_in_place, new_residual, 1, &
             mpi_real_precision, mpi_min, neko_comm, ierr)

        ! The innermost loop to determine the suitable step length
        ! using the Backtracking Line Search approach
        itto = 0
        do while ((new_residual .gt. residual_norm) .and. (itto .lt. 50))
           itto = itto + 1

           ! update the variables
           call device_add3s2(x%x_d, xold%x_d, dx%x_d, 1.0_rp, steg, this%n)
           call device_add3s2(y%x_d, yold%x_d, dy%x_d, 1.0_rp, steg, this%m)
           z = zold + steg*dz
           call device_add3s2(lambda%x_d, lambdaold%x_d, &
                dlambda%x_d, 1.0_rp, steg, this%m)
           call device_add3s2(xsi%x_d, xsiold%x_d, dxsi%x_d, &
                1.0_rp, steg, this%n)
           call device_add3s2(eta%x_d, etaold%x_d, deta%x_d, &
                1.0_rp, steg, this%n)
           call device_add3s2(mu%x_d, muold%x_d, dmu%x_d, &
                1.0_rp, steg, this%m)
           zeta = zetaold + steg*dzeta
           call device_add3s2(s%x_d, sold%x_d, ds%x_d, 1.0_rp, &
                steg, this%m)

           ! Recompute the new_residual to see if this stepsize improves
           ! the residue
           call device_rex(rex%x_d, x%x_d, this%low%x_d, &
                this%upp%x_d, this%pij%x_d, this%p0j%x_d, &
                this%qij%x_d, this%q0j%x_d, lambda%x_d, xsi%x_d, &
                eta%x_d, this%n, this%m)

           call device_col3(rey%x_d, this%d%x_d, y%x_d, this%m)
           call device_add2(rey%x_d, this%c%x_d, this%m)
           call device_sub2(rey%x_d, lambda%x_d, this%m)
           call device_sub2(rey%x_d, mu%x_d, this%m)

           rez = this%a0 - zeta - device_lcsc2(lambda%x_d, this%a%x_d, this%m)

           ! Accumulate sums for relambda (the term gi(x))
           call device_cfill(relambda%x_d, 0.0_rp, this%m)
           call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
                this%low%x_d, this%pij%x_d, this%qij%x_d, &
                this%n, this%m)

           call device_memcpy(relambda%x, relambda%x_d, this%m, &
                device_to_host, sync = .true.)
           call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
                mpi_real_precision, mpi_sum, neko_comm, ierr)
           call device_memcpy(relambda%x, relambda%x_d, &
                this%m, host_to_device, sync = .true.)

           call device_add3s2(relambda%x_d, relambda%x_d, &
                this%a%x_d, 1.0_rp, -z, this%m)
           call device_sub2(relambda%x_d, y%x_d, this%m)
           call device_add2(relambda%x_d, s%x_d, this%m)
           call device_sub2(relambda%x_d, this%bi%x_d, this%m)

           call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
           call device_col2(rexsi%x_d, xsi%x_d, this%n)
           call device_cadd(rexsi%x_d, - epsi, this%n)

           call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
           call device_col2(reeta%x_d, eta%x_d, this%n)
           call device_cadd(reeta%x_d, - epsi, this%n)

           call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
           call device_cadd(remu%x_d, - epsi, this%m)

           rezeta = zeta*z - epsi

           call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
           call device_cadd(res%x_d, - epsi, this%m)

           ! Compute squared norms for the residuals
           re_sq_norm = device_norm(rex%x_d, this%n) + &
                device_norm(rexsi%x_d, this%n) + &
                device_norm(reeta%x_d, this%n)
           call mpi_allreduce(mpi_in_place, re_sq_norm, 1, &
                mpi_real_precision, mpi_sum, neko_comm, ierr)

           new_residual = sqrt(device_norm(rey%x_d, this%m) + &
                rez**2 + &
                device_norm(relambda%x_d, this%m) + &
                device_norm(remu%x_d, this%m) + &
                rezeta**2 + &
                device_norm(res%x_d, this%m) + &
                re_sq_norm)

           steg = steg / 2.0_rp

        end do
        steg = 2.0_rp * steg ! Correction for the final division by 2

        ! Update the maximum and norm of the residuals
        residual_norm = new_residual
        residual_max = maxval([ &
             device_maxval(rex%x_d, this%n), &
             device_maxval(rey%x_d, this%m), &
             abs(rez), &
             device_maxval(relambda%x_d, this%m), &
             device_maxval(rexsi%x_d, this%n), &
             device_maxval(reeta%x_d, this%n), &
             device_maxval(remu%x_d, this%m), &
             abs(rezeta), &
             device_maxval(res%x_d, this%m)])

        call mpi_allreduce(mpi_in_place, residual_max, 1, &
             mpi_real_precision, mpi_max, neko_comm, ierr)

     end do

     epsi = 0.1_rp * epsi
  end do

  ! Save the new designx
  call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
  call device_copy(this%xold1%x_d, designx_d, this%n)
  call device_copy(designx_d, x%x_d, this%n)

  ! update the parameters of the MMA object necessary to compute KKT residual
  call device_copy(this%y%x_d, y%x_d, this%m)
  this%z = z
  call device_copy(this%lambda%x_d, lambda%x_d, this%m)
  this%zeta = zeta
  call device_copy(this%xsi%x_d, xsi%x_d, this%n)
  call device_copy(this%eta%x_d, eta%x_d, this%n)
  call device_copy(this%mu%x_d, mu%x_d, this%m)
  call device_copy(this%s%x_d, s%x_d, this%m)

  ! free all the initiated variables in this subroutine
  call y%free()
  call lambda%free()
  call s%free()
  call mu%free()
  call rey%free()
  call relambda%free()
  call remu%free()
  call res%free()
  call dely%free()
  call dellambda%free()
  call dy%free()
  call dlambda%free()
  call ds%free()
  call dmu%free()
  call yold%free()
  call lambdaold%free()
  call sold%free()
  call muold%free()
  call x%free()
  call xsi%free()
  call eta%free()
  call rex%free()
  call rexsi%free()
  call reeta%free()
  call delx%free()
  call diagx%free()
  call dx%free()
  call dxsi%free()
  call deta%free()
  call xold%free()
  call xsiold%free()
  call etaold%free()
  call bb%free()
  ! The matrices are initialized above as well and must be released too,
  ! otherwise their host and device storage leaks on every subsolve.
  call gg%free()
  call aa%free()

end subroutine mma_subsolve_dpip_device
799
802 subroutine mma_subsolve_dip_device(this, designx_d)
803 class(mma_t), intent(inout) :: this
804 type(c_ptr), intent(in) :: designx_d
805 integer :: iter, ierr
806 real(kind=rp) :: epsi, residumax, z, steg
807 ! vectors with size m
808 type(vector_t) :: y, lambda, mu, relambda, remu, dlambda, dmu, &
809 gradlambda, zerom, dd, dummy_m
810 ! vectors with size n
811 type(vector_t) :: x, pjlambda, qjlambda
812
813 ! inverse of a diag matrix:
814 type(vector_t) :: Ljjxinv ! [∇_x^2 Ljj]−1
815 type(matrix_t) :: hijx ! ∇_x hij
816 type(matrix_t) :: Hess
817 real(kind=rp) :: hesstrace
818
819 integer :: info
820 integer, dimension(this%m+1) :: ipiv
821 integer :: i
822
823 real(kind=rp) :: minimal_epsilon
824
825 call y%init(this%m)
826 call lambda%init(this%m)
827 call mu%init(this%m)
828 call relambda%init(this%m)
829 call remu%init(this%m)
830 call dlambda%init(this%m)
831 call dmu%init(this%m)
832 call gradlambda%init(this%m)
833 call zerom%init(this%m)
834 call dd%init(this%m)
835 call dummy_m%init(this%m)
836
837 call x%init(this%n)
838 call pjlambda%init(this%n)
839 call qjlambda%init(this%n)
840
841 call ljjxinv%init(this%n)
842 call hijx%init(this%m,this%n)
843 call hess%init(this%m,this%m)
844
845 call device_cfill(zerom%x_d, 0.0_rp, this%m)
846
847 ! ------------------------------------------------------------------------ !
848 ! initial value for the parameters in the subsolve based on
849 ! page 15 of "https://people.kth.se/~krille/mmagcmma.pdf"
850
851 epsi = 1.0_rp !100
852 call device_cfill(y%x_d, 1.0_rp, this%m)
853 ! initialize lambda with an array of ones (change to this%c%x/2 if needed!)
854 call device_cfill(lambda%x_d, 1.0_rp, this%m)
855 call device_cmult2(dummy_m%x_d, this%c%x_d, 0.5_rp, this%m)
856 call device_pwmax(lambda%x_d, dummy_m%x_d, this%m)
857
858 call device_cfill(mu%x_d, 1.0_rp, this%m)
859 z = 0.0_rp
860
861 ! dd is defined as this%d + 1.0e-8_rp, to avoid devision by 0 in computing y
862 call device_cadd2(dd%x_d, this%d%x_d, 1.0e-8_rp, this%m)
863
864 ! ------------------------------------------------------------------------ !
865 ! Computing the minimal epsilon and choose the most conservative one
866
867 minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
868 call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
869 mpi_real_precision, mpi_min, neko_comm, ierr)
870
871 ! ------------------------------------------------------------------------ !
872 ! The main loop of the dual-primal interior point method.
873
874 outer: do while (epsi .gt. minimal_epsilon)
875 ! calculating residuals based on
876 ! "https://people.kth.se/~krille/mmagcmma.pdf" for the variables
877 ! x, y, z, lambda residuals based on eq(5.9a)-(5.9d), respectively.
878 associate(p0j => this%p0j, q0j => this%q0j, &
879 pij => this%pij, qij => this%qij, &
880 low => this%low, upp => this%upp, &
881 alpha => this%alpha, beta => this%beta, &
882 c => this%c, a0 => this%a0, a => this%a)
883
884 ! minimize(L_x, L_y, L_z) and compute x(λ), y(λ), z(λ) for
885 ! the initial value of λ
886
887 ! Comput the value of y that minimizes L_y for the current λ
888 ! minimize (sum_{i=1}^{m} [ (c_i - λ_i) * y_i + 0.5 * d_i * y_i^2 ])
889 ! dL_y/dy =0 => y= (λ_i - c_i)/d_i, ensure y>=0
890 call device_sub3(y%x_d, lambda%x_d, c%x_d, this%m)
891 ! division by dd to avoid devision by 0 (in case this%d%x_d)
892 call device_invcol2(y%x_d, dd%x_d, this%m)
893 call device_pwmax(y%x_d, zerom%x_d, this%m)
894
895 ! Comput the value of z that minimizes L_z for the current λ
896 ! minimize ((a_0 - sum_{i=1}^{m} λ_i * a_i) * z)
897 ! if (a_0-dot_product(lambda, a)>=0) z=0 else z= 1.0
898 ! ensure z>=0
899 call device_col3(dummy_m%x_d, lambda%x_d, a%x_d, this%m)
900 z = device_glsum(dummy_m%x_d, this%m)
901 z = merge(0.0_rp, 1.0_rp, a0 - z >= 0.0)
902
903 ! Comput the value of x that minimizes L_x for the current λ
904 ! minimize( sum_{j=1}^{n} [ (p_{0j} + sum_{i=1}^{m} λ_i *
905 ! p_{ij}) / (u_j - x_j) + (q_{0j} + sum_{i=1}^{m} λ_i * q_{ij}) /
906 ! (x_j - l_j) ] - sum_{i=1}^{m} λ_i * b_i)
907 call device_mattrans_v_mul(pjlambda%x_d, pij%x_d, lambda%x_d, this%m, this%n)
908 call device_mattrans_v_mul(qjlambda%x_d, qij%x_d, lambda%x_d, this%m, this%n)
909 call device_add2(pjlambda%x_d, p0j%x_d, this%n)
910 call device_add2(qjlambda%x_d, q0j%x_d, this%n)
911
912 call device_mma_dipsolvesub1(x%x_d, pjlambda%x_d, qjlambda%x_d, &
913 low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
914
915 call device_cfill(relambda%x_d, 0.0_rp, this%m)
916 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
917 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
918
919 ! Global comminucation for relambda values
920
921 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
922 sync = .true.)
923 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
924 mpi_real_precision, mpi_sum, neko_comm, ierr)
925 call device_memcpy(relambda%x, relambda%x_d, this%m, &
926 host_to_device, sync = .true.)
927
928 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
929 call device_sub2(relambda%x_d, y%x_d, this%m)
930 call device_add2(relambda%x_d, mu%x_d, this%m)
931 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
932
933 call device_col3(remu%x_d, mu%x_d, lambda%x_d, this%m)
934 call device_cadd(remu%x_d, -epsi, this%m)
935
936 ! Download the re(lambda, mu) to CPU to calculate residumax
937
938 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
939 sync = .true.)
940 call device_memcpy(remu%x, remu%x_d, this%m, device_to_host, &
941 sync = .true.)
942 residumax = maxval(abs([relambda%x, remu%x]))
943
944 ! ------------------------------------------------------------------- !
945 ! Internal loop
946 do iter = 1, this%max_iter
947 !Check the condition
948 if (residumax .lt. epsi) exit
949
950 ! Compute dL(x, y, z, λ)/dλ for the updated x(λ), y(λ), z(λ)
951 ! based on the implementation in the following paper by Niels
952 ! https://doi.org/10.1007/s00158-012-0869-2
953 ! (https://github.com/topopt/TopOpt_in_PETSc/blob/master/MMA.cc)
954 ! The formula for gradlambda and relambda are basically the same:
955 ! thus, we utilise gradlambda = relambda - mu for efficiency
956 call device_copy(gradlambda%x_d, relambda%x_d, this%m)
957 call device_sub2(gradlambda%x_d, mu%x_d, this%m)
958
959 ! Update gradlambda as the right hand side for Newton's method(eq10)
960 call device_cfill(dummy_m%x_d, epsi, this%m)
961 call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)
962 call device_add2(gradlambda%x_d, dummy_m%x_d, this%m)
963 call device_cmult(gradlambda%x_d, -1.0_rp, this%m)
964
965 ! Computing the Hessian as in equation (13) in
966 !! https://doi.org/10.1007/s00158-012-0869-2
967
968 !--------------contributions of x terms to Hess--------------------!
969 call device_mma_ljjxinv(ljjxinv%x_d, pjlambda%x_d, qjlambda%x_d, &
970 x%x_d, low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
971
972 call device_gg(hijx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
973 this%pij%x_d, this%qij%x_d, this%n, this%m)
974
975 call device_memcpy(hijx%x, hijx%x_d, this%n*this%m, device_to_host, &
976 sync = .true.)
977
978 call device_cfill(hess%x_d, 0.0_rp, (this%m) * (this%m) )
979 call device_hess(hess%x_d, hijx%x_d, ljjxinv%x_d, this%n, this%m)
980
981 ! download Hess to CPU, mpi reduce, upload to the device
982 call device_memcpy(hess%x, hess%x_d, this%m*this%m, device_to_host, &
983 sync = .true.)
984 call mpi_allreduce(mpi_in_place, hess%x, &
985 this%m*this%m, mpi_real_precision, mpi_sum, neko_comm, ierr)
986 ! No need to upload to device since we solve LSE on CPU
987 ! call device_memcpy(Hess%x, Hess%x_d, this%m*this%m, &
988 ! HOST_TO_DEVICE, sync = .true.)
989
990 !---------------contributions of z terms to Hess-------------------!
991 ! There is no contibution to the Hess from z terms as z terms are
992 ! linear w.r.t λ
993
994
995 !---------------contributions of y terms to Hess-------------------!
996 ! Only for inactive constraint, we consider contributions to Hess.
997 ! Note that if d(i) = 0, the y terms (just like z terms) will not
998 ! contribute to the Hessian matrix.
999 ! Note that since we use DGESV to solve LSE on CPU, we dont need
1000 ! cuda kernel for this part
1001
1002 call device_memcpy(lambda%x, lambda%x_d, this%m, device_to_host, &
1003 sync = .true.)
1004 call device_memcpy(mu%x, mu%x_d, this%m, device_to_host, &
1005 sync = .true.)
1006 call device_memcpy(y%x, y%x_d, this%m, device_to_host, &
1007 sync = .true.)
1008 do i = 1, this%m
1009 if (y%x(i) .gt. 0.0_rp) then
1010 if (abs(this%d%x(i)) < 1.0e-15_rp) then
1011 ! Hess(i, i) = Hess(i, i) - 1.0_rp/1.0e-8_rp
1012 else
1013 hess%x(i, i) = hess%x(i, i) - 1.0_rp/this%d%x(i)
1014 end if
1015 end if
1016 ! Based on eq(10), note the term (-\Omega \Lambda)
1017 hess%x(i, i) = hess%x(i, i) - mu%x(i) / lambda%x(i)
1018 end do
1019
1020 ! Improve the robustness by stablizing the Hess using
1021 ! Levenberg-Marquardt algorithm (heuristically)
1022 hesstrace = 0.0_rp
1023 do i=1, this%m
1024 hesstrace = hesstrace + hess%x(i, i)
1025 end do
1026 do i=1, this%m
1027 hess%x(i,i) = hess%x(i, i) - &
1028 max(-1.0e-4_rp*hesstrace/this%m, 1.0e-7_rp)
1029 end do
1030
1031 call device_memcpy(gradlambda%x, gradlambda%x_d, this%m, device_to_host, &
1032 sync = .true.)
1033 call dgesv(this%m , 1, hess%x, this%m , ipiv, &
1034 gradlambda%x, this%m, info)
1035
1036 if (info .ne. 0) then
1037 call neko_error("DGESV failed to solve the linear system in " // &
1038 "mma_subsolve_dip (device).")
1039 end if
1040 call device_memcpy(gradlambda%x, gradlambda%x_d, this%m, host_to_device, &
1041 sync = .true.)
1042
1043 call device_copy(dlambda%x_d, gradlambda%x_d, this%m)
1044
1045 ! based on eq(11) for delta eta
1046 call device_copy(dummy_m%x_d, dlambda%x_d, this%m)
1047 call device_col2(dummy_m%x_d, mu%x_d, this%m)
1048 call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)
1049
1050 call device_cfill(dmu%x_d, epsi, this%m)
1051 call device_invcol2(dmu%x_d, lambda%x_d, this%m)
1052 call device_add2s2(dmu%x_d, dummy_m%x_d, -1.0_rp, this%m)
1053 call device_sub2(dmu%x_d, mu%x_d, this%m)
1054
1055 steg = maxval([1.005_rp, device_maxval2(dlambda%x_d, lambda%x_d, &
1056 -1.01_rp, this%m), device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, &
1057 this%m)])
1058 steg = 1.0_rp / steg
1059
1060 call device_add2s2(lambda%x_d, dlambda%x_d, steg, this%m)
1061 call device_add2s2(mu%x_d, dmu%x_d, steg, this%m)
1062
1063 call device_memcpy(lambda%x, lambda%x_d, this%m, device_to_host, &
1064 sync = .true.)
1065 call device_memcpy(mu%x, mu%x_d, this%m, device_to_host, &
1066 sync = .true.)
1067
1068 ! minimize(L_x, L_y, L_z) and compute x(λ), y(λ), z(λ) for
1069 ! the updated values of λ
1070
1071 ! Comput the value of y that minimizes L_y for the current λ
1072 ! minimize (sum_{i=1}^{m} [ (c_i - λ_i) * y_i + 0.5 * d_i * y_i^2 ])
1073 ! dL_y/dy =0 => y= (λ_i - c_i)/d_i, ensure y>=0
1074
1075 call device_sub3(y%x_d, lambda%x_d, c%x_d, this%m)
1076 ! division by dd to avoid devision by 0 (in case this%d%x_d)
1077 call device_invcol2(y%x_d, dd%x_d, this%m)
1078 call device_pwmax(y%x_d, zerom%x_d, this%m)
1079
1080 ! Comput the value of z that minimizes L_z for the current λ
1081 ! minimize ((a_0 - sum_{i=1}^{m} λ_i * a_i) * z)
1082 ! if (a_0-dot_product(lambda, a)>=0) z=0 else z= 1.0
1083 ! ensure z>=0
1084 call device_col3(dummy_m%x_d, lambda%x_d, a%x_d, this%m)
1085 z = device_glsum(dummy_m%x_d, this%m)
1086 z = merge(0.0_rp, 1.0_rp, a0 - z >= 0.0)
1087
1088 ! Comput the value of x that minimizes L_x for the current λ
1089 ! minimize( sum_{j=1}^{n} [ (p_{0j} + sum_{i=1}^{m} λ_i *
1090 ! p_{ij}) / (u_j - x_j) + (q_{0j} + sum_{i=1}^{m} λ_i * q_{ij}) /
1091 ! (x_j - l_j) ] - sum_{i=1}^{m} λ_i * b_i)
1092 call device_mattrans_v_mul(pjlambda%x_d, pij%x_d, lambda%x_d, this%m, this%n)
1093 call device_mattrans_v_mul(qjlambda%x_d, qij%x_d, lambda%x_d, this%m, this%n)
1094 call device_add2(pjlambda%x_d, p0j%x_d, this%n)
1095 call device_add2(qjlambda%x_d, q0j%x_d, this%n)
1096
1097 call device_mma_dipsolvesub1(x%x_d, pjlambda%x_d, qjlambda%x_d, &
1098 low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
1099
1100 ! Compute the residual for the lambda and mu using eq(9) and eq(15)
1101
1102 call device_cfill(relambda%x_d, 0.0_rp, this%m)
1103 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
1104 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
1105
1106 ! Global comminucation for relambda values
1107
1108 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
1109 sync = .true.)
1110 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
1111 mpi_real_precision, mpi_sum, neko_comm, ierr)
1112 call device_memcpy(relambda%x, relambda%x_d, this%m, &
1113 host_to_device, sync = .true.)
1114
1115 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
1116 call device_sub2(relambda%x_d, y%x_d, this%m)
1117 call device_add2(relambda%x_d, mu%x_d, this%m)
1118 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
1119
1120 call device_col3(remu%x_d, mu%x_d, lambda%x_d, this%m)
1121 call device_cadd(remu%x_d, -epsi, this%m)
1122
1123
1125
1126 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
1127 sync = .true.)
1128 call device_memcpy(remu%x, remu%x_d, this%m, device_to_host, &
1129 sync = .true.)
1130 residumax = maxval(abs([relambda%x, remu%x]))
1131 end do
1132 end associate
1133 epsi = 0.1_rp * epsi
1134 end do outer
1135
1136 ! Save the new designx
1137 call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
1138 call device_copy(this%xold1%x_d, designx_d, this%n)
1139 call device_copy(designx_d, x%x_d, this%n)
1140
1141 ! update the parameters of the MMA object nesessary to compute KKT residual
1142 call device_copy(this%y%x_d, y%x_d, this%m)
1143 this%z = z
1144 call device_copy(this%lambda%x_d, lambda%x_d, this%m)
1145 call device_copy(this%mu%x_d, mu%x_d, this%m)
1146
1147 call y%free()
1148 call lambda%free()
1149 call mu%free()
1150 call relambda%free()
1151 call remu%free()
1152 call dlambda%free()
1153 call dmu%free()
1154 call gradlambda%free()
1155 call zerom%free()
1156 call dd%free()
1157 call dummy_m%free()
1158
1159 call x%free()
1160 call pjlambda%free()
1161 call qjlambda%free()
1162
1163 call ljjxinv%free()
1164 call hijx%free()
1165 call hess%free()
1166 end subroutine mma_subsolve_dip_device
1167
1168end submodule mma_device