dvuploader/directupload.py — 80 changes: 49 additions & 31 deletions
@@ -2,11 +2,12 @@
 import json
 import os
 from io import BytesIO
-from typing import AsyncGenerator, Dict, List, Optional, Tuple
+from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
 from urllib.parse import urljoin

 import aiofiles
 import httpx
+from aiofiles.threadpool.binary import AsyncBufferedReader
 from rich.progress import Progress, TaskID

 from dvuploader.file import File
@@ -33,6 +34,8 @@
 UPLOAD_ENDPOINT = "/api/datasets/:persistentId/addFiles?persistentId="
 REPLACE_ENDPOINT = "/api/datasets/:persistentId/replaceFiles?persistentId="

+DEFAULT_CHUNK_SIZE = 1024 * 1024  # 1MB
+
 # Initialize logging
 init_logging()

@@ -399,34 +402,25 @@ async def _chunked_upload(
     )

     async with aiofiles.open(file.filepath, "rb") as f:
-        chunk = await f.read(chunk_size)
-        e_tags.append(
-            await _upload_chunk(
-                session=session,
-                url=next(urls),
-                file=BytesIO(chunk),
-                progress=progress,
-                pbar=pbar,
-                hash_func=file.checksum._hash_fun,
-            )
-        )
-
-        while chunk:
-            chunk = await f.read(chunk_size)
-
-            if not chunk:
-                break
-            else:
-                e_tags.append(
-                    await _upload_chunk(
-                        session=session,
-                        url=next(urls),
-                        file=BytesIO(chunk),
-                        progress=progress,
-                        pbar=pbar,
-                        hash_func=file.checksum._hash_fun,
-                    )
-                )
+        filesize = os.path.getsize(file.filepath)
+        current_position = 0
+
+        while current_position < filesize:
+            current_chunk_size = min(chunk_size, filesize - current_position)
+
+            e_tags.append(
+                await _upload_chunk(
+                    session=session,
+                    url=next(urls),
+                    file=f,
+                    progress=progress,
+                    pbar=pbar,
+                    hash_func=file.checksum._hash_fun,
+                    chunk_size=current_chunk_size,
+                )
+            )
+
+            current_position += current_chunk_size

     return e_tags

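Read in isolation, the new loop amounts to: determine the file size once, then walk the file in fixed-size steps, handing the open async handle and the size of the current chunk to the per-chunk uploader. A minimal sketch under those assumptions follows; the names `upload_one_chunk` and `presigned_urls` are hypothetical placeholders, not dvuploader APIs, and only the size-bounded loop mirrors the change above.

```python
import os
from typing import Iterator, List

import aiofiles

CHUNK_SIZE = 1024 * 1024  # 1MB


async def upload_in_chunks(path: str, presigned_urls: Iterator[str], upload_one_chunk) -> List[str]:
    """Walk a file in fixed-size steps and upload each chunk to the next presigned URL."""
    e_tags: List[str] = []
    filesize = os.path.getsize(path)
    position = 0

    async with aiofiles.open(path, "rb") as handle:
        while position < filesize:
            # The last chunk may be smaller than CHUNK_SIZE.
            current = min(CHUNK_SIZE, filesize - position)
            e_tags.append(await upload_one_chunk(next(presigned_urls), handle, current))
            position += current

    return e_tags
```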
@@ -452,10 +446,11 @@ def _validate_ticket_response(response: Dict) -> None:
 async def _upload_chunk(
     session: httpx.AsyncClient,
     url: str,
-    file: BytesIO,
+    file: Union[BytesIO, AsyncBufferedReader],
     progress: Progress,
     pbar: TaskID,
     hash_func,
+    chunk_size: int,
 ):
     """
     Upload a single chunk of data.
@@ -467,6 +462,7 @@ async def _upload_chunk(
         progress (Progress): Progress tracking object.
         pbar (TaskID): Progress bar task ID.
         hash_func: Hash function for checksum.
+        chunk_size (int): Size of chunk to upload.

     Returns:
         str: ETag from server response.
@@ -475,8 +471,13 @@ async def _upload_chunk(
     if TESTING:
         url = url.replace("localstack", "localhost", 1)

+    if isinstance(file, BytesIO):
+        file_size = len(file.getvalue())
+    else:
+        file_size = chunk_size
+
     headers = {
-        "Content-length": str(len(file.getvalue())),
+        "Content-length": str(file_size),
     }

     params = {
@@ -487,6 +488,7 @@ async def _upload_chunk(
             progress=progress,
             pbar=pbar,
             hash_func=hash_func,
+            chunk_size=chunk_size,
         ),
     }

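The Content-Length logic above comes down to this: an in-memory `BytesIO` chunk knows its exact size, while a chunk streamed from an async file handle is only known by the `chunk_size` the caller requested. A minimal sketch of that decision, using the same types the diff imports (`chunk_content_length` is a hypothetical helper name, not part of dvuploader):

```python
from io import BytesIO
from typing import Union

from aiofiles.threadpool.binary import AsyncBufferedReader


def chunk_content_length(file: Union[BytesIO, AsyncBufferedReader], chunk_size: int) -> str:
    # An in-memory buffer holds the whole chunk, so its length is authoritative.
    if isinstance(file, BytesIO):
        return str(len(file.getvalue()))
    # An async reader streams the body, so the requested chunk size is declared.
    return str(chunk_size)
```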
@@ -664,10 +666,11 @@ async def _multipart_json_data_request(


 async def upload_bytes(
-    file: BytesIO,
+    file: Union[BytesIO, AsyncBufferedReader],
     progress: Progress,
     pbar: TaskID,
     hash_func,
+    chunk_size: Optional[int] = None,
 ) -> AsyncGenerator[bytes, None]:
     """
     Generate chunks of file data for upload.
@@ -681,12 +684,27 @@ async def upload_bytes(
     Yields:
         bytes: Next chunk of file data.
     """
+    read_bytes = 0
     while True:
-        data = file.read(1024 * 1024)  # 1MB
+        if chunk_size is not None and read_bytes >= chunk_size:
+            break
+
+        if isinstance(file, AsyncBufferedReader):
+            data = await file.read(DEFAULT_CHUNK_SIZE)  # 1MB
+        else:
+            data = file.read(DEFAULT_CHUNK_SIZE)  # 1MB

         if not data:
             break

+        if chunk_size is not None:
+            remaining = chunk_size - read_bytes
+            if remaining <= 0:
+                break
+            if len(data) > remaining:
+                data = data[:remaining]
+        read_bytes += len(data)
+
         # Update the hash function with the data
         hash_func.update(data)

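With these changes, `upload_bytes` yields at most `chunk_size` bytes from a shared reader while feeding the running checksum. A self-contained sketch of the same idea is below; `bounded_chunks` is a hypothetical name, and unlike the diff it caps each read at the remaining budget instead of trimming after the fact, so it never reads past the chunk boundary of a shared file handle.

```python
from typing import AsyncGenerator, Optional

DEFAULT_CHUNK_SIZE = 1024 * 1024  # 1MB


async def bounded_chunks(
    reader,                        # any object with an async read(), e.g. an aiofiles handle
    hash_func,                     # e.g. hashlib.md5()
    chunk_size: Optional[int] = None,
) -> AsyncGenerator[bytes, None]:
    """Yield up to chunk_size bytes from reader, updating hash_func along the way."""
    read_bytes = 0
    while True:
        want = DEFAULT_CHUNK_SIZE
        if chunk_size is not None:
            remaining = chunk_size - read_bytes
            if remaining <= 0:
                break
            # Never request more than the remaining budget, so a shared file
            # handle is left positioned exactly at the next chunk boundary.
            want = min(want, remaining)
        data = await reader.read(want)
        if not data:
            break
        read_bytes += len(data)
        hash_func.update(data)
        yield data
```

Such a generator can then be streamed as the request body, since httpx accepts an async byte iterator as request content.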