11import typing
2+ from dataclasses import dataclass , field
23from enum import Enum
34from tempfile import SpooledTemporaryFile
45from urllib .parse import unquote_plus
@@ -21,15 +22,13 @@ class FormMessage(Enum):
2122 END = 5
2223
2324
24- class MultiPartMessage (Enum ):
25- PART_BEGIN = 1
26- PART_DATA = 2
27- PART_END = 3
28- HEADER_FIELD = 4
29- HEADER_VALUE = 5
30- HEADER_END = 6
31- HEADERS_FINISHED = 7
32- END = 8
25+ @dataclass
26+ class MultipartPart :
27+ content_disposition : typing .Optional [bytes ] = None
28+ field_name : str = ""
29+ data : bytes = b""
30+ file : typing .Optional [UploadFile ] = None
31+ item_headers : typing .List [typing .Tuple [bytes , bytes ]] = field (default_factory = list )
3332
3433
3534def _user_safe_decode (src : bytes , codec : str ) -> str :
@@ -120,53 +119,115 @@ class MultiPartParser:
120119 max_file_size = 1024 * 1024
121120
122121 def __init__ (
123- self , headers : Headers , stream : typing .AsyncGenerator [bytes , None ]
122+ self ,
123+ headers : Headers ,
124+ stream : typing .AsyncGenerator [bytes , None ],
125+ * ,
126+ max_files : typing .Union [int , float ] = 1000 ,
127+ max_fields : typing .Union [int , float ] = 1000 ,
124128 ) -> None :
125129 assert (
126130 multipart is not None
127131 ), "The `python-multipart` library must be installed to use form parsing."
128132 self .headers = headers
129133 self .stream = stream
130- self .messages : typing .List [typing .Tuple [MultiPartMessage , bytes ]] = []
134+ self .max_files = max_files
135+ self .max_fields = max_fields
136+ self .items : typing .List [typing .Tuple [str , typing .Union [str , UploadFile ]]] = []
137+ self ._current_files = 0
138+ self ._current_fields = 0
139+ self ._current_partial_header_name : bytes = b""
140+ self ._current_partial_header_value : bytes = b""
141+ self ._current_part = MultipartPart ()
142+ self ._charset = ""
143+ self ._file_parts_to_write : typing .List [typing .Tuple [MultipartPart , bytes ]] = []
144+ self ._file_parts_to_finish : typing .List [MultipartPart ] = []
131145
132146 def on_part_begin (self ) -> None :
133- message = (MultiPartMessage .PART_BEGIN , b"" )
134- self .messages .append (message )
147+ self ._current_part = MultipartPart ()
135148
136149 def on_part_data (self , data : bytes , start : int , end : int ) -> None :
137- message = (MultiPartMessage .PART_DATA , data [start :end ])
138- self .messages .append (message )
150+ message_bytes = data [start :end ]
151+ if self ._current_part .file is None :
152+ self ._current_part .data += message_bytes
153+ else :
154+ self ._file_parts_to_write .append ((self ._current_part , message_bytes ))
139155
140156 def on_part_end (self ) -> None :
141- message = (MultiPartMessage .PART_END , b"" )
142- self .messages .append (message )
157+ if self ._current_part .file is None :
158+ self .items .append (
159+ (
160+ self ._current_part .field_name ,
161+ _user_safe_decode (self ._current_part .data , self ._charset ),
162+ )
163+ )
164+ else :
165+ self ._file_parts_to_finish .append (self ._current_part )
166+ # The file can be added to the items right now even though it's not
167+ # finished yet, because it will be finished in the `parse()` method, before
168+ # self.items is used in the return value.
169+ self .items .append ((self ._current_part .field_name , self ._current_part .file ))
143170
144171 def on_header_field (self , data : bytes , start : int , end : int ) -> None :
145- message = (MultiPartMessage .HEADER_FIELD , data [start :end ])
146- self .messages .append (message )
172+ self ._current_partial_header_name += data [start :end ]
147173
148174 def on_header_value (self , data : bytes , start : int , end : int ) -> None :
149- message = (MultiPartMessage .HEADER_VALUE , data [start :end ])
150- self .messages .append (message )
175+ self ._current_partial_header_value += data [start :end ]
151176
152177 def on_header_end (self ) -> None :
153- message = (MultiPartMessage .HEADER_END , b"" )
154- self .messages .append (message )
178+ field = self ._current_partial_header_name .lower ()
179+ if field == b"content-disposition" :
180+ self ._current_part .content_disposition = self ._current_partial_header_value
181+ self ._current_part .item_headers .append (
182+ (field , self ._current_partial_header_value )
183+ )
184+ self ._current_partial_header_name = b""
185+ self ._current_partial_header_value = b""
155186
156187 def on_headers_finished (self ) -> None :
157- message = (MultiPartMessage .HEADERS_FINISHED , b"" )
158- self .messages .append (message )
188+ disposition , options = parse_options_header (
189+ self ._current_part .content_disposition
190+ )
191+ try :
192+ self ._current_part .field_name = _user_safe_decode (
193+ options [b"name" ], self ._charset
194+ )
195+ except KeyError :
196+ raise MultiPartException (
197+ 'The Content-Disposition header field "name" must be ' "provided."
198+ )
199+ if b"filename" in options :
200+ self ._current_files += 1
201+ if self ._current_files > self .max_files :
202+ raise MultiPartException (
203+ f"Too many files. Maximum number of files is { self .max_files } ."
204+ )
205+ filename = _user_safe_decode (options [b"filename" ], self ._charset )
206+ tempfile = SpooledTemporaryFile (max_size = self .max_file_size )
207+ self ._current_part .file = UploadFile (
208+ file = tempfile , # type: ignore[arg-type]
209+ size = 0 ,
210+ filename = filename ,
211+ headers = Headers (raw = self ._current_part .item_headers ),
212+ )
213+ else :
214+ self ._current_fields += 1
215+ if self ._current_fields > self .max_fields :
216+ raise MultiPartException (
217+ f"Too many fields. Maximum number of fields is { self .max_fields } ."
218+ )
219+ self ._current_part .file = None
159220
160221 def on_end (self ) -> None :
161- message = (MultiPartMessage .END , b"" )
162- self .messages .append (message )
222+ pass
163223
164224 async def parse (self ) -> FormData :
165225 # Parse the Content-Type header to get the multipart boundary.
166226 _ , params = parse_options_header (self .headers ["Content-Type" ])
167227 charset = params .get (b"charset" , "utf-8" )
168228 if type (charset ) == bytes :
169229 charset = charset .decode ("latin-1" )
230+ self ._charset = charset
170231 try :
171232 boundary = params [b"boundary" ]
172233 except KeyError :
@@ -186,68 +247,21 @@ async def parse(self) -> FormData:
186247
187248 # Create the parser.
188249 parser = multipart .MultipartParser (boundary , callbacks )
189- header_field = b""
190- header_value = b""
191- content_disposition = None
192- field_name = ""
193- data = b""
194- file : typing .Optional [UploadFile ] = None
195-
196- items : typing .List [typing .Tuple [str , typing .Union [str , UploadFile ]]] = []
197- item_headers : typing .List [typing .Tuple [bytes , bytes ]] = []
198-
199250 # Feed the parser with data from the request.
200251 async for chunk in self .stream :
201252 parser .write (chunk )
202- messages = list (self .messages )
203- self .messages .clear ()
204- for message_type , message_bytes in messages :
205- if message_type == MultiPartMessage .PART_BEGIN :
206- content_disposition = None
207- data = b""
208- item_headers = []
209- elif message_type == MultiPartMessage .HEADER_FIELD :
210- header_field += message_bytes
211- elif message_type == MultiPartMessage .HEADER_VALUE :
212- header_value += message_bytes
213- elif message_type == MultiPartMessage .HEADER_END :
214- field = header_field .lower ()
215- if field == b"content-disposition" :
216- content_disposition = header_value
217- item_headers .append ((field , header_value ))
218- header_field = b""
219- header_value = b""
220- elif message_type == MultiPartMessage .HEADERS_FINISHED :
221- disposition , options = parse_options_header (content_disposition )
222- try :
223- field_name = _user_safe_decode (options [b"name" ], charset )
224- except KeyError :
225- raise MultiPartException (
226- 'The Content-Disposition header field "name" must be '
227- "provided."
228- )
229- if b"filename" in options :
230- filename = _user_safe_decode (options [b"filename" ], charset )
231- tempfile = SpooledTemporaryFile (max_size = self .max_file_size )
232- file = UploadFile (
233- file = tempfile , # type: ignore[arg-type]
234- size = 0 ,
235- filename = filename ,
236- headers = Headers (raw = item_headers ),
237- )
238- else :
239- file = None
240- elif message_type == MultiPartMessage .PART_DATA :
241- if file is None :
242- data += message_bytes
243- else :
244- await file .write (message_bytes )
245- elif message_type == MultiPartMessage .PART_END :
246- if file is None :
247- items .append ((field_name , _user_safe_decode (data , charset )))
248- else :
249- await file .seek (0 )
250- items .append ((field_name , file ))
253+ # Write file data, it needs to use await with the UploadFile methods that
254+ # call the corresponding file methods *in a threadpool*, otherwise, if
255+ # they were called directly in the callback methods above (regular,
256+ # non-async functions), that would block the event loop in the main thread.
257+ for part , data in self ._file_parts_to_write :
258+ assert part .file # for type checkers
259+ await part .file .write (data )
260+ for part in self ._file_parts_to_finish :
261+ assert part .file # for type checkers
262+ await part .file .seek (0 )
263+ self ._file_parts_to_write .clear ()
264+ self ._file_parts_to_finish .clear ()
251265
252266 parser .finalize ()
253- return FormData (items )
267+ return FormData (self . items )
0 commit comments