Skip to content

Commit

Permalink
Merge branch 'dev' of github.com:Vaultexe/server into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmedabdou14 committed Mar 11, 2024
2 parents c23c8fb + a64def6 commit efebd0e
Show file tree
Hide file tree
Showing 11 changed files with 36 additions and 239 deletions.
8 changes: 4 additions & 4 deletions app/api/deps/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ async def get_current_user(
at_claim = AccessTokenClaim.from_encoded(token)

rt_key = cache.keys.refresh_token(at_claim.sub, req_device_id)
rt_claim = await cache.repo.get_token(
rc=rc,
rt_claim = await cache.tokens.get(
rc,
key=rt_key,
token_cls=RefreshTokenClaim,
)
Expand Down Expand Up @@ -158,8 +158,8 @@ async def get_refresh_user(
raise AuthenticationException

validator_rt_key = cache.keys.refresh_token(rt_claim.sub, req_device_id)
validator_rt = await cache.repo.get_token(
rc=rc,
validator_rt = await cache.tokens.get(
rc,
key=validator_rt_key,
token_cls=RefreshTokenClaim,
)
Expand Down
2 changes: 1 addition & 1 deletion app/api/deps/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def get_sync_redis_conn() -> redis.Redis:
return redis.Redis(connection_pool=mq_sync_redis_pool)
return redis.Redis(connection_pool=mq_sync_redis_pool, decode_responses=True)


def get_mq_low() -> rq.Queue:
Expand Down
12 changes: 6 additions & 6 deletions app/api/routes/v1/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ async def otp_login(
if not req_device_id:
raise AuthenticationException

otp_sh_claim = await cache.repo.get_token(
otp_sh_claim = await cache.tokens.get(
rc,
key=cache.keys.otp_shash_token(user.id),
token_cls=schemas.OTPSaltedHashClaim,
Expand All @@ -259,7 +259,7 @@ async def otp_login(
if otp_sh_claim.ip != ip:
raise AuthenticationException

await cache.repo.delete_token(rc, key=cache.keys.otp_shash_token(user.id))
await rc.unlink(cache.keys.otp_shash_token(user.id))

await repo.device.verify(db, id=req_device_id)
await db.commit()
Expand All @@ -284,7 +284,7 @@ async def logout(
* Deletes access & refresh token cookies
"""
key = cache.keys.refresh_token(user.id, req_device_id)
await cache.repo.delete_token(rc, key=key)
await rc.unlink(key)

res.delete_cookie(key=CookieKey.ACCESS_TOKEN)
res.delete_cookie(key=CookieKey.REFRESH_TOKEN)
Expand All @@ -310,7 +310,7 @@ async def logout_all(
devices = await repo.device.get_logged_in_devices(db, user_id=user.id)
keys = [cache.keys.refresh_token(user.id, device.id) for device in devices]

await cache.repo.delete_many_tokens(rc, keys=keys)
await rc.unlink(*keys)

res.delete_cookie(key=CookieKey.ACCESS_TOKEN)
res.delete_cookie(key=CookieKey.REFRESH_TOKEN)
Expand Down Expand Up @@ -349,7 +349,7 @@ async def grant_web_token(
is_admin=user.is_admin,
)

await cache.repo.save_token_claim(
await cache.tokens.save(
rc,
keep_ttl=is_refresh,
token_claim=rt_claim,
Expand Down Expand Up @@ -407,7 +407,7 @@ async def grant_autherization_code(
subject=user.id,
)

await cache.repo.save_token_claim(
await cache.tokens.save(
rc,
key=cache.keys.otp_shash_token(user.id),
token_claim=otp_sh_claim,
Expand Down
2 changes: 1 addition & 1 deletion app/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from . import keys
from .repo import repo
from .service import tokens
196 changes: 6 additions & 190 deletions app/cache/client.py
Original file line number Diff line number Diff line change
@@ -1,197 +1,13 @@
from collections.abc import AsyncGenerator, Mapping
from contextlib import asynccontextmanager
from typing import Any, overload

import redis.asyncio as aioredis
from redis.asyncio.client import Pipeline

from app.core.config import settings
from app.utils.coders import JsonCoder

async_redis_pool = aioredis.ConnectionPool.from_url(str(settings.REDIS_URI))


class AsyncRedisClient:
"""
Redis client wrapper.
"""

def __init__(self) -> None:
self.redis: aioredis.Redis = aioredis.Redis(connection_pool=async_redis_pool)

@overload
async def set(self, key: str, value: Any, *, ttl: int) -> bool:
...

@overload
async def set(self, key: str, value: Any, *, keepttl: bool) -> bool:
...

async def set(
self,
key: str,
value: Any,
ttl: int | None = None,
keepttl: bool | None = None,
) -> bool:
"""
Set a key-value pair in Redis.
Either ttl or keepttl must be specified but not both.
Overloads:
set(self, key: str, value: Any, *, ttl: int) -> bool
set(self, key: str, value: Any, *, keepttl: bool) -> bool
Args:
key (str): Key to set.
value (Any): Value to set.
ttl (int, optional): Expiration time in seconds. Defaults to None.
keepttl (bool, optional): Keep the TTL of the key. Defaults to False.
Returns:
bool: True if successful, False otherwise.
"""
value = self.encode(value)

if ttl is not None and keepttl is not None:
raise ValueError("Either ttl or keepttl must be specified but not both.")
elif ttl is not None:
res = await self.redis.set(key, value, ex=ttl)
elif keepttl is not None:
ttl = await self.redis.ttl(key)
if ttl == -2:
# key does not exist
return False
res = await self.redis.set(key, value, keepttl=keepttl)
else:
raise ValueError("Either ttl or keepttl must be specified but not both.")

return True if res else False

async def set_many(self, key_value_pairs: Mapping[str, Any], ttl: int) -> bool:
"""
Set multiple key-value pairs in Redis.
Returns:
True if successful, False otherwise.
"""
async with self.pipeline() as pipe:
for key, value in key_value_pairs.items():
value = self.encode(value)
pipe.set(key, value, ex=ttl)
return bool(await pipe.execute())

async def get(self, key: str) -> Any | None:
"""
Get a value from Redis.
Returns:
Value if successful, None otherwise.
"""
value = await self.redis.get(key)
return self.decode(value)

async def get_with_ttl(self, key: str) -> tuple[int, Any | None]:
"""
Get a value from Redis with its TTL.
Returns:
- (ttl, value) if successful
- (-2, None) if key does not exist
- (-1, None) if key has no expire set
"""
async with self.pipeline() as pipe:
pipe.ttl(key)
pipe.get(key)
ttl, value = await pipe.execute()
return ttl, self.decode(value)

async def get_all_startswith(
self,
key_prefix: str,
) -> dict[str, Any]:
"""
Get all key-value pairs starting with a given prefix.
Returns:
Dictionary of key-value pairs.
"""
keys = await self.redis.keys(f"{key_prefix}*")
values = await self.redis.mget(keys)
values = [self.decode(value) for value in values]
return dict(zip(keys, values, strict=True))

async def delete(self, key: str) -> int:
"""
Delete a key from Redis.
Returns:
True if successful, False otherwise.
"""
return await self.redis.unlink(key)

async def delete_many(self, keys: list[str]) -> int:
"""
Delete multiple keys from Redis.
Returns:
Number of keys deleted.
"""
return await self.redis.unlink(*keys)

async def delete_all_startswith(self, key_prefix: str) -> int:
"""
Delete all keys starting with a given prefix.
Returns:
Number of keys deleted.
"""
keys = await self.get_all_startswith(key_prefix)
return await self.redis.delete(*keys)

async def flushall(self) -> bool:
"""Delete all cache"""
return await self.redis.flushall()

async def exists(self, key: str) -> bool:
"""Check if a key exists in cache"""
return await self.redis.exists(key) == 1

async def expire(self, key: str, expire: int) -> bool:
"""
Set a key's expiration time in Redis.
Returns:
True if successful, False otherwise.
"""
return await self.redis.expire(key, expire)

async def ttl(self, key: str) -> int:
"""
Get a key's TTL in Redis.
Returns:
- ttl if key exists
- -2 if key does not exist
- -1 if key has no expire set
"""
return await self.redis.ttl(key)

@asynccontextmanager
async def pipeline(self, transactional: bool = True) -> AsyncGenerator[Pipeline, None]:
"""
Create a pipeline for Redis commands.
Returns:
Redis pipeline.
"""
yield self.redis.pipeline(transaction=transactional)

def encode(self, value: Any) -> bytes:
return JsonCoder.encode(value)

def decode(self, value: bytes | None) -> Any:
if value is None:
return None
return JsonCoder.decode(value)
class AsyncRedisClient(aioredis.Redis):
def __init__(self):
super().__init__(
connection_pool=async_redis_pool,
decode_responses=True,
)
1 change: 1 addition & 0 deletions app/cache/service/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .tokens import tokens
46 changes: 13 additions & 33 deletions app/cache/repo.py → app/cache/service/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
from app.schemas import OTPSaltedHashClaim, RefreshTokenClaim, TokenBase


class CacheRepo:
"""Cache repo"""

async def save_token_claim(
class TokensService:
async def save(
self,
rc: AsyncRedisClient,
*,
key: str,
token_claim: TokenBase,
keep_ttl: bool = False,
ttl: int | None = None,
) -> bool:
"""
Cache token claim.
Expand All @@ -22,26 +21,25 @@ async def save_token_claim(
If the keep_ttl is set to True, the ttl is not changed when the token is updated.
However if the token is not found in the cache, the ttl is set to the token's ttl.
If the hash_key is set to True, the key will be md5 hashed before being cached.
Returns:
bool: True if successful, False otherwise.
int: ttl of the token.
"""
value = token_claim.model_dump_json()

if keep_ttl:
exists = await rc.redis.set(
exists = await rc.set(
key,
token_claim.model_dump_json(),
value,
xx=True,
keepttl=True,
)
if exists:
return True

ttl = self._get_token_type_ttl(type(token_claim))

return await rc.set(key, token_claim, ttl=ttl)
ttl = ttl or self._get_token_type_ttl(type(token_claim))
return await rc.set(key, value, ex=ttl)

async def get_token[T: TokenBase](
async def get[T: TokenBase](
self,
rc: AsyncRedisClient,
*,
Expand All @@ -50,25 +48,7 @@ async def get_token[T: TokenBase](
) -> T | None:
"""Get token claim"""
token = await rc.get(key)
return token_cls.model_validate(token) if token else None

async def delete_token(
self,
rc: AsyncRedisClient,
*,
key: str,
) -> bool:
"""Delete token claim from redis"""
return bool(await rc.delete(key))

async def delete_many_tokens(
self,
rc: AsyncRedisClient,
*,
keys: list[str],
) -> int:
"""Delete many tokens from redis"""
return await rc.delete_many(keys)
return token_cls.model_validate_json(token) if token else None

def _get_token_type_ttl(self, token_cls: type[TokenBase]) -> int:
ttl = {
Expand All @@ -83,4 +63,4 @@ def _get_token_type_ttl(self, token_cls: type[TokenBase]) -> int:
return ttl


repo = CacheRepo()
tokens = TokensService()
2 changes: 1 addition & 1 deletion app/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def init_redis_connection() -> None:
try:
async with timeout(TIMEOUT):
redis = AsyncRedisClient()
await redis.redis.ping()
await redis.ping()
except Exception as e:
logger.error("--- Connection to Redis failed ---")
logger.log(logging.ERROR, e)
Expand Down
Loading

0 comments on commit efebd0e

Please sign in to comment.