module athena__concat_layer
  !! Module containing implementation of a concatenation layer
  !!
  !! This module implements a merge layer that concatenates multiple input
  !! tensors along a specified dimension (features for 2D, channels for 3D).
  !!
  !! Mathematical operation:
  !!   output = [input_1 || input_2 || ... || input_N]
  !!
  !! where || denotes concatenation along the appropriate dimension.
  !! Output size along concatenation dimension = sum of input sizes.
  !! Gradients are split to corresponding input portions during backpropagation.
  use coreutils, only: real32, stop_program
  use athena__base_layer, only: merge_layer_type, base_layer_type
  use diffstruc, only: array_type, operator(+)
  use athena__diffstruc_extd, only: array_ptr_type, concat_layers
  implicit none


  private

  public :: concat_layer_type
  public :: read_concat_layer


  type, extends(merge_layer_type) :: concat_layer_type
     !! Type for concatenate layer with overloaded procedures
     integer, dimension(:,:), allocatable :: io_map
     !! I/O mapping for the layer
   contains
     procedure, pass(this) :: set_hyperparams => set_hyperparams_concat
     !! Set the hyperparameters for concatenate layer
     procedure, pass(this) :: init => init_concat
     !! Initialise concatenate layer
     procedure, pass(this) :: print_to_unit => print_to_unit_concat
     !! Print the layer to a file
     procedure, pass(this) :: read => read_concat
     !! Read the layer from a file
     procedure, pass(this) :: calc_input_shape => calc_input_shape_concat
     !! Calculate input shape based on shapes of input layers
     procedure, pass(this) :: combine => combine_concat
     !! Concatenate the input arrays into the layer output
  end type concat_layer_type


  interface concat_layer_type
     !! Interface for setting up the concatenate layer
     module function layer_setup( &
          input_layer_ids, input_rank, verbose &
     ) result(layer)
       !! Setup a concatenate layer
       integer, dimension(:), intent(in) :: input_layer_ids
       !! Input layer IDs
       integer, optional, intent(in) :: input_rank
       !! Input rank
       integer, optional, intent(in) :: verbose
       !! Verbosity level
       type(concat_layer_type) :: layer
     end function layer_setup
  end interface concat_layer_type



contains

!###############################################################################
  module function layer_setup( &
       input_layer_ids, input_rank, verbose &
  ) result(layer)
    !! Setup a concatenate layer
    !!
    !! input_rank is required; if it is absent stop_program is called and
    !! the layer is returned without its hyperparameters set.
    implicit none

    ! Arguments
    integer, dimension(:), intent(in) :: input_layer_ids
    !! Input layer IDs
    integer, optional, intent(in) :: input_rank
    !! Input rank
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    type(concat_layer_type) :: layer
    !! Instance of the concatenate layer

    ! Local variables
    integer :: input_rank_
    !! Input rank
    integer :: verbose_
    !! Verbosity level


    ! Assign defaults at run time: initialising in the declaration would give
    ! these locals the implicit SAVE attribute and leak values between calls
    input_rank_ = 0
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose


    !---------------------------------------------------------------------------
    ! Set hyperparameters
    !---------------------------------------------------------------------------
    if(present(input_rank))then
       input_rank_ = input_rank
    else
       call stop_program( &
            "input_rank or input_shape must be provided to concat layer" &
       )
       return
    end if
    call layer%set_hyperparams( &
         input_layer_ids = input_layer_ids, &
         input_rank = input_rank_, &
         verbose = verbose_ &
    )

  end function layer_setup
!###############################################################################


!###############################################################################
  subroutine set_hyperparams_concat( &
       this, &
       input_layer_ids, &
       input_rank, &
       verbose &
  )
    !! Set the hyperparameters for concatenate layer
    !!
    !! Sets the layer name/type, selects merge mode 2 (concatenate) and stores
    !! the input layer IDs; the output rank is set equal to the input rank.
    implicit none

    ! Arguments
    class(concat_layer_type), intent(inout) :: this
    !! Instance of the concatenate layer
    integer, dimension(:), intent(in) :: input_layer_ids
    !! Input layer IDs
    integer, intent(in) :: input_rank
    !! Input rank
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    this%name = "concatenate"
    this%type = "merg"
    this%merge_mode = 2 ! concatenate mode
    this%input_layer_ids = input_layer_ids
    this%input_rank = input_rank
    this%output_rank = input_rank

  end subroutine set_hyperparams_concat
!###############################################################################


!###############################################################################
  subroutine init_concat(this, input_shape, verbose)
    !! Initialise concatenate layer
    !!
    !! Sets input_rank from size(input_shape), stores the input shape only if
    !! none is allocated yet, copies it to the output shape and allocates a
    !! 1x1 output placeholder.
    implicit none

    ! Arguments
    class(concat_layer_type), intent(inout) :: this
    !! Instance of the concatenate layer
    integer, dimension(:), intent(in) :: input_shape
    !! Input shape
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: verbose_
    !! Verbosity level


    !---------------------------------------------------------------------------
    ! Initialise optional arguments (run-time assignment avoids implicit SAVE)
    !---------------------------------------------------------------------------
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose


    !---------------------------------------------------------------------------
    ! Initialise input shape
    !---------------------------------------------------------------------------
    this%input_rank = size(input_shape)
    if(.not.allocated(this%input_shape)) call this%set_shape(input_shape)


    !---------------------------------------------------------------------------
    ! Initialise output shape
    !---------------------------------------------------------------------------
    this%output_shape = this%input_shape
    this%output_rank = size(this%output_shape)


    !---------------------------------------------------------------------------
    ! Allocate arrays
    !---------------------------------------------------------------------------
    if(allocated(this%output)) deallocate(this%output)
    allocate(this%output(1,1))

  end subroutine init_concat
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  subroutine print_to_unit_concat(this, unit)
    !! Print concatenate layer to unit
    !!
    !! Writes INPUT_RANK, INPUT_SHAPE (only when an input shape is allocated)
    !! and INPUT_LAYER_IDS in the format parsed by read_concat.
    implicit none

    ! Arguments
    class(concat_layer_type), intent(in) :: this
    !! Instance of the concatenate layer
    integer, intent(in) :: unit
    !! File unit

    ! Local variables
    character(100) :: fmt
    !! Run-time format string


    ! Write initial parameters
    !---------------------------------------------------------------------------
    write(unit,'(3X,"INPUT_RANK = ",I0)') this%input_rank
    ! read_concat falls back to INPUT_RANK when INPUT_SHAPE is absent
    if(allocated(this%input_shape))then
       write(fmt,'("(3X,""INPUT_SHAPE ="",",I0,"(1X,I0))")') size(this%input_shape)
       write(unit,fmt) this%input_shape
    end if
    write(fmt,'("(3X,""INPUT_LAYER_IDS ="",",I0,"(1X,I0))")') size(this%input_layer_ids)
    write(unit,fmt) this%input_layer_ids

  end subroutine print_to_unit_concat
!###############################################################################


!###############################################################################
  subroutine read_concat(this, unit, verbose)
    !! Read concatenate layer from file
    !!
    !! Parses the INPUT_SHAPE, INPUT_RANK and INPUT_LAYER_IDS tags of a layer
    !! card up to its matching END line, sets the hyperparameters, and
    !! initialises the layer only when an input shape was provided.
    !! Errors are reported through stop_program.
    use, intrinsic :: iso_fortran_env, only: error_unit
    use athena__tools_infile, only: assign_val, assign_vec, get_val
    use coreutils, only: to_lower, to_upper, icount
    implicit none

    ! Arguments
    class(concat_layer_type), intent(inout) :: this
    !! Instance of the concatenate layer
    integer, intent(in) :: unit
    !! Unit number
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: stat, verbose_
    !! File status and verbosity level
    integer :: itmp1
    !! Temporary integer
    integer :: input_rank
    !! Input rank
    integer, dimension(:), allocatable :: input_shape, input_layer_ids
    !! Input shape and input layer IDs
    character(256) :: buffer, tag, err_msg
    !! Buffer, tag, and error message


    ! Initialise locals at run time: declaration initialisers imply SAVE,
    ! which would carry input_rank over from a previously read layer
    !---------------------------------------------------------------------------
    verbose_ = 0
    itmp1 = 0
    input_rank = 0
    if(present(verbose)) verbose_ = verbose


    ! Loop over tags in layer card
    !---------------------------------------------------------------------------
    tag_loop: do

       ! Check for end of file
       !------------------------------------------------------------------------
       read(unit,'(A)',iostat=stat) buffer
       if(stat.ne.0)then
          write(err_msg,'("file encountered error (EoF?) before END ",A)') &
               to_upper(this%name)
          call stop_program(err_msg)
          return
       end if
       if(trim(adjustl(buffer)).eq."") cycle tag_loop

       ! Check for end of layer card
       !------------------------------------------------------------------------
       if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then
          backspace(unit)
          exit tag_loop
       end if

       tag=trim(adjustl(buffer))
       if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))

       ! Read parameters from file
       !------------------------------------------------------------------------
       select case(trim(tag))
       case("INPUT_SHAPE")
          itmp1 = icount(get_val(buffer))
          ! A repeated tag overrides the earlier value instead of failing
          if(allocated(input_shape)) deallocate(input_shape)
          allocate(input_shape(itmp1), source=0)
          call assign_vec(buffer, input_shape, itmp1)
       case("INPUT_RANK")
          call assign_val(buffer, input_rank, itmp1)
       case("INPUT_LAYER_IDS")
          itmp1 = icount(get_val(buffer))
          if(allocated(input_layer_ids)) deallocate(input_layer_ids)
          allocate(input_layer_ids(itmp1), source=0)
          call assign_vec(buffer, input_layer_ids, itmp1)
       case default
          ! Don't look for "e" due to scientific notation of numbers
          ! ... i.e. exponent (E+00)
          if(scan(to_lower(trim(adjustl(buffer))),&
               'abcdfghijklmnopqrstuvwxyz').eq.0)then
             cycle tag_loop
          elseif(tag(:3).eq.'END')then
             cycle tag_loop
          end if
          write(err_msg,'("Unrecognised line in input file: ",A)') &
               trim(adjustl(buffer))
          call stop_program(err_msg)
          return
       end select
    end do tag_loop

    if(allocated(input_shape))then
       if(input_rank.eq.0)then
          input_rank = size(input_shape)
       elseif(input_rank.ne.size(input_shape))then
          write(err_msg,'("input_rank (",I0,") does not match input_shape (",I0,")")') &
               input_rank, size(input_shape)
          call stop_program(err_msg)
          return
       end if
    elseif(input_rank.eq.0)then
       write(err_msg,'("input_rank must be provided if input_shape is not")')
       call stop_program(err_msg)
       return
    end if

    ! set_hyperparams needs a defined (allocated) array of layer IDs
    if(.not.allocated(input_layer_ids))then
       err_msg = "INPUT_LAYER_IDS must be provided for concatenate layer"
       call stop_program(err_msg)
       return
    end if


    ! Set hyperparameters and initialise layer
    !---------------------------------------------------------------------------
    call this%set_hyperparams( &
         input_layer_ids = input_layer_ids, &
         input_rank = input_rank, &
         verbose = verbose_ &
    )
    ! Without an input shape, initialisation must wait until the shape is
    ! known; passing an unallocated array to init would be invalid
    if(allocated(input_shape)) call this%init(input_shape = input_shape)


    ! Check for end of layer card
    !---------------------------------------------------------------------------
    read(unit,'(A)',iostat=stat) buffer
    if(stat.ne.0) buffer = ""
    if(trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then
       write(error_unit,*) trim(adjustl(buffer))
       write(err_msg,'("END ",A," not where expected")') to_upper(this%name)
       call stop_program(err_msg)
       return
    end if

  end subroutine read_concat
!###############################################################################


!###############################################################################
  function read_concat_layer(unit, verbose) result(layer)
    !! Read concatenate layer from file and return layer
    !!
    !! Creates a placeholder concatenate layer and then fills it in from the
    !! layer card on `unit` via its read procedure.
    implicit none

    ! Arguments
    integer, intent(in) :: unit
    !! Unit number
    integer, optional, intent(in) :: verbose
    !! Verbosity level
    class(base_layer_type), allocatable :: layer
    !! Instance of the concatenate layer

    ! Local variables
    integer :: verbose_
    !! Verbosity level

    ! Run-time assignment avoids the implicit SAVE of a declaration initialiser
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose
    allocate(layer, source=concat_layer_type( &
         input_layer_ids=[0,0], input_rank=1))
    call layer%read(unit, verbose=verbose_)

  end function read_concat_layer
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  function calc_input_shape_concat(this, input_shapes) result(input_shape)
    !! Calculate input shape based on shapes of input layers
    !!
    !! input_shapes(:,k) holds the shape of the k-th input layer. The last
    !! shape entry (features for rank-1 shapes, channels for higher ranks) is
    !! the concatenation dimension and is summed over all inputs. Every other
    !! entry is taken from the first input and must agree across inputs;
    !! a mismatch is reported through stop_program.
    !! NOTE(review): assumes channels-last shapes (concatenated axis is the
    !! last entry) - confirm against how concat_layers, called with dim = 1
    !! in combine_concat, lays out the data.
    implicit none

    ! Arguments
    class(concat_layer_type), intent(in) :: this
    !! Instance of the layer
    integer, dimension(:,:), intent(in) :: input_shapes
    !! Input shapes
    integer, allocatable, dimension(:) :: input_shape
    !! Calculated input shape

    ! Local variables
    integer :: num_dims, k
    !! Number of shape entries and input index
    character(256) :: err_msg
    !! Error message


    num_dims = size(input_shapes, 1)
    if(num_dims.eq.0 .or. size(input_shapes, 2).eq.0)then
       ! Degenerate case: nothing to concatenate
       input_shape = sum(input_shapes, dim=2)
       return
    end if

    ! Only the concatenation dimension grows; summing every entry would,
    ! e.g., double the spatial extent of two [W,H,C] inputs
    input_shape = input_shapes(:,1)
    input_shape(num_dims) = sum(input_shapes(num_dims,:))

    do k = 2, size(input_shapes, 2)
       if(any(input_shapes(1:num_dims-1,k).ne.input_shapes(1:num_dims-1,1)))then
          write(err_msg,'(A,I0,A)') "concatenate layer: input ", k, &
               " differs from input 1 outside the concatenation dimension"
          call stop_program(err_msg)
          return
       end if
    end do

  end function calc_input_shape_concat
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  subroutine combine_concat(this, input_list)
    !! Concatenate the input arrays element-wise into this%output
    !!
    !! this%output is (re)allocated to the shape of input_list(1)%array. For
    !! each index pair (i, s) where every input has an allocated element, the
    !! result of concat_layers over those elements is assigned to
    !! this%output(i, s) via assign_and_deallocate_source and marked as
    !! non-temporary. Index pairs where any input element is unallocated are
    !! not assigned. An empty input_list leaves the layer unchanged.
    implicit none

    ! Arguments
    class(concat_layer_type), intent(inout) :: this
    !! Instance of the concatenate layer
    type(array_ptr_type), dimension(:), intent(in) :: input_list
    !! Input values

    ! Local variables
    integer :: i, j, s
    !! Loop indices
    integer :: n1, n2
    !! Extents of the first input's array
    type(array_type), pointer :: ptr
    !! Pointer to the concatenated result


    ! input_list(1) is referenced below, so an empty list must be rejected
    if(size(input_list).eq.0) return

    n1 = size(input_list(1)%array, 1)
    n2 = size(input_list(1)%array, 2)

    if(allocated(this%output))then
       if(any(shape(this%output).ne.[n1, n2])) deallocate(this%output)
    end if
    if(.not.allocated(this%output)) allocate(this%output(n1, n2))

    ! Outer loop over the second index keeps memory access column-major
    do s = 1, n2
       index_loop: do i = 1, n1
          ! Skip (i, s) unless every input provides an allocated element
          do j = 1, size(input_list, 1)
             if(.not.input_list(j)%array(i,s)%allocated) cycle index_loop
          end do
          ptr => concat_layers(input_list, i, s, dim = 1)
          call this%output(i,s)%zero_grad()
          call this%output(i,s)%assign_and_deallocate_source(ptr)
          this%output(i,s)%is_temporary = .false.
       end do index_loop
    end do

  end subroutine combine_concat
!###############################################################################

end module athena__concat_layer