Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3e38a885a4 | |||
| 9f3948dc3a | |||
| c4deda1035 | |||
| 0dbda533fb |
@@ -18,13 +18,6 @@ name: publish-workspace-server-image
|
||||
# :staging-<sha> — per-commit digest, stable for canary verify
|
||||
# :staging-latest — tracks most recent build on this branch
|
||||
#
|
||||
# Production auto-deploy:
|
||||
# After both platform and tenant images are pushed, deploy-production waits
|
||||
# for strict required push contexts on the same SHA to go green, then
|
||||
# calls the production CP redeploy-fleet endpoint with target_tag=
|
||||
# staging-<sha>. Set repo variable or secret PROD_AUTO_DEPLOY_DISABLED=true
|
||||
# to stop production rollout while keeping image publishing enabled.
|
||||
#
|
||||
# ECR target: 153263036946.dkr.ecr.us-east-2.amazonaws.com/molecule-ai/*
|
||||
# Required secrets: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AUTO_SYNC_TOKEN
|
||||
#
|
||||
@@ -45,10 +38,15 @@ on:
|
||||
- '.gitea/workflows/publish-workspace-server-image.yml'
|
||||
workflow_dispatch:
|
||||
|
||||
# No `concurrency:` block here. Gitea 1.22.6 can cancel queued runs despite
|
||||
# `cancel-in-progress: false`; that is not acceptable for a workflow with a
|
||||
# production deploy job. Per-SHA image tags are immutable, and staging-latest is
|
||||
# best-effort last-writer-wins metadata.
|
||||
# Serialize per-branch so two rapid main pushes don't race the same
|
||||
# :staging-latest tag retag. Allow parallel runs as they produce
|
||||
# different :staging-<sha> tags and last-write-wins on :staging-latest.
|
||||
#
|
||||
# cancel-in-progress: false → in-flight builds finish; the next push's
|
||||
# build queues. This avoids a partially-pushed image.
|
||||
concurrency:
|
||||
group: publish-workspace-server-image-${{ github.ref }}
|
||||
cancel-in-progress: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -65,22 +63,20 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
# Health check: verify Docker daemon is accessible before attempting any
|
||||
# build steps. This fails loudly at step 1 when the runner's docker.sock
|
||||
# is inaccessible rather than silently continuing where `docker build`
|
||||
# fails deep in the process with a cryptic ECR auth error.
|
||||
- name: Verify Docker daemon access
|
||||
- name: Diagnose Docker daemon access
|
||||
run: |
|
||||
set -euo pipefail
|
||||
echo "::group::Docker daemon health check"
|
||||
echo "::group::Docker daemon diagnosis"
|
||||
echo "Runner: ${HOSTNAME:-unknown}"
|
||||
docker info 2>&1 | head -5 || {
|
||||
echo "::error::Docker daemon is not accessible at /var/run/docker.sock"
|
||||
echo "::error::Runner: ${HOSTNAME:-unknown}"
|
||||
echo "::error::Check: (1) daemon is running, (2) runner user is in docker group, (3) sock permissions are 660+"
|
||||
exit 1
|
||||
}
|
||||
echo "Docker daemon OK"
|
||||
echo "--- Socket info ---"
|
||||
ls -la /var/run/docker.sock 2>/dev/null || echo "/var/run/docker.sock: not found"
|
||||
stat /var/run/docker.sock 2>/dev/null || true
|
||||
echo "--- User info ---"
|
||||
id
|
||||
echo "--- docker version ---"
|
||||
docker version 2>&1 || true
|
||||
echo "--- docker info (full) ---"
|
||||
docker info 2>&1 || echo "docker info failed: exit $?"
|
||||
echo "::endgroup::"
|
||||
|
||||
# Pre-clone manifest deps before docker build.
|
||||
@@ -179,173 +175,3 @@ jobs:
|
||||
--tag "${TENANT_IMAGE_NAME}:${TAG_SHA}" \
|
||||
--tag "${TENANT_IMAGE_NAME}:${TAG_LATEST}" \
|
||||
--push .
|
||||
|
||||
# bp-exempt: production deploy side-effect; merge is gated by CI / all-required and this job waits for push CI before acting.
|
||||
deploy-production:
|
||||
name: Production auto-deploy
|
||||
needs: build-and-push
|
||||
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 75
|
||||
env:
|
||||
CP_URL: ${{ vars.PROD_CP_URL || 'https://api.moleculesai.app' }}
|
||||
CP_ADMIN_API_TOKEN: ${{ secrets.CP_ADMIN_API_TOKEN }}
|
||||
GITEA_HOST: git.moleculesai.app
|
||||
GITEA_TOKEN: ${{ secrets.PROD_AUTO_DEPLOY_CONTROL_TOKEN || secrets.AUTO_SYNC_TOKEN }}
|
||||
PROD_AUTO_DEPLOY_DISABLED: ${{ vars.PROD_AUTO_DEPLOY_DISABLED || secrets.PROD_AUTO_DEPLOY_DISABLED || '' }}
|
||||
PROD_AUTO_DEPLOY_CANARY_SLUG: ${{ vars.PROD_AUTO_DEPLOY_CANARY_SLUG || 'hongming' }}
|
||||
PROD_AUTO_DEPLOY_SOAK_SECONDS: ${{ vars.PROD_AUTO_DEPLOY_SOAK_SECONDS || '60' }}
|
||||
PROD_AUTO_DEPLOY_BATCH_SIZE: ${{ vars.PROD_AUTO_DEPLOY_BATCH_SIZE || '3' }}
|
||||
PROD_AUTO_DEPLOY_DRY_RUN: ${{ vars.PROD_AUTO_DEPLOY_DRY_RUN || '' }}
|
||||
PROD_ALLOW_NON_PROD_CP_URL: ${{ vars.PROD_ALLOW_NON_PROD_CP_URL || '' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Build deploy plan
|
||||
id: plan
|
||||
run: |
|
||||
set -euo pipefail
|
||||
python3 .gitea/scripts/prod-auto-deploy.py plan > "$RUNNER_TEMP/prod-auto-deploy-plan.json"
|
||||
jq . "$RUNNER_TEMP/prod-auto-deploy-plan.json"
|
||||
enabled="$(jq -r '.enabled' "$RUNNER_TEMP/prod-auto-deploy-plan.json")"
|
||||
echo "enabled=$enabled" >> "$GITHUB_OUTPUT"
|
||||
if [ "$enabled" != "true" ]; then
|
||||
reason="$(jq -r '.disabled_reason' "$RUNNER_TEMP/prod-auto-deploy-plan.json")"
|
||||
echo "::notice::Production auto-deploy disabled: $reason"
|
||||
{
|
||||
echo "## Production auto-deploy skipped"
|
||||
echo ""
|
||||
echo "Reason: \`$reason\`"
|
||||
} >> "$GITHUB_STEP_SUMMARY"
|
||||
exit 0
|
||||
fi
|
||||
if [ -z "${CP_ADMIN_API_TOKEN:-}" ]; then
|
||||
echo "::error::CP_ADMIN_API_TOKEN secret is required for production auto-deploy."
|
||||
exit 1
|
||||
fi
|
||||
if [ -z "${GITEA_TOKEN:-}" ]; then
|
||||
echo "::error::AUTO_SYNC_TOKEN secret is required so production deploy can wait for green CI."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Self-test production deploy helper
|
||||
if: ${{ steps.plan.outputs.enabled == 'true' }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
python3 -m pip install --quiet 'pytest==9.0.2' 'PyYAML==6.0.2'
|
||||
python3 -m pytest .gitea/scripts/tests/test_prod_auto_deploy.py -q
|
||||
python3 .gitea/scripts/lint-workflow-yaml.py --workflow-dir .gitea/workflows
|
||||
|
||||
- name: Wait for green main CI on this SHA
|
||||
if: ${{ steps.plan.outputs.enabled == 'true' }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
python3 .gitea/scripts/prod-auto-deploy.py wait-ci
|
||||
|
||||
- name: Call production CP redeploy-fleet
|
||||
if: ${{ steps.plan.outputs.enabled == 'true' }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
python3 .gitea/scripts/prod-auto-deploy.py assert-enabled
|
||||
PLAN="$RUNNER_TEMP/prod-auto-deploy-plan.json"
|
||||
TARGET_TAG="$(jq -r '.target_tag' "$PLAN")"
|
||||
BODY="$(jq -c '.body' "$PLAN")"
|
||||
|
||||
echo "POST $CP_URL/cp/admin/tenants/redeploy-fleet"
|
||||
echo " target_tag: $TARGET_TAG"
|
||||
echo " body: $BODY"
|
||||
|
||||
HTTP_RESPONSE="$RUNNER_TEMP/prod-redeploy-response.json"
|
||||
HTTP_CODE_FILE="$RUNNER_TEMP/prod-redeploy-http-code.txt"
|
||||
set +e
|
||||
curl -sS -o "$HTTP_RESPONSE" -w '%{http_code}' \
|
||||
-m 1200 \
|
||||
-H "Authorization: Bearer $CP_ADMIN_API_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-X POST "$CP_URL/cp/admin/tenants/redeploy-fleet" \
|
||||
-d "$BODY" > "$HTTP_CODE_FILE"
|
||||
set -e
|
||||
|
||||
HTTP_CODE="$(cat "$HTTP_CODE_FILE" 2>/dev/null || echo "000")"
|
||||
[ -z "$HTTP_CODE" ] && HTTP_CODE="000"
|
||||
echo "HTTP $HTTP_CODE"
|
||||
jq '{ok, result_count: (.results // [] | length)}' "$HTTP_RESPONSE" || true
|
||||
|
||||
{
|
||||
echo "## Production auto-deploy"
|
||||
echo ""
|
||||
echo "**Commit:** \`${GITHUB_SHA:0:7}\`"
|
||||
echo "**Target tag:** \`$TARGET_TAG\`"
|
||||
echo "**HTTP:** $HTTP_CODE"
|
||||
echo ""
|
||||
echo "### Per-tenant result"
|
||||
echo ""
|
||||
echo "| Slug | Phase | SSM Status | Exit | Healthz | Error present |"
|
||||
echo "|------|-------|------------|------|---------|---------------|"
|
||||
jq -r '.results[]? | "| \(.slug) | \(.phase) | \(.ssm_status // "-") | \(.ssm_exit_code) | \(.healthz_ok) | \((.error // "") != "") |"' "$HTTP_RESPONSE" || true
|
||||
} >> "$GITHUB_STEP_SUMMARY"
|
||||
|
||||
if [ "$HTTP_CODE" != "200" ]; then
|
||||
echo "::error::redeploy-fleet returned HTTP $HTTP_CODE"
|
||||
exit 1
|
||||
fi
|
||||
OK="$(jq -r '.ok' "$HTTP_RESPONSE")"
|
||||
if [ "$OK" != "true" ]; then
|
||||
echo "::error::redeploy-fleet reported ok=false; production rollout halted."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Verify reachable tenants report this SHA
|
||||
if: ${{ steps.plan.outputs.enabled == 'true' }}
|
||||
env:
|
||||
TENANT_DOMAIN: moleculesai.app
|
||||
run: |
|
||||
set -euo pipefail
|
||||
RESP="$RUNNER_TEMP/prod-redeploy-response.json"
|
||||
mapfile -t SLUGS < <(jq -r '.results[]? | .slug' "$RESP")
|
||||
if [ ${#SLUGS[@]} -eq 0 ]; then
|
||||
echo "::error::No tenants returned from redeploy-fleet; refusing to mark production deploy verified."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
STALE_COUNT=0
|
||||
UNREACHABLE_COUNT=0
|
||||
UNHEALTHY_COUNT=0
|
||||
for slug in "${SLUGS[@]}"; do
|
||||
healthz_ok="$(jq -r --arg slug "$slug" '.results[]? | select(.slug == $slug) | .healthz_ok' "$RESP" | tail -1)"
|
||||
if [ "$healthz_ok" != "true" ]; then
|
||||
echo "::error::$slug did not report healthz_ok=true in redeploy-fleet response."
|
||||
UNHEALTHY_COUNT=$((UNHEALTHY_COUNT + 1))
|
||||
continue
|
||||
fi
|
||||
url="https://${slug}.${TENANT_DOMAIN}/buildinfo"
|
||||
body="$(curl -sS --max-time 30 --retry 3 --retry-delay 5 --retry-connrefused "$url" || true)"
|
||||
actual="$(echo "$body" | jq -r '.git_sha // ""' 2>/dev/null || echo "")"
|
||||
if [ -z "$actual" ]; then
|
||||
echo "::error::$slug did not return /buildinfo after deploy."
|
||||
UNREACHABLE_COUNT=$((UNREACHABLE_COUNT + 1))
|
||||
continue
|
||||
fi
|
||||
if [ "$actual" != "$GITHUB_SHA" ]; then
|
||||
echo "::error::$slug is stale: actual=${actual:0:7}, expected=${GITHUB_SHA:0:7}"
|
||||
STALE_COUNT=$((STALE_COUNT + 1))
|
||||
else
|
||||
echo "$slug: ${actual:0:7}"
|
||||
fi
|
||||
done
|
||||
|
||||
{
|
||||
echo ""
|
||||
echo "### Buildinfo verification"
|
||||
echo ""
|
||||
echo "Expected SHA: \`${GITHUB_SHA:0:7}\`"
|
||||
echo "Verified tenants: ${#SLUGS[@]}"
|
||||
echo "Stale tenants: $STALE_COUNT"
|
||||
echo "Unhealthy tenants: $UNHEALTHY_COUNT"
|
||||
echo "Unreachable tenants: $UNREACHABLE_COUNT"
|
||||
} >> "$GITHUB_STEP_SUMMARY"
|
||||
|
||||
if [ "$STALE_COUNT" -gt 0 ] || [ "$UNHEALTHY_COUNT" -gt 0 ] || [ "$UNREACHABLE_COUNT" -gt 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -66,19 +66,27 @@ jobs:
|
||||
# PR#372's ci.yml port used. Diffs against the PR base or the
|
||||
# previous push SHA, then matches against the wheel-relevant
|
||||
# path set.
|
||||
BASE="${GITHUB_BASE_REF:-${{ github.event.before }}}"
|
||||
#
|
||||
# Root fix (mc#917): Gitea Actions does not expose github.event.before
|
||||
# as a ${{ }} template-expression that resolves in shell scripts for
|
||||
# push events (it becomes empty string). The env var GITHUB_EVENT_BEFORE
|
||||
# IS set by the runner for push events. Guard git cat-file with
|
||||
# `timeout 30` to prevent indefinite hangs on malformed BASE values.
|
||||
if [ "${{ github.event_name }}" = "pull_request" ] && [ -n "${{ github.event.pull_request.base.sha }}" ]; then
|
||||
BASE="${{ github.event.pull_request.base.sha }}"
|
||||
else
|
||||
BASE="${GITHUB_EVENT_BEFORE:-}"
|
||||
fi
|
||||
if [ -z "$BASE" ] || echo "$BASE" | grep -qE '^0+$'; then
|
||||
# New branch or no previous SHA: treat as wheel-relevant.
|
||||
echo "wheel=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
if ! git cat-file -e "$BASE" 2>/dev/null; then
|
||||
if ! timeout 30 git cat-file -e "$BASE" 2>/dev/null; then
|
||||
git fetch --depth=1 origin "$BASE" 2>/dev/null || true
|
||||
fi
|
||||
if ! git cat-file -e "$BASE" 2>/dev/null; then
|
||||
if ! timeout 30 git cat-file -e "$BASE" 2>/dev/null; then
|
||||
echo "::notice::BASE=$BASE not in local clone (shallow fetch or pruned ref)"
|
||||
echo "wheel=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
+148
-7
@@ -12,12 +12,14 @@ Environment variables (set by the workspace container):
|
||||
PLATFORM_URL — platform API base URL (e.g. http://platform:8080)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import stat
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
# Top-level (not inside main()) so the wheel rewriter expands this to
|
||||
@@ -765,24 +767,163 @@ async def main(): # pragma: no cover
|
||||
break
|
||||
|
||||
|
||||
def cli_main() -> None: # pragma: no cover
|
||||
"""Synchronous wrapper around the async MCP stdio loop.
|
||||
# --- HTTP/SSE Transport (for Hermes runtime) ---
|
||||
|
||||
# Per-connection pending request queue.
|
||||
# Maps connection-id → asyncio.Queue of JSON-RPC responses.
|
||||
_http_connection_queues: dict[str, asyncio.Queue] = {}
|
||||
_http_connection_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def _handle_http_mcp(request) -> dict | None:
|
||||
"""Handle an incoming JSON-RPC request over HTTP. Returns the JSON-RPC response dict, or None for notifications."""
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return {"jsonrpc": "2.0", "id": None, "error": {"code": -32700, "message": "Parse error"}}
|
||||
|
||||
req_id = body.get("id")
|
||||
method = body.get("method", "")
|
||||
|
||||
if method == "initialize":
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": _build_initialize_result(),
|
||||
}
|
||||
elif method == "notifications/initialized":
|
||||
return None # No response needed
|
||||
elif method == "tools/list":
|
||||
return {"jsonrpc": "2.0", "id": req_id, "result": {"tools": TOOLS}}
|
||||
elif method == "tools/call":
|
||||
params = body.get("params", {})
|
||||
tool_name = params.get("name", "")
|
||||
tool_args = params.get("arguments", {})
|
||||
result_text = await handle_tool_call(tool_name, tool_args)
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {"content": [{"type": "text", "text": result_text}]},
|
||||
}
|
||||
else:
|
||||
return {"jsonrpc": "2.0", "id": req_id, "error": {"code": -32601, "message": f"Method not found: {method}"}}
|
||||
|
||||
|
||||
async def _run_http_server(port: int) -> None:
|
||||
"""Run MCP server over HTTP/SSE — compatible with Hermes MCP-native agents."""
|
||||
try:
|
||||
from starlette.applications import Starlette # noqa: F401
|
||||
from starlette.routing import Route # noqa: F401
|
||||
from starlette.responses import JSONResponse, Response, StreamingResponse # noqa: F401
|
||||
except ImportError:
|
||||
logger.error("HTTP transport requires starlette — install with: pip install starlette uvicorn")
|
||||
return
|
||||
|
||||
# Import uvicorn here so the stdio path (the common case) doesn't pay
|
||||
# the import cost if starlette/uvicorn aren't installed.
|
||||
import uvicorn # noqa: F401
|
||||
|
||||
_http_connection_queues.clear()
|
||||
|
||||
async def mcp_handler(request):
|
||||
"""POST /mcp — receive and process JSON-RPC requests."""
|
||||
conn_id = request.headers.get("x-mcp-conn-id", "default")
|
||||
response = await _handle_http_mcp(request)
|
||||
if response is None:
|
||||
return Response(status_code=202)
|
||||
async with _http_connection_lock:
|
||||
queue = _http_connection_queues.get(conn_id)
|
||||
if queue is not None and not queue.full():
|
||||
await queue.put(response)
|
||||
return Response(status_code=202)
|
||||
# No SSE subscriber — return JSON directly
|
||||
return JSONResponse(response)
|
||||
|
||||
async def sse_handler(request):
|
||||
"""GET /mcp/stream — SSE stream for push-based responses."""
|
||||
conn_id = str(uuid.uuid4())
|
||||
queue: asyncio.Queue = asyncio.Queue(maxsize=100)
|
||||
async with _http_connection_lock:
|
||||
_http_connection_queues[conn_id] = queue
|
||||
|
||||
async def event_stream():
|
||||
yield f"event: connected\ndata: {json.dumps({'conn_id': conn_id})}\n\n"
|
||||
try:
|
||||
while True:
|
||||
response = await asyncio.wait_for(queue.get(), timeout=300)
|
||||
yield f"event: message\ndata: {json.dumps(response)}\n\n"
|
||||
if queue.empty():
|
||||
yield "event: heartbeat\ndata: null\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
async with _http_connection_lock:
|
||||
_http_connection_queues.pop(conn_id, None)
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
async def health_handler(_request):
|
||||
return JSONResponse({"ok": True, "transport": "http+sse", "port": port})
|
||||
|
||||
app = Starlette(
|
||||
routes=[
|
||||
Route("/mcp", mcp_handler, methods=["POST"]),
|
||||
Route("/mcp/stream", sse_handler, methods=["GET"]),
|
||||
Route("/health", health_handler),
|
||||
]
|
||||
)
|
||||
config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning")
|
||||
server = uvicorn.Server(config)
|
||||
logger.info(f"A2A MCP HTTP server listening on http://127.0.0.1:{port}/mcp")
|
||||
await server.serve()
|
||||
|
||||
|
||||
def cli_main(transport: str = "stdio", port: int = 9100) -> None: # pragma: no cover
|
||||
"""Synchronous wrapper — selects stdio or HTTP transport.
|
||||
|
||||
Called by ``mcp_cli.main`` (the ``molecule-mcp`` console-script
|
||||
entry point in scripts/build_runtime_package.py) AFTER env
|
||||
validation and the standalone register + heartbeat thread setup.
|
||||
Direct callers (in-container code that already validated env and
|
||||
runs heartbeat.py separately) can also invoke this — it's the
|
||||
smallest possible "run the MCP stdio JSON-RPC loop" surface.
|
||||
runs heartbeat.py separately) can also invoke this.
|
||||
|
||||
Wheel-smoke gates in scripts/wheel_smoke.py pin the importability
|
||||
of this name (alongside ``mcp_cli.main``) so a silent rename can't
|
||||
break every external-runtime operator's MCP install — the 0.1.16
|
||||
``main_sync`` rename incident is the cautionary precedent.
|
||||
|
||||
Args:
|
||||
transport: "stdio" (default) or "http" (HTTP+SSE for Hermes).
|
||||
port: TCP port for HTTP transport (default 9100).
|
||||
"""
|
||||
_assert_stdio_is_pipe_compatible()
|
||||
asyncio.run(main())
|
||||
if transport == "http":
|
||||
asyncio.run(_run_http_server(port))
|
||||
else:
|
||||
_assert_stdio_is_pipe_compatible()
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
cli_main()
|
||||
parser = argparse.ArgumentParser(description="A2A MCP Server")
|
||||
parser.add_argument(
|
||||
"--transport",
|
||||
default="stdio",
|
||||
choices=["stdio", "http"],
|
||||
help="Transport mode: stdio (default) or http (HTTP+SSE for Hermes)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=9100,
|
||||
help="TCP port for HTTP transport (default 9100)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
cli_main(transport=args.transport, port=args.port)
|
||||
|
||||
@@ -0,0 +1,671 @@
|
||||
"""Tests for the HTTP/SSE transport of a2a_mcp_server.
|
||||
|
||||
Covers:
|
||||
- _handle_http_mcp: JSON-RPC request parsing and routing
|
||||
- Starlette app routes: POST /mcp, GET /mcp/stream, GET /health
|
||||
- cli_main argparse: --transport and --port flags
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _DummyRequest:
|
||||
"""Minimal request duck-type for _handle_http_mcp."""
|
||||
|
||||
def __init__(self, body_json: dict, headers: dict | None = None):
|
||||
self._body = body_json
|
||||
self.headers = headers or {}
|
||||
|
||||
async def json(self) -> dict:
|
||||
return self._body
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_http_mcp — unit tests (no I/O)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_initialize():
|
||||
"""initialize method returns protocol version, capabilities, and server info."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest({"jsonrpc": "2.0", "id": 42, "method": "initialize", "params": {}})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["jsonrpc"] == "2.0"
|
||||
assert resp["id"] == 42
|
||||
assert "protocolVersion" in resp["result"]
|
||||
assert "capabilities" in resp["result"]
|
||||
assert resp["result"]["serverInfo"]["name"] == "molecule"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_notifications_initialized_returns_none():
|
||||
"""notifications/initialized is a notification (no response needed)."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest({"jsonrpc": "2.0", "method": "notifications/initialized"})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_list():
|
||||
"""tools/list returns the TOOLS schema."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest({"jsonrpc": "2.0", "id": 7, "method": "tools/list"})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["jsonrpc"] == "2.0"
|
||||
assert resp["id"] == 7
|
||||
assert "tools" in resp["result"]
|
||||
assert isinstance(resp["result"]["tools"], list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_unknown_method_returns_error():
|
||||
"""Unknown method returns -32601 Method not found."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest({"jsonrpc": "2.0", "id": 3, "method": "foobar", "params": {}})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["jsonrpc"] == "2.0"
|
||||
assert resp["id"] == 3
|
||||
assert resp["error"]["code"] == -32601
|
||||
assert "Method not found" in resp["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_malformed_json_returns_parse_error():
|
||||
"""Request with bad JSON returns -32700 parse error."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest.__new__(_DummyRequest)
|
||||
req.headers = {}
|
||||
req.json = AsyncMock(side_effect=ValueError("bad json"))
|
||||
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["error"]["code"] == -32700
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_with_get_workspace_info():
|
||||
"""tools/call for get_workspace_info returns workspace info (mocked platform call)."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
with patch("a2a_mcp_server.tool_get_workspace_info", AsyncMock(return_value="mocked info")):
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 9,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "get_workspace_info", "arguments": {}},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["jsonrpc"] == "2.0"
|
||||
assert resp["id"] == 9
|
||||
assert resp["result"]["content"][0]["text"] == "mocked info"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_unknown_tool():
|
||||
"""tools/call for an unknown tool returns the handle_tool_call error text."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 11,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "not_a_real_tool", "arguments": {}},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["jsonrpc"] == "2.0"
|
||||
assert resp["id"] == 11
|
||||
assert "Unknown tool" in resp["result"]["content"][0]["text"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Starlette app — integration tests with TestClient
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _clear_http_globals():
|
||||
"""Reset module-level HTTP state before and after each test."""
|
||||
import a2a_mcp_server
|
||||
|
||||
# Save and restore globals
|
||||
saved_queues = a2a_mcp_server._http_connection_queues.copy()
|
||||
saved_lock = a2a_mcp_server._http_connection_lock
|
||||
a2a_mcp_server._http_connection_queues.clear()
|
||||
yield
|
||||
# Restore
|
||||
a2a_mcp_server._http_connection_queues = saved_queues
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _register_sse_queue():
|
||||
"""Register a queue for SSE push delivery (synchronous — callable from tests)."""
|
||||
conn_id = str(uuid.uuid4())
|
||||
queue = asyncio.Queue(maxsize=100)
|
||||
import a2a_mcp_server
|
||||
a2a_mcp_server._http_connection_queues[conn_id] = queue
|
||||
return conn_id, queue
|
||||
|
||||
|
||||
def _build_test_app(port: int = 9100):
|
||||
"""Build the Starlette app for testing without starting a real server.
|
||||
|
||||
Mirrors the app construction inside _run_http_server, but returns
|
||||
the app directly so TestClient can drive it without binding a port.
|
||||
"""
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Route
|
||||
|
||||
import a2a_mcp_server
|
||||
|
||||
async def mcp_handler(request):
|
||||
conn_id = request.headers.get("x-mcp-conn-id", "default")
|
||||
response = await a2a_mcp_server._handle_http_mcp(request)
|
||||
if response is None:
|
||||
from starlette.responses import Response
|
||||
return Response(status_code=202)
|
||||
async with a2a_mcp_server._http_connection_lock:
|
||||
queue = a2a_mcp_server._http_connection_queues.get(conn_id)
|
||||
if queue is not None and not queue.full():
|
||||
await queue.put(response)
|
||||
from starlette.responses import Response
|
||||
return Response(status_code=202)
|
||||
from starlette.responses import JSONResponse
|
||||
return JSONResponse(response)
|
||||
|
||||
async def sse_handler(request):
|
||||
conn_id, queue = _register_sse_queue()
|
||||
|
||||
import asyncio as _asyncio
|
||||
|
||||
async def event_stream():
|
||||
import json as _json
|
||||
yield f"event: connected\ndata: {_json.dumps({'conn_id': conn_id})}\n\n"
|
||||
try:
|
||||
while True:
|
||||
response = await _asyncio.wait_for(queue.get(), timeout=300)
|
||||
import json as _json
|
||||
yield f"event: message\ndata: {_json.dumps(response)}\n\n"
|
||||
if queue.empty():
|
||||
yield "event: heartbeat\ndata: null\n\n"
|
||||
except _asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
async with a2a_mcp_server._http_connection_lock:
|
||||
a2a_mcp_server._http_connection_queues.pop(conn_id, None)
|
||||
|
||||
from starlette.responses import StreamingResponse
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
async def health_handler(_request):
|
||||
from starlette.responses import JSONResponse
|
||||
return JSONResponse({"ok": True, "transport": "http+sse", "port": port})
|
||||
|
||||
return Starlette(
|
||||
routes=[
|
||||
Route("/mcp", mcp_handler, methods=["POST"]),
|
||||
Route("/mcp/stream", sse_handler, methods=["GET"]),
|
||||
Route("/health", health_handler),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TestHTTPAppRoutes:
|
||||
"""Integration tests using Starlette TestClient against the HTTP app.
|
||||
|
||||
Starlette TestClient uses the ASGI interface directly (no real HTTP server
|
||||
or uvicorn needed), so no uvicorn mock is required.
|
||||
"""
|
||||
|
||||
def test_health_returns_ok_and_transport(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app(port=9100)
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/health")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is True
|
||||
assert data["transport"] == "http+sse"
|
||||
assert data["port"] == 9100
|
||||
|
||||
def test_health_accepts_different_port(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app(port=9999)
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/health")
|
||||
|
||||
assert resp.json()["port"] == 9999
|
||||
|
||||
def test_mcp_post_initialize(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.post("/mcp", json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {},
|
||||
})
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == 1
|
||||
assert "protocolVersion" in data["result"]
|
||||
|
||||
def test_mcp_post_tools_list(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.post("/mcp", json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/list",
|
||||
"params": {},
|
||||
})
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "tools" in data["result"]
|
||||
assert len(data["result"]["tools"]) > 0
|
||||
|
||||
def test_mcp_post_notifications_initialized_returns_202(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.post("/mcp", json={
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized",
|
||||
})
|
||||
|
||||
# Notifications return 202 with no body
|
||||
assert resp.status_code == 202
|
||||
|
||||
def test_mcp_post_unknown_method_returns_200_with_error(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.post("/mcp", json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 5,
|
||||
"method": "no_such_method",
|
||||
"params": {},
|
||||
})
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["error"]["code"] == -32601
|
||||
|
||||
def test_mcp_post_malformed_json_returns_error(self, _clear_http_globals):
|
||||
"""Malformed JSON body returns a JSON-RPC parse-error response (HTTP 200)."""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app, raise_server_exceptions=False) as client:
|
||||
resp = client.post(
|
||||
"/mcp",
|
||||
content=b"not json at all",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
# _handle_http_mcp catches ValueError from request.json() and returns
|
||||
# a JSON-RPC parse-error response with HTTP 200.
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["error"]["code"] == -32700
|
||||
assert "Parse error" in resp.json()["error"]["message"]
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_sse_stream_populates_queue(self, _clear_http_globals):
|
||||
"""_register_sse_queue adds a queue to _http_connection_queues before any async work."""
|
||||
import a2a_mcp_server
|
||||
|
||||
conn_id, queue = _register_sse_queue()
|
||||
|
||||
# The queue is registered synchronously — no await needed, no cleanup ran yet.
|
||||
assert conn_id in a2a_mcp_server._http_connection_queues
|
||||
assert len(conn_id) == 36 # valid UUID format
|
||||
assert not queue.full()
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_sse_queue_delivers_response(self, _clear_http_globals):
|
||||
"""POST /mcp with x-mcp-conn-id routes response into the SSE queue."""
|
||||
import uuid
|
||||
|
||||
import a2a_mcp_server
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
# Pre-register an SSE queue to simulate an active SSE subscriber
|
||||
conn_id = str(uuid.uuid4())
|
||||
queue: asyncio.Queue = asyncio.Queue(maxsize=100)
|
||||
async with a2a_mcp_server._http_connection_lock:
|
||||
a2a_mcp_server._http_connection_queues[conn_id] = queue
|
||||
|
||||
# POST a tools/call with the conn_id header
|
||||
with TestClient(_build_test_app()) as client:
|
||||
with patch("a2a_mcp_server.tool_get_workspace_info", AsyncMock(return_value="test-ws-info")):
|
||||
resp = client.post(
|
||||
"/mcp",
|
||||
headers={"x-mcp-conn-id": conn_id},
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 99,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "get_workspace_info", "arguments": {}},
|
||||
},
|
||||
)
|
||||
|
||||
# The handler returns 202 because the response was queued for SSE delivery
|
||||
assert resp.status_code == 202
|
||||
|
||||
# Verify the response was placed in the SSE queue
|
||||
result = await asyncio.wait_for(queue.get(), timeout=2.0)
|
||||
assert result["id"] == 99
|
||||
assert result["result"]["content"][0]["text"] == "test-ws-info"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_tool_call — remaining tool branches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_send_message_to_user_with_mixed_attachments():
|
||||
"""attachments with non-string elements are filtered; the list branch is exercised."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
with patch("a2a_mcp_server.tool_send_message_to_user", AsyncMock(return_value="sent ok")) as mock_fn:
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 21,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "send_message_to_user",
|
||||
"arguments": {
|
||||
"message": "hello",
|
||||
# Mixed types: list contains a dict (non-string) and an empty string
|
||||
"attachments": [{"url": "http://x"}, "", "valid.zip", None],
|
||||
},
|
||||
},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["result"]["content"][0]["text"] == "sent ok"
|
||||
# Only string, non-empty values passed through
|
||||
mock_fn.assert_called_once()
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs["attachments"] == ["valid.zip"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_wait_for_message():
|
||||
"""wait_for_message is dispatched and returns the wrapped result."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
with patch("a2a_mcp_server.tool_wait_for_message", AsyncMock(return_value="no messages")):
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 22,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "wait_for_message", "arguments": {"timeout_secs": 5.0}},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["result"]["content"][0]["text"] == "no messages"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_inbox_peek():
|
||||
"""inbox_peek is dispatched with the limit argument."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
with patch("a2a_mcp_server.tool_inbox_peek", AsyncMock(return_value="2 items")):
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 23,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "inbox_peek", "arguments": {"limit": 5}},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["result"]["content"][0]["text"] == "2 items"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_inbox_pop():
|
||||
"""inbox_pop is dispatched with the activity_id argument."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
with patch("a2a_mcp_server.tool_inbox_pop", AsyncMock(return_value="acked")):
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 24,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "inbox_pop", "arguments": {"activity_id": "abc-123"}},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["result"]["content"][0]["text"] == "acked"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_chat_history():
|
||||
"""chat_history is dispatched with peer_id, limit, and before_ts arguments."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
with patch("a2a_mcp_server.tool_chat_history", AsyncMock(return_value="history")):
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 25,
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "chat_history",
|
||||
"arguments": {"peer_id": "ws-peer-1", "limit": 10, "before_ts": ""},
|
||||
},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["result"]["content"][0]["text"] == "history"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cli_main argparse — unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_mcp_post_falls_back_to_json_when_sse_queue_is_full(_clear_http_globals):
|
||||
"""When the SSE queue is full (>100 pending), the handler returns JSON directly."""
|
||||
import a2a_mcp_server
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
# Pre-register a queue and fill it to capacity
|
||||
conn_id = str(uuid.uuid4())
|
||||
queue: asyncio.Queue = asyncio.Queue(maxsize=2) # small queue for testing
|
||||
|
||||
async def _setup():
|
||||
async with a2a_mcp_server._http_connection_lock:
|
||||
a2a_mcp_server._http_connection_queues[conn_id] = queue
|
||||
queue.put_nowait({"id": 1})
|
||||
queue.put_nowait({"id": 2})
|
||||
|
||||
_sync_run(_setup())
|
||||
assert queue.full()
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.post(
|
||||
"/mcp",
|
||||
headers={"x-mcp-conn-id": conn_id},
|
||||
json={"jsonrpc": "2.0", "id": 99, "method": "initialize", "params": {}},
|
||||
)
|
||||
|
||||
# With a full queue, the handler returns the response as JSON (not 202)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["id"] == 99
|
||||
assert "result" in resp.json()
|
||||
|
||||
|
||||
def _sync_run(coro):
|
||||
"""Run a coroutine synchronously for test isolation (no real event loop needed)."""
|
||||
try:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
|
||||
def test_cli_main_transport_stdio_calls_main(monkeypatch):
|
||||
"""cli_main(transport='stdio') calls asyncio.run(main) without HTTP."""
|
||||
import a2a_mcp_server
|
||||
|
||||
run_calls: list = []
|
||||
|
||||
async def fake_main():
|
||||
run_calls.append("called")
|
||||
|
||||
monkeypatch.setattr(a2a_mcp_server, "main", fake_main)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
|
||||
a2a_mcp_server.cli_main(transport="stdio", port=9100)
|
||||
|
||||
assert "called" in run_calls
|
||||
|
||||
|
||||
def test_cli_main_transport_http_calls_run_http_server(monkeypatch):
|
||||
"""cli_main(transport='http') calls _run_http_server without stdio."""
|
||||
import a2a_mcp_server
|
||||
|
||||
run_http_calls = []
|
||||
|
||||
async def fake_run_http(port):
|
||||
run_http_calls.append(port)
|
||||
|
||||
# asyncio.run must execute the coroutine for _run_http_server to be called
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_run_http_server", fake_run_http)
|
||||
# stdio path must not be entered
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
|
||||
a2a_mcp_server.cli_main(transport="http", port=9102)
|
||||
|
||||
assert run_http_calls == [9102]
|
||||
|
||||
|
||||
def test_cli_main_http_skips_stdio_check(monkeypatch):
|
||||
"""When transport=http, _assert_stdio_is_pipe_compatible must NOT be called."""
|
||||
import a2a_mcp_server
|
||||
|
||||
called = []
|
||||
|
||||
def fake_assert():
|
||||
called.append("assert_called")
|
||||
|
||||
# Patch on the module object directly
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", fake_assert)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", lambda fn: None)
|
||||
|
||||
a2a_mcp_server.cli_main(transport="http", port=9100)
|
||||
|
||||
assert "assert_called" not in called
|
||||
|
||||
|
||||
def test_cli_main_default_transport_is_stdio(monkeypatch):
|
||||
"""cli_main() with no args defaults to stdio transport."""
|
||||
import a2a_mcp_server
|
||||
|
||||
called_as: list = []
|
||||
|
||||
async def fake_main():
|
||||
called_as.append("called")
|
||||
|
||||
monkeypatch.setattr(a2a_mcp_server, "main", fake_main)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
|
||||
a2a_mcp_server.cli_main() # No args — defaults to stdio
|
||||
|
||||
assert "called" in called_as
|
||||
|
||||
|
||||
def test_cli_main_main_raises_propagates(monkeypatch):
|
||||
"""If main() raises, cli_main() re-raises (doesn't swallow)."""
|
||||
import a2a_mcp_server
|
||||
|
||||
async def fake_main():
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(a2a_mcp_server, "main", fake_main)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
a2a_mcp_server.cli_main(transport="stdio")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# uvicorn/starlette lazy-import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_run_http_server_is_coroutine_function():
|
||||
"""_run_http_server is a coroutine function accepting a port argument."""
|
||||
import inspect
|
||||
from a2a_mcp_server import _run_http_server
|
||||
|
||||
assert inspect.iscoroutinefunction(_run_http_server)
|
||||
|
||||
|
||||
def test_run_http_server_signature_port_int():
|
||||
"""_run_http_server accepts port as int."""
|
||||
import inspect
|
||||
from a2a_mcp_server import _run_http_server
|
||||
|
||||
sig = inspect.signature(_run_http_server)
|
||||
assert "port" in sig.parameters
|
||||
assert sig.parameters["port"].annotation == int
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Test coverage for builtin_tools.security._redact_secrets().
|
||||
|
||||
Issue #834 (C2): commit_memory must not persist API keys verbatim.
|
||||
|
||||
Pre-commit hook blocks bare secret-like strings (ghp_, sk-ant-, etc.) to prevent
|
||||
accidental commits of real credentials. These tests focus on the functional
|
||||
behaviour of the redaction logic: idempotency, contextual keyword=value patterns,
|
||||
boundary cases, and mixed content — without triggering the hook's length thresholds.
|
||||
The pre-commit hook itself is the primary guard for bare-pattern detection.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from builtin_tools.security import REDACTED, _redact_secrets
|
||||
|
||||
|
||||
class TestRedactContextual:
|
||||
"""Keyword=value patterns with high-entropy values (under pre-commit threshold)."""
|
||||
|
||||
def test_api_key_contextual(self):
|
||||
"""api_key=X where X ≥ 40 base64 chars → value replaced, keyword preserved."""
|
||||
value = "A" * 40
|
||||
assert _redact_secrets(f"api_key={value}") == f"api_key={REDACTED}"
|
||||
|
||||
def test_keyword_contextual(self):
|
||||
"""Generic 'key=' also matches."""
|
||||
value = "B" * 45
|
||||
assert _redact_secrets(f"key={value}") == f"key={REDACTED}"
|
||||
|
||||
def test_secret_contextual(self):
|
||||
value = "C" * 50
|
||||
assert _redact_secrets(f"secret= {value}") == f"secret= {REDACTED}"
|
||||
|
||||
def test_token_contextual(self):
|
||||
value = "D" * 40
|
||||
assert _redact_secrets(f"token={value}") == f"token={REDACTED}"
|
||||
|
||||
def test_password_contextual(self):
|
||||
value = "E" * 50
|
||||
assert _redact_secrets(f"password={value}") == f"password={REDACTED}"
|
||||
|
||||
def test_keyword_spacing_tolerated(self):
|
||||
"""Spaces around = are tolerated by the pattern."""
|
||||
value = "F" * 40
|
||||
assert _redact_secrets(f"key = {value}") == f"key = {REDACTED}"
|
||||
|
||||
def test_contextual_too_short_not_redacted(self):
|
||||
"""Value shorter than 40 chars is not redacted."""
|
||||
short = "A" * 39
|
||||
assert _redact_secrets(f"api_key={short}") == f"api_key={short}"
|
||||
|
||||
def test_case_insensitive_keyword(self):
|
||||
"""Keyword matching is case-insensitive."""
|
||||
value = "G" * 40
|
||||
assert _redact_secrets(f"API_KEY={value}") == f"API_KEY={REDACTED}"
|
||||
assert _redact_secrets(f"Token={value}") == f"Token={REDACTED}"
|
||||
assert _redact_secrets(f"SECRET={value}") == f"SECRET={REDACTED}"
|
||||
|
||||
def test_boundary_preserved(self):
|
||||
"""Contextual pattern preserves the keyword; only value is replaced."""
|
||||
value = "H" * 40
|
||||
result = _redact_secrets(f"api_key={value}")
|
||||
assert result.startswith("api_key=")
|
||||
assert result.endswith(REDACTED)
|
||||
assert result == f"api_key={REDACTED}"
|
||||
|
||||
def test_base64_chars_in_value(self):
|
||||
"""Base64 alphabet chars (/ +) in value are covered by the charset."""
|
||||
# 40-char string with base64 chars
|
||||
value = "A" * 20 + "/+" + "A" * 18
|
||||
result = _redact_secrets(f"api_key={value}")
|
||||
assert result == f"api_key={REDACTED}"
|
||||
|
||||
|
||||
class TestRedactEdgeCases:
|
||||
"""Non-secret strings, idempotency, and boundary conditions."""
|
||||
|
||||
def test_idempotent(self):
|
||||
"""Calling redaction twice produces the same result."""
|
||||
text = f"token={'A' * 40}"
|
||||
first = _redact_secrets(text)
|
||||
second = _redact_secrets(first)
|
||||
assert second == first
|
||||
assert REDACTED in first
|
||||
|
||||
def test_already_redacted_string(self):
|
||||
"""The [REDACTED] sentinel itself is not matched by any pattern."""
|
||||
assert _redact_secrets(f"see {REDACTED} here") == f"see {REDACTED} here"
|
||||
|
||||
def test_no_match_passthrough(self):
|
||||
"""Normal prose passes through unchanged."""
|
||||
assert _redact_secrets("The answer is 42.") == "The answer is 42."
|
||||
assert _redact_secrets("Hello, world!") == "Hello, world!"
|
||||
assert _redact_secrets("api_key short") == "api_key short"
|
||||
assert _redact_secrets("") == ""
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _redact_secrets("") == ""
|
||||
|
||||
def test_short_value_not_secret(self):
|
||||
"""A short string after a keyword= prefix is not a secret."""
|
||||
assert _redact_secrets("token=short") == "token=short"
|
||||
|
||||
def test_mixed_content(self):
|
||||
"""Real text with a secret-like prefix → only the secret is redacted."""
|
||||
value = "A" * 40
|
||||
result = _redact_secrets(f"found secret: api_key={value} in config")
|
||||
assert result == f"found secret: api_key={REDACTED} in config"
|
||||
Reference in New Issue
Block a user