diff --git a/python_multipart/multipart.py b/python_multipart/multipart.py index d50e5b3..1e503ad 100644 --- a/python_multipart/multipart.py +++ b/python_multipart/multipart.py @@ -55,6 +55,7 @@ class FormParserConfig(FileConfig): MAX_BODY_SIZE: float MAX_HEADER_COUNT: int MAX_HEADER_SIZE: int + ALLOW_BARE_LF: bool CallbackName: TypeAlias = Literal[ "start", @@ -997,6 +998,9 @@ class MultipartParser(BaseParser): max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded. max_header_count: The maximum number of headers allowed per part. max_header_size: The maximum size of a single header line (excluding the trailing CRLF). + allow_bare_lf: Accept bare ``\\n`` instead of ``\\r\\n`` as the framing line terminator. Defaults to ``False`` + (strict RFC 2046). The convention is detected from the opening boundary and applied to the whole message; + part data is never altered. """ # noqa: E501 def __init__( @@ -1007,6 +1011,7 @@ def __init__( *, max_header_count: int = DEFAULT_MAX_HEADER_COUNT, max_header_size: int = DEFAULT_MAX_HEADER_SIZE, + allow_bare_lf: bool = False, ) -> None: # Initialize parser state. super().__init__() @@ -1017,7 +1022,7 @@ def __init__( if not isinstance(max_size, Number) or max_size < 1: raise ValueError("max_size must be a positive number, not %r" % max_size) - self.max_size = max_size + self.max_size: int | float = max_size self._current_size = 0 self.max_header_count = max_header_count @@ -1029,11 +1034,15 @@ def __init__( # Setup marks. These are used to track the state of data received. self.marks: dict[str, int] = {} + self.allow_bare_lf = allow_bare_lf + self._lf_only = False + # Save our boundary. if isinstance(boundary, str): # pragma: no cover boundary = boundary.encode("latin-1") if len(boundary) > MAX_BOUNDARY_LENGTH: raise FormParserError(f"Boundary length {len(boundary)} exceeds maximum of {MAX_BOUNDARY_LENGTH}") + self._boundary_token = boundary self.boundary = b"\r\n--" + boundary def write(self, data: bytes) -> int: @@ -1085,6 +1094,9 @@ def _internal_write(self, data: bytes, length: int) -> int: current_header_count = self._current_header_count current_header_size = self._current_header_size + allow_bare_lf = self.allow_bare_lf + lf_only = self._lf_only + # Our index defaults to 0. i = 0 @@ -1129,10 +1141,10 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No if lookbehind_len <= boundary_length: self.callback(name, boundary, 0, lookbehind_len) elif self.flags & FLAG_PART_BOUNDARY: - lookback = boundary + b"\r\n" + lookback = boundary + (b"\n" if self._lf_only else b"\r\n") self.callback(name, lookback, 0, lookbehind_len) elif self.flags & FLAG_LAST_BOUNDARY: - lookback = boundary + b"--\r\n" + lookback = boundary + (b"--\n" if self._lf_only else b"--\r\n") self.callback(name, lookback, 0, lookbehind_len) else: # pragma: no cover (error case) self.logger.warning("Look-back buffer error") @@ -1175,14 +1187,24 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No if c == HYPHEN: # Potential empty message. state = MultipartState.END_BOUNDARY - elif c != CR: - # Error! + index += 1 + elif c == CR: + index += 1 + elif allow_bare_lf and c == LF: + # Bare-LF body: use the \n-prefixed boundary. + lf_only = self._lf_only = True + boundary = self.boundary = b"\n--" + self._boundary_token + boundary_length = len(boundary) + index = 0 + self.callback("part_begin") + current_header_count = 0 + current_header_size = 0 + state = MultipartState.HEADER_FIELD_START + else: msg = "Did not find CR at end of boundary (%d)" % (i,) self.logger.warning(msg) raise MultipartParseError(msg, offset=i) - index += 1 - elif index == boundary_length - 1: if c != LF: msg = "Did not find LF at end of boundary (%d)" % (i,) @@ -1214,8 +1236,9 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No # Mark the start of a header field here, reset the index, and # continue parsing our header field. index = 0 + nl = LF if lf_only else CR - if c != CR: + if c != nl: current_header_count += 1 if current_header_count > self.max_header_count: raise MultipartParseError("Maximum header count exceeded", offset=i) @@ -1224,11 +1247,9 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No # Set a mark of our header field. set_mark("header_field") - # Notify that we're starting a header if the next character is - # not a CR; a CR at the beginning of the header will cause us - # to stop parsing headers in the MultipartState.HEADER_FIELD state, - # below. - if c != CR: + # Notify that we're starting a header unless this is the blank + # line that ends the headers (handled in HEADER_FIELD below). + if c != nl: self.callback("header_begin") # Move to parsing header fields. @@ -1236,12 +1257,15 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No i -= 1 elif state == MultipartState.HEADER_FIELD: - # If we've reached a CR at the beginning of a header, it means - # that we've reached the second of 2 newlines, and so there are - # no more headers to parse. - if c == CR and index == 0: + # A line terminator at the start of a header is the blank line + # that ends the headers. + if index == 0 and c == (LF if lf_only else CR): delete_mark("header_field") - state = MultipartState.HEADERS_ALMOST_DONE + if lf_only: + self.callback("headers_finished") + state = MultipartState.PART_DATA_START + else: + state = MultipartState.HEADERS_ALMOST_DONE i += 1 continue @@ -1295,17 +1319,16 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No i -= 1 elif state == MultipartState.HEADER_VALUE: - # The value runs until the terminating CR; jump straight to it - # instead of inspecting every byte. - cr = data.find(b"\r", i, length) - end = cr if cr != -1 else length + # The value runs until its terminator (CR for CRLF, LF for bare LF). + term = data.find(b"\n" if lf_only else b"\r", i, length) + end = term if term != -1 else length advance_header_size(end - i) - if cr != -1: - i = cr + if term != -1: + i = term data_callback("header_value", i) self.callback("header_end") current_header_size = 0 - state = MultipartState.HEADER_VALUE_ALMOST_DONE + state = MultipartState.HEADER_FIELD_START if lf_only else MultipartState.HEADER_VALUE_ALMOST_DONE else: i = length @@ -1387,12 +1410,24 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No # Our index is equal to the length of our boundary! elif index == boundary_length: + if lf_only and c == LF: + # Bare-LF part boundary: the single LF terminates it. + data_callback("part_data", i - index) + self.callback("part_end") + self.callback("part_begin") + current_header_count = 0 + current_header_size = 0 + index = 0 + state = MultipartState.HEADER_FIELD_START + i += 1 + continue + # First we increment it. index += 1 # Now, if we've reached a newline, we need to set this as # the potential end of our boundary. - if c == CR: + if c == CR and not lf_only: flags |= FLAG_PART_BOUNDARY # Otherwise, if this is a hyphen, we might be at the last @@ -1561,6 +1596,8 @@ class FormParser: "UPLOAD_KEEP_EXTENSIONS": False, # Error on invalid Content-Transfer-Encoding? "UPLOAD_ERROR_ON_BAD_CTE": False, + # Accept bare LF (instead of CRLF) as the framing line terminator. + "ALLOW_BARE_LF": False, } def __init__( @@ -1795,6 +1832,7 @@ def _on_end() -> None: max_size=self.config["MAX_BODY_SIZE"], max_header_count=self.config["MAX_HEADER_COUNT"], max_header_size=self.config["MAX_HEADER_SIZE"], + allow_bare_lf=self.config["ALLOW_BARE_LF"], ) else: @@ -1876,6 +1914,7 @@ def parse_form( on_field: Callable[[Field], None] | None, on_file: Callable[[File], None] | None, chunk_size: int = 1048576, + config: dict[Any, Any] = {}, ) -> None: """This function is useful if you just want to parse a request body, without too much work. Pass it a dictionary-like object of the request's @@ -1889,12 +1928,13 @@ def parse_form( on_file: Callback to call with each parsed file. chunk_size: The maximum size to read from the input stream and write to the parser at one time. Defaults to 1 MiB. + config: Configuration variables to pass to the FormParser (e.g. ``{"ALLOW_BARE_LF": True}``). """ if chunk_size < 1: raise ValueError(f"chunk_size must be a positive number, not {chunk_size!r}") # Create our form parser. - parser = create_form_parser(headers, on_field, on_file) + parser = create_form_parser(headers, on_field, on_file, config) # Read chunks of 1MiB and write to the parser, but never read more than # the given Content-Length, if any. diff --git a/tests/compat.py b/tests/compat.py index 3c60fac..bc41801 100644 --- a/tests/compat.py +++ b/tests/compat.py @@ -5,7 +5,7 @@ import re import sys import types -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar, cast if TYPE_CHECKING: from collections.abc import Callable @@ -61,7 +61,7 @@ def decorator(func: Callable[..., Any]) -> Callable[..., Any]: class ParametrizingMetaclass(type): IDENTIFIER_RE = re.compile("[^A-Za-z0-9]") - def __new__(klass, name: str, bases: tuple[type, ...], attrs: types.MappingProxyType[str, Any]) -> type: + def __new__(cls, name: str, bases: tuple[type, ...], attrs: types.MappingProxyType[str, Any]) -> type: new_attrs = attrs.copy() for attr_name, attr in attrs.items(): # We only care about functions @@ -74,15 +74,14 @@ def __new__(klass, name: str, bases: tuple[type, ...], attrs: types.MappingProxy continue # Create multiple copies of the function. - for _, values in enumerate(param_values): + for i, values in enumerate(param_values): assert len(param_names) == len(values) # Get a repr of the values, and fix it to be a valid identifier - human = "_".join([klass.IDENTIFIER_RE.sub("", repr(x)) for x in values]) + human = "_".join([cls.IDENTIFIER_RE.sub("", repr(x)) for x in values]) - # Create a new name. - # new_name = attr.__name__ + "_%d" % i - new_name = attr.__name__ + "__" + human + # Cap length so multi-MB params can't overflow Windows' PYTEST_CURRENT_TEST env var. + new_name = f"{attr.__name__}__{i}_{human[:64]}" # Create a replacement function. def create_new_func( @@ -109,9 +108,12 @@ def new_func(self: types.FunctionType) -> Any: del new_attrs[attr_name] # We create the class as normal, except we use our new attributes. - return type.__new__(klass, name, bases, new_attrs) + return type.__new__(cls, name, bases, new_attrs) # This is a class decorator that actually applies the above metaclass. -def parametrize_class(klass: type) -> ParametrizingMetaclass: - return ParametrizingMetaclass(klass.__name__, klass.__bases__, klass.__dict__) +_ClassT = TypeVar("_ClassT", bound=type) + + +def parametrize_class(klass: _ClassT) -> _ClassT: + return cast(_ClassT, ParametrizingMetaclass(klass.__name__, klass.__bases__, klass.__dict__)) diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 8dda87c..839534b 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -1130,6 +1130,7 @@ def test_request_body_fuzz(self) -> None: # Pick what we're supposed to do. choice = random.choice([1, 2, 3]) + msg = "" if choice == 1: # Add a random byte. i = random.randrange(len(test_data)) @@ -1661,6 +1662,72 @@ def test_parse_form_negative_content_length(self) -> None: ) +# A field followed by a file, framed entirely with bare LF instead of CRLF. +_BARE_LF_BODY = ( + b"--boundary\n" + b'Content-Disposition: form-data; name="field"\n' + b"\n" + b"test1\n" + b"--boundary\n" + b'Content-Disposition: form-data; name="file"; filename="file.txt"\n' + b"Content-Type: text/plain\n" + b"\n" + b"test2\n" + b"--boundary--" +) + + +def _parse_bare_lf(chunks: list[bytes]) -> tuple[list[Field], list[File]]: + """Feed ``chunks`` to a bare-LF-enabled FormParser and return the parsed fields/files.""" + fields: list[Field] = [] + files: list[File] = [] + parser = FormParser( + "multipart/form-data", fields.append, files.append, boundary="boundary", config={"ALLOW_BARE_LF": True} + ) + for chunk in chunks: + parser.write(chunk) + parser.finalize() + return fields, files + + +def _assert_bare_lf_parsed(fields: list[Field], files: list[File]) -> None: + """Assert the field + file framed in ``_BARE_LF_BODY`` were parsed correctly.""" + assert [(f.field_name, f.value) for f in fields] == [(b"field", b"test1")] + assert len(files) == 1 + assert files[0].field_name == b"file" + assert files[0].file_name == b"file.txt" + files[0].file_object.seek(0) + assert files[0].file_object.read() == b"test2" + files[0].close() + + +def test_bare_lf_boundary_rejected_by_default() -> None: + """A bare-LF boundary is a parse error unless explicitly allowed.""" + parser = FormParser("multipart/form-data", lambda _: None, lambda _: None, boundary="boundary") + with pytest.raises(MultipartParseError): + parser.write(_BARE_LF_BODY) + + +def test_bare_lf_body_parses_across_every_chunk_split() -> None: + """The LF body parses identically however it is split, exercising the boundary look-behind.""" + for first, last in split_all(_BARE_LF_BODY): + _assert_bare_lf_parsed(*_parse_bare_lf([first, last])) + + +def test_parse_form_forwards_config_to_enable_bare_lf() -> None: + """parse_form forwards config, so the bare-LF opt-in is reachable from the top-level helper.""" + fields: list[Field] = [] + files: list[File] = [] + parse_form( + {"Content-Type": b"multipart/form-data; boundary=boundary"}, + BytesIO(_BARE_LF_BODY), + fields.append, + files.append, + config={"ALLOW_BARE_LF": True}, + ) + _assert_bare_lf_parsed(fields, files) + + def suite() -> unittest.TestSuite: suite = unittest.TestSuite() suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestFile))