get_partial_conv1d_input_val Subroutine

pure subroutine get_partial_conv1d_input_val(this, upstream_grad, output)

Get the partial derivative with respect to the input for a 1D convolution (subroutine version).

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:,:) :: output

Source Code

  pure subroutine get_partial_conv1d_input_val(this, upstream_grad, output)
    !! Get the partial derivative with respect to the input for a 1D
    !! convolution (subroutine version).
    !!
    !! Each upstream gradient entry for output position `i` of filter `c_out`
    !! is scattered back onto input positions `(i-1)*stride + (k-1)*dilation + 1`
    !! (`k = 1..kernel_h`) of every input channel, weighted by the matching
    !! kernel value. Taps that fall outside `1..input_h` are ignored.
    !!
    !! Index layout used below:
    !!  - upstream gradient: `upstream_grad(i + (c_out-1)*output_h, s)`
    !!  - kernel: `this%right_operand%val(k + (c_in-1)*kernel_h + (c_out-1)*kernel_h*num_channels, 1)`
    !!  - result: `output(i_in + (c_in-1)*input_h, s)`
    implicit none

    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    ! Upstream gradients whose magnitude does not exceed this are skipped.
    real(real32), parameter :: grad_tol = 1.e-30_real32

    integer :: s, c_in, num_samples
    integer :: input_h, kernel_h, output_h, num_channels, num_filters
    integer :: stride, dilation


    ! Unpack parameters
    num_channels = this%indices(1)
    num_filters = this%indices(2)
    stride = this%adj_ja(1,1)
    dilation = this%adj_ja(1,2)
    kernel_h = this%adj_ja(1,3)

    input_h = this%left_operand%shape(1)
    output_h = this%shape(1)
    num_samples = size(upstream_grad, dim=2)

    output = 0._real32

    ! Parallelise only over (batch sample, input channel). Each iteration then
    ! writes exclusively to output((c_in-1)*input_h+1 : c_in*input_h, s), so no
    ! two iterations touch the same element. Also parallelising over output
    ! position or filter would let different iterations accumulate into the
    ! same output element (overlapping receptive fields, shared input channel).
    ! DO CONCURRENT does not allow that, and it races when run in parallel.
    do concurrent(s = 1:num_samples, c_in = 1:num_channels)
       block
         ! Iteration-local temporaries.
         integer :: c_out, i, k, i_in, in_offset, k_base
         real(real32) :: grad_val

         in_offset = (c_in - 1) * input_h
         do c_out = 1, num_filters
            ! Start of the kernel taps for this (c_in, c_out) pair.
            k_base = (c_in - 1) * kernel_h + (c_out - 1) * kernel_h * num_channels
            do i = 1, output_h
               grad_val = upstream_grad(i + (c_out - 1) * output_h, s)
               ! Skip negligible gradients. NOTE: the comparison is false for
               ! NaN, so NaN upstream gradients are skipped as well.
               if (.not. (abs(grad_val) .gt. grad_tol)) cycle
               do k = 1, kernel_h
                  i_in = (i - 1) * stride + (k - 1) * dilation + 1
                  if (i_in .lt. 1 .or. i_in .gt. input_h) cycle
                  output(in_offset + i_in, s) = output(in_offset + i_in, s) + &
                       grad_val * this%right_operand%val(k_base + k, 1)
               end do
            end do
         end do
       end block
    end do

  end subroutine get_partial_conv1d_input_val