diff --git a/dvuploader/directupload.py b/dvuploader/directupload.py
index 9e2d974..0d8a666 100644
--- a/dvuploader/directupload.py
+++ b/dvuploader/directupload.py
@@ -2,11 +2,12 @@
 import json
 import os
 from io import BytesIO
-from typing import AsyncGenerator, Dict, List, Optional, Tuple
+from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
 from urllib.parse import urljoin
 
 import aiofiles
 import httpx
+from aiofiles.threadpool.binary import AsyncBufferedReader
 from rich.progress import Progress, TaskID
 
 from dvuploader.file import File
@@ -33,6 +34,8 @@
 UPLOAD_ENDPOINT = "/api/datasets/:persistentId/addFiles?persistentId="
 REPLACE_ENDPOINT = "/api/datasets/:persistentId/replaceFiles?persistentId="
 
+DEFAULT_CHUNK_SIZE = 1024 * 1024  # 1MB
+
 # Initialize logging
 init_logging()
 
@@ -399,34 +402,25 @@ async def _chunked_upload(
     )
 
     async with aiofiles.open(file.filepath, "rb") as f:
-        chunk = await f.read(chunk_size)
-        e_tags.append(
-            await _upload_chunk(
-                session=session,
-                url=next(urls),
-                file=BytesIO(chunk),
-                progress=progress,
-                pbar=pbar,
-                hash_func=file.checksum._hash_fun,
+        filesize = os.path.getsize(file.filepath)
+        current_position = 0
+
+        while current_position < filesize:
+            current_chunk_size = min(chunk_size, filesize - current_position)
+
+            e_tags.append(
+                await _upload_chunk(
+                    session=session,
+                    url=next(urls),
+                    file=f,
+                    progress=progress,
+                    pbar=pbar,
+                    hash_func=file.checksum._hash_fun,
+                    chunk_size=current_chunk_size,
+                )
             )
-        )
-
-        while chunk:
-            chunk = await f.read(chunk_size)
-            if not chunk:
-                break
-            else:
-                e_tags.append(
-                    await _upload_chunk(
-                        session=session,
-                        url=next(urls),
-                        file=BytesIO(chunk),
-                        progress=progress,
-                        pbar=pbar,
-                        hash_func=file.checksum._hash_fun,
-                    )
-                )
+            current_position += current_chunk_size
 
     return e_tags
 
 
@@ -452,10 +446,11 @@ def _validate_ticket_response(response: Dict) -> None:
 async def _upload_chunk(
     session: httpx.AsyncClient,
     url: str,
-    file: BytesIO,
+    file: Union[BytesIO, AsyncBufferedReader],
     progress: Progress,
     pbar: TaskID,
     hash_func,
+    chunk_size: int,
 ):
     """
     Upload a single chunk of data.
@@ -467,6 +462,7 @@
         progress (Progress): Progress tracking object.
         pbar (TaskID): Progress bar task ID.
         hash_func: Hash function for checksum.
+        chunk_size (int): Size of chunk to upload.
 
     Returns:
         str: ETag from server response.
@@ -475,8 +471,13 @@
     if TESTING:
         url = url.replace("localstack", "localhost", 1)
 
+    if isinstance(file, BytesIO):
+        file_size = len(file.getvalue())
+    else:
+        file_size = chunk_size
+
     headers = {
-        "Content-length": str(len(file.getvalue())),
+        "Content-length": str(file_size),
     }
 
     params = {
@@ -487,6 +488,7 @@
             progress=progress,
             pbar=pbar,
             hash_func=hash_func,
+            chunk_size=chunk_size,
         ),
     }
 
@@ -664,10 +666,11 @@ async def _multipart_json_data_request(
 
 
 async def upload_bytes(
-    file: BytesIO,
+    file: Union[BytesIO, AsyncBufferedReader],
     progress: Progress,
     pbar: TaskID,
     hash_func,
+    chunk_size: Optional[int] = None,
 ) -> AsyncGenerator[bytes, None]:
     """
     Generate chunks of file data for upload.
@@ -681,12 +684,27 @@
     Yields:
         bytes: Next chunk of file data.
     """
+    read_bytes = 0
     while True:
-        data = file.read(1024 * 1024)  # 1MB
+        if chunk_size is not None and read_bytes >= chunk_size:
+            break
+
+        if isinstance(file, AsyncBufferedReader):
+            data = await file.read(DEFAULT_CHUNK_SIZE)  # 1MB
+        else:
+            data = file.read(DEFAULT_CHUNK_SIZE)  # 1MB
 
        if not data:
             break
 
+        if chunk_size is not None:
+            remaining = chunk_size - read_bytes
+            if remaining <= 0:
+                break
+            if len(data) > remaining:
+                data = data[:remaining]
+            read_bytes += len(data)
+
         # Update the hash function with the data
         hash_func.update(data)