module athena__reshape_layer
  !! Module containing implementation of a reshape layer
  !!
  !! This module implements a general reshape layer that can transform tensors
  !! between arbitrary shapes while preserving the total number of elements.
  !! Unlike flatten (which only converts to 1D), reshape allows any target shape.
  !!
  !! Mathematical operation:
  !!   Reshape: (d1, d2, ..., dn) -> (d1', d2', ..., dm')
  !!   where:   d1 * d2 * ... * dn = d1' * d2' * ... * dm'
  !!
  !! At most one target dimension may be given as -1; it is then inferred
  !! from the input size when the layer is initialised.
  !!
  !! Examples:
  !!   - (28, 28)     -> (784)         [flatten]
  !!   - (784)        -> (28, 28)      [unflatten]
  !!   - (64, 32, 32) -> (64, 1024)    [spatial to sequence]
  !!   - (100, 50)    -> (10, 10, 50)  [add spatial dimension]
  !!   - (64, 32, 32) -> (64, -1)      [inferred as (64, 1024)]
  !!
  !! Properties:
  !!   - No learnable parameters (pure reshape operation)
  !!   - Preserves all information (bijective mapping)
  !!   - No computation beyond memory reorganisation
  !!   - Gradients flow unchanged (chain rule applies directly)
  use, intrinsic :: iso_fortran_env, only: error_unit
  use coreutils, only: real32, stop_program
  use athena__base_layer, only: base_layer_type
  use diffstruc, only: array_type, reshape
  use athena__misc_types, only: &
       onnx_node_type, onnx_initialiser_type, onnx_tensor_type
  implicit none


  private

  public :: reshape_layer_type
  public :: read_reshape_layer, create_from_onnx_reshape_layer


  type, extends(base_layer_type) :: reshape_layer_type
     !! Type for reshape layer with overloaded procedures
   contains
     procedure, pass(this) :: set_hyperparams => set_hyperparams_reshape
     !! Set hyperparameters for reshape layer
     procedure, pass(this) :: init => init_reshape
     !! Initialise reshape layer
     procedure, pass(this) :: print_to_unit => print_to_unit_reshape
     !! Print reshape layer to unit
     procedure, pass(this) :: read => read_reshape
     !! Read reshape layer from file
     procedure, pass(this) :: build_from_onnx => build_from_onnx_reshape
     !! Build reshape layer from ONNX node and initialisers
     procedure, pass(this) :: forward => forward_reshape
     !! Forward propagation derived type handler
  end type reshape_layer_type


  interface reshape_layer_type
     !! Interface for setting up the reshape layer
     module function layer_setup( &
          output_shape, input_shape, verbose &
     ) result(layer)
       !! Set up the reshape layer
       integer, dimension(:), intent(in) :: output_shape
       !! Target output shape (excluding batch dimension)
       integer, dimension(:), optional, intent(in) :: input_shape
       !! Input shape (excluding batch dimension)
       integer, optional, intent(in) :: verbose
       !! Verbosity level
       type(reshape_layer_type) :: layer
       !! Instance of the reshape layer
     end function layer_setup
  end interface reshape_layer_type



contains

!###############################################################################
  module function layer_setup( &
       output_shape, input_shape, verbose &
  ) result(layer)
    !! Set up the reshape layer
    !!
    !! Always sets hyperparameters; only initialises the layer (shape
    !! validation) when input_shape is supplied.
    implicit none

    ! Arguments
    integer, dimension(:), intent(in) :: output_shape
    !! Target output shape (excluding batch dimension)
    integer, dimension(:), optional, intent(in) :: input_shape
    !! Input shape (excluding batch dimension)
    integer, optional, intent(in) :: verbose
    !! Verbosity level
    type(reshape_layer_type) :: layer
    !! Instance of the reshape layer

    ! Local variables
    integer :: verbose_
    !! Verbosity level (assigned at run time; a declaration initialiser
    !! would give it implicit SAVE and leak values between calls)


    verbose_ = 0
    if(present(verbose)) verbose_ = verbose

    !---------------------------------------------------------------------------
    ! Set hyperparameters
    !---------------------------------------------------------------------------
    call layer%set_hyperparams(output_shape, verbose_)

    !---------------------------------------------------------------------------
    ! Initialise layer
    !---------------------------------------------------------------------------
    if(present(input_shape)) call layer%init(input_shape, verbose_)

  end function layer_setup
!###############################################################################


!###############################################################################
  subroutine set_hyperparams_reshape(this, output_shape, verbose)
    !! Set hyperparameters for reshape layer
    !!
    !! Records the target output shape and its rank; the input rank is reset
    !! to zero until init is called.
    implicit none

    ! Arguments
    class(reshape_layer_type), intent(inout) :: this
    !! Instance of the reshape layer
    integer, dimension(:), intent(in) :: output_shape
    !! Target output shape (excluding batch dimension); may contain one -1
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: verbose_
    !! Verbosity level


    verbose_ = 0
    if(present(verbose)) verbose_ = verbose

    this%type = "rshp"
    this%name = "reshape"
    this%input_rank = 0
    this%output_shape = output_shape
    this%output_rank = size(output_shape)
    if(verbose_ .gt. 0) write(*,'(" Setting up reshape layer")')

  end subroutine set_hyperparams_reshape
!###############################################################################


!###############################################################################
  subroutine init_reshape(this, input_shape, verbose)
    !! Initialise reshape layer
    !!
    !! Stores the input shape, resolves a single inferred (-1) output
    !! dimension, and checks that input and output element counts agree.
    !! Calls stop_program on any incompatibility.
    implicit none

    ! Arguments
    class(reshape_layer_type), intent(inout) :: this
    !! Instance of the reshape layer
    integer, dimension(:), intent(in) :: input_shape
    !! Input shape
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: verbose_
    !! Verbosity level
    integer :: input_size, output_size
    !! Total number of elements
    integer :: known_size
    !! Product of the explicitly specified (non -1) output dimensions
    logical :: inferable
    !! Whether a -1 output dimension can be resolved exactly


    verbose_ = 0
    if(present(verbose)) verbose_ = verbose

    !---------------------------------------------------------------------------
    ! Set input shape
    !---------------------------------------------------------------------------
    this%input_rank = size(input_shape)
    if(allocated(this%input_shape)) deallocate(this%input_shape)
    allocate(this%input_shape, source=input_shape)
    input_size = product(input_shape)

    !---------------------------------------------------------------------------
    ! Resolve an inferred (-1) output dimension, if present
    !---------------------------------------------------------------------------
    select case(count(this%output_shape .eq. -1))
    case(0)
       ! Fully specified output shape; nothing to infer
    case(1)
       ! product() of an all-false mask is 1, so a lone -1 becomes input_size
       known_size = product(this%output_shape, mask = this%output_shape .ne. -1)
       ! Separate tests: Fortran does not short-circuit, and mod() must never
       ! see a zero divisor
       if(known_size .le. 0)then
          inferable = .false.
       else
          inferable = mod(input_size, known_size) .eq. 0
       end if
       if(.not.inferable)then
          write(error_unit,'("ERROR: Reshape layer - cannot infer -1 dimension")')
          write(error_unit,'("  Input shape has ",I0," elements")') input_size
          call stop_program("Reshape layer shape mismatch")
          return
       end if
       where(this%output_shape .eq. -1) this%output_shape = input_size / known_size
    case default
       call stop_program("Reshape layer output shape may contain at most one -1")
       return
    end select

    !---------------------------------------------------------------------------
    ! Validate reshape compatibility
    !---------------------------------------------------------------------------
    ! Reject remaining negative dimensions: e.g. (-2, -392) has a positive
    ! product that could otherwise pass the element-count check below
    if(any(this%output_shape .lt. 0))then
       call stop_program("Reshape layer output shape has negative dimensions")
       return
    end if

    output_size = product(this%output_shape)
    if(input_size .ne. output_size)then
       write(error_unit,'("ERROR: Reshape layer - incompatible shapes")')
       write(error_unit,'("  Input shape has ",I0," elements")') input_size
       write(error_unit,'("  Output shape has ",I0," elements")') output_size
       call stop_program("Reshape layer shape mismatch")
       return
    end if

    !---------------------------------------------------------------------------
    ! Print layer info
    !---------------------------------------------------------------------------
    if(verbose_ .gt. 0)then
       write(*,'(" Reshape layer initialised")')
       ! ':' stops format processing after the last item (no trailing " x ")
       write(*,'(" Input shape: ",*(I0,:," x "))') this%input_shape
       write(*,'(" Output shape: ",*(I0,:," x "))') this%output_shape
    end if

  end subroutine init_reshape
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  subroutine print_to_unit_reshape(this, unit)
    !! Print reshape layer to unit
    !!
    !! Writes INPUT_RANK, INPUT_SHAPE (only when the layer has been
    !! initialised) and OUTPUT_SHAPE in the card format read by read_reshape.
    implicit none

    ! Arguments
    class(reshape_layer_type), intent(in) :: this
    !! Instance of the reshape layer
    integer, intent(in) :: unit
    !! File unit


    ! Write initial parameters
    !---------------------------------------------------------------------------
    write(unit,'(3X,"INPUT_RANK = ",I0)') this%input_rank
    ! Unlimited repeat '*(...)' handles any shape length, including zero
    if(allocated(this%input_shape)) &
         write(unit,'(3X,"INPUT_SHAPE =",*(1X,I0))') this%input_shape
    write(unit,'(3X,"OUTPUT_SHAPE =",*(1X,I0))') this%output_shape

  end subroutine print_to_unit_reshape
!###############################################################################


!###############################################################################
  subroutine read_reshape(this, unit, verbose)
    !! Read reshape layer from file
    !!
    !! Parses the layer card up to (but not consuming during the loop) the
    !! "END RESHAPE" line, sets hyperparameters, initialises the layer when
    !! INPUT_SHAPE was present, then consumes and checks the END line.
    use athena__tools_infile, only: assign_val, assign_vec, get_val
    use coreutils, only: to_lower, to_upper, icount
    implicit none

    ! Arguments
    class(reshape_layer_type), intent(inout) :: this
    !! Instance of the reshape layer
    integer, intent(in) :: unit
    !! File unit
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: stat, verbose_
    !! File status and verbosity level
    integer :: itmp1
    !! Temporary integer
    integer :: input_rank
    !! Input rank (parsed for validation only; init derives the rank from
    !! the size of INPUT_SHAPE)
    integer, dimension(:), allocatable :: input_shape, output_shape
    !! Input and output shapes
    character(256) :: buffer, tag, err_msg
    !! Buffer, tag, and error message
    character(256) :: value_str
    !! Temporary scalar value buffer


    ! Initialise locals and optional arguments
    !---------------------------------------------------------------------------
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose
    itmp1 = 0
    input_rank = 0


    ! Loop over tags in layer card
    !---------------------------------------------------------------------------
    tag_loop: do

       ! Check for end of file
       !------------------------------------------------------------------------
       read(unit,'(A)',iostat=stat) buffer
       if(stat.ne.0)then
          write(err_msg,'("file encountered error (EoF?) before END ",A)') &
               to_upper(this%name)
          call stop_program(err_msg)
          return
       end if
       if(trim(adjustl(buffer)).eq."") cycle tag_loop

       ! Check for end of layer card
       !------------------------------------------------------------------------
       if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then
          backspace(unit)
          exit tag_loop
       end if

       tag=trim(adjustl(buffer))
       if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))

       ! Read parameters from save file
       !------------------------------------------------------------------------
       select case(trim(tag))
       case("INPUT_RANK")
          value_str = get_val(buffer)
          read(value_str, *, iostat=stat) input_rank
          if(stat.ne.0)then
             call stop_program("Invalid INPUT_RANK value in reshape layer")
             return
          end if
       case("INPUT_SHAPE")
          itmp1 = icount(get_val(buffer))
          ! Tolerate a repeated tag: the last occurrence wins
          if(allocated(input_shape)) deallocate(input_shape)
          allocate(input_shape(itmp1), source=0)
          call assign_vec(buffer, input_shape, itmp1)
       case("OUTPUT_SHAPE")
          itmp1 = icount(get_val(buffer))
          if(allocated(output_shape)) deallocate(output_shape)
          allocate(output_shape(itmp1), source=0)
          call assign_vec(buffer, output_shape, itmp1)
       case default
          ! Don't look for "e" due to scientific notation of numbers
          ! ... i.e. exponent (E+00)
          if( &
               scan( &
                    to_lower(trim(adjustl(buffer))), &
                    'abcdfghijklmnopqrstuvwxyz' &
               ) .eq. 0 &
          )then
             cycle tag_loop
          elseif(tag(:3).eq.'END')then
             cycle tag_loop
          end if
          write(err_msg,'("Unrecognised line in input file: ",A)') &
               trim(adjustl(buffer))
          call stop_program(err_msg)
          return
       end select
    end do tag_loop

    if(.not.allocated(output_shape))then
       call stop_program("Reshape layer missing OUTPUT_SHAPE")
       return
    end if


    ! Set hyperparameters and initialise layer
    !---------------------------------------------------------------------------
    call this%set_hyperparams(output_shape = output_shape, verbose = verbose_)
    ! INPUT_SHAPE is optional in the card (print_to_unit omits it for an
    ! uninitialised layer); passing an unallocated array to init is invalid
    if(allocated(input_shape))then
       call this%init(input_shape = input_shape, verbose = verbose_)
    end if


    ! Check for end of layer card
    !---------------------------------------------------------------------------
    read(unit,'(A)',iostat=stat) buffer
    if(stat.ne.0) buffer = ""
    if(trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then
       write(error_unit,*) trim(adjustl(buffer))
       write(err_msg,'("END ",A," not where expected")') to_upper(this%name)
       call stop_program(err_msg)
       return
    end if

  end subroutine read_reshape
!###############################################################################


!###############################################################################
  function read_reshape_layer(unit, verbose) result(layer)
    !! Read reshape layer from file
    !!
    !! Allocates a reshape layer with a placeholder output shape and then
    !! populates it from the layer card on the given unit.
    implicit none

    ! Arguments
    integer, intent(in) :: unit
    !! File unit
    integer, intent(in), optional :: verbose
    !! Verbosity level
    class(base_layer_type), allocatable :: layer
    !! Instance of the reshape layer

    ! Local variables
    integer :: verbose_
    !! Verbosity level


    verbose_ = 0
    if(present(verbose)) verbose_ = verbose
    ! [0] is a placeholder; read overwrites it via set_hyperparams
    allocate(layer, source=reshape_layer_type(output_shape=[0]))
    call layer%read(unit, verbose=verbose_)

  end function read_reshape_layer
!###############################################################################


!###############################################################################
  subroutine build_from_onnx_reshape(this, node, initialisers, value_info, verbose)
    !! Build reshape layer from ONNX node and initialiser
    !!
    !! The output shape is taken from the first value_info entry, dropping its
    !! leading dimension when more than one dimension is present.
    implicit none

    ! Arguments
    class(reshape_layer_type), intent(inout) :: this
    !! Instance of the reshape layer
    type(onnx_node_type), intent(in) :: node
    !! ONNX node
    type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers
    !! ONNX initialisers
    type(onnx_tensor_type), dimension(:), intent(in) :: value_info
    !! ONNX value infos
    integer, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer, dimension(:), allocatable :: output_shape
    !! Output shape


    ! Check size of initialisers is zero or one (shape can be an initialiser)
    if(size(initialisers).gt.1)then
       write(error_unit,*) "WARNING: Multiple initialisers found for ONNX RESHAPE layer"
    end if

    ! Guard each access separately: Fortran does not short-circuit .or.
    if(size(value_info).lt.1)then
       call stop_program("ONNX RESHAPE layer requires output shape in value_info")
       return
    end if
    if(.not.allocated(value_info(1)%dims))then
       call stop_program("ONNX RESHAPE layer requires output shape in value_info")
       return
    end if

    ! Extract output shape from value_info (excluding batch dimension)
    select case(size(value_info(1)%dims))
    case(0)
       call stop_program("ONNX RESHAPE layer requires output shape in value_info")
       return
    case(1)
       allocate(output_shape(1))
       output_shape(1) = value_info(1)%dims(1)
    case default
       output_shape = value_info(1)%dims(2:)
    end select

    call this%set_hyperparams( &
         output_shape = output_shape, &
         verbose = verbose &
    )

  end subroutine build_from_onnx_reshape
!###############################################################################


!###############################################################################
  function create_from_onnx_reshape_layer( &
       node, initialisers, value_info, verbose &
  ) result(layer)
    !! Build reshape layer from ONNX node and initialiser
    !!
    !! Allocates a reshape layer with a placeholder output shape and then
    !! configures it via build_from_onnx.
    implicit none

    ! Arguments
    type(onnx_node_type), intent(in) :: node
    !! ONNX node
    type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers
    !! ONNX initialisers
    type(onnx_tensor_type), dimension(:), intent(in) :: value_info
    !! ONNX value infos
    integer, intent(in), optional :: verbose
    !! Verbosity level
    class(base_layer_type), allocatable :: layer
    !! Instance of the reshape layer

    ! Local variables
    integer :: verbose_
    !! Verbosity level


    verbose_ = 0
    if(present(verbose)) verbose_ = verbose
    ! [0] is a placeholder; build_from_onnx overwrites it via set_hyperparams
    allocate(layer, source=reshape_layer_type(output_shape=[0]))
    call layer%build_from_onnx(node, initialisers, value_info, verbose=verbose_)

  end function create_from_onnx_reshape_layer
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  subroutine forward_reshape(this, input)
    !! Forward propagation derived type handler
    !!
    !! Reshapes input(1,1) to this%output_shape using the diffstruc reshape
    !! and stores the result in this%output(1,1).
    implicit none

    ! Arguments
    class(reshape_layer_type), intent(inout) :: this
    !! Instance of the reshape layer
    class(array_type), dimension(:,:), intent(in) :: input
    !! Input array

    ! Local variables
    type(array_type), pointer :: ptr
    !! Result of the diffstruc reshape, handed to
    !! assign_and_deallocate_source (no initialiser: that would imply SAVE)


    ! Reshape input
    !---------------------------------------------------------------------------
    call this%output(1,1)%zero_grad()
    ptr => reshape(input(1,1), this%output_shape)
    call this%output(1,1)%assign_and_deallocate_source(ptr)
    this%output(1,1)%is_temporary = .false.

  end subroutine forward_reshape
!###############################################################################

end module athena__reshape_layer