submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule
  !! Submodule containing implementations for extended diffstruc array operations
  use coreutils, only: stop_program
  use diffstruc, only: &
       operator(+), operator(-), operator(*), concat, exp, sum, merge

contains

!###############################################################################
  module function add_array_ptr(a, idx1, idx2) result(c)
    !! Add together the (idx1, idx2) autodiff array of every entry of a.
    !!
    !! The result is a left-to-right chain of diffstruc additions,
    !! a(1) + a(2) + ... + a(size(a)). At least two entries are required;
    !! otherwise stop_program is called and c is returned disassociated.
    implicit none

    ! Arguments
    type(array_ptr_type), dimension(:), intent(in) :: a
    integer, intent(in) :: idx1, idx2
    type(array_type), pointer :: c

    ! Local variables
    integer :: i

    ! The first addition reads a(2); guard against out-of-bounds access
    if(size(a) .lt. 2)then
       call stop_program("add_array_ptr: at least two arrays are required")
       nullify(c)
       return
    end if

    c => a(1)%array(idx1, idx2) + a(2)%array(idx1, idx2)
    do i = 3, size(a), 1
       c => c + a(i)%array(idx1, idx2)
    end do

  end function add_array_ptr
!###############################################################################


!###############################################################################
  module function concat_array_ptr(a, idx1, idx2, dim) result(c)
    !! Concatenate the (idx1, idx2) autodiff array of every entry of a
    !! along dimension dim.
    !!
    !! Entries are concatenated in order a(1), a(2), ..., a(size(a)).
    !! At least two entries are required; otherwise stop_program is called
    !! and c is returned disassociated.
    implicit none

    ! Arguments
    type(array_ptr_type), dimension(:), intent(in) :: a
    integer, intent(in) :: idx1, idx2, dim
    type(array_type), pointer :: c

    ! Local variables
    integer :: i

    ! The first concatenation reads a(2); guard against out-of-bounds access
    if(size(a) .lt. 2)then
       call stop_program("concat_array_ptr: at least two arrays are required")
       nullify(c)
       return
    end if

    c => concat(a(1)%array(idx1, idx2), a(2)%array(idx1, idx2), dim)
    do i = 3, size(a), 1
       c => concat(c, a(i)%array(idx1, idx2), dim)
    end do

  end function concat_array_ptr
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module function add_bias(input, bias, dim, dim_act_on_shape) result(output)
    !! Add a 1D bias to input, broadcasting it along dimension dim.
    !!
    !! When dim_act_on_shape is .true., dim indexes input%shape. bias%val(j,1)
    !! is added to every element whose coordinate along that dimension is j,
    !! for every column of input%val. bias must be 1D, with
    !! bias%shape(1) == input%shape(dim).
    !! dim_act_on_shape = .false. (the default) is not implemented.
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: input
    class(array_type), intent(in), target :: bias
    integer, intent(in) :: dim
    logical, intent(in), optional :: dim_act_on_shape
    type(array_type), pointer :: output

    ! Local variables
    integer :: i, j, k, s, idx, itmp1
    integer :: num_elements_pre, num_elements_post, num_dims
    logical :: dim_act_on_shape_

    if(present(dim_act_on_shape))then
       dim_act_on_shape_ = dim_act_on_shape
    else
       dim_act_on_shape_ = .false.
    end if

    output => input%create_result()
    allocate(output%indices(2))
    ! indices(1): broadcast dimension (read back by get_partial_add_bias_val)
    ! indices(2): 1 when dim acts on input%shape, 0 otherwise
    output%indices(1) = dim

    if(dim_act_on_shape_)then
       num_dims = size(input%shape)
       if(dim .lt. 1)then
          call stop_program("Dimension for add_bias must be positive")
          return
       elseif(dim .gt. num_dims)then
          call stop_program("Dimension for add_bias exceeds input dimensions")
          return
       elseif(size(bias%shape) .ne. 1)then
          call stop_program("Bias must be a 1D array")
          return
       elseif(bias%shape(1) .ne. input%shape(dim))then
          ! A mismatched bias length would leave output elements unset or
          ! write past the end of the flattened feature axis
          call stop_program("Bias length does not match input shape along dim")
          return
       end if

       ! Column-major flattening: the dimensions before dim form contiguous
       ! runs of num_elements_pre values; the dimensions after dim repeat
       ! that pattern num_elements_post times, each block spanning itmp1
       num_elements_pre = 1
       num_elements_post = 1
       do i = 1, num_dims
          if(i .lt. dim)then
             num_elements_pre = num_elements_pre * input%shape(i)
          elseif(i .gt. dim)then
             num_elements_post = num_elements_post * input%shape(i)
          end if
       end do
       itmp1 = num_elements_pre * input%shape(dim)

       do s = 1, size(input%val, 2)
          do k = 1, num_elements_post
             do j = 1, bias%shape(1)
                idx = (j - 1) * num_elements_pre + (k - 1) * itmp1
                output%val(idx + 1 : idx + num_elements_pre, s) = &
                     input%val(idx + 1 : idx + num_elements_pre, s) + &
                     bias%val(j,1)
             end do
          end do
       end do
       output%indices(2) = 1
    else
       call stop_program("add_bias: dim_act_on_shape=.false. not implemented yet")
       output%indices(2) = 0
    end if

    output%get_partial_left => get_partial_add
    output%get_partial_right => get_partial_add_bias
    output%get_partial_left_val => get_partial_add_val
    output%get_partial_right_val => get_partial_add_bias_val
    if(input%requires_grad .or. bias%requires_grad)then
       output%requires_grad = .true.
       output%is_forward = input%is_forward .or. bias%is_forward
       output%operation = 'add_bias'
       output%left_operand => input
       output%right_operand => bias
    end if

  end function add_bias
!-------------------------------------------------------------------------------
  function get_partial_add(this, upstream_grad) result(output)
    !! Get partial derivative with respect to left operand.
    !!
    !! Addition passes the upstream gradient through unchanged.
    implicit none
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output

    output = upstream_grad

  end function get_partial_add
!-------------------------------------------------------------------------------
  pure subroutine get_partial_add_val(this, upstream_grad, output)
    !! Get partial derivative with respect to left operand (in-place version).
    !!
    !! Copies upstream_grad into output, summing over any axis on which
    !! output is smaller (i.e. the operand was broadcast in the forward pass).
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    if(size(upstream_grad,2).ne.size(output,2))then
       ! Column counts differ: reduce over the column axis
       if(size(output,1).eq.1)then
          output(1,1) = sum(upstream_grad)
       else
          output(:,1) = sum(upstream_grad, dim=2)
       end if
    else
       if(size(output,1).eq.1.and.size(output,1).ne.size(upstream_grad,1))then
          ! Single-row output against multi-row gradient: reduce over rows
          output(1,:) = sum(upstream_grad,1)
       else
          output = upstream_grad
       end if
    end if

  end subroutine get_partial_add_val
!-------------------------------------------------------------------------------
  function get_partial_add_bias(this, upstream_grad) result(output)
    !! Get partial derivative with respect to bias operand.
    !!
    !! Allocates an output shaped like the bias (plus a trailing length-1
    !! axis) and fills it via the in-place version.
    implicit none
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output

    call output%allocate(array_shape = [ this%right_operand%shape, 1 ])
    call this%get_partial_right_val(upstream_grad%val, output%val)

  end function get_partial_add_bias
!-------------------------------------------------------------------------------
  pure subroutine get_partial_add_bias_val(this, upstream_grad, output)
    !! Get partial derivative with respect to bias operand (in-place version).
    !!
    !! Accumulates into output(j,1) every upstream_grad element that
    !! received bias%val(j,1) in add_bias, summed over all columns.
    implicit none

    ! Arguments
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    integer :: i, j, k, s, idx, itmp1
    integer :: num_elements_pre, num_elements_post, num_dims

    ! Recompute the same flattened layout used by add_bias
    num_dims = size(this%left_operand%shape)
    num_elements_pre = 1
    num_elements_post = 1
    do i = 1, num_dims
       if(i .lt. this%indices(1))then
          num_elements_pre = num_elements_pre * this%left_operand%shape(i)
       elseif(i .gt. this%indices(1))then
          num_elements_post = num_elements_post * this%left_operand%shape(i)
       end if
    end do
    itmp1 = num_elements_pre * this%left_operand%shape(this%indices(1))

    output = 0._real32
    do s = 1, size(upstream_grad, 2)
       do k = 1, num_elements_post
          do j = 1, this%right_operand%shape(1)
             idx = (j - 1) * num_elements_pre + (k - 1) * itmp1
             do i = 1, num_elements_pre
                output(j,1) = output(j,1) + upstream_grad(idx + i, s)
             end do
          end do
       end do
    end do

  end subroutine get_partial_add_bias_val
!###############################################################################


!###############################################################################
  module function piecewise_array(input, gradient, limit) result(output)
    !! Apply piecewise activation function to input array.
    !!
    !! f(x) = gradient * (x - limit) + limit   for x >= limit
    !!        gradient * (x + limit) - limit   for x <= -limit
    !!        x                                otherwise
    !! gradient and limit are stored in a non-differentiable 2x1 right
    !! operand: val(1,1) = gradient, val(2,1) = limit.
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: input
    real(real32), intent(in) :: gradient
    real(real32), intent(in) :: limit
    type(array_type), pointer :: output

    type(array_type), pointer :: b_array

    output => input%create_result()
    where(input%val.ge.limit)
       output%val = gradient * (input%val - limit) + limit
    elsewhere(input%val.le.-limit)
       output%val = gradient * (input%val + limit) - limit
    elsewhere
       output%val = input%val
    end where

    output%get_partial_left => get_partial_piecewise
    output%get_partial_left_val => get_partial_piecewise_val
    if(input%requires_grad)then
       output%requires_grad = .true.
       output%is_forward = input%is_forward
       output%operation = 'piecewise'
       output%left_operand => input
       output%owns_left_operand = input%is_temporary
    end if

    ! Store the activation parameters for use in the backward pass
    allocate(b_array)
    b_array%is_sample_dependent = .false.
    b_array%requires_grad = .false.
    call b_array%allocate(array_shape=[2, 1])
    b_array%val(1,1) = gradient
    b_array%val(2,1) = limit
    output%right_operand => b_array
    output%owns_right_operand = .true.

  end function piecewise_array
!-------------------------------------------------------------------------------
  function get_partial_piecewise(this, upstream_grad) result(output)
    !! Get partial derivative of piecewise activation.
    !!
    !! The slope is 1 on the open interval (-limit, limit) and `gradient`
    !! elsewhere, matching the branches of piecewise_array.
    implicit none
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output

    type(array_type), pointer :: ptr

    ! Mask selects the linear interior, where the upstream gradient passes
    ! through unchanged (merge picks the first argument where mask is true)
    ptr => merge( &
         upstream_grad, &
         upstream_grad * this%right_operand%val(1,1), &
         this%left_operand%val.gt.-this%right_operand%val(2,1) .and. &
         this%left_operand%val.lt.this%right_operand%val(2,1) &
    )
    call output%assign_and_deallocate_source(ptr)

  end function get_partial_piecewise
!-------------------------------------------------------------------------------
  pure subroutine get_partial_piecewise_val(this, upstream_grad, output)
    !! Get partial derivative of piecewise activation (in-place version).
    !!
    !! The slope is 1 on the open interval (-limit, limit) and `gradient`
    !! elsewhere, matching the branches of piecewise_array.
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    associate( &
         x => this%left_operand%val, &
         slope => this%right_operand%val(1,1), &
         lim => this%right_operand%val(2,1) &
    )
      ! Interior is strict: piecewise_array uses the `gradient` branches
      ! for x >= limit and x <= -limit
      where(x .gt. -lim .and. x .lt. lim)
         output = upstream_grad
      elsewhere
         output = upstream_grad * slope
      end where
    end associate

  end subroutine get_partial_piecewise_val
!###############################################################################


!###############################################################################
  module function softmax_array(input, dim) result(output)
    !! Numerically stable softmax of input%val.
    !!
    !! dim = 1: each row input%val(i,:) is normalised independently.
    !! dim = 2: each column input%val(:,i) is normalised independently.
    !! The row/column maximum is subtracted before exponentiation.
    implicit none
    class(array_type), intent(in), target :: input
    integer, intent(in) :: dim
    type(array_type), pointer :: output

    integer :: i

    output => input%create_result()
    if(dim.eq.1)then
       do i = 1, size(input%val, 1)
          output%val(i, :) = exp(input%val(i, :) - maxval(input%val(i,:)))
          output%val(i, :) = output%val(i, :) / sum(output%val(i, :))
       end do
    elseif(dim.eq.2)then
       do i = 1, size(input%val, 2)
          output%val(:, i) = exp(input%val(:, i) - maxval(input%val(:, i)))
          output%val(:, i) = output%val(:, i) / sum(output%val(:, i))
       end do
    else
       call stop_program("softmax_array: Unsupported dimension")
    end if
    allocate(output%indices(1))
    output%indices(1) = dim

    output%get_partial_left => get_partial_softmax
    output%get_partial_left_val => get_partial_softmax_val
    output%get_partial_left_val_sum => get_partial_softmax_val_sum
    if(input%requires_grad)then
       output%requires_grad = .true.
       output%is_forward = input%is_forward
       output%operation = 'softmax'
       output%left_operand => input
       output%owns_left_operand = input%is_temporary
    end if

  end function softmax_array
!-------------------------------------------------------------------------------
  function get_partial_softmax(this, upstream_grad) result(output)
    !! Get partial derivative of softmax activation.
    !!
    !! Builds y * (g - sum(y * g)) via softmax_reverse_array, so the result
    !! stays differentiable. The reduction runs along the same axis that
    !! the forward softmax normalised over.
    implicit none
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output

    type(array_type), pointer :: ptr

    ptr => softmax_reverse_array(this, upstream_grad, this%indices(1))
    call output%assign_and_deallocate_source(ptr)

  end function get_partial_softmax
!-------------------------------------------------------------------------------
  pure subroutine get_partial_softmax_val(this, upstream_grad, output)
    !! Get partial derivative of softmax activation (in-place version).
    !!
    !! output = y * g - y * sum(y * g), with the sum taken over each row
    !! (indices(1) = 1) or each column (indices(1) = 2).
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    integer :: s, dim

    ! Reduction axis is the one opposite the softmax "dim" index
    if(this%indices(1).eq.1)then
       dim = 2
    else
       dim = 1
    end if

    output = this%val * upstream_grad
    if(dim.eq.1)then
       do s = 1, size(this%val, 2)
          output(:, s) = output(:, s) - this%val(:, s) * sum(output(:, s))
       end do
    elseif(dim.eq.2)then
       do s = 1, size(this%val, 1)
          output(s, :) = output(s, :) - this%val(s, :) * sum(output(s, :))
       end do
    end if

  end subroutine get_partial_softmax_val
!-------------------------------------------------------------------------------
  pure subroutine get_partial_softmax_val_sum(this, upstream_grad, output)
    !! Get partial derivative of softmax activation (in-place version,
    !! summed over the non-softmax axis).
    !!
    !! output(i) = sum over slices s of y(s,i) * (g(s,i) - <g(s,:), y(s,:)>)
    !! for indices(1) = 1, and the transposed form for indices(1) = 2.
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:), intent(out) :: output

    integer :: s, i, nfeat, nsamp
    real(real32) :: dot

    output = 0.0_real32
    if(this%indices(1) .eq. 1)then
       nsamp = size(this%val,1)
       nfeat = size(this%val,2)
       do s = 1, nsamp
          ! compute g·y for this row
          dot = 0.0_real32
          do i = 1, nfeat
             dot = dot + upstream_grad(s,i) * this%val(s,i)
          end do
          ! accumulate reduced gradient
          do concurrent( i = 1 : nfeat )
             output(i) = output(i) + this%val(s,i) * (upstream_grad(s,i) - dot)
          end do
       end do
    else
       nsamp = size(this%val,2)
       nfeat = size(this%val,1)
       do s = 1, nsamp
          dot = 0.0_real32
          do i = 1, nfeat
             dot = dot + upstream_grad(i,s) * this%val(i,s)
          end do
          do concurrent( i = 1 : nfeat )
             output(i) = output(i) + this%val(i,s) * (upstream_grad(i,s) - dot)
          end do
       end do
    end if

  end subroutine get_partial_softmax_val_sum
!###############################################################################


!###############################################################################
  module function swish_array(input, beta) result(output)
    !! Swish activation function: f(x) = x * sigmoid(beta * x).
    !!
    !! beta is stored in a non-differentiable scalar right operand
    !! (val(1,1)) for use in the backward pass.
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: input
    real(real32), intent(in) :: beta
    type(array_type), pointer :: output

    type(array_type), pointer :: b_array

    output => input%create_result()
    output%val = input%val * (1._real32 / (1._real32 + exp(-beta * input%val)))

    output%get_partial_left => get_partial_swish
    output%get_partial_left_val => get_partial_swish_val
    if(input%requires_grad)then
       output%requires_grad = .true.
       output%is_forward = input%is_forward
       output%operation = 'swish'
       output%left_operand => input
       output%owns_left_operand = input%is_temporary
    end if

    allocate(b_array)
    b_array%is_sample_dependent = .false.
    b_array%is_scalar = .true.
    b_array%requires_grad = .false.
    call b_array%allocate(array_shape=[1, 1])
    b_array%val(1,1) = beta
    output%right_operand => b_array
    output%owns_right_operand = .true.

  end function swish_array
!-------------------------------------------------------------------------------
  function get_partial_swish(this, upstream_grad) result(output)
    !! Get partial derivative of swish activation.
    !!
    !! With e = exp(beta * x):  f'(x) = e * (beta * x + e + 1) / (e + 1)**2.
    !! Built from diffstruc operations so the result stays differentiable.
    implicit none
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output

    type(array_type), pointer :: ptr
    type(array_type), pointer :: exp_term

    exp_term => exp(this%right_operand%val(1,1) * this%left_operand)
    ptr => upstream_grad * exp_term * ( &
         this%right_operand%val(1,1) * this%left_operand + &
         exp_term + 1._real32 &
    ) / ( ( exp_term + 1._real32 )**2._real32 )
    call output%assign_and_deallocate_source(ptr)

  end function get_partial_swish
!-------------------------------------------------------------------------------
  pure subroutine get_partial_swish_val(this, upstream_grad, output)
    !! Get partial derivative of swish activation (in-place version).
    !!
    !! With s = sigmoid(beta * x):  f'(x) = s * (1 + beta * x * (1 - s)).
    !! s is computed as 1 / (1 + exp(-beta * x)), the same form used by
    !! swish_array. This gives s -> 1 or s -> 0 at the extremes instead of
    !! the inf/inf = NaN that exp(beta * x) produces for large beta * x.
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    real(real32) :: beta
    real(real32), dimension(size(this%val,1), size(this%val,2)) :: sig

    beta = this%right_operand%val(1,1)
    sig = 1._real32 / (1._real32 + exp(-beta * this%left_operand%val))
    output = upstream_grad * sig * &
         (1._real32 + beta * this%left_operand%val * (1._real32 - sig))

  end subroutine get_partial_swish_val
!###############################################################################


!###############################################################################
  function softmax_reverse_array(softmax, gradient, dim) result(output)
    !! Softmax backward pass expressed as a differentiable operation.
    !!
    !! output = g * y - y * sum(g * y), where y = softmax%val and
    !! g = gradient%val. The sum runs over each row (dim = 1) or each
    !! column (dim = 2), matching softmax_array's normalisation axis.
    implicit none
    class(array_type), intent(in), target :: softmax
    class(array_type), intent(in), target :: gradient
    integer, intent(in) :: dim
    type(array_type), pointer :: output

    integer :: i
    real(real32), dimension(size(softmax%val,1), size(softmax%val,2)) :: temp_val

    output => softmax%create_result()
    temp_val = gradient%val * softmax%val
    if(dim.eq.1)then
       do concurrent(i=1:size(softmax%val,1))
          temp_val(i, :) = temp_val(i, :) - softmax%val(i, :) * sum(temp_val(i, :))
       end do
    elseif(dim.eq.2)then
       do concurrent(i=1:size(softmax%val,2))
          temp_val(:, i) = temp_val(:, i) - softmax%val(:, i) * sum(temp_val(:, i))
       end do
    else
       call stop_program("softmax_reverse_array: Unsupported dimension")
    end if
    output%val = temp_val
    output%indices = [dim]

    output%get_partial_left => get_partial_softmax_reverse_left
    output%get_partial_left_val => get_partial_softmax_reverse_left_val
    output%get_partial_right => get_partial_softmax_reverse_right
    output%get_partial_right_val => get_partial_softmax_reverse_right_val
    if(softmax%requires_grad .or. gradient%requires_grad)then
       output%requires_grad = .true.
       output%is_forward = softmax%is_forward .or. gradient%is_forward
       output%operation = 'softmax_reverse'
       output%left_operand => softmax
       output%right_operand => gradient
       output%owns_left_operand = softmax%is_temporary
       output%owns_right_operand = gradient%is_temporary
    end if

  end function softmax_reverse_array
!-------------------------------------------------------------------------------
  function get_partial_softmax_reverse_left(this, upstream_grad) result(output)
    !! Get partial derivative of softmax reverse operation w.r.t. y (left).
    !!
    !! For out = y*g - y*sum(y*g) and upstream u:
    !!   d/dy = u * (g - sum(y*g)) - g * sum(y*u)
    !! The sums run along the same axis as the forward op, 3 - indices(1).
    implicit none
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output

    type(array_type), pointer :: sum_yg, sum_yu
    type(array_type), pointer :: ptr
    integer :: red_dim

    ! softmax_reverse_array reduces over axis 2 for dim = 1, axis 1 for dim = 2
    red_dim = 3 - this%indices(1)
    sum_yg => sum(this%left_operand * this%right_operand, dim=red_dim)
    sum_yu => sum(this%left_operand * upstream_grad, dim=red_dim)
    ptr => upstream_grad * (this%right_operand - sum_yg) - &
         this%right_operand * sum_yu
    call output%assign_and_deallocate_source(ptr)

  end function get_partial_softmax_reverse_left
!-------------------------------------------------------------------------------
  function get_partial_softmax_reverse_right(this, upstream_grad) result(output)
    !! Get partial derivative of softmax reverse operation w.r.t. g (right).
    !!
    !! For out = y*g - y*sum(y*g) and upstream u:
    !!   d/dg = y * (u - sum(y*u))
    !! The sum runs along the same axis as the forward op, 3 - indices(1).
    implicit none
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output

    type(array_type), pointer :: ptr
    integer :: red_dim

    red_dim = 3 - this%indices(1)
    ptr => ( &
         upstream_grad - &
         sum(this%left_operand * upstream_grad, dim=red_dim) &
    ) * this%left_operand
    call output%assign_and_deallocate_source(ptr)

  end function get_partial_softmax_reverse_right
!-------------------------------------------------------------------------------
  pure subroutine get_partial_softmax_reverse_left_val(this, upstream_grad, output)
    !! Get partial derivative of softmax reverse operation w.r.t. y
    !! (in-place version).
    !!
    !!   output = u * (g - sum(y*g)) - g * sum(y*u)
    !! with one reduction per row (indices(1) = 1) or per column
    !! (indices(1) = 2), matching softmax_reverse_array.
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    integer :: red_dim, i
    ! One reduced value per row (indices(1) = 1) or per column (indices(1) = 2)
    real(real32), dimension(size(this%val,this%indices(1))) :: sum_yg
    real(real32), dimension(size(this%val,this%indices(1))) :: sum_yu

    red_dim = 3 - this%indices(1)
    sum_yg = sum(this%left_operand%val * this%right_operand%val, dim=red_dim)
    sum_yu = sum(this%left_operand%val * upstream_grad, dim=red_dim)

    if(this%indices(1).eq.1)then
       do concurrent(i=1:size(this%val,1))
          output(i, :) = &
               upstream_grad(i, :) * (this%right_operand%val(i, :) - sum_yg(i)) - &
               this%right_operand%val(i, :) * sum_yu(i)
       end do
    else
       do concurrent(i=1:size(this%val,2))
          output(:, i) = &
               upstream_grad(:, i) * (this%right_operand%val(:, i) - sum_yg(i)) - &
               this%right_operand%val(:, i) * sum_yu(i)
       end do
    end if

  end subroutine get_partial_softmax_reverse_left_val
!-------------------------------------------------------------------------------
  pure subroutine get_partial_softmax_reverse_right_val(this, upstream_grad, output)
    !! Get partial derivative of softmax reverse operation w.r.t. g
    !! (in-place version).
    !!
    !!   output = y * (u - sum(y*u))
    !! with one reduction per row (indices(1) = 1) or per column
    !! (indices(1) = 2), matching softmax_reverse_array.
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    integer :: red_dim, i
    ! One reduced value per row (indices(1) = 1) or per column (indices(1) = 2)
    real(real32), dimension(size(this%val,this%indices(1))) :: sum_yu

    red_dim = 3 - this%indices(1)
    sum_yu = sum(this%left_operand%val * upstream_grad, dim=red_dim)

    if(this%indices(1).eq.1)then
       do concurrent(i=1:size(this%val,1))
          output(i, :) = this%left_operand%val(i, :) * &
               (upstream_grad(i, :) - sum_yu(i))
       end do
    else
       do concurrent(i=1:size(this%val,2))
          output(:, i) = this%left_operand%val(:, i) * &
               (upstream_grad(:, i) - sum_yu(i))
       end do
    end if

  end subroutine get_partial_softmax_reverse_right_val
!###############################################################################

end submodule athena__diffstruc_extd_submodule