
Commit a8e2edb

Merge pull request #54 from Small-Bodies-Node/queue-position-in-catch-route
Queue position in catch route
2 parents: 39539e2 + edb89a6

25 files changed: +297 -142 lines
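
In short: when a /catch query is queued rather than answered immediately, the API response now reports the job's position in the processing queue. A hedged sketch of a queued response, assembled from the controller diff below; every value here is illustrative, and the URL shapes are assumptions:

result = {
    "job_id": "de54e63cd7184612b9a0b5fcf2f2ecfb",
    "queued": True,
    "queue_full": False,
    "queue_position": 2,  # new field; None when the job is not queued
    "message": None,
    "message_stream": "http://example.com/stream",  # SSE progress messages
    "results": "http://example.com/caught/de54e63cd7184612b9a0b5fcf2f2ecfb",
    "version": "1.0.0",  # illustrative API version string
}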

live-tests/test_alive.py

Lines changed: 35 additions & 34 deletions
@@ -10,34 +10,32 @@

 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument('--url', default='http://127.0.0.1:5000')
-    parser.add_argument('--target', default='65P')
-    parser.add_argument('--source', '-s', default=None, action='append')
-    parser.add_argument('--no-cached', '--force', '-f',
-                        action='store_false', dest='cached')
-    parser.add_argument('-v', dest='verbose', action='store_true',
-                        help='verbose mode')
+    parser.add_argument("--url", default="http://127.0.0.1:5000")
+    parser.add_argument("--target", default="65P")
+    parser.add_argument("--source", "-s", default=None, action="append")
+    parser.add_argument(
+        "--no-cached", "--force", "-f", action="store_false", dest="cached"
+    )
+    parser.add_argument("-v", dest="verbose", action="store_true", help="verbose mode")
     return parser.parse_args()


 def query(args: argparse.Namespace) -> Tuple[str, float, Any]:
     start = time.monotonic()

-    params = {
-        'target': args.target,
-        'cached': args.cached
-    }
+    params = {"target": args.target, "cached": args.cached}
     if args.source is not None:
-        params['sources'] = args.source
+        params["sources"] = args.source

-    res = requests.get(f'{args.url}/catch', params=params)
+    res = requests.get(f"{args.url}/catch", params=params)
     data = res.json()
     if args.verbose:
         print(data)
         print()

-    if data['queued']:
-        messages = SSEClient(data['message_stream'], chunk_size=128)
+    if data["queued"]:
+        print(f"Queue position: {data['queue_position']}")
+        messages = SSEClient(data["message_stream"], chunk_size=128)
         for message in messages:
             if len(message.data) == 0:
                 continue

@@ -51,58 +49,61 @@ def query(args: argparse.Namespace) -> Tuple[str, float, Any]:
             if not isinstance(message_data, dict):
                 continue

-            if message_data['job_prefix'] != data['job_id'][:8]:
+            if message_data["job_prefix"] != data["job_id"][:8]:
                 # this mesage is not for us
                 continue

             # this message is for us, print the text
-            print(message_data['text'], file=sys.stderr)
+            print(message_data["text"], file=sys.stderr)

             # Message status may be 'success', 'error', 'running', 'queued'.
-            if message_data['status'] == 'error':
-                raise Exception(message_data['text'])
+            if message_data["status"] == "error":
+                raise Exception(message_data["text"])

-            if message_data['status'] == 'success':
+            if message_data["status"] == "success":
                 break

-    elif data['results'] is None:
-        raise Exception(data['message'])
+    elif data["results"] is None:
+        raise Exception(data["message"])

     # 'results' is the URL to the search results
-    res = requests.get(data['results'])
+    res = requests.get(data["results"])
     dt = time.monotonic() - start

     # response is JSON formatted
-    return data['results'], dt, res.json()
+    return data["results"], dt, res.json()


 def summarize(results_url: str, dt: float, data: Any) -> None:
     count_by_survey = []
-    for survey in set([row['source'] for row in data['data']]):
-        n = len([row for row in data['data']
-                 if row['source'] == survey])
-        count_by_survey.append(f' - {n} {survey}\n')
+    for survey in set([row["source"] for row in data["data"]]):
+        n = len([row for row in data["data"] if row["source"] == survey])
+        count_by_survey.append(f" - {n} {survey}\n")

-    print(f'''
+    print(
+        f"""
 Job ID: {data['job_id']}
 Results: {results_url}
 Elapsed time: {dt:.1f} s
-Count: {data['count']}''')
+Count: {data['count']}"""
+    )

     if len(count_by_survey) > 0:
-        print(''.join(count_by_survey))
+        print("".join(count_by_survey))


-if __name__ == '__main__':
+if __name__ == "__main__":
     args = parse_args()

-    print(f'''
+    print(
+        f"""
 CATCH APIs Query

 Base URL: {args.url}
 Target: {args.target}
 Cached results allowed? {args.cached}
-''')
+"""
+    )

     results_url, dt, data = query(args)
     summarize(results_url, dt, data)
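
Reviewer note: the script now surfaces data["queue_position"] before listening on the SSE stream. For reference, a minimal sketch of one task message as this script consumes it; the payload structure is inferred from the fields used above, and all values are illustrative:

import json

event_data = json.dumps({
    "job_prefix": "de54e63c",   # first 8 characters of the job ID hex
    "text": "Query running.",   # progress text, echoed to stderr
    "status": "running",        # 'queued', 'running', 'error', or 'success'
})

message_data = json.loads(event_data)
assert message_data["status"] in {"queued", "running", "error", "success"}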

live-tests/test_query_fixed.py

Lines changed: 35 additions & 9 deletions
@@ -14,10 +14,12 @@
 import numpy as np
 from catch_apis.app import app

+
 @pytest.fixture()
 def test_client() -> TestClient:
     return app.test_client()

+
 def test_point_full_search(test_client: TestClient):
     # Crab nebula
     parameters = {
@@ -31,13 +33,36 @@ def test_point_full_search(test_client: TestClient):
     assert results["message"] == ""
     assert np.isclose(results["query"]["ra"], (5 + (34 + 32 / 60) / 60) * 15)
     assert np.isclose(results["query"]["dec"], 22 + 48 / 60 / 60)
-    assert "neat_palomar_tricam" in results["query"]["sources"]
-    assert len(results["data"]) == 428
+    assert set([row["source"] for row in results["data"]]) == {
+        "atlas_haleakela",
+        "atlas_mauna_loa",
+        "catalina_bigelow",
+        "catalina_lemmon",
+        "neat_maui_geodss",
+        "ps1dr2",
+        "spacewatch",
+    }
+
+    # only count static data sets
+    sources = [row["source"] for row in results["data"]]
+    expected = {
+        "neat_maui_geodss": 1,
+        "neat_palomar_tricam": 0,
+        "ps1dr2": 73,
+        "spacewatch": 166,
+    }
+    for source, count in expected.items():
+        assert sources.count(source) == count
+
     product_ids = [obs["product_id"] for obs in results["data"]]

-    # these are verified to be the Crab:
+    # verified to be the Crab:
     assert "rings.v3.skycell.1784.059.wrp.g.55560_46188.fits" in product_ids
-    assert "urn:nasa:pds:gbo.ast.catalina.survey:data_calibrated:703_20211108_2b_n24018_01_0002.arch" in product_ids
+    assert (
+        "urn:nasa:pds:gbo.ast.catalina.survey:data_calibrated:703_20211108_2b_n24018_01_0002.arch"
+        in product_ids
+    )
+

 def test_point_date_range(test_client: TestClient):
     # Crab nebula
@@ -51,30 +76,31 @@ def test_point_date_range(test_client: TestClient):
     results = response.json()
     assert results["query"]["start_date"] is None
     assert results["query"]["stop_date"] is None
-    assert len(results["data"]) == 167
+    assert len(results["data"]) == 166

     parameters["start_date"] = "2007-01-01"
     response = test_client.get("/fixed", params=parameters)
     response.raise_for_status()
     results = response.json()
     assert results["query"]["start_date"] == "2007-01-01 00:00:00.000"
-    assert len(results["data"]) == 118
+    assert len(results["data"]) == 117

     parameters["stop_date"] = "2008-01-01"
     response = test_client.get("/fixed", params=parameters)
     response.raise_for_status()
     results = response.json()
     assert results["query"]["stop_date"] == "2008-01-01 00:00:00.000"
     assert len(results["data"]) == 24
-
+
     del parameters["start_date"]
     response = test_client.get("/fixed", params=parameters)
     response.raise_for_status()
     results = response.json()
     assert results["query"]["start_date"] is None
     assert len(results["data"]) == 73

-    # 118 + 73 - 24 = 167
+    # 117 + 73 - 24 = 166
+

 def test_areal_search(test_client: TestClient):
     # Crab nebula
@@ -107,4 +133,4 @@ def test_areal_search(test_client: TestClient):
     response = test_client.get("/fixed", params=parameters)
     response.raise_for_status()
     results = response.json()
-    assert len(results["data"]) == 73
+    assert len(results["data"]) == 73
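
The revised counts stay self-consistent under inclusion-exclusion, as the updated comment notes; a quick check with the values asserted above:

n_total = 166      # no date constraints
n_from_2007 = 117  # start_date = 2007-01-01 only
n_to_2008 = 73     # stop_date = 2008-01-01 only
n_overlap = 24     # both constraints applied

# the two one-sided ranges double-count observations in [2007-01-01, 2008-01-01)
assert n_from_2007 + n_to_2008 - n_overlap == n_total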

live-tests/test_query_moving.py

Lines changed: 6 additions & 6 deletions
@@ -26,7 +26,7 @@
 ENV.REDIS_TASK_MESSAGES = "TASK_MESSAGES_TESTING"

 from catch_apis.app import app
-from catch_apis.services.stream import messages as stream_messages
+from catch_apis.services.stream import messages_service
 from catch_apis import woRQer


@@ -39,8 +39,8 @@ def test_client() -> TestClient:
 def mock_stream_messages(timeout):
     """Patch message stream to timeout in an absolute amount of time."""
     with mock.patch(
-        "catch_apis.services.stream.messages",
-        partial(stream_messages, timeout),
+        "catch_apis.services.stream.messages_service",
+        partial(messages_service, timeout),
     ):
         yield

@@ -124,7 +124,7 @@ def _query(
     # hang until the timeout is reached, run a worker in "burst" mode, then
     # directly read messages from the message function
     woRQer.run(True)
-    messages = [message for message in stream_messages(1)]
+    messages = [message for message in messages_service(1)]
     for message in messages:
         if len(message) == 0 or not message.startswith("data:"):
             continue

@@ -160,11 +160,11 @@
 def test_equivalencies(test_client: TestClient, targets: List[str]) -> None:
     source = "neat_maui_geodss"
     catch0, caught0, queued0 = _query(test_client, targets[0], True, source=source)
-    data0 = caught0["data"]
+    data0 = sorted(caught0["data"], key=lambda row: row["product_id"])
     assert len(data0) > 0
     for target in targets[1:]:
         catch, caught, queued = _query(test_client, target, True, source=source)
-        data = caught["data"]
+        data = sorted(caught["data"], key=lambda row: row["product_id"])
         for a, b in zip(data0, data):
             for k in COMPARE_KEYS:
                 # np.isclose using rtol = 1% in case of ephemeris updates and
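
The fixture's pattern, in isolation: functools.partial pre-binds the timeout argument, and mock.patch installs the bound callable under the service's import path, so code under test calls it unchanged. A self-contained sketch; the generator below is a stand-in, not the real Redis-backed service:

from functools import partial
from unittest import mock

def messages_service(timeout, channel="TASK_MESSAGES_TESTING"):
    # stand-in for catch_apis.services.stream.messages_service
    yield f"data: listened on {channel} for {timeout} s"

with mock.patch(f"{__name__}.messages_service", partial(messages_service, 1)):
    # callers now receive timeout=1 without passing it themselves
    print(next(messages_service()))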

src/catch_apis/api/__init__.py

Lines changed: 0 additions & 4 deletions
@@ -1,4 +0,0 @@
-from .catch import catch  # noqa: F401
-from .caught import caught  # noqa: F401
-from .fixed import fixed_target_query  # noqa: F401
-

src/catch_apis/api/catch.py

Lines changed: 13 additions & 5 deletions
@@ -9,7 +9,9 @@
 from astropy.time import Time

 from ..config import allowed_sources, get_logger, QueryStatus
-from .. import services
+from ..services.target_name import parse_target_name
+from ..services.catch import catch_service
+from ..services.queue import JobsQueue
 from ..services.message import (
     Message,
     listen_for_task_messages,
@@ -32,7 +34,7 @@ def _format_date(date):
     return date if date is None else date.iso


-def catch(
+def catch_controller(
     target: str,
     sources: Optional[List[str]] = None,
     start_date: Optional[str] = None,
@@ -75,7 +77,7 @@ def catch(

     target_type: str
     sanitized_target: str
-    target_type, sanitized_target = services.parse_target_name(target)
+    target_type, sanitized_target = parse_target_name(target)
     if sanitized_target == "":
         messages.append("Invalid target: empty string")
         valid_query = False
@@ -117,14 +119,15 @@ def catch(
         "job_id": job_id.hex,
         "queued": False,
         "queue_full": False,
+        "queue_position": None,
         "message": None,
         "version": version,
     }

     Message.reset_t0()
     listen_for_task_messages(job_id)

-    status: QueryStatus = services.catch(
+    status: QueryStatus = catch_service(
         job_id,
         sanitized_target,
         sources=_sources,
@@ -142,8 +145,13 @@
     message_stream_url: str = urllib.parse.urlunsplit(
         (parsed[0], parsed[1], os.path.join(parsed[2], "stream"), "", "")
     )
-
     if status == QueryStatus.QUEUED:
+        queue = JobsQueue()
+        for job in queue.jobs:
+            if job.args[0].hex == job_id.hex:
+                result["queue_position"] = job.get_position()
+                break
+
         result["queued"] = True
         result["message_stream"] = message_stream_url
         result["results"] = results_url
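
The position lookup assumes JobsQueue wraps a python-rq queue (the wrapper itself is outside this diff; the woRQer module name suggests RQ). In plain RQ terms the equivalent is roughly the following sketch, where the queue name is an assumption:

import uuid

from redis import Redis
from rq import Queue

job_id = uuid.uuid4()  # the CATCH job ID being searched for (illustrative)

# Queue.jobs lists the currently enqueued rq.job.Job objects;
# Job.get_position() returns the job's index within its queue,
# or None if it is no longer enqueued.
queue = Queue("jobs", connection=Redis())
queue_position = None
for job in queue.jobs:
    if job.args and job.args[0].hex == job_id.hex:
        queue_position = job.get_position()
        break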

src/catch_apis/api/caught.py

Lines changed: 5 additions & 5 deletions
@@ -1,20 +1,20 @@
 import uuid
-from typing import Any, Dict, List, Tuple, Union

-from .. import services
+from ..services.caught import caught_service
+from ..services.status.job_id import job_id_service
 from .. import __version__ as version


-def caught(job_id: str) -> Union[dict, Tuple[str, int]]:
+def caught_controller(job_id: str) -> dict | tuple[str, int]:
     """Controller for returning caught data."""

     try:
         _job_id: uuid.UUID = uuid.UUID(job_id, version=4)
     except ValueError:
         return "Invalid job ID", 400

-    parameters, status = services.status.job_id(_job_id)
-    data: List[Dict[str, Any]] = services.caught(_job_id)
+    parameters, status = job_id_service(_job_id)
+    data = caught_service(_job_id)
     return {
         "parameters": parameters,
         "status": status,

src/catch_apis/api/fixed.py

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@
 from astropy.coordinates import Angle
 import astropy.units as u

-from .. import services
+from ..services.fixed import fixed_target_query_service
 from ..config import CatchApisException, get_logger, allowed_sources
 from .. import __version__ as version

@@ -60,7 +60,7 @@ def _format_date(date):
     return date if date is None else date.iso


-def fixed_target_query(
+def fixed_target_query_controller(
     ra: str,
     dec: str,
     sources: Optional[List[str]] = None,
@@ -128,7 +128,7 @@ def fixed_target_query(
     data: List[dict] = []
     if valid_query:
         try:
-            data = services.fixed_target_query(
+            data = fixed_target_query_service(
                 job_id,
                 sanitized_ra,
                 sanitized_dec,
