athena__diffstruc_extd_submodule_batchnorm Submodule

Submodule containing implementations for extended diffstruc array operations


Uses


Functions

function get_partial_batchnorm_left(this, upstream_grad) result(output)

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(inout) :: this
type(array_type), intent(in) :: upstream_grad

Return Value type(array_type)

function get_partial_batchnorm_right(this, upstream_grad) result(output)

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(inout) :: this
type(array_type), intent(in) :: upstream_grad

Return Value type(array_type)


Subroutines

pure subroutine get_partial_batchnorm_left_val(this, upstream_grad, output)

Get the partial derivative with respect to the input for batchnorm (subroutine version)

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:,:) :: output

pure subroutine get_partial_batchnorm_right_val(this, upstream_grad, output)

Get the partial derivative with respect to the parameters for batchnorm (subroutine version)

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:,:) :: output

Module Functions

module function batchnorm(input, params, momentum, mean, variance, epsilon) result(output)

Batch normalisation operation

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(in), target :: input
class(array_type), intent(in), target :: params
real(kind=real32), intent(in) :: momentum
real(kind=real32), intent(in), dimension(:) :: mean
real(kind=real32), intent(in), dimension(:) :: variance
real(kind=real32), intent(in) :: epsilon

Return Value type(batchnorm_array_type), pointer

module function batchnorm_inference(input, params, mean, variance, epsilon) result(output)

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(in), target :: input
class(array_type), intent(in), target :: params
real(kind=real32), intent(in), dimension(:) :: mean
real(kind=real32), intent(in), dimension(:) :: variance
real(kind=real32), intent(in) :: epsilon

Return Value type(batchnorm_array_type), pointer