Skip to content

Commit

Permalink
Convert tests by using pytest-aiohttp
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed Nov 24, 2016
1 parent f98e81a commit bdfabf5
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 465 deletions.
9 changes: 6 additions & 3 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
flake8==3.0.4
flake8==3.1.0
coverage==4.2
sphinx==1.4.8
alabaster>=0.6.2
aiohttp==1.0.3
aiohttp==1.1.5
jinja2==2.8
pytest==3.0.3
pytest==3.0.4
pytest-cov==2.4.0
yarl==0.7.1
multidict==2.1.2
pytest-aiohttp==0.1.3
-e .
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import codecs
from setuptools import setup
import os
import re

from setuptools import setup

with codecs.open(os.path.join(os.path.abspath(os.path.dirname(
__file__)), 'aiohttp_jinja2', '__init__.py'), 'r', 'latin1') as fp:
Expand Down
229 changes: 0 additions & 229 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,229 +0,0 @@
import asyncio
import collections
import gc
import logging
import pytest
import re
import socket
import sys
import warnings

from aiohttp import web


class _AssertWarnsContext:
"""A context manager used to implement TestCase.assertWarns* methods."""

def __init__(self, expected, expected_regex=None):
self.expected = expected
if expected_regex is not None:
expected_regex = re.compile(expected_regex)
self.expected_regex = expected_regex
self.obj_name = None

def __enter__(self):
# The __warningregistry__'s need to be in a pristine state for tests
# to work properly.
for v in sys.modules.values():
if getattr(v, '__warningregistry__', None):
v.__warningregistry__ = {}
self.warnings_manager = warnings.catch_warnings(record=True)
self.warnings = self.warnings_manager.__enter__()
warnings.simplefilter("always", self.expected)
return self

def __exit__(self, exc_type, exc_value, tb):
self.warnings_manager.__exit__(exc_type, exc_value, tb)
if exc_type is not None:
# let unexpected exceptions pass through
return
try:
exc_name = self.expected.__name__
except AttributeError:
exc_name = str(self.expected)
first_matching = None
for m in self.warnings:
w = m.message
if not isinstance(w, self.expected):
continue
if first_matching is None:
first_matching = w
if (self.expected_regex is not None and
not self.expected_regex.search(str(w))):
continue
# store warning for later retrieval
self.warning = w
self.filename = m.filename
self.lineno = m.lineno
return
# Now we simply try to choose a helpful failure message
if first_matching is not None:
__tracebackhide__ = True
assert 0, '"{}" does not match "{}"'.format(
self.expected_regex.pattern, str(first_matching))
if self.obj_name:
__tracebackhide__ = True
assert 0, "{} not triggered by {}".format(exc_name,
self.obj_name)
else:
__tracebackhide__ = True
assert 0, "{} not triggered".format(exc_name)


_LoggingWatcher = collections.namedtuple("_LoggingWatcher",
["records", "output"])


class _CapturingHandler(logging.Handler):
"""
A logging handler capturing all (raw and formatted) logging output.
"""

def __init__(self):
logging.Handler.__init__(self)
self.watcher = _LoggingWatcher([], [])

def flush(self):
pass

def emit(self, record):
self.watcher.records.append(record)
msg = self.format(record)
self.watcher.output.append(msg)


class _AssertLogsContext:
"""A context manager used to implement TestCase.assertLogs()."""

LOGGING_FORMAT = "%(levelname)s:%(name)s:%(message)s"

def __init__(self, logger_name=None, level=None):
self.logger_name = logger_name
if level:
self.level = logging._nameToLevel.get(level, level)
else:
self.level = logging.INFO
self.msg = None

def __enter__(self):
if isinstance(self.logger_name, logging.Logger):
logger = self.logger = self.logger_name
else:
logger = self.logger = logging.getLogger(self.logger_name)
formatter = logging.Formatter(self.LOGGING_FORMAT)
handler = _CapturingHandler()
handler.setFormatter(formatter)
self.watcher = handler.watcher
self.old_handlers = logger.handlers[:]
self.old_level = logger.level
self.old_propagate = logger.propagate
logger.handlers = [handler]
logger.setLevel(self.level)
logger.propagate = False
return handler.watcher

def __exit__(self, exc_type, exc_value, tb):
self.logger.handlers = self.old_handlers
self.logger.propagate = self.old_propagate
self.logger.setLevel(self.old_level)
if exc_type is not None:
# let unexpected exceptions pass through
return False
if len(self.watcher.records) == 0:
__tracebackhide__ = True
assert 0, ("no logs of level {} or higher triggered on {}"
.format(logging.getLevelName(self.level),
self.logger.name))


@pytest.yield_fixture
def warning():
yield _AssertWarnsContext


@pytest.yield_fixture
def log():
yield _AssertLogsContext


@pytest.fixture
def unused_port():
def f():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 0))
return s.getsockname()[1]
return f


@pytest.yield_fixture
def loop(request):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)

yield loop

loop.stop()
loop.run_forever()
loop.close()
gc.collect()
asyncio.set_event_loop(None)


@pytest.yield_fixture
def create_server(loop, unused_port):
app = handler = srv = None

@asyncio.coroutine
def create(*, debug=False, ssl_ctx=None, proto='http', **kwargs):
nonlocal app, handler, srv
app = web.Application(loop=loop, **kwargs)
port = unused_port()
handler = app.make_handler(debug=debug, keep_alive_on=False)
srv = yield from loop.create_server(handler, '127.0.0.1', port,
ssl=ssl_ctx)
if ssl_ctx:
proto += 's'
url = "{}://127.0.0.1:{}".format(proto, port)
return app, url

yield create

@asyncio.coroutine
def finish():
yield from handler.finish_connections()
yield from app.finish()
srv.close()
yield from srv.wait_closed()

loop.run_until_complete(finish())


@pytest.mark.tryfirst
def pytest_pycollect_makeitem(collector, name, obj):
if collector.funcnamefilter(name):
if not callable(obj):
return
item = pytest.Function(name, parent=collector)
if 'run_loop' in item.keywords:
return list(collector._genfunctions(name, obj))


@pytest.mark.tryfirst
def pytest_pyfunc_call(pyfuncitem):
"""
Run asyncio marked test functions in an event loop instead of a normal
function call.
"""
if 'run_loop' in pyfuncitem.keywords:
funcargs = pyfuncitem.funcargs
loop = funcargs['loop']
testargs = {arg: funcargs[arg]
for arg in pyfuncitem._fixtureinfo.argnames}
loop.run_until_complete(pyfuncitem.obj(**testargs))
return True


def pytest_runtest_setup(item):
if 'run_loop' in item.keywords and 'loop' not in item.fixturenames:
# inject an event loop fixture for all async tests
item.fixturenames.append('loop')
40 changes: 22 additions & 18 deletions tests/test_context_processors.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import aiohttp
import aiohttp_jinja2
import asyncio

import jinja2
import pytest
from aiohttp import web

import aiohttp_jinja2


@pytest.mark.run_loop
def test_context_processors(create_server, loop):
@asyncio.coroutine
def test_context_processors(test_client, loop):

@aiohttp_jinja2.template('tmpl.jinja2')
@asyncio.coroutine
def func(request):
return {'bar': 2}

app, url = yield from create_server(
middlewares=[
app = web.Application(loop=loop, middlewares=[
aiohttp_jinja2.context_processors_middleware])
aiohttp_jinja2.setup(app, loader=jinja2.DictLoader(
{'tmpl.jinja2':
Expand All @@ -26,41 +26,44 @@ def func(request):
lambda request: {'foo': 1, 'bar': 'should be overwriten'}),
)

app.router.add_route('GET', '/', func)
app.router.add_get('/', func)

client = yield from test_client(app)

resp = yield from aiohttp.request('GET', url, loop=loop)
resp = yield from client.get('/')
assert 200 == resp.status
txt = yield from resp.text()
assert 'foo: 1, bar: 2, path: /' == txt


@pytest.mark.run_loop
def test_context_is_response(create_server, loop):
@asyncio.coroutine
def test_context_is_response(test_client, loop):

@aiohttp_jinja2.template('tmpl.jinja2')
def func(request):
return aiohttp.web_exceptions.HTTPForbidden()
return web.HTTPForbidden()

app, url = yield from create_server()
app = web.Application(loop=loop)
aiohttp_jinja2.setup(app, loader=jinja2.DictLoader(
{'tmpl.jinja2': "template"}))

app.router.add_route('GET', '/', func)
client = yield from test_client(app)

resp = yield from aiohttp.request('GET', url, loop=loop)
resp = yield from client.get('/')
assert 403 == resp.status
yield from resp.release()


@pytest.mark.run_loop
def test_context_processors_new_setup_style(create_server, loop):
@asyncio.coroutine
def test_context_processors_new_setup_style(test_client, loop):

@aiohttp_jinja2.template('tmpl.jinja2')
@asyncio.coroutine
def func(request):
return {'bar': 2}

app, url = yield from create_server()
app = web.Application(loop=loop)
aiohttp_jinja2.setup(
app,
loader=jinja2.DictLoader(
Expand All @@ -74,8 +77,9 @@ def func(request):
'bar': 'should be overwriten'})))

app.router.add_route('GET', '/', func)
client = yield from test_client(app)

resp = yield from aiohttp.request('GET', url, loop=loop)
resp = yield from client.get('/')
assert 200 == resp.status
txt = yield from resp.text()
assert 'foo: 1, bar: 2, path: /' == txt
16 changes: 9 additions & 7 deletions tests/test_jinja_filters.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import aiohttp
import aiohttp_jinja2
import asyncio

import jinja2
import pytest
from aiohttp import web

import aiohttp_jinja2


@pytest.mark.run_loop
def test_jinja_filters(create_server, loop):
@asyncio.coroutine
def test_jinja_filters(test_client, loop):

@aiohttp_jinja2.template('tmpl.jinja2')
@asyncio.coroutine
Expand All @@ -16,16 +17,17 @@ def index(request):
def add_2(value):
return value + 2

app, url = yield from create_server()
app = web.Application(loop=loop)
aiohttp_jinja2.setup(
app,
loader=jinja2.DictLoader({'tmpl.jinja2': "{{ 5|add_2 }}"}),
filters={'add_2': add_2}
)

app.router.add_route('GET', '/', index)
client = yield from test_client(app)

resp = yield from aiohttp.request('GET', url, loop=loop)
resp = yield from client.get('/')
assert 200 == resp.status
txt = yield from resp.text()
assert '7' == txt
Loading

0 comments on commit bdfabf5

Please sign in to comment.