Read 1D batch normalisation layer from file
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| class(batchnorm1d_layer_type) | intent(inout) | | | this | Instance of the 1D batch normalisation layer |
| integer | intent(in) | | | unit | File unit |
| integer | intent(in) | optional | | verbose | Verbosity level |
subroutine read_batchnorm1d(this, unit, verbose)
  !! Read 1D batch normalisation layer from file
  !!
  !! Reads the tags of a batchnorm1d layer card, starting at the current
  !! position of `unit` and stopping at the matching `END <NAME>` line
  !! (`<NAME>` being `to_upper(this%name)`). The hyperparameters and
  !! initialisers found are passed to `this%set_hyperparams`, the layer is
  !! initialised from INPUT_SHAPE, and any GAMMA / BETA value blocks are then
  !! re-read and copied into `this%params(1)%val`. On successful return the
  !! `END <NAME>` line has been consumed.
  !!
  !! Errors (premature end of file, unrecognised lines, missing INPUT_SHAPE,
  !! malformed GAMMA/BETA blocks, misplaced END lines) are reported through
  !! `stop_program`.
  use athena__tools_infile, only: assign_val, assign_vec, move
  use coreutils, only: to_lower, to_upper, icount
  use athena__initialiser, only: initialiser_setup
  implicit none

  ! Arguments
  class(batchnorm1d_layer_type), intent(inout) :: this
  !! Instance of the 1D batch normalisation layer
  integer, intent(in) :: unit
  !! File unit
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  ! NOTE: none of these are initialised in their declarations. An initialiser
  ! in a declaration implies SAVE, which would make values read for one layer
  ! (e.g. MOMENTUM, or the 'zeros' initialiser set by a GAMMA tag) leak into
  ! later calls. They are reset explicitly at the start of every call.
  character(len=5), dimension(2), parameter :: block_names = &
       [character(len=5) :: "GAMMA", "BETA"]
  !! Tags of the parameter value blocks, indexed like param_lines
  integer :: stat, verbose_
  !! File status and verbosity level
  integer :: i, j, k, c, itmp1, iline, final_line
  !! Temporary integers and loop indices
  integer :: num_channels
  !! Number of channels
  logical :: found_input_shape
  !! Whether an INPUT_SHAPE tag was present in the card
  real(real32) :: momentum, epsilon
  !! Momentum and epsilon
  class(base_init_type), allocatable :: gamma_initialiser, beta_initialiser
  !! Initialisers
  class(base_init_type), allocatable :: &
       moving_mean_initialiser, moving_variance_initialiser
  !! Moving mean and variance initialisers
  character(14) :: gamma_initialiser_name, beta_initialiser_name
  !! Initialisers
  character(14) :: &
       moving_mean_initialiser_name, &
       moving_variance_initialiser_name
  !! Moving mean and variance initialisers
  character(256) :: buffer, tag, err_msg
  !! Buffer, tag, and error message
  integer, dimension(2) :: input_shape
  !! Input shape
  real(real32), allocatable, dimension(:) :: data_list
  !! Data list
  integer, dimension(2) :: param_lines
  !! Line numbers (relative to the start of the card) of the GAMMA and BETA
  !! tags; 0 if the tag is absent

  ! Initialise locals and optional arguments
  !---------------------------------------------------------------------------
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose
  momentum = 0._real32
  epsilon = 1.E-5_real32
  gamma_initialiser_name = ''
  beta_initialiser_name = ''
  moving_mean_initialiser_name = ''
  moving_variance_initialiser_name = ''
  input_shape = 0
  found_input_shape = .false.
  num_channels = 0

  ! Loop over tags in layer card
  !---------------------------------------------------------------------------
  ! iline counts every physical line consumed (blank lines included) so the
  ! relative offsets later passed to move() match the layout of the file.
  iline = 0
  param_lines = 0
  final_line = 0
  tag_loop: do

     ! Check for end of file
     !------------------------------------------------------------------------
     read(unit,'(A)',iostat=stat) buffer
     if(stat.ne.0)then
        write(err_msg,'("file encountered error (EoF?) before END ",A)') &
             to_upper(this%name)
        call stop_program(err_msg)
        return
     end if

     ! Check for end of layer card
     !------------------------------------------------------------------------
     if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then
        final_line = iline
        ! Leave the END line unread; it is verified once more at the end
        backspace(unit)
        exit tag_loop
     end if

     iline = iline + 1
     if(trim(adjustl(buffer)).eq."") cycle tag_loop

     tag = trim(adjustl(buffer))
     if(scan(buffer,"=").ne.0) tag = trim(tag(:scan(tag,"=")-1))

     ! Read parameters from save file
     !------------------------------------------------------------------------
     select case(trim(tag))
     case("INPUT_SHAPE")
        call assign_vec(buffer, input_shape, itmp1)
        found_input_shape = .true.
     case("MOMENTUM")
        call assign_val(buffer, momentum, itmp1)
     case("EPSILON")
        call assign_val(buffer, epsilon, itmp1)
     case("NUM_CHANNELS")
        ! Parsed only so the line is accepted; the channel count is always
        ! derived from INPUT_SHAPE below.
        call assign_val(buffer, num_channels, itmp1)
        err_msg = "NUM_CHANNELS and INPUT_SHAPE are conflicting parameters; " &
             // "NUM_CHANNELS will be ignored"
        call print_warning(err_msg)
     case("GAMMA_INITIALISER", "KERNEL_INITIALISER")
        if(param_lines(1).ne.0)then
           write(err_msg,'("GAMMA and GAMMA_INITIALISER defined. Using GAMMA only.")')
           call print_warning(err_msg)
        end if
        call assign_val(buffer, gamma_initialiser_name, itmp1)
     case("BETA_INITIALISER", "BIAS_INITIALISER")
        if(param_lines(2).ne.0)then
           write(err_msg,'("BETA and BETA_INITIALISER defined. Using BETA only.")')
           call print_warning(err_msg)
        end if
        call assign_val(buffer, beta_initialiser_name, itmp1)
     case("MOVING_MEAN_INITIALISER")
        call assign_val(buffer, moving_mean_initialiser_name, itmp1)
     case("MOVING_VARIANCE_INITIALISER")
        call assign_val(buffer, moving_variance_initialiser_name, itmp1)
     case("GAMMA")
        ! Explicit values follow; they are read after the layer is initialised
        gamma_initialiser_name = 'zeros'
        param_lines(1) = iline
     case("BETA")
        beta_initialiser_name = 'zeros'
        param_lines(2) = iline
     case default
        ! Don't look for "e" due to scientific notation of numbers
        ! ... i.e. exponent (E+00)
        if(scan(to_lower(trim(adjustl(buffer))),&
             'abcdfghijklmnopqrstuvwxyz').eq.0)then
           cycle tag_loop
        elseif(tag(:3).eq.'END')then
           cycle tag_loop
        end if
        write(err_msg,'("Unrecognised line in input file: ",A)') &
             trim(adjustl(buffer))
        call stop_program(err_msg)
        return
     end select
  end do tag_loop

  ! INPUT_SHAPE is mandatory: the channel count and init() both depend on it
  !---------------------------------------------------------------------------
  if(.not.found_input_shape)then
     write(err_msg,'("INPUT_SHAPE not found before END ",A)') &
          to_upper(this%name)
     call stop_program(err_msg)
     return
  end if

  gamma_initialiser = initialiser_setup(gamma_initialiser_name)
  beta_initialiser = initialiser_setup(beta_initialiser_name)
  moving_mean_initialiser = initialiser_setup(moving_mean_initialiser_name)
  moving_variance_initialiser = &
       initialiser_setup(moving_variance_initialiser_name)

  ! Set hyperparameters and initialise layer
  !---------------------------------------------------------------------------
  num_channels = input_shape(size(input_shape))
  call this%set_hyperparams( &
       momentum = momentum, &
       epsilon = epsilon, &
       gamma_init_mean = this%gamma_init_mean, &
       gamma_init_std = this%gamma_init_std, &
       beta_init_mean = this%beta_init_mean, &
       beta_init_std = this%beta_init_std, &
       gamma_initialiser = gamma_initialiser, &
       beta_initialiser = beta_initialiser, &
       moving_mean_initialiser = moving_mean_initialiser, &
       moving_variance_initialiser = moving_variance_initialiser, &
       verbose = verbose_ &
  )
  call this%init(input_shape = input_shape)

  ! Read explicit GAMMA / BETA values, if present
  !---------------------------------------------------------------------------
  ! The file is currently positioned just after line final_line (== iline).
  ! Blocks are visited in reverse order (BETA, then GAMMA); move() is given
  ! the signed line offset from the current position.
  allocate(data_list(num_channels), source=0._real32)
  param_loop: do i = size(param_lines), 1, -1
     if(param_lines(i).eq.0) cycle param_loop

     ! Position the file just after the GAMMA/BETA tag line
     call move(unit, param_lines(i) - iline, iostat=stat)
     if(stat.ne.0)then
        write(err_msg,'("could not locate ",A," block in ",A)') &
             trim(block_names(i)), to_upper(trim(this%name))
        call stop_program(err_msg)
        return
     end if
     iline = param_lines(i)

     ! Gather num_channels values, which may be spread over several lines
     c = 1
     data_list = 0._real32
     data_concat_loop: do while(c.le.num_channels)
        read(unit,'(A)',iostat=stat) buffer
        if(stat.ne.0) exit data_concat_loop
        iline = iline + 1
        k = icount(buffer)
        if(k.eq.0) cycle data_concat_loop
        ! Refuse to write past the end of data_list
        if(c + k - 1 .gt. num_channels)then
           stat = 1
        else
           read(buffer,*,iostat=stat) (data_list(j), j = c, c + k - 1)
        end if
        if(stat.ne.0)then
           write(err_msg, &
                '("invalid ",A," data (expected ",I0," values): ",A)') &
                trim(block_names(i)), num_channels, trim(adjustl(buffer))
           call stop_program(err_msg)
           return
        end if
        c = c + k
     end do data_concat_loop

     ! The line following the values must close the block
     read(unit,'(A)',iostat=stat) buffer
     if(stat.ne.0) buffer = ''
     iline = iline + 1
     if(trim(adjustl(buffer)).ne."END "//trim(block_names(i)))then
        write(err_msg,'("END ",A," not where expected: ",A)') &
             trim(block_names(i)), trim(adjustl(buffer))
        call stop_program(err_msg)
        return
     end if

     select case(i)
     case(1) ! gamma
        this%params(1)%val(1:this%num_channels,1) = data_list
     case(2) ! beta
        this%params(1)%val(this%num_channels+1:this%num_channels*2,1) = &
             data_list
     end select
  end do param_loop
  deallocate(data_list)

  ! Check for end of layer card
  !---------------------------------------------------------------------------
  call move(unit, final_line - iline, iostat=stat)
  read(unit,'(A)',iostat=stat) buffer
  if(stat.ne.0) buffer = ''
  if(trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then
     write(0,*) trim(adjustl(buffer))
     write(err_msg,'("END ",A," not where expected")') to_upper(this%name)
     call stop_program(err_msg)
     return
  end if

end subroutine read_batchnorm1d