diff --git a/main.py b/main.py index 351aaf5..af74cce 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,6 @@ import logging +from seatable_thumbnail import DBSession from seatable_thumbnail.serializers import ThumbnailSerializer from seatable_thumbnail.permissions import ThumbnailPermission from seatable_thumbnail.thumbnail import Thumbnail @@ -34,12 +35,15 @@ async def __call__(self, scope, receive, send): # ===== thumbnail ===== elif 'thumbnail/' == request.url[:10]: + db_session = DBSession() + # serializer try: - serializer = ThumbnailSerializer(request) + serializer = ThumbnailSerializer(db_session, request) thumbnail_info = serializer.thumbnail_info except Exception as e: logger.exception(e) + db_session.close() response_start, response_body = gen_error_response( 400, 'Bad request.') await send(response_start) @@ -48,8 +52,9 @@ async def __call__(self, scope, receive, send): # permission try: - permission = ThumbnailPermission(**thumbnail_info) + permission = ThumbnailPermission(db_session, **thumbnail_info) if not permission.check(): + db_session.close() response_start, response_body = gen_error_response( 403, 'Forbidden.') await send(response_start) @@ -57,12 +62,15 @@ async def __call__(self, scope, receive, send): return except Exception as e: logger.exception(e) + db_session.close() response_start, response_body = gen_error_response( 500, 'Internal server error.') await send(response_start) await send(response_body) return + db_session.close() + # cache try: etag = thumbnail_info.get('etag') diff --git a/seatable_thumbnail/__init__.py b/seatable_thumbnail/__init__.py index 0986957..c8b546f 100644 --- a/seatable_thumbnail/__init__.py +++ b/seatable_thumbnail/__init__.py @@ -12,4 +12,3 @@ engine = create_engine(db_url, **db_kwargs) Base = declarative_base() DBSession = sessionmaker(bind=engine) -db_session = DBSession() diff --git a/seatable_thumbnail/permissions.py b/seatable_thumbnail/permissions.py index cdc1023..78cad26 100644 --- a/seatable_thumbnail/permissions.py +++ b/seatable_thumbnail/permissions.py @@ -1,5 +1,4 @@ from seaserv import ccnet_api -from seatable_thumbnail import db_session from seatable_thumbnail.models import DTables, DTableShare, \ DTableGroupShare, DTableViewUserShare, DTableViewGroupShare, \ DTableExternalLinks @@ -7,9 +6,10 @@ class ThumbnailPermission(object): - def __init__(self, **info): + def __init__(self, db_session, **info): + self.db_session = db_session self.__dict__.update(info) - self.dtable = db_session.query( + self.dtable = self.db_session.query( DTables).filter_by(uuid=self.dtable_uuid).first() def check(self): @@ -54,7 +54,7 @@ def check_dtable_permission(self): return PERMISSION_READ_WRITE if dtable: # check user's all permissions from `share`, `group-share` and checkout higher one - dtable_share = db_session.query( + dtable_share = self.db_session.query( DTableShare).filter_by(dtable_id=dtable.id, to_user=username).first() if dtable_share and dtable_share.permission == PERMISSION_READ_WRITE: return dtable_share.permission @@ -65,7 +65,7 @@ def check_dtable_permission(self): else: groups = ccnet_api.get_groups(username, return_ancestors=True) group_ids = [group.id for group in groups] - group_permissions = db_session.query( + group_permissions = self.db_session.query( DTableGroupShare.permission).filter(DTableGroupShare.dtable_id == dtable.id, DTableGroupShare.group_id.in_(group_ids)).all() for group_permission in group_permissions: @@ -91,7 +91,7 @@ def get_user_view_share_permission(self): username = self.username dtable = self.dtable - view_share = db_session.query( + view_share = self.db_session.query( DTableViewUserShare).filter_by(dtable_id=dtable.id, to_user=username).order_by(DTableViewUserShare.permission.desc()).first() if not view_share: return '' @@ -103,7 +103,7 @@ def get_group_view_share_permission(self): username = self.username dtable = self.dtable - view_shares = db_session.query( + view_shares = self.db_session.query( DTableViewGroupShare).filter_by(dtable_id=dtable.id).order_by(DTableViewGroupShare.permission.desc()).all() target_view_share = None diff --git a/seatable_thumbnail/serializers.py b/seatable_thumbnail/serializers.py index d57a1eb..60d6b89 100644 --- a/seatable_thumbnail/serializers.py +++ b/seatable_thumbnail/serializers.py @@ -5,7 +5,6 @@ from email.utils import formatdate from seaserv import seafile_api -from seatable_thumbnail import db_session import seatable_thumbnail.settings as settings from seatable_thumbnail.constants import FILE_EXT_TYPE_MAP, \ IMAGE, PSD, VIDEO, XMIND @@ -13,14 +12,14 @@ class ThumbnailSerializer(object): - def __init__(self, request): + def __init__(self, db_session, request): + self.db_session = db_session self.request = request self.check() self.gen_thumbnail_info() def check(self): self.params_check() - db_session.commit() # clear db session cache self.session_check() self.resource_check() self.gen_thumbnail_info() @@ -40,7 +39,7 @@ def parse_django_session(self, session_data): def session_check(self): session_key = self.request.cookies[settings.SESSION_KEY] - django_session = db_session.query( + django_session = self.db_session.query( DjangoSession).filter_by(session_key=session_key).first() self.session_data = self.parse_django_session(django_session.session_data) @@ -102,7 +101,7 @@ def resource_check(self): file_path = self.params['file_path'] size = self.params['size'] - workspace = db_session.query( + workspace = self.db_session.query( Workspaces).filter_by(id=workspace_id).first() repo_id = workspace.repo_id workspace_owner = workspace.owner