softmax_reverse_array Function

function softmax_reverse_array(softmax, gradient, dim) result(output)

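Backward pass of the softmax function for reverse-mode automatic differentiation: given the softmax output and the upstream gradient, returns the gradient with respect to the softmax input.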

Arguments

Type | Intent | Optional Attributes | Name
class(array_type), intent(in), target :: softmax
class(array_type), intent(in), target :: gradient
integer, intent(in) :: dim

Return Value type(array_type), pointer


Source Code

  function softmax_reverse_array(softmax, gradient, dim) result(output)
    !! Reverse-mode (backward) pass of the softmax function.
    !!
    !! Given the forward softmax output s and the upstream gradient g, this
    !! returns the vector-Jacobian product
    !!
    !!     s * g - s * sum(s * g)
    !!
    !! where the sum is a reduction over one dimension of the rank-2 values:
    !!  - dim == 1: for each index i of dimension 1, sum over val(i, :)
    !!  - dim == 2: for each index j of dimension 2, sum over val(:, j)
    !! NOTE(review): `dim` names the dimension that is looped over; the
    !! reduction runs along the *other* dimension. This must match the
    !! convention of the forward softmax - confirm against that routine.
    !!
    !! Any other value of `dim` is reported through stop_program.
    !!
    !! When either operand requires a gradient, the result is recorded in the
    !! autodiff graph as operation 'softmax_reverse', with `softmax` as the
    !! left operand and `gradient` as the right operand.
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: softmax
    !! Output of the forward softmax
    class(array_type), intent(in), target :: gradient
    !! Upstream gradient; must conform with softmax%val (element-wise product)
    integer, intent(in) :: dim
    !! Dimension selector (1 or 2); see the description above
    type(array_type), pointer :: output
    !! Gradient with respect to the softmax input

    ! Local variables
    integer :: i
    !! Loop index over the non-reduced dimension
    real(real32), dimension(:,:), allocatable :: temp_val
    !! Work array. Allocatable (heap) rather than automatic (stack) so that
    !! large inputs cannot overflow the stack.


    output => softmax%create_result()

    ! Element-wise s * g; reallocation on assignment sizes temp_val.
    temp_val = gradient%val * softmax%val

    ! Subtract s * sum(s * g). The RHS (including the sum) is fully evaluated
    ! before the slice is overwritten, and each iteration touches a disjoint
    ! slice, so the iterations are independent.
    select case(dim)
    case(1)
       do concurrent(i = 1:size(softmax%val, 1))
          temp_val(i, :) = temp_val(i, :) - &
               softmax%val(i, :) * sum(temp_val(i, :))
       end do
    case(2)
       do concurrent(i = 1:size(softmax%val, 2))
          temp_val(:, i) = temp_val(:, i) - &
               softmax%val(:, i) * sum(temp_val(:, i))
       end do
    case default
       call stop_program("softmax_reverse_array: Unsupported dimension")
    end select
    output%val = temp_val
    output%indices = [dim]

    ! Partial-derivative callbacks for higher-order differentiation.
    output%get_partial_left => get_partial_softmax_reverse_left
    output%get_partial_left_val => get_partial_softmax_reverse_left_val
    output%get_partial_right => get_partial_softmax_reverse_right
    output%get_partial_right_val => get_partial_softmax_reverse_right_val

    ! Record the node in the autodiff graph only if a gradient is needed.
    if(softmax%requires_grad .or. gradient%requires_grad)then
       output%requires_grad = .true.
       output%is_forward = softmax%is_forward .or. gradient%is_forward
       output%operation = 'softmax_reverse'
       output%left_operand => softmax
       output%right_operand => gradient
       ! Temporary operands are owned (and later released) by this node.
       output%owns_left_operand = softmax%is_temporary
       output%owns_right_operand = gradient%is_temporary
    end if

  end function softmax_reverse_array