35module simulation_checkpoint
36 use num_types,
only: rp, sp, dp
37 use case,
only: case_t
38 use json_file_module,
only: json_file
39 use json_utils,
only: json_get, json_get_or_default
40 use scalar_scheme,
only: scalar_scheme_t
41 use time_state,
only: time_state_t
42 use chkp_output,
only: chkp_output_t
43 use field,
only: field_t
44 use field_list,
only: field_list_t
45 use logger,
only: neko_log, log_size, neko_log_debug
46 use mpi_f08,
only: mpi_wtime, mpi_barrier
47 use comm,
only: neko_comm, pe_rank
48 use utils,
only: neko_error
49 use math,
only: copy, rzero
50 use profiler,
only: profiler_start_region, profiler_end_region
51 use neko_config,
only: neko_bcknd_device
52 use device,
only: device_memcpy, device_to_host, host_to_device
53 use registry,
only: neko_registry
!> Checkpoint manager for a simulation run.
!!
!! Bundles the user configuration, the list of fields that make up one
!! checkpoint, the host-memory copies of those fields and the writer used
!! for checkpoint files.
type, public :: simulation_checkpoint_t

  ! ---- Configuration ---------------------------------------------------

  !> Set to .true. by init_from_components. While .false., save, restore
  !! and reset return immediately.
  logical :: enabled = .false.
  !> Name of the checkpointing algorithm; save/restore dispatch on it.
  character(len=256) :: algorithm = "linear"
  !> Base file name handed to the checkpoint file writer.
  character(len=256) :: filename = "forward_checkpoint"
  !> Directory handed to the checkpoint file writer.
  character(len=256) :: path = "checkpoints/"
  !> File format handed to the checkpoint file writer.
  character(len=8) :: fmt = "chkp"
  !> Number of checkpoint slots kept in host memory (first extent of
  !! state_storage).
  integer :: n_saves_memory = 10
  !> When .false., free() deletes the checkpoint files it can find.
  logical :: keep_checkpoints = .false.

  ! ---- Bookkeeping -----------------------------------------------------

  !> NOTE(review): presumably the number of checkpoints written to disc;
  !! only zeroed in this file, maintained elsewhere - confirm.
  integer :: n_saves_disc = 0
  !> Number of save() calls performed so far.
  integer :: n_timesteps = 0
  !> NOTE(review): not read in this file; meaning of the default 2 should
  !! be confirmed against the algorithm implementation.
  integer :: first_valid_timestep = 2
  !> NOTE(review): presumably the index of the checkpoint currently
  !! loaded, with -1 meaning "none" - confirm.
  integer :: loaded_checkpoint = -1

  ! ---- State -----------------------------------------------------------

  !> Fields that are stored in every checkpoint.
  type(field_list_t) :: state_list
  !> Host copies of the fields, indexed as (memory slot, field).
  type(host_array), dimension(:,:), allocatable :: state_storage

  !> Writer used for checkpoint files.
  type(chkp_output_t) :: chkp_output

contains

  !> Initialise either from a JSON object or from explicit components.
  generic, public :: init => init_from_json, init_from_components
  !> Initialise from a JSON object.
  procedure, public, pass(this) :: init_from_json => &
       checkpoint_init_from_json
  !> Initialise from explicit components.
  procedure, public, pass(this) :: init_from_components => &
       checkpoint_init_from_components
  !> Release storage and restore the defaults.
  procedure, public, pass(this) :: free => checkpoint_free
  !> Clear the stored data without releasing it.
  procedure, public, pass(this) :: reset => checkpoint_reset
  !> Record a checkpoint of the current state.
  procedure, public, pass(this) :: save => checkpoint_save
  !> Restore the state of a given time step.
  procedure, public, pass(this) :: restore => checkpoint_restore

  !> Copy the current fields into a memory slot.
  procedure, pass(this) :: save_data => checkpoint_save_data
  !> Copy a memory slot back into the fields.
  procedure, pass(this) :: load_data => checkpoint_load_data

end type simulation_checkpoint_t
!> Host-side buffer holding one stored copy of a field's values.
type :: host_array
  !> The stored values.
  real(kind=rp), allocatable :: data(:)
  !> Element count used together with `data`; it is read by
  !! host_array_init and checkpoint_reset, so it must be a component.
  integer :: size = 0
contains
  !> Allocate and zero the buffer.
  procedure, pass(this) :: init => host_array_init
  !> Release the buffer.
  procedure, pass(this) :: free => host_array_free
  !> Query whether the buffer is allocated.
  procedure, pass(this) :: is_allocated => host_array_is_allocated
end type host_array
!> Interfaces of the "linear" checkpointing routines. Only the interfaces
!! are declared here; the bodies are separate module procedures.
interface

  !> Save step of the "linear" algorithm (called from checkpoint_save).
  !! @param this The checkpoint object.
  !! @param neko_case The simulation case being checkpointed.
  module subroutine checkpoint_save_linear(this, neko_case)
    class(simulation_checkpoint_t), intent(inout) :: this
    class(case_t), intent(inout) :: neko_case
  end subroutine checkpoint_save_linear

  !> Restore step of the "linear" algorithm (called from
  !! checkpoint_restore).
  !! @param this The checkpoint object.
  !! @param neko_case The simulation case to restore into.
  !! @param tstep The time step whose state is requested.
  module subroutine checkpoint_restore_linear(this, neko_case, tstep)
    class(simulation_checkpoint_t), intent(inout) :: this
    class(case_t), target, intent(inout) :: neko_case
    integer, intent(in) :: tstep
  end subroutine checkpoint_restore_linear

end interface
!> Initialise the checkpoint object from a JSON object.
!!
!! Reads "enabled" first and returns without touching `this` when it is
!! false or absent. Otherwise reads the remaining settings (with
!! defaults), resolves the optional "extra_fields" names through the
!! field registry and forwards everything to init_from_components.
!! @param this The checkpoint object to initialise.
!! @param neko_case The simulation case whose fields are checkpointed.
!! @param params JSON object holding the checkpoint settings.
subroutine checkpoint_init_from_json(this, neko_case, params)
  class(simulation_checkpoint_t), intent(inout) :: this
  class(case_t), target, intent(inout) :: neko_case
  type(json_file), target, intent(inout) :: params

  logical :: enabled, keep_checkpoints
  integer :: n_saves_memory
  character(len=:), allocatable :: path, filename, algorithm, fmt
  character(len=256), dimension(:), allocatable :: extra_field_names
  type(field_list_t) :: extra_fields
  type(field_t), pointer :: fi
  integer :: k, n_extra

  ! Nothing to set up unless checkpointing was requested.
  call json_get_or_default(params, "enabled", enabled, .false.)
  if (.not. enabled) return

  ! Scalar settings, each with its fallback value.
  call json_get_or_default(params, "algorithm", algorithm, "linear")
  call json_get_or_default(params, "n_memory", n_saves_memory, 10)
  call json_get_or_default(params, "path", path, "checkpoints/")
  call json_get_or_default(params, "filename", filename, "checkpoint")
  call json_get_or_default(params, "format", fmt, "chkp")
  call json_get_or_default(params, "keep_checkpoints", keep_checkpoints, &
       .false.)

  ! Without extra fields the optional argument is simply omitted.
  if (.not. ("extra_fields" .in. params)) then
    call this%init_from_components(neko_case, algorithm = algorithm, &
         n_saves_memory = n_saves_memory, path = path, &
         filename = filename, fmt = fmt, &
         keep_checkpoints = keep_checkpoints)
    return
  end if

  ! Look every requested name up in the registry and collect the fields.
  allocate(extra_field_names(0))
  call json_get(params, "extra_fields", extra_field_names)
  n_extra = size(extra_field_names)
  call extra_fields%init(n_extra)
  do k = 1, n_extra
    fi => neko_registry%get_field(extra_field_names(k))
    call extra_fields%assign(k, fi)
  end do

  call this%init_from_components(neko_case, algorithm = algorithm, &
       n_saves_memory = n_saves_memory, path = path, &
       filename = filename, fmt = fmt, &
       keep_checkpoints = keep_checkpoints, extra_fields = extra_fields)

end subroutine checkpoint_init_from_json
!> Initialise the checkpoint object from explicit settings.
!!
!! Enables checkpointing, stores every setting that is present (absent
!! ones keep their current value), makes sure the checkpoint directory
!! exists, initialises the file writer, builds the list of checkpointed
!! fields (p, u, v, w, then all scalars, then the extra fields),
!! allocates the in-memory slot table and logs the configuration.
!!
!! @note This routine calls mpi_barrier on neko_comm, so it must be
!! called by every rank.
!! @param this The checkpoint object to initialise.
!! @param neko_case The simulation case whose fields are checkpointed.
!! @param algorithm Name of the checkpointing algorithm.
!! @param n_saves_memory Number of checkpoint slots kept in host memory.
!! @param path Directory for checkpoint files.
!! @param filename Base name for checkpoint files.
!! @param fmt Checkpoint file format.
!! @param keep_checkpoints Keep checkpoint files when the object is freed.
!! @param extra_fields Additional fields to store in every checkpoint.
subroutine checkpoint_init_from_components(this, neko_case, algorithm, &
     n_saves_memory, path, filename, fmt, keep_checkpoints, extra_fields)
  class(simulation_checkpoint_t), intent(inout), target :: this
  class(case_t), target, intent(inout) :: neko_case
  character(len=*), optional, intent(in) :: algorithm
  integer, optional, intent(in) :: n_saves_memory
  character(len=*), optional, intent(in) :: path
  character(len=*), optional, intent(in) :: filename
  character(len=*), optional, intent(in) :: fmt
  logical, optional, intent(in) :: keep_checkpoints
  type(field_list_t), optional, intent(inout) :: extra_fields

  !> Number of fluid fields that are always stored: p, u, v, w.
  integer, parameter :: n_fluid_states = 4

  type(field_t), pointer :: si
  character(len=LOG_SIZE) :: msg
  integer :: i, n_states, n_scalars, n_extra, mkdir_stat
  logical :: exists

  ! ---- Settings --------------------------------------------------------
  this%enabled = .true.
  if (present(algorithm)) this%algorithm = algorithm
  if (present(n_saves_memory)) this%n_saves_memory = n_saves_memory
  if (present(path)) this%path = trim(path)
  if (present(filename)) this%filename = trim(filename)
  if (present(fmt)) this%fmt = trim(fmt)
  if (present(keep_checkpoints)) this%keep_checkpoints = keep_checkpoints

  ! ---- Checkpoint directory --------------------------------------------
  ! Only rank 0 inspects and creates the directory; every rank then meets
  ! at a single unconditional barrier. Putting the barrier behind a
  ! per-rank `inquire` result would mismatch the collective whenever the
  ! ranks see different file-system states.
  if (pe_rank .eq. 0 .and. len_trim(this%path) .gt. 0) then
    inquire(file = trim(this%path), exist = exists)
    if (.not. exists) then
      mkdir_stat = 0
      call execute_command_line("mkdir -p '" // trim(this%path) // "'", &
           exitstat = mkdir_stat)
      if (mkdir_stat .ne. 0) then
        call neko_error("Failed to create checkpoint directory: " // &
             trim(this%path))
      end if
    end if
  end if
  call mpi_barrier(neko_comm)

  ! ---- Checkpoint file writer ------------------------------------------
  call this%chkp_output%init(neko_case%chkp, this%filename, &
       fmt = this%fmt, path = this%path, overwrite = .true.)

  ! ---- List of checkpointed fields -------------------------------------
  n_scalars = 0
  if (allocated(neko_case%scalars)) then
    n_scalars = size(neko_case%scalars%scalar_fields)
  end if

  n_extra = 0
  if (present(extra_fields)) n_extra = extra_fields%size()

  n_states = n_fluid_states + n_scalars + n_extra
  call this%state_list%init(n_states)

  ! Fluid fields occupy the first four entries.
  call this%state_list%assign(1, neko_case%fluid%p)
  call this%state_list%assign(2, neko_case%fluid%u)
  call this%state_list%assign(3, neko_case%fluid%v)
  call this%state_list%assign(4, neko_case%fluid%w)

  ! Scalars follow the fluid fields (loop is empty without scalars).
  do i = 1, n_scalars
    si => neko_case%scalars%scalar_fields(i)%scalar%s
    call this%state_list%assign(n_fluid_states + i, si)
  end do

  ! Extra fields come last; n_extra is 0 when the argument is absent, so
  ! extra_fields is never referenced in that case.
  do i = 1, n_extra
    si => extra_fields%get_by_index(i)
    call this%state_list%assign(n_fluid_states + n_scalars + i, si)
  end do

  ! ---- In-memory slots -------------------------------------------------
  ! Only the slot table is allocated here; each buffer is allocated on
  ! first use by save_data.
  allocate(this%state_storage(this%n_saves_memory, this%state_list%size()))

  ! ---- Log the configuration -------------------------------------------
  ! String settings can be up to 256 characters long, so they are passed
  ! by concatenation instead of being written into the LOG_SIZE buffer,
  ! where an over-long record would be a run-time I/O error.
  call neko_log%section("Checkpointing")

  call neko_log%message("Algorithm: " // trim(this%algorithm))
  write(msg, '(A,I0)') "Number of checkpoints in RAM: ", this%n_saves_memory
  call neko_log%message(trim(msg))
  call neko_log%message("Checkpoint file path: " // trim(this%path))
  call neko_log%message("Checkpoint file name: " // trim(this%filename))
  call neko_log%message("Checkpoint file format: " // trim(this%fmt))

  if (.not. this%keep_checkpoints) then
    call neko_log%message("Checkpoint files will be deleted.")
  else
    call neko_log%message("Checkpoint files will be kept.")
  end if

  call neko_log%message("Fields in checkpoint:", neko_log_debug)
  do i = 1, this%state_list%size()
    si => this%state_list%get(i)
    call neko_log%message(" - " // trim(si%name), neko_log_debug)
  end do

  call neko_log%end_section()

end subroutine checkpoint_init_from_components
!> Release all storage and return the object to its declared defaults.
!!
!! Frees the in-memory slots and the field list. Unless keep_checkpoints
!! is set, rank 0 then deletes the checkpoint files with counters
!! n_timesteps down to 1 that exist on disc. All ranks synchronise on
!! neko_comm before the settings are reset.
!!
!! @note This routine calls mpi_barrier on neko_comm, so it must be
!! called by every rank.
!! @param this The checkpoint object to free.
subroutine checkpoint_free(this)
  class(simulation_checkpoint_t), intent(inout) :: this

  character(len=1024) :: file_name
  integer :: i, j, stat, unit
  logical :: exists

  ! ---- In-memory storage -----------------------------------------------
  ! Loop over the actual extents of the table rather than over
  ! n_saves_memory / the list size, which may no longer match it.
  if (allocated(this%state_storage)) then
    do j = 1, size(this%state_storage, 2)
      do i = 1, size(this%state_storage, 1)
        call this%state_storage(i, j)%free()
      end do
    end do
    deallocate(this%state_storage)
  end if

  call this%state_list%free()

  ! ---- Checkpoint files ------------------------------------------------
  if (.not. this%keep_checkpoints .and. pe_rank .eq. 0) then
    do i = this%n_timesteps, 1, -1
      call this%chkp_output%set_counter(i)
      file_name = this%chkp_output%file_%get_fname()
      inquire(file = trim(file_name), exist = exists)
      if (exists) then
        ! Opening and closing with status='delete' removes the file
        ! without going through the shell.
        open(newunit = unit, file = trim(file_name), iostat = stat, &
             status = 'old')
        if (stat .eq. 0) close(unit, status = 'delete')
      end if
    end do
  end if

  call mpi_barrier(neko_comm)

  ! ---- Restore the defaults declared in simulation_checkpoint_t --------
  ! Every value below mirrors the component initialiser, so a freed object
  ! behaves exactly like a freshly declared one.
  this%enabled = .false.
  this%algorithm = "linear"
  this%filename = "forward_checkpoint"
  this%path = "checkpoints/"
  this%fmt = "chkp"
  this%n_saves_memory = 10
  this%keep_checkpoints = .false.

  this%n_saves_disc = 0
  this%n_timesteps = 0
  this%first_valid_timestep = 2
  this%loaded_checkpoint = -1

end subroutine checkpoint_free
!> Record a checkpoint of the current simulation state.
!!
!! Does nothing while checkpointing is disabled. Otherwise the time step
!! counter is advanced and the save routine of the configured algorithm
!! is called; an unknown algorithm name is a fatal error.
!! @param this The checkpoint object.
!! @param neko_case The simulation case being checkpointed.
subroutine checkpoint_save(this, neko_case)
  class(simulation_checkpoint_t), intent(inout) :: this
  class(case_t), intent(inout) :: neko_case

  if (.not. this%enabled) return

  call profiler_start_region("Checkpoint save")

  ! One more time step has been offered for checkpointing.
  this%n_timesteps = this%n_timesteps + 1

  select case (trim(this%algorithm))
  case ("linear")
    call checkpoint_save_linear(this, neko_case)
  case default
    ! Trim the name: the component is blank-padded to 256 characters.
    call neko_error("Unknown checkpoint algorithm: " // &
         trim(this%algorithm))
  end select

  call profiler_end_region("Checkpoint save")
end subroutine checkpoint_save
!> Restore the simulation state belonging to a given time step.
!!
!! Does nothing while checkpointing is disabled. The requested step must
!! lie in [1, n_timesteps]; anything else, or an unknown algorithm name,
!! is a fatal error.
!! @param this The checkpoint object.
!! @param neko_case The simulation case to restore into.
!! @param tstep The time step whose state is requested.
subroutine checkpoint_restore(this, neko_case, tstep)
  class(simulation_checkpoint_t), intent(inout) :: this
  class(case_t), target, intent(inout) :: neko_case
  integer, intent(in) :: tstep
  character(len=256) :: msg

  if (.not. this%enabled) return

  call profiler_start_region("Checkpoint restore")

  ! Only steps that were offered to save() can be restored.
  if (tstep .lt. 1 .or. tstep .gt. this%n_timesteps) then
    write(msg, '(A,I0,A,I0,A)') "Requested timestep ", tstep, &
         " is out of range [1, ", this%n_timesteps, "]"
    call neko_error(trim(msg))
  end if

  select case (trim(this%algorithm))
  case ("linear")
    call checkpoint_restore_linear(this, neko_case, tstep)
  case default
    ! Trim the name: the component is blank-padded to 256 characters.
    call neko_error("Unknown checkpoint algorithm: " // &
         trim(this%algorithm))
  end select

  call profiler_end_region("Checkpoint restore")
end subroutine checkpoint_restore
!> Copy the current values of all checkpointed fields into a memory slot.
!!
!! Buffers of the slot are allocated on first use. On the host backend
!! (neko_bcknd_device == 0) the values are copied from the field's host
!! array; otherwise they are transferred from the device array, with the
!! synchronisation flag set only on the last field.
!! @param this The checkpoint object.
!! @param index The memory slot to write, in [1, n_saves_memory].
subroutine checkpoint_save_data(this, index)
  class(simulation_checkpoint_t), intent(inout) :: this
  integer, intent(in) :: index
  type(field_t), pointer :: fld
  character(len=1024) :: msg
  integer :: k, n_fields

  if (index .lt. 1 .or. index .gt. this%n_saves_memory) then
    write(msg, '(A,I0,A,I0,A)') "Checkpoint save index ", index, &
         " is out of range [1, ", this%n_saves_memory, "]"
    call neko_error(trim(msg))
  end if

  n_fields = this%state_list%size()

  ! First pass: make sure every buffer of this slot exists.
  do k = 1, n_fields
    if (this%state_storage(index, k)%is_allocated()) cycle
    fld => this%state_list%get(k)
    call this%state_storage(index, k)%init(fld%size())
  end do

  ! Second pass: fill the buffers.
  if (neko_bcknd_device .ne. 0) then
    do k = 1, n_fields
      fld => this%state_list%get(k)
      call device_memcpy(this%state_storage(index, k)%data, fld%x_d, &
           fld%size(), device_to_host, k .eq. n_fields)
    end do
  else
    do k = 1, n_fields
      fld => this%state_list%get(k)
      call copy(this%state_storage(index, k)%data, fld%x, fld%size())
    end do
  end if

end subroutine checkpoint_save_data
!> Copy the contents of a memory slot back into the checkpointed fields.
!!
!! On the host backend (neko_bcknd_device == 0) the values are copied
!! into the field's host array; otherwise they are transferred to the
!! device array, with the synchronisation flag set only on the last
!! field. Loading a slot that was never written is a fatal error.
!! @param this The checkpoint object.
!! @param index The memory slot to read, in [1, n_saves_memory].
subroutine checkpoint_load_data(this, index)
  class(simulation_checkpoint_t), intent(inout) :: this
  integer, intent(in) :: index
  type(field_t), pointer :: si
  character(len=1024) :: msg
  integer :: i, n_fields

  if (index .lt. 1 .or. index .gt. this%n_saves_memory) then
    write(msg, '(A,I0,A,I0,A)') "Checkpoint save index ", index, &
         " is out of range [1, ", this%n_saves_memory, "]"
    call neko_error(trim(msg))
  end if

  n_fields = this%state_list%size()

  ! Buffers are allocated lazily by save_data, so a slot that was never
  ! saved has unallocated data; referencing it below would be invalid.
  do i = 1, n_fields
    if (.not. this%state_storage(index, i)%is_allocated()) then
      write(msg, '(A,I0,A)') "Checkpoint slot ", index, &
           " has not been saved and cannot be loaded"
      call neko_error(trim(msg))
    end if
  end do

  if (neko_bcknd_device .eq. 0) then
    do i = 1, n_fields
      si => this%state_list%get(i)
      call copy(si%x, this%state_storage(index, i)%data, si%size())
    end do
  else
    do i = 1, n_fields
      si => this%state_list%get(i)
      call device_memcpy(this%state_storage(index, i)%data, si%x_d, &
           si%size(), host_to_device, i .eq. n_fields)
    end do
  end if

end subroutine checkpoint_load_data
!> Clear the stored checkpoints without releasing any memory.
!!
!! Does nothing while checkpointing is disabled. Resets loaded_checkpoint
!! and n_saves_disc and zeroes every buffer that has been allocated.
!! NOTE(review): n_timesteps is not modified here - confirm that this is
!! intended.
!! @param this The checkpoint object.
subroutine checkpoint_reset(this)
  class(simulation_checkpoint_t), intent(inout) :: this
  integer :: i, j

  if (.not. this%enabled) return

  this%loaded_checkpoint = -1
  this%n_saves_disc = 0

  if (.not. allocated(this%state_storage)) return

  ! Buffers are allocated on first save, so some slots may still be
  ! unallocated; those must not be handed to rzero.
  do j = 1, size(this%state_storage, 2)
    do i = 1, size(this%state_storage, 1)
      if (this%state_storage(i, j)%is_allocated()) then
        call rzero(this%state_storage(i, j)%data, &
             this%state_storage(i, j)%size)
      end if
    end do
  end do

end subroutine checkpoint_reset
!> Allocate the buffer with the given number of entries and zero it.
!!
!! An existing allocation is released first, so the routine can be called
!! again with a different size.
!! @param this The buffer to initialise.
!! @param size Number of entries to allocate.
subroutine host_array_init(this, size)
  class(host_array), intent(inout) :: this
  integer, intent(in) :: size

  ! Allocating an already allocated array is an error, so drop it first.
  if (allocated(this%data)) deallocate(this%data)

  ! Record the length before it is used for zeroing below.
  this%size = size
  allocate(this%data(size))
  call rzero(this%data, this%size)

end subroutine host_array_init
!> Release the buffer; safe to call when nothing is allocated.
!! @param this The buffer to free.
subroutine host_array_free(this)
  class(host_array), intent(inout) :: this

  if (allocated(this%data)) deallocate(this%data)
  ! Keep the recorded length in step with the (now absent) allocation.
  this%size = 0

end subroutine host_array_free
!> Report whether the buffer currently holds an allocation.
!! @param this The buffer to query.
!! @return .true. when `data` is allocated, .false. otherwise.
pure function host_array_is_allocated(this) result(is_alloc)
  class(host_array), intent(in) :: this
  logical :: is_alloc

  is_alloc = .false.
  if (allocated(this%data)) is_alloc = .true.

end function host_array_is_allocated
526end module simulation_checkpoint