Emit one Duvenaud readout timestep.
This expands to the timestep readout projection, the readout softmax, and the reduction over nodes before the timestep contributions are added.
| Type | Intent | Optional | Attributes | | Name | |
|---|---|---|---|---|---|---|
| character(len=*), | intent(in) | | | :: | prefix | |
| character(len=*), | intent(in) | | | :: | activation_name | |
| integer, | intent(in) | | | :: | t | |
| integer, | intent(in) | | | :: | nv | |
| integer, | intent(in) | | | :: | no | |
| real(kind=real32), | intent(in) | | | :: | weight_data(:) | |
| type(onnx_node_type), | intent(inout), | | dimension(:) | :: | nodes | |
| integer, | intent(inout) | | | :: | num_nodes | |
| type(onnx_initialiser_type), | intent(inout), | | dimension(:) | :: | inits | |
| integer, | intent(inout) | | | :: | num_inits | |
| character(len=128), | intent(out) | | | :: | step_sum | |
subroutine emit_duvenaud_readout_step( &
     prefix, activation_name, t, nv, no, weight_data, &
     nodes, num_nodes, inits, num_inits, step_sum)
  !! Emit the ONNX sub-graph for a single Duvenaud readout timestep.
  !!
  !! Four nodes are emitted, in this order: a Transpose of the timestep
  !! features, a MatMul of the readout matrix with the transposed features,
  !! a Softmax along axis 0, and a ReduceSum (keepdims = 0) whose axes come
  !! from an int64 constant. The readout matrix is registered as a 2-D
  !! weight initialiser before the nodes are emitted. Summing the
  !! per-timestep results is left to the caller, which receives the name
  !! of the ReduceSum output in `step_sum`.
  use athena__onnx_utils, only: emit_node, emit_constant_int64
  use athena__onnx_msgpass_utils, only: get_timestep_output_name, &
       emit_weight_initialiser_2d
  implicit none

  ! Arguments
  character(len=*), intent(in) :: prefix
  !! Name prefix; every tensor/node name built here starts with it.
  character(len=*), intent(in) :: activation_name
  !! Forwarded unchanged to get_timestep_output_name.
  integer, intent(in) :: t
  !! Timestep index, embedded in all generated names as "_ro_t<t>".
  integer, intent(in) :: nv
  !! Second extent passed to the readout initialiser
  !! (presumably the number of vertex features - confirm against helper).
  integer, intent(in) :: no
  !! First extent passed to the readout initialiser
  !! (presumably the number of readout outputs - confirm against helper).
  real(real32), intent(in) :: weight_data(:)
  !! Flat readout-matrix values for this timestep.
  type(onnx_node_type), intent(inout), dimension(:) :: nodes
  !! Node storage handed to the emit helpers.
  integer, intent(inout) :: num_nodes
  !! Node counter handed to the emit helpers.
  type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
  !! Initialiser storage handed to the emit helpers.
  integer, intent(inout) :: num_inits
  !! Initialiser counter handed to the emit helpers.
  character(len=128), intent(out) :: step_sum
  !! Name of the ReduceSum output tensor: "<prefix>_ro_t<t>_sum".

  ! ONNX attribute fragments attached to the emitted nodes.
  character(len=*), parameter :: attr_perm_10 = &
       ' "attribute": [{"name": "perm", "ints": ["1", "0"], ' // &
       '"type": "INTS"}]'
  character(len=*), parameter :: attr_softmax_axis0 = &
       ' "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
  character(len=*), parameter :: attr_sum_nokeep = &
       ' "attribute": [{"name": "keepdims", "i": "0", ' // &
       '"type": "INT"}]'

  ! Local tensor names
  character(len=128) :: step_tag
  !! Common stem "<prefix>_ro_t<t>" shared by all names below.
  character(len=128) :: feat_name
  !! Input feature tensor name, resolved by get_timestep_output_name.
  character(len=128) :: readout_name
  !! Readout-matrix initialiser name.
  character(len=128) :: feat_t_name
  !! Transposed feature tensor name.
  character(len=128) :: proj_name
  !! MatMul output name.
  character(len=128) :: prob_name
  !! Softmax output name.
  character(len=128) :: axes_name
  !! Name of the int64 constant supplying the ReduceSum axes.

  !---------------------------------------------------------------------
  ! Build every tensor name up front; all of them derive from step_tag.
  !---------------------------------------------------------------------
  write(step_tag, '(A,"_ro_t",I0)') trim(prefix), t
  write(readout_name, '(A,"_R")') trim(step_tag)
  write(feat_t_name, '(A,"_zt")') trim(step_tag)
  write(proj_name, '(A,"_Rz")') trim(step_tag)
  write(prob_name, '(A,"_sm")') trim(step_tag)
  write(axes_name, '(A,"_ax1")') trim(step_tag)
  write(step_sum, '(A,"_sum")') trim(step_tag)

  ! Resolve which tensor carries the node features for this timestep.
  call get_timestep_output_name( &
       prefix, t, activation_name, '_sq_out', '_sq', feat_name)

  !---------------------------------------------------------------------
  ! Readout projection: R * transpose(z)
  !---------------------------------------------------------------------
  ! The readout matrix for timestep t is stored as an ONNX initialiser.
  call emit_weight_initialiser_2d( &
       trim(readout_name), no, nv, weight_data, inits, num_inits)

  ! The features are transposed (perm = [1, 0]) so that the readout
  ! matrix can be applied from the left.
  call emit_node('Transpose', trim(step_tag)//'_transpose_z', &
       trim(feat_t_name), attr_perm_10, nodes, num_nodes, &
       in1=trim(feat_name))

  call emit_node('MatMul', trim(step_tag)//'_matmul_R', &
       trim(proj_name), '', nodes, num_nodes, &
       in1=trim(readout_name), in2=trim(feat_t_name))

  !---------------------------------------------------------------------
  ! Softmax followed by ReduceSum mirrors the ATHENA readout accumulation.
  !---------------------------------------------------------------------
  call emit_node('Softmax', trim(step_tag)//'_softmax', &
       trim(prob_name), attr_softmax_axis0, nodes, num_nodes, &
       in1=trim(proj_name))

  ! ReduceSum takes its axes as a second input tensor, so a constant is
  ! emitted for it first.
  call emit_constant_int64(trim(axes_name), [1], [1], &
       nodes, num_nodes, inits, num_inits)

  call emit_node('ReduceSum', trim(step_tag)//'_reducesum', &
       trim(step_sum), attr_sum_nokeep, nodes, num_nodes, &
       in1=trim(prob_name), in2=trim(axes_name))

end subroutine emit_duvenaud_readout_step