create_from_onnx_orthogonal_attention_layer Function

public function create_from_onnx_orthogonal_attention_layer(meta_key, meta_value, inits, verbose) result(layer)

Build orthogonal attention layer from ONNX metadata and return layer

Arguments

Type | Intent | Optional | Attributes | Name
character(len=*), intent(in) :: meta_key

NOP metadata key/value pair

character(len=*), intent(in) :: meta_value

NOP metadata key/value pair

type(onnx_initialiser_type), intent(in), dimension(:) :: inits

ONNX initialisers containing parameter tensors

integer, intent(in), optional :: verbose

Verbosity level

Return Value class(base_layer_type), allocatable

Constructed orthogonal attention layer


Source Code

  function create_from_onnx_orthogonal_attention_layer( &
       meta_key, meta_value, inits, verbose &
  ) result(layer)
    !! Build orthogonal attention layer from ONNX metadata and return layer
    !!
    !! `meta_value` is handed to `parse_nop_metadata` for the common layer
    !! settings and is additionally scanned here, as a semicolon-separated
    !! list of `key=value` tokens, for the attention-specific `key_dim`
    !! entry. Parameter tensors are then loaded from `inits` under names
    !! built from the prefix returned by `extract_nop_prefix(meta_key)` and
    !! the suffixes `_param1` ... `_param6`.
    !!
    !! A `key_dim` entry whose value cannot be read as an integer terminates
    !! the program with a diagnostic on the error unit.
    use, intrinsic :: iso_fortran_env, only: error_unit
    use athena__orthogonal_attention_layer, only: &
         orthogonal_attention_layer_type
    implicit none

    ! Arguments
    character(*), intent(in) :: meta_key, meta_value
    !! NOP metadata key/value pair
    type(onnx_initialiser_type), dimension(:), intent(in) :: inits
    !! ONNX initialisers containing parameter tensors
    integer, optional, intent(in) :: verbose
    !! Verbosity level
    class(base_layer_type), allocatable :: layer
    !! Constructed orthogonal attention layer

    ! Local variables
    integer :: num_inputs, num_outputs, num_modes, key_dim, verbose_
    !! Parsed layer dimensions, key dimension and effective verbosity level
    logical :: use_bias
    !! Whether the imported layer uses bias
    character(64) :: activation_name, nop_prefix
    !! Activation name and initialiser name prefix derived from meta_key
    integer :: k, pos, pos2
    !! Parsing indices
    integer :: ios
    !! I/O status of the key_dim conversion
    character(256) :: token, key, val
    !! Current `key=value` token and its two halves

    ! NOTE(review): verbose_ is resolved here but not consulted further in
    ! this procedure.
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose

    ! Defaults used when the metadata does not provide a value
    num_inputs = 0; num_outputs = 0; num_modes = 0; key_dim = 0
    use_bias = .true.; activation_name = 'none'

    call parse_nop_metadata(meta_value, &
         num_inputs, num_outputs, num_modes, use_bias, activation_name)

    ! Also parse key_dim
    pos = 1
    do while(pos .le. len_trim(meta_value))
       ! Isolate the next ';'-terminated token; pos always advances before
       ! any cycle so the loop is guaranteed to terminate
       pos2 = index(meta_value(pos:), ';')
       if(pos2 .eq. 0)then
          token = meta_value(pos:len_trim(meta_value))
          pos = len_trim(meta_value) + 1
       else
          token = meta_value(pos:pos+pos2-2)
          pos = pos + pos2
       end if
       ! Tokens without '=' carry no key/value pair and are skipped
       k = index(token, '=')
       if(k .eq. 0) cycle
       key = trim(adjustl(token(1:k-1)))
       val = trim(adjustl(token(k+1:)))
       if(trim(key) .eq. 'key_dim')then
          ! Guard the conversion so that an empty or non-integer value
          ! yields a clear diagnostic instead of a bare runtime I/O error
          read(val, *, iostat=ios) key_dim
          if(ios .ne. 0)then
             write(error_unit, '(A)') &
                  'create_from_onnx_orthogonal_attention_layer: ' // &
                  'cannot read key_dim from metadata value "' // &
                  trim(val) // '"'
             error stop 1
          end if
       end if
    end do

    nop_prefix = extract_nop_prefix(meta_key)

    block
      type(orthogonal_attention_layer_type) :: attn_layer

      attn_layer = orthogonal_attention_layer_type( &
           num_outputs = num_outputs, &
           num_basis = num_modes, &
           key_dim = key_dim, &
           num_inputs = num_inputs, &
           use_bias = use_bias, &
           activation = trim(activation_name) &
      )

      ! params: (1) W_Q, (2) W_K, (3) W_V, (4) B, (5) W, (6) b
      call load_nop_param_from_inits( &
           attn_layer%params(1), nop_prefix, '_param1', &
           inits, size(inits), [key_dim, num_inputs])
      call load_nop_param_from_inits( &
           attn_layer%params(2), nop_prefix, '_param2', &
           inits, size(inits), [key_dim, num_inputs])
      call load_nop_param_from_inits( &
           attn_layer%params(3), nop_prefix, '_param3', &
           inits, size(inits), [num_outputs, num_inputs])
      call load_nop_param_from_inits( &
           attn_layer%params(4), nop_prefix, '_param4', &
           inits, size(inits), [num_inputs, num_modes])
      call load_nop_param_from_inits( &
           attn_layer%params(5), nop_prefix, '_param5', &
           inits, size(inits), [num_outputs, num_inputs])
      ! The bias tensor is only loaded when the layer was exported with one
      if(use_bias)then
         call load_nop_param_from_inits( &
              attn_layer%params(6), nop_prefix, '_param6', &
              inits, size(inits), [num_outputs, 1])
      end if

      ! Return the concrete layer through the polymorphic result
      allocate(layer, source=attn_layer)
    end block

  end function create_from_onnx_orthogonal_attention_layer