module athena__clipper
  !! Module containing functions to clip gradients
  !!
  !! This module implements clipping methods for layer gradients.
  !! Two independent clipping modes are provided by clip_type:
  !!   * element-wise clamping of every value to the range [min, max]
  !!   * rescaling so the combined L2-norm of gradient and bias does not
  !!     exceed a maximum value
  use coreutils, only: real32
  implicit none


  private

  public :: clip_type


  type clip_type
     !! Type for clipping gradients
     logical :: l_min_max = .false.
     !! Boolean whether min/max values are set
     logical :: l_norm    = .false.
     !! Boolean whether a norm is set
     real(real32) :: min  =-huge(1._real32)
     !! Minimum value for clipping
     real(real32) :: max  = huge(1._real32)
     !! Maximum value for clipping
     real(real32) :: norm = huge(1._real32)
     !! Maximum L2-norm for clipping
   contains
     procedure, pass(this) :: read => read_clip
     !! Read clipping information
     procedure, pass(this) :: set => set_clip
     !! Set clipping information
     procedure, pass(this) :: apply => apply_clip
     !! Apply clipping to gradients
  end type clip_type

  interface clip_type
     !! Interface for the clip type
     module function clip_setup( &
          clip_min, clip_max, clip_norm) result(clip)
       !! Set up the clip dictionary
       real(real32), optional, intent(in) :: clip_min, clip_max, clip_norm
       !! Minimum, maximum, and norm values for clipping
       type(clip_type) :: clip
       !! Clip dictionary
     end function clip_setup
  end interface clip_type



contains

!###############################################################################
  module function clip_setup( &
       clip_min, clip_max, clip_norm) result(clip)
    !! Set up the clip dictionary
    !!
    !! Every argument is optional. Supplying clip_min and/or clip_max
    !! enables min/max clipping; supplying clip_norm enables norm clipping.
    !! Values that are not supplied keep the component defaults of clip_type.
    implicit none

    ! Arguments
    real(real32), optional, intent(in) :: clip_min, clip_max, clip_norm
    !! Minimum, maximum, and norm values for clipping

    type(clip_type) :: clip
    !! Instance of the clip type


    !---------------------------------------------------------------------------
    ! Set up clipping limits
    !---------------------------------------------------------------------------
    if(present(clip_min))then
       clip%l_min_max = .true.
       clip%min = clip_min
    end if
    if(present(clip_max))then
       clip%l_min_max = .true.
       clip%max = clip_max
    end if
    if(present(clip_norm))then
       clip%l_norm = .true.
       clip%norm = clip_norm
    end if

  end function clip_setup
!###############################################################################


!###############################################################################
  subroutine read_clip(this, min_str, max_str, norm_str)
    !! Read clipping information
    !!
    !! Each string is parsed with a list-directed read when it is not blank.
    !! A blank min/max string resets the corresponding limit to -huge/+huge.
    !! A blank norm string leaves the norm settings untouched.
    implicit none

    ! Arguments
    class(clip_type), intent(inout) :: this
    !! Instance of the clip type
    character(*), intent(in) :: min_str, max_str, norm_str
    !! Strings for min, max, and norm values

    ! Local variables
    logical :: has_min, has_max
    !! Whether the min and max strings contain non-blank characters

    has_min = len_trim(min_str) .gt. 0
    has_max = len_trim(max_str) .gt. 0

    if(has_min)then
       read(min_str,*) this%min
    else
       this%min = -huge(1._real32)
    end if
    if(has_max)then
       read(max_str,*) this%max
    else
       this%max = huge(1._real32)
    end if

    ! The flag is only ever switched on here, never switched off
    if(has_min .or. has_max) this%l_min_max = .true.

    if(len_trim(norm_str) .gt. 0)then
       read(norm_str,*) this%norm
       this%l_norm = .true.
    end if

  end subroutine read_clip
!###############################################################################


!###############################################################################
  subroutine set_clip(this, clip_dict, clip_min, clip_max, clip_norm)
    !! Set clipping information
    !!
    !! If clip_dict is present, all of its settings are copied and any
    !! individually supplied values are ignored (a notice is written to
    !! standard output). Otherwise, each supplied value is stored and its
    !! corresponding clipping mode is enabled.
    implicit none

    ! Arguments
    class(clip_type), intent(inout) :: this
    !! Instance of the clip type
    type(clip_type), optional, intent(in) :: clip_dict
    !! Clip dictionary
    real(real32), optional, intent(in) :: clip_min, clip_max, clip_norm
    !! Minimum, maximum, and norm values for clipping


    !---------------------------------------------------------------------------
    ! Set up clipping limits
    !---------------------------------------------------------------------------
    if(present(clip_dict))then
       this%l_min_max = clip_dict%l_min_max
       this%l_norm = clip_dict%l_norm
       this%min = clip_dict%min
       this%max = clip_dict%max
       this%norm = clip_dict%norm
       if(present(clip_min).or.present(clip_max).or.present(clip_norm))then
          write(*,*) "Multiple clip options provided"
          write(*,*) "Ignoring all except clip_dict"
       end if
    else
       if(present(clip_min))then
          this%l_min_max = .true.
          this%min = clip_min
       end if
       if(present(clip_max))then
          this%l_min_max = .true.
          this%max = clip_max
       end if
       if(present(clip_norm))then
          this%l_norm = .true.
          this%norm = clip_norm
       end if
    end if

  end subroutine set_clip
!###############################################################################


!###############################################################################
  pure subroutine apply_clip(this, length, gradient, bias)
    !! Function to apply clipping to gradients
    !!
    !! Min/max clipping (if enabled) is applied first, element-wise, to both
    !! the gradient and the bias. Norm clipping (if enabled) is applied
    !! afterwards: when the combined L2-norm of gradient and bias exceeds
    !! this%norm, both are multiplied by this%norm / L2-norm.
    implicit none

    ! Arguments
    class(clip_type), intent(in) :: this
    !! Instance of the clip type
    integer, intent(in) :: length
    !! Length of the gradient
    real(real32), dimension(length), intent(inout) :: gradient
    !! Gradient to be clipped
    real(real32), dimension(:), optional, intent(inout) :: bias
    !! Bias to be clipped

    ! Local variables
    real(real32) :: scale
    !! Scaling factor for the gradient
    real(real32) :: l2_norm
    !! Combined L2-norm of the gradient and bias
    real(real32), dimension(:), allocatable :: bias_
    !! Copy of the bias

    ! Work on a local copy so an absent bias can be treated as a single zero
    if(present(bias))then
       bias_ = bias
    else
       allocate(bias_(1), source=0._real32)
    end if

    ! Clip values to within limits of (min,max)
    if(this%l_min_max)then
       gradient = max(this%min,min(this%max,gradient))
       bias_    = max(this%min,min(this%max,bias_))
    end if

    ! Clip values to a maximum L2-norm
    if(this%l_norm)then
       ! Sum of squares over gradient AND bias; integer exponent is used
       ! because raising a negative real to a real power is non-conforming
       l2_norm = sqrt(sum(gradient**2) + sum(bias_**2))
       ! Only rescale when the norm is exceeded; the positivity check
       ! avoids a division by zero for an all-zero gradient and bias
       if(l2_norm .gt. this%norm .and. l2_norm .gt. 0._real32)then
          scale = this%norm / l2_norm
          gradient = gradient * scale
          bias_    = bias_ * scale
       end if
    end if

    if(present(bias)) bias = bias_

  end subroutine apply_clip
!###############################################################################

end module athena__clipper