elem_scale Module Function

module function elem_scale(input, scale) result(c)

Element-wise scaling with explicit support for sample-independent scale.

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(in), target :: input

Input tensor [n, batch]

class(array_type), intent(in), target :: scale

Scale tensor [n, 1]

Return Value type(array_type), pointer

Scaled output tensor


Source Code

  module function elem_scale(input, scale) result(c)
    !! Element-wise scaling with explicit support for sample-independent scale.
    !!
    !! Computes c(i, s) = input(i, s) * scale(i, 1) for every feature i and
    !! sample s, i.e. the first column of scale is broadcast across all
    !! samples (columns) of input.
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: input
    !! Input tensor [n, batch]
    class(array_type), intent(in), target :: scale
    !! Scale tensor [n, 1]; only column 1 is read
    type(array_type), pointer :: c
    !! Scaled output tensor [n, batch]

    ! Local variables
    integer :: i, s, n, ns
    !! Feature/sample indices and dimensions

    n  = size(input%val, 1)
    ns = size(input%val, 2)

    ! Guard the broadcast read scale%val(i, 1) below: without this check a
    ! scale tensor with fewer than n rows (or no columns) would be read out
    ! of bounds. Degenerate inputs (n == 0 or ns == 0) never touch scale,
    ! so they are accepted unchanged.
    if(n .gt. 0 .and. ns .gt. 0)then
       if(size(scale%val, 1) .lt. n .or. size(scale%val, 2) .lt. 1)then
          error stop "elem_scale: scale tensor must have at least [n, 1] elements"
       end if
    end if

    c => input%create_result(array_shape=[n, ns])
    ! Column-major friendly: i (first index) is the fastest-varying index.
    do concurrent(s = 1:ns, i = 1:n)
       c%val(i, s) = input%val(i, s) * scale%val(i, 1)
    end do

    ! Backward callbacks: the pointer-based partials are cleared and the
    ! value-based elem_scale partials are registered for both operands.
    ! This is done unconditionally; the graph links below are only set
    ! when a gradient is actually required.
    c%get_partial_left      => null()
    c%get_partial_right     => null()
    c%get_partial_left_val  => get_partial_elem_scale_input_val
    c%get_partial_right_val => get_partial_elem_scale_scale_val
    if(input%requires_grad .or. scale%requires_grad)then
       c%requires_grad      = .true.
       c%is_forward         = input%is_forward .or. scale%is_forward
       c%operation          = 'elem_scale'
       c%left_operand       => input
       c%right_operand      => scale
       ! The result takes ownership of an operand only if that operand is
       ! flagged as a temporary.
       c%owns_left_operand  = input%is_temporary
       c%owns_right_operand = scale%is_temporary
    end if

  end function elem_scale