!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2026 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief Utility functions for the perturbation calculations.
!> \note
!>      - routines are programmed with spins in mind
!>        but are as of now not tested with them
!> \par History
!>      22-08-2002, TCH, started development
! **************************************************************************************************
MODULE qs_p_env_methods
   USE admm_methods,                    ONLY: admm_aux_response_density
   USE admm_types,                      ONLY: admm_gapw_r3d_rs_type,&
                                              admm_type,&
                                              get_admm_env
   USE atomic_kind_types,               ONLY: atomic_kind_type
   USE cp_blacs_env,                    ONLY: cp_blacs_env_type
   USE cp_control_types,                ONLY: dft_control_type
   USE cp_dbcsr_api,                    ONLY: dbcsr_add,&
                                              dbcsr_copy,&
                                              dbcsr_p_type,&
                                              dbcsr_release,&
                                              dbcsr_scale,&
                                              dbcsr_set,&
                                              dbcsr_type
   USE cp_dbcsr_operations,             ONLY: copy_fm_to_dbcsr,&
                                              cp_dbcsr_plus_fm_fm_t,&
                                              cp_dbcsr_sm_fm_multiply,&
                                              dbcsr_allocate_matrix_set
   USE cp_fm_basic_linalg,              ONLY: cp_fm_triangular_multiply
   USE cp_fm_cholesky,                  ONLY: cp_fm_cholesky_decompose
   USE cp_fm_pool_types,                ONLY: cp_fm_pool_p_type,&
                                              cp_fm_pool_type,&
                                              fm_pool_create_fm,&
                                              fm_pool_get_el_struct,&
                                              fm_pool_give_back_fm,&
                                              fm_pools_create_fm_vect
   USE cp_fm_struct,                    ONLY: cp_fm_struct_create,&
                                              cp_fm_struct_get,&
                                              cp_fm_struct_release,&
                                              cp_fm_struct_type
   USE cp_fm_types,                     ONLY: cp_fm_create,&
                                              cp_fm_get_info,&
                                              cp_fm_release,&
                                              cp_fm_set_all,&
                                              cp_fm_to_fm,&
                                              cp_fm_type
   USE cp_log_handling,                 ONLY: cp_get_default_logger,&
                                              cp_logger_type,&
                                              cp_to_string
   USE cp_output_handling,              ONLY: cp_print_key_finished_output,&
                                              cp_print_key_unit_nr
   USE hartree_local_methods,           ONLY: init_coulomb_local
   USE hartree_local_types,             ONLY: hartree_local_create
   USE input_constants,                 ONLY: do_admm_aux_exch_func_none,&
                                              ot_precond_none
   USE input_section_types,             ONLY: section_vals_get,&
                                              section_vals_get_subs_vals,&
                                              section_vals_type
   USE kinds,                           ONLY: default_string_length,&
                                              dp
   USE message_passing,                 ONLY: mp_para_env_type
   USE parallel_gemm_api,               ONLY: parallel_gemm
   USE preconditioner_types,            ONLY: init_preconditioner
   USE pw_env_types,                    ONLY: pw_env_type
   USE pw_types,                        ONLY: pw_c1d_gs_type,&
                                              pw_r3d_rs_type
   USE qs_collocate_density,            ONLY: calculate_rho_elec
   USE qs_energy_types,                 ONLY: qs_energy_type
   USE qs_environment_types,            ONLY: get_qs_env,&
                                              qs_environment_type
   USE qs_kind_types,                   ONLY: qs_kind_type
   USE qs_kpp1_env_methods,             ONLY: kpp1_create,&
                                              kpp1_did_change
   USE qs_ks_methods,                   ONLY: qs_ks_update_qs_env
   USE qs_ks_types,                     ONLY: qs_ks_did_change,&
                                              qs_ks_env_type
   USE qs_linres_types,                 ONLY: linres_control_type
   USE qs_local_rho_types,              ONLY: local_rho_set_create
   USE qs_matrix_pools,                 ONLY: mpools_get
   USE qs_mo_types,                     ONLY: get_mo_set,&
                                              mo_set_type
   USE qs_neighbor_list_types,          ONLY: neighbor_list_set_p_type
   USE qs_p_env_types,                  ONLY: qs_p_env_type
   USE qs_rho0_ggrid,                   ONLY: rho0_s_grid_create
   USE qs_rho0_methods,                 ONLY: init_rho0
   USE qs_rho_atom_methods,             ONLY: allocate_rho_atom_internals,&
                                              calculate_rho_atom_coeff
   USE qs_rho_methods,                  ONLY: qs_rho_rebuild,&
                                              qs_rho_update_rho
   USE qs_rho_types,                    ONLY: qs_rho_create,&
                                              qs_rho_get,&
                                              qs_rho_type
   USE string_utilities,                ONLY: compress
   USE task_list_types,                 ONLY: task_list_type
#include "./base/base_uses.f90"

   IMPLICIT NONE

   ! name of this module (matches the MODULE statement)
   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'qs_p_env_methods'
   ! compile-time debug switch, disabled
   LOGICAL, PRIVATE, PARAMETER :: debug_this_module = .FALSE.

   ! everything is private by default; only the routines listed below are exported
   PRIVATE
   ! creation of the perturbation environment and refresh after a change of psi0
   PUBLIC :: p_env_create, p_env_psi0_changed
   ! pre-/post-orthogonalization of perturbation vectors
   PUBLIC :: p_preortho, p_postortho
   ! lazy allocation of the internal storage and update of the response density
   PUBLIC :: p_env_check_i_alloc, p_env_update_rho
   PUBLIC :: p_env_finish_kpp1

CONTAINS

! **************************************************************************************************
!> \brief allocates and initializes the perturbation environment (no setup)
!> \param p_env the environment to initialize
!> \param qs_env the qs_environment for the system
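!> \param p1_option optional matrix set that p_env%p1 is pointed to; if absent, one zeroed copy of
!>        the overlap matrix per spin is allocated instead
!> \param p1_admm_option optional matrix set that p_env%p1_admm is pointed to (ADMM only); if absent,
!>        one zeroed copy of the auxiliary-fit overlap matrix per spin is allocated instead
!> \param orthogonal_orbitals if the orbitals are orthogonal (treated as .FALSE. when absent)
!> \param linres_control optional linear-response settings; if its preconditioner_type differs from
!>        ot_precond_none, the preconditioners and PS_psi0 are set up as well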
!> \par History
!>      07.2002 created [fawzi]
!> \author Fawzi Mohamed
! **************************************************************************************************
   SUBROUTINE p_env_create(p_env, qs_env, p1_option, p1_admm_option, &
                           orthogonal_orbitals, linres_control)

      TYPE(qs_p_env_type)                                :: p_env
      TYPE(qs_environment_type), POINTER                 :: qs_env
      TYPE(dbcsr_p_type), DIMENSION(:), OPTIONAL, &
         POINTER                                         :: p1_option, p1_admm_option
      LOGICAL, INTENT(in), OPTIONAL                      :: orthogonal_orbitals
      TYPE(linres_control_type), OPTIONAL, POINTER       :: linres_control

      CHARACTER(len=*), PARAMETER                        :: routineN = 'p_env_create'

      INTEGER                                            :: handle, ispin, nao, natom, nmo, nspins
      LOGICAL                                            :: use_gapw, use_gapw_xc, want_precond
      TYPE(admm_gapw_r3d_rs_type), POINTER               :: gapw_admm
      TYPE(admm_type), POINTER                           :: admm_env
      TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
      TYPE(cp_blacs_env_type), POINTER                   :: blacs_env
      TYPE(cp_fm_pool_p_type), DIMENSION(:), POINTER     :: ao_mo_fm_pools, mo_mo_fm_pools
      TYPE(cp_fm_type), POINTER                          :: mo_coeff
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s, matrix_s_aux_fit
      TYPE(dft_control_type), POINTER                    :: dft_control
      TYPE(mp_para_env_type), POINTER                    :: para_env
      TYPE(pw_env_type), POINTER                         :: pw_env
      TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set

      CALL timeset(routineN, handle)

      NULLIFY (ao_mo_fm_pools, mo_mo_fm_pools, matrix_s, dft_control, para_env, blacs_env)
      CALL get_qs_env(qs_env, matrix_s=matrix_s, dft_control=dft_control, &
                      para_env=para_env, blacs_env=blacs_env)
      nspins = dft_control%nspins

      p_env%new_preconditioner = .TRUE.

      ! empty first-order density containers and the kpp1 environment
      ALLOCATE (p_env%rho1)
      CALL qs_rho_create(p_env%rho1)
      ALLOCATE (p_env%rho1_xc)
      CALL qs_rho_create(p_env%rho1_xc)
      ALLOCATE (p_env%kpp1_env)
      CALL kpp1_create(p_env%kpp1_env)

      ! p1: either supplied by the caller or zeroed copies of the overlap matrix
      IF (.NOT. PRESENT(p1_option)) THEN
         CALL dbcsr_allocate_matrix_set(p_env%p1, nspins)
         DO ispin = 1, nspins
            ALLOCATE (p_env%p1(ispin)%matrix)
            CALL dbcsr_copy(p_env%p1(ispin)%matrix, matrix_s(1)%matrix, &
                            name="p_env%p1-"//TRIM(ADJUSTL(cp_to_string(ispin))))
            CALL dbcsr_set(p_env%p1(ispin)%matrix, 0.0_dp)
         END DO
      ELSE
         p_env%p1 => p1_option
      END IF

      !----------------------!
      ! ADMM initializations !
      !----------------------!
      IF (dft_control%do_admm) THEN
         CALL get_admm_env(qs_env%admm_env, matrix_s_aux_fit=matrix_s_aux_fit)

         ! rho1_admm is only created when an auxiliary exchange functional is in use
         IF (qs_env%admm_env%aux_exch_func /= do_admm_aux_exch_func_none) THEN
            ALLOCATE (p_env%rho1_admm)
            CALL qs_rho_create(p_env%rho1_admm)
         END IF

         ! p1_admm: same policy as p1, but based on the auxiliary-fit overlap matrix
         IF (.NOT. PRESENT(p1_admm_option)) THEN
            CALL dbcsr_allocate_matrix_set(p_env%p1_admm, nspins)
            DO ispin = 1, nspins
               ALLOCATE (p_env%p1_admm(ispin)%matrix)
               CALL dbcsr_copy(p_env%p1_admm(ispin)%matrix, matrix_s_aux_fit(1)%matrix, &
                               name="p_env%p1_admm-"//TRIM(ADJUSTL(cp_to_string(ispin))))
               CALL dbcsr_set(p_env%p1_admm(ispin)%matrix, 0.0_dp)
            END DO
         ELSE
            p_env%p1_admm => p1_admm_option
         END IF

         ! atomic densities for the ADMM auxiliary basis in the GAPW case
         CALL get_qs_env(qs_env, admm_env=admm_env)
         IF (admm_env%do_gapw) THEN
            CALL get_qs_env(qs_env, atomic_kind_set=atomic_kind_set)
            gapw_admm => admm_env%admm_gapw_env
            CALL local_rho_set_create(p_env%local_rho_set_admm)
            CALL allocate_rho_atom_internals(p_env%local_rho_set_admm%rho_atom_set, &
                                             atomic_kind_set, gapw_admm%admm_kind_set, &
                                             dft_control, para_env)
         END IF
      END IF

      CALL mpools_get(qs_env%mpools, &
                      ao_mo_fm_pools=ao_mo_fm_pools, &
                      mo_mo_fm_pools=mo_mo_fm_pools)

      ! record the global dimensions of the MO coefficient matrices, spin by spin
      p_env%n_mo = 0
      p_env%n_ao = 0
      DO ispin = 1, nspins
         CALL get_mo_set(qs_env%mos(ispin), mo_coeff=mo_coeff)
         CALL cp_fm_get_info(mo_coeff, nrow_global=nao, ncol_global=nmo)
         p_env%n_ao(ispin) = nao
         p_env%n_mo(ispin) = nmo
      END DO

      ! orbitals are taken as non-orthogonal unless the caller says otherwise
      IF (PRESENT(orthogonal_orbitals)) THEN
         p_env%orthogonal_orbitals = orthogonal_orbitals
      ELSE
         p_env%orthogonal_orbitals = .FALSE.
      END IF

      ! full-matrix work space taken from the pools
      CALL fm_pools_create_fm_vect(ao_mo_fm_pools, elements=p_env%S_psi0, &
                                   name="p_env%S_psi0")
      CALL fm_pools_create_fm_vect(mo_mo_fm_pools, elements=p_env%m_epsilon, &
                                   name="p_env%m_epsilon")

      ! Smo_inv and psi0d are allocated here only for non-orthogonal orbitals
      IF (.NOT. p_env%orthogonal_orbitals) THEN
         CALL fm_pools_create_fm_vect(mo_mo_fm_pools, elements=p_env%Smo_inv, &
                                      name="p_env%Smo_inv")
         CALL fm_pools_create_fm_vect(ao_mo_fm_pools, elements=p_env%psi0d, &
                                      name="p_env%psi0d")
      END IF

      !------------------------------!
      ! GAPW/GAPW_XC initializations !
      !------------------------------!
      use_gapw = dft_control%qs_control%gapw
      use_gapw_xc = dft_control%qs_control%gapw_xc
      IF (use_gapw .OR. use_gapw_xc) THEN
         ! both flavours need the local (atomic) densities
         CALL get_qs_env(qs_env, atomic_kind_set=atomic_kind_set, qs_kind_set=qs_kind_set)
         CALL local_rho_set_create(p_env%local_rho_set)
         CALL allocate_rho_atom_internals(p_env%local_rho_set%rho_atom_set, atomic_kind_set, &
                                          qs_kind_set, dft_control, para_env)

         ! GAPW (which takes precedence over GAPW_XC) additionally sets up rho0 and the
         ! local Hartree terms
         IF (use_gapw) THEN
            CALL get_qs_env(qs_env, natom=natom, pw_env=pw_env)
            CALL init_rho0(p_env%local_rho_set, qs_env, dft_control%qs_control%gapw_control, &
                           zcore=0.0_dp)
            CALL rho0_s_grid_create(pw_env, p_env%local_rho_set%rho0_mpole)
            CALL hartree_local_create(p_env%hartree_local)
            CALL init_coulomb_local(p_env%hartree_local, natom)
         END IF
      END IF

      !------------------------!
      ! LINRES initializations !
      !------------------------!
      ! linres_control may only be dereferenced when it is present
      want_precond = .FALSE.
      IF (PRESENT(linres_control)) THEN
         want_precond = (linres_control%preconditioner_type /= ot_precond_none)
      END IF

      IF (want_precond .AND. .NOT. ASSOCIATED(p_env%preconditioner)) THEN
         ! Initialize the preconditioner matrix
         ALLOCATE (p_env%preconditioner(nspins))
         DO ispin = 1, nspins
            CALL init_preconditioner(p_env%preconditioner(ispin), &
                                     para_env=para_env, blacs_env=blacs_env)
         END DO

         CALL fm_pools_create_fm_vect(ao_mo_fm_pools, elements=p_env%PS_psi0, &
                                      name="p_env%PS_psi0")
      END IF

      CALL timestop(handle)

   END SUBROUTINE p_env_create

! **************************************************************************************************
!> \brief checks that the internal storage is allocated, and allocates it if needed
!> \param p_env the environment to check
!> \param qs_env the qs environment this p_env lives in
!> \par History
!>      12.2002 created [fawzi]
!> \author Fawzi Mohamed
!> \note
!>      intended for internal use, although it is exported (PUBLIC) from this module
! **************************************************************************************************
   SUBROUTINE p_env_check_i_alloc(p_env, qs_env)
      TYPE(qs_p_env_type)                                :: p_env
      TYPE(qs_environment_type), POINTER                 :: qs_env

      CHARACTER(len=*), PARAMETER :: routineN = 'p_env_check_i_alloc'

      CHARACTER(len=25)                                  :: prefix
      INTEGER                                            :: handle, nspins, spin
      LOGICAL                                            :: gapw_xc
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: template
      TYPE(dft_control_type), POINTER                    :: dft_control

      CALL timeset(routineN, handle)

      NULLIFY (dft_control, template)

      CALL get_qs_env(qs_env, dft_control=dft_control)
      nspins = dft_control%nspins
      gapw_xc = dft_control%qs_control%gapw_xc

      ! p_env%kpp1: one zeroed copy of the overlap matrix per spin; the response
      ! densities are rebuilt in the same go
      IF (.NOT. ASSOCIATED(p_env%kpp1)) THEN
         CALL get_qs_env(qs_env, matrix_s=template)
         CALL dbcsr_allocate_matrix_set(p_env%kpp1, nspins)

         prefix = "p_env%kpp1-"
         CALL compress(prefix, full=.TRUE.)
         DO spin = 1, nspins
            ALLOCATE (p_env%kpp1(spin)%matrix)
            CALL dbcsr_copy(p_env%kpp1(spin)%matrix, template(1)%matrix, &
                            name=TRIM(prefix)//ADJUSTL(cp_to_string(spin)))
            CALL dbcsr_set(p_env%kpp1(spin)%matrix, 0.0_dp)
         END DO

         CALL qs_rho_rebuild(p_env%rho1, qs_env=qs_env)
         IF (gapw_xc) CALL qs_rho_rebuild(p_env%rho1_xc, qs_env=qs_env)
      END IF

      ! p_env%kpp1_admm: same scheme, based on the auxiliary-fit overlap matrix
      IF (dft_control%do_admm) THEN
         IF (.NOT. ASSOCIATED(p_env%kpp1_admm)) THEN
            CALL get_admm_env(qs_env%admm_env, matrix_s_aux_fit=template)
            CALL dbcsr_allocate_matrix_set(p_env%kpp1_admm, nspins)

            prefix = "p_env%kpp1_admm-"
            CALL compress(prefix, full=.TRUE.)
            DO spin = 1, nspins
               ALLOCATE (p_env%kpp1_admm(spin)%matrix)
               CALL dbcsr_copy(p_env%kpp1_admm(spin)%matrix, template(1)%matrix, &
                               name=TRIM(prefix)//ADJUSTL(cp_to_string(spin)))
               CALL dbcsr_set(p_env%kpp1_admm(spin)%matrix, 0.0_dp)
            END DO

            ! the auxiliary density only exists with an auxiliary exchange functional
            IF (qs_env%admm_env%aux_exch_func /= do_admm_aux_exch_func_none) THEN
               CALL qs_rho_rebuild(p_env%rho1_admm, qs_env=qs_env, admm=.TRUE.)
            END IF
         END IF
      END IF

      ! NOTE(review): this branch hands an unassociated p_env%rho1 to qs_rho_rebuild;
      ! p_env_create allocates rho1, so the branch is presumably never taken - confirm
      IF (.NOT. ASSOCIATED(p_env%rho1)) THEN
         CALL qs_rho_rebuild(p_env%rho1, qs_env=qs_env)
         IF (gapw_xc) CALL qs_rho_rebuild(p_env%rho1_xc, qs_env=qs_env)

         ! admm_env must only be looked at when ADMM is switched on
         IF (dft_control%do_admm) THEN
            IF (qs_env%admm_env%aux_exch_func /= do_admm_aux_exch_func_none) THEN
               CALL qs_rho_rebuild(p_env%rho1_admm, qs_env=qs_env, admm=.TRUE.)
            END IF
         END IF
      END IF

      CALL timestop(handle)
   END SUBROUTINE p_env_check_i_alloc

! **************************************************************************************************
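!> \brief Copies the response density matrices p_env%p1 into p_env%rho1 and updates rho1;
!>        with ADMM the auxiliary response density (p1_admm, rho1_admm) is updated as well
!> \param p_env perturbation environment holding p1 and rho1 (with ADMM also p1_admm and rho1_admm)
!> \param qs_env the qs environment this p_env lives in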
! **************************************************************************************************
   SUBROUTINE p_env_update_rho(p_env, qs_env)
      TYPE(qs_p_env_type), INTENT(IN)                    :: p_env
      TYPE(qs_environment_type), INTENT(IN), POINTER     :: qs_env

      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'p_env_update_rho'

      CHARACTER(LEN=default_string_length)               :: aux_basis
      INTEGER                                            :: handle, spin
      TYPE(admm_type), POINTER                           :: admm_env
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: rho1_ao, rho1_aux_ao
      TYPE(dft_control_type), POINTER                    :: dft_control
      TYPE(mp_para_env_type), POINTER                    :: para_env
      TYPE(neighbor_list_set_p_type), DIMENSION(:), &
         POINTER                                         :: sab_aux_fit
      TYPE(pw_c1d_gs_type), DIMENSION(:), POINTER        :: rho1_aux_g
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho1_aux_r
      TYPE(qs_ks_env_type), POINTER                      :: ks_env
      TYPE(task_list_type), POINTER                      :: aux_task_list

      CALL timeset(routineN, handle)

      CALL get_qs_env(qs_env, dft_control=dft_control)

      ! with ADMM, the auxiliary response density matrix is derived from p1 first
      IF (dft_control%do_admm) THEN
         CALL admm_aux_response_density(qs_env, p_env%p1, p_env%p1_admm)
      END IF

      ! primary basis: put p1 into the AO matrices of rho1, then refresh rho1
      CALL qs_rho_get(p_env%rho1, rho_ao=rho1_ao)
      DO spin = 1, SIZE(rho1_ao)
         CALL dbcsr_copy(rho1_ao(spin)%matrix, p_env%p1(spin)%matrix)
      END DO
      CALL qs_rho_update_rho(rho_struct=p_env%rho1, &
                             rho_xc_external=p_env%rho1_xc, &
                             local_rho_set=p_env%local_rho_set, &
                             qs_env=qs_env)

      ! auxiliary basis: only with ADMM and an auxiliary exchange functional
      ! (kept as nested tests so that admm_env is never touched without ADMM)
      IF (dft_control%do_admm) THEN
         IF (qs_env%admm_env%aux_exch_func /= do_admm_aux_exch_func_none) THEN
            NULLIFY (ks_env, rho1_aux_ao, rho1_aux_g, rho1_aux_r, aux_task_list)

            CALL get_qs_env(qs_env, ks_env=ks_env, admm_env=admm_env)

            ! basis label and task list; GAPW overrides the regular aux-fit task list
            CALL get_admm_env(qs_env%admm_env, task_list_aux_fit=aux_task_list)
            IF (admm_env%do_gapw) THEN
               aux_basis = "AUX_FIT_SOFT"
               aux_task_list => admm_env%admm_gapw_env%task_list
            ELSE
               aux_basis = "AUX_FIT"
            END IF

            CALL qs_rho_get(p_env%rho1_admm, rho_ao=rho1_aux_ao, &
                            rho_g=rho1_aux_g, rho_r=rho1_aux_r)

            ! copy p1_admm and collocate it on the grids, spin by spin
            DO spin = 1, SIZE(rho1_aux_ao)
               CALL dbcsr_copy(rho1_aux_ao(spin)%matrix, p_env%p1_admm(spin)%matrix)
               CALL calculate_rho_elec(ks_env=ks_env, &
                                       matrix_p=rho1_aux_ao(spin)%matrix, &
                                       rho=rho1_aux_r(spin), &
                                       rho_gspace=rho1_aux_g(spin), &
                                       soft_valid=.FALSE., &
                                       basis_type=aux_basis, &
                                       task_list_external=aux_task_list)
            END DO

            ! GAPW: atomic density coefficients in the auxiliary basis
            IF (admm_env%do_gapw) THEN
               CALL get_qs_env(qs_env, para_env=para_env)
               CALL get_admm_env(admm_env, sab_aux_fit=sab_aux_fit)
               CALL calculate_rho_atom_coeff(qs_env, rho1_aux_ao, &
                                             rho_atom_set=p_env%local_rho_set_admm%rho_atom_set, &
                                             qs_kind_set=admm_env%admm_gapw_env%admm_kind_set, &
                                             oce=admm_env%admm_gapw_env%oce, &
                                             sab=sab_aux_fit, para_env=para_env)
            END IF
         END IF
      END IF

      CALL timestop(handle)

   END SUBROUTINE p_env_update_rho

! **************************************************************************************************
!> \brief To be called after the value of psi0 has changed.
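!>      Recalculates the quantities S_psi0, m_epsilon and psi0d (and, for non-orthogonal
!>      orbitals, Smo_inv and the ground-state density matrix of qs_env).
!> \param p_env the perturbation environment to set
!> \param qs_env the qs environment providing the MO coefficients, the overlap and the KS matrices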
!> \par History
!>      07.2002 created [fawzi]
!> \author Fawzi Mohamed
! **************************************************************************************************
   SUBROUTINE p_env_psi0_changed(p_env, qs_env)

      TYPE(qs_p_env_type)                                :: p_env
      TYPE(qs_environment_type), POINTER                 :: qs_env

      CHARACTER(len=*), PARAMETER :: routineN = 'p_env_psi0_changed'

      INTEGER                                            :: handle, iounit, lfomo, n_spins, nmo, spin
      LOGICAL                                            :: was_present
      REAL(KIND=dp)                                      :: maxocc
      TYPE(cp_fm_pool_p_type), DIMENSION(:), POINTER     :: ao_mo_fm_pools
      TYPE(cp_fm_type), DIMENSION(:), POINTER            :: psi0
      TYPE(cp_fm_type), POINTER                          :: mo_coeff
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_ks, matrix_s, rho_ao
      TYPE(dft_control_type), POINTER                    :: dft_control
      TYPE(mo_set_type), DIMENSION(:), POINTER           :: mos
      TYPE(mp_para_env_type), POINTER                    :: para_env
      TYPE(qs_energy_type), POINTER                      :: energy
      TYPE(qs_ks_env_type), POINTER                      :: ks_env
      TYPE(qs_rho_type), POINTER                         :: rho
      TYPE(section_vals_type), POINTER                   :: input, lr_section

      ! Overview of the steps below:
      !  1. copy the current MO coefficients of qs_env into a private array psi0
      !  2. orthogonal orbitals: psi0d simply takes over psi0;
      !     otherwise: build Smo_inv and psi0d from the Cholesky factor of psi0^T S psi0
      !     and rebuild the ground-state density matrix of qs_env from psi0 and psi0d
      !  3. update the KS matrix and compute m_epsilon = -psi0d^T K psi0d
      !  4. store S_psi0 = S psi0 and notify kpp1_env
      ! S_psi0 and m_epsilon serve as scratch space in between, so the order of the
      ! steps matters.

      CALL timeset(routineN, handle)

      ! NOTE(review): mos is listed twice in this NULLIFY statement
      NULLIFY (ao_mo_fm_pools, mos, psi0, matrix_s, mos, para_env, ks_env, rho, &
               logger, input, lr_section, energy, matrix_ks, dft_control, rho_ao)
      logger => cp_get_default_logger()

      ! NOTE(review): para_env is retrieved here but not used further in this routine
      CALL get_qs_env(qs_env, &
                      ks_env=ks_env, &
                      mos=mos, &
                      matrix_s=matrix_s, &
                      matrix_ks=matrix_ks, &
                      para_env=para_env, &
                      rho=rho, &
                      input=input, &
                      energy=energy, &
                      dft_control=dft_control)

      ! AO density matrices of the ground-state density (overwritten in the
      ! non-orthogonal branch below)
      CALL qs_rho_get(rho, rho_ao=rho_ao)

      n_spins = dft_control%nspins
      ! NOTE(review): ao_mo_fm_pools is retrieved here but not used further in this routine
      CALL mpools_get(qs_env%mpools, &
                      ao_mo_fm_pools=ao_mo_fm_pools)
      ! psi0: private copy of the MO coefficients, one full matrix per spin
      ALLOCATE (psi0(n_spins))
      DO spin = 1, n_spins
         CALL get_mo_set(mos(spin), mo_coeff=mo_coeff)
         CALL cp_fm_create(psi0(spin), mo_coeff%matrix_struct)
         CALL cp_fm_to_fm(mo_coeff, psi0(spin))
      END DO

      lr_section => section_vals_get_subs_vals(input, "PROPERTIES%LINRES")
      ! def psi0d
      IF (p_env%orthogonal_orbitals) THEN
         ! psi0d takes over the freshly allocated psi0 array (any previous psi0d is
         ! released first); psi0 is therefore only nullified, not released, at the end
         IF (ASSOCIATED(p_env%psi0d)) THEN
            CALL cp_fm_release(p_env%psi0d)
         END IF
         p_env%psi0d => psi0
      ELSE

         DO spin = 1, n_spins
            ! m_epsilon = Cholesky factor of the MO overlap psi0^T S psi0
            ! (S_psi0 = S psi0 is formed first as an intermediate; the inverse of the
            !  factor is only applied further down through invert_tr=.TRUE.)
            ! could be optimized by combining next two calls
            CALL cp_dbcsr_sm_fm_multiply(matrix_s(1)%matrix, &
                                         psi0(spin), &
                                         p_env%S_psi0(spin), &
                                         ncol=p_env%n_mo(spin), alpha=1.0_dp)
            CALL parallel_gemm(transa='T', transb='N', n=p_env%n_mo(spin), &
                               m=p_env%n_mo(spin), k=p_env%n_ao(spin), alpha=1.0_dp, &
                               matrix_a=psi0(spin), &
                               matrix_b=p_env%S_psi0(spin), &
                               beta=0.0_dp, matrix_c=p_env%m_epsilon(spin))
            CALL cp_fm_cholesky_decompose(p_env%m_epsilon(spin), &
                                          n=p_env%n_mo(spin))

            ! Smo_inv= (psi0^T S psi0)^-1
            ! start from the unit matrix (0 off the diagonal, 1 on it) and multiply from
            ! the right with the inverted factor and then with its inverted transpose
            CALL cp_fm_set_all(p_env%Smo_inv(spin), 0.0_dp, 1.0_dp)
            ! faster using cp_fm_cholesky_invert ?
            CALL cp_fm_triangular_multiply( &
               triangular_matrix=p_env%m_epsilon(spin), &
               matrix_b=p_env%Smo_inv(spin), side='R', &
               invert_tr=.TRUE., n_rows=p_env%n_mo(spin), &
               n_cols=p_env%n_mo(spin))
            CALL cp_fm_triangular_multiply( &
               triangular_matrix=p_env%m_epsilon(spin), &
               matrix_b=p_env%Smo_inv(spin), side='R', &
               transpose_tr=.TRUE., &
               invert_tr=.TRUE., n_rows=p_env%n_mo(spin), &
               n_cols=p_env%n_mo(spin))

            ! psi0d=psi0 (psi0^T S psi0)^-1
            ! same two triangular multiplications as above, applied to a copy of psi0
            ! faster using cp_fm_cholesky_invert ?
            CALL cp_fm_to_fm(psi0(spin), &
                             p_env%psi0d(spin))
            CALL cp_fm_triangular_multiply( &
               triangular_matrix=p_env%m_epsilon(spin), &
               matrix_b=p_env%psi0d(spin), side='R', &
               invert_tr=.TRUE., n_rows=p_env%n_ao(spin), &
               n_cols=p_env%n_mo(spin))
            CALL cp_fm_triangular_multiply( &
               triangular_matrix=p_env%m_epsilon(spin), &
               matrix_b=p_env%psi0d(spin), side='R', &
               transpose_tr=.TRUE., &
               invert_tr=.TRUE., n_rows=p_env%n_ao(spin), &
               n_cols=p_env%n_mo(spin))

            ! updates P: rho_ao = maxocc * psi0 psi0d^T
            ! only the case lfomo > nmo is implemented, anything else aborts
            ! (presumably lfomo > nmo means no fractionally occupied MOs - confirm
            !  against qs_mo_types)
            CALL get_mo_set(mos(spin), lfomo=lfomo, &
                            nmo=nmo, maxocc=maxocc)
            IF (lfomo > nmo) THEN
               CALL dbcsr_set(rho_ao(spin)%matrix, 0.0_dp)
               CALL cp_dbcsr_plus_fm_fm_t(rho_ao(spin)%matrix, &
                                          matrix_v=psi0(spin), &
                                          matrix_g=p_env%psi0d(spin), &
                                          ncol=p_env%n_mo(spin))
               CALL dbcsr_scale(rho_ao(spin)%matrix, alpha_scalar=maxocc)
            ELSE
               CPABORT("symmetrized onesided smearing to do")
            END IF
         END DO

         ! updates rho
         CALL qs_rho_update_rho(rho_struct=rho, qs_env=qs_env)

         ! tells ks_env that p changed
         CALL qs_ks_did_change(ks_env=ks_env, rho_changed=.TRUE.)

      END IF

      ! updates K (if necessary)
      CALL qs_ks_update_qs_env(qs_env)
      ! report the ground-state total energy, but only if the LINRES section was given
      ! explicitly in the input
      iounit = cp_print_key_unit_nr(logger, lr_section, "PRINT%PROGRAM_RUN_INFO", &
                                    extension=".linresLog")
      IF (iounit > 0) THEN
         CALL section_vals_get(lr_section, explicit=was_present)
         IF (was_present) THEN
            WRITE (UNIT=iounit, FMT="(/,(T3,A,T55,F25.14))") &
               "Total energy ground state:                     ", energy%total
         END IF
      END IF
      CALL cp_print_key_finished_output(iounit, logger, lr_section, &
                                        "PRINT%PROGRAM_RUN_INFO")
      !-----------------------------------------------------------------------|
      ! calculates                                                            |
      ! m_epsilon = - psi0d^T times K times psi0d                             |
      !           = - [K times psi0d]^T times psi0d (because K is symmetric)  |
      !-----------------------------------------------------------------------|
      ! this overwrites the Cholesky factor kept in m_epsilon by the non-orthogonal
      ! branch above; S_psi0 is used as scratch space for K psi0d
      DO spin = 1, n_spins
         ! S_psi0 = k times psi0d
         CALL cp_dbcsr_sm_fm_multiply(matrix_ks(spin)%matrix, &
                                      p_env%psi0d(spin), &
                                      p_env%S_psi0(spin), p_env%n_mo(spin))
         ! m_epsilon = -1 times S_psi0^T times psi0d
         CALL parallel_gemm('T', 'N', &
                            p_env%n_mo(spin), p_env%n_mo(spin), p_env%n_ao(spin), &
                            -1.0_dp, p_env%S_psi0(spin), p_env%psi0d(spin), &
                            0.0_dp, p_env%m_epsilon(spin))
      END DO

      !----------------------------------|
      ! calculates S_psi0 = S * psi0  |
      !----------------------------------|
      ! calculating this reduces the mat mult without storing a full aoxao
      ! matrix (for P). If nspin>1 you might consider calculating it on the
      ! fly to spare some memory
      ! (matrix_s is fetched once more here; this is the final content of S_psi0)
      CALL get_qs_env(qs_env, matrix_s=matrix_s)
      DO spin = 1, n_spins
         CALL cp_dbcsr_sm_fm_multiply(matrix_s(1)%matrix, &
                                      psi0(spin), &
                                      p_env%S_psi0(spin), &
                                      p_env%n_mo(spin))
      END DO

      ! releases psi0
      ! (in the orthogonal case the array now belongs to p_env%psi0d, so only the
      !  local pointer is dropped)
      IF (p_env%orthogonal_orbitals) THEN
         NULLIFY (psi0)
      ELSE
         CALL cp_fm_release(psi0)
      END IF

      ! tells kpp1_env about the change of psi0
      CALL kpp1_did_change(p_env%kpp1_env)

      CALL timestop(handle)

   END SUBROUTINE p_env_psi0_changed

! **************************************************************************************************
!> \brief does a preorthogonalization of the given matrices:
!>      v = (I-PS)v
!>      evaluated spin by spin as v - psi0d (S psi0)^T v
!> \param p_env the perturbation environment (S_psi0, psi0d, n_ao and n_mo are read from it)
!> \param qs_env the qs_env that is perturbed by this p_env
!> \param v matrices to orthogonalize (one per spin), overwritten with the result
!> \param n_cols for each spin, the number of columns of v to process
!>        (defaults to all the columns of v(spin))
!> \par History
!>      02.09.2002 adapted for new qs_p_env_type (TC)
!> \author Fawzi Mohamed
! **************************************************************************************************
   SUBROUTINE p_preortho(p_env, qs_env, v, n_cols)

      TYPE(qs_p_env_type)                                :: p_env
      TYPE(qs_environment_type), POINTER                 :: qs_env
      TYPE(cp_fm_type), DIMENSION(:), INTENT(inout)      :: v
      INTEGER, DIMENSION(:), INTENT(in), OPTIONAL        :: n_cols

      CHARACTER(len=*), PARAMETER                        :: routineN = 'p_preortho'

      INTEGER                                            :: handle, ispin, ncol_max, ncol_used, &
                                                            ncol_v, nrow_v, nspins, pool_ncol, &
                                                            pool_nrow
      LOGICAL                                            :: pool_fits
      TYPE(cp_fm_pool_type), POINTER                     :: mo_mo_pool
      TYPE(cp_fm_struct_type), POINTER                   :: mo_mo_struct, work_struct
      TYPE(cp_fm_type)                                   :: work
      TYPE(dft_control_type), POINTER                    :: dft_control

      CALL timeset(routineN, handle)

      NULLIFY (dft_control, mo_mo_pool, mo_mo_struct, work_struct)

      CALL get_qs_env(qs_env, dft_control=dft_control)
      nspins = dft_control%nspins
      CALL mpools_get(qs_env%mpools, maxmo_maxmo_fm_pool=mo_mo_pool)
      mo_mo_struct => fm_pool_get_el_struct(mo_mo_pool)
      CALL cp_fm_struct_get(mo_mo_struct, nrow_global=pool_nrow, ncol_global=pool_ncol)
      CPASSERT(SIZE(v) >= nspins)

      ! largest number of columns that has to be handled for any spin
      IF (.NOT. PRESENT(n_cols)) THEN
         ncol_max = 0
         DO ispin = 1, nspins
            CALL cp_fm_get_info(v(ispin), ncol_global=ncol_v)
            ncol_max = MAX(ncol_max, ncol_v)
         END DO
      ELSE
         ncol_max = MAXVAL(n_cols(1:nspins))
      END IF

      ! work storage: taken from the pool when its matrices have enough rows,
      ! otherwise created with ncol_max rows from the pool structure as template
      pool_fits = (ncol_max <= pool_nrow)
      IF (pool_fits) THEN
         CALL fm_pool_create_fm(mo_mo_pool, work)
      ELSE
         CALL cp_fm_struct_create(work_struct, nrow_global=ncol_max, &
                                  ncol_global=pool_ncol, template_fmstruct=mo_mo_struct)
         CALL cp_fm_create(work, matrix_struct=work_struct)
         CALL cp_fm_struct_release(work_struct)
      END IF

      DO ispin = 1, nspins

         CALL cp_fm_get_info(v(ispin), nrow_global=nrow_v, ncol_global=ncol_v)
         CPASSERT(nrow_v >= p_env%n_ao(ispin))
         IF (PRESENT(n_cols)) THEN
            CPASSERT(n_cols(ispin) <= ncol_v)
            ncol_used = n_cols(ispin)
         ELSE
            ncol_used = ncol_v
         END IF
         CPASSERT(ncol_used <= ncol_max)

         ! work = v^T (S psi0)
         CALL parallel_gemm('T', 'N', ncol_used, p_env%n_mo(ispin), p_env%n_ao(ispin), &
                            1.0_dp, v(ispin), p_env%S_psi0(ispin), &
                            0.0_dp, work)
         ! v = v - psi0d work^T = v - psi0d psi0^T S v
         CALL parallel_gemm('N', 'T', p_env%n_ao(ispin), ncol_used, p_env%n_mo(ispin), &
                            -1.0_dp, p_env%psi0d(ispin), work, &
                            1.0_dp, v(ispin))

      END DO

      ! hand the work storage back the same way it was obtained
      IF (pool_fits) THEN
         CALL fm_pool_give_back_fm(mo_mo_pool, work)
      ELSE
         CALL cp_fm_release(work)
      END IF

      CALL timestop(handle)

   END SUBROUTINE p_preortho

! **************************************************************************************************
!> \brief does a postorthogonalization on the given matrices:
!>      v = (I-SP) v
!>      evaluated spin by spin as v - (S psi0) psi0d^T v
!> \param p_env the perturbation environment (S_psi0, psi0d, n_ao and n_mo are read from it)
!> \param qs_env the qs_env that is perturbed by this p_env
!> \param v matrices to orthogonalize (one per spin), overwritten with the result
!> \param n_cols for each spin, the number of columns of v to process
!>        (defaults to all the columns of v(spin))
!> \par History
!>      07.2002 created [fawzi]
!> \author Fawzi Mohamed
! **************************************************************************************************
   SUBROUTINE p_postortho(p_env, qs_env, v, n_cols)

      TYPE(qs_p_env_type)                                :: p_env
      TYPE(qs_environment_type), POINTER                 :: qs_env
      TYPE(cp_fm_type), DIMENSION(:), INTENT(inout)      :: v
      INTEGER, DIMENSION(:), INTENT(in), OPTIONAL        :: n_cols

      CHARACTER(len=*), PARAMETER                        :: routineN = 'p_postortho'

      INTEGER                                            :: handle, ispin, ncol_max, ncol_used, &
                                                            ncol_v, nrow_v, nspins, pool_ncol, &
                                                            pool_nrow
      LOGICAL                                            :: pool_fits
      TYPE(cp_fm_pool_type), POINTER                     :: mo_mo_pool
      TYPE(cp_fm_struct_type), POINTER                   :: mo_mo_struct, work_struct
      TYPE(cp_fm_type)                                   :: work
      TYPE(dft_control_type), POINTER                    :: dft_control

      CALL timeset(routineN, handle)

      NULLIFY (dft_control, mo_mo_pool, mo_mo_struct, work_struct)

      CALL get_qs_env(qs_env, dft_control=dft_control)
      nspins = dft_control%nspins
      CALL mpools_get(qs_env%mpools, maxmo_maxmo_fm_pool=mo_mo_pool)
      mo_mo_struct => fm_pool_get_el_struct(mo_mo_pool)
      CALL cp_fm_struct_get(mo_mo_struct, nrow_global=pool_nrow, ncol_global=pool_ncol)
      CPASSERT(SIZE(v) >= nspins)

      ! largest number of columns that has to be handled for any spin
      IF (.NOT. PRESENT(n_cols)) THEN
         ncol_max = 0
         DO ispin = 1, nspins
            CALL cp_fm_get_info(v(ispin), ncol_global=ncol_v)
            ncol_max = MAX(ncol_max, ncol_v)
         END DO
      ELSE
         ncol_max = MAXVAL(n_cols(1:nspins))
      END IF

      ! work storage: taken from the pool when its matrices have enough rows,
      ! otherwise created with ncol_max rows from the pool structure as template
      pool_fits = (ncol_max <= pool_nrow)
      IF (pool_fits) THEN
         CALL fm_pool_create_fm(mo_mo_pool, work)
      ELSE
         CALL cp_fm_struct_create(work_struct, nrow_global=ncol_max, &
                                  ncol_global=pool_ncol, template_fmstruct=mo_mo_struct)
         CALL cp_fm_create(work, matrix_struct=work_struct)
         CALL cp_fm_struct_release(work_struct)
      END IF

      DO ispin = 1, nspins

         CALL cp_fm_get_info(v(ispin), nrow_global=nrow_v, ncol_global=ncol_v)
         CPASSERT(nrow_v >= p_env%n_ao(ispin))
         IF (PRESENT(n_cols)) THEN
            CPASSERT(n_cols(ispin) <= ncol_v)
            ncol_used = n_cols(ispin)
         ELSE
            ncol_used = ncol_v
         END IF
         CPASSERT(ncol_used <= ncol_max)

         ! work = v^T psi0d
         CALL parallel_gemm('T', 'N', ncol_used, p_env%n_mo(ispin), p_env%n_ao(ispin), &
                            1.0_dp, v(ispin), p_env%psi0d(ispin), &
                            0.0_dp, work)
         ! v = v - (S psi0) work^T = v - S psi0 psi0d^T v
         CALL parallel_gemm('N', 'T', p_env%n_ao(ispin), ncol_used, p_env%n_mo(ispin), &
                            -1.0_dp, p_env%S_psi0(ispin), work, &
                            1.0_dp, v(ispin))

      END DO

      ! hand the work storage back the same way it was obtained
      IF (pool_fits) THEN
         CALL fm_pool_give_back_fm(mo_mo_pool, work)
      ELSE
         CALL cp_fm_release(work)
      END IF

      CALL timestop(handle)

   END SUBROUTINE p_postortho

! **************************************************************************************************
!> \brief When ADMM is active (dft_control%do_admm), adds A^T * kpp1_admm * A to every
!>        kpp1 matrix of the perturbation environment, with A = admm_env%A;
!>        does nothing otherwise.
!> \param qs_env the qs_env from which dft_control and admm_env are taken
!> \param p_env the perturbation environment holding kpp1 and kpp1_admm
!> \note
!>      The contribution is transferred through a work matrix created as a copy of
!>      kpp1(1) and filled with keep_sparsity=.TRUE. before being added.
!>      admm_env%work_aux_orb and admm_env%work_orb_orb are used as scratch matrices.
! **************************************************************************************************
   SUBROUTINE p_env_finish_kpp1(qs_env, p_env)
      TYPE(qs_environment_type), INTENT(IN), POINTER     :: qs_env
      TYPE(qs_p_env_type), INTENT(IN)                    :: p_env

      CHARACTER(len=*), PARAMETER                        :: routineN = 'p_env_finish_kpp1'

      INTEGER                                            :: handle, n_aux, n_orb, spin
      TYPE(admm_type), POINTER                           :: admm_env
      TYPE(dbcsr_type)                                   :: aux_contrib
      TYPE(dft_control_type), POINTER                    :: dft_control

      CALL timeset(routineN, handle)

      CALL get_qs_env(qs_env, dft_control=dft_control, admm_env=admm_env)

      IF (dft_control%do_admm) THEN
         ! dimensions of A: n_aux rows, n_orb columns
         CALL cp_fm_get_info(admm_env%A, nrow_global=n_aux, ncol_global=n_orb)

         ! work matrix with the layout of the first kpp1 matrix
         CALL dbcsr_copy(aux_contrib, p_env%kpp1(1)%matrix)

         DO spin = 1, SIZE(p_env%kpp1)
            ! work_aux_orb = kpp1_admm * A
            CALL cp_dbcsr_sm_fm_multiply(p_env%kpp1_admm(spin)%matrix, admm_env%A, &
                                         admm_env%work_aux_orb, &
                                         ncol=n_orb, alpha=1.0_dp, beta=0.0_dp)
            ! work_orb_orb = A^T * work_aux_orb
            CALL parallel_gemm(transa='T', transb='N', m=n_orb, n=n_orb, k=n_aux, &
                               alpha=1.0_dp, matrix_a=admm_env%A, &
                               matrix_b=admm_env%work_aux_orb, &
                               beta=0.0_dp, matrix_c=admm_env%work_orb_orb)
            ! reset the work matrix, fill it from the dense result and add it to kpp1
            CALL dbcsr_set(aux_contrib, 0.0_dp)
            CALL copy_fm_to_dbcsr(admm_env%work_orb_orb, aux_contrib, keep_sparsity=.TRUE.)
            CALL dbcsr_add(p_env%kpp1(spin)%matrix, aux_contrib, 1.0_dp, 1.0_dp)
         END DO

         CALL dbcsr_release(aux_contrib)
      END IF

      CALL timestop(handle)

   END SUBROUTINE p_env_finish_kpp1

END MODULE qs_p_env_methods
