71 use num_types,
only: rp, dp, sp
72 use json_module,
only: json_file
73 use json_utils,
only: json_get_or_default
74 use vector,
only: vector_t
75 use matrix,
only: matrix_t
76 use comm,
only: pe_rank, neko_comm, pe_size, mpi_real_precision
77 use utils,
only: neko_error, filename_suffix
78 use neko_config,
only: neko_bcknd_device, neko_bcknd_cuda, neko_bcknd_hip, &
80 use device,
only: device_memcpy, host_to_device, device_to_host
81 use,
intrinsic :: iso_c_binding, only: c_ptr
82 use logger,
only: neko_log
83 use mpi_f08,
only: mpi_sum, mpi_allreduce, mpi_integer
84 use scratch_registry,
only: scratch_registry_t
! Problem sizes and scalar solver settings. n_global is filled in
! init_from_components as the MPI sum of the local n over neko_comm.
92 integer :: n, m, n_global, max_iter
93 real(kind=rp) :: a0, asyinit, asyincr, asydecr, epsimin, &
95 real(kind=rp) :: move_limit = 0.2_rp
! Per-design-variable vectors (xold1, xold2, low, upp, alpha, beta, xmax
! and xmin are initialised with length n in init_from_components) and
! the coefficient vectors a, c, d, which are set from dimension(m) inputs.
96 type(vector_t) :: xold1, xold2, low, upp, alpha, beta, a, c, d, xmax, xmin
! Lifecycle flags: is_initialized becomes .true. at the end of
! init_from_components; free resets both flags to .false.
97 logical :: is_initialized = .false.
98 logical :: is_updated = .false.
99 type(scratch_registry_t) :: scratch
! Name of the execution backend ("cpu" or "device") and of the subsolver.
100 character(len=:),
allocatable :: subsolver, bcknd
! Work arrays: p0j and q0j have length n, pij and qij are m x n.
103 type(vector_t) :: p0j, q0j
104 type(matrix_t) :: pij, qij
! Subproblem variables: lambda has length m, xsi and eta have length n.
108 real(kind=rp) :: z, zeta
109 type(vector_t) :: y, lambda, s, mu
110 type(vector_t) :: xsi, eta
! Public interface. The generic init resolves either to the JSON-driven
! initialiser or to the one taking every parameter explicitly.
113 generic,
public :: init => init_from_json, init_from_components
114 procedure,
public, pass(this) :: init_from_json => mma_init_from_json
115 procedure,
public, pass(this) :: init_from_components => &
116 mma_init_from_components
117 procedure,
public, pass(this) :: free => mma_free
! Read-only accessors.
118 procedure,
public, pass(this) :: get_n => mma_get_n
119 procedure,
public, pass(this) :: get_m => mma_get_m
120 procedure,
public, pass(this) :: get_residumax => mma_get_residumax
121 procedure,
public, pass(this) :: get_residunorm => mma_get_residunorm
122 procedure,
public, pass(this) :: get_max_iter => mma_get_max_iter
123 procedure,
public, pass(this) :: get_backend_and_subsolver => &
124 mma_get_backend_and_subsolver
! update and kkt accept vector_t/matrix_t objects, plain host arrays, or
! device pointers (type(c_ptr)); the specific is selected by argument type.
126 generic,
public :: update => update_vector, update_cpu, update_device
127 procedure, pass(this) :: update_vector => mma_update_vector
128 procedure, pass(this) :: update_cpu => mma_update_cpu
129 procedure, pass(this) :: update_device => mma_update_device
131 generic,
public :: kkt => kkt_vector, kkt_cpu, kkt_device
132 procedure, pass(this) :: kkt_vector => mma_kkt_vector
133 procedure, pass(this) :: kkt_cpu => mma_kkt_cpu
134 procedure, pass(this) :: kkt_device => mma_kkt_device
! Checkpoint I/O; only .h5/.hdf5/.hf5 file names are accepted.
136 procedure, pass(this) :: save_checkpoint => mma_save_checkpoint
137 procedure, pass(this) :: load_checkpoint => mma_load_checkpoint
! Host/device transfer of the full solver state.
140 procedure, pass(this) :: copy_from => mma_copy_from
! Fallback values used when a setting is missing from the JSON input or an
! optional argument is not passed to init_from_components.

! Problem coefficients and design-variable bounds.
real(kind=rp), parameter :: &
     a0_default = 1.0_rp, &
     a_default = 0.0_rp, &
     c_default = 100.0_rp, &
     d_default = 0.0_rp, &
     xmin_default = 0.0_rp, &
     xmax_default = 1.0_rp

! Asymptote control and move limit.
real(kind=rp), parameter :: &
     asyinit_default = 0.2_rp, &
     asyincr_default = 1.05_rp, &
     asydecr_default = 0.65_rp, &
     move_limit_default = 0.2_rp

! Iteration cap, subsolver choice and scaling options.
integer, parameter :: max_iter_default = 100
character(len=*), parameter :: subsolver_default = "dip"
real(kind=rp), parameter :: scale_default = 1.0_rp
logical, parameter :: auto_scale_default = .false.
!> Host-backend MMA update working on plain Fortran arrays.
!! x (length this%n) is updated in place; df0dx has length this%n,
!! fval has length this%m and dfdx is this%m x this%n.
module subroutine mma_update_cpu(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  real(kind=rp), intent(inout) :: x(this%n)
  real(kind=rp), intent(in) :: df0dx(this%n)
  real(kind=rp), intent(in) :: fval(this%m)
  real(kind=rp), intent(in) :: dfdx(this%m, this%n)
end subroutine mma_update_cpu
!> Host-backend KKT evaluation on plain Fortran arrays.
!! All array arguments are read-only: x and df0dx have length this%n,
!! fval has length this%m and dfdx is this%m x this%n.
module subroutine mma_kkt_cpu(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  real(kind=rp), intent(in) :: x(this%n)
  real(kind=rp), intent(in) :: df0dx(this%n)
  real(kind=rp), intent(in) :: fval(this%m)
  real(kind=rp), intent(in) :: dfdx(this%m, this%n)
end subroutine mma_kkt_cpu
!> Device-backend MMA update. Every array argument is a device pointer
!! (the vector/matrix %x_d handles); x is the only one that is modified.
module subroutine mma_update_device(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(inout) :: x
  type(c_ptr), intent(in) :: df0dx
  type(c_ptr), intent(in) :: fval
  type(c_ptr), intent(in) :: dfdx
end subroutine mma_update_device
!> Device-backend KKT evaluation; all array arguments are device pointers.
module subroutine mma_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x
  type(c_ptr), intent(in) :: df0dx
  type(c_ptr), intent(in) :: fval
  type(c_ptr), intent(in) :: dfdx
end subroutine mma_kkt_device
!> HDF5 checkpoint writer. mma_save_checkpoint uses it for files ending in
!! .h5, .hdf5 or .hf5; the optional overwrite flag is forwarded unchanged.
module subroutine mma_save_checkpoint_hdf5(object, filename, overwrite)
  class(mma_t), intent(inout) :: object
  logical, optional, intent(in) :: overwrite
  character(len=*), intent(in) :: filename
end subroutine mma_save_checkpoint_hdf5
!> HDF5 checkpoint reader. mma_load_checkpoint uses it for files ending in
!! .h5, .hdf5 or .hf5.
module subroutine mma_load_checkpoint_hdf5(object, filename)
  character(len=*), intent(in) :: filename
  class(mma_t), intent(inout) :: object
end subroutine mma_load_checkpoint_hdf5
!> Initialise the MMA solver from the 'mma.*' entries of a JSON file.
!! Each setting falls back to the corresponding *_default parameter when
!! the key is absent. The collected values are then forwarded to
!! init_from_components through the generic this%init.
!! @param x          design-variable vector (forwarded to this%init)
!! @param n          local number of design variables on this rank
!! @param m          number of constraints
!! @param json       JSON file holding the optional 'mma.*' keys
!! @param scale      out: value of 'mma.scale' (default scale_default)
!! @param auto_scale out: value of 'mma.auto_scale' (default auto_scale_default)
228 subroutine mma_init_from_json(this, x, n, m, json, scale, auto_scale)
240 class(mma_t),
intent(inout) :: this
241 integer,
intent(in) :: n, m
242 type(vector_t),
intent(in) :: x
244 type(json_file),
intent(inout) :: json
247 real(kind=rp),
intent(out) :: scale
248 logical,
intent(out) :: auto_scale
256 real(kind=rp),
dimension(n) :: xmax, xmin
257 real(kind=rp),
dimension(m) :: a, c, d
258 character(len=:),
allocatable :: subsolver, bcknd, bcknd_default
261 real(kind=rp) :: a0 , xmax_const, xmin_const, a_const, c_const, d_const
262 real(kind=rp) :: move_limit
264 integer :: max_iter, n_global, ierr
265 real(kind=rp) :: epsimin, asyinit, asyincr, asydecr
! Total number of design variables over all ranks. It is only needed
! for the epsimin default below.
267 call mpi_allreduce(n, n_global, 1, mpi_integer, &
268 mpi_sum, neko_comm, ierr)
! The default backend follows the build: "device" when a device backend
! is compiled in, "cpu" otherwise.
271 if (neko_bcknd_device .eq. 1)
then
272 bcknd_default =
"device"
274 bcknd_default =
"cpu"
! Read every setting, falling back to the *_default parameters. The
! epsimin default grows with sqrt(m + n_global).
280 call json_get_or_default(json,
'mma.epsimin', epsimin, &
281 1.0e-9_rp * sqrt(real(m + n_global, rp)))
282 call json_get_or_default(json,
'mma.max_iter', max_iter, max_iter_default)
285 call json_get_or_default(json,
'mma.asyinit', asyinit, asyinit_default)
286 call json_get_or_default(json,
'mma.asyincr', asyincr, asyincr_default)
287 call json_get_or_default(json,
'mma.asydecr', asydecr, asydecr_default)
289 call json_get_or_default(json,
'mma.backend', bcknd, bcknd_default)
290 call json_get_or_default(json,
'mma.subsolver', subsolver, subsolver_default)
! Bounds and coefficients are read as scalars.
! NOTE(review): the *_const values must be broadcast into the local arrays
! xmin, xmax, a, c, d before the this%init call below. Confirm this.
292 call json_get_or_default(json,
'mma.xmin', xmin_const, xmin_default)
293 call json_get_or_default(json,
'mma.xmax', xmax_const, xmax_default)
294 call json_get_or_default(json,
'mma.a0', a0, a0_default)
295 call json_get_or_default(json,
'mma.a', a_const, a_default)
296 call json_get_or_default(json,
'mma.c', c_const, c_default)
297 call json_get_or_default(json,
'mma.d', d_const, d_default)
298 call json_get_or_default(json,
'mma.move_limit', move_limit, move_limit_default)
! scale and auto_scale are returned to the caller via intent(out).
300 call json_get_or_default(json,
'mma.scale', scale, scale_default)
301 call json_get_or_default(json,
'mma.auto_scale', auto_scale, auto_scale_default)
! Delegate to init_from_components (the generic this%init).
313 call this%init(x, n, m, a0, a, c, d, xmin, xmax, &
314 max_iter, epsimin, asyinit, asyincr, asydecr, bcknd, subsolver, &
317 end subroutine mma_init_from_json
!> Release every allocated component of the MMA object and reset the
!! lifecycle flags, so the object can be initialised again.
!! All vector_t/matrix_t components initialised or transferred elsewhere in
!! the type (see init_from_components and copy_from) are freed here. The
!! previous version freed only a subset and leaked the rest.
!! NOTE(review): this relies on vector_t%free / matrix_t%free being safe
!! to call on an already-freed or never-initialised component. Confirm.
subroutine mma_free(this)
  class(mma_t), intent(inout) :: this

  ! Iterate history and per-variable vectors
  call this%xold1%free()
  call this%xold2%free()
  call this%low%free()
  call this%upp%free()
  call this%alpha%free()
  call this%beta%free()
  call this%xmax%free()
  call this%xmin%free()

  ! Per-constraint coefficients
  call this%a%free()
  call this%c%free()
  call this%d%free()

  ! Approximation work arrays
  call this%p0j%free()
  call this%q0j%free()
  call this%pij%free()
  call this%qij%free()
  call this%bi%free()

  ! Subproblem variables
  call this%y%free()
  call this%lambda%free()
  call this%s%free()
  call this%mu%free()
  call this%xsi%free()
  call this%eta%free()

  call this%scratch%free()

  this%is_initialized = .false.
  this%is_updated = .false.
end subroutine mma_free
!> Initialise the MMA solver from explicitly supplied parameters.
!! Steps:
!!  - initialise the scratch registry and the internal vectors/matrices;
!!  - on device builds, copy a, c, d, xmax, xmin to the device;
!!  - apply the *_default parameters for every absent optional argument;
!!  - log the effective configuration;
!!  - mark the object as initialised.
!! @param x            design-variable vector
!! @param n            local number of design variables on this rank
!! @param m            number of constraints
!! @param a0, a, c, d  problem coefficients (a, c, d have length m)
!! @param xmin, xmax   per-variable bounds (length n)
!! @param max_iter, epsimin, asyinit, asyincr, asydecr, move_limit
!!        optional solver settings
!! @param bcknd, subsolver  optional backend and subsolver names
354 subroutine mma_init_from_components(this, x, n, m, a0, a, c, d, xmin, xmax, &
355 max_iter, epsimin, asyinit, asyincr, asydecr, bcknd, subsolver, &
368 class(mma_t),
intent(inout) :: this
369 integer,
intent(in) :: n, m
370 type(vector_t),
intent(in) :: x
378 real(kind=rp),
intent(in),
dimension(n) :: xmax, xmin
379 real(kind=rp),
intent(in),
dimension(m) :: a, c, d
380 real(kind=rp),
intent(in) :: a0
381 integer,
intent(in),
optional :: max_iter
382 real(kind=rp),
intent(in),
optional :: epsimin, asyinit, asyincr, asydecr
383 real(kind=rp),
intent(in),
optional :: move_limit
384 character(len=*),
intent(in),
optional :: bcknd, subsolver
385 character(len=256) :: log_msg
! Scratch registry used for temporary work storage.
389 call this%scratch%init()
! Allocate the state: length-n vectors, m x n matrices, length-m lambda.
394 call this%xold1%init(n)
395 call this%xold2%init(n)
399 call this%alpha%init(n)
400 call this%beta%init(n)
405 call this%low%init(n)
406 call this%upp%init(n)
407 call this%xmax%init(n)
408 call this%xmin%init(n)
411 call this%p0j%init(n)
412 call this%q0j%init(n)
413 call this%pij%init(m, n)
414 call this%qij%init(m, n)
419 call this%lambda%init(m)
422 call this%xsi%init(n)
423 call this%eta%init(n)
! On device builds, push the problem data to the device. Only the final
! copy is synchronous.
434 if (neko_bcknd_device .eq. 1)
then
435 call this%a%copy_from(host_to_device, sync = .false.)
436 call this%c%copy_from(host_to_device, sync = .false.)
437 call this%d%copy_from(host_to_device, sync = .false.)
438 call this%xmax%copy_from(host_to_device, sync = .false.)
439 call this%xmin%copy_from(host_to_device, sync = .true.)
! Residuals start at the largest representable value until computed.
443 this%residumax = huge(0.0_rp)
444 this%residunorm = huge(0.0_rp)
! Global problem size (sum of n over all ranks), used in the epsimin
! default.
447 call mpi_allreduce(n, this%n_global, 1, mpi_integer, mpi_sum, neko_comm, &
454 if (.not.
present(max_iter)) this%max_iter = max_iter_default
455 if (.not.
present(epsimin))
then
456 this%epsimin = 1.0e-9_rp * sqrt(real(this%m + this%n_global, rp))
460 if (.not.
present(asyinit)) this%asyinit = asyinit_default
461 if (.not.
present(asyincr)) this%asyincr = asyincr_default
462 if (.not.
present(asydecr)) this%asydecr = asydecr_default
463 if (.not.
present(move_limit)) this%move_limit = move_limit_default
! Without an explicit bcknd, device builds default to "device".
466 if (.not.
present(bcknd) .and. neko_bcknd_device .eq. 0)
then
468 else if (.not.
present(bcknd))
then
469 this%bcknd =
"device"
473 if (.not.
present(subsolver)) this%subsolver = subsolver_default
! Explicitly supplied arguments override the defaults set above.
476 if (
present(max_iter)) this%max_iter = max_iter
477 if (
present(epsimin)) this%epsimin = epsimin
478 if (
present(asyinit)) this%asyinit = asyinit
479 if (
present(asyincr)) this%asyincr = asyincr
480 if (
present(asydecr)) this%asydecr = asydecr
481 if (
present(move_limit)) this%move_limit = move_limit
482 if (
present(bcknd)) this%bcknd = bcknd
483 if (
present(subsolver)) this%subsolver = subsolver
! Report the effective configuration through the Neko logger.
485 call neko_log%section(
'MMA Parameters')
487 write(log_msg,
'(A10,1X,A)')
'backend ', trim(this%bcknd)
488 call neko_log%message(log_msg)
489 write(log_msg,
'(A10,1X,A)')
'subsolver ', trim(this%subsolver)
490 call neko_log%message(log_msg)
492 write(log_msg,
'(A10,1X,I0)')
'n ', this%n_global
493 call neko_log%message(log_msg)
494 write(log_msg,
'(A10,1X,I0)')
'm ', this%m
495 call neko_log%message(log_msg)
496 write(log_msg,
'(A10,1X,I0)')
'max_iter ', this%max_iter
497 call neko_log%message(log_msg)
499 write(log_msg,
'(A10,1X,E11.5)')
'epsimin ', this%epsimin
500 call neko_log%message(log_msg)
502 write(log_msg,
'(A10,1X,E11.5)')
'asyinit ', this%asyinit
503 call neko_log%message(log_msg)
504 write(log_msg,
'(A10,1X,E11.5)')
'asyincr ', this%asyincr
505 call neko_log%message(log_msg)
506 write(log_msg,
'(A10,1X,E11.5)')
'asydecr ', this%asydecr
507 call neko_log%message(log_msg)
508 write(log_msg,
'(A10,1X,E11.5)')
'a0 ', this%a0
509 call neko_log%message(log_msg)
510 write(log_msg,
'(A10,1X,E11.5)')
'movelimit ', this%move_limit
511 call neko_log%message(log_msg)
! Per-constraint coefficients, one line per index.
! NOTE(review): the I2 descriptor prints '**' once a constraint index
! reaches 100. I0 would not overflow.
513 call neko_log%message(
'Parameters a')
515 write(log_msg,
'(3X,A,I2,A,E11.5)')
'a(', i,
') = ', this%a%x(i)
516 call neko_log%message(log_msg)
518 call neko_log%message(
'Parameters c')
520 write(log_msg,
'(3X,A,I2,A,E11.5)')
'c(', i,
') = ', this%c%x(i)
521 call neko_log%message(log_msg)
523 call neko_log%message(
'Parameters d')
525 write(log_msg,
'(3X,A,I2,A,E11.5)')
'd(', i,
') = ', this%d%x(i)
526 call neko_log%message(log_msg)
529 call neko_log%end_section()
! Cleared again by free.
532 this%is_initialized = .true.
533 end subroutine mma_init_from_components
!> Perform one MMA update on vector_t/matrix_t data, dispatching on
!! this%bcknd.
!! @param iter   iteration number forwarded to the backend routine
!! @param x      design variables (length n); updated in place
!! @param df0dx  length-n input (presumably the objective gradient)
!! @param fval   length-m input (presumably the constraint values)
!! @param dfdx   m x n input (presumably the constraint gradients)
!! The inputs are intent(inout) because, on device builds, their host
!! copies are refreshed from the device before a CPU update.
539 subroutine mma_update_vector(this, iter, x, df0dx, fval, dfdx)
540 class(mma_t),
intent(inout) :: this
541 integer,
intent(in) :: iter
542 type(vector_t),
intent(inout) :: x
543 type(vector_t),
intent(inout) :: df0dx, fval
544 type(matrix_t),
intent(inout) :: dfdx
547 select case (this%bcknd)
! CPU backend: on a device build, first bring all inputs to the host.
! Only the last transfer synchronises.
549 if (neko_bcknd_device .eq. 1)
then
550 call x%copy_from(device_to_host, sync = .false.)
551 call df0dx%copy_from(device_to_host, sync = .false.)
552 call fval%copy_from(device_to_host, sync = .false.)
553 call dfdx%copy_from(device_to_host, sync = .true.)
! Update on the host arrays.
556 call mma_update_cpu(this, iter, x%x, df0dx%x, fval%x, dfdx%x)
! Only x changes, so only x is pushed back to the device.
558 if (neko_bcknd_device .eq. 1)
then
559 call x%copy_from(host_to_device, sync = .true.)
! Device backend: operate directly on the device pointers.
563 call mma_update_device(this, iter, x%x_d, df0dx%x_d, fval%x_d, dfdx%x_d)
566 end subroutine mma_update_vector
!> Evaluate the KKT conditions on vector_t/matrix_t data, dispatching on
!! this%bcknd.
!! The array arguments have the same sizes as in mma_update_vector
!! (x, df0dx: length n; fval: length m; dfdx: m x n). They are only read
!! by the backend routines, but their host copies may be refreshed.
569 subroutine mma_kkt_vector(this, x, df0dx, fval, dfdx)
570 class(mma_t),
intent(inout) :: this
571 type(vector_t),
intent(inout) :: x, df0dx, fval
572 type(matrix_t),
intent(inout) :: dfdx
575 select case (this%bcknd )
! CPU backend on a device build: copy every input from device to host.
! NOTE(review): mma_update_vector does the same with copy_from; consider
! using one mechanism in both places.
577 if (neko_bcknd_device .eq. 1)
then
578 call device_memcpy(x%x, x%x_d, this%n, device_to_host, &
580 call device_memcpy(df0dx%x, df0dx%x_d, this%n, device_to_host, &
582 call device_memcpy(fval%x, fval%x_d, this%m, device_to_host, &
584 call device_memcpy(dfdx%x, dfdx%x_d, this%m * this%n, device_to_host,&
588 call mma_kkt_cpu(this, x%x, df0dx%x, fval%x, dfdx%x)
! Device backend: evaluate directly on the device pointers.
590 call mma_kkt_device(this, x%x_d, df0dx%x_d, fval%x_d, dfdx%x_d)
592 end subroutine mma_kkt_vector
!> Save the MMA state to a checkpoint file whose format is chosen by the
!! file extension.
!! @param filename   target file; must end in .h5, .hdf5 or .hf5
!! @param overwrite  optional flag forwarded to the HDF5 writer
601 subroutine mma_save_checkpoint(this, filename, overwrite)
602 class(mma_t),
intent(inout) :: this
603 character(len=*),
intent(in) :: filename
604 logical,
intent(in),
optional :: overwrite
605 character(len=12) :: file_ext
! Dispatch on the extension. Only HDF5 extensions are supported.
608 call filename_suffix(filename, file_ext)
610 select case (trim(file_ext))
611 case (
'h5',
'hdf5',
'hf5')
612 call mma_save_checkpoint_hdf5(this, filename, overwrite)
! Any other extension is reported through neko_error.
614 call neko_error(
'mma_save_checkpoint: Unsupported file format: ' // &
618 end subroutine mma_save_checkpoint
!> Restore the MMA state from a checkpoint file whose format is chosen by
!! the file extension.
!! @param filename  source file; must end in .h5, .hdf5 or .hf5
623 subroutine mma_load_checkpoint(this, filename)
624 class(mma_t),
intent(inout) :: this
625 character(len=*),
intent(in) :: filename
626 character(len=12) :: file_ext
! Dispatch on the extension. Only HDF5 extensions are supported.
629 call filename_suffix(filename, file_ext)
631 select case (trim(file_ext))
632 case (
'h5',
'hdf5',
'hf5')
633 call mma_load_checkpoint_hdf5(this, filename)
! Any other extension is reported through neko_error.
635 call neko_error(
'mma_load_checkpoint: Unsupported file format: ' // &
638 end subroutine mma_load_checkpoint
!> Return this%n, the number of design variables held on this rank.
!! The total over all ranks is stored separately in this%n_global.
!! The result is now declared and assigned; previously it was returned
!! undefined.
pure function mma_get_n(this) result(n)
  class(mma_t), intent(in) :: this
  integer :: n

  n = this%n
end function mma_get_n
!> Return this%m, the number of constraints.
!! The result is now declared and assigned; previously it was returned
!! undefined.
pure function mma_get_m(this) result(m)
  class(mma_t), intent(in) :: this
  integer :: m

  m = this%m
end function mma_get_m
!> Accessor for this%residumax. The value is huge(0.0_rp) right after
!! initialisation; presumably it is updated by the KKT routines.
pure real(kind=rp) function mma_get_residumax(this) result(res_max)
  class(mma_t), intent(in) :: this

  res_max = this%residumax
end function mma_get_residumax
!> Accessor for this%residunorm. The value is huge(0.0_rp) right after
!! initialisation; presumably it is updated by the KKT routines.
pure real(kind=rp) function mma_get_residunorm(this) result(res_norm)
  class(mma_t), intent(in) :: this

  res_norm = this%residunorm
end function mma_get_residunorm
!> Accessor for the configured iteration limit this%max_iter.
pure integer function mma_get_max_iter(this) result(iter_cap)
  class(mma_t), intent(in) :: this

  iter_cap = this%max_iter
end function mma_get_max_iter
!> Build a descriptive string of the form 'backend:<name>, subsolver:<...>'.
!! The backend name is chosen from the compile-time Neko backend flags
!! (CUDA, HIP, OpenCL, then a fallback). The subsolver part is presumably
!! this%subsolver (confirm).
679 pure function mma_get_backend_and_subsolver(this)
result(backend_subsolver)
680 class(mma_t),
intent(in) :: this
681 character(len=:),
allocatable :: backend_subsolver
682 character(len=:),
allocatable :: backend
! Pick the backend label from the compiled-in backend flags.
684 if (neko_bcknd_cuda .eq. 1)
then
686 else if (neko_bcknd_hip .eq. 1)
then
688 else if (neko_bcknd_opencl .eq. 1)
then
! Concatenate the backend label and the subsolver name.
694 backend_subsolver =
'backend:' // trim(backend) //
', subsolver:' // &
696 end function mma_get_backend_and_subsolver
!> Transfer the complete MMA state between host and device memory.
!! All components are copied in a fixed order with asynchronous transfers.
!! Only the final transfer (eta) uses the caller's sync flag.
!! @param direction  transfer direction, e.g. host_to_device or device_to_host
!! @param sync       synchronisation flag applied to the last transfer
subroutine mma_copy_from(this, direction, sync)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: direction
  logical, intent(in) :: sync
  logical, parameter :: no_wait = .false.

  ! Iterate history and variable bounds
  call this%xold1%copy_from(direction, sync = no_wait)
  call this%xold2%copy_from(direction, sync = no_wait)
  call this%xmax%copy_from(direction, sync = no_wait)
  call this%xmin%copy_from(direction, sync = no_wait)

  call this%low%copy_from(direction, sync = no_wait)
  call this%upp%copy_from(direction, sync = no_wait)

  ! Problem coefficients and subproblem variables y, s
  call this%a%copy_from(direction, sync = no_wait)
  call this%c%copy_from(direction, sync = no_wait)
  call this%d%copy_from(direction, sync = no_wait)
  call this%y%copy_from(direction, sync = no_wait)
  call this%s%copy_from(direction, sync = no_wait)

  ! Approximation work arrays
  call this%p0j%copy_from(direction, sync = no_wait)
  call this%q0j%copy_from(direction, sync = no_wait)
  call this%pij%copy_from(direction, sync = no_wait)
  call this%qij%copy_from(direction, sync = no_wait)
  call this%bi%copy_from(direction, sync = no_wait)

  ! Remaining subproblem variables
  call this%alpha%copy_from(direction, sync = no_wait)
  call this%beta%copy_from(direction, sync = no_wait)
  call this%lambda%copy_from(direction, sync = no_wait)
  call this%mu%copy_from(direction, sync = no_wait)
  call this%xsi%copy_from(direction, sync = no_wait)

  ! The last transfer carries the caller's synchronisation request
  call this%eta%copy_from(direction, sync = sync)
end subroutine mma_copy_from
!> Fallback HDF5 writer for builds without HDF5 support.
!! It writes nothing and reports an error through neko_error.
module subroutine mma_save_checkpoint_hdf5(object, filename, overwrite)
  class(mma_t), intent(inout) :: object
  character(len=*), intent(in) :: filename
  logical, optional, intent(in) :: overwrite
  character(len=*), parameter :: missing_hdf5 = &
       'mma: HDF5 support not enabled rebuild with HAVE_HDF5'

  call neko_error(missing_hdf5)
end subroutine mma_save_checkpoint_hdf5
!> Fallback HDF5 reader for builds without HDF5 support.
!! It leaves the object untouched and reports an error through neko_error.
module subroutine mma_load_checkpoint_hdf5(object, filename)
  class(mma_t), intent(inout) :: object
  character(len=*), intent(in) :: filename
  character(len=*), parameter :: missing_hdf5 = &
       'mma: HDF5 support not enabled rebuild with HAVE_HDF5'

  call neko_error(missing_hdf5)
end subroutine mma_load_checkpoint_hdf5