init_batch Module Subroutine

module subroutine init_batch(this, input_shape, verbose)

Initialise batch normalisation layer

Arguments

Type Intent Optional Attributes Name
class(batch_layer_type), intent(inout) :: this

Instance of the layer

integer, intent(in), dimension(:) :: input_shape

Input shape

integer, intent(in), optional :: verbose

Verbosity level


Source Code

  module subroutine init_batch(this, input_shape, verbose)
    !! Initialise batch normalisation layer
    !!
    !! Derives the output shape and number of channels from the input shape,
    !! allocates the learnable parameters (gamma in rows 1:num_channels and
    !! beta in rows num_channels+1: of params(1)) together with the mean and
    !! variance arrays, runs the layer's initialisers on each of them, and
    !! finally allocates the output array.
    use athena__initialiser, only: initialiser_setup
    use athena__misc_types, only: base_init_type
    implicit none

    ! Arguments
    class(batch_layer_type), intent(inout) :: this
    !! Instance of the layer
    integer, dimension(:), intent(in) :: input_shape
    !! Input shape
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    ! Verbosity level with the default applied. Deliberately NOT initialised
    ! in the declaration: that would give it the implicit SAVE attribute and
    ! a value from one call would leak into later calls that omit `verbose`.
    integer :: verbose_


    !---------------------------------------------------------------------------
    ! initialise optional arguments
    !---------------------------------------------------------------------------
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose


    !---------------------------------------------------------------------------
    ! initialise input shape
    !---------------------------------------------------------------------------
    if(.not.allocated(this%input_shape)) call this%set_shape(input_shape)


    !---------------------------------------------------------------------------
    ! set up number of channels, width, height
    !---------------------------------------------------------------------------
    ! Release any previous output shape so the allocate below cannot fail on
    ! an already-allocated array (this%output itself is released further down,
    ! immediately before it is reallocated).
    if(allocated(this%output_shape)) deallocate(this%output_shape)
    allocate(this%output_shape(this%input_rank))
    ! NOTE(review): the single-element branch writes output_shape(2), which is
    ! only in bounds when input_rank >= 2, while num_channels below reads
    ! input_shape(input_rank), which needs input_shape to hold at least
    ! input_rank elements - confirm how set_shape/input_rank guarantee both.
    if(size(this%input_shape).eq.1)then
       this%output_shape(1) = this%input_shape(1)
       this%output_shape(2) = 1
    else
       this%output_shape = this%input_shape
    end if
    this%num_channels = this%input_shape(this%input_rank)
    this%num_params = this%get_num_params()
    ! A single parameter array holds both gamma and beta, stacked along the
    ! first dimension (hence 2 * num_channels rows).
    allocate(this%params(1))
    call this%params(1)%allocate([2 * this%num_channels, 1])
    call this%params(1)%set_requires_grad(.true.)
    allocate(this%weight_shape(1,1))
    this%weight_shape(:,1) = [ this%num_channels ]
    this%bias_shape = [this%num_channels]


    !---------------------------------------------------------------------------
    ! allocate mean and variance
    !---------------------------------------------------------------------------
    ! Both start as zero-filled arrays with one entry per channel; the
    ! initialisers below then overwrite them.
    allocate(this%mean(this%num_channels), source=0._real32)
    allocate(this%variance, source=this%mean)


    !---------------------------------------------------------------------------
    ! initialise gamma
    !---------------------------------------------------------------------------
    call this%kernel_init%initialise(this%params(1)%val(1:this%num_channels,1), &
         fan_in =this%num_channels, &
         fan_out=this%num_channels)

    !---------------------------------------------------------------------------
    ! initialise beta
    !---------------------------------------------------------------------------
    call this%bias_init%initialise(this%params(1)%val(this%num_channels+1:,1), &
         fan_in =this%num_channels, &
         fan_out=this%num_channels)


    !---------------------------------------------------------------------------
    ! initialise moving mean
    !---------------------------------------------------------------------------
    call this%moving_mean_init%initialise(this%mean, &
         fan_in =this%num_channels, &
         fan_out=this%num_channels)

    !---------------------------------------------------------------------------
    ! initialise moving variance
    !---------------------------------------------------------------------------
    call this%moving_variance_init%initialise(this%variance, &
         fan_in =this%num_channels, &
         fan_out=this%num_channels)


    !---------------------------------------------------------------------------
    ! Allocate arrays
    !---------------------------------------------------------------------------
    if(this%use_graph_input)then
       call stop_program( &
            "Graph input not supported for batch normalisation layer" &
       )
       ! The return covers the case where stop_program hands control back
       ! instead of terminating.
       return
    end if
    if(allocated(this%output)) deallocate(this%output)
    allocate( batchnorm_array_type :: this%output(1,1) )

  end subroutine init_batch