emit_duvenaud_readout_impl Subroutine

private subroutine emit_duvenaud_readout_impl(prefix, layer, nodes, num_nodes, max_nodes, inits, num_inits, max_inits, readout_output)

Emit ONNX nodes for the Duvenaud readout: per-time-step readout terms are summed and a leading batch dimension is added to the result.

Arguments

Type | Intent | Optional Attributes | Name
character(len=*), intent(in) :: prefix
class(duvenaud_msgpass_layer_type), intent(in) :: layer
type(onnx_node_type), intent(inout), dimension(:) :: nodes
integer, intent(inout) :: num_nodes
integer, intent(in) :: max_nodes
type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
integer, intent(inout) :: num_inits
integer, intent(in) :: max_inits
character(len=128), intent(out) :: readout_output

Source Code

  subroutine emit_duvenaud_readout_impl( &
       prefix, layer, &
       nodes, num_nodes, max_nodes, &
       inits, num_inits, max_inits, &
       readout_output &
  )
    !! Emit ONNX nodes for the Duvenaud readout.
    !!
    !! For each time step t = 1..layer%num_time_steps, one readout step is
    !! emitted via emit_duvenaud_readout_step. Each step uses:
    !!   - the step's vertex feature count, layer%num_vertex_features(t);
    !!   - the first column of layer%params(t + layer%num_time_steps).
    !! The per-step output tensors are summed with a chain of ONNX 'Add'
    !! nodes named "<prefix>_ro_t<t>_accum". The sum is then passed through
    !! an 'Unsqueeze' node whose axes input is a constant [0], which adds a
    !! leading batch dimension.
    !!
    !! On return, readout_output holds the name of the final tensor,
    !! "<prefix>_readout".
    !!
    !! max_nodes and max_inits are accepted for interface compatibility but
    !! are not consulted here. NOTE(review): capacity of nodes/inits is
    !! presumably enforced (or assumed sufficient) elsewhere; confirm.
    !!
    !! Stops with an error if layer%num_time_steps < 1. In that case there
    !! is no readout term to feed the final Unsqueeze.
    use athena__onnx_utils, only: emit_node, emit_constant_int64
    implicit none
    character(*), intent(in) :: prefix
    class(duvenaud_msgpass_layer_type), intent(in) :: layer
    type(onnx_node_type), intent(inout), dimension(:) :: nodes
    integer, intent(inout) :: num_nodes
    integer, intent(in) :: max_nodes
    type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
    integer, intent(inout) :: num_inits
    integer, intent(in) :: max_inits
    character(128), intent(out) :: readout_output

    ! Local variables
    integer :: t
    character(128) :: tmp1, prev_accum, step_sum

    ! Guard against an empty time-step loop: otherwise prev_accum would never
    ! be assigned, and the Unsqueeze below would reference an undefined
    ! tensor name.
    if(layer%num_time_steps .lt. 1)then
       error stop 'emit_duvenaud_readout_impl: num_time_steps must be >= 1'
    end if

    do t = 1, layer%num_time_steps
       ! step_sum receives the name of this step's readout output tensor
       ! (as returned by emit_duvenaud_readout_step).
       call emit_duvenaud_readout_step( &
            prefix, layer%activation%name, t, &
            layer%num_vertex_features(t), layer%num_outputs, &
            layer%params(t + layer%num_time_steps)%val(:,1), &
            nodes, num_nodes, inits, num_inits, step_sum)

       ! Accumulate across timesteps: the first step seeds the running sum;
       ! later steps add onto it with an 'Add' node.
       if(t .eq. 1)then
          prev_accum = trim(step_sum)
       else
          write(tmp1, '(A,"_ro_t",I0,"_accum")') trim(prefix), t
          call emit_node('Add', trim(tmp1)//'_node', &
               trim(tmp1), '', nodes, num_nodes, &
               in1=trim(prev_accum), in2=trim(step_sum))
          prev_accum = trim(tmp1)
       end if
    end do

    ! Unsqueeze to add batch dimension: [no] -> [1, no]
    ! The axes are supplied as an int64 constant tensor holding [0].
    ! NOTE(review): argument order of emit_constant_int64 is assumed to be
    ! (name, values, dims); confirm.
    write(tmp1, '(A,"_ro_ax0")') trim(prefix)
    call emit_constant_int64(trim(tmp1), [0], [1], &
         nodes, num_nodes, inits, num_inits)
    write(readout_output, '(A,"_readout")') trim(prefix)
    call emit_node('Unsqueeze', trim(prefix)//'_us_readout', &
         trim(readout_output), '', nodes, num_nodes, &
         in1=trim(prev_accum), in2=trim(tmp1))

  end subroutine emit_duvenaud_readout_impl