Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Added "enabled" as a feature to FastAPILimiter #12

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions examples/main_disabled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import aioredis
import uvicorn
from fastapi import Depends, FastAPI

from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter

app = FastAPI()


@app.on_event("startup")
async def startup():
await FastAPILimiter.init(None, enabled=False)

@app.on_event("shutdown")
async def shutdown():
await FastAPILimiter.close()


@app.get("/", dependencies=[Depends(RateLimiter(times=2, seconds=5))])
async def index():
return {"msg": "Hello World"}

@app.get(
"/multiple",
dependencies=[
Depends(RateLimiter(times=1, seconds=5)),
Depends(RateLimiter(times=2, seconds=15)),
],
)
async def multiple():
return {"msg": "Hello World"}


if __name__ == "__main__":
uvicorn.run("main:app", debug=True, reload=True)
15 changes: 12 additions & 3 deletions fastapi_limiter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class FastAPILimiter:
lua_sha: str = None
identifier: Callable = None
callback: Callable = None
enabled: bool = True
lua_script = """local key = KEYS[1]
local limit = tonumber(ARGV[1])
local expire_time = ARGV[2]
Expand All @@ -61,14 +62,22 @@ async def init(
prefix: str = "fastapi-limiter",
identifier: Callable = default_identifier,
callback: Callable = default_callback,
enabled: bool = True
):
cls.redis = redis
cls.prefix = prefix
cls.identifier = identifier
cls.callback = callback
cls.lua_sha = await redis.script_load(cls.lua_script)
cls.enabled = enabled

if enabled:
cls.lua_sha = await redis.script_load(cls.lua_script)
else:
cls.lua_sha = None

@classmethod
async def close(cls):
cls.redis.close()
await cls.redis.wait_closed()
if cls.enabled:
cls.redis.close()
await cls.redis.wait_closed()

41 changes: 21 additions & 20 deletions fastapi_limiter/depends.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,24 @@ def __init__(
self.callback = callback

async def __call__(self, request: Request, response: Response):
if not FastAPILimiter.redis:
raise Exception("You must call FastAPILimiter.init in startup event of fastapi!")
index = 0
for route in request.app.routes:
if route.path == request.scope["path"]:
for idx, dependency in enumerate(route.dependencies):
if self is dependency.dependency:
index = idx
break
# moved here because constructor run before app startup
identifier = self.identifier or FastAPILimiter.identifier
callback = self.callback or FastAPILimiter.callback
redis = FastAPILimiter.redis
rate_key = await identifier(request)
key = f"{FastAPILimiter.prefix}:{rate_key}:{index}"
pexpire = await redis.evalsha(
FastAPILimiter.lua_sha, keys=[key], args=[self.times, self.milliseconds]
)
if pexpire != 0:
return await callback(request, response, pexpire)
if FastAPILimiter.enabled:
if not FastAPILimiter.redis:
raise Exception("You must call FastAPILimiter.init in startup event of fastapi!")
index = 0
for route in request.app.routes:
if route.path == request.scope["path"]:
for idx, dependency in enumerate(route.dependencies):
if self is dependency.dependency:
index = idx
break
# moved here because constructor run before app startup
identifier = self.identifier or FastAPILimiter.identifier
callback = self.callback or FastAPILimiter.callback
redis = FastAPILimiter.redis
rate_key = await identifier(request)
key = f"{FastAPILimiter.prefix}:{rate_key}:{index}"
pexpire = await redis.evalsha(
FastAPILimiter.lua_sha, keys=[key], args=[self.times, self.milliseconds]
)
if pexpire != 0:
return await callback(request, response, pexpire)
36 changes: 36 additions & 0 deletions tests/test_depends.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from starlette.testclient import TestClient

from examples.main import app
from examples.main_disabled import app as app_disabled


def test_limiter():
Expand All @@ -19,6 +20,21 @@ def test_limiter():
response = client.get("/")
assert response.status_code == 200

def test_limiter_disabled():
# Runs the same requests as test_limiter, but with RateLimiter disabled
with TestClient(app_disabled) as client:
response = client.get("/")
assert response.status_code == 200

client.get("/")

response = client.get("/")
assert response.status_code == 200
sleep(5)

response = client.get("/")
assert response.status_code == 200


def test_limiter_multiple():
with TestClient(app) as client:
Expand All @@ -38,3 +54,23 @@ def test_limiter_multiple():

response = client.get("/multiple")
assert response.status_code == 200

def test_limiter_multiple_disabled():
# Runs the same requests as test_limiter_multiple, but with RateLimiter disabled
with TestClient(app_disabled) as client:
response = client.get("/multiple")
assert response.status_code == 200

response = client.get("/multiple")
assert response.status_code == 200
sleep(5)

response = client.get("/multiple")
assert response.status_code == 200

response = client.get("/multiple")
assert response.status_code == 200
sleep(10)

response = client.get("/multiple")
assert response.status_code == 200