init_duvenaud Subroutine

private subroutine init_duvenaud(this, input_shape, verbose)

Initialise the message passing layer

Type Bound

duvenaud_msgpass_layer_type

Arguments

Type | Intent | Optional Attributes | Name
class(duvenaud_msgpass_layer_type), intent(inout) :: this

Instance of the Duvenaud message passing layer

integer, intent(in), dimension(:) :: input_shape

Input shape

integer, intent(in), optional :: verbose

Verbosity level


Source Code

  subroutine init_duvenaud(this, input_shape, verbose)
    !! Initialise the Duvenaud message passing layer
    !!
    !! Sets the input shape (only if it has not been set already), the output
    !! shape and the parameter count. It then (re)allocates and initialises two
    !! learnable parameters per message passing time step t:
    !!   * params(t)                      - message weight of shape
    !!     [num_vertex_features(t), num_vertex_features(t-1) +
    !!      num_edge_features(0), number of supported vertex degrees]
    !!   * params(t + num_time_steps)     - readout weight of shape
    !!     [num_outputs, num_vertex_features(t)]
    !! Arrays left allocated by a previous call are released first, so the
    !! routine may be called more than once on the same instance.
    use athena__initialiser, only: initialiser_setup
    implicit none

    ! Arguments
    class(duvenaud_msgpass_layer_type), intent(inout) :: this
    !! Instance of the Duvenaud message passing layer
    integer, dimension(:), intent(in) :: input_shape
    !! Input shape
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: t
    !! Loop index over message passing time steps
    integer :: verbose_
    !! Verbosity level


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    ! Assigned here rather than in the declaration: an initialised local has
    ! an implicit SAVE and would otherwise keep a previous call's value
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose


    !---------------------------------------------------------------------------
    ! Initialise number of inputs
    !---------------------------------------------------------------------------
    ! Keep an input shape that was already set (e.g. by an earlier call)
    if(.not.allocated(this%input_shape)) call this%set_shape([input_shape])
    this%output_shape = [this%num_outputs]
    this%num_params = this%get_num_params()


    !---------------------------------------------------------------------------
    ! Allocate weight, weight steps (velocities), output, and activation
    !---------------------------------------------------------------------------
    ! Columns 1..num_time_steps of weight_shape describe the message weights;
    ! columns num_time_steps+1..2*num_time_steps describe the readout weights.
    ! Release any previous allocation so that re-initialisation does not fail
    if(allocated(this%weight_shape)) deallocate(this%weight_shape)
    allocate(this%weight_shape(3, 2 * this%num_time_steps))
    if(allocated(this%params)) deallocate(this%params)
    allocate(this%params(2 * this%num_time_steps))
    do t = 1, this%num_time_steps
       ! Message weight: one slice per vertex degree in
       ! [min_vertex_degree, max_vertex_degree]
       this%weight_shape(:,t) = [ &
            this%num_vertex_features(t), &
            this%num_vertex_features(t-1) + this%num_edge_features(0), &
            this%max_vertex_degree - this%min_vertex_degree + 1 &
       ]
       ! Readout weight: maps the step-t vertex features to the outputs
       this%weight_shape(:,t+this%num_time_steps) = &
            [ this%num_outputs, this%num_vertex_features(t), 1 ]

       ! Each parameter is allocated with a trailing extent of 1; the
       ! val(:,1) column is what the kernel initialiser fills below
       call this%params(t)%allocate( [ this%weight_shape(:,t), 1 ] )
       call this%params(t+this%num_time_steps)%allocate( &
            [ this%weight_shape(:2,t+this%num_time_steps), 1 ] &
       )

       call this%params(t)%set_requires_grad(.true.)
       this%params(t)%fix_pointer = .true.
       this%params(t)%is_temporary = .false.
       this%params(t)%is_sample_dependent = .false.
       ! Record the vertex degree range covered by the third weight dimension
       this%params(t)%indices = [ this%min_vertex_degree, this%max_vertex_degree ]

       call this%params(t+this%num_time_steps)%set_requires_grad(.true.)
       this%params(t+this%num_time_steps)%fix_pointer = .true.
       this%params(t+this%num_time_steps)%is_temporary = .false.
       this%params(t+this%num_time_steps)%is_sample_dependent = .false.
    end do


    !---------------------------------------------------------------------------
    ! Initialise weights (kernels)
    !---------------------------------------------------------------------------
    do t = 1, this%num_time_steps, 1
       call this%kernel_init%initialise( &
            this%params(t)%val(:,1), &
            fan_in = this%num_vertex_features(t-1) + this%num_edge_features(0), &
            fan_out = this%num_vertex_features(t), &
            spacing = [ this%num_vertex_features(t-1) ] &
       )
       ! NOTE(review): fan_in here is the sum of the feature counts over all
       ! steps (including index 0), not num_vertex_features(t) of this
       ! readout matrix - confirm this is intended
       call this%kernel_init%initialise( &
            this%params(t+this%num_time_steps)%val(:,1), &
            fan_in = sum(this%num_vertex_features), &
            fan_out = this%num_outputs, &
            spacing = this%num_vertex_features &
       )
    end do


    !---------------------------------------------------------------------------
    ! Allocate arrays
    !---------------------------------------------------------------------------
    if(allocated(this%output)) deallocate(this%output)
    allocate(this%output(1,1))
    if(allocated(this%z)) deallocate(this%z)

  end subroutine init_duvenaud