Skip to content

Commit

Permalink
refactor(client): move connect / disconnect to base clients
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed Jan 21, 2024
1 parent d27e436 commit 01e24f6
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 203 deletions.
144 changes: 143 additions & 1 deletion src/prisma/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
from typing_extensions import Self

from ._types import Datasource, HttpConfig, TransactionId, DatasourceOverride
from .engine import BaseAbstractEngine, SyncAbstractEngine, AsyncAbstractEngine
from .engine import (
SyncQueryEngine,
AsyncQueryEngine,
BaseAbstractEngine,
SyncAbstractEngine,
AsyncAbstractEngine,
)
from .errors import ClientNotConnectedError, ClientNotRegisteredError
from ._compat import removeprefix
from ._registry import get_client
Expand Down Expand Up @@ -215,10 +221,146 @@ def _make_sqlite_url(self, url: str, *, relative_to: Path | None = None) -> str:

return f'file:{relative_to.joinpath(url_path).resolve()}'

def _prepare_connect_args(
self,
*,
timeout: int | timedelta | UseClientDefault = USE_CLIENT_DEFAULT,
) -> tuple[timedelta, list[DatasourceOverride] | None]:
"""Returns (timeout, datasources) to be passed to `AbstractEngine.connect()`"""
if isinstance(timeout, UseClientDefault):
timeout = self._connect_timeout

if isinstance(timeout, int):
message = (
'Passing an int as `timeout` argument is deprecated '
'and will be removed in the next major release. '
'Use a `datetime.timedelta` instance instead.'
)
warnings.warn(message, DeprecationWarning, stacklevel=2)
timeout = timedelta(seconds=timeout)

datasources: list[DatasourceOverride] | None = None
if self._datasource is not None:
ds = self._datasource.copy()
ds.setdefault('name', self._default_datasource['name'])
datasources = [ds]
elif self._active_provider == 'sqlite':
# Override the default SQLite path to protect against
# https://github.com/RobertCraigie/prisma-client-py/issues/409
datasources = [self._make_sqlite_datasource()]

return timeout, datasources


class SyncBasePrisma(BasePrisma[SyncAbstractEngine]):
__slots__ = ()

def _create_engine(self, dml_path: Path | None = None) -> SyncAbstractEngine:
if self._engine_type == EngineType.binary:
return SyncQueryEngine(
dml_path=dml_path or self._packaged_schema_path,
log_queries=self._log_queries,
http_config=self._http_config,
)

raise NotImplementedError(f'Unsupported engine type: {self._engine_type}')

@property
def _engine_class(self) -> type[SyncAbstractEngine]:
if self._engine_type == EngineType.binary:
return SyncQueryEngine

raise RuntimeError(f'Unhandled engine type: {self._engine_type}')

def connect(
self,
timeout: int | timedelta | UseClientDefault = USE_CLIENT_DEFAULT,
) -> None:
"""Connect to the Prisma query engine.
It is required to call this before accessing data.
"""
if self._internal_engine is None:
self._internal_engine = self._create_engine(dml_path=self._packaged_schema_path)

timeout, datasources = self._prepare_connect_args(timeout=timeout)

self._internal_engine.connect(
timeout=timeout,
datasources=datasources,
)

def disconnect(self, timeout: float | timedelta | None = None) -> None:
"""Disconnect the Prisma query engine."""
if self._internal_engine is not None:
engine = self._internal_engine
self._internal_engine = None

if isinstance(timeout, (int, float)):
message = (
'Passing a number as `timeout` argument is deprecated '
'and will be removed in the next major release. '
'Use a `datetime.timedelta` instead.'
)
warnings.warn(message, DeprecationWarning, stacklevel=2)
timeout = timedelta(seconds=timeout)

engine.close(timeout=timeout)
engine.stop(timeout=timeout)


class AsyncBasePrisma(BasePrisma[AsyncAbstractEngine]):
__slots__ = ()

def _create_engine(self, dml_path: Path | None = None) -> AsyncAbstractEngine:
if self._engine_type == EngineType.binary:
return AsyncQueryEngine(
dml_path=dml_path or self._packaged_schema_path,
log_queries=self._log_queries,
http_config=self._http_config,
)

raise NotImplementedError(f'Unsupported engine type: {self._engine_type}')

@property
def _engine_class(self) -> type[AsyncAbstractEngine]:
if self._engine_type == EngineType.binary:
return AsyncQueryEngine

raise RuntimeError(f'Unhandled engine type: {self._engine_type}')

async def connect(
self,
timeout: int | timedelta | UseClientDefault = USE_CLIENT_DEFAULT,
) -> None:
"""Connect to the Prisma query engine.
It is required to call this before accessing data.
"""
if self._internal_engine is None:
self._internal_engine = self._create_engine(dml_path=self._packaged_schema_path)

timeout, datasources = self._prepare_connect_args(timeout=timeout)

await self._internal_engine.connect(
timeout=timeout,
datasources=datasources,
)

async def disconnect(self, timeout: float | timedelta | None = None) -> None:
"""Disconnect the Prisma query engine."""
if self._internal_engine is not None:
engine = self._internal_engine
self._internal_engine = None

if isinstance(timeout, (int, float)):
message = (
'Passing a number as `timeout` argument is deprecated '
'and will be removed in the next major release. '
'Use a `datetime.timedelta` instead.'
)
warnings.warn(message, DeprecationWarning, stacklevel=2)
timeout = timedelta(seconds=timeout)

await engine.aclose(timeout=timeout)
engine.stop(timeout=timeout)
6 changes: 6 additions & 0 deletions src/prisma/engine/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ def connect(
datasources: list[DatasourceOverride] | None = None,
) -> None:
log.debug('Connecting to query engine')
if datasources:
log.debug('Datasources: %s', datasources)

if self.process is not None:
raise errors.AlreadyConnectedError('Already connected to the query engine')

Expand Down Expand Up @@ -333,6 +336,9 @@ async def connect(
datasources: list[DatasourceOverride] | None = None,
) -> None:
log.debug('Connecting to query engine')
if datasources:
log.debug('Datasources: %s', datasources)

if self.process is not None:
raise errors.AlreadyConnectedError('Already connected to the query engine')

Expand Down
74 changes: 0 additions & 74 deletions src/prisma/generator/templates/client.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ from ._base_client import BasePrisma, UseClientDefault, USE_CLIENT_DEFAULT
from .types import DatasourceOverride, HttpConfig, MetricsFormat
from ._types import BaseModelT, PrismaMethod, TransactionId
from .bases import _PrismaModel
from .engine import AbstractEngine, QueryEngine
from .builder import QueryBuilder, dumps
from .generator.models import EngineType, OptionalValueFromEnvVar, BinaryPaths
from ._compat import removeprefix, model_parse
Expand Down Expand Up @@ -126,66 +125,6 @@ class Prisma({% if is_async %}AsyncBasePrisma{% else %}SyncBasePrisma{% endif %}
self.disconnect()
{% endif %}

{{ maybe_async_def }}connect(
self,
timeout: Union[int, timedelta, UseClientDefault] = USE_CLIENT_DEFAULT,
) -> None:
"""Connect to the Prisma query engine.

It is required to call this before accessing data.
"""
if isinstance(timeout, UseClientDefault):
timeout = self._connect_timeout

if isinstance(timeout, int):
message = (
'Passing an int as `timeout` argument is deprecated '
'and will be removed in the next major release. '
'Use a `datetime.timedelta` instance instead.'
)
warnings.warn(message, DeprecationWarning, stacklevel=2)
timeout = timedelta(seconds=timeout)

if self._internal_engine is None:
self._internal_engine = self._create_engine(dml_path=PACKAGED_SCHEMA_PATH)

datasources: Optional[List[types.DatasourceOverride]] = None
if self._datasource is not None:
ds = self._datasource.copy()
ds.setdefault('name', '{{ datasources[0].name }}')
datasources = [ds]
{% if active_provider == 'sqlite' %}
else:
# Override the default SQLite path to protect against
# https://github.com/RobertCraigie/prisma-client-py/issues/409
datasources = [self._make_sqlite_datasource()]
{% endif %}

{{ maybe_await }}self._internal_engine.connect(
timeout=timeout,
datasources=datasources,
)

{{ maybe_async_def }}disconnect(self, timeout: Union[float, timedelta, None] = None) -> None:
"""Disconnect the Prisma query engine."""
if self._internal_engine is not None:
engine = self._internal_engine
self._internal_engine = None
if isinstance(timeout, (int, float)):
message = (
'Passing a number as `timeout` argument is deprecated '
'and will be removed in the next major release. '
'Use a `datetime.timedelta` instead.'
)
warnings.warn(message, DeprecationWarning, stacklevel=2)
timeout = timedelta(seconds=timeout)
{% if is_async %}
await engine.aclose(timeout=timeout)
{% else %}
engine.close(timeout=timeout)
{% endif %}
engine.stop(timeout=timeout)

{% if active_provider != 'mongodb' %}
{{ maybe_async_def }}execute_raw(self, query: LiteralString, *args: Any) -> int:
resp = {{ maybe_await }}self._execute(
Expand Down Expand Up @@ -372,19 +311,6 @@ class Prisma({% if is_async %}AsyncBasePrisma{% else %}SyncBasePrisma{% endif %}
)
return {{ maybe_await }}self._engine.query(builder.build(), tx_id=self._tx_id)

def _create_engine(self, dml_path: Path = PACKAGED_SCHEMA_PATH) -> AbstractEngine:
if ENGINE_TYPE == EngineType.binary:
return QueryEngine(dml_path=dml_path, log_queries=self._log_queries, http_config=self._http_config)

raise NotImplementedError(f'Unsupported engine type: {ENGINE_TYPE}')

@property
def _engine_class(self) -> Type[AbstractEngine]:
if ENGINE_TYPE == EngineType.binary:
return QueryEngine
else: # pragma: no cover
raise RuntimeError(f'Unhandled engine type: {ENGINE_TYPE}')


class TransactionManager:
"""Context manager for wrapping a Prisma instance within a transaction.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ from ._base_client import BasePrisma, UseClientDefault, USE_CLIENT_DEFAULT
from .types import DatasourceOverride, HttpConfig, MetricsFormat
from ._types import BaseModelT, PrismaMethod, TransactionId
from .bases import _PrismaModel
from .engine import AbstractEngine, QueryEngine
from .builder import QueryBuilder, dumps
from .generator.models import EngineType, OptionalValueFromEnvVar, BinaryPaths
from ._compat import removeprefix, model_parse
Expand Down Expand Up @@ -176,56 +175,6 @@ class Prisma(AsyncBasePrisma):
if self.is_connected():
await self.disconnect()

async def connect(
self,
timeout: Union[int, timedelta, UseClientDefault] = USE_CLIENT_DEFAULT,
) -> None:
"""Connect to the Prisma query engine.

It is required to call this before accessing data.
"""
if isinstance(timeout, UseClientDefault):
timeout = self._connect_timeout

if isinstance(timeout, int):
message = (
'Passing an int as `timeout` argument is deprecated '
'and will be removed in the next major release. '
'Use a `datetime.timedelta` instance instead.'
)
warnings.warn(message, DeprecationWarning, stacklevel=2)
timeout = timedelta(seconds=timeout)

if self._internal_engine is None:
self._internal_engine = self._create_engine(dml_path=PACKAGED_SCHEMA_PATH)

datasources: Optional[List[types.DatasourceOverride]] = None
if self._datasource is not None:
ds = self._datasource.copy()
ds.setdefault('name', 'db')
datasources = [ds]

await self._internal_engine.connect(
timeout=timeout,
datasources=datasources,
)

async def disconnect(self, timeout: Union[float, timedelta, None] = None) -> None:
"""Disconnect the Prisma query engine."""
if self._internal_engine is not None:
engine = self._internal_engine
self._internal_engine = None
if isinstance(timeout, (int, float)):
message = (
'Passing a number as `timeout` argument is deprecated '
'and will be removed in the next major release. '
'Use a `datetime.timedelta` instead.'
)
warnings.warn(message, DeprecationWarning, stacklevel=2)
timeout = timedelta(seconds=timeout)
await engine.aclose(timeout=timeout)
engine.stop(timeout=timeout)

async def execute_raw(self, query: LiteralString, *args: Any) -> int:
resp = await self._execute(
method='execute_raw',
Expand Down Expand Up @@ -410,19 +359,6 @@ class Prisma(AsyncBasePrisma):
)
return await self._engine.query(builder.build(), tx_id=self._tx_id)

def _create_engine(self, dml_path: Path = PACKAGED_SCHEMA_PATH) -> AbstractEngine:
if ENGINE_TYPE == EngineType.binary:
return QueryEngine(dml_path=dml_path, log_queries=self._log_queries, http_config=self._http_config)

raise NotImplementedError(f'Unsupported engine type: {ENGINE_TYPE}')

@property
def _engine_class(self) -> Type[AbstractEngine]:
if ENGINE_TYPE == EngineType.binary:
return QueryEngine
else: # pragma: no cover
raise RuntimeError(f'Unhandled engine type: {ENGINE_TYPE}')


class TransactionManager:
"""Context manager for wrapping a Prisma instance within a transaction.
Expand Down
Loading

0 comments on commit 01e24f6

Please sign in to comment.