read_batchnorm2d Subroutine

private subroutine read_batchnorm2d(this, unit, verbose)

Read 2D batch normalisation layer from file

Type Bound

batchnorm2d_layer_type

Arguments

Type Intent Optional Attributes Name
class(batchnorm2d_layer_type), intent(inout) :: this

Instance of the 2D batch normalisation layer

integer, intent(in) :: unit

File unit

integer, intent(in), optional :: verbose

Verbosity level


Source Code

  subroutine read_batchnorm2d(this, unit, verbose)
    !! Read 2D batch normalisation layer from file
    !!
    !! Parses the tags of a batchnorm2d layer card from `unit`, stopping at
    !! (without consuming) the line "END <LAYER NAME>". It then configures
    !! the hyperparameters and initialises the layer from INPUT_SHAPE. Next it
    !! repositions the file to read any inline GAMMA / BETA value blocks into
    !! the layer parameters. Finally it consumes the "END <LAYER NAME>" line.
    use athena__tools_infile, only: assign_val, assign_vec, move
    use coreutils, only: to_lower, to_upper, icount
    use athena__initialiser, only: initialiser_setup
    implicit none

    ! Arguments
    class(batchnorm2d_layer_type), intent(inout) :: this
    !! Instance of the 2D batch normalisation layer
    integer, intent(in) :: unit
    !! File unit
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    ! NOTE: none of these are initialised in their declarations. Doing so
    ! would give them the implicit SAVE attribute, so values (verbosity,
    ! momentum, initialiser names, ...) would leak between successive calls.
    integer :: stat, verbose_
    !! Status and verbosity level
    integer :: i, j, k, c, itmp1, iline, final_line
    !! Loop variables and temporary integer
    integer :: num_channels
    !! Number of channels
    real(real32) :: momentum, epsilon
    !! Momentum and epsilon
    class(base_init_type), allocatable :: gamma_initialiser, beta_initialiser
    !! Initialisers
    class(base_init_type), allocatable :: &
         moving_mean_initialiser, moving_variance_initialiser
    !! Moving mean and variance initialisers
    character(14) :: gamma_initialiser_name, beta_initialiser_name
    !! Initialisers
    character(14) :: &
         moving_mean_initialiser_name, &
         moving_variance_initialiser_name
    !! Moving mean and variance initialisers
    character(256) :: buffer, tag, err_msg
    !! Buffer, tag, and error message

    integer, dimension(3) :: input_shape
    !! Input shape
    logical :: found_input_shape
    !! Whether an INPUT_SHAPE tag was read
    real(real32), allocatable, dimension(:) :: data_list
    !! Data list
    integer, dimension(2) :: param_lines
    !! Lines where parameters are found (1 = GAMMA, 2 = BETA)


    ! Initialise local defaults (assigned here rather than in declarations)
    !---------------------------------------------------------------------------
    verbose_ = 0
    momentum = 0._real32
    epsilon = 1.E-5_real32
    gamma_initialiser_name = ''
    beta_initialiser_name = ''
    moving_mean_initialiser_name = ''
    moving_variance_initialiser_name = ''
    input_shape = 0
    found_input_shape = .false.


    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    if(present(verbose)) verbose_ = verbose


    ! Loop over tags in layer card
    ! iline counts every record read since entry, so that the relative
    ! `move` calls below can reposition the file exactly.
    !---------------------------------------------------------------------------
    iline = 0
    param_lines = 0
    final_line = 0
    tag_loop: do

       ! Check for end of file
       !------------------------------------------------------------------------
       read(unit,'(A)',iostat=stat) buffer
       if(stat.ne.0)then
          write(err_msg,'("file encountered error (EoF?) before END ",A)') &
               to_upper(this%name)
          call stop_program(err_msg)
          return
       end if
       ! Blank lines are still records; count them so the line offsets used
       ! for repositioning stay in step with the file position
       if(trim(adjustl(buffer)).eq."")then
          iline = iline + 1
          cycle tag_loop
       end if

       ! Check for end of layer card
       ! (step back so the END line is re-read at the very end)
       !------------------------------------------------------------------------
       if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then
          final_line = iline
          backspace(unit)
          exit tag_loop
       end if
       iline = iline + 1

       tag=trim(adjustl(buffer))
       if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))

       ! Read parameters from save file
       !------------------------------------------------------------------------
       select case(trim(tag))
       case("INPUT_SHAPE")
          call assign_vec(buffer, input_shape, itmp1)
          found_input_shape = .true.
       case("MOMENTUM")
          call assign_val(buffer, momentum, itmp1)
       case("EPSILON")
          call assign_val(buffer, epsilon, itmp1)
       case("NUM_CHANNELS")
          call assign_val(buffer, num_channels, itmp1)
          write(0,*) "NUM_CHANNELS and INPUT_SHAPE are conflicting parameters"
          write(0,*) "NUM_CHANNELS will be ignored"
       case("GAMMA_INITIALISER", "KERNEL_INITIALISER")
          if(param_lines(1).ne.0)then
             write(err_msg,'("GAMMA and GAMMA_INITIALISER defined. Using GAMMA only.")')
             call print_warning(err_msg)
          end if
          call assign_val(buffer, gamma_initialiser_name, itmp1)
       case("BETA_INITIALISER", "BIAS_INITIALISER")
          if(param_lines(2).ne.0)then
             write(err_msg,'("BETA and BETA_INITIALISER defined. Using BETA only.")')
             call print_warning(err_msg)
          end if
          call assign_val(buffer, beta_initialiser_name, itmp1)
       case("MOVING_MEAN_INITIALISER")
          call assign_val(buffer, moving_mean_initialiser_name, itmp1)
       case("MOVING_VARIANCE_INITIALISER")
          call assign_val(buffer, moving_variance_initialiser_name, itmp1)
       case("GAMMA")
          ! Explicit values follow; initialise with zeros, overwrite below
          gamma_initialiser_name = 'zeros'
          param_lines(1) = iline
       case("BETA")
          beta_initialiser_name   = 'zeros'
          param_lines(2) = iline
       case default
          ! Don't look for "e" due to scientific notation of numbers
          ! ... i.e. exponent (E+00)
          if(scan(to_lower(trim(adjustl(buffer))),&
               'abcdfghijklmnopqrstuvwxyz').eq.0)then
             cycle tag_loop
          elseif(tag(:3).eq.'END')then
             cycle tag_loop
          end if
          write(err_msg,'("Unrecognised line in input file: ",A)') &
               trim(adjustl(buffer))
          call stop_program(err_msg)
          return
       end select
    end do tag_loop

    ! INPUT_SHAPE is required: it determines the number of channels
    !---------------------------------------------------------------------------
    if(.not.found_input_shape)then
       write(err_msg,'("INPUT_SHAPE not found in ",A," layer card")') &
            to_upper(trim(this%name))
       call stop_program(err_msg)
       return
    end if

    gamma_initialiser = initialiser_setup(gamma_initialiser_name)
    beta_initialiser = initialiser_setup(beta_initialiser_name)
    moving_mean_initialiser = initialiser_setup(moving_mean_initialiser_name)
    moving_variance_initialiser = initialiser_setup(moving_variance_initialiser_name)


    ! Set hyperparameters and initialise layer
    ! (channels are taken from the last INPUT_SHAPE entry)
    !---------------------------------------------------------------------------
    num_channels = input_shape(size(input_shape,1))
    call this%set_hyperparams( &
         momentum = momentum, &
         epsilon = epsilon, &
         gamma_init_mean = this%gamma_init_mean, &
         gamma_init_std = this%gamma_init_std, &
         beta_init_mean = this%beta_init_mean, &
         beta_init_std = this%beta_init_std, &
         gamma_initialiser = gamma_initialiser, &
         beta_initialiser = beta_initialiser, &
         moving_mean_initialiser = moving_mean_initialiser, &
         moving_variance_initialiser = moving_variance_initialiser, &
         verbose = verbose_ &
    )
    call this%init(input_shape = input_shape)


    ! Read explicit GAMMA / BETA values, if present
    !---------------------------------------------------------------------------
    allocate(data_list(num_channels), source=0._real32)
    do i = 2, 1, -1
       if(param_lines(i).eq.0) cycle
       ! Reposition to the record just after the GAMMA/BETA tag line
       call move(unit, param_lines(i) - iline, iostat=stat)
       ! The extra +1 accounts for the "END GAMMA/BETA" record read after
       ! the data loop, which is not counted inside the loop itself
       iline = param_lines(i) + 1
       c = 1
       k = 1
       data_list = 0._real32
       data_concat_loop: do while(c.le.num_channels)
          iline = iline + 1
          read(unit,'(A)',iostat=stat) buffer
          if(stat.ne.0) exit data_concat_loop
          ! Never write past the end of data_list, even if a line holds more
          ! values than the number of channels still to be filled
          k = min(icount(buffer), num_channels - c + 1)
          read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1)
          c = c + k
       end do data_concat_loop
       read(unit,'(A)',iostat=stat) buffer
       select case(i)
       case(1) ! gamma
          this%params(1)%val(1:this%num_channels,1) = data_list
          if(trim(adjustl(buffer)).ne."END GAMMA")then
             write(err_msg,'("END GAMMA not where expected: ",A)') &
                  trim(adjustl(buffer))
             call stop_program(err_msg)
             return
          end if
       case(2) ! beta
          this%params(1)%val(this%num_channels+1:this%num_channels*2,1) = &
               data_list
          if(trim(adjustl(buffer)).ne."END BETA")then
             write(err_msg,'("END BETA not where expected: ",A)') &
                  trim(adjustl(buffer))
             call stop_program(err_msg)
             return
          end if
       end select
    end do
    deallocate(data_list)


    ! Check for end of layer card
    !---------------------------------------------------------------------------
    call move(unit, final_line - iline, iostat=stat)
    read(unit,'(A)',iostat=stat) buffer
    if(stat.ne.0)then
       write(err_msg,'("file encountered error (EoF?) before END ",A)') &
            to_upper(this%name)
       call stop_program(err_msg)
       return
    end if
    if(trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then
       write(0,*) trim(adjustl(buffer))
       write(err_msg,'("END ",A," not where expected")') to_upper(this%name)
       call stop_program(err_msg)
       return
    end if

  end subroutine read_batchnorm2d