get_num_params_ono_attn Function

private pure function get_num_params_ono_attn(this) result(num_params)

Return the number of learnable parameters for the layer

Type Bound

orthogonal_attention_layer_type

Arguments

Type | Intent | Optional Attributes | Name
class(orthogonal_attention_layer_type), intent(in) :: this

Layer instance

Return Value integer

Total number of learnable parameters


Source Code

  pure function get_num_params_ono_attn(this) result(num_params)
    !! Count the learnable parameters held by an orthogonal attention layer.
    !!
    !! The tally covers five weight arrays plus an optional bias:
    !!   W_Q, W_K : key_dim     x num_inputs  (query / key projections)
    !!   W_V, W   : num_outputs x num_inputs  (value projection / bypass)
    !!   B        : num_inputs  x num_basis   (basis weights to orthogonalise)
    !!   b        : num_outputs               (only when use_bias is set)
    implicit none

    ! Arguments
    class(orthogonal_attention_layer_type), intent(in) :: this
    !! Layer instance
    integer :: num_params
    !! Total number of learnable parameters

    associate( &
         n_in  => this%num_inputs, &
         n_out => this%num_outputs, &
         d_key => this%key_dim, &
         n_bas => this%num_basis &
    )
      ! W_Q and W_K have identical shapes, as do W_V and the bypass W,
      ! so each pair contributes twice its single-matrix size.
      num_params = 2 * (d_key * n_in) &
           + 2 * (n_out * n_in) &
           + n_in * n_bas

      ! The bias vector adds num_outputs parameters only when enabled.
      num_params = num_params + merge(n_out, 0, this%use_bias)
    end associate

  end function get_num_params_ono_attn