emit_onnx_nodes_neural_operator Subroutine

private subroutine emit_onnx_nodes_neural_operator(this, prefix, nodes, num_nodes, max_nodes, inits, num_inits, max_inits, input_name, is_last_layer, format)

Emit decomposed standard ONNX nodes for a Neural Operator layer.

Forward pass: $v = \sigma\left(W u + w_k \,\mathrm{mean}(u) + b\right)$

Type Bound

neural_operator_layer_type

Arguments

Type Intent Optional Attributes Name
class(neural_operator_layer_type), intent(in) :: this

Neural operator layer instance

character(len=*), intent(in) :: prefix

Layer name prefix

type(onnx_node_type), intent(inout), dimension(:) :: nodes

Node accumulator

integer, intent(inout) :: num_nodes

Node counter

integer, intent(in) :: max_nodes

Node limit

type(onnx_initialiser_type), intent(inout), dimension(:) :: inits

Initialiser accumulator

integer, intent(inout) :: num_inits

Initialiser counter

integer, intent(in) :: max_inits

Initialiser limit

character(len=*), intent(in), optional :: input_name

Name of the input tensor (if absent, nothing is emitted)

logical, intent(in), optional :: is_last_layer

Whether this is the last layer (if absent, nothing is emitted)

integer, intent(in), optional :: format

Export format selector (defaults to 1; nodes are emitted only when it is 2)


Source Code

  subroutine emit_onnx_nodes_neural_operator( &
       this, prefix, nodes, num_nodes, max_nodes, inits, num_inits, &
       max_inits, input_name, is_last_layer, format)
    !! Emit decomposed standard ONNX nodes for a Neural Operator layer.
    !!
    !! Forward: v = sigma(W * u + w_k * mean(u) + b)
    !!
    !! Node chain emitted (node names are "/<prefix>/<op>"):
    !!   Transpose(input)           -> x_t    (emit_nop_input_transpose)
    !!   MatMul(W, x_t)             -> local term
    !!   ReduceMean(x_t, axes=[0])  -> mean   (keepdims = 1)
    !!   Mul(w_k, mean)             -> mean-field kernel term
    !!   Add(local, kernel)
    !!   Add(.., b)                 -> only when this%use_bias
    !!   output tail / activation   (emit_nop_output_tail)
    !! followed by the initialisers W, w_k and (when use_bias) b.
    !!
    !! Nothing is emitted unless `format == 2` and both `input_name` and
    !! `is_last_layer` are present.
    !!
    !! Nodes appended directly by this routine are checked against
    !! `max_nodes`, and the initialisers against `max_inits`. Exceeding
    !! either limit stops with an error instead of writing past the
    !! caller's buffers.
    !! NOTE(review): nodes added inside the emit_nop_* helpers are not
    !! checked here.
    implicit none

    ! Arguments
    class(neural_operator_layer_type), intent(in) :: this
    !! Neural operator layer instance
    character(*), intent(in) :: prefix
    !! Layer name prefix
    type(onnx_node_type), intent(inout), dimension(:) :: nodes
    !! Node accumulator
    integer, intent(inout) :: num_nodes
    !! Node counter
    integer, intent(in) :: max_nodes
    !! Node limit
    type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
    !! Initialiser accumulator
    integer, intent(inout) :: num_inits
    !! Initialiser counter
    integer, intent(in) :: max_inits
    !! Initialiser limit
    character(*), optional, intent(in) :: input_name
    !! Name of the input tensor
    logical, optional, intent(in) :: is_last_layer
    !! Whether this is the last layer
    integer, optional, intent(in) :: format
    !! Export format selector

    ! Local variables
    integer :: num_new_inits
    !! Number of initialisers this layer adds (W, w_k and optionally b)
    character(128) :: w_name, wk_name, b_name
    character(128) :: trans_in_out, mm_w_out, reduce_out
    character(128) :: mul_out, add_out, add_b_out, final_output, &
         output_source
    character(4096) :: reduce_attr
    integer :: format_

    format_ = 1
    if(present(format)) format_ = format
    if(format_ .ne. 2) return
    if(.not.present(input_name)) return
    if(.not.present(is_last_layer)) return

    ! Fail fast, before any node is written, if the initialisers
    ! this layer needs will not fit in the caller's buffer.
    num_new_inits = 2
    if(this%use_bias) num_new_inits = 3
    if(num_inits + num_new_inits .gt. max_inits) error stop &
         'emit_onnx_nodes_neural_operator: max_inits exceeded'

    ! Initialiser (parameter) tensor names
    write(w_name, '(A,".W")') trim(prefix)
    write(wk_name, '(A,".w_k")') trim(prefix)
    write(b_name, '(A,".b")') trim(prefix)

    ! Intermediate tensor names
    write(trans_in_out, '("/",A,"/Transpose_output_0")') trim(prefix)
    write(mm_w_out, '("/",A,"/MatMul_output_0")') trim(prefix)
    write(reduce_out, '("/",A,"/ReduceMean_output_0")') trim(prefix)
    write(mul_out, '("/",A,"/Mul_output_0")') trim(prefix)
    write(add_out, '("/",A,"/Add_output_0")') trim(prefix)
    write(add_b_out, '("/",A,"/Add_1_output_0")') trim(prefix)

    ! NOTE(review): `axes` as a ReduceMean attribute is valid only up to
    ! ONNX opset 17 (opset 18 moved it to an input). Confirm the opset.
    reduce_attr = '        "attribute": [{"name": "axes", "ints": ' // &
         '["0"], "type": "INTS"}, {"name": "keepdims", "i": "1", ' // &
         '"type": "INT"}]'

    ! Transpose(input) -> trans_in_out
    call emit_nop_input_transpose(trim(prefix), trim(input_name), nodes, &
         num_nodes, trim(trans_in_out))

    ! Local term: MatMul(W, x_t)
    call push_nop_node('MatMul', 'MatMul', [w_name, trans_in_out], &
         mm_w_out, '')

    ! Mean of x_t over axis 0; keepdims = 1 keeps that axis with size 1
    call push_nop_node('ReduceMean', 'ReduceMean', [trans_in_out], &
         reduce_out, reduce_attr)

    ! Mean-field kernel term: Mul(w_k, mean)
    call push_nop_node('Mul', 'Mul', [wk_name, reduce_out], mul_out, '')

    ! Combine: Add(local, kernel)
    call push_nop_node('Add', 'Add', [mm_w_out, mul_out], add_out, '')

    ! Optional bias: Add(combined, b)
    if(this%use_bias)then
       call push_nop_node('Add_1', 'Add', [add_out, b_name], add_b_out, '')
       output_source = add_b_out
    else
       output_source = add_out
    end if

    call emit_nop_output_tail(trim(prefix), trim(this%activation%name), &
         is_last_layer, trim(output_source), nodes, num_nodes, final_output)

    ! Initialisers
    ! W: dimensions passed as (num_outputs, num_inputs)
    call emit_matrix_initialiser(trim(w_name), this%params(1)%val(:,1), &
         this%num_outputs, this%num_inputs, inits, num_inits)

    ! w_k: mean-field kernel [n_out, 1]
    call emit_float_initialiser(trim(wk_name), this%params(2)%val(:,1), &
         [this%num_outputs, 1], inits, num_inits)

    ! b: bias [n_out, 1]
    if(this%use_bias)then
       call emit_float_initialiser(trim(b_name), this%params(3)%val(:,1), &
            [this%num_outputs, 1], inits, num_inits)
    end if

  contains

    subroutine push_nop_node(suffix, op_type, node_inputs, node_output, &
         attrs)
      !! Append one node named "/<prefix>/<suffix>" to the host's `nodes`.
      !! Stops with an error if the host's `max_nodes` would be exceeded.
      implicit none
      character(*), intent(in) :: suffix
      !! Node name suffix after "/<prefix>/"
      character(*), intent(in) :: op_type
      !! ONNX operator type
      character(*), dimension(:), intent(in) :: node_inputs
      !! Input tensor names (trailing blanks are trimmed)
      character(*), intent(in) :: node_output
      !! Single output tensor name (trailing blanks are trimmed)
      character(*), intent(in) :: attrs
      !! Attribute JSON, stored verbatim
      integer :: i

      if(num_nodes .ge. max_nodes) error stop &
           'emit_onnx_nodes_neural_operator: max_nodes exceeded'

      num_nodes = num_nodes + 1
      write(nodes(num_nodes)%name, '("/",A,"/",A)') trim(prefix), &
           trim(suffix)
      nodes(num_nodes)%op_type = op_type
      allocate(nodes(num_nodes)%inputs(size(node_inputs)))
      do i = 1, size(node_inputs)
         nodes(num_nodes)%inputs(i) = trim(node_inputs(i))
      end do
      allocate(nodes(num_nodes)%outputs(1))
      nodes(num_nodes)%outputs(1) = trim(node_output)
      nodes(num_nodes)%attributes_json = attrs
    end subroutine push_nop_node

  end subroutine emit_onnx_nodes_neural_operator