Skip to content

Commit

Permalink
Pass the request ctx rather than use the globals in the app
Browse files Browse the repository at this point in the history
The globals have a performance penalty which can be justified for the
convinience in user code. In the app however the ctx can easily be
passed through the method calls thereby reducing the performance
penalty.

This may affect extensions if they have subclassed the app and
overridden these methods.
  • Loading branch information
pgjones committed Aug 20, 2023
1 parent 1d8b53f commit 0e738ab
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 41 deletions.
77 changes: 42 additions & 35 deletions src/flask/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,9 @@ def test_cli_runner(self, **kwargs: t.Any) -> FlaskCliRunner:
return cls(self, **kwargs) # type: ignore

def handle_http_exception(
self, e: HTTPException
self,
e: HTTPException,
ctx: RequestContext,
) -> HTTPException | ft.ResponseReturnValue:
"""Handles an HTTP exception. By default this will invoke the
registered error handlers and fall back to returning the
Expand Down Expand Up @@ -722,13 +724,15 @@ def handle_http_exception(
if isinstance(e, RoutingException):
return e

handler = self._find_error_handler(e, request.blueprints)
handler = self._find_error_handler(e, ctx.request.blueprints)
if handler is None:
return e
return self.ensure_sync(handler)(e)

def handle_user_exception(
self, e: Exception
self,
e: Exception,
ctx: RequestContext,
) -> HTTPException | ft.ResponseReturnValue:
"""This method is called whenever an exception occurs that
should be handled. A special case is :class:`~werkzeug
Expand All @@ -750,16 +754,16 @@ def handle_user_exception(
e.show_exception = True

if isinstance(e, HTTPException) and not self.trap_http_exception(e):
return self.handle_http_exception(e)
return self.handle_http_exception(e, ctx)

handler = self._find_error_handler(e, request.blueprints)
handler = self._find_error_handler(e, ctx.request.blueprints)

if handler is None:
raise

return self.ensure_sync(handler)(e)

def handle_exception(self, e: Exception) -> Response:
def handle_exception(self, e: Exception, ctx: RequestContext) -> Response:
"""Handle an exception that did not have an error handler
associated with it, or that was raised from an error handler.
This always causes a 500 ``InternalServerError``.
Expand Down Expand Up @@ -802,19 +806,20 @@ def handle_exception(self, e: Exception) -> Response:

raise e

self.log_exception(exc_info)
self.log_exception(exc_info, ctx)
server_error: InternalServerError | ft.ResponseReturnValue
server_error = InternalServerError(original_exception=e)
handler = self._find_error_handler(server_error, request.blueprints)
handler = self._find_error_handler(server_error, ctx.request.blueprints)

if handler is not None:
server_error = self.ensure_sync(handler)(server_error)

return self.finalize_request(server_error, from_error_handler=True)
return self.finalize_request(server_error, ctx, from_error_handler=True)

def log_exception(
self,
exc_info: (tuple[type, BaseException, TracebackType] | tuple[None, None, None]),
ctx: RequestContext,
) -> None:
"""Logs an exception. This is called by :meth:`handle_exception`
if debugging is disabled and right before the handler is called.
Expand All @@ -824,10 +829,10 @@ def log_exception(
.. versionadded:: 0.8
"""
self.logger.error(
f"Exception on {request.path} [{request.method}]", exc_info=exc_info
f"Exception on {ctx.request.path} [{ctx.request.method}]", exc_info=exc_info
)

def dispatch_request(self) -> ft.ResponseReturnValue:
def dispatch_request(self, ctx: RequestContext) -> ft.ResponseReturnValue:
"""Does the request dispatching. Matches the URL and returns the
return value of the view or error handler. This does not have to
be a response object. In order to convert the return value to a
Expand All @@ -837,22 +842,21 @@ def dispatch_request(self) -> ft.ResponseReturnValue:
This no longer does the exception handling, this code was
moved to the new :meth:`full_dispatch_request`.
"""
req = request_ctx.request
if req.routing_exception is not None:
self.raise_routing_exception(req)
rule: Rule = req.url_rule # type: ignore[assignment]
if ctx.request.routing_exception is not None:
self.raise_routing_exception(ctx.request)
rule: Rule = ctx.request.url_rule # type: ignore[assignment]
# if we provide automatic options for this URL and the
# request came with the OPTIONS method, reply automatically
if (
getattr(rule, "provide_automatic_options", False)
and req.method == "OPTIONS"
and ctx.request.method == "OPTIONS"
):
return self.make_default_options_response()
# otherwise dispatch to the handler for that endpoint
view_args: dict[str, t.Any] = req.view_args # type: ignore[assignment]
view_args: dict[str, t.Any] = ctx.request.view_args # type: ignore[assignment]
return self.ensure_sync(self.view_functions[rule.endpoint])(**view_args)

def full_dispatch_request(self) -> Response:
def full_dispatch_request(self, ctx: RequestContext) -> Response:
"""Dispatches the request and on top of that performs request
pre and postprocessing as well as HTTP exception catching and
error handling.
Expand All @@ -863,16 +867,17 @@ def full_dispatch_request(self) -> Response:

try:
request_started.send(self, _async_wrapper=self.ensure_sync)
rv = self.preprocess_request()
rv = self.preprocess_request(ctx)
if rv is None:
rv = self.dispatch_request()
rv = self.dispatch_request(ctx)
except Exception as e:
rv = self.handle_user_exception(e)
return self.finalize_request(rv)
rv = self.handle_user_exception(e, ctx)
return self.finalize_request(rv, ctx)

def finalize_request(
self,
rv: ft.ResponseReturnValue | HTTPException,
ctx: RequestContext,
from_error_handler: bool = False,
) -> Response:
"""Given the return value from a view function this finalizes
Expand All @@ -889,7 +894,7 @@ def finalize_request(
"""
response = self.make_response(rv)
try:
response = self.process_response(response)
response = self.process_response(response, ctx)
request_finished.send(
self, _async_wrapper=self.ensure_sync, response=response
)
Expand Down Expand Up @@ -1216,7 +1221,7 @@ def make_response(self, rv: ft.ResponseReturnValue) -> Response:

return rv

def preprocess_request(self) -> ft.ResponseReturnValue | None:
def preprocess_request(self, ctx: RequestContext) -> ft.ResponseReturnValue | None:
"""Called before the request is dispatched. Calls
:attr:`url_value_preprocessors` registered with the app and the
current blueprint (if any). Then calls :attr:`before_request_funcs`
Expand All @@ -1226,12 +1231,12 @@ def preprocess_request(self) -> ft.ResponseReturnValue | None:
value is handled as if it was the return value from the view, and
further request handling is stopped.
"""
names = (None, *reversed(request.blueprints))
names = (None, *reversed(ctx.request.blueprints))

for name in names:
if name in self.url_value_preprocessors:
for url_func in self.url_value_preprocessors[name]:
url_func(request.endpoint, request.view_args)
url_func(ctx.request.endpoint, ctx.request.view_args)

for name in names:
if name in self.before_request_funcs:
Expand All @@ -1243,7 +1248,7 @@ def preprocess_request(self) -> ft.ResponseReturnValue | None:

return None

def process_response(self, response: Response) -> Response:
def process_response(self, response: Response, ctx: RequestContext) -> Response:
"""Can be overridden in order to modify the response object
before it's sent to the WSGI server. By default this will
call all the :meth:`after_request` decorated functions.
Expand All @@ -1256,23 +1261,25 @@ def process_response(self, response: Response) -> Response:
:return: a new response object or the same, has to be an
instance of :attr:`response_class`.
"""
ctx = request_ctx._get_current_object() # type: ignore[attr-defined]

for func in ctx._after_request_functions:
response = self.ensure_sync(func)(response)

for name in chain(request.blueprints, (None,)):
for name in chain(ctx.request.blueprints, (None,)):
if name in self.after_request_funcs:
for func in reversed(self.after_request_funcs[name]):
response = self.ensure_sync(func)(response)

if not self.session_interface.is_null_session(ctx.session):
self.session_interface.save_session(self, ctx.session, response)
self.session_interface.save_session(
self, ctx.session, response # type: ignore[arg-type]
)

return response

def do_teardown_request(
self, exc: BaseException | None = _sentinel # type: ignore
self,
ctx: RequestContext,
exc: BaseException | None = _sentinel, # type: ignore
) -> None:
"""Called after the request is dispatched and the response is
returned, right before the request context is popped.
Expand All @@ -1297,7 +1304,7 @@ def do_teardown_request(
if exc is _sentinel:
exc = sys.exc_info()[1]

for name in chain(request.blueprints, (None,)):
for name in chain(ctx.request.blueprints, (None,)):
if name in self.teardown_request_funcs:
for func in reversed(self.teardown_request_funcs[name]):
self.ensure_sync(func)(exc)
Expand Down Expand Up @@ -1452,10 +1459,10 @@ def wsgi_app(self, environ: dict, start_response: t.Callable) -> t.Any:
try:
try:
ctx.push()
response = self.full_dispatch_request()
response = self.full_dispatch_request(ctx)
except Exception as e:
error = e
response = self.handle_exception(e)
response = self.handle_exception(e, ctx)
except: # noqa: B001
error = sys.exc_info()[1]
raise
Expand Down
2 changes: 1 addition & 1 deletion src/flask/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def pop(self, exc: BaseException | None = _sentinel) -> None: # type: ignore
if clear_request:
if exc is _sentinel:
exc = sys.exc_info()[1]
self.app.do_teardown_request(exc)
self.app.do_teardown_request(self, exc)

request_close = getattr(self.request, "close", None)
if request_close is not None:
Expand Down
8 changes: 4 additions & 4 deletions tests/test_reqctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,8 @@ def test_bad_environ_raises_bad_request():
# use a non-printable character in the Host - this is key to this test
environ["HTTP_HOST"] = "\x8a"

with app.request_context(environ):
response = app.full_dispatch_request()
with app.request_context(environ) as ctx:
response = app.full_dispatch_request(ctx)
assert response.status_code == 400


Expand All @@ -308,8 +308,8 @@ def index():
# these characters are all IDNA-compatible
environ["HTTP_HOST"] = "ąśźäüжŠßя.com"

with app.request_context(environ):
response = app.full_dispatch_request()
with app.request_context(environ) as ctx:
response = app.full_dispatch_request(ctx)

assert response.status_code == 200

Expand Down
2 changes: 1 addition & 1 deletion tests/test_subclassing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

def test_suppressed_exception_logging():
class SuppressedFlask(flask.Flask):
def log_exception(self, exc_info):
def log_exception(self, exc_info, ctx):
pass

out = StringIO()
Expand Down

0 comments on commit 0e738ab

Please sign in to comment.