Neko-TOP
A portable framework for high-order spectral element flow topology optimization.
Loading...
Searching...
No Matches
mma_device.f90
Go to the documentation of this file.
1! Copyright (c) 2025, The Neko-TOP Authors
2! All rights reserved.
3!
4! Redistribution and use in source and binary forms, with or without
5! modification, are permitted provided that the following conditions
6! are met:
7!
8! * Redistributions of source code must retain the above copyright
9! notice, this list of conditions and the following disclaimer.
10!
11! * Redistributions in binary form must reproduce the above
12! copyright notice, this list of conditions and the following
13! disclaimer in the documentation and/or other materials provided
14! with the distribution.
15!
16! * Neither the name of the authors nor the names of its
17! contributors may be used to endorse or promote products derived
18! from this software without specific prior written permission.
19!
20! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21! "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22! LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23! FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24! COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
25! INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
26! BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27! LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28! CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29! LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30! ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31! POSSIBILITY OF SUCH DAMAGE.
32
33submodule(mma) mma_device
34
35 use device_math, only: device_copy, device_cmult, device_cadd, device_cfill, &
36 device_add2, device_add3s2, device_invcol2, device_col2, device_col3, &
37 device_sub2, device_sub3
45
46
47 use neko_config, only: neko_bcknd_device
48 use device, only: device_to_host
49 use comm, only: pe_rank
50 use mpi_f08, only: mpi_in_place
51
52 implicit none
53
54contains
55
!> Perform one MMA update of the design variables using the device backend.
!!
!! Copies the host design `x` into a temporary device-backed vector,
!! generates the convex approximation (`mma_gensub_device`), solves it with
!! the interior point solver (`mma_subsolve_dpip_device`) and copies the
!! result back into `x`. Called once per iteration of the optimization loop.
!!
!! @param[inout] this  MMA object; execution stops with an error if it has
!!                     not been initialized. `is_updated` is set on return.
!! @param[in]    iter  Iteration counter, forwarded to `mma_gensub_device`.
!! @param[inout] x     Host design array of length `this%n`; overwritten with
!!                     the updated design.
!! @param df0dx, fval, dfdx  Forwarded unchanged to `mma_gensub_device`
!!                     (presumably objective gradient, constraint values and
!!                     constraint gradients -- confirm against the caller).
module subroutine mma_update_device(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  real(kind=rp), dimension(this%n), intent(inout) :: x

  type(vector_t) :: df0dx, fval, xdesign
  type(matrix_t) :: dfdx

  if (.not. this%is_initialized) then
    write(stderr, *) "The MMA object is not initialized."
    error stop
  end if

  ! Stage the host design in a temporary vector that owns a device buffer.
  call xdesign%init(this%n)
  call device_memcpy(x, xdesign%x_d, this%n, host_to_device, sync = .false.)

  ! Generate a convex approximation of the problem.
  call mma_gensub_device(this, iter, xdesign, df0dx, fval, dfdx)
  ! Solve the approximation problem using the interior point method.
  call mma_subsolve_dpip_device(this, xdesign)

  ! Update the design vector x on the host. The copy is blocking: the caller
  ! reads x as soon as we return, and the staging vector is released below.
  call device_memcpy(x, xdesign%x_d, this%n, device_to_host, sync = .true.)

  ! Release the temporary vector so repeated calls do not leak storage.
  call xdesign%free()

  this%is_updated = .true.
end subroutine mma_update_device
89
!> Evaluate the KKT residual of the current MMA state on the device.
!!
!! Assembles the residual vectors (rex, rey, relambda, rexsi, reeta, remu,
!! res) and the scalar residuals (rez, rezeta) from the state stored in
!! `this`, then stores two measures in the object:
!!  - `this%residumax`:  MPI max-reduction of the largest local residual
!!    entry reported by `device_maxval`,
!!  - `this%residunorm`: square root of the summed `device_norm`
!!    contributions plus `rez**2` and `rezeta**2`.
!! Only the contributions of the size-n vectors are summed over MPI ranks.
!!
!! NOTE(review): `device_norm` is presumably a sum of squares, since its
!! results are added to `rez**2` before the square root -- confirm.
!!
!! @param[inout] this   MMA object holding the primal and dual variables.
!! @param[in]    x      Host design array of length `this%n`.
!! @param[in]    df0dx  Vector passed to `device_kkt_rex`.
!! @param[in]    fval   Vector entering the `relambda` residual.
!! @param[in]    dfdx   Matrix passed to `device_kkt_rex`.
module subroutine mma_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  real(kind=rp), dimension(this%n), intent(in) :: x
  type(vector_t), intent(in) :: fval, df0dx
  type(matrix_t), intent(in) :: dfdx

  type(vector_t) :: designx
  real(kind=rp) :: rez, rezeta
  type(vector_t) :: rey, relambda, remu, res
  type(vector_t) :: rex, rexsi, reeta
  real(kind=rp) :: residu_val
  integer :: ierr
  real(kind=rp) :: re_xstuff_squ_global
  real(kind=rp) :: globaltemp_norm

  ! Wrap the host array x in a vector_t to obtain a device pointer for it.
  ! The upload is blocking so the staging vector can be freed safely below.
  call designx%init(this%n)
  designx%x = x
  call device_memcpy(designx%x, designx%x_d, this%n, host_to_device, &
       sync = .true.)

  ! Work vectors of size m
  call rey%init(this%m)
  call relambda%init(this%m)
  call remu%init(this%m)
  call res%init(this%m)

  ! Work vectors of size n
  call rex%init(this%n)
  call rexsi%init(this%n)
  call reeta%init(this%n)

  call device_kkt_rex(rex%x_d, df0dx%x_d, dfdx%x_d, this%xsi%x_d, &
       this%eta%x_d, this%lambda%x_d, this%n, this%m)

  ! rey = c + d*y - lambda - mu
  call device_col3(rey%x_d, this%d%x_d, this%y%x_d, this%m)
  call device_add2(rey%x_d, this%c%x_d, this%m)
  call device_sub2(rey%x_d, this%lambda%x_d, this%m)
  call device_sub2(rey%x_d, this%mu%x_d, this%m)

  rez = this%a0 - this%zeta - device_lcsc2(this%lambda%x_d, this%a%x_d, &
       this%m)

  ! relambda = fval - z*a - y + s
  call device_add3s2(relambda%x_d, fval%x_d, this%a%x_d, 1.0_rp, -this%z, &
       this%m)
  call device_sub2(relambda%x_d, this%y%x_d, this%m)
  call device_add2(relambda%x_d, this%s%x_d, this%m)

  ! rexsi = xsi*(x - xmin)
  call device_sub3(rexsi%x_d, designx%x_d, this%xmin%x_d, this%n)
  call device_col2(rexsi%x_d, this%xsi%x_d, this%n)

  ! reeta = eta*(xmax - x)
  call device_sub3(reeta%x_d, this%xmax%x_d, designx%x_d, this%n)
  call device_col2(reeta%x_d, this%eta%x_d, this%n)

  ! remu = mu*y
  call device_col3(remu%x_d, this%mu%x_d, this%y%x_d, this%m)

  rezeta = this%zeta*this%z

  ! res = lambda*s
  call device_col3(res%x_d, this%lambda%x_d, this%s%x_d, this%m)

  ! Largest local residual entry, then the maximum over all ranks.
  residu_val = maxval([device_maxval(rex%x_d, this%n), &
       device_maxval(rey%x_d, this%m), rez, &
       device_maxval(relambda%x_d, this%m), &
       device_maxval(rexsi%x_d, this%n), device_maxval(reeta%x_d, this%n), &
       device_maxval(remu%x_d, this%m), rezeta, &
       device_maxval(res%x_d, this%m)])

  call mpi_allreduce(residu_val, this%residumax, 1, &
       mpi_real_precision, mpi_max, neko_comm, ierr)

  ! Contributions of the size-n vectors are summed over the ranks; the
  ! size-m and scalar terms are added locally without a reduction.
  globaltemp_norm = device_norm(rex%x_d, this%n) + &
       device_norm(rexsi%x_d, this%n) + device_norm(reeta%x_d, this%n)
  call mpi_allreduce(globaltemp_norm, re_xstuff_squ_global, 1, &
       mpi_real_precision, mpi_sum, neko_comm, ierr)
  this%residunorm = sqrt(device_norm(rey%x_d, this%m) + rez**2 + &
       device_norm(relambda%x_d, this%m) + device_norm(remu%x_d, this%m) + &
       rezeta**2+device_norm(res%x_d, this%m) + re_xstuff_squ_global)

  ! Release all temporaries so repeated KKT evaluations do not leak storage.
  call designx%free()
  call rex%free()
  call rexsi%free()
  call reeta%free()
  call rey%free()
  call relambda%free()
  call remu%free()
  call res%free()
end subroutine mma_kkt_device
169
170 !============================================================================!
171 ! private internal subroutines
172
!> Build the MMA approximation sub-problem on the device.
!!
!! Sets the lower and upper asymptotes (`low`, `upp`), then calls the device
!! kernels that fill `alpha`, `beta`, `p0j`, `q0j`, `pij`, `qij` and `bi`.
!! After each stage the updated device arrays are mirrored to the host with
!! blocking copies. `bi` is summed over all MPI ranks and `fval` is
!! subtracted from the sum.
!!
!! @param[inout] this   MMA object whose sub-problem data is rebuilt.
!! @param[in]    iter   Iteration counter; iterations 1 and 2 use the
!!                      `asyinit` asymptote rule, later ones the adaptive
!!                      rule in `device_mma_gensub2`.
!! @param[in]    x      Current design (device vector of size n).
!! @param[in]    df0dx  Vector passed to `device_mma_gensub3`.
!! @param[in]    fval   Vector subtracted from the globally summed `bi`.
!! @param[in]    dfdx   Matrix passed to `device_mma_gensub3`.
subroutine mma_gensub_device(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(vector_t), intent(in) :: x
  type(vector_t), intent(in) :: df0dx
  type(vector_t), intent(in) :: fval
  type(matrix_t), intent(in) :: dfdx

  type(vector_t) :: bi_sum
  integer :: ierr

  call bi_sum%init(this%m)

  associate(n => this%n, m => this%m)

    ! ---------------------------------------------------------------------- !
    ! Asymptotes
    if (iter < 3) then
      ! low = x - asyinit*(xmax - xmin)
      call device_add3s2(this%low%x_d, this%xmax%x_d, this%xmin%x_d, &
           - this%asyinit, this%asyinit, n)
      call device_add2(this%low%x_d, x%x_d, n)

      ! upp = x + asyinit*(xmax - xmin)
      call device_add3s2(this%upp%x_d, this%xmax%x_d, this%xmin%x_d, &
           this%asyinit, - this%asyinit, n)
      call device_add2(this%upp%x_d, x%x_d, n)
    else
      call device_mma_gensub2(this%low%x_d, this%upp%x_d, x%x_d, &
           this%xold1%x_d, this%xold2%x_d, this%xmin%x_d, this%xmax%x_d, &
           this%asydecr, this%asyincr, n)
    end if

    ! Mirror the asymptotes on the host.
    call device_memcpy(this%upp%x, this%upp%x_d, n, device_to_host, &
         sync = .true.)
    call device_memcpy(this%low%x, this%low%x_d, n, device_to_host, &
         sync = .true.)

    ! ---------------------------------------------------------------------- !
    ! alpha, beta, p0j, q0j, pij, qij
    call device_mma_gensub3(x%x_d, df0dx%x_d, dfdx%x_d, this%low%x_d, &
         this%upp%x_d, this%xmin%x_d, this%xmax%x_d, this%alpha%x_d, &
         this%beta%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
         this%qij%x_d, n, m)

    call device_memcpy(this%alpha%x, this%alpha%x_d, n, device_to_host, &
         sync = .true.)
    call device_memcpy(this%beta%x, this%beta%x_d, n, device_to_host, &
         sync = .true.)

    ! ---------------------------------------------------------------------- !
    ! Rank-local part of bi
    call device_mma_gensub4(x%x_d, this%low%x_d, this%upp%x_d, &
         this%pij%x_d, this%qij%x_d, n, m, this%bi%x_d)

    call device_memcpy(this%pij%x, this%pij%x_d, n*m, device_to_host, &
         sync = .true.)
    call device_memcpy(this%qij%x, this%qij%x_d, n*m, device_to_host, &
         sync = .true.)

    ! ---------------------------------------------------------------------- !
    ! Sum bi over all ranks on the host, then bi = sum(bi) - fval on device
    bi_sum%x = 0.0_rp
    call device_memcpy(this%bi%x, this%bi%x_d, m, device_to_host, &
         sync = .true.)
    call mpi_allreduce(this%bi%x, bi_sum%x, m, mpi_real_precision, &
         mpi_sum, neko_comm, ierr)
    call device_memcpy(bi_sum%x, bi_sum%x_d, m, host_to_device, &
         sync = .true.)
    call device_sub3(this%bi%x_d, bi_sum%x_d, fval%x_d, m)

    ! Keep the host copy of bi in step with the device.
    call device_memcpy(this%bi%x, this%bi%x_d, m, device_to_host, &
         sync = .true.)

  end associate

  call bi_sum%free()
end subroutine mma_gensub_device
244
!> Solve the MMA sub-problem with a dual-primal interior point method.
!!
!! Starting from the initial point described on page 15 of
!! "https://people.kth.se/~krille/mmagcmma.pdf", the barrier parameter `epsi`
!! is reduced by a factor of 10 per outer iteration until it drops below
!! `max(0.9*this%epsimin, 1e-12)`. For each `epsi` at most `this%max_iter`
!! Newton steps are taken; each step solves a dense (m+1)x(m+1) system on the
!! host with LAPACK `dgesv` and is followed by a backtracking line search of
!! at most 50 step halvings.
!!
!! On exit the design history (`xold2`, `xold1`) is shifted, `designx`
!! receives the new design on the device, and the primal/dual variables
!! needed for the KKT residual (`y`, `z`, `lambda`, `zeta`, `xsi`, `eta`,
!! `mu`, `s`) are stored in `this`.
!!
!! @param[inout] this     MMA object holding the sub-problem data.
!! @param[in]    designx  Current design; the device buffer behind
!!                        `designx%x_d` is overwritten with the new design
!!                        (the handle itself is not modified).
subroutine mma_subsolve_dpip_device(this, designx)
  class(mma_t), intent(inout) :: this
  type(vector_t), intent(in) :: designx
  integer :: iter, itto, ierr
  real(kind=rp) :: epsi, residumax, residunorm, z, zeta, rez, rezeta, &
       delz, dz, dzeta, steg, dummy_one, zold, zetaold, newresidu
  ! vectors with size m
  type(vector_t) :: y, lambda, s, mu, rey, relambda, remu, res, &
       dely, dellambda, dy, dlambda, ds, dmu, yold, lambdaold, sold, muold
  type(vector_t) :: globaltmp_m

  ! vectors with size n
  type(vector_t) :: x, xsi, eta, rex, rexsi, reeta, &
       delx, diagx, dx, dxsi, deta, xold, xsiold, etaold

  ! right-hand side (size m+1) and matrices of the reduced Newton system
  type(vector_t) :: bb
  type(matrix_t) :: gg
  type(matrix_t) :: aa
  type(matrix_t) :: globaltmp_mm

  integer :: info
  integer, dimension(this%m+1) :: ipiv
  real(kind=rp) :: re_xstuff_squ_global

  integer :: nglobal, i

  real(kind=rp) :: cons
  real(kind=rp) :: minimal_epsilon

  call globaltmp_m%init(this%m)
  call globaltmp_mm%init(this%m, this%m)

  call y%init(this%m)
  call lambda%init(this%m)
  call s%init(this%m)
  call mu%init(this%m)
  call rey%init(this%m)
  call relambda%init(this%m)
  call remu%init(this%m)
  call res%init(this%m)
  call dely%init(this%m)
  call dellambda%init(this%m)
  call dy%init(this%m)
  call dlambda%init(this%m)
  call ds%init(this%m)
  call dmu%init(this%m)
  call yold%init(this%m)
  call lambdaold%init(this%m)
  call sold%init(this%m)
  call muold%init(this%m)
  call x%init(this%n)
  call xsi%init(this%n)
  call eta%init(this%n)
  call rex%init(this%n)
  call rexsi%init(this%n)
  call reeta%init(this%n)
  call delx%init(this%n)
  call diagx%init(this%n)
  call dx%init(this%n)
  call dxsi%init(this%n)
  call deta%init(this%n)
  call xold%init(this%n)
  call xsiold%init(this%n)
  call etaold%init(this%n)
  call bb%init(this%m+1)

  call gg%init(this%m, this%n)
  call aa%init(this%m+1, this%m+1)

  ! ------------------------------------------------------------------------ !
  ! initial value for the parameters in the subsolve based on
  ! page 15 of "https://people.kth.se/~krille/mmagcmma.pdf"
  dummy_one = 1.0_rp
  epsi = 1.0_rp !100
  ! x = (alpha + beta)/2
  call device_add3s2(x%x_d, this%alpha%x_d, this%beta%x_d, 0.5_rp, 0.5_rp, &
       this%n)
  call device_cfill(y%x_d, 1.0_rp, this%m)
  z = 1.0_rp
  zeta = 1.0_rp
  call device_cfill(lambda%x_d, 1.0_rp, this%m)
  call device_cfill(s%x_d, 1.0_rp, this%m)
  call device_mma_max(xsi%x_d, x%x_d, this%alpha%x_d, this%n)
  call device_mma_max(eta%x_d, this%beta%x_d, x%x_d, this%n)
  call device_max2(mu%x_d, 1.0_rp, this%c%x_d, 0.5_rp, this%m)
  call device_memcpy(xsi%x, xsi%x_d, this%n, device_to_host, sync = .true.)
  call device_memcpy(eta%x, eta%x_d, this%n, device_to_host, sync = .true.)
  call device_memcpy(mu%x, mu%x_d, this%m, device_to_host, sync = .true.)

  ! NOTE(review): nglobal is not read anywhere below; the reduction is kept
  ! only so that the sequence of collective calls is unchanged.
  call mpi_allreduce(this%n, nglobal, 1, mpi_integer, mpi_sum, &
       neko_comm, ierr)

  ! ------------------------------------------------------------------------ !
  ! Computing the minimal epsilon and choose the most conservative one
  minimal_epsilon = max(0.9_rp * this%epsimin, 1.0e-12_rp)
  call mpi_allreduce(mpi_in_place, minimal_epsilon, 1, &
       mpi_real_precision, mpi_min, neko_comm, ierr)

  ! ------------------------------------------------------------------------ !
  ! The main loop of the dual-primal interior point method.

  outer: do while (epsi .gt. minimal_epsilon)
    ! calculating residuals based on
    ! "https://people.kth.se/~krille/mmagcmma.pdf" for the variables
    ! x, y, z, lambda residuals based on eq(5.9a)-(5.9d), respectively.
    associate(p0j => this%p0j, q0j => this%q0j, &
         pij => this%pij, qij => this%qij, &
         low => this%low, upp => this%upp, &
         alpha => this%alpha, beta => this%beta, &
         c => this%c, d => this%d, &
         a0 => this%a0, a => this%a, &
         bi => this%bi)

      call device_rex(rex%x_d, x%x_d, low%x_d, upp%x_d, &
           pij%x_d, p0j%x_d, qij%x_d, q0j%x_d, &
           lambda%x_d, xsi%x_d, eta%x_d, this%n, this%m)

      call device_col3(rey%x_d, d%x_d, y%x_d, this%m)
      call device_add2(rey%x_d, c%x_d, this%m)
      call device_sub2(rey%x_d, lambda%x_d, this%m)
      call device_sub2(rey%x_d, mu%x_d, this%m)
      rez = a0 - zeta - device_lcsc2(lambda%x_d, a%x_d, this%m)
      call device_cfill(relambda%x_d, 0.0_rp, this%m)
      call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
           low%x_d, pij%x_d, qij%x_d, this%n, this%m)
      call device_memcpy(relambda%x, relambda%x_d, this%m, device_to_host, &
           sync = .true.)

    end associate

    ! Sum the rank-local part of relambda over all ranks on the host.
    globaltmp_m%x = 0.0_rp
    call mpi_allreduce(relambda%x, globaltmp_m%x, this%m, &
         mpi_real_precision, mpi_sum, neko_comm, ierr)

    call device_memcpy(globaltmp_m%x, globaltmp_m%x_d, this%m, &
         host_to_device, sync = .true.)
    call device_add3s2(relambda%x_d, globaltmp_m%x_d, this%a%x_d, &
         1.0_rp, -z, this%m)
    call device_sub2(relambda%x_d, y%x_d, this%m)
    call device_add2(relambda%x_d, s%x_d, this%m)
    call device_sub2(relambda%x_d, this%bi%x_d, this%m)

    ! rexsi = xsi*(x - alpha) - epsi
    call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
    call device_col2(rexsi%x_d, xsi%x_d, this%n)
    call device_cadd(rexsi%x_d, -epsi, this%n)

    ! reeta = eta*(beta - x) - epsi
    call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
    call device_col2(reeta%x_d, eta%x_d, this%n)
    call device_cadd(reeta%x_d, -epsi, this%n)

    ! remu = mu*y - epsi
    call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
    call device_cadd(remu%x_d, -epsi, this%m)

    rezeta = zeta*z -epsi

    ! res = lambda*s - epsi
    call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
    call device_cadd(res%x_d, -epsi, this%m)

    cons = 0.0_rp
    cons = maxval([device_maxval(rex%x_d, this%n), &
         device_maxval(rey%x_d, this%m), rez, &
         device_maxval(relambda%x_d, this%m), &
         device_maxval(rexsi%x_d, this%n), &
         device_maxval(reeta%x_d, this%n), &
         device_maxval(remu%x_d, this%m), rezeta, &
         device_maxval(res%x_d, this%m)])
    residumax = 0.0_rp
    call mpi_allreduce(cons, residumax, 1, mpi_real_precision, mpi_max, &
         neko_comm, ierr)

    re_xstuff_squ_global = 0.0_rp
    cons = device_norm(rex%x_d, this%n) + &
         device_norm(rexsi%x_d, this%n)+device_norm(reeta%x_d, this%n)
    call mpi_allreduce(cons, re_xstuff_squ_global, 1, &
         mpi_real_precision, mpi_sum, neko_comm, ierr)
    cons = device_norm(rey%x_d, this%m) + rez**2 + &
         device_norm(relambda%x_d, this%m) + &
         device_norm(remu%x_d, this%m)+ &
         rezeta**2+device_norm(res%x_d, this%m)
    residunorm = sqrt(cons + re_xstuff_squ_global)

    do iter = 1, this%max_iter !ittt

      if (residumax .lt. epsi) exit

      call device_delx(delx%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
           this%pij%x_d, this%qij%x_d, this%p0j%x_d, this%q0j%x_d, &
           this%alpha%x_d, this%beta%x_d, lambda%x_d, epsi, this%n, &
           this%m)

      call device_col3(dely%x_d, this%d%x_d, y%x_d, this%m)
      call device_add2(dely%x_d, this%c%x_d, this%m)
      call device_sub2(dely%x_d, lambda%x_d, this%m)
      call device_add2inv2(dely%x_d, y%x_d, -epsi, this%m)
      delz = this%a0 - device_lcsc2(lambda%x_d, this%a%x_d, this%m) - &
           epsi/z
      call device_cfill(dellambda%x_d, 0.0_rp, this%m)
      call device_relambda(dellambda%x_d, x%x_d, this%upp%x_d, &
           this%low%x_d, this%pij%x_d, this%qij%x_d, this%n, this%m)
      call device_memcpy(dellambda%x, dellambda%x_d, this%m, &
           device_to_host, sync = .true.)

      globaltmp_m%x = 0.0_rp
      call mpi_allreduce(dellambda%x, globaltmp_m%x, this%m, &
           mpi_real_precision, mpi_sum, neko_comm, ierr)

      call device_memcpy(globaltmp_m%x, globaltmp_m%x_d, this%m, &
           host_to_device, sync = .true.)
      call device_add3s2(dellambda%x_d, globaltmp_m%x_d, this%a%x_d, &
           1.0_rp, -z, this%m)

      call device_sub2(dellambda%x_d, y%x_d, this%m)
      call device_sub2(dellambda%x_d, this%bi%x_d, this%m)
      call device_add2inv2(dellambda%x_d, lambda%x_d, epsi, this%m)

      call device_gg(gg%x_d, x%x_d, this%low%x_d, this%upp%x_d, &
           this%pij%x_d, this%qij%x_d, this%n, this%m)

      call device_diagx(diagx%x_d, x%x_d, xsi%x_d, this%low%x_d, &
           this%upp%x_d, this%p0j%x_d, this%q0j%x_d, this%pij%x_d, &
           this%qij%x_d, this%alpha%x_d, this%beta%x_d, eta%x_d, &
           lambda%x_d, this%n, this%m)

      ! Right-hand side: the first m entries are summed over all ranks.
      call device_bb(bb%x_d, gg%x_d, delx%x_d, diagx%x_d, this%n, &
           this%m)
      call device_memcpy(bb%x, bb%x_d, this%m, device_to_host, &
           sync = .true.)

      call mpi_allreduce(mpi_in_place, bb%x(1:this%m), this%m, &
           mpi_real_precision, mpi_sum, neko_comm, ierr)

      call device_memcpy(bb%x, bb%x_d, this%m, &
           host_to_device, sync = .true.)

      call device_updatebb(bb%x_d, dellambda%x_d, dely%x_d, &
           this%d%x_d, mu%x_d, y%x_d, delz, this%m)

      ! System matrix: the leading m x m block is summed over all ranks.
      call device_cfill(aa%x_d, 0.0_rp, (this%m+1) * (this%m+1) )
      call device_aa(aa%x_d, gg%x_d, diagx%x_d, this%n, this%m)
      call device_memcpy(aa%x, aa%x_d, (this%m+1) * (this%m+1), &
           device_to_host, sync = .true.)
      call mpi_allreduce(mpi_in_place, aa%x(1:this%m, 1:this%m), &
           this%m * this%m, mpi_real_precision, mpi_sum, neko_comm, ierr)
      call device_memcpy(aa%x, aa%x_d, &
           (this%m) * (this%m), host_to_device, sync = .true.)

      ! The remaining entries of AA are assembled on the host, which needs
      ! current host copies of lambda, mu, y and s.
      call device_memcpy(lambda%x, lambda%x_d, this%m, device_to_host, &
           sync = .true.)
      call device_memcpy(mu%x, mu%x_d, this%m, device_to_host, &
           sync = .true.)
      call device_memcpy(y%x, y%x_d, this%m, device_to_host, &
           sync = .true.)
      call device_memcpy(s%x, s%x_d, this%m, device_to_host, &
           sync = .true.)
      do i = 1, this%m
        ! update the diag AA
        aa%x(i, i) = aa%x(i, i) &
             + s%x(i) / lambda%x(i) &
             + 1.0_rp / (this%d%x(i) + mu%x(i) / y%x(i))
      end do
      aa%x(1:this%m, this%m+1) = this%a%x
      aa%x(this%m+1, 1:this%m) = this%a%x
      aa%x(this%m+1, this%m+1) = - zeta/z

      ! Solve AA * [dlambda; dz] = bb on the host; bb is overwritten with
      ! the solution.
      call device_memcpy(bb%x, bb%x_d, this%m+1, device_to_host, &
           sync = .true.)
      call dgesv(this%m+1, 1, aa%x, this%m+1, ipiv, bb%x, this%m+1, &
           info)
      if (info .ne. 0) then
        write(stderr, *) "DGESV failed in mma_device.f90."
        write(stderr, *) "Please check mma_subsolve_dpip in mma.f90"
        error stop
      end if
      call device_memcpy(bb%x, bb%x_d, this%m+1, host_to_device, &
           sync = .true.)

      call device_copy(dlambda%x_d, bb%x_d, this%m)
      dz = bb%x(this%m + 1)

      ! Search directions for the remaining variables.
      call device_dx(dx%x_d, delx%x_d, diagx%x_d, gg%x_d, &
           dlambda%x_d, this%n, this%m)
      call device_dy(dy%x_d, dely%x_d, dlambda%x_d, this%d%x_d, &
           mu%x_d, y%x_d, this%m)
      call device_dxsi(dxsi%x_d, xsi%x_d, dx%x_d, x%x_d, &
           this%alpha%x_d, epsi, this%n)
      call device_deta(deta%x_d, eta%x_d, dx%x_d, x%x_d, &
           this%beta%x_d, epsi, this%n)

      ! dmu = (epsi - mu*dy)/y - mu
      call device_col3(dmu%x_d, mu%x_d, dy%x_d, this%m)
      call device_cmult(dmu%x_d, -1.0_rp, this%m)
      call device_cadd(dmu%x_d, epsi, this%m)
      call device_invcol2(dmu%x_d, y%x_d, this%m)
      call device_sub2(dmu%x_d, mu%x_d, this%m)

      dzeta = -zeta + (epsi-zeta*dz)/z

      ! ds = (epsi - dlambda*s)/lambda - s
      call device_col3(ds%x_d, dlambda%x_d, s%x_d, this%m)
      call device_cmult(ds%x_d, -1.0_rp, this%m)
      call device_cadd(ds%x_d, epsi, this%m)
      call device_invcol2(ds%x_d, lambda%x_d, this%m)
      call device_sub2(ds%x_d, s%x_d, this%m)

      ! Initial step length from the candidate ratios below, capped at 1.
      steg = maxval([dummy_one, device_maxval2(dy%x_d, y%x_d, &
           -1.01_rp, this%m), -1.01_rp*dz/z, &
           device_maxval2(dlambda%x_d, lambda%x_d, &
           -1.01_rp, this%m), &
           device_maxval2(dxsi%x_d, xsi%x_d, -1.01_rp, this%n), &
           device_maxval2(deta%x_d, eta%x_d, -1.01_rp, this%n), &
           device_maxval2(dmu%x_d, mu%x_d, -1.01_rp, this%m), &
           device_maxval2(ds%x_d, s%x_d, -1.01_rp, this%m), &
           device_maxval3(dx%x_d, x%x_d, this%alpha%x_d, -1.01_rp, &
           this%n), device_maxval3(dx%x_d, this%beta%x_d, x%x_d, &
           1.01_rp, this%n), -1.01_rp*dzeta/zeta])
      steg = 1.0_rp/steg

      ! find minimum step sizes between nodes
      ! (in-place reduction: MPI forbids passing the same variable as both
      ! the send and the receive buffer)
      call mpi_allreduce(mpi_in_place, steg, 1, &
           mpi_real_precision, mpi_min, neko_comm, ierr)

      ! Save the current iterate so that trial steps can restart from it.
      call device_copy(xold%x_d, x%x_d, this%n)
      call device_copy(yold%x_d, y%x_d, this%m)
      zold = z
      call device_copy(lambdaold%x_d, lambda%x_d, this%m)
      call device_copy(xsiold%x_d, xsi%x_d, this%n)
      call device_copy(etaold%x_d, eta%x_d, this%n)
      call device_copy(muold%x_d, mu%x_d, this%m)
      zetaold = zeta
      call device_copy(sold%x_d, s%x_d, this%m)
      newresidu = 2.0_rp*residunorm
      itto = 0

      ! The innermost loop to determine the suitable step length
      ! using the Backtracking Line Search approach
      do while ((newresidu .gt. residunorm) .and. (itto .lt. 50))
        itto = itto + 1
        call device_add3s2(x%x_d, xold%x_d, dx%x_d, 1.0_rp, &
             steg, this%n)
        call device_add3s2(y%x_d, yold%x_d, dy%x_d, 1.0_rp, &
             steg, this%m)
        z = zold + steg*dz
        call device_add3s2(lambda%x_d, lambdaold%x_d, &
             dlambda%x_d, 1.0_rp, steg, this%m)

        call device_add3s2(xsi%x_d, xsiold%x_d, dxsi%x_d, &
             1.0_rp, steg, this%n)
        call device_add3s2(eta%x_d, etaold%x_d, deta%x_d, &
             1.0_rp, steg, this%n)

        call device_add3s2(mu%x_d, muold%x_d, dmu%x_d, &
             1.0_rp, steg, this%m)

        zeta = zetaold + steg*dzeta

        call device_add3s2(s%x_d, sold%x_d, ds%x_d, 1.0_rp, &
             steg, this%m)

        ! recompute the newresidu to see if this stepsize improves
        ! the residue
        call device_rex(rex%x_d, x%x_d, this%low%x_d, &
             this%upp%x_d, this%pij%x_d, this%p0j%x_d, &
             this%qij%x_d, this%q0j%x_d, lambda%x_d, xsi%x_d, &
             eta%x_d, this%n, this%m)

        call device_memcpy(rex%x, rex%x_d, this%n, device_to_host, &
             sync = .true.)
        call device_memcpy(xsi%x, xsi%x_d, this%n, device_to_host, &
             sync = .true.)
        call device_memcpy(eta%x, eta%x_d, this%n, device_to_host, &
             sync = .true.)
        call device_memcpy(lambda%x, lambda%x_d, this%m, &
             device_to_host, sync = .true.)

        call device_col3(rey%x_d, this%d%x_d, y%x_d, this%m)
        call device_add2(rey%x_d, this%c%x_d, this%m)
        call device_sub2(rey%x_d, lambda%x_d, this%m)
        call device_sub2(rey%x_d, mu%x_d, this%m)

        rez = this%a0 - zeta - device_lcsc2(lambda%x_d, &
             this%a%x_d, this%m)

        call device_cfill(relambda%x_d, 0.0_rp, this%m)
        call device_relambda(relambda%x_d, x%x_d, this%upp%x_d, &
             this%low%x_d, this%pij%x_d, this%qij%x_d, &
             this%n, this%m)
        call device_memcpy(relambda%x, relambda%x_d, this%m, &
             device_to_host, sync = .true.)

        globaltmp_m%x = 0.0_rp
        call mpi_allreduce(relambda%x, globaltmp_m%x, this%m, &
             mpi_real_precision, mpi_sum, neko_comm, ierr)

        call device_memcpy(globaltmp_m%x, globaltmp_m%x_d, &
             this%m, host_to_device, sync = .true.)

        call device_add3s2(relambda%x_d, globaltmp_m%x_d, &
             this%a%x_d, 1.0_rp, -z, this%m)
        call device_sub2(relambda%x_d, y%x_d, this%m)
        call device_add2(relambda%x_d, s%x_d, this%m)
        call device_sub2(relambda%x_d, this%bi%x_d, this%m)

        call device_sub3(rexsi%x_d, x%x_d, this%alpha%x_d, this%n)
        call device_col2(rexsi%x_d, xsi%x_d, this%n)
        call device_cadd(rexsi%x_d, -epsi, this%n)

        call device_sub3(reeta%x_d, this%beta%x_d, x%x_d, this%n)
        call device_col2(reeta%x_d, eta%x_d, this%n)
        call device_cadd(reeta%x_d, -epsi, this%n)

        call device_col3(remu%x_d, mu%x_d, y%x_d, this%m)
        call device_cadd(remu%x_d, -epsi, this%m)

        rezeta = zeta*z - epsi

        call device_col3(res%x_d, lambda%x_d, s%x_d, this%m)
        call device_cadd(res%x_d, -epsi, this%m)

        re_xstuff_squ_global = 0.0_rp
        cons = device_norm(rex%x_d, this%n) + &
             device_norm(rexsi%x_d, this%n) + &
             device_norm(reeta%x_d, this%n)
        call mpi_allreduce(cons, re_xstuff_squ_global, 1, &
             mpi_real_precision, mpi_sum, neko_comm, ierr)

        cons = device_norm(rey%x_d, this%m) + rez**2 + &
             device_norm(relambda%x_d, this%m) + &
             device_norm(remu%x_d, this%m) + &
             rezeta**2+device_norm(res%x_d, this%m)

        newresidu = sqrt(cons+ re_xstuff_squ_global)

        ! Halve the step for the next trial.
        steg = steg/2.0_rp

        ! Local max residual of this trial point; reduced after the loop.
        cons = 0.0_rp
        cons = maxval([device_maxval(rex%x_d, this%n), &
             device_maxval(rey%x_d, this%m), rez, &
             device_maxval(relambda%x_d, this%m), &
             device_maxval(rexsi%x_d, this%n), &
             device_maxval(reeta%x_d, this%n), &
             device_maxval(remu%x_d, this%m), rezeta, &
             device_maxval(res%x_d, this%m)])
      end do
      residunorm = newresidu
      residumax = 0.0_rp
      call mpi_allreduce(cons, residumax, 1, mpi_real_precision, &
           mpi_max, neko_comm, ierr)
      ! Undo the last halving so steg is the step that was actually taken.
      steg = 2.0_rp*steg
    end do
    epsi = 0.1_rp * epsi
  end do outer

  ! Save the new designx
  call device_copy(this%xold2%x_d, this%xold1%x_d, this%n)
  call device_copy(this%xold1%x_d, designx%x_d, this%n)
  call device_copy(designx%x_d, x%x_d, this%n)

  ! update the parameters of the MMA object necessary to compute KKT residual
  call device_copy(this%y%x_d, y%x_d, this%m)
  this%z = z
  call device_copy(this%lambda%x_d, lambda%x_d, this%m)
  this%zeta = zeta
  call device_copy(this%xsi%x_d, xsi%x_d, this%n)
  call device_copy(this%eta%x_d, eta%x_d, this%n)
  call device_copy(this%mu%x_d, mu%x_d, this%m)
  call device_copy(this%s%x_d, s%x_d, this%m)

  ! Release every temporary created above; without this each call leaks the
  ! storage of all work vectors and matrices.
  ! NOTE(review): assumes the backend's free waits for, or is ordered after,
  ! the copies enqueued just above -- confirm for the backend in use.
  call globaltmp_m%free()
  call globaltmp_mm%free()

  call y%free()
  call lambda%free()
  call s%free()
  call mu%free()
  call rey%free()
  call relambda%free()
  call remu%free()
  call res%free()
  call dely%free()
  call dellambda%free()
  call dy%free()
  call dlambda%free()
  call ds%free()
  call dmu%free()
  call yold%free()
  call lambdaold%free()
  call sold%free()
  call muold%free()

  call x%free()
  call xsi%free()
  call eta%free()
  call rex%free()
  call rexsi%free()
  call reeta%free()
  call delx%free()
  call diagx%free()
  call dx%free()
  call dxsi%free()
  call deta%free()
  call xold%free()
  call xsiold%free()
  call etaold%free()

  call bb%free()
  call gg%free()
  call aa%free()

end subroutine mma_subsolve_dpip_device
728
729
730
731end submodule mma_device
subroutine, public device_dx(dx_d, delx_d, diagx_d, gg_d, dlambda_d, n, m)
subroutine, public device_relambda(relambda_d, x_d, upp_d, low_d, pij_d, qij_d, n, m)
real(kind=rp) function, public device_maxval2(dxx_d, xx_d, cons, n)
subroutine, public device_delx(delx_d, x_d, low_d, upp_d, pij_d, qij_d, p0j_d, q0j_d, alpha_d, beta_d, lambda_d, epsi, n, m)
subroutine, public device_mma_gensub4(x_d, low_d, upp_d, pij_d, qij_d, n, m, bi_d)
subroutine, public device_deta(deta_d, eta_d, dx_d, x_d, beta_d, epsi, n)
subroutine, public device_rex(rex_d, x_d, low_d, upp_d, pij_d, p0j_d, qij_d, q0j_d, lambda_d, xsi_d, eta_d, n, m)
real(kind=rp) function, public device_maxval3(dx_d, x_d, alpha_d, cons, n)
subroutine, public device_diagx(diagx_d, x_d, xsi_d, low_d, upp_d, p0j_d, q0j_d, pij_d, qij_d, alpha_d, beta_d, eta_d, lambda_d, n, m)
subroutine, public device_mma_gensub3(x_d, df0dx_d, dfdx_d, low_d, upp_d, min_d, max_d, alpha_d, beta_d, p0j_d, q0j_d, pij_d, qij_d, n, m)
real(kind=rp) function, public device_norm(rex_d, n)
subroutine, public device_mma_max(xsi_d, x_d, alpha_d, n)
subroutine, public device_bb(bb_d, gg_d, delx_d, diagx_d, n, m)
subroutine, public device_dxsi(dxsi_d, xsi_d, dx_d, x_d, alpha_d, epsi, n)
subroutine, public device_mma_gensub2(low_d, upp_d, x_d, xold1_d, xold2_d, xmin_d, xmax_d, asydecr, asyincr, n)
subroutine, public device_updatebb(bb_d, dellambda_d, dely_d, d_d, mu_d, y_d, delz, m)
subroutine, public device_updateaa(aa_d, globaltmp_mm_d, s_d, lambda_d, d_d, mu_d, y_d, a_d, zeta, z, m)
real(kind=rp) function, public device_maxval(rex_d, n)
subroutine, public device_aa(aa_d, gg_d, diagx_d, n, m)
subroutine, public device_kkt_rex(rex_d, df0dx_d, dfdx_d, xsi_d, eta_d, lambda_d, n, m)
real(kind=rp) function, public device_lcsc2(a_d, b_d, n)
subroutine, public device_dy(dy_d, dely_d, dlambda_d, d_d, mu_d, y_d, n)
subroutine, public device_gg(gg_d, x_d, low_d, upp_d, pij_d, qij_d, n, m)
subroutine, public device_add2inv2(a_d, b_d, c, n)
subroutine, public device_max2(a_d, b, c_d, d, n)
Definition mma.f90:34