Files
artifact-store/tests/unit/test_storage_s3.py

197 lines
6.2 KiB
Python

"""S3-compatible backend tests (ARTIFACT-STORE-WP-0004)."""
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Any
import pytest
from artifactstore.identity import ContentAddress, digest_bytes
from artifactstore.storage import ObjectNotFoundError, S3Backend, S3BackendConfig
async def _stream(data: bytes, chunk_size: int = 4) -> AsyncIterator[bytes]:
    """Asynchronously yield *data* as consecutive slices of ``chunk_size`` bytes.

    The final slice may be shorter; empty *data* yields nothing.
    """
    starts = range(0, len(data), chunk_size)
    for start in starts:
        stop = start + chunk_size
        yield data[start:stop]
async def _consume(stream: AsyncIterator[bytes]) -> bytes:
    """Drain *stream* completely and return all of its chunks joined together."""
    pieces = [piece async for piece in stream]
    return b"".join(pieces)
def _ca(data: bytes) -> ContentAddress:
    """Return the content address of the primary digest computed for *data*."""
    digests = digest_bytes(data)
    primary_digest = digests.primary
    return primary_digest.content_address
class FakeNotFoundError(Exception):
    """Exception the fake client raises for a missing key.

    It carries a ``response`` dict shaped like the one botocore's
    ``ClientError`` exposes, with ``response["Error"]["Code"] == "NoSuchKey"``.
    Presumably ``S3Backend`` inspects that code to map the failure onto
    ``ObjectNotFoundError``. Confirm against the backend implementation.
    """

    def __init__(self) -> None:
        error_details = {"Code": "NoSuchKey"}
        super().__init__("not found")
        self.response = {"Error": error_details}
class FakeBody:
    """Minimal async readable body, returned as ``get_object()["Body"]``.

    It hands out the bytes it was built with sequentially and returns ``b""``
    once everything has been read.
    """

    def __init__(self, data: bytes) -> None:
        self._data = data
        # Position of the next byte to hand out.
        self._offset = 0

    async def read(self, size: int | None = None) -> bytes:
        """Return up to *size* bytes starting at the current position.

        If *size* is ``None`` or negative, everything that remains is returned.
        This mirrors file-object ``read()`` semantics and the no-argument
        ``read()`` of real streaming bodies. Once the data is exhausted,
        ``b""`` is returned.
        """
        if size is None or size < 0:
            # A negative size used to produce a truncated slice
            # (data[off:off-1]). Treat it as "read the rest" instead.
            size = len(self._data) - self._offset
        chunk = self._data[self._offset : self._offset + size]
        self._offset += len(chunk)
        return chunk
class FakeS3Client:
    """In-memory stand-in for the async S3 client used by ``S3Backend``.

    Every method appends ``(method_name, kwargs)`` to :attr:`calls`, so tests
    can assert on the exact request sequence. Object bytes live in
    :attr:`objects`, keyed by ``Key``.
    """

    def __init__(self) -> None:
        # Object bytes keyed by the S3 ``Key`` they were written under.
        self.objects: dict[str, bytes] = {}
        # Ordered log of every client call as (method name, keyword arguments).
        self.calls: list[tuple[str, dict[str, Any]]] = []
        # Multipart uploads: UploadId -> list of (PartNumber, body) in arrival order.
        self.uploads: dict[str, list[tuple[int, bytes]]] = {}
        # Monotonic counter, so UploadIds stay unique even after aborted
        # uploads are removed from ``self.uploads``.
        self._upload_seq = 0

    async def __aenter__(self) -> FakeS3Client:
        return self

    async def __aexit__(self, *_exc: object) -> None:
        return None

    async def put_object(self, **kwargs: Any) -> None:
        self.calls.append(("put_object", kwargs))
        self.objects[kwargs["Key"]] = kwargs["Body"]

    async def create_multipart_upload(self, **kwargs: Any) -> dict[str, str]:
        self.calls.append(("create_multipart_upload", kwargs))
        self._upload_seq += 1
        upload_id = f"upload-{self._upload_seq}"
        self.uploads[upload_id] = []
        return {"UploadId": upload_id}

    async def upload_part(self, **kwargs: Any) -> dict[str, str]:
        self.calls.append(("upload_part", kwargs))
        self.uploads[kwargs["UploadId"]].append((kwargs["PartNumber"], kwargs["Body"]))
        return {"ETag": f"etag-{kwargs['PartNumber']}"}

    async def complete_multipart_upload(self, **kwargs: Any) -> None:
        self.calls.append(("complete_multipart_upload", kwargs))
        # S3 keeps only the most recent body uploaded for a PartNumber. dict()
        # lets later (retried) entries win instead of concatenating duplicates.
        latest_parts = dict(self.uploads[kwargs["UploadId"]])
        self.objects[kwargs["Key"]] = b"".join(
            latest_parts[number] for number in sorted(latest_parts)
        )

    async def abort_multipart_upload(self, **kwargs: Any) -> None:
        self.calls.append(("abort_multipart_upload", kwargs))
        # Discard the parts of an aborted upload, as S3 does. Tolerate unknown
        # or missing ids, since the original fake never read its arguments.
        self.uploads.pop(kwargs.get("UploadId"), None)

    async def get_object(self, **kwargs: Any) -> dict[str, FakeBody]:
        self.calls.append(("get_object", kwargs))
        try:
            data = self.objects[kwargs["Key"]]
        except KeyError as exc:
            raise FakeNotFoundError from exc
        range_header = kwargs.get("Range")
        if range_header:
            data = self._slice_range(data, str(range_header))
        return {"Body": FakeBody(data)}

    @staticmethod
    def _slice_range(data: bytes, range_header: str) -> bytes:
        """Apply a single HTTP byte-range spec (``bytes=...``) to *data*.

        Supports ``bytes=S-E`` (inclusive end), ``bytes=S-`` (from S to the
        end), and ``bytes=-N`` (the last N bytes).
        """
        start_text, _, end_text = range_header.removeprefix("bytes=").partition("-")
        if not start_text:
            # Suffix range. Guard 0, because data[-0:] would be the whole object.
            suffix_length = int(end_text)
            return data[-suffix_length:] if suffix_length else b""
        start = int(start_text)
        if not end_text:
            return data[start:]
        return data[start : int(end_text) + 1]

    async def head_object(self, **kwargs: Any) -> dict[str, int]:
        self.calls.append(("head_object", kwargs))
        try:
            data = self.objects[kwargs["Key"]]
        except KeyError as exc:
            raise FakeNotFoundError from exc
        return {"ContentLength": len(data)}

    async def delete_object(self, **kwargs: Any) -> None:
        self.calls.append(("delete_object", kwargs))
        self.objects.pop(kwargs["Key"], None)

    async def head_bucket(self, **kwargs: Any) -> None:
        self.calls.append(("head_bucket", kwargs))
@pytest.fixture
def fake_client() -> FakeS3Client:
    """Provide a new, empty ``FakeS3Client`` (function-scoped, so one per test)."""
    client = FakeS3Client()
    return client
@pytest.fixture
def backend(fake_client: FakeS3Client) -> S3Backend:
    """Build an ``S3Backend`` whose client factory always returns *fake_client*.

    All client calls therefore land in the one fake the test also receives.
    The tiny ``multipart_threshold_bytes`` / ``multipart_chunk_bytes`` values
    are presumably chosen so that a payload of a few bytes exercises the
    multipart path.
    """
    config = S3BackendConfig(
        endpoint_url="http://minio.test",
        region="us-east-1",
        bucket="artifacts",
        key_prefix="artifact-store",
        storage_class="STANDARD",
        sse="AES256",
        multipart_threshold_bytes=8,
        multipart_chunk_bytes=5,
    )

    def make_client() -> FakeS3Client:
        return fake_client

    return S3Backend(config, client_factory=make_client, chunk_size=3)
async def test_put_get_head_delete_round_trip(
    backend: S3Backend,
    fake_client: FakeS3Client,
) -> None:
    """Check a small object through put -> head -> get -> delete.

    Also checks the fanned-out object key layout and the storage-class /
    encryption settings sent with the single ``put_object`` call.
    """
    payload = b"abc"
    address = _ca(payload)
    receipt = await backend.put(address, _stream(payload), size_hint=len(payload))

    digest = address.to_digest()
    fanout = f"{digest.hex[:2]}/{digest.hex[2:4]}"
    expected_key = f"artifact-store/{digest.algorithm}/{fanout}/{digest.hex}"
    assert receipt.object_key == expected_key

    first_call_name, first_call_kwargs = fake_client.calls[0]
    assert first_call_name == "put_object"
    assert first_call_kwargs["StorageClass"] == "STANDARD"
    assert first_call_kwargs["ServerSideEncryption"] == "AES256"

    metadata = await backend.head(address)
    assert metadata.size_bytes == len(payload)

    body_stream = await backend.get(address)
    assert await _consume(body_stream) == payload

    await backend.delete(address)
    with pytest.raises(ObjectNotFoundError):
        await backend.head(address)
async def test_get_supports_range(backend: S3Backend, fake_client: FakeS3Client) -> None:
    """An inclusive ``(start, end)`` byte range returns those bytes and sends ``Range``."""
    payload = b"0123456789"
    address = _ca(payload)
    await backend.put(address, _stream(payload), size_hint=len(payload))

    ranged_stream = await backend.get(address, byte_range=(2, 5))
    ranged_bytes = await _consume(ranged_stream)
    assert ranged_bytes == b"2345"

    # Inspect the most recent client call only after the stream is drained.
    _last_name, last_kwargs = fake_client.calls[-1]
    assert last_kwargs["Range"] == "bytes=2-5"
async def test_put_uses_multipart_above_threshold(
    backend: S3Backend,
    fake_client: FakeS3Client,
) -> None:
    """A payload above ``multipart_threshold_bytes`` goes through a multipart upload."""
    payload = b"abcdefghij" + b"kl"
    address = _ca(payload)
    receipt = await backend.put(address, _stream(payload), size_hint=len(payload))
    assert receipt.size_bytes == len(payload)

    # 12 bytes with the fixture's multipart_chunk_bytes=5 are expected to
    # produce three upload_part calls.
    expected_sequence = [
        "create_multipart_upload",
        *("upload_part" for _ in range(3)),
        "complete_multipart_upload",
    ]
    observed_sequence = [call_name for call_name, _call_kwargs in fake_client.calls]
    assert observed_sequence == expected_sequence

    round_tripped = await _consume(await backend.get(address))
    assert round_tripped == payload
async def test_health_uses_head_bucket(
    backend: S3Backend,
    fake_client: FakeS3Client,
) -> None:
    """``health()`` reports a healthy ``s3`` backend and probes the bucket with ``head_bucket``."""
    status = await backend.health()
    assert status.healthy is True
    assert status.backend_id == "s3"

    # The test name promises a head_bucket probe. Verify the call actually
    # happened, against the configured bucket, rather than trusting the
    # returned status alone.
    head_bucket_calls = [
        call_kwargs for call_name, call_kwargs in fake_client.calls if call_name == "head_bucket"
    ]
    assert head_bucket_calls, "health() did not call head_bucket"
    assert head_bucket_calls[0].get("Bucket") == "artifacts"