71 use num_types,
only: rp, dp, sp
72 use json_module,
only: json_file
73 use json_utils,
only: json_get_or_default
74 use vector,
only: vector_t
75 use matrix,
only: matrix_t
76 use comm,
only: pe_rank, neko_comm, pe_size, mpi_real_precision
77 use utils,
only: neko_error, filename_suffix
78 use neko_config,
only: neko_bcknd_device, neko_bcknd_cuda, neko_bcknd_hip, &
80 use device,
only: device_memcpy, host_to_device, device_to_host
81 use,
intrinsic :: iso_c_binding, only: c_ptr
82 use logger,
only: neko_log
83 use mpi_f08,
only: mpi_sum, mpi_allreduce, mpi_integer
91 integer :: n, m, n_global, max_iter
92 real(kind=rp) :: a0, asyinit, asyincr, asydecr, epsimin, &
94 type(vector_t) :: xold1, xold2, low, upp, alpha, beta, a, c, d, xmax, xmin
95 logical :: is_initialized = .false.
96 logical :: is_updated = .false.
97 character(len=:),
allocatable :: subsolver, bcknd
100 type(vector_t) :: p0j, q0j
101 type(matrix_t) :: pij, qij
105 real(kind=rp) :: z, zeta
106 type(vector_t) :: y, lambda, s, mu
107 type(vector_t) :: xsi, eta
110 generic,
public :: init => init_from_json, init_from_components
112 procedure,
public, pass(this) :: init_from_components => &
113 mma_init_from_components
114 procedure,
public, pass(this) :: free => mma_free
115 procedure,
public, pass(this) :: get_n => mma_get_n
116 procedure,
public, pass(this) :: get_m => mma_get_m
117 procedure,
public, pass(this) :: get_residumax => mma_get_residumax
118 procedure,
public, pass(this) :: get_residunorm => mma_get_residunorm
119 procedure,
public, pass(this) :: get_max_iter => mma_get_max_iter
120 procedure,
public, pass(this) :: get_backend_and_subsolver => &
121 mma_get_backend_and_subsolver
123 generic,
public :: update => update_vector, update_cpu, update_device
124 procedure, pass(this) :: update_vector => mma_update_vector
125 procedure, pass(this) :: update_cpu => mma_update_cpu
126 procedure, pass(this) :: update_device => mma_update_device
128 generic,
public :: kkt => kkt_vector, kkt_cpu, kkt_device
129 procedure, pass(this) :: kkt_vector => mma_kkt_vector
130 procedure, pass(this) :: kkt_cpu => mma_kkt_cpu
131 procedure, pass(this) :: kkt_device => mma_kkt_device
133 procedure, pass(this) :: save_checkpoint => mma_save_checkpoint
134 procedure, pass(this) :: load_checkpoint => mma_load_checkpoint
137 procedure, pass(this) :: copy_from => mma_copy_from
!> Interface of the host-array MMA update.
!! x is read and may be modified in place (intent(inout)); the gradient and
!! constraint arrays are read-only.
module subroutine mma_update_cpu(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  real(kind=rp), intent(inout) :: x(this%n)
  real(kind=rp), intent(in) :: df0dx(this%n), fval(this%m)
  real(kind=rp), intent(in) :: dfdx(this%m, this%n)
end subroutine mma_update_cpu
!> Interface of the host-array KKT check.
!! All array arguments are read-only; only the object itself may be updated.
module subroutine mma_kkt_cpu(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  real(kind=rp), intent(in) :: x(this%n), df0dx(this%n)
  real(kind=rp), intent(in) :: fval(this%m)
  real(kind=rp), intent(in) :: dfdx(this%m, this%n)
end subroutine mma_kkt_cpu
!> Interface of the device MMA update; the arrays are passed as raw device
!! pointers (type(c_ptr)), and only the pointer for x is intent(inout).
module subroutine mma_update_device(this, iter, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: iter
  type(c_ptr), intent(inout) :: x
  type(c_ptr), intent(in) :: df0dx
  type(c_ptr), intent(in) :: fval
  type(c_ptr), intent(in) :: dfdx
end subroutine mma_update_device
!> Interface of the device KKT check; every array is a read-only device
!! pointer (type(c_ptr)).
module subroutine mma_kkt_device(this, x, df0dx, fval, dfdx)
  class(mma_t), intent(inout) :: this
  type(c_ptr), intent(in) :: x
  type(c_ptr), intent(in) :: df0dx
  type(c_ptr), intent(in) :: fval
  type(c_ptr), intent(in) :: dfdx
end subroutine mma_kkt_device
!> Interface of the HDF5 checkpoint writer.
!! overwrite is optional and may be forwarded absent by callers.
module subroutine mma_save_checkpoint_hdf5(object, filename, overwrite)
  class(mma_t), intent(inout) :: object
  character(len=*), intent(in) :: filename
  logical, optional, intent(in) :: overwrite
end subroutine mma_save_checkpoint_hdf5
!> Interface of the HDF5 checkpoint reader.
module subroutine mma_load_checkpoint_hdf5(object, filename)
  class(mma_t), intent(inout) :: object
  character(len=*), intent(in) :: filename
end subroutine mma_load_checkpoint_hdf5
!> Initialise the MMA optimizer from the 'mma.*' keys of a JSON object.
!! Every setting is read with json_get_or_default, so absent keys fall back
!! to built-in defaults. The parsed values are then forwarded to
!! mma_init_from_components through the generic this%init.
!! @param x          Design-variable vector, forwarded to this%init.
!! @param n          Number of variables on this rank (summed into n_global).
!! @param m          Length of the a, c and d arrays.
!! @param json       Configuration object queried with 'mma.*' keys.
!! @param scale      Out: value of 'mma.scale' (default 10.0); returned to the
!!                   caller, not passed to this%init.
!! @param auto_scale Out: value of 'mma.auto_scale' (default .false.);
!!                   returned to the caller, not passed to this%init.
205 subroutine mma_init_from_json(this, x, n, m, json, scale, auto_scale)
217 class(mma_t),
intent(inout) :: this
218 integer,
intent(in) :: n, m
219 type(vector_t),
intent(in) :: x
221 type(json_file),
intent(inout) :: json
224 real(kind=rp),
intent(out) :: scale
225 logical,
intent(out) :: auto_scale
! NOTE(review): xmax/xmin are automatic arrays of length n (and a/c/d of
! length m); for large per-rank n this may exhaust the stack. Consider
! allocatable arrays if that becomes an issue.
233 real(kind=rp),
dimension(n) :: xmax, xmin
234 real(kind=rp),
dimension(m) :: a, c, d
235 character(len=:),
allocatable :: subsolver, bcknd, bcknd_default
238 real(kind=rp) :: a0 , xmax_const, xmin_const, a_const, c_const, d_const
240 integer :: max_iter, n_global, ierr
241 real(kind=rp) :: epsimin, asyinit, asyincr, asydecr
! The default epsimin below depends on the global variable count, so the
! per-rank n is summed over neko_comm first.
243 call mpi_allreduce(n, n_global, 1, mpi_integer, &
244 mpi_sum, neko_comm, ierr)
! Default backend: "device" on device builds, otherwise "cpu".
247 if (neko_bcknd_device .eq. 1)
then
248 bcknd_default =
"device"
250 bcknd_default =
"cpu"
! Solver controls; the defaults match those used when the optionals are
! absent in mma_init_from_components.
256 call json_get_or_default(json,
'mma.epsimin', epsimin, &
257 1.0e-9_rp * sqrt(real(m + n_global, rp)))
258 call json_get_or_default(json,
'mma.max_iter', max_iter, 100)
! Asymptote initialisation / increase / decrease factors.
261 call json_get_or_default(json,
'mma.asyinit', asyinit, 0.5_rp)
262 call json_get_or_default(json,
'mma.asyincr', asyincr, 1.2_rp)
263 call json_get_or_default(json,
'mma.asydecr', asydecr, 0.7_rp)
265 call json_get_or_default(json,
'mma.backend', bcknd, bcknd_default)
266 call json_get_or_default(json,
'mma.subsolver', subsolver,
'dip')
! Scalar bounds and problem constants. NOTE(review): the statements that
! expand xmin_const/xmax_const/a_const/c_const/d_const into the xmin, xmax,
! a, c, d arrays are not visible in this excerpt; presumably each array is
! filled with its scalar - confirm.
268 call json_get_or_default(json,
'mma.xmin', xmin_const, 0.0_rp)
269 call json_get_or_default(json,
'mma.xmax', xmax_const, 1.0_rp)
270 call json_get_or_default(json,
'mma.a0', a0, 1.0_rp)
271 call json_get_or_default(json,
'mma.a', a_const, 0.0_rp)
272 call json_get_or_default(json,
'mma.c', c_const, 100.0_rp)
273 call json_get_or_default(json,
'mma.d', d_const, 0.0_rp)
! Scaling settings are only handed back to the caller.
275 call json_get_or_default(json,
'mma.scale', scale, 10.0_rp)
276 call json_get_or_default(json,
'mma.auto_scale', auto_scale, .false.)
! Delegate to the component-based initialiser (init_from_components).
288 call this%init(x, n, m, a0, a, c, d, xmin, xmax, &
289 max_iter, epsimin, asyinit, asyincr, asydecr, bcknd, subsolver)
291 end subroutine mma_init_from_json
!> Release the vectors held by the MMA object and reset its status flags.
!! NOTE(review): only the free() calls for xold1, xold2, alpha, beta, xmax,
!! xmin and lambda are visible in this excerpt. Confirm that the other
!! members initialised in mma_init_from_components (low, upp, p0j, q0j, pij,
!! qij, xsi, eta, ...) are released as well.
294 subroutine mma_free(this)
295 class(mma_t),
intent(inout) :: this
297 call this%xold1%free()
298 call this%xold2%free()
299 call this%alpha%free()
300 call this%beta%free()
306 call this%xmax%free()
307 call this%xmin%free()
312 call this%lambda%free()
! Return both status flags to their declared default (.false.).
322 this%is_initialized = .false.
323 this%is_updated = .false.
324 end subroutine mma_free
!> Initialise the MMA optimizer from explicit problem data.
!! The routine:
!!  - initialises the internal work vectors/matrices for n variables and
!!    m rows;
!!  - on device builds, copies a, c, d, xmax and xmin to the device;
!!  - fills absent optional settings with defaults;
!!  - logs the resulting parameters through neko_log.
!! @param x      Design-variable vector.
!! @param n      Number of variables on this rank (summed into n_global).
!! @param m      Length of a, c and d (and of fval in the update interfaces).
!! @param a0, a, c, d  MMA problem constants.
!! @param xmin, xmax   Per-variable lower/upper bounds (length n).
!! @param max_iter     Optional, default 100.
!! @param epsimin      Optional, default 1e-9 * sqrt(m + n_global).
!! @param asyinit, asyincr, asydecr  Optional, defaults 0.5, 1.2, 0.7.
!! @param bcknd        Optional; defaults to "device" on device builds.
!! @param subsolver    Optional, default "dip".
!! NOTE(review): these defaults duplicate the json_get_or_default defaults in
!! mma_init_from_json; keep the two in sync.
327 subroutine mma_init_from_components(this, x, n, m, a0, a, c, d, xmin, xmax, &
328 max_iter, epsimin, asyinit, asyincr, asydecr, bcknd, subsolver)
340 class(mma_t),
intent(inout) :: this
341 integer,
intent(in) :: n, m
342 type(vector_t),
intent(in) :: x
350 real(kind=rp),
intent(in),
dimension(n) :: xmax, xmin
351 real(kind=rp),
intent(in),
dimension(m) :: a, c, d
352 real(kind=rp),
intent(in) :: a0
353 integer,
intent(in),
optional :: max_iter
354 real(kind=rp),
intent(in),
optional :: epsimin, asyinit, asyincr, asydecr
355 character(len=*),
intent(in),
optional :: bcknd, subsolver
356 character(len=256) :: log_msg
! Length-n work vectors (history, move limits, asymptotes, bounds).
364 call this%xold1%init(n)
365 call this%xold2%init(n)
369 call this%alpha%init(n)
370 call this%beta%init(n)
375 call this%low%init(n)
376 call this%upp%init(n)
377 call this%xmax%init(n)
378 call this%xmin%init(n)
! Approximation coefficients: vectors of length n, matrices of size m x n.
381 call this%p0j%init(n)
382 call this%q0j%init(n)
383 call this%pij%init(m, n)
384 call this%qij%init(m, n)
389 call this%lambda%init(m)
392 call this%xsi%init(n)
393 call this%eta%init(n)
! On device builds push the problem constants to the device. Only the last
! transfer passes sync = .true.
404 if (neko_bcknd_device .eq. 1)
then
405 call this%a%copy_from(host_to_device, sync = .false.)
406 call this%c%copy_from(host_to_device, sync = .false.)
407 call this%d%copy_from(host_to_device, sync = .false.)
408 call this%xmax%copy_from(host_to_device, sync = .false.)
409 call this%xmin%copy_from(host_to_device, sync = .true.)
! Residuals start at the largest representable value (presumably so that
! any convergence test fails until a KKT evaluation has run).
413 this%residumax = huge(0.0_rp)
414 this%residunorm = huge(0.0_rp)
417 call mpi_allreduce(n, this%n_global, 1, mpi_integer, mpi_sum, neko_comm, &
! First pass: assign defaults for absent optionals. The present ones are
! copied in the second pass below.
424 if (.not.
present(max_iter)) this%max_iter = 100
425 if (.not.
present(epsimin))
then
426 this%epsimin = 1.0e-9_rp * sqrt(real(this%m + this%n_global, rp))
430 if (.not.
present(asyinit)) this%asyinit = 0.5_rp
431 if (.not.
present(asyincr)) this%asyincr = 1.2_rp
432 if (.not.
present(asydecr)) this%asydecr = 0.7_rp
! NOTE(review): the body of the non-device branch is not visible here;
! presumably it selects "cpu".
435 if (.not.
present(bcknd) .and. neko_bcknd_device .eq. 0)
then
437 else if (.not.
present(bcknd))
then
438 this%bcknd =
"device"
442 if (.not.
present(subsolver)) this%subsolver =
"dip"
! Second pass: caller-supplied values.
445 if (
present(max_iter)) this%max_iter = max_iter
446 if (
present(epsimin)) this%epsimin = epsimin
447 if (
present(asyinit)) this%asyinit = asyinit
448 if (
present(asyincr)) this%asyincr = asyincr
449 if (
present(asydecr)) this%asydecr = asydecr
450 if (
present(bcknd)) this%bcknd = bcknd
451 if (
present(subsolver)) this%subsolver = subsolver
! Report the effective configuration.
453 call neko_log%section(
'MMA Parameters')
455 write(log_msg,
'(A10,1X,A)')
'backend ', trim(this%bcknd)
456 call neko_log%message(log_msg)
457 write(log_msg,
'(A10,1X,A)')
'subsolver ', trim(this%subsolver)
458 call neko_log%message(log_msg)
460 write(log_msg,
'(A10,1X,I0)')
'n ', this%n_global
461 call neko_log%message(log_msg)
462 write(log_msg,
'(A10,1X,I0)')
'm ', this%m
463 call neko_log%message(log_msg)
464 write(log_msg,
'(A10,1X,I0)')
'max_iter ', this%max_iter
465 call neko_log%message(log_msg)
467 write(log_msg,
'(A10,1X,E11.5)')
'epsimin ', this%epsimin
468 call neko_log%message(log_msg)
470 write(log_msg,
'(A10,1X,E11.5)')
'asyinit ', this%asyinit
471 call neko_log%message(log_msg)
472 write(log_msg,
'(A10,1X,E11.5)')
'asyincr ', this%asyincr
473 call neko_log%message(log_msg)
474 write(log_msg,
'(A10,1X,E11.5)')
'asydecr ', this%asydecr
475 call neko_log%message(log_msg)
476 write(log_msg,
'(A10,1X,E11.5)')
'a0 ', this%a0
477 call neko_log%message(log_msg)
! Per-entry listings of a, c and d. NOTE(review): the loop headers over i
! are not visible here. The (I2) edit descriptor prints '**' for indices
! >= 100; consider I0 if m can be that large.
479 call neko_log%message(
'Parameters a')
481 write(log_msg,
'(3X,A,I2,A,E11.5)')
'a(', i,
') = ', this%a%x(i)
482 call neko_log%message(log_msg)
484 call neko_log%message(
'Parameters c')
486 write(log_msg,
'(3X,A,I2,A,E11.5)')
'c(', i,
') = ', this%c%x(i)
487 call neko_log%message(log_msg)
489 call neko_log%message(
'Parameters d')
491 write(log_msg,
'(3X,A,I2,A,E11.5)')
'd(', i,
') = ', this%d%x(i)
492 call neko_log%message(log_msg)
495 call neko_log%end_section()
498 this%is_initialized = .true.
499 end subroutine mma_init_from_components
!> Perform one MMA update using the vector_t / matrix_t wrapper types.
!! Dispatches on this%bcknd:
!!  - host path: calls mma_update_cpu on the host arrays (x%x, ...);
!!  - device path: calls mma_update_device on the device pointers (x%x_d, ...).
!! NOTE(review): the case labels are not visible in this excerpt; presumably
!! "cpu" and "device", as set in mma_init_from_components.
!! @param iter  Iteration counter forwarded to the update routine.
!! @param x     Design vector; updated in place by the selected backend.
!! @param df0dx, fval, dfdx  Objective gradient, constraint values and
!!              constraint Jacobian. They are intent(inout) only because
!!              their host copies may be refreshed from the device below.
505 subroutine mma_update_vector(this, iter, x, df0dx, fval, dfdx)
506 class(mma_t),
intent(inout) :: this
507 integer,
intent(in) :: iter
508 type(vector_t),
intent(inout) :: x
509 type(vector_t),
intent(inout) :: df0dx, fval
510 type(matrix_t),
intent(inout) :: dfdx
513 select case (this%bcknd)
! Host subsolver on a device build: pull all inputs to the host first. Only
! the last transfer synchronises.
515 if (neko_bcknd_device .eq. 1)
then
516 call x%copy_from(device_to_host, sync = .false.)
517 call df0dx%copy_from(device_to_host, sync = .false.)
518 call fval%copy_from(device_to_host, sync = .false.)
519 call dfdx%copy_from(device_to_host, sync = .true.)
522 call mma_update_cpu(this, iter, x%x, df0dx%x, fval%x, dfdx%x)
! x is the only intent(inout) array of mma_update_cpu, so it is the only
! one pushed back to the device.
524 if (neko_bcknd_device .eq. 1)
then
525 call x%copy_from(host_to_device, sync = .true.)
529 call mma_update_device(this, iter, x%x_d, df0dx%x_d, fval%x_d, dfdx%x_d)
532 end subroutine mma_update_vector
!> Run the KKT check using the vector_t / matrix_t wrapper types.
!! Dispatches on this%bcknd:
!!  - host path: on device builds, first copies the inputs from the device,
!!    then calls mma_kkt_cpu;
!!  - device path: calls mma_kkt_device with the device pointers.
!! NOTE(review): this routine uses device_memcpy with explicit element counts,
!! whereas mma_update_vector uses the copy_from type-bound method. Consider
!! aligning the two. The sync arguments of the device_memcpy calls are on
!! continuation lines that are not visible in this excerpt.
535 subroutine mma_kkt_vector(this, x, df0dx, fval, dfdx)
536 class(mma_t),
intent(inout) :: this
537 type(vector_t),
intent(inout) :: x, df0dx, fval
538 type(matrix_t),
intent(inout) :: dfdx
541 select case (this%bcknd )
! Refresh the host copies before running the host-side check.
543 if (neko_bcknd_device .eq. 1)
then
544 call device_memcpy(x%x, x%x_d, this%n, device_to_host, &
546 call device_memcpy(df0dx%x, df0dx%x_d, this%n, device_to_host, &
548 call device_memcpy(fval%x, fval%x_d, this%m, device_to_host, &
550 call device_memcpy(dfdx%x, dfdx%x_d, this%m * this%n, device_to_host,&
554 call mma_kkt_cpu(this, x%x, df0dx%x, fval%x, dfdx%x)
556 call mma_kkt_device(this, x%x_d, df0dx%x_d, fval%x_d, dfdx%x_d)
558 end subroutine mma_kkt_vector
!> Save the MMA state to a checkpoint file, choosing the writer by suffix.
!! Only the HDF5 suffixes 'h5', 'hdf5' and 'hf5' are handled. The fallback
!! branch (its case line is not visible here) calls neko_error.
!! @param filename  Target file; its suffix selects the format.
!! @param overwrite Optional flag forwarded unchanged to the HDF5 writer,
!!                  whose dummy is also optional, so it may be absent.
!! NOTE(review): the suffix match is case-sensitive, and file_ext holds at
!! most 12 characters.
567 subroutine mma_save_checkpoint(this, filename, overwrite)
568 class(mma_t),
intent(inout) :: this
569 character(len=*),
intent(in) :: filename
570 logical,
intent(in),
optional :: overwrite
571 character(len=12) :: file_ext
574 call filename_suffix(filename, file_ext)
576 select case (trim(file_ext))
577 case (
'h5',
'hdf5',
'hf5')
578 call mma_save_checkpoint_hdf5(this, filename, overwrite)
580 call neko_error(
'mma_save_checkpoint: Unsupported file format: ' // &
584 end subroutine mma_save_checkpoint
!> Restore the MMA state from a checkpoint file, choosing the reader by suffix.
!! Only the HDF5 suffixes 'h5', 'hdf5' and 'hf5' are handled. The fallback
!! branch (its case line is not visible here) calls neko_error.
!! @param filename  Source file; its suffix selects the format.
!! NOTE(review): the suffix match is case-sensitive, and file_ext holds at
!! most 12 characters.
589 subroutine mma_load_checkpoint(this, filename)
590 class(mma_t),
intent(inout) :: this
591 character(len=*),
intent(in) :: filename
592 character(len=12) :: file_ext
595 call filename_suffix(filename, file_ext)
597 select case (trim(file_ext))
598 case (
'h5',
'hdf5',
'hf5')
599 call mma_load_checkpoint_hdf5(this, filename)
601 call neko_error(
'mma_load_checkpoint: Unsupported file format: ' // &
604 end subroutine mma_load_checkpoint
!> Return the stored variable count `this%n` of the MMA object.
!! This is the per-rank count; the total over neko_comm is kept separately
!! in `n_global`.
pure function mma_get_n(this) result(n)
  class(mma_t), intent(in) :: this
  integer :: n

  n = this%n
end function mma_get_n
!> Return the stored `m` dimension of the MMA problem.
!! `m` is the length of a, c and d.
pure function mma_get_m(this) result(m)
  class(mma_t), intent(in) :: this
  integer :: m

  m = this%m
end function mma_get_m
!> Read-only accessor for the `residumax` component.
!! It is set to huge(0.0_rp) when the object is initialised from components.
pure function mma_get_residumax(this) result(res)
  class(mma_t), intent(in) :: this
  real(kind=rp) :: res

  res = this%residumax
end function mma_get_residumax
!> Read-only accessor for the `residunorm` component.
!! It is set to huge(0.0_rp) when the object is initialised from components.
pure function mma_get_residunorm(this) result(res)
  class(mma_t), intent(in) :: this
  real(kind=rp) :: res

  res = this%residunorm
end function mma_get_residunorm
!> Read-only accessor for the configured iteration limit `max_iter`.
pure function mma_get_max_iter(this) result(iter_limit)
  class(mma_t), intent(in) :: this
  integer :: iter_limit

  iter_limit = this%max_iter
end function mma_get_max_iter
!> Build a descriptive string of the form 'backend:<b>, subsolver:<s>'.
!! The backend name is chosen from the compile-time neko_bcknd_* flags
!! (CUDA, HIP, OpenCL, ...), not from this%bcknd.
!! NOTE(review): the branch bodies that assign `backend`, and the tail of the
!! concatenation, are not visible in this excerpt. neko_bcknd_opencl is
!! presumably imported on the continuation of the neko_config use statement,
!! which is also not visible.
645 pure function mma_get_backend_and_subsolver(this)
result(backend_subsolver)
646 class(mma_t),
intent(in) :: this
647 character(len=:),
allocatable :: backend_subsolver
648 character(len=:),
allocatable :: backend
650 if (neko_bcknd_cuda .eq. 1)
then
652 else if (neko_bcknd_hip .eq. 1)
then
654 else if (neko_bcknd_opencl .eq. 1)
then
660 backend_subsolver =
'backend:' // trim(backend) //
', subsolver:' // &
662 end function mma_get_backend_and_subsolver
!> Transfer every array member of the MMA object between host and device.
!! All transfers except the last are issued with sync = .false.; the caller's
!! `sync` flag is forwarded only to the final transfer (eta).
!! @param direction Transfer direction forwarded to copy_from
!!                  (e.g. host_to_device / device_to_host).
!! @param sync      Forwarded as the `sync` argument of the last copy_from.
subroutine mma_copy_from(this, direction, sync)
  class(mma_t), intent(inout) :: this
  integer, intent(in) :: direction
  logical, intent(in) :: sync
  logical, parameter :: no_sync = .false.

  ! Iteration history and variable bounds
  call this%xold1%copy_from(direction, sync = no_sync)
  call this%xold2%copy_from(direction, sync = no_sync)
  call this%xmax%copy_from(direction, sync = no_sync)
  call this%xmin%copy_from(direction, sync = no_sync)

  ! Asymptotes
  call this%low%copy_from(direction, sync = no_sync)
  call this%upp%copy_from(direction, sync = no_sync)

  ! Problem constants and subproblem variables
  call this%a%copy_from(direction, sync = no_sync)
  call this%c%copy_from(direction, sync = no_sync)
  call this%d%copy_from(direction, sync = no_sync)
  call this%y%copy_from(direction, sync = no_sync)
  call this%s%copy_from(direction, sync = no_sync)

  ! Approximation coefficients
  call this%p0j%copy_from(direction, sync = no_sync)
  call this%q0j%copy_from(direction, sync = no_sync)
  call this%pij%copy_from(direction, sync = no_sync)
  call this%qij%copy_from(direction, sync = no_sync)
  call this%bi%copy_from(direction, sync = no_sync)

  ! Move limits and multipliers; only the last transfer uses the caller's flag
  call this%alpha%copy_from(direction, sync = no_sync)
  call this%beta%copy_from(direction, sync = no_sync)
  call this%lambda%copy_from(direction, sync = no_sync)
  call this%mu%copy_from(direction, sync = no_sync)
  call this%xsi%copy_from(direction, sync = no_sync)
  call this%eta%copy_from(direction, sync = sync)
end subroutine mma_copy_from
!> Fallback HDF5 writer used when HDF5 support is not compiled in.
!! It always aborts through neko_error; its arguments are unused.
module subroutine mma_save_checkpoint_hdf5(object, filename, overwrite)
  class(mma_t), intent(inout) :: object
  character(len=*), intent(in) :: filename
  logical, optional, intent(in) :: overwrite

  call neko_error('mma: HDF5 support not enabled rebuild with HAVE_HDF5')
end subroutine mma_save_checkpoint_hdf5
!> Fallback HDF5 reader used when HDF5 support is not compiled in.
!! It always aborts through neko_error; its arguments are unused.
module subroutine mma_load_checkpoint_hdf5(object, filename)
  class(mma_t), intent(inout) :: object
  character(len=*), intent(in) :: filename

  call neko_error('mma: HDF5 support not enabled rebuild with HAVE_HDF5')
end subroutine mma_load_checkpoint_hdf5
subroutine mma_init_from_json(this, x, n, m, json, scale, auto_scale)
Device KKT check for convergence.