Initialise parameter storage and output buffers for the layer
orthogonal_attention_layer_type
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| class(orthogonal_attention_layer_type) | intent(inout) | | | this | Layer instance to initialise |
| integer | intent(in) | | dimension(:) | input_shape | Input shape used to infer num_inputs |
| integer | intent(in) | optional | | verbose | Verbosity level |
subroutine init_ono_attn(this, input_shape, verbose)
  !! Initialise parameter storage and output buffers for the layer
  !!
  !! Sets the layer shapes, allocates the learnable parameter arrays
  !! (five weight matrices plus an optional bias), initialises them with
  !! the layer's kernel/bias initialisers and resets the output buffers.
  implicit none

  ! Arguments
  class(orthogonal_attention_layer_type), intent(inout) :: this
  !! Layer instance to initialise
  integer, dimension(:), intent(in) :: input_shape
  !! Input shape used to infer num_inputs
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  integer, parameter :: num_weights = 5
  !! Number of weight matrices (W_Q, W_K, W_V, B, W)
  integer :: num_inputs
  !! Fan-in handed to the W_V, W and bias initialisers
  !! (this%num_inputs, plus one when a bias is used)
  integer :: idx
  !! Loop index over the parameter arrays
  integer :: verbose_
  !! Effective verbosity level

  ! Assigned at run time rather than in the declaration: a declaration
  ! initialiser would give the variable the SAVE attribute, so a verbosity
  ! passed on one call would persist into later calls that omit it.
  ! The value is currently not consulted further in this routine.
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose


  !---------------------------------------------------------------------------
  ! Set shapes
  !---------------------------------------------------------------------------
  if(.not.allocated(this%input_shape)) call this%set_shape(input_shape)
  this%num_inputs = this%input_shape(1)
  this%output_shape = [this%num_outputs]
  this%num_params = this%get_num_params()


  !---------------------------------------------------------------------------
  ! Allocate learnable parameters
  !
  !   params(1): W_Q  query projection    [key_dim     x num_inputs]
  !   params(2): W_K  key projection      [key_dim     x num_inputs]
  !   params(3): W_V  value projection    [num_outputs x num_inputs]
  !   params(4): B    basis weights       [num_inputs  x num_basis]
  !   params(5): W    bypass weights      [num_outputs x num_inputs]
  !   params(6): b    bias                [num_outputs]  (optional)
  !---------------------------------------------------------------------------
  allocate(this%weight_shape(2,num_weights))
  this%weight_shape(:,1) = [ this%key_dim, this%num_inputs ]
  this%weight_shape(:,2) = [ this%key_dim, this%num_inputs ]
  this%weight_shape(:,3) = [ this%num_outputs, this%num_inputs ]
  this%weight_shape(:,4) = [ this%num_inputs, this%num_basis ]
  this%weight_shape(:,5) = [ this%num_outputs, this%num_inputs ]

  if(this%use_bias)then
     this%bias_shape = [ this%num_outputs ]
     allocate(this%params(num_weights + 1))
  else
     allocate(this%params(num_weights))
  end if

  num_inputs = this%num_inputs
  if(this%use_bias) num_inputs = this%num_inputs + 1

  ! Each weight array takes its 2D shape from weight_shape with a trailing
  ! dimension of 1; the bias (when present) uses bias_shape the same way.
  do idx = 1, num_weights
     call this%params(idx)%allocate([this%weight_shape(:,idx), 1])
  end do
  if(this%use_bias)then
     call this%params(num_weights + 1)%allocate([this%bias_shape, 1])
  end if

  ! Every parameter array (weights and optional bias) gets the same flags.
  do idx = 1, size(this%params)
     call this%params(idx)%set_requires_grad(.true.)
     this%params(idx)%fix_pointer = .true.
     this%params(idx)%is_sample_dependent = .false.
     this%params(idx)%is_temporary = .false.
  end do


  !---------------------------------------------------------------------------
  ! Initialise learnable parameters
  !---------------------------------------------------------------------------
  ! W_Q
  call this%kernel_init%initialise( &
       this%params(1)%val(:,1), &
       fan_in = this%num_inputs, fan_out = this%key_dim, &
       spacing = [ this%key_dim ] &
  )
  ! W_K
  call this%kernel_init%initialise( &
       this%params(2)%val(:,1), &
       fan_in = this%num_inputs, fan_out = this%key_dim, &
       spacing = [ this%key_dim ] &
  )
  ! W_V
  call this%kernel_init%initialise( &
       this%params(3)%val(:,1), &
       fan_in = num_inputs, fan_out = this%num_outputs, &
       spacing = [ this%num_outputs ] &
  )
  ! B (basis weights)
  call this%kernel_init%initialise( &
       this%params(4)%val(:,1), &
       fan_in = this%num_inputs, fan_out = this%num_basis, &
       spacing = [ this%num_inputs ] &
  )
  ! W (bypass)
  call this%kernel_init%initialise( &
       this%params(5)%val(:,1), &
       fan_in = num_inputs, fan_out = this%num_outputs, &
       spacing = [ this%num_outputs ] &
  )
  ! b (bias, optional)
  if(this%use_bias)then
     call this%bias_init%initialise( &
          this%params(6)%val(:,1), &
          fan_in = num_inputs, fan_out = this%num_outputs &
     )
  end if


  !---------------------------------------------------------------------------
  ! Allocate output arrays
  !---------------------------------------------------------------------------
  ! Drop any output buffer left from a previous initialisation before
  ! allocating the fresh 1x1 buffer.
  if(allocated(this%output)) deallocate(this%output)
  allocate(this%output(1,1))

  if(this%z(1)%allocated) call this%z(1)%deallocate()

end subroutine init_ono_attn