Skip to content

Commit

Permalink
Merge pull request #50 from northpowered/36-bug-config-file-arg-for-c…
Browse files Browse the repository at this point in the history
…li-not-working

Refactoring and bugfixes
  • Loading branch information
northpowered authored Sep 8, 2022
2 parents 07a182f + 98b0785 commit 429e02a
Show file tree
Hide file tree
Showing 60 changed files with 946 additions and 769 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
[![codecov](https://codecov.io/gh/northpowered/fastapi-boilerplate/branch/master/graph/badge.svg?token=2E6WMLULD7)](https://codecov.io/gh/northpowered/fastapi-boilerplate)
# FastAPI boilerplate

> Version: 1.1.1
> Version: 1.1.2
Work in progress, please read [issues](https://github.com/northpowered/fastapi-boilerplate/issues)

Expand Down
19 changes: 11 additions & 8 deletions ci/init_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
dsn: str = "postgresql://test:[email protected]:5432/test?sslmode=disable"

engine = PostgresEngine(
config={
'dsn':dsn,
}
)
config={
'dsn': dsn,
}
)
username: str = 'test3'
database_name = 'test'
asyncio.run(engine.run_ddl(f"create user {username} with password '{username}'"))
asyncio.run(engine.run_ddl(f"grant all on database {database_name} to {username};"))
asyncio.run(engine.run_ddl(
f"create user {username} with password '{username}'"))
asyncio.run(engine.run_ddl(
f"grant all on database {database_name} to {username};"))
asyncio.run(engine.run_ddl(f"grant all on schema public to {username};"))
asyncio.run(engine.run_ddl(f"grant ALL ON ALL tables in schema public TO {username};"))
asyncio.run(engine.run_ddl(
f"grant ALL ON ALL tables in schema public TO {username};"))
asyncio.run(engine.run_ddl(f"alter database test owner to {username};"))
asyncio.run(engine.run_ddl(f"alter schema public owner to {username};"))
asyncio.run(engine.run_ddl(f"alter schema public owner to {username};"))
11 changes: 7 additions & 4 deletions ci/init_vault.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import requests # type: ignore
import requests # type: ignore

URL: str = "http://127.0.0.1:8200/"
TOKEN: str = "test"
HEADERS: dict = {'X-Vault-Token':TOKEN}
HEADERS: dict = {'X-Vault-Token': TOKEN}

database_mount: str = "database"
kv_mount: str = "kv_test"
Expand All @@ -11,17 +11,20 @@
db_dsn: str = f"postgresql://test:test@{db_host}/test?sslmode=disable"
db_role: str = "testrole"
kv_secret_name: str = "jwt"
def post(path: str, data: dict)->requests.Response:


def post(path: str, data: dict) -> requests.Response:
return requests.post(
url=f"{URL}{path}",
json=data,
headers=HEADERS
)


""" VAULT DATABASE INIT """

print("Creating database secret engine")
resp = post(f'v1/sys/mounts/{database_mount}',{"type":"database"})
resp = post(f'v1/sys/mounts/{database_mount}', {"type": "database"})
print(f"{resp.status_code} --- {resp.text}")

print("Creating database connection")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "fastapi-boilerplate"
version = "1.8.0"
version = "1.1.2"
description = ""
authors = ["northpowered <[email protected]>"]

Expand Down
9 changes: 4 additions & 5 deletions src/accounting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
from .roles import Role, role_router
from .groups import Group, group_router
from .rbac import (
Permission,
Policy,
M2MUserGroup,
M2MUserRole,
Permission,
Policy,
M2MUserGroup,
M2MUserRole,
rbac_user_router,
rbac_role_router,
rbac_policies_router,
rbac_group_router,
rbac_permissions_router
)
from .authentication import Sessions

2 changes: 1 addition & 1 deletion src/accounting/authentication/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .models import Sessions
from .models import Sessions
8 changes: 5 additions & 3 deletions src/accounting/authentication/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .jwt import create_access_token, get_user_by_token
from fastapi import Depends


async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
"""
Base auth endpoint with OAuth2PasswordBearer form
Expand All @@ -21,16 +22,17 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(
access_token: str
token_type: str
"""
user = await User.authenticate_user(form_data.username,form_data.password)
user = await User.authenticate_user(form_data.username, form_data.password)
access_token = create_access_token(
data={"sub": user.username}) # type: ignore
data={"sub": user.username}) # type: ignore
return {"access_token": access_token, "token_type": "bearer"}


async def get_current_user(current_user: User = Depends(get_user_by_token)):
"""
Obtaining {USER} object of authenticated user
Returns:
{USER}, see accounting.users
"""
return current_user
return current_user
27 changes: 16 additions & 11 deletions src/accounting/authentication/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")

def create_access_token(data: dict, expires_delta: timedelta | None = None)->str:

def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
"""
Creates JWT signed token from any payload, with expires time
Expand All @@ -28,10 +29,12 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None)->str
else:
expire = datetime.utcnow() + timedelta(seconds=config.Security.jwt_ttl)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=config.Security.jwt_algorithm)
encoded_jwt = jwt.encode(to_encode, SECRET_KEY,
algorithm=config.Security.jwt_algorithm)
return encoded_jwt

def decode_access_token(token: str)->dict:

def decode_access_token(token: str) -> dict:
"""
Decode string JWT token
Expand All @@ -45,13 +48,15 @@ def decode_access_token(token: str)->dict:
dict: extracted payload
"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=config.Security._available_jwt_algorithms)
except JWTError as ex:
payload = jwt.decode(
token, SECRET_KEY, algorithms=config.Security._available_jwt_algorithms)
except JWTError:
raise UnauthorizedException('Cannot decode token')
else:
return payload

async def get_user_by_token(token: str = Depends(oauth2_scheme))->User:

async def get_user_by_token(token: str = Depends(oauth2_scheme)) -> User:
"""
Returns USER data for username from token, if exists
Expand All @@ -62,16 +67,16 @@ async def get_user_by_token(token: str = Depends(oauth2_scheme))->User:
User: see accounting.users
"""
payload: dict = decode_access_token(token)
username: str = payload.get('sub',str())
return await User.get_by_username(username) # type: ignore
username: str = payload.get('sub', str())
return await User.get_by_username(username) # type: ignore


def decode_auth_header(header: str)->tuple[str,str]:
def decode_auth_header(header: str) -> tuple[str, str]:
try:
chunks: list = header.split(' ')
assert len(chunks) == 2, 'Bad header'
return(chunks[0],chunks[1])
return (chunks[0], chunks[1])
except AssertionError as ex:
raise UnauthorizedException(str(ex))
except IndexError:
raise UnauthorizedException('Wrong header payload')

44 changes: 22 additions & 22 deletions src/accounting/authentication/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@
import secrets
from loguru import logger
from typing import TypeVar, Type, Optional, cast
from piccolo.columns import Timestamp
from piccolo.columns.defaults.timestamp import TimestampOffset
from piccolo.columns.column_types import Text, Boolean, Timestamp, Timestamptz
from piccolo.columns.column_types import Text, Timestamp
from piccolo_api.session_auth.tables import SessionsBase
from asyncpg.exceptions import UniqueViolationError
from utils.exceptions import IntegrityException, ObjectNotFoundException, BaseBadRequestException
from piccolo.utils.sync import run_sync
from configuration import config
from accounting.users import T_U

T_S = TypeVar('T_S', bound='Sessions')


class Sessions(SessionsBase, tablename="sessions"):
"""
INHERITED from SessionsBase
Expand All @@ -24,8 +22,8 @@ class Sessions(SessionsBase, tablename="sessions"):
require login again, or just create a new session token.
"""

token = Text(length=100, null=False) # type: ignore
user_id = Text(null=False) # type: ignore
token = Text(length=100, null=False) # type: ignore
user_id = Text(null=False) # type: ignore
expiry_date: Timestamp | datetime.datetime = Timestamp(
default=TimestampOffset(hours=1), null=False
)
Expand All @@ -34,15 +32,15 @@ class Sessions(SessionsBase, tablename="sessions"):
)

@classmethod
async def create_session( # type: ignore
async def create_session( # type: ignore
cls: Type[T_S],
user_id: str,
expiry_date: Optional[datetime.datetime] = None,
max_expiry_date: Optional[datetime.datetime] = None,
) -> Type[T_S]:
while True:
token: str = secrets.token_urlsafe(nbytes=32)
if not await cls.exists().where(cls.token == token).run(): # type: ignore
if not await cls.exists().where(cls.token == token).run(): # type: ignore
break

session = cls(token=token, user_id=user_id)
Expand All @@ -52,16 +50,16 @@ async def create_session( # type: ignore
session.max_expiry_date = max_expiry_date

await session.save().run()
return session # type: ignore
return session # type: ignore

@classmethod
def create_session_sync( # type: ignore
def create_session_sync( # type: ignore
cls, user_id: str, expiry_date: Optional[datetime.datetime] = None
) -> Type[T_S]:
return run_sync(cls.create_session(user_id, expiry_date))

@classmethod
async def get_user_id( # type: ignore
async def get_user_id( # type: ignore
cls, token: str, increase_expiry: Optional[datetime.timedelta] = None
) -> Optional[str]:
"""
Expand All @@ -73,33 +71,35 @@ async def get_user_id( # type: ignore
happens. The `max_expiry_date` remains the same, so there's a hard
limit on how long a session can be used for.
"""
session: Type[T_S] = ( # type: ignore
await cls.objects().where(cls.token == token).first().run() # type: ignore
session: Type[T_S] = ( # type: ignore
await cls.objects().where(cls.token == token).first().run() # type: ignore
)
if not session:
return None
now = datetime.datetime.now()
if (session.expiry_date > now) and (session.max_expiry_date > now): # type: ignore
if (session.expiry_date > now) and (session.max_expiry_date > now): # type: ignore
if increase_expiry and (
cast(datetime.datetime, session.expiry_date) - now < increase_expiry # type: ignore
cast(datetime.datetime, session.expiry_date) -
now < increase_expiry # type: ignore
):
session.expiry_date = ( # type: ignore
cast(datetime.datetime, session.expiry_date) + increase_expiry # type: ignore
session.expiry_date = ( # type: ignore
cast(datetime.datetime, session.expiry_date) + \
increase_expiry # type: ignore
)
await session.save().run() # type: ignore
await session.save().run() # type: ignore

return cast(Optional[str], session.user_id) # type: ignore
return cast(Optional[str], session.user_id) # type: ignore
else:
return None

@classmethod
def get_user_id_sync(cls, token: str) -> Optional[str]: # type: ignore
def get_user_id_sync(cls, token: str) -> Optional[str]: # type: ignore
return run_sync(cls.get_user_id(token))

@classmethod
async def remove_session(cls, token: str):
await cls.delete().where(cls.token == token).run() # type: ignore
await cls.delete().where(cls.token == token).run() # type: ignore

@classmethod
def remove_session_sync(cls, token: str):
return run_sync(cls.remove_session(token))
return run_sync(cls.remove_session(token))
25 changes: 10 additions & 15 deletions src/accounting/authentication/routing.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,30 @@
from utils.api_versioning import APIRouter, APIVersion
from . import endpoints
from .jwt import oauth2_scheme
from .schemas import (
Token
)
from fastapi.security import OAuth2PasswordRequestForm
from accounting.schemas import UserRead
from fastapi import Depends
auth_router = APIRouter(
prefix="/auth",
tags=["AAA->Authentication"],
responses={
404: {"description": "URL not found"},
400: {"description": "Bad request"}
},
#version=APIVersion(1)
},
)

auth_router.add_api_route(
'/token',
endpoints.login_for_access_token,
'/token',
endpoints.login_for_access_token,
response_model=Token,
summary='Authenticate via JWT Bearer scheme',
methods=['post'],
#dependencies=[Depends(OAuth2PasswordRequestForm)]
)
summary='Authenticate via JWT Bearer scheme',
methods=['post']
)

auth_router.add_api_route(
'/me',
endpoints.get_current_user,
'/me',
endpoints.get_current_user,
response_model=UserRead,
summary='Get current user',
summary='Get current user',
methods=['get']
)
)
1 change: 1 addition & 0 deletions src/accounting/authentication/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pydantic import BaseModel


class Token(BaseModel):
"""READ model for obtaining JWT"""
access_token: str
Expand Down
Loading

0 comments on commit 429e02a

Please sign in to comment.