70 use num_types,
only: rp
71 use json_module,
only: json_file
72 use json_utils,
only: json_get_or_default
73 use vector,
only: vector_t
74 use matrix,
only: matrix_t
75 use mpi_f08,
only: mpi_allreduce, mpi_integer, mpi_sum, mpi_comm_world
76 use comm,
only: pe_rank
77 use utils,
only: neko_error
78 use neko_config,
only: neko_bcknd_device, neko_bcknd_cuda, neko_bcknd_hip, &
80 use device,
only: device_memcpy, host_to_device, device_to_host
81 use,
intrinsic :: iso_c_binding, only: c_ptr
82 use logger,
only: neko_log
! Problem dimensions: n is the local design size (the init routines sum it
! over MPI ranks where a global count is needed), m is the length of the
! constraint data (fval, a, c, d), and max_iter is the iteration limit
! (100 when not supplied).
89 integer :: n, m, max_iter
! Scalar parameters; the declaration continues on a line not shown here.
! NOTE(review): residumax and residunorm are read as components elsewhere in
! this file and are presumably declared on that continuation - confirm.
! The next declaration holds the vector state: xold1, xold2, low, upp, alpha,
! beta, xmax, xmin are sized n and a, c, d are sized m by the init routine.
90 real(kind=rp) :: a0, f0val, asyinit, asyincr, asydecr, epsimin, &
92 type(vector_t) :: xold1, xold2, low, upp, alpha, beta, a, c, d, xmax, xmin
! Lifecycle flags; is_initialized is set by init_from_components and both
! are cleared by mma_free.
93 logical :: is_initialized = .false.
94 logical :: is_updated = .false.
! Run-time backend name (dispatched on by update/kkt) and subsolver name.
95 character(len=:),
allocatable :: bcknd, subsolver
! Work terms: p0j/q0j have length n, pij/qij are m x n.
98 type(vector_t) :: p0j, q0j
99 type(matrix_t) :: pij, qij
! Subproblem variables; the init routine sizes lambda as m and xsi/eta as n.
103 real(kind=rp) :: z, zeta
104 type(vector_t) :: y, lambda, s, mu
105 type(vector_t) :: xsi, eta
! --- Construction and teardown -------------------------------------------
procedure, public, pass(this) :: init_from_json => mma_init_from_json
procedure, public, pass(this) :: init_from_components => &
     mma_init_from_components
generic, public :: init => init_from_json, init_from_components
procedure, public, pass(this) :: free => mma_free

! --- Read-only accessors --------------------------------------------------
procedure, public, pass(this) :: get_n => mma_get_n
procedure, public, pass(this) :: get_m => mma_get_m
procedure, public, pass(this) :: get_residumax => mma_get_residumax
procedure, public, pass(this) :: get_residunorm => mma_get_residunorm
procedure, public, pass(this) :: get_max_iter => mma_get_max_iter
procedure, public, pass(this) :: get_backend_and_subsolver => &
     mma_get_backend_and_subsolver

! --- Design update; the generic resolves on argument type: vector_t
! --- containers, host real arrays, or device pointers (type(c_ptr)).
procedure, pass(this) :: update_vector => mma_update_vector
procedure, pass(this) :: update_cpu => mma_update_cpu
procedure, pass(this) :: update_device => mma_update_device
generic, public :: update => update_vector, update_cpu, update_device

! --- KKT evaluation, with the same three argument forms -------------------
procedure, pass(this) :: kkt_vector => mma_kkt_vector
procedure, pass(this) :: kkt_cpu => mma_kkt_cpu
procedure, pass(this) :: kkt_device => mma_kkt_device
generic, public :: kkt => kkt_vector, kkt_cpu, kkt_device
!> Separate-module-procedure interface for the MMA update on host arrays.
!! x and df0dx have length n, fval has length m, dfdx is m x n; only x is
!! modified.
module subroutine mma_update_cpu(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  real(kind=rp), intent(inout) :: x(this%n)
  real(kind=rp), intent(in) :: df0dx(this%n)
  real(kind=rp), intent(in) :: fval(this%m)
  real(kind=rp), intent(in) :: dfdx(this%m, this%n)
end subroutine mma_update_cpu
!> Separate-module-procedure interface for the KKT evaluation on host arrays.
!! x and df0dx have length n, fval has length m, dfdx is m x n; all are
!! read-only here.
module subroutine mma_kkt_cpu(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  real(kind=rp), intent(in) :: x(this%n)
  real(kind=rp), intent(in) :: df0dx(this%n)
  real(kind=rp), intent(in) :: fval(this%m)
  real(kind=rp), intent(in) :: dfdx(this%m, this%n)
end subroutine mma_kkt_cpu
!> Separate-module-procedure interface for the MMA update on device memory.
!! Every array argument is passed as a device pointer (c_ptr); only the
!! buffer behind x is declared modifiable.
module subroutine mma_update_device(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(inout) :: x
  type(c_ptr), intent(in) :: df0dx
  type(c_ptr), intent(in) :: fval
  type(c_ptr), intent(in) :: dfdx
end subroutine mma_update_device
!> Separate-module-procedure interface for the KKT evaluation on device
!! memory; every array argument is a read-only device pointer (c_ptr).
module subroutine mma_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x
  type(c_ptr), intent(in) :: df0dx
  type(c_ptr), intent(in) :: fval
  type(c_ptr), intent(in) :: dfdx
end subroutine mma_kkt_device
!> Initialize the MMA solver from the 'mma' section of a JSON case file.
!!
!! Each setting is read with json_get_or_default from a key 'mma.<name>';
!! absent keys fall back to the defaults written in the calls below
!! (e.g. max_iter = 100, asyinit = 0.5, subsolver = 'dip'). The collected
!! values are then forwarded to init_from_components.
!!
!! @param x          design vector, passed through to init_from_components
!! @param n          local problem size; summed over MPI ranks (n_global)
!!                   to form the default epsimin
!! @param m          length of the a, c, d coefficient arrays
!! @param json       case file providing the 'mma.*' keys
!! @param scale      out: value of 'mma.scale' (default 10.0); returned to
!!                   the caller, not passed to init_from_components
!! @param auto_scale out: value of 'mma.auto_scale' (default .false.);
!!                   returned to the caller, not passed on either
179 subroutine mma_init_from_json(this, x, n, m, json, scale, auto_scale)
191 class(mma_t),
intent(inout) :: this
192 integer,
intent(in) :: n, m
193 type(vector_t),
intent(in) :: x
195 type(json_file),
intent(inout) :: json
198 real(kind=rp),
intent(out) :: scale
199 logical,
intent(out) :: auto_scale
207 real(kind=rp),
dimension(n) :: xmax, xmin
208 real(kind=rp),
dimension(m) :: a, c, d
209 character(len=:),
allocatable :: subsolver, bcknd, bcknd_default
212 real(kind=rp) :: a0 , xmax_const, xmin_const, a_const, c_const, d_const
214 integer :: max_iter, n_global, ierr
215 real(kind=rp) :: epsimin, asyinit, asyincr, asydecr
! Global problem size (sum of the local n over all ranks); it feeds the
! default epsimin below.
217 call mpi_allreduce(n, n_global, 1, mpi_integer, &
218 mpi_sum, mpi_comm_world, ierr)
! The default backend follows the compile-time device configuration.
221 if (neko_bcknd_device .eq. 1)
then
222 bcknd_default =
"device"
224 bcknd_default =
"cpu"
! Solver tuning parameters.
230 call json_get_or_default(json,
'mma.epsimin', epsimin, &
231 1.0e-9_rp * sqrt(real(m + n_global, rp)))
232 call json_get_or_default(json,
'mma.max_iter', max_iter, 100)
235 call json_get_or_default(json,
'mma.asyinit', asyinit, 0.5_rp)
236 call json_get_or_default(json,
'mma.asyincr', asyincr, 1.2_rp)
237 call json_get_or_default(json,
'mma.asydecr', asydecr, 0.7_rp)
239 call json_get_or_default(json,
'mma.backend', bcknd, bcknd_default)
240 call json_get_or_default(json,
'mma.subsolver', subsolver,
'dip')
! Scalar constants for the bounds and the a0/a/c/d coefficients.
242 call json_get_or_default(json,
'mma.xmin', xmin_const, 0.0_rp)
243 call json_get_or_default(json,
'mma.xmax', xmax_const, 1.0_rp)
244 call json_get_or_default(json,
'mma.a0', a0, 1.0_rp)
245 call json_get_or_default(json,
'mma.a', a_const, 0.0_rp)
246 call json_get_or_default(json,
'mma.c', c_const, 100.0_rp)
247 call json_get_or_default(json,
'mma.d', d_const, 0.0_rp)
! Scaling settings are handed back through the intent(out) arguments.
249 call json_get_or_default(json,
'mma.scale', scale, 10.0_rp)
250 call json_get_or_default(json,
'mma.auto_scale', auto_scale, .false.)
! NOTE(review): the local arrays a, c, d, xmin and xmax are not assigned in
! the visible lines; presumably they are filled from a_const, c_const,
! d_const, xmin_const and xmax_const just before this call - confirm.
262 call this%init(x, n, m, a0, a, c, d, xmin, xmax, &
263 max_iter, epsimin, asyinit, asyincr, asydecr, bcknd, subsolver)
265 end subroutine mma_init_from_json
!> Release every vector and matrix component of the MMA object and clear
!! the is_initialized / is_updated flags.
!!
!! Previously only xold1, xold2, alpha, beta, xmax, xmin and lambda were
!! freed, leaking the remaining components (low, upp, a, c, d, p0j, q0j,
!! pij, qij, y, s, mu, xsi, eta), including any device copies they hold.
!! vector_t%free / matrix_t%free are safe on components that were never
!! initialized, so freeing everything unconditionally is fine.
subroutine mma_free(this)
  class(mma_t), intent(inout) :: this

  ! History, move-limit and bound vectors.
  call this%xold1%free()
  call this%xold2%free()
  call this%low%free()
  call this%upp%free()
  call this%alpha%free()
  call this%beta%free()
  call this%xmax%free()
  call this%xmin%free()

  ! Problem coefficient vectors.
  call this%a%free()
  call this%c%free()
  call this%d%free()

  ! Work terms (vectors and m x n matrices).
  call this%p0j%free()
  call this%q0j%free()
  call this%pij%free()
  call this%qij%free()

  ! Subproblem variables.
  call this%y%free()
  call this%lambda%free()
  call this%s%free()
  call this%mu%free()
  call this%xsi%free()
  call this%eta%free()

  this%is_initialized = .false.
  this%is_updated = .false.
end subroutine mma_free
!> Initialize the MMA solver from explicitly supplied components.
!!
!! Allocates the internal work vectors/matrices, pushes a, c, d, xmax and
!! xmin to the device when built with a device backend, resets residumax
!! and residunorm to huge(0.0_rp), resolves the optional tuning parameters
!! against their defaults, logs the configuration and sets is_initialized.
!!
!! @param x           design vector
!! @param n           local problem size (length of xmin, xmax)
!! @param m           length of the a, c, d arrays
!! @param a0, a, c, d MMA problem coefficients
!! @param xmin, xmax  length-n arrays, presumably lower/upper bounds on x
!! @param max_iter, epsimin, asyinit, asyincr, asydecr  optional; when
!!        absent they default to 100, 1.0e-9*sqrt(m+n), 0.5, 1.2 and 0.7
!! @param bcknd, subsolver  backend and subsolver names (logged below)
!!
!! NOTE(review): bcknd and subsolver are declared character(len=:),
!! allocatable, so every caller must pass allocatable deferred-length
!! strings; character(len=*) would also accept literals.
301 subroutine mma_init_from_components(this, x, n, m, a0, a, c, d, xmin, xmax, &
302 max_iter, epsimin, asyinit, asyincr, asydecr, bcknd, subsolver)
314 class(mma_t),
intent(inout) :: this
315 integer,
intent(in) :: n, m
316 type(vector_t),
intent(in) :: x
324 real(kind=rp),
intent(in),
dimension(n) :: xmax, xmin
325 real(kind=rp),
intent(in),
dimension(m) :: a, c, d
326 real(kind=rp),
intent(in) :: a0
327 integer,
intent(in),
optional :: max_iter
328 real(kind=rp),
intent(in),
optional :: epsimin, asyinit, asyincr, asydecr
329 character(len=:),
intent(in),
allocatable :: bcknd, subsolver
330 character(len=256) :: log_msg
! Allocate work storage: length-n vectors, m x n matrices pij/qij, and the
! length-m lambda.
338 call this%xold1%init(n)
339 call this%xold2%init(n)
343 call this%alpha%init(n)
344 call this%beta%init(n)
349 call this%low%init(n)
350 call this%upp%init(n)
351 call this%xmax%init(n)
352 call this%xmin%init(n)
355 call this%p0j%init(n)
356 call this%q0j%init(n)
357 call this%pij%init(m, n)
358 call this%qij%init(m, n)
363 call this%lambda%init(m)
366 call this%xsi%init(n)
367 call this%eta%init(n)
! When built with a device backend, copy the host values of a, c, d, xmax
! and xmin into their device buffers.
378 if (neko_bcknd_device .eq. 1)
then
379 call device_memcpy(this%a%x, this%a%x_d, m, host_to_device, &
381 call device_memcpy(this%c%x, this%c%x_d, m, host_to_device, &
383 call device_memcpy(this%d%x, this%d%x_d, m, host_to_device, &
385 call device_memcpy(this%xmax%x, this%xmax%x_d, n, host_to_device, &
387 call device_memcpy(this%xmin%x, this%xmin%x_d, n, host_to_device, &
393 this%residumax = huge(0.0_rp)
394 this%residunorm = huge(0.0_rp)
! Resolve the optional tuning parameters: defaults first, then any values
! the caller supplied.
! NOTE(review): the default epsimin here uses the local n, whereas
! mma_init_from_json uses the MPI-global sum of n; under MPI the two
! defaults differ and this one can vary between ranks.
400 if (.not.
present(epsimin)) this%epsimin = 1.0e-9_rp * sqrt(real(m + n, rp))
401 if (.not.
present(max_iter)) this%max_iter = 100
404 if (.not.
present(asyinit)) this%asyinit = 0.5_rp
405 if (.not.
present(asyincr)) this%asyincr = 1.2_rp
406 if (.not.
present(asydecr)) this%asydecr = 0.7_rp
409 if (
present(max_iter)) this%max_iter = max_iter
410 if (
present(epsimin)) this%epsimin = epsimin
411 if (
present(asyinit)) this%asyinit = asyinit
412 if (
present(asyincr)) this%asyincr = asyincr
413 if (
present(asydecr)) this%asydecr = asydecr
415 this%subsolver = subsolver
! Log the resolved configuration.
417 call neko_log%section(
'MMA Parameters')
419 write(log_msg,
'(A10,1X,A)')
'backend ', trim(this%bcknd)
420 call neko_log%message(log_msg)
421 write(log_msg,
'(A10,1X,A)')
'subsolver ', trim(this%subsolver)
422 call neko_log%message(log_msg)
424 write(log_msg,
'(A10,1X,I0)')
'n ', this%n
425 call neko_log%message(log_msg)
426 write(log_msg,
'(A10,1X,I0)')
'm ', this%m
427 call neko_log%message(log_msg)
428 write(log_msg,
'(A10,1X,I0)')
'max_iter ', this%max_iter
429 call neko_log%message(log_msg)
431 write(log_msg,
'(A10,1X,E11.5)')
'epsimin ', this%epsimin
432 call neko_log%message(log_msg)
434 write(log_msg,
'(A10,1X,E11.5)')
'asyinit ', this%asyinit
435 call neko_log%message(log_msg)
436 write(log_msg,
'(A10,1X,E11.5)')
'asyincr ', this%asyincr
437 call neko_log%message(log_msg)
438 write(log_msg,
'(A10,1X,E11.5)')
'asydecr ', this%asydecr
439 call neko_log%message(log_msg)
440 write(log_msg,
'(A10,1X,E11.5)')
'a0 ', this%a0
441 call neko_log%message(log_msg)
! Per-entry coefficient listings.
! NOTE(review): the index is printed with I2, which shows '**' once it
! exceeds 99; I0 would avoid that.
443 call neko_log%message(
'Parameters a')
445 write(log_msg,
'(3X,A,I2,A,E11.5)')
'a(', i,
') = ', this%a%x(i)
446 call neko_log%message(log_msg)
448 call neko_log%message(
'Parameters c')
450 write(log_msg,
'(3X,A,I2,A,E11.5)')
'c(', i,
') = ', this%c%x(i)
451 call neko_log%message(log_msg)
453 call neko_log%message(
'Parameters d')
455 write(log_msg,
'(3X,A,I2,A,E11.5)')
'd(', i,
') = ', this%d%x(i)
456 call neko_log%message(log_msg)
459 call neko_log%end_section()
462 this%is_initialized = .true.
463 end subroutine mma_init_from_components
!> One MMA design update on Neko vector_t / matrix_t containers.
!!
!! Dispatches on the run-time backend string this%bcknd (case labels not
!! shown here). The host branch works as follows: when built with a device
!! backend it first pulls the device copies of x, df0dx, fval and dfdx to
!! the host. It then runs mma_update_cpu on the host arrays and pushes the
!! updated x back to the device. The device branch passes the device
!! pointers straight to mma_update_device.
!!
!! @param iter  iteration counter, forwarded unchanged
!! @param x     design vector (length n), updated in place
!! @param df0dx length-n input; presumably the objective gradient
!! @param fval  length-m input; presumably the constraint values
!! @param dfdx  m x n input; presumably the constraint Jacobian
!! df0dx, fval and dfdx are intent(inout) presumably because the host
!! branch overwrites their host copies with device data.
466 subroutine mma_update_vector(this, iter, x, df0dx, fval, dfdx)
467 class(mma_t),
intent(inout) :: this
468 integer,
intent(in) :: iter
469 type(vector_t),
intent(inout) :: x
470 type(vector_t),
intent(inout) :: df0dx, fval
471 type(matrix_t),
intent(inout) :: dfdx
! Select the host or device implementation at run time.
474 select case (this%bcknd)
476 if (neko_bcknd_device .eq. 1)
then
477 call device_memcpy(x%x, x%x_d, this%n, device_to_host, &
479 call device_memcpy(df0dx%x, df0dx%x_d, this%n, device_to_host, &
481 call device_memcpy(fval%x, fval%x_d, this%m, device_to_host, &
483 call device_memcpy(dfdx%x, dfdx%x_d, this%m * this%n, device_to_host,&
487 call mma_update_cpu(this, iter, x%x, df0dx%x, fval%x, dfdx%x)
! Keep the device copy of the design in sync with the host result.
489 if (neko_bcknd_device .eq. 1)
then
490 call device_memcpy(x%x, x%x_d, this%n, host_to_device, sync = .true.)
! Device branch: operate directly on device memory.
494 call mma_update_device(this, iter, x%x_d, df0dx%x_d, fval%x_d, dfdx%x_d)
497 end subroutine mma_update_vector
!> KKT evaluation on Neko vector_t / matrix_t containers.
!!
!! Dispatches on the run-time backend string this%bcknd. The host branch
!! pulls x, df0dx, fval and dfdx from the device when built with a device
!! backend, then calls mma_kkt_cpu on the host arrays. Unlike
!! mma_update_vector, nothing is copied back to the device in the visible
!! lines. The device branch calls mma_kkt_device with the device pointers.
!!
!! @param x, df0dx  length-n inputs
!! @param fval      length-m input
!! @param dfdx      m x n input
!! The arguments are intent(inout) presumably only because their host
!! copies may be refreshed from the device.
500 subroutine mma_kkt_vector(this, x, df0dx, fval, dfdx)
501 class(mma_t),
intent(inout) :: this
502 type(vector_t),
intent(inout) :: x, df0dx, fval
503 type(matrix_t),
intent(inout) :: dfdx
! Select the host or device implementation at run time.
506 select case (this%bcknd )
508 if (neko_bcknd_device .eq. 1)
then
509 call device_memcpy(x%x, x%x_d, this%n, device_to_host, &
511 call device_memcpy(df0dx%x, df0dx%x_d, this%n, device_to_host, &
513 call device_memcpy(fval%x, fval%x_d, this%m, device_to_host, &
515 call device_memcpy(dfdx%x, dfdx%x_d, this%m * this%n, device_to_host,&
519 call mma_kkt_cpu(this, x%x, df0dx%x, fval%x, dfdx%x)
! Device branch: operate directly on device memory.
521 call mma_kkt_device(this, x%x_d, df0dx%x_d, fval%x_d, dfdx%x_d)
523 end subroutine mma_kkt_vector
!> Accessor: the problem size n stored in the MMA object.
!!
!! The result variable was neither declared nor assigned in the previous
!! body, so the function returned an undefined value (or failed to compile
!! under implicit none). It now mirrors the other getters.
pure function mma_get_n(this) result(n)
  class(mma_t), intent(in) :: this
  integer :: n

  n = this%n
end function mma_get_n
!> Accessor: the dimension m (length of the constraint data) stored in
!! the MMA object.
!!
!! The result variable was neither declared nor assigned in the previous
!! body; it now mirrors the other getters.
pure function mma_get_m(this) result(m)
  class(mma_t), intent(in) :: this
  integer :: m

  m = this%m
end function mma_get_m
!> Accessor: the stored residumax value (init_from_components resets it
!! to huge(0.0_rp)).
pure function mma_get_residumax(this) result(res_max)
  class(mma_t), intent(in) :: this
  real(rp) :: res_max

  res_max = this%residumax
end function mma_get_residumax
!> Accessor: the stored residunorm value (init_from_components resets it
!! to huge(0.0_rp)).
pure function mma_get_residunorm(this) result(res_norm)
  class(mma_t), intent(in) :: this
  real(rp) :: res_norm

  res_norm = this%residunorm
end function mma_get_residunorm
!> Accessor: the iteration limit stored in the MMA object.
pure function mma_get_max_iter(this) result(limit)
  class(mma_t), intent(in) :: this
  integer :: limit

  limit = this%max_iter
end function mma_get_max_iter
!> Build a descriptive string of the form
!! 'backend:<name>, subsolver:<...>' for this MMA object.
!!
!! @return backend_subsolver  deferred-length string assembled below
564 pure function mma_get_backend_and_subsolver(this)
result(backend_subsolver)
565 class(mma_t),
intent(in) :: this
566 character(len=:),
allocatable :: backend_subsolver
567 character(len=:),
allocatable :: backend
! The backend label is chosen from the compile-time flags neko_bcknd_cuda,
! neko_bcknd_hip and neko_bcknd_opencl (branch bodies not shown here).
! NOTE(review): in the visible lines this%bcknd is not consulted, so a
! run-time 'cpu' backend on a device build may be reported under the device
! name - confirm against the hidden branches.
569 if (neko_bcknd_cuda .eq. 1)
then
571 else if (neko_bcknd_hip .eq. 1)
then
573 else if (neko_bcknd_opencl .eq. 1)
then
579 backend_subsolver =
'backend:' // trim(backend) //
', subsolver:' // &
581 end function mma_get_backend_and_subsolver