Skip to content

Commit 25251e7

Browse files
committed
revert to flask
1 parent 06bfa96 commit 25251e7

File tree

3 files changed

+63
-48
lines changed

3 files changed

+63
-48
lines changed

.flaskenv

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
FLASK_APP=server.app
2+
FLASK_RUN_PORT=8000
3+
FLASK_RUN_HOST=0.0.0.0

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ json_logic==0.6.3
66
sentence_transformers==3.2.1
77
uvicorn==0.32.0
88
xgboost==2.1.2
9-
model-provenance==0.1.0
9+
model-provenance==0.1.0
10+
Flask~=2.3.2

server/app.py

Lines changed: 58 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,90 @@
1-
import os
21
import sys
3-
import asyncio
42
import signal
3+
import threading
54

6-
from fastapi import FastAPI, HTTPException, Request, Depends
7-
from fastapi.responses import JSONResponse
5+
from server import config
6+
from server import data_stream
7+
8+
from flask import Flask, jsonify, request
89

9-
from server import config, data_stream
1010
from server.algos import algos
1111
from server.data_filter import operations_callback
1212

13-
app = FastAPI()
14-
15-
stream_stop_event = asyncio.Event()
13+
app = Flask(__name__)
1614

17-
async def start_data_stream():
18-
await data_stream.run(config.SERVICE_DID, operations_callback, stream_stop_event)
15+
stream_stop_event = threading.Event()
16+
stream_thread = threading.Thread(
17+
target=data_stream.run, args=(config.SERVICE_DID, operations_callback, stream_stop_event,)
18+
)
19+
stream_thread.start()
1920

20-
# Conditionally start the data stream on startup
21-
if os.getenv("ENABLE_DATA_STREAM") == "true":
22-
@app.on_event("startup")
23-
async def startup_event():
24-
asyncio.create_task(start_data_stream())
2521

26-
# Signal handler for graceful shutdown
2722
def sigint_handler(*_):
2823
print('Stopping data stream...')
2924
stream_stop_event.set()
3025
sys.exit(0)
3126

27+
3228
signal.signal(signal.SIGINT, sigint_handler)
3329

34-
@app.get("/")
35-
async def index():
36-
return "ATProto Feed Generator powered by The AT Protocol SDK for Python (https://github.com/MarshalX/atproto)."
3730

38-
@app.get("/.well-known/did.json")
39-
async def did_json():
31+
@app.route('/')
32+
def index():
33+
return 'ATProto Feed Generator powered by The AT Protocol SDK for Python (https://github.com/MarshalX/atproto).'
34+
35+
36+
@app.route('/.well-known/did.json', methods=['GET'])
37+
def did_json():
4038
if not config.SERVICE_DID.endswith(config.HOSTNAME):
41-
raise HTTPException(status_code=404, detail="Not Found")
39+
return '', 404
4240

43-
response_content = {
44-
"@context": ["https://www.w3.org/ns/did/v1"],
45-
"id": config.SERVICE_DID,
46-
"service": [
41+
return jsonify({
42+
'@context': ['https://www.w3.org/ns/did/v1'],
43+
'id': config.SERVICE_DID,
44+
'service': [
4745
{
48-
"id": "#bsky_fg",
49-
"type": "BskyFeedGenerator",
50-
"serviceEndpoint": f"https://{config.HOSTNAME}"
46+
'id': '#bsky_fg',
47+
'type': 'BskyFeedGenerator',
48+
'serviceEndpoint': f'https://{config.HOSTNAME}'
5149
}
5250
]
53-
}
54-
return JSONResponse(content=response_content)
55-
56-
@app.get("/xrpc/app.bsky.feed.describeFeedGenerator")
57-
async def describe_feed_generator():
58-
feeds = [{"uri": uri} for uri in algos.keys()]
59-
response_content = {
60-
"encoding": "application/json",
61-
"body": {
62-
"did": config.SERVICE_DID,
63-
"feeds": feeds
51+
})
52+
53+
54+
@app.route('/xrpc/app.bsky.feed.describeFeedGenerator', methods=['GET'])
55+
def describe_feed_generator():
56+
feeds = [{'uri': uri} for uri in algos.keys()]
57+
response = {
58+
'encoding': 'application/json',
59+
'body': {
60+
'did': config.SERVICE_DID,
61+
'feeds': feeds
6462
}
6563
}
66-
return JSONResponse(content=response_content)
64+
return jsonify(response)
65+
6766

68-
@app.get("/xrpc/app.bsky.feed.getFeedSkeleton")
69-
async def get_feed_skeleton(feed: str = None, cursor: str = None, limit: int = 20):
67+
@app.route('/xrpc/app.bsky.feed.getFeedSkeleton', methods=['GET'])
68+
def get_feed_skeleton():
69+
feed = request.args.get('feed', default=None, type=str)
7070
algo = algos.get(feed)
7171
if not algo:
72-
raise HTTPException(status_code=400, detail="Unsupported algorithm")
72+
return 'Unsupported algorithm', 400
73+
74+
# Example of how to check auth if giving user-specific results:
75+
"""
76+
from server.auth import AuthorizationError, validate_auth
77+
try:
78+
requester_did = validate_auth(request)
79+
except AuthorizationError:
80+
return 'Unauthorized', 401
81+
"""
7382

7483
try:
75-
body = await algo(cursor, limit)
84+
cursor = request.args.get('cursor', default=None, type=str)
85+
limit = request.args.get('limit', default=20, type=int)
86+
body = algo(cursor, limit)
7687
except ValueError:
77-
raise HTTPException(status_code=400, detail="Malformed cursor")
88+
return 'Malformed cursor', 400
7889

79-
return JSONResponse(content=body)
90+
return jsonify(body)

0 commit comments

Comments
 (0)