pad_data Subroutine

public subroutine pad_data(data, data_padded, kernel_size, padding_method, sample_dim, channel_dim, constant)

Pad data for convolutional layers

Arguments

Type | Intent | Optional | Attributes | Name
real(kind=real32), intent(in), dimension(..) :: data

Data to be padded

real(kind=real32), intent(out), allocatable, dimension(..) :: data_padded

Padded data

integer, intent(in), dimension(..) :: kernel_size

Width of kernel/filter

character(len=*), intent(inout) :: padding_method

Padding method

integer, intent(in), optional :: sample_dim

Index of the sample dimension; this dimension is excluded from padding

integer, intent(in), optional :: channel_dim

Index of the channel dimension; this dimension is excluded from padding

real(kind=real32), intent(in), optional :: constant

Constant value for padding


Source Code

  subroutine pad_data( &
       data, data_padded, &
       kernel_size, padding_method, &
       sample_dim, channel_dim, constant &
  )
    !! Pad data for convolutional layers
    !!
    !! data_padded is allocated with the same rank as data (ranks 1 to 5 are
    !! supported). Its bounds are those of data widened by padding(i) on both
    !! sides of every spatial dimension. The spatial dimensions are all
    !! dimensions other than sample_dim and channel_dim.
    !!
    !! The padding width of each spatial dimension comes from set_padding,
    !! using the matching kernel_size entry. A scalar or single-element
    !! kernel_size is applied to every spatial dimension.
    !!
    !! The copied data keeps its original indices (starting at 1), so the
    !! padding occupies indices below 1 and above size(data, idim).
    !!
    !! How the padded region is filled, by padding_method:
    !!  - "same", "full", "zero": the value of constant (default 0)
    !!  - "valid", "none": data is copied and the routine returns
    !!  - "circular": periodic wrap-around of the data
    !!  - "reflection": data mirrored about the boundary element
    !!                  (boundary element not repeated)
    !!  - "replication"/"symmetric": data mirrored with the boundary
    !!                  element repeated
    !!
    !! Errors are reported via stop_program, followed by a return.
    implicit none

    ! Arguments
    real(real32), dimension(..), intent(in) :: data
    !! Data to be padded
    real(real32), allocatable, dimension(..), intent(out) :: data_padded
    !! Padded data
    integer, dimension(..), intent(in) :: kernel_size
    !! Width of kernel/filter
    character(*), intent(inout) :: padding_method
    !! Padding method (passed through set_padding, which may update it)
    real(real32), optional, intent(in) :: constant
    !! Constant value for padding
    integer, optional, intent(in) :: sample_dim, channel_dim
    !! Sample and channel dimensions; these are NOT padded (absent/0 = none)

    ! Local variables
    integer :: i, idim, side
    !! Loop indices
    integer :: ndim, ndata_dim
    !! Rank of data and number of spatial (padded) dimensions
    integer :: sample_dim_, channel_dim_
    !! Sample and channel dimensions (0 = not present)
    real(real32) :: constant_
    !! Constant value for padding
    integer :: lo, hi, pad, max_pad
    !! Data bounds, padding width and its allowed maximum for one dimension
    integer, allocatable, dimension(:) :: padding
    !! Padding width per spatial dimension
    integer, allocatable, dimension(:) :: step
    !! Stride per dimension used when reading the source section
    integer, allocatable, dimension(:,:) :: trgt_bound, dest_bound
    !! Bounds of the original data and of the padded array
    integer, allocatable, dimension(:,:) :: tmp_trgt_bound, tmp_dest_bound
    !! Source and destination sections for one padding copy

    character(256) :: err_msg
    !! Error message


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    ! Defaults are assigned here rather than in the declarations: a
    ! declaration initialiser implies SAVE, which would leak values from a
    ! previous call into calls that omit the optional arguments.
    constant_ = 0._real32
    sample_dim_ = 0
    channel_dim_ = 0
    if(present(constant)) constant_ = constant
    if(present(sample_dim)) sample_dim_ = sample_dim
    if(present(channel_dim)) channel_dim_ = channel_dim

    ndim = rank(data)
#if defined(GFORTRAN)
    if(ndim.ne.rank(data_padded))then
       call stop_program("data and data_padded are not the same rank")
       return
    end if
#else
    select rank(data_padded)
    rank(1)
       if(ndim.ne.1)then
          call stop_program("data and data_padded are not the same rank")
          return
       end if
    rank(2)
       if(ndim.ne.2)then
          call stop_program("data and data_padded are not the same rank")
          return
       end if
    rank(3)
       if(ndim.ne.3)then
          call stop_program("data and data_padded are not the same rank")
          return
       end if
    rank(4)
       if(ndim.ne.4)then
          call stop_program("data and data_padded are not the same rank")
          return
       end if
    rank(5)
       if(ndim.ne.5)then
          call stop_program("data and data_padded are not the same rank")
          return
       end if
    end select
#endif
    if(ndim.lt.1.or.ndim.gt.5)then
       call stop_program("cannot handle data with this rank")
       return
    end if
    if(sample_dim_.gt.ndim.or.channel_dim_.gt.ndim)then
       call stop_program("sample_dim/channel_dim exceed the rank of data")
       return
    end if
    if(sample_dim_.gt.0.and.sample_dim_.eq.channel_dim_)then
       call stop_program("sample_dim and channel_dim must be different")
       return
    end if

    ndata_dim = ndim
    if(sample_dim_.gt.0)  ndata_dim = ndata_dim - 1
    if(channel_dim_.gt.0) ndata_dim = ndata_dim - 1


    !---------------------------------------------------------------------------
    ! Handle padding type name
    !---------------------------------------------------------------------------
    ! none  = alt. name for 'valid'
    ! zero  = alt. name for 'same'
    ! symmetric = alt.name for 'replication'
    ! valid = no padding
    ! same  = maintain spatial dimensions
    !         ... (i.e. padding added = (kernel_size - 1)/2)
    !         ... defaults to zeros in the padding
    ! full  = enough padding for filter to slide over every possible position
    !         ... (i.e. padding added = (kernel_size - 1)
    ! circular = maintain spatial dimensions
    !            ... wraps data around for padding (periodic)
    ! reflection = maintains spatial dimensions
    !              ... reflect data (about boundary index)
    ! replication = maintains spatial dimensions
    !               ... reflect data (boundary included)
    select rank(kernel_size)
    rank(0)
       allocate(padding(ndata_dim))
       do i=1,ndata_dim
          call set_padding(padding(i), kernel_size, padding_method, verbose=0)
       end do
    rank(1)
       if(size(kernel_size).eq.1.and.ndata_dim.gt.1)then
          allocate(padding(ndata_dim))
          do i=1,ndata_dim
             call set_padding( &
                  padding(i), &
                  kernel_size(1), &
                  padding_method, &
                  verbose = 0 &
             )
          end do
       else
          if(sample_dim_.eq.0.and.channel_dim_.eq.0.and.&
               size(kernel_size).ne.ndim)then
             write(err_msg,'("&
                  &kernel_size length not equal to rank of data",A,"&
                  &kernel dimension: ",I0,A,"&
                  &data rank: ",I0)' &
             ) &
                  achar(13) // achar(10), size(kernel_size), &
                  achar(13) // achar(10), ndim
             call stop_program(err_msg)
             return
          elseif(sample_dim_.gt.0.and.channel_dim_.gt.0.and.&
               size(kernel_size).ne.ndim-2)then
             write(err_msg,'("&
                  &kernel_size length not equal to rank of data-2",A,"&
                  &kernel dimension: ",I0,A,"&
                  &data rank: ",I0)' &
             ) &
                  achar(13) // achar(10), size(kernel_size), &
                  achar(13) // achar(10), ndim-2
             call stop_program(err_msg)
             return
          elseif((sample_dim_.gt.0.or.channel_dim_.gt.0).and.&
               .not.(sample_dim_.gt.0.and.channel_dim_.gt.0).and.&
               size(kernel_size).ne.ndim-1)then
             write(err_msg,'("&
                  &kernel_size length not equal to rank of data-1",A,"&
                  &kernel dimension: ",I0,A,"&
                  &data rank: ",I0)' &
             ) &
                  achar(13) // achar(10), size(kernel_size), &
                  achar(13) // achar(10), ndim-1
             call stop_program(err_msg)
             return
          else
             allocate(padding(size(kernel_size)))
          end if
          do i=1,size(kernel_size)
             call set_padding( &
                  padding(i), kernel_size(i), padding_method, verbose=0 &
             )
          end do
       end if
    rank default
       ! Without this, padding would stay unallocated and be indexed below.
       call stop_program("kernel_size must be a scalar or rank-1 array")
       return
    end select


    !---------------------------------------------------------------------------
    ! Allocate data set
    ! ... if appropriate, add padding
    !---------------------------------------------------------------------------
    select case(padding_method)
    case("same", "full", "zero")
       ! constant is used to fill the padding for these methods
    case default
       if(abs(constant_).gt.1.E-8)then
          write(*,*) "WARNING: constant is ignored for selected padding method"
       end if
    end select


    ! trgt_bound: bounds of the original data
    ! dest_bound: bounds of the padded array (spatial dims widened by padding)
    allocate(dest_bound(2,ndim))
    allocate(trgt_bound(2,ndim))
    i = 0
    do idim=1,ndim
       trgt_bound(:,idim) = [ lbound(data,dim=idim), ubound(data,dim=idim) ]
       dest_bound(:,idim) = trgt_bound(:,idim)
       if(idim.eq.sample_dim_.or.idim.eq.channel_dim_) cycle
       i = i + 1
       dest_bound(:,idim) = dest_bound(:,idim) + [ -padding(i), padding(i) ]
    end do

    ! Allocate the padded array filled with constant_, then copy the input
    ! data into its interior (original indices are kept).
    select rank(data_padded)
    rank(1)
       allocate(data_padded(&
            dest_bound(1,1):dest_bound(2,1)), source = constant_)
       select rank(data)
       rank(1)
          data_padded( &
               trgt_bound(1,1):trgt_bound(2,1) &
          ) = data
       end select
    rank(2)
       allocate(data_padded(&
            dest_bound(1,1):dest_bound(2,1), &
            dest_bound(1,2):dest_bound(2,2)), source = constant_)
       select rank(data)
       rank(2)
          data_padded( &
               trgt_bound(1,1):trgt_bound(2,1), &
               trgt_bound(1,2):trgt_bound(2,2) &
          ) = data
       end select
    rank(3)
       allocate( &
            data_padded(&
                 dest_bound(1,1):dest_bound(2,1),&
                 dest_bound(1,2):dest_bound(2,2),&
                 dest_bound(1,3):dest_bound(2,3) &
            ), source = constant_ &
       )
       select rank(data)
       rank(3)
          data_padded( &
               trgt_bound(1,1):trgt_bound(2,1), &
               trgt_bound(1,2):trgt_bound(2,2), &
               trgt_bound(1,3):trgt_bound(2,3) &
          ) = data
       end select
    rank(4)
       allocate( &
            data_padded( &
                 dest_bound(1,1):dest_bound(2,1), &
                 dest_bound(1,2):dest_bound(2,2), &
                 dest_bound(1,3):dest_bound(2,3), &
                 dest_bound(1,4):dest_bound(2,4) &
            ), source = constant_ &
       )
       select rank(data)
       rank(4)
          data_padded( &
               trgt_bound(1,1):trgt_bound(2,1), &
               trgt_bound(1,2):trgt_bound(2,2), &
               trgt_bound(1,3):trgt_bound(2,3), &
               trgt_bound(1,4):trgt_bound(2,4) &
          ) = data
       end select
    rank(5)
       allocate( &
            data_padded(&
                 dest_bound(1,1):dest_bound(2,1), &
                 dest_bound(1,2):dest_bound(2,2), &
                 dest_bound(1,3):dest_bound(2,3), &
                 dest_bound(1,4):dest_bound(2,4), &
                 dest_bound(1,5):dest_bound(2,5) &
            ), source = constant_ &
       )
       select rank(data)
       rank(5)
          data_padded( &
               trgt_bound(1,1):trgt_bound(2,1), &
               trgt_bound(1,2):trgt_bound(2,2), &
               trgt_bound(1,3):trgt_bound(2,3), &
               trgt_bound(1,4):trgt_bound(2,4), &
               trgt_bound(1,5):trgt_bound(2,5) &
          ) = data
       end select
    end select


    !---------------------------------------------------------------------------
    ! Return if constant -- or no -- padding
    !---------------------------------------------------------------------------
    select case(padding_method)
    case("same", "full", "zero", "valid", "vali", "none")
       return
    end select


    !---------------------------------------------------------------------------
    ! Insert padding
    !---------------------------------------------------------------------------
    ! Spatial dimensions are processed in order. Each copy spans the full
    ! padded range (dest_bound) of every other dimension, so corner regions
    ! are filled from dimensions that have already been padded.
    allocate(step(ndim))
    i = 0
    do idim=1,ndim
       if(idim.eq.sample_dim_.or.idim.eq.channel_dim_) cycle
       i = i + 1
       pad = padding(i)
       if(pad.le.0) cycle
       lo = trgt_bound(1,idim)
       hi = trgt_bound(2,idim)

       ! The source region must lie inside the original data, otherwise it
       ! would read (unfilled) padding or fall outside the array.
       select case(padding_method)
       case("circular", "replication", "symmetric")
          max_pad = hi - lo + 1
       case("reflection")
          max_pad = hi - lo
       case default
          call stop_program( &
               "unrecognised padding method: " // trim(padding_method) &
          )
          return
       end select
       if(pad.gt.max_pad)then
          write(err_msg,'("padding width ",I0," too large for ",A," padding of dimension ",I0)') &
               pad, trim(padding_method), idim
          call stop_program(err_msg)
          return
       end if

       ! side 1 = lower padding, side 2 = upper padding
       do side = 1, 2
          tmp_dest_bound = dest_bound
          tmp_trgt_bound = dest_bound
          step(:) = 1
          if(side.eq.1)then
             tmp_dest_bound(:,idim) = [ lo - pad, lo - 1 ]
          else
             tmp_dest_bound(:,idim) = [ hi + 1, hi + pad ]
          end if
          select case(padding_method)
          case("circular")
             ! periodic: lower pad <- top of data, upper pad <- bottom of data
             if(side.eq.1)then
                tmp_trgt_bound(:,idim) = [ hi - pad + 1, hi ]
             else
                tmp_trgt_bound(:,idim) = [ lo, lo + pad - 1 ]
             end if
          case("reflection")
             ! mirrored about the boundary element: read the source backwards
             step(idim) = -1
             if(side.eq.1)then
                tmp_trgt_bound(:,idim) = [ lo + pad, lo + 1 ]
             else
                tmp_trgt_bound(:,idim) = [ hi - 1, hi - pad ]
             end if
          case("replication", "symmetric")
             ! mirrored with the boundary element repeated: read backwards
             step(idim) = -1
             if(side.eq.1)then
                tmp_trgt_bound(:,idim) = [ lo + pad - 1, lo ]
             else
                tmp_trgt_bound(:,idim) = [ hi, hi - pad + 1 ]
             end if
          end select

          select rank(data_padded)
          rank(1)
             data_padded( &
                  tmp_dest_bound(1,1):tmp_dest_bound(2,1) &
             ) = data_padded( &
                  tmp_trgt_bound(1,1):tmp_trgt_bound(2,1):step(1) &
             )
          rank(2)
             data_padded( &
                  tmp_dest_bound(1,1):tmp_dest_bound(2,1), &
                  tmp_dest_bound(1,2):tmp_dest_bound(2,2) &
             ) = data_padded( &
                  tmp_trgt_bound(1,1):tmp_trgt_bound(2,1):step(1), &
                  tmp_trgt_bound(1,2):tmp_trgt_bound(2,2):step(2) &
             )
          rank(3)
             data_padded( &
                  tmp_dest_bound(1,1):tmp_dest_bound(2,1), &
                  tmp_dest_bound(1,2):tmp_dest_bound(2,2), &
                  tmp_dest_bound(1,3):tmp_dest_bound(2,3) &
             ) = data_padded( &
                  tmp_trgt_bound(1,1):tmp_trgt_bound(2,1):step(1), &
                  tmp_trgt_bound(1,2):tmp_trgt_bound(2,2):step(2), &
                  tmp_trgt_bound(1,3):tmp_trgt_bound(2,3):step(3) &
             )
          rank(4)
             data_padded( &
                  tmp_dest_bound(1,1):tmp_dest_bound(2,1), &
                  tmp_dest_bound(1,2):tmp_dest_bound(2,2), &
                  tmp_dest_bound(1,3):tmp_dest_bound(2,3), &
                  tmp_dest_bound(1,4):tmp_dest_bound(2,4) &
             ) = data_padded( &
                  tmp_trgt_bound(1,1):tmp_trgt_bound(2,1):step(1), &
                  tmp_trgt_bound(1,2):tmp_trgt_bound(2,2):step(2), &
                  tmp_trgt_bound(1,3):tmp_trgt_bound(2,3):step(3), &
                  tmp_trgt_bound(1,4):tmp_trgt_bound(2,4):step(4) &
             )
          rank(5)
             data_padded( &
                  tmp_dest_bound(1,1):tmp_dest_bound(2,1), &
                  tmp_dest_bound(1,2):tmp_dest_bound(2,2), &
                  tmp_dest_bound(1,3):tmp_dest_bound(2,3), &
                  tmp_dest_bound(1,4):tmp_dest_bound(2,4), &
                  tmp_dest_bound(1,5):tmp_dest_bound(2,5) &
             ) = data_padded( &
                  tmp_trgt_bound(1,1):tmp_trgt_bound(2,1):step(1), &
                  tmp_trgt_bound(1,2):tmp_trgt_bound(2,2):step(2), &
                  tmp_trgt_bound(1,3):tmp_trgt_bound(2,3):step(3), &
                  tmp_trgt_bound(1,4):tmp_trgt_bound(2,4):step(4), &
                  tmp_trgt_bound(1,5):tmp_trgt_bound(2,5):step(5) &
             )
          end select
       end do
    end do

  end subroutine pad_data