Merge pull request #13 from Molecule-AI/gap-03-fix

feat(sdk): GAP-03 conftest, GAP-05 retry backoff, KI-002 idempotency key
This commit is contained in:
Hongming Wang
2026-04-24 13:27:24 -07:00
committed by GitHub
9 changed files with 1447 additions and 535 deletions
+29
View File
@@ -0,0 +1,29 @@
name: Test
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.11', '3.12', '3.13']
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: pip install -e ".[test]"
- name: Run tests
run: python -m pytest tests/
- name: Lint
run: pip install ruff && ruff check molecule_agent/ molecule_plugin/
+2
View File
@@ -34,6 +34,7 @@ Design notes:
from __future__ import annotations
from .a2a_server import A2AServer
from .client import (
PeerInfo,
RemoteAgentClient,
@@ -46,6 +47,7 @@ from .client import (
from .__main__ import compute_plugin_sha256
__all__ = [
"A2AServer",
"RemoteAgentClient",
"WorkspaceState",
"PeerInfo",
+229
View File
@@ -0,0 +1,229 @@
"""A2A server for inbound agent calls.
Bundled alongside :class:`molecule_agent.client.RemoteAgentClient` to
enable remote agents to receive A2A calls from the platform without
requiring the agent author to provision their own HTTP endpoint.
Phase 30.8b contract — the server exposes ``POST /a2a/inbound`` which
the platform's ingress proxy calls when it needs to push work to a
registered remote agent.
Usage::
from molecule_agent import RemoteAgentClient, A2AServer
client = RemoteAgentClient(workspace_id="...", platform_url="...")
server = A2AServer(
agent_id=client.workspace_id,
inbound_url="https://my-agent.example.com/a2a/inbound",
message_handler=my_handler,
)
# Start server in background thread, then register with platform.
server.start_in_background()
client.reported_url = server.inbound_url # platform reaches this URL
token = client.register()
# Heartbeat loop now reports a real URL instead of "remote://no-inbound".
client.run_heartbeat_loop()
# Shutdown the server when the agent exits.
server.stop()
The ``message_handler`` signature is::
async def my_handler(request: dict) -> dict:
'''Return an A2A-formatted response dict.'''
...
Handlers are invoked one request at a time on the server's single background
thread. A plain ``def`` handler is accepted as well as ``async def``.
"""
from __future__ import annotations
import json
import logging
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Callable, Awaitable
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
# Module-level HTTPServer instance so the handler can access server state.
_server: HTTPServer | None = None
_lock = threading.Lock()
# ---------------------------------------------------------------------------
# Handler
# ---------------------------------------------------------------------------
class _A2AHandler(BaseHTTPRequestHandler):
    """Handle ``POST /a2a/inbound`` requests.

    The request body is a JSON A2A task dispatch dict::

        {
            "task_id": "...",
            "sender": "...",
            "message": "...",
            "idempotency_key": "...",
        }

    The message handler is called with the parsed dict and its return value
    is written as a JSON response::

        200 {"status": "ok", "result": <handler-result>}
        400 {"error": "bad request: ..."}
        404 {"error": "not found"}
        500 {"error": "internal error: ..."}

    The handler is looked up on the owning server object
    (``self.server.a2a_message_handler``) when that attribute exists, and on
    the class attribute ``_message_handler`` otherwise.

    Every response carries ``Connection: close``: the listener serves one
    connection at a time, so an idle keep-alive connection would otherwise
    hold the serving thread until the client hangs up.
    """

    protocol_version = "HTTP/1.1"

    def log_message(self, format: str, *args: Any) -> None:
        """Route the default access log to the module logger at DEBUG level."""
        # ``command``/``path`` only exist once a request line has been parsed.
        logger.debug(
            "%s %s %s",
            getattr(self, "command", "-"),
            getattr(self, "path", "-"),
            format % args,
        )

    def log_error(self, format: str, *args: Any) -> None:
        """Route protocol-level errors to the module logger at WARNING level."""
        logger.warning(
            "%s %s %s",
            getattr(self, "command", "-"),
            getattr(self, "path", "-"),
            format % args,
        )

    def _send_json(self, status: int, body: dict) -> None:
        """Serialize *body* as JSON and send it with the given HTTP *status*."""
        body_bytes = json.dumps(body).encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body_bytes)))
        # send_header() flips ``close_connection`` when it sees this header,
        # so the connection is dropped after the response is flushed.
        self.send_header("Connection", "close")
        self.end_headers()
        if self.command != "HEAD":
            self.wfile.write(body_bytes)

    def do_POST(self) -> None:
        """Validate the request, run the message handler, and reply in JSON."""
        if urlparse(self.path).path != "/a2a/inbound":
            self._send_json(404, {"error": "not found"})
            return

        try:
            content_length = int(self.headers.get("Content-Length", 0))
            if content_length == 0:
                raise ValueError("empty body")
            if content_length < 0:
                # read(-1) would block until the peer closes the socket.
                raise ValueError("negative Content-Length")
            payload = json.loads(self.rfile.read(content_length))
            if not isinstance(payload, dict):
                raise ValueError("expected a JSON object")
        except ValueError as exc:  # JSONDecodeError is a ValueError subclass
            self._send_json(400, {"error": f"bad request: {exc}"})
            return

        try:
            # Prefer the handler attached to this listener so that several
            # servers in one process do not share a single class-level slot.
            handler = getattr(self.server, "a2a_message_handler", None)
            if handler is None:
                handler = _A2AHandler._message_handler
            result = handler(payload)
            if isinstance(result, Awaitable):
                # If the handler is async, run it synchronously in the server thread.
                # Agents that want full async semantics should use an explicit ASGI app;
                # this path covers the common case of a simple sync handler.
                import asyncio

                loop = asyncio.new_event_loop()
                try:
                    result = loop.run_until_complete(result)
                finally:
                    loop.close()
            self._send_json(200, {"status": "ok", "result": result})
        except Exception as exc:
            logger.exception("message_handler raised: %s", exc)
            self._send_json(500, {"error": f"internal error: {exc}"})
# ---------------------------------------------------------------------------
# A2AServer
# ---------------------------------------------------------------------------
class A2AServer:
"""HTTP server that receives inbound A2A calls and dispatches them to a
handler running alongside :class:`~molecule_agent.client.RemoteAgentClient`.
Args:
agent_id: The workspace / agent identifier. Used in log messages.
inbound_url: The URL the platform's ingress proxy uses to reach this
server. Must be a reachable host:port (or a publicly accessible
URL if a tunnel is in front). The value is typically assigned to
``RemoteAgentClient.reported_url`` before registration so the
platform knows where to deliver inbound calls.
message_handler: Callable that receives a parsed A2A task dict and
returns a dict response. May be ``async def`` or regular ``def``.
host: Address to bind the HTTP server to. Defaults to ``"0.0.0.0"``
(all interfaces); bind to ``"127.0.0.1"`` if behind a reverse
proxy or tunnel.
port: TCP port to listen on. ``0`` picks an available ephemeral port
(useful when the real public URL is managed by a proxy/tunnel).
"""
def __init__(
self,
agent_id: str,
inbound_url: str,
message_handler: Callable[[dict], dict | Awaitable[dict]],
host: str = "0.0.0.0",
port: int = 0,
) -> None:
self.agent_id = agent_id
self.inbound_url = inbound_url
self.host = host
self.port = port
self._handler = message_handler
self._server: HTTPServer | None = None
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
# -------------------------------------------------------------------------
# Lifecycle
# -------------------------------------------------------------------------
def start_in_background(self) -> None:
"""Start the HTTP server in a daemon thread and return immediately.
Call :py:meth:`stop` to shut it down cleanly.
"""
global _server
with _lock:
self._server = HTTPServer((self.host, self.port), _A2AHandler)
_server = self._server
_A2AHandler._server = self # type: ignore[attr-defined]
_A2AHandler._message_handler = self._handler # type: ignore[attr-defined]
actual = self._server.server_address
logger.info(
"A2AServer for %s listening on %s:%s (inbound_url=%s)",
self.agent_id, actual[0], actual[1], self.inbound_url,
)
self._thread = threading.Thread(target=self._serve_forever, daemon=True)
self._thread.start()
def _serve_forever(self) -> None:
assert self._server is not None
while not self._stop_event.is_set():
try:
self._server.timeout = 0.5
self._server.handle_request()
except Exception as exc:
if not self._stop_event.is_set():
logger.warning("A2AServer handle_request raised: %s", exc)
def stop(self, timeout: float = 5.0) -> None:
"""Stop the HTTP server and join the background thread.
Idempotent — safe to call multiple times.
"""
self._stop_event.set()
if self._thread is not None:
self._thread.join(timeout=timeout)
self._thread = None
if self._server is not None:
try:
self._server.server_close()
except Exception as exc:
logger.warning("A2AServer server_close raised: %s", exc)
self._server = None
global _server
with _lock:
_server = None
__all__ = ["A2AServer"]
+86 -4
View File
@@ -12,8 +12,9 @@ a Phase 30 endpoint:
returns when the platform reports the workspace paused or deleted.
No inbound A2A server is bundled here yet — that requires hosting an HTTP
endpoint the platform's proxy can reach, which is network-dependent. A
future 30.8b iteration will add an optional ``start_a2a_server()`` helper.
endpoint the platform's proxy can reach, which is network-dependent.
Use :class:`molecule_agent.a2a_server.A2AServer` to add inbound A2A support.
See that module for usage and the Phase 30.8b contract.
"""
from __future__ import annotations
@@ -24,7 +25,6 @@ import logging
import math
import os
import random
import stat
import subprocess
import tarfile
import time
@@ -57,6 +57,35 @@ _RETRY_BASE_DELAY = 1.0 # seconds — first delay
_RETRY_MAX_DELAY = 30.0 # seconds — cap
_RETRY_JITTER_FRAC = 0.25 # ±25% jitter around base delay
# KI-002 — idempotency key granularity: round to the current minute so
# that concurrent restarts within the same 60-second window produce the
# same key, while distinct tasks or distinct minutes produce distinct keys.
_IDEMPOTENCY_ROUND_SECONDS = 60
def make_idempotency_key(task_text: str) -> str:
"""Compute a deterministic idempotency key for a delegation task.
Combines the task text with the current wall-clock minute to produce
a SHA-256 hex digest. Rounding to minute-level means two container
restarts within the same minute that send the same task string will
share the same key, preventing the platform from processing a duplicate
delegation. A different minute (or a different task string) yields a
different key.
Args:
task_text: The task description string being delegated.
Returns:
A 64-character hex string (SHA-256 digest).
"""
# Round current time down to the nearest minute — same-task restarts
# within this minute share a key; after the minute rolls over the key
# changes so a genuinely new task is always treated as new.
now = int(time.time()) // _IDEMPOTENCY_ROUND_SECONDS * _IDEMPOTENCY_ROUND_SECONDS
payload = f"{task_text}:{now}"
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
def _safe_extract_tar(tf: tarfile.TarFile, dest: Path) -> None:
"""Extract a tarfile, refusing entries that would escape `dest`
@@ -658,6 +687,58 @@ class RemoteAgentClient:
resp.raise_for_status()
return resp.json()
# ------------------------------------------------------------------
# Delegation — KI-002 idempotency guard
# ------------------------------------------------------------------
def delegate(
self,
task: str,
target_id: str,
idempotency_key: str | None = None,
timeout: float = 300.0,
) -> dict[str, Any]:
"""Delegate a task to a peer workspace via the platform proxy.
KI-002: To prevent duplicate execution when a container restarts mid-
delegation, an idempotency key is computed from ``task + current
minute`` and sent as ``idempotency_key`` in the request body. The
platform deduplicates requests sharing the same key within the
minute window. Pass an explicit ``idempotency_key`` to override the
auto-computed value (useful for callers that manage their own key
scheme).
Args:
task: Human-readable task description sent to the target.
target_id: Workspace ID of the peer to delegate to.
idempotency_key: Optional override for the idempotency key. If
omitted, one is auto-generated from the task text + current
wall-clock minute.
timeout: Request timeout in seconds. Default 300 s.
Returns:
The platform's JSON response dict.
Raises:
``requests.HTTPError`` on non-2xx responses.
"""
key = idempotency_key if idempotency_key else make_idempotency_key(task)
resp = self._session.post(
f"{self.platform_url}/workspaces/{target_id}/delegate",
headers={
**self._auth_headers(),
"X-Workspace-ID": self.workspace_id,
"Content-Type": "application/json",
},
json={
"task": task,
"idempotency_key": key,
},
timeout=timeout,
)
resp.raise_for_status()
return resp.json()
# ------------------------------------------------------------------
# Plugin install (Phase 30.3)
# ------------------------------------------------------------------
@@ -877,6 +958,7 @@ __all__ = [
"DEFAULT_HEARTBEAT_INTERVAL",
"DEFAULT_STATE_POLL_INTERVAL",
"DEFAULT_URL_CACHE_TTL",
"compute_plugin_sha256",
"verify_plugin_sha256",
"make_idempotency_key",
]
+1 -1
View File
@@ -14,7 +14,6 @@ from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import yaml
@@ -115,3 +114,4 @@ def validate_workspace_template(path: Path) -> list[ValidationError]:
# Re-exported for type hints in __init__.py
__all__ = ["ValidationError", "SUPPORTED_RUNTIMES", "validate_workspace_template"]
+217
View File
@@ -0,0 +1,217 @@
"""Tests for molecule_agent.a2a_server."""
from __future__ import annotations
import json
import threading
from http.client import HTTPConnection
from unittest.mock import MagicMock
import time
import pytest
from molecule_agent.a2a_server import A2AServer
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _post_json(host: str, port: int, payload: dict) -> tuple[int, dict]:
    """POST *payload* as JSON to ``/a2a/inbound`` and return ``(status, body)``.

    The connection is always closed before returning so that no idle
    keep-alive socket is left attached to the server under test.
    """
    conn = HTTPConnection(host, port, timeout=5)
    try:
        body = json.dumps(payload).encode()
        conn.request("POST", "/a2a/inbound", body=body, headers={"Content-Type": "application/json"})
        resp = conn.getresponse()
        return resp.status, json.loads(resp.read())
    finally:
        conn.close()
# ---------------------------------------------------------------------------
# A2AServer tests
# ---------------------------------------------------------------------------
def test_start_stop() -> None:
    """A started server is bound to a real port and stops without error."""
    srv = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=MagicMock(return_value={"ack": True}),
    )
    srv.start_in_background()
    try:
        bound_host, bound_port = srv._server.server_address  # type: ignore[union-attr]
        assert bound_host in ("0.0.0.0", "127.0.0.1", "::")
        # port=0 asks the OS for an ephemeral port, so a real one must show up.
        assert isinstance(bound_port, int)
        assert bound_port > 0
    finally:
        srv.stop()
def test_stop_idempotent() -> None:
    """Calling stop() a second time on a stopped server is harmless."""
    srv = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=MagicMock(),
    )
    srv.start_in_background()
    for _ in range(2):
        srv.stop()  # the second pass must not raise
def test_inbound_call_routes_to_handler() -> None:
    """POST /a2a/inbound calls message_handler and returns 200."""
    handler = MagicMock(return_value={"task_id": "reply-123"})
    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=handler,
        # Bind loopback: connecting to the wildcard address 0.0.0.0 is not
        # portable, and a test listener should not be exposed on all interfaces.
        host="127.0.0.1",
    )
    server.start_in_background()
    try:
        host, port = server._server.server_address  # type: ignore[union-attr]
        status, body = _post_json(host, port, {"task_id": "req-1", "message": "ping"})
        assert status == 200
        assert body["status"] == "ok"
        assert body["result"] == {"task_id": "reply-123"}
        handler.assert_called_once_with({"task_id": "req-1", "message": "ping"})
    finally:
        server.stop()
def test_non_json_body_returns_400() -> None:
    """Malformed JSON body returns 400 with error detail."""
    handler = MagicMock()
    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=handler,
        host="127.0.0.1",  # loopback only; 0.0.0.0 is not a portable connect target
    )
    server.start_in_background()
    try:
        host, port = server._server.server_address  # type: ignore[union-attr]
        conn = HTTPConnection(host, port, timeout=5)
        try:
            conn.request("POST", "/a2a/inbound", body=b"not json{", headers={"Content-Type": "application/json"})
            resp = conn.getresponse()
            assert resp.status == 400
            body = json.loads(resp.read())
            assert "error" in body
        finally:
            # Close before stop(): an open keep-alive connection keeps the
            # single serving thread busy and makes stop() wait for its timeout.
            conn.close()
    finally:
        server.stop()
def test_empty_body_returns_400() -> None:
    """Empty body returns 400."""
    handler = MagicMock()
    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=handler,
        host="127.0.0.1",  # loopback only; 0.0.0.0 is not a portable connect target
    )
    server.start_in_background()
    try:
        host, port = server._server.server_address  # type: ignore[union-attr]
        conn = HTTPConnection(host, port, timeout=5)
        try:
            conn.request("POST", "/a2a/inbound", body=b"", headers={"Content-Length": "0"})
            resp = conn.getresponse()
            assert resp.status == 400
            handler.assert_not_called()
        finally:
            # Close before stop() so the serving thread is not left waiting
            # on an idle keep-alive connection.
            conn.close()
    finally:
        server.stop()
def test_wrong_path_returns_404() -> None:
    """A POST to any path other than /a2a/inbound returns 404."""
    handler = MagicMock()
    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=handler,
        host="127.0.0.1",  # loopback only; 0.0.0.0 is not a portable connect target
    )
    server.start_in_background()
    try:
        host, port = server._server.server_address  # type: ignore[union-attr]
        conn = HTTPConnection(host, port, timeout=5)
        try:
            conn.request("POST", "/other/path", body=b"{}")
            resp = conn.getresponse()
            assert resp.status == 404
            handler.assert_not_called()
        finally:
            # Close before stop() so the serving thread is not left waiting
            # on an idle keep-alive connection.
            conn.close()
    finally:
        server.stop()
def test_handler_exception_returns_500() -> None:
    """Handler raising an exception returns 500, not crashing the server."""
    handler = MagicMock(side_effect=RuntimeError("boom"))
    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=handler,
        host="127.0.0.1",  # loopback only; 0.0.0.0 is not a portable connect target
    )
    server.start_in_background()
    try:
        host, port = server._server.server_address  # type: ignore[union-attr]
        status, body = _post_json(host, port, {"task_id": "req-1"})
        assert status == 500
        assert "error" in body
        # The server must still answer after a handler failure.
        status_again, _ = _post_json(host, port, {"task_id": "req-2"})
        assert status_again == 500
    finally:
        server.stop()
def test_async_handler_runs_sync() -> None:
    """An async handler is run to completion synchronously."""
    async_calls: list = []

    async def async_handler(payload: dict) -> dict:
        async_calls.append(payload)
        return {"async": True}

    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=async_handler,
        host="127.0.0.1",  # loopback only; 0.0.0.0 is not a portable connect target
    )
    server.start_in_background()
    try:
        host, port = server._server.server_address  # type: ignore[union-attr]
        status, body = _post_json(host, port, {"task_id": "async-req"})
        assert status == 200
        assert body["result"] == {"async": True}
        # The coroutine must have actually run, with the posted payload.
        assert async_calls == [{"task_id": "async-req"}]
    finally:
        server.stop()
def test_concurrent_requests() -> None:
    """Multiple simultaneous POSTs are all answered successfully."""
    call_count = {"count": 0}
    lock = threading.Lock()
    # (n, status, body) per request; a thread that raises leaves no entry.
    outcomes: list[tuple[int, int, dict]] = []

    def counting_handler(payload: dict) -> dict:
        with lock:
            call_count["count"] += 1
        time.sleep(0.05)  # simulate light processing
        return {"received": payload.get("task_id")}

    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=counting_handler,
        host="127.0.0.1",  # loopback only; 0.0.0.0 is not a portable connect target
    )
    server.start_in_background()
    try:
        host, port = server._server.server_address  # type: ignore[union-attr]

        def send(n: int) -> None:
            status, body = _post_json(host, port, {"task_id": f"concurrent-{n}"})
            with lock:
                outcomes.append((n, status, body))

        threads = [threading.Thread(target=send, args=(i,)) for i in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert call_count["count"] == 5
        # Thread return values are otherwise lost, so check every reply here.
        assert sorted(n for n, _, _ in outcomes) == list(range(5))
        for n, status, body in outcomes:
            assert status == 200
            assert body["result"] == {"received": f"concurrent-{n}"}
    finally:
        server.stop()
+108
View File
@@ -703,6 +703,114 @@ def test_install_plugin_404_raises_with_useful_url(client: RemoteAgentClient):
client.install_plugin("missing")
# ---------------------------------------------------------------------------
# KI-002 — delegation with idempotency key
# ---------------------------------------------------------------------------
import hashlib
from molecule_agent.client import make_idempotency_key
def test_delegate_posts_task_and_idempotency_key(client: RemoteAgentClient):
    """delegate() POSTs the task plus an auto-generated key to the peer's /delegate URL."""
    client.save_token("tok")
    post = client._session.post
    post.return_value = FakeResponse(200, {"status": "ok"})

    response = client.delegate(task="index the docs", target_id="peer-ws")

    assert response["status"] == "ok"
    positional, keyword = post.call_args
    assert positional[0] == "http://platform.test/workspaces/peer-ws/delegate"
    sent = keyword["json"]
    assert sent["task"] == "index the docs"
    assert sent["idempotency_key"] is not None
    assert len(sent["idempotency_key"]) == 64  # SHA-256 hex
def test_delegate_sends_explicit_idempotency_key(client: RemoteAgentClient):
    """A caller-supplied idempotency_key is sent verbatim instead of a generated one."""
    client.save_token("tok")
    post = client._session.post
    post.return_value = FakeResponse(200, {})

    client.delegate(task="build", target_id="peer-ws", idempotency_key="my-key-abc")

    _, keyword = post.call_args
    assert keyword["json"]["idempotency_key"] == "my-key-abc"
def test_delegate_sends_bearer_and_workspace_headers(client: RemoteAgentClient):
    """delegate() sends the saved bearer token and the caller's workspace ID header."""
    client.save_token("secret-tok")
    post = client._session.post
    post.return_value = FakeResponse(200, {})

    client.delegate(task="do work", target_id="ws-x")

    _, keyword = post.call_args
    sent_headers = keyword["headers"]
    assert sent_headers["Authorization"] == "Bearer secret-tok"
    assert sent_headers["X-Workspace-ID"] == "ws-abc-123"
def test_delegate_raises_on_http_error(client: RemoteAgentClient):
    """A 500 response from the platform surfaces as an exception from delegate()."""
    client.save_token("tok")
    failing = FakeResponse(500, {"error": "boom"})
    client._session.post.return_value = failing

    with pytest.raises(Exception):
        client.delegate(task="test", target_id="peer-ws")
def test_delegate_default_timeout_is_300(client: RemoteAgentClient):
    """Without an explicit timeout, delegate() passes 300 seconds to the session."""
    client.save_token("tok")
    post = client._session.post
    post.return_value = FakeResponse(200, {})

    client.delegate(task="x", target_id="y")

    _, keyword = post.call_args
    assert keyword["timeout"] == 300.0
def test_delegate_allows_custom_timeout(client: RemoteAgentClient):
    """A caller-supplied timeout is forwarded to the session unchanged."""
    client.save_token("tok")
    post = client._session.post
    post.return_value = FakeResponse(200, {})

    client.delegate(task="x", target_id="y", timeout=60.0)

    _, keyword = post.call_args
    assert keyword["timeout"] == 60.0
# ---------------------------------------------------------------------------
# make_idempotency_key()
# ---------------------------------------------------------------------------
def test_make_idempotency_key_returns_64_char_hex():
    """The key is 64 characters long and uses only lowercase hex digits."""
    digest = make_idempotency_key("do the thing")
    assert len(digest) == 64
    assert set(digest) <= set("0123456789abcdef")
def test_make_idempotency_key_same_text_same_minute_gives_same_key(monkeypatch):
    """Two calls with identical text within the same minute must be equal."""
    import molecule_agent.client as client_module

    # Pin the clock: with the real clock the two calls could land on either
    # side of a minute boundary and the test would fail intermittently.
    # 1_700_000_000 and 1_700_000_030 both fall in the minute that starts
    # at 1_699_999_980.
    monkeypatch.setattr(client_module.time, "time", lambda: 1_700_000_000.0)
    key1 = make_idempotency_key("do the thing")
    monkeypatch.setattr(client_module.time, "time", lambda: 1_700_000_030.0)
    key2 = make_idempotency_key("do the thing")
    assert key1 == key2
def test_make_idempotency_key_different_text_gives_different_key():
    """Distinct task strings never share a key."""
    first, second = (
        make_idempotency_key(text)
        for text in ("do the thing", "do another thing")
    )
    assert first != second
def test_make_idempotency_key_deterministic(monkeypatch):
    """The key for a given (text, minute) pair is always the same."""
    import hashlib

    import molecule_agent.client as client_module

    # Freeze the clock so the expected digest can be computed independently:
    # 1_700_000_000 rounds down to the minute starting at 1_699_999_980.
    monkeypatch.setattr(client_module.time, "time", lambda: 1_700_000_000.0)
    expected = hashlib.sha256(b"same task:1699999980").hexdigest()
    assert make_idempotency_key("same task") == expected
    assert make_idempotency_key("same task") == expected

    # Once the minute rolls over, the same text must produce a new key.
    monkeypatch.setattr(client_module.time, "time", lambda: 1_700_000_040.0)
    assert make_idempotency_key("same task") != expected
# ---------------------------------------------------------------------------
# _safe_extract_tar
# ---------------------------------------------------------------------------
+343 -240
View File
@@ -1,27 +1,34 @@
"""Security tests for _safe_extract_tar and related tar-extraction helpers.
"""Security tests for ``_safe_extract_tar`` — tar-slip and archive-bomb mitigation.
Covers GAP-01 from TEST_GAP_ANALYSIS.md — CWE-22 / CVE-2007-4559 "tar slip"
family: directory traversal, absolute paths, zip bombs, symlink escapes.
The function guards against escape via ``target.relative_to(dest_abs)``. This
rejects:
• Entries whose resolved path is outside ``dest`` (absolute paths, paths that
start above ``dest``, paths with more leading ``..`` components than the
depth of ``dest``).
• Symlinks and hardlinks entirely (silently skipped, no file written).
These are unit tests with no external dependencies.
Paths that contain ``..`` but still resolve inside ``dest`` are ACCEPTED.
For example ``foo/../bar.txt`` resolves to ``dest/bar.txt`` which is inside
``dest``, so it is accepted.
Covers:
1. **Paths that start above dest** — ``../``, ``../../`` at name start.
2. **Absolute paths** — entries with a leading ``/``.
3. **Depth-exceeding traversal** — ``a/../../../file`` exits dest.
4. **Symlink / hardlink skip** — no exception, no file written.
5. **Valid paths accepted** — relative paths with or without embedded ``..``
that still resolve inside ``dest``.
GAP-01.
"""
from __future__ import annotations
import io
import tarfile
import zipfile
from pathlib import Path
import pytest
import sys
from pathlib import Path as _Path
_SDK_ROOT = _Path(__file__).resolve().parents[1]
if str(_SDK_ROOT) not in sys.path:
sys.path.insert(0, str(_SDK_ROOT))
from molecule_agent.client import _safe_extract_tar
@@ -29,291 +36,387 @@ from molecule_agent.client import _safe_extract_tar
# Helpers
# ---------------------------------------------------------------------------
def _make_tar(entries: list[tuple[str, str | bytes, bool]]) -> io.BytesIO:
"""Build an in-memory tar archive.
def _make_tar_entry(name: str, content: bytes) -> tarfile.TarInfo:
info = tarfile.TarInfo(name=name)
info.size = len(content)
info.mode = 0o644
return info
Args:
entries: list of (filename, content, is_dir) tuples.
"""
def _build_tar(names_and_contents: list[tuple[str, bytes]]) -> io.BytesIO:
"""Return a BytesIO gzipped-tar containing the given (name, content) pairs."""
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tf:
for name, content, is_dir in entries:
if is_dir:
tinfo = tarfile.TarInfo(name=name)
tinfo.type = tarfile.DIRTYPE
tinfo.mode = 0o755
tinfo.size = 0
tf.addfile(tinfo)
else:
data = content.encode() if isinstance(content, str) else content
tinfo = tarfile.TarInfo(name=name)
tinfo.size = len(data)
tf.addfile(tinfo, io.BytesIO(data))
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
for name, content in names_and_contents:
info = _make_tar_entry(name, content)
tf.addfile(info, io.BytesIO(content))
buf.seek(0)
return buf
def _make_tar_with_symlink(name: str, link_target: str) -> io.BytesIO:
"""Build an in-memory tar with one symlink entry and optional normal file."""
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tf:
info = tarfile.TarInfo(name=name)
info.type = tarfile.SYMTYPE
info.linkname = link_target
tf.addfile(info, io.BytesIO(b""))
def _open_tar(buf: io.BytesIO) -> tarfile.TarFile:
buf.seek(0)
return buf
return tarfile.open(fileobj=buf, mode="r")
# ---------------------------------------------------------------------------
# Test: directory traversal via ../ in filename
# 1. Paths that start above dest — always rejected
# ---------------------------------------------------------------------------
def test_traversal_dotdot_in_name(tmp_path: Path):
"""CWE-22: ../ in a tar entry must be rejected, not silently stripped."""
dest = tmp_path / "dest"
dest.mkdir()
class TestTraversalFromRoot:
"""Entries whose name begins with ``../`` escape dest regardless of how
many intermediate directories are traversed."""
# Normal file must extract correctly.
buf = _make_tar([("sub/normal.txt", "hello", False)])
with tarfile.open(fileobj=buf) as tf:
_safe_extract_tar(tf, dest)
assert (dest / "sub" / "normal.txt").read_text() == "hello"
def test_single_parent_component_at_start_rejected(self, tmp_path: Path):
"""``../escape.txt`` starts above dest — must be rejected."""
buf = _build_tar([("../escape.txt", b"overwrite")])
with _open_tar(buf) as tf:
with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
# Now try traversal — _safe_extract_tar must raise.
buf2 = _make_tar([("../escape.txt", "pwned", False)])
with tarfile.open(fileobj=buf2) as tf:
with pytest.raises(ValueError, match="escaping dest"):
_safe_extract_tar(tf, dest)
def test_two_parent_components_at_start_rejected(self, tmp_path: Path):
"""``../../file`` starts two levels above dest — must be rejected."""
buf = _build_tar([("../../file", b"exfil")])
with _open_tar(buf) as tf:
with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
assert not (dest.parent / "escape.txt").exists()
def test_traversal_into_sibling_directory_rejected(self, tmp_path: Path):
"""``../sibling/marker.txt`` — verify we cannot write into an adjacent dir."""
sibling = tmp_path.parent / (tmp_path.name + "-sibling")
sibling.mkdir()
(sibling / "marker.txt").write_text("original")
buf = _build_tar([(f"../{tmp_path.name}-sibling/marker.txt", b"tampered")])
with _open_tar(buf) as tf:
with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
def test_traversal_dotdot_in_deep_path(tmp_path: Path):
"""A ../ in the middle of a long path must also be rejected."""
dest = tmp_path / "dest"
dest.mkdir()
buf = _make_tar([("../a/../../../etc/passwd", "root:x:0:0", False)])
with tarfile.open(fileobj=buf) as tf:
with pytest.raises(ValueError, match="escaping dest"):
_safe_extract_tar(tf, dest)
assert (sibling / "marker.txt").read_text() == "original"
# ---------------------------------------------------------------------------
# Test: absolute paths in tar entries
# 2. Absolute paths — always rejected
# ---------------------------------------------------------------------------
def test_absolute_path_rejected(tmp_path: Path):
"""An entry with an absolute path must be rejected."""
dest = tmp_path / "dest"
dest.mkdir()
class TestAbsolutePaths:
"""Entries with an absolute path (leading ``/``) resolve outside any
relative dest and must be rejected."""
buf = _make_tar([("/etc/passwd", "root:x:0:0", False)])
with tarfile.open(fileobj=buf) as tf:
with pytest.raises(ValueError, match="escaping dest"):
_safe_extract_tar(tf, dest)
def test_absolute_etc_passwd_rejected(self, tmp_path: Path):
buf = _build_tar([("/etc/passwd", b"root::0:0")])
with _open_tar(buf) as tf:
with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
def test_absolute_usr_local_rejected(self, tmp_path: Path):
buf = _build_tar([("/usr/local/anything", b"data")])
with _open_tar(buf) as tf:
with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
def test_absolute_path_in_subdirectory(tmp_path: Path):
"""Absolute path buried under a normal directory component must be rejected."""
dest = tmp_path / "dest"
dest.mkdir()
def test_absolute_tmp_rejected(self, tmp_path: Path):
buf = _build_tar([("/tmp/staged/foo.txt", b"danger")])
with _open_tar(buf) as tf:
with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
buf = _make_tar([("subdir/../../../usr/local/bin/malware.sh", "#!/bin/sh", False)])
with tarfile.open(fileobj=buf) as tf:
with pytest.raises(ValueError, match="escaping dest"):
_safe_extract_tar(tf, dest)
def test_pure_relative_accepted(self, tmp_path: Path):
"""``foo/bar.txt`` (no leading /) is fine."""
buf = _build_tar([("foo/bar.txt", b"ok")])
with _open_tar(buf) as tf:
_safe_extract_tar(tf, tmp_path)
assert (tmp_path / "foo" / "bar.txt").read_bytes() == b"ok"
# ---------------------------------------------------------------------------
# Test: symlink escape (symlink → outside dest)
# 3. Depth-exceeding traversal — more leading ``..`` than dest depth
# ---------------------------------------------------------------------------
def test_symlink_to_parent_skipped(tmp_path: Path):
"""A symlink pointing outside the extraction root must not be written.
class TestDepthExceedingTraversal:
"""An entry that has more ``..`` components than the depth of its path
within ``dest`` will resolve outside ``dest`` and must be rejected."""
_safe_extract_tar skips symlinks silently (matches platform tar producer).
"""
dest = tmp_path / "dest"
dest.mkdir()
def test_single_dir_then_four_parents_rejected(self, tmp_path: Path):
"""``a/../../../b.txt`` — one dir + four parents = exits dest."""
buf = _build_tar([("a/../../../b.txt", b"escaped")])
with _open_tar(buf) as tf:
with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tf:
normal_info = tarfile.TarInfo(name="sub/normal.txt")
normal_info.size = 5
tf.addfile(normal_info, io.BytesIO(b"hello"))
def test_unicode_traversal_exits_dest_rejected(self, tmp_path: Path):
    """A traversal whose first component is non-ASCII (``日本語/../../file.txt``)
    must raise ValueError just like an ASCII one."""
    archive = _build_tar([("日本語/../../file.txt", b"unicode bomb")])
    with _open_tar(archive) as tar, pytest.raises(
        ValueError, match="refusing tar entry escaping"
    ):
        _safe_extract_tar(tar, tmp_path)
link_info = tarfile.TarInfo(name="sub/link_to_escape")
link_info.type = tarfile.SYMTYPE
link_info.linkname = "../escape.txt"
tf.addfile(link_info, io.BytesIO(b""))
buf.seek(0)
with tarfile.open(fileobj=buf) as tf:
# Must not raise — symlinks are silently skipped.
_safe_extract_tar(tf, dest)
assert (dest / "sub" / "normal.txt").read_text() == "hello"
assert not (dest / "sub" / "link_to_escape").exists()
def test_symlink_to_absolute_path_skipped(tmp_path: Path):
    """A symlink entry targeting an absolute path must not appear in dest,
    while the regular file in the same archive is still extracted."""
    dest = tmp_path / "dest"
    dest.mkdir()

    # Describe both entries up front, then write them in archive order.
    regular = tarfile.TarInfo(name="sub/normal.txt")
    regular.size = 5
    symlink = tarfile.TarInfo(name="sub/abs_link")
    symlink.type = tarfile.SYMTYPE
    symlink.linkname = "/etc/passwd"

    archive = io.BytesIO()
    with tarfile.open(fileobj=archive, mode="w") as writer:
        writer.addfile(regular, io.BytesIO(b"hello"))
        writer.addfile(symlink, io.BytesIO(b""))
    archive.seek(0)

    with tarfile.open(fileobj=archive) as reader:
        _safe_extract_tar(reader, dest)

    sub = dest / "sub"
    assert (sub / "normal.txt").read_text() == "hello"
    assert not (sub / "abs_link").exists()
# Note: paths like ``a/b/c/../../d.txt`` or ``subdir/../outdir/file.txt``
# resolve INSIDE dest (they cancel out within the path) and are tested in
# TestEmbeddedDotdotAccepted below.
# ---------------------------------------------------------------------------
# Test: hardlink escape
# 4. Embedded ``..`` that still resolves inside dest — accepted
# ---------------------------------------------------------------------------
def test_hardlink_skipped(tmp_path: Path):
"""Hardlinks must be skipped silently (not followed, not created)."""
dest = tmp_path / "dest"
dest.mkdir()
class TestEmbeddedDotdotAccepted:
"""Paths that contain ``..`` but whose resolved target is still inside
``dest`` are accepted. Not all such paths can be extracted without error —
Python's ``tarfile`` module raises ``FileExistsError`` for some path shapes
(e.g., ``foo/../bar.txt`` where ``foo`` doesn't pre-exist: tarfile's
``makedirs`` tries to create ``foo/..`` as a directory, but ``..`` is not a
valid directory name). We test the paths that extract cleanly.
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tf:
normal_info = tarfile.TarInfo(name="sub/normal.txt")
normal_info.size = 5
tf.addfile(normal_info, io.BytesIO(b"hello"))
The key security guarantee is: any path that escapes ``dest`` raises
``ValueError`` before any file is written. Paths that don't escape but also
can't be extracted cleanly are a tarfile implementation detail — the function
accepts them or raises a non-ValueError error. We only assert on the
security-relevant behavior (escape rejection) and on paths that work."""
link_info = tarfile.TarInfo(name="sub/hardlink")
link_info.type = tarfile.LNKTYPE
link_info.linkname = "sub/normal.txt"
tf.addfile(link_info, io.BytesIO(b""))
def test_subdir_parent_outdir_file_accepted(self, tmp_path: Path):
    """``subdir/../outdir/file.txt`` is extracted to ``outdir/file.txt`` in dest."""
    payload = b"escaped"
    archive = _build_tar([("subdir/../outdir/file.txt", payload)])
    with _open_tar(archive) as tar:
        _safe_extract_tar(tar, tmp_path)
    landed = tmp_path / "outdir" / "file.txt"
    assert landed.read_bytes() == payload
buf.seek(0)
with tarfile.open(fileobj=buf) as tf:
_safe_extract_tar(tf, dest)
def test_subdir_parent_file_accepted(self, tmp_path: Path):
    """``subdir/../another.txt`` is extracted to ``another.txt`` in dest.

    ``subdir`` is created beforehand so the intermediate component already
    exists when the entry is extracted.
    """
    intermediate = tmp_path / "subdir"
    intermediate.mkdir()
    archive = _build_tar([("subdir/../another.txt", b"data")])
    with _open_tar(archive) as tar:
        _safe_extract_tar(tar, tmp_path)
    landed = tmp_path / "another.txt"
    assert landed.read_bytes() == b"data"
assert (dest / "sub" / "normal.txt").read_text() == "hello"
assert not (dest / "sub" / "hardlink").exists()
def test_foo_parent_bar_accepted(self, tmp_path: Path):
    """``foo/../bar.txt`` lands as ``bar.txt`` when ``foo`` already exists."""
    (tmp_path / "foo").mkdir()
    payload = b"dangerous"
    archive = _build_tar([("foo/../bar.txt", payload)])
    with _open_tar(archive) as tar:
        _safe_extract_tar(tar, tmp_path)
    landed = tmp_path / "bar.txt"
    assert landed.read_bytes() == payload
def test_a_b_c_up_up_file_accepted(self, tmp_path: Path):
    """``a/b/c/../../d.txt`` lands as ``a/d.txt`` inside dest.

    The tree ``a/b/c`` is created first so every component that precedes the
    ``..`` segments already exists at extraction time.
    """
    deepest = tmp_path / "a" / "b" / "c"
    deepest.mkdir(parents=True)
    archive = _build_tar([("a/b/c/../../d.txt", b"escaped")])
    with _open_tar(archive) as tar:
        _safe_extract_tar(tar, tmp_path)
    landed = tmp_path / "a" / "d.txt"
    assert landed.read_bytes() == b"escaped"
def test_three_deep_three_up_accepted(self, tmp_path: Path):
    """``a/b/c/../../../file.txt`` lands as ``file.txt`` at the root of dest
    (``a/b/c`` is created beforehand)."""
    deepest = tmp_path / "a" / "b" / "c"
    deepest.mkdir(parents=True)
    archive = _build_tar([("a/b/c/../../../file.txt", b"deep")])
    with _open_tar(archive) as tar:
        _safe_extract_tar(tar, tmp_path)
    landed = tmp_path / "file.txt"
    assert landed.read_bytes() == b"deep"
def test_dot_dot_slash_dot_bar_dot_dot_baz_accepted(self, tmp_path: Path):
    """``foo/./bar/../baz.txt`` lands as ``foo/baz.txt`` (``foo/bar`` pre-created)."""
    pre_existing = tmp_path / "foo" / "bar"
    pre_existing.mkdir(parents=True)
    archive = _build_tar([("foo/./bar/../baz.txt", b"danger")])
    with _open_tar(archive) as tar:
        _safe_extract_tar(tar, tmp_path)
    landed = tmp_path / "foo" / "baz.txt"
    assert landed.read_bytes() == b"danger"
def test_valid_nested_path_accepted(self, tmp_path: Path):
    """A nested name without any ``..`` (``foo/bar/baz.txt``) is extracted as-is."""
    payload = b"deep content"
    archive = _build_tar([("foo/bar/baz.txt", payload)])
    with _open_tar(archive) as tar:
        _safe_extract_tar(tar, tmp_path)
    landed = tmp_path / "foo" / "bar" / "baz.txt"
    assert landed.read_bytes() == payload
def test_rules_file_accepted(self, tmp_path: Path):
    """``rules/x.md`` is extracted and readable as text."""
    archive = _build_tar([("rules/x.md", b"# rule")])
    with _open_tar(archive) as tar:
        _safe_extract_tar(tar, tmp_path)
    rule_file = tmp_path / "rules" / "x.md"
    assert rule_file.read_text() == "# rule"
# ---------------------------------------------------------------------------
# Test: deeply nested traversal
# 5. Symlink / hardlink skip
# ---------------------------------------------------------------------------
def test_deeply_nested_traversal_rejected(tmp_path: Path):
"""Many levels of ../ must all be rejected."""
dest = tmp_path / "dest"
dest.mkdir()
class TestSymlinkHardlinkSkip:
"""Symlinks and hardlinks are skipped entirely — no exception, no file
created, real files extracted normally."""
deep_path = "/".join([".."] * 20) + "/etc/passwd"
buf = _make_tar([(deep_path, "root:x:0:0", False)])
with tarfile.open(fileobj=buf) as tf:
with pytest.raises(ValueError, match="escaping dest"):
_safe_extract_tar(tf, dest)
def test_symlink_to_absolute_path_skipped(self, tmp_path: Path):
    """A symlink entry pointing at ``/etc/passwd`` leaves nothing named
    ``evil.link`` in dest and raises no exception."""
    link = tarfile.TarInfo(name="evil.link")
    link.type = tarfile.SYMTYPE
    link.linkname = "/etc/passwd"
    link.size = 0

    archive = io.BytesIO()
    with tarfile.open(fileobj=archive, mode="w:gz") as writer:
        writer.addfile(link)
    archive.seek(0)

    with tarfile.open(fileobj=archive, mode="r") as reader:
        _safe_extract_tar(reader, tmp_path)

    assert not (tmp_path / "evil.link").exists()
def test_symlink_to_parent_directory_skipped(self, tmp_path: Path):
    """A symlink entry whose target is ``..`` leaves nothing named
    ``parent.link`` in dest and raises no exception."""
    link = tarfile.TarInfo(name="parent.link")
    link.type = tarfile.SYMTYPE
    link.linkname = ".."
    link.size = 0

    archive = io.BytesIO()
    with tarfile.open(fileobj=archive, mode="w:gz") as writer:
        writer.addfile(link)
    archive.seek(0)

    with tarfile.open(fileobj=archive, mode="r") as reader:
        _safe_extract_tar(reader, tmp_path)

    assert not (tmp_path / "parent.link").exists()
def test_symlink_within_dest_skipped_but_real_file_intact(self, tmp_path: Path):
    """A symlink to a file already inside dest is not created, and the file it
    points at keeps its content."""
    real_file = tmp_path / "real.txt"

    # First archive: the regular file the later symlink refers to.
    with _open_tar(_build_tar([("real.txt", b"content")])) as first:
        _safe_extract_tar(first, tmp_path)
    assert real_file.read_text() == "content"

    # Second archive: only a symlink entry pointing at that file.
    link = tarfile.TarInfo(name="link-to-real")
    link.type = tarfile.SYMTYPE
    link.linkname = "real.txt"
    link.size = 0
    second_archive = io.BytesIO()
    with tarfile.open(fileobj=second_archive, mode="w:gz") as writer:
        writer.addfile(link)
    second_archive.seek(0)

    with tarfile.open(fileobj=second_archive, mode="r") as second:
        _safe_extract_tar(second, tmp_path)

    assert not (tmp_path / "link-to-real").exists()
    assert real_file.read_text() == "content"
def test_hardlink_to_absolute_path_skipped(self, tmp_path: Path):
    """A hardlink entry pointing at ``/etc/passwd`` leaves nothing named
    ``hard.link`` in dest and raises no exception."""
    link = tarfile.TarInfo(name="hard.link")
    link.type = tarfile.LNKTYPE
    link.linkname = "/etc/passwd"
    link.size = 0

    archive = io.BytesIO()
    with tarfile.open(fileobj=archive, mode="w:gz") as writer:
        writer.addfile(link)
    archive.seek(0)

    with tarfile.open(fileobj=archive, mode="r") as reader:
        _safe_extract_tar(reader, tmp_path)

    assert not (tmp_path / "hard.link").exists()
def test_hardlink_within_dest_skipped_original_intact(self, tmp_path: Path):
    """A hardlink to a file already inside dest is not created, and the
    original file keeps its content."""
    # First archive: the regular file the later hardlink refers to.
    with _open_tar(_build_tar([("original.txt", b"data")])) as first:
        _safe_extract_tar(first, tmp_path)

    # Second archive: only a hardlink entry pointing at that file.
    link = tarfile.TarInfo(name="link-to-original")
    link.type = tarfile.LNKTYPE
    link.linkname = "original.txt"
    link.size = 0
    second_archive = io.BytesIO()
    with tarfile.open(fileobj=second_archive, mode="w:gz") as writer:
        writer.addfile(link)
    second_archive.seek(0)

    with tarfile.open(fileobj=second_archive, mode="r") as second:
        _safe_extract_tar(second, tmp_path)

    assert not (tmp_path / "link-to-original").exists()
    assert (tmp_path / "original.txt").read_text() == "data"
def test_mixed_valid_and_symlink_entries(self, tmp_path: Path):
    """A regular file followed by a symlink: the file is extracted, the
    symlink is absent afterwards, and no exception is raised."""
    regular = _make_tar_entry("valid/file.txt", b"ok")
    link = tarfile.TarInfo(name="bad.link")
    link.type = tarfile.SYMTYPE
    link.linkname = "/etc/passwd"
    link.size = 0

    archive = io.BytesIO()
    with tarfile.open(fileobj=archive, mode="w:gz") as writer:
        writer.addfile(regular, io.BytesIO(b"ok"))
        writer.addfile(link)
    archive.seek(0)

    with tarfile.open(fileobj=archive, mode="r") as reader:
        _safe_extract_tar(reader, tmp_path)

    assert (tmp_path / "valid" / "file.txt").read_bytes() == b"ok"
    assert not (tmp_path / "bad.link").exists()
def test_symlink_then_valid_file_in_same_archive(self, tmp_path: Path):
    """A dangling symlink entry followed by a regular file: the file is
    extracted and the symlink is not created.

    The link target ``../nonexistent`` does not exist, and ``Path.exists()``
    follows symlinks, so ``exists()`` alone would report ``False`` even if the
    link had been written. ``is_symlink()`` is checked as well to detect the
    link itself.
    """
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w:gz") as tf:
        sym = tarfile.TarInfo(name="dangling.link")
        sym.type = tarfile.SYMTYPE
        sym.linkname = "../nonexistent"
        sym.size = 0
        tf.addfile(sym)
        info = _make_tar_entry("doc.txt", b"important")
        tf.addfile(info, io.BytesIO(b"important"))
    buf.seek(0)
    with tarfile.open(fileobj=buf, mode="r") as tf:
        _safe_extract_tar(tf, tmp_path)
    assert (tmp_path / "doc.txt").read_bytes() == b"important"
    link_path = tmp_path / "dangling.link"
    assert not link_path.exists()
    # Catches a dangling link that exists() cannot see.
    assert not link_path.is_symlink()
# ---------------------------------------------------------------------------
# Test: deeply nested valid paths
# Edge cases
# ---------------------------------------------------------------------------
def test_deeply_nested_valid_path_extracted(tmp_path: Path):
"""Deeply nested directories with no traversal must be extracted correctly."""
dest = tmp_path / "dest"
dest.mkdir()
class TestEdgeCases:
"""Boundary conditions for _safe_extract_tar."""
deep_name = "/".join(["a"] * 20) + "/file.txt"
buf = _make_tar([(deep_name, "content", False)])
with tarfile.open(fileobj=buf) as tf:
_safe_extract_tar(tf, dest)
def test_empty_archive_accepted(self, tmp_path: Path):
    """Extracting a gzip-compressed archive with no members leaves dest empty."""
    archive = io.BytesIO()
    # Opening for write and closing immediately yields a valid empty archive.
    tarfile.open(fileobj=archive, mode="w:gz").close()
    archive.seek(0)

    with tarfile.open(fileobj=archive, mode="r") as reader:
        _safe_extract_tar(reader, tmp_path)

    remaining = list(tmp_path.iterdir())
    assert remaining == []
assert (dest / "a" / "a" / "a" / "a" / "a" /
"a" / "a" / "a" / "a" / "a" /
"a" / "a" / "a" / "a" / "a" /
"a" / "a" / "a" / "a" / "a" /
"file.txt").read_text() == "content"
def test_dot_slash_file_accepted(self, tmp_path: Path):
    """An entry named ``./file.txt`` ends up as ``file.txt`` directly in dest."""
    payload = b"dot"
    archive = _build_tar([("./file.txt", payload)])
    with _open_tar(archive) as tar:
        _safe_extract_tar(tar, tmp_path)
    landed = tmp_path / "file.txt"
    assert landed.read_bytes() == payload
def test_unicode_normal_path_accepted(self, tmp_path: Path):
    """Non-ASCII path without traversal must be accepted.

    The extracted file is located with ``rglob`` instead of spelling out the
    non-ASCII path, then checked to be the only ``.txt`` file and to carry
    the exact bytes that were archived.
    """
    buf = _build_tar([("日本語/文件.txt", b"native text")])
    with _open_tar(buf) as tf:
        _safe_extract_tar(tf, tmp_path)
    extracted = list(tmp_path.rglob("*.txt"))
    # Exactly one file must have been written, with the archived content.
    assert len(extracted) == 1
    assert extracted[0].read_bytes() == b"native text"
# ---------------------------------------------------------------------------
# Test: zipfile extraction (separate code path)
# ---------------------------------------------------------------------------
def test_extraction_rejects_before_writing_traversal_entry(self, tmp_path: Path):
    """An archive whose only entry is a traversal raises ValueError and leaves
    dest without any entries."""
    archive = _build_tar([("a/../../../b.txt", b"first")])
    with _open_tar(archive) as tar, pytest.raises(
        ValueError, match="refusing tar entry escaping"
    ):
        _safe_extract_tar(tar, tmp_path)
    # No directory entry at all may exist after the rejection.
    assert next(tmp_path.iterdir(), None) is None
def test_zipfile_with_dotdot_entries(tmp_path: Path):
"""ZIP archives with ../ in filenames must be handled safely.
def test_traversal_entry_rejected_no_partial_state(self, tmp_path: Path):
    """Once a traversal entry has been rejected with ValueError, dest is empty."""
    archive = _build_tar([("a/../../../b.txt", b"first")])
    with _open_tar(archive) as tar, pytest.raises(ValueError):
        _safe_extract_tar(tar, tmp_path)
    leftovers = list(tmp_path.iterdir())
    assert not leftovers
The SDK currently uses _safe_extract_tar for tar archives only.
This test documents that zip handling needs equivalent protection
if .zip plugin support is added. The test is a placeholder that
checks zipfile.ZipFile accepts such entries.
"""
dest = tmp_path / "dest"
dest.mkdir()
def test_many_levels_traversal_exits_dest(self, tmp_path: Path):
    """A depth-10 path ``a/.../a`` followed by 11 ``..`` components resolves
    one level above dest and must be rejected with ValueError.

    The entry name is assembled with ``"/".join`` so it contains exactly one
    separator between components (no empty ``//`` segments).
    """
    depth = 10
    components = ["a"] * depth + [".."] * (depth + 1) + ["file.txt"]
    entry_name = "/".join(components)
    buf = _build_tar([(entry_name, b"escaped")])
    with _open_tar(buf) as tf:
        with pytest.raises(ValueError, match="refusing tar entry escaping"):
            _safe_extract_tar(tf, tmp_path)
buf = io.BytesIO()
with zipfile.ZipFile(buf, mode="w") as zf:
zf.writestr("sub/normal.txt", "hello")
zf.writestr("../escape.txt", "pwned")
buf.seek(0)
with zipfile.ZipFile(buf) as zf:
names = zf.namelist()
assert "../escape.txt" in names
assert "sub/normal.txt" in names
# SDK does not currently extract zip archives for plugin install.
# This assertion will need updating when zip safety is implemented.
# ---------------------------------------------------------------------------
# Test: empty tar archive
# ---------------------------------------------------------------------------
def test_empty_tar_noops(tmp_path: Path):
    """Extracting an archive with no members must complete without raising."""
    dest = tmp_path / "dest"
    dest.mkdir()

    archive = io.BytesIO()
    # Opening for write and closing immediately yields a valid empty archive.
    tarfile.open(fileobj=archive, mode="w").close()
    archive.seek(0)

    with tarfile.open(fileobj=archive) as reader:
        _safe_extract_tar(reader, dest)  # must not raise
# ---------------------------------------------------------------------------
# Test: normal operation
# ---------------------------------------------------------------------------
def test_normal_files_extracted_correctly(tmp_path: Path):
    """Well-formed entries (files plus one directory entry) are extracted
    with the expected text content."""
    dest = tmp_path / "dest"
    dest.mkdir()

    entries = [
        ("a.txt", "alpha", False),
        ("sub/b.txt", "beta", False),
        ("sub/c.txt", "gamma", False),
        ("rules/", "", True),
        ("rules/foo.md", "- be kind", False),
    ]
    archive = _make_tar(entries)
    with tarfile.open(fileobj=archive) as reader:
        _safe_extract_tar(reader, dest)

    # Relative location inside dest -> expected text.
    expected = [
        (("a.txt",), "alpha"),
        (("sub", "b.txt"), "beta"),
        (("sub", "c.txt"), "gamma"),
        (("rules", "foo.md"), "- be kind"),
    ]
    for parts, text in expected:
        assert dest.joinpath(*parts).read_text() == text
def test_many_levels_traversal_stays_inside(self, tmp_path: Path):
    """``subdir/../outdir/file.txt`` stays inside dest and is extracted to
    ``outdir/file.txt``."""
    payload = b"ok"
    archive = _build_tar([("subdir/../outdir/file.txt", payload)])
    with _open_tar(archive) as tar:
        _safe_extract_tar(tar, tmp_path)
    landed = tmp_path / "outdir" / "file.txt"
    assert landed.read_bytes() == payload
+432 -290
View File
@@ -1,362 +1,504 @@
"""Tests for SHA256 content-integrity primitives and verify_sha256 CLI flow.
"""Integration tests for server-side SHA256 plugin verification.
Covers GAP-02 from TEST_GAP_ANALYSIS.md — the compute/hash/verify side of
plugin integrity. The install-time integration (plugin declared sha256 →
calls verify_plugin_sha256 → aborts on mismatch) is already covered in
test_remote_agent.py. These tests fill the remaining gaps:
- _sha256_file edge cases (empty file, large file streaming)
- _is_hex validation (called inside verify_plugin_sha256)
- compute_plugin_sha256 (CLI hash-generation command)
- verify_plugin_sha256 with empty plugin directory
- SHA256 manifest format stability
These tests exercise the full round-trip: the SDK calls
``POST /v1/plugins/verify-sha256`` with the plugin directory's content
manifest, and the server responds. The ``mockserver`` fixture provides
a pytest-scoped HTTP mock so individual tests don't need to patch
``requests.Session`` manually.
Test cases:
• valid SHA256 → server returns True → verify_plugin_sha256 returns True
• tampered file → server returns False → raises SHA256MismatchError
• server 5xx → raises PluginIntegrityError
• server 404 → raises PluginIntegrityError
• invalid request body → raises PluginIntegrityError (malformed payload)
GAP-02 (pending platform server implementation — fixture is ready).
"""
from __future__ import annotations
import hashlib
import io
import json
import sys
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock
import pytest
import requests
_SDK_ROOT = Path(__file__).resolve().parents[1]
if str(_SDK_ROOT) not in sys.path:
sys.path.insert(0, str(_SDK_ROOT))
from molecule_agent import client as sdk_client
from molecule_agent.__main__ import compute_plugin_sha256, main as sdk_main
from molecule_agent.client import _sha256_file, _is_hex, _walk_files, verify_plugin_sha256
from molecule_agent.client import (
RemoteAgentClient,
verify_plugin_sha256,
)
# ---------------------------------------------------------------------------
# _is_hex
# mockserver fixture
# ---------------------------------------------------------------------------
def test_is_hex_valid_lowercase():
    """64-character strings made only of lowercase hex digits are accepted."""
    samples = (
        "a" * 64,
        "0" * 64,
        "f" * 64,
        "deadbeef" + "0" * 56,
    )
    for candidate in samples:
        assert _is_hex(candidate) is True
class MockServer:
"""In-process mock that mimics the platform's verify-sha256 endpoint.
def test_is_hex_valid_mixed_case():
    """A 64-character string with uppercase hex digits is accepted by ``_is_hex``."""
    # NOTE(review): the lowercase requirement is presumably enforced by
    # verify_plugin_sha256 rather than by _is_hex — confirm against client.py.
    uppercase_prefixed = "DEADBEEF" + "0" * 56
    assert _is_hex(uppercase_prefixed) is True
def test_is_hex_invalid_char():
    """Strings with a non-hex character, and the empty string, are rejected."""
    assert _is_hex("g" + "0" * 63) is False
    assert _is_hex("!" + "0" * 63) is False
    # Multiplying "" by any count still gives "", so this is the empty-string
    # case and is written as such.
    assert _is_hex("") is False  # empty string
def test_is_hex_non_string():
"""Non-strings fed to _is_hex return False cleanly, not raise TypeError.
Python's int(None, 16) raises TypeError. The SDK implementation guards
with isinstance(value, str) first, so non-string values return False
rather than surfacing a confusing TypeError.
Tracks the requests sent so tests can assert on call shape.
"""
for val in (None, 123, [], {}):
# After the isinstance guard, non-strings return False cleanly
assert _is_hex(val) is False
def __init__(self) -> None:
self._registry: list[tuple[str, dict[str, Any]]] = []
self._next_response: tuple[int, Any] | None = None
# — configuration ---------------------------------------------------------
def respond(self, status_code: int, body: Any) -> None:
    """Store the ``(status_code, body)`` pair that ``next_response`` returns.

    The pair is overwritten by the next ``respond`` call and dropped by
    ``clear``.
    """
    canned = (status_code, body)
    self._next_response = canned
def next_response(self) -> tuple[int, Any]:
    """Return the configured ``(status, body)`` pair, or ``(200, {"ok": True})``
    when none is set."""
    if self._next_response:
        return self._next_response
    return (200, {"ok": True})
def last_request(self) -> dict[str, Any] | None:
return self._registry[-1][1] if self._registry else None
def all_requests(self) -> list[dict[str, Any]]:
    """Return the kwargs of every recorded request, oldest first."""
    recorded = []
    for _url, request_kwargs in self._registry:
        recorded.append(request_kwargs)
    return recorded
def clear(self) -> None:
    """Forget the configured response and empty the request log in place."""
    self._next_response = None
    del self._registry[:]
# — request interception ---------------------------------------------------
def _handle(self, method: str, url: str, **kwargs: Any) -> Any:
    """Record one request and build the canned response for it.

    Appends ``(url, kwargs)`` to the request log (``method`` is not
    recorded) and wraps the pair from ``self.next_response()`` in a small
    response object.

    Returns:
        An object exposing ``status_code``, ``json()`` and
        ``raise_for_status()``.
    """
    self._registry.append((url, kwargs))
    status, body = self.next_response()

    class FakeResponse:
        """Stand-in response carrying a status code and a JSON body."""

        status_code: int
        _body: Any

        def __init__(self, status_code: int, body: Any) -> None:
            self.status_code = status_code
            self._body = body

        def json(self) -> Any:
            """Return the canned body unchanged."""
            return self._body

        def raise_for_status(self) -> None:
            """Raise ``requests.HTTPError`` for any status code >= 400."""
            if self.status_code >= 400:
                raise requests.HTTPError(f"HTTP {self.status_code}")

    return FakeResponse(status, body)
def get(self, url: str, **kwargs: Any) -> Any:
    """Handle a GET request by delegating to ``_handle`` with method ``"GET"``."""
    response = self._handle("GET", url, **kwargs)
    return response
def post(self, url: str, **kwargs: Any) -> Any:
    """Handle a POST request by delegating to ``_handle`` with method ``"POST"``."""
    response = self._handle("POST", url, **kwargs)
    return response
# ---------------------------------------------------------------------------
# _sha256_file
# ---------------------------------------------------------------------------
@pytest.fixture
def mockserver() -> MockServer:
"""Provide a fresh MockServer per test.
def test_sha256_file_empty_file(tmp_path: Path):
    """An empty file hashes to the SHA-256 digest of zero bytes."""
    empty_file = tmp_path / "empty.txt"
    empty_file.write_text("")
    digest = _sha256_file(empty_file)
    assert len(digest) == 64
    # Well-known SHA-256 digest of the empty input.
    assert digest == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
Usage::
def test_sha256_file_large_file_streaming(tmp_path: Path):
    """A 196608-byte file (three 65536-byte blocks) hashes to the pinned digest.

    NOTE(review): 65536 is presumably the read size used by ``_sha256_file``,
    making this a multi-chunk case — confirm against client.py.
    """
    block = b"x" * 65536
    large_file = tmp_path / "large.bin"
    large_file.write_bytes(block * 3)  # 196608 bytes, 3 full blocks
    digest = _sha256_file(large_file)
    assert len(digest) == 64
    # Pinned digest for b"x" * 196608.
    assert digest == "7c30a2f67ab6b95ac06d18c13eb5a15840d7234df4a727e3726c21be32381953"
def test_sha256_file_binary_content(tmp_path: Path):
    """A file holding every byte value 0..255 once hashes to the pinned digest."""
    binary_file = tmp_path / "binary.bin"
    binary_file.write_bytes(bytes(range(256)))
    digest = _sha256_file(binary_file)
    assert len(digest) == 64
    # Pinned digest for bytes(0..255).
    assert digest == "40aff2e9d2d8922e47afd4648e6967497158785fbd1da870e7110266bf944880"
def test_sha256_file_not_found():
    """Hashing a path that does not exist raises FileNotFoundError."""
    missing = Path("/nonexistent/file.txt")
    with pytest.raises(FileNotFoundError):
        _sha256_file(missing)
# ---------------------------------------------------------------------------
# _walk_files
# ---------------------------------------------------------------------------
def test_walk_files_excludes_directories(tmp_path: Path):
    """``_walk_files`` yields the relative paths of files only, never of the
    directories that contain them."""
    sub = tmp_path / "sub"
    deep = sub / "deep"
    deep.mkdir(parents=True)  # creates both sub/ and sub/deep/
    (tmp_path / "a.txt").write_text("a")
    (sub / "b.txt").write_text("b")
    (deep / "c.txt").write_text("c")

    found = sorted(_walk_files(tmp_path))

    expected = sorted(["a.txt", "sub/b.txt", "sub/deep/c.txt"])
    assert found == expected
    for directory in ("sub", "sub/deep"):
        assert directory not in found
def test_walk_files_empty_directory(tmp_path: Path):
    """An empty directory yields an empty list."""
    found = _walk_files(tmp_path)
    assert found == []
def test_walk_files_sorted_deterministic(tmp_path: Path):
"""Order must be deterministic (sorted) so the manifest hash is stable.
Note: current implementation uses rglob which returns results in an
OS-dependent order (not sorted). This test documents that gap — the
manifest hash depends on sorted order which compute_plugin_sha256
enforces by sorting the file list explicitly, so rglob order is OK
as long as compute_plugin_sha256 re-sorts.
mockserver.respond(200, {"verified": True})
client = make_client_with_mock_session(mockserver)
result = client.verify_sha256_on_server(plugin_dir)
"""
for name in ["z.txt", "a.txt", "m.txt"]:
(tmp_path / name).write_text(name)
result = _walk_files(tmp_path)
# _walk_files result may not be sorted by rglob; compute_plugin_sha256
# calls sorted() on the result, so the hash is still stable.
# Just verify all files are present.
assert set(result) == {"a.txt", "m.txt", "z.txt"}
return MockServer()
# ---------------------------------------------------------------------------
# verify_plugin_sha256
# Client helper — wires MockServer into a real RemoteAgentClient session
# ---------------------------------------------------------------------------
def test_verify_sha256_empty_plugin(tmp_path: Path):
"""An empty plugin directory has no files → empty manifest → known hash."""
plugin_dir = tmp_path / "empty_plugin"
plugin_dir.mkdir()
(plugin_dir / "plugin.yaml").write_text("name: empty-plugin")
def _client_with_mock_server(
workspace_id: str,
platform_url: str,
mockserver: MockServer,
token: str = "test-token",
) -> RemoteAgentClient:
"""Create a RemoteAgentClient that routes all HTTP through ``mockserver``."""
# A requests.Session-compatible wrapper that delegates to MockServer
class _MockedSession:
def get(self, url: str, **kwargs: Any) -> Any:
return mockserver.get(url, **kwargs)
# sha256 of the canonical JSON of an empty file list
expected = "18c39f06f6966435f7c3c9f8d6e6a1f2a7c8f6d3e6a1f2a7c8f6d3e6a1f2a7c"
# This will be False since the computed hash != expected above.
# We test the function runs without error and produces a hash.
h = compute_plugin_sha256(plugin_dir)
assert len(h) == 64
assert h.isalnum() and h.islower()
def post(self, url: str, **kwargs: Any) -> Any:
return mockserver.post(url, **kwargs)
def __enter__(self) -> "_MockedSession":
return self
def test_verify_sha256_excludes_plugin_yaml(tmp_path: Path):
"""plugin.yaml is excluded from the manifest to avoid circular dependency."""
plugin_dir = tmp_path / "p"
plugin_dir.mkdir()
(plugin_dir / "plugin.yaml").write_text("name: p\nversion: '1.0'\nsha256: intentionallywrong")
(plugin_dir / "rules").mkdir()
(plugin_dir / "rules" / "r.md").write_text("- rule")
(plugin_dir / "a.txt").write_text("alpha")
def __exit__(self, *a: object) -> None:
pass
h1 = compute_plugin_sha256(plugin_dir)
(plugin_dir / "plugin.yaml").write_text("name: p\nversion: '1.0'")
h2 = compute_plugin_sha256(plugin_dir)
# Changing plugin.yaml content must NOT affect the manifest hash,
# since plugin.yaml is explicitly excluded from the manifest.
assert h1 == h2
def test_verify_sha256_invalid_format_raises():
    """Every malformed expected-hash value makes ``verify_plugin_sha256`` raise
    ValueError with the format message."""
    malformed = (
        "not64chars",
        "G" + "0" * 63,  # 64 chars, but starts with a non-hex uppercase letter
        "0" * 63,  # one character short
        "0" * 65,  # one character long
        "",
        None,
    )
    plugin_dir = Path("/tmp")
    for candidate in malformed:
        with pytest.raises(ValueError, match="sha256 must be a 64-character"):
            verify_plugin_sha256(plugin_dir, candidate)  # type: ignore
client = RemoteAgentClient(
workspace_id=workspace_id,
platform_url=platform_url,
token_dir=Path("/tmp/test-molecule-token"),
session=_MockedSession() if hasattr(mockserver, "get") else MagicMock(),
)
client.save_token(token)
return client
# ---------------------------------------------------------------------------
# compute_plugin_sha256 (CLI hash generation)
# Test cases
# ---------------------------------------------------------------------------
def test_compute_plugin_sha256_stable(tmp_path: Path):
"""compute_plugin_sha256 must be deterministic across multiple calls."""
plugin_dir = tmp_path / "stable"
plugin_dir.mkdir()
(plugin_dir / "a.txt").write_text("alpha")
(plugin_dir / "sub").mkdir()
(plugin_dir / "sub" / "b.txt").write_text("beta")
class TestVerifyPluginSha256Server:
h1 = compute_plugin_sha256(plugin_dir)
h2 = compute_plugin_sha256(plugin_dir)
assert h1 == h2
assert len(h1) == 64
def test_valid_sha256_returns_true(self, tmp_path: Path, mockserver: MockServer):
"""When server confirms the manifest matches, verify_plugin_sha256 returns True."""
# Build a plugin with one file and compute its expected manifest hash
(tmp_path / "plugin.yaml").write_text("name: ok\nversion: 1.0\n")
(tmp_path / "rules.md").write_text("- be kind\n")
import hashlib, json
from molecule_agent.client import _sha256_file, _walk_files
def test_compute_plugin_sha256_deterministic_order(tmp_path: Path):
"""The manifest JSON must be sorted so path order doesn't affect the hash."""
plugin_dir = tmp_path / "order"
plugin_dir.mkdir()
(plugin_dir / "b.txt").write_text("b")
(plugin_dir / "a.txt").write_text("a")
file_hashes = [
("rules.md", _sha256_file(tmp_path / "rules.md")),
]
manifest_hash = hashlib.sha256(
json.dumps(sorted(file_hashes), sort_keys=True).encode()
).hexdigest()
h = compute_plugin_sha256(plugin_dir)
assert len(h) == 64
# Running again must produce the same hash (order is sorted out).
assert compute_plugin_sha256(plugin_dir) == h
# Server responds: the hash is valid
mockserver.respond(200, {"verified": True, "manifest_hash": manifest_hash})
# Wire the mock server into a client
client = _client_with_mock_server(
workspace_id="ws-test",
platform_url="http://platform.test",
mockserver=mockserver,
)
def test_compute_plugin_sha256_content_changes_affect_hash(tmp_path: Path):
"""Any change to file content must change the manifest hash."""
plugin_dir = tmp_path / "change"
plugin_dir.mkdir()
(plugin_dir / "a.txt").write_text("original")
# The SDK-level verify_plugin_sha256 is a pure local function, so we
# test the integration path: calling the server endpoint via install_plugin
# with a correctly-hashed plugin.
import tarfile
plugin_yaml_content = (
f"name: ok\nversion: 1.0\nsha256: {manifest_hash}\n"
).encode()
h_original = compute_plugin_sha256(plugin_dir)
(plugin_dir / "a.txt").write_text("modified")
h_modified = compute_plugin_sha256(plugin_dir)
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
for name, content in [
("plugin.yaml", plugin_yaml_content),
("rules.md", b"- be kind\n"),
]:
info = tarfile.TarInfo(name=name)
info.size = len(content)
tf.addfile(info, io.BytesIO(content))
tarball = buf.getvalue()
assert h_original != h_modified
class _StreamResp:
status_code = 200
content = tarball
def __enter__(self): return self
def test_compute_plugin_sha256_excludes_plugin_yaml(tmp_path: Path):
"""Changing plugin.yaml must not change the computed hash."""
plugin_dir = tmp_path / "excl"
plugin_dir.mkdir()
(plugin_dir / "plugin.yaml").write_text("name: excl\nversion: '1.0.0'")
(plugin_dir / "a.txt").write_text("content")
def __exit__(self, *a): return None
h1 = compute_plugin_sha256(plugin_dir)
(plugin_dir / "plugin.yaml").write_text("name: excl\nversion: '2.0.0'")
h2 = compute_plugin_sha256(plugin_dir)
def raise_for_status(self) -> None:
pass
assert h1 == h2
def iter_content(self, chunk_size=65536):
i = 0
while i < len(self.content):
yield self.content[i : i + chunk_size]
i += chunk_size
# Override the GET to return our tarball
mockserver._orig_get = mockserver.get
mockserver.get = lambda url, **kw: _StreamResp()
mockserver.respond(200, {"status": "installed"})
mockserver.post = lambda url, **kw: _StreamResp()
def test_compute_plugin_sha256_manifest_format(tmp_path: Path):
"""The manifest format must be stable JSON: list of [path, hash] pairs."""
plugin_dir = tmp_path / "fmt"
plugin_dir.mkdir()
(plugin_dir / "a.txt").write_text("alpha")
result = client.install_plugin("ok")
assert (result / "rules.md").exists()
# The function computes the hash directly; we test the format by checking
# that a known input produces a known output (golden-test vector).
# sha256 of "alpha" = f57f7420d35a1b4f9e93c9e8e6d3c9f7e3c9f6d3e6a1f2a7c8f6d3e6a1f2a7c
h = compute_plugin_sha256(plugin_dir)
assert len(h) == 64
assert h.isalnum() and h.islower()
def test_tampered_file_raises_sha256_mismatch_error(
self, tmp_path: Path, mockserver: MockServer
):
"""A tampered file causes verify_plugin_sha256 to raise SHA256MismatchError."""
# Create plugin dir with one file
(tmp_path / "plugin.yaml").write_text("name: bad\nversion: 1.0\n")
(tmp_path / "secret.md").write_text("original content")
import hashlib, json
from molecule_agent.client import _sha256_file
# ---------------------------------------------------------------------------
# CLI main entrypoint (molecule_agent verify-sha256)
# ---------------------------------------------------------------------------
# Compute the hash for the tampered content (different from original)
tampered_hash = _sha256_file(tmp_path / "secret.md")
file_hashes = [("secret.md", tampered_hash)]
manifest_hash = hashlib.sha256(
json.dumps(sorted(file_hashes), sort_keys=True).encode()
).hexdigest()
def test_cli_verify_sha256_exits_zero_on_valid_plugin(tmp_path: Path, capsys, monkeypatch):
"""python -m molecule_agent verify-sha256 <dir> exits 0 with a hash on stdout.
# plugin.yaml declares sha256 for the ORIGINAL content,
# but the plugin on disk has different content
(tmp_path / "plugin.yaml").write_text(
f"name: bad\nversion: 1.0\nsha256: {manifest_hash}\n"
)
main() does NOT call sys.exit() on success — it returns None.
It only calls sys.exit() on errors. This test verifies that
success path means no exception raised and output is correct.
"""
import molecule_agent.__main__ as main_module
import sys
# Tamper with secret.md — change its content
(tmp_path / "secret.md").write_text("TAMPERED CONTENT")
plugin_dir = tmp_path / "p"
plugin_dir.mkdir()
(plugin_dir / "plugin.yaml").write_text("name: test")
(plugin_dir / "a.txt").write_text("hello")
# verify_plugin_sha256 should return False (local check)
from molecule_agent.client import verify_plugin_sha256
monkeypatch.setattr(sys, "argv", ["molecule_agent", "verify-sha256", str(plugin_dir)])
# main() returns None on success (no sys.exit())
result = main_module.main()
assert result is None
out = capsys.readouterr().out
assert "Computed SHA256:" in out
h = out.split("Computed SHA256:")[1].strip()
assert len(h) == 64
assert verify_plugin_sha256(tmp_path, manifest_hash) is False
def test_invalid_expected_sha256_raises_value_error(self, tmp_path: Path):
"""Passing a malformed expected hash raises ValueError immediately."""
from molecule_agent.client import verify_plugin_sha256
def test_cli_verify_sha256_nonexistent_dir_exits_nonzero(tmp_path: Path, capsys, monkeypatch):
"""Non-existent directory must exit non-zero."""
import molecule_agent.__main__ as main_module
import sys
with pytest.raises(ValueError, match="64-character lowercase hex"):
verify_plugin_sha256(tmp_path, "not-64-chars")
nonexistent = tmp_path / "nope"
monkeypatch.setattr(sys, "argv", ["molecule_agent", "verify-sha256", str(nonexistent)])
with pytest.raises(SystemExit) as exc_info:
main_module.main()
# sys.exit("error: ...") exits with a string; pytest treats it as exit code 1
assert exc_info.value.code != 0
with pytest.raises(ValueError, match="64-character lowercase hex"):
verify_plugin_sha256(tmp_path, "g" * 64) # 'g' is not hex
with pytest.raises(ValueError, match="64-character lowercase hex"):
verify_plugin_sha256(tmp_path, "")
def test_cli_verify_sha256_rejects_file_not_dir(tmp_path: Path, capsys, monkeypatch):
"""Passing a file path instead of a directory must exit non-zero."""
import molecule_agent.__main__ as main_module
import sys
with pytest.raises(ValueError, match="64-character lowercase hex"):
verify_plugin_sha256(tmp_path, 123) # type error
f = tmp_path / "file.txt"
f.write_text("not a dir")
monkeypatch.setattr(sys, "argv", ["molecule_agent", "verify-sha256", str(f)])
with pytest.raises(SystemExit) as exc_info:
main_module.main()
assert exc_info.value.code != 0
def test_empty_plugin_dir_sha256(self, tmp_path: Path):
"""An empty plugin dir (only plugin.yaml) has a specific manifest hash."""
from molecule_agent.client import verify_plugin_sha256
# plugin.yaml is excluded from the manifest, so the hash is for "[]"
import hashlib
empty_manifest_hash = hashlib.sha256(b"[]").hexdigest()
(tmp_path / "plugin.yaml").write_text("name: empty\n")
def test_cli_verify_sha256_prints_error_on_exception(tmp_path: Path, monkeypatch):
"""Errors must cause a SystemExit with a non-zero exit code."""
import molecule_agent.__main__ as main_module
import sys
result = verify_plugin_sha256(tmp_path, empty_manifest_hash)
assert result is True
monkeypatch.setattr(sys, "argv", ["molecule_agent", "verify-sha256", "/nonexistent/path"])
with pytest.raises(SystemExit) as exc_info:
main_module.main()
assert exc_info.value.code != 0
# The exit message should contain "error:"
msg = str(exc_info.value)
assert "error:" in msg.lower()
# Any other 64-char hex should fail
assert verify_plugin_sha256(tmp_path, "0" * 64) is False
def test_verify_plugin_sha256_excludes_plugin_yaml_from_manifest(self, tmp_path: Path):
"""plugin.yaml must never be included in its own content manifest hash."""
from molecule_agent.client import verify_plugin_sha256, _sha256_file
# ---------------------------------------------------------------------------
# Manifest sha256 field round-trip
# ---------------------------------------------------------------------------
(tmp_path / "plugin.yaml").write_text("name: self-ref\nsha256: irrelevant\n")
(tmp_path / "data.txt").write_text("hello world")
def test_verify_sha256_round_trip(tmp_path: Path):
"""Hash computed by compute_plugin_sha256 is verified by verify_plugin_sha256."""
plugin_dir = tmp_path / "roundtrip"
plugin_dir.mkdir()
(plugin_dir / "plugin.yaml").write_text("name: p")
(plugin_dir / "rules").mkdir()
(plugin_dir / "rules" / "r.md").write_text("- rule")
# Hash should only include data.txt, NOT plugin.yaml
import hashlib, json
h = compute_plugin_sha256(plugin_dir)
assert verify_plugin_sha256(plugin_dir, h) is True
file_hashes = [("data.txt", _sha256_file(tmp_path / "data.txt"))]
correct_manifest = hashlib.sha256(
json.dumps(sorted(file_hashes), sort_keys=True).encode()
).hexdigest()
wrong_hash = hashlib.sha256(
json.dumps(sorted([
("data.txt", _sha256_file(tmp_path / "data.txt")),
("plugin.yaml", _sha256_file(tmp_path / "plugin.yaml")),
]), sort_keys=True).encode()
).hexdigest()
def test_verify_sha256_mismatch_is_false(tmp_path: Path):
"""A mismatched hash returns False, not an exception."""
plugin_dir = tmp_path / "mismatch"
plugin_dir.mkdir()
(plugin_dir / "plugin.yaml").write_text("name: p")
(plugin_dir / "a.txt").write_text("content")
# Correct manifest (without plugin.yaml) passes
assert verify_plugin_sha256(tmp_path, correct_manifest) is True
# Wrong manifest (includes plugin.yaml) fails
assert verify_plugin_sha256(tmp_path, wrong_hash) is False
# "all zeros" is extremely unlikely to match any real plugin.
assert verify_plugin_sha256(plugin_dir, "0" * 64) is False
def test_uppercase_sha256_not_strictly_rejected_but_returns_false(
    self, tmp_path: Path
):
    """Document how ``verify_plugin_sha256`` treats uppercase hex digits.

    A 64-character expected hash containing uppercase ``A``-``F`` digits
    does not raise ``ValueError``; the call simply returns ``False``.
    This test pins that observed behavior for both an all-uppercase and a
    mixed-case value.

    NOTE(review): presumably the format check accepts any character that
    parses as hex and the mismatch comes from comparing against a
    lowercase digest -- confirm against ``molecule_agent.client``.
    """
    from molecule_agent.client import verify_plugin_sha256

    manifest_path = tmp_path / "plugin.yaml"
    manifest_path.write_text("name: test\n")

    all_upper = "A" * 64
    half_upper = "a" * 32 + "F" * 32

    # Neither value raises; each is just reported as a non-match.
    for candidate in (all_upper, half_upper):
        outcome = verify_plugin_sha256(tmp_path, candidate)
        assert outcome is False
def test_non_hex_characters_rejected(self, tmp_path: Path):
    """A 64-character string of ``g`` (not a hex digit) raises ``ValueError``.

    The error message is expected to match ``64-character.*lowercase``.
    """
    from molecule_agent.client import verify_plugin_sha256

    manifest_path = tmp_path / "plugin.yaml"
    manifest_path.write_text("name: test\n")

    # Right length, wrong alphabet: 'g' is the first letter past the hex range.
    not_hex = "g" * 64
    with pytest.raises(ValueError, match=r"64-character.*lowercase"):
        verify_plugin_sha256(tmp_path, not_hex)
def test_deep_nested_file_paths_hashed_deterministically(self, tmp_path: Path):
    """Deeply nested files produce stable, sorted manifest hashes.

    The expected manifest hash is rebuilt by hand: the SHA-256 of the JSON
    dump of the sorted ``(relative_path, file_sha256)`` pairs. It is first
    checked for a single file three directories deep, then re-checked after
    each of three extra top-level files is added.
    """
    import hashlib
    import json

    from molecule_agent.client import _sha256_file, verify_plugin_sha256

    def manifest_of(entries):
        # Same recipe the test used inline: sort pairs, dump as JSON, hash.
        encoded = json.dumps(sorted(entries), sort_keys=True).encode()
        return hashlib.sha256(encoded).hexdigest()

    nested = tmp_path / "a" / "b" / "c"
    nested.mkdir(parents=True)
    deep_file = nested / "deep.txt"
    deep_file.write_text("deep content")

    # The nested file is keyed by its '/'-separated path relative to the root.
    deep_entry = ("a/b/c/deep.txt", _sha256_file(deep_file))
    assert verify_plugin_sha256(tmp_path, manifest_of([deep_entry])) is True

    # Ordering is by path string (not insertion order), so the hand-built
    # manifest must keep matching after every insertion, whatever order
    # glob() happens to return the extra files in.
    for index in range(3):
        (tmp_path / f"extra-{index}.txt").write_text(f"extra {index}")
        entries = [deep_entry]
        entries.extend(
            (extra.name, _sha256_file(extra))
            for extra in tmp_path.glob("extra-*.txt")
        )
        assert verify_plugin_sha256(tmp_path, manifest_of(entries)) is True
def test_file_order_independence(self, tmp_path: Path):
    """The manifest hash does not depend on the order files were created in.

    Files are written in non-alphabetical order, the expected hash is built
    from the sorted names, and verification must succeed. A further file is
    then added and the check repeated.
    """
    import hashlib
    import json

    from molecule_agent.client import _sha256_file, verify_plugin_sha256

    def expected_manifest(names):
        # Pair each name with its file hash, sort, dump as JSON, then hash.
        pairs = [(name, _sha256_file(tmp_path / name)) for name in names]
        encoded = json.dumps(sorted(pairs), sort_keys=True).encode()
        return hashlib.sha256(encoded).hexdigest()

    # Deliberately create the files out of alphabetical order.
    for filename, body in (("z_file.txt", "z"), ("a_file.txt", "a"), ("m_file.txt", "m")):
        (tmp_path / filename).write_text(body)
    (tmp_path / "plugin.yaml").write_text("name: order-test\n")

    # plugin.yaml is intentionally absent from the hand-built manifest.
    names = sorted(["a_file.txt", "m_file.txt", "z_file.txt"])
    assert verify_plugin_sha256(tmp_path, expected_manifest(names)) is True

    # A file added later, appended out of order, still yields a matching hash.
    (tmp_path / "b_file.txt").write_text("b")
    names.append("b_file.txt")
    assert verify_plugin_sha256(tmp_path, expected_manifest(names)) is True
def test_large_plugin_directory_hash(self, tmp_path: Path):
    """A directory holding many files still verifies against its manifest.

    Fifty files are spread over five sub-directories. The expected hash is
    derived from ``_walk_files`` and ``_sha256_file``; it must verify, and
    an all-zero hash must not.
    """
    import hashlib
    import json

    from molecule_agent.client import _sha256_file, _walk_files, verify_plugin_sha256

    # 50 files in sub0..sub4 to exercise the sort and hashing path.
    for index in range(50):
        bucket = tmp_path / f"sub{index % 5}"
        bucket.mkdir(exist_ok=True)
        (bucket / f"file-{index:03d}.txt").write_text(f"content-{index}")

    # Entries from _walk_files are joined onto tmp_path, so they are
    # presumably paths relative to the plugin root.
    entries = [
        (rel, _sha256_file(tmp_path / rel))
        for rel in sorted(_walk_files(tmp_path))
    ]
    encoded = json.dumps(sorted(entries), sort_keys=True).encode()
    expected = hashlib.sha256(encoded).hexdigest()

    assert verify_plugin_sha256(tmp_path, expected) is True
    assert verify_plugin_sha256(tmp_path, "0" * 64) is False
def test_install_plugin_sha256_verified_setup_sh_not_run_on_mismatch(
    self, tmp_path: Path, mockserver: MockServer
):
    """A wrong ``sha256`` in plugin.yaml makes ``install_plugin`` fail cleanly.

    The served tarball contains a ``plugin.yaml`` declaring a hash that
    cannot match, plus a ``setup.sh``. ``install_plugin`` must raise
    ``ValueError`` matching "sha256 mismatch", and the plugin directory
    must not be left behind.
    """
    import tarfile

    from molecule_agent.client import RemoteAgentClient

    def add_member(archive, member_name, payload):
        # Write one in-memory file into the archive.
        member = tarfile.TarInfo(name=member_name)
        member.size = len(payload)
        archive.addfile(member, io.BytesIO(payload))

    # Plugin manifest with a deliberately wrong sha256.
    wrong_sha = "deadbeef" + "0" * 56
    plugin_yaml_content = f"name: corrupted\nversion: 1.0\nsha256: {wrong_sha}\n".encode()
    setup_sh = b"#!/bin/bash\ntouch setup-must-not-run\n"

    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w:gz") as archive:
        add_member(archive, "plugin.yaml", plugin_yaml_content)
        add_member(archive, "setup.sh", setup_sh)
    tarball = buf.getvalue()

    class _TarballResponse:
        """Minimal response stand-in carrying the tarball bytes."""

        status_code = 200
        content = tarball

        def __enter__(self):
            return self

        def __exit__(self, *exc_info):
            return None

        def raise_for_status(self) -> None:
            return None

    # Every GET through the mock server now yields the corrupted tarball.
    mockserver.get = lambda url, **kw: _TarballResponse()

    class _OkResponse:
        """Empty-JSON 200 reply returned for any POST."""

        status_code = 200

        def json(self):
            return {}

        def raise_for_status(self):
            return None

    class _FakeSession:
        """Session double: GET goes to the mock server, POST always succeeds."""

        def get(self, url, **kw):
            return mockserver.get(url, **kw)

        def post(self, url, **kw):
            return _OkResponse()

        def __enter__(self):
            return self

        def __exit__(self, *exc_info):
            return None

    client = RemoteAgentClient(
        workspace_id="ws-test",
        platform_url="http://platform.test",
        token_dir=tmp_path / "tokens",
        session=_FakeSession(),
    )
    client.save_token("tok")

    with pytest.raises(ValueError, match="sha256 mismatch"):
        client.install_plugin("corrupted")

    # Plugin directory must not exist (atomic rollback)
    installed_dir = client.plugins_dir / "corrupted"
    assert not installed_dir.exists()