Skip to content

Commit

Permalink
Fixing test 3 units
Browse files Browse the repository at this point in the history
  • Loading branch information
LiamCPinchbeck committed Sep 4, 2024
1 parent 69f8b14 commit 02b3540
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/test_3_irfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from gammabayes.likelihoods.irfs import IRF_LogLikelihood

import numpy as np
from astropy import units as u

def test_irf_input_output():
energy_true_axis, longitudeaxistrue, latitudeaxistrue = np.logspace(-1,2,31), np.linspace(-5,5,21), np.linspace(-4,4,17)
energy_true_axis, longitudeaxistrue, latitudeaxistrue = np.logspace(-1,2,31)*u.TeV, np.linspace(-5,5,21)*u.deg, np.linspace(-4,4,17)*u.deg

energy_recon_axis, longitudeaxis, latitudeaxis = np.logspace(-1,2,16), np.linspace(-5,5,11), np.linspace(-4,4,9)
energy_recon_axis, longitudeaxis, latitudeaxis = np.logspace(-1,2,16)*u.TeV, np.linspace(-5,5,11)*u.deg, np.linspace(-4,4,9)*u.deg


irf_loglike = IRF_LogLikelihood(axes=[energy_recon_axis, longitudeaxis, latitudeaxis],
Expand All @@ -20,14 +21,14 @@ def test_irf_input_output():
zenith=40,
hemisphere='North',
prod_vers=5)
recon_lon, recon_lat, true_energy, true_lon, true_lat = np.asarray(0.), np.asarray(0.), np.asarray(1.0), np.asarray(0.), np.asarray(0.)
recon_lon, recon_lat, true_energy, true_lon, true_lat = np.asarray(0.)*u.deg, np.asarray(0.)*u.deg, np.asarray(1.0)*u.TeV, np.asarray(0.)*u.deg, np.asarray(0.)*u.deg

result = np.squeeze(irf_loglike.log_psf(recon_lon=recon_lon, recon_lat=recon_lat, true_energy=true_energy, true_lon=true_lon, true_lat=true_lat))
print(result)
assert np.isneginf(result)


recon_lon, recon_lat, true_energy, true_lon, true_lat = np.asarray(10.), np.asarray(10.), np.asarray(1.0), np.asarray(10.), np.asarray(10.)
recon_lon, recon_lat, true_energy, true_lon, true_lat = np.asarray(10.)*u.deg, np.asarray(10.)*u.deg, np.asarray(1.0)*u.TeV, np.asarray(10.)*u.deg, np.asarray(10.)*u.deg

result = np.squeeze(irf_loglike.log_psf(recon_lon=recon_lon, recon_lat=recon_lat, true_energy=true_energy, true_lon=true_lon, true_lat=true_lat))
print(result)
Expand Down

0 comments on commit 02b3540

Please sign in to comment.