Neko-TOP
A portable framework for high-order spectral element flow topology optimization.
Loading...
Searching...
No Matches
mma_device.f90
1! Copyright (c) 2025, The Neko-TOP Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32
33submodule(mma) mma_device
34
35 use device_math, only: device_copy, device_cmult, device_cadd, device_cfill, &
36 device_add2, device_add3s2, device_invcol2, device_col2, device_col3, &
37 device_sub2, device_sub3, device_add2s2, device_cadd2, device_pwmax2, &
38 device_glsum, device_cmult2
39 use device_mma_math, only: device_maxval, device_norm, device_lcsc2, &
40 device_maxval2, device_maxval3, device_mma_gensub3, &
41 device_mma_gensub4, device_mma_max, device_max2, device_rex, &
42 device_relambda, device_delx, device_add2inv2, device_gg, device_diagx, &
43 device_bb, device_updatebb, device_aa, device_updateaa, device_dx, &
44 device_dy, device_dxsi, device_deta, device_kkt_rex, &
45 device_mma_gensub2, device_mattrans_v_mul, device_mma_dipsolvesub1, &
46 device_mma_ljjxinv, device_hess, device_solve_linear_system, &
47 device_prepare_hessian, device_prepare_aa_matrix
48
49 use neko_config, only: neko_bcknd_device, neko_device_mpi
50 use device, only: device_to_host
51 use comm, only: neko_comm, pe_rank, mpi_real_precision
52 use mpi_f08, only: mpi_in_place, mpi_max, mpi_min
53 use profiler, only: profiler_start_region, profiler_end_region
54 use scratch_registry, only: neko_scratch_registry
55
56 implicit none
57
58contains
59
!> Perform one MMA update of the design variables on the device.
!!
!! First generates the convex approximation of the problem around the
!! current design (`mma_gensub_device`), then solves that approximation
!! with the interior point subsolver selected by `this%subsolver`
!! ("dip" or "dpip"). Meant to be called once per iteration of the
!! optimization loop. On return `this%is_updated` is set to `.true.`.
!!
!! @param this  MMA object; an error is raised if it is not initialized.
!! @param iter  Iteration counter, forwarded to the sub-problem generator.
!! @param x     Device pointer to the design variables, forwarded to the
!!              sub-problem generator and to the subsolver.
!! @param df0dx Device pointer forwarded to the sub-problem generator
!!              (by MMA convention the objective sensitivities).
!! @param fval  Device pointer forwarded to the sub-problem generator
!!              (by MMA convention the constraint values).
!! @param dfdx  Device pointer forwarded to the sub-problem generator
!!              (by MMA convention the constraint sensitivities).
module subroutine mma_update_device(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(inout) :: x
  type(c_ptr), intent(in) :: df0dx, fval, dfdx

  if (.not. this%is_initialized) then
     call neko_error("The MMA object is not initialized.")
  end if

  ! Generate a convex approximation of the problem
  call profiler_start_region("MMA gensub")
  call mma_gensub_device(this, iter, x, df0dx, fval, dfdx)
  call profiler_end_region("MMA gensub")

  ! Solve the approximation problem using an interior point method
  call profiler_start_region("MMA subsolve")
  select case (this%subsolver)
  case ("dip")
     call mma_subsolve_dip_device(this, x)
  case ("dpip")
     call mma_subsolve_dpip_device(this, x)
  case default
     call neko_error("Unrecognized subsolver for MMA in mma_device.")
  end select
  call profiler_end_region("MMA subsolve")

  this%is_updated = .true.
end subroutine mma_update_device
95
!> Compute the KKT residuals on the device for the configured subsolver.
!!
!! Dispatches on `this%subsolver` to the matching residual routine, which
!! stores its result in `this%residumax` and `this%residunorm`.
!! An unrecognised subsolver is reported as an error, consistent with
!! `mma_update_device`, instead of silently being treated as "dpip".
!!
!! @param this  MMA object holding the subsolver state.
!! @param x     Device pointer to the design variables.
!! @param df0dx Device pointer forwarded to the residual routine.
!! @param fval  Device pointer forwarded to the residual routine.
!! @param dfdx  Device pointer forwarded to the residual routine.
module subroutine mma_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  select case (this%subsolver)
  case ("dip")
     call mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
  case ("dpip")
     call mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
  case default
     call neko_error("Unrecognized subsolver for MMA in mma_device.")
  end select
end subroutine mma_kkt_device
106
!> Compute the KKT residuals associated with the dual interior
!! point method (dip) subsolve of MMA algorithm.
!!
!! Only quantities of size `this%m` are involved: the residual of the
!! dual variables lambda and the complementarity residual lambda * mu.
!! The results are stored in `this%residumax` and `this%residunorm`.
!!
!! @param this  MMA object; reads a, z, y, mu and lambda.
!! @param x     Not referenced by this routine.
!! @param df0dx Not referenced by this routine.
!! @param fval  Device pointer entering the lambda residual.
!! @param dfdx  Not referenced by this routine.
module subroutine mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  type(vector_t), pointer :: relambda, remu
  integer :: ind(2)

  ! Temporary work vectors of size m
  call neko_scratch_registry%request(relambda, ind(1), this%m, .false.)
  call neko_scratch_registry%request(remu, ind(2), this%m, .false.)

  ! relambda = fval - this%a%x * this%z - this%y%x + this%mu%x
  call device_add3s2(relambda%x_d, fval, this%a%x_d, 1.0_rp, -this%z, &
       this%m)
  call device_sub2(relambda%x_d, this%y%x_d, this%m)
  call device_add2(relambda%x_d, this%mu%x_d, this%m)

  ! Compute residual for mu (eta in the paper): remu = lambda * mu
  call device_col3(remu%x_d, this%lambda%x_d, this%mu%x_d, this%m)

  ! NOTE(review): device_maxval / device_norm are presumably the largest
  ! magnitude and the squared 2-norm of a device vector -- confirm.
  this%residumax = maxval([device_maxval(relambda%x_d, this%m), &
       device_maxval(remu%x_d, this%m)])
  this%residunorm = sqrt(device_norm(relambda%x_d, this%m) + &
       device_norm(remu%x_d, this%m))

  call neko_scratch_registry%relinquish(ind)
end subroutine mma_dip_kkt_device
135
!> Compute the KKT residuals associated with the dual-primal interior
!! point method (dpip) subsolve of MMA algorithm.
!!
!! Residuals of size `this%n` (rex, rexsi, reeta) are combined across
!! ranks with MPI reductions; residuals of size `this%m` and the scalar
!! residuals are added without a reduction. The results are stored in
!! `this%residumax` and `this%residunorm`.
!!
!! @param this  MMA object holding the primal and dual variables.
!! @param x     Device pointer to the design variables.
!! @param df0dx Device pointer passed to `device_kkt_rex`.
!! @param fval  Device pointer entering the lambda residual.
!! @param dfdx  Device pointer passed to `device_kkt_rex`.
module subroutine mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  real(kind=rp) :: rez, rezeta
  type(vector_t), pointer :: rey, relambda, remu, res
  type(vector_t), pointer :: rex, rexsi, reeta
  integer :: ierr, ind(7)
  real(kind=rp) :: re_sq_norm

  ! Temporary work vectors of size m
  call neko_scratch_registry%request(rey, ind(1), this%m, .false.)
  call neko_scratch_registry%request(relambda, ind(2), this%m, .false.)
  call neko_scratch_registry%request(remu, ind(3), this%m, .false.)
  call neko_scratch_registry%request(res, ind(4), this%m, .false.)

  ! Temporary work vectors of size n
  call neko_scratch_registry%request(rex, ind(5), this%n, .false.)
  call neko_scratch_registry%request(rexsi, ind(6), this%n, .false.)
  call neko_scratch_registry%request(reeta, ind(7), this%n, .false.)

  ! Residual for x
  call device_kkt_rex(rex%x_d, df0dx, dfdx, this%xsi%x_d, &
       this%eta%x_d, this%lambda%x_d, this%n, this%m)

  ! rey = c + d * y - lambda - mu
  call device_col3(rey%x_d, this%d%x_d, this%y%x_d, this%m)
  call device_add2(rey%x_d, this%c%x_d, this%m)
  call device_sub2(rey%x_d, this%lambda%x_d, this%m)
  call device_sub2(rey%x_d, this%mu%x_d, this%m)

  rez = this%a0 - this%zeta - device_lcsc2(this%lambda%x_d, this%a%x_d, &
       this%m)

  ! relambda = fval - a * z - y + s
  call device_add3s2(relambda%x_d, fval, this%a%x_d, 1.0_rp, -this%z, &
       this%m)
  call device_sub2(relambda%x_d, this%y%x_d, this%m)
  call device_add2(relambda%x_d, this%s%x_d, this%m)

  ! rexsi = xsi * (x - xmin)
  call device_sub3(rexsi%x_d, x, this%xmin%x_d, this%n)
  call device_col2(rexsi%x_d, this%xsi%x_d, this%n)

  ! reeta = eta * (xmax - x)
  call device_sub3(reeta%x_d, this%xmax%x_d, x, this%n)
  call device_col2(reeta%x_d, this%eta%x_d, this%n)

  ! remu = mu * y
  call device_col3(remu%x_d, this%mu%x_d, this%y%x_d, this%m)

  rezeta = this%zeta * this%z

  ! res = lambda * s
  call device_col3(res%x_d, this%lambda%x_d, this%s%x_d, this%m)

  this%residumax = maxval([ &
       device_maxval(rex%x_d, this%n), &
       device_maxval(rey%x_d, this%m), &
       abs(rez), &
       device_maxval(relambda%x_d, this%m), &
       device_maxval(rexsi%x_d, this%n), &
       device_maxval(reeta%x_d, this%n), &
       device_maxval(remu%x_d, this%m), &
       abs(rezeta), &
       device_maxval(res%x_d, this%m)])

  ! Contribution of the size-n residuals, summed over ranks below
  re_sq_norm = device_norm(rex%x_d, this%n) + &
       device_norm(rexsi%x_d, this%n) + &
       device_norm(reeta%x_d, this%n)

  call mpi_allreduce(mpi_in_place, this%residumax, 1, &
       mpi_real_precision, mpi_max, neko_comm, ierr)

  call mpi_allreduce(mpi_in_place, re_sq_norm, 1, &
       mpi_real_precision, mpi_sum, neko_comm, ierr)

  ! Size-m and scalar contributions are added without a reduction
  this%residunorm = sqrt(( &
       device_norm(rey%x_d, this%m) + &
       rez**2 + &
       device_norm(relambda%x_d, this%m) + &
       device_norm(remu%x_d, this%m) + &
       rezeta**2 + &
       device_norm(res%x_d, this%m) &
       ) + re_sq_norm)

  call neko_scratch_registry%relinquish(ind)
end subroutine mma_dpip_kkt_device
217
218 !============================================================================!
219 ! private internal subroutines
220
!> Generate the approximation sub-problem on the device.
!!
!! Computes the lower and upper asymptotes and the other necessary
!! parameters (alpha, beta, p0j, q0j, pij, qij, bi) and stores them in
!! the MMA object.
!!
!! @param this  MMA object receiving low, upp, alpha, beta, p0j, q0j,
!!              pij, qij and bi.
!! @param iter  Iteration counter; for `iter < 3` the asymptotes are
!!              initialised from `asyinit`, afterwards they are adapted
!!              using the two previous designs.
!! @param x     Device pointer to the current design variables.
!! @param df0dx Device pointer passed to `device_mma_gensub3`.
!! @param fval  Device pointer subtracted from the reduced `bi`.
!! @param dfdx  Device pointer passed to `device_mma_gensub3`.
subroutine mma_gensub_device(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(in) :: x
  type(c_ptr), intent(in) :: df0dx
  type(c_ptr), intent(in) :: fval
  type(c_ptr), intent(in) :: dfdx

  integer :: ierr
  type(vector_t), pointer :: x_diff
  integer :: ind

  call neko_scratch_registry%request(x_diff, ind, this%n, .false.)

  ! x_diff = xmax - xmin
  call device_sub3(x_diff%x_d, this%xmax%x_d, this%xmin%x_d, this%n)

  ! ------------------------------------------------------------------------ !
  ! Setup the current asymptotes

  if (iter .lt. 3) then
     ! low = x - asyinit * x_diff, upp = x + asyinit * x_diff
     call device_copy(this%low%x_d, x, this%n)
     call device_add2s2(this%low%x_d, x_diff%x_d, - this%asyinit, this%n)
     call device_copy(this%upp%x_d, x, this%n)
     call device_add2s2(this%upp%x_d, x_diff%x_d, this%asyinit, this%n)
  else
     call device_mma_gensub2(this%low%x_d, this%upp%x_d, x, &
          this%xold1%x_d, this%xold2%x_d, x_diff%x_d, &
          this%asydecr, this%asyincr, this%n)
  end if

  ! ------------------------------------------------------------------------ !
  ! Calculate p0j, q0j, pij, qij, alpha, and beta

  call device_mma_gensub3(x, df0dx, dfdx, this%low%x_d, &
       this%upp%x_d, this%xmin%x_d, this%xmax%x_d, this%alpha%x_d, &
       this%beta%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
       this%qij%x_d, this%n, this%m)

  ! ------------------------------------------------------------------------ !
  ! Computing bi as defined in page 5

  call device_mma_gensub4(x, this%low%x_d, this%upp%x_d, this%pij%x_d, &
       this%qij%x_d, this%n, this%m, this%bi%x_d)

  ! Sum the local contributions to bi over all ranks
  if (neko_device_mpi) then
     ! NOTE(review): this branch hands the type(c_ptr) device handle to
     ! mpi_allreduce as the buffer; confirm that the MPI binding in use
     ! reduces the device memory it refers to and not the handle variable.
     call mpi_allreduce(mpi_in_place, this%bi%x_d, this%m, &
          mpi_real_precision, mpi_sum, neko_comm, ierr)
  else
     ! Stage the reduction through the host copy of bi
     call device_memcpy(this%bi%x, this%bi%x_d, this%m, device_to_host, &
          sync = .true.)
     call mpi_allreduce(mpi_in_place, this%bi%x, this%m, &
          mpi_real_precision, mpi_sum, neko_comm, ierr)
     call device_memcpy(this%bi%x, this%bi%x_d, this%m, host_to_device, &
          sync = .true.)
  end if

  ! bi = bi - fval
  call device_sub2(this%bi%x_d, fval, this%m)

  call neko_scratch_registry%relinquish(ind)
end subroutine mma_gensub_device
287
!> Solve the MMA sub-problem on the device with the dual-primal interior
!! point (dpip) method.
!!
!! The barrier parameter `epsi` starts at 1 and is divided by 10 until it
!! drops to `minimal_epsilon`. For every value of `epsi`, at most
!! `this%max_iter` Newton iterations with a backtracking line search are
!! performed. On exit the new design is copied into `designx_d`, the two
!! previous designs are shifted into `this%xold1` / `this%xold2`, and the
!! primal/dual variables needed for the KKT residual are stored in `this`.
!!
!! @param this      MMA object holding the sub-problem data (low, upp,
!!                  alpha, beta, p0j, q0j, pij, qij, bi, a0, a, c, d).
!! @param designx_d Device pointer to the design variables; the memory it
!!                  refers to is overwritten with the new design.
subroutine mma_subsolve_dpip_device(this, designx_d)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: designx_d
  integer :: iter, itto, ierr
  real(kind=rp) :: epsi, residual_max, residual_norm, z, zeta, rez, rezeta, &
       delz, dz, dzeta, steg, zold, zetaold, new_residual
  ! vectors with size m
  type(vector_t), pointer :: y, lambda, s, mu, rey, relambda, remu, res, &
       dely, dellambda, dy, dlambda, ds, dmu, yold, lambdaold, sold, muold

  ! vectors with size n
  type(vector_t), pointer :: x, xsi, eta, rex, rexsi, reeta, &
       delx, diagx, dx, dxsi, deta, xold, xsiold, etaold

  ! right hand side (size m+1) and matrices of the reduced linear system
  type(vector_t), pointer :: bb
  type(matrix_t), pointer :: GG
  type(matrix_t), pointer :: AA

  integer :: info
  real(kind=rp) :: re_sq_norm

  integer :: ind(35)

  real(kind=rp) :: minimal_epsilon

  call neko_scratch_registry%request(y, ind(1), this%m, .false.)
  call neko_scratch_registry%request(lambda, ind(2), this%m, .false.)
  call neko_scratch_registry%request(s, ind(3), this%m, .false.)
  call neko_scratch_registry%request(mu, ind(4), this%m, .false.)
  call neko_scratch_registry%request(rey, ind(5), this%m, .false.)
  call neko_scratch_registry%request(relambda, ind(6), this%m, .false.)
  call neko_scratch_registry%request(remu, ind(7), this%m, .false.)
  call neko_scratch_registry%request(res, ind(8), this%m, .false.)
  call neko_scratch_registry%request(dely, ind(9), this%m, .false.)
  call neko_scratch_registry%request(dellambda, ind(10), this%m, .false.)
  call neko_scratch_registry%request(dy, ind(11), this%m, .false.)
  call neko_scratch_registry%request(dlambda, ind(12), this%m, .false.)
  call neko_scratch_registry%request(ds, ind(13), this%m, .false.)
  call neko_scratch_registry%request(dmu, ind(14), this%m, .false.)
  call neko_scratch_registry%request(yold, ind(15), this%m, .false.)
  call neko_scratch_registry%request(lambdaold, ind(16), this%m, .false.)
  call neko_scratch_registry%request(sold, ind(17), this%m, .false.)
  call neko_scratch_registry%request(muold, ind(18), this%m, .false.)
  call neko_scratch_registry%request(x, ind(19), this%n, .false.)
  call neko_scratch_registry%request(xsi, ind(20), this%n, .false.)
  call neko_scratch_registry%request(eta, ind(21), this%n, .false.)
  call neko_scratch_registry%request(rex, ind(22), this%n, .false.)
  call neko_scratch_registry%request(rexsi, ind(23), this%n, .false.)
  call neko_scratch_registry%request(reeta, ind(24), this%n, .false.)
  call neko_scratch_registry%request(delx, ind(25), this%n, .false.)
  call neko_scratch_registry%request(diagx, ind(26), this%n, .false.)
  call neko_scratch_registry%request(dx, ind(27), this%n, .false.)
  call neko_scratch_registry%request(dxsi, ind(28), this%n, .false.)
  call neko_scratch_registry%request(deta, ind(29), this%n, .false.)
  call neko_scratch_registry%request(xold, ind(30), this%n, .false.)
  call neko_scratch_registry%request(xsiold, ind(31), this%n, .false.)
  call neko_scratch_registry%request(etaold, ind(32), this%n, .false.)
  call neko_scratch_registry%request(bb, ind(33), this%m+1, .false.)

  call neko_scratch_registry%request(gg, ind(34), this%m, this%n, .false.)
  call neko_scratch_registry%request(aa, ind(35), this%m+1, this%m+1, .false.)

  ! ------------------------------------------------------------------------ !
  ! initial value for the parameters in the subsolve based on
  ! page 15 of "https://people.kth.se/~krille/mmagcmma.pdf"

  epsi = 1.0_rp !100
  ! x = 0.5 * (alpha + beta)
  call device_add3s2(x%x_d, this%alpha%x_d, this%beta%x_d, 0.5_rp, 0.5_rp, &
       this%n)
  call device_cfill(y%x_d, 1.0_rp, this%m)
  z = 1.0_rp
  zeta = 1.0_rp
  call device_cfill(lambda%x_d, 1.0_rp, this%m)
  call device_cfill(s%x_d, 1.0_rp, this%m)
  call device_mma_max(xsi%x_d, x%x_d, this%alpha%x_d, this%n)
  call device_mma_max(eta%x_d, this%beta%x_d, x%x_d, this%n)
  call device_max2(mu%x_d, 1.0_rp, this%c%x_d, 0.5_rp, this%m)

  ! ------------------------------------------------------------------------ !
  ! Computing the minimal epsilon and choose the most conservative one

  minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
  call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
       mpi_real_precision, mpi_min, neko_comm, ierr)

  ! ------------------------------------------------------------------------ !
  ! The main loop of the dual-primal interior point method.

  do while (epsi .gt. minimal_epsilon)

     ! --------------------------------------------------------------------- !
     ! Calculating residuals based on
     ! "https://people.kth.se/~krille/mmagcmma.pdf" for the variables
     ! x, y, z, lambda residuals based on eq(5.9a)-(5.9d), respectively.

     associate(p0j => this%p0j, q0j => this%q0j, &
          pij => this%pij, qij => this%qij, &
          low => this%low, upp => this%upp, &
          alpha => this%alpha, beta => this%beta, &
          c => this%c, d => this%d, &
          a0 => this%a0, a => this%a)

       call device_rex(rex%x_d, x%x_d, low%x_d, upp%x_d, &
            pij%x_d, p0j%x_d, qij%x_d, q0j%x_d, &
            lambda%x_d, xsi%x_d, eta%x_d, this%n, this%m)

       ! rey = c + d * y - lambda - mu
       call device_col3(rey%x_d, d%x_d, y%x_d, this%m)
       call device_add2(rey%x_d, c%x_d, this%m)
       call device_sub2(rey%x_d, lambda%x_d, this%m)
       call device_sub2(rey%x_d, mu%x_d, this%m)
       rez = a0 - zeta - device_lcsc2(lambda%x_d, a%x_d, this%m)

       ! Accumulate the local sums for relambda (the term gi(x))
       call device_cfill(relambda%x_d, 0.0_rp, this%m)
       call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
            low%x_d, pij%x_d, qij%x_d, this%n, this%m)

     end associate

     ! --------------------------------------------------------------------- !
     ! Computing the norm of the residuals

     ! Complete the computations of lambda residuals
     if (neko_device_mpi) then
        ! NOTE(review): this branch hands the type(c_ptr) device handle to
        ! mpi_allreduce as the buffer; confirm that the MPI binding in use
        ! reduces the device memory it refers to and not the handle
        ! variable.
        call mpi_allreduce(mpi_in_place, relambda%x_d, this%m, &
             mpi_real_precision, mpi_sum, neko_comm, ierr)
     else
        ! Stage the reduction through the host copy
        call device_memcpy(relambda%x, relambda%x_d, this%m, &
             device_to_host, sync = .true.)
        call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
             mpi_real_precision, mpi_sum, neko_comm, ierr)
        call device_memcpy(relambda%x, relambda%x_d, this%m, &
             host_to_device, sync = .true.)
     end if

     ! relambda = relambda - a * z - y + s - bi
     call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
     call device_sub2(relambda%x_d, y%x_d, this%m)
     call device_add2(relambda%x_d, s%x_d, this%m)
     call device_sub2(relambda%x_d, this%bi%x_d, this%m)

     ! rexsi = xsi * (x - alpha) - epsi
     call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
     call device_col2(rexsi%x_d, xsi%x_d, this%n)
     call device_cadd(rexsi%x_d, - epsi, this%n)

     ! reeta = eta * (beta - x) - epsi
     call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
     call device_col2(reeta%x_d, eta%x_d, this%n)
     call device_cadd(reeta%x_d, - epsi, this%n)

     ! remu = mu * y - epsi
     call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
     call device_cadd(remu%x_d, - epsi, this%m)

     rezeta = zeta * z - epsi

     ! res = lambda * s - epsi
     call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
     call device_cadd(res%x_d, - epsi, this%m)

     ! Setup vectors of residuals and their norms
     residual_max = maxval([device_maxval(rex%x_d, this%n), &
          device_maxval(rey%x_d, this%m), abs(rez), &
          device_maxval(relambda%x_d, this%m), &
          device_maxval(rexsi%x_d, this%n), &
          device_maxval(reeta%x_d, this%n), &
          device_maxval(remu%x_d, this%m), abs(rezeta), &
          device_maxval(res%x_d, this%m)])

     ! Contribution of the size-n residuals, summed over ranks below
     re_sq_norm = device_norm(rex%x_d, this%n) + &
          device_norm(rexsi%x_d, this%n) + device_norm(reeta%x_d, this%n)

     call mpi_allreduce(mpi_in_place, residual_max, 1, &
          mpi_real_precision, mpi_max, neko_comm, ierr)

     call mpi_allreduce(mpi_in_place, re_sq_norm, &
          1, mpi_real_precision, mpi_sum, neko_comm, ierr)

     residual_norm = sqrt(device_norm(rey%x_d, this%m) + &
          rez**2 + &
          device_norm(relambda%x_d, this%m) + &
          device_norm(remu%x_d, this%m)+ &
          rezeta**2 + &
          device_norm(res%x_d, this%m) &
          + re_sq_norm)

     ! --------------------------------------------------------------------- !
     ! Internal loop

     do iter = 1, this%max_iter

        ! Converged for the current barrier parameter
        if (residual_max .lt. epsi) exit

        call device_delx(delx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
             this%pij%x_d, this%qij%x_d, this%p0j%x_d, this%q0j%x_d, &
             this%alpha%x_d, this%beta%x_d, lambda%x_d, epsi, this%n, &
             this%m)

        call device_col3(dely%x_d, this%d%x_d, y%x_d, this%m)
        call device_add2(dely%x_d, this%c%x_d, this%m)
        call device_sub2(dely%x_d, lambda%x_d, this%m)
        call device_add2inv2(dely%x_d, y%x_d, - epsi, this%m)
        delz = this%a0 - device_lcsc2(lambda%x_d, this%a%x_d, this%m) - &
             epsi/z

        ! Accumulate sums for dellambda (the term gi(x))
        call device_cfill(dellambda%x_d, 0.0_rp, this%m)
        call device_relambda(dellambda%x_d, x%x_d, this%upp%x_d, &
             this%low%x_d, this%pij%x_d, this%qij%x_d, this%n, this%m)

        ! This reduction is always staged through the host copy
        call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
             device_to_host, sync = .true.)
        call mpi_allreduce(mpi_in_place, dellambda%x, this%m, &
             mpi_real_precision, mpi_sum, neko_comm, ierr)
        call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
             host_to_device, sync = .true.)

        call device_add3s2(dellambda%x_d, dellambda%x_d, this%a%x_d, &
             1.0_rp, -z, this%m)
        call device_sub2(dellambda%x_d, y%x_d, this%m)
        call device_sub2(dellambda%x_d, this%bi%x_d, this%m)
        call device_add2inv2(dellambda%x_d, lambda%x_d, epsi, this%m)

        call device_gg(gg%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
             this%pij%x_d, this%qij%x_d, this%n, this%m)

        call device_diagx(diagx%x_d, x%x_d, xsi%x_d, this%low%x_d, &
             this%upp%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
             this%qij%x_d, this%alpha%x_d, this%beta%x_d, eta%x_d, &
             lambda%x_d, this%n, this%m)

        !Here we only consider the case m<n in the matlab code
        !assembling the right hand side matrix based on eq(5.20)
        ! bb = [dellambda + dely/(this%d%x + &
        ! (mu/y)) - matmul(GG,delx/diagx), delz ]

        !--------------------------------------------------------------------!
        ! for MPI computation of bb

        call device_bb(bb%x_d, gg%x_d, delx%x_d, diagx%x_d, this%n, &
             this%m)

        call device_memcpy(bb%x, bb%x_d, this%m + 1, device_to_host, &
             sync = .true.)
        call mpi_allreduce(mpi_in_place, bb%x, this%m + 1, &
             mpi_real_precision, mpi_sum, neko_comm, ierr)
        call device_memcpy(bb%x, bb%x_d, this%m + 1, &
             host_to_device, sync = .true.)

        call device_updatebb(bb%x_d, dellambda%x_d, dely%x_d, &
             this%d%x_d, mu%x_d, y%x_d, delz, this%m)

        ! assembling the coefficients matrix AA based on eq(5.20)
        ! AA(1:this%m,1:this%m) = &
        ! matmul(matmul(GG,mma_diag(1/diagx)), transpose(GG))
        ! !update diag(AA)
        ! AA(1:this%m,1:this%m) = AA(1:this%m,1:this%m) + &
        ! mma_diag(s/lambda + 1.0/(this%d%x + (mu/y)))

        call device_cfill(aa%x_d, 0.0_rp, (this%m+1) * (this%m+1))
        call device_aa(aa%x_d, gg%x_d, diagx%x_d, this%n, this%m)

        call device_memcpy(aa%x, aa%x_d, (this%m+1) * (this%m+1), &
             device_to_host, sync = .true.)
        call mpi_allreduce(mpi_in_place, aa%x, &
             (this%m + 1)**2, mpi_real_precision, mpi_sum, neko_comm, ierr)
        call device_memcpy(aa%x, aa%x_d, (this%m+1) * (this%m+1), &
             host_to_device, sync = .true.)

        call device_prepare_aa_matrix(aa%x_d, s%x_d, lambda%x_d, &
             this%d%x_d, mu%x_d, y%x_d, this%a%x_d, zeta, z, this%m)

        ! Device solve for the linear system
        call device_solve_linear_system(aa%x_d, bb%x_d, this%m + 1, info)
        if (info .ne. 0) then
           call neko_error("Linear solver failed on the device in " // &
                "mma_subsolve_dpip")
        end if

        ! The first m entries of the solution are dlambda
        call device_copy(dlambda%x_d, bb%x_d, this%m)

        !We need to write the last element of bb to dz so this is necessary
        call device_memcpy(bb%x, bb%x_d, this%m+1, device_to_host, &
             sync = .true.)
        dz = bb%x(this%m + 1)

        ! based on eq(5.19)
        call device_dx(dx%x_d, delx%x_d, diagx%x_d, gg%x_d, &
             dlambda%x_d, this%n, this%m)
        call device_dy(dy%x_d, dely%x_d, dlambda%x_d, this%d%x_d, &
             mu%x_d, y%x_d, this%m)
        call device_dxsi(dxsi%x_d, xsi%x_d, dx%x_d, x%x_d, &
             this%alpha%x_d, epsi, this%n)
        call device_deta(deta%x_d, eta%x_d, dx%x_d, x%x_d, &
             this%beta%x_d, epsi, this%n)

        ! dmu = (epsi - mu * dy) / y - mu
        call device_col3(dmu%x_d, mu%x_d, dy%x_d, this%m)
        call device_cmult(dmu%x_d, -1.0_rp, this%m)
        call device_cadd(dmu%x_d, epsi, this%m)
        call device_invcol2(dmu%x_d, y%x_d, this%m)
        call device_sub2(dmu%x_d, mu%x_d, this%m)
        dzeta = -zeta + (epsi - zeta * dz) / z
        ! ds = (epsi - dlambda * s) / lambda - s
        call device_col3(ds%x_d, dlambda%x_d, s%x_d, this%m)
        call device_cmult(ds%x_d, -1.0_rp, this%m)
        call device_cadd(ds%x_d, epsi, this%m)
        call device_invcol2(ds%x_d, lambda%x_d, this%m)
        call device_sub2(ds%x_d, s%x_d, this%m)

        ! Inverse of the initial step length (steg) for the line search
        steg = maxval([1.0_rp, &
             device_maxval2(dy%x_d, y%x_d, -1.01_rp, this%m), &
             -1.01_rp * dz / z, &
             device_maxval2(dlambda%x_d, lambda%x_d, -1.01_rp, this%m), &
             device_maxval2(dxsi%x_d, xsi%x_d, -1.01_rp, this%n), &
             device_maxval2(deta%x_d, eta%x_d, -1.01_rp, this%n), &
             device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, this%m), &
             -1.01_rp * dzeta / zeta, &
             device_maxval2(ds%x_d, s%x_d, -1.01_rp, this%m), &
             device_maxval3(dx%x_d, x%x_d, this%alpha%x_d, -1.01_rp, &
             this%n), &
             device_maxval3(dx%x_d, this%beta%x_d, x%x_d, 1.01_rp, this%n)])

        steg = 1.0_rp / steg

        ! Keep the current iterate so every trial step starts from it
        call device_copy(xold%x_d, x%x_d, this%n)
        call device_copy(yold%x_d, y%x_d, this%m)
        zold = z
        call device_copy(lambdaold%x_d, lambda%x_d, this%m)
        call device_copy(xsiold%x_d, xsi%x_d, this%n)
        call device_copy(etaold%x_d, eta%x_d, this%n)
        call device_copy(muold%x_d, mu%x_d, this%m)
        zetaold = zeta
        call device_copy(sold%x_d, s%x_d, this%m)

        ! Guarantees that the line search below runs at least once
        new_residual = 2.0_rp * residual_norm

        ! Share the new_residual and steg values
        call mpi_allreduce(mpi_in_place, steg, 1, &
             mpi_real_precision, mpi_min, neko_comm, ierr)
        call mpi_allreduce(mpi_in_place, new_residual, 1, &
             mpi_real_precision, mpi_min, neko_comm, ierr)

        ! The innermost loop to determine the suitable step length
        ! using the Backtracking Line Search approach
        itto = 0
        do while ((new_residual .gt. residual_norm) .and. (itto .lt. 50))
           itto = itto + 1

           ! update the variables
           call device_add3s2(x%x_d, xold%x_d, dx%x_d, 1.0_rp, steg, &
                this%n)
           call device_add3s2(y%x_d, yold%x_d, dy%x_d, 1.0_rp, steg, &
                this%m)
           z = zold + steg*dz
           call device_add3s2(lambda%x_d, lambdaold%x_d, &
                dlambda%x_d, 1.0_rp, steg, this%m)
           call device_add3s2(xsi%x_d, xsiold%x_d, dxsi%x_d, &
                1.0_rp, steg, this%n)
           call device_add3s2(eta%x_d, etaold%x_d, deta%x_d, &
                1.0_rp, steg, this%n)
           call device_add3s2(mu%x_d, muold%x_d, dmu%x_d, &
                1.0_rp, steg, this%m)
           zeta = zetaold + steg*dzeta
           call device_add3s2(s%x_d, sold%x_d, ds%x_d, 1.0_rp, &
                steg, this%m)

           ! Recompute the new_residual to see if this stepsize improves
           ! the residue
           call device_rex(rex%x_d, x%x_d, this%low%x_d, &
                this%upp%x_d, this%pij%x_d, this%p0j%x_d, &
                this%qij%x_d, this%q0j%x_d, lambda%x_d, xsi%x_d, &
                eta%x_d, this%n, this%m)

           call device_col3(rey%x_d, this%d%x_d, y%x_d, this%m)
           call device_add2(rey%x_d, this%c%x_d, this%m)
           call device_sub2(rey%x_d, lambda%x_d, this%m)
           call device_sub2(rey%x_d, mu%x_d, this%m)

           rez = this%a0 - zeta - &
                device_lcsc2(lambda%x_d, this%a%x_d, this%m)

           ! Accumulate sums for relambda (the term gi(x))
           call device_cfill(relambda%x_d, 0.0_rp, this%m)
           call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
                this%low%x_d, this%pij%x_d, this%qij%x_d, &
                this%n, this%m)

           call device_memcpy(relambda%x, relambda%x_d, this%m, &
                device_to_host, sync = .true.)
           call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
                mpi_real_precision, mpi_sum, neko_comm, ierr)
           call device_memcpy(relambda%x, relambda%x_d, &
                this%m, host_to_device, sync = .true.)

           call device_add3s2(relambda%x_d, relambda%x_d, &
                this%a%x_d, 1.0_rp, -z, this%m)
           call device_sub2(relambda%x_d, y%x_d, this%m)
           call device_add2(relambda%x_d, s%x_d, this%m)
           call device_sub2(relambda%x_d, this%bi%x_d, this%m)

           call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
           call device_col2(rexsi%x_d, xsi%x_d, this%n)
           call device_cadd(rexsi%x_d, - epsi, this%n)

           call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
           call device_col2(reeta%x_d, eta%x_d, this%n)
           call device_cadd(reeta%x_d, - epsi, this%n)

           call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
           call device_cadd(remu%x_d, - epsi, this%m)

           rezeta = zeta*z - epsi

           call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
           call device_cadd(res%x_d, - epsi, this%m)

           ! Compute squared norms for the residuals
           re_sq_norm = device_norm(rex%x_d, this%n) + &
                device_norm(rexsi%x_d, this%n) + &
                device_norm(reeta%x_d, this%n)
           call mpi_allreduce(mpi_in_place, re_sq_norm, 1, &
                mpi_real_precision, mpi_sum, neko_comm, ierr)

           new_residual = sqrt(device_norm(rey%x_d, this%m) + &
                rez**2 + &
                device_norm(relambda%x_d, this%m) + &
                device_norm(remu%x_d, this%m) + &
                rezeta**2 + &
                device_norm(res%x_d, this%m) + &
                re_sq_norm)

           ! Halve the step for the next trial
           steg = steg / 2.0_rp

        end do
        steg = 2.0_rp * steg ! Correction for the final division by 2

        ! Update the maximum and norm of the residuals
        residual_norm = new_residual
        residual_max = maxval([ &
             device_maxval(rex%x_d, this%n), &
             device_maxval(rey%x_d, this%m), &
             abs(rez), &
             device_maxval(relambda%x_d, this%m), &
             device_maxval(rexsi%x_d, this%n), &
             device_maxval(reeta%x_d, this%n), &
             device_maxval(remu%x_d, this%m), &
             abs(rezeta), &
             device_maxval(res%x_d, this%m)])

        call mpi_allreduce(mpi_in_place, residual_max, 1, &
             mpi_real_precision, mpi_max, neko_comm, ierr)

     end do

     ! Tighten the barrier parameter
     epsi = 0.1_rp * epsi
  end do

  ! Save the new designx
  call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
  call device_copy(this%xold1%x_d, designx_d, this%n)
  call device_copy(designx_d, x%x_d, this%n)

  ! update the parameters of the MMA object necessary to compute KKT residual
  call device_copy(this%y%x_d, y%x_d, this%m)
  this%z = z
  call device_copy(this%lambda%x_d, lambda%x_d, this%m)
  this%zeta = zeta
  call device_copy(this%xsi%x_d, xsi%x_d, this%n)
  call device_copy(this%eta%x_d, eta%x_d, this%n)
  call device_copy(this%mu%x_d, mu%x_d, this%m)
  call device_copy(this%s%x_d, s%x_d, this%m)

  ! Release all the scratch variables requested in this subroutine
  call neko_scratch_registry%relinquish(ind)
end subroutine mma_subsolve_dpip_device
756
!> Solve the MMA subproblem on the device with a dual interior point method.
!!
!! The dual variables lambda (and their slacks mu) are iterated with Newton
!! steps while the barrier parameter epsi is reduced by a factor of ten per
!! outer iteration, until it drops below a rank-agreed minimal value. For
!! every lambda the primal variables x, y and z are recomputed.
!! Based on "https://people.kth.se/~krille/mmagcmma.pdf" and
!! https://doi.org/10.1007/s00158-012-0869-2
!! (https://github.com/topopt/TopOpt_in_PETSc/blob/master/MMA.cc).
!!
!! On exit this%xold2 and this%xold1 are shifted, the design is overwritten
!! with the new x, and this%y, this%z, this%lambda and this%mu are updated.
!!
!! @param this The MMA object (asymptotes, bounds and coefficients are read,
!! the primal/dual state is written).
!! @param designx_d Device pointer to the design variables (this%n entries);
!! the device data it refers to is overwritten with the new design.
subroutine mma_subsolve_dip_device(this, designx_d)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: designx_d
  integer :: iter, ierr
  real(kind=rp) :: epsi, residumax, z, steg
  ! vectors with size m
  type(vector_t), pointer :: y, lambda, mu, relambda, remu, dlambda, dmu, &
       gradlambda, zerom, dd, dummy_m
  ! vectors with size n
  type(vector_t), pointer :: x, pjlambda, qjlambda

  ! inverse of a diagonal matrix: [grad_x^2 L_jj]^-1
  type(vector_t), pointer :: Ljjxinv
  ! grad_x h_ij
  type(matrix_t), pointer :: hijx
  type(matrix_t), pointer :: Hess

  integer :: info, ind(17)

  real(kind=rp) :: minimal_epsilon

  ! Work arrays of size m
  call neko_scratch_registry%request(y, ind(1), this%m, .false.)
  call neko_scratch_registry%request(lambda, ind(2), this%m, .false.)
  call neko_scratch_registry%request(mu, ind(3), this%m, .false.)
  call neko_scratch_registry%request(relambda, ind(4), this%m, .false.)
  call neko_scratch_registry%request(remu, ind(5), this%m, .false.)
  call neko_scratch_registry%request(dlambda, ind(6), this%m, .false.)
  call neko_scratch_registry%request(dmu, ind(7), this%m, .false.)
  call neko_scratch_registry%request(gradlambda, ind(8), this%m, .false.)
  call neko_scratch_registry%request(zerom, ind(9), this%m, .false.)
  call neko_scratch_registry%request(dd, ind(10), this%m, .false.)
  call neko_scratch_registry%request(dummy_m, ind(11), this%m, .false.)

  ! Work arrays of size n
  call neko_scratch_registry%request(x, ind(12), this%n, .false.)
  call neko_scratch_registry%request(pjlambda, ind(13), this%n, .false.)
  call neko_scratch_registry%request(qjlambda, ind(14), this%n, .false.)
  call neko_scratch_registry%request(ljjxinv, ind(15), this%n, .false.)

  ! Work matrices (m x n) and (m x m)
  call neko_scratch_registry%request(hijx, ind(16), this%m, this%n, .false.)
  call neko_scratch_registry%request(hess, ind(17), this%m, this%m, .false.)

  ! ------------------------------------------------------------------------ !
  ! Initial values for the parameters in the subsolve based on
  ! page 15 of "https://people.kth.se/~krille/mmagcmma.pdf"

  epsi = 1.0_rp
  call device_cfill(y%x_d, 1.0_rp, this%m)

  ! lambda starts from ones, combined point-wise with c/2 through
  ! device_pwmax2
  call device_cfill(lambda%x_d, 1.0_rp, this%m)
  call device_cmult2(dummy_m%x_d, this%c%x_d, 0.5_rp, this%m)
  call device_pwmax2(lambda%x_d, dummy_m%x_d, this%m)

  call device_cfill(mu%x_d, 1.0_rp, this%m)
  z = 0.0_rp

  ! zerom is the lower bound used to keep y >= 0. Nothing else in this
  ! routine writes to it, so it is filled explicitly instead of relying on
  ! whatever the scratch buffer held before.
  call device_cfill(zerom%x_d, 0.0_rp, this%m)

  ! dd is defined as this%d + 1.0e-8_rp, to avoid division by 0 when
  ! computing y
  call device_cadd2(dd%x_d, this%d%x_d, 1.0e-8_rp, this%m)

  ! ------------------------------------------------------------------------ !
  ! Compute the minimal epsilon and agree on the smallest one across ranks,
  ! so that every rank runs the same number of outer iterations.

  minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
  call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
       mpi_real_precision, mpi_min, neko_comm, ierr)

  ! ------------------------------------------------------------------------ !
  ! The main loop of the dual interior point method.

  outer: do while (epsi .gt. minimal_epsilon)

    ! x(lambda), y(lambda), z(lambda) and the residuals for the current
    ! lambda and barrier parameter
    call update_primal_and_residuals()

    ! ---------------------------------------------------------------------- !
    ! Internal (Newton) loop
    do iter = 1, this%max_iter
      ! Converged for this value of epsi
      if (residumax .lt. epsi) exit

      ! Compute dL(x, y, z, lambda)/dlambda for the current x, y, z.
      ! The formulas for gradlambda and relambda are basically the same,
      ! thus gradlambda = relambda - mu is used for efficiency.
      call device_copy(gradlambda%x_d, relambda%x_d, this%m)
      call device_sub2(gradlambda%x_d, mu%x_d, this%m)

      ! Right hand side for Newton's method (eq. 10 of
      ! https://doi.org/10.1007/s00158-012-0869-2):
      ! gradlambda <- -(gradlambda + epsi / lambda)
      call device_cfill(dummy_m%x_d, epsi, this%m)
      call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)
      call device_add2(gradlambda%x_d, dummy_m%x_d, this%m)
      call device_cmult(gradlambda%x_d, -1.0_rp, this%m)

      ! Hessian as in equation (13) of
      ! https://doi.org/10.1007/s00158-012-0869-2

      !--------------contributions of x terms to Hess----------------------!
      call device_mma_ljjxinv(ljjxinv%x_d, pjlambda%x_d, qjlambda%x_d, &
           x%x_d, this%low%x_d, this%upp%x_d, this%alpha%x_d, &
           this%beta%x_d, this%n)

      call device_gg(hijx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
           this%pij%x_d, this%qij%x_d, this%n, this%m)

      call device_cfill(hess%x_d, 0.0_rp, (this%m) * (this%m))
      call device_hess(hess%x_d, hijx%x_d, ljjxinv%x_d, this%n, this%m)

      ! Sum the rank-local Hessians: download to the host, reduce over
      ! all ranks, and upload again since the linear system is solved on
      ! the device.
      call device_memcpy(hess%x, hess%x_d, this%m*this%m, device_to_host, &
           sync = .true.)
      call mpi_allreduce(mpi_in_place, hess%x, &
           this%m*this%m, mpi_real_precision, mpi_sum, neko_comm, ierr)
      call device_memcpy(hess%x, hess%x_d, this%m*this%m, &
           host_to_device, sync = .true.)

      !---------------contributions of z terms to Hess---------------------!
      ! There is no contribution to the Hess from z terms as z terms are
      ! linear w.r.t. lambda

      !---------------contributions of y terms to Hess---------------------!
      ! Only for inactive constraints are contributions to Hess considered.
      ! Note that if d(i) = 0, the y terms (just like the z terms) do not
      ! contribute to the Hessian matrix.
      ! The Hessian is also stabilised heuristically (Levenberg-Marquardt)
      ! to improve robustness.
      call device_prepare_hessian(hess%x_d, y%x_d, this%d%x_d, &
           mu%x_d, lambda%x_d, this%m)

      ! Solve the linear system on the device; the solution is returned
      ! in gradlambda.
      call device_solve_linear_system(hess%x_d, gradlambda%x_d, &
           this%m, info)
      if (info .ne. 0) then
        call neko_error("Linear solver failed on the device in " // &
             "mma_subsolve_dip")
      end if

      call device_copy(dlambda%x_d, gradlambda%x_d, this%m)

      ! Based on eq. (11): dmu = epsi / lambda - mu * dlambda / lambda - mu
      call device_copy(dummy_m%x_d, dlambda%x_d, this%m)
      call device_col2(dummy_m%x_d, mu%x_d, this%m)
      call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)

      call device_cfill(dmu%x_d, epsi, this%m)
      call device_invcol2(dmu%x_d, lambda%x_d, this%m)
      call device_add2s2(dmu%x_d, dummy_m%x_d, -1.0_rp, this%m)
      call device_sub2(dmu%x_d, mu%x_d, this%m)

      ! Step length: the reciprocal of the largest of 1.005 and the two
      ! device_maxval2 values (never larger than 1 / 1.005)
      steg = maxval([1.005_rp, device_maxval2(dlambda%x_d, lambda%x_d, &
           -1.01_rp, this%m), device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, &
           this%m)])
      steg = 1.0_rp / steg

      call device_add2s2(lambda%x_d, dlambda%x_d, steg, this%m)
      call device_add2s2(mu%x_d, dmu%x_d, steg, this%m)

      ! x(lambda), y(lambda), z(lambda) and the residuals (eq. 9 and
      ! eq. 15) for the updated lambda and mu
      call update_primal_and_residuals()
    end do

    epsi = 0.1_rp * epsi
  end do outer

  ! Save the new designx
  call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
  call device_copy(this%xold1%x_d, designx_d, this%n)
  call device_copy(designx_d, x%x_d, this%n)

  ! Update the parameters of the MMA object necessary to compute the KKT
  ! residual
  call device_copy(this%y%x_d, y%x_d, this%m)
  this%z = z
  call device_copy(this%lambda%x_d, lambda%x_d, this%m)
  call device_copy(this%mu%x_d, mu%x_d, this%m)

  call neko_scratch_registry%relinquish(ind)

contains

  !> Minimize the Lagrangian w.r.t. x, y and z for the current lambda and
  !! evaluate the residuals relambda, remu and their maximum residumax.
  !! Works on the host routine's variables through host association.
  subroutine update_primal_and_residuals()

    ! Compute the value of y that minimizes L_y for the current lambda
    ! minimize (sum_{i=1}^{m} [ (c_i - lambda_i) * y_i + 0.5 * d_i * y_i^2 ])
    ! dL_y/dy = 0 => y = (lambda_i - c_i) / d_i, ensure y >= 0
    call device_sub3(y%x_d, lambda%x_d, this%c%x_d, this%m)
    ! division by dd to avoid division by 0 (in case this%d%x_d = 0)
    call device_invcol2(y%x_d, dd%x_d, this%m)
    call device_pwmax2(y%x_d, zerom%x_d, this%m)

    ! Compute the value of z that minimizes L_z for the current lambda
    ! minimize ((a_0 - sum_{i=1}^{m} lambda_i * a_i) * z)
    ! if (a_0 - dot_product(lambda, a) >= 0) z = 0 else z = 1.0
    ! NOTE(review): if device_glsum also reduces over MPI ranks while
    ! lambda and a are replicated on every rank, the sum is scaled by the
    ! number of ranks - confirm.
    call device_col3(dummy_m%x_d, lambda%x_d, this%a%x_d, this%m)
    z = device_glsum(dummy_m%x_d, this%m)
    z = merge(0.0_rp, 1.0_rp, this%a0 - z >= 0.0_rp)

    ! Compute the value of x that minimizes L_x for the current lambda
    ! minimize( sum_{j=1}^{n} [ (p_{0j} + sum_{i=1}^{m} lambda_i *
    ! p_{ij}) / (u_j - x_j) + (q_{0j} + sum_{i=1}^{m} lambda_i * q_{ij}) /
    ! (x_j - l_j) ] - sum_{i=1}^{m} lambda_i * b_i)
    call device_mattrans_v_mul(pjlambda%x_d, this%pij%x_d, lambda%x_d, &
         this%m, this%n)
    call device_mattrans_v_mul(qjlambda%x_d, this%qij%x_d, lambda%x_d, &
         this%m, this%n)
    call device_add2(pjlambda%x_d, this%p0j%x_d, this%n)
    call device_add2(qjlambda%x_d, this%q0j%x_d, this%n)

    call device_mma_dipsolvesub1(x%x_d, pjlambda%x_d, qjlambda%x_d, &
         this%low%x_d, this%upp%x_d, this%alpha%x_d, this%beta%x_d, this%n)

    ! Rank-local part of the residual for lambda
    call device_cfill(relambda%x_d, 0.0_rp, this%m)
    call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
         this%low%x_d, this%pij%x_d, this%qij%x_d, this%n, this%m)

    ! Global communication for relambda values: download, sum over all
    ! ranks on the host, upload again
    call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
         sync = .true.)
    call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
         mpi_real_precision, mpi_sum, neko_comm, ierr)
    call device_memcpy(relambda%x, relambda%x_d, this%m, &
         host_to_device, sync = .true.)

    ! relambda <- relambda - z * a - y + mu - b
    call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
    call device_sub2(relambda%x_d, y%x_d, this%m)
    call device_add2(relambda%x_d, mu%x_d, this%m)
    call device_sub2(relambda%x_d, this%bi%x_d, this%m)

    ! remu <- mu * lambda - epsi
    call device_col3(remu%x_d, mu%x_d, lambda%x_d, this%m)
    call device_cadd(remu%x_d, -epsi, this%m)

    ! NOTE(review): confirm that device_maxval measures the magnitude of
    ! the entries; a signed maximum would ignore large negative residuals.
    residumax = maxval([device_maxval(relambda%x_d, this%m), &
         device_maxval(remu%x_d, this%m)])
  end subroutine update_primal_and_residuals

end subroutine mma_subsolve_dip_device
1055
1056end submodule mma_device