Skip to content

Commit

Permalink
Add typings (#900)
Browse files Browse the repository at this point in the history
  • Loading branch information
Quantum-0 authored Aug 29, 2023
1 parent 106100a commit fad541f
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 61 deletions.
4 changes: 2 additions & 2 deletions aiopg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class _PoolConnectionContextManager:

__slots__ = ("_pool", "_conn")

def __init__(self, pool: "Pool", conn: Connection):
def __init__(self, pool: "Pool", conn: Connection) -> None:
self._pool: Optional[Pool] = pool
self._conn: Optional[Connection] = conn

Expand Down Expand Up @@ -130,7 +130,7 @@ class _PoolCursorContextManager:

__slots__ = ("_pool", "_conn", "_cursor")

def __init__(self, pool: "Pool", conn: Connection, cursor: Cursor):
def __init__(self, pool: "Pool", conn: Connection, cursor: Cursor) -> None:
self._pool = pool
self._conn = conn
self._cursor = cursor
Expand Down
21 changes: 11 additions & 10 deletions aiopg/sa/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sqlalchemy.sql.ddl import DDLElement
from sqlalchemy.sql.dml import UpdateBase

from ..connection import Cursor
from ..utils import _ContextManager, _IterableContextManager
from . import exc
from .result import ResultProxy
Expand Down Expand Up @@ -43,7 +44,7 @@ class SAConnection:
"_query_compile_kwargs",
)

def __init__(self, connection, engine):
def __init__(self, connection, engine) -> None:
self._connection = connection
self._transaction = None
self._savepoint_seq = 0
Expand All @@ -52,7 +53,7 @@ def __init__(self, connection, engine):
self._cursors = weakref.WeakSet()
self._query_compile_kwargs = dict(self._QUERY_COMPILE_KWARGS)

def execute(self, query, *multiparams, **params):
def execute(self, query, *multiparams, **params) -> _IterableContextManager[ResultProxy]:
"""Executes a SQL query with optional parameters.
query - a SQL query string or any sqlalchemy expression.
Expand Down Expand Up @@ -92,18 +93,18 @@ def execute(self, query, *multiparams, **params):
coro = self._execute(query, *multiparams, **params)
return _IterableContextManager[ResultProxy](coro, _close_result_proxy)

async def _open_cursor(self):
async def _open_cursor(self) -> Cursor:
if self._connection is None:
raise exc.ResourceClosedError("This connection is closed.")
cursor = await self._connection.cursor()
self._cursors.add(cursor)
return cursor

def _close_cursor(self, cursor):
def _close_cursor(self, cursor) -> None:
self._cursors.remove(cursor)
cursor.close()

async def _execute(self, query, *multiparams, **params):
async def _execute(self, query, *multiparams, **params) -> ResultProxy:
cursor = await self._open_cursor()
dp = _distill_params(multiparams, params)
if len(dp) > 1:
Expand Down Expand Up @@ -181,7 +182,7 @@ async def scalar(self, query, *multiparams, **params):
return await res.scalar()

@property
def closed(self):
def closed(self) -> bool:
"""The readonly property that returns True if connections is closed."""
return self.connection is None or self.connection.closed

Expand Down Expand Up @@ -231,7 +232,7 @@ def begin(self, isolation_level=None, readonly=False, deferrable=False):
coro, _commit_transaction_if_active, _rollback_transaction
)

async def _begin(self, isolation_level, readonly, deferrable):
async def _begin(self, isolation_level, readonly, deferrable) -> Transaction:
if self._transaction is None:
self._transaction = RootTransaction(self)
await self._begin_impl(isolation_level, readonly, deferrable)
Expand Down Expand Up @@ -377,11 +378,11 @@ async def commit_prepared(self, xid, *, is_prepared=True):
await self._commit_impl()

@property
def in_transaction(self):
def in_transaction(self) -> bool:
"""Return True if a transaction is in progress."""
return self._transaction is not None and self._transaction.is_active

async def close(self):
async def close(self) -> None:
"""Close this SAConnection.
This results in a release of the underlying database
Expand All @@ -401,7 +402,7 @@ async def close(self):

await asyncio.shield(self._close())

async def _close(self):
async def _close(self) -> None:
if self._transaction is not None:
with contextlib.suppress(Exception):
await self._transaction.rollback()
Expand Down
36 changes: 20 additions & 16 deletions aiopg/sa/engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import annotations

import asyncio
import json

from sqlalchemy.dialects.postgresql.base import PGDialect

import aiopg

from ..connection import TIMEOUT
Expand Down Expand Up @@ -41,7 +45,7 @@ def _exec_default(self, default):
return default.arg


def get_dialect(json_serializer=json.dumps, json_deserializer=lambda x: x):
def get_dialect(json_serializer=json.dumps, json_deserializer=lambda x: x) -> PGDialect:
dialect = PGDialect_psycopg2(
json_serializer=json_serializer, json_deserializer=json_deserializer
)
Expand Down Expand Up @@ -69,7 +73,7 @@ def create_engine(
timeout=TIMEOUT,
pool_recycle=-1,
**kwargs
):
) -> _ContextManager[Engine]:
"""A coroutine for Engine creation.
Returns Engine instance with embedded connection pool.
Expand Down Expand Up @@ -98,7 +102,7 @@ async def _create_engine(
timeout=TIMEOUT,
pool_recycle=-1,
**kwargs
):
) -> Engine:

pool = await aiopg.create_pool(
dsn,
Expand All @@ -116,7 +120,7 @@ async def _create_engine(
await pool.release(conn)


async def _close_engine(engine: "Engine") -> None:
async def _close_engine(engine: Engine) -> None:
engine.close()
await engine.wait_closed()

Expand All @@ -136,19 +140,19 @@ class Engine:

__slots__ = ("_dialect", "_pool", "_dsn", "_loop")

def __init__(self, dialect, pool, dsn):
def __init__(self, dialect, pool, dsn) -> None:
self._dialect = dialect
self._pool = pool
self._dsn = dsn
self._loop = get_running_loop()

@property
def dialect(self):
def dialect(self) -> PGDialect:
"""An dialect for engine."""
return self._dialect

@property
def name(self):
def name(self) -> str:
"""A name of the dialect."""
return self._dialect.name

Expand Down Expand Up @@ -186,15 +190,15 @@ def freesize(self):
def closed(self):
return self._pool.closed

def close(self):
def close(self) -> None:
"""Close engine.
Mark all engine connections to be closed on getting back to pool.
Closed engine doesn't allow to acquire new connections.
"""
self._pool.close()

def terminate(self):
def terminate(self) -> None:
"""Terminate engine.
Terminate engine pool with instantly closing all acquired
Expand All @@ -206,12 +210,12 @@ async def wait_closed(self):
"""Wait for closing all engine's connections."""
await self._pool.wait_closed()

def acquire(self):
def acquire(self) -> _ContextManager[SAConnection]:
"""Get a connection from pool."""
coro = self._acquire()
return _ContextManager[SAConnection](coro, _close_connection)

async def _acquire(self):
async def _acquire(self) -> SAConnection:
raw = await self._pool.acquire()
return SAConnection(raw, self)

Expand Down Expand Up @@ -244,10 +248,10 @@ def __await__(self):
conn = yield from self._acquire().__await__()
return _ConnectionContextManager(conn, self._loop)

async def __aenter__(self):
async def __aenter__(self) -> Engine:
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()
await self.wait_closed()

Expand All @@ -269,13 +273,13 @@ class _ConnectionContextManager:

__slots__ = ("_conn", "_loop")

def __init__(self, conn: SAConnection, loop: asyncio.AbstractEventLoop):
def __init__(self, conn: SAConnection, loop: asyncio.AbstractEventLoop) -> None:
self._conn = conn
self._loop = loop

def __enter__(self):
def __enter__(self) -> SAConnection:
return self._conn

def __exit__(self, *args):
def __exit__(self, *args) -> None:
asyncio.ensure_future(self._conn.close(), loop=self._loop)
self._conn = None
Loading

0 comments on commit fad541f

Please sign in to comment.