Skip to content

Commit

Permalink
Accept MAST URIs as input to get_cloud_uris, add jwst to supported cl…
Browse files Browse the repository at this point in the history
…oud missions
  • Loading branch information
snbianco committed Jan 28, 2025
1 parent dca9873 commit 51605cb
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 46 deletions.
21 changes: 10 additions & 11 deletions astroquery/mast/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def __init__(self, provider="AWS", profile=None, verbose=False):
import boto3
import botocore

self.supported_missions = ["mast:hst/product", "mast:tess/product", "mast:kepler", "mast:galex", "mast:ps1"]
self.supported_missions = ["mast:hst/product", "mast:tess/product", "mast:kepler", "mast:galex", "mast:ps1",

Check warning on line 55 in astroquery/mast/cloud.py

View check run for this annotation

Codecov / codecov/patch

astroquery/mast/cloud.py#L55

Added line #L55 was not covered by tests
"mast:jwst/product"]

self.boto3 = boto3
self.botocore = botocore
Expand All @@ -77,11 +78,7 @@ def is_supported(self, data_product):
response : bool
Is the product from a supported mission.
"""

for mission in self.supported_missions:
if data_product['dataURI'].lower().startswith(mission):
return True
return False
return any(data_product['dataURI'].lower().startswith(mission) for mission in self.supported_missions)

Check warning on line 81 in astroquery/mast/cloud.py

View check run for this annotation

Codecov / codecov/patch

astroquery/mast/cloud.py#L81

Added line #L81 was not covered by tests

def get_cloud_uri(self, data_product, include_bucket=True, full_url=False):
"""
Expand All @@ -92,7 +89,7 @@ def get_cloud_uri(self, data_product, include_bucket=True, full_url=False):
Parameters
----------
data_product : `~astropy.table.Row`
data_product : `~astropy.table.Row`, str
Product to be converted into cloud data uri.
include_bucket : bool
Default True. When false returns the path of the file relative to the
Expand All @@ -108,6 +105,8 @@ def get_cloud_uri(self, data_product, include_bucket=True, full_url=False):
Cloud URI generated from the data product. If the product cannot be
found in the cloud, None is returned.
"""
# If data_product is a string, convert to a list
data_product = [data_product] if isinstance(data_product, str) else data_product

Check warning on line 109 in astroquery/mast/cloud.py

View check run for this annotation

Codecov / codecov/patch

astroquery/mast/cloud.py#L109

Added line #L109 was not covered by tests

uri_list = self.get_cloud_uri_list(data_product, include_bucket=include_bucket, full_url=full_url)

Expand All @@ -124,8 +123,8 @@ def get_cloud_uri_list(self, data_products, include_bucket=True, full_url=False)
Parameters
----------
data_products : `~astropy.table.Table`
Table containing products to be converted into cloud data uris.
data_products : `~astropy.table.Table`, list
Table containing products or list of MAST uris to be converted into cloud data uris.
include_bucket : bool
Default True. When false returns the path of the file relative to the
top level cloud storage location.
Expand All @@ -141,8 +140,8 @@ def get_cloud_uri_list(self, data_products, include_bucket=True, full_url=False)
if data_products includes products not found in the cloud.
"""
s3_client = self.boto3.client('s3', config=self.config)

paths = utils.mast_relative_path(data_products["dataURI"])
data_uris = data_products if isinstance(data_products, list) else data_products['dataURI']
paths = utils.mast_relative_path(data_uris)

Check warning on line 144 in astroquery/mast/cloud.py

View check run for this annotation

Codecov / codecov/patch

astroquery/mast/cloud.py#L143-L144

Added lines #L143 - L144 were not covered by tests
if isinstance(paths, str): # Handle the case where only one product was requested
paths = [paths]

Expand Down
9 changes: 5 additions & 4 deletions astroquery/mast/missions.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,15 @@ def _parse_result(self, response, *, verbose=False): # Used by the async_to_syn

if self.service == self._search:
results = self._service_api_connection._parse_result(response, verbose, data_key='results')

# Warn if maximum results are returned
if len(results) >= self.limit:
warnings.warn("Maximum results returned, may not include all sources within radius.",
MaxResultsWarning)
elif self.service == self._list_products:
# Results from post_list_products endpoint need to be handled differently
results = Table(response.json()['products'])

if len(results) >= self.limit:
warnings.warn("Maximum results returned, may not include all sources within radius.",
MaxResultsWarning)

return results

def _validate_criteria(self, **criteria):
Expand Down
23 changes: 15 additions & 8 deletions astroquery/mast/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,9 +854,9 @@ def get_cloud_uris(self, data_products=None, *, include_bucket=True, full_url=Fa
Parameters
----------
data_products : `~astropy.table.Table`
Table containing products to be converted into cloud data uris. If provided, this will supercede
page_size, page, or any keyword arguments passed in as criteria.
data_products : `~astropy.table.Table`, list
Table containing products or list of MAST uris to be converted into cloud data uris.
If provided, this will supercede page_size, page, or any keyword arguments passed in as criteria.
include_bucket : bool
Default True. When False, returns the path of the file relative to the
top level cloud storage location.
Expand Down Expand Up @@ -920,16 +920,23 @@ def get_cloud_uris(self, data_products=None, *, include_bucket=True, full_url=Fa
# Return list of associated data products
data_products = self.get_product_list(obs)

# Filter product list
data_products = self.filter_products(data_products, mrp_only=mrp_only, extension=extension, **filter_products)
if isinstance(data_products, Table):

Check warning on line 923 in astroquery/mast/observations.py

View check run for this annotation

Codecov / codecov/patch

astroquery/mast/observations.py#L923

Added line #L923 was not covered by tests
# Filter product list
data_products = self.filter_products(data_products, mrp_only=mrp_only, extension=extension,

Check warning on line 925 in astroquery/mast/observations.py

View check run for this annotation

Codecov / codecov/patch

astroquery/mast/observations.py#L925

Added line #L925 was not covered by tests
**filter_products)
else: # data_products is a list of URIs
# Warn if trying to supply filters
if filter_products or extension or mrp_only:
warnings.warn('Filtering is not supported when providing a list of MAST URIs. '

Check warning on line 930 in astroquery/mast/observations.py

View check run for this annotation

Codecov / codecov/patch

astroquery/mast/observations.py#L929-L930

Added lines #L929 - L930 were not covered by tests
'To apply filters, please provide query criteria or a table of data products '
'as returned by `Observations.get_product_list`', InputWarning)

if not len(data_products):
warnings.warn("No matching products to fetch associated cloud URIs.", NoResultsWarning)
warnings.warn('No matching products to fetch associated cloud URIs.', NoResultsWarning)

Check warning on line 935 in astroquery/mast/observations.py

View check run for this annotation

Codecov / codecov/patch

astroquery/mast/observations.py#L935

Added line #L935 was not covered by tests
return

# Remove duplicate products
data_products = utils.remove_duplicate_products(data_products, 'dataURI')

return self._cloud_connection.get_cloud_uri_list(data_products, include_bucket, full_url)

def get_cloud_uri(self, data_product, *, include_bucket=True, full_url=False):
Expand All @@ -941,7 +948,7 @@ def get_cloud_uri(self, data_product, *, include_bucket=True, full_url=False):
Parameters
----------
data_product : `~astropy.table.Row`
data_product : `~astropy.table.Row`, str
Product to be converted into cloud data uri.
include_bucket : bool
Default True. When false returns the path of the file relative to the
Expand Down
60 changes: 42 additions & 18 deletions astroquery/mast/tests/test_mast_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,20 +692,6 @@ def test_observations_download_products_no_duplicates(self, tmp_path, caplog, ms
with caplog.at_level("INFO", logger="astroquery"):
assert "products were duplicates" in caplog.text

def test_observations_get_cloud_uris_no_duplicates(self, msa_product_table):

# Get a product list with 6 duplicate JWST MSA config files
products = msa_product_table

assert len(products) == 6

# enable access to public AWS S3 bucket
Observations.enable_cloud_dataset(provider='AWS')

# Check that only one URI is returned
uris = Observations.get_cloud_uris(products)
assert len(uris) == 1

def test_observations_download_file(self, tmp_path):

def check_result(result, path):
Expand Down Expand Up @@ -776,7 +762,7 @@ def test_observations_download_file_escaped(self, tmp_path):
"s3://stpubdata/panstarrs/ps1/public/rings.v3.skycell/1334/061/"
"rings.v3.skycell.1334.061.stk.r.unconv.exp.fits")
])
def test_get_cloud_uri(self, test_data_uri, expected_cloud_uri):
def test_observations_get_cloud_uri(self, test_data_uri, expected_cloud_uri):
pytest.importorskip("boto3")
# get a product list
product = Table()
Expand All @@ -790,13 +776,17 @@ def test_get_cloud_uri(self, test_data_uri, expected_cloud_uri):
assert len(uri) > 0, f'Product for dataURI {test_data_uri} was not found in the cloud.'
assert uri == expected_cloud_uri, f'Cloud URI does not match expected. ({uri} != {expected_cloud_uri})'

# pass the URI as a string
uri = Observations.get_cloud_uri(test_data_uri)
assert uri == expected_cloud_uri, f'Cloud URI does not match expected. ({uri} != {expected_cloud_uri})'

@pytest.mark.parametrize("test_obs_id", ["25568122", "31411", "107604081"])
def test_get_cloud_uris(self, test_obs_id):
def test_observations_get_cloud_uris(self, test_obs_id):
pytest.importorskip("boto3")

# get a product list
index = 24 if test_obs_id == '25568122' else 0
products = Observations.get_product_list(test_obs_id)[index:]
products = Observations.get_product_list(test_obs_id)[index:index + 2]

assert len(products) > 0, (f'No products found for OBSID {test_obs_id}. '
'Unable to move forward with getting URIs from the cloud.')
Expand All @@ -814,7 +804,28 @@ def test_get_cloud_uris(self, test_obs_id):
Observations.get_cloud_uris(products,
extension='png')

def test_get_cloud_uris_query(self):
def test_observations_get_cloud_uris_list_input(self):
uri_list = ['mast:HST/product/u24r0102t_c1f.fits',
'mast:PS1/product/rings.v3.skycell.1334.061.stk.r.unconv.exp.fits']
expected = ['s3://stpubdata/hst/public/u24r/u24r0102t/u24r0102t_c1f.fits',
's3://stpubdata/panstarrs/ps1/public/rings.v3.skycell/1334/061/rings.v3.skycell.1334.'
'061.stk.r.unconv.exp.fits']

# list of URI strings as input
uris = Observations.get_cloud_uris(uri_list)
assert len(uris) > 0, f'Products for URI list {uri_list} were not found in the cloud.'
assert uris == expected

# check for warning if filters are provided with list input
with pytest.warns(InputWarning, match='Filtering is not supported'):
Observations.get_cloud_uris(uri_list,
extension='png')

# check for warning if one of the URIs is not found
with pytest.warns(NoResultsWarning, match='Failed to retrieve MAST relative path'):
Observations.get_cloud_uris(['mast:HST/product/does_not_exist.fits'])

def test_observations_get_cloud_uris_query(self):
pytest.importorskip("boto3")

# enable access to public AWS S3 bucket
Expand All @@ -839,6 +850,19 @@ def test_get_cloud_uris_query(self):
with pytest.warns(NoResultsWarning):
Observations.get_cloud_uris(target_name=234295611)

def test_observations_get_cloud_uris_no_duplicates(self, msa_product_table):
# Get a product list with 6 duplicate JWST MSA config files
products = msa_product_table

assert len(products) == 6

# enable access to public AWS S3 bucket
Observations.enable_cloud_dataset(provider='AWS')

# Check that only one URI is returned
uris = Observations.get_cloud_uris(products)
assert len(uris) == 1

######################
# CatalogClass tests #
######################
Expand Down
26 changes: 21 additions & 5 deletions astroquery/mast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Miscellaneous functions used throughout the MAST module.
"""

import warnings
import numpy as np

import requests
Expand All @@ -14,11 +15,11 @@
from urllib import parse

import astropy.coordinates as coord
from astropy.table import unique
from astropy.table import unique, Table

from .. import log
from ..version import version
from ..exceptions import ResolverError, InvalidQueryError
from ..exceptions import NoResultsWarning, ResolverError, InvalidQueryError
from ..utils import commons

from . import conf
Expand Down Expand Up @@ -192,6 +193,9 @@ def mast_relative_path(mast_uri):
# ("uri", "/path/to/product")
# so we index for path (index=1)
path = json_response.get(uri[1])["path"]
if path is None:
warnings.warn(f"Failed to retrieve MAST relative path for {uri[1]}. Skipping...", NoResultsWarning)
continue

Check warning on line 198 in astroquery/mast/utils.py

View check run for this annotation

Codecov / codecov/patch

astroquery/mast/utils.py#L196-L198

Added lines #L196 - L198 were not covered by tests
if 'galex' in path:
path = path.lstrip("/mast/")
elif '/ps1/' in path:
Expand All @@ -218,19 +222,31 @@ def _split_list_into_chunks(input_list, chunk_size):
def remove_duplicate_products(data_products, uri_key):
"""
Removes duplicate data products that have the same data URI.
Parameters
----------
data_products : `~astropy.table.Table`
Table containing products to be checked for duplicates.
data_products : `~astropy.table.Table`, list
Table containing products or list of URIs to be checked for duplicates.
uri_key : str
Column name representing the URI of a product.
Returns
-------
unique_products : `~astropy.table.Table`
Table containing products with unique dataURIs.
"""
# Get unique products based on input type
if isinstance(data_products, Table):
unique_products = unique(data_products, keys=uri_key)
else: # data_products is a list
seen = set()
unique_products = []
for uri in data_products:
if uri not in seen:
seen.add(uri)
unique_products.append(uri)

Check warning on line 247 in astroquery/mast/utils.py

View check run for this annotation

Codecov / codecov/patch

astroquery/mast/utils.py#L242-L247

Added lines #L242 - L247 were not covered by tests

number = len(data_products)
unique_products = unique(data_products, keys=uri_key)
number_unique = len(unique_products)
if number_unique < number:
log.info(f"{number - number_unique} of {number} products were duplicates. "
Expand Down

0 comments on commit 51605cb

Please sign in to comment.