Neko-TOP
A portable framework for high-order spectral element flow topology optimization.
Loading...
Searching...
No Matches
mma_device.f90
1
34
35submodule(mma) mma_device
36
37 use device_math, only: device_copy, device_cmult, device_cadd, device_cfill, &
38 device_add2, device_add3s2, device_invcol2, device_col2, device_col3, &
39 device_sub2, device_sub3, device_add2s2, device_cadd2, device_pwmax2, &
40 device_glsum, device_cmult2
41 use device_mma_math, only: device_maxval, device_norm, device_lcsc2, &
42 device_maxval2, device_maxval3, device_mma_gensub3, &
43 device_mma_gensub4, device_mma_max, device_max2, device_rex, &
44 device_relambda, device_delx, device_add2inv2, device_gg, device_diagx, &
45 device_bb, device_updatebb, device_aa, device_updateaa, device_dx, &
46 device_dy, device_dxsi, device_deta, device_kkt_rex, &
47 device_mma_gensub2, device_mattrans_v_mul, device_mma_dipsolvesub1, &
48 device_mma_ljjxinv, device_hess, device_solve_linear_system, &
49 device_prepare_hessian, device_prepare_aa_matrix
50
51 use neko_config, only: neko_bcknd_device, neko_device_mpi
52 use device, only: device_to_host
53 use comm, only: neko_comm, pe_rank, mpi_real_precision
54 use mpi_f08, only: mpi_in_place, mpi_max, mpi_min
55 use profiler, only: profiler_start_region, profiler_end_region
56 use scratch_registry, only: neko_scratch_registry
57
58 implicit none
59
60contains
61
module subroutine mma_update_device(this, iter, x, df0dx, fval, dfdx)
  ! ----------------------------------------------------------------------- !
  ! Perform one MMA design update on the device.                            !
  !                                                                         !
  ! Step 1 builds the convex MMA approximation around the current design    !
  ! (mma_gensub_device). Step 2 solves that approximation with the          !
  ! interior point subsolver selected by this%subsolver ("dip" or "dpip").  !
  ! The design pointed to by x is overwritten with the new design, and      !
  ! this%is_updated is set on success.                                      !
  !                                                                         !
  ! iter  : outer optimization iteration counter (drives asymptote setup)   !
  ! x     : device pointer to the design variables (updated in place)       !
  ! df0dx : device pointer to the objective gradient                        !
  ! fval  : device pointer to the constraint values                         !
  ! dfdx  : device pointer to the constraint gradients                      !
  ! ----------------------------------------------------------------------- !
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(inout) :: x
  type(c_ptr), intent(in) :: df0dx, fval, dfdx

  ! Guard: the MMA state (asymptotes, work vectors) must exist first.
  if (.not. this%is_initialized) &
       call neko_error("The MMA object is not initialized.")

  ! Step 1: convex approximation of the problem around x.
  call profiler_start_region("MMA gensub")
  call mma_gensub_device(this, iter, x, df0dx, fval, dfdx)
  call profiler_end_region("MMA gensub")

  ! Step 2: interior point solve of the approximation.
  call profiler_start_region("MMA subsolve")
  if (this%subsolver == "dip") then
     call mma_subsolve_dip_device(this, x)
  else if (this%subsolver == "dpip") then
     call mma_subsolve_dpip_device(this, x)
  else
     call neko_error("Unrecognized subsolver for MMA in mma_device.")
  end if
  call profiler_end_region("MMA subsolve")

  this%is_updated = .true.
end subroutine mma_update_device
97
module subroutine mma_kkt_device(this, x, df0dx, fval, dfdx)
  ! ----------------------------------------------------------------------- !
  ! Evaluate the KKT residual of the current MMA state on the device.      !
  !                                                                         !
  ! Dispatches on this%subsolver, exactly like mma_update_device does:      !
  !   "dip"  -> mma_dip_kkt_device                                          !
  !   "dpip" -> mma_dpip_kkt_device                                         !
  ! Any other value is rejected with neko_error. Previously an unknown      !
  ! value silently fell through to the dpip residual, which would evaluate  !
  ! residuals for a subsolver that never ran.                               !
  !                                                                         !
  ! x, df0dx, fval, dfdx : device pointers to the design, objective         !
  !                        gradient, constraint values and constraint       !
  !                        gradients                                        !
  ! Results are stored in this%residumax and this%residunorm by the         !
  ! selected routine.                                                       !
  ! ----------------------------------------------------------------------- !
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  if (this%subsolver .eq. "dip") then
     call mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
  else if (this%subsolver .eq. "dpip") then
     call mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
  else
     call neko_error("Unrecognized subsolver for MMA in mma_kkt_device.")
  end if
end subroutine mma_kkt_device
108
!> KKT residual of the MMA subproblem for the dual interior point (dip)
!! subsolve of the MMA algorithm.
!!
!! Only length-m residuals are formed:
!!   relambda = fval - a*z - y + mu
!!   remu     = lambda * mu
!! this%residumax is the largest device_maxval of the two, and
!! this%residunorm is sqrt of the sum of their device_norm values.
!! No explicit MPI call is made in this routine.
!! x, df0dx and dfdx belong to the shared KKT interface and are not read here.
module subroutine mma_dip_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  type(vector_t), pointer :: res_lambda, res_mu
  integer :: slot(2)
  real(kind=rp) :: max_lambda, max_mu, sq_lambda, sq_mu

  call neko_scratch_registry%request(res_lambda, slot(1), this%m, .false.)
  call neko_scratch_registry%request(res_mu, slot(2), this%m, .false.)

  ! res_lambda <- fval - z*a, then subtract y and add mu (same order as before)
  call device_add3s2(res_lambda%x_d, fval, this%a%x_d, 1.0_rp, -this%z, &
       this%m)
  call device_sub2(res_lambda%x_d, this%y%x_d, this%m)
  call device_add2(res_lambda%x_d, this%mu%x_d, this%m)

  ! Complementarity residual lambda*mu (called eta in the paper)
  call device_col3(res_mu%x_d, this%lambda%x_d, this%mu%x_d, this%m)

  max_lambda = device_maxval(res_lambda%x_d, this%m)
  max_mu = device_maxval(res_mu%x_d, this%m)
  this%residumax = maxval([max_lambda, max_mu])

  sq_lambda = device_norm(res_lambda%x_d, this%m)
  sq_mu = device_norm(res_mu%x_d, this%m)
  this%residunorm = sqrt(sq_lambda + sq_mu)

  call neko_scratch_registry%relinquish(slot)
end subroutine mma_dip_kkt_device
137
!> KKT residual of the MMA subproblem for the primal-dual interior point
!! (dpip) subsolve of the MMA algorithm, evaluated at design x with the
!! multipliers/slacks stored in this.
!!
!! Residual components (cf. eq. (5.9) of
!! "https://people.kth.se/~krille/mmagcmma.pdf"):
!!   rex      : device_kkt_rex(df0dx, dfdx, xsi, eta, lambda)   (length n)
!!   rey      = c + d*y - lambda - mu                           (length m)
!!   rez      = a0 - zeta - sum(lambda*a)                       (scalar)
!!   relambda = fval - a*z - y + s                              (length m)
!!   rexsi    = xsi*(x - xmin),  reeta = eta*(xmax - x)         (length n)
!!   remu     = mu*y,  rezeta = zeta*z,  res = lambda*s
!! this%residumax is MPI-max reduced over ranks. In this%residunorm the
!! length-n squared norms are MPI-sum reduced before being combined with
!! the length-m and scalar contributions.
module subroutine mma_dpip_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x, df0dx, fval, dfdx

  ! length-m residual work vectors
  type(vector_t), pointer :: r_y, r_lambda, r_mu, r_s
  ! length-n residual work vectors
  type(vector_t), pointer :: r_x, r_xsi, r_eta
  real(kind=rp) :: r_z, r_zeta, sq_norm_m, sq_norm_n
  integer :: slots(7), ierr

  call neko_scratch_registry%request(r_y, slots(1), this%m, .false.)
  call neko_scratch_registry%request(r_lambda, slots(2), this%m, .false.)
  call neko_scratch_registry%request(r_mu, slots(3), this%m, .false.)
  call neko_scratch_registry%request(r_s, slots(4), this%m, .false.)
  call neko_scratch_registry%request(r_x, slots(5), this%n, .false.)
  call neko_scratch_registry%request(r_xsi, slots(6), this%n, .false.)
  call neko_scratch_registry%request(r_eta, slots(7), this%n, .false.)

  ! Box-constraint complementarity residuals (length n)
  call device_sub3(r_xsi%x_d, x, this%xmin%x_d, this%n)
  call device_col2(r_xsi%x_d, this%xsi%x_d, this%n)
  call device_sub3(r_eta%x_d, this%xmax%x_d, x, this%n)
  call device_col2(r_eta%x_d, this%eta%x_d, this%n)

  ! Stationarity with respect to x
  call device_kkt_rex(r_x%x_d, df0dx, dfdx, this%xsi%x_d, &
       this%eta%x_d, this%lambda%x_d, this%n, this%m)

  ! Stationarity with respect to y
  call device_col3(r_y%x_d, this%d%x_d, this%y%x_d, this%m)
  call device_add2(r_y%x_d, this%c%x_d, this%m)
  call device_sub2(r_y%x_d, this%lambda%x_d, this%m)
  call device_sub2(r_y%x_d, this%mu%x_d, this%m)

  ! Constraint residual
  call device_add3s2(r_lambda%x_d, fval, this%a%x_d, 1.0_rp, -this%z, &
       this%m)
  call device_sub2(r_lambda%x_d, this%y%x_d, this%m)
  call device_add2(r_lambda%x_d, this%s%x_d, this%m)

  ! Complementarity residuals (length m)
  call device_col3(r_mu%x_d, this%mu%x_d, this%y%x_d, this%m)
  call device_col3(r_s%x_d, this%lambda%x_d, this%s%x_d, this%m)

  ! Scalar residuals
  r_z = this%a0 - this%zeta - device_lcsc2(this%lambda%x_d, this%a%x_d, &
       this%m)
  r_zeta = this%zeta * this%z

  this%residumax = maxval([ &
       device_maxval(r_x%x_d, this%n), &
       device_maxval(r_y%x_d, this%m), &
       abs(r_z), &
       device_maxval(r_lambda%x_d, this%m), &
       device_maxval(r_xsi%x_d, this%n), &
       device_maxval(r_eta%x_d, this%n), &
       device_maxval(r_mu%x_d, this%m), &
       abs(r_zeta), &
       device_maxval(r_s%x_d, this%m)])

  ! Squared norms of the rank-local length-n parts
  sq_norm_n = device_norm(r_x%x_d, this%n) + &
       device_norm(r_xsi%x_d, this%n) + &
       device_norm(r_eta%x_d, this%n)

  call mpi_allreduce(mpi_in_place, this%residumax, 1, &
       mpi_real_precision, mpi_max, neko_comm, ierr)
  call mpi_allreduce(mpi_in_place, sq_norm_n, 1, &
       mpi_real_precision, mpi_sum, neko_comm, ierr)

  sq_norm_m = device_norm(r_y%x_d, this%m) + &
       r_z**2 + &
       device_norm(r_lambda%x_d, this%m) + &
       device_norm(r_mu%x_d, this%m) + &
       r_zeta**2 + &
       device_norm(r_s%x_d, this%m)

  this%residunorm = sqrt(sq_norm_m + sq_norm_n)

  call neko_scratch_registry%relinquish(slots)
end subroutine mma_dpip_kkt_device
219
220 !============================================================================!
221 ! private internal subroutines
222
subroutine mma_gensub_device(this, iter, x, df0dx, fval, dfdx)
  ! ----------------------------------------------------------------------- !
  ! Generate the MMA approximation subproblem on the device.                !
  !                                                                         !
  ! Sets the lower/upper asymptotes (low, upp) and computes alpha, beta,    !
  ! p0j, q0j, pij, qij and bi, all stored in this.                          !
  !                                                                         !
  ! iter  : outer iteration counter. For iter < 3 the asymptotes are        !
  !         x -/+ asyinit*(xmax - xmin). Afterwards they are adapted from   !
  !         xold1/xold2 using asydecr/asyincr (device_mma_gensub2).         !
  ! x     : device pointer to the current design                           !
  ! df0dx : device pointer to the objective gradient                       !
  ! fval  : device pointer to the constraint values (length m)             !
  ! dfdx  : device pointer to the constraint gradients                     !
  ! ----------------------------------------------------------------------- !
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(in) :: x
  type(c_ptr), intent(in) :: df0dx
  type(c_ptr), intent(in) :: fval
  type(c_ptr), intent(in) :: dfdx

  type(vector_t), pointer :: x_diff
  integer :: ind, ierr

  call neko_scratch_registry%request(x_diff, ind, this%n, .false.)

  ! x_diff = xmax - xmin, the width of the box constraints
  call device_sub3(x_diff%x_d, this%xmax%x_d, this%xmin%x_d, this%n)

  ! ---------------------------------------------------------------------- !
  ! Setup the current asymptotes

  if (iter .lt. 3) then
     call device_copy(this%low%x_d, x, this%n)
     call device_add2s2(this%low%x_d, x_diff%x_d, - this%asyinit, this%n)
     call device_copy(this%upp%x_d, x, this%n)
     call device_add2s2(this%upp%x_d, x_diff%x_d, this%asyinit, this%n)
  else
     call device_mma_gensub2(this%low%x_d, this%upp%x_d, x, &
          this%xold1%x_d, this%xold2%x_d, x_diff%x_d, &
          this%asydecr, this%asyincr, this%n)
  end if

  ! ---------------------------------------------------------------------- !
  ! Calculate p0j, q0j, pij, qij, alpha, and beta

  call device_mma_gensub3(x, df0dx, dfdx, this%low%x_d, &
       this%upp%x_d, this%xmin%x_d, this%xmax%x_d, this%alpha%x_d, &
       this%beta%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
       this%qij%x_d, this%n, this%m)

  ! ---------------------------------------------------------------------- !
  ! Computing bi as defined in page 5. device_mma_gensub4 leaves rank-local
  ! partial sums in bi, which are summed over all ranks below.

  call device_mma_gensub4(x, this%low%x_d, this%upp%x_d, this%pij%x_d, &
       this%qij%x_d, this%n, this%m, this%bi%x_d)

  ! Reduce through a host copy. A type(c_ptr) handed to mpi_allreduce as a
  ! choice buffer designates the c_ptr variable itself, not the device array
  ! it points to, so the former neko_device_mpi branch reduced the wrong
  ! memory. m is small, so staging through the host is cheap. This matches
  ! how mma_subsolve_dpip_device reduces its m-sized device vectors.
  call device_memcpy(this%bi%x, this%bi%x_d, this%m, device_to_host, &
       sync = .true.)
  call mpi_allreduce(mpi_in_place, this%bi%x, this%m, &
       mpi_real_precision, mpi_sum, neko_comm, ierr)
  call device_memcpy(this%bi%x, this%bi%x_d, this%m, host_to_device, &
       sync = .true.)

  ! bi = (global sum) - fval
  call device_sub2(this%bi%x_d, fval, this%m)

  call neko_scratch_registry%relinquish(ind)
end subroutine mma_gensub_device
289
!> Solve the MMA approximation subproblem on the device with the
!! primal-dual interior point (dpip) method of
!! "https://people.kth.se/~krille/mmagcmma.pdf".
!!
!! The barrier parameter epsi starts at 1 and is reduced by a factor 10 per
!! outer iteration while epsi > max(0.9*epsimin, 1e-12) (minimum over ranks).
!! For each epsi, up to this%max_iter Newton steps are taken. They stop early
!! once residual_max < epsi. Each step solves the reduced (m+1)x(m+1) system of
!! eq. (5.20) and applies a backtracking line search of at most 50 trial
!! steps, halving the step length after each trial.
!!
!! On exit:
!! - xold2 <- xold1 and xold1 <- previous design.
!! - designx_d holds the new design.
!! - The final y, z, lambda, zeta, xsi, eta, mu and s are stored in this for
!!   the KKT evaluation.
!!
!! @param this       MMA object holding the approximation built by
!!                   mma_gensub_device.
!! @param designx_d  Device pointer to the design (read, then overwritten).
subroutine mma_subsolve_dpip_device(this, designx_d)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: designx_d
  integer :: iter, itto, ierr
  real(kind=rp) :: epsi, residual_max, residual_norm, z, zeta, rez, rezeta, &
       delz, dz, dzeta, steg, zold, zetaold, new_residual
  ! vectors with size m
  type(vector_t) , pointer :: y, lambda, s, mu, rey, relambda, remu, res, &
       dely, dellambda, dy, dlambda, ds, dmu, yold, lambdaold, sold, muold

  ! vectors with size n
  type(vector_t), pointer :: x, xsi, eta, rex, rexsi, reeta, &
       delx, diagx, dx, dxsi, deta, xold, xsiold, etaold

  ! right-hand side of the reduced Newton system (size m+1)
  type(vector_t), pointer :: bb
  ! GG is m x n; AA is the (m+1) x (m+1) reduced system matrix, eq. (5.20)
  type(matrix_t), pointer :: GG
  type(matrix_t), pointer :: AA

  integer :: info
  real(kind=rp) :: re_sq_norm

  integer :: ind(35)

  real(kind=rp) :: minimal_epsilon

  call neko_scratch_registry%request(y, ind(1), this%m, .false.)
  call neko_scratch_registry%request(lambda, ind(2), this%m, .false.)
  call neko_scratch_registry%request(s, ind(3), this%m, .false.)
  call neko_scratch_registry%request(mu, ind(4), this%m, .false.)
  call neko_scratch_registry%request(rey, ind(5), this%m, .false.)
  call neko_scratch_registry%request(relambda, ind(6), this%m, .false.)
  call neko_scratch_registry%request(remu, ind(7), this%m, .false.)
  call neko_scratch_registry%request(res, ind(8), this%m, .false.)
  call neko_scratch_registry%request(dely, ind(9), this%m, .false.)
  call neko_scratch_registry%request(dellambda, ind(10), this%m, .false.)
  call neko_scratch_registry%request(dy, ind(11), this%m, .false.)
  call neko_scratch_registry%request(dlambda, ind(12), this%m, .false.)
  call neko_scratch_registry%request(ds, ind(13), this%m, .false.)
  call neko_scratch_registry%request(dmu, ind(14), this%m, .false.)
  call neko_scratch_registry%request(yold, ind(15), this%m, .false.)
  call neko_scratch_registry%request(lambdaold, ind(16), this%m, .false.)
  call neko_scratch_registry%request(sold, ind(17), this%m, .false.)
  call neko_scratch_registry%request(muold, ind(18), this%m, .false.)
  call neko_scratch_registry%request(x, ind(19), this%n, .false.)
  call neko_scratch_registry%request(xsi, ind(20), this%n, .false.)
  call neko_scratch_registry%request(eta, ind(21), this%n, .false.)
  call neko_scratch_registry%request(rex, ind(22), this%n, .false.)
  call neko_scratch_registry%request(rexsi, ind(23), this%n, .false.)
  call neko_scratch_registry%request(reeta, ind(24), this%n, .false.)
  call neko_scratch_registry%request(delx, ind(25), this%n, .false.)
  call neko_scratch_registry%request(diagx, ind(26), this%n, .false.)
  call neko_scratch_registry%request(dx, ind(27), this%n, .false.)
  call neko_scratch_registry%request(dxsi, ind(28), this%n, .false.)
  call neko_scratch_registry%request(deta, ind(29), this%n, .false.)
  call neko_scratch_registry%request(xold, ind(30), this%n, .false.)
  call neko_scratch_registry%request(xsiold, ind(31), this%n, .false.)
  call neko_scratch_registry%request(etaold, ind(32), this%n, .false.)
  call neko_scratch_registry%request(bb, ind(33), this%m+1, .false.)

  call neko_scratch_registry%request(gg, ind(34), this%m, this%n, .false.)
  call neko_scratch_registry%request(aa, ind(35), this%m+1, this%m+1, .false.)

  ! ---------------------------------------------------------------------- !
  ! initial value for the parameters in the subsolve based on
  ! page 15 of "https://people.kth.se/~krille/mmagcmma.pdf"

  epsi = 1.0_rp
  ! x = (alpha + beta) / 2
  call device_add3s2(x%x_d, this%alpha%x_d, this%beta%x_d, 0.5_rp, 0.5_rp, &
       this%n)
  call device_cfill(y%x_d, 1.0_rp, this%m)
  z = 1.0_rp
  zeta = 1.0_rp
  call device_cfill(lambda%x_d, 1.0_rp, this%m)
  call device_cfill(s%x_d, 1.0_rp, this%m)
  call device_mma_max(xsi%x_d, x%x_d, this%alpha%x_d, this%n)
  call device_mma_max(eta%x_d, this%beta%x_d, x%x_d, this%n)
  call device_max2(mu%x_d, 1.0_rp, this%c%x_d, 0.5_rp, this%m)

  ! ---------------------------------------------------------------------- !
  ! Computing the minimal epsilon and choose the most conservative one

  minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
  call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
       mpi_real_precision, mpi_min, neko_comm, ierr)

  ! ---------------------------------------------------------------------- !
  ! The main loop of the dual-primal interior point method.

  do while (epsi .gt. minimal_epsilon)

     ! ------------------------------------------------------------------- !
     ! Calculating residuals based on
     ! "https://people.kth.se/~krille/mmagcmma.pdf" for the variables
     ! x, y, z, lambda residuals based on eq(5.9a)-(5.9d), respectively.

     associate(p0j => this%p0j, q0j => this%q0j, &
          pij => this%pij, qij => this%qij, &
          low => this%low, upp => this%upp, &
          alpha => this%alpha, beta => this%beta, &
          c => this%c, d => this%d, &
          a0 => this%a0, a => this%a)

       call device_rex(rex%x_d, x%x_d, low%x_d, upp%x_d, &
            pij%x_d, p0j%x_d, qij%x_d, q0j%x_d, &
            lambda%x_d, xsi%x_d, eta%x_d, this%n, this%m)

       call device_col3(rey%x_d, d%x_d, y%x_d, this%m)
       call device_add2(rey%x_d, c%x_d, this%m)
       call device_sub2(rey%x_d, lambda%x_d, this%m)
       call device_sub2(rey%x_d, mu%x_d, this%m)
       rez = a0 - zeta - device_lcsc2(lambda%x_d, a%x_d, this%m)

       ! Rank-local partial sums of the term gi(x)
       call device_cfill(relambda%x_d, 0.0_rp, this%m)
       call device_relambda(relambda%x_d, x%x_d, upp%x_d, &
            low%x_d, pij%x_d, qij%x_d, this%n, this%m)

     end associate

     ! ------------------------------------------------------------------- !
     ! Computing the norm of the residuals

     ! Complete the computations of lambda residuals. The reduction always
     ! goes through a host copy: a type(c_ptr) passed to mpi_allreduce
     ! designates the c_ptr variable itself, not the device array, so the
     ! former neko_device_mpi branch reduced the wrong memory. This matches
     ! the reductions of dellambda, bb, aa and the line-search relambda below.
     call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
          sync = .true.)
     call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
          mpi_real_precision, mpi_sum, neko_comm, ierr)
     call device_memcpy(relambda%x, relambda%x_d, this%m, host_to_device, &
          sync = .true.)

     call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
     call device_sub2(relambda%x_d, y%x_d, this%m)
     call device_add2(relambda%x_d, s%x_d, this%m)
     call device_sub2(relambda%x_d, this%bi%x_d, this%m)

     call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
     call device_col2(rexsi%x_d, xsi%x_d, this%n)
     call device_cadd(rexsi%x_d, - epsi, this%n)

     call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
     call device_col2(reeta%x_d, eta%x_d, this%n)
     call device_cadd(reeta%x_d, - epsi, this%n)

     call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
     call device_cadd(remu%x_d, - epsi, this%m)

     rezeta = zeta * z - epsi

     call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
     call device_cadd(res%x_d, - epsi, this%m)

     ! Setup vectors of residuals and their norms
     residual_max = maxval([device_maxval(rex%x_d, this%n), &
          device_maxval(rey%x_d, this%m), abs(rez), &
          device_maxval(relambda%x_d, this%m), &
          device_maxval(rexsi%x_d, this%n), &
          device_maxval(reeta%x_d, this%n), &
          device_maxval(remu%x_d, this%m), abs(rezeta), &
          device_maxval(res%x_d, this%m)])

     ! Rank-local squared norms of the length-n residuals
     re_sq_norm = device_norm(rex%x_d, this%n) + &
          device_norm(rexsi%x_d, this%n) + device_norm(reeta%x_d, this%n)

     call mpi_allreduce(mpi_in_place, residual_max, 1, &
          mpi_real_precision, mpi_max, neko_comm, ierr)

     call mpi_allreduce(mpi_in_place, re_sq_norm, &
          1, mpi_real_precision, mpi_sum, neko_comm, ierr)

     residual_norm = sqrt(device_norm(rey%x_d, this%m) + &
          rez**2 + &
          device_norm(relambda%x_d, this%m) + &
          device_norm(remu%x_d, this%m)+ &
          rezeta**2 + &
          device_norm(res%x_d, this%m) &
          + re_sq_norm)

     ! ------------------------------------------------------------------- !
     ! Internal loop: Newton iterations for the current epsi

     do iter = 1, this%max_iter

        if (residual_max .lt. epsi) exit

        call device_delx(delx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
             this%pij%x_d, this%qij%x_d, this%p0j%x_d, this%q0j%x_d, &
             this%alpha%x_d, this%beta%x_d, lambda%x_d, epsi, this%n, &
             this%m)

        call device_col3(dely%x_d, this%d%x_d, y%x_d, this%m)
        call device_add2(dely%x_d, this%c%x_d, this%m)
        call device_sub2(dely%x_d, lambda%x_d, this%m)
        call device_add2inv2(dely%x_d, y%x_d, - epsi, this%m)
        delz = this%a0 - device_lcsc2(lambda%x_d, this%a%x_d, this%m) - epsi/z

        ! Accumulate sums for dellambda (the term gi(x))
        call device_cfill(dellambda%x_d, 0.0_rp, this%m)
        call device_relambda(dellambda%x_d, x%x_d, this%upp%x_d, &
             this%low%x_d, this%pij%x_d, this%qij%x_d, this%n, this%m)

        call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
             device_to_host, sync = .true.)
        call mpi_allreduce(mpi_in_place, dellambda%x, this%m, &
             mpi_real_precision, mpi_sum, neko_comm, ierr)
        call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
             host_to_device, sync = .true.)

        call device_add3s2(dellambda%x_d, dellambda%x_d, this%a%x_d, &
             1.0_rp, -z, this%m)
        call device_sub2(dellambda%x_d, y%x_d, this%m)
        call device_sub2(dellambda%x_d, this%bi%x_d, this%m)
        call device_add2inv2(dellambda%x_d, lambda%x_d, epsi, this%m)

        call device_gg(gg%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
             this%pij%x_d, this%qij%x_d, this%n, this%m)

        call device_diagx(diagx%x_d, x%x_d, xsi%x_d, this%low%x_d, &
             this%upp%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
             this%qij%x_d, this%alpha%x_d, this%beta%x_d, eta%x_d, &
             lambda%x_d, this%n, this%m)

        ! Here we only consider the case m<n in the matlab code
        ! assembling the right hand side matrix based on eq(5.20)
        ! bb = [dellambda + dely/(this%d%x + &
        !      (mu/y)) - matmul(GG,delx/diagx), delz ]

        ! ---------------------------------------------------------------- !
        ! for MPI computation of bb: the GG*(delx/diagx) part is a sum over
        ! n and is reduced across ranks before the m-sized terms are added

        call device_bb(bb%x_d, gg%x_d, delx%x_d, diagx%x_d, this%n, &
             this%m)

        call device_memcpy(bb%x, bb%x_d, this%m + 1, device_to_host, &
             sync = .true.)
        call mpi_allreduce(mpi_in_place, bb%x, this%m + 1, &
             mpi_real_precision, mpi_sum, neko_comm, ierr)
        call device_memcpy(bb%x, bb%x_d, this%m + 1, &
             host_to_device, sync = .true.)

        call device_updatebb(bb%x_d, dellambda%x_d, dely%x_d, &
             this%d%x_d, mu%x_d, y%x_d, delz, this%m)

        ! assembling the coefficients matrix AA based on eq(5.20)
        ! AA(1:this%m,1:this%m) = &
        !    matmul(matmul(GG,mma_diag(1/diagx)), transpose(GG))
        ! !update diag(AA)
        ! AA(1:this%m,1:this%m) = AA(1:this%m,1:this%m) + &
        !    mma_diag(s/lambda + 1.0/(this%d%x + (mu/y)))

        call device_cfill(aa%x_d, 0.0_rp, (this%m+1) * (this%m+1))
        call device_aa(aa%x_d, gg%x_d, diagx%x_d, this%n, this%m)

        call device_memcpy(aa%x, aa%x_d, (this%m+1) * (this%m+1), &
             device_to_host, sync = .true.)
        call mpi_allreduce(mpi_in_place, aa%x, &
             (this%m + 1)**2, mpi_real_precision, mpi_sum, neko_comm, ierr)
        call device_memcpy(aa%x, aa%x_d, (this%m+1) * (this%m+1), &
             host_to_device, sync = .true.)

        call device_prepare_aa_matrix(aa%x_d, s%x_d, lambda%x_d, &
             this%d%x_d, mu%x_d, y%x_d, this%a%x_d, zeta, z, this%m)

        ! Device solve for the linear system; the solution overwrites bb
        call device_solve_linear_system(aa%x_d, bb%x_d, this%m + 1, info)
        if (info .ne. 0) then
           call neko_error("Linear solver failed on the device in " // &
                "mma_subsolve_dpip")
        end if

        call device_copy(dlambda%x_d, bb%x_d, this%m)

        ! The last element of bb is dz; it is needed on the host
        call device_memcpy(bb%x, bb%x_d, this%m+1, device_to_host, &
             sync = .true.)
        dz = bb%x(this%m + 1)

        ! based on eq(5.19)
        call device_dx(dx%x_d, delx%x_d, diagx%x_d, gg%x_d, &
             dlambda%x_d, this%n, this%m)
        call device_dy(dy%x_d, dely%x_d, dlambda%x_d, this%d%x_d, &
             mu%x_d, y%x_d, this%m)
        call device_dxsi(dxsi%x_d, xsi%x_d, dx%x_d, x%x_d, &
             this%alpha%x_d, epsi, this%n)
        call device_deta(deta%x_d, eta%x_d, dx%x_d, x%x_d, &
             this%beta%x_d, epsi, this%n)

        ! dmu = -mu + (epsi - mu*dy) / y
        call device_col3(dmu%x_d, mu%x_d, dy%x_d, this%m)
        call device_cmult(dmu%x_d, -1.0_rp, this%m)
        call device_cadd(dmu%x_d, epsi, this%m)
        call device_invcol2(dmu%x_d, y%x_d, this%m)
        call device_sub2(dmu%x_d, mu%x_d, this%m)
        dzeta = -zeta + (epsi - zeta * dz) / z
        ! ds = -s + (epsi - s*dlambda) / lambda
        call device_col3(ds%x_d, dlambda%x_d, s%x_d, this%m)
        call device_cmult(ds%x_d, -1.0_rp, this%m)
        call device_cadd(ds%x_d, epsi, this%m)
        call device_invcol2(ds%x_d, lambda%x_d, this%m)
        call device_sub2(ds%x_d, s%x_d, this%m)

        ! Largest admissible step keeping the iterates strictly feasible
        ! (rank-local for the length-n terms; reduced with mpi_min below)
        steg = maxval([1.0_rp, &
             device_maxval2(dy%x_d, y%x_d, -1.01_rp, this%m), &
             -1.01_rp * dz / z, &
             device_maxval2(dlambda%x_d, lambda%x_d, -1.01_rp, this%m), &
             device_maxval2(dxsi%x_d, xsi%x_d, -1.01_rp, this%n), &
             device_maxval2(deta%x_d, eta%x_d, -1.01_rp, this%n), &
             device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, this%m), &
             -1.01_rp * dzeta / zeta, &
             device_maxval2(ds%x_d, s%x_d, -1.01_rp, this%m), &
             device_maxval3(dx%x_d, x%x_d, this%alpha%x_d, -1.01_rp, this%n),&
             device_maxval3(dx%x_d, this%beta%x_d, x%x_d, 1.01_rp, this%n)])

        steg = 1.0_rp / steg

        call device_copy(xold%x_d, x%x_d, this%n)
        call device_copy(yold%x_d, y%x_d, this%m)
        zold = z
        call device_copy(lambdaold%x_d, lambda%x_d, this%m)
        call device_copy(xsiold%x_d, xsi%x_d, this%n)
        call device_copy(etaold%x_d, eta%x_d, this%n)
        call device_copy(muold%x_d, mu%x_d, this%m)
        zetaold = zeta
        call device_copy(sold%x_d, s%x_d, this%m)

        ! Start above residual_norm so the line search runs at least once
        new_residual = 2.0_rp * residual_norm

        ! Share the new_residual and steg values
        call mpi_allreduce(mpi_in_place, steg, 1, &
             mpi_real_precision, mpi_min, neko_comm, ierr)
        call mpi_allreduce(mpi_in_place, new_residual, 1, &
             mpi_real_precision, mpi_min, neko_comm, ierr)

        ! The innermost loop to determine the suitable step length
        ! using the Backtracking Line Search approach
        itto = 0
        do while ((new_residual .gt. residual_norm) .and. (itto .lt. 50))
           itto = itto + 1

           ! update the variables
           call device_add3s2(x%x_d, xold%x_d, dx%x_d, 1.0_rp, steg, this%n)
           call device_add3s2(y%x_d, yold%x_d, dy%x_d, 1.0_rp, steg, this%m)
           z = zold + steg*dz
           call device_add3s2(lambda%x_d, lambdaold%x_d, &
                dlambda%x_d, 1.0_rp, steg, this%m)
           call device_add3s2(xsi%x_d, xsiold%x_d, dxsi%x_d, &
                1.0_rp, steg, this%n)
           call device_add3s2(eta%x_d, etaold%x_d, deta%x_d, &
                1.0_rp, steg, this%n)
           call device_add3s2(mu%x_d, muold%x_d, dmu%x_d, &
                1.0_rp, steg, this%m)
           zeta = zetaold + steg*dzeta
           call device_add3s2(s%x_d, sold%x_d, ds%x_d, 1.0_rp, &
                steg, this%m)

           ! Recompute the new_residual to see if this stepsize improves
           ! the residue
           call device_rex(rex%x_d, x%x_d, this%low%x_d, &
                this%upp%x_d, this%pij%x_d, this%p0j%x_d, &
                this%qij%x_d, this%q0j%x_d, lambda%x_d, xsi%x_d, &
                eta%x_d, this%n, this%m)

           call device_col3(rey%x_d, this%d%x_d, y%x_d, this%m)
           call device_add2(rey%x_d, this%c%x_d, this%m)
           call device_sub2(rey%x_d, lambda%x_d, this%m)
           call device_sub2(rey%x_d, mu%x_d, this%m)

           rez = this%a0 - zeta - device_lcsc2(lambda%x_d, this%a%x_d, this%m)

           ! Accumulate sums for relambda (the term gi(x))
           call device_cfill(relambda%x_d, 0.0_rp, this%m)
           call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
                this%low%x_d, this%pij%x_d, this%qij%x_d, &
                this%n, this%m)

           call device_memcpy(relambda%x, relambda%x_d, this%m, &
                device_to_host, sync = .true.)
           call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
                mpi_real_precision, mpi_sum, neko_comm, ierr)
           call device_memcpy(relambda%x, relambda%x_d, &
                this%m, host_to_device, sync = .true.)

           call device_add3s2(relambda%x_d, relambda%x_d, &
                this%a%x_d, 1.0_rp, -z, this%m)
           call device_sub2(relambda%x_d, y%x_d, this%m)
           call device_add2(relambda%x_d, s%x_d, this%m)
           call device_sub2(relambda%x_d, this%bi%x_d, this%m)

           call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
           call device_col2(rexsi%x_d, xsi%x_d, this%n)
           call device_cadd(rexsi%x_d, - epsi, this%n)

           call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
           call device_col2(reeta%x_d, eta%x_d, this%n)
           call device_cadd(reeta%x_d, - epsi, this%n)

           call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
           call device_cadd(remu%x_d, - epsi, this%m)

           rezeta = zeta*z - epsi

           call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
           call device_cadd(res%x_d, - epsi, this%m)

           ! Compute squared norms for the residuals
           re_sq_norm = device_norm(rex%x_d, this%n) + &
                device_norm(rexsi%x_d, this%n) + &
                device_norm(reeta%x_d, this%n)
           call mpi_allreduce(mpi_in_place, re_sq_norm, 1, &
                mpi_real_precision, mpi_sum, neko_comm, ierr)

           new_residual = sqrt(device_norm(rey%x_d, this%m) + &
                rez**2 + &
                device_norm(relambda%x_d, this%m) + &
                device_norm(remu%x_d, this%m) + &
                rezeta**2 + &
                device_norm(res%x_d, this%m) + &
                re_sq_norm)

           steg = steg / 2.0_rp

        end do
        ! Note: steg is recomputed from scratch at the next Newton step,
        ! so no correction of the final halving is needed here.

        ! Update the maximum and norm of the residuals
        residual_norm = new_residual
        residual_max = maxval([ &
             device_maxval(rex%x_d, this%n), &
             device_maxval(rey%x_d, this%m), &
             abs(rez), &
             device_maxval(relambda%x_d, this%m), &
             device_maxval(rexsi%x_d, this%n), &
             device_maxval(reeta%x_d, this%n), &
             device_maxval(remu%x_d, this%m), &
             abs(rezeta), &
             device_maxval(res%x_d, this%m)])

        call mpi_allreduce(mpi_in_place, residual_max, 1, &
             mpi_real_precision, mpi_max, neko_comm, ierr)

     end do

     epsi = 0.1_rp * epsi
  end do

  ! Save the new designx
  call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
  call device_copy(this%xold1%x_d, designx_d, this%n)
  call device_copy(designx_d, x%x_d, this%n)

  ! update the parameters of the MMA object necessary to compute KKT residual
  call device_copy(this%y%x_d, y%x_d, this%m)
  this%z = z
  call device_copy(this%lambda%x_d, lambda%x_d, this%m)
  this%zeta = zeta
  call device_copy(this%xsi%x_d, xsi%x_d, this%n)
  call device_copy(this%eta%x_d, eta%x_d, this%n)
  call device_copy(this%mu%x_d, mu%x_d, this%m)
  call device_copy(this%s%x_d, s%x_d, this%m)

  ! free all the scratch vectors/matrices requested in this subroutine
  call neko_scratch_registry%relinquish(ind)
end subroutine mma_subsolve_dpip_device
758
761 subroutine mma_subsolve_dip_device(this, designx_d)
762 class(mma_t), intent(inout) :: this
763 type(c_ptr), intent(in) :: designx_d
764 integer :: iter, ierr
765 real(kind=rp) :: epsi, residumax, z, steg
766 ! vectors with size m
767 type(vector_t), pointer :: y, lambda, mu, relambda, remu, dlambda, dmu, &
768 gradlambda, zerom, dd, dummy_m
769 ! vectors with size n
770 type(vector_t), pointer :: x, pjlambda, qjlambda
771
772 ! inverse of a diag matrix:
773 type(vector_t), pointer :: Ljjxinv ! [∇_x^2 Ljj]−1
774 type(matrix_t), pointer :: hijx ! ∇_x hij
775 type(matrix_t), pointer :: Hess
776
777 integer :: info, ind(17)
778
779 real(kind=rp) :: minimal_epsilon
780
781 call neko_scratch_registry%request(y, ind(1), this%m, .false.)
782 call neko_scratch_registry%request(lambda, ind(2), this%m, .false.)
783 call neko_scratch_registry%request(mu, ind(3), this%m, .false.)
784 call neko_scratch_registry%request(relambda, ind(4), this%m, .false.)
785 call neko_scratch_registry%request(remu, ind(5), this%m, .false.)
786 call neko_scratch_registry%request(dlambda, ind(6), this%m, .false.)
787 call neko_scratch_registry%request(dmu, ind(7), this%m, .false.)
788 call neko_scratch_registry%request(gradlambda, ind(8), this%m, .false.)
789 call neko_scratch_registry%request(zerom, ind(9), this%m, .false.)
790 call neko_scratch_registry%request(dd, ind(10), this%m, .false.)
791 call neko_scratch_registry%request(dummy_m, ind(11), this%m, .false.)
792
793 call neko_scratch_registry%request(x, ind(12), this%n, .false.)
794 call neko_scratch_registry%request(pjlambda,ind(13), this%n, .false.)
795 call neko_scratch_registry%request(qjlambda, ind(14), this%n, .false.)
796
797 call neko_scratch_registry%request(ljjxinv, ind(15), this%n, .false.)
798
799 call neko_scratch_registry%request(hijx, ind(16), this%m, this%n, .false.)
800 call neko_scratch_registry%request(hess, ind(17), this%m, this%m, .false.)
801
802 ! ------------------------------------------------------------------------ !
803 ! initial value for the parameters in the subsolve based on
804 ! page 15 of "https://people.kth.se/~krille/mmagcmma.pdf"
805
806 epsi = 1.0_rp !100
807 call device_cfill(y%x_d, 1.0_rp, this%m)
808 ! initialize lambda with an array of ones (change to this%c%x/2 if needed!)
809 call device_cfill(lambda%x_d, 1.0_rp, this%m)
810 call device_cmult2(dummy_m%x_d, this%c%x_d, 0.5_rp, this%m)
811 call device_pwmax2(lambda%x_d, dummy_m%x_d, this%m)
812
813 call device_cfill(mu%x_d, 1.0_rp, this%m)
814 z = 0.0_rp
815
816 ! dd is defined as this%d + 1.0e-8_rp, to avoid devision by 0 in computing y
817 call device_cadd2(dd%x_d, this%d%x_d, 1.0e-8_rp, this%m)
818
819 ! ------------------------------------------------------------------------ !
820 ! Computing the minimal epsilon and choose the most conservative one
821
822 minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
823 call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
824 mpi_real_precision, mpi_min, neko_comm, ierr)
825
826 ! ------------------------------------------------------------------------ !
827 ! The main loop of the dual-primal interior point method.
828
829 outer: do while (epsi .gt. minimal_epsilon)
830 ! calculating residuals based on
831 ! "https://people.kth.se/~krille/mmagcmma.pdf" for the variables
832 ! x, y, z, lambda residuals based on eq(5.9a)-(5.9d), respectively.
833 associate(p0j => this%p0j, q0j => this%q0j, &
834 pij => this%pij, qij => this%qij, &
835 low => this%low, upp => this%upp, &
836 alpha => this%alpha, beta => this%beta, &
837 c => this%c, a0 => this%a0, a => this%a)
838
839 ! minimize(L_x, L_y, L_z) and compute x(λ), y(λ), z(λ) for
840 ! the initial value of λ
841
842 ! Comput the value of y that minimizes L_y for the current λ
843 ! minimize (sum_{i=1}^{m} [ (c_i - λ_i) * y_i + 0.5 * d_i * y_i^2 ])
844 ! dL_y/dy =0 => y= (λ_i - c_i)/d_i, ensure y>=0
845 call device_sub3(y%x_d, lambda%x_d, c%x_d, this%m)
846 ! division by dd to avoid devision by 0 (in case this%d%x_d = 0)
847 call device_invcol2(y%x_d, dd%x_d, this%m)
848 call device_pwmax2(y%x_d, zerom%x_d, this%m)
849
850 ! Comput the value of z that minimizes L_z for the current λ
851 ! minimize ((a_0 - sum_{i=1}^{m} λ_i * a_i) * z)
852 ! if (a_0-dot_product(lambda, a)>=0) z=0 else z= 1.0
853 ! ensure z>=0
854 call device_col3(dummy_m%x_d, lambda%x_d, a%x_d, this%m)
855 z = device_glsum(dummy_m%x_d, this%m)
856 z = merge(0.0_rp, 1.0_rp, a0 - z >= 0.0)
857
858 ! Comput the value of x that minimizes L_x for the current λ
859 ! minimize( sum_{j=1}^{n} [ (p_{0j} + sum_{i=1}^{m} λ_i *
860 ! p_{ij}) / (u_j - x_j) + (q_{0j} + sum_{i=1}^{m} λ_i * q_{ij}) /
861 ! (x_j - l_j) ] - sum_{i=1}^{m} λ_i * b_i)
862 call device_mattrans_v_mul(pjlambda%x_d, pij%x_d, lambda%x_d, this%m, this%n)
863 call device_mattrans_v_mul(qjlambda%x_d, qij%x_d, lambda%x_d, this%m, this%n)
864 call device_add2(pjlambda%x_d, p0j%x_d, this%n)
865 call device_add2(qjlambda%x_d, q0j%x_d, this%n)
866
867 call device_mma_dipsolvesub1(x%x_d, pjlambda%x_d, qjlambda%x_d, &
868 low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
869
870 call device_cfill(relambda%x_d, 0.0_rp, this%m)
871 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
872 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
873
874 ! Global comminucation for relambda values
875
876 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
877 sync = .true.)
878 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
879 mpi_real_precision, mpi_sum, neko_comm, ierr)
880 call device_memcpy(relambda%x, relambda%x_d, this%m, &
881 host_to_device, sync = .true.)
882
883 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
884 call device_sub2(relambda%x_d, y%x_d, this%m)
885 call device_add2(relambda%x_d, mu%x_d, this%m)
886 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
887
888 call device_col3(remu%x_d, mu%x_d, lambda%x_d, this%m)
889 call device_cadd(remu%x_d, -epsi, this%m)
890
891 residumax = maxval([device_maxval(relambda%x_d, this%m), &
892 device_maxval(remu%x_d, this%m)])
893
894 ! ------------------------------------------------------------------- !
895 ! Internal loop
896 do iter = 1, this%max_iter
897 !Check the condition
898 if (residumax .lt. epsi) exit
899
900 ! Compute dL(x, y, z, λ)/dλ for the updated x(λ), y(λ), z(λ)
901 ! based on the implementation in the following paper by Niels
902 ! https://doi.org/10.1007/s00158-012-0869-2
903 ! (https://github.com/topopt/TopOpt_in_PETSc/blob/master/MMA.cc)
904 ! The formula for gradlambda and relambda are basically the same:
905 ! thus, we utilise gradlambda = relambda - mu for efficiency
906 call device_copy(gradlambda%x_d, relambda%x_d, this%m)
907 call device_sub2(gradlambda%x_d, mu%x_d, this%m)
908
909 ! Update gradlambda as the right hand side for Newton's method(eq10)
910 call device_cfill(dummy_m%x_d, epsi, this%m)
911 call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)
912 call device_add2(gradlambda%x_d, dummy_m%x_d, this%m)
913 call device_cmult(gradlambda%x_d, -1.0_rp, this%m)
914
915 ! Computing the Hessian as in equation (13) in
916 !! https://doi.org/10.1007/s00158-012-0869-2
917
918 !--------------contributions of x terms to Hess--------------------!
919 call device_mma_ljjxinv(ljjxinv%x_d, pjlambda%x_d, qjlambda%x_d, &
920 x%x_d, low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
921
922 call device_gg(hijx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
923 this%pij%x_d, this%qij%x_d, this%n, this%m)
924
925 call device_cfill(hess%x_d, 0.0_rp, (this%m) * (this%m) )
926 call device_hess(hess%x_d, hijx%x_d, ljjxinv%x_d, this%n, this%m)
927
928 ! download Hess to CPU, mpi reduce, upload to the device
929 call device_memcpy(hess%x, hess%x_d, this%m*this%m, device_to_host, &
930 sync = .true.)
931 call mpi_allreduce(mpi_in_place, hess%x, &
932 this%m*this%m, mpi_real_precision, mpi_sum, neko_comm, ierr)
933 ! No need to upload to device since we solve LSE on CPU
934 ! But now we solve LSE on GPU, so upload it:
935 call device_memcpy(hess%x, hess%x_d, this%m*this%m, &
936 host_to_device, sync = .true.)
937
938 !---------------contributions of z terms to Hess-------------------!
939 ! There is no contibution to the Hess from z terms as z terms are
940 ! linear w.r.t λ
941
942
943 !---------------contributions of y terms to Hess-------------------!
944 ! Only for inactive constraint, we consider contributions to Hess.
945 ! Note that if d(i) = 0, the y terms (just like z terms) will not
946 ! contribute to the Hessian matrix.
947 ! Note that since we use DGESV to solve LSE on CPU, we dont need
948 ! cuda kernel for this part
949 ! Also, improve the robustness by stablizing the Hess using
950 ! Levenberg-Marquardt algorithm (heuristically)
951 call device_prepare_hessian(hess%x_d, y%x_d, this%d%x_d, &
952 mu%x_d, lambda%x_d, this%m)
953
954 ! Device solve for the linear system
955 call device_solve_linear_system(hess%x_d, gradlambda%x_d, &
956 this%m, info)
957 if (info .ne. 0) then
958 call neko_error("Linear solver failed on the device in " // &
959 "mma_subsolve_dip")
960 end if
961
962 call device_copy(dlambda%x_d, gradlambda%x_d, this%m)
963
964 ! based on eq(11) for delta eta
965 call device_copy(dummy_m%x_d, dlambda%x_d, this%m)
966 call device_col2(dummy_m%x_d, mu%x_d, this%m)
967 call device_invcol2(dummy_m%x_d, lambda%x_d, this%m)
968
969 call device_cfill(dmu%x_d, epsi, this%m)
970 call device_invcol2(dmu%x_d, lambda%x_d, this%m)
971 call device_add2s2(dmu%x_d, dummy_m%x_d, -1.0_rp, this%m)
972 call device_sub2(dmu%x_d, mu%x_d, this%m)
973
974 steg = maxval([1.005_rp, device_maxval2(dlambda%x_d, lambda%x_d, &
975 -1.01_rp, this%m), device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, &
976 this%m)])
977 steg = 1.0_rp / steg
978
979 call device_add2s2(lambda%x_d, dlambda%x_d, steg, this%m)
980 call device_add2s2(mu%x_d, dmu%x_d, steg, this%m)
981
982 ! minimize(L_x, L_y, L_z) and compute x(λ), y(λ), z(λ) for
983 ! the updated values of λ
984
985 ! Comput the value of y that minimizes L_y for the current λ
986 ! minimize (sum_{i=1}^{m} [ (c_i - λ_i) * y_i + 0.5 * d_i * y_i^2 ])
987 ! dL_y/dy =0 => y= (λ_i - c_i)/d_i, ensure y>=0
988
989 call device_sub3(y%x_d, lambda%x_d, c%x_d, this%m)
990 ! division by dd to avoid devision by 0 (in case this%d%x_d = 0)
991 call device_invcol2(y%x_d, dd%x_d, this%m)
992 call device_pwmax2(y%x_d, zerom%x_d, this%m)
993
994 ! Comput the value of z that minimizes L_z for the current λ
995 ! minimize ((a_0 - sum_{i=1}^{m} λ_i * a_i) * z)
996 ! if (a_0-dot_product(lambda, a)>=0) z=0 else z= 1.0
997 ! ensure z>=0
998 call device_col3(dummy_m%x_d, lambda%x_d, a%x_d, this%m)
999 z = device_glsum(dummy_m%x_d, this%m)
1000 z = merge(0.0_rp, 1.0_rp, a0 - z >= 0.0)
1001
1002 ! Comput the value of x that minimizes L_x for the current λ
1003 ! minimize( sum_{j=1}^{n} [ (p_{0j} + sum_{i=1}^{m} λ_i *
1004 ! p_{ij}) / (u_j - x_j) + (q_{0j} + sum_{i=1}^{m} λ_i * q_{ij}) /
1005 ! (x_j - l_j) ] - sum_{i=1}^{m} λ_i * b_i)
1006 call device_mattrans_v_mul(pjlambda%x_d, pij%x_d, lambda%x_d, this%m, this%n)
1007 call device_mattrans_v_mul(qjlambda%x_d, qij%x_d, lambda%x_d, this%m, this%n)
1008 call device_add2(pjlambda%x_d, p0j%x_d, this%n)
1009 call device_add2(qjlambda%x_d, q0j%x_d, this%n)
1010
1011 call device_mma_dipsolvesub1(x%x_d, pjlambda%x_d, qjlambda%x_d, &
1012 low%x_d, upp%x_d, alpha%x_d, beta%x_d, this%n)
1013
1014 ! Compute the residual for the lambda and mu using eq(9) and eq(15)
1015
1016 call device_cfill(relambda%x_d, 0.0_rp, this%m)
1017 call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
1018 low%x_d, pij%x_d, qij%x_d, this%n, this%m)
1019
1020 ! Global comminucation for relambda values
1021
1022 call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
1023 sync = .true.)
1024 call mpi_allreduce(mpi_in_place, relambda%x, this%m, &
1025 mpi_real_precision, mpi_sum, neko_comm, ierr)
1026 call device_memcpy(relambda%x, relambda%x_d, this%m, &
1027 host_to_device, sync = .true.)
1028
1029 call device_add2s2(relambda%x_d, this%a%x_d, -z, this%m)
1030 call device_sub2(relambda%x_d, y%x_d, this%m)
1031 call device_add2(relambda%x_d, mu%x_d, this%m)
1032 call device_sub2(relambda%x_d, this%bi%x_d, this%m)
1033
1034 call device_col3(remu%x_d, mu%x_d, lambda%x_d, this%m)
1035 call device_cadd(remu%x_d, -epsi, this%m)
1036
1037 residumax = maxval([device_maxval(relambda%x_d, this%m), &
1038 device_maxval(remu%x_d, this%m)])
1039 end do
1040 end associate
1041 epsi = 0.1_rp * epsi
1042 end do outer
1043
1044 ! Save the new designx
1045 call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
1046 call device_copy(this%xold1%x_d, designx_d, this%n)
1047 call device_copy(designx_d, x%x_d, this%n)
1048
1049 ! update the parameters of the MMA object nesessary to compute KKT residual
1050 call device_copy(this%y%x_d, y%x_d, this%m)
1051 this%z = z
1052 call device_copy(this%lambda%x_d, lambda%x_d, this%m)
1053 call device_copy(this%mu%x_d, mu%x_d, this%m)
1054
1055 call neko_scratch_registry%relinquish(ind)
1056 end subroutine mma_subsolve_dip_device
1057
1058end submodule mma_device
MMA module.
Definition mma.f90:69