set_hyperparams_duvenaud Subroutine

private subroutine set_hyperparams_duvenaud(this, num_vertex_features, num_edge_features, min_vertex_degree, max_vertex_degree, num_time_steps, num_outputs, message_activation, readout_activation, kernel_initialiser, verbose)

Set the hyperparameters for the message passing layer

Type Bound

duvenaud_msgpass_layer_type

Arguments

Type | Intent | Optional Attributes | Name
class(duvenaud_msgpass_layer_type), intent(inout) :: this

Instance of the message passing layer

integer, intent(in), dimension(:) :: num_vertex_features

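Number of vertex features: either a single value used for every time step, or one value per time step (length num_time_steps + 1)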

integer, intent(in), dimension(:) :: num_edge_features

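Number of edge features: either a single value used for every time step, or one value per time step (length num_time_steps + 1)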

integer, intent(in) :: min_vertex_degree

Minimum vertex degree

integer, intent(in) :: max_vertex_degree

Maximum vertex degree

integer, intent(in) :: num_time_steps

Number of time steps

integer, intent(in) :: num_outputs

Number of outputs

class(base_actv_type), intent(in), allocatable :: message_activation

Message activation function (a default is used if not allocated)

class(base_actv_type), intent(in), allocatable :: readout_activation

Readout activation function (a default is used if not allocated)

class(base_init_type), intent(in), allocatable :: kernel_initialiser

Kernel initialiser (if not allocated, a default is chosen based on the message activation function)

integer, intent(in), optional :: verbose

Verbosity level


Source Code

  subroutine set_hyperparams_duvenaud( &
       this, &
       num_vertex_features, num_edge_features, &
       min_vertex_degree, &
       max_vertex_degree, &
       num_time_steps, &
       num_outputs, &
       message_activation, &
       readout_activation, &
       kernel_initialiser, &
       verbose &
  )
    !! Set the hyperparameters for the Duvenaud message passing layer.
    !!
    !! Stores the layer dimensions, expands the vertex/edge feature counts
    !! to arrays indexed 0:num_time_steps, selects the message and readout
    !! activation functions and the kernel initialiser (falling back to
    !! defaults when the corresponding argument is not allocated), and
    !! computes the number of message and readout parameters. Any
    !! previously allocated input/output shapes are discarded.
    !!
    !! The program terminates with `error stop` (message on error_unit) if
    !! a feature-count argument has a length other than 1 or
    !! num_time_steps + 1.
    use, intrinsic :: iso_fortran_env, only: error_unit
    use athena__activation, only: activation_setup
    use athena__initialiser, only: get_default_initialiser, initialiser_setup
    implicit none

    ! Arguments
    class(duvenaud_msgpass_layer_type), intent(inout) :: this
    !! Instance of the message passing layer
    integer, dimension(:), intent(in) :: num_vertex_features
    !! Number of vertex features: a single value used for every time step,
    !! or one value per time step 0..num_time_steps
    integer, dimension(:), intent(in) :: num_edge_features
    !! Number of edge features: a single value used for every time step,
    !! or one value per time step 0..num_time_steps
    integer, intent(in) :: min_vertex_degree
    !! Minimum vertex degree
    integer, intent(in) :: max_vertex_degree
    !! Maximum vertex degree
    integer, intent(in) :: num_time_steps
    !! Number of time steps
    integer, intent(in) :: num_outputs
    !! Number of outputs
    class(base_actv_type), allocatable, intent(in) :: message_activation
    !! Message activation function (a default is used if not allocated)
    class(base_actv_type), allocatable, intent(in) :: readout_activation
    !! Readout activation function (a default is used if not allocated)
    class(base_init_type), allocatable, intent(in) :: kernel_initialiser
    !! Kernel initialiser (if not allocated, a default is chosen from the
    !! name of the message activation function)
    integer, optional, intent(in) :: verbose
    !! Verbosity level; any non-zero value prints the selected settings

    ! Local variables
    character(len=256) :: buffer
    !! Name of the default kernel initialiser


    this%name = 'duvenaud'
    this%type = 'msgp'
    this%input_rank = 2
    this%output_rank = 1
    this%min_vertex_degree = min_vertex_degree
    this%max_vertex_degree = max_vertex_degree
    this%num_time_steps = num_time_steps
    this%num_outputs = num_outputs

    call expand_per_step( &
         num_vertex_features, num_time_steps, &
         this%num_vertex_features, "num_vertex_features" &
    )
    call expand_per_step( &
         num_edge_features, num_time_steps, &
         this%num_edge_features, "num_edge_features" &
    )

    this%use_graph_input = .true.
    this%use_graph_output = .false.

    ! Activation functions: copy the caller's choice or build the default
    if(allocated(this%activation)) deallocate(this%activation)
    if(allocated(this%activation_readout)) deallocate(this%activation_readout)
    if(.not.allocated(message_activation))then
       this%activation = activation_setup(default_message_actv_name)
    else
       allocate( this%activation, source=message_activation )
    end if
    if(.not.allocated(readout_activation))then
       this%activation_readout = activation_setup(default_readout_actv_name)
    else
       allocate(this%activation_readout, source=readout_activation)
    end if

    ! Kernel initialiser: the default depends on the message activation,
    ! so this must come after this%activation has been set
    if(allocated(this%kernel_init)) deallocate(this%kernel_init)
    if(.not.allocated(kernel_initialiser))then
       buffer = get_default_initialiser(this%activation%name)
       this%kernel_init = initialiser_setup(buffer)
    else
       allocate(this%kernel_init, source=kernel_initialiser)
    end if

    if(present(verbose))then
       if(abs(verbose).gt.0)then
          write(*,'("DUVENAUD message activation function: ",A)') &
               trim(this%activation%name)
          write(*,'("DUVENAUD readout activation function: ",A)') &
               trim(this%activation_readout%name)
          write(*,'("DUVENAUD kernel initialiser: ",A)') &
               trim(this%kernel_init%name)
       end if
    end if

    ! Message parameters for step t (1..num_time_steps):
    !   (vertex features at t-1 + edge features) * vertex features at t
    !   * number of distinct vertex degrees handled
    ! NOTE(review): only the step-0 edge feature count is used, presumably
    ! because edge features are not updated between time steps -- confirm
    ! against the forward pass.
    if(allocated(this%num_params_msg)) deallocate(this%num_params_msg)
    allocate(this%num_params_msg(1:this%num_time_steps))
    this%num_params_msg(:) = &
         ( this%num_vertex_features(0:this%num_time_steps-1) + &
         this%num_edge_features(0) ) * &
         this%num_vertex_features(1:this%num_time_steps) * &
         ( this%max_vertex_degree - this%min_vertex_degree + 1 )

    ! NOTE(review): the readout count includes time step 0 as well as
    ! steps 1..num_time_steps -- confirm this matches the readout weights.
    this%num_params_readout = &
         sum( this%num_vertex_features * this%num_outputs )

    if(allocated(this%input_shape)) deallocate(this%input_shape)
    if(allocated(this%output_shape)) deallocate(this%output_shape)

  contains

    subroutine expand_per_step(values, last_step, per_step, label)
      !! Expand a user-supplied feature-count vector to an array indexed
      !! 0:last_step. A single value is broadcast to every step; a vector
      !! of length last_step + 1 is copied as-is; any other length is a
      !! fatal error.
      integer, dimension(:), intent(in) :: values
      !! Feature counts supplied by the caller
      integer, intent(in) :: last_step
      !! Upper bound of the destination array (number of time steps)
      integer, dimension(:), allocatable, intent(inout) :: per_step
      !! Destination array; (re)allocated with bounds 0:last_step
      character(len=*), intent(in) :: label
      !! Argument name used in the error message

      if(allocated(per_step)) deallocate(per_step)
      if(size(values, 1) .eq. 1)then
         allocate(per_step(0:last_step), source = values(1))
      elseif(size(values, 1) .eq. last_step + 1)then
         allocate(per_step(0:last_step), source = values)
      else
         ! Report on stderr and exit with a failure status so callers and
         ! scripts can detect the misconfiguration
         write(error_unit,'(A)') "Error: " // label // &
              " must be a scalar or a vector of length num_time_steps + 1"
         error stop
      end if
    end subroutine expand_per_step

  end subroutine set_hyperparams_duvenaud