Accumulate face gradients for 3D padding - raw array version
| Type | Intent | Optional | Attributes | | Name |
|---|---|---|---|---|---|
| real(kind=real32), | intent(in), | | dimension(:,:) | :: | upstream_grad |
| real(kind=real32), | intent(inout), | | dimension(:,:) | :: | output |
| integer, | intent(in), | | dimension(5) | :: | input_shape |
| integer, | intent(in), | | dimension(:) | :: | indices |
| integer, | intent(in), | | dimension(:,:) | :: | adj_ja |
pure subroutine accumulate_face_gradients_3d_val(upstream_grad, output, &
     input_shape, indices, adj_ja)
  !! Accumulate face gradients for 3D padding - raw array version
  !!
  !! Scatters the gradient that arrived at the padded face regions of the
  !! (padded) output back onto the input elements those regions were filled
  !! from, adding it into `output` (the input-gradient accumulator).
  !!
  !! Argument layout as read by this routine (NOTE(review): confirm against
  !! the caller that builds `indices` / `adj_ja`):
  !!
  !! * `upstream_grad(:, s)` / `output(:, s)`: sample `s` flattened in
  !!   column-major (h, w, d, c) order; `upstream_grad` uses the padded
  !!   extents, `output` the unpadded (input) extents.
  !! * `input_shape = [h, w, d, c, n_samples]`.
  !! * `indices(1)`: padding method; 3 = circular, 4 = reflection,
  !!   5 = replication. Any other value leaves `output` unchanged.
  !! * `indices(2:4)`: presumably the pad widths along dims 1, 2, 3.
  !! * `indices(5)`: number of faces; `indices(7 + f)` is the dimension
  !!   (1..3) that face `f` is perpendicular to.
  !! * `adj_ja(1, 6*(f-1)+1 : 6*f)`: index ranges `[lo1, hi1, lo2, hi2,
  !!   lo3, hi3]` of face `f` in the input; `adj_ja(2, ...)` gives the
  !!   matching ranges in the padded output.
  implicit none

  ! Arguments
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(inout) :: output
  integer, dimension(5), intent(in) :: input_shape
  integer, dimension(:), intent(in) :: indices
  integer, dimension(:,:), intent(in) :: adj_ja

  ! Local variables
  integer :: i, j, k, m, s, f, idim, base
  integer :: step1, step2, step3, idx_in, idx_out
  integer :: input_h, input_w, input_d
  integer :: output_h, output_w, output_d
  integer :: in_plane, in_vol, out_plane, out_vol
  integer, dimension(2,3) :: orig, dest
  real(real32) :: grad_sum

  if(indices(5) .eq. 0) return

  input_h = input_shape(1)
  input_w = input_shape(2)
  input_d = input_shape(3)
  ! Each padded extent grows by its own pad width on both sides; the width
  ! uses indices(3) (not the depth pad indices(4)).
  output_h = input_h + 2 * indices(2)
  output_w = input_w + 2 * indices(3)
  output_d = input_d + 2 * indices(4)

  ! Strides of the flattened (h, w, d, c) layouts, hoisted out of the loops.
  in_plane = input_h * input_w
  in_vol = in_plane * input_d
  out_plane = output_h * output_w
  out_vol = out_plane * output_d

  select case(indices(1))
  case(3, 4) ! circular or reflection
     do f = 1, indices(5)
        idim = indices(7 + f)
        base = (f - 1) * 6
        ! Column-major reshape: orig(1:2, d) = [lo, hi] along dimension d.
        orig = reshape(adj_ja(1, base + 1 : base + 6), [2, 3])
        dest = reshape(adj_ja(2, base + 1 : base + 6), [2, 3])
        ! Reflection walks the source backwards along the face normal;
        ! circular padding walks it forwards.
        step1 = merge(-1, 1, indices(1) .eq. 4 .and. idim .eq. 1)
        step2 = merge(-1, 1, indices(1) .eq. 4 .and. idim .eq. 2)
        step3 = merge(-1, 1, indices(1) .eq. 4 .and. idim .eq. 3)
        do s = 1, input_shape(5)
           do m = 1, input_shape(4)
              do k = dest(1,3), dest(2,3)
                 do j = dest(1,2), dest(2,2)
                    do i = dest(1,1), dest(2,1)
                       idx_out = i + (j - 1) * output_h + &
                            (k - 1) * out_plane + (m - 1) * out_vol
                       idx_in = orig(1,1) + step1 * (i - dest(1,1)) + &
                            (orig(1,2) + step2 * (j - dest(1,2)) - 1) * input_h + &
                            (orig(1,3) + step3 * (k - dest(1,3)) - 1) * in_plane + &
                            (m - 1) * in_vol
                       output(idx_in, s) = output(idx_in, s) + &
                            upstream_grad(idx_out, s)
                    end do
                 end do
              end do
           end do
        end do
     end do
  case(5) ! replication
     ! Every padded cell along the face normal copies the same boundary
     ! element, so its gradients are summed along that normal first.
     do f = 1, indices(5)
        idim = indices(7 + f)
        base = (f - 1) * 6
        orig = reshape(adj_ja(1, base + 1 : base + 6), [2, 3])
        dest = reshape(adj_ja(2, base + 1 : base + 6), [2, 3])
        select case(idim)
        case(1) ! Face perpendicular to dimension 1
           do s = 1, input_shape(5)
              do m = 1, input_shape(4)
                 do k = dest(1,3), dest(2,3)
                    do j = dest(1,2), dest(2,2)
                       idx_in = orig(1,1) + &
                            (j - dest(1,2)) * input_h + &
                            (k - dest(1,3)) * in_plane + &
                            (m - 1) * in_vol
                       grad_sum = 0._real32
                       do i = dest(1,1), dest(2,1)
                          idx_out = i + (j - 1) * output_h + &
                               (k - 1) * out_plane + (m - 1) * out_vol
                          grad_sum = grad_sum + upstream_grad(idx_out, s)
                       end do
                       output(idx_in, s) = output(idx_in, s) + grad_sum
                    end do
                 end do
              end do
           end do
        case(2) ! Face perpendicular to dimension 2
           ! The source slice is j = orig(1,2); it must appear in the input
           ! offset, otherwise the far face is accumulated onto j = 1.
           do s = 1, input_shape(5)
              do m = 1, input_shape(4)
                 do k = dest(1,3), dest(2,3)
                    do i = dest(1,1), dest(2,1)
                       idx_in = i - dest(1,1) + 1 + &
                            (orig(1,2) - 1) * input_h + &
                            (k - dest(1,3)) * in_plane + &
                            (m - 1) * in_vol
                       grad_sum = 0._real32
                       do j = dest(1,2), dest(2,2)
                          idx_out = i + (j - 1) * output_h + &
                               (k - 1) * out_plane + (m - 1) * out_vol
                          grad_sum = grad_sum + upstream_grad(idx_out, s)
                       end do
                       output(idx_in, s) = output(idx_in, s) + grad_sum
                    end do
                 end do
              end do
           end do
        case(3) ! Face perpendicular to dimension 3
           ! The source slice is k = orig(1,3); include it for the same
           ! reason as in the dimension-2 case.
           do s = 1, input_shape(5)
              do m = 1, input_shape(4)
                 do j = dest(1,2), dest(2,2)
                    do i = dest(1,1), dest(2,1)
                       idx_in = i - dest(1,1) + 1 + &
                            (j - dest(1,2)) * input_h + &
                            (orig(1,3) - 1) * in_plane + &
                            (m - 1) * in_vol
                       grad_sum = 0._real32
                       do k = dest(1,3), dest(2,3)
                          idx_out = i + (j - 1) * output_h + &
                               (k - 1) * out_plane + (m - 1) * out_vol
                          grad_sum = grad_sum + upstream_grad(idx_out, s)
                       end do
                       output(idx_in, s) = output(idx_in, s) + grad_sum
                    end do
                 end do
              end do
           end do
        end select
     end do
  end select
end subroutine accumulate_face_gradients_3d_val