Pad data for convolutional layers
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| real(kind=real32), | intent(in), | dimension(..) | :: | data |
Data to be padded |
|
| real(kind=real32), | intent(out), | allocatable, dimension(..) | :: | data_padded |
Padded data |
|
| integer, | intent(in), | dimension(..) | :: | kernel_size |
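Width of kernel/filter: a scalar, or a rank-1 array with one entry per padded dimension (a single entry is applied to every padded dimension) |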
|
| character(len=*), | intent(inout) | | :: | padding_method |
Padding method name (e.g. valid, same, full, zero, circular, reflection, replication) |
|
| integer, | intent(in), | optional | :: | sample_dim |
Index of the sample dimension; this dimension is not padded (absent or 0: no sample dimension) |
|
| integer, | intent(in), | optional | :: | channel_dim |
Index of the channel dimension; this dimension is not padded (absent or 0: no channel dimension) |
|
| real(kind=real32), | intent(in), | optional | :: | constant |
Constant value for padding (default 0); only used by the same, full and zero methods
subroutine pad_data( &
     data, data_padded, &
     kernel_size, padding_method, &
     sample_dim, channel_dim, constant &
)
  !! Pad data for convolutional layers
  !!
  !! `data_padded` is allocated with the bounds of `data`, widened by
  !! `padding(i)` on both sides of every dimension other than
  !! `sample_dim`/`channel_dim`. It is initialised to `constant`, and then
  !! `data` is copied into its centre. The width of each padded dimension is
  !! obtained from `set_padding` using `kernel_size` and `padding_method`.
  !!
  !! For "same", "full", "zero", "valid" and "none" the border keeps the
  !! constant fill. For the remaining methods each border is overwritten
  !! (lo/hi are the bounds of `data` along the padded dimension, k >= 1):
  !!   circular              - periodic wrap:         x(lo-k) = x(hi-k+1)
  !!   reflection            - mirror, edge excluded: x(lo-k) = x(lo+k)
  !!   replication/symmetric - mirror, edge included: x(lo-k) = x(lo+k-1)
  !! The upper side is handled symmetrically. Dimensions are processed in
  !! turn over the full padded extent of the others, so corners are filled
  !! too. Errors are reported through `stop_program`.
  implicit none

  ! Arguments
  real(real32), dimension(..), intent(in) :: data
  !! Data to be padded (rank 1 to 5)
  real(real32), allocatable, dimension(..), intent(out) :: data_padded
  !! Padded data; must have the same rank as data
  integer, dimension(..), intent(in) :: kernel_size
  !! Width of kernel/filter: scalar, or rank-1 with one entry per padded
  !! dimension (a single entry is applied to every padded dimension)
  character(*), intent(inout) :: padding_method
  !! Padding method
  real(real32), optional, intent(in) :: constant
  !! Constant value for padding (default 0)
  integer, optional, intent(in) :: sample_dim, channel_dim
  !! Indices of the sample and channel dimensions. These are never padded
  !! (absent or <= 0: no such dimension)

  ! Local variables
  integer :: i, j, idim
  !! Loop indices
  integer :: ndim, ndata_dim, padded_rank
  !! Rank of data, number of padded dimensions, rank of data_padded
  integer :: lo, hi, pad, max_pad
  !! Data bounds, padding width and largest admissible width
  integer :: sample_dim_, channel_dim_
  !! Sample and channel dimensions
  real(real32) :: constant_
  !! Constant value for padding
  integer, allocatable, dimension(:) :: padding
  !! Padding width per padded dimension
  integer, allocatable, dimension(:) :: src_step
  !! Section stride per dimension for the border source (-1 = mirrored)
  integer, allocatable, dimension(:,:) :: trgt_bound, dest_bound
  !! Bounds of the data region and of the whole padded array
  integer, allocatable, dimension(:,:) :: tmp_trgt_bound, tmp_dest_bound
  !! Source and destination bounds of one border copy
  character(256) :: err_msg
  !! Error message


  !---------------------------------------------------------------------------
  ! Initialise optional arguments
  !---------------------------------------------------------------------------
  ! Assign the defaults at run time. Initialising in the declaration would
  ! give these locals the SAVE attribute, so a value passed in one call
  ! would leak into later calls that omit the argument.
  sample_dim_ = 0
  channel_dim_ = 0
  constant_ = 0._real32
  if(present(constant)) constant_ = constant
  if(present(sample_dim)) sample_dim_ = sample_dim
  if(present(channel_dim)) channel_dim_ = channel_dim


  !---------------------------------------------------------------------------
  ! Validate ranks and special dimensions
  !---------------------------------------------------------------------------
  ndim = rank(data)
  if(ndim.lt.1.or.ndim.gt.5)then
     call stop_program("cannot handle data with this rank")
     return
  end if

#if defined(GFORTRAN)
  padded_rank = rank(data_padded)
#else
  select rank(data_padded)
  rank(1)
     padded_rank = 1
  rank(2)
     padded_rank = 2
  rank(3)
     padded_rank = 3
  rank(4)
     padded_rank = 4
  rank(5)
     padded_rank = 5
  rank default
     ! data has rank 1 to 5 here, so any other rank is a mismatch
     padded_rank = -1
  end select
#endif
  if(ndim.ne.padded_rank)then
     call stop_program("data and data_padded are not the same rank")
     return
  end if

  if(sample_dim_.gt.ndim.or.channel_dim_.gt.ndim)then
     call stop_program("sample_dim and channel_dim must not exceed the rank of data")
     return
  end if
  if(sample_dim_.gt.0.and.sample_dim_.eq.channel_dim_)then
     call stop_program("sample_dim and channel_dim must refer to different dimensions")
     return
  end if

  ndata_dim = ndim
  if(sample_dim_.gt.0) ndata_dim = ndata_dim - 1
  if(channel_dim_.gt.0) ndata_dim = ndata_dim - 1


  !---------------------------------------------------------------------------
  ! Handle padding type name
  !---------------------------------------------------------------------------
  ! none = alt. name for 'valid'
  ! zero = alt. name for 'same'
  ! symmetric = alt.name for 'replication'
  ! valid = no padding
  ! same = maintain spatial dimensions
  ! ... (i.e. padding added = (kernel_size - 1)/2)
  ! ... defaults to zeros in the padding
  ! full = enough padding for filter to slide over every possible position
  ! ... (i.e. padding added = (kernel_size - 1)
  ! circular = maintain spatial dimensions
  ! ... wraps data around for padding (periodic)
  ! reflection = maintains spatial dimensions
  ! ... reflect data (about boundary index)
  ! replication = maintains spatial dimensions
  ! ... reflect data (boundary included)
  select rank(kernel_size)
  rank(0)
     allocate(padding(ndata_dim))
     do i = 1, ndata_dim
        call set_padding(padding(i), kernel_size, padding_method, verbose=0)
     end do
  rank(1)
     if(size(kernel_size).eq.1.and.ndata_dim.gt.1)then
        ! a single kernel width is applied to every padded dimension
        allocate(padding(ndata_dim))
        do i = 1, ndata_dim
           call set_padding( &
                padding(i), &
                kernel_size(1), &
                padding_method, &
                verbose = 0 &
           )
        end do
     else
        if(sample_dim_.eq.0.and.channel_dim_.eq.0.and.&
             size(kernel_size).ne.ndim)then
           write(err_msg,'("&
                &kernel_size length not equal to rank of data",A,"&
                &kernel dimension: ",I0,A,"&
                &data rank: ",I0)' &
           ) &
                achar(13) // achar(10), size(kernel_size), &
                achar(13) // achar(10), ndim
           call stop_program(err_msg)
           return
        elseif(sample_dim_.gt.0.and.channel_dim_.gt.0.and.&
             size(kernel_size).ne.ndim-2)then
           write(err_msg,'("&
                &kernel_size length not equal to rank of data-2",A,"&
                &kernel dimension: ",I0,A,"&
                &data rank: ",I0)' &
           ) &
                achar(13) // achar(10), size(kernel_size), &
                achar(13) // achar(10), ndim-2
           call stop_program(err_msg)
           return
        elseif((sample_dim_.gt.0.or.channel_dim_.gt.0).and.&
             .not.(sample_dim_.gt.0.and.channel_dim_.gt.0).and.&
             size(kernel_size).ne.ndim-1)then
           write(err_msg,'("&
                &kernel_size length not equal to rank of data-1",A,"&
                &kernel dimension: ",I0,A,"&
                &data rank: ",I0)' &
           ) &
                achar(13) // achar(10), size(kernel_size), &
                achar(13) // achar(10), ndim-1
           call stop_program(err_msg)
           return
        else
           allocate(padding(size(kernel_size)))
        end if
        do i = 1, size(kernel_size)
           call set_padding( &
                padding(i), kernel_size(i), padding_method, verbose=0 &
           )
        end do
     end if
  rank default
     ! without this, padding would be used unallocated below
     call stop_program("kernel_size must be a scalar or a rank-1 array")
     return
  end select


  !---------------------------------------------------------------------------
  ! Allocate data set
  ! ... if appropriate, add padding
  !---------------------------------------------------------------------------
  select case(padding_method)
  case("same", "full", "zero")
  case default
     if(abs(constant_).gt.1.E-8_real32)then
        write(*,*) "WARNING: constant is ignored for selected padding method"
     end if
  end select

  allocate(dest_bound(2,ndim))
  allocate(trgt_bound(2,ndim))
  i = 0
  do idim = 1, ndim
     trgt_bound(:,idim) = [ lbound(data,dim=idim), ubound(data,dim=idim) ]
     dest_bound(:,idim) = trgt_bound(:,idim)
     if(idim.eq.sample_dim_.or.idim.eq.channel_dim_) cycle
     i = i + 1
     dest_bound(:,idim) = dest_bound(:,idim) + [ -padding(i), padding(i) ]
  end do

  select rank(data_padded)
  rank(1)
     allocate(data_padded(&
          dest_bound(1,1):dest_bound(2,1)), source = constant_)

     ! Copy input data
     !------------------------------------------------------------------------
     select rank(data)
     rank(1)
        data_padded( &
             trgt_bound(1,1):trgt_bound(2,1) &
        ) = data( &
             trgt_bound(1,1):trgt_bound(2,1) &
        )
     end select
  rank(2)
     allocate(data_padded(&
          dest_bound(1,1):dest_bound(2,1), &
          dest_bound(1,2):dest_bound(2,2)), source = constant_)

     ! Copy input data
     !------------------------------------------------------------------------
     select rank(data)
     rank(2)
        data_padded( &
             trgt_bound(1,1):trgt_bound(2,1), &
             trgt_bound(1,2):trgt_bound(2,2) &
        ) = data( &
             trgt_bound(1,1):trgt_bound(2,1), &
             trgt_bound(1,2):trgt_bound(2,2) &
        )
     end select
  rank(3)
     allocate( &
          data_padded(&
               dest_bound(1,1):dest_bound(2,1),&
               dest_bound(1,2):dest_bound(2,2),&
               dest_bound(1,3):dest_bound(2,3) &
          ), source = constant_ &
     )

     ! Copy input data
     !------------------------------------------------------------------------
     select rank(data)
     rank(3)
        data_padded( &
             trgt_bound(1,1):trgt_bound(2,1), &
             trgt_bound(1,2):trgt_bound(2,2), &
             trgt_bound(1,3):trgt_bound(2,3) &
        ) = data( &
             trgt_bound(1,1):trgt_bound(2,1), &
             trgt_bound(1,2):trgt_bound(2,2), &
             trgt_bound(1,3):trgt_bound(2,3) &
        )
     end select
  rank(4)
     allocate( &
          data_padded( &
               dest_bound(1,1):dest_bound(2,1), &
               dest_bound(1,2):dest_bound(2,2), &
               dest_bound(1,3):dest_bound(2,3), &
               dest_bound(1,4):dest_bound(2,4) &
          ), source = constant_ &
     )

     ! Copy input data
     !------------------------------------------------------------------------
     select rank(data)
     rank(4)
        data_padded( &
             trgt_bound(1,1):trgt_bound(2,1), &
             trgt_bound(1,2):trgt_bound(2,2), &
             trgt_bound(1,3):trgt_bound(2,3), &
             trgt_bound(1,4):trgt_bound(2,4) &
        ) = data( &
             trgt_bound(1,1):trgt_bound(2,1), &
             trgt_bound(1,2):trgt_bound(2,2), &
             trgt_bound(1,3):trgt_bound(2,3), &
             trgt_bound(1,4):trgt_bound(2,4) &
        )
     end select
  rank(5)
     allocate( &
          data_padded(&
               dest_bound(1,1):dest_bound(2,1), &
               dest_bound(1,2):dest_bound(2,2), &
               dest_bound(1,3):dest_bound(2,3), &
               dest_bound(1,4):dest_bound(2,4), &
               dest_bound(1,5):dest_bound(2,5) &
          ), source = constant_ &
     )

     ! Copy input data
     !------------------------------------------------------------------------
     select rank(data)
     rank(5)
        data_padded( &
             trgt_bound(1,1):trgt_bound(2,1), &
             trgt_bound(1,2):trgt_bound(2,2), &
             trgt_bound(1,3):trgt_bound(2,3), &
             trgt_bound(1,4):trgt_bound(2,4), &
             trgt_bound(1,5):trgt_bound(2,5) &
        ) = data( &
             trgt_bound(1,1):trgt_bound(2,1), &
             trgt_bound(1,2):trgt_bound(2,2), &
             trgt_bound(1,3):trgt_bound(2,3), &
             trgt_bound(1,4):trgt_bound(2,4), &
             trgt_bound(1,5):trgt_bound(2,5) &
        )
     end select
  end select


  !---------------------------------------------------------------------------
  ! Return if constant -- or no -- padding
  !---------------------------------------------------------------------------
  select case(padding_method)
  case("same", "full", "zero", "valid", "vali", "none")
     return
  end select


  !---------------------------------------------------------------------------
  ! Insert padding
  !---------------------------------------------------------------------------
  allocate(src_step(ndim))
  i = 0
  do idim = 1, ndim
     if(idim.eq.sample_dim_.or.idim.eq.channel_dim_) cycle
     i = i + 1
     pad = padding(i)
     lo = trgt_bound(1,idim)
     hi = trgt_bound(2,idim)

     ! Largest border that can be sourced from the data region alone
     select case(padding_method)
     case("circular", "replication", "symmetric")
        max_pad = hi - lo + 1
     case("reflection")
        max_pad = hi - lo
     case default
        write(err_msg,'("unrecognised padding method: ",A)') &
             trim(padding_method)
        call stop_program(err_msg)
        return
     end select
     if(pad.gt.max_pad)then
        write(err_msg,'("padding width ",I0," too large for ",A,&
             &" padding along dimension ",I0)') &
             pad, trim(padding_method), idim
        call stop_program(err_msg)
        return
     end if

     ! j = 1: lower border, j = 2: upper border
     do j = 1, 2
        ! All other dimensions are copied over their full padded extent,
        ! so borders of earlier dimensions propagate into the corners.
        tmp_dest_bound = dest_bound
        tmp_trgt_bound = dest_bound
        src_step = 1
        if(j.eq.1)then
           tmp_dest_bound(:,idim) = [ lo - pad, lo - 1 ]
           select case(padding_method)
           case("circular")
              tmp_trgt_bound(:,idim) = [ hi - pad + 1, hi ]
           case("reflection")
              ! mirrored about lo: x(lo-k) = x(lo+k)
              tmp_trgt_bound(:,idim) = [ lo + pad, lo + 1 ]
              src_step(idim) = -1
           case("replication", "symmetric")
              ! mirrored including lo: x(lo-k) = x(lo+k-1)
              tmp_trgt_bound(:,idim) = [ lo + pad - 1, lo ]
              src_step(idim) = -1
           end select
        else
           tmp_dest_bound(:,idim) = [ hi + 1, hi + pad ]
           select case(padding_method)
           case("circular")
              tmp_trgt_bound(:,idim) = [ lo, lo + pad - 1 ]
           case("reflection")
              ! mirrored about hi: x(hi+k) = x(hi-k)
              tmp_trgt_bound(:,idim) = [ hi - 1, hi - pad ]
              src_step(idim) = -1
           case("replication", "symmetric")
              ! mirrored including hi: x(hi+k) = x(hi-k+1)
              tmp_trgt_bound(:,idim) = [ hi, hi - pad + 1 ]
              src_step(idim) = -1
           end select
        end if

        select rank(data_padded)
        rank(1)
           data_padded( &
                tmp_dest_bound(1,1):tmp_dest_bound(2,1) &
           ) = data_padded( &
                tmp_trgt_bound(1,1):tmp_trgt_bound(2,1):src_step(1) &
           )
        rank(2)
           data_padded( &
                tmp_dest_bound(1,1):tmp_dest_bound(2,1), &
                tmp_dest_bound(1,2):tmp_dest_bound(2,2) &
           ) = data_padded( &
                tmp_trgt_bound(1,1):tmp_trgt_bound(2,1):src_step(1), &
                tmp_trgt_bound(1,2):tmp_trgt_bound(2,2):src_step(2) &
           )
        rank(3)
           data_padded( &
                tmp_dest_bound(1,1):tmp_dest_bound(2,1), &
                tmp_dest_bound(1,2):tmp_dest_bound(2,2), &
                tmp_dest_bound(1,3):tmp_dest_bound(2,3) &
           ) = data_padded( &
                tmp_trgt_bound(1,1):tmp_trgt_bound(2,1):src_step(1), &
                tmp_trgt_bound(1,2):tmp_trgt_bound(2,2):src_step(2), &
                tmp_trgt_bound(1,3):tmp_trgt_bound(2,3):src_step(3) &
           )
        rank(4)
           data_padded( &
                tmp_dest_bound(1,1):tmp_dest_bound(2,1), &
                tmp_dest_bound(1,2):tmp_dest_bound(2,2), &
                tmp_dest_bound(1,3):tmp_dest_bound(2,3), &
                tmp_dest_bound(1,4):tmp_dest_bound(2,4) &
           ) = data_padded( &
                tmp_trgt_bound(1,1):tmp_trgt_bound(2,1):src_step(1), &
                tmp_trgt_bound(1,2):tmp_trgt_bound(2,2):src_step(2), &
                tmp_trgt_bound(1,3):tmp_trgt_bound(2,3):src_step(3), &
                tmp_trgt_bound(1,4):tmp_trgt_bound(2,4):src_step(4) &
           )
        rank(5)
           data_padded( &
                tmp_dest_bound(1,1):tmp_dest_bound(2,1), &
                tmp_dest_bound(1,2):tmp_dest_bound(2,2), &
                tmp_dest_bound(1,3):tmp_dest_bound(2,3), &
                tmp_dest_bound(1,4):tmp_dest_bound(2,4), &
                tmp_dest_bound(1,5):tmp_dest_bound(2,5) &
           ) = data_padded( &
                tmp_trgt_bound(1,1):tmp_trgt_bound(2,1):src_step(1), &
                tmp_trgt_bound(1,2):tmp_trgt_bound(2,2):src_step(2), &
                tmp_trgt_bound(1,3):tmp_trgt_bound(2,3):src_step(3), &
                tmp_trgt_bound(1,4):tmp_trgt_bound(2,4):src_step(4), &
                tmp_trgt_bound(1,5):tmp_trgt_bound(2,5):src_step(5) &
           )
        end select
     end do
  end do

end subroutine pad_data