Neko-TOP
A portable framework for high-order spectral element flow topology optimization.
Loading...
Searching...
No Matches
mma_device.f90
1! Copyright (c) 2025, The Neko-TOP Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32
33submodule(mma) mma_device
34
35 use device_math, only: device_copy, device_cmult, device_cadd, device_cfill, &
36 device_add2, device_add3s2, device_invcol2, device_col2, device_col3, &
37 device_sub2, device_sub3, device_add2s2, device_cadd2, device_pwmax2, &
38 device_glsum, device_cmult2
39 use device_mma_math, only: device_maxval, device_norm, device_lcsc2, &
40 device_maxval2, device_maxval3, device_mma_gensub3, &
41 device_mma_gensub4, device_mma_max, device_max2, device_rex, &
42 device_relambda, device_delx, device_add2inv2, device_gg, device_diagx, &
43 device_bb, device_updatebb, device_aa, device_updateaa, device_dx, &
44 device_dy, device_dxsi, device_deta, device_kkt_rex, &
45 device_mma_gensub2, device_mattrans_v_mul, device_mma_dipsolvesub1, &
46 device_mma_ljjxinv, device_hess
47
48 use neko_config, only: neko_bcknd_device
49 use device, only: device_to_host
50 use comm, only: neko_comm, pe_rank, mpi_real_precision
51 use mpi_f08, only: mpi_in_place, mpi_max, mpi_min
52 use profiler, only: profiler_start_region, profiler_end_region
53
54 implicit none
55
56contains
57
module subroutine mma_update_device(this, iter, x, df0dx, fval, dfdx)
  ! ----------------------------------------------------- !
  ! Perform one MMA update on the device.                 !
  !                                                       !
  ! Two stages, each wrapped in its own profiler region:  !
  !   1. "MMA gensub"   - build the convex approximation  !
  !                       (mma_gensub_device).            !
  !   2. "MMA subsolve" - solve it with the subsolver     !
  !                       named in this%subsolver         !
  !                       ("dip" or "dpip").              !
  !                                                       !
  ! Arguments:                                            !
  !   this  - MMA object; must already be initialized,    !
  !           otherwise neko_error is called.             !
  !   iter  - iteration counter, forwarded to gensub.     !
  !   x     - device pointer to the design variables;     !
  !           handed to the subsolver.                    !
  !   df0dx, fval, dfdx - device pointers forwarded to    !
  !           gensub (presumably objective gradient,      !
  !           constraint values and constraint gradients  !
  !           - confirm against the parent module).       !
  !                                                       !
  ! On return this%is_updated is set to .true.            !
  ! ----------------------------------------------------- !
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(inout) :: x
  type(c_ptr), intent(in) :: df0dx, fval, dfdx

  ! Guard: refuse to work on an object that was never set up.
  if (.not. this%is_initialized) &
       call neko_error("The MMA object is not initialized.")

  ! Stage 1: convex approximation of the problem.
  call profiler_start_region("MMA gensub")
  call mma_gensub_device(this, iter, x, df0dx, fval, dfdx)
  call profiler_end_region("MMA gensub")

  ! Stage 2: interior point solve of the approximated problem.
  call profiler_start_region("MMA subsolve")
  select case (this%subsolver)
  case ("dip")
     call mma_subsolve_dip_device(this, x)
  case ("dpip")
     call mma_subsolve_dpip_device(this, x)
  case default
     call neko_error("Unrecognized subsolver for MMA in mma_device.")
  end select
  call profiler_end_region("MMA subsolve")

  this%is_updated = .true.
end subroutine mma_update_device
93
module subroutine mma_kkt_device(this, x, df0dx, fval, dfdx)
  ! ----------------------------------------------------- !
  ! Dispatch the KKT residual computation to the routine  !
  ! matching the configured subsolver.                    !
  !                                                       !
  !   "dip"  -> mma_dip_kkt_device                        !
  !   "dpip" -> mma_dpip_kkt_device                       !
  !                                                       !
  ! Any other value is reported through neko_error, in    !
  ! the same way mma_update_device treats an unknown      !
  ! subsolver, instead of silently falling back to the    !
  ! dpip routine.                                         !
  !                                                       !
  ! Arguments:                                            !
  !   this - MMA object whose residumax / residunorm are  !
  !          updated by the selected routine.             !
  !   x, df0dx, fval, dfdx - device pointers, forwarded   !
  !          unchanged.                                   !
  ! ----------------------------------------------------- !
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  select case (this%subsolver)
  case ("dip")
     call mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
  case ("dpip")
     call mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
  case default
     call neko_error("Unrecognized subsolver for MMA in mma_device.")
  end select
end subroutine mma_kkt_device
104
105 ! Compute the KKT residuals for the dual interior
106 ! point method (dip) subsolve of MMA algorithm.
module subroutine mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
  ! ----------------------------------------------------- !
  ! Compute the residual measures stored in               !
  ! this%residumax and this%residunorm for the "dip"      !
  ! subsolver.                                            !
  !                                                       !
  ! Two residual vectors of length this%m are formed on   !
  ! the device:                                           !
  !   res_lambda = fval - a * z - y + mu                  !
  !   res_mu     = lambda * mu   (elementwise)            !
  !                                                       !
  ! Arguments:                                            !
  !   this - MMA object; reads a, z, y, mu, lambda and    !
  !          writes residumax, residunorm.                !
  !   fval - device pointer, first operand of res_lambda. !
  !   x, df0dx, dfdx - not used by this routine; present  !
  !          to match the interface of the dpip variant.  !
  ! ----------------------------------------------------- !
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  type(vector_t) :: res_lambda, res_mu

  call res_lambda%init(this%m)
  call res_mu%init(this%m)

  associate(m => this%m)

    ! res_lambda = fval - this%a%x * this%z - this%y%x + this%mu%x
    call device_add3s2(res_lambda%x_d, fval, this%a%x_d, 1.0_rp, &
         -this%z, m)
    call device_sub2(res_lambda%x_d, this%y%x_d, m)
    call device_add2(res_lambda%x_d, this%mu%x_d, m)

    ! Residual for mu (eta in the paper): elementwise lambda * mu
    call device_col3(res_mu%x_d, this%lambda%x_d, this%mu%x_d, m)

    ! Largest value reported by device_maxval over both residual vectors
    this%residumax = max(device_maxval(res_lambda%x_d, m), &
         device_maxval(res_mu%x_d, m))

    ! Square root of the summed device_norm contributions
    this%residunorm = sqrt(device_norm(res_lambda%x_d, m) + &
         device_norm(res_mu%x_d, m))

  end associate

  call res_lambda%free()
  call res_mu%free()
end subroutine mma_dip_kkt_device
134
135 ! Compute the KKT residuals for the dual-primal interior
136 ! point method (dpip) subsolve of MMA algorithm.
module subroutine mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
  ! ----------------------------------------------------- !
  ! Compute the residual measures stored in               !
  ! this%residumax and this%residunorm for the "dpip"     !
  ! subsolver.                                            !
  !                                                       !
  ! Residuals of length this%n (res_x, res_xsi, res_eta)  !
  ! are reduced over neko_comm with MPI; residuals of     !
  ! length this%m and the two scalar residuals are used   !
  ! as computed locally.                                  !
  !                                                       !
  ! Arguments:                                            !
  !   this  - MMA object; reads the primal/dual state     !
  !           (y, z, zeta, lambda, mu, s, xsi, eta, ...)  !
  !           and writes residumax, residunorm.           !
  !   x     - device pointer to the design variables.     !
  !   df0dx, dfdx - device pointers passed on to          !
  !           device_kkt_rex.                             !
  !   fval  - device pointer, first operand of the lambda !
  !           residual.                                   !
  ! ----------------------------------------------------- !
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  ! Residual vectors of length n
  type(vector_t) :: res_x, res_xsi, res_eta
  ! Residual vectors of length m
  type(vector_t) :: res_y, res_lambda, res_mu, res_s
  ! Scalar residuals
  real(kind=rp) :: res_z, res_zeta
  ! Squared-norm accumulators for the n-sized and the m-sized parts
  real(kind=rp) :: sq_norm_n, sq_norm_m
  integer :: ierr

  call res_x%init(this%n)
  call res_xsi%init(this%n)
  call res_eta%init(this%n)
  call res_y%init(this%m)
  call res_lambda%init(this%m)
  call res_mu%init(this%m)
  call res_s%init(this%m)

  associate(n => this%n, m => this%m)

    ! ------------------------------------------------------------------- !
    ! Scalar residuals

    res_z = this%a0 - this%zeta - &
         device_lcsc2(this%lambda%x_d, this%a%x_d, m)
    res_zeta = this%zeta * this%z

    ! ------------------------------------------------------------------- !
    ! Residuals of length n

    call device_kkt_rex(res_x%x_d, df0dx, dfdx, this%xsi%x_d, &
         this%eta%x_d, this%lambda%x_d, n, m)

    ! res_xsi = xsi * (x - xmin)
    call device_sub3(res_xsi%x_d, x, this%xmin%x_d, n)
    call device_col2(res_xsi%x_d, this%xsi%x_d, n)

    ! res_eta = eta * (xmax - x)
    call device_sub3(res_eta%x_d, this%xmax%x_d, x, n)
    call device_col2(res_eta%x_d, this%eta%x_d, n)

    ! ------------------------------------------------------------------- !
    ! Residuals of length m

    ! res_y = d * y + c - lambda - mu
    call device_col3(res_y%x_d, this%d%x_d, this%y%x_d, m)
    call device_add2(res_y%x_d, this%c%x_d, m)
    call device_sub2(res_y%x_d, this%lambda%x_d, m)
    call device_sub2(res_y%x_d, this%mu%x_d, m)

    ! res_lambda = fval - a * z - y + s
    call device_add3s2(res_lambda%x_d, fval, this%a%x_d, 1.0_rp, &
         -this%z, m)
    call device_sub2(res_lambda%x_d, this%y%x_d, m)
    call device_add2(res_lambda%x_d, this%s%x_d, m)

    ! res_mu = mu * y
    call device_col3(res_mu%x_d, this%mu%x_d, this%y%x_d, m)

    ! res_s = lambda * s
    call device_col3(res_s%x_d, this%lambda%x_d, this%s%x_d, m)

    ! ------------------------------------------------------------------- !
    ! Maximum over all residuals, then made global across ranks

    this%residumax = max( &
         device_maxval(res_x%x_d, n), &
         device_maxval(res_y%x_d, m), &
         abs(res_z), &
         device_maxval(res_lambda%x_d, m), &
         device_maxval(res_xsi%x_d, n), &
         device_maxval(res_eta%x_d, n), &
         device_maxval(res_mu%x_d, m), &
         abs(res_zeta), &
         device_maxval(res_s%x_d, m))

    sq_norm_n = device_norm(res_x%x_d, n) + &
         device_norm(res_xsi%x_d, n) + &
         device_norm(res_eta%x_d, n)

    call mpi_allreduce(mpi_in_place, this%residumax, 1, &
         mpi_real_precision, mpi_max, neko_comm, ierr)

    call mpi_allreduce(mpi_in_place, sq_norm_n, 1, &
         mpi_real_precision, mpi_sum, neko_comm, ierr)

    ! ------------------------------------------------------------------- !
    ! Norm: m-sized and scalar parts plus the reduced n-sized part

    sq_norm_m = device_norm(res_y%x_d, m) + &
         res_z**2 + &
         device_norm(res_lambda%x_d, m) + &
         device_norm(res_mu%x_d, m) + &
         res_zeta**2 + &
         device_norm(res_s%x_d, m)

    this%residunorm = sqrt(sq_norm_m + sq_norm_n)

  end associate

  call res_y%free()
  call res_lambda%free()
  call res_mu%free()
  call res_s%free()
  call res_x%free()
  call res_xsi%free()
  call res_eta%free()
end subroutine mma_dpip_kkt_device
222
223 !============================================================================!
224 ! private internal subroutines
225
subroutine mma_gensub_device(this, iter, x, df0dx, fval, dfdx)
  ! ----------------------------------------------------- !
  ! Generate the approximation sub problem by computing   !
  ! the lower and upper asymptotes and the other          !
  ! necessary parameters (alpha, beta, p0j, q0j, pij,     !
  ! qij, ...).                                            !
  !                                                       !
  ! Arguments:                                            !
  !   this  - MMA object; low, upp, alpha, beta, p0j,     !
  !           q0j, pij, qij and bi are overwritten.       !
  !   iter  - iteration counter; for iter < 3 the         !
  !           asymptotes are initialised from asyinit,    !
  !           afterwards they are adapted from the two    !
  !           previous designs (xold1, xold2).            !
  !   x     - device pointer to the current design.       !
  !   df0dx, dfdx - device pointers passed on to          !
  !           device_mma_gensub3.                         !
  !   fval  - device pointer subtracted from bi.          !
  ! ----------------------------------------------------- !
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(in) :: x
  type(c_ptr), intent(in) :: df0dx
  type(c_ptr), intent(in) :: fval
  type(c_ptr), intent(in) :: dfdx

  integer :: ierr

  ! Work vector holding xmax - xmin; only its device copy is used, so no
  ! transfer to the host is performed.
  type(vector_t) :: x_diff

  call x_diff%init(this%n)
  call device_sub3(x_diff%x_d, this%xmax%x_d, this%xmin%x_d, this%n)

  ! ------------------------------------------------------------------------ !
  ! Setup the current asymptotes

  if (iter .lt. 3) then
     ! low = x - asyinit * (xmax - xmin), upp = x + asyinit * (xmax - xmin)
     call device_copy(this%low%x_d, x, this%n)
     call device_add2s2(this%low%x_d, x_diff%x_d, - this%asyinit, this%n)
     call device_copy(this%upp%x_d, x, this%n)
     call device_add2s2(this%upp%x_d, x_diff%x_d, this%asyinit, this%n)
  else
     call device_mma_gensub2(this%low%x_d, this%upp%x_d, x, &
          this%xold1%x_d, this%xold2%x_d, x_diff%x_d, &
          this%asydecr, this%asyincr, this%n)
  end if

  ! ------------------------------------------------------------------------ !
  ! Calculate p0j, q0j, pij, qij, alpha, and beta

  call device_mma_gensub3(x, df0dx, dfdx, this%low%x_d, &
       this%upp%x_d, this%xmin%x_d, this%xmax%x_d, this%alpha%x_d, &
       this%beta%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
       this%qij%x_d, this%n, this%m)

  ! ------------------------------------------------------------------------ !
  ! Computing bi as defined in page 5

  call device_mma_gensub4(x, this%low%x_d, this%upp%x_d, this%pij%x_d, &
       this%qij%x_d, this%n, this%m, this%bi%x_d)

  ! The per-rank partial sums are combined on the host and sent back.
  call device_memcpy(this%bi%x, this%bi%x_d, this%m, device_to_host, &
       sync = .true.)
  call mpi_allreduce(mpi_in_place, this%bi%x, this%m, &
       mpi_real_precision, mpi_sum, neko_comm, ierr)
  call device_memcpy(this%bi%x, this%bi%x_d, this%m, host_to_device, &
       sync = .true.)
  call device_sub2(this%bi%x_d, fval, this%m)

  ! Release the work vector, as every other routine in this file does for
  ! its temporaries.
  call x_diff%free()

end subroutine mma_gensub_device
286
subroutine mma_subsolve_dpip_device(this, designx_d)
  ! ----------------------------------------------------- !
  ! Solve the convex MMA sub problem with the dual-primal !
  ! interior point method (dpip) on the device.           !
  !                                                       !
  ! Structure:                                            !
  !   - outer loop: epsi starts at 1 and is multiplied by !
  !     0.1 until it drops to minimal_epsilon or below;   !
  !   - inner loop: at most this%max_iter Newton steps,   !
  !     left early once residual_max < epsi; the reduced  !
  !     (m+1) x (m+1) system is solved on the host with   !
  !     DGESV;                                            !
  !   - line search: the step length is halved (at most   !
  !     50 times) until the residual norm does not        !
  !     increase.                                         !
  !                                                       !
  ! Arguments:                                            !
  !   this      - MMA object; reads the sub problem data  !
  !               (alpha, beta, low, upp, p0j, q0j, pij,  !
  !               qij, bi, a0, a, c, d) and receives the  !
  !               solution state (y, z, lambda, zeta,     !
  !               xsi, eta, mu, s) as well as the updated !
  !               design history xold1 / xold2.           !
  !   designx_d - device pointer to the design variables; !
  !               the data it points to is overwritten    !
  !               with the new design.                    !
  !                                                       !
  ! NOTE(review): the host copies this%d%x and this%a%x   !
  ! are read directly when assembling AA; this assumes    !
  ! they are in sync with their device copies.            !
  ! ----------------------------------------------------- !
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: designx_d
  integer :: iter, itto, ierr
  real(kind=rp) :: epsi, residual_max, residual_norm, z, zeta, rez, rezeta, &
       delz, dz, dzeta, steg, zold, zetaold, new_residual
  ! vectors with size m
  type(vector_t) :: y, lambda, s, mu, rey, relambda, remu, res, &
       dely, dellambda, dy, dlambda, ds, dmu, yold, lambdaold, sold, muold

  ! vectors with size n
  type(vector_t) :: x, xsi, eta, rex, rexsi, reeta, &
       delx, diagx, dx, dxsi, deta, xold, xsiold, etaold

  ! right hand side (size m+1), and the matrices GG (m x n) and
  ! AA ((m+1) x (m+1)) of the reduced linear system
  type(vector_t) :: bb
  type(matrix_t) :: gg
  type(matrix_t) :: aa

  integer :: info
  integer, dimension(this%m+1) :: ipiv
  real(kind=rp) :: re_sq_norm

  integer :: i

  real(kind=rp) :: minimal_epsilon

  call y%init(this%m)
  call lambda%init(this%m)
  call s%init(this%m)
  call mu%init(this%m)
  call rey%init(this%m)
  call relambda%init(this%m)
  call remu%init(this%m)
  call res%init(this%m)
  call dely%init(this%m)
  call dellambda%init(this%m)
  call dy%init(this%m)
  call dlambda%init(this%m)
  call ds%init(this%m)
  call dmu%init(this%m)
  call yold%init(this%m)
  call lambdaold%init(this%m)
  call sold%init(this%m)
  call muold%init(this%m)
  call x%init(this%n)
  call xsi%init(this%n)
  call eta%init(this%n)
  call rex%init(this%n)
  call rexsi%init(this%n)
  call reeta%init(this%n)
  call delx%init(this%n)
  call diagx%init(this%n)
  call dx%init(this%n)
  call dxsi%init(this%n)
  call deta%init(this%n)
  call xold%init(this%n)
  call xsiold%init(this%n)
  call etaold%init(this%n)
  call bb%init(this%m+1)

  call gg%init(this%m, this%n)
  call aa%init(this%m+1, this%m+1)

  ! ------------------------------------------------------------------------ !
  ! initial value for the parameters in the subsolve based on
  ! page 15 of "https://people.kth.se/~krille/mmagcmma.pdf"

  epsi = 1.0_rp !100
  call device_add3s2(x%x_d, this%alpha%x_d, this%beta%x_d, 0.5_rp, 0.5_rp, &
       this%n)
  call device_cfill(y%x_d, 1.0_rp, this%m)
  z = 1.0_rp
  zeta = 1.0_rp
  call device_cfill(lambda%x_d, 1.0_rp, this%m)
  call device_cfill(s%x_d, 1.0_rp, this%m)
  call device_mma_max(xsi%x_d, x%x_d, this%alpha%x_d, this%n)
  call device_mma_max(eta%x_d, this%beta%x_d, x%x_d, this%n)
  call device_max2(mu%x_d, 1.0_rp, this%c%x_d, 0.5_rp, this%m)

  ! ------------------------------------------------------------------------ !
  ! Computing the minimal epsilon and choose the most conservative one

  minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
  call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
       mpi_real_precision, mpi_min, neko_comm, ierr)

  ! ------------------------------------------------------------------------ !
  ! The main loop of the dual-primal interior point method.

  do while (epsi .gt. minimal_epsilon)

     ! --------------------------------------------------------------------- !
     ! Calculating residuals based on
     ! "https://people.kth.se/~krille/mmagcmma.pdf" for the variables
     ! x, y, z, lambda residuals based on eq(5.9a)-(5.9d), respectively.

     associate(p0j => this%p0j, q0j => this%q0j, &
          pij => this%pij, qij => this%qij, &
          low => this%low, upp => this%upp, &
          alpha => this%alpha, beta => this%beta, &
          c => this%c, d => this%d, &
          a0 => this%a0, a => this%a)

       call device_rex(rex%x_d, x%x_d, low%x_d, upp%x_d, &
            pij%x_d, p0j%x_d, qij%x_d, q0j%x_d, &
            lambda%x_d, xsi%x_d, eta%x_d, this%n, this%m)

       call device_col3(rey%x_d, d%x_d, y%x_d, this%m)
       call device_add2(rey%x_d, c%x_d, this%m)
       call device_sub2(rey%x_d, lambda%x_d, this%m)
       call device_sub2(rey%x_d, mu%x_d, this%m)
       rez = a0 - zeta - device_lcsc2(lambda%x_d, a%x_d, this%m)

       call device_cfill(relambda%x_d, 0.0_rp, this%m)
       call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
            low%x_d, pij%x_d, qij%x_d, this%n, this%m)

     end associate

     ! --------------------------------------------------------------------- !
     ! Computing the norm of the residuals

     ! Complete the computations of lambda residuals
     call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
          sync = .true.)
     call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
          mpi_real_precision, mpi_sum, neko_comm, ierr)
     call device_memcpy(relambda%x, relambda%x_d, this%m, host_to_device, &
          sync = .true.)

     call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
     call device_sub2(relambda%x_d, y%x_d, this%m)
     call device_add2(relambda%x_d, s%x_d, this%m)
     call device_sub2(relambda%x_d, this%bi%x_d, this%m)

     call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
     call device_col2(rexsi%x_d, xsi%x_d, this%n)
     call device_cadd(rexsi%x_d, - epsi, this%n)

     call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
     call device_col2(reeta%x_d, eta%x_d, this%n)
     call device_cadd(reeta%x_d, - epsi, this%n)

     call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
     call device_cadd(remu%x_d, - epsi, this%m)

     rezeta = zeta * z - epsi

     call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
     call device_cadd(res%x_d, - epsi, this%m)

     ! Setup vectors of residuals and their norms
     residual_max = maxval([device_maxval(rex%x_d, this%n), &
          device_maxval(rey%x_d, this%m), abs(rez), &
          device_maxval(relambda%x_d, this%m), &
          device_maxval(rexsi%x_d, this%n), &
          device_maxval(reeta%x_d, this%n), &
          device_maxval(remu%x_d, this%m), abs(rezeta), &
          device_maxval(res%x_d, this%m)])

     re_sq_norm = device_norm(rex%x_d, this%n) + &
          device_norm(rexsi%x_d, this%n) + device_norm(reeta%x_d, this%n)

     call mpi_allreduce(mpi_in_place, residual_max, 1, &
          mpi_real_precision, mpi_max, neko_comm, ierr)

     call mpi_allreduce(mpi_in_place, re_sq_norm, &
          1, mpi_real_precision, mpi_sum, neko_comm, ierr)

     residual_norm = sqrt(device_norm(rey%x_d, this%m) + &
          rez**2 + &
          device_norm(relambda%x_d, this%m) + &
          device_norm(remu%x_d, this%m)+ &
          rezeta**2 + &
          device_norm(res%x_d, this%m) &
          + re_sq_norm)

     ! --------------------------------------------------------------------- !
     ! Internal loop

     do iter = 1, this%max_iter

        if (residual_max .lt. epsi) exit

        call device_delx(delx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
             this%pij%x_d, this%qij%x_d, this%p0j%x_d, this%q0j%x_d, &
             this%alpha%x_d, this%beta%x_d, lambda%x_d, epsi, this%n, &
             this%m)

        call device_col3(dely%x_d, this%d%x_d, y%x_d, this%m)
        call device_add2(dely%x_d, this%c%x_d, this%m)
        call device_sub2(dely%x_d, lambda%x_d, this%m)
        call device_add2inv2(dely%x_d, y%x_d, - epsi, this%m)
        delz = this%a0 - device_lcsc2(lambda%x_d, this%a%x_d, this%m) - epsi/z

        ! Accumulate sums for dellambda (the term gi(x))
        call device_cfill(dellambda%x_d, 0.0_rp, this%m)
        call device_relambda(dellambda%x_d, x%x_d, this%upp%x_d, &
             this%low%x_d, this%pij%x_d, this%qij%x_d, this%n, this%m)

        call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
             device_to_host, sync = .true.)
        call mpi_allreduce(mpi_in_place, dellambda%x, this%m, &
             mpi_real_precision, mpi_sum, neko_comm, ierr)
        call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
             host_to_device, sync = .true.)

        call device_add3s2(dellambda%x_d, dellambda%x_d, this%a%x_d, &
             1.0_rp, -z, this%m)
        call device_sub2(dellambda%x_d, y%x_d, this%m)
        call device_sub2(dellambda%x_d, this%bi%x_d, this%m)
        call device_add2inv2(dellambda%x_d, lambda%x_d, epsi, this%m)

        call device_gg(gg%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
             this%pij%x_d, this%qij%x_d, this%n, this%m)

        call device_diagx(diagx%x_d, x%x_d, xsi%x_d, this%low%x_d, &
             this%upp%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
             this%qij%x_d, this%alpha%x_d, this%beta%x_d, eta%x_d, &
             lambda%x_d, this%n, this%m)

        !Here we only consider the case m<n in the matlab code
        !assembling the right hand side matrix based on eq(5.20)
        ! bb = [dellambda + dely/(this%d%x + &
        !         (mu/y)) - matmul(GG,delx/diagx), delz ]

        !--------------------------------------------------------------------!
        ! for MPI computation of bb

        call device_bb(bb%x_d, gg%x_d, delx%x_d, diagx%x_d, this%n, &
             this%m)

        call device_memcpy(bb%x, bb%x_d, this%m + 1, device_to_host, &
             sync = .true.)
        call mpi_allreduce(mpi_in_place, bb%x, this%m + 1, &
             mpi_real_precision, mpi_sum, neko_comm, ierr)
        call device_memcpy(bb%x, bb%x_d, this%m + 1, &
             host_to_device, sync = .true.)

        call device_updatebb(bb%x_d, dellambda%x_d, dely%x_d, &
             this%d%x_d, mu%x_d, y%x_d, delz, this%m)

        !--------------------------------------------------------------------!
        ! assembling the coefficients matrix AA based on eq(5.20)
        ! AA(1:this%m,1:this%m) = &
        !     matmul(matmul(GG,mma_diag(1/diagx)), transpose(GG))
        ! !update diag(AA)
        ! AA(1:this%m,1:this%m) = AA(1:this%m,1:this%m) + &
        !     mma_diag(s/lambda + 1.0/(this%d%x + (mu/y)))

        call device_cfill(aa%x_d, 0.0_rp, (this%m+1) * (this%m+1))
        call device_aa(aa%x_d, gg%x_d, diagx%x_d, this%n, this%m)
        call device_memcpy(aa%x, aa%x_d, (this%m+1) * (this%m+1), &
             device_to_host, sync = .true.)

        call mpi_allreduce(mpi_in_place, aa%x, &
             (this%m + 1)**2, mpi_real_precision, mpi_sum, neko_comm, ierr)

        ! The three asynchronous copies are completed by the final
        ! synchronous one before the host values are read.
        call device_memcpy(lambda%x, lambda%x_d, this%m, device_to_host, &
             sync = .false.)
        call device_memcpy(mu%x, mu%x_d, this%m, device_to_host, &
             sync = .false.)
        call device_memcpy(y%x, y%x_d, this%m, device_to_host, &
             sync = .false.)
        call device_memcpy(s%x, s%x_d, this%m, device_to_host, &
             sync = .true.)
        do i = 1, this%m
           ! update the diag AA
           aa%x(i, i) = aa%x(i, i) &
                + s%x(i) / lambda%x(i) &
                + 1.0_rp / (this%d%x(i) + mu%x(i) / y%x(i))
        end do
        aa%x(1:this%m, this%m+1) = this%a%x
        aa%x(this%m+1, 1:this%m) = this%a%x
        aa%x(this%m+1, this%m+1) = - zeta/z

        call device_memcpy(aa%x, aa%x_d, &
             (this%m + 1) * (this%m + 1), host_to_device, sync = .true.)

        ! Solve AA * sol = bb on the host; the solution overwrites bb%x.
        call device_memcpy(bb%x, bb%x_d, this%m+1, device_to_host, &
             sync = .true.)
        call dgesv(this%m + 1, 1, aa%x, this%m + 1, ipiv, bb%x, this%m + 1, &
             info)

        if (info .ne. 0) then
           call neko_error("DGESV failed to solve the linear system in " // &
                "mma_subsolve_dpip (device).")
        end if

        call device_memcpy(bb%x, bb%x_d, this%m+1, host_to_device, &
             sync = .true.)

        dlambda%x = bb%x(1:this%m)
        call device_memcpy(dlambda%x, dlambda%x_d, this%m, host_to_device, &
             sync = .true.)

        dz = bb%x(this%m + 1)

        ! based on eq(5.19)
        call device_dx(dx%x_d, delx%x_d, diagx%x_d, gg%x_d, &
             dlambda%x_d, this%n, this%m)
        call device_dy(dy%x_d, dely%x_d, dlambda%x_d, this%d%x_d, &
             mu%x_d, y%x_d, this%m)
        call device_dxsi(dxsi%x_d, xsi%x_d, dx%x_d, x%x_d, &
             this%alpha%x_d, epsi, this%n)
        call device_deta(deta%x_d, eta%x_d, dx%x_d, x%x_d, &
             this%beta%x_d, epsi, this%n)

        ! dmu = (epsi - mu * dy) / y - mu
        call device_col3(dmu%x_d, mu%x_d, dy%x_d, this%m)
        call device_cmult(dmu%x_d, -1.0_rp, this%m)
        call device_cadd(dmu%x_d, epsi, this%m)
        call device_invcol2(dmu%x_d, y%x_d, this%m)
        call device_sub2(dmu%x_d, mu%x_d, this%m)
        dzeta = -zeta + (epsi - zeta * dz) / z
        ! ds = (epsi - dlambda * s) / lambda - s
        call device_col3(ds%x_d, dlambda%x_d, s%x_d, this%m)
        call device_cmult(ds%x_d, -1.0_rp, this%m)
        call device_cadd(ds%x_d, epsi, this%m)
        call device_invcol2(ds%x_d, lambda%x_d, this%m)
        call device_sub2(ds%x_d, s%x_d, this%m)

        steg = maxval([1.0_rp, &
             device_maxval2(dy%x_d, y%x_d, -1.01_rp, this%m), &
             -1.01_rp * dz / z, &
             device_maxval2(dlambda%x_d, lambda%x_d, -1.01_rp, this%m), &
             device_maxval2(dxsi%x_d, xsi%x_d, -1.01_rp, this%n), &
             device_maxval2(deta%x_d, eta%x_d, -1.01_rp, this%n), &
             device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, this%m), &
             -1.01_rp * dzeta / zeta, &
             device_maxval2(ds%x_d, s%x_d, -1.01_rp, this%m), &
             device_maxval3(dx%x_d, x%x_d, this%alpha%x_d, -1.01_rp, this%n),&
             device_maxval3(dx%x_d, this%beta%x_d, x%x_d, 1.01_rp, this%n)])

        steg = 1.0_rp / steg

        call device_copy(xold%x_d, x%x_d, this%n)
        call device_copy(yold%x_d, y%x_d, this%m)
        zold = z
        call device_copy(lambdaold%x_d, lambda%x_d, this%m)
        call device_copy(xsiold%x_d, xsi%x_d, this%n)
        call device_copy(etaold%x_d, eta%x_d, this%n)
        call device_copy(muold%x_d, mu%x_d, this%m)
        zetaold = zeta
        call device_copy(sold%x_d, s%x_d, this%m)

        new_residual = 2.0_rp * residual_norm

        ! Share the new_residual and steg values
        call mpi_allreduce(mpi_in_place, steg, 1, &
             mpi_real_precision, mpi_min, neko_comm, ierr)
        call mpi_allreduce(mpi_in_place, new_residual, 1, &
             mpi_real_precision, mpi_min, neko_comm, ierr)

        ! The innermost loop to determine the suitable step length
        ! using the Backtracking Line Search approach
        itto = 0
        do while ((new_residual .gt. residual_norm) .and. (itto .lt. 50))
           itto = itto + 1

           ! update the variables
           call device_add3s2(x%x_d, xold%x_d, dx%x_d, 1.0_rp, steg, this%n)
           call device_add3s2(y%x_d, yold%x_d, dy%x_d, 1.0_rp, steg, this%m)
           z = zold + steg*dz
           call device_add3s2(lambda%x_d, lambdaold%x_d, &
                dlambda%x_d, 1.0_rp, steg, this%m)
           call device_add3s2(xsi%x_d, xsiold%x_d, dxsi%x_d, &
                1.0_rp, steg, this%n)
           call device_add3s2(eta%x_d, etaold%x_d, deta%x_d, &
                1.0_rp, steg, this%n)
           call device_add3s2(mu%x_d, muold%x_d, dmu%x_d, &
                1.0_rp, steg, this%m)
           zeta = zetaold + steg*dzeta
           call device_add3s2(s%x_d, sold%x_d, ds%x_d, 1.0_rp, &
                steg, this%m)

           ! Recompute the new_residual to see if this stepsize improves
           ! the residue
           call device_rex(rex%x_d, x%x_d, this%low%x_d, &
                this%upp%x_d, this%pij%x_d, this%p0j%x_d, &
                this%qij%x_d, this%q0j%x_d, lambda%x_d, xsi%x_d, &
                eta%x_d, this%n, this%m)

           call device_col3(rey%x_d, this%d%x_d, y%x_d, this%m)
           call device_add2(rey%x_d, this%c%x_d, this%m)
           call device_sub2(rey%x_d, lambda%x_d, this%m)
           call device_sub2(rey%x_d, mu%x_d, this%m)

           rez = this%a0 - zeta - device_lcsc2(lambda%x_d, this%a%x_d, this%m)

           ! Accumulate sums for relambda (the term gi(x))
           call device_cfill(relambda%x_d, 0.0_rp, this%m)
           call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
                this%low%x_d, this%pij%x_d, this%qij%x_d, &
                this%n, this%m)

           call device_memcpy(relambda%x, relambda%x_d, this%m, &
                device_to_host, sync = .true.)
           call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
                mpi_real_precision, mpi_sum, neko_comm, ierr)
           call device_memcpy(relambda%x, relambda%x_d, &
                this%m, host_to_device, sync = .true.)

           call device_add3s2(relambda%x_d, relambda%x_d, &
                this%a%x_d, 1.0_rp, -z, this%m)
           call device_sub2(relambda%x_d, y%x_d, this%m)
           call device_add2(relambda%x_d, s%x_d, this%m)
           call device_sub2(relambda%x_d, this%bi%x_d, this%m)

           call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
           call device_col2(rexsi%x_d, xsi%x_d, this%n)
           call device_cadd(rexsi%x_d, - epsi, this%n)

           call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
           call device_col2(reeta%x_d, eta%x_d, this%n)
           call device_cadd(reeta%x_d, - epsi, this%n)

           call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
           call device_cadd(remu%x_d, - epsi, this%m)

           rezeta = zeta*z - epsi

           call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
           call device_cadd(res%x_d, - epsi, this%m)

           ! Compute squared norms for the residuals
           re_sq_norm = device_norm(rex%x_d, this%n) + &
                device_norm(rexsi%x_d, this%n) + &
                device_norm(reeta%x_d, this%n)
           call mpi_allreduce(mpi_in_place, re_sq_norm, 1, &
                mpi_real_precision, mpi_sum, neko_comm, ierr)

           new_residual = sqrt(device_norm(rey%x_d, this%m) + &
                rez**2 + &
                device_norm(relambda%x_d, this%m) + &
                device_norm(remu%x_d, this%m) + &
                rezeta**2 + &
                device_norm(res%x_d, this%m) + &
                re_sq_norm)

           steg = steg / 2.0_rp

        end do
        steg = 2.0_rp * steg ! Correction for the final division by 2

        ! Update the maximum and norm of the residuals
        residual_norm = new_residual
        residual_max = maxval([ &
             device_maxval(rex%x_d, this%n), &
             device_maxval(rey%x_d, this%m), &
             abs(rez), &
             device_maxval(relambda%x_d, this%m), &
             device_maxval(rexsi%x_d, this%n), &
             device_maxval(reeta%x_d, this%n), &
             device_maxval(remu%x_d, this%m), &
             abs(rezeta), &
             device_maxval(res%x_d, this%m)])

        call mpi_allreduce(mpi_in_place, residual_max, 1, &
             mpi_real_precision, mpi_max, neko_comm, ierr)

     end do

     epsi = 0.1_rp * epsi
  end do

  ! Save the new designx
  call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
  call device_copy(this%xold1%x_d, designx_d, this%n)
  call device_copy(designx_d, x%x_d, this%n)

  ! update the parameters of the MMA object necessary to compute KKT residual
  call device_copy(this%y%x_d, y%x_d, this%m)
  this%z = z
  call device_copy(this%lambda%x_d, lambda%x_d, this%m)
  this%zeta = zeta
  call device_copy(this%xsi%x_d, xsi%x_d, this%n)
  call device_copy(this%eta%x_d, eta%x_d, this%n)
  call device_copy(this%mu%x_d, mu%x_d, this%m)
  call device_copy(this%s%x_d, s%x_d, this%m)

  !free all the initiated variables in this subroutine
  call y%free()
  call lambda%free()
  call s%free()
  call mu%free()
  call rey%free()
  call relambda%free()
  call remu%free()
  call res%free()
  call dely%free()
  call dellambda%free()
  call dy%free()
  call dlambda%free()
  call ds%free()
  call dmu%free()
  call yold%free()
  call lambdaold%free()
  call sold%free()
  call muold%free()
  call x%free()
  call xsi%free()
  call eta%free()
  call rex%free()
  call rexsi%free()
  call reeta%free()
  call delx%free()
  call diagx%free()
  call dx%free()
  call dxsi%free()
  call deta%free()
  call xold%free()
  call xsiold%free()
  call etaold%free()
  call bb%free()

  ! The two matrices are initialised above as well and must be released
  ! like the vectors.
  call gg%free()
  call aa%free()

end subroutine mma_subsolve_dpip_device
804
807 subroutine mma_subsolve_dip_device(this, designx_d)
808 class(mma_t), intent(inout) :: this
809 type(c_ptr), intent(in) :: designx_d
810 integer :: iter, ierr
811 real(kind=rp) :: epsi, residumax, z, steg
812 ! vectors with size m
813 type(vector_t) :: y, lambda, mu, relambda, remu, dlambda, dmu, &
814 gradlambda, zerom, dd, dummy_m
815 ! vectors with size n
816 type(vector_t) :: x, pjlambda, qjlambda
817
818 ! inverse of a diag matrix:
819 type(vector_t) :: Ljjxinv ! [∇_x^2 Ljj]−1
820 type(matrix_t) :: hijx ! ∇_x hij
821 type(matrix_t) :: Hess
822 real(kind=rp) :: hesstrace
823
824 integer :: info
825 integer, dimension(this%m+1) :: ipiv
826 integer :: i
827
828 real(kind=rp) :: minimal_epsilon
829
830 call y%init(this%m)
831 call lambda%init(this%m)
832 call mu%init(this%m)
833 call relambda%init(this%m)
834 call remu%init(this%m)
835 call dlambda%init(this%m)
836 call dmu%init(this%m)
837 call gradlambda%init(this%m)
838 call zerom%init(this%m)
839 call dd%init(this%m)
840 call dummy_m%init(this%m)
841
842 call x%init(this%n)
843 call pjlambda%init(this%n)
844 call qjlambda%init(this%n)
845
846 call ljjxinv%init(this%n)
847 call hijx%init(this%m,this%n)
848 call hess%init(this%m,this%m)
849
850 call device_cfill(zerom%x_d, 0.0_rp, this%m)
851
852 ! ------------------------------------------------------------------------ !
853 ! initial value for the parameters in the subsolve based on
854 ! page 15 of "https://people.kth.se/~krille/mmagcmma.pdf"
855
856 epsi = 1.0_rp !100
857 call device_cfill(y%x_d, 1.0_rp, this%m)
858 ! initialize lambda with an array of ones (change to this%c%x/2 if needed!)
859 call device_cfill(lambda%x_d, 1.0_rp, this%m)
860 call device_cmult2(dummy_m%x_d, this%c%x_d, 0.5_rp, this%m)
861 call device_pwmax2(lambda%x_d, dummy_m%x_d, this%m)
862
863 call device_cfill(mu%x_d, 1.0_rp, this%m)
864 z = 0.0_rp
865
866 ! dd is defined as this%d + 1.0e-8_rp, to avoid devision by 0 in computing y
867 call device_cadd2(dd%x_d, this%d%x_d, 1.0e-8_rp, this%m)
868
869 ! ------------------------------------------------------------------------ !
870 ! Computing the minimal epsilon and choose the most conservative one
871
872 minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
873 call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
874 mpi_real_precision, mpi_min, neko_comm, ierr)
875
876 ! ------------------------------------------------------------------------ !
877 ! The main loop of the dual-primal interior point method.
878
879 outer: do while (epsi .gt. minimal_epsilon)
880 ! calculating residuals based on
881 ! "https://people.kth.se/~krille/mmagcmma.pdf" for the variables
882 ! x, y, z, lambda residuals based on eq(5.9a)-(5.9d), respectively.
883 associate(p0j => this%p0j, q0j => this%q0j, &
884 pij => this%pij, qij => this%qij, &
885 low => this%low, upp => this%upp, &
886 alpha => this%alpha, beta => this%beta, &
887 c => this%c, a0 => this%a0, a => this%a)
888
889 ! minimize(L_x, L_y, L_z) and compute x(λ), y(λ), z(λ) for
890 ! the initial value of λ
891
892 ! Compute the value of y that minimizes L_y for the current λ
893 ! minimize (sum_{i=1}^{m} [ (c_i - λ_i) * y_i + 0.5 * d_i * y_i^2 ])
894 ! dL_y/dy =0 => y= (λ_i - c_i)/d_i, ensure y>=0
895 call device_sub3(y%x_d, lambda%x_d, c%x_d, this%m)
896 ! divide by dd instead of this%d to avoid division by 0 (in case this%d is zero)
897 call device_invcol2(y%x_d, dd%x_d, this%m)
898 call device_pwmax2(y%x_d, zerom%x_d, this%m)
899
900 ! Compute the value of z that minimizes L_z for the current λ
901 ! minimize ((a_0 - sum_{i=1}^{m} λ_i * a_i) * z)
902 ! if (a_0-dot_product(lambda, a)>=0) z=0 else z= 1.0
903 ! ensure z>=0
904 call device_col3(dummy_m%x_d, lambda%x_d, a%x_d, this%m)
905 z = device_glsum(dummy_m%x_d, this%m)
906 z = merge(0.0_rp, 1.0_rp, a0 - z >= 0.0)
907
908 ! Compute the value of x that minimizes L_x for the current λ
909 ! minimize( sum_{j=1}^{n} [ (p_{0j} + sum_{i=1}^{m} λ_i *
910 ! p_{ij}) / (u_j - x_j) + (q_{0j} + sum_{i=1}^{m} λ_i * q_{ij}) /
911 ! (x_j - l_j) ] - sum_{i=1}^{m} λ_i * b_i)
912 call device_mattrans_v_mul(pjlambda%x_d, pij%x_d, lambda%x_d, this%m, this%n)
913 call device_mattrans_v_mul(qjlambda%x_d, qij%x_d, lambda%x_d, this%m, this%n)
914 call device_add2(pjlambda%x_d, p0j%x_d, this%n)
915 call device_add2(qjlambda%x_d, q0j%x_d, this%n)
916
917 call device_mma_dipsolvesub1(x%x_d, pjlambda%x_d, qjlambda%x_d, &
918 low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
919
920 call device_cfill(relambda%x_d, 0.0_rp, this%m)
921 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
922 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
923
924 ! Global communication for relambda values
925
926 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
927 sync = .true.)
928 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
929 mpi_real_precision, mpi_sum, neko_comm, ierr)
930 call device_memcpy(relambda%x, relambda%x_d, this%m, &
931 host_to_device, sync = .true.)
932
933 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
934 call device_sub2(relambda%x_d, y%x_d, this%m)
935 call device_add2(relambda%x_d, mu%x_d, this%m)
936 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
937
938 call device_col3(remu%x_d, mu%x_d, lambda%x_d, this%m)
939 call device_cadd(remu%x_d, -epsi, this%m)
940
941 ! Download the re(lambda, mu) to CPU to calculate residumax
942
943 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
944 sync = .true.)
945 call device_memcpy(remu%x, remu%x_d, this%m, device_to_host, &
946 sync = .true.)
947 residumax = maxval(abs([relambda%x, remu%x]))
948
949 ! ------------------------------------------------------------------- !
950 ! Internal loop
951 do iter = 1, this%max_iter
952 !Check the condition
953 if (residumax .lt. epsi) exit
954
955 ! Compute dL(x, y, z, λ)/dλ for the updated x(λ), y(λ), z(λ)
956 ! based on the implementation in the following paper by Niels
957 ! https://doi.org/10.1007/s00158-012-0869-2
958 ! (https://github.com/topopt/TopOpt_in_PETSc/blob/master/MMA.cc)
959 ! The formula for gradlambda and relambda are basically the same:
960 ! thus, we utilise gradlambda = relambda - mu for efficiency
961 call device_copy(gradlambda%x_d, relambda%x_d, this%m)
962 call device_sub2(gradlambda%x_d, mu%x_d, this%m)
963
964 ! Update gradlambda as the right hand side for Newton's method(eq10)
965 call device_cfill(dummy_m%x_d, epsi, this%m)
966 call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)
967 call device_add2(gradlambda%x_d, dummy_m%x_d, this%m)
968 call device_cmult(gradlambda%x_d, -1.0_rp, this%m)
969
970 ! Computing the Hessian as in equation (13) in
971 !! https://doi.org/10.1007/s00158-012-0869-2
972
973 !--------------contributions of x terms to Hess--------------------!
974 call device_mma_ljjxinv(ljjxinv%x_d, pjlambda%x_d, qjlambda%x_d, &
975 x%x_d, low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
976
977 call device_gg(hijx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
978 this%pij%x_d, this%qij%x_d, this%n, this%m)
979
980 call device_memcpy(hijx%x, hijx%x_d, this%n*this%m, device_to_host, &
981 sync = .true.)
982
983 call device_cfill(hess%x_d, 0.0_rp, (this%m) * (this%m) )
984 call device_hess(hess%x_d, hijx%x_d, ljjxinv%x_d, this%n, this%m)
985
986 ! download Hess to CPU, mpi reduce, upload to the device
987 call device_memcpy(hess%x, hess%x_d, this%m*this%m, device_to_host, &
988 sync = .true.)
989 call mpi_allreduce(mpi_in_place, hess%x, &
990 this%m*this%m, mpi_real_precision, mpi_sum, neko_comm, ierr)
991 ! No need to upload to device since we solve LSE on CPU
992 ! call device_memcpy(Hess%x, Hess%x_d, this%m*this%m, &
993 ! HOST_TO_DEVICE, sync = .true.)
994
995 !---------------contributions of z terms to Hess-------------------!
996 ! There is no contribution to the Hess from z terms as z terms are
997 ! linear w.r.t λ
998
999
1000 !---------------contributions of y terms to Hess-------------------!
1001 ! Only for inactive constraint, we consider contributions to Hess.
1002 ! Note that if d(i) = 0, the y terms (just like z terms) will not
1003 ! contribute to the Hessian matrix.
1004 ! Note that since we use DGESV to solve LSE on CPU, we don't need
1005 ! a CUDA kernel for this part
1006
1007 call device_memcpy(lambda%x, lambda%x_d, this%m, device_to_host, &
1008 sync = .true.)
1009 call device_memcpy(mu%x, mu%x_d, this%m, device_to_host, &
1010 sync = .true.)
1011 call device_memcpy(y%x, y%x_d, this%m, device_to_host, &
1012 sync = .true.)
1013 do i = 1, this%m
1014 if (y%x(i) .gt. 0.0_rp) then
1015 if (abs(this%d%x(i)) < 1.0e-15_rp) then
1016 ! Hess(i, i) = Hess(i, i) - 1.0_rp/1.0e-8_rp
1017 else
1018 hess%x(i, i) = hess%x(i, i) - 1.0_rp/this%d%x(i)
1019 end if
1020 end if
1021 ! Based on eq(10), note the term (-\Omega \Lambda)
1022 hess%x(i, i) = hess%x(i, i) - mu%x(i) / lambda%x(i)
1023 end do
1024
1025 ! Improve the robustness by stabilizing the Hess using
1026 ! Levenberg-Marquardt algorithm (heuristically)
1027 hesstrace = 0.0_rp
1028 do i=1, this%m
1029 hesstrace = hesstrace + hess%x(i, i)
1030 end do
1031 do i=1, this%m
1032 hess%x(i,i) = hess%x(i, i) - &
1033 max(-1.0e-4_rp*hesstrace/this%m, 1.0e-7_rp)
1034 end do
1035
1036 call device_memcpy(gradlambda%x, gradlambda%x_d, this%m, device_to_host, &
1037 sync = .true.)
1038 call dgesv(this%m , 1, hess%x, this%m , ipiv, &
1039 gradlambda%x, this%m, info)
1040
1041 if (info .ne. 0) then
1042 call neko_error("DGESV failed to solve the linear system in " // &
1043 "mma_subsolve_dip (device).")
1044 end if
1045 call device_memcpy(gradlambda%x, gradlambda%x_d, this%m, host_to_device, &
1046 sync = .true.)
1047
1048 call device_copy(dlambda%x_d, gradlambda%x_d, this%m)
1049
1050 ! based on eq(11) for delta eta
1051 call device_copy(dummy_m%x_d, dlambda%x_d, this%m)
1052 call device_col2(dummy_m%x_d, mu%x_d, this%m)
1053 call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)
1054
1055 call device_cfill(dmu%x_d, epsi, this%m)
1056 call device_invcol2(dmu%x_d, lambda%x_d, this%m)
1057 call device_add2s2(dmu%x_d, dummy_m%x_d, -1.0_rp, this%m)
1058 call device_sub2(dmu%x_d, mu%x_d, this%m)
1059
1060 steg = maxval([1.005_rp, device_maxval2(dlambda%x_d, lambda%x_d, &
1061 -1.01_rp, this%m), device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, &
1062 this%m)])
1063 steg = 1.0_rp / steg
1064
1065 call device_add2s2(lambda%x_d, dlambda%x_d, steg, this%m)
1066 call device_add2s2(mu%x_d, dmu%x_d, steg, this%m)
1067
1068 call device_memcpy(lambda%x, lambda%x_d, this%m, device_to_host, &
1069 sync = .true.)
1070 call device_memcpy(mu%x, mu%x_d, this%m, device_to_host, &
1071 sync = .true.)
1072
1073 ! minimize(L_x, L_y, L_z) and compute x(λ), y(λ), z(λ) for
1074 ! the updated values of λ
1075
1076 ! Compute the value of y that minimizes L_y for the current λ
1077 ! minimize (sum_{i=1}^{m} [ (c_i - λ_i) * y_i + 0.5 * d_i * y_i^2 ])
1078 ! dL_y/dy =0 => y= (λ_i - c_i)/d_i, ensure y>=0
1079
1080 call device_sub3(y%x_d, lambda%x_d, c%x_d, this%m)
1081 ! divide by dd instead of this%d to avoid division by 0 (in case this%d is zero)
1082 call device_invcol2(y%x_d, dd%x_d, this%m)
1083 call device_pwmax2(y%x_d, zerom%x_d, this%m)
1084
1085 ! Compute the value of z that minimizes L_z for the current λ
1086 ! minimize ((a_0 - sum_{i=1}^{m} λ_i * a_i) * z)
1087 ! if (a_0-dot_product(lambda, a)>=0) z=0 else z= 1.0
1088 ! ensure z>=0
1089 call device_col3(dummy_m%x_d, lambda%x_d, a%x_d, this%m)
1090 z = device_glsum(dummy_m%x_d, this%m)
1091 z = merge(0.0_rp, 1.0_rp, a0 - z >= 0.0)
1092
1093 ! Compute the value of x that minimizes L_x for the current λ
1094 ! minimize( sum_{j=1}^{n} [ (p_{0j} + sum_{i=1}^{m} λ_i *
1095 ! p_{ij}) / (u_j - x_j) + (q_{0j} + sum_{i=1}^{m} λ_i * q_{ij}) /
1096 ! (x_j - l_j) ] - sum_{i=1}^{m} λ_i * b_i)
1097 call device_mattrans_v_mul(pjlambda%x_d, pij%x_d, lambda%x_d, this%m, this%n)
1098 call device_mattrans_v_mul(qjlambda%x_d, qij%x_d, lambda%x_d, this%m, this%n)
1099 call device_add2(pjlambda%x_d, p0j%x_d, this%n)
1100 call device_add2(qjlambda%x_d, q0j%x_d, this%n)
1101
1102 call device_mma_dipsolvesub1(x%x_d, pjlambda%x_d, qjlambda%x_d, &
1103 low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
1104
1105 ! Compute the residual for the lambda and mu using eq(9) and eq(15)
1106
1107 call device_cfill(relambda%x_d, 0.0_rp, this%m)
1108 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
1109 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
1110
1111 ! Global communication for relambda values
1112
1113 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
1114 sync = .true.)
1115 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
1116 mpi_real_precision, mpi_sum, neko_comm, ierr)
1117 call device_memcpy(relambda%x, relambda%x_d, this%m, &
1118 host_to_device, sync = .true.)
1119
1120 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
1121 call device_sub2(relambda%x_d, y%x_d, this%m)
1122 call device_add2(relambda%x_d, mu%x_d, this%m)
1123 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
1124
1125 call device_col3(remu%x_d, mu%x_d, lambda%x_d, this%m)
1126 call device_cadd(remu%x_d, -epsi, this%m)
1127
1128
1130
1131 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
1132 sync = .true.)
1133 call device_memcpy(remu%x, remu%x_d, this%m, device_to_host, &
1134 sync = .true.)
1135 residumax = maxval(abs([relambda%x, remu%x]))
1136 end do
1137 end associate
1138 epsi = 0.1_rp * epsi
1139 end do outer
1140
1141 ! Save the new designx
1142 call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
1143 call device_copy(this%xold1%x_d, designx_d, this%n)
1144 call device_copy(designx_d, x%x_d, this%n)
1145
1146 ! update the parameters of the MMA object necessary to compute KKT residual
1147 call device_copy(this%y%x_d, y%x_d, this%m)
1148 this%z = z
1149 call device_copy(this%lambda%x_d, lambda%x_d, this%m)
1150 call device_copy(this%mu%x_d, mu%x_d, this%m)
1151
1152 call y%free()
1153 call lambda%free()
1154 call mu%free()
1155 call relambda%free()
1156 call remu%free()
1157 call dlambda%free()
1158 call dmu%free()
1159 call gradlambda%free()
1160 call zerom%free()
1161 call dd%free()
1162 call dummy_m%free()
1163
1164 call x%free()
1165 call pjlambda%free()
1166 call qjlambda%free()
1167
1168 call ljjxinv%free()
1169 call hijx%free()
1170 call hess%free()
1171 end subroutine mma_subsolve_dip_device
1172
1173end submodule mma_device