detect_gnn_expanded_activation Function

private function detect_gnn_expanded_activation(prefix, nodes, num_nodes) result(name)

Detect the activation op used in a GNN layer cluster.

Arguments

Type | Intent | Optional Attributes | Name
character(len=*), intent(in) :: prefix

Layer node prefix (e.g. "node_2")

type(onnx_node_type), intent(in) :: nodes(:)

Parsed ONNX nodes

integer, intent(in) :: num_nodes

Number of valid node entries

Return Value character(len=64)

Detected ATHENA activation name


Source Code

  function detect_gnn_expanded_activation( &
       prefix, nodes, num_nodes) result(name)
    !! Detect the activation op used in a GNN layer cluster.
    !!
    !! Scans the first `num_nodes` entries of `nodes` for nodes whose name
    !! starts with `<prefix>_t` and whose op_type is one of Relu, LeakyRelu,
    !! Sigmoid, Tanh, Selu or Swish. The matching op_type is mapped to an
    !! ATHENA activation name via `onnx_to_athena_activation`. If several
    !! matching nodes exist, the last one (in node order) wins. Returns
    !! 'none' when no matching activation node is found.
    use athena__onnx_utils, only: onnx_to_athena_activation
    implicit none

    ! Arguments
    character(*), intent(in) :: prefix
    !! Layer node prefix (e.g. "node_2")
    type(onnx_node_type), intent(in) :: nodes(:)
    !! Parsed ONNX nodes
    integer, intent(in) :: num_nodes
    !! Number of valid node entries
    character(64) :: name
    !! Detected ATHENA activation name ('none' if not found)

    ! Local variables
    integer :: i
    character(:), allocatable :: check_prefix
    !! Name prefix identifying nodes of this layer cluster. Deferred length
    !! so that long prefixes cannot overflow a fixed-size buffer.

    name = 'none'
    check_prefix = trim(prefix) // '_t'

    ! Never index past the actual extent of `nodes`, even if `num_nodes`
    ! overstates it.
    do i = 1, min(num_nodes, size(nodes))
       ! Only consider nodes whose name begins with check_prefix.
       if(index(trim(nodes(i)%name), check_prefix) .ne. 1) cycle
       select case(trim(nodes(i)%op_type))
       case('Relu', 'LeakyRelu', 'Sigmoid', &
            'Tanh', 'Selu', 'Swish')
          ! No early exit: a later matching activation node overrides this.
          name = onnx_to_athena_activation( &
               trim(nodes(i)%op_type))
       end select
    end do

  end function detect_gnn_expanded_activation