Optimise the input so that the network output matches a target. Wraps the array_type implementation after converting the input to a 2D array.
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| class(network_type) | intent(inout) | | target | this | Instance of the network |
| type(array_type) | intent(in) | | dimension(:,:) | target | Target output values |
| type(array_type) | intent(in) | | dimension(:,:) | x_init | Initial input values |
| class(base_optimiser_type) | intent(in) | optional | | optimiser | Optimiser for input updates (defaults to network optimiser) |
| integer | intent(in) | | | steps | Number of optimisation iterations |
**Return value:** `type(array_type), dimension(size(x_init,1), size(x_init,2))` — Optimised input
module function inverse_design_array_2d( &
     this, target, x_init, optimiser, steps &
) result(x_opt)
  !! Optimise the input so the network output matches a target.
  !!
  !! The input is the free variable. Each step runs a forward pass with the
  !! current input, evaluates the network loss against `target`,
  !! back-propagates to the input layer output, and updates a flattened copy
  !! of the input with a separate optimiser.
  !! Network parameters are saved beforehand and restored on exit, and the
  !! training/inference mode is restored, so only `x_opt` is changed.
  !! Only a batch size of 1 is supported.
  implicit none

  ! Arguments
  class(network_type), intent(inout), target :: this
  !! Instance of the network
  type(array_type), dimension(:,:), intent(in) :: target
  !! Target output values
  type(array_type), dimension(:,:), intent(in) :: x_init
  !! Initial input values
  class(base_optimiser_type), optional, intent(in) :: optimiser
  !! Optimiser for input updates (defaults to network optimiser)
  integer, intent(in) :: steps
  !! Number of optimisation iterations
  type(array_type), dimension(size(x_init,1), size(x_init,2)) :: x_opt
  !! Optimised input

  ! Local variables
  integer :: step, i, j, itmp1, root_id, num_x, num_samples, num_elements
  !! Loop indices, flat-vector offset, root layer id, and input sizes
  integer :: num_graph_inputs
  !! Graph inputs in use: 1 = vertex features only, 2 = vertex + edge
  logical :: use_edge_features
  !! Whether edge features are used in the input
  type(array_type), pointer :: loss
  !! Loss pointer
  class(base_optimiser_type), allocatable :: opt
  !! Local optimiser instance
  real(real32), allocatable :: x_flat(:), x_grad(:)
  !! Flat input vector and gradient
  logical, allocatable :: mode_store(:)
  !! Storage for inference mode booleans
  real(real32), allocatable :: saved_params(:)
  !! Saved network parameters


  !---------------------------------------------------------------------------
  ! Ensure the network has a loss function
  !---------------------------------------------------------------------------
  if(.not.allocated(this%loss))then
     call this%set_loss("mse")
  end if


  !---------------------------------------------------------------------------
  ! Get number of input elements
  !---------------------------------------------------------------------------
  ! Flat layout used throughout:
  !  - graph input: all vertex features (column-major), then, if present,
  !    all edge features (column-major)
  !  - otherwise:   first column of each x(i,j), i outer, j inner
  num_x = 0
  use_edge_features = .false.
  if(this%use_graph_input)then
     num_samples = size(x_init, dim=2)
     num_x = size(x_init(1,1)%val)   ! vertex features
     ! determine if edge features are used by checking the output shape of
     ! the input layer
     if(size(this%model(this%root_vertices(1))%layer%output_shape, dim=1) &
          .eq.2)then
        use_edge_features = .true.
        num_x = num_x + size(x_init(2,1)%val)   ! edge features
     end if
  else
     num_samples = size(x_init(1,1)%val, dim=2)
     do i = 1, size(x_init,1)
        do j = 1, size(x_init,2)
           num_x = num_x + size(x_init(i,j)%val, dim=1)
        end do
     end do
  end if
  num_graph_inputs = merge(2, 1, use_edge_features)

  x_opt = x_init
  if(num_samples.gt.1)then
     call stop_program( &
          "inverse_design_array_2d: batch size greater than 1 not supported" &
     )
  end if


  !---------------------------------------------------------------------------
  ! Set up optimiser for input variables
  !---------------------------------------------------------------------------
  ! Without an explicit optimiser, fall back to a base optimiser that uses
  ! the network optimiser's learning rate.
  if(present(optimiser))then
     allocate(opt, source=optimiser)
  else
     allocate(opt, source=base_optimiser_type( &
          learning_rate=this%optimiser%learning_rate))
  end if
  call opt%init_gradients(num_x)
  opt%iter = 0


  !---------------------------------------------------------------------------
  ! Pre-allocate flat arrays used in the optimisation loop
  !---------------------------------------------------------------------------
  allocate(x_flat(num_x))
  allocate(x_grad(num_x))


  !---------------------------------------------------------------------------
  ! Ensure training mode is active so the full graph is built
  !---------------------------------------------------------------------------
  call this%set_training_mode(mode_store)


  !---------------------------------------------------------------------------
  ! Get root layer id
  !---------------------------------------------------------------------------
  root_id = this%auto_graph%vertex(this%root_vertices(1))%id
  call this%set_batch_size(num_samples)


  !---------------------------------------------------------------------------
  ! Save network parameters so they can be restored afterwards
  !---------------------------------------------------------------------------
  allocate(saved_params(this%num_params))
  saved_params = this%get_params()


  !---------------------------------------------------------------------------
  ! Optimisation loop
  !---------------------------------------------------------------------------
  do step = 1, steps

     ! Forward pass with current x
     call this%forward(x_opt)

     ! Enable gradient tracking on the input layer output
     if(this%use_graph_input)then
        do i = 1, num_graph_inputs
           call this%model(root_id)%layer%output(i,1)%set_requires_grad(.true.)
        end do
     else
        do i = 1, size(x_opt,1)
           do j = 1, size(x_opt,2)
              call this%model(root_id)%layer%output(i,j)%set_requires_grad( &
                   .true.)
           end do
        end do
     end if

     ! Compute loss via the network's loss function
     call this%save_output(target)
     loss => this%loss_eval(1, num_samples)

     ! Backward pass
     call loss%grad_reverse()

     ! Flatten the current input and its gradient.
     ! x_flat is always filled from x_opt, even when no gradient is
     ! available, so the optimiser never sees undefined values. Each input
     ! array's gradient is checked individually; a missing one contributes
     ! zeros.
     x_grad = 0._real32
     itmp1 = 0
     if(this%use_graph_input)then
        do i = 1, num_graph_inputs
           ! size of THIS input array (vertex or edge features)
           num_elements = size(x_opt(i,1)%val)
           x_flat(itmp1+1:itmp1+num_elements) = &
                reshape(x_opt(i,1)%val, [num_elements])
           if(associated(this%model(root_id)%layer%output(i,1)%grad))then
              x_grad(itmp1+1:itmp1+num_elements) = reshape( &
                   this%model(root_id)%layer%output(i,1)%grad%val, &
                   [num_elements])
           end if
           itmp1 = itmp1 + num_elements
        end do
     else
        do i = 1, size(x_opt,1)
           do j = 1, size(x_opt,2)
              num_elements = size(x_opt(i,j)%val, dim=1)
              x_flat(itmp1+1:itmp1+num_elements) = x_opt(i,j)%val(:,1)
              if(associated(this%model(root_id)%layer%output(i,j)%grad))then
                 x_grad(itmp1+1:itmp1+num_elements) = &
                      this%model(root_id)%layer%output(i,j)%grad%val(:,1)
              end if
              itmp1 = itmp1 + num_elements
           end do
        end do
     end if

     ! Update x using the optimiser (not the model weights)
     opt%iter = opt%iter + 1
     call opt%minimise(x_flat, x_grad)

     ! Convert flat x back to array form (same layout as above)
     itmp1 = 0
     if(this%use_graph_input)then
        do i = 1, num_graph_inputs
           num_elements = size(x_opt(i,1)%val)
           x_opt(i,1)%val(:,:) = reshape( &
                x_flat(itmp1+1:itmp1+num_elements), shape(x_opt(i,1)%val))
           itmp1 = itmp1 + num_elements
        end do
     else
        do i = 1, size(x_opt,1)
           do j = 1, size(x_opt,2)
              num_elements = size(x_opt(i,j)%val, dim=1)
              x_opt(i,j)%val(:,1) = x_flat(itmp1+1:itmp1+num_elements)
              itmp1 = itmp1 + num_elements
           end do
        end do
     end if

     ! Clean up computation graph
     call loss%nullify_graph()
     deallocate(loss)
     nullify(loss)

     ! Reset network parameter gradients so they remain unchanged
     call this%reset_gradients()

  end do


  !---------------------------------------------------------------------------
  ! Restore training/inference mode
  !---------------------------------------------------------------------------
  call this%restore_mode(mode_store)


  !---------------------------------------------------------------------------
  ! Restore network parameters to ensure model is unchanged
  !---------------------------------------------------------------------------
  call this%set_params(saved_params)

end function inverse_design_array_2d