emit_duvenaud_readout_step Subroutine

private subroutine emit_duvenaud_readout_step(prefix, activation_name, t, nv, no, weight_data, nodes, num_nodes, inits, num_inits, step_sum)

Emit one Duvenaud readout timestep.

This emits the readout projection for the timestep, the readout softmax, and the reduction over nodes. The per-timestep contributions are added together outside this routine.

Arguments

Type | Intent | Optional Attributes | Name
character(len=*), intent(in) :: prefix
character(len=*), intent(in) :: activation_name
integer, intent(in) :: t
integer, intent(in) :: nv
integer, intent(in) :: no
real(kind=real32), intent(in) :: weight_data(:)
type(onnx_node_type), intent(inout), dimension(:) :: nodes
integer, intent(inout) :: num_nodes
type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
integer, intent(inout) :: num_inits
character(len=128), intent(out) :: step_sum

Source Code

  subroutine emit_duvenaud_readout_step( &
       prefix, activation_name, t, nv, no, weight_data, &
       nodes, num_nodes, inits, num_inits, step_sum)
    !! Emit the ONNX subgraph for a single Duvenaud readout timestep.
    !!
    !! Emitted, in this order:
    !!   1. the readout-matrix initialiser for timestep `t`,
    !!   2. Transpose (perm = [1, 0]) of the timestep's node features,
    !!   3. MatMul of the readout matrix with the transposed features,
    !!   4. Softmax along axis 0,
    !!   5. an int64 constant holding the reduction axis (1),
    !!   6. ReduceSum over that axis (keepdims = 0), i.e. over nodes.
    !!
    !! The ReduceSum output name is returned in `step_sum`. Adding the
    !! contributions of different timesteps together is not done here.
    use athena__onnx_utils, only: emit_node, emit_constant_int64
    use athena__onnx_msgpass_utils, only: get_timestep_output_name, &
         emit_weight_initialiser_2d
    implicit none

    ! Arguments
    character(len=*), intent(in) :: prefix
      !! Name prefix shared by every node/initialiser emitted here.
    character(len=*), intent(in) :: activation_name
      !! Forwarded to get_timestep_output_name to locate the input tensor.
    integer, intent(in) :: t
      !! Timestep index; embedded in all generated names.
    integer, intent(in) :: nv
      !! Second dimension passed to emit_weight_initialiser_2d.
    integer, intent(in) :: no
      !! First dimension passed to emit_weight_initialiser_2d.
    real(real32), intent(in) :: weight_data(:)
      !! Flat readout-matrix values for this timestep.
    type(onnx_node_type), intent(inout), dimension(:) :: nodes
      !! Node list that new nodes are appended to.
    integer, intent(inout) :: num_nodes
      !! Number of populated entries in `nodes`.
    type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
      !! Initialiser list that new initialisers are appended to.
    integer, intent(inout) :: num_inits
      !! Number of populated entries in `inits`.
    character(len=128), intent(out) :: step_sum
      !! Name of the ReduceSum output for this timestep.

    ! Local variables
    character(len=128) :: stem        ! "<prefix>_ro_t<t>", shared by all names
    character(len=128) :: feat_in     ! node-feature tensor consumed at step t
    character(len=128) :: r_init      ! readout-matrix initialiser name
    character(len=128) :: feat_t      ! Transpose output
    character(len=128) :: proj_out    ! MatMul output
    character(len=128) :: prob_out    ! Softmax output
    character(len=128) :: axes_const  ! ReduceSum axes constant
    character(len=*), parameter :: softmax_attr = &
         '        "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
    character(len=*), parameter :: transpose_attr = &
         '        "attribute": [{"name": "perm", "ints": ["1", "0"], "type": "INTS"}]'
    character(len=*), parameter :: reducesum_attr = &
         '        "attribute": [{"name": "keepdims", "i": "0", "type": "INT"}]'

    ! Build every tensor/initialiser name first. The internal writes are
    ! kept deliberately: they fail loudly rather than truncate silently if
    ! a name does not fit in 128 characters.
    write(stem, '(A,"_ro_t",I0)') trim(prefix), t
    call get_timestep_output_name( &
         prefix, t, activation_name, '_sq_out', '_sq', feat_in)
    write(r_init,     '(A,"_R")')   trim(stem)
    write(feat_t,     '(A,"_zt")')  trim(stem)
    write(proj_out,   '(A,"_Rz")')  trim(stem)
    write(prob_out,   '(A,"_sm")')  trim(stem)
    write(axes_const, '(A,"_ax1")') trim(stem)
    write(step_sum,   '(A,"_sum")') trim(stem)

    ! Readout matrix for this timestep, stored as an initialiser.
    call emit_weight_initialiser_2d( &
         trim(r_init), no, nv, weight_data, inits, num_inits)

    ! Projection: R x transpose(z).
    call emit_node('Transpose', trim(stem)//'_transpose_z', &
         trim(feat_t), transpose_attr, nodes, num_nodes, &
         in1=trim(feat_in))
    call emit_node('MatMul', trim(stem)//'_matmul_R', &
         trim(proj_out), '', nodes, num_nodes, &
         in1=trim(r_init), in2=trim(feat_t))

    ! Softmax followed by ReduceSum mirrors the ATHENA readout accumulation.
    call emit_node('Softmax', trim(stem)//'_softmax', &
         trim(prob_out), softmax_attr, nodes, num_nodes, &
         in1=trim(proj_out))
    ! NOTE(review): the two [1] arrays are presumably (values, shape) of the
    ! axes constant -- confirm against emit_constant_int64.
    call emit_constant_int64(trim(axes_const), [1], [1], &
         nodes, num_nodes, inits, num_inits)
    call emit_node('ReduceSum', trim(stem)//'_reducesum', &
         trim(step_sum), reducesum_attr, nodes, num_nodes, &
         in1=trim(prob_out), in2=trim(axes_const))

  end subroutine emit_duvenaud_readout_step