diff --git a/lya_2pt/tracer.py b/lya_2pt/tracer.py index 25e12cc..e998ea6 100644 --- a/lya_2pt/tracer.py +++ b/lya_2pt/tracer.py @@ -2,7 +2,8 @@ import numpy as np from lya_2pt.constants import ABSORBER_IGM -from lya_2pt.tracer_utils import rebin, project_deltas, get_angle_list, gram_schmidt +from lya_2pt.tracer_utils import ( + rebin, project_deltas, get_angle_list, gram_schmidt, get_orthonormal_vectors_svd) class Tracer: @@ -269,18 +270,24 @@ def rebin(self, rebin_factor, dwave, absorption_line): self.logwave_term = log_lambda - np.sum(log_lambda * weights) / self.sum_weights self.term3_norm = (weights * self.logwave_term**2).sum() - def project(self, old_projection=True): + def project(self, old_projection=True, use_svd=True): """Apply projection matrix to deltas""" assert not self.is_projected, "Tracer already projected" + if old_projection: self.deltas = project_deltas(self.log_lambda, self.deltas, self.weights, self.order) + self.is_projected = True + return + + if use_svd: + basis = get_orthonormal_vectors_svd(self.log_lambda, self.weights, self.order) else: basis = gram_schmidt(self.log_lambda, self.weights, self.order) - for b in basis: - self.deltas -= b * np.dot(b * self.weights, self.deltas) + proj_vec_mat = basis * self.weights + self.deltas -= basis.T.dot(proj_vec_mat.dot(self.deltas)) - self.proj_vec_mat = (basis * self.weights).T + self.proj_vec_mat = proj_vec_mat.T self.is_projected = True diff --git a/lya_2pt/tracer_utils.py b/lya_2pt/tracer_utils.py index 36f85c5..c4c50bc 100644 --- a/lya_2pt/tracer_utils.py +++ b/lya_2pt/tracer_utils.py @@ -189,6 +189,16 @@ def get_projection_matrix(log_lambda, weights, order): return Vh.T, np.eye(weights.size) - Vh.T @ Vh +def get_orthonormal_vectors_svd(log_lambda, weights, order): + wsqrt = np.sqrt(weights) + input_vectors_matrix = np.vander(log_lambda, order + 1).T * wsqrt + Vh = np.linalg.svd(input_vectors_matrix, full_matrices=False)[2] + s = weights != 0 + Vh[~s] = 0 + Vh[s] /= wsqrt[s] + return Vh + + def gram_schmidt(log_lambda, weights, order): basis = [] for n in range(order + 1):