Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 67 additions & 27 deletions python_multipart/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class FormParserConfig(FileConfig):
MAX_BODY_SIZE: float
MAX_HEADER_COUNT: int
MAX_HEADER_SIZE: int
ALLOW_BARE_LF: bool

CallbackName: TypeAlias = Literal[
"start",
Expand Down Expand Up @@ -997,6 +998,9 @@ class MultipartParser(BaseParser):
max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
max_header_count: The maximum number of headers allowed per part.
max_header_size: The maximum size of a single header line (excluding the trailing CRLF).
allow_bare_lf: Accept bare ``\\n`` instead of ``\\r\\n`` as the framing line terminator. Defaults to ``False``
(strict RFC 2046). The convention is detected from the opening boundary and applied to the whole message;
part data is never altered.
""" # noqa: E501

def __init__(
Expand All @@ -1007,6 +1011,7 @@ def __init__(
*,
max_header_count: int = DEFAULT_MAX_HEADER_COUNT,
max_header_size: int = DEFAULT_MAX_HEADER_SIZE,
allow_bare_lf: bool = False,
) -> None:
# Initialize parser state.
super().__init__()
Expand All @@ -1017,7 +1022,7 @@ def __init__(

if not isinstance(max_size, Number) or max_size < 1:
raise ValueError("max_size must be a positive number, not %r" % max_size)
self.max_size = max_size
self.max_size: int | float = max_size
self._current_size = 0

self.max_header_count = max_header_count
Expand All @@ -1029,11 +1034,15 @@ def __init__(
# Setup marks. These are used to track the state of data received.
self.marks: dict[str, int] = {}

self.allow_bare_lf = allow_bare_lf
self._lf_only = False

# Save our boundary.
if isinstance(boundary, str): # pragma: no cover
boundary = boundary.encode("latin-1")
if len(boundary) > MAX_BOUNDARY_LENGTH:
raise FormParserError(f"Boundary length {len(boundary)} exceeds maximum of {MAX_BOUNDARY_LENGTH}")
self._boundary_token = boundary
self.boundary = b"\r\n--" + boundary

def write(self, data: bytes) -> int:
Expand Down Expand Up @@ -1085,6 +1094,9 @@ def _internal_write(self, data: bytes, length: int) -> int:
current_header_count = self._current_header_count
current_header_size = self._current_header_size

allow_bare_lf = self.allow_bare_lf
lf_only = self._lf_only

# Our index defaults to 0.
i = 0

Expand Down Expand Up @@ -1129,10 +1141,10 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
if lookbehind_len <= boundary_length:
self.callback(name, boundary, 0, lookbehind_len)
elif self.flags & FLAG_PART_BOUNDARY:
lookback = boundary + b"\r\n"
lookback = boundary + (b"\n" if self._lf_only else b"\r\n")
self.callback(name, lookback, 0, lookbehind_len)
elif self.flags & FLAG_LAST_BOUNDARY:
lookback = boundary + b"--\r\n"
lookback = boundary + (b"--\n" if self._lf_only else b"--\r\n")
self.callback(name, lookback, 0, lookbehind_len)
else: # pragma: no cover (error case)
self.logger.warning("Look-back buffer error")
Expand Down Expand Up @@ -1175,14 +1187,24 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
if c == HYPHEN:
# Potential empty message.
state = MultipartState.END_BOUNDARY
elif c != CR:
# Error!
index += 1
elif c == CR:
index += 1
elif allow_bare_lf and c == LF:
# Bare-LF body: use the \n-prefixed boundary.
lf_only = self._lf_only = True
boundary = self.boundary = b"\n--" + self._boundary_token
boundary_length = len(boundary)
index = 0
self.callback("part_begin")
current_header_count = 0
current_header_size = 0
state = MultipartState.HEADER_FIELD_START
else:
msg = "Did not find CR at end of boundary (%d)" % (i,)
self.logger.warning(msg)
raise MultipartParseError(msg, offset=i)

index += 1

elif index == boundary_length - 1:
if c != LF:
msg = "Did not find LF at end of boundary (%d)" % (i,)
Expand Down Expand Up @@ -1214,8 +1236,9 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
# Mark the start of a header field here, reset the index, and
# continue parsing our header field.
index = 0
nl = LF if lf_only else CR

if c != CR:
if c != nl:
current_header_count += 1
if current_header_count > self.max_header_count:
raise MultipartParseError("Maximum header count exceeded", offset=i)
Expand All @@ -1224,24 +1247,25 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
# Set a mark of our header field.
set_mark("header_field")

# Notify that we're starting a header if the next character is
# not a CR; a CR at the beginning of the header will cause us
# to stop parsing headers in the MultipartState.HEADER_FIELD state,
# below.
if c != CR:
# Notify that we're starting a header unless this is the blank
# line that ends the headers (handled in HEADER_FIELD below).
if c != nl:
self.callback("header_begin")

# Move to parsing header fields.
state = MultipartState.HEADER_FIELD
i -= 1

elif state == MultipartState.HEADER_FIELD:
# If we've reached a CR at the beginning of a header, it means
# that we've reached the second of 2 newlines, and so there are
# no more headers to parse.
if c == CR and index == 0:
# A line terminator at the start of a header is the blank line
# that ends the headers.
if index == 0 and c == (LF if lf_only else CR):
delete_mark("header_field")
state = MultipartState.HEADERS_ALMOST_DONE
if lf_only:
self.callback("headers_finished")
state = MultipartState.PART_DATA_START
else:
state = MultipartState.HEADERS_ALMOST_DONE
i += 1
continue

Expand Down Expand Up @@ -1295,17 +1319,16 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No
i -= 1

elif state == MultipartState.HEADER_VALUE:
# The value runs until the terminating CR; jump straight to it
# instead of inspecting every byte.
cr = data.find(b"\r", i, length)
end = cr if cr != -1 else length
# The value runs until its terminator (CR for CRLF, LF for bare LF).
term = data.find(b"\n" if lf_only else b"\r", i, length)
end = term if term != -1 else length
advance_header_size(end - i)
if cr != -1:
i = cr
if term != -1:
i = term
data_callback("header_value", i)
self.callback("header_end")
current_header_size = 0
state = MultipartState.HEADER_VALUE_ALMOST_DONE
state = MultipartState.HEADER_FIELD_START if lf_only else MultipartState.HEADER_VALUE_ALMOST_DONE
else:
i = length

Expand Down Expand Up @@ -1387,12 +1410,24 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No

# Our index is equal to the length of our boundary!
elif index == boundary_length:
if lf_only and c == LF:
# Bare-LF part boundary: the single LF terminates it.
data_callback("part_data", i - index)
self.callback("part_end")
self.callback("part_begin")
current_header_count = 0
current_header_size = 0
index = 0
state = MultipartState.HEADER_FIELD_START
i += 1
continue

# First we increment it.
index += 1

# Now, if we've reached a newline, we need to set this as
# the potential end of our boundary.
if c == CR:
if c == CR and not lf_only:
flags |= FLAG_PART_BOUNDARY

# Otherwise, if this is a hyphen, we might be at the last
Expand Down Expand Up @@ -1561,6 +1596,8 @@ class FormParser:
"UPLOAD_KEEP_EXTENSIONS": False,
# Error on invalid Content-Transfer-Encoding?
"UPLOAD_ERROR_ON_BAD_CTE": False,
# Accept bare LF (instead of CRLF) as the framing line terminator.
"ALLOW_BARE_LF": False,
}

def __init__(
Expand Down Expand Up @@ -1795,6 +1832,7 @@ def _on_end() -> None:
max_size=self.config["MAX_BODY_SIZE"],
max_header_count=self.config["MAX_HEADER_COUNT"],
max_header_size=self.config["MAX_HEADER_SIZE"],
allow_bare_lf=self.config["ALLOW_BARE_LF"],
)

else:
Expand Down Expand Up @@ -1876,6 +1914,7 @@ def parse_form(
on_field: Callable[[Field], None] | None,
on_file: Callable[[File], None] | None,
chunk_size: int = 1048576,
config: dict[Any, Any] = {},
) -> None:
"""This function is useful if you just want to parse a request body,
without too much work. Pass it a dictionary-like object of the request's
Expand All @@ -1889,12 +1928,13 @@ def parse_form(
on_file: Callback to call with each parsed file.
chunk_size: The maximum size to read from the input stream and write to the parser at one time.
Defaults to 1 MiB.
config: Configuration variables to pass to the FormParser (e.g. ``{"ALLOW_BARE_LF": True}``).
"""
if chunk_size < 1:
raise ValueError(f"chunk_size must be a positive number, not {chunk_size!r}")

# Create our form parser.
parser = create_form_parser(headers, on_field, on_file)
parser = create_form_parser(headers, on_field, on_file, config)

# Read chunks of 1MiB and write to the parser, but never read more than
# the given Content-Length, if any.
Expand Down
22 changes: 12 additions & 10 deletions tests/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
import sys
import types
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeVar, cast

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -61,7 +61,7 @@ def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
class ParametrizingMetaclass(type):
IDENTIFIER_RE = re.compile("[^A-Za-z0-9]")

def __new__(klass, name: str, bases: tuple[type, ...], attrs: types.MappingProxyType[str, Any]) -> type:
def __new__(cls, name: str, bases: tuple[type, ...], attrs: types.MappingProxyType[str, Any]) -> type:
new_attrs = attrs.copy()
for attr_name, attr in attrs.items():
# We only care about functions
Expand All @@ -74,15 +74,14 @@ def __new__(klass, name: str, bases: tuple[type, ...], attrs: types.MappingProxy
continue

# Create multiple copies of the function.
for _, values in enumerate(param_values):
for i, values in enumerate(param_values):
assert len(param_names) == len(values)

# Get a repr of the values, and fix it to be a valid identifier
human = "_".join([klass.IDENTIFIER_RE.sub("", repr(x)) for x in values])
human = "_".join([cls.IDENTIFIER_RE.sub("", repr(x)) for x in values])

# Create a new name.
# new_name = attr.__name__ + "_%d" % i
new_name = attr.__name__ + "__" + human
# Cap length so multi-MB params can't overflow Windows' PYTEST_CURRENT_TEST env var.
new_name = f"{attr.__name__}__{i}_{human[:64]}"

# Create a replacement function.
def create_new_func(
Expand All @@ -109,9 +108,12 @@ def new_func(self: types.FunctionType) -> Any:
del new_attrs[attr_name]

# We create the class as normal, except we use our new attributes.
return type.__new__(klass, name, bases, new_attrs)
return type.__new__(cls, name, bases, new_attrs)


# This is a class decorator that actually applies the above metaclass.
def parametrize_class(klass: type) -> ParametrizingMetaclass:
return ParametrizingMetaclass(klass.__name__, klass.__bases__, klass.__dict__)
_ClassT = TypeVar("_ClassT", bound=type)


def parametrize_class(klass: _ClassT) -> _ClassT:
return cast(_ClassT, ParametrizingMetaclass(klass.__name__, klass.__bases__, klass.__dict__))
67 changes: 67 additions & 0 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,7 @@ def test_request_body_fuzz(self) -> None:

# Pick what we're supposed to do.
choice = random.choice([1, 2, 3])
msg = ""
if choice == 1:
# Add a random byte.
i = random.randrange(len(test_data))
Expand Down Expand Up @@ -1661,6 +1662,72 @@ def test_parse_form_negative_content_length(self) -> None:
)


# A field followed by a file, framed entirely with bare LF instead of CRLF.
_BARE_LF_BODY = (
b"--boundary\n"
b'Content-Disposition: form-data; name="field"\n'
b"\n"
b"test1\n"
b"--boundary\n"
b'Content-Disposition: form-data; name="file"; filename="file.txt"\n'
b"Content-Type: text/plain\n"
b"\n"
b"test2\n"
b"--boundary--"
)


def _parse_bare_lf(chunks: list[bytes]) -> tuple[list[Field], list[File]]:
"""Feed ``chunks`` to a bare-LF-enabled FormParser and return the parsed fields/files."""
fields: list[Field] = []
files: list[File] = []
parser = FormParser(
"multipart/form-data", fields.append, files.append, boundary="boundary", config={"ALLOW_BARE_LF": True}
)
for chunk in chunks:
parser.write(chunk)
parser.finalize()
return fields, files


def _assert_bare_lf_parsed(fields: list[Field], files: list[File]) -> None:
"""Assert the field + file framed in ``_BARE_LF_BODY`` were parsed correctly."""
assert [(f.field_name, f.value) for f in fields] == [(b"field", b"test1")]
assert len(files) == 1
assert files[0].field_name == b"file"
assert files[0].file_name == b"file.txt"
files[0].file_object.seek(0)
assert files[0].file_object.read() == b"test2"
files[0].close()


def test_bare_lf_boundary_rejected_by_default() -> None:
"""A bare-LF boundary is a parse error unless explicitly allowed."""
parser = FormParser("multipart/form-data", lambda _: None, lambda _: None, boundary="boundary")
with pytest.raises(MultipartParseError):
parser.write(_BARE_LF_BODY)


def test_bare_lf_body_parses_across_every_chunk_split() -> None:
"""The LF body parses identically however it is split, exercising the boundary look-behind."""
for first, last in split_all(_BARE_LF_BODY):
_assert_bare_lf_parsed(*_parse_bare_lf([first, last]))


def test_parse_form_forwards_config_to_enable_bare_lf() -> None:
"""parse_form forwards config, so the bare-LF opt-in is reachable from the top-level helper."""
fields: list[Field] = []
files: list[File] = []
parse_form(
{"Content-Type": b"multipart/form-data; boundary=boundary"},
BytesIO(_BARE_LF_BODY),
fields.append,
files.append,
config={"ALLOW_BARE_LF": True},
)
_assert_bare_lf_parsed(fields, files)


def suite() -> unittest.TestSuite:
suite = unittest.TestSuite()
suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestFile))
Expand Down
Loading