36 use num_types,
only: rp
37 use json_module,
only: json_file
38 use json_utils,
only: json_get_or_default
39 use vector,
only: vector_t
40 use matrix,
only: matrix_t
41 use mpi_f08,
only: mpi_allreduce, mpi_integer, mpi_sum, mpi_comm_world
42 use comm,
only: pe_rank
43 use utils,
only: neko_error
44 use neko_config,
only: neko_bcknd_device, neko_bcknd_cuda, neko_bcknd_hip, &
46 use device,
only: device_memcpy, host_to_device, device_to_host
47 use,
intrinsic :: iso_c_binding, only: c_ptr
! Components and type-bound procedures of the MMA optimizer type that the
! procedures in this file refer to as mma_t. The enclosing type statement
! (and any private/contains statements) are not visible in this excerpt.
!
! Problem sizes: in the CPU interfaces x and df0dx are dimensioned this%n,
! fval is dimensioned this%m and dfdx is (this%m, this%n).
! max_iter falls back to 100 when it is not passed to the initializer.
54 integer :: n, m, max_iter
! Scalar parameters and work vectors. The initializer assigns epsimin,
! asyinit, asyincr and asydecr (with fallbacks when the optional arguments
! are absent) and sizes xold1, xold2, low, upp, alpha, beta, xmax, xmin
! with n. a, c and d are copied to the device with m elements.
! NOTE(review): the real declaration ends in a continuation whose next line
! is not visible here; residumax and residunorm, which the getters read, are
! presumably declared on it - confirm.
55 real(kind=rp) :: a0, f0val, asyinit, asyincr, asydecr, epsimin, &
57 type(vector_t) :: xold1, xold2, low, upp, alpha, beta, a, c, d, xmax, xmin
! is_initialized is set to .true. at the end of mma_init_from_components and
! reset to .false. by mma_free; is_updated is reset to .false. by mma_free.
58 logical :: is_initialized = .false.
59 logical :: is_updated = .false.
! bcknd selects the branch taken in mma_update_vector / mma_kkt_vector.
! subsolver is checked against "dip" and "dpip" by the initializer.
60 character(len=:),
allocatable :: bcknd, subsolver
! Sized by the initializer as p0j, q0j: n and pij, qij: (m, n).
63 type(vector_t) :: p0j, q0j
64 type(matrix_t) :: pij, qij
! lambda is sized m, xsi and eta are sized n by the initializer.
! NOTE(review): the sizing of y, s and mu is not visible in this excerpt.
68 real(kind=rp) :: z, zeta
69 type(vector_t) :: y, lambda, s, mu
70 type(vector_t) :: xsi, eta
! Generic initializer: resolves to the JSON-driven or the component-wise
! variant depending on the actual arguments.
73 generic,
public :: init => init_from_json, init_from_components
74 procedure,
public, pass(this) :: init_from_json => mma_init_from_json
75 procedure,
public, pass(this) :: init_from_components => &
76 mma_init_from_components
! Destructor and read-only accessors.
77 procedure,
public, pass(this) :: free => mma_free
78 procedure,
public, pass(this) :: get_n => mma_get_n
79 procedure,
public, pass(this) :: get_m => mma_get_m
80 procedure,
public, pass(this) :: get_residumax => mma_get_residumax
81 procedure,
public, pass(this) :: get_residunorm => mma_get_residunorm
82 procedure,
public, pass(this) :: get_max_iter => mma_get_max_iter
83 procedure,
public, pass(this) :: get_backend_and_subsolver => &
84 mma_get_backend_and_subsolver
! Generic update: vector_t/matrix_t arguments, plain real arrays (CPU) or
! c_ptr device handles.
87 generic,
public :: update => update_vector, update_cpu, update_device
88 procedure, pass(this) :: update_vector => mma_update_vector
89 procedure, pass(this) :: update_cpu => mma_update_cpu
90 procedure, pass(this) :: update_device => mma_update_device
! Generic KKT evaluation with the same three argument flavours as update.
92 generic,
public :: kkt => kkt_vector, kkt_cpu, kkt_device
93 procedure, pass(this) :: kkt_vector => mma_kkt_vector
94 procedure, pass(this) :: kkt_cpu => mma_kkt_cpu
95 procedure, pass(this) :: kkt_device => mma_kkt_device
!> Interface of the CPU flavour of the MMA update (separate module
!! procedure; the body is defined elsewhere).
!! @param this  MMA object, modified by the update.
!! @param iter  Iteration counter supplied by the caller.
!! @param x     Array of length this%n, updated in place.
!! @param df0dx Array of length this%n (read only).
!! @param fval  Array of length this%m (read only).
!! @param dfdx  Array of shape (this%m, this%n) (read only).
module subroutine mma_update_cpu(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  real(kind=rp), intent(inout) :: x(this%n)
  real(kind=rp), intent(in) :: df0dx(this%n)
  real(kind=rp), intent(in) :: fval(this%m)
  real(kind=rp), intent(in) :: dfdx(this%m, this%n)
end subroutine mma_update_cpu
!> Interface of the CPU flavour of the KKT evaluation (separate module
!! procedure; the body is defined elsewhere).
!! @param this  MMA object, may be modified.
!! @param x     Array of length this%n (read only).
!! @param df0dx Array of length this%n (read only).
!! @param fval  Array of length this%m (read only).
!! @param dfdx  Array of shape (this%m, this%n) (read only).
module subroutine mma_kkt_cpu(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  real(kind=rp), intent(in) :: x(this%n)
  real(kind=rp), intent(in) :: df0dx(this%n)
  real(kind=rp), intent(in) :: fval(this%m)
  real(kind=rp), intent(in) :: dfdx(this%m, this%n)
end subroutine mma_kkt_cpu
!> Interface of the device flavour of the MMA update (separate module
!! procedure; the body is defined elsewhere). All data arguments are
!! opaque C pointers.
!! @param this  MMA object, modified by the update.
!! @param iter  Iteration counter supplied by the caller.
!! @param x     Handle declared intent(inout).
!! @param df0dx Handle declared intent(in).
!! @param fval  Handle declared intent(in).
!! @param dfdx  Handle declared intent(in).
module subroutine mma_update_device(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(inout) :: x
  type(c_ptr), intent(in) :: df0dx
  type(c_ptr), intent(in) :: fval
  type(c_ptr), intent(in) :: dfdx
end subroutine mma_update_device
!> Interface of the device flavour of the KKT evaluation (separate module
!! procedure; the body is defined elsewhere). All data arguments are
!! opaque C pointers and are read only.
!! @param this  MMA object, may be modified.
!! @param x     Handle declared intent(in).
!! @param df0dx Handle declared intent(in).
!! @param fval  Handle declared intent(in).
!! @param dfdx  Handle declared intent(in).
module subroutine mma_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x
  type(c_ptr), intent(in) :: df0dx
  type(c_ptr), intent(in) :: fval
  type(c_ptr), intent(in) :: dfdx
end subroutine mma_kkt_device
!> Initialize the MMA object from a JSON configuration.
!! Reads every `mma.*` entry with a default, then forwards the values to the
!! generic `init` (which, for this argument list, is the component-wise
!! initializer).
!! @param this       MMA object to initialize.
!! @param x          Forwarded unchanged to `this%init`.
!! @param n          Local size, summed over MPI_COMM_WORLD into n_global.
!! @param m          Size used for the local arrays a, c and d.
!! @param json       Configuration source.
!! @param scale      Out: value of `mma.scale` (default 10.0).
!! @param auto_scale Out: value of `mma.auto_scale` (default .false.).
!! NOTE(review): parts of this routine (else/end if lines and possibly
!! assignments) are not visible in this excerpt.
144 subroutine mma_init_from_json(this, x, n, m, json, scale, auto_scale)
156 class(mma_t),
intent(inout) :: this
157 integer,
intent(in) :: n, m
158 type(vector_t),
intent(in) :: x
160 type(json_file),
intent(inout) :: json
163 real(kind=rp),
intent(out) :: scale
164 logical,
intent(out) :: auto_scale
! Local arrays handed to this%init as bounds and constraint coefficients.
! NOTE(review): no visible statement fills xmax, xmin, a, c, d from the
! *_const scalars read below - presumably done in lines not shown; confirm.
172 real(kind=rp),
dimension(n) :: xmax, xmin
173 real(kind=rp),
dimension(m) :: a, c, d
174 character(len=:),
allocatable :: subsolver, bcknd, bcknd_default
177 real(kind=rp) :: a0 , xmax_const, xmin_const, a_const, c_const, d_const
179 integer :: max_iter, n_global, ierr
180 real(kind=rp) :: epsimin, asyinit, asyincr, asydecr
! Sum the local n of every rank; n_global is only used for the default
! value of epsimin below.
182 call mpi_allreduce(n, n_global, 1, mpi_integer, &
183 mpi_sum, mpi_comm_world, ierr)
! Default backend: "device" when neko_bcknd_device is 1.
! NOTE(review): the "cpu" assignment is presumably in an else branch whose
! else/end if lines are not visible here.
186 if (neko_bcknd_device .eq. 1)
then
187 bcknd_default =
"device"
189 bcknd_default =
"cpu"
! Tolerance and iteration settings; the epsimin default scales with
! sqrt(m + n_global).
195 call json_get_or_default(json,
'mma.epsimin', epsimin, &
196 1.0e-9_rp * sqrt(real(m + n_global, rp)))
197 call json_get_or_default(json,
'mma.max_iter', max_iter, 100)
! Asymptote parameters.
200 call json_get_or_default(json,
'mma.asyinit', asyinit, 0.5_rp)
201 call json_get_or_default(json,
'mma.asyincr', asyincr, 1.2_rp)
202 call json_get_or_default(json,
'mma.asydecr', asydecr, 0.7_rp)
! Backend and subsolver names.
204 call json_get_or_default(json,
'mma.backend', bcknd, bcknd_default)
205 call json_get_or_default(json,
'mma.subsolver', subsolver,
'dip')
! Scalar constants for the bounds and the a0/a/c/d coefficients.
207 call json_get_or_default(json,
'mma.xmin', xmin_const, 0.0_rp)
208 call json_get_or_default(json,
'mma.xmax', xmax_const, 1.0_rp)
209 call json_get_or_default(json,
'mma.a0', a0, 1.0_rp)
210 call json_get_or_default(json,
'mma.a', a_const, 0.0_rp)
211 call json_get_or_default(json,
'mma.c', c_const, 100.0_rp)
212 call json_get_or_default(json,
'mma.d', d_const, 0.0_rp)
! Values returned to the caller through the intent(out) arguments.
214 call json_get_or_default(json,
'mma.scale', scale, 10.0_rp)
215 call json_get_or_default(json,
'mma.auto_scale', auto_scale, .false.)
! Only rank 0 reports the selected backend.
224 if (pe_rank .eq. 0)
then
225 print *,
"Initializing MMA backend to >>> ", bcknd
! Delegate the actual setup to the component-wise initializer.
231 call this%init(x, n, m, a0, a, c, d, xmin, xmax, &
232 max_iter, epsimin, asyinit, asyincr, asydecr, bcknd, subsolver)
234 end subroutine mma_init_from_json
!> Release the storage held by the MMA object and reset its state flags.
!! @param this MMA object to clean up.
!! NOTE(review): only xold1, xold2, alpha, beta, xmax, xmin and lambda are
!! freed in the lines visible here; the other vector/matrix components
!! (low, upp, a, c, d, p0j, q0j, pij, qij, y, s, mu, xsi, eta) are
!! presumably freed in lines not shown - confirm.
237 subroutine mma_free(this)
238 class(mma_t),
intent(inout) :: this
240 call this%xold1%free()
241 call this%xold2%free()
242 call this%alpha%free()
243 call this%beta%free()
249 call this%xmax%free()
250 call this%xmin%free()
255 call this%lambda%free()
! Mark the object as neither initialized nor updated.
265 this%is_initialized = .false.
266 this%is_updated = .false.
267 end subroutine mma_free
!> Initialize the MMA object from explicitly supplied components.
!! @param this      MMA object to initialize.
!! @param x         NOTE(review): not referenced in any line visible here.
!! @param n         Size used for the n-sized work vectors.
!! @param m         Size used for lambda and the rows of pij/qij.
!! @param a0        Scalar coefficient (printed in the summary).
!! @param a,c,d     Arrays of length m.
!! @param xmin,xmax Arrays of length n.
!! @param max_iter  Optional; 100 when absent.
!! @param epsimin   Optional; 1.0e-9 * sqrt(m + n) when absent.
!! @param asyinit   Optional; 0.5 when absent.
!! @param asyincr   Optional; 1.2 when absent.
!! @param asydecr   Optional; 0.7 when absent.
!! @param bcknd     Backend name.
!! @param subsolver Subsolver name; "dip" and "dpip" are recognised.
!! NOTE(review): many lines of this routine (assignments, else/end if,
!! continuation lines) are not visible in this excerpt.
!! NOTE(review): bcknd and subsolver are intent(in) but allocatable, so
!! callers must pass allocatable actual arguments - confirm this is intended.
270 subroutine mma_init_from_components(this, x, n, m, a0, a, c, d, xmin, xmax, &
271 max_iter, epsimin, asyinit, asyincr, asydecr, bcknd, subsolver)
283 class(mma_t),
intent(inout) :: this
284 integer,
intent(in) :: n, m
285 type(vector_t),
intent(in) :: x
293 real(kind=rp),
intent(in),
dimension(n) :: xmax, xmin
294 real(kind=rp),
intent(in),
dimension(m) :: a, c, d
295 real(kind=rp),
intent(in) :: a0
296 integer,
intent(in),
optional :: max_iter
297 real(kind=rp),
intent(in),
optional :: epsimin, asyinit, asyincr, asydecr
298 character(len=:),
intent(in),
allocatable :: bcknd, subsolver
! Size the work vectors (n or m entries) and matrices (m by n).
305 call this%xold1%init(n)
306 call this%xold2%init(n)
310 call this%alpha%init(n)
311 call this%beta%init(n)
316 call this%low%init(n)
317 call this%upp%init(n)
318 call this%xmax%init(n)
319 call this%xmin%init(n)
322 call this%p0j%init(n)
323 call this%q0j%init(n)
324 call this%pij%init(m, n)
325 call this%qij%init(m, n)
330 call this%lambda%init(m)
333 call this%xsi%init(n)
334 call this%eta%init(n)
! With a device backend compiled in, push a, c, d (m entries) and
! xmax, xmin (n entries) from host to device. The continuation lines of
! these calls are not visible here.
345 if (neko_bcknd_device .eq. 1)
then
346 call device_memcpy(this%a%x, this%a%x_d, m, host_to_device, &
348 call device_memcpy(this%c%x, this%c%x_d, m, host_to_device, &
350 call device_memcpy(this%d%x, this%d%x_d, m, host_to_device, &
352 call device_memcpy(this%xmax%x, this%xmax%x_d, n, host_to_device, &
354 call device_memcpy(this%xmin%x, this%xmin%x_d, n, host_to_device, &
! Residuals start at the largest representable value.
360 this%residumax = huge(0.0_rp)
361 this%residunorm = huge(0.0_rp)
! Rank 0 reports which backend is in use; unknown backends abort through
! neko_error. NOTE(review): the construct that separates the CPU and the
! device report (presumably a select case on the backend name) is not
! visible here.
366 if (pe_rank == 0)
then
367 print *,
"MMA initialized with CPU backend!"
370 if (pe_rank == 0)
then
371 if (neko_bcknd_cuda .eq. 1)
then
372 print *,
"MMA initialized with CUDA backend!"
373 else if (neko_bcknd_hip .eq. 1)
then
374 print *,
"MMA initialized with HIP backend!"
375 else if (neko_bcknd_opencl .eq. 1)
then
376 print *,
"MMA initialized with OPENCL backend!"
378 call neko_error(
'Unknown backend device in mma_init_components')
382 call neko_error(
'Unknown backend in mma_init_components')
! Fallback values for absent optional arguments.
! NOTE(review): the epsimin fallback uses the local n, whereas
! mma_init_from_json derives its default from the MPI-summed n_global;
! the two defaults differ on more than one rank - confirm which is intended.
389 if (.not.
present(epsimin)) this%epsimin = 1.0e-9_rp * sqrt(real(m + n, rp))
390 if (.not.
present(max_iter)) this%max_iter = 100
393 if (.not.
present(asyinit)) this%asyinit = 0.5_rp
394 if (.not.
present(asyincr)) this%asyincr = 1.2_rp
395 if (.not.
present(asydecr)) this%asydecr = 0.7_rp
! Caller-supplied values take precedence over the fallbacks above.
398 if (
present(max_iter)) this%max_iter = max_iter
399 if (
present(epsimin)) this%epsimin = epsimin
400 if (
present(asyinit)) this%asyinit = asyinit
401 if (
present(asyincr)) this%asyincr = asyincr
402 if (
present(asydecr)) this%asydecr = asydecr
! Store the subsolver name; rank 0 reports the choice and an unrecognised
! name aborts through neko_error.
404 this%subsolver = subsolver
405 if (pe_rank == 0)
then
406 if (this%subsolver .eq.
"dip")
then
407 print *,
"Using dual solver for MMA subsolve."
408 elseif (this%subsolver .eq.
"dpip")
then
409 print *,
"Using dual-primal solver for MMA subsolve."
411 call neko_error(
'Unknown subsolver for MMA, mma_init_from_components')
! Summary printed by rank 0.
! NOTE(review): the "epsimin = " label has no leading ", " separator,
! unlike the preceding labels.
415 if (pe_rank .eq. 0)
then
416 print *,
"MMA is initialized with a0 = ", a0,
", a = ", a,
", c = ", c, &
417 ", d = ", d,
"epsimin = ", this%epsimin
420 this%is_initialized = .true.
421 end subroutine mma_init_from_components
!> Update flavour that takes vector_t / matrix_t arguments and dispatches on
!! this%bcknd to the CPU or the device implementation.
!! @param this  MMA object.
!! @param iter  Iteration counter forwarded to the implementation.
!! @param x     Vector forwarded as x%x (CPU) or x%x_d (device).
!! @param df0dx Vector forwarded as df0dx%x or df0dx%x_d.
!! @param fval  Vector forwarded as fval%x or fval%x_d.
!! @param dfdx  Matrix forwarded as dfdx%x or dfdx%x_d.
!! NOTE(review): the case labels, end if / end select lines and the
!! continuation lines of the device_memcpy calls are not visible here.
424 subroutine mma_update_vector(this, iter, x, df0dx, fval, dfdx)
425 class(mma_t),
intent(inout) :: this
426 integer,
intent(in) :: iter
427 type(vector_t),
intent(inout) :: x
428 type(vector_t),
intent(inout) :: df0dx, fval
429 type(matrix_t),
intent(inout) :: dfdx
432 select case (this%bcknd)
! CPU path: when a device backend is compiled in, first copy x, df0dx
! (n entries), fval (m entries) and dfdx (m * n entries) from device to
! host, then run the CPU update on the host arrays.
434 if (neko_bcknd_device .eq. 1)
then
435 call device_memcpy(x%x, x%x_d, this%n, device_to_host, &
437 call device_memcpy(df0dx%x, df0dx%x_d, this%n, device_to_host, &
439 call device_memcpy(fval%x, fval%x_d, this%m, device_to_host, &
441 call device_memcpy(dfdx%x, dfdx%x_d, this%m * this%n, device_to_host,&
445 call mma_update_cpu(this, iter, x%x, df0dx%x, fval%x, dfdx%x)
! Only x is copied back to the device after the CPU update.
447 if (neko_bcknd_device .eq. 1)
then
448 call device_memcpy(x%x, x%x_d, this%n, host_to_device, sync = .true.)
! Device path: pass the device handles straight through.
452 call mma_update_device(this, iter, x%x_d, df0dx%x_d, fval%x_d, dfdx%x_d)
455 end subroutine mma_update_vector
!> KKT flavour that takes vector_t / matrix_t arguments and dispatches on
!! this%bcknd to the CPU or the device implementation.
!! @param this  MMA object.
!! @param x     Vector forwarded as x%x (CPU) or x%x_d (device).
!! @param df0dx Vector forwarded as df0dx%x or df0dx%x_d.
!! @param fval  Vector forwarded as fval%x or fval%x_d.
!! @param dfdx  Matrix forwarded as dfdx%x or dfdx%x_d.
!! NOTE(review): the case labels, end if / end select lines and the
!! continuation lines of the device_memcpy calls are not visible here.
458 subroutine mma_kkt_vector(this, x, df0dx, fval, dfdx)
459 class(mma_t),
intent(inout) :: this
460 type(vector_t),
intent(inout) :: x, df0dx, fval
461 type(matrix_t),
intent(inout) :: dfdx
464 select case (this%bcknd )
! CPU path: when a device backend is compiled in, first copy x, df0dx
! (n entries), fval (m entries) and dfdx (m * n entries) from device to
! host, then evaluate the KKT conditions on the host arrays.
466 if (neko_bcknd_device .eq. 1)
then
467 call device_memcpy(x%x, x%x_d, this%n, device_to_host, &
469 call device_memcpy(df0dx%x, df0dx%x_d, this%n, device_to_host, &
471 call device_memcpy(fval%x, fval%x_d, this%m, device_to_host, &
473 call device_memcpy(dfdx%x, dfdx%x_d, this%m * this%n, device_to_host,&
477 call mma_kkt_cpu(this, x%x, df0dx%x, fval%x, dfdx%x)
! Device path: pass the device handles straight through.
479 call mma_kkt_device(this, x%x_d, df0dx%x_d, fval%x_d, dfdx%x_d)
481 end subroutine mma_kkt_vector
!> Return the size `n` stored in the MMA object.
!! @param this MMA object to query.
!! @return The integer component `n`.
pure function mma_get_n(this) result(n)
  class(mma_t), intent(in) :: this
  integer :: n

  ! The result must be declared and assigned explicitly, as in the other
  ! getters of this type.
  n = this%n
end function mma_get_n
!> Return the size `m` stored in the MMA object.
!! @param this MMA object to query.
!! @return The integer component `m`.
pure function mma_get_m(this) result(m)
  class(mma_t), intent(in) :: this
  integer :: m

  ! The result must be declared and assigned explicitly, as in the other
  ! getters of this type.
  m = this%m
end function mma_get_m
!> Accessor for the `residumax` component.
!! @param this MMA object to query.
!! @return Current value of this%residumax.
pure real(kind=rp) function mma_get_residumax(this) result(residumax)
  class(mma_t), intent(in) :: this

  residumax = this%residumax
end function mma_get_residumax
!> Accessor for the `residunorm` component.
!! @param this MMA object to query.
!! @return Current value of this%residunorm.
pure real(kind=rp) function mma_get_residunorm(this) result(residunorm)
  class(mma_t), intent(in) :: this

  residunorm = this%residunorm
end function mma_get_residunorm
!> Accessor for the `max_iter` component.
!! @param this MMA object to query.
!! @return Current value of this%max_iter.
pure integer function mma_get_max_iter(this) result(max_iter_value)
  class(mma_t), intent(in) :: this

  max_iter_value = this%max_iter
end function mma_get_max_iter
!> Build a string that starts with 'backend:' followed by the trimmed
!! backend name and ', subsolver:'.
!! @param this MMA object to query.
!! @return Allocatable character string describing backend and subsolver.
!! NOTE(review): the assignments to `backend` inside each branch, the
!! else / end if lines and the final continuation line of the concatenation
!! (presumably the subsolver name) are not visible in this excerpt.
522 pure function mma_get_backend_and_subsolver(this)
result(backend_subsolver)
523 class(mma_t),
intent(in) :: this
524 character(len=:),
allocatable :: backend_subsolver
525 character(len=:),
allocatable :: backend
! The backend name is chosen from the compile-time flags, tested in the
! order CUDA, HIP, OpenCL.
527 if (neko_bcknd_cuda .eq. 1)
then
529 else if (neko_bcknd_hip .eq. 1)
then
531 else if (neko_bcknd_opencl .eq. 1)
then
! Assemble the result string.
537 backend_subsolver =
'backend:' // trim(backend) //
', subsolver:' // &
539 end function mma_get_backend_and_subsolver