Emit ONNX nodes for Duvenaud readout
| Type | Intent | Optional | Attributes | Name |
|---|---|---|---|---|
| character(len=*) | intent(in) | | | prefix |
| class(duvenaud_msgpass_layer_type) | intent(in) | | | layer |
| type(onnx_node_type) | intent(inout) | | dimension(:) | nodes |
| integer | intent(inout) | | | num_nodes |
| integer | intent(in) | | | max_nodes |
| type(onnx_initialiser_type) | intent(inout) | | dimension(:) | inits |
| integer | intent(inout) | | | num_inits |
| integer | intent(in) | | | max_inits |
| character(len=128) | intent(out) | | | readout_output |
subroutine emit_duvenaud_readout_impl( &
     prefix, layer, &
     nodes, num_nodes, max_nodes, &
     inits, num_inits, max_inits, &
     readout_output &
)
  !! Emit ONNX nodes for the Duvenaud readout.
  !!
  !! For each time step t = 1, ..., layer%num_time_steps this routine calls
  !! emit_duvenaud_readout_step, which returns (in step_sum) the name of the
  !! tensor holding that step's readout contribution. The per-step tensors are
  !! summed with a chain of ONNX Add nodes, and the running sum is finally
  !! passed through an Unsqueeze node whose axes constant is [0]. The name of
  !! the Unsqueeze output, "<prefix>_readout", is returned in readout_output.
  !!
  !! The layer must have at least one time step; otherwise there is no tensor
  !! to unsqueeze and the routine stops with an error.
  use athena__onnx_utils, only: emit_node, emit_constant_int64
  implicit none

  ! Arguments
  character(*), intent(in) :: prefix
  !! Prefix prepended to every node and tensor name emitted here
  class(duvenaud_msgpass_layer_type), intent(in) :: layer
  !! Layer being exported
  type(onnx_node_type), intent(inout), dimension(:) :: nodes
  !! Node list; passed through to the emit_* helpers
  integer, intent(inout) :: num_nodes
  !! Node count; passed through to the emit_* helpers
  integer, intent(in) :: max_nodes
  !! Not referenced by this routine
  type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
  !! Initialiser list; passed through to the emit_* helpers
  integer, intent(inout) :: num_inits
  !! Initialiser count; passed through to the emit_* helpers
  integer, intent(in) :: max_inits
  !! Not referenced by this routine
  character(128), intent(out) :: readout_output
  !! Name of the final readout tensor, "<prefix>_readout"

  ! Local variables
  integer :: t
  !! Time-step index
  character(128) :: step_sum
  !! Output tensor name returned by the readout step for time step t
  character(128) :: prev_accum
  !! Name of the tensor holding the running sum over steps 1..t
  character(128) :: accum_name
  !! Name of the Add output created for step t (t > 1)
  character(128) :: axes_name
  !! Name of the int64 constant holding the Unsqueeze axes

  ! Without at least one step prev_accum would never be assigned and the
  ! Unsqueeze below would reference an undefined tensor name.
  if(layer%num_time_steps .lt. 1)then
     error stop "emit_duvenaud_readout_impl: layer has no time steps"
  end if

  do t = 1, layer%num_time_steps
     ! NOTE(review): the readout weights for step t appear to be stored at
     ! params(num_time_steps + t), and only column 1 of val is passed;
     ! confirm against the layer's parameter layout.
     call emit_duvenaud_readout_step( &
          prefix, layer%activation%name, t, &
          layer%num_vertex_features(t), layer%num_outputs, &
          layer%params(t + layer%num_time_steps)%val(:,1), &
          nodes, num_nodes, inits, num_inits, step_sum)

     ! Accumulate across time steps: step 1 seeds the sum, later steps
     ! are chained on with one Add node each.
     if(t .eq. 1)then
        prev_accum = trim(step_sum)
     else
        write(accum_name, '(A,"_ro_t",I0,"_accum")') trim(prefix), t
        call emit_node('Add', trim(accum_name)//'_node', &
             trim(accum_name), '', nodes, num_nodes, &
             in1=trim(prev_accum), in2=trim(step_sum))
        prev_accum = trim(accum_name)
     end if
  end do

  ! Unsqueeze to add a leading batch dimension: [no] -> [1, no]
  write(axes_name, '(A,"_ro_ax0")') trim(prefix)
  call emit_constant_int64(trim(axes_name), [0], [1], &
       nodes, num_nodes, inits, num_inits)
  write(readout_output, '(A,"_readout")') trim(prefix)
  call emit_node('Unsqueeze', trim(prefix)//'_us_readout', &
       trim(readout_output), '', nodes, num_nodes, &
       in1=trim(prev_accum), in2=trim(axes_name))
end subroutine emit_duvenaud_readout_impl