Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 'I/O operation on closed file' and 'Form data has been processed already' upon redirect on multipart data #9201

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
4 changes: 4 additions & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
ClientResponse,
Fingerprint,
RequestInfo,
process_data_to_payload,
)
from .client_ws import (
DEFAULT_WS_CLIENT_TIMEOUT,
Expand Down Expand Up @@ -521,6 +522,9 @@ async def _request(
for trace in traces:
await trace.send_request_start(method, url.update_query(params), headers)

# Preprocess the data so the same Payload object can be reused when a redirect is followed.
data = process_data_to_payload(data)

timer = tm.timer()
try:
with timer:
Expand Down
18 changes: 18 additions & 0 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,24 @@ class ConnectionKey:
proxy_headers_hash: Optional[int] # hash(CIMultiDict)


def process_data_to_payload(body: Any) -> Any:
# Convert the data to a payload before entering the redirect loop, so that a payload
# wrapping an IO object is kept alive and its stored data can be reused for the next request.
if body is None:
return None

# FormData
GLGDLY marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(body, FormData):
body = body()

try:
body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
except payload.LookupError:
pass # keep for ClientRequest to handle
GLGDLY marked this conversation as resolved.
Show resolved Hide resolved

return body


class ClientRequest:
GET_METHODS = {
hdrs.METH_GET,
Expand Down
7 changes: 3 additions & 4 deletions aiohttp/formdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
self._writer = multipart.MultipartWriter("form-data", boundary=self._boundary)
self._fields: List[Any] = []
self._is_multipart = False
self._is_processed = False
self._quote_fields = quote_fields
self._charset = charset

Expand Down Expand Up @@ -117,8 +116,8 @@

def _gen_form_data(self) -> multipart.MultipartWriter:
"""Encode a list of fields using the multipart/form-data MIME format"""
if self._is_processed:
raise RuntimeError("Form data has been processed already")
if not self._fields:
return self._writer

Check warning on line 120 in aiohttp/formdata.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/formdata.py#L120

Added line #L120 was not covered by tests
GLGDLY marked this conversation as resolved.
Show resolved Hide resolved
for dispparams, headers, value in self._fields:
try:
if hdrs.CONTENT_TYPE in headers:
Expand Down Expand Up @@ -149,7 +148,7 @@

self._writer.append_payload(part)

self._is_processed = True
self._fields.clear()
return self._writer

def __call__(self) -> Payload:
Expand Down
82 changes: 58 additions & 24 deletions aiohttp/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,17 +307,39 @@
if hdrs.CONTENT_DISPOSITION not in self.headers:
self.set_content_disposition(disposition, filename=self._filename)

self._writable = True
self._seekable = True
try:
# Oddly, some IO objects do not provide a `seekable()` method the way IOBase objects do,
# so it is better to directly try whether `seek()` and `tell()` are available,
# e.g. tarfile.TarFile._Stream
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
self._value.seek(self._value.tell())
except (AttributeError, OSError):
GLGDLY marked this conversation as resolved.
Show resolved Hide resolved
self._seekable = False

if self._seekable:
self._stream_pos = self._value.tell()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: check if tell can be blocking

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty sure tell() is blocking as aiofiles delegates it to the executor as well https://pypi.org/project/aiofiles/

else:
self._stream_pos = 0

async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
try:
if self._seekable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: see if executor jobs can be combined

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be combined into a single executor job. Example

diff --git a/aiohttp/payload.py b/aiohttp/payload.py
index 395c44f6e..3b222b93e 100644
--- a/aiohttp/payload.py
+++ b/aiohttp/payload.py
@@ -319,15 +319,19 @@ class IOBasePayload(Payload):
         else:
             self._stream_pos = 0
 
-    async def write(self, writer: AbstractStreamWriter) -> None:
-        loop = asyncio.get_event_loop()
+    def _read_first(self) -> None:
+        """Read the first chunk of data from the stream."""
         if self._seekable:
-            await loop.run_in_executor(None, self._value.seek, self._stream_pos)
+            self._value.seek(self._stream_pos)
         elif not self._writable:
             raise RuntimeError(
                 f'Non-seekable IO payload "{self._value}" is already consumed (possibly due to redirect, consider storing in a seekable IO buffer instead)'
             )
-        chunk = await loop.run_in_executor(None, self._value.read, 2**16)
+        return self._value.read(2**16)
+
+    async def write(self, writer: AbstractStreamWriter) -> None:
+        loop = asyncio.get_running_loop()
+        chunk = await loop.run_in_executor(None, self._read_first)
         while chunk:
             await writer.write(chunk)
             chunk = await loop.run_in_executor(None, self._value.read, 2**16)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to seek here if it's the first time? That seems like it will be the common case

await loop.run_in_executor(None, self._value.seek, self._stream_pos)
elif not self._writable:
raise RuntimeError(
f'Non-seekable IO payload "{self._value}" is already consumed (possibly due to redirect, consider storing in a seekable IO buffer instead)'
)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
while chunk:
await writer.write(chunk)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
while chunk:
await writer.write(chunk)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
finally:
await loop.run_in_executor(None, self._value.close)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: verify close will still always happen

if not self._seekable:
self._writable = False # Non-seekable IO `_value` can only be consumed once

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
if self._seekable:
self._value.seek(self._stream_pos)

Check warning on line 342 in aiohttp/payload.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/payload.py#L342

Added line #L342 was not covered by tests
return "".join(r.decode(encoding, errors) for r in self._value.readlines())


Expand Down Expand Up @@ -354,40 +376,50 @@
@property
def size(self) -> Optional[int]:
try:
return os.fstat(self._value.fileno()).st_size - self._value.tell()
return os.fstat(self._value.fileno()).st_size - self._stream_pos
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: verify this doesn't run in the event loop

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like multipart will call this in the event loop from append_payload via ClientRequest.update_body_from_data

except OSError:
return None

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
if self._seekable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to seek on the first time?

self._value.seek(self._stream_pos)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: verify this isn't called in the event loop as it does block

return self._value.read()

async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
try:
if self._seekable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to seek here if it's the first time?

await loop.run_in_executor(None, self._value.seek, self._stream_pos)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: see if executor jobs can be combined

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jobs can be combined like #9201 (comment)

elif not self._writable:
raise RuntimeError(

Check warning on line 393 in aiohttp/payload.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/payload.py#L393

Added line #L393 was not covered by tests
f'Non-seekable IO payload "{self._value}" is already consumed (possibly due to redirect, consider storing in a seekable IO buffer instead)'
)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
while chunk:
data = (
chunk.encode(encoding=self._encoding)
if self._encoding
else chunk.encode()
)
await writer.write(data)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
while chunk:
data = (
chunk.encode(encoding=self._encoding)
if self._encoding
else chunk.encode()
)
await writer.write(data)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
finally:
await loop.run_in_executor(None, self._value.close)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: verify close still happens in failure

if not self._seekable:
self._writable = False # Non-seekable IO `_value` can only be consumed once

Check warning on line 406 in aiohttp/payload.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/payload.py#L406

Added line #L406 was not covered by tests


class BytesIOPayload(IOBasePayload):
_value: io.BytesIO

@property
def size(self) -> int:
position = self._value.tell()
end = self._value.seek(0, os.SEEK_END)
self._value.seek(position)
return end - position
def size(self) -> Optional[int]:
if self._seekable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: verify this doesn't run in the event loop

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end = self._value.seek(0, os.SEEK_END)
self._value.seek(self._stream_pos)
return end - self._stream_pos
return None

Check warning on line 418 in aiohttp/payload.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/payload.py#L418

Added line #L418 was not covered by tests

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
if self._seekable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: make sure this is run in the executor

self._value.seek(self._stream_pos)
return self._value.read().decode(encoding, errors)


Expand All @@ -397,7 +429,7 @@
@property
def size(self) -> Optional[int]:
try:
return os.fstat(self._value.fileno()).st_size - self._value.tell()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: make sure this is run in the executor

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return os.fstat(self._value.fileno()).st_size - self._stream_pos
except (OSError, AttributeError):
# data.fileno() is not supported, e.g.
# io.BufferedReader(io.BytesIO(b'data'))
Expand All @@ -406,6 +438,8 @@
return None

def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
if self._seekable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: verify this does not run in the event loop

self._value.seek(self._stream_pos)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to happen on the first attempt?

return self._value.read().decode(encoding, errors)


Expand Down
10 changes: 6 additions & 4 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,8 @@ async def test_GET_DEFLATE(
aiohttp_client: AiohttpClient, data: Optional[bytes]
) -> None:
async def handler(request: web.Request) -> web.Response:
recv_data = await request.read()
assert recv_data == b"" # both cases should receive empty bytes
return web.json_response({"ok": True})

write_mock = None
Expand All @@ -1553,10 +1555,10 @@ async def write_bytes(
self: ClientRequest, writer: StreamWriter, conn: Connection
) -> None:
nonlocal write_mock
original_write = writer._write
original_write = writer.write

with mock.patch.object(
writer, "_write", autospec=True, spec_set=True, side_effect=original_write
writer, "write", autospec=True, spec_set=True, side_effect=original_write
) as write_mock:
await original_write_bytes(self, writer, conn)

Expand All @@ -1571,8 +1573,8 @@ async def write_bytes(
assert content == {"ok": True}

assert write_mock is not None
# No chunks should have been sent for an empty body.
write_mock.assert_not_called()
# Empty b"" should have been sent for an empty body.
write_mock.assert_called_once_with(b"")


async def test_POST_DATA_DEFLATE(aiohttp_client: AiohttpClient) -> None:
Expand Down
122 changes: 105 additions & 17 deletions tests/test_formdata.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import io
import pathlib
import tarfile
from unittest import mock

import pytest

from aiohttp import FormData, web
from aiohttp.client_exceptions import ClientConnectionError
from aiohttp.http_writer import StreamWriter
from aiohttp.pytest_plugin import AiohttpClient

Expand Down Expand Up @@ -95,28 +98,113 @@
assert b'name="email 1"' in buf


async def test_mark_formdata_as_processed(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
return web.Response()
async def test_formdata_boundary_param() -> None:
    """A boundary given to ``FormData`` is the boundary of its multipart writer."""
    expected = "some_boundary"
    assert FormData(boundary=expected)._writer.boundary == expected

app = web.Application()
app.add_routes([web.post("/", handler)])

client = await aiohttp_client(app)
async def test_formdata_on_redirect(aiohttp_client: AiohttpClient) -> None:
with pathlib.Path(pathlib.Path(__file__).parent / "sample.txt").open("rb") as fobj:
content = fobj.read()
fobj.seek(0)

data = FormData()
data.add_field("test", "test_value", content_type="application/json")
async def handler_0(request: web.Request) -> web.Response:
raise web.HTTPPermanentRedirect("/1")

resp = await client.post("/", data=data)
assert len(data._writer._parts) == 1
async def handler_1(request: web.Request) -> web.Response:
req_data = await request.post()
assert ["sample.txt"] == list(req_data.keys())
file_field = req_data["sample.txt"]
assert isinstance(file_field, web.FileField)
assert content == file_field.file.read()
return web.Response()

with pytest.raises(RuntimeError):
await client.post("/", data=data)
app = web.Application()
app.router.add_post("/0", handler_0)
app.router.add_post("/1", handler_1)

resp.release()
client = await aiohttp_client(app)

data = FormData()
data.add_field("sample.txt", fobj)

async def test_formdata_boundary_param() -> None:
boundary = "some_boundary"
form = FormData(boundary=boundary)
assert form._writer.boundary == boundary
resp = await client.post("/0", data=data)
assert len(data._writer._parts) == 1
assert resp.status == 200

resp.release()


async def test_formdata_on_redirect_after_recv(aiohttp_client: AiohttpClient) -> None:
    """A multipart file upload is re-sent intact after a redirect.

    ``/0`` reads and checks the complete form before raising a permanent
    redirect to ``/1``; ``/1`` must then receive the same file content.
    """
    sample_path = pathlib.Path(__file__).parent / "sample.txt"
    with sample_path.open("rb") as fobj:
        content = fobj.read()
        # Rewind so the upload starts from the beginning of the file.
        fobj.seek(0)

        async def check_upload(request: web.Request) -> None:
            # Shared by both handlers: the form must hold exactly the sample file.
            req_data = await request.post()
            assert ["sample.txt"] == list(req_data.keys())
            file_field = req_data["sample.txt"]
            assert isinstance(file_field, web.FileField)
            assert content == file_field.file.read()

        async def handler_0(request: web.Request) -> web.Response:
            await check_upload(request)
            raise web.HTTPPermanentRedirect("/1")

        async def handler_1(request: web.Request) -> web.Response:
            await check_upload(request)
            return web.Response()

        app = web.Application()
        for route, handler in (("/0", handler_0), ("/1", handler_1)):
            app.router.add_post(route, handler)

        client = await aiohttp_client(app)

        data = FormData()
        data.add_field("sample.txt", fobj)

        resp = await client.post("/0", data=data)
        assert len(data._writer._parts) == 1
        assert resp.status == 200

        resp.release()


async def test_streaming_tarfile_on_redirect(aiohttp_client: AiohttpClient) -> None:
data = b"This is a tar file payload text file."

async def handler_0(request: web.Request) -> web.Response:
await request.read()
raise web.HTTPPermanentRedirect("/1")

async def handler_1(request: web.Request) -> web.Response:
await request.read()
return web.Response()

Check warning on line 185 in tests/test_formdata.py

View check run for this annotation

Codecov / codecov/patch

tests/test_formdata.py#L185

Added line #L185 was not covered by tests

app = web.Application()
app.router.add_post("/0", handler_0)
app.router.add_post("/1", handler_1)

client = await aiohttp_client(app)

buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tf:
ti = tarfile.TarInfo(name="payload1.txt")
ti.size = len(data)
tf.addfile(tarinfo=ti, fileobj=io.BytesIO(data))

# Streaming tarfile.
buf.seek(0)
tf = tarfile.open(fileobj=buf, mode="r|")
for entry in tf:
with pytest.raises(ClientConnectionError) as exc_info:
await client.post("/0", data=tf.extractfile(entry))
raw_exc_info = exc_info._excinfo
assert isinstance(raw_exc_info, tuple)
cause_exc = raw_exc_info[1].__cause__
assert isinstance(cause_exc, RuntimeError)
assert len(cause_exc.args) == 1
assert cause_exc.args[0].startswith("Non-seekable IO payload")
Loading
Loading