accumulate_edge_gradients_3d_val Subroutine

pure subroutine accumulate_edge_gradients_3d_val(upstream_grad, output, input_shape, indices, adj_ja)

Accumulate edge gradients for 3D padding - raw array version

Arguments

Type | Intent | Optional Attributes | Name
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(inout), dimension(:,:) :: output
integer, intent(in), dimension(5) :: input_shape
integer, intent(in), dimension(:) :: indices
integer, intent(in), dimension(:,:) :: adj_ja

Source Code

  pure subroutine accumulate_edge_gradients_3d_val(upstream_grad, output, &
       input_shape, indices, adj_ja)
    !! Accumulate edge gradients for 3D padding - raw array version
    !!
    !! Adds to `output` the gradient flowing back through every padded
    !! edge region (two spatial dimensions in the padding, one not). Each
    !! padded element of `upstream_grad` is added onto the input element it
    !! was copied from by the padding.
    !!
    !! Both arrays are flattened column-major over (h, w, d, channel), with
    !! the second index running over the samples 1..input_shape(5):
    !!   upstream_grad(i + (j-1)*out_h + (k-1)*out_h*out_w + (c-1)*out_vol, s)
    !!   output       (i + (j-1)*in_h  + (k-1)*in_h*in_w   + (c-1)*in_vol,  s)
    !!
    !! input_shape  : [h, w, d, channels, batch] of the unpadded input
    !! indices(1)   : padding method; 3 circular, 4 reflection,
    !!                5 replication (any other value leaves output as is)
    !! indices(2:4) : padding width in h, w, d, applied on both sides
    !! indices(5)   : number of six-entry regions stored in adj_ja ahead of
    !!                the edges (presumably the faces; skipped here)
    !! indices(6)   : number of edge regions to process
    !! indices(7 + indices(5) + f) : dimension the f-th edge runs along
    !! adj_ja(1,:)  : source ("orig") ranges in input coordinates
    !! adj_ja(2,:)  : padded ("dest") ranges in output coordinates;
    !!                six entries [first1,last1,first2,last2,first3,last3]
    !!                per region
    implicit none

    ! Arguments
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(inout) :: output
    integer, dimension(5), intent(in) :: input_shape
    integer, dimension(:), intent(in) :: indices
    integer, dimension(:,:), intent(in) :: adj_ja

    ! Local variables
    integer :: i, j, k, m, s, f, idim, base
    integer :: step1, step2, step3, idx_in, idx_out
    integer :: input_h, input_w, input_d
    integer :: output_h, output_w, output_d
    integer :: in_plane, in_vol, out_plane, out_vol
    integer, dimension(2,3) :: orig, dest
    real(real32) :: grad_sum

    input_h = input_shape(1)
    input_w = input_shape(2)
    input_d = input_shape(3)
    output_h = input_h + 2 * indices(2)
    output_w = input_w + 2 * indices(3)
    output_d = input_d + 2 * indices(4)

    ! Strides of the flattened (h, w, d, channel) layouts
    in_plane = input_h * input_w
    in_vol = in_plane * input_d
    out_plane = output_h * output_w
    out_vol = out_plane * output_d

    if(indices(6) .eq. 0) return

    do f = 1, indices(6)
       ! Six adj_ja columns per region; the first indices(5) regions are
       ! not edges and are skipped.
       base = (indices(5) + f - 1) * 6
       idim = indices(7 + indices(5) + f)
       ! Column-major reshape: orig(1:2,d) = adj_ja(1, base+2d-1:base+2d)
       orig = reshape(adj_ja(1, base + 1:base + 6), [2, 3])
       dest = reshape(adj_ja(2, base + 1:base + 6), [2, 3])

       select case(indices(1))
       case(3, 4) ! circular or reflection
          ! One-to-one copy starting at orig(1,:). Reflection mirrors the
          ! source along the two padded dimensions - every dimension except
          ! idim, the one the edge runs along - so those walk backwards.
          ! Circular padding copies every dimension in order.
          step1 = merge(-1, 1, indices(1) .eq. 4 .and. idim .ne. 1)
          step2 = merge(-1, 1, indices(1) .eq. 4 .and. idim .ne. 2)
          step3 = merge(-1, 1, indices(1) .eq. 4 .and. idim .ne. 3)

          do s = 1, input_shape(5)
             do m = 1, input_shape(4)
                do k = dest(1,3), dest(2,3)
                   do j = dest(1,2), dest(2,2)
                      do i = dest(1,1), dest(2,1)
                         idx_out = i + (j - 1) * output_h + &
                              (k - 1) * out_plane + (m - 1) * out_vol
                         idx_in = orig(1,1) + step1 * (i - dest(1,1)) + &
                              (orig(1,2) + step2 * (j - dest(1,2)) - 1) * &
                              input_h + &
                              (orig(1,3) + step3 * (k - dest(1,3)) - 1) * &
                              in_plane + (m - 1) * in_vol
                         output(idx_in, s) = output(idx_in, s) + &
                              upstream_grad(idx_out, s)
                      end do
                   end do
                end do
             end do
          end do
       case(5) ! replication
          ! Every padded element copies the single boundary element at
          ! orig(1,:) in the two padded dimensions, so the gradient over
          ! those two dimensions is summed into one input element. Along
          ! idim the input index is taken as (x - dest(1,idim) + 1), i.e.
          ! the edge is assumed to cover that dimension from input index 1.
          select case(idim)
          case(1) ! Edge along dimension 1
             do s = 1, input_shape(5)
                do m = 1, input_shape(4)
                   do i = dest(1,1), dest(2,1)
                      idx_in = i - dest(1,1) + 1 + &
                           (orig(1,2) - 1) * input_h + &
                           (orig(1,3) - 1) * in_plane + (m - 1) * in_vol
                      grad_sum = 0._real32
                      do k = dest(1,3), dest(2,3)
                         do j = dest(1,2), dest(2,2)
                            idx_out = i + (j - 1) * output_h + &
                                 (k - 1) * out_plane + (m - 1) * out_vol
                            grad_sum = grad_sum + upstream_grad(idx_out, s)
                         end do
                      end do
                      output(idx_in, s) = output(idx_in, s) + grad_sum
                   end do
                end do
             end do
          case(2) ! Edge along dimension 2
             do s = 1, input_shape(5)
                do m = 1, input_shape(4)
                   do j = dest(1,2), dest(2,2)
                      idx_in = orig(1,1) + &
                           (j - dest(1,2)) * input_h + &
                           (orig(1,3) - 1) * in_plane + (m - 1) * in_vol
                      grad_sum = 0._real32
                      do k = dest(1,3), dest(2,3)
                         do i = dest(1,1), dest(2,1)
                            idx_out = i + (j - 1) * output_h + &
                                 (k - 1) * out_plane + (m - 1) * out_vol
                            grad_sum = grad_sum + upstream_grad(idx_out, s)
                         end do
                      end do
                      output(idx_in, s) = output(idx_in, s) + grad_sum
                   end do
                end do
             end do
          case(3) ! Edge along dimension 3
             do s = 1, input_shape(5)
                do m = 1, input_shape(4)
                   do k = dest(1,3), dest(2,3)
                      idx_in = orig(1,1) + &
                           (orig(1,2) - 1) * input_h + &
                           (k - dest(1,3)) * in_plane + (m - 1) * in_vol
                      grad_sum = 0._real32
                      do j = dest(1,2), dest(2,2)
                         do i = dest(1,1), dest(2,1)
                            idx_out = i + (j - 1) * output_h + &
                                 (k - 1) * out_plane + (m - 1) * out_vol
                            grad_sum = grad_sum + upstream_grad(idx_out, s)
                         end do
                      end do
                      output(idx_in, s) = output(idx_in, s) + grad_sum
                   end do
                end do
             end do
          end select
       end select
    end do

  end subroutine accumulate_edge_gradients_3d_val