Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 42fc46fde1 | |||
| 003d48e2b7 | |||
| 367b7eb024 | |||
| 1da02f9d06 |
@@ -66,27 +66,19 @@ jobs:
|
||||
# PR#372's ci.yml port used. Diffs against the PR base or the
|
||||
# previous push SHA, then matches against the wheel-relevant
|
||||
# path set.
|
||||
#
|
||||
# Root fix (mc#917): Gitea Actions does not expose github.event.before
|
||||
# as a ${{ }} template-expression that resolves in shell scripts for
|
||||
# push events (it becomes empty string). The env var GITHUB_EVENT_BEFORE
|
||||
# IS set by the runner for push events. Guard git cat-file with
|
||||
# `timeout 30` to prevent indefinite hangs on malformed BASE values.
|
||||
BASE="${GITHUB_BASE_REF:-${{ github.event.before }}}"
|
||||
if [ "${{ github.event_name }}" = "pull_request" ] && [ -n "${{ github.event.pull_request.base.sha }}" ]; then
|
||||
BASE="${{ github.event.pull_request.base.sha }}"
|
||||
else
|
||||
BASE="${GITHUB_EVENT_BEFORE:-}"
|
||||
fi
|
||||
if [ -z "$BASE" ] || echo "$BASE" | grep -qE '^0+$'; then
|
||||
# New branch or no previous SHA: treat as wheel-relevant.
|
||||
echo "wheel=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
if ! timeout 30 git cat-file -e "$BASE" 2>/dev/null; then
|
||||
if ! git cat-file -e "$BASE" 2>/dev/null; then
|
||||
git fetch --depth=1 origin "$BASE" 2>/dev/null || true
|
||||
fi
|
||||
if ! timeout 30 git cat-file -e "$BASE" 2>/dev/null; then
|
||||
echo "::notice::BASE=$BASE not in local clone (shallow fetch or pruned ref)"
|
||||
if ! git cat-file -e "$BASE" 2>/dev/null; then
|
||||
echo "wheel=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
+7
-148
@@ -12,14 +12,12 @@ Environment variables (set by the workspace container):
|
||||
PLATFORM_URL — platform API base URL (e.g. http://platform:8080)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import stat
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
# Top-level (not inside main()) so the wheel rewriter expands this to
|
||||
@@ -767,163 +765,24 @@ async def main(): # pragma: no cover
|
||||
break
|
||||
|
||||
|
||||
# --- HTTP/SSE Transport (for Hermes runtime) ---
|
||||
|
||||
# Per-connection pending request queue.
|
||||
# Maps connection-id → asyncio.Queue of JSON-RPC responses.
|
||||
_http_connection_queues: dict[str, asyncio.Queue] = {}
|
||||
_http_connection_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def _handle_http_mcp(request) -> dict | None:
|
||||
"""Handle an incoming JSON-RPC request over HTTP. Returns the JSON-RPC response dict, or None for notifications."""
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return {"jsonrpc": "2.0", "id": None, "error": {"code": -32700, "message": "Parse error"}}
|
||||
|
||||
req_id = body.get("id")
|
||||
method = body.get("method", "")
|
||||
|
||||
if method == "initialize":
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": _build_initialize_result(),
|
||||
}
|
||||
elif method == "notifications/initialized":
|
||||
return None # No response needed
|
||||
elif method == "tools/list":
|
||||
return {"jsonrpc": "2.0", "id": req_id, "result": {"tools": TOOLS}}
|
||||
elif method == "tools/call":
|
||||
params = body.get("params", {})
|
||||
tool_name = params.get("name", "")
|
||||
tool_args = params.get("arguments", {})
|
||||
result_text = await handle_tool_call(tool_name, tool_args)
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {"content": [{"type": "text", "text": result_text}]},
|
||||
}
|
||||
else:
|
||||
return {"jsonrpc": "2.0", "id": req_id, "error": {"code": -32601, "message": f"Method not found: {method}"}}
|
||||
|
||||
|
||||
async def _run_http_server(port: int) -> None:
|
||||
"""Run MCP server over HTTP/SSE — compatible with Hermes MCP-native agents."""
|
||||
try:
|
||||
from starlette.applications import Starlette # noqa: F401
|
||||
from starlette.routing import Route # noqa: F401
|
||||
from starlette.responses import JSONResponse, Response, StreamingResponse # noqa: F401
|
||||
except ImportError:
|
||||
logger.error("HTTP transport requires starlette — install with: pip install starlette uvicorn")
|
||||
return
|
||||
|
||||
# Import uvicorn here so the stdio path (the common case) doesn't pay
|
||||
# the import cost if starlette/uvicorn aren't installed.
|
||||
import uvicorn # noqa: F401
|
||||
|
||||
_http_connection_queues.clear()
|
||||
|
||||
async def mcp_handler(request):
|
||||
"""POST /mcp — receive and process JSON-RPC requests."""
|
||||
conn_id = request.headers.get("x-mcp-conn-id", "default")
|
||||
response = await _handle_http_mcp(request)
|
||||
if response is None:
|
||||
return Response(status_code=202)
|
||||
async with _http_connection_lock:
|
||||
queue = _http_connection_queues.get(conn_id)
|
||||
if queue is not None and not queue.full():
|
||||
await queue.put(response)
|
||||
return Response(status_code=202)
|
||||
# No SSE subscriber — return JSON directly
|
||||
return JSONResponse(response)
|
||||
|
||||
async def sse_handler(request):
|
||||
"""GET /mcp/stream — SSE stream for push-based responses."""
|
||||
conn_id = str(uuid.uuid4())
|
||||
queue: asyncio.Queue = asyncio.Queue(maxsize=100)
|
||||
async with _http_connection_lock:
|
||||
_http_connection_queues[conn_id] = queue
|
||||
|
||||
async def event_stream():
|
||||
yield f"event: connected\ndata: {json.dumps({'conn_id': conn_id})}\n\n"
|
||||
try:
|
||||
while True:
|
||||
response = await asyncio.wait_for(queue.get(), timeout=300)
|
||||
yield f"event: message\ndata: {json.dumps(response)}\n\n"
|
||||
if queue.empty():
|
||||
yield "event: heartbeat\ndata: null\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
async with _http_connection_lock:
|
||||
_http_connection_queues.pop(conn_id, None)
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
async def health_handler(_request):
|
||||
return JSONResponse({"ok": True, "transport": "http+sse", "port": port})
|
||||
|
||||
app = Starlette(
|
||||
routes=[
|
||||
Route("/mcp", mcp_handler, methods=["POST"]),
|
||||
Route("/mcp/stream", sse_handler, methods=["GET"]),
|
||||
Route("/health", health_handler),
|
||||
]
|
||||
)
|
||||
config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning")
|
||||
server = uvicorn.Server(config)
|
||||
logger.info(f"A2A MCP HTTP server listening on http://127.0.0.1:{port}/mcp")
|
||||
await server.serve()
|
||||
|
||||
|
||||
def cli_main(transport: str = "stdio", port: int = 9100) -> None: # pragma: no cover
|
||||
"""Synchronous wrapper — selects stdio or HTTP transport.
|
||||
def cli_main() -> None: # pragma: no cover
|
||||
"""Synchronous wrapper around the async MCP stdio loop.
|
||||
|
||||
Called by ``mcp_cli.main`` (the ``molecule-mcp`` console-script
|
||||
entry point in scripts/build_runtime_package.py) AFTER env
|
||||
validation and the standalone register + heartbeat thread setup.
|
||||
Direct callers (in-container code that already validated env and
|
||||
runs heartbeat.py separately) can also invoke this.
|
||||
runs heartbeat.py separately) can also invoke this — it's the
|
||||
smallest possible "run the MCP stdio JSON-RPC loop" surface.
|
||||
|
||||
Wheel-smoke gates in scripts/wheel_smoke.py pin the importability
|
||||
of this name (alongside ``mcp_cli.main``) so a silent rename can't
|
||||
break every external-runtime operator's MCP install — the 0.1.16
|
||||
``main_sync`` rename incident is the cautionary precedent.
|
||||
|
||||
Args:
|
||||
transport: "stdio" (default) or "http" (HTTP+SSE for Hermes).
|
||||
port: TCP port for HTTP transport (default 9100).
|
||||
"""
|
||||
if transport == "http":
|
||||
asyncio.run(_run_http_server(port))
|
||||
else:
|
||||
_assert_stdio_is_pipe_compatible()
|
||||
asyncio.run(main())
|
||||
_assert_stdio_is_pipe_compatible()
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
parser = argparse.ArgumentParser(description="A2A MCP Server")
|
||||
parser.add_argument(
|
||||
"--transport",
|
||||
default="stdio",
|
||||
choices=["stdio", "http"],
|
||||
help="Transport mode: stdio (default) or http (HTTP+SSE for Hermes)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=9100,
|
||||
help="TCP port for HTTP transport (default 9100)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
cli_main(transport=args.transport, port=args.port)
|
||||
cli_main()
|
||||
|
||||
@@ -34,6 +34,8 @@ async def list_peers() -> list[dict]:
|
||||
|
||||
async def delegate_task(workspace_id: str, task: str) -> str:
|
||||
"""Send a task to a peer workspace via A2A and return the response text."""
|
||||
if not workspace_id:
|
||||
return "Error: workspace_id is required"
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
# Discover target URL
|
||||
try:
|
||||
|
||||
@@ -2103,3 +2103,71 @@ def test_peer_metadata_set_replaces_existing_entry_in_place(_reset_peer_metadata
|
||||
)
|
||||
cached = a2a_client._peer_metadata[peer]
|
||||
assert cached[1]["name"] == "v2", "re-write must update the value in place"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _safe_activity_id — non-string input guard (line 192)
|
||||
# =============================================================================
|
||||
|
||||
class TestSafeActivityId:
|
||||
"""Coverage for a2a_mcp_server._safe_activity_id()."""
|
||||
|
||||
def test_non_string_returns_empty_string(self):
|
||||
"""Non-str input is the defensive guard branch (line 192)."""
|
||||
from a2a_mcp_server import _safe_activity_id
|
||||
|
||||
# Each non-string type exercises the isinstance guard
|
||||
for value in (None, 123, [], {"a": 1}, 0.0):
|
||||
result = _safe_activity_id(value)
|
||||
assert result == "", f"{type(value).__name__} must return empty string"
|
||||
|
||||
def test_string_invalid_format_returns_empty(self):
|
||||
"""Valid type but non-UUID format → empty string."""
|
||||
from a2a_mcp_server import _safe_activity_id
|
||||
|
||||
assert _safe_activity_id("not-a-uuid") == ""
|
||||
assert _safe_activity_id("") == ""
|
||||
|
||||
def test_string_valid_format_passthrough(self):
|
||||
"""Valid UUID-format string passes through."""
|
||||
from a2a_mcp_server import _safe_activity_id
|
||||
|
||||
valid = "00000000-0000-0000-0000-000000000001"
|
||||
assert _safe_activity_id(valid) == valid
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _safe_ts — non-string input guard (line 198)
|
||||
# =============================================================================
|
||||
|
||||
class TestSafeTs:
|
||||
"""Coverage for a2a_mcp_server._safe_ts()."""
|
||||
|
||||
def test_non_string_returns_empty_string(self):
|
||||
"""Non-str input is the defensive guard branch (line 198)."""
|
||||
from a2a_mcp_server import _safe_ts
|
||||
|
||||
for value in (None, 123, [], {"a": 1}, 0.0):
|
||||
result = _safe_ts(value)
|
||||
assert result == "", f"{type(value).__name__} must return empty string"
|
||||
|
||||
def test_string_invalid_format_returns_empty(self):
|
||||
"""Valid type but non-ISO8601 format → empty string.
|
||||
|
||||
The regex accepts any 9999-99-99T99:99:99Z skeleton since it
|
||||
checks structure not calendar validity; use a clearly-invalid
|
||||
skeleton to exercise the else-branch.
|
||||
"""
|
||||
from a2a_mcp_server import _safe_ts
|
||||
|
||||
# Missing T separator — clearly not ISO8601
|
||||
assert _safe_ts("2026-05-13 10:30:00Z") == ""
|
||||
assert _safe_ts("not a date") == ""
|
||||
assert _safe_ts("") == ""
|
||||
|
||||
def test_string_valid_format_passthrough(self):
|
||||
"""Valid ISO8601 string passes through."""
|
||||
from a2a_mcp_server import _safe_ts
|
||||
|
||||
valid = "2026-05-13T10:30:00Z"
|
||||
assert _safe_ts(valid) == valid
|
||||
|
||||
@@ -1,671 +0,0 @@
|
||||
"""Tests for the HTTP/SSE transport of a2a_mcp_server.
|
||||
|
||||
Covers:
|
||||
- _handle_http_mcp: JSON-RPC request parsing and routing
|
||||
- Starlette app routes: POST /mcp, GET /mcp/stream, GET /health
|
||||
- cli_main argparse: --transport and --port flags
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _DummyRequest:
|
||||
"""Minimal request duck-type for _handle_http_mcp."""
|
||||
|
||||
def __init__(self, body_json: dict, headers: dict | None = None):
|
||||
self._body = body_json
|
||||
self.headers = headers or {}
|
||||
|
||||
async def json(self) -> dict:
|
||||
return self._body
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_http_mcp — unit tests (no I/O)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_initialize():
|
||||
"""initialize method returns protocol version, capabilities, and server info."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest({"jsonrpc": "2.0", "id": 42, "method": "initialize", "params": {}})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["jsonrpc"] == "2.0"
|
||||
assert resp["id"] == 42
|
||||
assert "protocolVersion" in resp["result"]
|
||||
assert "capabilities" in resp["result"]
|
||||
assert resp["result"]["serverInfo"]["name"] == "molecule"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_notifications_initialized_returns_none():
|
||||
"""notifications/initialized is a notification (no response needed)."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest({"jsonrpc": "2.0", "method": "notifications/initialized"})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_list():
|
||||
"""tools/list returns the TOOLS schema."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest({"jsonrpc": "2.0", "id": 7, "method": "tools/list"})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["jsonrpc"] == "2.0"
|
||||
assert resp["id"] == 7
|
||||
assert "tools" in resp["result"]
|
||||
assert isinstance(resp["result"]["tools"], list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_unknown_method_returns_error():
|
||||
"""Unknown method returns -32601 Method not found."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest({"jsonrpc": "2.0", "id": 3, "method": "foobar", "params": {}})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["jsonrpc"] == "2.0"
|
||||
assert resp["id"] == 3
|
||||
assert resp["error"]["code"] == -32601
|
||||
assert "Method not found" in resp["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_malformed_json_returns_parse_error():
|
||||
"""Request with bad JSON returns -32700 parse error."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest.__new__(_DummyRequest)
|
||||
req.headers = {}
|
||||
req.json = AsyncMock(side_effect=ValueError("bad json"))
|
||||
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["error"]["code"] == -32700
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_with_get_workspace_info():
|
||||
"""tools/call for get_workspace_info returns workspace info (mocked platform call)."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
with patch("a2a_mcp_server.tool_get_workspace_info", AsyncMock(return_value="mocked info")):
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 9,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "get_workspace_info", "arguments": {}},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["jsonrpc"] == "2.0"
|
||||
assert resp["id"] == 9
|
||||
assert resp["result"]["content"][0]["text"] == "mocked info"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_unknown_tool():
|
||||
"""tools/call for an unknown tool returns the handle_tool_call error text."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 11,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "not_a_real_tool", "arguments": {}},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["jsonrpc"] == "2.0"
|
||||
assert resp["id"] == 11
|
||||
assert "Unknown tool" in resp["result"]["content"][0]["text"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Starlette app — integration tests with TestClient
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _clear_http_globals():
|
||||
"""Reset module-level HTTP state before and after each test."""
|
||||
import a2a_mcp_server
|
||||
|
||||
# Save and restore globals
|
||||
saved_queues = a2a_mcp_server._http_connection_queues.copy()
|
||||
saved_lock = a2a_mcp_server._http_connection_lock
|
||||
a2a_mcp_server._http_connection_queues.clear()
|
||||
yield
|
||||
# Restore
|
||||
a2a_mcp_server._http_connection_queues = saved_queues
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _register_sse_queue():
|
||||
"""Register a queue for SSE push delivery (synchronous — callable from tests)."""
|
||||
conn_id = str(uuid.uuid4())
|
||||
queue = asyncio.Queue(maxsize=100)
|
||||
import a2a_mcp_server
|
||||
a2a_mcp_server._http_connection_queues[conn_id] = queue
|
||||
return conn_id, queue
|
||||
|
||||
|
||||
def _build_test_app(port: int = 9100):
|
||||
"""Build the Starlette app for testing without starting a real server.
|
||||
|
||||
Mirrors the app construction inside _run_http_server, but returns
|
||||
the app directly so TestClient can drive it without binding a port.
|
||||
"""
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Route
|
||||
|
||||
import a2a_mcp_server
|
||||
|
||||
async def mcp_handler(request):
|
||||
conn_id = request.headers.get("x-mcp-conn-id", "default")
|
||||
response = await a2a_mcp_server._handle_http_mcp(request)
|
||||
if response is None:
|
||||
from starlette.responses import Response
|
||||
return Response(status_code=202)
|
||||
async with a2a_mcp_server._http_connection_lock:
|
||||
queue = a2a_mcp_server._http_connection_queues.get(conn_id)
|
||||
if queue is not None and not queue.full():
|
||||
await queue.put(response)
|
||||
from starlette.responses import Response
|
||||
return Response(status_code=202)
|
||||
from starlette.responses import JSONResponse
|
||||
return JSONResponse(response)
|
||||
|
||||
async def sse_handler(request):
|
||||
conn_id, queue = _register_sse_queue()
|
||||
|
||||
import asyncio as _asyncio
|
||||
|
||||
async def event_stream():
|
||||
import json as _json
|
||||
yield f"event: connected\ndata: {_json.dumps({'conn_id': conn_id})}\n\n"
|
||||
try:
|
||||
while True:
|
||||
response = await _asyncio.wait_for(queue.get(), timeout=300)
|
||||
import json as _json
|
||||
yield f"event: message\ndata: {_json.dumps(response)}\n\n"
|
||||
if queue.empty():
|
||||
yield "event: heartbeat\ndata: null\n\n"
|
||||
except _asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
async with a2a_mcp_server._http_connection_lock:
|
||||
a2a_mcp_server._http_connection_queues.pop(conn_id, None)
|
||||
|
||||
from starlette.responses import StreamingResponse
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
async def health_handler(_request):
|
||||
from starlette.responses import JSONResponse
|
||||
return JSONResponse({"ok": True, "transport": "http+sse", "port": port})
|
||||
|
||||
return Starlette(
|
||||
routes=[
|
||||
Route("/mcp", mcp_handler, methods=["POST"]),
|
||||
Route("/mcp/stream", sse_handler, methods=["GET"]),
|
||||
Route("/health", health_handler),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TestHTTPAppRoutes:
|
||||
"""Integration tests using Starlette TestClient against the HTTP app.
|
||||
|
||||
Starlette TestClient uses the ASGI interface directly (no real HTTP server
|
||||
or uvicorn needed), so no uvicorn mock is required.
|
||||
"""
|
||||
|
||||
def test_health_returns_ok_and_transport(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app(port=9100)
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/health")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is True
|
||||
assert data["transport"] == "http+sse"
|
||||
assert data["port"] == 9100
|
||||
|
||||
def test_health_accepts_different_port(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app(port=9999)
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/health")
|
||||
|
||||
assert resp.json()["port"] == 9999
|
||||
|
||||
def test_mcp_post_initialize(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.post("/mcp", json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {},
|
||||
})
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == 1
|
||||
assert "protocolVersion" in data["result"]
|
||||
|
||||
def test_mcp_post_tools_list(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.post("/mcp", json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/list",
|
||||
"params": {},
|
||||
})
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "tools" in data["result"]
|
||||
assert len(data["result"]["tools"]) > 0
|
||||
|
||||
def test_mcp_post_notifications_initialized_returns_202(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.post("/mcp", json={
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized",
|
||||
})
|
||||
|
||||
# Notifications return 202 with no body
|
||||
assert resp.status_code == 202
|
||||
|
||||
def test_mcp_post_unknown_method_returns_200_with_error(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.post("/mcp", json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 5,
|
||||
"method": "no_such_method",
|
||||
"params": {},
|
||||
})
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["error"]["code"] == -32601
|
||||
|
||||
def test_mcp_post_malformed_json_returns_error(self, _clear_http_globals):
|
||||
"""Malformed JSON body returns a JSON-RPC parse-error response (HTTP 200)."""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app, raise_server_exceptions=False) as client:
|
||||
resp = client.post(
|
||||
"/mcp",
|
||||
content=b"not json at all",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
# _handle_http_mcp catches ValueError from request.json() and returns
|
||||
# a JSON-RPC parse-error response with HTTP 200.
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["error"]["code"] == -32700
|
||||
assert "Parse error" in resp.json()["error"]["message"]
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_sse_stream_populates_queue(self, _clear_http_globals):
|
||||
"""_register_sse_queue adds a queue to _http_connection_queues before any async work."""
|
||||
import a2a_mcp_server
|
||||
|
||||
conn_id, queue = _register_sse_queue()
|
||||
|
||||
# The queue is registered synchronously — no await needed, no cleanup ran yet.
|
||||
assert conn_id in a2a_mcp_server._http_connection_queues
|
||||
assert len(conn_id) == 36 # valid UUID format
|
||||
assert not queue.full()
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_sse_queue_delivers_response(self, _clear_http_globals):
|
||||
"""POST /mcp with x-mcp-conn-id routes response into the SSE queue."""
|
||||
import uuid
|
||||
|
||||
import a2a_mcp_server
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
# Pre-register an SSE queue to simulate an active SSE subscriber
|
||||
conn_id = str(uuid.uuid4())
|
||||
queue: asyncio.Queue = asyncio.Queue(maxsize=100)
|
||||
async with a2a_mcp_server._http_connection_lock:
|
||||
a2a_mcp_server._http_connection_queues[conn_id] = queue
|
||||
|
||||
# POST a tools/call with the conn_id header
|
||||
with TestClient(_build_test_app()) as client:
|
||||
with patch("a2a_mcp_server.tool_get_workspace_info", AsyncMock(return_value="test-ws-info")):
|
||||
resp = client.post(
|
||||
"/mcp",
|
||||
headers={"x-mcp-conn-id": conn_id},
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 99,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "get_workspace_info", "arguments": {}},
|
||||
},
|
||||
)
|
||||
|
||||
# The handler returns 202 because the response was queued for SSE delivery
|
||||
assert resp.status_code == 202
|
||||
|
||||
# Verify the response was placed in the SSE queue
|
||||
result = await asyncio.wait_for(queue.get(), timeout=2.0)
|
||||
assert result["id"] == 99
|
||||
assert result["result"]["content"][0]["text"] == "test-ws-info"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_tool_call — remaining tool branches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_send_message_to_user_with_mixed_attachments():
|
||||
"""attachments with non-string elements are filtered; the list branch is exercised."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
with patch("a2a_mcp_server.tool_send_message_to_user", AsyncMock(return_value="sent ok")) as mock_fn:
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 21,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "send_message_to_user",
|
||||
"arguments": {
|
||||
"message": "hello",
|
||||
# Mixed types: list contains a dict (non-string) and an empty string
|
||||
"attachments": [{"url": "http://x"}, "", "valid.zip", None],
|
||||
},
|
||||
},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["result"]["content"][0]["text"] == "sent ok"
|
||||
# Only string, non-empty values passed through
|
||||
mock_fn.assert_called_once()
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs["attachments"] == ["valid.zip"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_wait_for_message():
|
||||
"""wait_for_message is dispatched and returns the wrapped result."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
with patch("a2a_mcp_server.tool_wait_for_message", AsyncMock(return_value="no messages")):
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 22,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "wait_for_message", "arguments": {"timeout_secs": 5.0}},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["result"]["content"][0]["text"] == "no messages"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_inbox_peek():
|
||||
"""inbox_peek is dispatched with the limit argument."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
with patch("a2a_mcp_server.tool_inbox_peek", AsyncMock(return_value="2 items")):
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 23,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "inbox_peek", "arguments": {"limit": 5}},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["result"]["content"][0]["text"] == "2 items"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_inbox_pop():
|
||||
"""inbox_pop is dispatched with the activity_id argument."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
with patch("a2a_mcp_server.tool_inbox_pop", AsyncMock(return_value="acked")):
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 24,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "inbox_pop", "arguments": {"activity_id": "abc-123"}},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["result"]["content"][0]["text"] == "acked"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_chat_history():
|
||||
"""chat_history is dispatched with peer_id, limit, and before_ts arguments."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
with patch("a2a_mcp_server.tool_chat_history", AsyncMock(return_value="history")):
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 25,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "chat_history",
|
||||
"arguments": {"peer_id": "ws-peer-1", "limit": 10, "before_ts": ""},
|
||||
},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["result"]["content"][0]["text"] == "history"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cli_main argparse — unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_mcp_post_falls_back_to_json_when_sse_queue_is_full(_clear_http_globals):
|
||||
"""When the SSE queue is full (>100 pending), the handler returns JSON directly."""
|
||||
import a2a_mcp_server
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
# Pre-register a queue and fill it to capacity
|
||||
conn_id = str(uuid.uuid4())
|
||||
queue: asyncio.Queue = asyncio.Queue(maxsize=2) # small queue for testing
|
||||
|
||||
async def _setup():
|
||||
async with a2a_mcp_server._http_connection_lock:
|
||||
a2a_mcp_server._http_connection_queues[conn_id] = queue
|
||||
queue.put_nowait({"id": 1})
|
||||
queue.put_nowait({"id": 2})
|
||||
|
||||
_sync_run(_setup())
|
||||
assert queue.full()
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.post(
|
||||
"/mcp",
|
||||
headers={"x-mcp-conn-id": conn_id},
|
||||
json={"jsonrpc": "2.0", "id": 99, "method": "initialize", "params": {}},
|
||||
)
|
||||
|
||||
# With a full queue, the handler returns the response as JSON (not 202)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["id"] == 99
|
||||
assert "result" in resp.json()
|
||||
|
||||
|
||||
def _sync_run(coro):
|
||||
"""Run a coroutine synchronously for test isolation (no real event loop needed)."""
|
||||
try:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
|
||||
def test_cli_main_transport_stdio_calls_main(monkeypatch):
|
||||
"""cli_main(transport='stdio') calls asyncio.run(main) without HTTP."""
|
||||
import a2a_mcp_server
|
||||
|
||||
run_calls: list = []
|
||||
|
||||
async def fake_main():
|
||||
run_calls.append("called")
|
||||
|
||||
monkeypatch.setattr(a2a_mcp_server, "main", fake_main)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
|
||||
a2a_mcp_server.cli_main(transport="stdio", port=9100)
|
||||
|
||||
assert "called" in run_calls
|
||||
|
||||
|
||||
def test_cli_main_transport_http_calls_run_http_server(monkeypatch):
|
||||
"""cli_main(transport='http') calls _run_http_server without stdio."""
|
||||
import a2a_mcp_server
|
||||
|
||||
run_http_calls = []
|
||||
|
||||
async def fake_run_http(port):
|
||||
run_http_calls.append(port)
|
||||
|
||||
# asyncio.run must execute the coroutine for _run_http_server to be called
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_run_http_server", fake_run_http)
|
||||
# stdio path must not be entered
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
|
||||
a2a_mcp_server.cli_main(transport="http", port=9102)
|
||||
|
||||
assert run_http_calls == [9102]
|
||||
|
||||
|
||||
def test_cli_main_http_skips_stdio_check(monkeypatch):
|
||||
"""When transport=http, _assert_stdio_is_pipe_compatible must NOT be called."""
|
||||
import a2a_mcp_server
|
||||
|
||||
called = []
|
||||
|
||||
def fake_assert():
|
||||
called.append("assert_called")
|
||||
|
||||
# Patch on the module object directly
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", fake_assert)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", lambda fn: None)
|
||||
|
||||
a2a_mcp_server.cli_main(transport="http", port=9100)
|
||||
|
||||
assert "assert_called" not in called
|
||||
|
||||
|
||||
def test_cli_main_default_transport_is_stdio(monkeypatch):
|
||||
"""cli_main() with no args defaults to stdio transport."""
|
||||
import a2a_mcp_server
|
||||
|
||||
called_as: list = []
|
||||
|
||||
async def fake_main():
|
||||
called_as.append("called")
|
||||
|
||||
monkeypatch.setattr(a2a_mcp_server, "main", fake_main)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
|
||||
a2a_mcp_server.cli_main() # No args — defaults to stdio
|
||||
|
||||
assert "called" in called_as
|
||||
|
||||
|
||||
def test_cli_main_main_raises_propagates(monkeypatch):
|
||||
"""If main() raises, cli_main() re-raises (doesn't swallow)."""
|
||||
import a2a_mcp_server
|
||||
|
||||
async def fake_main():
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(a2a_mcp_server, "main", fake_main)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
a2a_mcp_server.cli_main(transport="stdio")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# uvicorn/starlette lazy-import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_run_http_server_is_coroutine_function():
|
||||
"""_run_http_server is a coroutine function accepting a port argument."""
|
||||
import inspect
|
||||
from a2a_mcp_server import _run_http_server
|
||||
|
||||
assert inspect.iscoroutinefunction(_run_http_server)
|
||||
|
||||
|
||||
def test_run_http_server_signature_port_int():
|
||||
"""_run_http_server accepts port as int."""
|
||||
import inspect
|
||||
from a2a_mcp_server import _run_http_server
|
||||
|
||||
sig = inspect.signature(_run_http_server)
|
||||
assert "port" in sig.parameters
|
||||
assert sig.parameters["port"].annotation == int
|
||||
@@ -0,0 +1,432 @@
|
||||
"""Test coverage for ``builtin_tools.a2a_tools`` and ``send_message_wrapper``.
|
||||
|
||||
Issue #367: 21 new test cases targeting previously-uncovered branches.
|
||||
|
||||
HTTP mocking: each test patches ``builtin_tools.a2a_tools.httpx.AsyncClient``
|
||||
with an ``AsyncMock`` so no real network I/O occurs. The patch target is
|
||||
the attribute as seen inside the ``a2a_tools`` module (where httpx is imported
|
||||
as ``import httpx``), so ``@pytest.fixture(autouse=True)`` from conftest.py is
|
||||
harmless — it replaces the module-level name *after* our patch exits.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# conftest.py fixture — swap the MagicMock for the real module for THIS file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _real_a2a_tools_module():
|
||||
"""Replace conftest's MagicMock of builtin_tools.a2a_tools with the real module.
|
||||
|
||||
conftest.py sets sys.modules["builtin_tools.a2a_tools"] = <MagicMock> so that
|
||||
adapter tests don't accidentally hit the platform. For THIS test file we
|
||||
want the real module, so we restore it from disk and swap it back after.
|
||||
"""
|
||||
import builtin_tools.a2a_tools as real_module
|
||||
|
||||
# conftest.py may have clobbered builtin_tools.__path__; restore it so the
|
||||
# import above finds builtin_tools/a2a_tools.py on disk.
|
||||
if "builtin_tools" in sys.modules:
|
||||
real_builtin = sys.modules["builtin_tools"]
|
||||
if getattr(real_builtin, "__path__", None) == []:
|
||||
builtin_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
real_builtin.__path__ = [os.path.join(builtin_dir, "builtin_tools")]
|
||||
|
||||
saved = sys.modules.get("builtin_tools.a2a_tools")
|
||||
# Ensure we have the real module (reload if sys.modules already has it)
|
||||
if saved is None or saved is real_module:
|
||||
import importlib
|
||||
importlib.reload(real_module)
|
||||
sys.modules["builtin_tools.a2a_tools"] = real_module
|
||||
yield
|
||||
sys.modules["builtin_tools.a2a_tools"] = saved
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _require_env(monkeypatch):
|
||||
"""Per-test: set required env vars."""
|
||||
monkeypatch.setenv("WORKSPACE_ID", "00000000-0000-0000-0000-000000000001")
|
||||
monkeypatch.setenv("PLATFORM_URL", "http://test.invalid")
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_mock_response(
|
||||
json_data, status_code: int = 200
|
||||
) -> MagicMock:
|
||||
"""Return a fully-configured AsyncMock that mirrors httpx.Response."""
|
||||
resp = MagicMock()
|
||||
resp.json = MagicMock(return_value=json_data)
|
||||
resp.status_code = status_code
|
||||
return resp
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# builtin_tools/a2a_tools — list_peers
|
||||
# =============================================================================
|
||||
|
||||
class TestListPeers:
|
||||
"""Coverage for builtin_tools/a2a_tools.list_peers()."""
|
||||
|
||||
async def test_returns_peers_on_200(self):
|
||||
"""Successful GET returns the peer list."""
|
||||
from builtin_tools.a2a_tools import list_peers
|
||||
|
||||
peers = [
|
||||
{"id": "ws-1", "name": "Alpha", "role": "sre", "status": "online"},
|
||||
{"id": "ws-2", "name": "Beta", "role": "dev", "status": "busy"},
|
||||
]
|
||||
mock_resp = _make_mock_response(peers, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
result = await list_peers()
|
||||
assert result == peers
|
||||
|
||||
async def test_returns_empty_list_on_non_200(self):
|
||||
"""list_peers swallows all non-200 responses gracefully."""
|
||||
from builtin_tools.a2a_tools import list_peers
|
||||
|
||||
mock_resp = _make_mock_response({}, 500)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
result = await list_peers()
|
||||
assert result == []
|
||||
|
||||
async def test_returns_empty_list_on_exception(self):
|
||||
"""Network errors must not propagate — list_peers returns []. """
|
||||
from builtin_tools.a2a_tools import list_peers
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(
|
||||
side_effect=RuntimeError("dns failure")
|
||||
)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
result = await list_peers()
|
||||
assert result == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# builtin_tools/a2a_tools — delegate_task
|
||||
# =============================================================================
|
||||
|
||||
_DISCOVER_ROUTE = "http://test.invalid/registry/discover/ws-target"
|
||||
|
||||
|
||||
class TestDelegateTask:
|
||||
"""Coverage for builtin_tools/a2a_tools.delegate_task(workspace_id, task)."""
|
||||
|
||||
async def test_empty_workspace_id_returns_error(self):
|
||||
"""Empty workspace_id is validated before any network call."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
out = await delegate_task("", "do it")
|
||||
assert "Error" in out
|
||||
assert "workspace_id" in out.lower()
|
||||
|
||||
async def test_discover_returns_non_200(self):
|
||||
"""Discovery 4xx/5xx → error message with status code."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({}, 404)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert "Error" in out
|
||||
assert "404" in out
|
||||
|
||||
async def test_discover_returns_200_with_empty_url(self):
|
||||
"""Discovery 200 but no url field → actionable error."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"name": "orphan"}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert "Error" in out
|
||||
assert "no URL" in out
|
||||
|
||||
async def test_a2a_post_returns_500(self):
|
||||
"""A2A send 5xx with empty body → str(data) returned (code doesn't check status_code)."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response({}, 500)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
# Code checks json body, not status_code; empty body {} → str({})
|
||||
assert out == "{}"
|
||||
|
||||
async def test_result_parts_empty_dict(self):
|
||||
"""Regression #279: {"parts": []} → str(result), not "(no text)"."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response({"result": {"parts": []}}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
# Must return str(result), not "(no text)"
|
||||
assert "parts" in out
|
||||
assert "(no text)" not in out
|
||||
|
||||
async def test_result_is_plain_string(self):
|
||||
"""A bare string result returns as-is."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response({"result": "just a plain string"}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert out == "just a plain string"
|
||||
|
||||
async def test_result_is_number(self):
|
||||
"""Non-dict, non-string result → falls through to "(no text)"."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response({"result": 12345}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert out == "(no text)"
|
||||
|
||||
async def test_result_parts_non_dict_element(self):
|
||||
"""parts[0] is not a dict → falls through to "(no text)".
|
||||
|
||||
The code checks if parts[0] is a dict; since 123 is an int, it hits
|
||||
the else-branch and returns "(no text)".
|
||||
"""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response({"result": {"parts": [123, "also a string"]}}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert out == "(no text)"
|
||||
|
||||
async def test_error_dict_form(self):
|
||||
"""{"error": {"message": "..."}} → "Error: ..."."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response(
|
||||
{"error": {"message": "peer overloaded", "code": 429}}, 200
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert out == "Error: peer overloaded"
|
||||
|
||||
async def test_error_string_form(self):
|
||||
"""{"error": "string error"} → "Error: string error"."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response({"error": "workspace offline"}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert out == "Error: workspace offline"
|
||||
|
||||
async def test_error_null(self):
|
||||
"""{"error": null} → "Error: None" (edge case — str(null) in message)."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response({"error": None}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert "Error" in out
|
||||
|
||||
async def test_a2a_post_raises_exception(self):
|
||||
"""Network error during A2A POST → Error: sending A2A message: ..."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(
|
||||
side_effect=ConnectionError("connection refused")
|
||||
)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert "Error" in out
|
||||
assert "connection refused" in out
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# builtin_tools/a2a_tools — get_peers_summary
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetPeersSummary:
|
||||
"""Coverage for builtin_tools/a2a_tools.get_peers_summary()."""
|
||||
|
||||
async def test_empty_peers_returns_no_peers_available(self):
|
||||
from builtin_tools.a2a_tools import get_peers_summary
|
||||
|
||||
mock_resp = _make_mock_response([], 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await get_peers_summary()
|
||||
assert "No peers" in out
|
||||
|
||||
async def test_peer_missing_fields(self):
|
||||
"""Peers with missing name/id/role/status must not KeyError/TypeError."""
|
||||
from builtin_tools.a2a_tools import get_peers_summary
|
||||
|
||||
mock_resp = _make_mock_response([{"id": "ws-x"}], 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await get_peers_summary()
|
||||
assert "ws-x" in out
|
||||
assert isinstance(out, str)
|
||||
|
||||
async def test_healthy_peer_roundtrip(self):
|
||||
"""Sanity: normal peer dicts produce a formatted list."""
|
||||
from builtin_tools.a2a_tools import get_peers_summary
|
||||
|
||||
peers = [
|
||||
{"id": "ws-alpha", "name": "Alpha", "role": "sre", "status": "online"},
|
||||
]
|
||||
mock_resp = _make_mock_response(peers, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await get_peers_summary()
|
||||
assert "Alpha" in out
|
||||
assert "ws-alpha" in out
|
||||
assert "sre" in out
|
||||
assert "online" in out
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# send_message_wrapper — safe_send_message
|
||||
# =============================================================================
|
||||
|
||||
from adapters.smolagents.send_message_wrapper import safe_send_message
|
||||
|
||||
|
||||
class TestSafeSendMessage:
|
||||
"""Coverage for adapters.smolagents.send_message_wrapper.safe_send_message()."""
|
||||
|
||||
def test_non_string_input_converted(self):
|
||||
"""Non-str text is str()-converted before escaping."""
|
||||
delivered = []
|
||||
safe_send_message(42, send_fn=lambda s: delivered.append(s))
|
||||
assert delivered == ["[smolagents] 42"]
|
||||
assert isinstance(delivered[0], str)
|
||||
|
||||
def test_html_entities_escaped(self):
|
||||
"""< > ' are escaped so rendered UIs cannot be injected.
|
||||
|
||||
The payload <script>alert('xss')</script> has no literal '&', so &
|
||||
does not appear. The escape output is: <script>alert('xss')</script>
|
||||
"""
|
||||
delivered = []
|
||||
safe_send_message(
|
||||
"<script>alert('xss')</script>",
|
||||
send_fn=lambda s: delivered.append(s),
|
||||
)
|
||||
assert "<" in delivered[0]
|
||||
assert ">" in delivered[0]
|
||||
assert "'" in delivered[0]
|
||||
assert "<script>" in delivered[0]
|
||||
# The angle brackets and quotes must NOT appear unescaped
|
||||
assert "<script>" not in delivered[0]
|
||||
assert "alert('" not in delivered[0]
|
||||
|
||||
def test_truncation_at_max_len(self):
|
||||
"""Text > 2000 chars is truncated; caller is warned."""
|
||||
delivered = []
|
||||
with patch(
|
||||
"adapters.smolagents.send_message_wrapper.logger"
|
||||
) as mock_logger:
|
||||
long_text = "A" * 2500
|
||||
safe_send_message(long_text, send_fn=lambda s: delivered.append(s))
|
||||
assert len(delivered[0]) < len(long_text)
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "truncating" in mock_logger.warning.call_args[0][0]
|
||||
|
||||
def test_no_truncation_under_max_len(self):
|
||||
"""Text ≤ 2000 chars is passed through intact with no warning."""
|
||||
delivered = []
|
||||
with patch(
|
||||
"adapters.smolagents.send_message_wrapper.logger"
|
||||
) as mock_logger:
|
||||
text = "A" * 1500
|
||||
safe_send_message(text, send_fn=lambda s: delivered.append(s))
|
||||
expected = f"[smolagents] {text}"
|
||||
assert delivered[0] == expected
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
def test_debug_log_emitted(self):
|
||||
"""Every delivery logs at DEBUG with final payload length."""
|
||||
delivered = []
|
||||
with patch(
|
||||
"adapters.smolagents.send_message_wrapper.logger"
|
||||
) as mock_logger:
|
||||
safe_send_message("hello", send_fn=lambda s: delivered.append(s))
|
||||
mock_logger.debug.assert_called_once()
|
||||
assert "delivering" in mock_logger.debug.call_args[0][0]
|
||||
|
||||
def test_label_prefix_always_present(self):
|
||||
"""Every delivered payload starts with '[smolagents]'."""
|
||||
delivered = []
|
||||
safe_send_message("x", send_fn=lambda s: delivered.append(s))
|
||||
assert delivered[0].startswith("[smolagents]")
|
||||
@@ -998,3 +998,87 @@ def test_heartbeat_500_does_not_increment_auth_counter(monkeypatch, caplog):
|
||||
f"5xx must NOT be classified as auth failure — would mislead operator. "
|
||||
f"Got 'revoked' ERRORs: {[r.message[:80] for r in revoked_errors]}"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Subcommand dispatch — doctor / --help / MOLECULE_WORKSPACES errors
|
||||
# =============================================================================
|
||||
|
||||
class TestSubcommandDispatch:
|
||||
"""Coverage for mcp_cli.py argv dispatch (lines 110-122, 138-140)."""
|
||||
|
||||
def test_doctor_subcommand_calls_mcp_doctor_run(self, monkeypatch, capsys):
|
||||
"""molecule-mcp doctor → imports mcp_doctor and exits with its code."""
|
||||
import mcp_doctor
|
||||
|
||||
monkeypatch.setattr(mcp_doctor, "run", lambda: 0)
|
||||
monkeypatch.setattr(sys, "argv", ["mcp", "doctor"])
|
||||
# Also stub PLATFORM_URL so we don't hit the env-check first
|
||||
monkeypatch.setenv("PLATFORM_URL", "http://test.invalid")
|
||||
monkeypatch.setenv("WORKSPACE_ID", "00000000-0000-0000-0000-000000000001")
|
||||
monkeypatch.setenv("MOLECULE_WORKSPACE_TOKEN", "tok1")
|
||||
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
mcp_cli.main()
|
||||
assert exc_info.value.code == 0
|
||||
|
||||
def test_help_flag_exits_zero_and_prints_usage(self, monkeypatch, capsys):
|
||||
"""molecule-mcp --help / -h / help → prints usage and exits 0."""
|
||||
for arg in ("--help", "-h", "help"):
|
||||
monkeypatch.setattr(sys, "argv", ["mcp", arg])
|
||||
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
mcp_cli.main()
|
||||
captured = capsys.readouterr()
|
||||
assert exc_info.value.code == 0
|
||||
assert "molecule-mcp" in captured.out
|
||||
assert "doctor" in captured.out
|
||||
|
||||
def test_molecule_workspaces_error_prints_to_stderr(self, monkeypatch, capsys):
|
||||
"""MOLECULE_WORKSPACES with invalid entries prints to stderr."""
|
||||
# Must have PLATFORM_URL set or it exits before reaching this branch
|
||||
monkeypatch.setenv("PLATFORM_URL", "http://test.invalid")
|
||||
# Invalid MOLECULE_WORKSPACES format so _resolve_workspaces returns errors
|
||||
monkeypatch.setenv("MOLECULE_WORKSPACES", "invalid-entry")
|
||||
# Reset argv to a clean state
|
||||
monkeypatch.setattr(sys, "argv", ["mcp"])
|
||||
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
mcp_cli.main()
|
||||
captured = capsys.readouterr()
|
||||
assert exc_info.value.code == 2
|
||||
assert "invalid MOLECULE_WORKSPACES" in captured.err
|
||||
|
||||
|
||||
class TestRegisterWorkspaceTokenImportError:
|
||||
"""Coverage for mcp_cli.py lines 181-185 — ImportError fallback."""
|
||||
|
||||
def test_import_error_is_swallowed_and_continues(
|
||||
self, monkeypatch, capsys, tmp_path
|
||||
):
|
||||
"""When platform_auth.register_workspace_token is absent, CLI continues."""
|
||||
# Set up a valid single-workspace environment so main() does NOT
|
||||
# exit early — it reaches the register_workspace_token call
|
||||
monkeypatch.setenv("PLATFORM_URL", "http://test.invalid")
|
||||
monkeypatch.setenv("WORKSPACE_ID", "00000000-0000-0000-0000-000000000001")
|
||||
monkeypatch.setenv("MOLECULE_WORKSPACE_TOKEN", "tok1")
|
||||
# Ensure heartbeat is disabled
|
||||
monkeypatch.setenv("MOLECULE_MCP_DISABLE_HEARTBEAT", "1")
|
||||
monkeypatch.setenv("CONFIGS_DIR", str(tmp_path))
|
||||
|
||||
# Remove register_workspace_token from platform_auth so the
|
||||
# ImportError branch fires (lines 181-185)
|
||||
import platform_auth
|
||||
saved = getattr(platform_auth, "register_workspace_token", None)
|
||||
if saved is not None:
|
||||
delattr(platform_auth, "register_workspace_token")
|
||||
|
||||
try:
|
||||
# If ImportError is not handled, main() raises ImportError here.
|
||||
# The test verifies it is handled (no exception propagates).
|
||||
with pytest.raises(SystemExit):
|
||||
mcp_cli.main()
|
||||
finally:
|
||||
# Restore so other tests are not affected
|
||||
if saved is not None:
|
||||
platform_auth.register_workspace_token = saved
|
||||
|
||||
@@ -0,0 +1,300 @@
|
||||
"""Test coverage for shared_runtime helpers (issue #366).
|
||||
|
||||
Six helper functions previously had zero test coverage:
|
||||
_extract_part_text, extract_message_text, format_conversation_history,
|
||||
build_task_text, append_peer_guidance, brief_task
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
from shared_runtime import (
|
||||
_extract_part_text,
|
||||
append_peer_guidance,
|
||||
brief_task,
|
||||
build_task_text,
|
||||
extract_message_text,
|
||||
format_conversation_history,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _extract_part_text
|
||||
# =============================================================================
|
||||
|
||||
class TestExtractPartText:
|
||||
"""Coverage for shared_runtime._extract_part_text()."""
|
||||
|
||||
def test_dict_with_text_field(self):
|
||||
assert _extract_part_text({"text": "hello"}) == "hello"
|
||||
|
||||
def test_dict_without_text_field(self):
|
||||
assert _extract_part_text({"type": "image"}) == ""
|
||||
|
||||
def test_dict_with_empty_text_field(self):
|
||||
assert _extract_part_text({"text": ""}) == ""
|
||||
|
||||
def test_dict_with_root_nesting(self):
|
||||
"""Text buried in part['root']['text'] is extracted."""
|
||||
assert _extract_part_text({"root": {"text": "nested"}}) == "nested"
|
||||
|
||||
def test_dict_with_root_non_dict(self):
|
||||
"""part['root'] that is not a dict is safely skipped."""
|
||||
assert _extract_part_text({"root": "string", "text": "top"}) == "top"
|
||||
|
||||
def test_object_with_text_attribute(self):
|
||||
class FakePart:
|
||||
text = "attr-text"
|
||||
|
||||
assert _extract_part_text(FakePart()) == "attr-text"
|
||||
|
||||
def test_object_with_root_object_with_text(self):
|
||||
"""Object with root.attr.text is extracted (A2A v1 object style)."""
|
||||
|
||||
class FakeRoot:
|
||||
text = "root-attr-text"
|
||||
|
||||
class FakePart:
|
||||
root = FakeRoot()
|
||||
|
||||
assert _extract_part_text(FakePart()) == "root-attr-text"
|
||||
|
||||
def test_object_with_empty_text_attribute(self):
|
||||
class FakePart:
|
||||
text = ""
|
||||
|
||||
assert _extract_part_text(FakePart()) == ""
|
||||
|
||||
def test_none_input(self):
|
||||
assert _extract_part_text(None) == ""
|
||||
|
||||
def test_unexpected_type(self):
|
||||
"""Plain int/float/bool falls through to empty string."""
|
||||
assert _extract_part_text(42) == ""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# extract_message_text
|
||||
# =============================================================================
|
||||
|
||||
class TestExtractMessageText:
|
||||
"""Coverage for shared_runtime.extract_message_text()."""
|
||||
|
||||
def test_list_of_dict_parts(self):
|
||||
parts = [{"text": "hello"}, {"text": "world"}]
|
||||
assert extract_message_text(parts) == "hello world"
|
||||
|
||||
def test_single_part(self):
|
||||
assert extract_message_text([{"text": "single"}]) == "single"
|
||||
|
||||
def test_context_object_with_message_parts(self):
|
||||
"""RequestContext-like: .message.parts is the parts list."""
|
||||
|
||||
class FakeContext:
|
||||
class _Msg:
|
||||
parts = [{"text": "from context"}]
|
||||
|
||||
message = _Msg()
|
||||
|
||||
assert extract_message_text(FakeContext()) == "from context"
|
||||
|
||||
def test_context_object_without_message(self):
|
||||
"""No .message attr → falls back to treating input as a parts list."""
|
||||
|
||||
class FakeContext:
|
||||
pass # no .message
|
||||
|
||||
# Pass a list directly as the context-like object
|
||||
assert extract_message_text([{"text": "fallback"}]) == "fallback"
|
||||
|
||||
def test_whitespace_normalized(self):
|
||||
"""Leading/trailing whitespace is stripped; internal newlines are preserved."""
|
||||
parts = [{"text": " hello "}, {"text": "\nworld\n"}]
|
||||
result = extract_message_text(parts)
|
||||
# Leading/trailing stripped, but internal \n stays (join uses single space)
|
||||
assert result == "hello \nworld"
|
||||
assert not result.startswith(" ")
|
||||
assert not result.endswith(" ")
|
||||
|
||||
def test_empty_parts_list(self):
|
||||
assert extract_message_text([]) == ""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# format_conversation_history
|
||||
# =============================================================================
|
||||
|
||||
class TestFormatConversationHistory:
|
||||
"""Coverage for shared_runtime.format_conversation_history()."""
|
||||
|
||||
def test_single_user_message(self):
|
||||
hist = [("human", "hello")]
|
||||
out = format_conversation_history(hist)
|
||||
assert out == "User: hello"
|
||||
|
||||
def test_single_agent_message(self):
|
||||
hist = [("ai", "response")]
|
||||
out = format_conversation_history(hist)
|
||||
assert out == "Agent: response"
|
||||
|
||||
def test_interleaved_history(self):
|
||||
hist = [
|
||||
("human", "hello"),
|
||||
("ai", "hi there"),
|
||||
("human", "what is 2+2?"),
|
||||
("ai", "four"),
|
||||
]
|
||||
out = format_conversation_history(hist)
|
||||
lines = out.split("\n")
|
||||
assert lines[0] == "User: hello"
|
||||
assert lines[1] == "Agent: hi there"
|
||||
assert lines[2] == "User: what is 2+2?"
|
||||
assert lines[3] == "Agent: four"
|
||||
|
||||
def test_empty_history(self):
|
||||
assert format_conversation_history([]) == ""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# build_task_text
|
||||
# =============================================================================
|
||||
|
||||
class TestBuildTaskText:
|
||||
"""Coverage for shared_runtime.build_task_text()."""
|
||||
|
||||
def test_no_history_returns_user_message_unchanged(self):
|
||||
assert build_task_text("do the thing", []) == "do the thing"
|
||||
|
||||
def test_history_prepends_transcript(self):
|
||||
hist = [("human", "hello"), ("ai", "hi")]
|
||||
result = build_task_text("follow-up", hist)
|
||||
assert "Conversation so far:" in result
|
||||
assert "User: hello" in result
|
||||
assert "Agent: hi" in result
|
||||
assert "follow-up" in result
|
||||
|
||||
def test_user_message_after_conversation_header(self):
|
||||
hist = [("human", "hello")]
|
||||
result = build_task_text("do it", hist)
|
||||
assert result.startswith("Conversation so far:")
|
||||
assert result.endswith("Current request: do it")
|
||||
|
||||
def test_empty_user_message_with_history(self):
|
||||
"""Empty user_message is still rendered with history."""
|
||||
hist = [("human", "hello")]
|
||||
result = build_task_text("", hist)
|
||||
assert "Conversation so far:" in result
|
||||
assert "Current request:" in result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# append_peer_guidance
|
||||
# =============================================================================
|
||||
|
||||
class TestAppendPeerGuidance:
|
||||
"""Coverage for shared_runtime.append_peer_guidance()."""
|
||||
|
||||
def test_base_text_appended(self):
|
||||
result = append_peer_guidance(
|
||||
"base text",
|
||||
peers_info="alpha: ws-1",
|
||||
default_text="default",
|
||||
tool_name="delegate_task",
|
||||
)
|
||||
assert result.startswith("base text")
|
||||
assert "## Peers" in result
|
||||
assert "alpha: ws-1" in result
|
||||
assert "Use delegate_task" in result
|
||||
|
||||
def test_null_base_text_uses_default(self):
|
||||
result = append_peer_guidance(
|
||||
None,
|
||||
peers_info="peer info",
|
||||
default_text="DEFAULT_TEXT",
|
||||
tool_name="tool",
|
||||
)
|
||||
assert result.startswith("DEFAULT_TEXT")
|
||||
|
||||
def test_whitespace_base_text_strips_to_empty_peers_still_added(self):
|
||||
"""Whitespace-only base_text is stripped but default_text is NOT used
|
||||
(only None triggers the fallback). The peers section is still appended."""
|
||||
result = append_peer_guidance(
|
||||
" ",
|
||||
peers_info="peer",
|
||||
default_text="DEF",
|
||||
tool_name="t",
|
||||
)
|
||||
# " ".strip() == ""; default_text is NOT substituted for whitespace
|
||||
assert "## Peers" in result
|
||||
assert "peer" in result
|
||||
assert "DEF" not in result # default_text only on None, not whitespace
|
||||
|
||||
def test_none_base_text_uses_default(self):
|
||||
"""None base_text triggers fallback to default_text."""
|
||||
result = append_peer_guidance(
|
||||
None,
|
||||
peers_info="peer",
|
||||
default_text="DEFAULT",
|
||||
tool_name="tool",
|
||||
)
|
||||
assert result.startswith("DEFAULT")
|
||||
assert "## Peers" in result
|
||||
|
||||
def test_empty_peers_info_skips_section(self):
|
||||
result = append_peer_guidance(
|
||||
"base",
|
||||
peers_info="",
|
||||
default_text="def",
|
||||
tool_name="tool",
|
||||
)
|
||||
# No "## Peers" section when peers_info is empty
|
||||
assert result == "base"
|
||||
|
||||
def test_whitespace_in_base_and_peers_normalized(self):
|
||||
result = append_peer_guidance(
|
||||
" base \n",
|
||||
peers_info=" peer-1 \n",
|
||||
default_text="def",
|
||||
tool_name="tool",
|
||||
)
|
||||
# Base should be stripped of leading/trailing whitespace
|
||||
assert result.startswith("base")
|
||||
# Peer info should be appended
|
||||
assert "peer-1" in result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# brief_task
|
||||
# =============================================================================
|
||||
|
||||
class TestBriefTask:
|
||||
"""Coverage for shared_runtime.brief_task()."""
|
||||
|
||||
def test_short_text_returned_unchanged(self):
|
||||
assert brief_task("hello", limit=60) == "hello"
|
||||
|
||||
def test_exact_limit_no_ellipsis(self):
|
||||
text = "A" * 60
|
||||
assert brief_task(text, limit=60) == text
|
||||
assert "..." not in text
|
||||
|
||||
def test_truncated_with_ellipsis(self):
|
||||
text = "A" * 80
|
||||
result = brief_task(text, limit=60)
|
||||
assert len(result) == 63 # 60 chars + "..."
|
||||
assert result.endswith("...")
|
||||
|
||||
def test_limit_10_shortens(self):
|
||||
result = brief_task("hello world", limit=10)
|
||||
assert len(result) == 13 # 10 chars + "..."
|
||||
assert result.endswith("...")
|
||||
|
||||
def test_limit_0_returns_ellipsis(self):
|
||||
"""limit=0 → 0-char slice + "..." since len("hello") > 0."""
|
||||
result = brief_task("hello", limit=0)
|
||||
assert result == "..."
|
||||
|
||||
def test_limit_1_single_char_plus_ellipsis(self):
|
||||
result = brief_task("hello", limit=1)
|
||||
assert len(result) == 4 # 1 char + "..."
|
||||
assert result.startswith("h")
|
||||
assert result.endswith("...")
|
||||
Reference in New Issue
Block a user