inverse_design_array_2d Module Function

module function inverse_design_array_2d(this, target, x_init, optimiser, steps) result(x_opt)

Optimise the input so the network output matches a target. Operates on a 2D array of array_type inputs; only a batch size of 1 is supported.

Arguments

Type | Intent | Optional | Attributes | Name
class(network_type), intent(inout), target :: this

Instance of the network

type(array_type), intent(in), dimension(:,:) :: target

Target output values

type(array_type), intent(in), dimension(:,:) :: x_init

Initial input values

class(base_optimiser_type), intent(in), optional :: optimiser

Optimiser for input updates (if absent, a base optimiser using the network optimiser's learning rate is used)

integer, intent(in) :: steps

Number of optimisation iterations

Return Value type(array_type), dimension(size(x_init,1), size(x_init,2))

Optimised input


Source Code

  module function inverse_design_array_2d( &
       this, target, x_init, optimiser, steps &
  ) result(x_opt)
    !! Optimise the input so the network output matches a target.
    !!
    !! Implementation for a 2D array of array_type inputs. Starting from
    !! x_init, each step runs a forward pass, evaluates the network loss
    !! against target, back-propagates to the input layer output, and
    !! updates the flattened input with an optimiser. The network
    !! parameters and the training/inference mode are saved on entry and
    !! restored on exit. Only a batch size of 1 is supported.
    implicit none

    ! Arguments
    class(network_type), intent(inout), target :: this
    !! Instance of the network
    type(array_type), dimension(:,:), intent(in) :: target
    !! Target output values
    type(array_type), dimension(:,:), intent(in) :: x_init
    !! Initial input values
    class(base_optimiser_type), optional, intent(in) :: optimiser
    !! Optimiser for input updates (if absent, a base optimiser using the
    !! network optimiser's learning rate is used)
    integer, intent(in) :: steps
    !! Number of optimisation iterations
    type(array_type), dimension(size(x_init,1), size(x_init,2)) :: x_opt
    !! Optimised input

    ! Local variables
    integer :: step, i, j, root_id, num_x, num_samples
    !! Loop indices, root layer id, number of input elements, batch size
    logical :: use_edge_features
    !! Whether edge features are used in the input
    type(array_type), pointer :: loss
    !! Loss pointer
    class(base_optimiser_type), allocatable :: opt
    !! Local optimiser instance
    real(real32), allocatable :: x_flat(:), x_grad(:)
    !! Flat input vector and gradient
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans
    real(real32), allocatable :: saved_params(:)
    !! Saved network parameters


    !---------------------------------------------------------------------------
    ! Ensure the network has a loss function
    !---------------------------------------------------------------------------
    if(.not.allocated(this%loss))then
       call this%set_loss("mse")
    end if


    !---------------------------------------------------------------------------
    ! Get number of input elements
    !---------------------------------------------------------------------------
    num_x = 0
    use_edge_features = .false.
    if(this%use_graph_input)then
       num_samples = size(x_init, dim=2)
       num_x = size(x_init(1,1)%val) ! vertex features
       ! determine if edge features are used by checking the output shape of
       ! the input layer
       ! NOTE(review): this indexes model by root_vertices(1) directly, while
       ! the loop below maps it through auto_graph%vertex(...)%id; confirm
       ! that these coincide.
       if(size(this%model(this%root_vertices(1))%layer%output_shape,dim=1).eq.2)then
          use_edge_features = .true.
          num_x = num_x + size(x_init(2,1)%val) ! edge features
       end if
    else
       num_samples = size(x_init(1,1)%val, dim=2)
       ! only the first column (sample) of each entry is optimised
       do i = 1, size(x_init,1)
          do j = 1, size(x_init,2)
             num_x = num_x + size(x_init(i,j)%val,dim=1)
          end do
       end do
    end if
    x_opt = x_init
    if(num_samples.gt.1)then
       call stop_program( &
            "inverse_design_array_2d: batch size greater than 1 not supported" &
       )
    end if


    !---------------------------------------------------------------------------
    ! Set up optimiser for input variables
    !---------------------------------------------------------------------------
    if(present(optimiser))then
       allocate(opt, source=optimiser)
    else
       allocate(opt, source=base_optimiser_type( &
            learning_rate=this%optimiser%learning_rate))
    end if
    call opt%init_gradients(num_x)
    opt%iter = 0


    !---------------------------------------------------------------------------
    ! Pre-allocate flat arrays used in the optimisation loop
    !---------------------------------------------------------------------------
    allocate(x_flat(num_x))
    allocate(x_grad(num_x))


    !---------------------------------------------------------------------------
    ! Ensure training mode is active so the full graph is built
    !---------------------------------------------------------------------------
    call this%set_training_mode(mode_store)


    !---------------------------------------------------------------------------
    ! Get root layer id
    !---------------------------------------------------------------------------
    root_id = this%auto_graph%vertex(this%root_vertices(1))%id
    call this%set_batch_size(num_samples)


    !---------------------------------------------------------------------------
    ! Save network parameters so they can be restored afterwards
    !---------------------------------------------------------------------------
    allocate(saved_params(this%num_params))
    saved_params = this%get_params()


    !---------------------------------------------------------------------------
    ! Optimisation loop
    !---------------------------------------------------------------------------
    do step = 1, steps

       ! Forward pass with current x
       call this%forward(x_opt)

       ! Enable gradient tracking on the input layer output
       if(this%use_graph_input)then
          call this%model(root_id)%layer%output(1,1)%set_requires_grad(.true.)
          if(use_edge_features)then
             call this%model(root_id)%layer%output(2,1)%set_requires_grad(.true.)
          end if
       else
          do i = 1, size(x_opt,1)
             do j = 1, size(x_opt,2)
                call this%model(root_id)%layer%output(i,j)%set_requires_grad(.true.)
             end do
          end do
       end if

       ! Compute loss via the network's loss function
       call this%save_output(target)
       loss => this%loss_eval(1, num_samples)

       ! Backward pass
       call loss%grad_reverse()

       ! Flatten the current input and its gradient. x_flat is always
       ! refreshed from x_opt (even where no gradient is available) so the
       ! optimiser never operates on uninitialised values.
       call copy_input_flat(to_flat=.true.)

       ! Update x using the optimiser (not the model weights)
       opt%iter = opt%iter + 1
       call opt%minimise(x_flat, x_grad)

       ! Convert flat x back to array form, using the same layout as above
       call copy_input_flat(to_flat=.false.)

       ! Clean up computation graph
       call loss%nullify_graph()
       deallocate(loss)
       nullify(loss)

       ! Reset network parameter gradients so they remain unchanged
       call this%reset_gradients()

    end do


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)


    !---------------------------------------------------------------------------
    ! Restore network parameters to ensure model is unchanged
    !---------------------------------------------------------------------------
    call this%set_params(saved_params)

  contains

    subroutine copy_input_flat(to_flat)
      !! Walk the input entries in a fixed order and either copy x_opt (and
      !! the input-layer gradients) into x_flat/x_grad, or copy x_flat back
      !! into x_opt. One traversal serves both directions, so the flat
      !! layout used for the update always matches the one used to restore.
      logical, intent(in) :: to_flat
      !! .true. : x_opt -> x_flat/x_grad; .false. : x_flat -> x_opt
      integer :: ii, jj, offset

      offset = 0
      if(this%use_graph_input)then
         ! Graph input: vertex features in (1,1), optional edge features in
         ! (2,1); every column of each array is transferred.
         call copy_entry_flat(1, 1, .true., to_flat, offset)
         if(use_edge_features) &
              call copy_entry_flat(2, 1, .true., to_flat, offset)
      else
         ! Non-graph input: only the first column (sample) of each entry.
         do ii = 1, size(x_opt,1)
            do jj = 1, size(x_opt,2)
               call copy_entry_flat(ii, jj, .false., to_flat, offset)
            end do
         end do
      end if
    end subroutine copy_input_flat

    subroutine copy_entry_flat(ii, jj, all_columns, to_flat, offset)
      !! Transfer the columns of x_opt(ii,jj)%val to/from the flat arrays
      !! starting after position offset, advancing offset past them. When
      !! flattening, entries whose input-layer gradient is not associated
      !! receive a zero gradient.
      integer, intent(in) :: ii, jj
      !! Indices of the input entry
      logical, intent(in) :: all_columns
      !! Transfer every column (.true.) or only the first (.false.)
      logical, intent(in) :: to_flat
      !! Direction of the copy
      integer, intent(inout) :: offset
      !! Number of flat elements already consumed
      integer :: k, n_rows, n_cols

      n_rows = size(x_opt(ii,jj)%val, dim=1)
      n_cols = 1
      if(all_columns) n_cols = size(x_opt(ii,jj)%val, dim=2)

      do k = 1, n_cols
         if(to_flat)then
            x_flat(offset+1:offset+n_rows) = x_opt(ii,jj)%val(:,k)
            if(associated(this%model(root_id)%layer%output(ii,jj)%grad))then
               x_grad(offset+1:offset+n_rows) = &
                    this%model(root_id)%layer%output(ii,jj)%grad%val(:,k)
            else
               x_grad(offset+1:offset+n_rows) = 0._real32
            end if
         else
            x_opt(ii,jj)%val(:,k) = x_flat(offset+1:offset+n_rows)
         end if
         offset = offset + n_rows
      end do
    end subroutine copy_entry_flat

  end function inverse_design_array_2d