diff --git a/examples/error_estimates_forces.jl b/examples/error_estimates_forces.jl
index 71eb8f578f..e2c3f209c2 100644
--- a/examples/error_estimates_forces.jl
+++ b/examples/error_estimates_forces.jl
@@ -54,7 +54,7 @@ tol = 1e-5;
 # We compute the reference solution ``P_*`` from which we will compute the
 # references forces.
 scfres_ref = self_consistent_field(basis_ref; tol, callback=identity)
-ψ_ref = DFTK.select_occupied_orbitals(basis_ref, scfres_ref.ψ, scfres_ref.occupation).ψ;
+ψ_ref = DFTK.select_occupied_orbitals(scfres_ref.ψ, scfres_ref.occupation).ψ;
 
 # We compute a variational approximation of the reference solution with
 # smaller `Ecut`. `ψr`, `ρr` and `Er` are the quantities computed with `Ecut`
@@ -69,16 +69,16 @@ Ecut = 15
 basis = PlaneWaveBasis(model; Ecut, kgrid)
 scfres = self_consistent_field(basis; tol, callback=identity)
 ψr = DFTK.transfer_blochwave(scfres.ψ, basis, basis_ref)
-ρr = compute_density(basis_ref, ψr, scfres.occupation)
-Er, hamr = energy_hamiltonian(basis_ref, ψr, scfres.occupation; ρ=ρr);
+ρr = compute_density(ψr, scfres.occupation)
+Er, hamr = energy_hamiltonian(ψr, scfres.occupation; ρ=ρr);
 
 # We then compute several quantities that we need to evaluate the error bounds.
 
 # - Compute the residual ``R(P)``, and remove the virtual orbitals, as required
 #   in [`src/scf/newton.jl`](https://github.com/JuliaMolSim/DFTK.jl/blob/fedc720dab2d194b30d468501acd0f04bd4dd3d6/src/scf/newton.jl#L121).
-res = DFTK.compute_projected_gradient(basis_ref, ψr, scfres.occupation)
-res, occ = DFTK.select_occupied_orbitals(basis_ref, res, scfres.occupation)
-ψr = DFTK.select_occupied_orbitals(basis_ref, ψr, scfres.occupation).ψ;
+res = DFTK.compute_projected_gradient(ψr, scfres.occupation)
+res, occ = DFTK.select_occupied_orbitals(BlochWaves(ψr.basis, res), scfres.occupation)
+ψr = DFTK.select_occupied_orbitals(ψr, scfres.occupation).ψ;
 
 # - Compute the error ``P-P_*`` on the associated orbitals ``ϕ-ψ`` after aligning
 #   them: this is done by solving ``\min |ϕ - ψU|`` for ``U`` unitary matrix of
@@ -149,7 +149,7 @@ Mres = apply_metric(ψr.data, P, res, apply_inv_M);
 
 # - Compute the projection of the residual onto the high and low frequencies:
 resLF = DFTK.transfer_blochwave(res, basis_ref, basis)
-resHF = res - DFTK.transfer_blochwave(resLF, basis, basis_ref);
+resHF = denest(res) - denest(DFTK.transfer_blochwave(resLF, basis, basis_ref));
 
 # - Compute ``{\boldsymbol M}^{-1}_{22}R_2(P)``:
 e2 = apply_metric(ψr, P, resHF, apply_inv_M);
@@ -163,15 +163,15 @@ e2 = apply_metric(ψr, P, resHF, apply_inv_M);
 end
 ΩpKe2 = DFTK.apply_Ω(e2, ψr, hamr, Λ) .+ DFTK.apply_K(basis_ref, e2, ψr, ρr, occ)
 ΩpKe2 = DFTK.transfer_blochwave(ΩpKe2, basis_ref, basis)
-rhs = resLF - ΩpKe2;
+rhs = denest(resLF) - denest(ΩpKe2);
 
 # - Solve the Schur system to compute ``R_{\rm Schur}(P)``: this is the most
 #   costly step, but inverting ``\boldsymbol{Ω} + \boldsymbol{K}`` on the small space has
 #   the same cost than the full SCF cycle on the small grid.
-(; ψ) = DFTK.select_occupied_orbitals(basis, scfres.ψ, scfres.occupation)
-e1 = DFTK.solve_ΩplusK(basis, ψ, rhs, occ; tol).δψ
+(; ψ) = DFTK.select_occupied_orbitals(scfres.ψ, scfres.occupation)
+e1 = DFTK.solve_ΩplusK(ψ, rhs, occ; tol).δψ
 e1 = DFTK.transfer_blochwave(e1, basis, basis_ref)
-res_schur = e1 + Mres;
+res_schur = denest(e1) + Mres;
 
 # ## Error estimates
 
@@ -197,8 +197,9 @@ relerror["F(P)"] = compute_relerror(f);
 # To this end, we use the `ForwardDiff.jl` package to compute ``{\rm d}F(P)``
 # using automatic differentiation.
 function df(basis, occupation, ψ, δψ, ρ)
-    δρ = DFTK.compute_δρ(basis, ψ, δψ, occupation)
-    ForwardDiff.derivative(ε -> compute_forces(basis, ψ.+ε.*δψ, occupation; ρ=ρ+ε.*δρ), 0)
+    δρ = DFTK.compute_δρ(ψ, δψ, occupation)
+    ForwardDiff.derivative(ε -> compute_forces(BlochWaves(ψ.basis, denest(ψ).+ε.*δψ),
+                                               occupation; ρ=ρ+ε.*δρ), 0)
 end;
 
 # - Computation of the forces by a linearization argument if we have access to
diff --git a/examples/geometry_optimization.jl b/examples/geometry_optimization.jl
index 8f27b6464b..4789402b7d 100644
--- a/examples/geometry_optimization.jl
+++ b/examples/geometry_optimization.jl
@@ -38,6 +38,11 @@ function compute_scfres(x)
     if isnothing(ρ)
         ρ = guess_density(basis)
     end
+    if isnothing(ψ)
+        ψ = BlochWaves(basis)
+    else
+        ψ = BlochWaves(basis, denest(ψ))
+    end
     is_converged = DFTK.ScfConvergenceForce(tol / 10)
     scfres = self_consistent_field(basis; ψ, ρ, is_converged, callback=identity)
     ψ = scfres.ψ
diff --git a/examples/publications/2022_cazalis.jl b/examples/publications/2022_cazalis.jl
index 14d6add458..f3ec6bd86f 100644
--- a/examples/publications/2022_cazalis.jl
+++ b/examples/publications/2022_cazalis.jl
@@ -9,8 +9,8 @@ using Plots
 struct Hartree2D end
 struct Term2DHartree <: DFTK.TermNonlinear end
 (t::Hartree2D)(basis) = Term2DHartree()
-function DFTK.ene_ops(term::Term2DHartree, basis::PlaneWaveBasis{T},
-                      ψ, occ; ρ, kwargs...) where {T}
+function DFTK.ene_ops(term::Term2DHartree, ψ::BlochWaves{T}, occ; ρ, kwargs...) where {T}
+    basis = ψ.basis
     ## 2D Fourier transform of 3D Coulomb interaction 1/|x|
     poisson_green_coeffs = 2T(π) ./ [norm(G) for G in G_vectors_cart(basis)]
     poisson_green_coeffs[1] = 0  # DC component
diff --git a/src/DFTK.jl b/src/DFTK.jl
index d5fa385e95..7609de88fb 100644
--- a/src/DFTK.jl
+++ b/src/DFTK.jl
@@ -72,8 +72,10 @@ export compute_fft_size
 export G_vectors, G_vectors_cart, r_vectors, r_vectors_cart
 export Gplusk_vectors, Gplusk_vectors_cart
 export Kpoint
-export to_composite_σG
-export from_composite_σG
+export BlochWaves, view_component, nest, denest
+export blochwave_as_matrix
+export blochwave_as_tensor
+export blochwaves_as_matrices
 export ifft
 export irfft
 export ifft!
diff --git a/src/densities.jl b/src/densities.jl
index e10387213c..b9560a80f7 100644
--- a/src/densities.jl
+++ b/src/densities.jl
@@ -1,23 +1,24 @@
 # Densities (and potentials) are represented by arrays
 # ρ[ix,iy,iz,iσ] in real space, where iσ ∈ [1:n_spin_components]
 
-# TODO: We reduce all components for the density. Will need to be though again when we merge
-# the components and the spins.
 """
-    compute_density(basis::PlaneWaveBasis, ψ::AbstractVector, occupation::AbstractVector)
+    compute_density(ψ::BlochWaves, occupation::AbstractVector)
 
-Compute the density for a wave function `ψ` discretized on the plane-wave
-grid `basis`, where the individual k-points are occupied according to `occupation`.
-`ψ` should be one coefficient matrix per ``k``-point.
+Compute the density for a wave function `ψ` discretized on the plane-wave grid `ψ.basis`,
+where the individual k-points are occupied according to `occupation`.
+`ψ` should contain one coefficient matrix per ``k``-point.
 It is possible to ask only for occupations higher than a certain level to be computed by
 using an optional `occupation_threshold`. By default all occupation numbers are considered.
 """
-@views @timing function compute_density(basis::PlaneWaveBasis{T}, ψ, occupation;
-                                        occupation_threshold=zero(T)) where {T}
-    S = promote_type(T, real(eltype(ψ[1])))
+# TODO: We reduce all components for the density. Will need to be though again when we merge
+# the components and the spins.
+@views @timing function compute_density(ψ::BlochWaves{T, Tψ}, occupation;
+                                        occupation_threshold=zero(T)) where {T, Tψ}
+    S = promote_type(T, real(Tψ))
     # occupation should be on the CPU as we are going to be doing scalar indexing.
     occupation = [to_cpu(oc) for oc in occupation]
 
+    basis = ψ.basis
     mask_occ = [findall(occnk -> abs(occnk) ≥ occupation_threshold, occk)
                 for occk in occupation]
     if all(isempty, mask_occ)  # No non-zero occupations => return zero density
@@ -66,21 +67,22 @@ using an optional `occupation_threshold`. By default all occupation numbers are
 end
 
 # Variation in density corresponding to a variation in the orbitals and occupations.
-@views @timing function compute_δρ(basis::PlaneWaveBasis{T}, ψ, δψ,
-                                   occupation, δoccupation=zero.(occupation);
+@views @timing function compute_δρ(ψ::BlochWaves{T}, δψ, occupation,
+                                   δoccupation=zero.(occupation);
                                    occupation_threshold=zero(T)) where {T}
     ForwardDiff.derivative(zero(T)) do ε
         ψ_ε   = [ψk   .+ ε .* δψk   for (ψk,   δψk)   in zip(ψ, δψ)]
         occ_ε = [occk .+ ε .* δocck for (occk, δocck) in zip(occupation, δoccupation)]
-        compute_density(basis, ψ_ε, occ_ε; occupation_threshold)
+        compute_density(BlochWaves(ψ.basis, ψ_ε), occ_ε; occupation_threshold)
     end
 end
 
-@views @timing function compute_kinetic_energy_density(basis::PlaneWaveBasis{TT}, ψ,
-                                                       occupation) where {TT}
+@views @timing function compute_kinetic_energy_density(ψ::BlochWaves{T, Tψ},
+                                                       occupation) where {T, Tψ}
+    basis = ψ.basis
     @assert basis.model.n_components == 1
-    T = promote_type(TT, real(eltype(ψ[1])))
-    τ = similar(ψ[1], T, (basis.fft_size..., basis.model.n_spin_components))
+    TT = promote_type(T, real(Tψ))
+    τ = similar(ψ[1], TT, (basis.fft_size..., basis.model.n_spin_components))
     τ .= 0
     dαψnk_real = zeros(complex(T), basis.fft_size)
     for (ik, kpt) in enumerate(basis.kpoints)
diff --git a/src/orbitals.jl b/src/orbitals.jl
index eaea3da812..4ef4bdd418 100644
--- a/src/orbitals.jl
+++ b/src/orbitals.jl
@@ -4,19 +4,20 @@ using Random  # Used to have a generic API for CPU and GPU computations alike: s
 # virtual states (or states with small occupation level for metals).
 # threshold is a parameter to distinguish between states we want to keep and the
 # others when using temperature. It is set to 0.0 by default, to treat with insulators.
-function select_occupied_orbitals(basis, ψ, occupation; threshold=0.0)
+function select_occupied_orbitals(ψ, occupation; threshold=0.0)
     N = [something(findlast(x -> x > threshold, occk), 0) for occk in occupation]
     selected_ψ   = [@view ψk[:, :, 1:N[ik]] for (ik, ψk)   in enumerate(ψ)]
     selected_occ = [      occk[1:N[ik]]     for (ik, occk) in enumerate(occupation)]
 
+    ψ = BlochWaves(ψ.basis, selected_ψ)
     # If we have an insulator, sanity check that the orbitals we kept are the occupied ones.
     if iszero(threshold)
-        model   = basis.model
+        model   = ψ.basis.model
         n_spin  = model.n_spin_components
         n_bands = div(model.n_electrons, n_spin * filled_occupation(model), RoundUp)
         @assert all([n_bands == size(ψk, 3) for ψk in ψ])
     end
-    (; ψ=selected_ψ, occupation=selected_occ)
+    (; ψ, occupation=selected_occ)
 end
 
 # Packing routines used in direct_minimization and newton algorithms.
diff --git a/src/postprocess/forces.jl b/src/postprocess/forces.jl
index f18622c7f0..a54dd74f1b 100644
--- a/src/postprocess/forces.jl
+++ b/src/postprocess/forces.jl
@@ -5,10 +5,11 @@ lattice vectors. To get cartesian forces use [`compute_forces_cart`](@ref).
 Returns a list of lists of forces (as SVector{3}) in the same order as the `atoms`
 and `positions` in the underlying [`Model`](@ref).
 """
-@timing function compute_forces(basis::PlaneWaveBasis{T}, ψ, occupation; kwargs...) where {T}
+@timing function compute_forces(ψ::BlochWaves{T}, occupation; kwargs...) where {T}
+    basis = ψ.basis
     # no explicit symmetrization is performed here, it is the
     # responsability of each term to return symmetric forces
-    forces_per_term = [compute_forces(term, basis, ψ, occupation; kwargs...)
+    forces_per_term = [compute_forces(term, ψ, occupation; kwargs...)
                        for term in basis.terms]
     sum(filter(!isnothing, forces_per_term))
 end
@@ -19,14 +20,14 @@ Returns a list of lists of forces
 `[[force for atom in positions] for (element, positions) in atoms]`
 which has the same structure as the `atoms` object passed to the underlying [`Model`](@ref).
 """
-function compute_forces_cart(basis::PlaneWaveBasis, ψ, occupation; kwargs...)
-    forces_reduced = compute_forces(basis, ψ, occupation; kwargs...)
-    covector_red_to_cart.(basis.model, forces_reduced)
+function compute_forces_cart(ψ::BlochWaves, occupation; kwargs...)
+    forces_reduced = compute_forces(ψ, occupation; kwargs...)
+    covector_red_to_cart.(ψ.basis.model, forces_reduced)
 end
 
 function compute_forces(scfres)
-    compute_forces(scfres.basis, scfres.ψ, scfres.occupation; scfres.ρ)
+    compute_forces(scfres.ψ, scfres.occupation; scfres.ρ)
 end
 function compute_forces_cart(scfres)
-    compute_forces_cart(scfres.basis, scfres.ψ, scfres.occupation; scfres.ρ)
+    compute_forces_cart(scfres.ψ, scfres.occupation; scfres.ρ)
 end
diff --git a/src/postprocess/stresses.jl b/src/postprocess/stresses.jl
index f1c4de4412..ff31994d84 100644
--- a/src/postprocess/stresses.jl
+++ b/src/postprocess/stresses.jl
@@ -12,8 +12,9 @@ Compute the stresses (= 1/Vol dE/d(M*lattice), taken at M=I) of an obtained SCF
                                    basis.kgrid, basis.symmetries_respect_rgrid,
                                    basis.use_symmetries_for_kpoint_reduction,
                                    basis.comm_kpts, basis.architecture)
-        ρ = compute_density(new_basis, scfres.ψ, scfres.occupation)
-        energies = energy_hamiltonian(new_basis, scfres.ψ, scfres.occupation;
+        ψ = BlochWaves(new_basis, denest(scfres.ψ))
+        ρ = compute_density(ψ, scfres.occupation)
+        energies = energy_hamiltonian(ψ, scfres.occupation;
                                       ρ, scfres.eigenvalues, scfres.εF).energies
         energies.total
     end
diff --git a/src/response/hessian.jl b/src/response/hessian.jl
index 259b0401ed..6ac5e376b5 100644
--- a/src/response/hessian.jl
+++ b/src/response/hessian.jl
@@ -40,9 +40,10 @@ end
 Compute the application of K defined at ψ to δψ. ρ is the density issued from ψ.
 δψ also generates a δρ, computed with `compute_δρ`.
 """
+# T@D@ basis redundant; change signature maybe?
 @views @timing function apply_K(basis::PlaneWaveBasis, δψ, ψ, ρ, occupation)
     δψ = proj_tangent(δψ, ψ)
-    δρ = compute_δρ(basis, ψ, δψ, occupation)
+    δρ = compute_δρ(ψ, δψ, occupation)
     δV = apply_kernel(basis, δρ; ρ)
 
     Kδψ = map(enumerate(ψ)) do (ik, ψk)
@@ -62,13 +63,14 @@ Compute the application of K defined at ψ to δψ. ρ is the density issued fro
 end
 
 """
-    solve_ΩplusK(basis::PlaneWaveBasis{T}, ψ, res, occupation;
+    solve_ΩplusK(ψ::BlochWaves{T}, rhs, occupation;
                  tol=1e-10, verbose=false) where {T}
 
 Return δψ where (Ω+K) δψ = rhs
 """
-@timing function solve_ΩplusK(basis::PlaneWaveBasis{T}, ψ, rhs, occupation;
-                      callback=identity, tol=1e-10) where {T}
+@timing function solve_ΩplusK(ψ::BlochWaves{T}, rhs, occupation; callback=identity,
+                              tol=1e-10) where {T}
+    basis = ψ.basis
     filled_occ = filled_occupation(basis.model)
     # for now, all orbitals have to be fully occupied -> need to strip them beforehand
     @assert all(all(occ_k .== filled_occ) for occ_k in occupation)
@@ -79,8 +81,8 @@ Return δψ where (Ω+K) δψ = rhs
     @assert mpi_nprocs() == 1  # Distributed implementation not yet available
 
     # compute quantites at the point which define the tangent space
-    ρ = compute_density(basis, ψ, occupation)
-    H = energy_hamiltonian(basis, ψ, occupation; ρ).ham
+    ρ = compute_density(ψ, occupation)
+    H = energy_hamiltonian(ψ, occupation; ρ).ham
 
     ψ_matrices = blochwaves_as_matrices(ψ)
     pack(ψ) = reinterpret_real(pack_ψ(ψ))
@@ -152,11 +154,12 @@ Solve the problem `(Ω+K) δψ = rhs` using a split algorithm, where `rhs` is ty
     basis = ham.basis
     @assert size(rhs[1]) == size(ψ[1])  # Assume the same number of bands in ψ and rhs
 
+    ψ_array = denest(ψ)
     # compute δρ0 (ignoring interactions)
-    δψ0, δoccupation0 = apply_χ0_4P(ham, ψ, occupation, εF, eigenvalues, -rhs;
+    δψ0, δoccupation0 = apply_χ0_4P(ham, ψ_array, occupation, εF, eigenvalues, -rhs;
                                     tol=tol_sternheimer, occupation_threshold,
                                     kwargs...)  # = -χ04P * rhs
-    δρ0 = compute_δρ(basis, ψ, δψ0, occupation, δoccupation0; occupation_threshold)
+    δρ0 = compute_δρ(ψ, δψ0, occupation, δoccupation0; occupation_threshold)
 
     # compute total δρ
     pack(δρ)   = vec(δρ)
@@ -183,13 +186,13 @@ Solve the problem `(Ω+K) δψ = rhs` using a split algorithm, where `rhs` is ty
     end
 
     # Compute total change in eigenvalues
-    δeigenvalues = map(ψ, δHψ) do ψk, δHψk
+    δeigenvalues = map(ψ_array, δHψ) do ψk, δHψk
         map(eachslice(ψk; dims=3), eachslice(δHψk; dims=3)) do ψnk, δHψnk
             real(dot(ψnk, δHψnk))  # δε_{nk} = <ψnk | δH | ψnk>
         end
     end
 
-    δψ, δoccupation, δεF = apply_χ0_4P(ham, ψ, occupation, εF, eigenvalues, δHψ;
+    δψ, δoccupation, δεF = apply_χ0_4P(ham, ψ_array, occupation, εF, eigenvalues, δHψ;
                                        occupation_threshold, tol=tol_sternheimer,
                                        kwargs...)
 
diff --git a/src/scf/direct_minimization.jl b/src/scf/direct_minimization.jl
index 2898752d0b..b03284d3b1 100644
--- a/src/scf/direct_minimization.jl
+++ b/src/scf/direct_minimization.jl
@@ -63,14 +63,16 @@ Computes the ground state by direct minimization. `kwargs...` are
 passed to `Optim.Options()`. Note that the resulting ψ are not
 necessarily eigenvectors of the Hamiltonian.
 """
-direct_minimization(basis::PlaneWaveBasis; kwargs...) = direct_minimization(basis, nothing; kwargs...)
-function direct_minimization(basis::PlaneWaveBasis{T}, ψ0;
-                             prec_type=PreconditionerTPA, maxiter=1_000,
+direct_minimization(basis::PlaneWaveBasis; kwargs...) =
+    direct_minimization(BlochWaves(basis); kwargs...)
+
+function direct_minimization(ψ0::BlochWaves{T}; prec_type=PreconditionerTPA, maxiter=1_000,
                              optim_solver=Optim.LBFGS, tol=1e-6, kwargs...) where {T}
     if mpi_nprocs() > 1
         # need synchronization in Optim
         error("Direct minimization with MPI is not supported yet")
     end
+    basis = ψ0.basis
     model = basis.model
     @assert model.n_components == 1
     @assert iszero(model.temperature)  # temperature is not yet supported
@@ -81,7 +83,7 @@ function direct_minimization(basis::PlaneWaveBasis{T}, ψ0;
     Nk = length(basis.kpoints)
 
     if isnothing(ψ0)
-        ψ0 = [random_orbitals(basis, kpt, n_bands) for kpt in basis.kpoints]
+        ψ0 = BlochWaves(basis, [random_orbitals(basis, kpt, n_bands) for kpt in basis.kpoints])
     end
     ψ0_matrices = blochwaves_as_matrices(ψ0)
     occupation = [filled_occ * ones(T, n_bands) for _ = 1:Nk]
@@ -100,8 +102,8 @@ function direct_minimization(basis::PlaneWaveBasis{T}, ψ0;
     # computes energies and gradients
     function fg!(E, G, ψ)
         ψ = unpack(ψ)
-        ρ = compute_density(basis, ψ, occupation)
-        energies, H = energy_hamiltonian(basis, ψ, occupation; ρ)
+        ρ = compute_density(BlochWaves(basis, ψ), occupation)
+        energies, H = energy_hamiltonian(BlochWaves(basis, ψ), occupation; ρ)
 
         # The energy has terms like occ * <ψ|H|ψ>, so the gradient is 2occ Hψ
         if G !== nothing
@@ -145,5 +147,6 @@ function direct_minimization(basis::PlaneWaveBasis{T}, ψ0;
 
     # We rely on the fact that the last point where fg! was called is the minimizer to
     # avoid recomputing at ψ
-    (; ham=H, basis, energies, converged=true, ρ, ψ, eigenvalues, occupation, εF, optim_res=res)
+    (; ham=H, basis, energies, converged=true, ρ, ψ=BlochWaves(basis, ψ), eigenvalues,
+     occupation, εF, optim_res=res)
 end
diff --git a/src/scf/nbands_algorithm.jl b/src/scf/nbands_algorithm.jl
index 79a3a13968..aa2f93caec 100644
--- a/src/scf/nbands_algorithm.jl
+++ b/src/scf/nbands_algorithm.jl
@@ -67,7 +67,7 @@ function determine_n_bands(bands::AdaptiveBands, occupation::Nothing, eigenvalue
     (; n_bands_converge, n_bands_compute)
 end
 function determine_n_bands(bands::AdaptiveBands, occupation::AbstractVector,
-                           eigenvalues::AbstractVector, ψ::AbstractVector)
+                           eigenvalues::AbstractVector, ψ::BlochWaves)
     # TODO Could return different bands per k-Points
 
     # Determine number of bands to be actually converged
diff --git a/src/scf/newton.jl b/src/scf/newton.jl
index 8fb7621ec2..3fd59aeef6 100644
--- a/src/scf/newton.jl
+++ b/src/scf/newton.jl
@@ -49,9 +49,9 @@ using IterativeSolvers
 #  Compute the gradient of the energy, projected on the space tangent to ψ, that
 #  is to say H(ψ)*ψ - ψ*λ where λ is the set of Rayleigh coefficients associated
 #  to the ψ.
-function compute_projected_gradient(basis::PlaneWaveBasis, ψ, occupation)
-    ρ = compute_density(basis, ψ, occupation)
-    H = energy_hamiltonian(basis, ψ, occupation; ρ).ham
+function compute_projected_gradient(ψ, occupation)
+    ρ = compute_density(ψ, occupation)
+    H = energy_hamiltonian(ψ, occupation; ρ).ham
 
     [proj_tangent_kpt(H.blocks[ik] * ψk, ψk) for (ik, ψk) in enumerate(ψ)]
 end
@@ -75,18 +75,18 @@ end
 
 
 """
-    newton(basis::PlaneWaveBasis{T}, ψ0;
+    newton(ψ0::BlochWaves{T};
            tol=1e-6, tol_cg=tol / 100, maxiter=20, callback=ScfDefaultCallback(),
            is_converged=ScfConvergenceDensity(tol))
 
 Newton algorithm. Be careful that the starting point needs to be not too far
 from the solution.
 """
-function newton(basis::PlaneWaveBasis{T}, ψ0; tol=1e-6, tol_cg=tol / 100, maxiter=20,
-                callback=ScfDefaultCallback(),
-                is_converged=ScfConvergenceDensity(tol)) where {T}
+function newton(ψ0::BlochWaves{T}; tol=1e-6, tol_cg=tol / 100, maxiter=20,
+                callback=ScfDefaultCallback(), is_converged=ScfConvergenceDensity(tol)) where {T}
 
     # setting parameters
+    basis = ψ0.basis
     model = basis.model
     @assert model.n_components == 1
     @assert iszero(model.temperature)  # temperature is not yet supported
@@ -106,9 +106,9 @@ function newton(basis::PlaneWaveBasis{T}, ψ0; tol=1e-6, tol_cg=tol / 100, maxit
     n_iter = 0
 
     # orbitals, densities and energies to be updated along the iterations
-    ψ = deepcopy(ψ0)
-    ρ = compute_density(basis, ψ, occupation)
-    energies, H = energy_hamiltonian(basis, ψ, occupation; ρ)
+    ψ = BlochWaves(basis, denest(ψ0))
+    ρ = compute_density(ψ, occupation)
+    energies, H = energy_hamiltonian(ψ, occupation; ρ)
     converged = false
 
     # perform iterations
@@ -116,13 +116,13 @@ function newton(basis::PlaneWaveBasis{T}, ψ0; tol=1e-6, tol_cg=tol / 100, maxit
         n_iter += 1
 
         # compute Newton step and next iteration
-        res = compute_projected_gradient(basis, ψ, occupation)
+        res = compute_projected_gradient(ψ, occupation)
         # solve (Ω+K) δψ = -res so that the Newton step is ψ <- ψ + δψ
-        δψ = solve_ΩplusK(basis, ψ, -res, occupation; tol=tol_cg).δψ
-        ψ  = [ortho_qr(ψ[ik] + δψ[ik]) for ik = 1:Nk]
+        δψ = solve_ΩplusK(ψ, -res, occupation; tol=tol_cg).δψ
+        ψ  = BlochWaves(basis, [ortho_qr(ψ[ik] + δψ[ik]) for ik = 1:Nk])
 
-        ρ_next = compute_density(basis, ψ, occupation)
-        energies, H = energy_hamiltonian(basis, ψ, occupation; ρ=ρ_next)
+        ρ_next = compute_density(ψ, occupation)
+        energies, H = energy_hamiltonian(ψ, occupation; ρ=ρ_next)
         info = (; ham=H, basis, converged, stage=:iterate, ρin=ρ, ρout=ρ_next, n_iter,
                 energies, algorithm="Newton")
         callback(info)
diff --git a/src/scf/potential_mixing.jl b/src/scf/potential_mixing.jl
index 2892a43590..a9e3f4983b 100644
--- a/src/scf/potential_mixing.jl
+++ b/src/scf/potential_mixing.jl
@@ -229,7 +229,7 @@ trial_damping(damping::FixedDamping, args...) = damping.α
     fermialg::AbstractFermiAlgorithm=default_fermialg(basis.model),
     ρ=guess_density(basis),
     V=nothing,
-    ψ=nothing,
+    ψ::BlochWaves=BlochWaves(basis),
     tol=1e-6,
     maxiter=100,
     eigensolver=lobpcg_hyper,
@@ -253,15 +253,16 @@ trial_damping(damping::FixedDamping, args...) = damping.α
     end
 
     # Initial guess for V (if none given)
-    ham = energy_hamiltonian(basis, nothing, nothing; ρ).ham
+    ham = energy_hamiltonian(BlochWaves(basis), nothing; ρ).ham
     isnothing(V) && (V = total_local_potential(ham))
 
-    function EVρ(Vin; diagtol=tol / 10, ψ=nothing, eigenvalues=nothing, occupation=nothing)
+    function EVρ(Vin; diagtol=tol / 10, ψ=BlochWaves(ham.basis), eigenvalues=nothing,
+                 occupation=nothing)
         ham_V = hamiltonian_with_total_potential(ham, Vin)
 
         res_V = next_density(ham_V, nbandsalg, fermialg; eigensolver, ψ, eigenvalues,
                              occupation, miniter=diag_miniter, tol=diagtol)
-        new_E, new_ham = energy_hamiltonian(basis, res_V.ψ, res_V.occupation;
+        new_E, new_ham = energy_hamiltonian(res_V.ψ, res_V.occupation;
                                             ρ=res_V.ρout, eigenvalues=res_V.eigenvalues,
                                             εF=res_V.εF)
         (; basis, ham=new_ham, energies=new_E,
diff --git a/src/scf/self_consistent_field.jl b/src/scf/self_consistent_field.jl
index e3e8f3f550..a72ff52361 100644
--- a/src/scf/self_consistent_field.jl
+++ b/src/scf/self_consistent_field.jl
@@ -13,8 +13,9 @@ data structure to determine and adjust the number of bands to be computed.
 function next_density(ham::Hamiltonian,
                       nbandsalg::NbandsAlgorithm=AdaptiveBands(ham.basis.model),
                       fermialg::AbstractFermiAlgorithm=default_fermialg(ham.basis.model);
-                      eigensolver=lobpcg_hyper, ψ=nothing, eigenvalues=nothing,
-                      occupation=nothing, kwargs...)
+                      eigensolver=lobpcg_hyper, ψ=BlochWaves(ham.basis),
+                      eigenvalues=nothing, occupation=nothing, kwargs...)
+    @assert ham.basis == ψ.basis
     n_bands_converge, n_bands_compute = determine_n_bands(nbandsalg, occupation,
                                                           eigenvalues, ψ)
 
@@ -49,10 +50,10 @@ function next_density(ham::Hamiltonian,
               "`nbandsalg=AdaptiveBands(model; n_bands_converge=$(n_bands_converge + 3)`)")
     end
 
-    ψ = eigres.X
-    ρout = compute_density(ham.basis, ψ, occupation; nbandsalg.occupation_threshold)
-    (; ψ, eigenvalues=eigres.λ, occupation, εF, ρout, diagonalization=eigres,
-     n_bands_converge, nbandsalg.occupation_threshold)
+    ψ = BlochWaves(ham.basis, eigres.X)
+    ρout = compute_density(ψ, occupation; nbandsalg.occupation_threshold)
+    (; ψ, eigenvalues=eigres.λ, occupation, εF, ρout,
+     diagonalization=eigres, n_bands_converge, nbandsalg.occupation_threshold)
 end
 
 
@@ -87,7 +88,7 @@ Overview of parameters:
 @timing function self_consistent_field(
     basis::PlaneWaveBasis{T};
     ρ=guess_density(basis),
-    ψ=nothing,
+    ψ::BlochWaves=BlochWaves(basis),
     tol=1e-6,
     is_converged=ScfConvergenceDensity(tol),
     maxiter=100,
@@ -102,10 +103,8 @@ Overview of parameters:
     compute_consistent_energies=true,
     response=ResponseOptions(),  # Dummy here, only for AD
 ) where {T}
+    @assert basis == ψ.basis
     # All these variables will get updated by fixpoint_map
-    if !isnothing(ψ)
-        @assert length(ψ) == length(basis.kpoints)
-    end
     occupation = nothing
     eigenvalues = nothing
     ρout = ρ
@@ -124,7 +123,7 @@ Overview of parameters:
 
         # Note that ρin is not the density of ψ, and the eigenvalues
         # are not the self-consistent ones, which makes this energy non-variational
-        energies, ham = energy_hamiltonian(basis, ψ, occupation; ρ=ρin, eigenvalues, εF)
+        energies, ham = energy_hamiltonian(ψ, occupation; ρ=ρin, eigenvalues, εF)
 
         # Diagonalize `ham` to get the new state
         nextstate = next_density(ham, nbandsalg, fermialg; eigensolver, ψ, eigenvalues,
@@ -138,8 +137,7 @@ Overview of parameters:
 
         # Compute the energy of the new state
         if compute_consistent_energies
-            energies = energy_hamiltonian(basis, ψ, occupation;
-                                          ρ=ρout, eigenvalues, εF).energies
+            energies = energy_hamiltonian(ψ, occupation; ρ=ρout, eigenvalues, εF).energies
         end
         info = merge(info, (; energies))
 
@@ -162,7 +160,7 @@ Overview of parameters:
     # We do not use the return value of solver but rather the one that got updated by fixpoint_map
     # ψ is consistent with ρout, so we return that. We also perform a last energy computation
     # to return a correct variational energy
-    energies, ham = energy_hamiltonian(basis, ψ, occupation; ρ=ρout, eigenvalues, εF)
+    energies, ham = energy_hamiltonian(ψ, occupation; ρ=ρout, eigenvalues, εF)
 
     # Measure for the accuracy of the SCF
     # TODO probably should be tracked all the way ...
diff --git a/src/supercell.jl b/src/supercell.jl
index 11cfeb22dc..4d2bac2bc0 100644
--- a/src/supercell.jl
+++ b/src/supercell.jl
@@ -88,7 +88,7 @@ function cell_to_supercell(ψ, basis::PlaneWaveBasis{T},
         push!(ψ_out_blocs, ψk_supercell)
     end
     # Note that each column is normalize since each ψ[ik][:,n] is.
-    [cat(ψ_out_blocs...; dims=3)]
+    BlochWaves(basis_supercell, [cat(ψ_out_blocs...; dims=3)])
 end
 
 @doc raw"""
@@ -112,12 +112,11 @@ function cell_to_supercell(scfres::NamedTuple)
     ψ_supercell     = cell_to_supercell(ψ, basis, basis_supercell)
     eigs_supercell  = [vcat(scfres_unfold.eigenvalues...)]
     occ_supercell   = compute_occupation(basis_supercell, eigs_supercell, scfres.εF).occupation
-    ρ_supercell     = compute_density(basis_supercell, ψ_supercell, occ_supercell;
-                                      scfres.occupation_threshold)
+    ρ_supercell     = compute_density(ψ_supercell, occ_supercell; scfres.occupation_threshold)
 
     # Supercell Energies
-    Eham_supercell = energy_hamiltonian(basis_supercell, ψ_supercell, occ_supercell;
-                                        ρ=ρ_supercell, eigenvalues=eigs_supercell, scfres.εF)
+    Eham_supercell = energy_hamiltonian(ψ_supercell, occ_supercell; ρ=ρ_supercell,
+                                        eigenvalues=eigs_supercell, scfres.εF)
 
     merge(scfres, (; ham=Eham_supercell.ham, basis=basis_supercell, ψ=ψ_supercell,
                    energies=Eham_supercell.energies, ρ=ρ_supercell,
diff --git a/src/symmetry.jl b/src/symmetry.jl
index bfd8fa01ea..64e6a24a15 100644
--- a/src/symmetry.jl
+++ b/src/symmetry.jl
@@ -432,7 +432,11 @@ function unfold_array_(basis_irred, basis_unfolded, data, is_ψ)
             data_unfolded[ik_unfolded] = data[ik_irred]
         end
     end
-    data_unfolded
+    if is_ψ
+        BlochWaves(basis_unfolded, denest(basis_unfolded, data_unfolded))
+    else
+        data_unfolded
+    end
 end
 
 function unfold_bz(scfres)
@@ -440,7 +444,7 @@ function unfold_bz(scfres)
     ψ = unfold_array_(scfres.basis, basis_unfolded, scfres.ψ, true)
     eigenvalues = unfold_array_(scfres.basis, basis_unfolded, scfres.eigenvalues, false)
     occupation = unfold_array_(scfres.basis, basis_unfolded, scfres.occupation, false)
-    energies, ham = energy_hamiltonian(basis_unfolded, ψ, occupation;
+    energies, ham = energy_hamiltonian(ψ, occupation;
                                        scfres.ρ, eigenvalues, scfres.εF)
     @assert energies.total ≈ scfres.energies.total
     new_scfres = (; basis=basis_unfolded, ψ, ham, eigenvalues, occupation)
diff --git a/src/terms/Hamiltonian.jl b/src/terms/Hamiltonian.jl
index db9923f725..17a18264bd 100644
--- a/src/terms/Hamiltonian.jl
+++ b/src/terms/Hamiltonian.jl
@@ -194,9 +194,10 @@ end
 # Get energies and Hamiltonian
 # kwargs is additional info that might be useful for the energy terms to precompute
 # (eg the density ρ)
-@timing function energy_hamiltonian(basis::PlaneWaveBasis, ψ, occupation; kwargs...)
+@timing function energy_hamiltonian(ψ::BlochWaves, occupation; kwargs...)
+    basis = ψ.basis
     # it: index into terms, ik: index into kpoints
-    @timing "ene_ops" ene_ops_arr = [ene_ops(term, basis, ψ, occupation; kwargs...)
+    @timing "ene_ops" ene_ops_arr = [ene_ops(term, ψ, occupation; kwargs...)
                                      for term in basis.terms]
     energies  = [eh.E for eh in ene_ops_arr]
     operators = [eh.ops for eh in ene_ops_arr]         # operators[it][ik]
@@ -221,8 +222,8 @@ end
     energies = Energies(basis.model.term_types, energies)
     (; energies, ham)
 end
-function Hamiltonian(basis::PlaneWaveBasis; ψ=nothing, occupation=nothing, kwargs...)
-    energy_hamiltonian(basis, ψ, occupation; kwargs...).ham
+function Hamiltonian(basis::PlaneWaveBasis; ψ=BlochWaves(basis), occupation=nothing, kwargs...)
+    energy_hamiltonian(ψ, occupation; kwargs...).ham
 end
 
 """
diff --git a/src/terms/anyonic.jl b/src/terms/anyonic.jl
index 16b4629ecd..0f04cba2ed 100644
--- a/src/terms/anyonic.jl
+++ b/src/terms/anyonic.jl
@@ -99,8 +99,9 @@ function TermAnyonic(basis::PlaneWaveBasis{T}, hbar, β) where {T}
     TermAnyonic(hbar, β, ρref, Aref)
 end
 
-function ene_ops(term::TermAnyonic, basis::PlaneWaveBasis{T}, ψ, occupation;
+function ene_ops(term::TermAnyonic, ψ::BlochWaves{T}, occupation;
                  ρ, kwargs...) where {T}
+    basis = ψ.basis
     @assert basis.model.n_components == 1
     hbar = term.hbar
     β = term.β
diff --git a/src/terms/entropy.jl b/src/terms/entropy.jl
index d24e710e39..b07f7d508e 100644
--- a/src/terms/entropy.jl
+++ b/src/terms/entropy.jl
@@ -8,8 +8,8 @@ struct Entropy end
 (::Entropy)(basis) = TermEntropy()
 struct TermEntropy <: Term end
 
-function ene_ops(term::TermEntropy, basis::PlaneWaveBasis{T}, ψ, occupation;
-                 kwargs...) where {T}
+function ene_ops(term::TermEntropy, ψ::BlochWaves{T}, occupation; kwargs...) where {T}
+    basis = ψ.basis
     ops = [NoopOperator(basis, kpt) for kpt in basis.kpoints]
     smearing    = basis.model.smearing
     temperature = basis.model.temperature
diff --git a/src/terms/ewald.jl b/src/terms/ewald.jl
index 3d23d166e8..fc052bf406 100644
--- a/src/terms/ewald.jl
+++ b/src/terms/ewald.jl
@@ -25,10 +25,10 @@ end
     TermEwald(energy, forces, η)
 end
 
-function ene_ops(term::TermEwald, basis::PlaneWaveBasis, ψ, occupation; kwargs...)
-    (; E=term.energy, ops=[NoopOperator(basis, kpt) for kpt in basis.kpoints])
+function ene_ops(term::TermEwald, ψ::BlochWaves, occupation; kwargs...)
+    (; E=term.energy, ops=[NoopOperator(ψ.basis, kpt) for kpt in ψ.basis.kpoints])
 end
-compute_forces(term::TermEwald, ::PlaneWaveBasis, ψ, occupation; kwargs...) = term.forces
+compute_forces(term::TermEwald, ::BlochWaves, occupation; kwargs...) = term.forces
 
 """
 Standard computation of energy and forces.
diff --git a/src/terms/hartree.jl b/src/terms/hartree.jl
index 0c4cd78c12..f193a02d5c 100644
--- a/src/terms/hartree.jl
+++ b/src/terms/hartree.jl
@@ -39,8 +39,9 @@ function TermHartree(basis::PlaneWaveBasis{T}, scaling_factor) where {T}
     TermHartree(T(scaling_factor), T(scaling_factor) .* poisson_green_coeffs)
 end
 
-@timing "ene_ops: hartree" function ene_ops(term::TermHartree, basis::PlaneWaveBasis{T},
-                                            ψ, occupation; ρ, kwargs...) where {T}
+@timing "ene_ops: hartree" function ene_ops(term::TermHartree, ψ::BlochWaves{T}, occupation;
+                                            ρ, kwargs...) where {T}
+    basis = ψ.basis
     ρtot_fourier = fft(basis, total_density(ρ))
     pot_fourier = term.poisson_green_coeffs .* ρtot_fourier
     pot_real = irfft(basis, pot_fourier)
diff --git a/src/terms/kinetic.jl b/src/terms/kinetic.jl
index 3649260383..992ed92770 100644
--- a/src/terms/kinetic.jl
+++ b/src/terms/kinetic.jl
@@ -34,8 +34,9 @@ function kinetic_energy(kin::Kinetic, Ecut, q)
     kinetic_energy(kin.blowup, kin.scaling_factor, Ecut, q)
 end
 
-@timing "ene_ops: kinetic" function ene_ops(term::TermKinetic, basis::PlaneWaveBasis{T},
-                                            ψ, occupation; kwargs...) where {T}
+@timing "ene_ops: kinetic" function ene_ops(term::TermKinetic, ψ::BlochWaves{T}, occupation;
+                                            kwargs...) where {T}
+    basis = ψ.basis
     ops = [FourierMultiplication(basis, kpoint, term.kinetic_energies[ik])
            for (ik, kpoint) in enumerate(basis.kpoints)]
     if isnothing(ψ) || isnothing(occupation)
diff --git a/src/terms/local.jl b/src/terms/local.jl
index 3ee08cb40b..e2378e06fe 100644
--- a/src/terms/local.jl
+++ b/src/terms/local.jl
@@ -6,10 +6,10 @@
 # two spin components.
 abstract type TermLocalPotential <: Term end
 
-@timing "ene_ops: local" function ene_ops(term::TermLocalPotential,
-                                          basis::PlaneWaveBasis{T}, ψ, occupation;
-                                          kwargs...) where {T}
+@timing "ene_ops: local" function ene_ops(term::TermLocalPotential, ψ::BlochWaves{T},
+                                          occupation; kwargs...) where {T}
     potview(data, spin) = ndims(data) == 4 ? (@view data[:, :, :, spin]) : data
+    basis = ψ.basis
     ops = [RealSpaceMultiplication(basis, kpt, potview(term.potential_values, kpt.spin))
            for kpt in basis.kpoints]
     if :ρ in keys(kwargs)
@@ -123,18 +123,19 @@ function (::AtomicLocal)(basis::PlaneWaveBasis{T}) where {T}
     TermAtomicLocal(pot_real)
 end
 
-@timing "forces: local" function compute_forces(::TermAtomicLocal, basis::PlaneWaveBasis{TT},
-                                                ψ, occupation; ρ, kwargs...) where {TT}
-    T = promote_type(TT, real(eltype(ψ[1])))
+@timing "forces: local" function compute_forces(::TermAtomicLocal, ψ::BlochWaves{T, Tψ},
+                                                occupation; ρ, kwargs...) where {T, Tψ}
+    TT = promote_type(T, real(Tψ))
+    basis = ψ.basis
     model = basis.model
     ρ_fourier = fft(basis, total_density(ρ))
 
     # energy = sum of form_factor(G) * struct_factor(G) * rho(G)
     # where struct_factor(G) = e^{-i G·r}
-    forces = [zero(Vec3{T}) for _ = 1:length(model.positions)]
+    forces = [zero(Vec3{TT}) for _ = 1:length(model.positions)]
     for group in model.atom_groups
         element = model.atoms[first(group)]
-        form_factors = [Complex{T}(local_potential_fourier(element, norm(G)))
+        form_factors = [Complex{TT}(local_potential_fourier(element, norm(G)))
                         for G in G_vectors_cart(basis)]
         for idx in group
             r = model.positions[idx]
diff --git a/src/terms/local_nonlinearity.jl b/src/terms/local_nonlinearity.jl
index 34998827f1..7c4e665cd8 100644
--- a/src/terms/local_nonlinearity.jl
+++ b/src/terms/local_nonlinearity.jl
@@ -9,8 +9,9 @@ struct TermLocalNonlinearity{TF} <: TermNonlinear
 end
 (L::LocalNonlinearity)(::AbstractBasis) = TermLocalNonlinearity(L.f)
 
-function ene_ops(term::TermLocalNonlinearity, basis::PlaneWaveBasis{T}, ψ, occupation;
+function ene_ops(term::TermLocalNonlinearity, ψ::BlochWaves{T}, occupation;
                  ρ, kwargs...) where {T}
+    basis = ψ.basis
     fp(ρ) = ForwardDiff.derivative(term.f, ρ)
     E = sum(fρ -> convert_dual(T, fρ), term.f.(ρ)) * basis.dvol
     potential = convert_dual.(T, fp.(ρ))
diff --git a/src/terms/magnetic.jl b/src/terms/magnetic.jl
index b49e76af76..4e5e9d7b65 100644
--- a/src/terms/magnetic.jl
+++ b/src/terms/magnetic.jl
@@ -30,8 +30,8 @@ function TermMagnetic(basis::PlaneWaveBasis{T}, Afunction::Function) where {T}
     TermMagnetic(Apotential)
 end
 
-function ene_ops(term::TermMagnetic, basis::PlaneWaveBasis{T}, ψ, occupation;
-                 kwargs...) where {T}
+function ene_ops(term::TermMagnetic, ψ::BlochWaves{T}, occupation; kwargs...) where {T}
+    basis = ψ.basis
     ops = [MagneticFieldOperator(basis, kpoint, term.Apotential)
            for (ik, kpoint) in enumerate(basis.kpoints)]
     if isnothing(ψ) || isnothing(occupation)
diff --git a/src/terms/nonlocal.jl b/src/terms/nonlocal.jl
index 32a6bb4946..7a31f83421 100644
--- a/src/terms/nonlocal.jl
+++ b/src/terms/nonlocal.jl
@@ -25,9 +25,9 @@ struct TermAtomicNonlocal <: Term
     ops::Vector{NonlocalOperator}
 end
 
-@timing "ene_ops: nonlocal" function ene_ops(term::TermAtomicNonlocal,
-                                             basis::PlaneWaveBasis{T},
-                                             ψ, occupation; kwargs...) where {T}
+@timing "ene_ops: nonlocal" function ene_ops(term::TermAtomicNonlocal, ψ::BlochWaves{T},
+                                             occupation; kwargs...) where {T}
+    basis = ψ.basis
     if isnothing(ψ) || isnothing(occupation)
         return (; E=T(Inf), term.ops)
     end
@@ -45,10 +45,10 @@ end
     (; E, term.ops)
 end
 
-@timing "forces: nonlocal" function compute_forces(::TermAtomicNonlocal,
-                                                   basis::PlaneWaveBasis{TT},
-                                                   ψ, occupation; kwargs...) where {TT}
-    T = promote_type(TT, real(eltype(ψ[1])))
+@timing "forces: nonlocal" function compute_forces(::TermAtomicNonlocal, ψ::BlochWaves{T, Tψ},
+                                                   occupation; kwargs...) where {T, Tψ}
+    basis = ψ.basis
+    TT = promote_type(T, real(Tψ))
     model = basis.model
     unit_cell_volume = model.unit_cell_volume
     psp_groups = [group for group in model.atom_groups
@@ -58,11 +58,11 @@ end
     isempty(psp_groups) && return nothing
 
     # energy terms are of the form <psi, P C P' psi>, where P(G) = form_factor(G) * structure_factor(G)
-    forces = [zero(Vec3{T}) for _ = 1:length(model.positions)]
+    forces = [zero(Vec3{TT}) for _ = 1:length(model.positions)]
     for group in psp_groups
         element = model.atoms[first(group)]
 
-        C = build_projection_coefficients_(T, element.psp)
+        C = build_projection_coefficients_(TT, element.psp)
         for (ik, kpt) in enumerate(basis.kpoints)
             # we compute the forces from the irreductible BZ; they are symmetrized later
             qs = Gplusk_vectors(basis, kpt)
@@ -74,7 +74,7 @@ end
                 P = structure_factors .* form_factors ./ sqrt(unit_cell_volume)
 
                 forces[idx] += map(1:3) do α
-                    dPdR = [-2T(π)*im*q[α] for q in qs] .* P
+                    dPdR = [-2TT(π)*im*q[α] for q in qs] .* P
                     mapreduce(+, 1:model.n_components) do σ
                         ψkσ = ψ[ik][σ, :, :]
                         dHψkσ = P * (C * (dPdR' * ψkσ))
diff --git a/src/terms/pairwise.jl b/src/terms/pairwise.jl
index 55f597459c..8198a64895 100644
--- a/src/terms/pairwise.jl
+++ b/src/terms/pairwise.jl
@@ -33,10 +33,10 @@ struct TermPairwisePotential{TV, Tparams, T} <:Term
     forces::Vector{Vec3{T}}
 end
 
-function ene_ops(term::TermPairwisePotential, basis::PlaneWaveBasis, ψ, occupation; kwargs...)
-    (; E=term.energy, ops=[NoopOperator(basis, kpt) for kpt in basis.kpoints])
+function ene_ops(term::TermPairwisePotential, ψ::BlochWaves, occupation; kwargs...)
+    (; E=term.energy, ops=[NoopOperator(ψ.basis, kpt) for kpt in ψ.basis.kpoints])
 end
-compute_forces(term::TermPairwisePotential, ::PlaneWaveBasis, ψ, occ; kwargs...) = term.forces
+compute_forces(term::TermPairwisePotential, ::BlochWaves, occ; kwargs...) = term.forces
 
 
 """
diff --git a/src/terms/psp_correction.jl b/src/terms/psp_correction.jl
index 198b0771bc..52ff73005d 100644
--- a/src/terms/psp_correction.jl
+++ b/src/terms/psp_correction.jl
@@ -15,8 +15,8 @@ function TermPspCorrection(basis::PlaneWaveBasis)
     TermPspCorrection(energy_psp_correction(model))
 end
 
-function ene_ops(term::TermPspCorrection, basis::PlaneWaveBasis, ψ, occupation; kwargs...)
-    (; E=term.energy, ops=[NoopOperator(basis, kpt) for kpt in basis.kpoints])
+function ene_ops(term::TermPspCorrection, ψ::BlochWaves, occupation; kwargs...)
+    (; E=term.energy, ops=[NoopOperator(ψ.basis, kpt) for kpt in ψ.basis.kpoints])
 end
 
 """
diff --git a/src/terms/terms.jl b/src/terms/terms.jl
index b4bf7e4c52..1343b77348 100644
--- a/src/terms/terms.jl
+++ b/src/terms/terms.jl
@@ -4,7 +4,7 @@ include("operators.jl")
 # - A Term is something that, given a state, returns a named tuple (; E, hams) with an energy
 #   and a list of RealFourierOperator (for each kpoint).
 # - Each term must overload
-#     `ene_ops(term, basis, ψ, occupation; kwargs...)`
+#     `ene_ops(term, ψ, occupation; kwargs...)`
 #         -> (; E::Real, ops::Vector{RealFourierOperator}).
 # - Note that terms are allowed to hold on to references to ψ (eg Fock term),
 #   so ψ should not mutated after ene_ops
@@ -26,8 +26,8 @@ abstract type TermNonlinear <: Term end
 A term with a constant zero energy.
 """
 struct TermNoop <: Term end
-function ene_ops(term::TermNoop, basis::PlaneWaveBasis{T}, ψ, occupation; kwargs...) where {T}
-    (; E=zero(eltype(T)), ops=[NoopOperator(basis, kpt) for kpt in basis.kpoints])
+function ene_ops(term::TermNoop, ψ::BlochWaves{T}, occupation; kwargs...) where {T}
+    (; E=zero(eltype(T)), ops=[NoopOperator(ψ.basis, kpt) for kpt in ψ.basis.kpoints])
 end
 
 include("Hamiltonian.jl")
@@ -59,9 +59,9 @@ include("anyonic.jl")
 breaks_symmetries(::Anyonic) = true
 
 # forces computes either nothing or an array forces[at][α] (by default no forces)
-compute_forces(::Term, ::AbstractBasis, ψ, occupation; kwargs...) = nothing
+compute_forces(::Term, ::BlochWaves, occupation; kwargs...) = nothing
 # dynamical matrix for phonons computations (array dynmat[n_dim, n_atom, n_dim, n_atom])
-compute_dynmat(::Term, ::AbstractBasis, ψ, occupation; kwargs...) = nothing
+compute_dynmat(::Term, ::BlochWaves, occupation; kwargs...) = nothing
 
 @doc raw"""
     compute_kernel(basis::PlaneWaveBasis; kwargs...)
diff --git a/src/terms/xc.jl b/src/terms/xc.jl
index 3a79070d0e..f4581f11ae 100644
--- a/src/terms/xc.jl
+++ b/src/terms/xc.jl
@@ -55,10 +55,10 @@ struct TermXc{T,CT} <: TermNonlinear where {T,CT}
     ρcore::CT
 end
 
-function xc_potential_real(term::TermXc, basis::PlaneWaveBasis{T}, ψ, occupation;
-                           ρ, τ=nothing) where {T}
+function xc_potential_real(term::TermXc, ψ::BlochWaves{T}, occupation; ρ, τ=nothing) where {T}
     @assert !isempty(term.functionals)
 
+    basis    = ψ.basis
     model    = basis.model
     n_spin   = model.n_spin_components
     potential_threshold = term.potential_threshold
@@ -74,7 +74,7 @@ function xc_potential_real(term::TermXc, basis::PlaneWaveBasis{T}, ψ, occupatio
         if isnothing(ψ) || isnothing(occupation)
             τ = zero(ρ)
         else
-            τ = compute_kinetic_energy_density(basis, ψ, occupation)
+            τ = compute_kinetic_energy_density(ψ, occupation)
         end
     end
 
@@ -135,11 +135,11 @@ function xc_potential_real(term::TermXc, basis::PlaneWaveBasis{T}, ψ, occupatio
     (; E, potential, Vτ)
 end
 
-@views @timing "ene_ops: xc" function ene_ops(term::TermXc, basis::PlaneWaveBasis{T},
-                                              ψ, occupation; ρ, τ=nothing,
-                                              kwargs...) where {T}
-    E, Vxc, Vτ = xc_potential_real(term, basis, ψ, occupation; ρ, τ)
+@views @timing "ene_ops: xc" function ene_ops(term::TermXc, ψ::BlochWaves, occupation;
+                                              ρ, τ=nothing, kwargs...)
+    E, Vxc, Vτ = xc_potential_real(term, ψ, occupation; ρ, τ)
 
+    basis = ψ.basis
     ops = map(basis.kpoints) do kpt
         if !isnothing(Vτ)
             [RealSpaceMultiplication(basis, kpt, Vxc[:, :, :, kpt.spin]),
@@ -151,14 +151,14 @@ end
     (; E, ops)
 end
 
-@timing "forces: xc" function compute_forces(term::TermXc, basis::PlaneWaveBasis{T},
-                                             ψ, occupation; ρ, τ=nothing,
-                                             kwargs...) where {T}
+@timing "forces: xc" function compute_forces(term::TermXc, ψ::BlochWaves{T}, occupation;
+                                             ρ, τ=nothing, kwargs...) where {T}
     # the only non-zero force contribution is from the nlcc core charge
     # early return if nlcc is disabled / no elements have model core charges
     isnothing(term.ρcore) && return nothing
 
-    Vxc_real = xc_potential_real(term, basis, ψ, occupation; ρ, τ).potential
+    basis = ψ.basis
+    Vxc_real = xc_potential_real(term, ψ, occupation; ρ, τ).potential
     # TODO: the factor of 2 here should be associated with the density, not the potential
     if basis.model.spin_polarization in (:none, :spinless)
         Vxc_fourier = fft(basis, Vxc_real[:,:,:,1])
diff --git a/src/transfer.jl b/src/transfer.jl
index 5b7d089411..74689588f7 100644
--- a/src/transfer.jl
+++ b/src/transfer.jl
@@ -150,9 +150,10 @@ function transfer_blochwave(ψ_in, basis_in::PlaneWaveBasis{T},
     # It is then of size G_vectors(basis_out.kpoints[ik]) and the transfer can be done with
     # ψ_out[ik] .= ψ_in[ik][idcs_in[ik], :]
 
-    map(enumerate(basis_out.kpoints)) do (ik, kpt_out)
-        transfer_blochwave_kpt(ψ_in[ik], basis_in, basis_in.kpoints[ik], basis_out, kpt_out)
-    end
+    BlochWaves(basis_out, map(enumerate(basis_out.kpoints)) do (ik, kpt_out)
+               transfer_blochwave_kpt(ψ_in[ik], basis_in, basis_in.kpoints[ik],
+                                      basis_out, kpt_out)
+    end)
 end
 
 @doc raw"""
diff --git a/src/workarounds/forwarddiff_rules.jl b/src/workarounds/forwarddiff_rules.jl
index ac46c1584b..431176db1a 100644
--- a/src/workarounds/forwarddiff_rules.jl
+++ b/src/workarounds/forwarddiff_rules.jl
@@ -200,14 +200,14 @@ function self_consistent_field(basis_dual::PlaneWaveBasis{T};
     ## Compute external perturbation (contained in ham_dual) and from matvec with bands
     Hψ_dual = let
         occupation_dual = [T.(occk) for occk in scfres.occupation]
-        ψ_dual = [Complex.(T.(real(ψk)), T.(imag(ψk))) for ψk in scfres.ψ]
-        ρ_dual = compute_density(basis_dual, ψ_dual, occupation_dual)
+        ψ_dual_values = [Complex.(T.(real(ψk)), T.(imag(ψk))) for ψk in scfres.ψ]
+        ψ_dual = BlochWaves(basis_dual, ψ_dual_values)
+        ρ_dual = compute_density(ψ_dual, occupation_dual)
         εF_dual = T(scfres.εF)  # Only needed for entropy term
         eigenvalues_dual = [T.(εk) for εk in scfres.eigenvalues]
-        ham_dual = energy_hamiltonian(basis_dual, ψ_dual, occupation_dual;
-                                      ρ=ρ_dual, eigenvalues=eigenvalues_dual,
-                                      εF=εF_dual).ham
-        ham_dual * ψ_dual
+        ham_dual = energy_hamiltonian(ψ_dual, occupation_dual;
+                                      ρ=ρ_dual, eigenvalues=eigenvalues_dual, εF=εF_dual).ham
+        ham_dual * ψ_dual_values
     end
 
     ## Implicit differentiation
@@ -219,12 +219,12 @@ function self_consistent_field(basis_dual::PlaneWaveBasis{T};
 
     ## Convert and combine
     DT = ForwardDiff.Dual{ForwardDiff.tagtype(T)}
-    ψ = map(scfres.ψ, getfield.(δresults, :δψ)...) do ψk, δψk...
+    ψ = BlochWaves(basis_dual, map(scfres.ψ, getfield.(δresults, :δψ)...) do ψk, δψk...
         map(ψk, δψk...) do ψnk, δψnk...
             Complex(DT(real(ψnk), real.(δψnk)),
                     DT(imag(ψnk), imag.(δψnk)))
         end
-    end
+    end)
     ρ = map((ρi, δρi...) -> DT(ρi, δρi), scfres.ρ, getfield.(δresults, :δρ)...)
     eigenvalues = map(scfres.eigenvalues, getfield.(δresults, :δeigenvalues)...) do εk, δεk...
         map((εnk, δεnk...) -> DT(εnk, δεnk), εk, δεk...)
@@ -237,7 +237,7 @@ function self_consistent_field(basis_dual::PlaneWaveBasis{T};
     # TODO Could add δresults[α].δVind the dual part of the total local potential in ham_dual
     # and in this way return a ham that represents also the total change in Hamiltonian
 
-    (; energies, ham) = energy_hamiltonian(basis_dual, ψ, occupation; ρ, eigenvalues, εF)
+    energies, ham = energy_hamiltonian(ψ, occupation; ρ, eigenvalues, εF)
 
     # This has to be changed whenever the scfres structure changes
     (; ham, basis=basis_dual, energies, ρ, eigenvalues, occupation, εF, ψ,
diff --git a/test/chi0.jl b/test/chi0.jl
index 1bd4687b24..dcfad9111f 100644
--- a/test/chi0.jl
+++ b/test/chi0.jl
@@ -35,7 +35,7 @@ function test_chi0(testcase; symmetries=false, temperature=0, spin_polarization=
                           model_kwargs...)
         basis = PlaneWaveBasis(model; basis_kwargs...)
         ρ0    = guess_density(basis, magnetic_moments)
-        ham0  = energy_hamiltonian(basis, nothing, nothing; ρ=ρ0).ham
+        ham0  = energy_hamiltonian(BlochWaves(basis), nothing; ρ=ρ0).ham
         nbandsalg = is_εF_fixed ? FixedBands(; n_bands_converge=6) : AdaptiveBands(model)
         res = DFTK.next_density(ham0, nbandsalg; tol, eigensolver)
         scfres = (; ham=ham0, res...)
@@ -56,7 +56,7 @@ function test_chi0(testcase; symmetries=false, temperature=0, spin_polarization=
             model = model_LDA(testcase.lattice, testcase.atoms, testcase.positions;
                               model_kwargs..., extra_terms=[term_builder])
             basis = PlaneWaveBasis(model; basis_kwargs...)
-            ham = energy_hamiltonian(basis, nothing, nothing; ρ=ρ0).ham
+            ham = energy_hamiltonian(BlochWaves(basis), nothing; ρ=ρ0).ham
             res = DFTK.next_density(ham, nbandsalg; tol, eigensolver)
             res.ρout
         end
@@ -71,7 +71,7 @@ function test_chi0(testcase; symmetries=false, temperature=0, spin_polarization=
         @test norm(diff_findiff - diff_applied_χ0) < testtol
 
         # Test apply_χ0 without extra bands
-        ψ_occ, occ_occ = DFTK.select_occupied_orbitals(basis, scfres.ψ, scfres.occupation;
+        ψ_occ, occ_occ = DFTK.select_occupied_orbitals(scfres.ψ, scfres.occupation;
                                                        threshold=scfres.occupation_threshold)
         ε_occ = [scfres.eigenvalues[ik][1:size(ψk, 3)] for (ik, ψk) in enumerate(ψ_occ)]
 
diff --git a/test/compute_density.jl b/test/compute_density.jl
index 7f45abdb4e..510d660071 100644
--- a/test/compute_density.jl
+++ b/test/compute_density.jl
@@ -25,14 +25,15 @@
 
         res = diagonalize_all_kblocks(lobpcg_hyper, ham, n_bands; tol)
         occ, εF = DFTK.compute_occupation(basis, res.λ, FermiBisection())
-        ρnew = compute_density(basis, res.X, occ)
+        ρnew = compute_density(BlochWaves(basis, res.X), occ)
 
         for it = 1:n_rounds
             ham = Hamiltonian(basis; ρ=ρnew)
-            res = diagonalize_all_kblocks(lobpcg_hyper, ham, n_bands; tol, ψguess=res.X)
+            ψguess = BlochWaves(basis, res.X)
+            res = diagonalize_all_kblocks(lobpcg_hyper, ham, n_bands; tol, ψguess)
 
             occ, εF = DFTK.compute_occupation(basis, res.λ, FermiBisection())
-            ρnew = compute_density(basis, res.X, occ)
+            ρnew = compute_density(BlochWaves(basis, res.X), occ)
         end
 
         ham, res.X, res.λ, ρnew, occ
diff --git a/test/compute_jacobian_eigen.jl b/test/compute_jacobian_eigen.jl
index 281ff300cf..98cbf24d78 100644
--- a/test/compute_jacobian_eigen.jl
+++ b/test/compute_jacobian_eigen.jl
@@ -7,7 +7,8 @@ using DFTK: apply_K, apply_Ω
 using DFTK: precondprep!, FunctionPreconditioner
 using LinearMaps
 
-function eigen_ΩplusK(basis::PlaneWaveBasis{T}, ψ, occupation, numval) where {T}
+function eigen_ΩplusK(ψ::BlochWaves{T}, occupation, numval) where {T}
+    basis = ψ.basis
     n_components = basis.model.n_components
     @assert n_components == 1
     ψ_matrices = blochwaves_as_matrices(ψ)
@@ -16,8 +17,8 @@ function eigen_ΩplusK(basis::PlaneWaveBasis{T}, ψ, occupation, numval) where {
     unpack(x) = unpack_ψ(reinterpret_complex(x), size.(ψ))
 
     # compute quantites at the point which define the tangent space
-    ρ = compute_density(basis, ψ, occupation)
-    H = energy_hamiltonian(basis, ψ, occupation; ρ).ham
+    ρ = compute_density(ψ, occupation)
+    H = energy_hamiltonian(ψ, occupation; ρ).ham
 
     # preconditioner
     Pks = [PreconditionerTPA(basis, kpt) for kpt in basis.kpoints]
@@ -39,7 +40,7 @@ function eigen_ΩplusK(basis::PlaneWaveBasis{T}, ψ, occupation, numval) where {
     x0 = map(1:numval) do _
         initial = map(enumerate(basis.kpoints)) do (ik, kpt)
             n_Gk = length(G_vectors(basis, kpt))
-            randn(Complex{T}, n_components, n_Gk, size(ψ[ik], 3))
+            randn(Complex{eltype(basis)}, n_components, n_Gk, size(ψ[ik], 3))
         end
         pack(proj_tangent(initial, ψ))
     end
@@ -75,9 +76,9 @@ end
         model  = model_atomic(testcase.lattice, testcase.atoms, testcase.positions)
         basis  = PlaneWaveBasis(model; Ecut=5, kgrid=[1, 1, 1])
         scfres = self_consistent_field(basis; tol=1e-8)
-        ψ, occupation = select_occupied_orbitals(basis, scfres.ψ, scfres.occupation)
+        ψ, occupation = select_occupied_orbitals(scfres.ψ, scfres.occupation)
 
-        res = eigen_ΩplusK(basis, ψ, occupation, numval)
+        res = eigen_ΩplusK(ψ, occupation, numval)
         gap = scfres.eigenvalues[1][5] - scfres.eigenvalues[1][4]
 
         # in the linear case, the smallest eigenvalue of Ω is the gap
@@ -91,9 +92,9 @@ end
         model  = model_LDA(testcase.lattice, testcase.atoms, testcase.positions)
         basis  = PlaneWaveBasis(model; Ecut=5, kgrid=[1, 1, 1])
         scfres = self_consistent_field(basis; tol=1e-8)
-        ψ, occupation = select_occupied_orbitals(basis, scfres.ψ, scfres.occupation)
+        ψ, occupation = select_occupied_orbitals(scfres.ψ, scfres.occupation)
 
-        res = eigen_ΩplusK(basis, ψ, occupation, numval)
+        res = eigen_ΩplusK(ψ, occupation, numval)
         @test res.λ[1] > 1e-3
     end
 end
diff --git a/test/energies_guess_density.jl b/test/energies_guess_density.jl
index c471a45e33..ca2c13a945 100644
--- a/test/energies_guess_density.jl
+++ b/test/energies_guess_density.jl
@@ -17,7 +17,7 @@
     basis = PlaneWaveBasis(model; Ecut, kgrid, fft_size, kshift)
 
     ρ0 = guess_density(basis, ValenceDensityGaussian())
-    E, H = energy_hamiltonian(basis, nothing, nothing; ρ=ρ0)
+    E, H = energy_hamiltonian(BlochWaves(basis), nothing; ρ=ρ0)
 
     @test E["Hartree"] ≈  0.3527293727197568  atol=5e-8
     @test E["Xc"]      ≈ -2.3033165870558165  atol=5e-8
@@ -26,8 +26,8 @@
     res = diagonalize_all_kblocks(lobpcg_hyper, H, n_bands, tol=1e-9)
     occupation = [[2.0, 2.0, 2.0, 2.0, 0.0, 0.0, 0.0, 0.0]
                   for i = 1:length(basis.kpoints)]
-    ρ = compute_density(H.basis, res.X, occupation)
-    E, H = energy_hamiltonian(basis, res.X, occupation; ρ)
+    ρ = compute_density(BlochWaves(H.basis, res.X), occupation)
+    E, H = energy_hamiltonian(BlochWaves(basis, res.X), occupation; ρ)
 
     @test E["Kinetic"]        ≈  3.3824289861522194  atol=5e-8
     @test E["AtomicLocal"]    ≈ -2.4178712046759157  atol=5e-8
@@ -49,7 +49,7 @@
                                    PairwisePotential(V, params)],
                       )
     basis = PlaneWaveBasis(model; Ecut, kgrid, fft_size, kshift)
-    E, H = energy_hamiltonian(basis, res.X, occupation; ρ)
+    E, H = energy_hamiltonian(BlochWaves(basis, res.X), occupation; ρ)
 
     @test E["Kinetic"]             ≈  3.3824289861522194  atol=5e-8
     @test E["AtomicLocal"]         ≈ -2.4178712046759157  atol=5e-8
diff --git a/test/hamiltonian_consistency.jl b/test/hamiltonian_consistency.jl
index e6179afb59..4b56d881c0 100644
--- a/test/hamiltonian_consistency.jl
+++ b/test/hamiltonian_consistency.jl
@@ -38,27 +38,28 @@ function test_consistency_term(term; rtol=1e-4, atol=1e-8, ε=1e-6, kgrid=[1, 2,
         n_bands = div(n_electrons, 2, RoundUp)
         filled_occ = DFTK.filled_occupation(model)
 
-        ψ = map(basis.kpoints) do kpt
-            Q = Matrix(qr(randn(ComplexF64, length(G_vectors(basis, kpt)), n_bands)).Q)
-            reshape(Q, 1, size(Q)...)
-        end
+        ψ = BlochWaves(basis, map(basis.kpoints) do kpt
+            n_Gk = length(G_vectors(basis, kpt))
+            Q = Matrix(qr(randn(ComplexF64, n_Gk, n_bands)).Q)
+            reshape(Q, 1, n_Gk, n_bands)
+        end)
         occupation = [filled_occ * rand(n_bands) for _ = 1:length(basis.kpoints)]
         occ_scaling = n_electrons / sum(sum(occupation))
         occupation = [occ * occ_scaling for occ in occupation]
         ρ = with_logger(NullLogger()) do
-            compute_density(basis, ψ, occupation)
+            compute_density(ψ, occupation)
         end
-        E0, ham = energy_hamiltonian(basis, ψ, occupation; ρ)
+        E0, ham = energy_hamiltonian(ψ, occupation; ρ)
 
         @assert length(basis.terms) == 1
 
         δψ = [randn(ComplexF64, size(ψ[ik])) for ik = 1:length(basis.kpoints)]
         function compute_E(ε)
-            ψ_trial = ψ .+ ε .* δψ
+            ψ_trial = BlochWaves(basis, denest(ψ) .+ ε .* δψ)
             ρ_trial = with_logger(NullLogger()) do
-                compute_density(basis, ψ_trial, occupation)
+                compute_density(ψ_trial, occupation)
             end
-            E = energy_hamiltonian(basis, ψ_trial, occupation; ρ=ρ_trial).energies
+            E = energy_hamiltonian(ψ_trial, occupation; ρ=ρ_trial).energies
             E.total
         end
 
@@ -66,7 +67,6 @@ function test_consistency_term(term; rtol=1e-4, atol=1e-8, ε=1e-6, kgrid=[1, 2,
 
         diff_predicted = 0.0
         for ik = 1:length(basis.kpoints)
-            # T@D@ should be ok without 1
             Hψk = ham.blocks[ik]*ψ[ik][1, :, :]
             test_matrix_repr_operator(ham.blocks[ik], ψ[ik]; atol)
             δψkHψk = sum(occupation[ik][iband] * real(dot(δψ[ik][1, :, iband], Hψk[:, iband]))
diff --git a/test/hessian.jl b/test/hessian.jl
index 83881c37ae..8d095060ab 100644
--- a/test/hessian.jl
+++ b/test/hessian.jl
@@ -7,13 +7,13 @@ function setup_quantities(testcase)
     basis = PlaneWaveBasis(model; Ecut=3, kgrid=(3, 3, 3), fft_size=[9, 9, 9])
     scfres = self_consistent_field(basis; tol=10)
 
-    ψ, occupation = select_occupied_orbitals(basis, scfres.ψ, scfres.occupation)
+    ψ, occupation = select_occupied_orbitals(scfres.ψ, scfres.occupation)
 
-    ρ = compute_density(basis, ψ, occupation)
-    rhs = compute_projected_gradient(basis, ψ, occupation)
-    ϕ = rhs + ψ
+    ρ = compute_density(ψ, occupation)
+    rhs = compute_projected_gradient(ψ, occupation)
+    ϕ = rhs + denest(ψ)
 
-    (; scfres, basis, ψ, occupation, ρ, rhs, ϕ)
+    (; scfres, ψ, occupation, ρ, rhs, ϕ)
 end
 end
 
@@ -22,11 +22,11 @@ end
     =#    tags=[:dont_test_mpi] setup=[Hessian, TestCases] begin
     using DFTK: solve_ΩplusK
     using LinearAlgebra
-    (; basis, ψ, occupation, rhs, ϕ) = Hessian.setup_quantities(TestCases.silicon)
+    (; ψ, occupation, rhs, ϕ) = Hessian.setup_quantities(TestCases.silicon)
 
     @test isapprox(
-        real(dot(ϕ, solve_ΩplusK(basis, ψ, rhs, occupation).δψ)),
-        real(dot(solve_ΩplusK(basis, ψ, ϕ, occupation).δψ, rhs)),
+        real(dot(ϕ, solve_ΩplusK(ψ, rhs, occupation).δψ)),
+        real(dot(solve_ΩplusK(ψ, ϕ, occupation).δψ, rhs)),
         atol=1e-7
     )
 end
@@ -36,11 +36,12 @@ end
     using DFTK
     using DFTK: apply_Ω
     using LinearAlgebra
-    (; basis, ψ, occupation, ρ, rhs, ϕ) = Hessian.setup_quantities(TestCases.silicon)
+    (; ψ, occupation, ρ, rhs, ϕ) = Hessian.setup_quantities(TestCases.silicon)
+    ψ_arr = denest(ψ)
 
-    H = energy_hamiltonian(basis, ψ, occupation; ρ).ham
+    H = energy_hamiltonian(ψ, occupation; ρ).ham
     # Rayleigh-coefficients
-    Λ = [ψk[1, :, :]'Hψk[1, :, :] for (ψk, Hψk) in zip(ψ, H * ψ)]
+    Λ = [ψk[1, :, :]'Hψk[1, :, :] for (ψk, Hψk) in zip(ψ_arr, H * ψ_arr)]
 
     # Ω is complex-linear and so self-adjoint as a complex operator.
     @test isapprox(
@@ -54,13 +55,13 @@ end
     =#    tags=[:dont_test_mpi] setup=[Hessian, TestCases] begin
     using DFTK: apply_K
     using LinearAlgebra
-    (; basis, ψ, occupation, ρ, rhs, ϕ) = Hessian.setup_quantities(TestCases.silicon)
+    (; ψ, occupation, ρ, rhs, ϕ) = Hessian.setup_quantities(TestCases.silicon)
 
     # K involves conjugates and is only a real-linear operator,
     # hence we test using the real dot product.
     @test isapprox(
-        real(dot(ϕ, apply_K(basis, rhs, ψ, ρ, occupation))),
-        real(dot(apply_K(basis, ϕ, ψ, ρ, occupation), rhs)),
+        real(dot(ϕ, apply_K(ψ.basis, rhs, ψ, ρ, occupation))),
+        real(dot(apply_K(ψ.basis, ϕ, ψ, ρ, occupation), rhs)),
         atol=1e-7
     )
 end
@@ -76,8 +77,8 @@ end
     basis = PlaneWaveBasis(model; Ecut=3, kgrid=(3, 3, 3), fft_size=[9, 9, 9])
     scfres = self_consistent_field(basis; tol=10)
 
-    rhs = compute_projected_gradient(basis, scfres.ψ, scfres.occupation)
-    ϕ = rhs + scfres.ψ
+    rhs = compute_projected_gradient(scfres.ψ, scfres.occupation)
+    ϕ = rhs + denest(scfres.ψ)
 
     @testset "self-adjointness of solve_ΩplusK_split" begin
         @test isapprox(real(dot(ϕ, solve_ΩplusK_split(scfres, rhs).δψ)),
@@ -88,11 +89,12 @@ end
     @testset "solve_ΩplusK_split agrees with solve_ΩplusK" begin
         scfres = self_consistent_field(basis; tol=1e-10)
         δψ1 = solve_ΩplusK_split(scfres, rhs).δψ
-        δψ1 = select_occupied_orbitals(basis, δψ1, scfres.occupation).ψ
-        (; ψ, occupation) = select_occupied_orbitals(basis, scfres.ψ, scfres.occupation)
-        rhs_trunc = select_occupied_orbitals(basis, rhs, occupation).ψ
-        δψ2 = solve_ΩplusK(basis, ψ, rhs_trunc, occupation).δψ
-        @test norm(δψ1 - δψ2) < 1e-7
+        δψ1 = select_occupied_orbitals(BlochWaves(basis, δψ1), scfres.occupation).ψ
+        (; ψ, occupation) = select_occupied_orbitals(scfres.ψ, scfres.occupation)
+        rhs_trunc = select_occupied_orbitals(BlochWaves(basis, rhs), occupation).ψ
+        δψ2 = solve_ΩplusK(ψ, rhs_trunc, occupation).δψ
+        # T@D@
+        @test norm([e for e in δψ1] - δψ2) < 1e-7
     end
 end
 
@@ -109,8 +111,8 @@ end
     scfres = self_consistent_field(basis; tol=1e-12, nbandsalg)
 
     ψ = scfres.ψ
-    rhs = compute_projected_gradient(basis, scfres.ψ, scfres.occupation)
-    ϕ = rhs + ψ
+    rhs = compute_projected_gradient(scfres.ψ, scfres.occupation)
+    ϕ = rhs + denest(ψ)
 
     @testset "self-adjointness of solve_ΩplusK_split" begin
         @test isapprox(real(dot(ϕ, solve_ΩplusK_split(scfres, rhs).δψ)),
diff --git a/test/kernel.jl b/test/kernel.jl
index e153339db8..1e936e4f7f 100644
--- a/test/kernel.jl
+++ b/test/kernel.jl
@@ -28,8 +28,8 @@
             δρ = randn(size(ρ0))
             ρ_minus     = ρ0 - ε * δρ
             ρ_plus      = ρ0 + ε * δρ
-            ops_minus = DFTK.ene_ops(term, basis, nothing, nothing; ρ=ρ_minus).ops
-            ops_plus  = DFTK.ene_ops(term, basis, nothing, nothing; ρ=ρ_plus).ops
+            ops_minus = DFTK.ene_ops(term, BlochWaves(basis), nothing; ρ=ρ_minus).ops
+            ops_plus  = DFTK.ene_ops(term, BlochWaves(basis), nothing; ρ=ρ_plus).ops
             δV = zero(ρ0)
 
             for iσ = 1:model.n_spin_components
diff --git a/test/pairwise.jl b/test/pairwise.jl
index 01b98c84f0..98befa4e49 100644
--- a/test/pairwise.jl
+++ b/test/pairwise.jl
@@ -20,7 +20,7 @@
 
     model = Model(lattice, atoms, positions; terms=[term])
     basis = PlaneWaveBasis(model; Ecut=20, kgrid=(1, 1, 1))
-    forces = compute_forces(only(basis.terms), basis, nothing, nothing)
+    forces = compute_forces(only(basis.terms), BlochWaves(basis), nothing)
 
     # Compare forces to finite differences
     ε=1e-8
diff --git a/test/phonon/helpers.jl b/test/phonon/helpers.jl
index d843fcdb41..e966c5a1dd 100644
--- a/test/phonon/helpers.jl
+++ b/test/phonon/helpers.jl
@@ -29,7 +29,7 @@ function ph_compute_reference(basis_supercell)
             model_disp = Model(convert(Model{eltype(ε)}, model_supercell); lattice, positions)
             # TODO: Would be cleaner with PR #675.
             basis_disp_bs = PlaneWaveBasis(model_disp; Ecut=5)
-            forces = compute_forces(basis_disp_bs, nothing, nothing)
+            forces = compute_forces(BlochWaves(basis_disp_bs), nothing)
             reduce(hcat, forces)
         end
     end
diff --git a/test/stresses.jl b/test/stresses.jl
index 7ac7678700..d530e0e4f6 100644
--- a/test/stresses.jl
+++ b/test/stresses.jl
@@ -17,14 +17,15 @@
     function recompute_energy(lattice, symmetries, element)
         basis = make_basis(lattice, symmetries, element)
         scfres = self_consistent_field(basis; is_converged=DFTK.ScfConvergenceDensity(1e-13))
-        (; energies) = energy_hamiltonian(basis, scfres.ψ, scfres.occupation; ρ=scfres.ρ)
+        (; energies) = energy_hamiltonian(scfres.ψ, scfres.occupation; ρ=scfres.ρ)
         energies.total
     end
 
     function hellmann_feynman_energy(scfres, lattice, symmetries, element)
         basis = make_basis(lattice, symmetries, element)
-        ρ = DFTK.compute_density(basis, scfres.ψ, scfres.occupation)
-        (; energies) = energy_hamiltonian(basis, scfres.ψ, scfres.occupation; ρ)
+        ψ = BlochWaves(basis, denest(scfres.ψ))
+        ρ = DFTK.compute_density(ψ, scfres.occupation)
+        (; energies) = energy_hamiltonian(ψ, scfres.occupation; ρ)
         energies.total
     end