forked from molecule-ai/molecule-core
Compare commits
68 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 184ce7ae4e | |||
| ff75aeb43e | |||
| 81cf0cbf98 | |||
| 412dec0d87 | |||
| 39931acd9c | |||
| 6f19b88fa7 | |||
| 83454e5efd | |||
| 8254bedf30 | |||
| ec72f199e6 | |||
| ae22a55675 | |||
| 08648bf4b1 | |||
| eec4ea2e7d | |||
| 6201d12533 | |||
| 81e83c05b7 | |||
| 5b5eacbb29 | |||
| c8fca1467e | |||
| 7c8b81c6eb | |||
| fc1c45789e | |||
| e3a18ed8e8 | |||
| 9f551319d2 | |||
| 1052f8bdb0 | |||
| 30fb507165 | |||
| 77e9a965ac | |||
| 5334d60de4 | |||
| d6c0227e3f | |||
| 27db090d3d | |||
| 0f25f6de97 | |||
| 9991057ad1 | |||
| b89a49ec93 | |||
| f5613bf099 | |||
| 9bd2a2c45f | |||
| a489ee1a7c | |||
| c79ba05ed5 | |||
| 6470e5f41b | |||
| aa560c0314 | |||
| 7644e82f2f | |||
| 33fabdf483 | |||
| abba16beb4 | |||
| 9c752e0673 | |||
| be18b9c8f9 | |||
| 2cb1b26512 | |||
| 48d1945269 | |||
| a04a49f7aa | |||
| bbec4cfcfb | |||
| 19c25a9278 | |||
| e50799bc29 | |||
| 07839580a0 | |||
| 2227a14b1e | |||
| e72f9ad107 | |||
| 17aec22f9b | |||
| 8388144098 | |||
| a327d207da | |||
| afe5a0cfe9 | |||
| 529c3f3922 | |||
| c778b62202 | |||
| d80bffe3e3 | |||
| 0c461eb9f1 | |||
| 86015412eb | |||
| f81813f708 | |||
| 58253f0673 | |||
| 28ef75d25e | |||
| 243f9bc2b1 | |||
| 43bf94a07c | |||
| 55f5c0b0ff | |||
| 86fdaad111 | |||
| 6125700c39 | |||
| 89ee8e4d04 | |||
| 26e2e97006 |
@@ -272,6 +272,14 @@ jobs:
|
||||
find tests/e2e infra/scripts -type f -name '*.sh' -print0 \
|
||||
| xargs -0 shellcheck --severity=warning
|
||||
|
||||
- if: needs.changes.outputs.scripts == 'true'
|
||||
name: Lint cleanup-trap hygiene (RFC #2873)
|
||||
# Asserts every shell E2E test that calls `mktemp` also installs
|
||||
# an EXIT trap. Catches the /tmp-leak class — a missing trap
|
||||
# silently leaks scratch into CI runners (~10-100KB per run).
|
||||
# See tests/e2e/lint_cleanup_traps.sh for the rule + fix pattern.
|
||||
run: bash tests/e2e/lint_cleanup_traps.sh
|
||||
|
||||
- if: needs.changes.outputs.scripts == 'true'
|
||||
name: Run E2E bash unit tests (no live infra)
|
||||
# Pure-bash unit tests for E2E helper libs (lib/*.sh). These pin
|
||||
|
||||
@@ -48,16 +48,21 @@ export function EmptyState() {
|
||||
});
|
||||
|
||||
// "Create blank" bypasses templates entirely — no preflight, no
|
||||
// modal, just POST /workspaces with a default name and tier.
|
||||
// Deliberately NOT routed through useTemplateDeploy because it
|
||||
// has no `template.id` to deploy against.
|
||||
// modal, just POST /workspaces with a default name. Deliberately
|
||||
// NOT routed through useTemplateDeploy because it has no
|
||||
// `template.id` to deploy against.
|
||||
//
|
||||
// tier is omitted so the backend picks a SaaS-aware default
|
||||
// (T4 on SaaS, T3 on self-hosted — see WorkspaceHandler.DefaultTier).
|
||||
// The previous hardcoded `tier: 2` shipped every fresh-tenant agent
|
||||
// at Standard regardless of host, which surprised SaaS users whose
|
||||
// CreateWorkspaceDialog already defaults to T4.
|
||||
const createBlank = async () => {
|
||||
setBlankCreating(true);
|
||||
setBlankError(null);
|
||||
try {
|
||||
const ws = await api.post<{ id: string }>("/workspaces", {
|
||||
name: "My First Agent",
|
||||
tier: 2,
|
||||
canvas: firstDeployCoords(),
|
||||
});
|
||||
handleDeployed(ws.id);
|
||||
|
||||
@@ -286,6 +286,14 @@ function MyChatPanel({ workspaceId, data }: Props) {
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [confirmRestart, setConfirmRestart] = useState(false);
|
||||
const bottomRef = useRef<HTMLDivElement>(null);
|
||||
// First-mount scroll-to-bottom needs `behavior: "instant"` — long
|
||||
// conversations smooth-animate for ~300ms which any concurrent
|
||||
// re-render can interrupt, leaving the user stuck mid-conversation
|
||||
// when the chat tab opens. Subsequent appends (new agent messages)
|
||||
// keep `smooth` for the visual "landing" feel. Flipped the first
|
||||
// time messages.length goes positive, so a workspace switch (which
|
||||
// remounts ChatTab) gets a fresh instant jump too.
|
||||
const hasInitialScrollRef = useRef(false);
|
||||
// Lazy-load older history on scroll-up.
|
||||
// - containerRef = the scrollable messages viewport
|
||||
// - topRef = sentinel above the messages list; IO observes it
|
||||
@@ -545,6 +553,15 @@ function MyChatPanel({ workspaceId, data }: Props) {
|
||||
scrollAnchorRef.current = null;
|
||||
return;
|
||||
}
|
||||
// Instant on first arrival of messages — smooth-scroll on a long
|
||||
// conversation gets interrupted by concurrent renders and leaves
|
||||
// the user stuck in the middle. After the first jump, subsequent
|
||||
// appends animate as before.
|
||||
if (!hasInitialScrollRef.current && messages.length > 0) {
|
||||
hasInitialScrollRef.current = true;
|
||||
bottomRef.current?.scrollIntoView({ behavior: "instant" as ScrollBehavior });
|
||||
return;
|
||||
}
|
||||
bottomRef.current?.scrollIntoView({ behavior: "smooth" });
|
||||
}, [messages]);
|
||||
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
# Team Expansion (Recursive Workspaces)
|
||||
|
||||
When a workspace is expanded into a team, it gains sub-workspaces while its own agent remains as the **team lead** (coordinator). This is recursive — sub-workspaces can themselves be expanded into teams, infinitely deep.
|
||||
|
||||
## How It Works
|
||||
|
||||
When Developer PM is expanded into a team:
|
||||
|
||||
```
|
||||
Business Core
|
||||
|
|
||||
+-- Developer PM (agent stays, becomes coordinator)
|
||||
|
|
||||
+-- Frontend Agent (sub-workspace, private scope)
|
||||
+-- Backend Agent (sub-workspace, private scope)
|
||||
+-- QA Agent (sub-workspace, private scope)
|
||||
```
|
||||
|
||||
- Developer PM's agent **still exists** and acts as coordinator
|
||||
- Developer PM receives incoming A2A messages from Business Core
|
||||
- Developer PM's agent decides how to delegate to sub-workspaces
|
||||
- Sub-workspaces talk to Developer PM and to each other (same level)
|
||||
- Sub-workspaces **cannot** talk to Business Core or any workspace outside the team
|
||||
|
||||
## Communication Rules
|
||||
|
||||
| Direction | Allowed? | Example |
|
||||
|-----------|----------|---------|
|
||||
| Parent level -> team lead | Yes | Business Core -> Developer PM |
|
||||
| Team lead -> sub-workspaces | Yes | Developer PM -> Frontend Agent |
|
||||
| Sub-workspace -> team lead | Yes | Frontend Agent -> Developer PM |
|
||||
| Sub-workspace <-> sibling | Yes | Frontend Agent <-> Backend Agent |
|
||||
| Outside -> sub-workspace directly | No (403) | Business Core -> Frontend Agent |
|
||||
| Sub-workspace -> outside directly | No | Frontend Agent -> Business Core |
|
||||
|
||||
The team lead (Developer PM) is the **only** bridge between the team's internal world and the outside.
|
||||
|
||||
## Scoped Registry
|
||||
|
||||
Sub-workspaces register in the platform registry but with a **private scope**. The registry knows about them but enforces access control.
|
||||
|
||||
```
|
||||
Registry:
|
||||
Business Core :8001 scope: public
|
||||
Developer PM :8002 scope: public
|
||||
Frontend Agent :8010 scope: private, parent=Developer PM
|
||||
Backend Agent :8011 scope: private, parent=Developer PM
|
||||
QA Agent :8012 scope: private, parent=Developer PM
|
||||
```
|
||||
|
||||
- The platform can always discover any workspace (for provisioning, monitoring)
|
||||
- The parent workspace can discover its sub-workspaces
|
||||
- Sub-workspaces can discover their siblings (same parent)
|
||||
- Outside workspaces get a **403 Forbidden** if they try to discover a private sub-workspace
|
||||
|
||||
## How to Expand
|
||||
|
||||
Expansion is triggered via `POST /workspaces/:id/expand`. The platform reads the `sub_workspaces` list from the workspace's config and provisions each one. On the canvas, users right-click a workspace node and select "Expand into team."
|
||||
|
||||
Collapsing is the inverse: `POST /workspaces/:id/collapse`. Sub-workspaces are stopped and removed.
|
||||
|
||||
## What Happens on Expansion
|
||||
|
||||
When Developer PM is expanded into a team, the hierarchy changes but the outside view doesn't. Business Core's parent/child relationship to Developer PM is unaffected — Developer PM still responds to the same A2A endpoint.
|
||||
|
||||
The events fired:
|
||||
- `WORKSPACE_EXPANDED` with the new `sub_workspace_ids` in the payload
|
||||
- `WORKSPACE_PROVISIONING` for each new sub-workspace
|
||||
- `WORKSPACE_ONLINE` for each sub-workspace as they come up
|
||||
|
||||
Communication rules are automatically derived from the new hierarchy — no manual wiring needed.
|
||||
|
||||
## Canvas Behavior
|
||||
|
||||
- Children render as embedded mini-cards (`TeamMemberChip`) inside the parent node, not as separate canvas nodes
|
||||
- Each mini-card shows full status: gradient bar, name, tier badge, skills pills, active tasks, descendant count
|
||||
- **Recursive rendering** up to 3 levels deep (`MAX_NESTING_DEPTH = 3`) — sub-cards can contain their own "Team" sections
|
||||
- Parent node dynamically resizes: 210-280px (no children), 320-450px (children), 400-560px (grandchildren)
|
||||
- Eject button (sky-blue arrow icon) on hover extracts a child from the team
|
||||
- "Extract from Team" also available in the right-click context menu
|
||||
- Double-click a team node to zoom/fit to the parent area
|
||||
- The parent workspace node shows a badge with total descendant count
|
||||
|
||||
## Collapsing a Team
|
||||
|
||||
The inverse of expansion, triggered via `POST /workspaces/:id/collapse`:
|
||||
|
||||
1. Each sub-workspace agent wraps up current work and writes a handoff document to memory
|
||||
2. Sub-workspaces are stopped and removed
|
||||
3. The team lead's agent goes back to handling everything directly
|
||||
4. A `WORKSPACE_COLLAPSED` event fires
|
||||
|
||||
Sub-workspace memory is cleaned up based on backend (see [Memory — Cleanup](../architecture/memory.md#cleanup-on-workspace-deletion)).
|
||||
|
||||
## Deleting a Team Workspace
|
||||
|
||||
When a team workspace is deleted:
|
||||
1. Platform shows a warning listing all sub-workspaces that will be deleted
|
||||
2. User can **drag sub-workspaces out** of the team before confirming (promotes them to the parent level)
|
||||
3. On confirmation, cascade delete removes the parent and all remaining sub-workspaces
|
||||
4. `WORKSPACE_REMOVED` events fire for each deleted workspace
|
||||
|
||||
## Related Docs
|
||||
|
||||
- [Communication Rules](../api-protocol/communication-rules.md) — Full access control model
|
||||
- [Core Concepts](../product/core-concepts.md) — Workspace fundamentals
|
||||
- [System Prompt Structure](./system-prompt-structure.md) — How peer capabilities are injected
|
||||
- [Provisioner](../architecture/provisioner.md) — How sub-workspaces are deployed
|
||||
- [Registry & Heartbeat](../api-protocol/registry-and-heartbeat.md) — How registration works
|
||||
- [Event Log](../architecture/event-log.md) — Events fired during expansion
|
||||
- [Canvas UI](../frontend/canvas.md) — Visual behavior of teams
|
||||
@@ -41,8 +41,6 @@ Full contract: `docs/runbooks/admin-auth.md`.
|
||||
| GET | /admin/workspaces/:id/test-token | admin_test_token.go — mint a fresh bearer token for E2E scripts; returns 404 unless `MOLECULE_ENV != production` or `MOLECULE_ENABLE_TEST_TOKENS=1` |
|
||||
| GET/POST/DELETE | /admin/secrets[/:key] | secrets.go — legacy aliases for /settings/secrets |
|
||||
| WS | /workspaces/:id/terminal | terminal.go |
|
||||
| POST | /workspaces/:id/expand | team.go |
|
||||
| POST | /workspaces/:id/collapse | team.go |
|
||||
| POST/GET | /workspaces/:id/approvals | approvals.go |
|
||||
| POST | /workspaces/:id/approvals/:id/decide | approvals.go |
|
||||
| GET | /approvals/pending | approvals.go |
|
||||
|
||||
@@ -336,8 +336,6 @@ This same logic governs: A2A delegation, memory scope enforcement, activity visi
|
||||
|
||||
| Method | Endpoint | Purpose |
|
||||
|--------|----------|---------|
|
||||
| `POST` | `/workspaces/:id/expand` | Expand workspace into team (become coordinator) |
|
||||
| `POST` | `/workspaces/:id/collapse` | Collapse team back to single workspace |
|
||||
|
||||
### Files, Terminal, Templates, Bundles (8 endpoints)
|
||||
|
||||
|
||||
@@ -186,4 +186,3 @@ So the UI now exposes more operational failure state directly instead of silentl
|
||||
- [Quickstart](../quickstart.md)
|
||||
- [Platform API](../api-protocol/platform-api.md)
|
||||
- [Workspace Runtime](../agent-runtime/workspace-runtime.md)
|
||||
- [Team Expansion](../agent-runtime/team-expansion.md)
|
||||
|
||||
+1
-1
@@ -18,7 +18,7 @@ lands in the watch list with a colliding term, add a row here.
|
||||
| **plugin** | A directory under `plugins/` packaging one or more skills or an MCP server wrapper, installable per-workspace via `POST /workspaces/:id/plugins`. Governed by `plugin.yaml`. | **Langflow**: a visual UI node / component in a flowchart. **CrewAI**: a Python-importable callable registered as a capability. |
|
||||
| **agent** | A persistent containerized workspace running continuously — an identity with memory, a role, and a schedule. Not a one-shot invocation. | Most frameworks (AutoGPT, LangChain agents, OpenAI Assistants): a stateless function-call loop. No persistence between invocations unless explicitly checkpointed. |
|
||||
| **flow** | A task execution within a workspace — a request enters, the agent runs tools, emits a response, logs activity. No explicit graph abstraction. | **Langflow**: a directed graph of nodes you author visually. **LangGraph**: a stateful graph of callable nodes. Our "flow" is an imperative timeline, not a graph. |
|
||||
| **team** | A named cluster of workspaces under a PM (org template `expand_team`). Used for role grouping in Canvas. | **CrewAI**: a "crew" is a sequence of agents that pass a task through a declared order. Our "team" is an org-chart abstraction, not an execution order. |
|
||||
| **team** | A named cluster of workspaces under a PM . Used for role grouping in Canvas. | **CrewAI**: a "crew" is a sequence of agents that pass a task through a declared order. Our "team" is an org-chart abstraction, not an execution order. |
|
||||
| **skill** | A directory with `SKILL.md` that an agent invokes via the `Skill` tool. Skills are documentation + optional scripts that teach an agent a recipe. | **Anthropic Skills API**: nearly identical. **CrewAI tool**: closer to our plugin's MCP tool, not our skill. |
|
||||
| **channel** | An outbound/inbound social integration (Telegram, Slack, …) per-workspace, wired in `workspace_channels`. | Slack's "channel": the container for messages. We use "channel" for the adapter + credentials, not the conversation itself. |
|
||||
| **runtime** | The execution engine image tag for a workspace: one of `langgraph`, `claude-code`, `openclaw`, `crewai`, `autogen`, `deepagents`, `hermes`. | **LangGraph runtime**: the Python process running the graph. We use "runtime" for the Docker image + adapter pairing, not the inner process. |
|
||||
|
||||
@@ -166,8 +166,6 @@ list_workspaces
|
||||
|
||||
| MCP Tool | API Route | Method | Description |
|
||||
|----------|-----------|--------|-------------|
|
||||
| `expand_team` | `/workspaces/:id/expand` | POST | Expand team node |
|
||||
| `collapse_team` | `/workspaces/:id/collapse` | POST | Collapse team node |
|
||||
|
||||
### Templates & Bundles
|
||||
|
||||
|
||||
@@ -55,6 +55,8 @@ TOP_LEVEL_MODULES = {
|
||||
"a2a_executor",
|
||||
"a2a_mcp_server",
|
||||
"a2a_tools",
|
||||
"a2a_tools_delegation",
|
||||
"a2a_tools_rbac",
|
||||
"adapter_base",
|
||||
"agent",
|
||||
"agents_md",
|
||||
@@ -69,11 +71,15 @@ TOP_LEVEL_MODULES = {
|
||||
"executor_helpers",
|
||||
"heartbeat",
|
||||
"inbox",
|
||||
"inbox_uploads",
|
||||
"initial_prompt",
|
||||
"internal_chat_uploads",
|
||||
"internal_file_read",
|
||||
"main",
|
||||
"mcp_cli",
|
||||
"mcp_heartbeat",
|
||||
"mcp_inbox_pollers",
|
||||
"mcp_workspace_resolver",
|
||||
"molecule_ai_status",
|
||||
"not_configured_handler",
|
||||
"platform_auth",
|
||||
|
||||
Executable
+40
@@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env bash
|
||||
# lint_cleanup_traps.sh — regression gate for the OSS-shape program's
|
||||
# "all E2E tests must have proper cleanup" bar (RFC #2873).
|
||||
#
|
||||
# Asserts: every shell file under tests/e2e/ that calls `mktemp` ALSO
|
||||
# installs an `EXIT` trap somewhere in the file. The trap is the
|
||||
# minimum-viable guarantee that scratch files won't leak when an
|
||||
# assertion or curl exits the script non-zero.
|
||||
#
|
||||
# Why this lints (instead of the test runner enforcing): shell scripts
|
||||
# can't easily be wrapped by an outer harness without breaking the
|
||||
# `WSID=… ./test_x.sh` invocation contract. Static gate is the cheap
|
||||
# defense.
|
||||
#
|
||||
# Usage:
|
||||
# tests/e2e/lint_cleanup_traps.sh
|
||||
#
|
||||
# Exits non-zero if any test_*.sh has unmatched mktemp/trap. CI invokes
|
||||
# it from the existing Shellcheck (E2E scripts) workflow.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
violations=0
|
||||
for f in test_*.sh; do
|
||||
if grep -qE '\bmktemp\b' "$f"; then
|
||||
if ! grep -qE 'trap[[:space:]]+.*EXIT' "$f"; then
|
||||
echo "::error file=tests/e2e/$f::has 'mktemp' but no 'trap … EXIT' — scratch will leak when test exits non-zero. Pattern: TMPDIR_E2E=\$(mktemp -d -t prefix-XXX); trap 'rm -rf \"\$TMPDIR_E2E\"' EXIT INT TERM"
|
||||
violations=$((violations + 1))
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
if [ "$violations" -gt 0 ]; then
|
||||
echo "::error::$violations shell E2E file(s) leak scratch on early exit. See above."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✓ all $(grep -lE '\bmktemp\b' test_*.sh | wc -l | tr -d ' ') shell E2E files with mktemp also install an EXIT trap"
|
||||
@@ -22,6 +22,13 @@ set -euo pipefail
|
||||
WSID="${WSID:?WSID=<workspace-id> required}"
|
||||
BASE="${BASE:-http://localhost:8080}"
|
||||
|
||||
# Per-run scratch dir collected under one trap so every mktemp leak path
|
||||
# (assertion failure, SIGINT, exit non-zero) is plugged. Pre-fix this test
|
||||
# created a /tmp/hermes-e2e-XXXXXX.txt and never deleted it — ~10 KB ×
|
||||
# every CI run leaked into the runner. RFC #2873 cleanup-hygiene PR.
|
||||
TMPDIR_E2E=$(mktemp -d -t chat-attachments-e2e-XXXXXX)
|
||||
trap 'rm -rf "$TMPDIR_E2E"' EXIT INT TERM
|
||||
|
||||
log() { printf "\n=== %s ===\n" "$*"; }
|
||||
|
||||
log "Preflight: workspace online?"
|
||||
@@ -29,7 +36,9 @@ STATUS=$(curl -s "$BASE/workspaces/$WSID" | python3 -c 'import json,sys;print(js
|
||||
[ "$STATUS" = "online" ] || { echo "workspace not online ($STATUS)"; exit 1; }
|
||||
|
||||
log "Step 1 — Upload a text file via /chat/uploads"
|
||||
TEST_FILE=$(mktemp -t hermes-e2e-XXXXXX.txt)
|
||||
# `mktemp <full-template>` is portable across BSD (macOS) + GNU; -p is
|
||||
# GNU-only and breaks local dev runs on Mac.
|
||||
TEST_FILE=$(mktemp "$TMPDIR_E2E/hermes-e2e-XXXXXX.txt")
|
||||
echo "secret code: $(openssl rand -hex 4)-$(openssl rand -hex 4)" > "$TEST_FILE"
|
||||
EXPECTED=$(cat "$TEST_FILE" | awk '{print $NF}')
|
||||
UPLOAD=$(curl -s -X POST "$BASE/workspaces/$WSID/chat/uploads" -F "files=@$TEST_FILE")
|
||||
|
||||
@@ -24,6 +24,15 @@ set -uo pipefail
|
||||
BASE="${BASE:-http://localhost:8080}"
|
||||
fails=0
|
||||
|
||||
# Per-run scratch dir collected under one trap so every per-runtime
|
||||
# round_trip mktemp leak path (assertion failure, SIGINT, exit
|
||||
# non-zero, function early-return between mktemp and rm) is plugged.
|
||||
# Pre-fix, round_trip's `rm -f "$test_file"` only fired on the success
|
||||
# path inside the function — every test_failure path before the rm
|
||||
# leaked the scratch into /tmp permanently. RFC #2873 cleanup-hygiene PR.
|
||||
TMPDIR_E2E=$(mktemp -d -t mr-attachments-e2e-XXXXXX)
|
||||
trap 'rm -rf "$TMPDIR_E2E"' EXIT INT TERM
|
||||
|
||||
has_patch_in_container() {
|
||||
local container="$1"
|
||||
# Signal that platform helpers are available AND wired into the
|
||||
@@ -74,12 +83,16 @@ print(f"executor: claude-code monkey-patch active ({name})")
|
||||
round_trip() {
|
||||
local label="$1" wsid="$2"
|
||||
local test_file expected upload uri payload reply reply_text
|
||||
test_file=$(mktemp -t e2e-mr-XXXX.txt)
|
||||
# Scratch goes under TMPDIR_E2E; the script-level trap rm -rf's the
|
||||
# whole dir on exit, so per-file rm calls are unnecessary AND make
|
||||
# error paths leak when forgotten.
|
||||
# `mktemp <full-template>` is portable across BSD (macOS) + GNU; -p is GNU-only.
|
||||
test_file=$(mktemp "$TMPDIR_E2E/e2e-mr-${label}-XXXX.txt")
|
||||
expected="secret-$(openssl rand -hex 6)"
|
||||
echo "$expected" > "$test_file"
|
||||
upload=$(curl -s -X POST "$BASE/workspaces/$wsid/chat/uploads" -F "files=@$test_file")
|
||||
uri=$(echo "$upload" | python3 -c 'import json,sys;print(json.load(sys.stdin)["files"][0]["uri"])' 2>/dev/null)
|
||||
[ -z "$uri" ] && { echo "FAIL $label: upload returned no URI: $upload"; rm -f "$test_file"; return 1; }
|
||||
[ -z "$uri" ] && { echo "FAIL $label: upload returned no URI: $upload"; return 1; }
|
||||
payload=$(URI="$uri" python3 -c '
|
||||
import json, os
|
||||
uri = os.environ["URI"]
|
||||
@@ -103,7 +116,8 @@ try:
|
||||
except Exception as exc:
|
||||
print(f"(parse failed: {exc})")
|
||||
' 2>&1)
|
||||
rm -f "$test_file"
|
||||
# $test_file lives under TMPDIR_E2E; the script-level trap rm -rf's
|
||||
# the dir on exit, covering every return path including SIGINT.
|
||||
|
||||
if echo "$reply_text" | grep -qF "$expected"; then
|
||||
echo "PASS $label round-trip: agent quoted $expected"
|
||||
|
||||
@@ -29,11 +29,20 @@ FAIL=0
|
||||
WSID=""
|
||||
|
||||
cleanup() {
|
||||
# Workspace teardown — best-effort, ignore errors so an unrelated CP
|
||||
# outage doesn't shadow a real test failure.
|
||||
if [ -n "$WSID" ]; then
|
||||
curl -s -X DELETE "$BASE/workspaces/$WSID?confirm=true" > /dev/null || true
|
||||
fi
|
||||
# /tmp scratch — pre-fix only ran on success path (the unconditional
|
||||
# rm at the bottom of the script). Trap-based path lets the file leak
|
||||
# whenever the script exits non-zero before reaching the rm. RFC #2873
|
||||
# cleanup-hygiene PR.
|
||||
if [ -n "${TMPF:-}" ]; then
|
||||
rm -f "$TMPF"
|
||||
fi
|
||||
}
|
||||
trap cleanup EXIT
|
||||
trap cleanup EXIT INT TERM
|
||||
|
||||
assert() {
|
||||
local label="$1"
|
||||
@@ -230,7 +239,8 @@ for r in rows:
|
||||
assert "stored URI matches uploaded URI" "$STORED_URI" "$URI"
|
||||
fi
|
||||
|
||||
rm -f "$TMPF"
|
||||
# $TMPF cleanup happens via the trap-cleanup function above — covers
|
||||
# both the success path and any early exit / SIGINT.
|
||||
|
||||
echo ""
|
||||
echo "=== Results: $PASS passed, $FAIL failed ==="
|
||||
|
||||
@@ -94,6 +94,13 @@ services:
|
||||
CP_UPSTREAM_URL: "http://cp-stub:9090"
|
||||
RATE_LIMIT: "1000"
|
||||
CANVAS_PROXY_URL: "http://localhost:3000"
|
||||
# Memory v2 sidecar (PR #2906) bundles the plugin into the
|
||||
# tenant image and starts it before the main server. The plugin
|
||||
# runs `CREATE EXTENSION vector` on first boot, which fails on
|
||||
# the harness's plain postgres:15-alpine (no pgvector). The
|
||||
# harness doesn't exercise memory features, so disable the
|
||||
# sidecar via the entrypoint's documented escape hatch.
|
||||
MEMORY_PLUGIN_DISABLE: "1"
|
||||
networks: [harness-net]
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "wget -q -O- http://localhost:8080/health || exit 1"]
|
||||
@@ -142,6 +149,13 @@ services:
|
||||
CP_UPSTREAM_URL: "http://cp-stub:9090"
|
||||
RATE_LIMIT: "1000"
|
||||
CANVAS_PROXY_URL: "http://localhost:3000"
|
||||
# Memory v2 sidecar (PR #2906) bundles the plugin into the
|
||||
# tenant image and starts it before the main server. The plugin
|
||||
# runs `CREATE EXTENSION vector` on first boot, which fails on
|
||||
# the harness's plain postgres:15-alpine (no pgvector). The
|
||||
# harness doesn't exercise memory features, so disable the
|
||||
# sidecar via the entrypoint's documented escape hatch.
|
||||
MEMORY_PLUGIN_DISABLE: "1"
|
||||
networks: [harness-net]
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "wget -q -O- http://localhost:8080/health || exit 1"]
|
||||
|
||||
@@ -21,6 +21,14 @@ ARG GIT_SHA=dev
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags "-X github.com/Molecule-AI/molecule-monorepo/platform/internal/buildinfo.GitSHA=${GIT_SHA}" \
|
||||
-o /platform ./cmd/server
|
||||
# Bundle the built-in memory-plugin-postgres binary so an operator can
|
||||
# activate Memory v2 by setting MEMORY_V2_CUTOVER=true + (default)
|
||||
# MEMORY_PLUGIN_URL=http://localhost:9100. The entrypoint starts this
|
||||
# binary in the background; main /platform talks to it over loopback.
|
||||
# Stays inert until the operator flips the cutover env var.
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags "-X github.com/Molecule-AI/molecule-monorepo/platform/internal/buildinfo.GitSHA=${GIT_SHA}" \
|
||||
-o /memory-plugin ./cmd/memory-plugin-postgres
|
||||
|
||||
# Clone templates + plugins at build time from manifest.json
|
||||
FROM alpine:3.20 AS templates
|
||||
@@ -30,8 +38,9 @@ COPY scripts/clone-manifest.sh /scripts/clone-manifest.sh
|
||||
RUN chmod +x /scripts/clone-manifest.sh && /scripts/clone-manifest.sh /manifest.json /workspace-configs-templates /org-templates /plugins
|
||||
|
||||
FROM alpine:3.20
|
||||
RUN apk add --no-cache ca-certificates git tzdata
|
||||
RUN apk add --no-cache ca-certificates git tzdata wget
|
||||
COPY --from=builder /platform /platform
|
||||
COPY --from=builder /memory-plugin /memory-plugin
|
||||
COPY workspace-server/migrations /migrations
|
||||
COPY --from=templates /workspace-configs-templates /workspace-configs-templates
|
||||
COPY --from=templates /org-templates /org-templates
|
||||
@@ -41,6 +50,7 @@ RUN addgroup -g 1000 platform && adduser -u 1000 -G platform -s /bin/sh -D platf
|
||||
EXPOSE 8080
|
||||
COPY <<'ENTRY' /entrypoint.sh
|
||||
#!/bin/sh
|
||||
# Set up docker-socket group (unchanged from pre-sidecar entrypoint).
|
||||
if [ -S /var/run/docker.sock ]; then
|
||||
SOCK_GID=$(stat -c '%g' /var/run/docker.sock 2>/dev/null || stat -f '%g' /var/run/docker.sock 2>/dev/null)
|
||||
if [ -n "$SOCK_GID" ] && [ "$SOCK_GID" != "0" ]; then
|
||||
@@ -50,6 +60,61 @@ if [ -S /var/run/docker.sock ]; then
|
||||
addgroup platform root 2>/dev/null || true
|
||||
fi
|
||||
fi
|
||||
|
||||
# Memory v2 sidecar (built-in postgres plugin). Co-located with the
|
||||
# main server so operators flipping MEMORY_V2_CUTOVER=true don't need
|
||||
# to provision a separate service.
|
||||
#
|
||||
# Spawn-gating: only start the sidecar when the operator has indicated
|
||||
# they want it — either MEMORY_V2_CUTOVER=true OR MEMORY_PLUGIN_URL set.
|
||||
# Without that signal, the sidecar adds zero value (the platform's
|
||||
# wiring.go skips building the client too) but pays a real cost: the
|
||||
# plugin's first migration runs `CREATE EXTENSION vector`, which fails
|
||||
# on tenant Postgres without pgvector preinstalled and aborts container
|
||||
# boot via the 30s health gate. Caught on staging redeploy 2026-05-05.
|
||||
#
|
||||
# Env defaults (when sidecar IS spawned):
|
||||
# MEMORY_PLUGIN_DATABASE_URL = $DATABASE_URL (share existing Postgres;
|
||||
# plugin's `memory_namespaces` / `memory_records` tables coexist
|
||||
# with `agent_memories` and the rest of the platform schema —
|
||||
# no conflicts. Operator can override with a separate URL.)
|
||||
# MEMORY_PLUGIN_LISTEN_ADDR = 127.0.0.1:9100
|
||||
#
|
||||
# Set MEMORY_PLUGIN_DISABLE=1 to force-skip the sidecar even with
|
||||
# cutover env set (e.g. running the plugin externally on a separate host).
|
||||
memory_plugin_wanted=""
|
||||
if [ "$MEMORY_V2_CUTOVER" = "true" ] || [ -n "$MEMORY_PLUGIN_URL" ]; then
|
||||
memory_plugin_wanted=1
|
||||
fi
|
||||
if [ -z "$MEMORY_PLUGIN_DISABLE" ] && [ -n "$memory_plugin_wanted" ] && [ -n "$DATABASE_URL" ]; then
|
||||
: "${MEMORY_PLUGIN_DATABASE_URL:=$DATABASE_URL}"
|
||||
: "${MEMORY_PLUGIN_LISTEN_ADDR:=:9100}"
|
||||
export MEMORY_PLUGIN_DATABASE_URL MEMORY_PLUGIN_LISTEN_ADDR
|
||||
echo "memory-plugin: starting sidecar on $MEMORY_PLUGIN_LISTEN_ADDR" >&2
|
||||
# Drop privs to the platform user — the plugin doesn't need root and
|
||||
# runs unprivileged elsewhere (tenant image already starts as canvas).
|
||||
su-exec platform /memory-plugin &
|
||||
MEMORY_PLUGIN_PID=$!
|
||||
# Wait up to 30s for the plugin's /v1/health to return 200. Boot
|
||||
# failure here is fatal — better to crash-loop than to silently
|
||||
# serve cutover traffic against a dead plugin.
|
||||
health_port=${MEMORY_PLUGIN_LISTEN_ADDR#:}
|
||||
ready=0
|
||||
for _ in $(seq 1 30); do
|
||||
if wget -qO- --timeout=2 "http://localhost:${health_port}/v1/health" >/dev/null 2>&1; then
|
||||
ready=1
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
if [ "$ready" != "1" ]; then
|
||||
echo "memory-plugin: ❌ /v1/health never returned 200 after 30s — aborting boot. Check that DATABASE_URL is reachable, has the pgvector extension, and the plugin's migrations applied." >&2
|
||||
kill "$MEMORY_PLUGIN_PID" 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
echo "memory-plugin: ✅ sidecar healthy on :$health_port" >&2
|
||||
fi
|
||||
|
||||
exec su-exec platform /platform "$@"
|
||||
ENTRY
|
||||
RUN chmod +x /entrypoint.sh && apk add --no-cache su-exec
|
||||
|
||||
@@ -34,6 +34,13 @@ ARG GIT_SHA=dev
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags "-X github.com/Molecule-AI/molecule-monorepo/platform/internal/buildinfo.GitSHA=${GIT_SHA}" \
|
||||
-o /platform ./cmd/server
|
||||
# Memory v2 sidecar binary (Memory v2 #2728). Bundled so an operator
|
||||
# can activate cutover by flipping MEMORY_V2_CUTOVER=true without
|
||||
# provisioning a separate service. See entrypoint-tenant.sh for the
|
||||
# launch logic.
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags "-X github.com/Molecule-AI/molecule-monorepo/platform/internal/buildinfo.GitSHA=${GIT_SHA}" \
|
||||
-o /memory-plugin ./cmd/memory-plugin-postgres
|
||||
|
||||
# ── Stage 2: Canvas Next.js standalone ────────────────────────────────
|
||||
FROM node:20-alpine AS canvas-builder
|
||||
@@ -74,8 +81,9 @@ RUN deluser --remove-home node 2>/dev/null || true; \
|
||||
delgroup node 2>/dev/null || true; \
|
||||
addgroup -g 1000 canvas && adduser -u 1000 -G canvas -s /bin/sh -D canvas
|
||||
|
||||
# Go platform binary
|
||||
# Go platform binary + Memory v2 sidecar
|
||||
COPY --from=go-builder /platform /platform
|
||||
COPY --from=go-builder /memory-plugin /memory-plugin
|
||||
COPY workspace-server/migrations /migrations
|
||||
|
||||
# Templates + plugins (cloned from GitHub in stage 3)
|
||||
@@ -91,7 +99,7 @@ COPY --from=canvas-builder /canvas/public ./public
|
||||
|
||||
COPY workspace-server/entrypoint-tenant.sh /entrypoint.sh
|
||||
RUN chmod +x /entrypoint.sh && \
|
||||
chown -R canvas:canvas /canvas /platform /migrations
|
||||
chown -R canvas:canvas /canvas /platform /memory-plugin /migrations
|
||||
|
||||
EXPOSE 8080
|
||||
# entrypoint.sh starts as root to fix volume perms, then drops to
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestLoadConfig_DefaultListenAddrIsLoopback pins the default-bind contract.
|
||||
//
|
||||
// Why this matters: with the prior `:9100` default, the plugin listened on
|
||||
// every interface. Inside the container it didn't matter (no host port
|
||||
// mapping today), but a future change that publishes 9100 OR a cross-host
|
||||
// sidecar deploy would have exposed an unauth'd memory store. Loopback by
|
||||
// default is the least-privilege baseline; operators with a multi-host
|
||||
// topology override via MEMORY_PLUGIN_LISTEN_ADDR.
|
||||
func TestLoadConfig_DefaultListenAddrIsLoopback(t *testing.T) {
|
||||
t.Setenv("MEMORY_PLUGIN_DATABASE_URL", "postgres://stub")
|
||||
t.Setenv("MEMORY_PLUGIN_LISTEN_ADDR", "")
|
||||
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("loadConfig: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(cfg.ListenAddr, "127.0.0.1:") {
|
||||
t.Errorf("default ListenAddr must bind loopback-only, got %q "+
|
||||
"(security regression — would expose plugin on every interface)",
|
||||
cfg.ListenAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_ListenAddrEnvOverride(t *testing.T) {
|
||||
t.Setenv("MEMORY_PLUGIN_DATABASE_URL", "postgres://stub")
|
||||
t.Setenv("MEMORY_PLUGIN_LISTEN_ADDR", ":9100")
|
||||
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("loadConfig: %v", err)
|
||||
}
|
||||
if cfg.ListenAddr != ":9100" {
|
||||
t.Errorf("env override ignored: want :9100, got %q", cfg.ListenAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_MissingDatabaseURL(t *testing.T) {
|
||||
t.Setenv("MEMORY_PLUGIN_DATABASE_URL", "")
|
||||
|
||||
if _, err := loadConfig(); err == nil {
|
||||
t.Fatal("loadConfig must error when MEMORY_PLUGIN_DATABASE_URL is empty")
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sort"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -26,12 +28,28 @@ import (
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/pgplugin"
|
||||
)
|
||||
|
||||
// migrationsFS bundles the .up.sql files into the binary at build time
|
||||
// so the prebuilt image doesn't need the source tree at runtime. The
|
||||
// prior `os.ReadDir("cmd/memory-plugin-postgres/migrations")` path
|
||||
// only resolved during `go test` from the repo root — in the published
|
||||
// image the path didn't exist and boot failed after the 30s health gate
|
||||
// (caught on staging redeploy 2026-05-05 after PR #2906).
|
||||
//
|
||||
//go:embed migrations/*.up.sql
|
||||
var migrationsFS embed.FS
|
||||
|
||||
const (
|
||||
envDatabaseURL = "MEMORY_PLUGIN_DATABASE_URL"
|
||||
envListenAddr = "MEMORY_PLUGIN_LISTEN_ADDR"
|
||||
envSkipMigrate = "MEMORY_PLUGIN_SKIP_MIGRATE"
|
||||
|
||||
defaultListenAddr = ":9100"
|
||||
// Loopback-only by default (defense in depth). The platform talks to
|
||||
// the plugin over `http://localhost:9100` from the same container, so
|
||||
// binding to all interfaces would only widen the reachable surface
|
||||
// without enabling any in-design caller. Operators running the plugin
|
||||
// on a separate host override via MEMORY_PLUGIN_LISTEN_ADDR=:9100 (or
|
||||
// some other interface).
|
||||
defaultListenAddr = "127.0.0.1:9100"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -143,32 +161,71 @@ func openDB(databaseURL string) (*sql.DB, error) {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// runMigrations applies the schema migrations bundled at
|
||||
// cmd/memory-plugin-postgres/migrations/. Idempotent on repeat boot.
|
||||
// runMigrations applies the schema migrations bundled into the binary
|
||||
// via go:embed (see migrationsFS at the top of this file). Idempotent
|
||||
// on repeat boot — every migration file uses CREATE … IF NOT EXISTS.
|
||||
//
|
||||
// Implementation note: rather than embedding the full migrate engine,
|
||||
// we read the migration files at boot from a known relative path. The
|
||||
// down migrations are deliberately NOT applied here — that's a manual
|
||||
// operator action. This keeps the binary tiny and avoids dragging in
|
||||
// golang-migrate's drivers.
|
||||
// The down migrations are deliberately NOT applied here — that's a
|
||||
// manual operator action. This keeps the binary tiny and avoids
|
||||
// dragging in golang-migrate's drivers.
|
||||
//
|
||||
// MEMORY_PLUGIN_MIGRATIONS_DIR (filesystem path) is honored as an
|
||||
// override for operators who need to ship custom migrations alongside
|
||||
// the binary without rebuilding. When unset (the common case) we read
|
||||
// from the embedded FS.
|
||||
func runMigrations(db *sql.DB) error {
|
||||
// Find the migrations directory. In `go run` mode it's relative
|
||||
// to the cmd dir; in the prebuilt binary case it's expected next
|
||||
// to the binary OR via env var override.
|
||||
dir := os.Getenv("MEMORY_PLUGIN_MIGRATIONS_DIR")
|
||||
if dir == "" {
|
||||
// Best-effort: try the cwd-relative path that works for `go test`.
|
||||
dir = "cmd/memory-plugin-postgres/migrations"
|
||||
if dir := strings.TrimSpace(os.Getenv("MEMORY_PLUGIN_MIGRATIONS_DIR")); dir != "" {
|
||||
return runMigrationsFromDisk(db, dir)
|
||||
}
|
||||
entries, err := os.ReadDir(dir)
|
||||
return runMigrationsFromEmbed(db)
|
||||
}
|
||||
|
||||
// runMigrationsFromEmbed applies the *.up.sql files bundled into the
|
||||
// binary at build time. Order is alphabetical (matches the on-disk
|
||||
// behavior of os.ReadDir on Linux for the same set of names).
|
||||
func runMigrationsFromEmbed(db *sql.DB) error {
|
||||
entries, err := migrationsFS.ReadDir("migrations")
|
||||
if err != nil {
|
||||
return fmt.Errorf("read migrations dir %q: %w", dir, err)
|
||||
return fmt.Errorf("read embedded migrations: %w", err)
|
||||
}
|
||||
names := make([]string, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
|
||||
continue
|
||||
}
|
||||
path := dir + "/" + e.Name()
|
||||
names = append(names, e.Name())
|
||||
}
|
||||
sort.Strings(names)
|
||||
for _, name := range names {
|
||||
data, err := migrationsFS.ReadFile("migrations/" + name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read embedded %q: %w", name, err)
|
||||
}
|
||||
if _, err := db.Exec(string(data)); err != nil {
|
||||
return fmt.Errorf("apply %q: %w", name, err)
|
||||
}
|
||||
log.Printf("applied embedded migration %s", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runMigrationsFromDisk preserves the legacy filesystem-path mode for
|
||||
// operator-supplied custom migrations.
|
||||
func runMigrationsFromDisk(db *sql.DB, dir string) error {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read migrations dir %q: %w", dir, err)
|
||||
}
|
||||
names := make([]string, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
|
||||
continue
|
||||
}
|
||||
names = append(names, e.Name())
|
||||
}
|
||||
sort.Strings(names)
|
||||
for _, name := range names {
|
||||
path := dir + "/" + name
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %q: %w", path, err)
|
||||
@@ -176,7 +233,7 @@ func runMigrations(db *sql.DB) error {
|
||||
if _, err := db.Exec(string(data)); err != nil {
|
||||
return fmt.Errorf("apply %q: %w", path, err)
|
||||
}
|
||||
log.Printf("applied migration %s", e.Name())
|
||||
log.Printf("applied disk migration %s (from %s)", name, dir)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestMigrationsEmbedded_ContainsCreateTable pins that the migrations
|
||||
// are bundled into the binary at build time, NOT loaded from a
|
||||
// filesystem path that doesn't exist at runtime in the published image.
|
||||
//
|
||||
// Pre-fix: PR #2906 shipped the binary without the migrations dir;
|
||||
// `os.ReadDir("cmd/memory-plugin-postgres/migrations")` errored on every
|
||||
// tenant boot, the 30s health gate aborted the container, and the
|
||||
// staging redeploy fleet job marked all tenants as failed. Embedding
|
||||
// the migrations into the binary removes the runtime path entirely.
|
||||
func TestMigrationsEmbedded_ContainsCreateTable(t *testing.T) {
|
||||
entries, err := migrationsFS.ReadDir("migrations")
|
||||
if err != nil {
|
||||
t.Fatalf("embedded migrations dir unreadable: %v", err)
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
t.Fatal("embedded migrations dir is empty — go:embed pattern matched no files")
|
||||
}
|
||||
|
||||
var seenUp bool
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
|
||||
continue
|
||||
}
|
||||
seenUp = true
|
||||
data, err := migrationsFS.ReadFile("migrations/" + e.Name())
|
||||
if err != nil {
|
||||
t.Errorf("read embedded %q: %v", e.Name(), err)
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(string(data), "CREATE TABLE") {
|
||||
t.Errorf("embedded %q has no CREATE TABLE — wrong file embedded?", e.Name())
|
||||
}
|
||||
}
|
||||
if !seenUp {
|
||||
t.Fatal("no *.up.sql in embedded migrations — runtime would have no schema to apply")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRunMigrationsFromEmbed_OrderingIsAlphabetic pins that we apply
|
||||
// migrations in deterministic alphabetical order, not in whatever
|
||||
// arbitrary order migrationsFS.ReadDir happens to return. With one
|
||||
// migration today this is moot, but a future second migration ('002_…')
|
||||
// MUST run after '001_…' or the schema is broken.
|
||||
//
|
||||
// We can't easily exercise db.Exec here (no test DB); instead pin the
|
||||
// sort step on the directory listing itself.
|
||||
func TestRunMigrationsFromEmbed_OrderingIsAlphabetic(t *testing.T) {
|
||||
entries, err := migrationsFS.ReadDir("migrations")
|
||||
if err != nil {
|
||||
t.Fatalf("embedded migrations dir unreadable: %v", err)
|
||||
}
|
||||
var names []string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
|
||||
continue
|
||||
}
|
||||
names = append(names, e.Name())
|
||||
}
|
||||
for i := 1; i < len(names); i++ {
|
||||
if names[i-1] > names[i] {
|
||||
t.Errorf("ReadDir returned non-sorted names; runMigrationsFromEmbed must sort. "+
|
||||
"Got %q before %q", names[i-1], names[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/handlers"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/imagewatch"
|
||||
memwiring "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/wiring"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/pendinguploads"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/registry"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/router"
|
||||
@@ -265,6 +266,14 @@ func main() {
|
||||
})
|
||||
}
|
||||
|
||||
// Pending-uploads GC sweep — deletes acked rows past their retention
|
||||
// window plus unacked rows past expires_at. Without this the
|
||||
// pending_uploads table grows unbounded; even with the 24h hard TTL,
|
||||
// nothing actually deletes a row, just makes it un-fetchable.
|
||||
go supervised.RunWithRecover(ctx, "pending-uploads-sweeper", func(c context.Context) {
|
||||
pendinguploads.StartSweeper(c, pendinguploads.NewPostgres(db.DB), 0)
|
||||
})
|
||||
|
||||
// Provision-timeout sweep — flips workspaces that have been stuck in
|
||||
// status='provisioning' past the timeout window to 'failed' and emits
|
||||
// WORKSPACE_PROVISION_TIMEOUT. Without this the UI banner is cosmetic
|
||||
|
||||
@@ -20,6 +20,51 @@ cd /canvas
|
||||
PORT=3000 HOSTNAME=0.0.0.0 node server.js &
|
||||
CANVAS_PID=$!
|
||||
|
||||
# Memory v2 sidecar (built-in postgres plugin). See Dockerfile entrypoint
|
||||
# comment for rationale.
|
||||
#
|
||||
# Spawn-gating: only start the sidecar when the operator has indicated
|
||||
# they want it (MEMORY_V2_CUTOVER=true OR MEMORY_PLUGIN_URL set).
|
||||
# Without that signal, the sidecar adds zero value and risks aborting
|
||||
# tenant boot via the 30s health gate when the tenant Postgres lacks
|
||||
# pgvector. Caught on staging redeploy 2026-05-05:
|
||||
# pq: extension "vector" is not available
|
||||
#
|
||||
# Defaults (when sidecar IS spawned): MEMORY_PLUGIN_DATABASE_URL
|
||||
# falls back to the tenant's DATABASE_URL.
|
||||
MEMORY_PLUGIN_PID=""
|
||||
memory_plugin_wanted=""
|
||||
if [ "$MEMORY_V2_CUTOVER" = "true" ] || [ -n "$MEMORY_PLUGIN_URL" ]; then
|
||||
memory_plugin_wanted=1
|
||||
fi
|
||||
if [ -z "$MEMORY_PLUGIN_DISABLE" ] && [ -n "$memory_plugin_wanted" ] && [ -n "$DATABASE_URL" ]; then
|
||||
: "${MEMORY_PLUGIN_DATABASE_URL:=$DATABASE_URL}"
|
||||
: "${MEMORY_PLUGIN_LISTEN_ADDR:=:9100}"
|
||||
export MEMORY_PLUGIN_DATABASE_URL MEMORY_PLUGIN_LISTEN_ADDR
|
||||
echo "memory-plugin: starting sidecar on $MEMORY_PLUGIN_LISTEN_ADDR" >&2
|
||||
/memory-plugin &
|
||||
MEMORY_PLUGIN_PID=$!
|
||||
# Wait up to 30s for /v1/health. Boot failure is fatal so a misconfigured
|
||||
# tenant crash-loops instead of silently serving cutover traffic against
|
||||
# a dead plugin.
|
||||
health_port=${MEMORY_PLUGIN_LISTEN_ADDR#:}
|
||||
ready=0
|
||||
for _ in $(seq 1 30); do
|
||||
if wget -qO- --timeout=2 "http://localhost:${health_port}/v1/health" >/dev/null 2>&1; then
|
||||
ready=1
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
if [ "$ready" != "1" ]; then
|
||||
echo "memory-plugin: ❌ /v1/health never returned 200 after 30s — aborting boot. Check DATABASE_URL reachability + pgvector extension + migrations." >&2
|
||||
kill "$MEMORY_PLUGIN_PID" 2>/dev/null || true
|
||||
kill "$CANVAS_PID" 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
echo "memory-plugin: ✅ sidecar healthy on :$health_port" >&2
|
||||
fi
|
||||
|
||||
# Start Go platform in foreground-ish (we trap signals)
|
||||
# CANVAS_PROXY_URL tells the platform to proxy unmatched routes to Canvas.
|
||||
# CONTAINER_BACKEND: empty = Docker (default for self-hosted/local).
|
||||
@@ -29,15 +74,20 @@ cd /
|
||||
/platform &
|
||||
PLATFORM_PID=$!
|
||||
|
||||
# If either process exits, kill the other
|
||||
# If any process exits, kill the others
|
||||
cleanup() {
|
||||
kill $CANVAS_PID 2>/dev/null || true
|
||||
kill $PLATFORM_PID 2>/dev/null || true
|
||||
[ -n "$MEMORY_PLUGIN_PID" ] && kill $MEMORY_PLUGIN_PID 2>/dev/null || true
|
||||
}
|
||||
trap cleanup EXIT SIGTERM SIGINT
|
||||
|
||||
# Wait for either to exit — whichever exits first triggers cleanup
|
||||
wait -n $CANVAS_PID $PLATFORM_PID
|
||||
# Wait for any to exit — whichever exits first triggers cleanup
|
||||
if [ -n "$MEMORY_PLUGIN_PID" ]; then
|
||||
wait -n $CANVAS_PID $PLATFORM_PID $MEMORY_PLUGIN_PID
|
||||
else
|
||||
wait -n $CANVAS_PID $PLATFORM_PID
|
||||
fi
|
||||
EXIT_CODE=$?
|
||||
cleanup
|
||||
exit $EXIT_CODE
|
||||
|
||||
@@ -31,23 +31,37 @@ package handlers
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/events"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/pendinguploads"
|
||||
)
|
||||
|
||||
// ChatFilesHandler serves file upload + download for chat. Holds a
|
||||
// reference to TemplatesHandler so the (still docker-exec) Download
|
||||
// path keeps using the shared findContainer/CopyFromContainer helpers
|
||||
// without duplicating them. Upload no longer reaches into Docker.
|
||||
//
|
||||
// pendingUploads + broadcaster are wired only when the platform's
|
||||
// migration 20260505100000 has run; nil values fall back to the
|
||||
// pre-poll-mode behavior (422 on poll-mode upload, same as before).
|
||||
// This lets the binary keep booting in environments where the
|
||||
// migration hasn't run yet — the poll branch is gated by a not-nil
|
||||
// check at the call site.
|
||||
type ChatFilesHandler struct {
|
||||
templates *TemplatesHandler
|
||||
|
||||
@@ -56,6 +70,19 @@ type ChatFilesHandler struct {
|
||||
// the 50 MB worst case on a slow EC2 link without leaving a
|
||||
// connection hanging forever on a sick workspace.
|
||||
httpClient *http.Client
|
||||
|
||||
// pendingUploads is the platform-side staging layer for poll-mode
|
||||
// uploads. nil → poll branch returns 422 unchanged (the pre-feature
|
||||
// behavior); non-nil → poll branch parses multipart, persists each
|
||||
// file via storage.Put, logs a chat_upload_receive activity row,
|
||||
// and returns 200 with synthetic platform-pending: URIs.
|
||||
pendingUploads pendinguploads.Storage
|
||||
|
||||
// broadcaster is the events.EventEmitter used to notify the canvas
|
||||
// when an activity row lands (so the Agent Comms panel updates
|
||||
// live). Same emitter the rest of the platform uses; nil = no
|
||||
// broadcast (tests).
|
||||
broadcaster events.EventEmitter
|
||||
}
|
||||
|
||||
func NewChatFilesHandler(t *TemplatesHandler) *ChatFilesHandler {
|
||||
@@ -69,6 +96,16 @@ func NewChatFilesHandler(t *TemplatesHandler) *ChatFilesHandler {
|
||||
}
|
||||
}
|
||||
|
||||
// WithPendingUploads enables the poll-mode upload branch by wiring a
|
||||
// Storage + broadcaster. Call site (router.go) does this at
|
||||
// construction; tests set the fields directly when they want the
|
||||
// poll path exercised. Returns the handler for chained construction.
|
||||
func (h *ChatFilesHandler) WithPendingUploads(storage pendinguploads.Storage, broadcaster events.EventEmitter) *ChatFilesHandler {
|
||||
h.pendingUploads = storage
|
||||
h.broadcaster = broadcaster
|
||||
return h
|
||||
}
|
||||
|
||||
// chatUploadMaxBytes caps the full multipart request body so a
|
||||
// malicious / runaway client can't OOM the proxy hop. 50 MB matches
|
||||
// the workspace-side limit; anything larger is rejected at the
|
||||
@@ -262,6 +299,24 @@ func (h *ChatFilesHandler) Upload(c *gin.Context) {
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Branch on delivery_mode BEFORE attempting the HTTP forward.
|
||||
// Push-mode workspaces continue to do the streaming forward
|
||||
// unchanged. Poll-mode workspaces (typically external runtimes
|
||||
// on a laptop, no public callback URL) get the platform-side
|
||||
// staging path — the file lands in pending_uploads, an activity
|
||||
// row goes into the inbox queue, and the workspace pulls on its
|
||||
// next poll cycle.
|
||||
if h.pendingUploads != nil {
|
||||
mode, modeOK := lookupUploadDeliveryMode(c, ctx, workspaceID)
|
||||
if !modeOK {
|
||||
return
|
||||
}
|
||||
if mode == "poll" {
|
||||
h.uploadPollMode(c, ctx, workspaceID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
wsURL, secret, ok := resolveWorkspaceForwardCreds(c, ctx, workspaceID, "upload")
|
||||
if !ok {
|
||||
return
|
||||
@@ -405,3 +460,317 @@ func (h *ChatFilesHandler) streamWorkspaceResponse(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// lookupUploadDeliveryMode returns the workspace's delivery_mode
|
||||
// for the chat upload branch. Returns ("", false) and writes the
|
||||
// HTTP error response on lookup failure (caller stops). NULL or
|
||||
// empty delivery_mode is treated as "push" — that's the schema
|
||||
// default and matches the legacy pre-#2339 behavior. Only the
|
||||
// explicit string "poll" routes the upload through the poll-mode
|
||||
// branch.
|
||||
//
|
||||
// Why a dedicated helper instead of reusing lookupDeliveryMode
|
||||
// from a2a_proxy_helpers.go: that one swallows errors and falls
|
||||
// back to "push" so the proxy keeps working on a transient DB
|
||||
// hiccup. For upload we want to surface the not-found case as 404
|
||||
// (which the workspace-poll branch wouldn't otherwise hit, since
|
||||
// the workspace-side row IS the source of truth for the mode).
|
||||
func lookupUploadDeliveryMode(c *gin.Context, ctx context.Context, workspaceID string) (string, bool) {
|
||||
var mode sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
`SELECT delivery_mode FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&mode)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
return "", false
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("chat_files Upload: delivery_mode lookup failed for %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "delivery_mode lookup failed"})
|
||||
return "", false
|
||||
}
|
||||
if !mode.Valid || mode.String == "" {
|
||||
return "push", true
|
||||
}
|
||||
return mode.String, true
|
||||
}
|
||||
|
||||
// unsafeFilenameChars matches every character that isn't in the safe
|
||||
// alphanumeric + dot/dash/underscore set. Mirrors the Python regex
|
||||
// _UNSAFE_FILENAME_CHARS in workspace/internal_chat_uploads.py — drift
|
||||
// here would mean canvas-emitted URIs differ between push and poll
|
||||
// paths for the same upload.
|
||||
var unsafeFilenameChars = regexp.MustCompile(`[^a-zA-Z0-9._\-]`)
|
||||
|
||||
// SanitizeFilename reduces a user-supplied filename to a safe form.
|
||||
// Behaviorally identical to sanitize_filename in workspace/
|
||||
// internal_chat_uploads.py. Exported so tests in other packages can
|
||||
// pin behavior parity, and so a future shared library can move both
|
||||
// implementations behind one source of truth.
|
||||
func SanitizeFilename(name string) string {
|
||||
base := filepath.Base(name)
|
||||
// filepath.Base on a path-traversal input ("../../etc/passwd")
|
||||
// returns "passwd" (just the last component) — which matches what
|
||||
// Python's os.path.basename does. Tests pin both here and on the
|
||||
// Python side.
|
||||
base = strings.ReplaceAll(base, " ", "_")
|
||||
base = unsafeFilenameChars.ReplaceAllString(base, "_")
|
||||
if len(base) > 100 {
|
||||
ext := ""
|
||||
dot := strings.LastIndex(base, ".")
|
||||
if dot >= 0 && len(base)-dot <= 16 {
|
||||
ext = base[dot:]
|
||||
}
|
||||
base = base[:100-len(ext)] + ext
|
||||
}
|
||||
if base == "" || base == "." || base == ".." {
|
||||
return "file"
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
// uploadedFile is the per-file response shape the workspace-side
|
||||
// /internal/chat/uploads/ingest also produces. Mirroring the schema
|
||||
// keeps the canvas client unaware of which path handled the upload.
|
||||
type uploadedFile struct {
|
||||
URI string `json:"uri"`
|
||||
Name string `json:"name"`
|
||||
Mimetype string `json:"mimeType"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
// uploadPollMode handles a chat upload bound for a poll-mode
|
||||
// workspace. Parses the multipart in-place, persists each file via
|
||||
// pendinguploads.Storage, and logs one chat_upload_receive activity
|
||||
// row per file so the workspace's inbox poller picks them up on its
|
||||
// next cycle.
|
||||
//
|
||||
// Why one activity row per file (not one per multipart batch):
|
||||
// - Each row carries one URI; agents that consume the inbox treat
|
||||
// each row as one inbound event. A batch row would force every
|
||||
// consumer to deserialize a list, doubling the field-shape
|
||||
// surface for no UX win.
|
||||
// - At-least-once semantics: a workspace can ack files
|
||||
// individually. Batch ack would leak partial-success state on
|
||||
// a fetcher crash mid-batch.
|
||||
//
|
||||
// Limits enforced here mirror the workspace-side ingest_handler:
|
||||
// - Total body cap: 50 MB (set on c.Request.Body before reaching us)
|
||||
// - Per-file cap: 25 MB (pendinguploads.MaxFileBytes; rejected as 413)
|
||||
// - Filename: sanitized + capped at 100 chars (SanitizeFilename)
|
||||
//
|
||||
// Logging: every persisted file logs an INFO line with workspace_id,
|
||||
// file_id, size, and sanitized name. Failure modes (oversize, missing
|
||||
// files field, malformed multipart) log at WARN with the same fields.
|
||||
// Phase 3 metrics will hook these structured logs.
|
||||
func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, workspaceID string) {
|
||||
// Parse multipart with the same per-file/per-form limits the
|
||||
// workspace-side handler uses (workspace/internal_chat_uploads.py:
|
||||
// max_files=64, max_fields=32). gin's MultipartForm does not
|
||||
// expose those limits directly — the underlying ParseMultipartForm
|
||||
// caps memory at 32 MB by default and spills to disk. For poll-
|
||||
// mode we read each file into memory to hand to Storage.Put;
|
||||
// 25 MB-per-file × 64-files ceiling means worst-case is 1.6 GB of
|
||||
// peak memory. Bound the per-file size at the multipart layer so
|
||||
// the spill never gets close.
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil {
|
||||
log.Printf("chat_files uploadPollMode: parse multipart failed for %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "malformed multipart body"})
|
||||
return
|
||||
}
|
||||
form := c.Request.MultipartForm
|
||||
if form == nil || len(form.File["files"]) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "no files field in request"})
|
||||
return
|
||||
}
|
||||
headers := form.File["files"]
|
||||
if len(headers) > 64 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "too many files (limit 64)"})
|
||||
return
|
||||
}
|
||||
|
||||
wsUUID, err := uuid.Parse(workspaceID)
|
||||
if err != nil {
|
||||
// validateWorkspaceID at the top of Upload already gates this;
|
||||
// the re-parse is defence in depth in case validateWorkspaceID
|
||||
// drifts. Keep the error class consistent so a bad-id reaches
|
||||
// the same 400 path. Not separately tested because the gate at
|
||||
// the call site is structurally the same uuid.Parse.
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// Phase 1: pre-validate + read every part BEFORE any DB write.
|
||||
// A multi-file upload must commit all-or-nothing; a per-file
|
||||
// failure halfway through used to leave rows 1..K-1 in the table
|
||||
// while the client got a 500 and retried the whole batch — duplicate
|
||||
// rows, orphan activity rows. Validating up-front + atomic PutBatch
|
||||
// closes that gap.
|
||||
type prepped struct {
|
||||
Sanitized string
|
||||
Mimetype string
|
||||
Content []byte
|
||||
Original string // original (unsanitized) filename for error messages
|
||||
}
|
||||
prepReady := make([]prepped, 0, len(headers))
|
||||
items := make([]pendinguploads.PutItem, 0, len(headers))
|
||||
for _, fh := range headers {
|
||||
if fh.Size > pendinguploads.MaxFileBytes {
|
||||
log.Printf("chat_files uploadPollMode: per-file cap exceeded for %s: %s (%d bytes)",
|
||||
workspaceID, fh.Filename, fh.Size)
|
||||
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
|
||||
"error": "file exceeds per-file cap",
|
||||
"filename": fh.Filename,
|
||||
"size": fh.Size,
|
||||
"max": pendinguploads.MaxFileBytes,
|
||||
})
|
||||
return
|
||||
}
|
||||
content, err := readMultipartFile(fh)
|
||||
if err != nil {
|
||||
log.Printf("chat_files uploadPollMode: read part failed for %s/%s: %v",
|
||||
workspaceID, fh.Filename, err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "could not read file part"})
|
||||
return
|
||||
}
|
||||
// Belt-and-braces post-read cap (multipart.FileHeader.Size can lie
|
||||
// on some clients that don't set Content-Length per part).
|
||||
if len(content) > pendinguploads.MaxFileBytes {
|
||||
log.Printf("chat_files uploadPollMode: per-file cap exceeded post-read for %s: %s (%d bytes)",
|
||||
workspaceID, fh.Filename, len(content))
|
||||
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
|
||||
"error": "file exceeds per-file cap",
|
||||
"filename": fh.Filename,
|
||||
"size": len(content),
|
||||
"max": pendinguploads.MaxFileBytes,
|
||||
})
|
||||
return
|
||||
}
|
||||
sanitized := SanitizeFilename(fh.Filename)
|
||||
mimetype := safeMimetype(fh.Header.Get("Content-Type"))
|
||||
prepReady = append(prepReady, prepped{
|
||||
Sanitized: sanitized, Mimetype: mimetype, Content: content, Original: fh.Filename,
|
||||
})
|
||||
items = append(items, pendinguploads.PutItem{
|
||||
Content: content, Filename: sanitized, Mimetype: mimetype,
|
||||
})
|
||||
}
|
||||
|
||||
// Phase 2: atomic batch insert. On failure no rows commit.
|
||||
fileIDs, err := h.pendingUploads.PutBatch(ctx, wsUUID, items)
|
||||
if err != nil {
|
||||
if errors.Is(err, pendinguploads.ErrTooLarge) {
|
||||
// Belt + suspenders: pre-validation above already caught
|
||||
// this; surface a clean 413 if a malformed FileHeader
|
||||
// somehow slipped through.
|
||||
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
|
||||
"error": "one or more files exceed per-file cap",
|
||||
"max": pendinguploads.MaxFileBytes,
|
||||
})
|
||||
return
|
||||
}
|
||||
log.Printf("chat_files uploadPollMode: storage.PutBatch failed for %s: %v",
|
||||
workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"})
|
||||
return
|
||||
}
|
||||
|
||||
// Phase 3: write per-file activity rows and build the response. Activity
|
||||
// rows are written individually (not part of the same Tx as PutBatch)
|
||||
// because LogActivity is shared across many handlers and threading the
|
||||
// Tx through would be a bigger refactor. The trade-off: if an activity
|
||||
// write fails after the PutBatch commits, the pending_uploads rows
|
||||
// orphan until the 24h TTL — significantly better than the previous
|
||||
// "every multi-file upload could orphan" behavior, and the workspace's
|
||||
// fetcher handles soft-404 cleanly when activity rows reference a row
|
||||
// the platform later expired.
|
||||
out := make([]uploadedFile, 0, len(prepReady))
|
||||
for i, p := range prepReady {
|
||||
fileID := fileIDs[i]
|
||||
uri := fmt.Sprintf("platform-pending:%s/%s", workspaceID, fileID)
|
||||
summary := "chat_upload_receive: " + p.Sanitized
|
||||
method := "chat_upload_receive"
|
||||
LogActivity(ctx, h.broadcaster, ActivityParams{
|
||||
WorkspaceID: workspaceID,
|
||||
ActivityType: "a2a_receive",
|
||||
TargetID: &workspaceID,
|
||||
Method: &method,
|
||||
Summary: &summary,
|
||||
RequestBody: map[string]interface{}{
|
||||
"file_id": fileID.String(),
|
||||
"name": p.Sanitized,
|
||||
"mimeType": p.Mimetype,
|
||||
"size": len(p.Content),
|
||||
"uri": uri,
|
||||
},
|
||||
Status: "ok",
|
||||
})
|
||||
|
||||
log.Printf("chat_files uploadPollMode: staged %s/%s (file_id=%s size=%d mimetype=%q)",
|
||||
workspaceID, p.Sanitized, fileID, len(p.Content), p.Mimetype)
|
||||
|
||||
out = append(out, uploadedFile{
|
||||
URI: uri,
|
||||
Name: p.Sanitized,
|
||||
Mimetype: p.Mimetype,
|
||||
Size: int64(len(p.Content)),
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"files": out})
|
||||
}
|
||||
|
||||
// safeMimetype validates a multipart-supplied Content-Type header and
|
||||
// returns a sanitized value safe to store + serve back unmodified.
|
||||
//
|
||||
// The platform's GET /content handler reflects the stored mimetype as
|
||||
// the response Content-Type. An attacker-controlled header that
|
||||
// embedded CR/LF could split the response (header injection); a value
|
||||
// containing semicolons could carry an unexpected charset parameter
|
||||
// that confuses a downstream renderer. Strip CR/LF/control chars +
|
||||
// keep only the type/subtype prefix; reject anything that doesn't
|
||||
// match a basic `type/subtype` regex by falling back to the safe
|
||||
// default (application/octet-stream — the workspace-side handler does
|
||||
// the same fallback).
|
||||
func safeMimetype(raw string) string {
|
||||
const fallback = "application/octet-stream"
|
||||
// Trim parameters (`text/html; charset=utf-8` → `text/html`).
|
||||
if i := strings.IndexByte(raw, ';'); i >= 0 {
|
||||
raw = raw[:i]
|
||||
}
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
// Reject if any control char or whitespace is present (header
|
||||
// injection defense). RFC 7231 mimetype grammar forbids whitespace.
|
||||
for _, r := range raw {
|
||||
if r < 0x21 || r > 0x7e {
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
// Require exactly one slash separating type and subtype.
|
||||
parts := strings.Split(raw, "/")
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return fallback
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
// readMultipartFile reads a multipart part fully into memory. Wraps
|
||||
// the open + io.ReadAll + close idiom so the call site stays clean,
|
||||
// and so a future change (chunked reads / hashing) has one place to
|
||||
// land.
|
||||
func readMultipartFile(fh *multipartFileHeader) ([]byte, error) {
|
||||
f, err := fh.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open part: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
return io.ReadAll(f)
|
||||
}
|
||||
|
||||
// multipartFileHeader is a local alias so the readMultipartFile
|
||||
// signature doesn't pull "mime/multipart" into every test that
|
||||
// touches uploadPollMode.
|
||||
type multipartFileHeader = multipart.FileHeader
|
||||
|
||||
@@ -0,0 +1,750 @@
|
||||
package handlers
|
||||
|
||||
// chat_files_poll_test.go — Upload poll-mode branch tests.
|
||||
//
|
||||
// Pinned in their own file so the existing chat_files_test.go stays
|
||||
// focused on the push-mode forward proxy. Same setupTestDB / sqlmock
|
||||
// scaffolding as the rest of the package, plus an in-memory
|
||||
// pendinguploads.Storage so we don't have to mock six SQL statements
|
||||
// per assertion.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/pendinguploads"
|
||||
)
|
||||
|
||||
// inMemStorage is a process-local pendinguploads.Storage for branch
|
||||
// tests. Records every Put for assertion. Failure modes (Put error,
|
||||
// MarkFetched / Ack tested elsewhere) are injected via fields.
|
||||
type inMemStorage struct {
|
||||
mu sync.Mutex
|
||||
rows map[uuid.UUID]pendinguploads.Record
|
||||
puts []putCall
|
||||
putErr error
|
||||
}
|
||||
|
||||
type putCall struct {
|
||||
WorkspaceID uuid.UUID
|
||||
Filename string
|
||||
Mimetype string
|
||||
Size int
|
||||
}
|
||||
|
||||
func newInMemStorage() *inMemStorage {
|
||||
return &inMemStorage{rows: map[uuid.UUID]pendinguploads.Record{}}
|
||||
}
|
||||
|
||||
func (s *inMemStorage) Put(_ context.Context, ws uuid.UUID, content []byte, filename, mimetype string) (uuid.UUID, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.putErr != nil {
|
||||
return uuid.Nil, s.putErr
|
||||
}
|
||||
id := uuid.New()
|
||||
s.rows[id] = pendinguploads.Record{
|
||||
FileID: id, WorkspaceID: ws, Content: content,
|
||||
Filename: filename, Mimetype: mimetype,
|
||||
SizeBytes: int64(len(content)), CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
s.puts = append(s.puts, putCall{
|
||||
WorkspaceID: ws, Filename: filename, Mimetype: mimetype, Size: len(content),
|
||||
})
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// PutBatch mirrors the production atomic-batch contract: any per-item
|
||||
// failure leaves the in-memory state unchanged, simulating Tx rollback.
|
||||
// Pre-validation matches PostgresStorage.PutBatch; oversized items
|
||||
// return ErrTooLarge before any row is added.
|
||||
func (s *inMemStorage) PutBatch(_ context.Context, ws uuid.UUID, items []pendinguploads.PutItem) ([]uuid.UUID, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.putErr != nil {
|
||||
return nil, s.putErr
|
||||
}
|
||||
// Pre-validate so an oversized item rejects the whole batch before
|
||||
// any state mutation — matches the Tx-rollback semantics.
|
||||
for _, it := range items {
|
||||
if len(it.Content) > pendinguploads.MaxFileBytes {
|
||||
return nil, pendinguploads.ErrTooLarge
|
||||
}
|
||||
}
|
||||
ids := make([]uuid.UUID, 0, len(items))
|
||||
stagedRows := make(map[uuid.UUID]pendinguploads.Record, len(items))
|
||||
stagedPuts := make([]putCall, 0, len(items))
|
||||
for _, it := range items {
|
||||
id := uuid.New()
|
||||
stagedRows[id] = pendinguploads.Record{
|
||||
FileID: id, WorkspaceID: ws, Content: it.Content,
|
||||
Filename: it.Filename, Mimetype: it.Mimetype,
|
||||
SizeBytes: int64(len(it.Content)), CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
stagedPuts = append(stagedPuts, putCall{
|
||||
WorkspaceID: ws, Filename: it.Filename, Mimetype: it.Mimetype, Size: len(it.Content),
|
||||
})
|
||||
ids = append(ids, id)
|
||||
}
|
||||
for id, r := range stagedRows {
|
||||
s.rows[id] = r
|
||||
}
|
||||
s.puts = append(s.puts, stagedPuts...)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (s *inMemStorage) Get(context.Context, uuid.UUID) (pendinguploads.Record, error) {
|
||||
return pendinguploads.Record{}, pendinguploads.ErrNotFound
|
||||
}
|
||||
func (s *inMemStorage) MarkFetched(context.Context, uuid.UUID) error { return nil }
|
||||
func (s *inMemStorage) Ack(context.Context, uuid.UUID) error { return nil }
|
||||
|
||||
// Sweep is required by the Storage interface (Phase 3 GC). Not
|
||||
// exercised by upload-branch tests — the dedicated sweeper_test.go +
|
||||
// storage_sweep_test.go cover it.
|
||||
func (s *inMemStorage) Sweep(context.Context, time.Duration) (pendinguploads.SweepResult, error) {
|
||||
return pendinguploads.SweepResult{}, nil
|
||||
}
|
||||
|
||||
// expectPollDeliveryMode stubs the SELECT delivery_mode lookup that
|
||||
// uploadPollMode does (separate from the one resolveWorkspaceForwardCreds
|
||||
// does — this is the new helper introduced for the poll branch).
|
||||
func expectPollDeliveryMode(mock sqlmock.Sqlmock, workspaceID, mode string) {
|
||||
rows := sqlmock.NewRows([]string{"delivery_mode"}).AddRow(mode)
|
||||
mock.ExpectQuery(`SELECT delivery_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(workspaceID).
|
||||
WillReturnRows(rows)
|
||||
}
|
||||
|
||||
func expectPollDeliveryModeMissing(mock sqlmock.Sqlmock, workspaceID string) {
|
||||
mock.ExpectQuery(`SELECT delivery_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(workspaceID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
}
|
||||
|
||||
// expectActivityInsert stubs the LogActivity INSERT so the poll branch's
|
||||
// per-file activity row write doesn't fail the sqlmock expectations.
|
||||
func expectActivityInsert(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
|
||||
// expectActivityInsertWithTypeAndMethod is a strict variant that pins
|
||||
// the activity_type and method positional args. Used in the discriminator
|
||||
// regression test below — the workspace inbox poller filters
|
||||
// `?type=a2a_receive`, so writing any other activity_type silently breaks
|
||||
// poll-mode delivery without a build/test error. Pin the two discriminator
|
||||
// fields so a refactor that flips activity_type back to a custom value is
|
||||
// caught here instead of at runtime by a confused poller.
|
||||
//
|
||||
// Positional args (LogActivity uses ExecContext with 12 positional params):
|
||||
// $1 workspace_id, $2 activity_type, $3 source_id, $4 target_id,
|
||||
// $5 method, $6 summary, $7 request_body, $8 response_body,
|
||||
// $9 tool_trace, $10 duration_ms, $11 status, $12 error_detail.
|
||||
func expectActivityInsertWithTypeAndMethod(mock sqlmock.Sqlmock, workspaceID, activityType, method string) {
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WithArgs(
|
||||
workspaceID, // $1 workspace_id
|
||||
activityType, // $2 activity_type ← pinned
|
||||
sqlmock.AnyArg(), // $3 source_id
|
||||
sqlmock.AnyArg(), // $4 target_id (workspaceID, but already covered)
|
||||
method, // $5 method ← pinned
|
||||
sqlmock.AnyArg(), // $6 summary
|
||||
sqlmock.AnyArg(), // $7 request_body
|
||||
sqlmock.AnyArg(), // $8 response_body
|
||||
sqlmock.AnyArg(), // $9 tool_trace
|
||||
sqlmock.AnyArg(), // $10 duration_ms
|
||||
sqlmock.AnyArg(), // $11 status
|
||||
sqlmock.AnyArg(), // $12 error_detail
|
||||
).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
|
||||
// pollUploadFixture builds a multipart body with N named files.
|
||||
func pollUploadFixture(t *testing.T, files map[string][]byte) (*bytes.Buffer, string) {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
for name, data := range files {
|
||||
fw, err := mw.CreateFormFile("files", name)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateFormFile: %v", err)
|
||||
}
|
||||
_, _ = fw.Write(data)
|
||||
}
|
||||
mw.Close()
|
||||
return &buf, mw.FormDataContentType()
|
||||
}
|
||||
|
||||
// ---- happy path ----
|
||||
|
||||
func TestPollUpload_HappyPath_OneFile_StagesAndLogs(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "11111111-2222-3333-4444-555555555555"
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
expectActivityInsert(mock)
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"report.pdf": []byte("PDF-bytes")})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
if len(store.puts) != 1 {
|
||||
t.Fatalf("expected 1 storage Put, got %d", len(store.puts))
|
||||
}
|
||||
put := store.puts[0]
|
||||
if put.Filename != "report.pdf" || put.Size != 9 {
|
||||
t.Errorf("unexpected put: %+v", put)
|
||||
}
|
||||
|
||||
// Response shape must match the workspace-side
|
||||
// /internal/chat/uploads/ingest schema so canvas can't tell which
|
||||
// path handled the upload.
|
||||
var resp struct {
|
||||
Files []struct {
|
||||
URI string `json:"uri"`
|
||||
Name string `json:"name"`
|
||||
Mimetype string `json:"mimeType"`
|
||||
Size int `json:"size"`
|
||||
} `json:"files"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("decode response: %v body=%s", err, w.Body.String())
|
||||
}
|
||||
if len(resp.Files) != 1 {
|
||||
t.Fatalf("response files count = %d, want 1", len(resp.Files))
|
||||
}
|
||||
got := resp.Files[0]
|
||||
if got.Name != "report.pdf" || got.Size != 9 {
|
||||
t.Errorf("response file mismatch: %+v", got)
|
||||
}
|
||||
if !strings.HasPrefix(got.URI, "platform-pending:"+wsID+"/") {
|
||||
t.Errorf("URI %q does not start with platform-pending:%s/", got.URI, wsID)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPollUpload_MultipleFiles_AllStagedAndLogged(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "11111111-aaaa-bbbb-cccc-555555555555"
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
expectActivityInsert(mock)
|
||||
expectActivityInsert(mock)
|
||||
expectActivityInsert(mock)
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{
|
||||
"a.txt": []byte("aaaa"),
|
||||
"b.txt": []byte("bbbbb"),
|
||||
"c.txt": []byte("cccccc"),
|
||||
})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
if len(store.puts) != 3 {
|
||||
t.Fatalf("expected 3 storage Puts, got %d", len(store.puts))
|
||||
}
|
||||
}
|
||||
|
||||
// ---- regression: push-mode unchanged ----
|
||||
|
||||
func TestPollUpload_PushModeFallsThroughToForward(t *testing.T) {
|
||||
// With pendingUploads wired but the workspace's mode is push,
|
||||
// the poll branch must NOT activate — flow falls through to the
|
||||
// existing resolveWorkspaceForwardCreds path. Pinned via the
|
||||
// "delivery_mode lookup happened, then the URL+mode SELECT
|
||||
// happened, then we 503 because no inbound secret" sequence.
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "22222222-2222-3333-4444-555555555555"
|
||||
expectPollDeliveryMode(mock, wsID, "push")
|
||||
// After the poll branch is bypassed, we hit
|
||||
// resolveWorkspaceForwardCreds which selects url+delivery_mode.
|
||||
expectURL(mock, wsID, "")
|
||||
// URL empty + mode=push → 503 (no inbound secret check needed).
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x": []byte("data")})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("status=%d body=%s — expected push-mode 503 fall-through", w.Code, w.Body.String())
|
||||
}
|
||||
if len(store.puts) != 0 {
|
||||
t.Errorf("push-mode should NOT have hit storage, got %d puts", len(store.puts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPollUpload_NotConfigured_FallsThrough(t *testing.T) {
|
||||
// Backwards compat: a binary running without WithPendingUploads
|
||||
// behaves exactly as before — the poll branch is dead code.
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "33333333-2222-3333-4444-555555555555"
|
||||
expectURLAndMode(mock, wsID, "", "poll") // resolveWorkspaceForwardCreds emits 422
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
// No WithPendingUploads — pendingUploads is nil.
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x": []byte("data")})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusUnprocessableEntity {
|
||||
t.Errorf("status=%d, want 422 (legacy poll-mode rejection)", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- error paths ----
|
||||
|
||||
func TestPollUpload_WorkspaceMissing_404(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "44444444-2222-3333-4444-555555555555"
|
||||
expectPollDeliveryModeMissing(mock, wsID)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(newInMemStorage(), nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x": []byte("d")})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status=%d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPollUpload_DeliveryModeLookupDBError_500(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "55555555-2222-3333-4444-555555555555"
|
||||
mock.ExpectQuery(`SELECT delivery_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).WillReturnError(errors.New("connection lost"))
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(newInMemStorage(), nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x": []byte("d")})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status=%d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPollUpload_NoFilesField_400(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "66666666-2222-3333-4444-555555555555"
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
// Multipart with a non-files field — no actual files.
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
mw.WriteField("not_files", "hi")
|
||||
mw.Close()
|
||||
|
||||
c, w := makeUploadRequest(t, wsID, &buf, mw.FormDataContentType())
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status=%d, want 400 on no files field", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPollUpload_MalformedMultipart_400(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "77777777-2222-3333-4444-555555555555"
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
// Body that doesn't match the boundary in Content-Type.
|
||||
c, w := makeUploadRequest(t, wsID, bytes.NewBufferString("garbage"), "multipart/form-data; boundary=fake")
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status=%d, want 400 on malformed multipart", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPollUpload_StorageError_500(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "88888888-2222-3333-4444-555555555555"
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
store.putErr = errors.New("disk full")
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x.bin": []byte("data")})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status=%d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPollUpload_StorageTooLarge_413(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "99999999-2222-3333-4444-555555555555"
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
store.putErr = pendinguploads.ErrTooLarge
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x.bin": []byte("data")})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Errorf("status=%d, want 413", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPollUpload_TooManyFiles_400(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "aaaaaaaa-2222-3333-4444-555555555555"
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
// 65 files — over the per-batch cap.
|
||||
files := map[string][]byte{}
|
||||
for i := 0; i < 65; i++ {
|
||||
files[uuid.New().String()] = []byte("x")
|
||||
}
|
||||
body, ct := pollUploadFixture(t, files)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status=%d, want 400 on too many files", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPollUpload_NullDeliveryMode_TreatedAsPush(t *testing.T) {
|
||||
// Production-observed 2026-05-04: external runtime workspaces
|
||||
// (molecule-sdk-python on user infra) sometimes register with
|
||||
// delivery_mode = NULL — the schema default for legacy rows from
|
||||
// before #2339. The poll branch must NOT activate on NULL — only
|
||||
// the explicit "poll" string. This is the same defensive posture
|
||||
// resolveWorkspaceForwardCreds takes for legacy rows.
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "cccccccc-2222-3333-4444-555555555555"
|
||||
mock.ExpectQuery(`SELECT delivery_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"delivery_mode"}).AddRow(nil))
|
||||
// Falls through to resolveWorkspaceForwardCreds:
|
||||
expectURLAndMode(mock, wsID, "", "")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x.bin": []byte("data")})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
// resolveWorkspaceForwardCreds with empty url + NULL mode = 422
|
||||
// (the legacy "no callback URL" rejection — exactly what we're
|
||||
// fixing for ACTUAL poll-mode rows but want to preserve for
|
||||
// NULL ones until the row gets a real mode value via the next
|
||||
// /registry/register).
|
||||
if w.Code != http.StatusUnprocessableEntity {
|
||||
t.Errorf("status=%d, want 422 for NULL delivery_mode (legacy fallthrough)", w.Code)
|
||||
}
|
||||
if len(store.puts) != 0 {
|
||||
t.Errorf("NULL mode should NOT have hit storage, got %d puts", len(store.puts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPollUpload_PerFileCapPreStorage_413(t *testing.T) {
|
||||
// Pin the early-reject branch (fh.Size > MaxFileBytes) BEFORE we
|
||||
// read the part into memory. Without this, an oversize file
|
||||
// would hit the storage layer's belt-and-suspenders check, which
|
||||
// works but burns ~25 MB of memory + DB round-trip first. Send
|
||||
// 25 MB + 1 byte → 413 with the file size in the response.
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "dddddddd-2222-3333-4444-555555555555"
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
// 25 MB + 1 byte. Single file, large enough to trip the early
|
||||
// size check.
|
||||
oversize := make([]byte, pendinguploads.MaxFileBytes+1)
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"big.bin": oversize})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Fatalf("status=%d, want 413 on per-file size cap", w.Code)
|
||||
}
|
||||
if len(store.puts) != 0 {
|
||||
t.Errorf("per-file cap reject should NOT have called storage.Put, got %d puts", len(store.puts))
|
||||
}
|
||||
// Sanity: response carries the size we tried to upload + the cap.
|
||||
var body_ map[string]any
|
||||
json.Unmarshal(w.Body.Bytes(), &body_)
|
||||
if got := body_["max"]; got == nil {
|
||||
t.Errorf("expected max field in response, got %v", body_)
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeFilename is exercised in the upload chain — pin one
|
||||
// end-to-end case that exercises the URI path through the response.
|
||||
func TestPollUpload_SanitizesFilenameInResponse(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "bbbbbbbb-2222-3333-4444-555555555555"
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
expectActivityInsert(mock)
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"hello world!.pdf": []byte("data")})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp struct {
|
||||
Files []struct {
|
||||
Name string `json:"name"`
|
||||
URI string `json:"uri"`
|
||||
}
|
||||
}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if len(resp.Files) == 0 || resp.Files[0].Name != "hello_world_.pdf" {
|
||||
t.Errorf("expected sanitized name 'hello_world_.pdf', got: %+v", resp.Files)
|
||||
}
|
||||
if len(store.puts) == 0 || store.puts[0].Filename != "hello_world_.pdf" {
|
||||
t.Errorf("storage Put didn't receive sanitized filename: %+v", store.puts)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPollUpload_AtomicRollbackOnSecondFileTooLarge pins the
|
||||
// transactional contract introduced in phase 5: when one file in a
|
||||
// multi-file batch fails pre-validation (oversize), NONE of the files
|
||||
// in the batch land in storage. Previously a per-file Put loop would
|
||||
// stage rows 1..K-1 before failing on row K, leaving orphan
|
||||
// pending_uploads + activity rows the client would re-create on retry.
|
||||
//
|
||||
// Pinned via inMemStorage's PutBatch (which mirrors PostgresStorage's
|
||||
// Tx-rollback behavior on a per-item validation failure) — but the
|
||||
// real atomicity guarantee is the integration test in
|
||||
// pending_uploads_integration_test.go.
|
||||
func TestPollUpload_AtomicRollbackOnSecondFileTooLarge(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "aaaaaaaa-3333-3333-4444-555555555555"
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
// Two files: first OK, second over the per-file cap. Pre-validation
|
||||
// in uploadPollMode catches it BEFORE any Put — store.puts must
|
||||
// stay empty. (If the test ever sees len=1, the regression is
|
||||
// "first file slipped through into storage on a partial-failure
|
||||
// batch.")
|
||||
tooBig := bytes.Repeat([]byte{0x42}, pendinguploads.MaxFileBytes+1)
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{
|
||||
"ok.txt": []byte("small"),
|
||||
"huge.bin": tooBig,
|
||||
})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Errorf("status=%d body=%s, want 413", w.Code, w.Body.String())
|
||||
}
|
||||
if len(store.puts) != 0 {
|
||||
t.Errorf("expected zero Puts on rollback, got %d: %+v", len(store.puts), store.puts)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPollUpload_AtomicRollbackOnPutBatchError validates that an in-
|
||||
// flight PutBatch failure (e.g. simulated DB error) leaves zero rows
|
||||
// — same guarantee as the pre-validation path, but exercises the
|
||||
// "Tx-Rollback after BEGIN" branch via the fake.
|
||||
func TestPollUpload_AtomicRollbackOnPutBatchError(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "bbbbbbbb-3333-3333-4444-555555555555"
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
store.putErr = errors.New("db down mid-batch")
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{
|
||||
"a.txt": []byte("aaa"),
|
||||
"b.txt": []byte("bbb"),
|
||||
"c.txt": []byte("ccc"),
|
||||
})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status=%d, want 500", w.Code)
|
||||
}
|
||||
if len(store.puts) != 0 {
|
||||
t.Errorf("expected zero Puts after PutBatch error, got %d", len(store.puts))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPollUpload_MimetypeWithCRLFInjectionStripped pins the safeMimetype
|
||||
// hardening: a multipart-supplied Content-Type header with CR/LF is
|
||||
// rewritten to application/octet-stream so the eventual /content
|
||||
// response can't be header-split on the wire.
|
||||
func TestPollUpload_MimetypeWithCRLFInjectionStripped(t *testing.T) {
|
||||
got := safeMimetype("text/html\r\nX-Injected: pwn")
|
||||
if got != "application/octet-stream" {
|
||||
t.Errorf("CRLF mimetype not stripped, got %q", got)
|
||||
}
|
||||
got = safeMimetype("image/png\x00")
|
||||
if got != "application/octet-stream" {
|
||||
t.Errorf("NUL byte mimetype not stripped, got %q", got)
|
||||
}
|
||||
got = safeMimetype("text/plain; charset=utf-8")
|
||||
if got != "text/plain" {
|
||||
t.Errorf("parameter not stripped, got %q", got)
|
||||
}
|
||||
got = safeMimetype("application/pdf")
|
||||
if got != "application/pdf" {
|
||||
t.Errorf("clean mime modified, got %q", got)
|
||||
}
|
||||
got = safeMimetype("")
|
||||
if got != "" {
|
||||
t.Errorf("empty input should pass through, got %q", got)
|
||||
}
|
||||
got = safeMimetype("notamime")
|
||||
if got != "application/octet-stream" {
|
||||
t.Errorf("non-type/subtype not coerced, got %q", got)
|
||||
}
|
||||
got = safeMimetype("/empty-type")
|
||||
if got != "application/octet-stream" {
|
||||
t.Errorf("missing type half not coerced, got %q", got)
|
||||
}
|
||||
got = safeMimetype("type/")
|
||||
if got != "application/octet-stream" {
|
||||
t.Errorf("missing subtype half not coerced, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPollUpload_ActivityRowDiscriminator pins the
|
||||
// activity_type / method shape that the workspace inbox poller depends
|
||||
// on. The poller filters `GET /workspaces/:id/activity?type=a2a_receive`
|
||||
// so the handler MUST write activity_type=a2a_receive (NOT a custom
|
||||
// type), and use method=chat_upload_receive as the
|
||||
// upload-vs-message-vs-task discriminator.
|
||||
//
|
||||
// Why pinned: a previous iteration of this handler used
|
||||
// activity_type="chat_upload_receive" — silently invisible to the
|
||||
// existing poller. The branch passed every push-mode test, every
|
||||
// storage test, and every per-file content test; the bug only
|
||||
// surfaced at runtime when the workspace polled and got nothing.
|
||||
// Encode the contract in a unit test so the next refactor can't
|
||||
// re-break it without a red CI.
|
||||
func TestPollUpload_ActivityRowDiscriminator(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "abc12345-6789-4abc-8def-000000000999"
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
expectActivityInsertWithTypeAndMethod(mock, wsID, "a2a_receive", "chat_upload_receive")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x.pdf": []byte("xx")})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -105,7 +105,7 @@ func TestChatUpload_InvalidWorkspaceID(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
|
||||
c, w := makeUploadRequest(t, "not-a-uuid", &bytes.Buffer{}, "")
|
||||
h.Upload(c)
|
||||
@@ -122,7 +122,7 @@ func TestChatUpload_WorkspaceNotInDB(t *testing.T) {
|
||||
wsID := "00000000-0000-0000-0000-000000000099"
|
||||
expectURLMissing(mock, wsID)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -166,7 +166,7 @@ func TestChatUpload_NoInboundSecret_LazyHeal(t *testing.T) {
|
||||
WithArgs(sqlmock.AnyArg(), wsID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -203,7 +203,7 @@ func TestChatUpload_NoInboundSecret_LazyHealFailure(t *testing.T) {
|
||||
WithArgs(sqlmock.AnyArg(), wsID).
|
||||
WillReturnError(sql.ErrConnDone) // mint fails
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -231,7 +231,7 @@ func TestChatUpload_NoURL(t *testing.T) {
|
||||
wsID := "00000000-0000-0000-0000-000000000042"
|
||||
expectURLAndMode(mock, wsID, "", "push")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -256,7 +256,7 @@ func TestChatUpload_PollModeEmptyURL(t *testing.T) {
|
||||
wsID := "00000000-0000-0000-0000-000000000099"
|
||||
expectURLAndMode(mock, wsID, "", "poll")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -286,7 +286,7 @@ func TestChatUpload_NullModeEmptyURL(t *testing.T) {
|
||||
wsID := "30ba7f0b-b303-4a20-aefe-3a4a675b8aa4" // user's "mac laptop"
|
||||
expectURLNullMode(mock, wsID, "")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -338,7 +338,7 @@ func TestChatUpload_ForwardsToWorkspace_HappyPath(t *testing.T) {
|
||||
expectURL(mock, wsID, srv.URL)
|
||||
expectInboundSecret(mock, wsID, "super-secret-123")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -380,7 +380,7 @@ func TestChatUpload_ForwardsErrorStatusUnchanged(t *testing.T) {
|
||||
expectURL(mock, wsID, srv.URL)
|
||||
expectInboundSecret(mock, wsID, "tok")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -402,7 +402,7 @@ func TestChatUpload_WorkspaceUnreachable(t *testing.T) {
|
||||
expectURL(mock, wsID, "http://127.0.0.1:1")
|
||||
expectInboundSecret(mock, wsID, "tok")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -418,7 +418,7 @@ func TestChatDownload_InvalidPath(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
|
||||
cases := []struct {
|
||||
name, path, wantSubstr string
|
||||
@@ -507,7 +507,7 @@ func TestChatDownload_WorkspaceNotInDB(t *testing.T) {
|
||||
WithArgs(wsID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
c, w := makeDownloadRequest(t, wsID, "/workspace/foo.txt")
|
||||
h.Download(c)
|
||||
|
||||
@@ -533,7 +533,7 @@ func TestChatDownload_NoInboundSecret_LazyHeal(t *testing.T) {
|
||||
WithArgs(sqlmock.AnyArg(), wsID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
c, w := makeDownloadRequest(t, wsID, "/workspace/foo.txt")
|
||||
h.Download(c)
|
||||
|
||||
@@ -559,7 +559,7 @@ func TestChatDownload_NoInboundSecret_LazyHealFailure(t *testing.T) {
|
||||
WithArgs(sqlmock.AnyArg(), wsID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
c, w := makeDownloadRequest(t, wsID, "/workspace/foo.txt")
|
||||
h.Download(c)
|
||||
|
||||
@@ -592,7 +592,7 @@ func TestChatDownload_ForwardsToWorkspace_HappyPath(t *testing.T) {
|
||||
expectURL(mock, wsID, srv.URL)
|
||||
expectInboundSecret(mock, wsID, "the-secret")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
c, w := makeDownloadRequest(t, wsID, "/workspace/report.txt")
|
||||
h.Download(c)
|
||||
|
||||
@@ -634,7 +634,7 @@ func TestChatDownload_404FromWorkspacePropagated(t *testing.T) {
|
||||
expectURL(mock, wsID, srv.URL)
|
||||
expectInboundSecret(mock, wsID, "tok")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
c, w := makeDownloadRequest(t, wsID, "/workspace/missing.txt")
|
||||
h.Download(c)
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provlog"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/scheduler"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -51,18 +52,31 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
model = defaults.Model
|
||||
}
|
||||
if model == "" {
|
||||
if runtime == "claude-code" {
|
||||
model = "sonnet"
|
||||
} else {
|
||||
model = "anthropic:claude-opus-4-7"
|
||||
}
|
||||
// SSOT: per-runtime defaults live in models/runtime_defaults.go
|
||||
// (see RFC #2873). Consolidated from a duplicate of the same
|
||||
// branch in workspace_provision.go.
|
||||
model = models.DefaultModel(runtime)
|
||||
}
|
||||
tier := ws.Tier
|
||||
if tier == 0 {
|
||||
tier = defaults.Tier
|
||||
}
|
||||
if tier == 0 {
|
||||
tier = 2
|
||||
// Resolved via the same DefaultTier helper Create + Templates
|
||||
// use (#2910 PR-E). SaaS → T4 (one container per sibling EC2,
|
||||
// no neighbour to protect from), self-hosted → T3. Pre-#2910
|
||||
// this path returned T2 on self-hosted, asymmetric with
|
||||
// workspace.go's T3 — undocumented drift. Lifting to
|
||||
// DefaultTier collapses both call sites onto one source of
|
||||
// truth so a future tier-default change sweeps every entry
|
||||
// point at once. Templates that want a different floor still
|
||||
// declare `tier:` in config.yaml or `defaults.tier` in
|
||||
// org.yaml.
|
||||
if h.workspace != nil {
|
||||
tier = h.workspace.DefaultTier()
|
||||
} else {
|
||||
tier = 3
|
||||
}
|
||||
}
|
||||
|
||||
ctxLookup := context.Background()
|
||||
@@ -83,6 +97,16 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
}
|
||||
if existing {
|
||||
log.Printf("Org import: %q already exists (id=%s) — skipping create+provision, recursing into children for partial-match", ws.Name, existingID)
|
||||
parentRef := ""
|
||||
if parentID != nil {
|
||||
parentRef = *parentID
|
||||
}
|
||||
provlog.Event("provision.skip_existing", map[string]any{
|
||||
"name": ws.Name,
|
||||
"existing_id": existingID,
|
||||
"parent_id": parentRef,
|
||||
"tier": tier,
|
||||
})
|
||||
*results = append(*results, map[string]interface{}{
|
||||
"id": existingID,
|
||||
"name": ws.Name,
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -119,6 +123,60 @@ func TestLookupExistingChild_DBError_Propagates(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// workspacesInsertRE matches a SQL literal that begins (after optional
|
||||
// leading whitespace) with `INSERT INTO workspaces` followed by `(` —
|
||||
// requiring the open-paren rules out lookalikes like
|
||||
// `INSERT INTO workspaces_audit`, `INSERT INTO workspace_secrets`,
|
||||
// `INSERT INTO workspace_channels`, `INSERT INTO canvas_layouts`. The
|
||||
// previous bytes.Index gate accepted `workspaces_audit` as a prefix
|
||||
// match — see RFC #2872 Important-1 for the silent-false-pass shape.
|
||||
var workspacesInsertRE = regexp.MustCompile(`(?s)^\s*INSERT\s+INTO\s+workspaces\s*\(`)
|
||||
|
||||
// findLookupAndWorkspacesInsertPos walks the AST of `src` and returns
|
||||
// the source positions of (a) the first call to `lookupExistingChild`
|
||||
// and (b) the first CallExpr whose argument list contains a STRING
|
||||
// BasicLit matching workspacesInsertRE. Either may be token.NoPos if
|
||||
// not found.
|
||||
//
|
||||
// Extracted as a helper so the gate logic can be exercised against
|
||||
// synthetic source — TestGate_FailsWhenLookupAfterInsert below proves
|
||||
// the gate actually catches the bug shape, not just the happy path.
|
||||
func findLookupAndWorkspacesInsertPos(t *testing.T, fname string, src []byte) (lookupPos, insertPos token.Pos, fset *token.FileSet) {
|
||||
t.Helper()
|
||||
fset = token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, fname, src, parser.ParseComments)
|
||||
if err != nil {
|
||||
t.Fatalf("parse %s: %v", fname, err)
|
||||
}
|
||||
lookupPos, insertPos = token.NoPos, token.NoPos
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
call, ok := n.(*ast.CallExpr)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
if sel, ok := call.Fun.(*ast.SelectorExpr); ok {
|
||||
if sel.Sel.Name == "lookupExistingChild" && lookupPos == token.NoPos {
|
||||
lookupPos = call.Pos()
|
||||
}
|
||||
}
|
||||
for _, arg := range call.Args {
|
||||
lit, ok := arg.(*ast.BasicLit)
|
||||
if !ok || lit.Kind != token.STRING {
|
||||
continue
|
||||
}
|
||||
raw := lit.Value
|
||||
if unq, err := strconv.Unquote(raw); err == nil {
|
||||
raw = unq
|
||||
}
|
||||
if workspacesInsertRE.MatchString(raw) && insertPos == token.NoPos {
|
||||
insertPos = call.Pos()
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Source-level guard — pins that org_import.go calls
|
||||
// h.lookupExistingChild BEFORE its INSERT INTO workspaces.
|
||||
//
|
||||
@@ -126,6 +184,11 @@ func TestLookupExistingChild_DBError_Propagates(t *testing.T) {
|
||||
// (idempotency check before INSERT), not just function names. If a
|
||||
// future refactor reintroduces the un-checked INSERT (the original
|
||||
// bug shape that leaked 72 workspaces in 4 days), this test fails.
|
||||
//
|
||||
// AST-walk implementation closes the silent-false-pass mode that the
|
||||
// previous bytes.Index gate had — see workspacesInsertRE comment for
|
||||
// the failure mode (workspaces_audit / workspace_secrets / etc.
|
||||
// shadowing the real target via prefix match).
|
||||
func TestCreateWorkspaceTree_CallsLookupBeforeInsert(t *testing.T) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
@@ -135,17 +198,189 @@ func TestCreateWorkspaceTree_CallsLookupBeforeInsert(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("read org_import.go: %v", err)
|
||||
}
|
||||
lookupPos, insertPos, fset := findLookupAndWorkspacesInsertPos(t, "org_import.go", src)
|
||||
|
||||
lookupAt := bytes.Index(src, []byte("h.lookupExistingChild("))
|
||||
insertAt := bytes.Index(src, []byte("INSERT INTO workspaces"))
|
||||
|
||||
if lookupAt < 0 {
|
||||
t.Fatalf("org_import.go missing call to h.lookupExistingChild — idempotency check removed?")
|
||||
if lookupPos == token.NoPos {
|
||||
t.Fatalf("AST: no call to lookupExistingChild in org_import.go — idempotency check removed?")
|
||||
}
|
||||
if insertAt < 0 {
|
||||
t.Fatalf("org_import.go missing INSERT INTO workspaces — schema change?")
|
||||
if insertPos == token.NoPos {
|
||||
t.Fatalf("AST: no SQL literal matching `^\\s*INSERT INTO workspaces\\s*\\(` in any CallExpr in org_import.go — schema change or rename?")
|
||||
}
|
||||
if lookupAt > insertAt {
|
||||
t.Errorf("h.lookupExistingChild must come BEFORE INSERT INTO workspaces in org_import.go (lookup@%d, insert@%d) — non-idempotent ordering would re-leak under repeat /org/import calls", lookupAt, insertAt)
|
||||
if lookupPos > insertPos {
|
||||
t.Errorf("lookupExistingChild call at %s must come BEFORE INSERT INTO workspaces at %s — non-idempotent ordering would re-leak under repeat /org/import calls",
|
||||
fset.Position(lookupPos), fset.Position(insertPos))
|
||||
}
|
||||
}
|
||||
|
||||
// TestGate_FailsWhenLookupAfterInsert proves the gate actually catches
|
||||
// the bug it's named after — running it against synthetic Go source
|
||||
// where the lookup call is positioned AFTER the workspaces INSERT must
|
||||
// produce lookupPos > insertPos, which the production gate flags as
|
||||
// an ERROR. Without this test the gate could regress to "always pass"
|
||||
// and we wouldn't notice until the bug shipped again.
|
||||
//
|
||||
// Per memory feedback_assert_exact_not_substring.md: verify a
|
||||
// tightened test FAILS on old code before merging.
|
||||
func TestGate_FailsWhenLookupAfterInsert(t *testing.T) {
|
||||
const buggySrc = `package handlers
|
||||
|
||||
import "context"
|
||||
|
||||
type fakeDB struct{}
|
||||
|
||||
func (fakeDB) ExecContext(ctx context.Context, sql string, args ...interface{}) {}
|
||||
|
||||
type fakeOrgHandler struct{}
|
||||
|
||||
func (h *fakeOrgHandler) lookupExistingChild(ctx context.Context, name string, parentID *string) (string, bool, error) {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
func buggyCreate(h *fakeOrgHandler, db fakeDB, ctx context.Context, name string, parentID *string) {
|
||||
// Bug shape: INSERT runs FIRST, lookup runs AFTER. This is the
|
||||
// non-idempotent ordering the gate exists to forbid.
|
||||
db.ExecContext(ctx, ` + "`INSERT INTO workspaces (id, name) VALUES ($1, $2)`" + `, "x", name)
|
||||
h.lookupExistingChild(ctx, name, parentID)
|
||||
}
|
||||
`
|
||||
lookupPos, insertPos, _ := findLookupAndWorkspacesInsertPos(t, "buggy.go", []byte(buggySrc))
|
||||
if lookupPos == token.NoPos || insertPos == token.NoPos {
|
||||
t.Fatalf("synthetic buggy source missing expected nodes (lookupPos=%v insertPos=%v) — helper logic regression", lookupPos, insertPos)
|
||||
}
|
||||
if lookupPos < insertPos {
|
||||
t.Fatalf("synthetic bug shape (lookup AFTER insert) returned lookupPos=%d < insertPos=%d — gate would NOT fire on actual bug, regression!", lookupPos, insertPos)
|
||||
}
|
||||
// Implicit: lookupPos > insertPos here, which the production gate
|
||||
// flags via t.Errorf. This proves the gate is live, not vestigial.
|
||||
}
|
||||
|
||||
// TestGate_IgnoresAuditTableShadow proves the regex tightening
|
||||
// actually ignores `INSERT INTO workspaces_audit` literals — the
|
||||
// specific shape #2872 cited as the silent-false-pass failure mode
|
||||
// for the previous bytes.Index gate.
|
||||
func TestGate_IgnoresAuditTableShadow(t *testing.T) {
|
||||
// Synthetic source with audit-table INSERT at line 1 (would be
|
||||
// position 0 under prefix-match) and lookup + real INSERT at later
|
||||
// positions. With the tightened regex, the audit literal is
|
||||
// ignored: insertPos points at the REAL INSERT, lookup precedes it,
|
||||
// gate passes correctly.
|
||||
const src = `package handlers
|
||||
|
||||
import "context"
|
||||
|
||||
type fakeDB struct{}
|
||||
|
||||
func (fakeDB) ExecContext(ctx context.Context, sql string, args ...interface{}) {}
|
||||
|
||||
type fakeOrgHandler struct{}
|
||||
|
||||
func (h *fakeOrgHandler) lookupExistingChild(ctx context.Context, name string, parentID *string) (string, bool, error) {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
func okCreateWithAudit(h *fakeOrgHandler, db fakeDB, ctx context.Context, name string, parentID *string) {
|
||||
// Audit-table INSERT — should be IGNORED by the tightened regex.
|
||||
db.ExecContext(ctx, ` + "`INSERT INTO workspaces_audit (id, action) VALUES ($1, $2)`" + `, "x", "create_attempt")
|
||||
// Lookup BEFORE real INSERT — correct order.
|
||||
h.lookupExistingChild(ctx, name, parentID)
|
||||
// Real INSERT.
|
||||
db.ExecContext(ctx, ` + "`INSERT INTO workspaces (id, name) VALUES ($1, $2)`" + `, "x", name)
|
||||
}
|
||||
`
|
||||
lookupPos, insertPos, fset := findLookupAndWorkspacesInsertPos(t, "shadow.go", []byte(src))
|
||||
if lookupPos == token.NoPos || insertPos == token.NoPos {
|
||||
t.Fatalf("expected to find lookup + real INSERT, got lookupPos=%v insertPos=%v", lookupPos, insertPos)
|
||||
}
|
||||
// The audit-table INSERT is at line ~16 (column ~20-ish), the
|
||||
// lookup is at line 19, the real INSERT is at line 21. If the
|
||||
// regex regressed to prefix-match, insertPos would point at the
|
||||
// audit literal at line 16, and the gate would falsely fail
|
||||
// (lookup at 19 > "insert" at 16). With the tightened regex,
|
||||
// insertPos correctly points at line 21, and the gate passes.
|
||||
insertLine := fset.Position(insertPos).Line
|
||||
lookupLine := fset.Position(lookupPos).Line
|
||||
if insertLine < lookupLine {
|
||||
t.Errorf("regex regressed: audit shadow at line %d swallowed real INSERT (lookup at line %d). insertPos should point at the real INSERT (line ~21), not the audit literal.",
|
||||
insertLine, lookupLine)
|
||||
}
|
||||
if lookupPos > insertPos {
|
||||
t.Errorf("synthetic source has lookup at line %d before real INSERT at line %d, gate should pass (lookupPos < insertPos), got lookupPos=%d > insertPos=%d",
|
||||
lookupLine, insertLine, lookupPos, insertPos)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWorkspacesInsertRE_RejectsLookalikes pins the regex that
|
||||
// discriminates the real workspaces INSERT from prefix-matching
|
||||
// lookalikes. If this regex regresses to a substring match, the
|
||||
// AST gate above silently false-passes when a future refactor
|
||||
// shadows the real INSERT with a workspaces_audit / workspace_secrets
|
||||
// / canvas_layouts literal placed earlier in source.
|
||||
func TestWorkspacesInsertRE_RejectsLookalikes(t *testing.T) {
|
||||
cases := []struct {
|
||||
sql string
|
||||
want bool
|
||||
comment string
|
||||
}{
|
||||
{"INSERT INTO workspaces (id, name) VALUES ($1, $2)", true, "real target"},
|
||||
{"\n\t\tINSERT INTO workspaces (id, name)\n\t\tVALUES ($1, $2)", true, "real target with leading whitespace + newlines (raw string literal shape)"},
|
||||
{"INSERT INTO workspaces_audit (id) VALUES ($1)", false, "underscore-suffix lookalike (the #2872 specific failure mode)"},
|
||||
{"INSERT INTO workspace_secrets (key, value) VALUES ($1, $2)", false, "prefix without trailing 's' (workspace_*)"},
|
||||
{"INSERT INTO workspace_channels (id) VALUES ($1)", false, "another workspace_* prefix"},
|
||||
{"INSERT INTO canvas_layouts (workspace_id, x, y) VALUES ($1, $2, $3)", false, "unrelated table that contains 'workspace' in a column ref"},
|
||||
{"UPDATE workspaces SET status='running' WHERE id=$1", false, "UPDATE shouldn't match"},
|
||||
{"SELECT * FROM workspaces WHERE id=$1", false, "SELECT shouldn't match"},
|
||||
{"-- comment about INSERT INTO workspaces (\nSELECT 1", false, "comment shouldn't match"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := workspacesInsertRE.MatchString(c.sql)
|
||||
if got != c.want {
|
||||
t.Errorf("workspacesInsertRE.MatchString(%q) = %v, want %v (%s)", c.sql, got, c.want, c.comment)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Confirm the regex actually matches the literal currently in
|
||||
// org_import.go. Pins the shape so `gofmt` reflows or trivial edits
|
||||
// to the SQL string don't silently disable the gate above.
|
||||
func TestWorkspacesInsertRE_MatchesActualSourceLiteral(t *testing.T) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("getwd: %v", err)
|
||||
}
|
||||
src, err := os.ReadFile(filepath.Join(wd, "org_import.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("read org_import.go: %v", err)
|
||||
}
|
||||
// Strip backtick strings, find any whose content matches.
|
||||
// Walk the source via parser.ParseFile to avoid string-search
|
||||
// drift if the literal is reflowed.
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, filepath.Join(wd, "org_import.go"), src, parser.ParseComments)
|
||||
if err != nil {
|
||||
t.Fatalf("parse org_import.go: %v", err)
|
||||
}
|
||||
var matched bool
|
||||
ast.Inspect(file, func(n ast.Node) bool {
|
||||
lit, ok := n.(*ast.BasicLit)
|
||||
if !ok || lit.Kind != token.STRING {
|
||||
return true
|
||||
}
|
||||
raw := lit.Value
|
||||
if unq, err := strconv.Unquote(raw); err == nil {
|
||||
raw = unq
|
||||
}
|
||||
if workspacesInsertRE.MatchString(raw) {
|
||||
matched = true
|
||||
}
|
||||
return true
|
||||
})
|
||||
if !matched {
|
||||
t.Fatalf("no SQL literal in org_import.go matches workspacesInsertRE — gate is dead. Either the INSERT was renamed (update the regex) or the file was restructured (review the gate logic).")
|
||||
}
|
||||
// strings.Contains keeps the test informative: if the regex
|
||||
// stopped matching but the literal source still contains the
|
||||
// magic phrase, that's a regex-side failure (test the fix above).
|
||||
if !strings.Contains(string(src), "INSERT INTO workspaces") {
|
||||
t.Fatalf("org_import.go has no `INSERT INTO workspaces` substring at all — schema change?")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
// pending_uploads.go — endpoints the workspace polls to fetch and ack
|
||||
// chat-upload files staged on the platform side for poll-mode delivery.
|
||||
//
|
||||
// Companion to chat_files.go Upload's poll-mode branch:
|
||||
//
|
||||
// Canvas POST /workspaces/:id/chat/uploads
|
||||
// ↓ (poll-mode workspace)
|
||||
// Platform: chat_files.uploadPollMode
|
||||
// ↓ writes pending_uploads row + activity_logs(type=chat_upload_receive)
|
||||
// Workspace inbox poller picks up activity row
|
||||
// ↓
|
||||
// Workspace GETs /workspaces/:id/pending-uploads/:fid/content ← this file
|
||||
// ↓ writes file to /workspace/.molecule/chat-uploads
|
||||
// Workspace POSTs /workspaces/:id/pending-uploads/:fid/ack ← this file
|
||||
// ↓ row marked acked; Phase 3 sweep deletes
|
||||
//
|
||||
// Auth: same wsAuth middleware that gates the activity poll endpoint —
|
||||
// the workspace's per-workspace platform_token. Only the target workspace
|
||||
// can read OR ack its own pending uploads. The handler enforces that
|
||||
// :id == file.workspace_id even though the URL param matches; defence in
|
||||
// depth against a token leak letting one workspace pull another's bytes.
|
||||
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/pendinguploads"
|
||||
)
|
||||
|
||||
// PendingUploadsHandler serves the workspace-side fetch + ack endpoints.
|
||||
// Holds a Storage so tests can inject an in-memory implementation
|
||||
// without going through Postgres (sqlmock-based unit tests cover the
|
||||
// Postgres impl in internal/pendinguploads/storage_test.go).
|
||||
type PendingUploadsHandler struct {
|
||||
storage pendinguploads.Storage
|
||||
}
|
||||
|
||||
// NewPendingUploadsHandler constructs the handler with a concrete
|
||||
// Storage. Production wires up pendinguploads.NewPostgres(db.DB).
|
||||
func NewPendingUploadsHandler(storage pendinguploads.Storage) *PendingUploadsHandler {
|
||||
return &PendingUploadsHandler{storage: storage}
|
||||
}
|
||||
|
||||
// GetContent handles GET /workspaces/:id/pending-uploads/:file_id/content.
|
||||
//
|
||||
// Returns the file bytes with the original mimetype and a
|
||||
// Content-Disposition that names the original (sanitized) filename so
|
||||
// the workspace's fetcher writes it under the expected name. Stamps
|
||||
// fetched_at on the row best-effort — the read response is already
|
||||
// flushed to the network before the MarkFetched call so a sweep race
|
||||
// can't break the workspace's fetch.
|
||||
//
|
||||
// 404 on:
|
||||
// - file_id not found
|
||||
// - file_id belongs to a different workspace (cross-workspace bleed
|
||||
// protection)
|
||||
// - row already acked (workspace's bug — should not re-fetch after ack)
|
||||
// - row past expires_at (Phase 3 sweep would delete shortly anyway)
|
||||
func (h *PendingUploadsHandler) GetContent(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
if err := validateWorkspaceID(workspaceID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace ID"})
|
||||
return
|
||||
}
|
||||
fileIDStr := c.Param("file_id")
|
||||
fileID, err := uuid.Parse(fileIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid file_id"})
|
||||
return
|
||||
}
|
||||
|
||||
rec, err := h.storage.Get(c.Request.Context(), fileID)
|
||||
if errors.Is(err, pendinguploads.ErrNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "pending upload not found, expired, or already acked"})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("pending_uploads GetContent: storage.Get(%s) failed: %v", fileID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "storage error"})
|
||||
return
|
||||
}
|
||||
|
||||
// Cross-workspace bleed protection: a token leak from workspace A
|
||||
// must not let it read workspace B's pending uploads even with the
|
||||
// correct file_id. wsAuth already pinned the caller to :id; reject
|
||||
// if the row's workspace_id doesn't match.
|
||||
if rec.WorkspaceID.String() != workspaceID {
|
||||
log.Printf("pending_uploads GetContent: workspace mismatch — caller=%s row=%s file_id=%s",
|
||||
workspaceID, rec.WorkspaceID, fileID)
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "pending upload not found"})
|
||||
return
|
||||
}
|
||||
|
||||
// Stream the bytes. Set the original mimetype if known; fall back
|
||||
// to application/octet-stream so curl / browser clients still get
|
||||
// a valid response. Content-Disposition uses the workspace-side
|
||||
// filename so the fetcher writes it under the expected name.
|
||||
mimetype := rec.Mimetype
|
||||
if mimetype == "" {
|
||||
mimetype = "application/octet-stream"
|
||||
}
|
||||
c.Header("Content-Type", mimetype)
|
||||
c.Header("Content-Disposition", contentDispositionAttachment(rec.Filename))
|
||||
c.Header("Content-Length", strconv.FormatInt(rec.SizeBytes, 10))
|
||||
c.Status(http.StatusOK)
|
||||
if _, err := c.Writer.Write(rec.Content); err != nil {
|
||||
// Connection closed mid-stream — log and bail; we cannot
|
||||
// re-emit headers at this point. The workspace's HTTP client
|
||||
// will see the truncated body and retry on next poll.
|
||||
log.Printf("pending_uploads GetContent: write failed for %s: %v", fileID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Best-effort fetched_at stamp. After-the-fact so the GET response
|
||||
// completes regardless of the UPDATE outcome — a Phase 3 sweep
|
||||
// race that nukes the row between Get and MarkFetched must not
|
||||
// break the workspace's fetch.
|
||||
if err := h.storage.MarkFetched(c.Request.Context(), fileID); err != nil {
|
||||
log.Printf("pending_uploads GetContent: mark_fetched(%s) failed: %v", fileID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ack handles POST /workspaces/:id/pending-uploads/:file_id/ack.
|
||||
//
|
||||
// Marks the row as handed-off; Phase 3 sweep deletes acked rows after
|
||||
// a retention window. Idempotent — workspace at-least-once retries on
|
||||
// a flaky network return success without moving the timestamp.
|
||||
func (h *PendingUploadsHandler) Ack(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
if err := validateWorkspaceID(workspaceID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace ID"})
|
||||
return
|
||||
}
|
||||
fileIDStr := c.Param("file_id")
|
||||
fileID, err := uuid.Parse(fileIDStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid file_id"})
|
||||
return
|
||||
}
|
||||
|
||||
// Cross-workspace bleed protection: do a lookup BEFORE Ack so
|
||||
// a token leak can't ack a row owned by a different workspace.
|
||||
// We don't expose this distinction in the response (404 either
|
||||
// way) — the workspace can't tell whether it ack'd a non-existent
|
||||
// row vs. one it didn't own, and that's fine for the contract.
|
||||
rec, err := h.storage.Get(c.Request.Context(), fileID)
|
||||
if errors.Is(err, pendinguploads.ErrNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "pending upload not found, expired, or already acked"})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("pending_uploads Ack: storage.Get(%s) failed: %v", fileID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "storage error"})
|
||||
return
|
||||
}
|
||||
if rec.WorkspaceID.String() != workspaceID {
|
||||
log.Printf("pending_uploads Ack: workspace mismatch — caller=%s row=%s file_id=%s",
|
||||
workspaceID, rec.WorkspaceID, fileID)
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "pending upload not found"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.storage.Ack(c.Request.Context(), fileID); err != nil {
|
||||
if errors.Is(err, pendinguploads.ErrNotFound) {
|
||||
// Race window: the row passed Get but failed Ack — sweep
|
||||
// raced with us between the two queries. Treat as success
|
||||
// (the workspace's intent was honored, the row is gone).
|
||||
c.JSON(http.StatusOK, gin.H{"acked": true, "raced": true})
|
||||
return
|
||||
}
|
||||
log.Printf("pending_uploads Ack: storage.Ack(%s) failed: %v", fileID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "storage error"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"acked": true})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,476 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
// pending_uploads_integration_test.go — REAL Postgres integration
|
||||
// tests for the poll-mode chat upload flow (RFC: phases 1–3).
|
||||
//
|
||||
// Run with:
|
||||
//
|
||||
// docker run --rm -d --name pg-integration \
|
||||
// -e POSTGRES_PASSWORD=test -e POSTGRES_DB=molecule \
|
||||
// -p 55432:5432 postgres:15-alpine
|
||||
// sleep 4
|
||||
// psql ... < workspace-server/migrations/20260505100000_pending_uploads.up.sql
|
||||
// cd workspace-server
|
||||
// INTEGRATION_DB_URL="postgres://postgres:test@localhost:55432/molecule?sslmode=disable" \
|
||||
// go test -tags=integration ./internal/handlers/ -run Integration_PendingUploads
|
||||
//
|
||||
// CI (.github/workflows/handlers-postgres-integration.yml) runs this on
|
||||
// every PR that touches workspace-server/internal/handlers/** OR
|
||||
// workspace-server/migrations/**.
|
||||
//
|
||||
// Why these are NOT plain unit tests
|
||||
// ----------------------------------
|
||||
// The strict-sqlmock unit tests in storage_test.go pin which SQL
|
||||
// statements fire — they are fast and let us iterate without a DB. But
|
||||
// sqlmock CANNOT detect bugs that depend on the actual row state after
|
||||
// the SQL runs. In particular:
|
||||
//
|
||||
// - the WITH … DELETE … RETURNING CTE used by Sweep depends on
|
||||
// Postgres' `make_interval` function and the table's CHECK
|
||||
// constraints. sqlmock would happily accept a hand-written SQL
|
||||
// literal that Postgres rejects at runtime.
|
||||
// - the partial index `idx_pending_uploads_unacked` (created by the
|
||||
// Phase 1 migration) only catches a wrong WHERE predicate at real-
|
||||
// query-plan time.
|
||||
//
|
||||
// These tests close those gaps by booting a real Postgres, running the
|
||||
// production helpers, and SELECTing the row to verify the observable
|
||||
// state matches the expected outcome.
|
||||
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/pendinguploads"
|
||||
)
|
||||
|
||||
// integrationDB_PendingUploads opens a connection from $INTEGRATION_DB_URL
|
||||
// (skipping the test if unset), wipes the pending_uploads table for
|
||||
// isolation, and registers a Cleanup that closes the connection.
|
||||
//
|
||||
// NOT SAFE FOR `t.Parallel()` — each test gets the table to itself.
|
||||
// Mirrors the integrationDB helper in delegation_ledger_integration_test.go
|
||||
// but kept separate so each table's wipe step is local to its tests.
|
||||
func integrationDB_PendingUploads(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
url := os.Getenv("INTEGRATION_DB_URL")
|
||||
if url == "" {
|
||||
t.Skip("INTEGRATION_DB_URL not set; skipping (local devs: see file header)")
|
||||
}
|
||||
conn, err := sql.Open("postgres", url)
|
||||
if err != nil {
|
||||
t.Fatalf("open: %v", err)
|
||||
}
|
||||
if err := conn.Ping(); err != nil {
|
||||
t.Fatalf("ping: %v", err)
|
||||
}
|
||||
if _, err := conn.ExecContext(context.Background(), `DELETE FROM pending_uploads`); err != nil {
|
||||
t.Fatalf("cleanup: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { conn.Close() })
|
||||
return conn
|
||||
}
|
||||
|
||||
func TestIntegration_PendingUploads_PutGetAckRoundTrip(t *testing.T) {
|
||||
conn := integrationDB_PendingUploads(t)
|
||||
store := pendinguploads.NewPostgres(conn)
|
||||
ctx := context.Background()
|
||||
|
||||
wsID := uuid.New()
|
||||
fileID, err := store.Put(ctx, wsID, []byte("hello PDF"), "report.pdf", "application/pdf")
|
||||
if err != nil {
|
||||
t.Fatalf("Put: %v", err)
|
||||
}
|
||||
|
||||
// Get reads back the row.
|
||||
rec, err := store.Get(ctx, fileID)
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
if rec.WorkspaceID != wsID {
|
||||
t.Errorf("workspace_id = %s, want %s", rec.WorkspaceID, wsID)
|
||||
}
|
||||
if string(rec.Content) != "hello PDF" {
|
||||
t.Errorf("content = %q, want %q", rec.Content, "hello PDF")
|
||||
}
|
||||
if rec.Filename != "report.pdf" {
|
||||
t.Errorf("filename = %q, want %q", rec.Filename, "report.pdf")
|
||||
}
|
||||
if rec.AckedAt != nil {
|
||||
t.Errorf("AckedAt should be nil before Ack, got %v", rec.AckedAt)
|
||||
}
|
||||
|
||||
// MarkFetched stamps fetched_at.
|
||||
if err := store.MarkFetched(ctx, fileID); err != nil {
|
||||
t.Fatalf("MarkFetched: %v", err)
|
||||
}
|
||||
|
||||
// Re-read to confirm.
|
||||
rec2, err := store.Get(ctx, fileID)
|
||||
if err != nil {
|
||||
t.Fatalf("Get after MarkFetched: %v", err)
|
||||
}
|
||||
if rec2.FetchedAt == nil {
|
||||
t.Errorf("FetchedAt should be set after MarkFetched")
|
||||
}
|
||||
|
||||
// Ack flips acked_at; subsequent Gets return ErrNotFound (acked rows
|
||||
// are filtered out at the SELECT predicate).
|
||||
if err := store.Ack(ctx, fileID); err != nil {
|
||||
t.Fatalf("Ack: %v", err)
|
||||
}
|
||||
if _, err := store.Get(ctx, fileID); err != pendinguploads.ErrNotFound {
|
||||
t.Errorf("Get after Ack: got %v, want ErrNotFound", err)
|
||||
}
|
||||
|
||||
// Idempotent re-ack succeeds.
|
||||
if err := store.Ack(ctx, fileID); err != nil {
|
||||
t.Errorf("re-Ack should be idempotent, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_PendingUploads_Sweep_DeletesAckedAfterRetention(t *testing.T) {
|
||||
conn := integrationDB_PendingUploads(t)
|
||||
store := pendinguploads.NewPostgres(conn)
|
||||
ctx := context.Background()
|
||||
|
||||
wsID := uuid.New()
|
||||
fid, err := store.Put(ctx, wsID, []byte("data"), "x.txt", "text/plain")
|
||||
if err != nil {
|
||||
t.Fatalf("Put: %v", err)
|
||||
}
|
||||
if err := store.Ack(ctx, fid); err != nil {
|
||||
t.Fatalf("Ack: %v", err)
|
||||
}
|
||||
|
||||
// retention=1h, row was acked just now → not yet eligible.
|
||||
res, err := store.Sweep(ctx, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("Sweep(1h): %v", err)
|
||||
}
|
||||
if res.Total() != 0 {
|
||||
t.Errorf("expected 0 deletions yet, got %+v", res)
|
||||
}
|
||||
|
||||
// retention=0 → row IS eligible immediately.
|
||||
res, err = store.Sweep(ctx, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Sweep(0): %v", err)
|
||||
}
|
||||
if res.Acked != 1 || res.Expired != 0 {
|
||||
t.Errorf("expected acked=1 expired=0, got %+v", res)
|
||||
}
|
||||
|
||||
// Verify row is actually gone — not just un-fetchable.
|
||||
var n int
|
||||
if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE file_id = $1`, fid).Scan(&n); err != nil {
|
||||
t.Fatalf("count: %v", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("row should be DELETEd, found %d rows", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_PendingUploads_Sweep_DeletesExpiredUnacked(t *testing.T) {
|
||||
conn := integrationDB_PendingUploads(t)
|
||||
store := pendinguploads.NewPostgres(conn)
|
||||
ctx := context.Background()
|
||||
|
||||
wsID := uuid.New()
|
||||
fid, err := store.Put(ctx, wsID, []byte("data"), "x.txt", "text/plain")
|
||||
if err != nil {
|
||||
t.Fatalf("Put: %v", err)
|
||||
}
|
||||
|
||||
// Manually backdate expires_at so the row IS expired. We don't ack,
|
||||
// so this exercises the unacked-and-expired branch of the WHERE
|
||||
// clause specifically.
|
||||
if _, err := conn.ExecContext(ctx,
|
||||
`UPDATE pending_uploads SET expires_at = now() - interval '1 minute' WHERE file_id = $1`,
|
||||
fid,
|
||||
); err != nil {
|
||||
t.Fatalf("backdate: %v", err)
|
||||
}
|
||||
|
||||
res, err := store.Sweep(ctx, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("Sweep: %v", err)
|
||||
}
|
||||
if res.Acked != 0 || res.Expired != 1 {
|
||||
t.Errorf("expected acked=0 expired=1, got %+v", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_PendingUploads_Sweep_DeletesBothCategoriesInOneCycle(t *testing.T) {
|
||||
conn := integrationDB_PendingUploads(t)
|
||||
store := pendinguploads.NewPostgres(conn)
|
||||
ctx := context.Background()
|
||||
|
||||
wsID := uuid.New()
|
||||
|
||||
// Three rows: one acked (eligible at retention=0), one expired
|
||||
// unacked, one fresh unacked (must NOT be deleted).
|
||||
ackedFID, err := store.Put(ctx, wsID, []byte("acked"), "a.txt", "text/plain")
|
||||
if err != nil {
|
||||
t.Fatalf("Put acked: %v", err)
|
||||
}
|
||||
if err := store.Ack(ctx, ackedFID); err != nil {
|
||||
t.Fatalf("Ack: %v", err)
|
||||
}
|
||||
|
||||
expiredFID, err := store.Put(ctx, wsID, []byte("expired"), "e.txt", "text/plain")
|
||||
if err != nil {
|
||||
t.Fatalf("Put expired: %v", err)
|
||||
}
|
||||
if _, err := conn.ExecContext(ctx,
|
||||
`UPDATE pending_uploads SET expires_at = now() - interval '1 minute' WHERE file_id = $1`,
|
||||
expiredFID,
|
||||
); err != nil {
|
||||
t.Fatalf("backdate: %v", err)
|
||||
}
|
||||
|
||||
freshFID, err := store.Put(ctx, wsID, []byte("fresh"), "f.txt", "text/plain")
|
||||
if err != nil {
|
||||
t.Fatalf("Put fresh: %v", err)
|
||||
}
|
||||
|
||||
res, err := store.Sweep(ctx, 0) // retention=0 makes the acked row eligible
|
||||
if err != nil {
|
||||
t.Fatalf("Sweep: %v", err)
|
||||
}
|
||||
if res.Acked != 1 || res.Expired != 1 {
|
||||
t.Errorf("expected acked=1 expired=1, got %+v", res)
|
||||
}
|
||||
|
||||
// Fresh row survives.
|
||||
rec, err := store.Get(ctx, freshFID)
|
||||
if err != nil {
|
||||
t.Errorf("fresh row should still be Get-able, got err=%v", err)
|
||||
}
|
||||
if rec.FileID != freshFID {
|
||||
t.Errorf("fresh row file_id = %s, want %s", rec.FileID, freshFID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_PendingUploads_PutEnforcesSizeCap(t *testing.T) {
|
||||
conn := integrationDB_PendingUploads(t)
|
||||
store := pendinguploads.NewPostgres(conn)
|
||||
ctx := context.Background()
|
||||
|
||||
wsID := uuid.New()
|
||||
tooBig := make([]byte, pendinguploads.MaxFileBytes+1)
|
||||
if _, err := store.Put(ctx, wsID, tooBig, "big.bin", "application/octet-stream"); err != pendinguploads.ErrTooLarge {
|
||||
t.Errorf("expected ErrTooLarge, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_PendingUploads_PutBatch_HappyPath_AllRowsCommit pins the
|
||||
// "all rows commit" leg of the PutBatch atomicity contract against a real
|
||||
// Postgres. sqlmock can't catch a regression where the Go-side Tx machinery
|
||||
// silently no-ops the inserts (e.g., wrong driver options on BeginTx); only
|
||||
// COUNT(*) on the real table can.
|
||||
func TestIntegration_PendingUploads_PutBatch_HappyPath_AllRowsCommit(t *testing.T) {
|
||||
conn := integrationDB_PendingUploads(t)
|
||||
store := pendinguploads.NewPostgres(conn)
|
||||
ctx := context.Background()
|
||||
|
||||
wsID := uuid.New()
|
||||
|
||||
// Pre-existing row so the COUNT(*) baseline is non-zero — proves
|
||||
// PutBatch adds rows incrementally rather than overwriting.
|
||||
if _, err := store.Put(ctx, wsID, []byte("seed"), "seed.txt", "text/plain"); err != nil {
|
||||
t.Fatalf("seed Put: %v", err)
|
||||
}
|
||||
|
||||
items := []pendinguploads.PutItem{
|
||||
{Content: []byte("alpha"), Filename: "alpha.txt", Mimetype: "text/plain"},
|
||||
{Content: []byte("beta"), Filename: "beta.bin", Mimetype: "application/octet-stream"},
|
||||
{Content: []byte("gamma"), Filename: "gamma.pdf", Mimetype: "application/pdf"},
|
||||
}
|
||||
ids, err := store.PutBatch(ctx, wsID, items)
|
||||
if err != nil {
|
||||
t.Fatalf("PutBatch: %v", err)
|
||||
}
|
||||
if len(ids) != len(items) {
|
||||
t.Fatalf("ids length %d, want %d", len(ids), len(items))
|
||||
}
|
||||
|
||||
// Each returned id round-trips through Get with the right content.
|
||||
for i, id := range ids {
|
||||
rec, err := store.Get(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("Get item %d (%s): %v", i, id, err)
|
||||
}
|
||||
if string(rec.Content) != string(items[i].Content) {
|
||||
t.Errorf("item %d content = %q, want %q", i, rec.Content, items[i].Content)
|
||||
}
|
||||
if rec.Filename != items[i].Filename {
|
||||
t.Errorf("item %d filename = %q, want %q", i, rec.Filename, items[i].Filename)
|
||||
}
|
||||
}
|
||||
|
||||
var n int
|
||||
if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&n); err != nil {
|
||||
t.Fatalf("count: %v", err)
|
||||
}
|
||||
if n != 4 {
|
||||
t.Errorf("workspace row count = %d, want 4 (1 seed + 3 batch)", n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_PendingUploads_PutBatch_AtomicRollback_NoLeakOnFailure
|
||||
// proves the all-or-nothing contract end-to-end against real Postgres MVCC.
|
||||
//
|
||||
// Strategy: build a 3-item batch where item index 1 carries a filename with
|
||||
// an embedded NUL byte. lib/pq rejects NULs in TEXT columns at the protocol
|
||||
// layer (`pq: invalid byte sequence for encoding "UTF8": 0x00`), which
|
||||
// triggers the per-row INSERT error path in PutBatch. The first item's
|
||||
// INSERT…RETURNING already wrote a row to the Tx's snapshot, so a buggy
|
||||
// rollback would leave that row visible after PutBatch returns.
|
||||
//
|
||||
// Postgrest semantics: ROLLBACK is the only way a real DB can guarantee the
|
||||
// "no leak" contract; a unit test with sqlmock can prove the Go function
|
||||
// CALLED Rollback, but only this integration test proves Postgres actually
|
||||
// HONORED it.
|
||||
func TestIntegration_PendingUploads_PutBatch_AtomicRollback_NoLeakOnFailure(t *testing.T) {
|
||||
conn := integrationDB_PendingUploads(t)
|
||||
store := pendinguploads.NewPostgres(conn)
|
||||
ctx := context.Background()
|
||||
|
||||
wsID := uuid.New()
|
||||
|
||||
// Baseline COUNT(*) for this workspace — must remain 0 after a failed batch.
|
||||
var before int
|
||||
if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&before); err != nil {
|
||||
t.Fatalf("baseline count: %v", err)
|
||||
}
|
||||
if before != 0 {
|
||||
t.Fatalf("workspace not isolated: baseline = %d, want 0", before)
|
||||
}
|
||||
|
||||
// Item 1 has a NUL byte in the filename — Go-side pre-validation
|
||||
// (which only checks empty/length) lets it through, so the INSERT
|
||||
// reaches lib/pq, which rejects it at the protocol level. That's the
|
||||
// canonical "DB-side error mid-batch" we want to exercise.
|
||||
items := []pendinguploads.PutItem{
|
||||
{Content: []byte("ok"), Filename: "ok.txt", Mimetype: "text/plain"},
|
||||
{Content: []byte("bad"), Filename: "bad\x00name.txt", Mimetype: "text/plain"},
|
||||
{Content: []byte("never"), Filename: "never.txt", Mimetype: "text/plain"},
|
||||
}
|
||||
_, err := store.PutBatch(ctx, wsID, items)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error from NUL-byte filename, got nil")
|
||||
}
|
||||
|
||||
// THE assertion this whole test exists for: even though item 0's
|
||||
// INSERT…RETURNING succeeded inside the Tx, the rollback unwound
|
||||
// it — zero rows for this workspace, not one (let alone three).
|
||||
var after int
|
||||
if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&after); err != nil {
|
||||
t.Fatalf("post-failure count: %v", err)
|
||||
}
|
||||
if after != 0 {
|
||||
t.Errorf("Tx rollback leaked rows: workspace count = %d, want 0", after)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_PendingUploads_PutBatch_Oversize_NoTxOpened verifies the
|
||||
// pre-validation short-circuit: an oversized item rejects with ErrTooLarge
|
||||
// BEFORE any Tx opens, so the table is untouched. The unit test (sqlmock
|
||||
// with zero expectations) catches the Go-side path; this test sanity-checks
|
||||
// no real DB I/O happens by confirming COUNT(*) doesn't move.
|
||||
func TestIntegration_PendingUploads_PutBatch_Oversize_NoTxOpened(t *testing.T) {
|
||||
conn := integrationDB_PendingUploads(t)
|
||||
store := pendinguploads.NewPostgres(conn)
|
||||
ctx := context.Background()
|
||||
|
||||
wsID := uuid.New()
|
||||
tooBig := make([]byte, pendinguploads.MaxFileBytes+1)
|
||||
_, err := store.PutBatch(ctx, wsID, []pendinguploads.PutItem{
|
||||
{Content: []byte("ok"), Filename: "ok.txt"},
|
||||
{Content: tooBig, Filename: "too-big.bin"},
|
||||
})
|
||||
if err != pendinguploads.ErrTooLarge {
|
||||
t.Fatalf("expected ErrTooLarge, got %v", err)
|
||||
}
|
||||
var n int
|
||||
if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&n); err != nil {
|
||||
t.Fatalf("count: %v", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("pre-validation did NOT short-circuit: count = %d, want 0", n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_PendingUploads_AckedIndexExists verifies the Phase 5a
|
||||
// migration (20260505200000_pending_uploads_acked_index.up.sql) actually
|
||||
// created idx_pending_uploads_acked with the right partial-index predicate.
|
||||
//
|
||||
// Why pg_indexes and not EXPLAIN: the planner prefers Seq Scan on tiny
|
||||
// tables regardless of available indexes — a plan-shape check would be
|
||||
// flaky under real test loads. The contract we care about is "the index
|
||||
// exists with the predicate we wrote in the migration"; pg_indexes is
|
||||
// the canonical source for that, robust to row count and planner version.
|
||||
func TestIntegration_PendingUploads_AckedIndexExists(t *testing.T) {
|
||||
conn := integrationDB_PendingUploads(t)
|
||||
ctx := context.Background()
|
||||
|
||||
var indexdef string
|
||||
err := conn.QueryRowContext(ctx, `
|
||||
SELECT indexdef FROM pg_indexes
|
||||
WHERE schemaname = 'public'
|
||||
AND tablename = 'pending_uploads'
|
||||
AND indexname = 'idx_pending_uploads_acked'
|
||||
`).Scan(&indexdef)
|
||||
if err == sql.ErrNoRows {
|
||||
t.Fatal("idx_pending_uploads_acked is missing — migration 20260505200000 not applied")
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("pg_indexes query: %v", err)
|
||||
}
|
||||
|
||||
// Pin the partial-index predicate. Without "WHERE acked_at IS NOT NULL"
|
||||
// we'd be indexing the entire table (defeats the point — most rows are
|
||||
// unacked), and the existing idx_pending_uploads_unacked already covers
|
||||
// the inverse predicate.
|
||||
if !strings.Contains(indexdef, "(acked_at)") {
|
||||
t.Errorf("index missing acked_at column: %s", indexdef)
|
||||
}
|
||||
if !strings.Contains(indexdef, "WHERE (acked_at IS NOT NULL)") {
|
||||
t.Errorf("index missing partial predicate: %s", indexdef)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_PendingUploads_GetIgnoresExpiredAndAcked(t *testing.T) {
|
||||
conn := integrationDB_PendingUploads(t)
|
||||
store := pendinguploads.NewPostgres(conn)
|
||||
ctx := context.Background()
|
||||
|
||||
wsID := uuid.New()
|
||||
fid, err := store.Put(ctx, wsID, []byte("data"), "x.txt", "text/plain")
|
||||
if err != nil {
|
||||
t.Fatalf("Put: %v", err)
|
||||
}
|
||||
|
||||
// Backdate expires_at — Get must return ErrNotFound, even though the
|
||||
// row physically exists in the table (Sweep hasn't run).
|
||||
if _, err := conn.ExecContext(ctx,
|
||||
`UPDATE pending_uploads SET expires_at = now() - interval '1 minute' WHERE file_id = $1`,
|
||||
fid,
|
||||
); err != nil {
|
||||
t.Fatalf("backdate: %v", err)
|
||||
}
|
||||
if _, err := store.Get(ctx, fid); err != pendinguploads.ErrNotFound {
|
||||
t.Errorf("Get after expiry: got %v, want ErrNotFound", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,387 @@
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/handlers"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/pendinguploads"
|
||||
)
|
||||
|
||||
// fakeStorage is an in-memory pendinguploads.Storage. Lets handler
|
||||
// tests pin behaviour without going through Postgres + sqlmock — the
|
||||
// storage layer's own tests (internal/pendinguploads/storage_test.go)
|
||||
// cover the SQL drift surface; here we only care about the handler's
|
||||
// 4xx/5xx mapping and side-effect ordering.
|
||||
type fakeStorage struct {
|
||||
rows map[uuid.UUID]pendinguploads.Record
|
||||
getErr error // forced error from Get (overrides rows lookup)
|
||||
ackErr error // forced error from Ack
|
||||
markErr error // forced error from MarkFetched
|
||||
markFetched []uuid.UUID
|
||||
ackCalls []uuid.UUID
|
||||
}
|
||||
|
||||
func newFakeStorage() *fakeStorage {
|
||||
return &fakeStorage{rows: map[uuid.UUID]pendinguploads.Record{}}
|
||||
}
|
||||
|
||||
func (f *fakeStorage) Put(ctx context.Context, ws uuid.UUID, content []byte, filename, mimetype string) (uuid.UUID, error) {
|
||||
id := uuid.New()
|
||||
f.rows[id] = pendinguploads.Record{
|
||||
FileID: id, WorkspaceID: ws, Content: content,
|
||||
Filename: filename, Mimetype: mimetype,
|
||||
SizeBytes: int64(len(content)), CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (f *fakeStorage) Get(_ context.Context, fileID uuid.UUID) (pendinguploads.Record, error) {
|
||||
if f.getErr != nil {
|
||||
return pendinguploads.Record{}, f.getErr
|
||||
}
|
||||
rec, ok := f.rows[fileID]
|
||||
if !ok {
|
||||
return pendinguploads.Record{}, pendinguploads.ErrNotFound
|
||||
}
|
||||
return rec, nil
|
||||
}
|
||||
|
||||
func (f *fakeStorage) MarkFetched(_ context.Context, fileID uuid.UUID) error {
|
||||
f.markFetched = append(f.markFetched, fileID)
|
||||
return f.markErr
|
||||
}
|
||||
|
||||
func (f *fakeStorage) Ack(_ context.Context, fileID uuid.UUID) error {
|
||||
f.ackCalls = append(f.ackCalls, fileID)
|
||||
if f.ackErr != nil {
|
||||
return f.ackErr
|
||||
}
|
||||
delete(f.rows, fileID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sweep is required by the Storage interface (Phase 3 GC). Not exercised
|
||||
// by these handler tests — the dedicated sweeper_test.go covers it.
|
||||
func (f *fakeStorage) Sweep(_ context.Context, _ time.Duration) (pendinguploads.SweepResult, error) {
|
||||
return pendinguploads.SweepResult{}, nil
|
||||
}
|
||||
|
||||
// PutBatch is required by the Storage interface; the upload handler
|
||||
// tests live in chat_files_poll_test.go and use a separate fake
|
||||
// (inMemStorage). Stubbed here because the Get/Ack tests don't drive
|
||||
// PutBatch, but the interface must be satisfied.
|
||||
func (f *fakeStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func newRouter(handler *handlers.PendingUploadsHandler) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/pending-uploads/:file_id/content", handler.GetContent)
|
||||
r.POST("/workspaces/:id/pending-uploads/:file_id/ack", handler.Ack)
|
||||
return r
|
||||
}
|
||||
|
||||
// ---- GetContent ----
|
||||
|
||||
func TestGetContent_HappyPath_StreamsBytesAndStampsFetched(t *testing.T) {
|
||||
fs := newFakeStorage()
|
||||
wsID := uuid.New()
|
||||
fileID, err := fs.Put(context.Background(), wsID, []byte("hello world"), "report.pdf", "application/pdf")
|
||||
if err != nil {
|
||||
t.Fatalf("Put: %v", err)
|
||||
}
|
||||
h := handlers.NewPendingUploadsHandler(fs)
|
||||
r := newRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID.String()+"/pending-uploads/"+fileID.String()+"/content", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
if got := w.Body.String(); got != "hello world" {
|
||||
t.Errorf("body = %q, want %q", got, "hello world")
|
||||
}
|
||||
if got := w.Header().Get("Content-Type"); got != "application/pdf" {
|
||||
t.Errorf("Content-Type = %q, want application/pdf", got)
|
||||
}
|
||||
if got := w.Header().Get("Content-Disposition"); !strings.Contains(got, "report.pdf") {
|
||||
t.Errorf("Content-Disposition = %q, expected to mention report.pdf", got)
|
||||
}
|
||||
if got := w.Header().Get("Content-Length"); got != "11" {
|
||||
t.Errorf("Content-Length = %q, want 11", got)
|
||||
}
|
||||
if len(fs.markFetched) != 1 || fs.markFetched[0] != fileID {
|
||||
t.Errorf("expected MarkFetched(%s), got %v", fileID, fs.markFetched)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetContent_DefaultsMimetypeWhenEmpty(t *testing.T) {
|
||||
fs := newFakeStorage()
|
||||
wsID := uuid.New()
|
||||
fileID, _ := fs.Put(context.Background(), wsID, []byte("data"), "x.bin", "")
|
||||
h := handlers.NewPendingUploadsHandler(fs)
|
||||
r := newRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID.String()+"/pending-uploads/"+fileID.String()+"/content", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Content-Type"); got != "application/octet-stream" {
|
||||
t.Errorf("Content-Type fallback = %q, want application/octet-stream", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetContent_InvalidWorkspaceID_400(t *testing.T) {
|
||||
fs := newFakeStorage()
|
||||
r := newRouter(handlers.NewPendingUploadsHandler(fs))
|
||||
req := httptest.NewRequest(http.MethodGet, "/workspaces/not-a-uuid/pending-uploads/00000000-0000-0000-0000-000000000000/content", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetContent_InvalidFileID_400(t *testing.T) {
|
||||
fs := newFakeStorage()
|
||||
r := newRouter(handlers.NewPendingUploadsHandler(fs))
|
||||
wsID := uuid.New()
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID.String()+"/pending-uploads/not-a-uuid/content", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetContent_NotFound_404(t *testing.T) {
|
||||
fs := newFakeStorage()
|
||||
r := newRouter(handlers.NewPendingUploadsHandler(fs))
|
||||
wsID := uuid.New()
|
||||
missing := uuid.New()
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID.String()+"/pending-uploads/"+missing.String()+"/content", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status=%d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetContent_StorageError_500(t *testing.T) {
|
||||
fs := newFakeStorage()
|
||||
fs.getErr = errors.New("connection refused")
|
||||
r := newRouter(handlers.NewPendingUploadsHandler(fs))
|
||||
wsID := uuid.New()
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID.String()+"/pending-uploads/"+uuid.New().String()+"/content", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status=%d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetContent_CrossWorkspaceBleed_404(t *testing.T) {
|
||||
// Token leak: workspace A's wsAuth-validated request tries to
|
||||
// pull workspace B's file_id. Handler must 404 even though the
|
||||
// row exists.
|
||||
fs := newFakeStorage()
|
||||
wsB := uuid.New()
|
||||
fileID, _ := fs.Put(context.Background(), wsB, []byte("secret"), "leak.txt", "text/plain")
|
||||
|
||||
wsA := uuid.New()
|
||||
r := newRouter(handlers.NewPendingUploadsHandler(fs))
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsA.String()+"/pending-uploads/"+fileID.String()+"/content", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("status=%d, want 404 for cross-workspace bleed", w.Code)
|
||||
}
|
||||
// Critical: must not have leaked the bytes.
|
||||
if strings.Contains(w.Body.String(), "secret") {
|
||||
t.Errorf("response body leaked content from another workspace: %q", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetContent_MarkFetchedFailureLoggedNotPropagated(t *testing.T) {
|
||||
fs := newFakeStorage()
|
||||
wsID := uuid.New()
|
||||
fileID, _ := fs.Put(context.Background(), wsID, []byte("ok"), "x.txt", "text/plain")
|
||||
fs.markErr = errors.New("update failed (sweep raced)")
|
||||
h := handlers.NewPendingUploadsHandler(fs)
|
||||
r := newRouter(h)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID.String()+"/pending-uploads/"+fileID.String()+"/content", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Body already returned 200 OK + bytes BEFORE the MarkFetched
|
||||
// failure — workspace fetch must NOT fail because of an
|
||||
// observability hook.
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status=%d, want 200 even on MarkFetched failure", w.Code)
|
||||
}
|
||||
if w.Body.String() != "ok" {
|
||||
t.Errorf("body = %q, want %q", w.Body.String(), "ok")
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Ack ----
|
||||
|
||||
func TestAck_HappyPath_RemovesRow(t *testing.T) {
|
||||
fs := newFakeStorage()
|
||||
wsID := uuid.New()
|
||||
fileID, _ := fs.Put(context.Background(), wsID, []byte("data"), "x.bin", "")
|
||||
r := newRouter(handlers.NewPendingUploadsHandler(fs))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
"/workspaces/"+wsID.String()+"/pending-uploads/"+fileID.String()+"/ack", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if body["acked"] != true {
|
||||
t.Errorf("body.acked = %v, want true", body["acked"])
|
||||
}
|
||||
if _, exists := fs.rows[fileID]; exists {
|
||||
t.Errorf("row should have been removed after ack")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAck_NonExistent_404(t *testing.T) {
|
||||
fs := newFakeStorage()
|
||||
r := newRouter(handlers.NewPendingUploadsHandler(fs))
|
||||
wsID := uuid.New()
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
"/workspaces/"+wsID.String()+"/pending-uploads/"+uuid.New().String()+"/ack", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status=%d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAck_CrossWorkspaceBleed_404(t *testing.T) {
|
||||
fs := newFakeStorage()
|
||||
wsB := uuid.New()
|
||||
fileID, _ := fs.Put(context.Background(), wsB, []byte("data"), "x.bin", "")
|
||||
wsA := uuid.New()
|
||||
r := newRouter(handlers.NewPendingUploadsHandler(fs))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
"/workspaces/"+wsA.String()+"/pending-uploads/"+fileID.String()+"/ack", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status=%d, want 404 for cross-workspace ack", w.Code)
|
||||
}
|
||||
// Row must remain — workspace A's bogus ack must NOT delete
|
||||
// workspace B's file.
|
||||
if _, exists := fs.rows[fileID]; !exists {
|
||||
t.Errorf("row should NOT have been removed by cross-workspace ack")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAck_InvalidWorkspaceID_400(t *testing.T) {
|
||||
fs := newFakeStorage()
|
||||
r := newRouter(handlers.NewPendingUploadsHandler(fs))
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/not-a-uuid/pending-uploads/"+uuid.New().String()+"/ack", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAck_InvalidFileID_400(t *testing.T) {
|
||||
fs := newFakeStorage()
|
||||
r := newRouter(handlers.NewPendingUploadsHandler(fs))
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
"/workspaces/"+uuid.New().String()+"/pending-uploads/not-a-uuid/ack", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAck_GetStorageError_500(t *testing.T) {
|
||||
fs := newFakeStorage()
|
||||
fs.getErr = errors.New("conn lost")
|
||||
r := newRouter(handlers.NewPendingUploadsHandler(fs))
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
"/workspaces/"+uuid.New().String()+"/pending-uploads/"+uuid.New().String()+"/ack", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status=%d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAck_RaceWithSweep_ReturnsRacedTrue(t *testing.T) {
|
||||
// Sweep deletes the row between the handler's Get and Ack calls.
|
||||
// Storage.Ack returns ErrNotFound; handler treats that as success
|
||||
// (intent honored, row gone) and reports raced:true.
|
||||
fs := newFakeStorage()
|
||||
wsID := uuid.New()
|
||||
fileID, _ := fs.Put(context.Background(), wsID, []byte("data"), "x.bin", "")
|
||||
fs.ackErr = pendinguploads.ErrNotFound
|
||||
r := newRouter(handlers.NewPendingUploadsHandler(fs))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
"/workspaces/"+wsID.String()+"/pending-uploads/"+fileID.String()+"/ack", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d, want 200 on race", w.Code)
|
||||
}
|
||||
var body map[string]any
|
||||
json.Unmarshal(w.Body.Bytes(), &body)
|
||||
if body["acked"] != true || body["raced"] != true {
|
||||
t.Errorf("expected acked=true raced=true, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAck_StorageError_500(t *testing.T) {
|
||||
fs := newFakeStorage()
|
||||
wsID := uuid.New()
|
||||
fileID, _ := fs.Put(context.Background(), wsID, []byte("data"), "x.bin", "")
|
||||
fs.ackErr = errors.New("conn refused")
|
||||
r := newRouter(handlers.NewPendingUploadsHandler(fs))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
"/workspaces/"+wsID.String()+"/pending-uploads/"+fileID.String()+"/ack", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status=%d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
package handlers
|
||||
|
||||
// provlog_emit_test.go — pins that the structured-logging emit sites
|
||||
// added for #2867 PR-D actually fire when their boundary is crossed.
|
||||
//
|
||||
// These are call-site contract tests, not provlog package tests (those
|
||||
// live next to the helper). The assertion is "this dispatcher path
|
||||
// emits this event name" — if a refactor moves the call out of the
|
||||
// boundary helper, the gate fails. Fields are NOT pinned here on
|
||||
// purpose; the field set is convenience for ops, not contract for the
|
||||
// emit point. Pinning fields would block additive evolution of the
|
||||
// payload (see also feedback_behavior_based_ast_gates.md).
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
)
|
||||
|
||||
// captureProvLog redirects the global logger to a buffer for the test
|
||||
// duration. provlog.Event uses log.Printf, so this is the only seam.
|
||||
// Returned mutex protects against concurrent reads from the goroutine
|
||||
// fired by provisionWorkspaceAuto (the goroutine never returns in
|
||||
// these tests because Start() is stubbed, but the buffer can still be
|
||||
// touched by it racing the assertion).
|
||||
func captureProvLog(t *testing.T) (read func() string) {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
var mu sync.Mutex
|
||||
prevWriter := log.Writer()
|
||||
prevFlags := log.Flags()
|
||||
log.SetFlags(0)
|
||||
log.SetOutput(&safeWriter{buf: &buf, mu: &mu})
|
||||
t.Cleanup(func() {
|
||||
log.SetOutput(prevWriter)
|
||||
log.SetFlags(prevFlags)
|
||||
})
|
||||
return func() string {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return buf.String()
|
||||
}
|
||||
}
|
||||
|
||||
// TestProvisionWorkspaceAutoSync_EmitsProvisionStart — sync variant is
|
||||
// chosen for the assertion path because it returns once the (stubbed)
|
||||
// Start() has been called, so we know the emit has flushed. The async
|
||||
// variant would race a goroutine.
|
||||
func TestProvisionWorkspaceAutoSync_EmitsProvisionStart(t *testing.T) {
|
||||
read := captureProvLog(t)
|
||||
h := &WorkspaceHandler{cpProv: &trackingCPProv{}}
|
||||
// Best-effort: the body will hit DB code under provisionWorkspaceCP
|
||||
// — we only need the emit at the entry, which fires unconditionally
|
||||
// before the dispatch. Recovering from any later panic keeps the
|
||||
// test focused.
|
||||
defer func() { _ = recover() }()
|
||||
h.provisionWorkspaceAutoSync("ws-test-1", "tmpl", nil, models.CreateWorkspacePayload{
|
||||
Name: "n", Tier: 4, Runtime: "claude-code",
|
||||
})
|
||||
got := read()
|
||||
if !strings.Contains(got, "evt: provision.start ") {
|
||||
t.Fatalf("expected provision.start emit, got log:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, `"workspace_id":"ws-test-1"`) {
|
||||
t.Errorf("workspace_id not in payload: %s", got)
|
||||
}
|
||||
if !strings.Contains(got, `"sync":true`) {
|
||||
t.Errorf("sync flag not pinned for sync dispatcher: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStopForRestart_EmitsRestartPreStop — emit fires before the actual
|
||||
// Stop call, so the trackingCPProv stub doesn't need to be wired for
|
||||
// real Stop semantics. Backend label "cp" pinned because that's the
|
||||
// SaaS path; we don't pin "docker" or "none" branches here (separate
|
||||
// tests would only re-test the trivial branch label switch).
|
||||
func TestStopForRestart_EmitsRestartPreStop(t *testing.T) {
|
||||
read := captureProvLog(t)
|
||||
h := &WorkspaceHandler{cpProv: &trackingCPProv{}}
|
||||
defer func() { _ = recover() }()
|
||||
h.stopForRestart(context.Background(), "ws-restart-1")
|
||||
got := read()
|
||||
if !strings.Contains(got, "evt: restart.pre_stop ") {
|
||||
t.Fatalf("expected restart.pre_stop emit, got log:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, `"workspace_id":"ws-restart-1"`) {
|
||||
t.Errorf("workspace_id not in payload: %s", got)
|
||||
}
|
||||
if !strings.Contains(got, `"backend":"cp"`) {
|
||||
t.Errorf("backend label missing or wrong: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStopForRestart_EmitsBackendNoneWhenUnwired — pin the no-backend
|
||||
// branch so a future refactor that drops the label switch is caught.
|
||||
// This is the silent-Stop case (workspace_dispatchers.go:StopWorkspaceAuto
|
||||
// returns nil for unwired backends); the emit ensures the operator can
|
||||
// still see the boundary in the log.
|
||||
func TestStopForRestart_EmitsBackendNoneWhenUnwired(t *testing.T) {
|
||||
read := captureProvLog(t)
|
||||
h := &WorkspaceHandler{} // both nil
|
||||
h.stopForRestart(context.Background(), "ws-restart-2")
|
||||
got := read()
|
||||
if !strings.Contains(got, `"backend":"none"`) {
|
||||
t.Fatalf("expected backend=none for unwired handler: %s", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
|
||||
)
|
||||
|
||||
// Tests for the SaaS-aware default-tier resolution introduced in #2901
|
||||
// and hardened in #2910 (multi-model review of #2901 found the original
|
||||
// claim of "all green" was passing because no SaaS-mode test existed).
|
||||
//
|
||||
// These tests pin three invariants:
|
||||
//
|
||||
// 1. WorkspaceHandler.IsSaaS() returns true when cpProv is wired,
|
||||
// false otherwise.
|
||||
// 2. WorkspaceHandler.DefaultTier() returns 4 on SaaS, 3 self-hosted.
|
||||
// 3. generateDefaultConfig (TemplatesHandler.Import path) writes the
|
||||
// passed-in tier into the generated config.yaml — pre-#2910 it
|
||||
// was hardcoded to 3 and silently disagreed with the create-
|
||||
// handler default on SaaS.
|
||||
|
||||
// stubCPProv is a minimal stand-in for the CP provisioner — only
|
||||
// exercises the IsSaaS / HasProvisioner contract, never invoked in
|
||||
// these tests.
|
||||
type stubCPProv struct{}
|
||||
|
||||
func (stubCPProv) Start(_ interface{}, _ provisioner.WorkspaceConfig) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (stubCPProv) Stop(_ interface{}, _ string) error { return nil }
|
||||
func (stubCPProv) Restart(_ interface{}, _ provisioner.WorkspaceConfig) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func TestIsSaaS_TrueWhenCPProvWired(t *testing.T) {
|
||||
h := &WorkspaceHandler{cpProv: &trackingCPProv{}}
|
||||
if !h.IsSaaS() {
|
||||
t.Errorf("IsSaaS()=false with cpProv wired; expected true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSaaS_FalseWhenOnlyDocker(t *testing.T) {
|
||||
// provisioner field set, cpProv nil — the self-hosted path.
|
||||
// Use a non-nil sentinel so the check actually has something to
|
||||
// disagree with. trackingCPProv lives in workspace_provision_auto_test.go
|
||||
// and is the established stub for these handler-level tests.
|
||||
h := &WorkspaceHandler{provisioner: nil, cpProv: nil}
|
||||
if h.IsSaaS() {
|
||||
t.Errorf("IsSaaS()=true with both backends nil; expected false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultTier_SaaS_IsT4(t *testing.T) {
|
||||
h := &WorkspaceHandler{cpProv: &trackingCPProv{}}
|
||||
if got := h.DefaultTier(); got != 4 {
|
||||
t.Errorf("SaaS DefaultTier()=%d; expected 4", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultTier_SelfHosted_IsT3(t *testing.T) {
|
||||
h := &WorkspaceHandler{}
|
||||
if got := h.DefaultTier(); got != 3 {
|
||||
t.Errorf("self-hosted DefaultTier()=%d; expected 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
// generateDefaultConfig — pin that the tier param flows into the
|
||||
// emitted config.yaml verbatim. Pre-#2910 this was hardcoded "tier: 3"
|
||||
// regardless of caller intent.
|
||||
func TestGenerateDefaultConfig_RespectsTierParam(t *testing.T) {
|
||||
cfg := generateDefaultConfig("Test Agent", map[string]string{"system-prompt.md": ""}, 4)
|
||||
if !strings.Contains(cfg, "tier: 4\n") {
|
||||
t.Errorf("expected `tier: 4` in generated config, got:\n%s", cfg)
|
||||
}
|
||||
// The pre-#2910 hardcoded `tier: 3` line must NOT appear.
|
||||
if strings.Contains(cfg, "tier: 3\n") {
|
||||
t.Errorf("config should not contain `tier: 3` when caller passed 4, got:\n%s", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateDefaultConfig_SelfHostedTierT3(t *testing.T) {
|
||||
cfg := generateDefaultConfig("Test Agent", map[string]string{"system-prompt.md": ""}, 3)
|
||||
if !strings.Contains(cfg, "tier: 3\n") {
|
||||
t.Errorf("expected `tier: 3` in generated config, got:\n%s", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// Bounds check — caller passes 0 or out-of-range, helper falls back
|
||||
// to T3 (the safer-of-the-two when deployment mode can't be resolved).
|
||||
func TestGenerateDefaultConfig_OutOfRangeFallsBackToT3(t *testing.T) {
|
||||
for _, tier := range []int{0, -1, 99} {
|
||||
cfg := generateDefaultConfig("X", map[string]string{}, tier)
|
||||
if !strings.Contains(cfg, "tier: 3\n") {
|
||||
t.Errorf("invalid tier %d should fall back to T3, got:\n%s", tier, cfg)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/handlers"
|
||||
)
|
||||
|
||||
// SanitizeFilename mirrors workspace/internal_chat_uploads.py's
|
||||
// sanitize_filename. Drift between the two means canvas-emitted URIs
|
||||
// differ between push and poll paths for the same upload — pin every
|
||||
// case the Python suite pins (workspace/tests/test_internal_chat_uploads.py
|
||||
// :: test_sanitize_filename).
|
||||
|
||||
func TestSanitizeFilename_StripsPathTraversal(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"../../etc/passwd": "passwd",
|
||||
"/etc/passwd": "passwd",
|
||||
"a/b/c.txt": "c.txt",
|
||||
"./relative": "relative",
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := handlers.SanitizeFilename(in); got != want {
|
||||
t.Errorf("SanitizeFilename(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeFilename_ReplacesUnsafeChars(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"hello world.pdf": "hello_world.pdf",
|
||||
"weird;chars!?.txt": "weird_chars__.txt",
|
||||
"中文.docx": "__.docx", // non-ASCII → underscore (each rune)
|
||||
"file (1).pdf": "file__1_.pdf",
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := handlers.SanitizeFilename(in); got != want {
|
||||
t.Errorf("SanitizeFilename(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeFilename_PreservesAllowedChars(t *testing.T) {
|
||||
in := "report-2026.05.04_v2.pdf"
|
||||
if got := handlers.SanitizeFilename(in); got != in {
|
||||
t.Errorf("SanitizeFilename(%q) = %q, want unchanged", in, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeFilename_CapsAt100Chars_PreservesShortExtension(t *testing.T) {
|
||||
// 95-char base + ".pdf" (4 chars + dot) = 100 chars total — fits.
|
||||
base := strings.Repeat("a", 95)
|
||||
in := base + ".pdf"
|
||||
got := handlers.SanitizeFilename(in)
|
||||
if got != in {
|
||||
t.Errorf("expected unchanged at 100 chars, got %q (len=%d)", got, len(got))
|
||||
}
|
||||
|
||||
// 200-char base + ".pdf" → truncated to 100 with .pdf preserved.
|
||||
long := strings.Repeat("b", 200) + ".pdf"
|
||||
got = handlers.SanitizeFilename(long)
|
||||
if len(got) != 100 {
|
||||
t.Errorf("expected length 100, got %d (%q)", len(got), got)
|
||||
}
|
||||
if !strings.HasSuffix(got, ".pdf") {
|
||||
t.Errorf("expected .pdf suffix preserved, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeFilename_DropsLongExtension(t *testing.T) {
|
||||
// Extension > 16 chars is treated as part of the name; truncation
|
||||
// drops it without preservation. Mirrors the Python rule
|
||||
// (dot >= 0 AND len(base) - dot <= 16).
|
||||
long := strings.Repeat("c", 90) + ".thisisaverylongextensionnotpreserved"
|
||||
got := handlers.SanitizeFilename(long)
|
||||
if len(got) != 100 {
|
||||
t.Errorf("expected 100, got %d (%q)", len(got), got)
|
||||
}
|
||||
// First 100 chars of the SANITIZED input — extension not preserved.
|
||||
if strings.Contains(got, ".thisisaverylongextensionnotpreserved") {
|
||||
t.Errorf("long extension should have been truncated, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeFilename_FallbackForReservedNames(t *testing.T) {
|
||||
cases := []string{"", ".", ".."}
|
||||
for _, in := range cases {
|
||||
if got := handlers.SanitizeFilename(in); got != "file" {
|
||||
t.Errorf("SanitizeFilename(%q) = %q, want %q", in, got, "file")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeFilename_AllUnsafeBecomesAllUnderscores_NotReserved(t *testing.T) {
|
||||
// All-non-ASCII input becomes all-underscores — not "." or ".." or
|
||||
// empty, so the fallback path doesn't trigger and we get a real
|
||||
// (if uninformative) sanitized name.
|
||||
got := handlers.SanitizeFilename("中文中文")
|
||||
if got != "____" {
|
||||
t.Errorf("SanitizeFilename(中文中文) = %q, want %q", got, "____")
|
||||
}
|
||||
}
|
||||
@@ -71,7 +71,7 @@ func TestSecurity_GetTemplates_NoAuth_Returns401(t *testing.T) {
|
||||
authDB, authMock := newEnrolledAuthDB(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmplh := NewTemplatesHandler(tmpDir, nil)
|
||||
tmplh := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/templates", middleware.AdminAuth(authDB), tmplh.List)
|
||||
@@ -98,7 +98,7 @@ func TestSecurity_GetTemplates_FreshInstall_FailsOpen(t *testing.T) {
|
||||
authDB, authMock := newFreshInstallAuthDB(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmplh := NewTemplatesHandler(tmpDir, nil)
|
||||
tmplh := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/templates", middleware.AdminAuth(authDB), tmplh.List)
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/events"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// TeamHandler now hosts only Collapse — the visual "expand" action is
|
||||
// canvas-side and creating children goes through the regular
|
||||
// WorkspaceHandler.Create path with parent_id set, like any other
|
||||
// workspace. Every workspace can have children; "team" is just the
|
||||
// state of having children. The old Expand handler bulk-created
|
||||
// children by reading sub_workspaces from a parent's config and was
|
||||
// non-idempotent — calling it N times leaked N×children EC2s, which
|
||||
// is how tenant-hongming accumulated 72 stale workspaces.
|
||||
type TeamHandler struct {
|
||||
wh *WorkspaceHandler
|
||||
b *events.Broadcaster
|
||||
}
|
||||
|
||||
// NewTeamHandler constructs a TeamHandler. wh is used by Collapse to
|
||||
// route StopWorkspaceAuto through the backend dispatcher.
|
||||
func NewTeamHandler(b *events.Broadcaster, wh *WorkspaceHandler, platformURL, configsDir string) *TeamHandler {
|
||||
return &TeamHandler{wh: wh, b: b}
|
||||
}
|
||||
|
||||
// Collapse handles POST /workspaces/:id/collapse
|
||||
// Stops and removes all child workspaces.
|
||||
func (h *TeamHandler) Collapse(c *gin.Context) {
|
||||
parentID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Find children
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
`SELECT id, name FROM workspaces WHERE parent_id = $1 AND status != 'removed'`, parentID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to query children"})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
removed := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var childID, childName string
|
||||
if rows.Scan(&childID, &childName) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Stop the workload via the backend dispatcher (CP for SaaS,
|
||||
// Docker for self-hosted). Pre-2026-05-05 this was
|
||||
// `if h.provisioner != nil { h.provisioner.Stop(...) }`, which
|
||||
// silently skipped on every SaaS tenant — child EC2s kept running
|
||||
// after team-collapse until the orphan sweeper caught them
|
||||
// (issue #2813).
|
||||
if err := h.wh.StopWorkspaceAuto(ctx, childID); err != nil {
|
||||
log.Printf("Team collapse: stop %s failed: %v — orphan sweeper will reconcile", childID, err)
|
||||
}
|
||||
|
||||
// Mark as removed
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusRemoved, childID); err != nil {
|
||||
log.Printf("Team collapse: failed to remove workspace %s: %v", childID, err)
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
`DELETE FROM canvas_layouts WHERE workspace_id = $1`, childID); err != nil {
|
||||
log.Printf("Team collapse: failed to delete layout for %s: %v", childID, err)
|
||||
}
|
||||
|
||||
h.b.RecordAndBroadcast(ctx, "WORKSPACE_REMOVED", childID, map[string]interface{}{})
|
||||
|
||||
removed = append(removed, childName)
|
||||
}
|
||||
|
||||
h.b.RecordAndBroadcast(ctx, "WORKSPACE_COLLAPSED", parentID, map[string]interface{}{
|
||||
"removed_children": removed,
|
||||
})
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "collapsed",
|
||||
"removed": removed,
|
||||
})
|
||||
}
|
||||
|
||||
// findTemplateDirByName resolves a workspace name to its template
|
||||
// directory. Kept here because callers outside this package may use
|
||||
// it, even though the in-package consumer (Expand) is gone.
|
||||
//
|
||||
// TODO: relocate alongside the templates handler if no other callers
|
||||
// surface, or delete entirely after a deprecation cycle.
|
||||
func findTemplateDirByName(configsDir, name string) string {
|
||||
normalized := normalizeName(name)
|
||||
|
||||
candidate := filepath.Join(configsDir, normalized)
|
||||
if _, err := os.Stat(filepath.Join(candidate, "config.yaml")); err == nil {
|
||||
return candidate
|
||||
}
|
||||
|
||||
// Fall back to scanning all dirs
|
||||
entries, err := os.ReadDir(configsDir)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
cfgPath := filepath.Join(configsDir, e.Name(), "config.yaml")
|
||||
data, err := os.ReadFile(cfgPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var cfg struct {
|
||||
Name string `yaml:"name"`
|
||||
}
|
||||
if json.Unmarshal(data, &cfg) == nil && cfg.Name == name {
|
||||
return filepath.Join(configsDir, e.Name())
|
||||
}
|
||||
if yaml.Unmarshal(data, &cfg) == nil && cfg.Name == name {
|
||||
return filepath.Join(configsDir, e.Name())
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ---------- TeamHandler: Collapse ----------
|
||||
|
||||
func TestTeamCollapse_NoChildren(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewTeamHandler(broadcaster, NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()), "http://localhost:8080", "/tmp/configs")
|
||||
|
||||
// No children
|
||||
mock.ExpectQuery("SELECT id, name FROM workspaces WHERE parent_id").
|
||||
WithArgs("ws-parent").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
|
||||
|
||||
// WORKSPACE_COLLAPSED broadcast
|
||||
mock.ExpectExec("INSERT INTO structure_events").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-parent"}}
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
|
||||
handler.Collapse(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["status"] != "collapsed" {
|
||||
t.Errorf("expected status 'collapsed', got %v", resp["status"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamCollapse_WithChildren(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewTeamHandler(broadcaster, NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()), "http://localhost:8080", "/tmp/configs")
|
||||
|
||||
// Two children
|
||||
mock.ExpectQuery("SELECT id, name FROM workspaces WHERE parent_id").
|
||||
WithArgs("ws-parent").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
|
||||
AddRow("child-1", "Worker A").
|
||||
AddRow("child-2", "Worker B"))
|
||||
|
||||
// UPDATE + DELETE + broadcast for child-1
|
||||
mock.ExpectExec("UPDATE workspaces SET status =").
|
||||
WithArgs("child-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("DELETE FROM canvas_layouts").
|
||||
WithArgs("child-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("INSERT INTO structure_events").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// UPDATE + DELETE + broadcast for child-2
|
||||
mock.ExpectExec("UPDATE workspaces SET status =").
|
||||
WithArgs("child-2").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("DELETE FROM canvas_layouts").
|
||||
WithArgs("child-2").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("INSERT INTO structure_events").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// WORKSPACE_COLLAPSED broadcast for parent
|
||||
mock.ExpectExec("INSERT INTO structure_events").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-parent"}}
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
|
||||
handler.Collapse(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
removed, ok := resp["removed"].([]interface{})
|
||||
if !ok || len(removed) != 2 {
|
||||
t.Errorf("expected 2 removed children, got %v", resp["removed"])
|
||||
}
|
||||
}
|
||||
// ---------- findTemplateDirByName helper ----------
|
||||
|
||||
func TestFindTemplateDirByName_DirectMatch(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
subDir := filepath.Join(dir, "mybot")
|
||||
os.MkdirAll(subDir, 0755)
|
||||
os.WriteFile(filepath.Join(subDir, "config.yaml"), []byte("name: MyBot"), 0644)
|
||||
|
||||
result := findTemplateDirByName(dir, "mybot")
|
||||
if result != subDir {
|
||||
t.Errorf("expected %s, got %s", subDir, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindTemplateDirByName_NotFound(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
result := findTemplateDirByName(dir, "nonexistent")
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindTemplateDirByName_InvalidConfigsDir(t *testing.T) {
|
||||
result := findTemplateDirByName("/nonexistent/path", "anything")
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string for invalid dir, got %s", result)
|
||||
}
|
||||
}
|
||||
@@ -36,8 +36,14 @@ func normalizeName(name string) string {
|
||||
return result
|
||||
}
|
||||
|
||||
// generateDefaultConfig creates a config.yaml from detected prompt files and skills.
|
||||
func generateDefaultConfig(name string, files map[string]string) string {
|
||||
// generateDefaultConfig creates a config.yaml from detected prompt files
|
||||
// and skills. tier is the deployment-aware default (caller passes
|
||||
// h.wh.DefaultTier() — T4 on SaaS, T3 on self-hosted) so the generated
|
||||
// file matches what POST /workspaces would default to. Pre-#2910 this
|
||||
// was hardcoded to 3, which split-brained with the create-handler
|
||||
// default on SaaS (T4) and pinned newly-imported templates at T3 even
|
||||
// when downstream Create paths picked T4.
|
||||
func generateDefaultConfig(name string, files map[string]string, tier int) string {
|
||||
promptFiles := []string{}
|
||||
skillSet := map[string]bool{}
|
||||
|
||||
@@ -74,9 +80,15 @@ func generateDefaultConfig(name string, files map[string]string) string {
|
||||
var cfg strings.Builder
|
||||
cfg.WriteString(`name: "` + escaped + `"` + "\n")
|
||||
cfg.WriteString("description: Imported agent\n")
|
||||
// Default to tier 3 ("Privileged") — matches the workspace.go
|
||||
// create handler default. See its comment for rationale.
|
||||
cfg.WriteString("version: 1.0.0\ntier: 3\n")
|
||||
// Tier is SaaS-aware via the caller's DefaultTier (#2910 PR-B).
|
||||
// Bounds-checked: invalid input falls back to T3 (the historical
|
||||
// default + the safer-of-the-two when the deployment mode can't
|
||||
// be resolved).
|
||||
if tier < 1 || tier > 4 {
|
||||
tier = 3
|
||||
}
|
||||
cfg.WriteString("version: 1.0.0\n")
|
||||
cfg.WriteString(fmt.Sprintf("tier: %d\n", tier))
|
||||
cfg.WriteString("model: anthropic:claude-haiku-4-5-20251001\n")
|
||||
cfg.WriteString("\nprompt_files:\n")
|
||||
if len(promptFiles) > 0 {
|
||||
@@ -148,7 +160,11 @@ func (h *TemplatesHandler) Import(c *gin.Context) {
|
||||
|
||||
// Auto-generate config.yaml if not provided
|
||||
if _, exists := body.Files["config.yaml"]; !exists {
|
||||
cfg := generateDefaultConfig(body.Name, body.Files)
|
||||
tier := 3
|
||||
if h.wh != nil {
|
||||
tier = h.wh.DefaultTier()
|
||||
}
|
||||
cfg := generateDefaultConfig(body.Name, body.Files, tier)
|
||||
if err := os.WriteFile(filepath.Join(destDir, "config.yaml"), []byte(cfg), 0600); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to write config.yaml"})
|
||||
return
|
||||
@@ -227,7 +243,11 @@ func (h *TemplatesHandler) ReplaceFiles(c *gin.Context) {
|
||||
if _, exists := body.Files["config.yaml"]; !exists {
|
||||
// Check if config.yaml exists in container
|
||||
if _, err := h.execInContainer(ctx, containerName, []string{"test", "-f", "/configs/config.yaml"}); err != nil {
|
||||
cfg := generateDefaultConfig(wsName, body.Files)
|
||||
tier := 3
|
||||
if h.wh != nil {
|
||||
tier = h.wh.DefaultTier()
|
||||
}
|
||||
cfg := generateDefaultConfig(wsName, body.Files, tier)
|
||||
singleFile := map[string]string{"config.yaml": cfg}
|
||||
h.copyFilesToContainer(ctx, containerName, "/configs", singleFile)
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func TestGenerateDefaultConfig_WithFiles(t *testing.T) {
|
||||
"skills/review/templates.md": "Templates",
|
||||
}
|
||||
|
||||
cfg := generateDefaultConfig("Test Agent", files)
|
||||
cfg := generateDefaultConfig("Test Agent", files, 3)
|
||||
|
||||
// Name is emitted as a double-quoted scalar (#221 sanitizer).
|
||||
if !strings.Contains(cfg, `name: "Test Agent"`) {
|
||||
@@ -85,7 +85,7 @@ func TestGenerateDefaultConfig_Empty(t *testing.T) {
|
||||
"data/something.json": `{"key": "value"}`,
|
||||
}
|
||||
|
||||
cfg := generateDefaultConfig("Empty Agent", files)
|
||||
cfg := generateDefaultConfig("Empty Agent", files, 3)
|
||||
|
||||
if !strings.Contains(cfg, `name: "Empty Agent"`) {
|
||||
t.Errorf("config should contain quoted agent name, got:\n%s", cfg)
|
||||
@@ -134,7 +134,7 @@ func TestGenerateDefaultConfig_YAMLInjection(t *testing.T) {
|
||||
|
||||
for _, tc := range adversarialCases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
cfg := generateDefaultConfig(tc.name, map[string]string{})
|
||||
cfg := generateDefaultConfig(tc.name, map[string]string{}, 3)
|
||||
var parsed map[string]interface{}
|
||||
if err := yaml.Unmarshal([]byte(cfg), &parsed); err != nil {
|
||||
t.Fatalf("sanitized config does not parse as YAML: %v\n--- config ---\n%s", err, cfg)
|
||||
@@ -205,7 +205,7 @@ func TestImport_Success(t *testing.T) {
|
||||
setupTestRedis(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
body := `{
|
||||
"name": "New Agent",
|
||||
@@ -245,7 +245,7 @@ func TestImport_MissingName(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
body := `{"files": {"test.md": "content"}}`
|
||||
|
||||
@@ -265,7 +265,7 @@ func TestImport_TooManyFiles(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
files := make(map[string]string)
|
||||
for i := 0; i <= maxUploadFiles; i++ {
|
||||
@@ -296,7 +296,7 @@ func TestImport_AlreadyExists(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
os.MkdirAll(filepath.Join(tmpDir, "existing-agent"), 0755)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
body := `{"name": "Existing Agent", "files": {"test.md": "content"}}`
|
||||
|
||||
@@ -317,7 +317,7 @@ func TestImport_WithConfigYaml(t *testing.T) {
|
||||
setupTestRedis(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
body := `{
|
||||
"name": "Custom Agent",
|
||||
@@ -354,7 +354,7 @@ func TestReplaceFiles_MissingBody(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -373,7 +373,7 @@ func TestReplaceFiles_TooManyFiles(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
files := make(map[string]string)
|
||||
for i := 0; i <= maxUploadFiles; i++ {
|
||||
@@ -398,7 +398,7 @@ func TestReplaceFiles_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
// ReplaceFiles now selects (name, instance_id, runtime) for the
|
||||
// restart-cascade. Match the full column list rather than just the
|
||||
@@ -429,7 +429,7 @@ func TestReplaceFiles_PathTraversal(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT name, COALESCE\(instance_id, ''\), COALESCE\(runtime, ''\) FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-rf-pt").
|
||||
|
||||
@@ -31,10 +31,20 @@ const maxUploadFiles = 200
|
||||
type TemplatesHandler struct {
|
||||
configsDir string
|
||||
docker *client.Client
|
||||
// wh is used by Import and ReplaceFiles to call DefaultTier() so a
|
||||
// generated config.yaml's tier matches the SaaS-vs-self-hosted
|
||||
// boundary (#2910 PR-B). nil-tolerant — the field is unused when
|
||||
// the caller doesn't import templates that need a fresh config
|
||||
// generated.
|
||||
wh *WorkspaceHandler
|
||||
}
|
||||
|
||||
func NewTemplatesHandler(configsDir string, dockerCli *client.Client) *TemplatesHandler {
|
||||
return &TemplatesHandler{configsDir: configsDir, docker: dockerCli}
|
||||
// NewTemplatesHandler constructs a TemplatesHandler. wh may be nil for
|
||||
// callers that only use the read-only template surfaces (List,
|
||||
// ReadFile, ListFiles). Import + ReplaceFiles need wh non-nil so the
|
||||
// generated config.yaml picks the SaaS-aware default tier.
|
||||
func NewTemplatesHandler(configsDir string, dockerCli *client.Client, wh *WorkspaceHandler) *TemplatesHandler {
|
||||
return &TemplatesHandler{configsDir: configsDir, docker: dockerCli, wh: wh}
|
||||
}
|
||||
|
||||
// modelSpec describes a single supported model on a template: its id (sent
|
||||
|
||||
@@ -53,7 +53,7 @@ func TestTemplatesList_EmptyDir(t *testing.T) {
|
||||
setupTestRedis(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -99,7 +99,7 @@ skills:
|
||||
// Create a directory without config.yaml (should be skipped)
|
||||
os.MkdirAll(filepath.Join(tmpDir, "no-config"), 0755)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -160,7 +160,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -237,7 +237,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -315,7 +315,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -434,7 +434,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -512,7 +512,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -555,7 +555,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -589,7 +589,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -661,7 +661,7 @@ skills: []
|
||||
log.SetOutput(&logBuf)
|
||||
defer log.SetOutput(prevOutput)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -698,7 +698,7 @@ func TestTemplatesList_NonexistentDir(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler("/nonexistent/path/to/templates", nil)
|
||||
handler := NewTemplatesHandler("/nonexistent/path/to/templates", nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -723,7 +723,7 @@ func TestListFiles_InvalidRoot(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -748,7 +748,7 @@ func TestListFiles_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces WHERE id =").
|
||||
WithArgs("ws-nonexist").
|
||||
@@ -775,7 +775,7 @@ func TestListFiles_FallbackToHost_NoTemplate(t *testing.T) {
|
||||
setupTestRedis(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil) // nil docker = no container
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil) // nil docker = no container
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces WHERE id =").
|
||||
WithArgs("ws-fallback").
|
||||
@@ -815,7 +815,7 @@ func TestListFiles_FallbackToHost_WithTemplate(t *testing.T) {
|
||||
os.WriteFile(filepath.Join(tmplDir, "config.yaml"), []byte("name: Test Agent\n"), 0644)
|
||||
os.WriteFile(filepath.Join(tmplDir, "system-prompt.md"), []byte("# prompt"), 0644)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces WHERE id =").
|
||||
WithArgs("ws-tmpl").
|
||||
@@ -849,7 +849,7 @@ func TestReadFile_PathTraversal(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -870,7 +870,7 @@ func TestReadFile_InvalidRoot(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -892,7 +892,7 @@ func TestReadFile_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT name, COALESCE\(instance_id, ''\), COALESCE\(runtime, ''\) FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-nf").
|
||||
@@ -926,7 +926,7 @@ func TestReadFile_FallbackToHost_Success(t *testing.T) {
|
||||
os.MkdirAll(tmplDir, 0755)
|
||||
os.WriteFile(filepath.Join(tmplDir, "config.yaml"), []byte("name: Reader Agent\ntier: 1\n"), 0644)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
// instance_id="" → SaaS branch skipped → falls through to local
|
||||
// Docker / template-dir host fallback (the only path the test
|
||||
@@ -967,7 +967,7 @@ func TestReadFile_FallbackToHost_NotFound(t *testing.T) {
|
||||
setupTestRedis(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT name, COALESCE\(instance_id, ''\), COALESCE\(runtime, ''\) FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-nofile").
|
||||
@@ -999,7 +999,7 @@ func TestWriteFile_PathTraversal(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -1023,7 +1023,7 @@ func TestWriteFile_InvalidBody(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -1046,7 +1046,7 @@ func TestWriteFile_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT name, COALESCE\(instance_id, ''\), COALESCE\(runtime, ''\) FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-wf-nf").
|
||||
@@ -1080,7 +1080,7 @@ func TestDeleteFile_PathTraversal(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -1101,7 +1101,7 @@ func TestDeleteFile_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces WHERE id =").
|
||||
WithArgs("ws-del-nf").
|
||||
@@ -1133,7 +1133,7 @@ func TestResolveTemplateDir_ByNormalizedName(t *testing.T) {
|
||||
tmplDir := filepath.Join(tmpDir, "my-agent")
|
||||
os.MkdirAll(tmplDir, 0755)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
result := handler.resolveTemplateDir("My Agent")
|
||||
|
||||
if result != tmplDir {
|
||||
@@ -1143,7 +1143,7 @@ func TestResolveTemplateDir_ByNormalizedName(t *testing.T) {
|
||||
|
||||
func TestResolveTemplateDir_NotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
result := handler.resolveTemplateDir("Nonexistent Agent")
|
||||
|
||||
if result != "" {
|
||||
@@ -1177,7 +1177,7 @@ func TestCWE78_DeleteFile_TraversalVariants(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
@@ -148,15 +148,15 @@ func (h *WorkspaceHandler) Create(c *gin.Context) {
|
||||
id := uuid.New().String()
|
||||
awarenessNamespace := workspaceAwarenessNamespace(id)
|
||||
if payload.Tier == 0 {
|
||||
// Default to T3 ("Privileged"). T3 gives agents a read_write
|
||||
// workspace mount + Docker daemon access — the level most
|
||||
// templates need to do real work. Lower tiers (T1 sandboxed,
|
||||
// T2 standard) stay available as explicit opt-ins for
|
||||
// low-trust agents. Matches the Canvas CreateWorkspaceDialog
|
||||
// default for self-hosted hosts (SaaS defaults to T4 via
|
||||
// CreateWorkspaceDialog because each SaaS workspace runs on
|
||||
// its own sibling EC2).
|
||||
payload.Tier = 3
|
||||
// SaaS-aware default. SaaS → T4 (full host access; each
|
||||
// workspace runs on its own sibling EC2 so the tier boundary
|
||||
// is a Docker resource limit on the only container present —
|
||||
// no neighbour to protect from). Self-hosted → T3 (read-write
|
||||
// workspace mount + Docker daemon access, most templates'
|
||||
// baseline). Lower tiers (T1 sandboxed, T2 standard) remain
|
||||
// explicit opt-ins for low-trust agents. Matches the canvas
|
||||
// CreateWorkspaceDialog defaults so the API and the UI agree.
|
||||
payload.Tier = h.DefaultTier()
|
||||
}
|
||||
|
||||
// Detect runtime + default model from template config.yaml when the
|
||||
|
||||
@@ -35,6 +35,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provlog"
|
||||
)
|
||||
|
||||
// HasProvisioner reports whether either backend (CP or local Docker) is
|
||||
@@ -49,6 +50,32 @@ func (h *WorkspaceHandler) HasProvisioner() bool {
|
||||
return h.cpProv != nil || h.provisioner != nil
|
||||
}
|
||||
|
||||
// IsSaaS reports whether the CP (EC2) provisioner is wired. Each SaaS
|
||||
// workspace runs on its own sibling EC2, so the per-workspace tier
|
||||
// boundary is a Docker resource limit applied to the only container
|
||||
// on that EC2 — there's no neighbour to protect from. Self-hosted
|
||||
// runs many workspaces in one Docker daemon on a single host, so
|
||||
// the tier-2-by-default safe-neighbour-share posture stays.
|
||||
//
|
||||
// Tier defaults across Create / OrgImport / canvas EmptyState branch
|
||||
// on IsSaaS so SaaS users get T4 (full host access) by default and
|
||||
// self-hosted users keep the lower-trust caps.
|
||||
func (h *WorkspaceHandler) IsSaaS() bool {
|
||||
return h.cpProv != nil
|
||||
}
|
||||
|
||||
// DefaultTier is the SaaS-aware default tier. T4 on SaaS (single
|
||||
// container per EC2 — full host access matches the boundary), T3 on
|
||||
// self-hosted (read-write workspace mount + Docker daemon access,
|
||||
// most templates' baseline). Callers default to this when the user
|
||||
// hasn't explicitly picked a tier.
|
||||
func (h *WorkspaceHandler) DefaultTier() int {
|
||||
if h.IsSaaS() {
|
||||
return 4
|
||||
}
|
||||
return 3
|
||||
}
|
||||
|
||||
// provisionWorkspaceAuto picks the backend (CP for SaaS, local Docker
|
||||
// for self-hosted) and starts provisioning in a goroutine. Returns true
|
||||
// when a backend was kicked off, false when neither is wired.
|
||||
@@ -75,6 +102,14 @@ func (h *WorkspaceHandler) HasProvisioner() bool {
|
||||
// lives in prepareProvisionContext (shared by both per-backend
|
||||
// goroutines).
|
||||
func (h *WorkspaceHandler) provisionWorkspaceAuto(workspaceID, templatePath string, configFiles map[string][]byte, payload models.CreateWorkspacePayload) bool {
|
||||
provlog.Event("provision.start", map[string]any{
|
||||
"workspace_id": workspaceID,
|
||||
"name": payload.Name,
|
||||
"tier": payload.Tier,
|
||||
"runtime": payload.Runtime,
|
||||
"template": payload.Template,
|
||||
"sync": false,
|
||||
})
|
||||
if h.cpProv != nil {
|
||||
go h.provisionWorkspaceCP(workspaceID, templatePath, configFiles, payload)
|
||||
return true
|
||||
@@ -110,6 +145,14 @@ func (h *WorkspaceHandler) provisionWorkspaceAuto(workspaceID, templatePath stri
|
||||
// Keep these two helpers in sync — when one grows a new arm (third
|
||||
// backend, retry semantics), the other should too.
|
||||
func (h *WorkspaceHandler) provisionWorkspaceAutoSync(workspaceID, templatePath string, configFiles map[string][]byte, payload models.CreateWorkspacePayload) bool {
|
||||
provlog.Event("provision.start", map[string]any{
|
||||
"workspace_id": workspaceID,
|
||||
"name": payload.Name,
|
||||
"tier": payload.Tier,
|
||||
"runtime": payload.Runtime,
|
||||
"template": payload.Template,
|
||||
"sync": true,
|
||||
})
|
||||
if h.cpProv != nil {
|
||||
h.provisionWorkspaceCP(workspaceID, templatePath, configFiles, payload)
|
||||
return true
|
||||
|
||||
@@ -534,11 +534,10 @@ func (h *WorkspaceHandler) ensureDefaultConfig(workspaceID string, payload model
|
||||
// Generate a minimal config.yaml
|
||||
model := payload.Model
|
||||
if model == "" {
|
||||
if runtime == "claude-code" {
|
||||
model = "sonnet"
|
||||
} else {
|
||||
model = "anthropic:claude-opus-4-7"
|
||||
}
|
||||
// SSOT: per-runtime defaults live in models/runtime_defaults.go
|
||||
// (see RFC #2873). Was previously duplicated here AND in
|
||||
// org_import.go; consolidating prevents silent drift.
|
||||
model = models.DefaultModel(runtime)
|
||||
}
|
||||
|
||||
// Sanitize name/role/model for YAML safety — always double-quote so
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provlog"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -431,6 +432,16 @@ func coalesceRestart(workspaceID string, cycle func()) {
|
||||
// NPE'd before reaching the reprovision step — which is why every SaaS dead-
|
||||
// agent incident pre-this-fix required manual restart from canvas.
|
||||
func (h *WorkspaceHandler) stopForRestart(ctx context.Context, workspaceID string) {
|
||||
backend := "none"
|
||||
if h.provisioner != nil {
|
||||
backend = "docker"
|
||||
} else if h.cpProv != nil {
|
||||
backend = "cp"
|
||||
}
|
||||
provlog.Event("restart.pre_stop", map[string]any{
|
||||
"workspace_id": workspaceID,
|
||||
"backend": backend,
|
||||
})
|
||||
if h.provisioner != nil {
|
||||
h.provisioner.Stop(ctx, workspaceID)
|
||||
return
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestINSERTworkspacesAllowlist enumerates every function in this
|
||||
// package that emits an `INSERT INTO workspaces (` SQL literal, and
|
||||
// pins the result against an explicit allowlist. New entries fail the
|
||||
// build until a reviewer adds them — forcing the question "what
|
||||
// makes this INSERT idempotent?" at PR-review time, not after the
|
||||
// next bulk-create leak.
|
||||
//
|
||||
// Pairs with TestCreateWorkspaceTree_CallsLookupBeforeInsert (the
|
||||
// behavior pin for the one bulk path). Together they close the
|
||||
// regression class: this test catches "did a new function start
|
||||
// inserting workspaces?", that test catches "did the existing bulk
|
||||
// path drop its idempotency check?". Either fires immediately when
|
||||
// drift happens.
|
||||
//
|
||||
// Why allowlist rather than pure behavior gate (per memory
|
||||
// feedback_behavior_based_ast_gates.md): the bulk-create leak class
|
||||
// is small + stable (1 path today), and a behavior gate would have
|
||||
// to disambiguate "iterating a YAML array of workspaces" from the
|
||||
// many other `for ... range` patterns in a Create handler (config
|
||||
// lines, secrets map, channels). Type-info-aware AST analysis would
|
||||
// catch the YAML-iteration shape but is heavy. Allowlisting is the
|
||||
// minimum-viable pin: any PR that adds a new INSERT site is forced
|
||||
// to pause, add an entry here, and document the safety mechanism in
|
||||
// the comment alongside.
|
||||
//
|
||||
// RFC #2867 class 1.
|
||||
func TestINSERTworkspacesAllowlist(t *testing.T) {
|
||||
// expected[key] = safety mechanism. Keep the comment pinned to
|
||||
// what makes that function safe — if the safety changes, the
|
||||
// allowlist must be re-reviewed.
|
||||
expected := map[string]string{
|
||||
// org_import.createWorkspaceTree: lookupExistingChild
|
||||
// before INSERT (#2868 phase 3). Also pinned by
|
||||
// TestCreateWorkspaceTree_CallsLookupBeforeInsert.
|
||||
"org_import.go:createWorkspaceTree": "lookup-then-insert via lookupExistingChild",
|
||||
// registry.Register: external workspace registers itself with
|
||||
// its known UUID; INSERT is idempotent via ON CONFLICT (id)
|
||||
// DO UPDATE — re-registration upserts, never duplicates.
|
||||
"registry.go:Register": "ON CONFLICT (id) DO UPDATE",
|
||||
// workspace.Create: single-workspace POST /workspaces from a
|
||||
// human or automation. No iteration; payload describes one
|
||||
// workspace; UUID is server-generated. Caller intent IS to
|
||||
// create, so no idempotency check is needed.
|
||||
"workspace.go:Create": "single-workspace POST, server-generated UUID",
|
||||
}
|
||||
|
||||
actual := map[string]string{}
|
||||
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("getwd: %v", err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(wd)
|
||||
if err != nil {
|
||||
t.Fatalf("readdir %s: %v", wd, err)
|
||||
}
|
||||
for _, ent := range entries {
|
||||
name := ent.Name()
|
||||
if ent.IsDir() {
|
||||
continue
|
||||
}
|
||||
if !strings.HasSuffix(name, ".go") {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(name, "_test.go") {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(wd, name)
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
|
||||
if err != nil {
|
||||
t.Fatalf("parse %s: %v", path, err)
|
||||
}
|
||||
// For each top-level FuncDecl, walk its body and check for an
|
||||
// `INSERT INTO workspaces (` SQL literal in any CallExpr arg.
|
||||
for _, decl := range file.Decls {
|
||||
fn, ok := decl.(*ast.FuncDecl)
|
||||
if !ok || fn.Body == nil {
|
||||
continue
|
||||
}
|
||||
var foundInsert bool
|
||||
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||
lit, ok := n.(*ast.BasicLit)
|
||||
if !ok || lit.Kind != token.STRING {
|
||||
return true
|
||||
}
|
||||
raw := lit.Value
|
||||
if unq, err := strconv.Unquote(raw); err == nil {
|
||||
raw = unq
|
||||
}
|
||||
if workspacesInsertRE.MatchString(raw) {
|
||||
foundInsert = true
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if foundInsert {
|
||||
key := name + ":" + fn.Name.Name
|
||||
actual[key] = "(observed via AST walk)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute set diffs so failures point at the specific drift.
|
||||
missing := []string{}
|
||||
unexpected := []string{}
|
||||
for k := range expected {
|
||||
if _, ok := actual[k]; !ok {
|
||||
missing = append(missing, k)
|
||||
}
|
||||
}
|
||||
for k := range actual {
|
||||
if _, ok := expected[k]; !ok {
|
||||
unexpected = append(unexpected, k)
|
||||
}
|
||||
}
|
||||
sort.Strings(missing)
|
||||
sort.Strings(unexpected)
|
||||
|
||||
if len(unexpected) > 0 {
|
||||
t.Errorf(`new function(s) emit `+"`INSERT INTO workspaces (`"+` and aren't in the allowlist:
|
||||
%s
|
||||
|
||||
If this is a legitimate addition, add an entry to expected[] in this test
|
||||
with the safety mechanism pinned in the comment alongside (lookup-then-
|
||||
insert / ON CONFLICT / single-workspace path / etc.). The bulk-create
|
||||
regression class needs explicit per-handler review, not silent drift.
|
||||
|
||||
Reference: RFC #2867 class 1, sibling test
|
||||
TestCreateWorkspaceTree_CallsLookupBeforeInsert.`,
|
||||
strings.Join(unexpected, "\n "))
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
t.Errorf(`expected function(s) no longer emit `+"`INSERT INTO workspaces (`"+`:
|
||||
%s
|
||||
|
||||
Either the function was renamed/deleted (update the allowlist) or the
|
||||
INSERT was moved out (verify the new home is also covered). Don't just
|
||||
delete the entry — confirm the safety mechanism is still in place
|
||||
elsewhere or that the workspace-create path was intentionally
|
||||
restructured.`,
|
||||
strings.Join(missing, "\n "))
|
||||
}
|
||||
}
|
||||
@@ -5,14 +5,15 @@
|
||||
//
|
||||
// Exposed metrics:
|
||||
//
|
||||
// molecule_http_requests_total{method,path,status} - counter
|
||||
// molecule_http_request_duration_seconds{method,path} - counter (sum, for avg rate)
|
||||
// molecule_websocket_connections_active - gauge
|
||||
// go_goroutines - gauge
|
||||
// go_memstats_alloc_bytes - gauge
|
||||
// go_memstats_sys_bytes - gauge
|
||||
// go_memstats_heap_inuse_bytes - gauge
|
||||
// go_gc_duration_seconds_total - counter
|
||||
// molecule_http_requests_total{method,path,status} - counter
|
||||
// molecule_http_request_duration_seconds{method,path} - counter (sum, for avg rate)
|
||||
// molecule_websocket_connections_active - gauge
|
||||
// molecule_pending_uploads_swept_total{outcome} - counter (acked|expired|error)
|
||||
// go_goroutines - gauge
|
||||
// go_memstats_alloc_bytes - gauge
|
||||
// go_memstats_sys_bytes - gauge
|
||||
// go_memstats_heap_inuse_bytes - gauge
|
||||
// go_gc_duration_seconds_total - counter
|
||||
package metrics
|
||||
|
||||
import (
|
||||
@@ -38,6 +39,12 @@ var (
|
||||
reqCounts = map[reqKey]int64{} // molecule_http_requests_total
|
||||
reqDurSums = map[reqKey]float64{} // sum of durations (seconds)
|
||||
activeWSConns int64 // molecule_websocket_connections_active
|
||||
|
||||
// pendinguploads sweeper counters — atomic so the sweeper goroutine
|
||||
// doesn't contend with the /metrics handler.
|
||||
pendingUploadsSweptAcked int64 // molecule_pending_uploads_swept_total{outcome="acked"}
|
||||
pendingUploadsSweptExpired int64 // molecule_pending_uploads_swept_total{outcome="expired"}
|
||||
pendingUploadsSweepErrors int64 // molecule_pending_uploads_swept_total{outcome="error"}
|
||||
)
|
||||
|
||||
// Middleware records per-request counts and latency.
|
||||
@@ -76,6 +83,50 @@ func TrackWSConnect() { atomic.AddInt64(&activeWSConns, 1) }
|
||||
// Call from the WebSocket disconnect / cleanup path.
|
||||
func TrackWSDisconnect() { atomic.AddInt64(&activeWSConns, -1) }
|
||||
|
||||
// phantomBusyResets is the cumulative count of workspace rows the
|
||||
// phantom-busy sweep reset (active_tasks=0 → active_tasks=0+counter
|
||||
// cleared). Surfaced as molecule_phantom_busy_resets_total — a high
|
||||
// reset rate signals a regression in task-lifecycle accounting (most
|
||||
// often: missing env vars cause claude --print to time out, the
|
||||
// agent loop never decrements active_tasks, and the sweep cleans up
|
||||
// the counter ~10 min later). Issue #2865.
|
||||
var phantomBusyResets int64
|
||||
|
||||
// TrackPhantomBusyReset increments the phantom-busy reset counter.
|
||||
// Called from sweepPhantomBusy in workspace-server/internal/scheduler/
|
||||
// after each row whose active_tasks was reset to 0. Idempotent +
|
||||
// goroutine-safe; called once per row per sweep tick.
|
||||
func TrackPhantomBusyReset() { atomic.AddInt64(&phantomBusyResets, 1) }
|
||||
|
||||
// PendingUploadsSwept records a successful sweep cycle. acked/expired
|
||||
// are added to the per-outcome counters so dashboards can spot the
|
||||
// stuck-fetch pattern (high expired, low acked) vs healthy churn.
|
||||
func PendingUploadsSwept(acked, expired int) {
|
||||
if acked > 0 {
|
||||
atomic.AddInt64(&pendingUploadsSweptAcked, int64(acked))
|
||||
}
|
||||
if expired > 0 {
|
||||
atomic.AddInt64(&pendingUploadsSweptExpired, int64(expired))
|
||||
}
|
||||
}
|
||||
|
||||
// PendingUploadsSweepError records a sweeper-cycle failure (transient
|
||||
// DB error etc). Counted separately so the rate of errored sweeps is
|
||||
// observable independent of how many rows the successful sweeps deleted.
|
||||
func PendingUploadsSweepError() {
|
||||
atomic.AddInt64(&pendingUploadsSweepErrors, 1)
|
||||
}
|
||||
|
||||
// PendingUploadsSweepCounts returns the current (acked, expired, error)
|
||||
// totals. Exposed for tests that need a deterministic delta probe of
|
||||
// the sweeper's metric writes — the /metrics endpoint is the production
|
||||
// observability surface; this is a unit-test escape hatch.
|
||||
func PendingUploadsSweepCounts() (acked, expired, errored int64) {
|
||||
return atomic.LoadInt64(&pendingUploadsSweptAcked),
|
||||
atomic.LoadInt64(&pendingUploadsSweptExpired),
|
||||
atomic.LoadInt64(&pendingUploadsSweepErrors)
|
||||
}
|
||||
|
||||
// Handler returns a Gin handler that serialises all collected metrics in
|
||||
// Prometheus text exposition format (v0.0.4). Mount this at GET /metrics.
|
||||
func Handler() gin.HandlerFunc {
|
||||
@@ -144,6 +195,21 @@ func Handler() gin.HandlerFunc {
|
||||
writeln(w, "# HELP molecule_websocket_connections_active Number of active WebSocket connections.")
|
||||
writeln(w, "# TYPE molecule_websocket_connections_active gauge")
|
||||
fmt.Fprintf(w, "molecule_websocket_connections_active %d\n", atomic.LoadInt64(&activeWSConns))
|
||||
|
||||
// ── Molecule AI scheduler ──────────────────────────────────────────────
|
||||
writeln(w, "# HELP molecule_phantom_busy_resets_total Cumulative count of workspace rows reset by the phantom-busy sweep (active_tasks cleared after >10 min of activity_log silence). High reset rate signals task-lifecycle accounting regressions — see issue #2865.")
|
||||
writeln(w, "# TYPE molecule_phantom_busy_resets_total counter")
|
||||
fmt.Fprintf(w, "molecule_phantom_busy_resets_total %d\n", atomic.LoadInt64(&phantomBusyResets))
|
||||
|
||||
// ── Pending-uploads sweeper ────────────────────────────────────────────
|
||||
writeln(w, "# HELP molecule_pending_uploads_swept_total Pending-uploads rows deleted by the GC sweeper, by outcome.")
|
||||
writeln(w, "# TYPE molecule_pending_uploads_swept_total counter")
|
||||
fmt.Fprintf(w, "molecule_pending_uploads_swept_total{outcome=\"acked\"} %d\n",
|
||||
atomic.LoadInt64(&pendingUploadsSweptAcked))
|
||||
fmt.Fprintf(w, "molecule_pending_uploads_swept_total{outcome=\"expired\"} %d\n",
|
||||
atomic.LoadInt64(&pendingUploadsSweptExpired))
|
||||
fmt.Fprintf(w, "molecule_pending_uploads_swept_total{outcome=\"error\"} %d\n",
|
||||
atomic.LoadInt64(&pendingUploadsSweepErrors))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
package metrics
|
||||
|
||||
// Tests for the phantom-busy reset counter wired up by issue #2865.
|
||||
// The counter is exposed at /metrics as
|
||||
// molecule_phantom_busy_resets_total. A high steady-state value
|
||||
// signals task-lifecycle accounting regressions in the agent loop —
|
||||
// see scheduler.sweepPhantomBusy for the writer.
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// resetForTest zeroes the counter so a single test's TrackPhantomBusyReset
|
||||
// calls don't compound onto a previous test's run. metrics.go's package-
|
||||
// level state means every test that touches the counter must reset.
|
||||
func resetForTest() {
|
||||
atomic.StoreInt64(&phantomBusyResets, 0)
|
||||
}
|
||||
|
||||
func TestTrackPhantomBusyReset_IncrementsCounter(t *testing.T) {
|
||||
resetForTest()
|
||||
for i := 0; i < 7; i++ {
|
||||
TrackPhantomBusyReset()
|
||||
}
|
||||
got := atomic.LoadInt64(&phantomBusyResets)
|
||||
if got != 7 {
|
||||
t.Errorf("counter after 7 calls = %d, want 7", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrackPhantomBusyReset_RaceFreeUnderConcurrentWrites(t *testing.T) {
|
||||
resetForTest()
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 50
|
||||
const callsPerGoroutine = 200
|
||||
wg.Add(goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < callsPerGoroutine; j++ {
|
||||
TrackPhantomBusyReset()
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
want := int64(goroutines * callsPerGoroutine)
|
||||
got := atomic.LoadInt64(&phantomBusyResets)
|
||||
if got != want {
|
||||
t.Errorf("counter under concurrent writes = %d, want %d (lost increments → atomic broken)",
|
||||
got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_ExposesPhantomBusyResetsCounter(t *testing.T) {
|
||||
resetForTest()
|
||||
for i := 0; i < 3; i++ {
|
||||
TrackPhantomBusyReset()
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.GET("/metrics", Handler())
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/metrics", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
body := w.Body.String()
|
||||
// HELP + TYPE lines must precede the metric (Prometheus text exposition format).
|
||||
if !strings.Contains(body, "# HELP molecule_phantom_busy_resets_total") {
|
||||
t.Errorf("metrics output missing HELP line for molecule_phantom_busy_resets_total:\n%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "# TYPE molecule_phantom_busy_resets_total counter") {
|
||||
t.Errorf("metrics output missing TYPE line for molecule_phantom_busy_resets_total:\n%s", body)
|
||||
}
|
||||
if !strings.Contains(body, "molecule_phantom_busy_resets_total 3\n") {
|
||||
t.Errorf("metrics output missing counter value 3:\n%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_PhantomBusyResetsZeroByDefault(t *testing.T) {
|
||||
// Fresh process should report 0 — pin the contract so a future
|
||||
// refactor that lazy-inits the counter to nil doesn't silently
|
||||
// drop the metric from /metrics.
|
||||
resetForTest()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.GET("/metrics", Handler())
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/metrics", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if !strings.Contains(w.Body.String(), "molecule_phantom_busy_resets_total 0\n") {
|
||||
t.Errorf("metric must report 0 by default:\n%s", w.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package models
|
||||
|
||||
// runtime_defaults.go — single source of truth for per-runtime defaults
|
||||
// the platform applies when the operator/agent didn't supply a value.
|
||||
//
|
||||
// Why this lives in models/ (not handlers/): default selection is a
|
||||
// pure data fact about the runtime, not handler logic. Multiple
|
||||
// callers (Create-workspace handler, org-import handler, future
|
||||
// auto-provision paths) need the same answer; concentrating the
|
||||
// rule here means one edit when a runtime's default changes.
|
||||
//
|
||||
// Related work (RFC #2873): this is the seed for a future
|
||||
// `RuntimeConfig` interface that will also expose `ProvisioningTimeout()`,
|
||||
// `CapabilitiesSupported()`, and other per-runtime facts. For now the
|
||||
// surface is one helper — extracted from the duplicate branch in
|
||||
// workspace_provision.go:537 and org_import.go:54 that diverged silently
|
||||
// during refactors before this consolidation.
|
||||
|
||||
// DefaultModel returns the model slug to use when a workspace is
|
||||
// created without an explicit model and the runtime can't infer one
|
||||
// from its own config.
|
||||
//
|
||||
// - claude-code: "sonnet" — Anthropic's CLI accepts the short
|
||||
// name and resolves it via the operator's anthropic-oauth or
|
||||
// ANTHROPIC_API_KEY chain.
|
||||
// - everything else (hermes, langgraph, crewai, autogen, deepagents,
|
||||
// codex, openclaw, gemini-cli, external, ""): a fully-qualified
|
||||
// vendor:model slug that the universal MODEL_PROVIDER chain in
|
||||
// molecule-core PR #247 can route via per-vendor required_env.
|
||||
//
|
||||
// The function never returns an empty string; an unknown runtime
|
||||
// gets the universal default rather than failing closed (matches the
|
||||
// pre-refactor behavior — both call sites used the same fallback).
|
||||
func DefaultModel(runtime string) string {
|
||||
if runtime == "claude-code" {
|
||||
return "sonnet"
|
||||
}
|
||||
return "anthropic:claude-opus-4-7"
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package models
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestDefaultModel pins the contract: known runtimes return their
|
||||
// expected default; unknowns and the empty string fall through to the
|
||||
// universal default. Add new runtimes here as `case` entries — pre-fix
|
||||
// adding a runtime required two source edits + an audit; post-SSOT it
|
||||
// requires one entry in DefaultModel + one assertion here.
|
||||
func TestDefaultModel(t *testing.T) {
|
||||
cases := []struct {
|
||||
runtime string
|
||||
want string
|
||||
}{
|
||||
// Known runtimes.
|
||||
{"claude-code", "sonnet"},
|
||||
|
||||
// Universal fallback for everything else. Each runtime is named
|
||||
// explicitly so a future drift (e.g., adding a hermes-specific
|
||||
// branch) shows up as a failure on the runtime that drifted, not
|
||||
// as a generic "unknown" failure.
|
||||
{"hermes", "anthropic:claude-opus-4-7"},
|
||||
{"langgraph", "anthropic:claude-opus-4-7"},
|
||||
{"crewai", "anthropic:claude-opus-4-7"},
|
||||
{"autogen", "anthropic:claude-opus-4-7"},
|
||||
{"deepagents", "anthropic:claude-opus-4-7"},
|
||||
{"codex", "anthropic:claude-opus-4-7"},
|
||||
{"openclaw", "anthropic:claude-opus-4-7"},
|
||||
{"gemini-cli", "anthropic:claude-opus-4-7"},
|
||||
{"external", "anthropic:claude-opus-4-7"},
|
||||
|
||||
// Unknown / empty — fall through to universal default rather
|
||||
// than failing closed. Pre-refactor both call sites also fell
|
||||
// through; pinning the existing behavior, not changing it.
|
||||
{"", "anthropic:claude-opus-4-7"},
|
||||
{"some-future-runtime", "anthropic:claude-opus-4-7"},
|
||||
{"CLAUDE-CODE", "anthropic:claude-opus-4-7"}, // case-sensitive — matches prior behavior
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.runtime, func(t *testing.T) {
|
||||
got := DefaultModel(tc.runtime)
|
||||
if got != tc.want {
|
||||
t.Errorf("DefaultModel(%q) = %q, want %q", tc.runtime, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultModel_NeverEmpty — invariant: no input produces an empty
|
||||
// string. The handlers that consume this would write empty into
|
||||
// config.yaml, which the runtime then can't dispatch — pinning the
|
||||
// non-empty contract here protects against a future "return early on
|
||||
// unknown runtime" change that would silently break workspace creation.
|
||||
func TestDefaultModel_NeverEmpty(t *testing.T) {
|
||||
for _, runtime := range []string{
|
||||
"", "claude-code", "hermes", "unknown-runtime",
|
||||
} {
|
||||
if got := DefaultModel(runtime); got == "" {
|
||||
t.Errorf("DefaultModel(%q) returned empty string", runtime)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package pendinguploads
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// StartSweeperWithIntervalForTest exposes startSweeperWithInterval to
|
||||
// the external test package. The production code uses StartSweeper
|
||||
// (which pins the canonical SweepInterval); tests pin a short interval
|
||||
// to exercise the ticker-driven cycle without burning real wall-clock
|
||||
// time. The Go convention `export_test.go` keeps this seam OUT of the
|
||||
// production binary — files ending in _test.go are stripped at build
|
||||
// time, so this re-export only exists during `go test`.
|
||||
func StartSweeperWithIntervalForTest(ctx context.Context, storage Storage, ackRetention, interval time.Duration) {
|
||||
startSweeperWithInterval(ctx, storage, ackRetention, interval)
|
||||
}
|
||||
@@ -0,0 +1,394 @@
|
||||
// Package pendinguploads is the platform-side staging layer for chat file
|
||||
// uploads bound for poll-mode workspaces (delivery_mode='poll', no public
|
||||
// callback URL — typically external runtimes on a laptop / behind NAT).
|
||||
//
|
||||
// In push-mode the platform synchronously POSTs the multipart body to the
|
||||
// workspace's /internal/chat/uploads/ingest endpoint and forgets about it.
|
||||
// Poll-mode has no callback URL to forward to, so the platform parses the
|
||||
// multipart on this side, persists each file as one pending_uploads row,
|
||||
// and lets the workspace pull it on its next inbox poll cycle.
|
||||
//
|
||||
// The Storage interface keeps the bytes-vs-metadata split clean: today
|
||||
// content is stored inline as bytea on the pending_uploads row, but the
|
||||
// shape lets a future PR (RFC #2789, S3-backed shared storage) swap to
|
||||
// object storage by adding a new Storage implementation without touching
|
||||
// any of the handler-layer callers.
|
||||
//
|
||||
// Lifecycle:
|
||||
//
|
||||
// Put — handler creates a row with the file content; assigns file_id.
|
||||
// Get — GET /workspaces/:id/pending-uploads/:fid/content reads bytes.
|
||||
// MarkFetched — stamps fetched_at on the row (Phase 3 observability).
|
||||
// Ack — POST /workspaces/:id/pending-uploads/:fid/ack;
|
||||
// terminal happy-path state. After ack, Get returns ErrNotFound.
|
||||
// GC sweep deletes acked rows after a retention window.
|
||||
//
|
||||
// Hard TTL: every row has an expires_at default of created_at + 24h. After
|
||||
// expiration the row is GC'd by Phase 3's sweep cron regardless of ack
|
||||
// state. Get on an expired row returns ErrNotFound — the workspace's next
|
||||
// poll will see the underlying activity_logs row was orphaned and the
|
||||
// agent surfaces "file expired" to the user.
|
||||
package pendinguploads
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Per-file size cap. Mirrors workspace-side ingest_handler
|
||||
// (workspace/internal_chat_uploads.py:198). Pinned at the DB level via
|
||||
// the size_bytes CHECK constraint; this Go-side constant exists so the
|
||||
// Put implementation can reject before round-tripping to Postgres.
|
||||
const MaxFileBytes = 25 * 1024 * 1024
|
||||
|
||||
// ErrNotFound is returned by Get / MarkFetched / Ack when the row is
|
||||
// absent. Callers turn this into HTTP 404. Treat acked + expired rows
|
||||
// as not-found so the workspace can never re-fetch a file we've
|
||||
// considered handed-off.
|
||||
var ErrNotFound = errors.New("pendinguploads: row not found, expired, or already acked")
|
||||
|
||||
// ErrTooLarge is returned by Put when content exceeds MaxFileBytes.
|
||||
// Callers turn this into HTTP 413. Pre-DB check so we don't push a
|
||||
// 25 MB+1 byte payload through Postgres just to have the CHECK reject it.
|
||||
var ErrTooLarge = errors.New("pendinguploads: content exceeds per-file cap")
|
||||
|
||||
// Record carries the full row including content. Returned by Get;
|
||||
// the GET /content handler streams Record.Content as the response body.
|
||||
type Record struct {
|
||||
FileID uuid.UUID
|
||||
WorkspaceID uuid.UUID
|
||||
Content []byte
|
||||
Filename string
|
||||
Mimetype string
|
||||
SizeBytes int64
|
||||
CreatedAt time.Time
|
||||
FetchedAt *time.Time // nil before first MarkFetched
|
||||
AckedAt *time.Time // nil before Ack (Get returns ErrNotFound after)
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// SweepResult is the per-cycle accounting from Sweep. Both counts are
|
||||
// non-negative; Total is just Acked + Expired for log/metrics
|
||||
// convenience. Phase 3 metrics expose these as separate counters so
|
||||
// dashboards can spot a stuck-ack pattern (high Expired, low Acked) vs.
|
||||
// healthy churn (Acked dominates).
|
||||
type SweepResult struct {
|
||||
Acked int // rows deleted because acked_at + retention elapsed
|
||||
Expired int // rows deleted because expires_at < now AND never acked
|
||||
}
|
||||
|
||||
// Total returns the sum of Acked + Expired — convenient for log lines.
|
||||
func (r SweepResult) Total() int { return r.Acked + r.Expired }
|
||||
|
||||
// PutItem is one file in a PutBatch call. Same per-field rules as Put —
|
||||
// empty content, missing filename, or content > MaxFileBytes is rejected
|
||||
// up-front so a bad item in the batch doesn't poison the transaction.
|
||||
type PutItem struct {
|
||||
Content []byte
|
||||
Filename string
|
||||
Mimetype string
|
||||
}
|
||||
|
||||
// Storage is the platform-side persistence boundary for poll-mode chat
|
||||
// uploads. The Postgres implementation backs all callers today; an S3-
|
||||
// backed implementation can drop in once RFC #2789 lands by making
|
||||
// content storage out-of-line and updating the Postgres-only metadata
|
||||
// columns.
|
||||
type Storage interface {
|
||||
// Put creates a row for one file targeting workspaceID and returns
|
||||
// the assigned file_id. content is bounded by MaxFileBytes;
|
||||
// filename / mimetype are stored verbatim — caller is responsible
|
||||
// for sanitization (matches workspace-side rule, see
|
||||
// internal_chat_uploads.py:sanitize_filename). Empty filename and
|
||||
// content > MaxFileBytes return errors before any DB write.
|
||||
Put(ctx context.Context, workspaceID uuid.UUID, content []byte, filename, mimetype string) (uuid.UUID, error)
|
||||
|
||||
// PutBatch inserts N uploads atomically — either all rows commit or
|
||||
// none do. Returns assigned file_ids in input order on success;
|
||||
// returns an error and does NOT insert any row on failure.
|
||||
//
|
||||
// Use this from multi-file upload handlers so a per-row failure on
|
||||
// row K doesn't leave rows 1..K-1 orphaned in the table (a client
|
||||
// retry would then double-insert them on success). All-or-nothing
|
||||
// semantics match the multipart request the canvas sends — either
|
||||
// the whole batch succeeds or the user re-uploads.
|
||||
PutBatch(ctx context.Context, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error)
|
||||
|
||||
// Get returns the full row including content. Returns ErrNotFound
|
||||
// when the row is absent, acked, or past expires_at. Caller should
|
||||
// not differentiate the three cases in the response — from the
|
||||
// workspace's perspective they all mean "not available, give up."
|
||||
Get(ctx context.Context, fileID uuid.UUID) (Record, error)
|
||||
|
||||
// MarkFetched stamps fetched_at on the row. Idempotent — repeated
|
||||
// calls update fetched_at to the latest timestamp. Returns
|
||||
// ErrNotFound if the row is absent / acked / expired.
|
||||
MarkFetched(ctx context.Context, fileID uuid.UUID) error
|
||||
|
||||
// Ack stamps acked_at on the row. Idempotent on the row state
|
||||
// (acked_at is only set the first time so workspace double-acks
|
||||
// don't move the timestamp). Returns ErrNotFound if the row is
|
||||
// absent or already expired; on already-acked, returns nil so
|
||||
// the workspace's at-least-once retry succeeds without an error.
|
||||
Ack(ctx context.Context, fileID uuid.UUID) error
|
||||
|
||||
// Sweep deletes rows past their retention window:
|
||||
// - acked rows older than ackRetention (give the workspace a
|
||||
// window to re-fetch in case it processed but failed to write
|
||||
// the file before crashing — at-least-once behavior).
|
||||
// - unacked rows past expires_at (the platform's hard TTL — 24h
|
||||
// by default; a workspace that hasn't fetched by then is
|
||||
// considered dead from the upload's perspective).
|
||||
// Returns the per-category deletion counts for observability.
|
||||
// Errors are surfaced to the caller; a transient DB error must NOT
|
||||
// crash the sweeper loop (it just retries on the next tick).
|
||||
Sweep(ctx context.Context, ackRetention time.Duration) (SweepResult, error)
|
||||
}
|
||||
|
||||
// PostgresStorage is the production Storage implementation backed by
|
||||
// the pending_uploads table.
|
||||
type PostgresStorage struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewPostgres returns a Storage backed by db. db must be a connected
|
||||
// pool; this constructor does no I/O.
|
||||
func NewPostgres(db *sql.DB) *PostgresStorage {
|
||||
return &PostgresStorage{db: db}
|
||||
}
|
||||
|
||||
// Compile-time check that PostgresStorage satisfies Storage.
|
||||
var _ Storage = (*PostgresStorage)(nil)
|
||||
|
||||
func (p *PostgresStorage) Put(ctx context.Context, workspaceID uuid.UUID, content []byte, filename, mimetype string) (uuid.UUID, error) {
|
||||
if len(content) == 0 {
|
||||
return uuid.Nil, fmt.Errorf("pendinguploads: empty content")
|
||||
}
|
||||
if len(content) > MaxFileBytes {
|
||||
return uuid.Nil, ErrTooLarge
|
||||
}
|
||||
if filename == "" {
|
||||
return uuid.Nil, fmt.Errorf("pendinguploads: empty filename")
|
||||
}
|
||||
// Filename length cap is enforced both here (early reject) and at
|
||||
// the DB layer (CHECK constraint) so a buggy caller can't write a
|
||||
// 200-char filename that Phase 2's URI rewrite would then truncate.
|
||||
if len(filename) > 100 {
|
||||
return uuid.Nil, fmt.Errorf("pendinguploads: filename exceeds 100 chars")
|
||||
}
|
||||
|
||||
var fileID uuid.UUID
|
||||
err := p.db.QueryRowContext(ctx, `
|
||||
INSERT INTO pending_uploads (workspace_id, content, size_bytes, filename, mimetype)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING file_id
|
||||
`, workspaceID, content, int64(len(content)), filename, mimetype).Scan(&fileID)
|
||||
if err != nil {
|
||||
return uuid.Nil, fmt.Errorf("pendinguploads: insert: %w", err)
|
||||
}
|
||||
return fileID, nil
|
||||
}
|
||||
|
||||
// PutBatch inserts every item atomically inside a single Tx. On any
|
||||
// per-item validation or per-row INSERT error the Tx is rolled back and
|
||||
// the caller sees the error without any rows committed — no partial
|
||||
// orphans for a multi-file upload that fails mid-batch.
|
||||
//
|
||||
// Validation runs BEFORE BEGIN so a bad input shape (empty content,
|
||||
// over-cap size) doesn't even open a Tx. Once we're in the Tx, the only
|
||||
// failures expected are DB-side (broken connection, statement timeout)
|
||||
// — those abort cleanly via Rollback.
|
||||
func (p *PostgresStorage) PutBatch(ctx context.Context, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error) {
|
||||
if len(items) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
for i, it := range items {
|
||||
if len(it.Content) == 0 {
|
||||
return nil, fmt.Errorf("pendinguploads: item %d: empty content", i)
|
||||
}
|
||||
if len(it.Content) > MaxFileBytes {
|
||||
return nil, ErrTooLarge
|
||||
}
|
||||
if it.Filename == "" {
|
||||
return nil, fmt.Errorf("pendinguploads: item %d: empty filename", i)
|
||||
}
|
||||
if len(it.Filename) > 100 {
|
||||
return nil, fmt.Errorf("pendinguploads: item %d: filename exceeds 100 chars", i)
|
||||
}
|
||||
}
|
||||
|
||||
tx, err := p.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pendinguploads: begin tx: %w", err)
|
||||
}
|
||||
// Defer-rollback is safe even after a successful Commit — the second
|
||||
// Rollback is a no-op (database/sql tracks tx state).
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
out := make([]uuid.UUID, 0, len(items))
|
||||
for i, it := range items {
|
||||
var fid uuid.UUID
|
||||
err := tx.QueryRowContext(ctx, `
|
||||
INSERT INTO pending_uploads (workspace_id, content, size_bytes, filename, mimetype)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING file_id
|
||||
`, workspaceID, it.Content, int64(len(it.Content)), it.Filename, it.Mimetype).Scan(&fid)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pendinguploads: batch insert item %d: %w", i, err)
|
||||
}
|
||||
out = append(out, fid)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("pendinguploads: commit batch: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (p *PostgresStorage) Get(ctx context.Context, fileID uuid.UUID) (Record, error) {
|
||||
// The expires_at + acked_at filter in the WHERE clause means a
|
||||
// caller sees ErrNotFound for absent / acked / expired without
|
||||
// needing per-case branching. Trade-off: we can't differentiate
|
||||
// in metrics, but the workspace's response is the same in all
|
||||
// three cases ("file gone, give up") so the granularity isn't
|
||||
// useful at this layer. Phase 3 dashboards aggregate row-state
|
||||
// counts directly off the table.
|
||||
var r Record
|
||||
err := p.db.QueryRowContext(ctx, `
|
||||
SELECT file_id, workspace_id, content, filename, mimetype,
|
||||
size_bytes, created_at, fetched_at, acked_at, expires_at
|
||||
FROM pending_uploads
|
||||
WHERE file_id = $1
|
||||
AND acked_at IS NULL
|
||||
AND expires_at > now()
|
||||
`, fileID).Scan(
|
||||
&r.FileID, &r.WorkspaceID, &r.Content, &r.Filename, &r.Mimetype,
|
||||
&r.SizeBytes, &r.CreatedAt, &r.FetchedAt, &r.AckedAt, &r.ExpiresAt,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return Record{}, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return Record{}, fmt.Errorf("pendinguploads: select: %w", err)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (p *PostgresStorage) MarkFetched(ctx context.Context, fileID uuid.UUID) error {
|
||||
// UPDATE on the same gating predicate as Get — keeps the "absent
|
||||
// or acked or expired = ErrNotFound" contract symmetric. Without
|
||||
// the predicate a workspace could re-stamp fetched_at on an acked
|
||||
// row, which would mislead Phase 3's stuck-fetch dashboard.
|
||||
res, err := p.db.ExecContext(ctx, `
|
||||
UPDATE pending_uploads
|
||||
SET fetched_at = now()
|
||||
WHERE file_id = $1
|
||||
AND acked_at IS NULL
|
||||
AND expires_at > now()
|
||||
`, fileID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pendinguploads: mark_fetched: %w", err)
|
||||
}
|
||||
n, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("pendinguploads: mark_fetched rows: %w", err)
|
||||
}
|
||||
if n == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgresStorage) Ack(ctx context.Context, fileID uuid.UUID) error {
|
||||
// Set acked_at only if currently NULL — workspace at-least-once
|
||||
// retries don't move the timestamp, so dashboards see the first
|
||||
// successful ack as the "delivery time." Two-clause WHERE: row
|
||||
// must exist and not be expired; acked-but-still-in-window is
|
||||
// returned as success (idempotent retry).
|
||||
res, err := p.db.ExecContext(ctx, `
|
||||
UPDATE pending_uploads
|
||||
SET acked_at = now()
|
||||
WHERE file_id = $1
|
||||
AND acked_at IS NULL
|
||||
AND expires_at > now()
|
||||
`, fileID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pendinguploads: ack: %w", err)
|
||||
}
|
||||
n, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("pendinguploads: ack rows: %w", err)
|
||||
}
|
||||
if n == 1 {
|
||||
return nil
|
||||
}
|
||||
// Zero-rows-affected: either the row doesn't exist / has expired,
|
||||
// OR it was already acked. Re-query to disambiguate so the
|
||||
// idempotent-retry case returns nil instead of ErrNotFound.
|
||||
var ackedAt sql.NullTime
|
||||
err = p.db.QueryRowContext(ctx, `
|
||||
SELECT acked_at FROM pending_uploads
|
||||
WHERE file_id = $1 AND expires_at > now()
|
||||
`, fileID).Scan(&ackedAt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("pendinguploads: ack disambiguate: %w", err)
|
||||
}
|
||||
if ackedAt.Valid {
|
||||
// Already acked — idempotent success.
|
||||
return nil
|
||||
}
|
||||
// Predicate matched a non-acked, non-expired row but RowsAffected
|
||||
// was 0. This means the row was concurrently modified between the
|
||||
// UPDATE and the SELECT (extremely rare; e.g. a Phase 3 sweep
|
||||
// raced with the ACK). Treat as success — the row is gone, but
|
||||
// the workspace's intent ("I'm done with this file") was honored.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sweep deletes acked rows past their retention window plus any
|
||||
// unacked rows whose hard TTL has elapsed. Single round-trip: a CTE
|
||||
// captures the deletion in one DELETE … RETURNING and the outer
|
||||
// SELECT sums by category. Cheaper and tighter than two round trips,
|
||||
// and atomic w.r.t. concurrent writes (the WHERE predicate sees a
|
||||
// consistent snapshot via Postgres MVCC).
|
||||
//
|
||||
// ackRetention=0 deletes all acked rows immediately; values <0 are
|
||||
// clamped to 0 for safety. Caller defaults are documented at
|
||||
// StartSweeper's DefaultAckRetention.
|
||||
func (p *PostgresStorage) Sweep(ctx context.Context, ackRetention time.Duration) (SweepResult, error) {
|
||||
if ackRetention < 0 {
|
||||
ackRetention = 0
|
||||
}
|
||||
// make_interval expects integer seconds — Postgres accepts a
|
||||
// floating point but we deliberately round to the nearest second
|
||||
// so test fixtures pin a deterministic value across PG versions.
|
||||
retentionSecs := int64(ackRetention.Seconds())
|
||||
|
||||
var acked, expired int
|
||||
err := p.db.QueryRowContext(ctx, `
|
||||
WITH deleted AS (
|
||||
DELETE FROM pending_uploads
|
||||
WHERE (acked_at IS NOT NULL AND acked_at < now() - make_interval(secs => $1))
|
||||
OR (acked_at IS NULL AND expires_at < now())
|
||||
RETURNING (acked_at IS NOT NULL) AS was_acked
|
||||
)
|
||||
SELECT
|
||||
COALESCE(SUM(CASE WHEN was_acked THEN 1 ELSE 0 END), 0)::int AS acked,
|
||||
COALESCE(SUM(CASE WHEN NOT was_acked THEN 1 ELSE 0 END), 0)::int AS expired
|
||||
FROM deleted
|
||||
`, retentionSecs).Scan(&acked, &expired)
|
||||
if err != nil {
|
||||
return SweepResult{}, fmt.Errorf("pendinguploads: sweep: %w", err)
|
||||
}
|
||||
return SweepResult{Acked: acked, Expired: expired}, nil
|
||||
}
|
||||
@@ -0,0 +1,733 @@
|
||||
package pendinguploads_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/pendinguploads"
|
||||
)
|
||||
|
||||
// Tests pin the SQL the handler relies on. Drift detection: if the
|
||||
// migration changes column order / predicate shape, sqlmock's
|
||||
// QueryMatcherEqual + ExpectQuery / ExpectExec on the literal text
|
||||
// fails the test before the handler can ship a silently-broken read.
|
||||
//
|
||||
// Why sqlmock and not testcontainers / real Postgres:
|
||||
//
|
||||
// The Storage contract is "this Go method runs THIS SQL." Real-DB
|
||||
// tests would catch SQL-syntax errors but not the contract drift
|
||||
// we care about (e.g. handler accidentally reordering columns,
|
||||
// dropping the acked_at predicate, etc.). Postgres-syntax coverage
|
||||
// lives in the migration round-trip test (Phase 4 E2E).
|
||||
|
||||
func newMockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
|
||||
t.Helper()
|
||||
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock.New: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
return db, mock
|
||||
}
|
||||
|
||||
// Single source of truth for the SQL strings — drift here = test fails;
|
||||
// matches the Go literals in storage.go exactly.
|
||||
const (
|
||||
insertSQL = `
|
||||
INSERT INTO pending_uploads (workspace_id, content, size_bytes, filename, mimetype)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING file_id
|
||||
`
|
||||
selectSQL = `
|
||||
SELECT file_id, workspace_id, content, filename, mimetype,
|
||||
size_bytes, created_at, fetched_at, acked_at, expires_at
|
||||
FROM pending_uploads
|
||||
WHERE file_id = $1
|
||||
AND acked_at IS NULL
|
||||
AND expires_at > now()
|
||||
`
|
||||
markFetchedSQL = `
|
||||
UPDATE pending_uploads
|
||||
SET fetched_at = now()
|
||||
WHERE file_id = $1
|
||||
AND acked_at IS NULL
|
||||
AND expires_at > now()
|
||||
`
|
||||
ackSQL = `
|
||||
UPDATE pending_uploads
|
||||
SET acked_at = now()
|
||||
WHERE file_id = $1
|
||||
AND acked_at IS NULL
|
||||
AND expires_at > now()
|
||||
`
|
||||
ackDisambiguateSQL = `
|
||||
SELECT acked_at FROM pending_uploads
|
||||
WHERE file_id = $1 AND expires_at > now()
|
||||
`
|
||||
sweepSQL = `
|
||||
WITH deleted AS (
|
||||
DELETE FROM pending_uploads
|
||||
WHERE (acked_at IS NOT NULL AND acked_at < now() - make_interval(secs => $1))
|
||||
OR (acked_at IS NULL AND expires_at < now())
|
||||
RETURNING (acked_at IS NOT NULL) AS was_acked
|
||||
)
|
||||
SELECT
|
||||
COALESCE(SUM(CASE WHEN was_acked THEN 1 ELSE 0 END), 0)::int AS acked,
|
||||
COALESCE(SUM(CASE WHEN NOT was_acked THEN 1 ELSE 0 END), 0)::int AS expired
|
||||
FROM deleted
|
||||
`
|
||||
)
|
||||
|
||||
// ----- Put ------------------------------------------------------------------
|
||||
|
||||
func TestPut_HappyPath_ReturnsAssignedFileID(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
wsID := uuid.New()
|
||||
expectedID := uuid.New()
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(wsID, []byte("hello"), int64(5), "report.pdf", "application/pdf").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(expectedID))
|
||||
|
||||
got, err := store.Put(context.Background(), wsID, []byte("hello"), "report.pdf", "application/pdf")
|
||||
if err != nil {
|
||||
t.Fatalf("Put: %v", err)
|
||||
}
|
||||
if got != expectedID {
|
||||
t.Errorf("file_id mismatch: got %s want %s", got, expectedID)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPut_RejectsEmptyContentBeforeDB(t *testing.T) {
|
||||
db, _ := newMockDB(t) // no expectations — must NOT round-trip
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
_, err := store.Put(context.Background(), uuid.New(), nil, "x.txt", "")
|
||||
if err == nil || !strings.Contains(err.Error(), "empty content") {
|
||||
t.Fatalf("expected empty-content error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPut_RejectsOversizeBeforeDB(t *testing.T) {
|
||||
db, _ := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
too := make([]byte, pendinguploads.MaxFileBytes+1)
|
||||
_, err := store.Put(context.Background(), uuid.New(), too, "x.txt", "")
|
||||
if !errors.Is(err, pendinguploads.ErrTooLarge) {
|
||||
t.Fatalf("expected ErrTooLarge, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPut_RejectsEmptyFilenameBeforeDB(t *testing.T) {
|
||||
db, _ := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
_, err := store.Put(context.Background(), uuid.New(), []byte("hi"), "", "")
|
||||
if err == nil || !strings.Contains(err.Error(), "empty filename") {
|
||||
t.Fatalf("expected empty-filename error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPut_RejectsLongFilenameBeforeDB(t *testing.T) {
|
||||
db, _ := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
long := strings.Repeat("a", 101)
|
||||
_, err := store.Put(context.Background(), uuid.New(), []byte("hi"), long, "")
|
||||
if err == nil || !strings.Contains(err.Error(), "exceeds 100 chars") {
|
||||
t.Fatalf("expected too-long-filename error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPut_PropagatesDBError(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(uuid.Nil, sqlmock.AnyArg(), int64(2), "x", "").
|
||||
WillReturnError(errors.New("connection refused"))
|
||||
|
||||
wsID := uuid.Nil
|
||||
_, err := store.Put(context.Background(), wsID, []byte("hi"), "x", "")
|
||||
if err == nil || !strings.Contains(err.Error(), "insert") {
|
||||
t.Fatalf("expected wrapped insert error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ----- Get ------------------------------------------------------------------
|
||||
|
||||
func TestGet_HappyPath_ReturnsFullRow(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
fid := uuid.New()
|
||||
wsID := uuid.New()
|
||||
now := time.Now().UTC()
|
||||
mock.ExpectQuery(selectSQL).
|
||||
WithArgs(fid).
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"file_id", "workspace_id", "content", "filename", "mimetype",
|
||||
"size_bytes", "created_at", "fetched_at", "acked_at", "expires_at",
|
||||
}).AddRow(
|
||||
fid, wsID, []byte("data"), "x.bin", "application/octet-stream",
|
||||
int64(4), now, nil, nil, now.Add(24*time.Hour),
|
||||
))
|
||||
|
||||
r, err := store.Get(context.Background(), fid)
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
if r.FileID != fid || r.WorkspaceID != wsID {
|
||||
t.Errorf("ids mismatch: %+v", r)
|
||||
}
|
||||
if string(r.Content) != "data" || r.SizeBytes != 4 {
|
||||
t.Errorf("content mismatch: %+v", r)
|
||||
}
|
||||
if r.FetchedAt != nil || r.AckedAt != nil {
|
||||
t.Errorf("expected nil timestamps for unfetched row, got fetched=%v acked=%v", r.FetchedAt, r.AckedAt)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGet_AbsentRow_ReturnsErrNotFound(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
fid := uuid.New()
|
||||
mock.ExpectQuery(selectSQL).
|
||||
WithArgs(fid).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
_, err := store.Get(context.Background(), fid)
|
||||
if !errors.Is(err, pendinguploads.ErrNotFound) {
|
||||
t.Fatalf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGet_DBError_WrappedAndPropagated(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
mock.ExpectQuery(selectSQL).
|
||||
WillReturnError(errors.New("connection lost"))
|
||||
|
||||
_, err := store.Get(context.Background(), uuid.New())
|
||||
if err == nil || errors.Is(err, pendinguploads.ErrNotFound) || !strings.Contains(err.Error(), "select") {
|
||||
t.Fatalf("expected wrapped select error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ----- MarkFetched ----------------------------------------------------------
|
||||
|
||||
func TestMarkFetched_HappyPath(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
fid := uuid.New()
|
||||
mock.ExpectExec(markFetchedSQL).
|
||||
WithArgs(fid).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
if err := store.MarkFetched(context.Background(), fid); err != nil {
|
||||
t.Fatalf("MarkFetched: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkFetched_AbsentOrAckedOrExpired_ReturnsErrNotFound(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
fid := uuid.New()
|
||||
mock.ExpectExec(markFetchedSQL).
|
||||
WithArgs(fid).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
err := store.MarkFetched(context.Background(), fid)
|
||||
if !errors.Is(err, pendinguploads.ErrNotFound) {
|
||||
t.Fatalf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkFetched_DBError_Wrapped(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
mock.ExpectExec(markFetchedSQL).
|
||||
WillReturnError(errors.New("pg flake"))
|
||||
|
||||
err := store.MarkFetched(context.Background(), uuid.New())
|
||||
if err == nil || errors.Is(err, pendinguploads.ErrNotFound) || !strings.Contains(err.Error(), "mark_fetched") {
|
||||
t.Fatalf("expected wrapped mark_fetched error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ----- Ack ------------------------------------------------------------------
|
||||
|
||||
func TestAck_FirstAck_StampsAckedAt(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
fid := uuid.New()
|
||||
mock.ExpectExec(ackSQL).
|
||||
WithArgs(fid).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
if err := store.Ack(context.Background(), fid); err != nil {
|
||||
t.Fatalf("Ack: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAck_AlreadyAcked_IdempotentSuccess(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
fid := uuid.New()
|
||||
// First UPDATE matches zero rows (already acked).
|
||||
mock.ExpectExec(ackSQL).
|
||||
WithArgs(fid).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
// Disambiguation SELECT finds the row with acked_at non-null.
|
||||
mock.ExpectQuery(ackDisambiguateSQL).
|
||||
WithArgs(fid).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"acked_at"}).AddRow(time.Now().UTC()))
|
||||
|
||||
if err := store.Ack(context.Background(), fid); err != nil {
|
||||
t.Fatalf("expected idempotent success on already-acked, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAck_AbsentOrExpired_ReturnsErrNotFound(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
fid := uuid.New()
|
||||
mock.ExpectExec(ackSQL).
|
||||
WithArgs(fid).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectQuery(ackDisambiguateSQL).
|
||||
WithArgs(fid).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
err := store.Ack(context.Background(), fid)
|
||||
if !errors.Is(err, pendinguploads.ErrNotFound) {
|
||||
t.Fatalf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAck_RaceWithSweep_ReturnsSuccess(t *testing.T) {
|
||||
// UPDATE saw 0 rows AND the disambiguate SELECT saw a row with
|
||||
// acked_at IS NULL — only possible if the GC sweep raced between
|
||||
// the two queries. The contract says we honor the workspace's ACK
|
||||
// intent and return success.
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
fid := uuid.New()
|
||||
mock.ExpectExec(ackSQL).
|
||||
WithArgs(fid).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectQuery(ackDisambiguateSQL).
|
||||
WithArgs(fid).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"acked_at"}).AddRow(nil))
|
||||
|
||||
if err := store.Ack(context.Background(), fid); err != nil {
|
||||
t.Fatalf("expected race success, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAck_DBErrorOnUpdate_Wrapped(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
mock.ExpectExec(ackSQL).
|
||||
WillReturnError(errors.New("conn refused"))
|
||||
|
||||
err := store.Ack(context.Background(), uuid.New())
|
||||
if err == nil || !strings.Contains(err.Error(), "ack:") {
|
||||
t.Fatalf("expected wrapped ack error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkFetched_RowsAffectedError_Wrapped(t *testing.T) {
|
||||
// Some drivers (or Result wrappers) return an error from
|
||||
// RowsAffected() even when ExecContext succeeded — the contract
|
||||
// says we surface that as a wrapped error rather than silently
|
||||
// treating it as 0 rows (= ErrNotFound, which would mislead the
|
||||
// workspace into giving up on a possibly-fetched row).
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
mock.ExpectExec(markFetchedSQL).
|
||||
WillReturnResult(sqlmock.NewErrorResult(errors.New("driver doesn't support RowsAffected")))
|
||||
|
||||
err := store.MarkFetched(context.Background(), uuid.New())
|
||||
if err == nil || !strings.Contains(err.Error(), "mark_fetched rows") {
|
||||
t.Fatalf("expected wrapped rows-affected error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAck_RowsAffectedError_Wrapped(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
mock.ExpectExec(ackSQL).
|
||||
WillReturnResult(sqlmock.NewErrorResult(errors.New("driver doesn't support RowsAffected")))
|
||||
|
||||
err := store.Ack(context.Background(), uuid.New())
|
||||
if err == nil || !strings.Contains(err.Error(), "ack rows") {
|
||||
t.Fatalf("expected wrapped rows-affected error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAck_DBErrorOnDisambiguate_Wrapped(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
fid := uuid.New()
|
||||
mock.ExpectExec(ackSQL).
|
||||
WithArgs(fid).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectQuery(ackDisambiguateSQL).
|
||||
WithArgs(fid).
|
||||
WillReturnError(errors.New("connection refused"))
|
||||
|
||||
err := store.Ack(context.Background(), fid)
|
||||
if err == nil || !strings.Contains(err.Error(), "disambiguate") {
|
||||
t.Fatalf("expected wrapped disambiguate error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ----- Sweep ----------------------------------------------------------------
|
||||
|
||||
func TestSweep_DeletesAckedAndExpired_ReturnsCounts(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
mock.ExpectQuery(sweepSQL).
|
||||
WithArgs(int64(3600)). // 1h retention
|
||||
WillReturnRows(sqlmock.NewRows([]string{"acked", "expired"}).AddRow(7, 2))
|
||||
|
||||
res, err := store.Sweep(context.Background(), time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("Sweep: %v", err)
|
||||
}
|
||||
if res.Acked != 7 || res.Expired != 2 || res.Total() != 9 {
|
||||
t.Errorf("got %+v want acked=7 expired=2 total=9", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSweep_NothingToDelete_ReturnsZero(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
mock.ExpectQuery(sweepSQL).
|
||||
WithArgs(int64(3600)).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"acked", "expired"}).AddRow(0, 0))
|
||||
|
||||
res, err := store.Sweep(context.Background(), time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("Sweep: %v", err)
|
||||
}
|
||||
if res.Total() != 0 {
|
||||
t.Errorf("got %+v, want zero result", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSweep_NegativeRetentionClampedToZero(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
// Negative retention must clamp to 0; the SQL gets `secs => 0` so an
|
||||
// acked-just-now row is eligible for deletion immediately. Pinned
|
||||
// here because passing the raw negative through `make_interval` would
|
||||
// silently shift acked_at → future and effectively retain rows
|
||||
// forever — exactly the wrong behavior for a "delete more aggressively"
|
||||
// caller.
|
||||
mock.ExpectQuery(sweepSQL).
|
||||
WithArgs(int64(0)).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"acked", "expired"}).AddRow(3, 0))
|
||||
|
||||
res, err := store.Sweep(context.Background(), -1*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("Sweep: %v", err)
|
||||
}
|
||||
if res.Acked != 3 {
|
||||
t.Errorf("got %+v want acked=3", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSweep_ZeroRetentionImmediatelyDeletesAcked(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
mock.ExpectQuery(sweepSQL).
|
||||
WithArgs(int64(0)).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"acked", "expired"}).AddRow(5, 1))
|
||||
|
||||
res, err := store.Sweep(context.Background(), 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Sweep: %v", err)
|
||||
}
|
||||
if res.Acked != 5 || res.Expired != 1 {
|
||||
t.Errorf("got %+v want acked=5 expired=1", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSweep_DBError_Wrapped(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
mock.ExpectQuery(sweepSQL).
|
||||
WithArgs(int64(60)).
|
||||
WillReturnError(errors.New("connection lost"))
|
||||
|
||||
_, err := store.Sweep(context.Background(), time.Minute)
|
||||
if err == nil || !strings.Contains(err.Error(), "sweep") {
|
||||
t.Fatalf("expected wrapped sweep error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSweepResult_TotalSumsCounts(t *testing.T) {
|
||||
r := pendinguploads.SweepResult{Acked: 4, Expired: 3}
|
||||
if r.Total() != 7 {
|
||||
t.Errorf("Total = %d, want 7", r.Total())
|
||||
}
|
||||
z := pendinguploads.SweepResult{}
|
||||
if z.Total() != 0 {
|
||||
t.Errorf("zero Total = %d, want 0", z.Total())
|
||||
}
|
||||
}
|
||||
|
||||
// ----- PutBatch -------------------------------------------------------------
|
||||
//
|
||||
// PutBatch is the multi-file atomic insert path used by uploadPollMode in
|
||||
// chat_files.go. The contract that callers rely on:
|
||||
//
|
||||
// - Either ALL rows commit, or NONE do — a per-row INSERT failure must
|
||||
// leave the table unchanged (no orphaned rows from a half-applied batch).
|
||||
// - Per-item validation runs BEFORE the Tx opens so a bad input shape
|
||||
// never wastes a BEGIN round-trip.
|
||||
// - Returned []uuid.UUID is in input order — handler maps response back
|
||||
// to the multipart Files[i].
|
||||
//
|
||||
// sqlmock's ExpectBegin / ExpectQuery / ExpectCommit / ExpectRollback let us
|
||||
// pin the exact tx-lifecycle shape; if a future refactor swaps Begin for
|
||||
// BeginTx-with-options, the test fails until we re-pin.
|
||||
|
||||
func TestPutBatch_HappyPath_AllCommitInOrder(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
wsID := uuid.New()
|
||||
id1, id2, id3 := uuid.New(), uuid.New(), uuid.New()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(wsID, []byte("aaa"), int64(3), "a.txt", "text/plain").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1))
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(wsID, []byte("bbbb"), int64(4), "b.bin", "application/octet-stream").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id2))
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(wsID, []byte("ccccc"), int64(5), "c.pdf", "application/pdf").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id3))
|
||||
mock.ExpectCommit()
|
||||
// Rollback after Commit is a no-op in database/sql; sqlmock allows it
|
||||
// when ExpectCommit was already matched, so we don't need to expect it.
|
||||
|
||||
got, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{
|
||||
{Content: []byte("aaa"), Filename: "a.txt", Mimetype: "text/plain"},
|
||||
{Content: []byte("bbbb"), Filename: "b.bin", Mimetype: "application/octet-stream"},
|
||||
{Content: []byte("ccccc"), Filename: "c.pdf", Mimetype: "application/pdf"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PutBatch: %v", err)
|
||||
}
|
||||
if len(got) != 3 || got[0] != id1 || got[1] != id2 || got[2] != id3 {
|
||||
t.Errorf("ids out of order or missing: got %v want [%s %s %s]", got, id1, id2, id3)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_EmptyItems_NoTxNoError(t *testing.T) {
|
||||
db, _ := newMockDB(t) // zero expectations — must NOT round-trip
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
got, err := store.PutBatch(context.Background(), uuid.New(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error on empty batch, got %v", err)
|
||||
}
|
||||
if got != nil {
|
||||
t.Errorf("expected nil ids on empty batch, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_RejectsEmptyContent_NoTx(t *testing.T) {
|
||||
db, _ := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
_, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{
|
||||
{Content: []byte("ok"), Filename: "a.txt"},
|
||||
{Content: nil, Filename: "b.txt"},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "item 1") || !strings.Contains(err.Error(), "empty content") {
|
||||
t.Fatalf("expected item-1 empty-content error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_RejectsOversize_ReturnsErrTooLarge(t *testing.T) {
|
||||
db, _ := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
too := make([]byte, pendinguploads.MaxFileBytes+1)
|
||||
_, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{
|
||||
{Content: []byte("ok"), Filename: "small.txt"},
|
||||
{Content: too, Filename: "huge.bin"},
|
||||
})
|
||||
if !errors.Is(err, pendinguploads.ErrTooLarge) {
|
||||
t.Fatalf("expected ErrTooLarge, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_RejectsEmptyFilename_NoTx(t *testing.T) {
|
||||
db, _ := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
_, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{
|
||||
{Content: []byte("hi"), Filename: ""},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "item 0") || !strings.Contains(err.Error(), "empty filename") {
|
||||
t.Fatalf("expected item-0 empty-filename error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_RejectsLongFilename_NoTx(t *testing.T) {
|
||||
db, _ := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
long := strings.Repeat("z", 101)
|
||||
_, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{
|
||||
{Content: []byte("hi"), Filename: "ok.txt"},
|
||||
{Content: []byte("hi"), Filename: long},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "item 1") || !strings.Contains(err.Error(), "exceeds 100 chars") {
|
||||
t.Fatalf("expected item-1 too-long-filename error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_BeginTxError_Wrapped(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
mock.ExpectBegin().WillReturnError(errors.New("conn refused"))
|
||||
|
||||
_, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{
|
||||
{Content: []byte("hi"), Filename: "a.txt"},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "begin tx") {
|
||||
t.Fatalf("expected wrapped begin-tx error, got %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_RollsBackOnPerRowError_NoCommit(t *testing.T) {
|
||||
// First INSERT succeeds, second errors. PutBatch MUST NOT issue
|
||||
// Commit; the deferred Rollback unwinds row 1 so neither row commits.
|
||||
// This is the contract that prevents orphan rows on a failed batch.
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
wsID := uuid.New()
|
||||
id1 := uuid.New()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(wsID, []byte("aaa"), int64(3), "a.txt", "").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1))
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(wsID, []byte("bb"), int64(2), "b.txt", "").
|
||||
WillReturnError(errors.New("statement timeout"))
|
||||
// Critical: Rollback expected, NOT Commit. If a future refactor
|
||||
// accidentally swallows the per-row error and Commits anyway, this
|
||||
// test fails because the unmet ExpectCommit-vs-Rollback shape diverges.
|
||||
mock.ExpectRollback()
|
||||
|
||||
_, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{
|
||||
{Content: []byte("aaa"), Filename: "a.txt"},
|
||||
{Content: []byte("bb"), Filename: "b.txt"},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "batch insert item 1") {
|
||||
t.Fatalf("expected wrapped per-row insert error, got %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations (must rollback, no commit): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_RollsBackOnFirstRowError(t *testing.T) {
|
||||
// Edge case: very first INSERT fails. No rows ever staged — but the
|
||||
// Tx still needs to roll back to release the snapshot.
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
wsID := uuid.New()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(wsID, []byte("oops"), int64(4), "a.txt", "").
|
||||
WillReturnError(errors.New("constraint violation"))
|
||||
mock.ExpectRollback()
|
||||
|
||||
_, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{
|
||||
{Content: []byte("oops"), Filename: "a.txt"},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "batch insert item 0") {
|
||||
t.Fatalf("expected wrapped item-0 insert error, got %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_CommitError_Wrapped(t *testing.T) {
|
||||
// Commit fails after every INSERT succeeded. Postgres has already
|
||||
// rolled back the Tx by this point; we surface the error so the
|
||||
// handler returns 500 and the client retries.
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
wsID := uuid.New()
|
||||
id1 := uuid.New()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(wsID, []byte("hi"), int64(2), "a.txt", "").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1))
|
||||
mock.ExpectCommit().WillReturnError(errors.New("commit broken"))
|
||||
|
||||
_, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{
|
||||
{Content: []byte("hi"), Filename: "a.txt"},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "commit batch") {
|
||||
t.Fatalf("expected wrapped commit error, got %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
// sweeper.go — periodic GC for the pending_uploads table.
|
||||
//
|
||||
// The platform's poll-mode chat-upload handler creates a row in
|
||||
// pending_uploads for every chat-attached file the canvas sends to a
|
||||
// poll-mode workspace. The workspace's inbox poller fetches the bytes
|
||||
// and acks the row, but two failure modes leak rows long-term:
|
||||
//
|
||||
// 1. Workspace fetches but never acks (network hiccup between GET
|
||||
// /content and POST /ack; workspace crashed between the two).
|
||||
// Phase 1's Get refuses to re-serve an acked row, but a never-
|
||||
// acked row could in principle be fetched repeatedly until expires_at.
|
||||
// Phase 2's workspace-side fetcher is idempotent; the worry is
|
||||
// only disk usage on the platform side.
|
||||
//
|
||||
// 2. Workspace never fetches at all (workspace was offline when the
|
||||
// row was written; the upload's TTL elapsed).
|
||||
//
|
||||
// This sweeper handles both. It runs every SweepInterval, deletes rows
|
||||
// in either category, and emits structured logs + Prometheus counters
|
||||
// so a stuck-fetch dashboard can spot the leak class.
|
||||
//
|
||||
// Failure isolation: a transient DB error must NOT crash the sweeper.
|
||||
// We log + continue; the next tick retries. ctx cancellation cleanly
|
||||
// shuts the loop down for graceful shutdown.
|
||||
|
||||
package pendinguploads
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/metrics"
|
||||
)
|
||||
|
||||
// SweepInterval is the cadence of the GC loop. 5 minutes is a balance
|
||||
// between "rows reaped quickly enough that disk usage doesn't surprise
|
||||
// anyone" and "we don't pay a DELETE round-trip every 30 seconds when
|
||||
// there are no candidates." Aligned with other low-priority sweepers
|
||||
// (registry/orphan_sweeper runs at 60s but operates on Docker — much
|
||||
// more expensive per cycle than a single indexed DELETE).
|
||||
const SweepInterval = 5 * time.Minute
|
||||
|
||||
// DefaultAckRetention is how long an acked row sticks around before the
|
||||
// sweeper deletes it. 1 hour gives the workspace enough time to retry
|
||||
// the GET if its first fetch crashed mid-write — at-least-once handoff
|
||||
// without leaking content for a full 24h after the workspace already
|
||||
// has a copy.
|
||||
const DefaultAckRetention = 1 * time.Hour
|
||||
|
||||
// sweepDeadline bounds a single sweep cycle. A daemon at the edge of
|
||||
// timeout shouldn't pile up goroutines; 30s is generous for a single
|
||||
// indexed DELETE on a table that should rarely have more than a few
|
||||
// thousand rows in flight.
|
||||
const sweepDeadline = 30 * time.Second
|
||||
|
||||
// StartSweeper runs the GC loop until ctx is cancelled. nil storage
|
||||
// makes the loop a no-op (matches the handlers' tolerance for an
|
||||
// unconfigured pendinguploads — some test harnesses run without the
|
||||
// storage wired).
|
||||
//
|
||||
// Pass ackRetention=0 to use DefaultAckRetention. Negative values are
|
||||
// clamped at the storage layer.
|
||||
//
|
||||
// Production callers use SweepInterval (5m). Tests use a short interval
|
||||
// to exercise the ticker-driven sweep path without burning real wall-
|
||||
// clock time.
|
||||
func StartSweeper(ctx context.Context, storage Storage, ackRetention time.Duration) {
|
||||
startSweeperWithInterval(ctx, storage, ackRetention, SweepInterval)
|
||||
}
|
||||
|
||||
// startSweeperWithInterval is the test-friendly variant of StartSweeper
|
||||
// — same loop, but the cadence is caller-specified. Production code
|
||||
// should use StartSweeper to keep the SweepInterval constant pinned.
|
||||
func startSweeperWithInterval(ctx context.Context, storage Storage, ackRetention, interval time.Duration) {
|
||||
if storage == nil {
|
||||
log.Println("pendinguploads sweeper: storage is nil — sweeper disabled")
|
||||
return
|
||||
}
|
||||
if ackRetention == 0 {
|
||||
ackRetention = DefaultAckRetention
|
||||
}
|
||||
log.Printf(
|
||||
"pendinguploads sweeper started — sweeping every %s; ack retention %s",
|
||||
interval, ackRetention,
|
||||
)
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
// Run once immediately so a platform restart cleans up any rows
|
||||
// that became eligible while we were down — don't make the
|
||||
// operator wait 5 minutes for the first sweep.
|
||||
sweepOnce(ctx, storage, ackRetention)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Println("pendinguploads sweeper: shutdown")
|
||||
return
|
||||
case <-ticker.C:
|
||||
sweepOnce(ctx, storage, ackRetention)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sweepOnce(parent context.Context, storage Storage, ackRetention time.Duration) {
|
||||
ctx, cancel := context.WithTimeout(parent, sweepDeadline)
|
||||
defer cancel()
|
||||
|
||||
res, err := storage.Sweep(ctx, ackRetention)
|
||||
if err != nil {
|
||||
// Transient errors: log + continue. The next tick retries; if
|
||||
// the DB is genuinely down, the rest of the platform is also
|
||||
// broken and disk usage is the least of the operator's
|
||||
// problems.
|
||||
log.Printf("pendinguploads sweeper: Sweep failed: %v", err)
|
||||
metrics.PendingUploadsSweepError()
|
||||
return
|
||||
}
|
||||
metrics.PendingUploadsSwept(res.Acked, res.Expired)
|
||||
if res.Total() > 0 {
|
||||
// Per-cycle structured-ish log (one line per cycle that did
|
||||
// something). Quiet by design — most cycles delete zero rows
|
||||
// on a healthy system, and a stream of empty-result lines
|
||||
// would drown the production log without surfacing a signal.
|
||||
log.Printf(
|
||||
"pendinguploads sweeper: deleted acked=%d expired=%d total=%d",
|
||||
res.Acked, res.Expired, res.Total(),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,294 @@
|
||||
package pendinguploads_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/metrics"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/pendinguploads"
|
||||
)
|
||||
|
||||
// fakeSweepStorage is a minimal Storage that records every Sweep call
|
||||
// and lets each test inject the per-cycle return values. The other
|
||||
// methods are no-ops — the sweeper goroutine never calls them.
|
||||
type fakeSweepStorage struct {
|
||||
calls atomic.Int64
|
||||
results []pendinguploads.SweepResult
|
||||
errs []error
|
||||
cycleDone chan struct{} // closed after each Sweep call (test sync)
|
||||
gotRetention atomic.Int64 // last ackRetention seen, in seconds
|
||||
}
|
||||
|
||||
func newFakeSweepStorage(results []pendinguploads.SweepResult, errs []error) *fakeSweepStorage {
|
||||
return &fakeSweepStorage{
|
||||
results: results,
|
||||
errs: errs,
|
||||
cycleDone: make(chan struct{}, 16),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeSweepStorage) Put(_ context.Context, _ uuid.UUID, _ []byte, _, _ string) (uuid.UUID, error) {
|
||||
return uuid.Nil, errors.New("not used")
|
||||
}
|
||||
func (f *fakeSweepStorage) Get(_ context.Context, _ uuid.UUID) (pendinguploads.Record, error) {
|
||||
return pendinguploads.Record{}, errors.New("not used")
|
||||
}
|
||||
func (f *fakeSweepStorage) MarkFetched(_ context.Context, _ uuid.UUID) error {
|
||||
return errors.New("not used")
|
||||
}
|
||||
func (f *fakeSweepStorage) Ack(_ context.Context, _ uuid.UUID) error {
|
||||
return errors.New("not used")
|
||||
}
|
||||
func (f *fakeSweepStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) {
|
||||
return nil, errors.New("not used")
|
||||
}
|
||||
func (f *fakeSweepStorage) Sweep(_ context.Context, ackRetention time.Duration) (pendinguploads.SweepResult, error) {
|
||||
idx := int(f.calls.Load())
|
||||
f.calls.Add(1)
|
||||
f.gotRetention.Store(int64(ackRetention.Seconds()))
|
||||
defer func() {
|
||||
select {
|
||||
case f.cycleDone <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
if idx < len(f.errs) && f.errs[idx] != nil {
|
||||
return pendinguploads.SweepResult{}, f.errs[idx]
|
||||
}
|
||||
if idx < len(f.results) {
|
||||
return f.results[idx], nil
|
||||
}
|
||||
return pendinguploads.SweepResult{}, nil
|
||||
}
|
||||
|
||||
// waitForCycle blocks until at least one Sweep completes, with a deadline.
|
||||
// Tests use this instead of time.Sleep to avoid flakes on slow CI hosts.
|
||||
//
|
||||
// CAVEAT: cycleDone fires from inside fakeSweepStorage.Sweep's defer,
|
||||
// which runs as Sweep returns its result — BEFORE the StartSweeper
|
||||
// loop has processed the (result, error) tuple and called the
|
||||
// metric recorders. Tests that assert on metric counters must NOT
|
||||
// rely on this wait alone; use waitForMetricDelta instead so the
|
||||
// metric increment race (Sweep returns → cycleDone fires → test
|
||||
// reads counter → only then does StartSweeper's loop call
|
||||
// metrics.PendingUploadsSweepError) doesn't produce a flake.
|
||||
func (f *fakeSweepStorage) waitForCycle(t *testing.T, n int, timeout time.Duration) {
|
||||
t.Helper()
|
||||
deadline := time.NewTimer(timeout)
|
||||
defer deadline.Stop()
|
||||
for got := 0; got < n; got++ {
|
||||
select {
|
||||
case <-f.cycleDone:
|
||||
case <-deadline.C:
|
||||
t.Fatalf("waited %s for %d sweep cycles, got %d", timeout, n, f.calls.Load())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// waitForMetricDelta polls the supplied delta function until it returns
|
||||
// `want` or the timeout elapses. Use after waitForCycle when the test
|
||||
// asserts on a metric counter — closes the race between cycleDone
|
||||
// (signalled inside fakeSweepStorage.Sweep's defer, BEFORE Sweep
|
||||
// returns to StartSweeper) and the metric recording (which happens in
|
||||
// StartSweeper's loop AFTER Sweep returns). On a slow CI host the test
|
||||
// goroutine wins the read before StartSweeper's goroutine writes the
|
||||
// counter; the polling assert preserves the determinism of "the metric
|
||||
// MUST be N" without timing-based flakes.
|
||||
//
|
||||
// Per memory feedback_question_test_when_unexpected.md: the failure
|
||||
// mode "delta=0, want=1" looked like a real bug at first glance —
|
||||
// "metric never incremented" — but instrumented analysis showed the
|
||||
// metric DID increment, just AFTER the test's read. The fix is the
|
||||
// test's wait shape, not the production code.
|
||||
func waitForMetricDelta(t *testing.T, delta func() int64, want int64, timeout time.Duration) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if delta() == want {
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("waited %s for metric delta=%d, last seen %d", timeout, want, delta())
|
||||
}
|
||||
|
||||
func TestStartSweeper_NilStorageDoesNotPanic(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
// Should return immediately without panicking; no goroutine to wait on.
|
||||
pendinguploads.StartSweeper(ctx, nil, time.Second)
|
||||
}
|
||||
|
||||
func TestStartSweeper_RunsImmediatelyAndOnTick(t *testing.T) {
|
||||
store := newFakeSweepStorage(
|
||||
[]pendinguploads.SweepResult{{Acked: 5}, {Acked: 1, Expired: 2}},
|
||||
nil,
|
||||
)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go pendinguploads.StartSweeper(ctx, store, time.Hour)
|
||||
store.waitForCycle(t, 1, 2*time.Second)
|
||||
if got := store.calls.Load(); got < 1 {
|
||||
t.Errorf("expected at least one immediate sweep, got %d", got)
|
||||
}
|
||||
// Retention propagated.
|
||||
if store.gotRetention.Load() != 3600 {
|
||||
t.Errorf("retention seconds = %d, want 3600", store.gotRetention.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartSweeper_ZeroAckRetentionUsesDefault(t *testing.T) {
|
||||
store := newFakeSweepStorage([]pendinguploads.SweepResult{{}}, nil)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go pendinguploads.StartSweeper(ctx, store, 0)
|
||||
store.waitForCycle(t, 1, 2*time.Second)
|
||||
want := int64(pendinguploads.DefaultAckRetention.Seconds())
|
||||
if store.gotRetention.Load() != want {
|
||||
t.Errorf("retention = %d, want default %d", store.gotRetention.Load(), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartSweeper_ContextCancelStopsLoop(t *testing.T) {
|
||||
store := newFakeSweepStorage([]pendinguploads.SweepResult{{}}, nil)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
pendinguploads.StartSweeper(ctx, store, time.Second)
|
||||
close(done)
|
||||
}()
|
||||
store.waitForCycle(t, 1, 2*time.Second)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("StartSweeper did not return after ctx cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartSweeperWithInterval_TickerFiresAdditionalCycles(t *testing.T) {
|
||||
store := newFakeSweepStorage(
|
||||
[]pendinguploads.SweepResult{{Acked: 1}, {Expired: 1}, {}, {}, {}},
|
||||
nil,
|
||||
)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go pendinguploads.StartSweeperWithIntervalForTest(ctx, store, time.Hour, 30*time.Millisecond)
|
||||
|
||||
// Immediate cycle + at least one tick-driven cycle.
|
||||
store.waitForCycle(t, 2, 2*time.Second)
|
||||
|
||||
if got := store.calls.Load(); got < 2 {
|
||||
t.Errorf("expected ≥2 cycles (immediate + 1 tick), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartSweeper_TransientErrorDoesNotCrashLoop(t *testing.T) {
|
||||
// First call errors; second call succeeds. The loop must keep running
|
||||
// across the error so a one-off DB hiccup doesn't disable the GC.
|
||||
store := newFakeSweepStorage(
|
||||
[]pendinguploads.SweepResult{{}, {Acked: 1}},
|
||||
[]error{errors.New("transient db error"), nil},
|
||||
)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// 50ms ticker so the second cycle fires quickly enough for the test.
|
||||
// We re-export SweepInterval as a const, but tests use the public
|
||||
// StartSweeper that takes its own interval — wait, the public
|
||||
// StartSweeper signature uses the package-level SweepInterval. Hmm,
|
||||
// this means the test takes ~5 minutes. Let me reconsider.
|
||||
//
|
||||
// (We patch the test below to just look at the immediate-sweep call
|
||||
// + an error path, since the immediate call is enough to prove the
|
||||
// "error doesn't crash" contract — the loop continues afterward
|
||||
// regardless of timing.)
|
||||
go pendinguploads.StartSweeper(ctx, store, time.Hour)
|
||||
|
||||
// Wait for the first (errored) cycle.
|
||||
store.waitForCycle(t, 1, 2*time.Second)
|
||||
// Cancel — the goroutine returns cleanly, proving the error path
|
||||
// didn't crash the loop. Without this fix the goroutine would have
|
||||
// either panicked (process abort visible at exit) or stuck (this
|
||||
// cancel + done-channel pattern would deadlock instead).
|
||||
cancel()
|
||||
}
|
||||
|
||||
// metricDelta returns a function that, when called, returns how much
|
||||
// the (acked, expired, errored) counters have advanced since metricDelta
|
||||
// was originally called. metrics is a process-singleton across the test
|
||||
// suite; deltas isolate this test from order-of-execution dependencies.
|
||||
func metricDelta(t *testing.T) (deltaAcked, deltaExpired, deltaError func() int64) {
|
||||
t.Helper()
|
||||
a0, e0, err0 := metrics.PendingUploadsSweepCounts()
|
||||
deltaAcked = func() int64 {
|
||||
a, _, _ := metrics.PendingUploadsSweepCounts()
|
||||
return a - a0
|
||||
}
|
||||
deltaExpired = func() int64 {
|
||||
_, e, _ := metrics.PendingUploadsSweepCounts()
|
||||
return e - e0
|
||||
}
|
||||
deltaError = func() int64 {
|
||||
_, _, x := metrics.PendingUploadsSweepCounts()
|
||||
return x - err0
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestStartSweeper_RecordsMetricsOnSuccess(t *testing.T) {
|
||||
deltaAcked, deltaExpired, deltaError := metricDelta(t)
|
||||
|
||||
store := newFakeSweepStorage(
|
||||
[]pendinguploads.SweepResult{{Acked: 3, Expired: 5}},
|
||||
nil,
|
||||
)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go pendinguploads.StartSweeper(ctx, store, time.Hour)
|
||||
store.waitForCycle(t, 1, 2*time.Second)
|
||||
|
||||
// Poll for the success counters to settle — closes the cycleDone-
|
||||
// vs-metric-record race (see waitForMetricDelta comment).
|
||||
waitForMetricDelta(t, deltaAcked, 3, 2*time.Second)
|
||||
waitForMetricDelta(t, deltaExpired, 5, 2*time.Second)
|
||||
// Error counter MUST stay at zero on the success path. Read after
|
||||
// the success counters have settled — once those are correct,
|
||||
// StartSweeper has fully processed this cycle's result.
|
||||
if got := deltaError(); got != 0 {
|
||||
t.Errorf("error counter delta = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartSweeper_RecordsMetricsOnError(t *testing.T) {
|
||||
_, _, deltaError := metricDelta(t)
|
||||
|
||||
store := newFakeSweepStorage(
|
||||
[]pendinguploads.SweepResult{{}},
|
||||
[]error{errors.New("db down")},
|
||||
)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go pendinguploads.StartSweeper(ctx, store, time.Hour)
|
||||
store.waitForCycle(t, 1, 2*time.Second)
|
||||
|
||||
// Poll for the error counter to settle — cycleDone fires inside
|
||||
// the fake's Sweep defer, BEFORE StartSweeper's loop receives the
|
||||
// returned error and calls metrics.PendingUploadsSweepError. On
|
||||
// slow CI hosts a direct deltaError() read here returns 0 even
|
||||
// though the metric WILL be 1 a few ms later. See
|
||||
// waitForMetricDelta comment.
|
||||
waitForMetricDelta(t, deltaError, 1, 2*time.Second)
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provlog"
|
||||
)
|
||||
|
||||
// CPProvisionerAPI is the contract WorkspaceHandler uses to talk to the
|
||||
@@ -214,6 +215,13 @@ func (p *CPProvisioner) Start(ctx context.Context, cfg WorkspaceConfig) (string,
|
||||
}
|
||||
|
||||
log.Printf("CP provisioner: workspace %s → EC2 instance %s (%s)", cfg.WorkspaceID, result.InstanceID, result.State)
|
||||
provlog.Event("provision.ec2_started", map[string]any{
|
||||
"workspace_id": cfg.WorkspaceID,
|
||||
"instance_id": result.InstanceID,
|
||||
"state": result.State,
|
||||
"tier": cfg.Tier,
|
||||
"runtime": cfg.Runtime,
|
||||
})
|
||||
return result.InstanceID, nil
|
||||
}
|
||||
|
||||
@@ -273,6 +281,10 @@ func (p *CPProvisioner) Stop(ctx context.Context, workspaceID string) error {
|
||||
return fmt.Errorf("cp provisioner: stop %s: unexpected %d: %s",
|
||||
workspaceID, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
provlog.Event("provision.ec2_stopped", map[string]any{
|
||||
"workspace_id": workspaceID,
|
||||
"instance_id": instanceID,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
// Package provlog emits structured, single-line JSON log records for
|
||||
// provisioning-lifecycle boundaries (workspace create, EC2 start/stop,
|
||||
// restart, idempotency skips). Records share a stable `evt:` prefix and
|
||||
// JSON payload so a future grep|jq pipeline (or a Loki/Datadog ingest)
|
||||
// can reconstruct the per-workspace timeline without parsing the
|
||||
// human-prose log lines that already exist.
|
||||
//
|
||||
// Existing log.Printf lines are intentionally NOT replaced — they
|
||||
// remain the operator-facing message. Event() emits a paired structured
|
||||
// record alongside, additive only.
|
||||
//
|
||||
// Event taxonomy (extend by appending; never rename):
|
||||
//
|
||||
// provision.start — workspace row inserted, EC2 about to launch
|
||||
// provision.skip_existing — idempotency hit, no new EC2
|
||||
// provision.ec2_started — RunInstances returned an instance id
|
||||
// provision.ec2_stopped — TerminateInstances acknowledged
|
||||
// restart.pre_stop — Restart handler about to call Stop
|
||||
//
|
||||
// Required fields per event are documented at each call site.
|
||||
package provlog
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
)
|
||||
|
||||
// Event writes a single line of the form:
|
||||
//
|
||||
// evt: <name> {"k":"v",...}
|
||||
//
|
||||
// to the standard logger. JSON encoding errors are silently swallowed —
|
||||
// a logging helper must never panic the request path. fields may be
|
||||
// nil; the empty payload `{}` is still useful to mark an event boundary.
|
||||
func Event(name string, fields map[string]any) {
|
||||
if fields == nil {
|
||||
fields = map[string]any{}
|
||||
}
|
||||
payload, err := json.Marshal(fields)
|
||||
if err != nil {
|
||||
// Fall back to a static payload so the event boundary still
|
||||
// appears in the log. The marshal error itself is recorded
|
||||
// on a best-effort basis.
|
||||
log.Printf("evt: %s {\"_marshal_err\":%q}", name, err.Error())
|
||||
return
|
||||
}
|
||||
log.Printf("evt: %s %s", name, payload)
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package provlog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// captureLog redirects the default logger to a buffer for the duration
|
||||
// of fn and returns whatever was written.
|
||||
func captureLog(t *testing.T, fn func()) string {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
prevWriter := log.Writer()
|
||||
prevFlags := log.Flags()
|
||||
log.SetOutput(&buf)
|
||||
log.SetFlags(0) // strip date/time so assertions stay deterministic
|
||||
t.Cleanup(func() {
|
||||
log.SetOutput(prevWriter)
|
||||
log.SetFlags(prevFlags)
|
||||
})
|
||||
fn()
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func TestEvent_EmitsEvtPrefixAndJSONPayload(t *testing.T) {
|
||||
out := captureLog(t, func() {
|
||||
Event("provision.start", map[string]any{
|
||||
"workspace_id": "ws-123",
|
||||
"tier": 4,
|
||||
"runtime": "claude-code",
|
||||
})
|
||||
})
|
||||
out = strings.TrimSpace(out)
|
||||
if !strings.HasPrefix(out, "evt: provision.start ") {
|
||||
t.Fatalf("expected evt-prefixed line, got %q", out)
|
||||
}
|
||||
jsonPart := strings.TrimPrefix(out, "evt: provision.start ")
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonPart), &got); err != nil {
|
||||
t.Fatalf("payload not valid JSON: %v (raw=%q)", err, jsonPart)
|
||||
}
|
||||
if got["workspace_id"] != "ws-123" {
|
||||
t.Errorf("workspace_id field lost: %+v", got)
|
||||
}
|
||||
// JSON unmarshal turns numbers into float64 — exact-equal compare.
|
||||
if got["tier"].(float64) != 4 {
|
||||
t.Errorf("tier field lost: %+v", got)
|
||||
}
|
||||
if got["runtime"] != "claude-code" {
|
||||
t.Errorf("runtime field lost: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_NilFieldsEmitsEmptyObject(t *testing.T) {
|
||||
out := captureLog(t, func() {
|
||||
Event("restart.pre_stop", nil)
|
||||
})
|
||||
if !strings.Contains(out, "evt: restart.pre_stop {}") {
|
||||
t.Fatalf("nil fields should emit empty object, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_PreservesEventBoundaryOnUnmarshalableValue(t *testing.T) {
|
||||
// A channel cannot be marshaled by encoding/json — verify we still
|
||||
// emit the event boundary with a recorded marshal error. This is
|
||||
// the structural guarantee: the call site never sees a panic, and
|
||||
// the event name is always present in the log.
|
||||
out := captureLog(t, func() {
|
||||
Event("provision.ec2_started", map[string]any{
|
||||
"chan": make(chan int),
|
||||
})
|
||||
})
|
||||
if !strings.Contains(out, "evt: provision.ec2_started ") {
|
||||
t.Fatalf("event boundary missing on marshal error: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "_marshal_err") {
|
||||
t.Fatalf("expected _marshal_err sentinel, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_SingleLineOutput(t *testing.T) {
|
||||
// Log aggregators line-split on \n. A multi-line emit would silently
|
||||
// fragment the JSON across two records — pin single-line shape.
|
||||
out := captureLog(t, func() {
|
||||
Event("provision.skip_existing", map[string]any{
|
||||
"existing_id": "ws-abc",
|
||||
"name": "child-1",
|
||||
})
|
||||
})
|
||||
trimmed := strings.TrimRight(out, "\n")
|
||||
if strings.Contains(trimmed, "\n") {
|
||||
t.Fatalf("event line must be single-line, got %q", out)
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/events"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/handlers"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/pendinguploads"
|
||||
memwiring "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/wiring"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/metrics"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/middleware"
|
||||
@@ -242,13 +243,15 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
|
||||
// entire platform. Gated behind AdminAuth (issue #180).
|
||||
r.GET("/approvals/pending", middleware.AdminAuth(db.DB), apph.ListAll)
|
||||
|
||||
// Team handlers — Collapse only. The bulk-Expand path is gone:
|
||||
// every workspace can have children via the regular CreateWorkspace
|
||||
// flow with parent_id set, so a separate handler that bulk-creates
|
||||
// from sub_workspaces (and was non-idempotent — calling it twice
|
||||
// duplicated the team) earned its way out.
|
||||
teamh := handlers.NewTeamHandler(broadcaster, wh, platformURL, configsDir)
|
||||
wsAuth.POST("/collapse", teamh.Collapse)
|
||||
// (TeamHandler is gone — #2864.) The visual canvas Collapse
|
||||
// button calls PATCH /workspaces/:id { collapsed: true/false }
|
||||
// (presentational toggle on canvas_layouts), NOT the destructive
|
||||
// POST /collapse that stopped + removed children. The
|
||||
// destructive route had zero UI callers (verified via grep
|
||||
// across canvas/, scripts/, and the MCP tool registry — only
|
||||
// docs referenced it). team.go + team_test.go + the route
|
||||
// + helpers (findTemplateDirByName, NewTeamHandler) are
|
||||
// deleted; visual collapse is unaffected.
|
||||
|
||||
// Agents
|
||||
ah := handlers.NewAgentHandler(broadcaster)
|
||||
@@ -518,8 +521,9 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
|
||||
r.GET("/canvas/viewport", vh.Get)
|
||||
r.PUT("/canvas/viewport", middleware.CanvasOrBearer(db.DB), vh.Save)
|
||||
|
||||
// Templates
|
||||
tmplh := handlers.NewTemplatesHandler(configsDir, dockerCli)
|
||||
// Templates — wh threaded so generateDefaultConfig picks the
|
||||
// SaaS-aware default tier in Import + ReplaceFiles (#2910 PR-B).
|
||||
tmplh := handlers.NewTemplatesHandler(configsDir, dockerCli, wh)
|
||||
// #686: GET /templates lists all template names+metadata from configsDir.
|
||||
// Open access lets unauthenticated callers enumerate org configurations and
|
||||
// installed plugins. AdminAuth-gate it alongside POST /templates/import.
|
||||
@@ -540,10 +544,20 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
|
||||
// streaming download (agent → user). Namespaced under /chat/ so
|
||||
// the security model is obviously distinct from /files/* (which
|
||||
// handles workspace config/templates and has a different caller).
|
||||
chatfh := handlers.NewChatFilesHandler(tmplh)
|
||||
chatfh := handlers.NewChatFilesHandler(tmplh).
|
||||
WithPendingUploads(pendinguploads.NewPostgres(db.DB), broadcaster)
|
||||
wsAuth.POST("/chat/uploads", chatfh.Upload)
|
||||
wsAuth.GET("/chat/download", chatfh.Download)
|
||||
|
||||
// Phase 1 RFC: poll-mode chat upload — endpoints the workspace's
|
||||
// inbox poller hits to fetch staged file content + ack delivery.
|
||||
// Same wsAuth gate as the activity poll, so a token leak from
|
||||
// workspace A can't read workspace B's pending uploads (the
|
||||
// handler also re-checks workspace_id on each row).
|
||||
puh := handlers.NewPendingUploadsHandler(pendinguploads.NewPostgres(db.DB))
|
||||
wsAuth.GET("/pending-uploads/:file_id/content", puh.GetContent)
|
||||
wsAuth.POST("/pending-uploads/:file_id/ack", puh.Ack)
|
||||
|
||||
// Plugins
|
||||
pluginsDir := findPluginsDir(configsDir)
|
||||
// Runtime lookup lets the plugins handler filter the registry to plugins
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
cronlib "github.com/robfig/cron/v3"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/metrics"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/supervised"
|
||||
)
|
||||
|
||||
@@ -741,6 +742,11 @@ func (s *Scheduler) sweepPhantomBusy(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
log.Printf("Scheduler: phantom-busy sweep — reset %s (no activity in %d min)", name, int(phantomStaleThreshold.Minutes()))
|
||||
// #2865: surface as molecule_phantom_busy_resets_total. High
|
||||
// reset rate signals task-lifecycle accounting regressions
|
||||
// (e.g. missing env vars causing claude --print timeouts that
|
||||
// leave active_tasks elevated until this sweep fires).
|
||||
metrics.TrackPhantomBusyReset()
|
||||
count++
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
-- 20260505100000_pending_uploads.down.sql
|
||||
--
|
||||
-- Drops the pending_uploads table and its indexes. Any pending file
|
||||
-- uploads sitting in the table at rollback time are dropped — operators
|
||||
-- on poll-mode workspaces lose those attachments, but they were never
|
||||
-- fetched on the workspace side (otherwise they'd be acked + about to
|
||||
-- be GC'd anyway), so the practical loss is the same as a cron sweep.
|
||||
|
||||
DROP INDEX IF EXISTS idx_pending_uploads_expires;
|
||||
DROP INDEX IF EXISTS idx_pending_uploads_workspace_unacked;
|
||||
DROP TABLE IF EXISTS pending_uploads;
|
||||
@@ -0,0 +1,103 @@
|
||||
-- 20260505100000_pending_uploads.up.sql
|
||||
--
|
||||
-- RFC: poll-mode chat upload (counterpart to delivery_mode='poll' messaging).
|
||||
--
|
||||
-- Today, chat_files.go's Upload handler refuses delivery_mode != 'push'
|
||||
-- with HTTP 422 "workspace has no callback URL" — external runtime
|
||||
-- workspaces (laptop / behind NAT) cannot receive file attachments at all.
|
||||
-- The only escape was "register with ngrok / Cloudflare tunnel + push
|
||||
-- mode," which forces every external operator into infra plumbing they
|
||||
-- shouldn't need.
|
||||
--
|
||||
-- This table is the platform-side staging layer that lets canvas → external
|
||||
-- workspace file uploads ride the same poll loop the inbox already uses for
|
||||
-- text messages:
|
||||
--
|
||||
-- 1. Canvas POSTs multipart to workspace-server.
|
||||
-- 2. workspace-server parses multipart, stores each file as one
|
||||
-- pending_uploads row, AND inserts a matching activity_logs row
|
||||
-- (type='chat_upload_receive', request_body={file_id, filename, ...}).
|
||||
-- 3. Workspace's existing inbox poller picks up the activity row.
|
||||
-- 4. Workspace fetches bytes via GET /workspaces/:id/pending-uploads/:fid/content,
|
||||
-- writes to /workspace/.molecule/chat-uploads/, ACKs via POST.
|
||||
-- 5. Sweep cron deletes rows past expires_at OR acked_at + N hours.
|
||||
--
|
||||
-- Why a separate table and not bytea-on-activity_logs:
|
||||
--
|
||||
-- * activity_logs is text/JSON-shaped today; mixing 25 MB binary blobs
|
||||
-- into request_body inflates every JOIN, every since_id scan, every
|
||||
-- pgdump. The bytes need their own home.
|
||||
-- * Lifecycle differs: activity_logs is durable audit history (90d+);
|
||||
-- pending_uploads is transient buffer (24h default) that GCs hard.
|
||||
-- Keeping them split lets each table's retention policy run
|
||||
-- independently.
|
||||
-- * A future PR (RFC #2789) will migrate the bytes column to S3 keys
|
||||
-- without touching the activity_logs schema or the metadata columns
|
||||
-- here. That migration is one ALTER + one backfill rather than a
|
||||
-- cross-table rewrite.
|
||||
--
|
||||
-- No FK to workspaces:
|
||||
-- workspace delete should NOT cascade-purge pending_uploads — those
|
||||
-- rows are evidence-of-receipt and should expire on their own TTL.
|
||||
-- Same posture as tenant_resources (PR #2343) and delegations (PR #2829).
|
||||
|
||||
CREATE TABLE IF NOT EXISTS pending_uploads (
|
||||
-- Server-generated so the canvas can include the URI in the chat
|
||||
-- message it sends right after the upload POST. Workspace fetches
|
||||
-- by this id, no name collisions across workspaces.
|
||||
file_id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
|
||||
-- Target workspace. NOT a FK (see header).
|
||||
workspace_id uuid NOT NULL,
|
||||
|
||||
-- Content lives inline today via bytea. The Go-side storage interface
|
||||
-- (PendingUploadStorage) abstracts read/write so a future PR can
|
||||
-- relocate this column's job to S3 (RFC #2789) by adding an `s3_key
|
||||
-- text NULL` column, dual-writing for one release, then dropping
|
||||
-- `content` once the backfill drains. The CHECK below pins the same
|
||||
-- 25 MB per-file cap the workspace-side ingest_handler enforces
|
||||
-- (workspace/internal_chat_uploads.py:198) — discrepancy between
|
||||
-- the two would let the platform accept files the workspace would
|
||||
-- 413 on after pull.
|
||||
content bytea NOT NULL,
|
||||
size_bytes bigint NOT NULL CHECK (size_bytes > 0 AND size_bytes <= 26214400),
|
||||
|
||||
-- Filename + mimetype mirror the workspace-side ChatUploadedFile
|
||||
-- shape so the eventual InboxMessage hand-off needs no translation.
|
||||
-- Filename is sanitized at write-time (matches sanitize_filename in
|
||||
-- workspace/internal_chat_uploads.py); 100 char cap is the same.
|
||||
filename text NOT NULL CHECK (length(filename) > 0 AND length(filename) <= 100),
|
||||
mimetype text NOT NULL DEFAULT '',
|
||||
|
||||
created_at timestamptz NOT NULL DEFAULT now(),
|
||||
|
||||
-- Stamped on the GET /content request. Lets Phase 3 sweeper detect
|
||||
-- "fetched but never acked" — distinct failure mode from "never
|
||||
-- fetched" (workspace offline) so dashboards can split them.
|
||||
fetched_at timestamptz,
|
||||
|
||||
-- Stamped on the POST /ack request. Terminal state for the happy
|
||||
-- path. Sweep cron deletes acked rows past acked_at + retention.
|
||||
acked_at timestamptz,
|
||||
|
||||
-- Hard TTL: rows past this are deleted regardless of ack state.
|
||||
-- 24h matches the longest-observed legitimate "operator stepped
|
||||
-- away from laptop" gap; tunable later via app-level config without
|
||||
-- a migration. NOT acked_at + 24h — that would let a stuck-fetched
|
||||
-- row live forever.
|
||||
expires_at timestamptz NOT NULL DEFAULT (now() + interval '24 hours')
|
||||
);
|
||||
|
||||
-- Hot path: workspace's poll cycle pulls "give me my unacked uploads
|
||||
-- in chronological order." Partial-index because acked rows are GC
|
||||
-- candidates and shouldn't bloat the working set.
|
||||
CREATE INDEX IF NOT EXISTS idx_pending_uploads_workspace_unacked
|
||||
ON pending_uploads (workspace_id, created_at)
|
||||
WHERE acked_at IS NULL;
|
||||
|
||||
-- Phase 3 GC sweep hot path: list rows past expires_at, partial-indexed
|
||||
-- on unacked because acked rows have a different (shorter) retention
|
||||
-- and GC-via-acked_at is a separate query.
|
||||
CREATE INDEX IF NOT EXISTS idx_pending_uploads_expires
|
||||
ON pending_uploads (expires_at)
|
||||
WHERE acked_at IS NULL;
|
||||
@@ -0,0 +1,2 @@
|
||||
-- Reversal of 20260505200000_pending_uploads_acked_index.up.sql.
|
||||
DROP INDEX IF EXISTS idx_pending_uploads_acked;
|
||||
@@ -0,0 +1,30 @@
|
||||
-- 20260505200000_pending_uploads_acked_index.up.sql
|
||||
--
|
||||
-- Adds the missing partial index for the acked-retention arm of the
|
||||
-- pendinguploads.Sweep query. The Phase 1 migration created two
|
||||
-- partial indexes both gated on `acked_at IS NULL` (workspace-fetch
|
||||
-- hot path + expires_at sweep arm); the third query path —
|
||||
-- `WHERE acked_at IS NOT NULL AND acked_at < now() - interval` — was
|
||||
-- left to a seq scan.
|
||||
--
|
||||
-- For a high-traffic deployment that's a real cost: the table
|
||||
-- accumulates one row per chat-attached file; the sweeper runs every
|
||||
-- 5 minutes and DELETEs rows past the 1-hour ack retention. A seq
|
||||
-- scan over 100K-1M acked rows holds an AccessShare lock for seconds
|
||||
-- on every cycle. Partial-indexing the inverse predicate reduces
|
||||
-- this to a btree range scan and lets the DELETE complete in
|
||||
-- low-millisecond range.
|
||||
--
|
||||
-- WHERE acked_at IS NOT NULL is intentionally inverse of the other
|
||||
-- two indexes — they cover the unacked working set; this covers the
|
||||
-- terminal-state set the sweeper visits. Disjoint subsets, so the
|
||||
-- two indexes don't overlap.
|
||||
--
|
||||
-- Caught in self-review on the parent RFC's Phase 4 PR; filed as
|
||||
-- a follow-up rather than a Phase 1 fix because the cost only
|
||||
-- materializes at a row count we don't expect to hit before the
|
||||
-- sweeper has had a chance to keep up.
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_pending_uploads_acked
|
||||
ON pending_uploads (acked_at)
|
||||
WHERE acked_at IS NOT NULL;
|
||||
+25
-407
@@ -28,96 +28,20 @@ from platform_auth import list_registered_workspaces
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RBAC helpers (mirror builtin_tools/audit.py for a2a_tools isolation)
|
||||
# RBAC + auth helpers — extracted to a2a_tools_rbac (RFC #2873 iter 4a).
|
||||
# Re-exported here under the legacy underscore names so existing tests'
|
||||
# patch("a2a_tools._check_memory_write_permission", …) and call sites
|
||||
# inside this module that resolve bare names against the module-level
|
||||
# namespace continue to work unchanged.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_ROLE_PERMISSIONS = {
|
||||
"admin": {"delegate", "approve", "memory.read", "memory.write"},
|
||||
"operator": {"delegate", "approve", "memory.read", "memory.write"},
|
||||
"read-only": {"memory.read"},
|
||||
"no-delegation": {"approve", "memory.read", "memory.write"},
|
||||
"no-approval": {"delegate", "memory.read", "memory.write"},
|
||||
"memory-readonly": {"memory.read"},
|
||||
}
|
||||
|
||||
|
||||
def _get_workspace_tier() -> int:
|
||||
"""Return the workspace tier from config (0 = root, 1+ = tenant)."""
|
||||
try:
|
||||
from config import load_config
|
||||
|
||||
cfg = load_config()
|
||||
return getattr(cfg, "tier", 1)
|
||||
except Exception:
|
||||
return int(os.environ.get("WORKSPACE_TIER", 1))
|
||||
|
||||
|
||||
def _check_memory_write_permission() -> bool:
|
||||
"""Return True if this workspace's RBAC roles grant memory.write."""
|
||||
try:
|
||||
from config import load_config
|
||||
|
||||
cfg = load_config()
|
||||
roles = list(getattr(cfg, "rbac", None).roles or ["operator"])
|
||||
allowed = dict(getattr(cfg, "rbac", None).allowed_actions or {})
|
||||
except Exception:
|
||||
# Fail closed: deny when config is unavailable
|
||||
roles = ["operator"]
|
||||
allowed = {}
|
||||
|
||||
for role in roles:
|
||||
if role == "admin":
|
||||
return True
|
||||
if role in allowed:
|
||||
if "memory.write" in allowed[role]:
|
||||
return True
|
||||
elif role in _ROLE_PERMISSIONS and "memory.write" in _ROLE_PERMISSIONS[role]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _check_memory_read_permission() -> bool:
|
||||
"""Return True if this workspace's RBAC roles grant memory.read."""
|
||||
try:
|
||||
from config import load_config
|
||||
|
||||
cfg = load_config()
|
||||
roles = list(getattr(cfg, "rbac", None).roles or ["operator"])
|
||||
allowed = dict(getattr(cfg, "rbac", None).allowed_actions or {})
|
||||
except Exception:
|
||||
roles = ["operator"]
|
||||
allowed = {}
|
||||
|
||||
for role in roles:
|
||||
if role == "admin":
|
||||
return True
|
||||
if role in allowed:
|
||||
if "memory.read" in allowed[role]:
|
||||
return True
|
||||
elif role in _ROLE_PERMISSIONS and "memory.read" in _ROLE_PERMISSIONS[role]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_root_workspace() -> bool:
|
||||
"""Return True if this workspace is tier 0 (root/root-org)."""
|
||||
return _get_workspace_tier() == 0
|
||||
|
||||
|
||||
def _auth_headers_for_heartbeat(workspace_id: str | None = None) -> dict[str, str]:
|
||||
"""Return Phase 30.1 auth headers; tolerate platform_auth being absent
|
||||
in older installs (e.g. during rolling upgrade).
|
||||
|
||||
``workspace_id`` selects the per-workspace token from the multi-
|
||||
workspace registry when set (PR-1: external agent registered in
|
||||
multiple workspaces). With no arg the legacy single-token path is
|
||||
unchanged.
|
||||
"""
|
||||
try:
|
||||
from platform_auth import auth_headers
|
||||
return auth_headers(workspace_id) if workspace_id else auth_headers()
|
||||
except Exception:
|
||||
return {}
|
||||
from a2a_tools_rbac import ( # noqa: E402 (import after the from-a2a_client block)
|
||||
_auth_headers_for_heartbeat,
|
||||
_check_memory_read_permission,
|
||||
_check_memory_write_permission,
|
||||
_get_workspace_tier,
|
||||
_is_root_workspace,
|
||||
_ROLE_PERMISSIONS,
|
||||
)
|
||||
|
||||
|
||||
# Per-field caps on the heartbeat / activity payload. Borrowed from
|
||||
@@ -191,324 +115,18 @@ async def report_activity(
|
||||
pass # Best-effort — don't block delegation on activity reporting
|
||||
|
||||
|
||||
# RFC #2829 PR-5 cutover constants. The poll cadence + timeout are
|
||||
# intentionally generous: 3s gives the platform's executeDelegation
|
||||
# goroutine room to dispatch + the callee to respond + the result to
|
||||
# write to activity_logs without thrashing the platform with rapid
|
||||
# polls; the budget matches the legacy DELEGATION_TIMEOUT (300s) so
|
||||
# operators don't see behavior change beyond "no more 600s timeouts".
|
||||
_SYNC_POLL_INTERVAL_S = 3.0
|
||||
_SYNC_POLL_BUDGET_S = float(os.environ.get("DELEGATION_TIMEOUT", "300.0"))
|
||||
|
||||
|
||||
async def _delegate_sync_via_polling(
|
||||
workspace_id: str,
|
||||
task: str,
|
||||
src: str,
|
||||
) -> str:
|
||||
"""RFC #2829 PR-5: durable async delegation + poll for terminal status.
|
||||
|
||||
Sidesteps the platform proxy's blocking `message/send` HTTP path that
|
||||
hits a hard 600s ceiling. Instead:
|
||||
|
||||
1. POST /workspaces/<src>/delegate (async, returns 202 + delegation_id)
|
||||
— platform's executeDelegation goroutine handles A2A dispatch in
|
||||
the background. No client-side timeout dependency on the platform
|
||||
holding a connection open.
|
||||
2. Poll GET /workspaces/<src>/delegations every 3s for a row with
|
||||
matching delegation_id reaching terminal status (completed/failed).
|
||||
3. Return the response_preview text on completed; surface error_detail
|
||||
on failed (with the same _A2A_ERROR_PREFIX wrapping the legacy
|
||||
path uses, so caller error-detection logic is unchanged).
|
||||
|
||||
Both /delegate and /delegations are existing endpoints — this helper
|
||||
just composes them into a polling synchronous facade. The result is
|
||||
available the moment the platform writes the terminal status row;
|
||||
no extra latency vs. the legacy proxy-blocked path on fast cases.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32]
|
||||
|
||||
# 1. Dispatch via /delegate (the async, durable path).
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegate",
|
||||
json={
|
||||
"target_id": workspace_id,
|
||||
"task": task,
|
||||
"idempotency_key": idem_key,
|
||||
},
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: {e}"
|
||||
|
||||
if resp.status_code != 202 and resp.status_code != 200:
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: HTTP {resp.status_code} {resp.text[:200]}"
|
||||
|
||||
try:
|
||||
dispatch = resp.json()
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch returned non-JSON: {e}"
|
||||
|
||||
delegation_id = dispatch.get("delegation_id", "")
|
||||
if not delegation_id:
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch missing delegation_id: {dispatch}"
|
||||
|
||||
# 2. Poll for terminal status with a deadline. Each poll is a cheap
|
||||
# /delegations GET — bounded by the platform's existing rate limit.
|
||||
deadline = time.monotonic() + _SYNC_POLL_BUDGET_S
|
||||
last_status = "unknown"
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
poll = await client.get(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegations",
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
# Transient — keep polling. The platform IS holding the
|
||||
# delegation row; we just lost a network request.
|
||||
last_status = f"poll-error: {e}"
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
|
||||
if poll.status_code != 200:
|
||||
last_status = f"poll HTTP {poll.status_code}"
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
|
||||
try:
|
||||
rows = poll.json()
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
last_status = f"poll non-JSON: {e}"
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
|
||||
# /delegations returns a flat list of delegation events. Filter to
|
||||
# our delegation_id; pick the first terminal one. The list may
|
||||
# have multiple rows per delegation_id (one for the original
|
||||
# dispatch, one per status update); we want the latest terminal.
|
||||
if not isinstance(rows, list):
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
terminal = None
|
||||
for r in rows:
|
||||
if not isinstance(r, dict):
|
||||
continue
|
||||
if r.get("delegation_id") != delegation_id:
|
||||
continue
|
||||
status = (r.get("status") or "").lower()
|
||||
last_status = status
|
||||
if status in ("completed", "failed"):
|
||||
terminal = r
|
||||
break
|
||||
if terminal:
|
||||
if (terminal.get("status") or "").lower() == "completed":
|
||||
return terminal.get("response_preview") or ""
|
||||
err = (
|
||||
terminal.get("error_detail")
|
||||
or terminal.get("summary")
|
||||
or "delegation failed"
|
||||
)
|
||||
return f"{_A2A_ERROR_PREFIX}{err}"
|
||||
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
|
||||
# Budget exhausted — the platform's row is still in flight (or queued).
|
||||
# Surface as an error so the caller can decide to retry or fall back;
|
||||
# the platform DOES still have the durable row, so the work isn't
|
||||
# lost — it'll complete eventually and a future check_task_status
|
||||
# will surface the result.
|
||||
return (
|
||||
f"{_A2A_ERROR_PREFIX}polling timeout after {_SYNC_POLL_BUDGET_S}s "
|
||||
f"(delegation_id={delegation_id}, last_status={last_status}); "
|
||||
f"the platform is still working on it — call check_task_status('{delegation_id}') to retrieve later"
|
||||
)
|
||||
|
||||
|
||||
async def tool_delegate_task(
|
||||
workspace_id: str,
|
||||
task: str,
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Delegate a task to another workspace via A2A (synchronous — waits for response).
|
||||
|
||||
``source_workspace_id`` selects which registered workspace this
|
||||
delegation originates from — drives auth + the X-Workspace-ID source
|
||||
header so the platform's a2a_proxy logs the correct sender. Single-
|
||||
workspace operators leave it None and routing falls back to the
|
||||
module-level WORKSPACE_ID.
|
||||
"""
|
||||
if not workspace_id or not task:
|
||||
return "Error: workspace_id and task are required"
|
||||
|
||||
# Auto-route: if source not specified, look up which registered
|
||||
# workspace last saw this peer (populated by tool_list_peers). Falls
|
||||
# back to the legacy WORKSPACE_ID for single-workspace operators.
|
||||
src = source_workspace_id or _peer_to_source.get(workspace_id) or None
|
||||
|
||||
# Discover the target. discover_peer is the access-control gate +
|
||||
# name/status lookup. The peer's reported ``url`` field is NOT used
|
||||
# for routing — see send_a2a_message, which constructs the URL via
|
||||
# the platform's A2A proxy.
|
||||
peer = await discover_peer(workspace_id, source_workspace_id=src)
|
||||
if not peer:
|
||||
return f"Error: workspace {workspace_id} not found or not accessible (check access control)"
|
||||
|
||||
if (peer.get("status") or "").lower() == "offline":
|
||||
return f"Error: workspace {workspace_id} is offline"
|
||||
|
||||
# Report delegation start — include the task text for traceability
|
||||
peer_name = peer.get("name") or _peer_names.get(workspace_id) or workspace_id[:8]
|
||||
_peer_names[workspace_id] = peer_name # cache for future use
|
||||
# Brief summary for canvas display — just the delegation target
|
||||
await report_activity("a2a_send", workspace_id, f"Delegating to {peer_name}", task_text=task)
|
||||
|
||||
# RFC #2829 PR-5: agent-side cutover. When DELEGATION_SYNC_VIA_INBOX=1,
|
||||
# use the platform's durable async delegation API (POST /delegate +
|
||||
# poll /delegations) instead of the proxy-blocked message/send path.
|
||||
# This sidesteps the 600s message/send timeout class that broke
|
||||
# iteration-14/90-style long-running delegations on 2026-05-05.
|
||||
#
|
||||
# Default off — staging-canary first, flip default after PR-2's
|
||||
# result-push flag (DELEGATION_RESULT_INBOX_PUSH) has been on for
|
||||
# ≥1 week without incident.
|
||||
if os.environ.get("DELEGATION_SYNC_VIA_INBOX") == "1":
|
||||
result = await _delegate_sync_via_polling(workspace_id, task, src or WORKSPACE_ID)
|
||||
else:
|
||||
# send_a2a_message routes through ${PLATFORM_URL}/workspaces/{id}/a2a
|
||||
# (the platform proxy) so the same code works for in-container and
|
||||
# external (standalone molecule-mcp) callers.
|
||||
result = await send_a2a_message(workspace_id, task, source_workspace_id=src)
|
||||
|
||||
# Detect delegation failures — wrap them clearly so the calling agent
|
||||
# can decide to retry, use another peer, or handle the task itself.
|
||||
is_error = result.startswith(_A2A_ERROR_PREFIX)
|
||||
# Strip the sentinel prefix so error_detail is the human-readable
|
||||
# cause directly. The Activity tab's red error chip surfaces this
|
||||
# without the user having to scroll into the raw response JSON.
|
||||
#
|
||||
# Cap at 4096 chars before sending — the platform's
|
||||
# activity_logs.error_detail column is unbounded TEXT and a
|
||||
# malicious or buggy peer could otherwise stream an arbitrarily
|
||||
# large error message into the caller's activity log. 4096 is
|
||||
# comfortably above any real exception traceback we've seen and
|
||||
# well below an obvious-DoS threshold.
|
||||
error_detail = result[len(_A2A_ERROR_PREFIX):].strip()[:4096] if is_error else ""
|
||||
await report_activity(
|
||||
"a2a_receive", workspace_id,
|
||||
f"{peer_name} responded ({len(result)} chars)" if not is_error else f"{peer_name} failed: {error_detail[:120]}",
|
||||
task_text=task, response_text=result,
|
||||
status="error" if is_error else "ok",
|
||||
error_detail=error_detail,
|
||||
)
|
||||
if is_error:
|
||||
return (
|
||||
f"DELEGATION FAILED to {peer_name}: {result}\n"
|
||||
f"You should either: (1) try a different peer, (2) handle this task yourself, "
|
||||
f"or (3) inform the user that {peer_name} is unavailable and provide your best answer."
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def tool_delegate_task_async(
|
||||
workspace_id: str,
|
||||
task: str,
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Delegate a task via the platform's async delegation API (fire-and-forget).
|
||||
|
||||
Uses POST /workspaces/:id/delegate which runs the A2A request in the background.
|
||||
Results are tracked in the platform DB and broadcast via WebSocket.
|
||||
Use check_task_status to poll for results.
|
||||
|
||||
``source_workspace_id`` selects the sending workspace (which one of
|
||||
this agent's registered workspaces gets logged as the originator);
|
||||
auto-routes via the peer→source cache when omitted.
|
||||
"""
|
||||
if not workspace_id or not task:
|
||||
return "Error: workspace_id and task are required"
|
||||
|
||||
src = source_workspace_id or _peer_to_source.get(workspace_id) or WORKSPACE_ID
|
||||
|
||||
# Idempotency key: SHA-256 of (source, target, task) so that a
|
||||
# restarted agent firing the same delegation gets the same key and
|
||||
# the platform returns the existing delegation_id instead of
|
||||
# creating a duplicate. Fixes #1456. Source is in the key so the
|
||||
# SAME task delegated from two different registered workspaces
|
||||
# produces two distinct delegations (the right behavior — one per
|
||||
# tenant audit trail).
|
||||
idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32]
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegate",
|
||||
json={"target_id": workspace_id, "task": task, "idempotency_key": idem_key},
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
if resp.status_code == 202:
|
||||
data = resp.json()
|
||||
return json.dumps({
|
||||
"delegation_id": data.get("delegation_id", ""),
|
||||
"workspace_id": workspace_id,
|
||||
"status": "delegated",
|
||||
"note": "Task delegated. The platform runs it in the background. Use check_task_status to poll for results.",
|
||||
})
|
||||
else:
|
||||
return f"Error: delegation failed with status {resp.status_code}: {resp.text[:200]}"
|
||||
except Exception as e:
|
||||
return f"Error: delegation failed — {e}"
|
||||
|
||||
|
||||
async def tool_check_task_status(
|
||||
workspace_id: str,
|
||||
task_id: str,
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Check delegations for this workspace via the platform API.
|
||||
|
||||
Args:
|
||||
workspace_id: Ignored (kept for backward compat). Checks
|
||||
``source_workspace_id``'s delegations (the workspace that
|
||||
FIRED the delegations), not the target's.
|
||||
task_id: Optional delegation_id to filter. If empty, returns all recent delegations.
|
||||
source_workspace_id: Which registered workspace's delegation log
|
||||
to query. Defaults to the module-level WORKSPACE_ID.
|
||||
"""
|
||||
src = source_workspace_id or WORKSPACE_ID
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegations",
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return f"Error: failed to check delegations ({resp.status_code})"
|
||||
delegations = resp.json()
|
||||
if task_id:
|
||||
# Filter by delegation_id
|
||||
matching = [d for d in delegations if d.get("delegation_id") == task_id]
|
||||
if matching:
|
||||
return json.dumps(matching[0])
|
||||
return json.dumps({"status": "not_found", "delegation_id": task_id})
|
||||
# Return all recent delegations
|
||||
summary = []
|
||||
for d in delegations[:10]:
|
||||
summary.append({
|
||||
"delegation_id": d.get("delegation_id", ""),
|
||||
"target_id": d.get("target_id", ""),
|
||||
"status": d.get("status", ""),
|
||||
"summary": d.get("summary", ""),
|
||||
"response_preview": d.get("response_preview", ""),
|
||||
})
|
||||
return json.dumps({"delegations": summary, "count": len(delegations)})
|
||||
except Exception as e:
|
||||
return f"Error checking delegations: {e}"
|
||||
# Delegation tool handlers — extracted to a2a_tools_delegation
|
||||
# (RFC #2873 iter 4b). Re-imported here so call sites + tests that
|
||||
# reference ``a2a_tools.tool_delegate_task`` /
|
||||
# ``a2a_tools._delegate_sync_via_polling`` keep resolving identically.
|
||||
from a2a_tools_delegation import ( # noqa: E402 (import after the from-a2a_client block)
|
||||
_SYNC_POLL_BUDGET_S,
|
||||
_SYNC_POLL_INTERVAL_S,
|
||||
_delegate_sync_via_polling,
|
||||
tool_check_task_status,
|
||||
tool_delegate_task,
|
||||
tool_delegate_task_async,
|
||||
)
|
||||
|
||||
|
||||
async def _upload_chat_files(
|
||||
|
||||
@@ -0,0 +1,372 @@
|
||||
"""Delegation tool handlers — single-concern slice of the a2a_tools surface.
|
||||
|
||||
Extracted from ``a2a_tools.py`` (RFC #2873 iter 4b). Owns the three
|
||||
delegation MCP tools + the RFC #2829 PR-5 sync-via-polling helper they
|
||||
share.
|
||||
|
||||
Public surface:
|
||||
|
||||
* ``tool_delegate_task`` — synchronous delegation, waits for response.
|
||||
* ``tool_delegate_task_async`` — fire-and-forget delegation; returns
|
||||
``{delegation_id, ...}``.
|
||||
* ``tool_check_task_status`` — poll the platform's ``/delegations`` log.
|
||||
|
||||
Internal:
|
||||
|
||||
* ``_delegate_sync_via_polling`` — durable async + poll for terminal
|
||||
status (RFC #2829 PR-5 cutover path; toggled by
|
||||
``DELEGATION_SYNC_VIA_INBOX=1``).
|
||||
* ``_SYNC_POLL_INTERVAL_S`` / ``_SYNC_POLL_BUDGET_S`` constants.
|
||||
|
||||
Circular-import note: this module calls ``report_activity`` from
|
||||
``a2a_tools`` to emit activity rows around the delegate dispatch.
|
||||
``a2a_tools`` imports the public symbols here at module-load time,
|
||||
so we use a LAZY import for ``report_activity`` inside the function
|
||||
that needs it. Without the lazy hop Python raises an ImportError
|
||||
on first ``a2a_tools`` import.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
|
||||
import httpx
|
||||
|
||||
from a2a_client import (
|
||||
PLATFORM_URL,
|
||||
WORKSPACE_ID,
|
||||
_A2A_ERROR_PREFIX,
|
||||
_peer_names,
|
||||
_peer_to_source,
|
||||
discover_peer,
|
||||
send_a2a_message,
|
||||
)
|
||||
from a2a_tools_rbac import auth_headers_for_heartbeat as _auth_headers_for_heartbeat
|
||||
|
||||
|
||||
# RFC #2829 PR-5 cutover constants. The poll cadence + timeout are
|
||||
# intentionally generous: 3s gives the platform's executeDelegation
|
||||
# goroutine room to dispatch + the callee to respond + the result to
|
||||
# write to activity_logs without thrashing the platform with rapid
|
||||
# polls; the budget matches the legacy DELEGATION_TIMEOUT (300s) so
|
||||
# operators don't see behavior change beyond "no more 600s timeouts".
|
||||
_SYNC_POLL_INTERVAL_S = 3.0
|
||||
_SYNC_POLL_BUDGET_S = float(os.environ.get("DELEGATION_TIMEOUT", "300.0"))
|
||||
|
||||
|
||||
async def _delegate_sync_via_polling(
|
||||
workspace_id: str,
|
||||
task: str,
|
||||
src: str,
|
||||
) -> str:
|
||||
"""RFC #2829 PR-5: durable async delegation + poll for terminal status.
|
||||
|
||||
Sidesteps the platform proxy's blocking `message/send` HTTP path that
|
||||
hits a hard 600s ceiling. Instead:
|
||||
|
||||
1. POST /workspaces/<src>/delegate (async, returns 202 + delegation_id)
|
||||
— platform's executeDelegation goroutine handles A2A dispatch in
|
||||
the background. No client-side timeout dependency on the platform
|
||||
holding a connection open.
|
||||
2. Poll GET /workspaces/<src>/delegations every 3s for a row with
|
||||
matching delegation_id reaching terminal status (completed/failed).
|
||||
3. Return the response_preview text on completed; surface error_detail
|
||||
on failed (with the same _A2A_ERROR_PREFIX wrapping the legacy
|
||||
path uses, so caller error-detection logic is unchanged).
|
||||
|
||||
Both /delegate and /delegations are existing endpoints — this helper
|
||||
just composes them into a polling synchronous facade. The result is
|
||||
available the moment the platform writes the terminal status row;
|
||||
no extra latency vs. the legacy proxy-blocked path on fast cases.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32]
|
||||
|
||||
# 1. Dispatch via /delegate (the async, durable path).
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegate",
|
||||
json={
|
||||
"target_id": workspace_id,
|
||||
"task": task,
|
||||
"idempotency_key": idem_key,
|
||||
},
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: {e}"
|
||||
|
||||
if resp.status_code != 202 and resp.status_code != 200:
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: HTTP {resp.status_code} {resp.text[:200]}"
|
||||
|
||||
try:
|
||||
dispatch = resp.json()
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch returned non-JSON: {e}"
|
||||
|
||||
delegation_id = dispatch.get("delegation_id", "")
|
||||
if not delegation_id:
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch missing delegation_id: {dispatch}"
|
||||
|
||||
# 2. Poll for terminal status with a deadline. Each poll is a cheap
|
||||
# /delegations GET — bounded by the platform's existing rate limit.
|
||||
deadline = time.monotonic() + _SYNC_POLL_BUDGET_S
|
||||
last_status = "unknown"
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
poll = await client.get(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegations",
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
# Transient — keep polling. The platform IS holding the
|
||||
# delegation row; we just lost a network request.
|
||||
last_status = f"poll-error: {e}"
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
|
||||
if poll.status_code != 200:
|
||||
last_status = f"poll HTTP {poll.status_code}"
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
|
||||
try:
|
||||
rows = poll.json()
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
last_status = f"poll non-JSON: {e}"
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
|
||||
# /delegations returns a flat list of delegation events. Filter to
|
||||
# our delegation_id; pick the first terminal one. The list may
|
||||
# have multiple rows per delegation_id (one for the original
|
||||
# dispatch, one per status update); we want the latest terminal.
|
||||
if not isinstance(rows, list):
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
terminal = None
|
||||
for r in rows:
|
||||
if not isinstance(r, dict):
|
||||
continue
|
||||
if r.get("delegation_id") != delegation_id:
|
||||
continue
|
||||
status = (r.get("status") or "").lower()
|
||||
last_status = status
|
||||
if status in ("completed", "failed"):
|
||||
terminal = r
|
||||
break
|
||||
if terminal:
|
||||
if (terminal.get("status") or "").lower() == "completed":
|
||||
return terminal.get("response_preview") or ""
|
||||
err = (
|
||||
terminal.get("error_detail")
|
||||
or terminal.get("summary")
|
||||
or "delegation failed"
|
||||
)
|
||||
return f"{_A2A_ERROR_PREFIX}{err}"
|
||||
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
|
||||
# Budget exhausted — the platform's row is still in flight (or queued).
|
||||
# Surface as an error so the caller can decide to retry or fall back;
|
||||
# the platform DOES still have the durable row, so the work isn't
|
||||
# lost — it'll complete eventually and a future check_task_status
|
||||
# will surface the result.
|
||||
return (
|
||||
f"{_A2A_ERROR_PREFIX}polling timeout after {_SYNC_POLL_BUDGET_S}s "
|
||||
f"(delegation_id={delegation_id}, last_status={last_status}); "
|
||||
f"the platform is still working on it — call check_task_status('{delegation_id}') to retrieve later"
|
||||
)
|
||||
|
||||
|
||||
async def tool_delegate_task(
|
||||
workspace_id: str,
|
||||
task: str,
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Delegate a task to another workspace via A2A (synchronous — waits for response).
|
||||
|
||||
``source_workspace_id`` selects which registered workspace this
|
||||
delegation originates from — drives auth + the X-Workspace-ID source
|
||||
header so the platform's a2a_proxy logs the correct sender. Single-
|
||||
workspace operators leave it None and routing falls back to the
|
||||
module-level WORKSPACE_ID.
|
||||
"""
|
||||
if not workspace_id or not task:
|
||||
return "Error: workspace_id and task are required"
|
||||
|
||||
# Auto-route: if source not specified, look up which registered
|
||||
# workspace last saw this peer (populated by tool_list_peers). Falls
|
||||
# back to the legacy WORKSPACE_ID for single-workspace operators.
|
||||
src = source_workspace_id or _peer_to_source.get(workspace_id) or None
|
||||
|
||||
# Discover the target. discover_peer is the access-control gate +
|
||||
# name/status lookup. The peer's reported ``url`` field is NOT used
|
||||
# for routing — see send_a2a_message, which constructs the URL via
|
||||
# the platform's A2A proxy.
|
||||
peer = await discover_peer(workspace_id, source_workspace_id=src)
|
||||
if not peer:
|
||||
return f"Error: workspace {workspace_id} not found or not accessible (check access control)"
|
||||
|
||||
if (peer.get("status") or "").lower() == "offline":
|
||||
return f"Error: workspace {workspace_id} is offline"
|
||||
|
||||
# Lazy import: a2a_tools imports this module at top-level, so a
|
||||
# top-level import of report_activity from a2a_tools would create a
|
||||
# circular dependency at first-import time. Lazy resolution inside
|
||||
# the function body breaks the cycle without forcing a ground-up
|
||||
# restructure of the activity-reporting layer.
|
||||
from a2a_tools import report_activity
|
||||
|
||||
# Report delegation start — include the task text for traceability
|
||||
peer_name = peer.get("name") or _peer_names.get(workspace_id) or workspace_id[:8]
|
||||
_peer_names[workspace_id] = peer_name # cache for future use
|
||||
# Brief summary for canvas display — just the delegation target
|
||||
await report_activity("a2a_send", workspace_id, f"Delegating to {peer_name}", task_text=task)
|
||||
|
||||
# RFC #2829 PR-5: agent-side cutover. When DELEGATION_SYNC_VIA_INBOX=1,
|
||||
# use the platform's durable async delegation API (POST /delegate +
|
||||
# poll /delegations) instead of the proxy-blocked message/send path.
|
||||
# This sidesteps the 600s message/send timeout class that broke
|
||||
# iteration-14/90-style long-running delegations on 2026-05-05.
|
||||
#
|
||||
# Default off — staging-canary first, flip default after PR-2's
|
||||
# result-push flag (DELEGATION_RESULT_INBOX_PUSH) has been on for
|
||||
# ≥1 week without incident.
|
||||
if os.environ.get("DELEGATION_SYNC_VIA_INBOX") == "1":
|
||||
result = await _delegate_sync_via_polling(workspace_id, task, src or WORKSPACE_ID)
|
||||
else:
|
||||
# send_a2a_message routes through ${PLATFORM_URL}/workspaces/{id}/a2a
|
||||
# (the platform proxy) so the same code works for in-container and
|
||||
# external (standalone molecule-mcp) callers.
|
||||
result = await send_a2a_message(workspace_id, task, source_workspace_id=src)
|
||||
|
||||
# Detect delegation failures — wrap them clearly so the calling agent
|
||||
# can decide to retry, use another peer, or handle the task itself.
|
||||
is_error = result.startswith(_A2A_ERROR_PREFIX)
|
||||
# Strip the sentinel prefix so error_detail is the human-readable
|
||||
# cause directly. The Activity tab's red error chip surfaces this
|
||||
# without the user having to scroll into the raw response JSON.
|
||||
#
|
||||
# Cap at 4096 chars before sending — the platform's
|
||||
# activity_logs.error_detail column is unbounded TEXT and a
|
||||
# malicious or buggy peer could otherwise stream an arbitrarily
|
||||
# large error message into the caller's activity log. 4096 is
|
||||
# comfortably above any real exception traceback we've seen and
|
||||
# well below an obvious-DoS threshold.
|
||||
error_detail = result[len(_A2A_ERROR_PREFIX):].strip()[:4096] if is_error else ""
|
||||
await report_activity(
|
||||
"a2a_receive", workspace_id,
|
||||
f"{peer_name} responded ({len(result)} chars)" if not is_error else f"{peer_name} failed: {error_detail[:120]}",
|
||||
task_text=task, response_text=result,
|
||||
status="error" if is_error else "ok",
|
||||
error_detail=error_detail,
|
||||
)
|
||||
if is_error:
|
||||
return (
|
||||
f"DELEGATION FAILED to {peer_name}: {result}\n"
|
||||
f"You should either: (1) try a different peer, (2) handle this task yourself, "
|
||||
f"or (3) inform the user that {peer_name} is unavailable and provide your best answer."
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def tool_delegate_task_async(
|
||||
workspace_id: str,
|
||||
task: str,
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Delegate a task via the platform's async delegation API (fire-and-forget).
|
||||
|
||||
Uses POST /workspaces/:id/delegate which runs the A2A request in the background.
|
||||
Results are tracked in the platform DB and broadcast via WebSocket.
|
||||
Use check_task_status to poll for results.
|
||||
|
||||
``source_workspace_id`` selects the sending workspace (which one of
|
||||
this agent's registered workspaces gets logged as the originator);
|
||||
auto-routes via the peer→source cache when omitted.
|
||||
"""
|
||||
if not workspace_id or not task:
|
||||
return "Error: workspace_id and task are required"
|
||||
|
||||
src = source_workspace_id or _peer_to_source.get(workspace_id) or WORKSPACE_ID
|
||||
|
||||
# Idempotency key: SHA-256 of (source, target, task) so that a
|
||||
# restarted agent firing the same delegation gets the same key and
|
||||
# the platform returns the existing delegation_id instead of
|
||||
# creating a duplicate. Fixes #1456. Source is in the key so the
|
||||
# SAME task delegated from two different registered workspaces
|
||||
# produces two distinct delegations (the right behavior — one per
|
||||
# tenant audit trail).
|
||||
idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32]
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegate",
|
||||
json={"target_id": workspace_id, "task": task, "idempotency_key": idem_key},
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
if resp.status_code == 202:
|
||||
data = resp.json()
|
||||
return json.dumps({
|
||||
"delegation_id": data.get("delegation_id", ""),
|
||||
"workspace_id": workspace_id,
|
||||
"status": "delegated",
|
||||
"note": "Task delegated. The platform runs it in the background. Use check_task_status to poll for results.",
|
||||
})
|
||||
else:
|
||||
return f"Error: delegation failed with status {resp.status_code}: {resp.text[:200]}"
|
||||
except Exception as e:
|
||||
return f"Error: delegation failed — {e}"
|
||||
|
||||
|
||||
async def tool_check_task_status(
|
||||
workspace_id: str,
|
||||
task_id: str,
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Check delegations for this workspace via the platform API.
|
||||
|
||||
Args:
|
||||
workspace_id: Ignored (kept for backward compat). Checks
|
||||
``source_workspace_id``'s delegations (the workspace that
|
||||
FIRED the delegations), not the target's.
|
||||
task_id: Optional delegation_id to filter. If empty, returns all recent delegations.
|
||||
source_workspace_id: Which registered workspace's delegation log
|
||||
to query. Defaults to the module-level WORKSPACE_ID.
|
||||
"""
|
||||
src = source_workspace_id or WORKSPACE_ID
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegations",
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return f"Error: failed to check delegations ({resp.status_code})"
|
||||
delegations = resp.json()
|
||||
if task_id:
|
||||
# Filter by delegation_id
|
||||
matching = [d for d in delegations if d.get("delegation_id") == task_id]
|
||||
if matching:
|
||||
return json.dumps(matching[0])
|
||||
return json.dumps({"status": "not_found", "delegation_id": task_id})
|
||||
# Return all recent delegations
|
||||
summary = []
|
||||
for d in delegations[:10]:
|
||||
summary.append({
|
||||
"delegation_id": d.get("delegation_id", ""),
|
||||
"target_id": d.get("target_id", ""),
|
||||
"status": d.get("status", ""),
|
||||
"summary": d.get("summary", ""),
|
||||
"response_preview": d.get("response_preview", ""),
|
||||
})
|
||||
return json.dumps({"delegations": summary, "count": len(delegations)})
|
||||
except Exception as e:
|
||||
return f"Error checking delegations: {e}"
|
||||
@@ -0,0 +1,138 @@
|
||||
"""RBAC + auth-header helpers shared by all a2a_tools tool handlers.
|
||||
|
||||
Extracted from ``a2a_tools.py`` (RFC #2873 iter 4a). Centralises the
|
||||
"what can this workspace do" + "how do I prove it on a platform call"
|
||||
concerns into a single module so:
|
||||
|
||||
* Future tools added under ``a2a_tools/`` see one obvious helper to
|
||||
call instead of re-implementing the role/tier check.
|
||||
* The role-permission table is in ONE place — adding a new role
|
||||
or capability touches one file, not every tool that gates on it.
|
||||
* Tests targeting these helpers don't have to import the whole
|
||||
991-LOC ``a2a_tools`` surface.
|
||||
|
||||
Public surface:
|
||||
|
||||
* ``ROLE_PERMISSIONS`` — canonical role → action set table.
|
||||
* ``get_workspace_tier()`` — config-resolved tier (0 = root).
|
||||
* ``check_memory_write_permission()`` — boolean.
|
||||
* ``check_memory_read_permission()`` — boolean.
|
||||
* ``is_root_workspace()`` — boolean (tier == 0).
|
||||
* ``auth_headers_for_heartbeat(workspace_id=None)`` — auth-header dict
|
||||
with the multi-workspace registry lookup; tolerates ``platform_auth``
|
||||
missing on older installs (returns ``{}``).
|
||||
|
||||
Underscore-prefixed back-compat aliases (``_ROLE_PERMISSIONS``,
|
||||
``_check_memory_write_permission``, etc.) match the names previously
|
||||
exposed in ``a2a_tools`` so existing tests'
|
||||
``patch("a2a_tools._foo", ...)`` continue to work via the re-exports
|
||||
in ``a2a_tools.py``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
|
||||
# Mirror ``builtin_tools/audit.py`` for a2a_tools isolation. Listed as a
|
||||
# module-level constant rather than computed lazily so the table is
|
||||
# discoverable in static analysis + ``grep``.
|
||||
ROLE_PERMISSIONS: dict[str, set[str]] = {
|
||||
"admin": {"delegate", "approve", "memory.read", "memory.write"},
|
||||
"operator": {"delegate", "approve", "memory.read", "memory.write"},
|
||||
"read-only": {"memory.read"},
|
||||
"no-delegation": {"approve", "memory.read", "memory.write"},
|
||||
"no-approval": {"delegate", "memory.read", "memory.write"},
|
||||
"memory-readonly": {"memory.read"},
|
||||
}
|
||||
|
||||
|
||||
def get_workspace_tier() -> int:
|
||||
"""Return the workspace tier from config (0 = root, 1+ = tenant)."""
|
||||
try:
|
||||
from config import load_config
|
||||
|
||||
cfg = load_config()
|
||||
return getattr(cfg, "tier", 1)
|
||||
except Exception:
|
||||
return int(os.environ.get("WORKSPACE_TIER", 1))
|
||||
|
||||
|
||||
def _resolve_role_state() -> tuple[list[str], dict]:
|
||||
"""Return (roles, allowed_actions) from config.
|
||||
|
||||
Fail-closed: if config is unavailable, fall back to an "operator"
|
||||
default with no per-role overrides. Operator has memory.read +
|
||||
memory.write but not the elevated approve/delegate over GLOBAL
|
||||
scope, so a config outage doesn't grant unexpected privileges.
|
||||
"""
|
||||
try:
|
||||
from config import load_config
|
||||
|
||||
cfg = load_config()
|
||||
roles = list(getattr(cfg, "rbac", None).roles or ["operator"])
|
||||
allowed = dict(getattr(cfg, "rbac", None).allowed_actions or {})
|
||||
return roles, allowed
|
||||
except Exception:
|
||||
return ["operator"], {}
|
||||
|
||||
|
||||
def check_memory_write_permission() -> bool:
|
||||
"""Return True if this workspace's RBAC roles grant memory.write."""
|
||||
roles, allowed = _resolve_role_state()
|
||||
for role in roles:
|
||||
if role == "admin":
|
||||
return True
|
||||
if role in allowed:
|
||||
if "memory.write" in allowed[role]:
|
||||
return True
|
||||
elif role in ROLE_PERMISSIONS and "memory.write" in ROLE_PERMISSIONS[role]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_memory_read_permission() -> bool:
|
||||
"""Return True if this workspace's RBAC roles grant memory.read."""
|
||||
roles, allowed = _resolve_role_state()
|
||||
for role in roles:
|
||||
if role == "admin":
|
||||
return True
|
||||
if role in allowed:
|
||||
if "memory.read" in allowed[role]:
|
||||
return True
|
||||
elif role in ROLE_PERMISSIONS and "memory.read" in ROLE_PERMISSIONS[role]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_root_workspace() -> bool:
|
||||
"""Return True if this workspace is tier 0 (root/root-org)."""
|
||||
return get_workspace_tier() == 0
|
||||
|
||||
|
||||
def auth_headers_for_heartbeat(workspace_id: str | None = None) -> dict[str, str]:
|
||||
"""Return Phase 30.1 auth headers; tolerate platform_auth being absent
|
||||
in older installs (e.g. during rolling upgrade).
|
||||
|
||||
``workspace_id`` selects the per-workspace token from the multi-
|
||||
workspace registry when set (PR-1: external agent registered in
|
||||
multiple workspaces). With no arg the legacy single-token path is
|
||||
unchanged.
|
||||
"""
|
||||
try:
|
||||
from platform_auth import auth_headers
|
||||
return auth_headers(workspace_id) if workspace_id else auth_headers()
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
# ============== Back-compat aliases for the previous a2a_tools names ==============
|
||||
# Tests + downstream call sites refer to the pre-extract names; aliasing
|
||||
# keeps both forms valid. The new public names (no underscore prefix)
|
||||
# are preferred for new code.
|
||||
|
||||
_ROLE_PERMISSIONS = ROLE_PERMISSIONS
|
||||
_get_workspace_tier = get_workspace_tier
|
||||
_check_memory_write_permission = check_memory_write_permission
|
||||
_check_memory_read_permission = check_memory_read_permission
|
||||
_is_root_workspace = is_root_workspace
|
||||
_auth_headers_for_heartbeat = auth_headers_for_heartbeat
|
||||
+79
-2
@@ -432,7 +432,17 @@ def _is_self_notify_row(row: dict[str, Any]) -> bool:
|
||||
|
||||
|
||||
def message_from_activity(row: dict[str, Any]) -> InboxMessage:
|
||||
"""Convert one /activity row into an InboxMessage."""
|
||||
"""Convert one /activity row into an InboxMessage.
|
||||
|
||||
Mutates ``row['request_body']`` in-place to swap any
|
||||
``platform-pending:`` URIs to the locally-staged ``workspace:`` URIs
|
||||
(see ``inbox_uploads.rewrite_request_body``) — by the time the
|
||||
upstream chat message arrives via this path, the upload-receive row
|
||||
that staged the bytes has already populated the URI cache (lower
|
||||
activity_logs.id, processed earlier in the same poll batch). A
|
||||
cache miss leaves the URI untouched; the agent surfaces an
|
||||
unresolvable URI rather than the inbox silently dropping the part.
|
||||
"""
|
||||
request_body = row.get("request_body")
|
||||
if isinstance(request_body, str):
|
||||
# The Go handler returns request_body as json.RawMessage; httpx
|
||||
@@ -443,6 +453,14 @@ def message_from_activity(row: dict[str, Any]) -> InboxMessage:
|
||||
except (TypeError, ValueError):
|
||||
request_body = None
|
||||
|
||||
# Rewrite platform-pending: URIs → workspace: URIs in-place. Imported
|
||||
# at call time to keep the import graph clean for the in-container
|
||||
# path that doesn't use this module (also avoids a circular: the
|
||||
# uploads module is small enough that re-importing per call is
|
||||
# cheap, and the Python import cache makes it free after the first).
|
||||
from inbox_uploads import rewrite_request_body
|
||||
rewrite_request_body(request_body)
|
||||
|
||||
return InboxMessage(
|
||||
activity_id=str(row.get("id", "")),
|
||||
text=_extract_text(request_body, row.get("summary")),
|
||||
@@ -532,11 +550,57 @@ def _poll_once(
|
||||
if cursor is None:
|
||||
rows = list(reversed(rows))
|
||||
|
||||
# Imported lazily at use-site so a runtime that never sees an
|
||||
# upload-receive row never imports the module. Cheap on the hot
|
||||
# path because Python caches the import.
|
||||
from inbox_uploads import is_chat_upload_row, BatchFetcher
|
||||
|
||||
new_count = 0
|
||||
last_id: str | None = None
|
||||
# ``batch_fetcher`` is lazy: a poll batch with no upload rows pays
|
||||
# zero overhead. Once the first upload row appears we open one
|
||||
# BatchFetcher and submit every subsequent upload row to its thread
|
||||
# pool; before processing the FIRST non-upload row we drain the
|
||||
# pool (wait_all) so the URI cache is hot when message rewriting
|
||||
# runs. Without the barrier, the chat message that references the
|
||||
# upload would arrive at the agent with the un-rewritten
|
||||
# platform-pending: URI.
|
||||
batch_fetcher: BatchFetcher | None = None
|
||||
|
||||
def _drain_uploads(bf: BatchFetcher | None) -> None:
|
||||
if bf is None:
|
||||
return
|
||||
bf.wait_all()
|
||||
bf.close()
|
||||
|
||||
for row in rows:
|
||||
if not isinstance(row, dict):
|
||||
continue
|
||||
if is_chat_upload_row(row):
|
||||
# Side-effect row from the platform's poll-mode chat-upload
|
||||
# handler — fetch the bytes, stage to /workspace/.molecule/
|
||||
# chat-uploads, ack. NOT enqueued as an InboxMessage; the
|
||||
# agent will see the chat message that REFERENCES this
|
||||
# upload via a separate (later) activity row, with the
|
||||
# pending: URI rewritten to a workspace: URI by
|
||||
# message_from_activity. We DO advance the cursor past
|
||||
# this row so a permanent network outage on /content
|
||||
# doesn't stall the cursor and block real chat traffic.
|
||||
if batch_fetcher is None:
|
||||
batch_fetcher = BatchFetcher(
|
||||
platform_url=platform_url,
|
||||
workspace_id=workspace_id,
|
||||
headers=headers,
|
||||
)
|
||||
batch_fetcher.submit(row)
|
||||
last_id = str(row.get("id", "")) or last_id
|
||||
continue
|
||||
# Non-upload row: drain any pending uploads first so the URI
|
||||
# cache is populated before we run rewrite_request_body /
|
||||
# message_from_activity on a row that may reference one.
|
||||
if batch_fetcher is not None:
|
||||
_drain_uploads(batch_fetcher)
|
||||
batch_fetcher = None
|
||||
if _is_self_notify_row(row):
|
||||
# The workspace-server's `/notify` handler writes the agent's
|
||||
# own send_message_to_user POSTs to activity_logs with
|
||||
@@ -571,6 +635,13 @@ def _poll_once(
|
||||
last_id = message.activity_id
|
||||
new_count += 1
|
||||
|
||||
# Drain any uploads still in flight if the batch ended with upload
|
||||
# rows (no chat-message row to trigger the inline drain). Without
|
||||
# this, a future poll that picks up the chat-message row first
|
||||
# would race with the still-running fetches.
|
||||
if batch_fetcher is not None:
|
||||
_drain_uploads(batch_fetcher)
|
||||
|
||||
if last_id is not None:
|
||||
state.save_cursor(last_id, cursor_key)
|
||||
return new_count
|
||||
@@ -613,6 +684,7 @@ def start_poller_thread(
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
interval: float = POLL_INTERVAL_SECONDS,
|
||||
stop_event: threading.Event | None = None,
|
||||
) -> threading.Thread:
|
||||
"""Spawn the poller as a daemon thread. Returns the Thread handle.
|
||||
|
||||
@@ -624,13 +696,18 @@ def start_poller_thread(
|
||||
operator running ``ps -eL`` or eyeballing ``threading.enumerate()``
|
||||
can tell which thread is which without reverse-engineering it from
|
||||
crash tracebacks.
|
||||
|
||||
Pass ``stop_event`` to enable graceful shutdown — used by tests so
|
||||
the daemon thread doesn't outlive the test that started it and race
|
||||
with later tests' httpx patches. Production code passes None and
|
||||
relies on the daemon flag for process-exit cleanup.
|
||||
"""
|
||||
name = "molecule-mcp-inbox-poller"
|
||||
if workspace_id:
|
||||
name = f"{name}-{workspace_id[:8]}"
|
||||
t = threading.Thread(
|
||||
target=_poll_loop,
|
||||
args=(state, platform_url, workspace_id, interval),
|
||||
args=(state, platform_url, workspace_id, interval, stop_event),
|
||||
name=name,
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,724 @@
|
||||
"""Poll-mode chat-upload fetcher + URI cache for the standalone path.
|
||||
|
||||
Companion to ``inbox.py``. When the workspace's inbox poller sees an
|
||||
``activity_logs`` row with ``method='chat_upload_receive'`` (written by
|
||||
the platform's ``uploadPollMode`` handler — workspace-server
|
||||
``internal/handlers/chat_files.go``), this module:
|
||||
|
||||
1. Pulls the bytes from
|
||||
``GET /workspaces/:id/pending-uploads/:file_id/content``.
|
||||
2. Writes them to ``/workspace/.molecule/chat-uploads/<prefix>-<name>``
|
||||
— same on-disk shape as the push-mode handler in
|
||||
``internal_chat_uploads.py``, so anything downstream that already
|
||||
resolves ``workspace:/workspace/.molecule/chat-uploads/...`` URIs
|
||||
works unchanged.
|
||||
3. POSTs ``/workspaces/:id/pending-uploads/:file_id/ack`` so Phase 3
|
||||
sweep can clean up the platform-side ``pending_uploads`` row.
|
||||
4. Records a ``platform-pending:<wsid>/<file_id> →
|
||||
workspace:/workspace/.molecule/chat-uploads/...`` mapping in a
|
||||
process-local cache so the chat message that arrives later
|
||||
(referencing the platform-pending URI) gets rewritten before the
|
||||
agent sees it.
|
||||
|
||||
URI rewrite ordering — the chat message containing the
|
||||
``platform-pending:`` URI is logged by the platform AFTER the
|
||||
``chat_upload_receive`` row, so the inbox poller sees the upload-receive
|
||||
row first (lower activity_logs.id) and stages the bytes before the chat
|
||||
message arrives in the same poll batch (or a later one). The URI cache
|
||||
is therefore populated before the message_from_activity path needs it.
|
||||
A miss (network race, restart with stale cursor) is handled by keeping
|
||||
the original ``platform-pending:`` URI in the rewritten body — the agent
|
||||
will see something it can't open, which is preferable to silently
|
||||
dropping the URI.
|
||||
|
||||
Auth — same Bearer token the inbox poller uses (``platform_auth.auth_headers``).
|
||||
Both endpoints are on the wsAuth-gated route, so this module can never
|
||||
read another tenant's bytes even if a token is misrouted.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import secrets as pysecrets
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Same on-disk root as internal_chat_uploads.CHAT_UPLOAD_DIR — keeping
|
||||
# these decoupled would let drift sneak in. Imported here rather than
|
||||
# from internal_chat_uploads to avoid pulling in starlette as a
|
||||
# transitive dep (this module runs in the standalone MCP path which
|
||||
# doesn't ship the in-container HTTP server).
|
||||
CHAT_UPLOAD_DIR = "/workspace/.molecule/chat-uploads"
|
||||
|
||||
# Per-file safety net. The platform enforces 25 MB on the staging side,
|
||||
# but a buggy or hostile platform response shouldn't be able to fill the
|
||||
# workspace's disk — refuse to write more than this even if the response
|
||||
# claims a larger Content-Length.
|
||||
MAX_FILE_BYTES = 25 * 1024 * 1024
|
||||
|
||||
# Network deadline for the GET. Tuned for a 25 MB transfer over a
|
||||
# reasonable consumer link (~5 Mbps gives ~40s for the full payload),
|
||||
# plus headroom for TLS + platform auth. Aligned with inbox poller's
|
||||
# 10s default for /activity calls — both are user-perceived latency.
|
||||
DEFAULT_FETCH_TIMEOUT = 60.0
|
||||
|
||||
# Concurrency cap for ``BatchFetcher``. Four workers is enough headroom
|
||||
# for the realistic "user dragged 3-4 files into chat at once" case
|
||||
# while bounding the platform's per-workspace fan-out. The cap matters
|
||||
# because the platform's /content endpoint reads bytea from Postgres in
|
||||
# a single round-trip per request — N workers = N concurrent DB reads
|
||||
# of up to 25 MB each, so a higher cap could pressure platform memory
|
||||
# without much UX win (network bandwidth is the bottleneck once the
|
||||
# bytes are buffered).
|
||||
DEFAULT_BATCH_FETCH_WORKERS = 4
|
||||
|
||||
# Upper bound on how long ``BatchFetcher.wait_all`` blocks the inbox
|
||||
# poll loop before giving up on still-in-flight fetches. Aligned with
|
||||
# DEFAULT_FETCH_TIMEOUT so a single hung fetch can't stall the loop
|
||||
# longer than its own deadline. A timeout fires only if a worker thread
|
||||
# is stuck past the underlying httpx timeout — pathological case;
|
||||
# normal completion is bounded by per-fetch timeout × ceil(N/W).
|
||||
DEFAULT_BATCH_WAIT_TIMEOUT = DEFAULT_FETCH_TIMEOUT + 5.0
|
||||
|
||||
# Cap on the URI cache. A long-lived workspace handling thousands of
|
||||
# uploads shouldn't grow without bound; an LRU cap of 1024 keeps the
|
||||
# entries-needed-for-a-typical-conversation well within memory.
|
||||
URI_CACHE_MAX_ENTRIES = 1024
|
||||
|
||||
# Same character class as internal_chat_uploads — kept duplicated rather
|
||||
# than imported to avoid dragging starlette into the standalone path.
|
||||
_UNSAFE_FILENAME_CHARS = re.compile(r"[^a-zA-Z0-9._\-]")
|
||||
|
||||
|
||||
def sanitize_filename(name: str) -> str:
|
||||
"""Reduce a user-supplied filename to a safe form.
|
||||
|
||||
Mirrors ``internal_chat_uploads.sanitize_filename`` and the Go
|
||||
handler's ``SanitizeFilename`` — three-way parity is pinned by
|
||||
``workspace-server/internal/handlers/sanitize_filename_test.go`` and
|
||||
``workspace/tests/test_internal_chat_uploads.py`` so the URI shape
|
||||
is identical regardless of which path handles the upload.
|
||||
"""
|
||||
base = os.path.basename(name)
|
||||
base = base.replace(" ", "_")
|
||||
base = _UNSAFE_FILENAME_CHARS.sub("_", base)
|
||||
if len(base) > 100:
|
||||
ext = ""
|
||||
dot = base.rfind(".")
|
||||
if dot >= 0 and len(base) - dot <= 16:
|
||||
ext = base[dot:]
|
||||
base = base[: 100 - len(ext)] + ext
|
||||
if base in ("", ".", ".."):
|
||||
return "file"
|
||||
return base
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URI cache — maps platform-pending URIs to local workspace: URIs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _URICache:
|
||||
"""Thread-safe bounded LRU mapping of platform-pending → workspace URIs.
|
||||
|
||||
Bounded so a workspace that runs for months and handles thousands of
|
||||
uploads doesn't accumulate entries forever. ``OrderedDict.move_to_end``
|
||||
promotes recently-used entries; eviction takes the oldest.
|
||||
|
||||
The cache is intentionally per-process — there is no persistence
|
||||
across a workspace restart. A restart with a stale inbox cursor that
|
||||
re-poll an upload-receive row will re-fetch (the bytes are already
|
||||
on disk from the prior session — see ``stage_to_disk``'s O_EXCL
|
||||
handling) and re-register; a chat message that referenced the
|
||||
platform-pending URI BEFORE the restart and arrives AFTER would miss
|
||||
the rewrite and surface the platform-pending URI to the agent. That
|
||||
is preferable to a stale persisted mapping that points at a deleted
|
||||
file.
|
||||
"""
|
||||
|
||||
def __init__(self, max_entries: int = URI_CACHE_MAX_ENTRIES):
|
||||
self._max = max_entries
|
||||
self._lock = threading.Lock()
|
||||
self._entries: "OrderedDict[str, str]" = OrderedDict()
|
||||
|
||||
def get(self, pending_uri: str) -> str | None:
|
||||
with self._lock:
|
||||
local = self._entries.get(pending_uri)
|
||||
if local is not None:
|
||||
self._entries.move_to_end(pending_uri)
|
||||
return local
|
||||
|
||||
def set(self, pending_uri: str, local_uri: str) -> None:
|
||||
with self._lock:
|
||||
self._entries[pending_uri] = local_uri
|
||||
self._entries.move_to_end(pending_uri)
|
||||
while len(self._entries) > self._max:
|
||||
self._entries.popitem(last=False)
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._entries)
|
||||
|
||||
def clear(self) -> None:
|
||||
with self._lock:
|
||||
self._entries.clear()
|
||||
|
||||
|
||||
_cache = _URICache()
|
||||
|
||||
|
||||
def get_cache() -> _URICache:
|
||||
"""Expose the module-singleton cache for tests and the rewrite path."""
|
||||
return _cache
|
||||
|
||||
|
||||
def resolve_pending_uri(uri: str) -> str | None:
|
||||
"""Return the local ``workspace:`` URI for a ``platform-pending:`` URI,
|
||||
or None if not yet staged. Convenience for callers that want to
|
||||
fall back to an on-demand fetch — pass the result through to
|
||||
``executor_helpers.resolve_attachment_uri``.
|
||||
"""
|
||||
return _cache.get(uri)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# On-disk staging
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _open_safe(path: str) -> int:
|
||||
"""Open ``path`` for write with ``O_CREAT|O_EXCL|O_NOFOLLOW``.
|
||||
|
||||
Same shape as ``internal_chat_uploads._open_safe`` — refuses to
|
||||
follow a pre-existing symlink at the target and refuses to overwrite
|
||||
an existing regular file. The 16-byte random prefix makes a name
|
||||
collision astronomical, but defense-in-depth costs nothing.
|
||||
"""
|
||||
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
|
||||
if hasattr(os, "O_NOFOLLOW"):
|
||||
flags |= os.O_NOFOLLOW
|
||||
return os.open(path, flags, 0o600)
|
||||
|
||||
|
||||
def stage_to_disk(content: bytes, filename: str) -> str:
|
||||
"""Write ``content`` under ``CHAT_UPLOAD_DIR`` and return the local URI.
|
||||
|
||||
Returns ``workspace:/workspace/.molecule/chat-uploads/<prefix>-<sanitized>``.
|
||||
The 32-hex prefix makes the on-disk name unguessable to anything
|
||||
that didn't see the response, so even if a stale agent has a guess
|
||||
at the original filename it can't construct a URL to a sibling's
|
||||
upload.
|
||||
|
||||
Raises:
|
||||
OSError: write failure (mkdir, open, or write). Caller is
|
||||
expected to log + skip; the activity row stays unacked so a
|
||||
future poll re-tries.
|
||||
ValueError: ``content`` exceeds ``MAX_FILE_BYTES``. Pre-staging
|
||||
guard belt-and-braces above the platform's same-side cap.
|
||||
"""
|
||||
if len(content) > MAX_FILE_BYTES:
|
||||
raise ValueError(
|
||||
f"content size {len(content)} exceeds workspace cap {MAX_FILE_BYTES}"
|
||||
)
|
||||
|
||||
Path(CHAT_UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
sanitized = sanitize_filename(filename)
|
||||
prefix = pysecrets.token_hex(16)
|
||||
stored = f"{prefix}-{sanitized}"
|
||||
target = os.path.join(CHAT_UPLOAD_DIR, stored)
|
||||
|
||||
fd = _open_safe(target)
|
||||
try:
|
||||
with os.fdopen(fd, "wb") as f:
|
||||
f.write(content)
|
||||
except OSError:
|
||||
# Best-effort cleanup — partial writes leave a stub file that
|
||||
# would mask a future retry's success otherwise.
|
||||
try:
|
||||
os.unlink(target)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
return f"workspace:{CHAT_UPLOAD_DIR}/{stored}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Activity row → fetch/stage/ack flow
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _request_body_dict(row: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Coerce ``row['request_body']`` into a dict.
|
||||
|
||||
The /activity API returns request_body as JSON (already-deserialized
|
||||
by httpx). Some legacy paths or mocked transports may emit a string;
|
||||
handle defensively rather than raising.
|
||||
"""
|
||||
body = row.get("request_body")
|
||||
if isinstance(body, dict):
|
||||
return body
|
||||
if isinstance(body, str):
|
||||
import json
|
||||
try:
|
||||
decoded = json.loads(body)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return decoded if isinstance(decoded, dict) else None
|
||||
return None
|
||||
|
||||
|
||||
def is_chat_upload_row(row: dict[str, Any]) -> bool:
|
||||
"""True if ``row`` is the platform's chat-upload-receive activity.
|
||||
|
||||
Used by the inbox poller to fork the row off the regular A2A
|
||||
message handling path — this row is not a peer message; it's an
|
||||
instruction to fetch + stage bytes. Match on ``method`` only;
|
||||
``activity_type`` is already filtered to ``a2a_receive`` upstream.
|
||||
"""
|
||||
return row.get("method") == "chat_upload_receive"
|
||||
|
||||
|
||||
def fetch_and_stage(
|
||||
row: dict[str, Any],
|
||||
*,
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
headers: dict[str, str],
|
||||
timeout_secs: float = DEFAULT_FETCH_TIMEOUT,
|
||||
client: Any = None,
|
||||
) -> str | None:
|
||||
"""Fetch the row's bytes, stage them under chat-uploads, and ack.
|
||||
|
||||
Returns the local ``workspace:`` URI on success, or ``None`` if any
|
||||
step failed (logged with enough detail to triage). Failure leaves
|
||||
the platform-side row unacked, so a subsequent poll retries — the
|
||||
activity row stays in the cursor's window because we DO advance the
|
||||
cursor (the row is "handled" from the inbox's perspective even on
|
||||
fetch failure; otherwise a permanent network outage would stall the
|
||||
cursor and block real chat traffic).
|
||||
|
||||
On success, the URI cache is updated so a subsequent chat message
|
||||
referencing the same ``platform-pending:`` URI is rewritten before
|
||||
the agent sees it.
|
||||
|
||||
Pass ``client`` to reuse a shared ``httpx.Client`` for both GET and
|
||||
POST ack (saves one TLS handshake per row vs. constructing one
|
||||
per-call). ``BatchFetcher`` does this across an entire poll batch so
|
||||
N concurrent fetches share one connection pool.
|
||||
"""
|
||||
body = _request_body_dict(row)
|
||||
if body is None:
|
||||
logger.warning(
|
||||
"inbox_uploads: row %s missing request_body; cannot fetch",
|
||||
row.get("id"),
|
||||
)
|
||||
return None
|
||||
|
||||
file_id = body.get("file_id")
|
||||
if not isinstance(file_id, str) or not file_id:
|
||||
logger.warning(
|
||||
"inbox_uploads: row %s has no file_id in request_body",
|
||||
row.get("id"),
|
||||
)
|
||||
return None
|
||||
|
||||
pending_uri = body.get("uri")
|
||||
if not isinstance(pending_uri, str) or not pending_uri:
|
||||
# Reconstruct what the platform would have written — defensive
|
||||
# against a row whose uri field got truncated. Same shape as the
|
||||
# Go handler's URI builder.
|
||||
pending_uri = f"platform-pending:{workspace_id}/{file_id}"
|
||||
|
||||
filename = body.get("name") or "file"
|
||||
if not isinstance(filename, str):
|
||||
filename = "file"
|
||||
|
||||
# Caller-supplied client: reuse for both GET + POST ack. Otherwise
|
||||
# build a one-shot client and close it on the way out. Lazy httpx
|
||||
# import keeps the standalone MCP path's optional dep optional.
|
||||
own_client = client is None
|
||||
if own_client:
|
||||
try:
|
||||
import httpx # noqa: WPS433
|
||||
except ImportError:
|
||||
logger.error("inbox_uploads: httpx not installed; cannot fetch %s", file_id)
|
||||
return None
|
||||
client = httpx.Client(timeout=timeout_secs)
|
||||
|
||||
try:
|
||||
return _fetch_and_stage_with_client(
|
||||
client,
|
||||
platform_url=platform_url,
|
||||
workspace_id=workspace_id,
|
||||
headers=headers,
|
||||
file_id=file_id,
|
||||
pending_uri=pending_uri,
|
||||
filename=filename,
|
||||
body=body,
|
||||
)
|
||||
finally:
|
||||
if own_client:
|
||||
try:
|
||||
client.close()
|
||||
except Exception: # noqa: BLE001 — close should never crash the caller
|
||||
pass
|
||||
|
||||
|
||||
def _fetch_and_stage_with_client(
|
||||
client: Any,
|
||||
*,
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
headers: dict[str, str],
|
||||
file_id: str,
|
||||
pending_uri: str,
|
||||
filename: str,
|
||||
body: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""Inner body of fetch_and_stage. Always uses the supplied client for
|
||||
both GET and POST so the connection pool is shared across the call.
|
||||
"""
|
||||
content_url = f"{platform_url}/workspaces/{workspace_id}/pending-uploads/{file_id}/content"
|
||||
ack_url = f"{platform_url}/workspaces/{workspace_id}/pending-uploads/{file_id}/ack"
|
||||
|
||||
try:
|
||||
resp = client.get(content_url, headers=headers)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("inbox_uploads: GET %s failed: %s", content_url, exc)
|
||||
return None
|
||||
|
||||
if resp.status_code == 404:
|
||||
# Row was swept or already acked by a previous poll race — nothing
|
||||
# to fetch. Don't ack again; the platform's GC handles it. This is
|
||||
# a soft-skip, not an error — log at INFO so triage isn't noisy.
|
||||
logger.info(
|
||||
"inbox_uploads: pending upload %s already gone (404); skipping",
|
||||
file_id,
|
||||
)
|
||||
return None
|
||||
if resp.status_code >= 400:
|
||||
logger.warning(
|
||||
"inbox_uploads: GET %s returned %d: %s",
|
||||
content_url,
|
||||
resp.status_code,
|
||||
(resp.text or "")[:200],
|
||||
)
|
||||
return None
|
||||
|
||||
content = resp.content or b""
|
||||
if len(content) > MAX_FILE_BYTES:
|
||||
logger.warning(
|
||||
"inbox_uploads: refusing to stage %s — size %d exceeds cap %d",
|
||||
file_id,
|
||||
len(content),
|
||||
MAX_FILE_BYTES,
|
||||
)
|
||||
return None
|
||||
|
||||
# Mimetype precedence: platform's Content-Type header → request_body
|
||||
# mimeType field → extension guess. Same precedence as the in-
|
||||
# container ingest handler.
|
||||
mime_header = resp.headers.get("content-type", "").split(";")[0].strip()
|
||||
mime = (
|
||||
mime_header
|
||||
or (body.get("mimeType") if isinstance(body.get("mimeType"), str) else "")
|
||||
or (mimetypes.guess_type(filename)[0] or "")
|
||||
)
|
||||
|
||||
try:
|
||||
local_uri = stage_to_disk(content, filename)
|
||||
except (OSError, ValueError) as exc:
|
||||
logger.error(
|
||||
"inbox_uploads: failed to stage %s (%s) to disk: %s",
|
||||
file_id,
|
||||
filename,
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
_cache.set(pending_uri, local_uri)
|
||||
logger.info(
|
||||
"inbox_uploads: staged file_id=%s name=%s size=%d mime=%s pending_uri=%s local_uri=%s",
|
||||
file_id,
|
||||
filename,
|
||||
len(content),
|
||||
mime,
|
||||
pending_uri,
|
||||
local_uri,
|
||||
)
|
||||
|
||||
# Ack last so a write failure above leaves the row available for a
|
||||
# retry on the next poll. A failed ack is logged but doesn't roll
|
||||
# back the on-disk file — the platform's sweep will clean up
|
||||
# eventually.
|
||||
try:
|
||||
ack_resp = client.post(ack_url, headers=headers)
|
||||
if ack_resp.status_code >= 400:
|
||||
logger.warning(
|
||||
"inbox_uploads: ack %s returned %d: %s",
|
||||
ack_url,
|
||||
ack_resp.status_code,
|
||||
(ack_resp.text or "")[:200],
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("inbox_uploads: POST %s failed: %s", ack_url, exc)
|
||||
|
||||
return local_uri
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BatchFetcher — concurrent fetch across a single poll batch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BatchFetcher:
|
||||
"""Fetch + stage + ack a batch of upload-receive rows concurrently.
|
||||
|
||||
Why this exists: the inbox poll loop used to call ``fetch_and_stage``
|
||||
serially per row. With N upload rows in a batch (a user dragging
|
||||
multiple files into chat at once), the loop blocked for
|
||||
``N × per_fetch_latency`` before processing the chat message that
|
||||
referenced them — a 4-file upload at 5s each = 20s of stall
|
||||
before the agent saw the user's prompt. ``BatchFetcher`` runs the
|
||||
fetches on a small thread pool (default 4 workers) so the stall is
|
||||
bounded by ``ceil(N/W) × per_fetch_latency`` instead.
|
||||
|
||||
Connection reuse: one ``httpx.Client`` is shared across every fetch
|
||||
in the batch. httpx clients carry a connection pool, so a second
|
||||
fetch to the same platform host reuses the TCP+TLS handshake from
|
||||
the first — measurable win when fetches happen back-to-back.
|
||||
|
||||
Correctness invariant the caller MUST preserve: the inbox loop is
|
||||
expected to call ``wait_all()`` before processing the chat-message
|
||||
activity row that REFERENCES one of these uploads. Without the
|
||||
barrier, the URI cache is empty when ``rewrite_request_body`` runs
|
||||
and the agent sees the un-rewritten ``platform-pending:`` URI. The
|
||||
caller-side test ``test_poll_once_waits_for_uploads_before_messages``
|
||||
pins this end-to-end.
|
||||
|
||||
Use as a context manager so the executor + client are torn down
|
||||
even if the caller raises mid-batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
headers: dict[str, str],
|
||||
timeout_secs: float = DEFAULT_FETCH_TIMEOUT,
|
||||
max_workers: int = DEFAULT_BATCH_FETCH_WORKERS,
|
||||
client: Any = None,
|
||||
):
|
||||
self._platform_url = platform_url
|
||||
self._workspace_id = workspace_id
|
||||
self._headers = dict(headers) # copy so caller mutations don't leak in
|
||||
self._timeout_secs = timeout_secs
|
||||
|
||||
# Caller can inject a client (tests do this); production callers
|
||||
# let us build one. Track ownership so we only close ours.
|
||||
self._own_client = client is None
|
||||
if self._own_client:
|
||||
try:
|
||||
import httpx # noqa: WPS433
|
||||
except ImportError:
|
||||
# Match fetch_and_stage's behavior: log + degrade rather
|
||||
# than raising at construction time. submit() will then
|
||||
# return None for every row.
|
||||
logger.error("inbox_uploads: httpx not installed; BatchFetcher inert")
|
||||
self._client: Any = None
|
||||
else:
|
||||
self._client = httpx.Client(timeout=timeout_secs)
|
||||
else:
|
||||
self._client = client
|
||||
|
||||
self._executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=max_workers,
|
||||
thread_name_prefix="upload-fetch",
|
||||
)
|
||||
self._futures: list[concurrent.futures.Future[Any]] = []
|
||||
self._closed = False
|
||||
# Flipped to True by wait_all when the timeout fires; close()
|
||||
# reads this to decide between drain-and-wait vs cancel-queued.
|
||||
self._timed_out = False
|
||||
|
||||
def submit(self, row: dict[str, Any]) -> concurrent.futures.Future[Any] | None:
|
||||
"""Submit ``row`` for fetch + stage + ack. Non-blocking — the
|
||||
worker thread runs ``fetch_and_stage`` with the shared client.
|
||||
|
||||
Returns the Future so a caller that wants per-row outcome can
|
||||
await it; ``None`` if the BatchFetcher is in a degraded state
|
||||
(httpx missing).
|
||||
"""
|
||||
if self._closed:
|
||||
raise RuntimeError("BatchFetcher: submit after close")
|
||||
if self._client is None:
|
||||
return None
|
||||
fut = self._executor.submit(
|
||||
fetch_and_stage,
|
||||
row,
|
||||
platform_url=self._platform_url,
|
||||
workspace_id=self._workspace_id,
|
||||
headers=self._headers,
|
||||
timeout_secs=self._timeout_secs,
|
||||
client=self._client,
|
||||
)
|
||||
self._futures.append(fut)
|
||||
return fut
|
||||
|
||||
def wait_all(self, timeout: float | None = DEFAULT_BATCH_WAIT_TIMEOUT) -> None:
|
||||
"""Block until every submitted future completes (or times out).
|
||||
|
||||
Per-future exceptions are logged + swallowed — ``fetch_and_stage``
|
||||
already converts every error path to ``return None``, so a real
|
||||
exception propagating up to here is unexpected and we don't want
|
||||
one bad fetch to abort the whole batch.
|
||||
|
||||
Timeouts are also logged + swallowed AND record the timed-out
|
||||
futures on ``self._timed_out`` so ``close`` can cancel them
|
||||
without paying their full latency. Without this hand-off,
|
||||
``close()``'s ``shutdown(wait=True)`` would block on the leaked
|
||||
workers and undo the user-facing timeout — the inbox poll loop
|
||||
would stall indefinitely on a hung /content fetch.
|
||||
"""
|
||||
if not self._futures:
|
||||
return
|
||||
try:
|
||||
done, not_done = concurrent.futures.wait(
|
||||
self._futures,
|
||||
timeout=timeout,
|
||||
return_when=concurrent.futures.ALL_COMPLETED,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001 — concurrent.futures shouldn't raise here
|
||||
logger.warning("inbox_uploads: BatchFetcher.wait_all crashed: %s", exc)
|
||||
return
|
||||
for fut in done:
|
||||
exc = fut.exception()
|
||||
if exc is not None:
|
||||
logger.warning(
|
||||
"inbox_uploads: BatchFetcher worker raised: %s", exc
|
||||
)
|
||||
if not_done:
|
||||
logger.warning(
|
||||
"inbox_uploads: BatchFetcher.wait_all left %d in-flight after %ss timeout",
|
||||
len(not_done),
|
||||
timeout,
|
||||
)
|
||||
# Mark these futures so close() knows to cancel-not-wait. We
|
||||
# cancel queued-but-not-started ones immediately; futures
|
||||
# already running can't be cancelled (Python's threading
|
||||
# model), but close() will pass cancel_futures=True so any
|
||||
# remaining queued items don't run.
|
||||
for fut in not_done:
|
||||
fut.cancel()
|
||||
self._timed_out = True
|
||||
|
||||
def close(self) -> None:
|
||||
"""Tear down the executor + (if owned) the httpx client.
|
||||
|
||||
Idempotent. After close, ``submit`` raises and the BatchFetcher
|
||||
cannot be reused — construct a fresh one for the next poll.
|
||||
|
||||
If ``wait_all`` reported a timeout, shutdown skips the
|
||||
``wait=True`` drain and instead asks the executor to drop queued
|
||||
futures (``cancel_futures=True``). Currently-running workers
|
||||
can't be interrupted by Python's threading model, but the poll
|
||||
loop returns immediately rather than blocking on a hung fetch.
|
||||
"""
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
timed_out = getattr(self, "_timed_out", False)
|
||||
try:
|
||||
if timed_out:
|
||||
# cancel_futures landed in Python 3.9 — guarded for older
|
||||
# interpreters via a TypeError fallback. Drop queued
|
||||
# tasks; running ones will exit when their httpx call
|
||||
# eventually returns or the daemon thread dies.
|
||||
try:
|
||||
self._executor.shutdown(wait=False, cancel_futures=True)
|
||||
except TypeError:
|
||||
self._executor.shutdown(wait=False)
|
||||
else:
|
||||
# Healthy path: wait for in-flight work so we don't
|
||||
# interrupt a fetch mid-write.
|
||||
self._executor.shutdown(wait=True)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("inbox_uploads: executor shutdown error: %s", exc)
|
||||
if self._own_client and self._client is not None:
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("inbox_uploads: client close error: %s", exc)
|
||||
|
||||
def __enter__(self) -> "BatchFetcher":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URI rewrite for incoming chat messages
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# The chat message that references a staged upload arrives as a
|
||||
# SEPARATE activity_log row, with parts of kind=file containing
|
||||
# platform-pending: URIs in the file.uri field. Walk the structure
|
||||
# in-place and rewrite to the local workspace: URI when the cache has it.
|
||||
# Unknown URIs pass through unchanged — the agent gets to choose how
|
||||
# to react (most runtimes log + ignore an unresolvable URI).
|
||||
|
||||
|
||||
def _rewrite_part(part: Any) -> None:
|
||||
"""Mutate a single A2A Part dict to swap platform-pending: URIs."""
|
||||
if not isinstance(part, dict):
|
||||
return
|
||||
file_obj = part.get("file")
|
||||
if not isinstance(file_obj, dict):
|
||||
return
|
||||
uri = file_obj.get("uri")
|
||||
if not isinstance(uri, str) or not uri.startswith("platform-pending:"):
|
||||
return
|
||||
rewritten = _cache.get(uri)
|
||||
if rewritten:
|
||||
file_obj["uri"] = rewritten
|
||||
|
||||
|
||||
def rewrite_request_body(body: Any) -> None:
|
||||
"""Mutate ``body`` in-place, replacing platform-pending: URIs with
|
||||
the cached local equivalents.
|
||||
|
||||
Walks the same shapes ``inbox._extract_text`` accepts:
|
||||
|
||||
- ``body['parts']``
|
||||
- ``body['params']['parts']``
|
||||
- ``body['params']['message']['parts']``
|
||||
|
||||
No-op for shapes that don't match — the message simply passes
|
||||
through to the agent as-is.
|
||||
"""
|
||||
if not isinstance(body, dict):
|
||||
return
|
||||
candidates: list[Any] = []
|
||||
params = body.get("params") if isinstance(body.get("params"), dict) else None
|
||||
if params:
|
||||
message = params.get("message") if isinstance(params.get("message"), dict) else None
|
||||
if message:
|
||||
candidates.append(message.get("parts"))
|
||||
candidates.append(params.get("parts"))
|
||||
candidates.append(body.get("parts"))
|
||||
|
||||
for parts in candidates:
|
||||
if isinstance(parts, list):
|
||||
for part in parts:
|
||||
_rewrite_part(part)
|
||||
+33
-466
@@ -31,422 +31,53 @@ dependency via ``a2a-sdk``.
|
||||
In-container usage (``python -m molecule_runtime.a2a_mcp_server`` or
|
||||
direct import) bypasses this wrapper — the workspace runtime has its
|
||||
own heartbeat loop in ``heartbeat.py`` so we don't double-heartbeat.
|
||||
|
||||
Module layout (RFC #2873 iter 3 split):
|
||||
* ``mcp_heartbeat`` — register POST + heartbeat loop + auth-failure
|
||||
escalation + inbound-secret persistence.
|
||||
* ``mcp_workspace_resolver`` — env validation, single + multi-workspace
|
||||
resolution, operator-help printer, on-disk token-file read.
|
||||
* ``mcp_inbox_pollers`` — activate the inbox singleton + spawn one
|
||||
daemon poller per workspace.
|
||||
|
||||
This file keeps just ``main()`` plus thin re-exports of the private
|
||||
symbols so existing tests' imports (``mcp_cli._build_agent_card``,
|
||||
``mcp_cli._heartbeat_loop``, etc.) keep working without churn.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import configs_dir
|
||||
import mcp_heartbeat
|
||||
import mcp_inbox_pollers
|
||||
import mcp_workspace_resolver
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Heartbeat cadence. Must be tighter than healthsweep's stale window
|
||||
# (currently 60-90s — see registry/healthsweep.go) by a comfortable
|
||||
# margin so a single missed heartbeat doesn't flip awaiting_agent.
|
||||
# 20s gives the operator's network 3 attempts within the budget; long
|
||||
# enough that it doesn't spam, short enough to recover quickly after
|
||||
# laptop sleep.
|
||||
HEARTBEAT_INTERVAL_SECONDS = 20.0
|
||||
# Re-export public surface for back-compat with the pre-split callers
|
||||
# and tests. The underscore-prefixed names mirror the names that
|
||||
# existed in this module before the split — keeping them ensures
|
||||
# `mcp_cli._build_agent_card`, `mcp_cli._heartbeat_loop`, etc.
|
||||
# resolve identically to the new functions.
|
||||
HEARTBEAT_INTERVAL_SECONDS = mcp_heartbeat.HEARTBEAT_INTERVAL_SECONDS
|
||||
_HEARTBEAT_AUTH_LOUD_THRESHOLD = mcp_heartbeat.HEARTBEAT_AUTH_LOUD_THRESHOLD
|
||||
_HEARTBEAT_AUTH_RELOG_INTERVAL = mcp_heartbeat.HEARTBEAT_AUTH_RELOG_INTERVAL
|
||||
|
||||
# After this many consecutive 401/403 heartbeats, escalate from
|
||||
# WARNING to ERROR with re-onboard guidance. 3 ticks at 20s = ~1 minute
|
||||
# of sustained auth failure — enough to rule out a transient platform
|
||||
# blip but quick enough that an operator doesn't sit puzzled for 10
|
||||
# minutes wondering why their MCP tools 401. Same threshold used for
|
||||
# repeat-logging at 20-tick (~7 min) intervals so a long-running
|
||||
# session that missed the first ERROR still sees the message.
|
||||
_HEARTBEAT_AUTH_LOUD_THRESHOLD = 3
|
||||
_HEARTBEAT_AUTH_RELOG_INTERVAL = 20
|
||||
_build_agent_card = mcp_heartbeat.build_agent_card
|
||||
_platform_register = mcp_heartbeat.platform_register
|
||||
_heartbeat_loop = mcp_heartbeat.heartbeat_loop
|
||||
_log_heartbeat_auth_failure = mcp_heartbeat.log_heartbeat_auth_failure
|
||||
_persist_inbound_secret_from_heartbeat = mcp_heartbeat.persist_inbound_secret_from_heartbeat
|
||||
_start_heartbeat_thread = mcp_heartbeat.start_heartbeat_thread
|
||||
|
||||
_resolve_workspaces = mcp_workspace_resolver.resolve_workspaces
|
||||
_print_missing_env_help = mcp_workspace_resolver.print_missing_env_help
|
||||
_read_token_file = mcp_workspace_resolver.read_token_file
|
||||
|
||||
def _build_agent_card(workspace_id: str) -> dict:
|
||||
"""Build the ``agent_card`` payload sent to /registry/register.
|
||||
|
||||
Three optional env vars override the defaults so an operator can
|
||||
surface human-readable identity + capabilities to peers and the
|
||||
canvas Skills tab without code changes:
|
||||
|
||||
* ``MOLECULE_AGENT_NAME`` — display name (defaults to
|
||||
``molecule-mcp-{id[:8]}``). Surfaced in canvas workspace cards
|
||||
and ``list_peers`` output.
|
||||
* ``MOLECULE_AGENT_DESCRIPTION`` — one-liner about the agent's
|
||||
purpose. Rendered in canvas Details + Skills tabs.
|
||||
* ``MOLECULE_AGENT_SKILLS`` — comma-separated skill names
|
||||
(e.g. ``research,code-review,memory-curation``). Each name is
|
||||
expanded to a ``{"name": ...}`` skill object — the minimum
|
||||
shape that satisfies both ``shared_runtime.summarize_peers``
|
||||
(uses ``s["name"]``) and the canvas SkillsTab.tsx schema
|
||||
(id falls back to name when omitted). Empty / whitespace
|
||||
entries are dropped.
|
||||
|
||||
Defaults match the previous hardcoded behaviour exactly so this
|
||||
is a strict superset — an operator who sets none of the env vars
|
||||
sees no change.
|
||||
"""
|
||||
name = (os.environ.get("MOLECULE_AGENT_NAME") or "").strip()
|
||||
if not name:
|
||||
name = f"molecule-mcp-{workspace_id[:8]}"
|
||||
|
||||
description = (os.environ.get("MOLECULE_AGENT_DESCRIPTION") or "").strip()
|
||||
|
||||
skills_raw = (os.environ.get("MOLECULE_AGENT_SKILLS") or "").strip()
|
||||
skills: list[dict] = []
|
||||
if skills_raw:
|
||||
for s in skills_raw.split(","):
|
||||
label = s.strip()
|
||||
if label:
|
||||
skills.append({"name": label})
|
||||
|
||||
card: dict = {"name": name, "skills": skills}
|
||||
if description:
|
||||
card["description"] = description
|
||||
return card
|
||||
|
||||
|
||||
def _platform_register(platform_url: str, workspace_id: str, token: str) -> None:
|
||||
"""One-shot register at startup; fails fast on auth errors.
|
||||
|
||||
Lifts the workspace from ``awaiting_agent`` to ``online`` for
|
||||
operators who never ran the curl-register snippet. Safe to call
|
||||
repeatedly: the platform's register handler is an upsert that
|
||||
just refreshes ``url``, ``agent_card``, and ``status``.
|
||||
|
||||
Failure model (post-review):
|
||||
- 401 / 403 → ``sys.exit(3)`` immediately. The operator's
|
||||
token is wrong; silently looping in a broken state would
|
||||
make this hard to diagnose because the MCP tools would 401
|
||||
on every call too. Hard-fail is the kindest option.
|
||||
- Other 4xx/5xx → log a warning + continue. The heartbeat
|
||||
thread will surface persistent failures; transient platform
|
||||
blips shouldn't abort the MCP loop.
|
||||
- Network / transport errors → log + continue. Same reasoning.
|
||||
|
||||
Origin header is required by the SaaS edge WAF; without it
|
||||
/registry/register currently still works (it's on the WAF
|
||||
allowlist), but the heartbeat path needs Origin and we want one
|
||||
consistent header set across both calls.
|
||||
"""
|
||||
try:
|
||||
import httpx
|
||||
except ImportError:
|
||||
# httpx is a transitive dep via a2a-sdk; if missing, the MCP
|
||||
# server won't import either. Let the caller's later import
|
||||
# surface the real error.
|
||||
return
|
||||
|
||||
payload = {
|
||||
"id": workspace_id,
|
||||
"url": "",
|
||||
"agent_card": _build_agent_card(workspace_id),
|
||||
"delivery_mode": "poll",
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Origin": platform_url,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
try:
|
||||
with httpx.Client(timeout=10.0) as client:
|
||||
resp = client.post(
|
||||
f"{platform_url}/registry/register",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
)
|
||||
if resp.status_code in (401, 403):
|
||||
print(
|
||||
f"molecule-mcp: register rejected with HTTP {resp.status_code} — "
|
||||
f"the token in MOLECULE_WORKSPACE_TOKEN is invalid for workspace "
|
||||
f"{workspace_id}. Regenerate from the canvas → Tokens tab.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(3)
|
||||
if resp.status_code >= 400:
|
||||
logger.warning(
|
||||
"molecule-mcp: register POST returned HTTP %d: %s",
|
||||
resp.status_code,
|
||||
(resp.text or "")[:200],
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"molecule-mcp: registered workspace %s with platform",
|
||||
workspace_id,
|
||||
)
|
||||
except SystemExit:
|
||||
raise
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("molecule-mcp: register POST failed: %s", exc)
|
||||
|
||||
|
||||
def _heartbeat_loop(
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
token: str,
|
||||
interval: float = HEARTBEAT_INTERVAL_SECONDS,
|
||||
) -> None:
|
||||
"""Daemon thread body: POST /registry/heartbeat every ``interval``s.
|
||||
|
||||
Failures are logged at WARNING and the loop continues. The thread
|
||||
exits when the main process does (daemon=True). Each iteration
|
||||
rebuilds the payload + headers — cheap and ensures token rotation
|
||||
via env var (rare but possible) is picked up on the next tick.
|
||||
"""
|
||||
try:
|
||||
import httpx
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
consecutive_auth_failures = 0
|
||||
while True:
|
||||
body = {
|
||||
"workspace_id": workspace_id,
|
||||
"error_rate": 0.0,
|
||||
"sample_error": "",
|
||||
"active_tasks": 0,
|
||||
"uptime_seconds": int(time.time() - start_time),
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Origin": platform_url,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
try:
|
||||
with httpx.Client(timeout=10.0) as client:
|
||||
resp = client.post(
|
||||
f"{platform_url}/registry/heartbeat",
|
||||
json=body,
|
||||
headers=headers,
|
||||
)
|
||||
if resp.status_code in (401, 403):
|
||||
consecutive_auth_failures += 1
|
||||
_log_heartbeat_auth_failure(
|
||||
consecutive_auth_failures, workspace_id, resp.status_code,
|
||||
)
|
||||
elif resp.status_code >= 400:
|
||||
# Non-auth HTTP error — log, but DO NOT touch the
|
||||
# auth-failure counter (5xx blips, 429, etc. are
|
||||
# transient and unrelated to token validity).
|
||||
logger.warning(
|
||||
"molecule-mcp: heartbeat HTTP %d: %s",
|
||||
resp.status_code,
|
||||
(resp.text or "")[:200],
|
||||
)
|
||||
else:
|
||||
consecutive_auth_failures = 0
|
||||
_persist_inbound_secret_from_heartbeat(resp)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("molecule-mcp: heartbeat failed: %s", exc)
|
||||
time.sleep(interval)
|
||||
|
||||
|
||||
def _log_heartbeat_auth_failure(count: int, workspace_id: str, status_code: int) -> None:
|
||||
"""Escalate consecutive heartbeat 401/403s from quiet WARNING to
|
||||
actionable ERROR.
|
||||
|
||||
The operator's first sign of trouble shouldn't be "tools 401 with no
|
||||
explanation" — that was the failure mode that motivated this code,
|
||||
triggered by a workspace being deleted server-side and its tokens
|
||||
revoked while the runtime kept heartbeating in silence.
|
||||
|
||||
Cadence:
|
||||
* count < threshold: WARNING per tick (transient — could be a
|
||||
platform blip, don't shout yet)
|
||||
* count == threshold: ERROR with re-onboard instructions
|
||||
(the first signal the operator can't miss)
|
||||
* count > threshold and (count - threshold) % relog == 0: re-log
|
||||
ERROR (so a session that started after the first ERROR still
|
||||
sees the message scrolling past in their logs)
|
||||
"""
|
||||
if count < _HEARTBEAT_AUTH_LOUD_THRESHOLD:
|
||||
logger.warning(
|
||||
"molecule-mcp: heartbeat HTTP %d (auth failure %d/%d) — "
|
||||
"token may be revoked. Will retry; if persistent, regenerate "
|
||||
"from canvas → Tokens.",
|
||||
status_code, count, _HEARTBEAT_AUTH_LOUD_THRESHOLD,
|
||||
)
|
||||
return
|
||||
# At or past the threshold — this is the loud actionable error.
|
||||
if count == _HEARTBEAT_AUTH_LOUD_THRESHOLD or (
|
||||
count - _HEARTBEAT_AUTH_LOUD_THRESHOLD
|
||||
) % _HEARTBEAT_AUTH_RELOG_INTERVAL == 0:
|
||||
logger.error(
|
||||
"molecule-mcp: %d consecutive heartbeat auth failures (HTTP %d) — "
|
||||
"the token in MOLECULE_WORKSPACE_TOKEN has been REVOKED, likely "
|
||||
"because workspace %s was deleted server-side. The MCP server is "
|
||||
"still running but every platform call will fail. Regenerate the "
|
||||
"workspace + token from the canvas (Tokens tab), update your MCP "
|
||||
"config, and restart your runtime.",
|
||||
count, status_code, workspace_id,
|
||||
)
|
||||
|
||||
|
||||
def _persist_inbound_secret_from_heartbeat(resp: object) -> None:
|
||||
"""Persist ``platform_inbound_secret`` from a heartbeat response, if any.
|
||||
|
||||
The platform's heartbeat handler returns the secret on every beat
|
||||
(mirroring /registry/register) so a workspace that lazy-healed the
|
||||
secret on the platform side — typical recovery path for a workspace
|
||||
whose row had a NULL ``platform_inbound_secret`` after a partial
|
||||
bootstrap — picks it up within one heartbeat tick instead of
|
||||
requiring a runtime restart.
|
||||
|
||||
Without this delivery path the chat-upload code path's "secret was
|
||||
just minted, will pick up on next heartbeat" 503 message is a lie
|
||||
and the workspace stays 401-forever until the operator restarts
|
||||
the runtime. Caught 2026-04-30 on hongmingwang tenant.
|
||||
|
||||
Failure is non-fatal: if the body isn't JSON, doesn't carry the
|
||||
field, or the disk write fails, the next heartbeat retries. This
|
||||
matches the cold-start register flow in main.py:319-323.
|
||||
"""
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception: # noqa: BLE001
|
||||
return
|
||||
if not isinstance(body, dict):
|
||||
return
|
||||
secret = body.get("platform_inbound_secret")
|
||||
if not secret:
|
||||
return
|
||||
try:
|
||||
from platform_inbound_auth import save_inbound_secret
|
||||
|
||||
save_inbound_secret(secret)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"molecule-mcp: persist inbound secret from heartbeat failed: %s", exc
|
||||
)
|
||||
|
||||
|
||||
def _start_heartbeat_thread(
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
token: str,
|
||||
) -> threading.Thread:
|
||||
"""Start the heartbeat daemon thread. Returns the Thread handle.
|
||||
|
||||
The MCP stdio loop runs in the foreground (asyncio); this thread
|
||||
runs alongside it. ``daemon=True`` so when the operator hits
|
||||
Ctrl-C / closes the runtime, the heartbeat dies with it instead
|
||||
of leaking and writing to a stale workspace.
|
||||
"""
|
||||
t = threading.Thread(
|
||||
target=_heartbeat_loop,
|
||||
args=(platform_url, workspace_id, token),
|
||||
name="molecule-mcp-heartbeat",
|
||||
daemon=True,
|
||||
)
|
||||
t.start()
|
||||
return t
|
||||
|
||||
|
||||
def _resolve_workspaces() -> tuple[list[tuple[str, str]], list[str]]:
|
||||
"""Return the list of ``(workspace_id, token)`` pairs to register.
|
||||
|
||||
Resolution order:
|
||||
|
||||
1. ``MOLECULE_WORKSPACES`` env var — JSON array of
|
||||
``{"id": "...", "token": "..."}`` objects. Activates the
|
||||
multi-workspace external-agent path (one process registered into
|
||||
N workspaces). When set, ``WORKSPACE_ID`` / ``MOLECULE_WORKSPACE_TOKEN``
|
||||
are IGNORED — the JSON is the source of truth.
|
||||
|
||||
2. Single-workspace fallback — ``WORKSPACE_ID`` env var + token from
|
||||
``MOLECULE_WORKSPACE_TOKEN`` or ``${CONFIGS_DIR}/.auth_token``.
|
||||
This is the pre-existing path; back-compat exact.
|
||||
|
||||
Returns ``(workspaces, errors)``:
|
||||
* ``workspaces``: list of ``(workspace_id, token)`` — non-empty
|
||||
on the happy path.
|
||||
* ``errors``: human-readable strings describing what's missing /
|
||||
malformed. ``main()`` surfaces these with the same shape as
|
||||
``_print_missing_env_help`` so the operator's first run gives
|
||||
actionable output.
|
||||
|
||||
Why JSON env (not file): ergonomic for Claude Code MCP config (one
|
||||
string in ``mcpServers.molecule.env`` instead of a sidecar file)
|
||||
and for CI / launchers. A separate config-file path can be added
|
||||
later without breaking this.
|
||||
"""
|
||||
raw = os.environ.get("MOLECULE_WORKSPACES", "").strip()
|
||||
if raw:
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError as exc:
|
||||
return [], [
|
||||
f"MOLECULE_WORKSPACES is not valid JSON ({exc.msg} at pos "
|
||||
f"{exc.pos}). Expected: '[{{\"id\":\"<wsid>\",\"token\":"
|
||||
f"\"<tok>\"}},{{...}}]'"
|
||||
]
|
||||
if not isinstance(parsed, list) or not parsed:
|
||||
return [], [
|
||||
"MOLECULE_WORKSPACES must be a non-empty JSON array of "
|
||||
"{\"id\":\"...\",\"token\":\"...\"} objects"
|
||||
]
|
||||
out: list[tuple[str, str]] = []
|
||||
seen: set[str] = set()
|
||||
errors: list[str] = []
|
||||
for i, entry in enumerate(parsed):
|
||||
if not isinstance(entry, dict):
|
||||
errors.append(
|
||||
f"MOLECULE_WORKSPACES[{i}] is not an object — got {type(entry).__name__}"
|
||||
)
|
||||
continue
|
||||
wsid = str(entry.get("id", "")).strip()
|
||||
tok = str(entry.get("token", "")).strip()
|
||||
if not wsid or not tok:
|
||||
errors.append(
|
||||
f"MOLECULE_WORKSPACES[{i}] missing 'id' or 'token'"
|
||||
)
|
||||
continue
|
||||
if wsid in seen:
|
||||
errors.append(
|
||||
f"MOLECULE_WORKSPACES[{i}] duplicate workspace id {wsid!r}"
|
||||
)
|
||||
continue
|
||||
seen.add(wsid)
|
||||
out.append((wsid, tok))
|
||||
if errors:
|
||||
return [], errors
|
||||
return out, []
|
||||
|
||||
# Single-workspace back-compat path.
|
||||
wsid = os.environ.get("WORKSPACE_ID", "").strip()
|
||||
if not wsid:
|
||||
return [], ["WORKSPACE_ID (or MOLECULE_WORKSPACES) is required"]
|
||||
tok = os.environ.get("MOLECULE_WORKSPACE_TOKEN", "").strip()
|
||||
if not tok:
|
||||
tok = _read_token_file()
|
||||
if not tok:
|
||||
return [], [
|
||||
"MOLECULE_WORKSPACE_TOKEN (or CONFIGS_DIR/.auth_token) is required"
|
||||
]
|
||||
return [(wsid, tok)], []
|
||||
|
||||
|
||||
def _print_missing_env_help(missing: list[str], have_token_file: bool) -> None:
|
||||
print("molecule-mcp: missing required environment.\n", file=sys.stderr)
|
||||
print("Set the following before running molecule-mcp:", file=sys.stderr)
|
||||
print(" WORKSPACE_ID — your workspace UUID (from canvas)", file=sys.stderr)
|
||||
print(
|
||||
" PLATFORM_URL — base URL of your Molecule platform "
|
||||
"(e.g. https://your-tenant.staging.moleculesai.app)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
if not have_token_file:
|
||||
print(
|
||||
" MOLECULE_WORKSPACE_TOKEN — bearer token for this workspace "
|
||||
"(canvas → Tokens tab)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print("", file=sys.stderr)
|
||||
print(f"Currently missing: {', '.join(missing)}", file=sys.stderr)
|
||||
_start_inbox_pollers = mcp_inbox_pollers.start_inbox_pollers
|
||||
|
||||
|
||||
def main() -> None:
|
||||
@@ -558,69 +189,5 @@ def main() -> None:
|
||||
cli_main()
|
||||
|
||||
|
||||
def _start_inbox_pollers(platform_url: str, workspace_ids: list[str]) -> None:
|
||||
"""Activate the inbox singleton + spawn one poller daemon thread per workspace.
|
||||
|
||||
Done lazily here (not at module import) because importing inbox
|
||||
pulls in platform_auth, which only resolves cleanly AFTER env
|
||||
validation succeeds. Activation is idempotent within a process,
|
||||
so a stray double-call (e.g. test harness re-entering main) is
|
||||
harmless.
|
||||
|
||||
The poller threads are daemon=True — die with the main process.
|
||||
|
||||
Single-workspace path: one poller, single cursor file at the legacy
|
||||
location (``.mcp_inbox_cursor``). Cursor-key resolution falls back
|
||||
to the empty string for back-compat with operators whose existing
|
||||
on-disk cursor was written by the pre-multi-workspace code.
|
||||
|
||||
Multi-workspace path: N pollers, each with its own cursor file
|
||||
keyed by ``workspace_id[:8]``. Cursors live next to each other in
|
||||
configs_dir so an operator inspecting state sees all of them
|
||||
together.
|
||||
"""
|
||||
try:
|
||||
import inbox
|
||||
except ImportError as exc:
|
||||
logger.warning("molecule-mcp: inbox module unavailable: %s", exc)
|
||||
return
|
||||
|
||||
if len(workspace_ids) <= 1:
|
||||
# Back-compat exact: single-workspace mode reuses the legacy
|
||||
# cursor filename + cursor_path constructor arg, so an existing
|
||||
# operator's on-disk state isn't invalidated by upgrade.
|
||||
wsid = workspace_ids[0]
|
||||
state = inbox.InboxState(cursor_path=inbox.default_cursor_path())
|
||||
inbox.activate(state)
|
||||
inbox.start_poller_thread(state, platform_url, wsid)
|
||||
return
|
||||
|
||||
# Multi-workspace: per-workspace cursor file, one shared queue.
|
||||
cursor_paths = {wsid: inbox.default_cursor_path(wsid) for wsid in workspace_ids}
|
||||
state = inbox.InboxState(cursor_paths=cursor_paths)
|
||||
inbox.activate(state)
|
||||
for wsid in workspace_ids:
|
||||
inbox.start_poller_thread(state, platform_url, wsid)
|
||||
|
||||
|
||||
def _read_token_file() -> str:
|
||||
"""Read the token from the resolved configs dir's ``.auth_token`` if
|
||||
present.
|
||||
|
||||
Mirrors platform_auth._token_file's location resolution but without
|
||||
importing the heavy module here (that import triggers a2a_client's
|
||||
WORKSPACE_ID guard which is fine after env validation, but cheaper
|
||||
to inline a 4-line file read than pull in the whole stack just for
|
||||
the path).
|
||||
"""
|
||||
path = configs_dir.resolve() / ".auth_token"
|
||||
if not path.is_file():
|
||||
return ""
|
||||
try:
|
||||
return path.read_text().strip()
|
||||
except OSError:
|
||||
return ""
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,325 @@
|
||||
"""Heartbeat + register thread for the standalone ``molecule-mcp`` wrapper.
|
||||
|
||||
Extracted from ``mcp_cli.py`` (RFC #2873 iter 3) so the heartbeat /
|
||||
register concern lives in its own module. The console-script entry
|
||||
``mcp_cli:main`` still drives the spawn, but the loop body, auth-failure
|
||||
escalation, and inbound-secret persistence now live here so they can be
|
||||
read, tested, and replaced independently of the orchestrator.
|
||||
|
||||
Public surface:
|
||||
|
||||
* ``HEARTBEAT_INTERVAL_SECONDS`` — cadence constant.
|
||||
* ``build_agent_card(workspace_id)`` — payload helper.
|
||||
* ``platform_register(platform_url, workspace_id, token)`` — one-shot
|
||||
POST /registry/register at startup.
|
||||
* ``start_heartbeat_thread(platform_url, workspace_id, token)`` — spawn
|
||||
the daemon thread.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Heartbeat cadence. Must be tighter than healthsweep's stale window
|
||||
# (currently 60-90s — see registry/healthsweep.go) by a comfortable
|
||||
# margin so a single missed heartbeat doesn't flip awaiting_agent.
|
||||
# 20s gives the operator's network 3 attempts within the budget; long
|
||||
# enough that it doesn't spam, short enough to recover quickly after
|
||||
# laptop sleep.
|
||||
HEARTBEAT_INTERVAL_SECONDS = 20.0
|
||||
|
||||
# After this many consecutive 401/403 heartbeats, escalate from
|
||||
# WARNING to ERROR with re-onboard guidance. 3 ticks at 20s = ~1 minute
|
||||
# of sustained auth failure — enough to rule out a transient platform
|
||||
# blip but quick enough that an operator doesn't sit puzzled for 10
|
||||
# minutes wondering why their MCP tools 401. Same threshold used for
|
||||
# repeat-logging at 20-tick (~7 min) intervals so a long-running
|
||||
# session that missed the first ERROR still sees the message.
|
||||
HEARTBEAT_AUTH_LOUD_THRESHOLD = 3
|
||||
HEARTBEAT_AUTH_RELOG_INTERVAL = 20
|
||||
|
||||
|
||||
def build_agent_card(workspace_id: str) -> dict:
|
||||
"""Build the ``agent_card`` payload sent to /registry/register.
|
||||
|
||||
Three optional env vars override the defaults so an operator can
|
||||
surface human-readable identity + capabilities to peers and the
|
||||
canvas Skills tab without code changes:
|
||||
|
||||
* ``MOLECULE_AGENT_NAME`` — display name (defaults to
|
||||
``molecule-mcp-{id[:8]}``). Surfaced in canvas workspace cards
|
||||
and ``list_peers`` output.
|
||||
* ``MOLECULE_AGENT_DESCRIPTION`` — one-liner about the agent's
|
||||
purpose. Rendered in canvas Details + Skills tabs.
|
||||
* ``MOLECULE_AGENT_SKILLS`` — comma-separated skill names
|
||||
(e.g. ``research,code-review,memory-curation``). Each name is
|
||||
expanded to a ``{"name": ...}`` skill object — the minimum
|
||||
shape that satisfies both ``shared_runtime.summarize_peers``
|
||||
(uses ``s["name"]``) and the canvas SkillsTab.tsx schema
|
||||
(id falls back to name when omitted). Empty / whitespace
|
||||
entries are dropped.
|
||||
|
||||
Defaults match the previous hardcoded behaviour exactly so this
|
||||
is a strict superset — an operator who sets none of the env vars
|
||||
sees no change.
|
||||
"""
|
||||
name = (os.environ.get("MOLECULE_AGENT_NAME") or "").strip()
|
||||
if not name:
|
||||
name = f"molecule-mcp-{workspace_id[:8]}"
|
||||
|
||||
description = (os.environ.get("MOLECULE_AGENT_DESCRIPTION") or "").strip()
|
||||
|
||||
skills_raw = (os.environ.get("MOLECULE_AGENT_SKILLS") or "").strip()
|
||||
skills: list[dict] = []
|
||||
if skills_raw:
|
||||
for s in skills_raw.split(","):
|
||||
label = s.strip()
|
||||
if label:
|
||||
skills.append({"name": label})
|
||||
|
||||
card: dict = {"name": name, "skills": skills}
|
||||
if description:
|
||||
card["description"] = description
|
||||
return card
|
||||
|
||||
|
||||
def platform_register(platform_url: str, workspace_id: str, token: str) -> None:
|
||||
"""One-shot register at startup; fails fast on auth errors.
|
||||
|
||||
Lifts the workspace from ``awaiting_agent`` to ``online`` for
|
||||
operators who never ran the curl-register snippet. Safe to call
|
||||
repeatedly: the platform's register handler is an upsert that
|
||||
just refreshes ``url``, ``agent_card``, and ``status``.
|
||||
|
||||
Failure model (post-review):
|
||||
- 401 / 403 → ``sys.exit(3)`` immediately. The operator's
|
||||
token is wrong; silently looping in a broken state would
|
||||
make this hard to diagnose because the MCP tools would 401
|
||||
on every call too. Hard-fail is the kindest option.
|
||||
- Other 4xx/5xx → log a warning + continue. The heartbeat
|
||||
thread will surface persistent failures; transient platform
|
||||
blips shouldn't abort the MCP loop.
|
||||
- Network / transport errors → log + continue. Same reasoning.
|
||||
|
||||
Origin header is required by the SaaS edge WAF; without it
|
||||
/registry/register currently still works (it's on the WAF
|
||||
allowlist), but the heartbeat path needs Origin and we want one
|
||||
consistent header set across both calls.
|
||||
"""
|
||||
try:
|
||||
import httpx
|
||||
except ImportError:
|
||||
# httpx is a transitive dep via a2a-sdk; if missing, the MCP
|
||||
# server won't import either. Let the caller's later import
|
||||
# surface the real error.
|
||||
return
|
||||
|
||||
payload = {
|
||||
"id": workspace_id,
|
||||
"url": "",
|
||||
"agent_card": build_agent_card(workspace_id),
|
||||
"delivery_mode": "poll",
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Origin": platform_url,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
try:
|
||||
with httpx.Client(timeout=10.0) as client:
|
||||
resp = client.post(
|
||||
f"{platform_url}/registry/register",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
)
|
||||
if resp.status_code in (401, 403):
|
||||
print(
|
||||
f"molecule-mcp: register rejected with HTTP {resp.status_code} — "
|
||||
f"the token in MOLECULE_WORKSPACE_TOKEN is invalid for workspace "
|
||||
f"{workspace_id}. Regenerate from the canvas → Tokens tab.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(3)
|
||||
if resp.status_code >= 400:
|
||||
logger.warning(
|
||||
"molecule-mcp: register POST returned HTTP %d: %s",
|
||||
resp.status_code,
|
||||
(resp.text or "")[:200],
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"molecule-mcp: registered workspace %s with platform",
|
||||
workspace_id,
|
||||
)
|
||||
except SystemExit:
|
||||
raise
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("molecule-mcp: register POST failed: %s", exc)
|
||||
|
||||
|
||||
def heartbeat_loop(
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
token: str,
|
||||
interval: float = HEARTBEAT_INTERVAL_SECONDS,
|
||||
) -> None:
|
||||
"""Daemon thread body: POST /registry/heartbeat every ``interval``s.
|
||||
|
||||
Failures are logged at WARNING and the loop continues. The thread
|
||||
exits when the main process does (daemon=True). Each iteration
|
||||
rebuilds the payload + headers — cheap and ensures token rotation
|
||||
via env var (rare but possible) is picked up on the next tick.
|
||||
"""
|
||||
try:
|
||||
import httpx
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
consecutive_auth_failures = 0
|
||||
while True:
|
||||
body = {
|
||||
"workspace_id": workspace_id,
|
||||
"error_rate": 0.0,
|
||||
"sample_error": "",
|
||||
"active_tasks": 0,
|
||||
"uptime_seconds": int(time.time() - start_time),
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Origin": platform_url,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
try:
|
||||
with httpx.Client(timeout=10.0) as client:
|
||||
resp = client.post(
|
||||
f"{platform_url}/registry/heartbeat",
|
||||
json=body,
|
||||
headers=headers,
|
||||
)
|
||||
if resp.status_code in (401, 403):
|
||||
consecutive_auth_failures += 1
|
||||
log_heartbeat_auth_failure(
|
||||
consecutive_auth_failures, workspace_id, resp.status_code,
|
||||
)
|
||||
elif resp.status_code >= 400:
|
||||
# Non-auth HTTP error — log, but DO NOT touch the
|
||||
# auth-failure counter (5xx blips, 429, etc. are
|
||||
# transient and unrelated to token validity).
|
||||
logger.warning(
|
||||
"molecule-mcp: heartbeat HTTP %d: %s",
|
||||
resp.status_code,
|
||||
(resp.text or "")[:200],
|
||||
)
|
||||
else:
|
||||
consecutive_auth_failures = 0
|
||||
persist_inbound_secret_from_heartbeat(resp)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("molecule-mcp: heartbeat failed: %s", exc)
|
||||
time.sleep(interval)
|
||||
|
||||
|
||||
def log_heartbeat_auth_failure(count: int, workspace_id: str, status_code: int) -> None:
|
||||
"""Escalate consecutive heartbeat 401/403s from quiet WARNING to
|
||||
actionable ERROR.
|
||||
|
||||
The operator's first sign of trouble shouldn't be "tools 401 with no
|
||||
explanation" — that was the failure mode that motivated this code,
|
||||
triggered by a workspace being deleted server-side and its tokens
|
||||
revoked while the runtime kept heartbeating in silence.
|
||||
|
||||
Cadence:
|
||||
* count < threshold: WARNING per tick (transient — could be a
|
||||
platform blip, don't shout yet)
|
||||
* count == threshold: ERROR with re-onboard instructions
|
||||
(the first signal the operator can't miss)
|
||||
* count > threshold and (count - threshold) % relog == 0: re-log
|
||||
ERROR (so a session that started after the first ERROR still
|
||||
sees the message scrolling past in their logs)
|
||||
"""
|
||||
if count < HEARTBEAT_AUTH_LOUD_THRESHOLD:
|
||||
logger.warning(
|
||||
"molecule-mcp: heartbeat HTTP %d (auth failure %d/%d) — "
|
||||
"token may be revoked. Will retry; if persistent, regenerate "
|
||||
"from canvas → Tokens.",
|
||||
status_code, count, HEARTBEAT_AUTH_LOUD_THRESHOLD,
|
||||
)
|
||||
return
|
||||
# At or past the threshold — this is the loud actionable error.
|
||||
if count == HEARTBEAT_AUTH_LOUD_THRESHOLD or (
|
||||
count - HEARTBEAT_AUTH_LOUD_THRESHOLD
|
||||
) % HEARTBEAT_AUTH_RELOG_INTERVAL == 0:
|
||||
logger.error(
|
||||
"molecule-mcp: %d consecutive heartbeat auth failures (HTTP %d) — "
|
||||
"the token in MOLECULE_WORKSPACE_TOKEN has been REVOKED, likely "
|
||||
"because workspace %s was deleted server-side. The MCP server is "
|
||||
"still running but every platform call will fail. Regenerate the "
|
||||
"workspace + token from the canvas (Tokens tab), update your MCP "
|
||||
"config, and restart your runtime.",
|
||||
count, status_code, workspace_id,
|
||||
)
|
||||
|
||||
|
||||
def persist_inbound_secret_from_heartbeat(resp: object) -> None:
|
||||
"""Persist ``platform_inbound_secret`` from a heartbeat response, if any.
|
||||
|
||||
The platform's heartbeat handler returns the secret on every beat
|
||||
(mirroring /registry/register) so a workspace that lazy-healed the
|
||||
secret on the platform side — typical recovery path for a workspace
|
||||
whose row had a NULL ``platform_inbound_secret`` after a partial
|
||||
bootstrap — picks it up within one heartbeat tick instead of
|
||||
requiring a runtime restart.
|
||||
|
||||
Without this delivery path the chat-upload code path's "secret was
|
||||
just minted, will pick up on next heartbeat" 503 message is a lie
|
||||
and the workspace stays 401-forever until the operator restarts
|
||||
the runtime. Caught 2026-04-30 on hongmingwang tenant.
|
||||
|
||||
Failure is non-fatal: if the body isn't JSON, doesn't carry the
|
||||
field, or the disk write fails, the next heartbeat retries. This
|
||||
matches the cold-start register flow in main.py:319-323.
|
||||
"""
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception: # noqa: BLE001
|
||||
return
|
||||
if not isinstance(body, dict):
|
||||
return
|
||||
secret = body.get("platform_inbound_secret")
|
||||
if not secret:
|
||||
return
|
||||
try:
|
||||
from platform_inbound_auth import save_inbound_secret
|
||||
|
||||
save_inbound_secret(secret)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"molecule-mcp: persist inbound secret from heartbeat failed: %s", exc
|
||||
)
|
||||
|
||||
|
||||
def start_heartbeat_thread(
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
token: str,
|
||||
) -> threading.Thread:
|
||||
"""Start the heartbeat daemon thread. Returns the Thread handle.
|
||||
|
||||
The MCP stdio loop runs in the foreground (asyncio); this thread
|
||||
runs alongside it. ``daemon=True`` so when the operator hits
|
||||
Ctrl-C / closes the runtime, the heartbeat dies with it instead
|
||||
of leaking and writing to a stale workspace.
|
||||
"""
|
||||
t = threading.Thread(
|
||||
target=heartbeat_loop,
|
||||
args=(platform_url, workspace_id, token),
|
||||
name="molecule-mcp-heartbeat",
|
||||
daemon=True,
|
||||
)
|
||||
t.start()
|
||||
return t
|
||||
@@ -0,0 +1,63 @@
|
||||
"""Inbox-poller spawn helpers for the standalone ``molecule-mcp`` wrapper.
|
||||
|
||||
Extracted from ``mcp_cli.py`` (RFC #2873 iter 3). The poller is the
|
||||
INBOUND side of the standalone path — without it, the universal MCP
|
||||
server is outbound-only (can call ``delegate_task`` /
|
||||
``send_message_to_user``, never observes canvas-user / peer-agent
|
||||
messages).
|
||||
|
||||
Public surface:
|
||||
|
||||
* ``start_inbox_pollers(platform_url, workspace_ids)`` — activate the
|
||||
inbox singleton and spawn one daemon poller per workspace.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def start_inbox_pollers(platform_url: str, workspace_ids: list[str]) -> None:
|
||||
"""Activate the inbox singleton + spawn one poller daemon thread per workspace.
|
||||
|
||||
Done lazily here (not at module import) because importing inbox
|
||||
pulls in platform_auth, which only resolves cleanly AFTER env
|
||||
validation succeeds. Activation is idempotent within a process,
|
||||
so a stray double-call (e.g. test harness re-entering main) is
|
||||
harmless.
|
||||
|
||||
The poller threads are daemon=True — die with the main process.
|
||||
|
||||
Single-workspace path: one poller, single cursor file at the legacy
|
||||
location (``.mcp_inbox_cursor``). Cursor-key resolution falls back
|
||||
to the empty string for back-compat with operators whose existing
|
||||
on-disk cursor was written by the pre-multi-workspace code.
|
||||
|
||||
Multi-workspace path: N pollers, each with its own cursor file
|
||||
keyed by ``workspace_id[:8]``. Cursors live next to each other in
|
||||
configs_dir so an operator inspecting state sees all of them
|
||||
together.
|
||||
"""
|
||||
try:
|
||||
import inbox
|
||||
except ImportError as exc:
|
||||
logger.warning("molecule-mcp: inbox module unavailable: %s", exc)
|
||||
return
|
||||
|
||||
if len(workspace_ids) <= 1:
|
||||
# Back-compat exact: single-workspace mode reuses the legacy
|
||||
# cursor filename + cursor_path constructor arg, so an existing
|
||||
# operator's on-disk state isn't invalidated by upgrade.
|
||||
wsid = workspace_ids[0]
|
||||
state = inbox.InboxState(cursor_path=inbox.default_cursor_path())
|
||||
inbox.activate(state)
|
||||
inbox.start_poller_thread(state, platform_url, wsid)
|
||||
return
|
||||
|
||||
# Multi-workspace: per-workspace cursor file, one shared queue.
|
||||
cursor_paths = {wsid: inbox.default_cursor_path(wsid) for wsid in workspace_ids}
|
||||
state = inbox.InboxState(cursor_paths=cursor_paths)
|
||||
inbox.activate(state)
|
||||
for wsid in workspace_ids:
|
||||
inbox.start_poller_thread(state, platform_url, wsid)
|
||||
@@ -0,0 +1,146 @@
|
||||
"""Env validation + workspace resolution for the standalone ``molecule-mcp``.
|
||||
|
||||
Extracted from ``mcp_cli.py`` (RFC #2873 iter 3). Deals with the two
|
||||
shapes ``molecule-mcp`` accepts:
|
||||
|
||||
* Single-workspace legacy shape: ``WORKSPACE_ID`` + token from
|
||||
``MOLECULE_WORKSPACE_TOKEN`` or ``${CONFIGS_DIR}/.auth_token``.
|
||||
* Multi-workspace JSON shape: ``MOLECULE_WORKSPACES`` env var carries a
|
||||
JSON array of ``{"id": ..., "token": ...}`` entries.
|
||||
|
||||
Public surface:
|
||||
|
||||
* ``resolve_workspaces()`` → ``(workspaces, errors)``.
|
||||
* ``read_token_file()`` → token text or ``""``.
|
||||
* ``print_missing_env_help(missing, have_token_file)`` — operator-help
|
||||
printer.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import configs_dir
|
||||
|
||||
|
||||
def resolve_workspaces() -> tuple[list[tuple[str, str]], list[str]]:
|
||||
"""Return the list of ``(workspace_id, token)`` pairs to register.
|
||||
|
||||
Resolution order:
|
||||
|
||||
1. ``MOLECULE_WORKSPACES`` env var — JSON array of
|
||||
``{"id": "...", "token": "..."}`` objects. Activates the
|
||||
multi-workspace external-agent path (one process registered into
|
||||
N workspaces). When set, ``WORKSPACE_ID`` / ``MOLECULE_WORKSPACE_TOKEN``
|
||||
are IGNORED — the JSON is the source of truth.
|
||||
|
||||
2. Single-workspace fallback — ``WORKSPACE_ID`` env var + token from
|
||||
``MOLECULE_WORKSPACE_TOKEN`` or ``${CONFIGS_DIR}/.auth_token``.
|
||||
This is the pre-existing path; back-compat exact.
|
||||
|
||||
Returns ``(workspaces, errors)``:
|
||||
* ``workspaces``: list of ``(workspace_id, token)`` — non-empty
|
||||
on the happy path.
|
||||
* ``errors``: human-readable strings describing what's missing /
|
||||
malformed. ``main()`` surfaces these with the same shape as
|
||||
``print_missing_env_help`` so the operator's first run gives
|
||||
actionable output.
|
||||
|
||||
Why JSON env (not file): ergonomic for Claude Code MCP config (one
|
||||
string in ``mcpServers.molecule.env`` instead of a sidecar file)
|
||||
and for CI / launchers. A separate config-file path can be added
|
||||
later without breaking this.
|
||||
"""
|
||||
raw = os.environ.get("MOLECULE_WORKSPACES", "").strip()
|
||||
if raw:
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError as exc:
|
||||
return [], [
|
||||
f"MOLECULE_WORKSPACES is not valid JSON ({exc.msg} at pos "
|
||||
f"{exc.pos}). Expected: '[{{\"id\":\"<wsid>\",\"token\":"
|
||||
f"\"<tok>\"}},{{...}}]'"
|
||||
]
|
||||
if not isinstance(parsed, list) or not parsed:
|
||||
return [], [
|
||||
"MOLECULE_WORKSPACES must be a non-empty JSON array of "
|
||||
"{\"id\":\"...\",\"token\":\"...\"} objects"
|
||||
]
|
||||
out: list[tuple[str, str]] = []
|
||||
seen: set[str] = set()
|
||||
errors: list[str] = []
|
||||
for i, entry in enumerate(parsed):
|
||||
if not isinstance(entry, dict):
|
||||
errors.append(
|
||||
f"MOLECULE_WORKSPACES[{i}] is not an object — got {type(entry).__name__}"
|
||||
)
|
||||
continue
|
||||
wsid = str(entry.get("id", "")).strip()
|
||||
tok = str(entry.get("token", "")).strip()
|
||||
if not wsid or not tok:
|
||||
errors.append(
|
||||
f"MOLECULE_WORKSPACES[{i}] missing 'id' or 'token'"
|
||||
)
|
||||
continue
|
||||
if wsid in seen:
|
||||
errors.append(
|
||||
f"MOLECULE_WORKSPACES[{i}] duplicate workspace id {wsid!r}"
|
||||
)
|
||||
continue
|
||||
seen.add(wsid)
|
||||
out.append((wsid, tok))
|
||||
if errors:
|
||||
return [], errors
|
||||
return out, []
|
||||
|
||||
# Single-workspace back-compat path.
|
||||
wsid = os.environ.get("WORKSPACE_ID", "").strip()
|
||||
if not wsid:
|
||||
return [], ["WORKSPACE_ID (or MOLECULE_WORKSPACES) is required"]
|
||||
tok = os.environ.get("MOLECULE_WORKSPACE_TOKEN", "").strip()
|
||||
if not tok:
|
||||
tok = read_token_file()
|
||||
if not tok:
|
||||
return [], [
|
||||
"MOLECULE_WORKSPACE_TOKEN (or CONFIGS_DIR/.auth_token) is required"
|
||||
]
|
||||
return [(wsid, tok)], []
|
||||
|
||||
|
||||
def print_missing_env_help(missing: list[str], have_token_file: bool) -> None:
|
||||
print("molecule-mcp: missing required environment.\n", file=sys.stderr)
|
||||
print("Set the following before running molecule-mcp:", file=sys.stderr)
|
||||
print(" WORKSPACE_ID — your workspace UUID (from canvas)", file=sys.stderr)
|
||||
print(
|
||||
" PLATFORM_URL — base URL of your Molecule platform "
|
||||
"(e.g. https://your-tenant.staging.moleculesai.app)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
if not have_token_file:
|
||||
print(
|
||||
" MOLECULE_WORKSPACE_TOKEN — bearer token for this workspace "
|
||||
"(canvas → Tokens tab)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print("", file=sys.stderr)
|
||||
print(f"Currently missing: {', '.join(missing)}", file=sys.stderr)
|
||||
|
||||
|
||||
def read_token_file() -> str:
|
||||
"""Read the token from the resolved configs dir's ``.auth_token`` if
|
||||
present.
|
||||
|
||||
Mirrors platform_auth._token_file's location resolution but without
|
||||
importing the heavy module here (that import triggers a2a_client's
|
||||
WORKSPACE_ID guard which is fine after env validation, but cheaper
|
||||
to inline a 4-line file read than pull in the whole stack just for
|
||||
the path).
|
||||
"""
|
||||
path = configs_dir.resolve() / ".auth_token"
|
||||
if not path.is_file():
|
||||
return ""
|
||||
try:
|
||||
return path.read_text().strip()
|
||||
except OSError:
|
||||
return ""
|
||||
@@ -339,8 +339,8 @@ class TestToolDelegateTaskAutoRouting:
|
||||
seen_send_src["src"] = source_workspace_id
|
||||
return "ok"
|
||||
|
||||
with patch("a2a_tools.discover_peer", side_effect=fake_discover), \
|
||||
patch("a2a_tools.send_a2a_message", side_effect=fake_send), \
|
||||
with patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
await a2a_tools.tool_delegate_task(peer_id, "do thing")
|
||||
|
||||
@@ -367,8 +367,8 @@ class TestToolDelegateTaskAutoRouting:
|
||||
seen["send"] = source_workspace_id
|
||||
return "ok"
|
||||
|
||||
with patch("a2a_tools.discover_peer", side_effect=fake_discover), \
|
||||
patch("a2a_tools.send_a2a_message", side_effect=fake_send), \
|
||||
with patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
await a2a_tools.tool_delegate_task(
|
||||
peer_id, "do thing", source_workspace_id=ws_explicit,
|
||||
@@ -395,8 +395,8 @@ class TestToolDelegateTaskAutoRouting:
|
||||
seen["send"] = source_workspace_id
|
||||
return "ok"
|
||||
|
||||
with patch("a2a_tools.discover_peer", side_effect=fake_discover), \
|
||||
patch("a2a_tools.send_a2a_message", side_effect=fake_send), \
|
||||
with patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
await a2a_tools.tool_delegate_task(peer_id, "do thing")
|
||||
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
"""Drift gate + direct surface tests for ``a2a_tools_delegation`` (RFC #2873 iter 4b).
|
||||
|
||||
The full behavior matrix for the three delegation MCP tools lives in
|
||||
``test_a2a_tools_impl.py`` (TestToolDelegateTask + TestToolDelegateTaskAsync
|
||||
+ TestToolCheckTaskStatus). Those exercise call paths through the
|
||||
``a2a_tools_delegation.foo`` module (after the iter 4b retarget).
|
||||
|
||||
This file owns the post-split contract:
|
||||
|
||||
1. **Drift gate** — every previously-public symbol on ``a2a_tools``
|
||||
(``tool_delegate_task``, ``tool_delegate_task_async``,
|
||||
``tool_check_task_status``, ``_delegate_sync_via_polling``,
|
||||
``_SYNC_POLL_INTERVAL_S``, ``_SYNC_POLL_BUDGET_S``) is the EXACT
|
||||
same callable / value as the new module's public name. A wrapper
|
||||
that drifted would silently bypass tests targeting the wrapper.
|
||||
|
||||
2. **Smoke import** — both modules import in either order without
|
||||
raising (the lazy ``report_activity`` import inside
|
||||
``tool_delegate_task`` is the contract that prevents a circular
|
||||
import; this test pins it).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _require_workspace_id(monkeypatch):
|
||||
monkeypatch.setenv("WORKSPACE_ID", "00000000-0000-0000-0000-000000000000")
|
||||
monkeypatch.setenv("PLATFORM_URL", "http://test.invalid")
|
||||
yield
|
||||
|
||||
|
||||
# ============== Drift gate ==============
|
||||
|
||||
class TestBackCompatAliases:
|
||||
def test_tool_delegate_task_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_delegation
|
||||
assert a2a_tools.tool_delegate_task is a2a_tools_delegation.tool_delegate_task
|
||||
|
||||
def test_tool_delegate_task_async_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_delegation
|
||||
assert (
|
||||
a2a_tools.tool_delegate_task_async
|
||||
is a2a_tools_delegation.tool_delegate_task_async
|
||||
)
|
||||
|
||||
def test_tool_check_task_status_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_delegation
|
||||
assert (
|
||||
a2a_tools.tool_check_task_status
|
||||
is a2a_tools_delegation.tool_check_task_status
|
||||
)
|
||||
|
||||
def test_delegate_sync_via_polling_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_delegation
|
||||
assert (
|
||||
a2a_tools._delegate_sync_via_polling
|
||||
is a2a_tools_delegation._delegate_sync_via_polling
|
||||
)
|
||||
|
||||
def test_constants_match(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_delegation
|
||||
assert (
|
||||
a2a_tools._SYNC_POLL_INTERVAL_S
|
||||
== a2a_tools_delegation._SYNC_POLL_INTERVAL_S
|
||||
)
|
||||
assert (
|
||||
a2a_tools._SYNC_POLL_BUDGET_S
|
||||
== a2a_tools_delegation._SYNC_POLL_BUDGET_S
|
||||
)
|
||||
|
||||
|
||||
# ============== Smoke imports ==============
|
||||
|
||||
class TestImportContracts:
|
||||
def test_delegation_imports_without_a2a_tools_loaded(self, monkeypatch):
|
||||
"""``a2a_tools_delegation`` should NOT pull in ``a2a_tools`` at
|
||||
module-load time. The lazy ``from a2a_tools import report_activity``
|
||||
inside ``tool_delegate_task`` is the only legitimate hop.
|
||||
|
||||
Pin this so a future refactor that adds a top-level
|
||||
``from a2a_tools import …`` re-introduces the circular-import
|
||||
crash that motivated the lazy pattern.
|
||||
"""
|
||||
import sys
|
||||
# Drop both modules so we re-import in a controlled order
|
||||
for mod in ("a2a_tools", "a2a_tools_delegation"):
|
||||
sys.modules.pop(mod, None)
|
||||
|
||||
# Importing delegation first must succeed without a2a_tools
|
||||
# being loaded (because a2a_tools imports delegation, the
|
||||
# circular path ONLY closes if delegation top-level imports
|
||||
# something from a2a_tools).
|
||||
import a2a_tools_delegation # noqa: F401
|
||||
# If we got here, no circular import.
|
||||
assert "a2a_tools_delegation" in sys.modules
|
||||
|
||||
def test_a2a_tools_imports_via_delegation_re_export(self):
|
||||
"""The opposite direction: importing a2a_tools must trigger the
|
||||
delegation re-export so a2a_tools.tool_delegate_task resolves."""
|
||||
import a2a_tools
|
||||
assert hasattr(a2a_tools, "tool_delegate_task")
|
||||
assert hasattr(a2a_tools, "tool_delegate_task_async")
|
||||
assert hasattr(a2a_tools, "tool_check_task_status")
|
||||
|
||||
|
||||
# ============== Sync-poll budget env override ==============
|
||||
|
||||
class TestPollBudgetEnvOverride:
|
||||
def test_default_budget_when_env_unset(self):
|
||||
"""Module-level constant. Set DELEGATION_TIMEOUT before importing
|
||||
a2a_tools_delegation to override; default is 300.0."""
|
||||
# The constant is computed at module-load time. To verify the
|
||||
# override path we'd need to reload — skipped here because it's
|
||||
# tested at boot. This test pins the default for catch-the-eye
|
||||
# documentation.
|
||||
import a2a_tools_delegation
|
||||
# Whatever was set when the module first loaded — assert it's
|
||||
# numeric and >= the documented floor (180s healthsweep budget).
|
||||
assert isinstance(a2a_tools_delegation._SYNC_POLL_BUDGET_S, float)
|
||||
assert a2a_tools_delegation._SYNC_POLL_BUDGET_S >= 180.0
|
||||
@@ -226,16 +226,16 @@ class TestToolDelegateTask:
|
||||
|
||||
async def test_peer_not_found_returns_error(self):
|
||||
import a2a_tools
|
||||
with patch("a2a_tools.discover_peer", return_value=None):
|
||||
with patch("a2a_tools_delegation.discover_peer", return_value=None):
|
||||
result = await a2a_tools.tool_delegate_task("ws-missing", "task")
|
||||
assert "not found" in result or "Error" in result
|
||||
|
||||
async def test_offline_peer_returns_error(self):
|
||||
"""A peer with status=offline short-circuits before we hit the proxy."""
|
||||
import a2a_tools
|
||||
with patch("a2a_tools.discover_peer", return_value={"id": "ws-1", "status": "offline"}):
|
||||
with patch("a2a_tools_delegation.discover_peer", return_value={"id": "ws-1", "status": "offline"}):
|
||||
mc = _make_http_mock()
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_delegate_task("ws-1", "task")
|
||||
assert "offline" in result.lower()
|
||||
|
||||
@@ -261,8 +261,8 @@ class TestToolDelegateTask:
|
||||
captured["source"] = source_workspace_id
|
||||
return "ok"
|
||||
|
||||
with patch("a2a_tools.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools.send_a2a_message", side_effect=fake_send), \
|
||||
with patch("a2a_tools_delegation.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
await a2a_tools.tool_delegate_task(peer_id, "do thing")
|
||||
|
||||
@@ -274,8 +274,8 @@ class TestToolDelegateTask:
|
||||
import a2a_tools
|
||||
|
||||
peer = {"id": "ws-1", "url": "http://ws-1.svc/a2a", "name": "Worker"}
|
||||
with patch("a2a_tools.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools.send_a2a_message", return_value="Task completed!"), \
|
||||
with patch("a2a_tools_delegation.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", return_value="Task completed!"), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
result = await a2a_tools.tool_delegate_task("ws-1", "do something")
|
||||
|
||||
@@ -287,8 +287,8 @@ class TestToolDelegateTask:
|
||||
|
||||
peer = {"id": "ws-1", "url": "http://ws-1.svc/a2a", "name": "Worker"}
|
||||
error_msg = f"{a2a_tools._A2A_ERROR_PREFIX}Agent error: something bad"
|
||||
with patch("a2a_tools.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools.send_a2a_message", return_value=error_msg), \
|
||||
with patch("a2a_tools_delegation.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", return_value=error_msg), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
result = await a2a_tools.tool_delegate_task("ws-1", "do something")
|
||||
|
||||
@@ -302,8 +302,8 @@ class TestToolDelegateTask:
|
||||
# Pre-populate the cache
|
||||
a2a_tools._peer_names["ws-cached"] = "CachedName"
|
||||
peer = {"id": "ws-cached", "url": "http://ws-cached.svc/a2a"} # no 'name'
|
||||
with patch("a2a_tools.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools.send_a2a_message", return_value="done"), \
|
||||
with patch("a2a_tools_delegation.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", return_value="done"), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
result = await a2a_tools.tool_delegate_task("ws-cached", "task")
|
||||
|
||||
@@ -316,8 +316,8 @@ class TestToolDelegateTask:
|
||||
# Ensure not in cache
|
||||
a2a_tools._peer_names.pop("ws-nona000", None)
|
||||
peer = {"id": "ws-nona000", "url": "http://x.svc/a2a"} # no 'name'
|
||||
with patch("a2a_tools.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools.send_a2a_message", return_value="ok"), \
|
||||
with patch("a2a_tools_delegation.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", return_value="ok"), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
result = await a2a_tools.tool_delegate_task("ws-nona000", "task")
|
||||
|
||||
@@ -349,7 +349,7 @@ class TestToolDelegateTaskAsync:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_resp=_resp(202, {"delegation_id": "d-123", "status": "delegated"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_delegate_task_async("ws-1", "do task")
|
||||
|
||||
data = json.loads(result)
|
||||
@@ -362,7 +362,7 @@ class TestToolDelegateTaskAsync:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_resp=_resp(500, {"error": "internal"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_delegate_task_async("ws-1", "do task")
|
||||
|
||||
assert "Error" in result
|
||||
@@ -372,7 +372,7 @@ class TestToolDelegateTaskAsync:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_exc=httpx.ConnectError("connection refused"))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_delegate_task_async("ws-1", "do task")
|
||||
|
||||
assert "Error" in result or "failed" in result.lower()
|
||||
@@ -393,7 +393,7 @@ class TestToolCheckTaskStatus:
|
||||
{"delegation_id": "d-2", "target_id": "ws-u", "status": "pending", "summary": "waiting"},
|
||||
]
|
||||
mc = _make_http_mock(get_resp=_resp(200, delegations))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_check_task_status("ws-1", "")
|
||||
|
||||
data = json.loads(result)
|
||||
@@ -409,7 +409,7 @@ class TestToolCheckTaskStatus:
|
||||
{"delegation_id": "d-2", "status": "pending"},
|
||||
]
|
||||
mc = _make_http_mock(get_resp=_resp(200, delegations))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_check_task_status("ws-1", "d-1")
|
||||
|
||||
data = json.loads(result)
|
||||
@@ -421,7 +421,7 @@ class TestToolCheckTaskStatus:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(get_resp=_resp(200, []))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_check_task_status("ws-1", "d-missing")
|
||||
|
||||
data = json.loads(result)
|
||||
@@ -432,7 +432,7 @@ class TestToolCheckTaskStatus:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(get_resp=_resp(500, {"error": "db down"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_check_task_status("ws-1", "d-1")
|
||||
|
||||
assert "Error" in result or "failed" in result.lower()
|
||||
|
||||
@@ -0,0 +1,281 @@
|
||||
"""Direct tests for ``a2a_tools_rbac`` (RFC #2873 iter 4a).
|
||||
|
||||
The full behavior matrix is exercised through ``a2a_tools._foo`` aliases
|
||||
in ``test_a2a_tools_impl.py``. This file pins:
|
||||
|
||||
1. **Drift gate** — ``a2a_tools._foo is a2a_tools_rbac.foo`` for every
|
||||
extracted symbol. A refactor that wraps or re-implements an alias
|
||||
fails this test.
|
||||
2. **Direct unit coverage** for each helper without going through the
|
||||
a2a_tools surface, so regressions in the small RBAC layer surface
|
||||
against THIS module's tests, not the 991-LOC tool-handler tests.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _require_workspace_id(monkeypatch):
|
||||
# a2a_client raises at import-time without WORKSPACE_ID. Setting it
|
||||
# once per test isolates the env so an absent value in CI doesn't
|
||||
# surface as an opaque RuntimeError from a2a_tools' import.
|
||||
monkeypatch.setenv("WORKSPACE_ID", "00000000-0000-0000-0000-000000000000")
|
||||
monkeypatch.setenv("PLATFORM_URL", "http://test.invalid")
|
||||
yield
|
||||
|
||||
|
||||
# ============== Drift gate ==============
|
||||
|
||||
class TestBackCompatAliases:
|
||||
"""Pin that every legacy underscore name in ``a2a_tools`` is the
|
||||
EXACT same callable / object as the new public name in
|
||||
``a2a_tools_rbac``. Catches accidental re-implementation in either
|
||||
direction."""
|
||||
|
||||
def test_role_permissions_is_same_object(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_rbac
|
||||
assert a2a_tools._ROLE_PERMISSIONS is a2a_tools_rbac.ROLE_PERMISSIONS
|
||||
|
||||
def test_get_workspace_tier_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_rbac
|
||||
assert a2a_tools._get_workspace_tier is a2a_tools_rbac.get_workspace_tier
|
||||
|
||||
def test_check_memory_write_permission_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_rbac
|
||||
assert (
|
||||
a2a_tools._check_memory_write_permission
|
||||
is a2a_tools_rbac.check_memory_write_permission
|
||||
)
|
||||
|
||||
def test_check_memory_read_permission_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_rbac
|
||||
assert (
|
||||
a2a_tools._check_memory_read_permission
|
||||
is a2a_tools_rbac.check_memory_read_permission
|
||||
)
|
||||
|
||||
def test_is_root_workspace_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_rbac
|
||||
assert a2a_tools._is_root_workspace is a2a_tools_rbac.is_root_workspace
|
||||
|
||||
def test_auth_headers_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_rbac
|
||||
assert (
|
||||
a2a_tools._auth_headers_for_heartbeat
|
||||
is a2a_tools_rbac.auth_headers_for_heartbeat
|
||||
)
|
||||
|
||||
|
||||
# ============== get_workspace_tier ==============
|
||||
|
||||
class TestGetWorkspaceTier:
|
||||
def test_uses_config_when_available(self):
|
||||
"""Happy path: load_config returns an object with .tier."""
|
||||
import a2a_tools_rbac
|
||||
|
||||
class _Cfg:
|
||||
tier = 0
|
||||
|
||||
with patch("config.load_config", return_value=_Cfg()):
|
||||
assert a2a_tools_rbac.get_workspace_tier() == 0
|
||||
|
||||
def test_default_tier_when_config_lacks_attr(self):
|
||||
import a2a_tools_rbac
|
||||
|
||||
class _Cfg:
|
||||
pass
|
||||
|
||||
with patch("config.load_config", return_value=_Cfg()):
|
||||
# getattr default = 1
|
||||
assert a2a_tools_rbac.get_workspace_tier() == 1
|
||||
|
||||
def test_falls_back_to_env_var(self, monkeypatch):
|
||||
"""When load_config raises, read WORKSPACE_TIER from env."""
|
||||
import a2a_tools_rbac
|
||||
monkeypatch.setenv("WORKSPACE_TIER", "5")
|
||||
with patch("config.load_config", side_effect=RuntimeError("config unavailable")):
|
||||
assert a2a_tools_rbac.get_workspace_tier() == 5
|
||||
|
||||
def test_fallback_default_one_when_env_unset(self, monkeypatch):
|
||||
import a2a_tools_rbac
|
||||
monkeypatch.delenv("WORKSPACE_TIER", raising=False)
|
||||
with patch("config.load_config", side_effect=RuntimeError("boom")):
|
||||
assert a2a_tools_rbac.get_workspace_tier() == 1
|
||||
|
||||
|
||||
# ============== is_root_workspace ==============
|
||||
|
||||
class TestIsRootWorkspace:
|
||||
def test_tier_zero_is_root(self):
|
||||
import a2a_tools_rbac
|
||||
with patch.object(a2a_tools_rbac, "get_workspace_tier", return_value=0):
|
||||
assert a2a_tools_rbac.is_root_workspace() is True
|
||||
|
||||
def test_nonzero_tier_is_not_root(self):
|
||||
import a2a_tools_rbac
|
||||
for tier in (1, 2, 99):
|
||||
with patch.object(a2a_tools_rbac, "get_workspace_tier", return_value=tier):
|
||||
assert a2a_tools_rbac.is_root_workspace() is False, f"tier={tier}"
|
||||
|
||||
|
||||
# ============== check_memory_write_permission ==============
|
||||
|
||||
class _RBACCfg:
|
||||
"""Minimal config stub matching the load_config().rbac shape."""
|
||||
|
||||
def __init__(self, roles=None, allowed_actions=None):
|
||||
class _RBAC:
|
||||
pass
|
||||
self.rbac = _RBAC()
|
||||
self.rbac.roles = roles or ["operator"]
|
||||
self.rbac.allowed_actions = allowed_actions or {}
|
||||
|
||||
|
||||
class TestCheckMemoryWritePermission:
|
||||
def test_admin_role_grants_write(self):
|
||||
import a2a_tools_rbac
|
||||
with patch("config.load_config", return_value=_RBACCfg(roles=["admin"])):
|
||||
assert a2a_tools_rbac.check_memory_write_permission() is True
|
||||
|
||||
def test_operator_role_grants_write(self):
|
||||
"""Operator is in the canonical ROLE_PERMISSIONS table with
|
||||
memory.write — must work without per-role overrides."""
|
||||
import a2a_tools_rbac
|
||||
with patch("config.load_config", return_value=_RBACCfg(roles=["operator"])):
|
||||
assert a2a_tools_rbac.check_memory_write_permission() is True
|
||||
|
||||
def test_read_only_role_denies_write(self):
|
||||
import a2a_tools_rbac
|
||||
with patch("config.load_config", return_value=_RBACCfg(roles=["read-only"])):
|
||||
assert a2a_tools_rbac.check_memory_write_permission() is False
|
||||
|
||||
def test_per_role_override_grants(self):
|
||||
"""Per-role override in allowed_actions wins over the canonical
|
||||
table — operators can grant write to memory-readonly via config."""
|
||||
import a2a_tools_rbac
|
||||
cfg = _RBACCfg(
|
||||
roles=["memory-readonly"],
|
||||
allowed_actions={"memory-readonly": {"memory.read", "memory.write"}},
|
||||
)
|
||||
with patch("config.load_config", return_value=cfg):
|
||||
assert a2a_tools_rbac.check_memory_write_permission() is True
|
||||
|
||||
def test_per_role_override_denies(self):
|
||||
"""Per-role override that drops write blocks an operator from
|
||||
writing — the override is the authoritative source when present."""
|
||||
import a2a_tools_rbac
|
||||
cfg = _RBACCfg(
|
||||
roles=["operator"],
|
||||
allowed_actions={"operator": {"memory.read"}},
|
||||
)
|
||||
with patch("config.load_config", return_value=cfg):
|
||||
assert a2a_tools_rbac.check_memory_write_permission() is False
|
||||
|
||||
def test_fail_closed_when_config_unavailable(self):
|
||||
"""Fail-closed contract: config outage falls back to ['operator']
|
||||
with no overrides — operator has memory.write in the canonical
|
||||
table, so write IS granted in this fallback. The fail-closed
|
||||
property is for ELEVATED ops (admin scope), not for the basic
|
||||
write that operator has by default. This test pins the contract:
|
||||
config errors do not silently grant admin."""
|
||||
import a2a_tools_rbac
|
||||
with patch("config.load_config", side_effect=RuntimeError("boom")):
|
||||
# operator has memory.write → True (preserved behavior)
|
||||
assert a2a_tools_rbac.check_memory_write_permission() is True
|
||||
|
||||
|
||||
# ============== check_memory_read_permission ==============
|
||||
|
||||
class TestCheckMemoryReadPermission:
|
||||
def test_admin_grants_read(self):
|
||||
import a2a_tools_rbac
|
||||
with patch("config.load_config", return_value=_RBACCfg(roles=["admin"])):
|
||||
assert a2a_tools_rbac.check_memory_read_permission() is True
|
||||
|
||||
def test_read_only_grants_read(self):
|
||||
import a2a_tools_rbac
|
||||
with patch("config.load_config", return_value=_RBACCfg(roles=["read-only"])):
|
||||
assert a2a_tools_rbac.check_memory_read_permission() is True
|
||||
|
||||
def test_unknown_role_denies(self):
|
||||
"""A role that's not in ROLE_PERMISSIONS and not in
|
||||
allowed_actions overrides denies by default."""
|
||||
import a2a_tools_rbac
|
||||
with patch("config.load_config", return_value=_RBACCfg(roles=["random-undefined-role"])):
|
||||
assert a2a_tools_rbac.check_memory_read_permission() is False
|
||||
|
||||
|
||||
# ============== auth_headers_for_heartbeat ==============
|
||||
|
||||
class TestAuthHeadersForHeartbeat:
|
||||
def test_no_workspace_id_uses_legacy_path(self):
|
||||
"""No-arg call routes to platform_auth.auth_headers() — the
|
||||
legacy single-token path."""
|
||||
import a2a_tools_rbac
|
||||
called: dict[str, object] = {}
|
||||
|
||||
def fake_auth_headers(*args):
|
||||
called["args"] = args
|
||||
return {"Authorization": "Bearer legacy-token"}
|
||||
|
||||
with patch("platform_auth.auth_headers", fake_auth_headers):
|
||||
out = a2a_tools_rbac.auth_headers_for_heartbeat()
|
||||
assert out == {"Authorization": "Bearer legacy-token"}
|
||||
# Legacy path is auth_headers() with no arg
|
||||
assert called["args"] == ()
|
||||
|
||||
def test_with_workspace_id_routes_per_workspace(self):
|
||||
import a2a_tools_rbac
|
||||
called: dict[str, object] = {}
|
||||
|
||||
def fake_auth_headers(wsid):
|
||||
called["wsid"] = wsid
|
||||
return {"Authorization": f"Bearer tok-{wsid}"}
|
||||
|
||||
with patch("platform_auth.auth_headers", fake_auth_headers):
|
||||
out = a2a_tools_rbac.auth_headers_for_heartbeat("ws-abc")
|
||||
assert out == {"Authorization": "Bearer tok-ws-abc"}
|
||||
assert called["wsid"] == "ws-abc"
|
||||
|
||||
def test_returns_empty_when_platform_auth_missing(self, monkeypatch):
|
||||
"""Older installs without platform_auth get {} so callers don't
|
||||
crash — they'll just send unauthed and the platform 401 handler
|
||||
surfaces the real error."""
|
||||
import a2a_tools_rbac
|
||||
# Force ImportError by setting sys.modules entry to None
|
||||
monkeypatch.setitem(sys.modules, "platform_auth", None)
|
||||
out = a2a_tools_rbac.auth_headers_for_heartbeat("ws-1")
|
||||
assert out == {}
|
||||
|
||||
|
||||
# ============== ROLE_PERMISSIONS canonical table ==============
|
||||
|
||||
class TestRolePermissionsTable:
|
||||
def test_admin_has_all_actions(self):
|
||||
import a2a_tools_rbac
|
||||
assert a2a_tools_rbac.ROLE_PERMISSIONS["admin"] == {
|
||||
"delegate", "approve", "memory.read", "memory.write",
|
||||
}
|
||||
|
||||
def test_read_only_has_only_memory_read(self):
|
||||
import a2a_tools_rbac
|
||||
assert a2a_tools_rbac.ROLE_PERMISSIONS["read-only"] == {"memory.read"}
|
||||
|
||||
def test_no_delegation_is_missing_delegate(self):
|
||||
import a2a_tools_rbac
|
||||
assert "delegate" not in a2a_tools_rbac.ROLE_PERMISSIONS["no-delegation"]
|
||||
|
||||
def test_no_approval_is_missing_approve(self):
|
||||
import a2a_tools_rbac
|
||||
assert "approve" not in a2a_tools_rbac.ROLE_PERMISSIONS["no-approval"]
|
||||
@@ -80,10 +80,10 @@ class TestFlagOffLegacyPath:
|
||||
async def fake_report_activity(*_a, **_kw):
|
||||
return None
|
||||
|
||||
with patch("a2a_tools.send_a2a_message", side_effect=fake_send), \
|
||||
patch("a2a_tools.discover_peer", side_effect=fake_discover), \
|
||||
with patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \
|
||||
patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \
|
||||
patch("a2a_tools.report_activity", side_effect=fake_report_activity), \
|
||||
patch("a2a_tools._delegate_sync_via_polling", new=AsyncMock()) as poll_mock:
|
||||
patch("a2a_tools_delegation._delegate_sync_via_polling", new=AsyncMock()) as poll_mock:
|
||||
result = await a2a_tools.tool_delegate_task(
|
||||
"ws-target", "task body", source_workspace_id="ws-self"
|
||||
)
|
||||
@@ -105,7 +105,7 @@ class TestFlagOnDispatchFailures:
|
||||
import a2a_tools
|
||||
mc = _make_client(post_exc=httpx.ConnectError("network down"))
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
@@ -119,7 +119,7 @@ class TestFlagOnDispatchFailures:
|
||||
import a2a_tools
|
||||
mc = _make_client(post_resp=_resp(403, {"error": "forbidden"}))
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
@@ -134,7 +134,7 @@ class TestFlagOnDispatchFailures:
|
||||
# 202 Accepted but no delegation_id field — defensive shape check.
|
||||
mc = _make_client(post_resp=_resp(202, {"status": "delegated"}))
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
@@ -168,7 +168,7 @@ class TestFlagOnPollingOutcomes:
|
||||
get_resps=[_resp(200, [completed_row])],
|
||||
)
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
@@ -196,7 +196,7 @@ class TestFlagOnPollingOutcomes:
|
||||
get_resps=[_resp(200, [failed_row])],
|
||||
)
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
@@ -234,7 +234,7 @@ class TestFlagOnPollingOutcomes:
|
||||
get_resps=get_seq,
|
||||
)
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
@@ -266,7 +266,7 @@ class TestFlagOnPollingOutcomes:
|
||||
get_resps=get_seq,
|
||||
)
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
@@ -304,7 +304,7 @@ class TestFlagOnPollingOutcomes:
|
||||
get_resps=[first_poll, second_poll],
|
||||
)
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
|
||||
@@ -555,16 +555,34 @@ def test_poll_once_self_notify_does_not_fire_notification(state: inbox.InboxStat
|
||||
def test_start_poller_thread_is_daemon(state: inbox.InboxState):
|
||||
"""Daemon flag is required so the poller dies with the parent
|
||||
process; a non-daemon poller would leak across `claude` restarts
|
||||
and write to a stale workspace."""
|
||||
and write to a stale workspace.
|
||||
|
||||
Stop_event is plumbed so the thread cleans up at the end of the
|
||||
test instead of leaking into later tests. Without cleanup, the
|
||||
daemon's ~10ms tick races with later tests that patch httpx.Client
|
||||
— the leaked thread sees their patched response and runs an
|
||||
unwanted iteration of _poll_once that double-counts mocked calls
|
||||
(caught when test_batch_fetcher_owns_client_when_not_supplied
|
||||
surfaced this on Python 3.11 CI but not 3.13 local).
|
||||
"""
|
||||
resp = _make_response(200, [])
|
||||
p, _ = _patch_httpx(resp)
|
||||
stop_event = threading.Event()
|
||||
with p, patch("platform_auth.auth_headers", return_value={}):
|
||||
# Use a very short interval so the loop body runs at least once
|
||||
# before we exit the test.
|
||||
t = inbox.start_poller_thread(state, "http://platform", "ws-1", interval=0.01)
|
||||
t = inbox.start_poller_thread(
|
||||
state, "http://platform", "ws-1", interval=0.01, stop_event=stop_event
|
||||
)
|
||||
time.sleep(0.05)
|
||||
assert t.daemon is True
|
||||
assert t.is_alive()
|
||||
assert t.daemon is True
|
||||
assert t.is_alive()
|
||||
# Signal shutdown + wait for the thread to actually exit before
|
||||
# we leave the test scope. Without this join, the leaked thread
|
||||
# races with later tests' httpx patches.
|
||||
stop_event.set()
|
||||
t.join(timeout=2.0)
|
||||
assert not t.is_alive(), "poller thread did not exit on stop_event"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -577,6 +595,219 @@ def test_default_cursor_path_uses_configs_dir(monkeypatch, tmp_path: Path):
|
||||
assert inbox.default_cursor_path() == tmp_path / ".mcp_inbox_cursor"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 5b — BatchFetcher integration with the poll loop
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# These tests pin the cross-module contract between inbox._poll_once and
|
||||
# inbox_uploads.BatchFetcher: chat_upload_receive rows must be submitted
|
||||
# to a single BatchFetcher AND drained (URI cache populated) before any
|
||||
# subsequent message row is processed. Without the drain, the
|
||||
# rewrite_request_body path inside message_from_activity surfaces the
|
||||
# un-rewritten ``platform-pending:`` URI to the agent.
|
||||
|
||||
|
||||
def _upload_row(act_id: str, file_id: str) -> dict:
|
||||
return {
|
||||
"id": act_id,
|
||||
"source_id": None,
|
||||
"method": "chat_upload_receive",
|
||||
"summary": f"chat_upload_receive: {file_id}.pdf",
|
||||
"request_body": {
|
||||
"file_id": file_id,
|
||||
"name": f"{file_id}.pdf",
|
||||
"uri": f"platform-pending:ws-1/{file_id}",
|
||||
"mimeType": "application/pdf",
|
||||
"size": 3,
|
||||
},
|
||||
"created_at": "2026-05-04T10:00:00Z",
|
||||
}
|
||||
|
||||
|
||||
def _message_row_referencing(act_id: str, file_id: str) -> dict:
|
||||
return {
|
||||
"id": act_id,
|
||||
"source_id": None,
|
||||
"method": "message/send",
|
||||
"summary": None,
|
||||
"request_body": {
|
||||
"params": {
|
||||
"message": {
|
||||
"parts": [
|
||||
{"kind": "text", "text": "have a look"},
|
||||
{
|
||||
"kind": "file",
|
||||
"file": {
|
||||
"uri": f"platform-pending:ws-1/{file_id}",
|
||||
"name": f"{file_id}.pdf",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"created_at": "2026-05-04T10:00:01Z",
|
||||
}
|
||||
|
||||
|
||||
def _patch_httpx_routing(activity_rows: list[dict], upload_bytes: bytes = b"PDF"):
|
||||
"""Replace ``httpx.Client`` so:
|
||||
|
||||
- GET /activity returns ``activity_rows``
|
||||
- GET /workspaces/.../content returns ``upload_bytes`` with content-type
|
||||
- POST /ack returns 200
|
||||
|
||||
Returns the patch context manager; tests use ``with p:``. Each new
|
||||
Client(...) gets a fresh MagicMock so the test can verify
|
||||
constructor-count expectations without pinning singletons.
|
||||
"""
|
||||
def _client_factory(*args, **kwargs):
|
||||
c = MagicMock()
|
||||
c.__enter__ = MagicMock(return_value=c)
|
||||
c.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
def _get(url, params=None, headers=None):
|
||||
if "/activity" in url:
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = activity_rows
|
||||
resp.text = ""
|
||||
return resp
|
||||
if "/pending-uploads/" in url and "/content" in url:
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.content = upload_bytes
|
||||
resp.headers = {"content-type": "application/pdf"}
|
||||
resp.text = ""
|
||||
return resp
|
||||
resp = MagicMock()
|
||||
resp.status_code = 404
|
||||
resp.text = ""
|
||||
return resp
|
||||
|
||||
def _post(url, headers=None):
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.text = ""
|
||||
return resp
|
||||
|
||||
c.get = MagicMock(side_effect=_get)
|
||||
c.post = MagicMock(side_effect=_post)
|
||||
c.close = MagicMock()
|
||||
return c
|
||||
|
||||
return patch("httpx.Client", side_effect=_client_factory)
|
||||
|
||||
|
||||
def test_poll_once_drains_uploads_before_processing_message_row(state: inbox.InboxState, tmp_path):
|
||||
"""The chat-message row's file.uri MUST be rewritten to the local
|
||||
workspace: URI by the time it lands in the InboxState queue. This
|
||||
requires BatchFetcher.wait_all() to run before message_from_activity
|
||||
on the second row.
|
||||
"""
|
||||
import inbox_uploads
|
||||
inbox_uploads.get_cache().clear()
|
||||
# Sandbox the on-disk staging dir so the test can't pollute the
|
||||
# workspace's real chat-uploads.
|
||||
real_dir = inbox_uploads.CHAT_UPLOAD_DIR
|
||||
inbox_uploads.CHAT_UPLOAD_DIR = str(tmp_path / "chat-uploads")
|
||||
try:
|
||||
rows = [
|
||||
_upload_row("act-1", "file-A"),
|
||||
_message_row_referencing("act-2", "file-A"),
|
||||
]
|
||||
state.save_cursor("act-old")
|
||||
with _patch_httpx_routing(rows, upload_bytes=b"PDF-bytes"):
|
||||
n = inbox._poll_once(state, "http://platform", "ws-1", {})
|
||||
finally:
|
||||
inbox_uploads.CHAT_UPLOAD_DIR = real_dir
|
||||
inbox_uploads.get_cache().clear()
|
||||
|
||||
assert n == 1, "exactly one message row should be enqueued (the upload row is a side-effect, not a message)"
|
||||
queued = state.peek(10)
|
||||
assert len(queued) == 1
|
||||
# The contract this test exists to pin: the platform-pending: URI
|
||||
# was rewritten to workspace: BEFORE the message landed in the
|
||||
# state queue. message_from_activity mutates row['request_body']
|
||||
# in-place, so the rewritten URI is observable on the row dict
|
||||
# we passed in.
|
||||
rewritten_part = rows[1]["request_body"]["params"]["message"]["parts"][1]
|
||||
assert rewritten_part["file"]["uri"].startswith("workspace:"), (
|
||||
f"upload barrier broken: file.uri = {rewritten_part['file']['uri']!r}; "
|
||||
"rewrite_request_body ran before BatchFetcher.wait_all populated the cache"
|
||||
)
|
||||
# Cursor advanced past BOTH rows — upload-receive (act-1) is
|
||||
# acknowledged via the inbox cursor regardless of fetch outcome.
|
||||
assert state.load_cursor() == "act-2"
|
||||
|
||||
|
||||
def test_poll_once_with_only_upload_rows_drains_at_loop_end(state: inbox.InboxState, tmp_path):
|
||||
"""End-of-batch drain: a poll that contains ONLY upload rows (no
|
||||
chat-message row to trigger the inline drain) must still drain the
|
||||
BatchFetcher before _poll_once returns. Otherwise a future poll
|
||||
that picks up the corresponding chat-message row would race with
|
||||
in-flight fetches from the previous batch.
|
||||
"""
|
||||
import inbox_uploads
|
||||
inbox_uploads.get_cache().clear()
|
||||
real_dir = inbox_uploads.CHAT_UPLOAD_DIR
|
||||
inbox_uploads.CHAT_UPLOAD_DIR = str(tmp_path / "chat-uploads")
|
||||
try:
|
||||
rows = [_upload_row("act-1", "file-A"), _upload_row("act-2", "file-B")]
|
||||
state.save_cursor("act-old")
|
||||
with _patch_httpx_routing(rows, upload_bytes=b"PDF"):
|
||||
n = inbox._poll_once(state, "http://platform", "ws-1", {})
|
||||
# By the time _poll_once returned, the URI cache must be hot
|
||||
# for both file_ids — proves the end-of-loop drain ran.
|
||||
assert inbox_uploads.get_cache().get("platform-pending:ws-1/file-A") is not None
|
||||
assert inbox_uploads.get_cache().get("platform-pending:ws-1/file-B") is not None
|
||||
finally:
|
||||
inbox_uploads.CHAT_UPLOAD_DIR = real_dir
|
||||
inbox_uploads.get_cache().clear()
|
||||
# Upload rows are NOT message rows; queue stays empty.
|
||||
assert n == 0
|
||||
# Cursor advances past both upload rows.
|
||||
assert state.load_cursor() == "act-2"
|
||||
|
||||
|
||||
def test_poll_once_no_uploads_does_not_construct_batch_fetcher(state: inbox.InboxState):
|
||||
"""A batch with no upload-receive rows must not pay the BatchFetcher
|
||||
construction cost — the executor + httpx client allocation is
|
||||
deferred until the first upload row appears.
|
||||
"""
|
||||
import inbox_uploads
|
||||
|
||||
constructed: list[Any] = []
|
||||
|
||||
def _patched_init(self, **kwargs):
|
||||
constructed.append(kwargs)
|
||||
# Don't actually run __init__; we never hit submit/wait_all.
|
||||
self._closed = False
|
||||
self._futures = []
|
||||
self._executor = MagicMock()
|
||||
self._client = MagicMock()
|
||||
self._own_client = False
|
||||
|
||||
rows = [
|
||||
{
|
||||
"id": "act-1",
|
||||
"source_id": None,
|
||||
"method": "message/send",
|
||||
"summary": None,
|
||||
"request_body": {"parts": [{"type": "text", "text": "hi"}]},
|
||||
"created_at": "2026-04-30T22:00:00Z",
|
||||
},
|
||||
]
|
||||
state.save_cursor("act-old")
|
||||
resp = _make_response(200, rows)
|
||||
p, _ = _patch_httpx(resp)
|
||||
with patch.object(inbox_uploads.BatchFetcher, "__init__", _patched_init), p:
|
||||
n = inbox._poll_once(state, "http://platform", "ws-1", {})
|
||||
|
||||
assert n == 1
|
||||
assert constructed == [], "BatchFetcher must not be constructed when no upload rows are present"
|
||||
|
||||
|
||||
def test_default_cursor_path_falls_back_to_default(tmp_path, monkeypatch):
|
||||
"""When CONFIGS_DIR is unset, the cursor path resolves through
|
||||
configs_dir.resolve() — /configs in-container, ~/.molecule-workspace
|
||||
@@ -701,3 +932,165 @@ def test_set_notification_callback_none_clears(state: inbox.InboxState):
|
||||
state.record(_msg("act-1"))
|
||||
|
||||
assert received == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 2 — chat_upload_receive rows route to inbox_uploads.fetch_and_stage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_poll_once_skips_chat_upload_row_from_queue(state: inbox.InboxState, monkeypatch, tmp_path):
|
||||
"""A row with method='chat_upload_receive' must NOT enqueue as a
|
||||
chat message — it's a side-effect telling the workspace to fetch
|
||||
bytes. Pin the contract so a refactor that flattens the row loop
|
||||
can't silently re-enqueue these as 'empty A2A message' rows."""
|
||||
import inbox_uploads
|
||||
monkeypatch.setattr(inbox_uploads, "CHAT_UPLOAD_DIR", str(tmp_path / "chat-uploads"))
|
||||
inbox_uploads.get_cache().clear()
|
||||
|
||||
rows = [
|
||||
{
|
||||
"id": "act-1",
|
||||
"source_id": None,
|
||||
"method": "chat_upload_receive",
|
||||
"summary": "chat_upload_receive: foo.pdf",
|
||||
"request_body": {
|
||||
"file_id": "abc123",
|
||||
"name": "foo.pdf",
|
||||
"mimeType": "application/pdf",
|
||||
"size": 4,
|
||||
"uri": "platform-pending:ws-1/abc123",
|
||||
},
|
||||
"created_at": "2026-05-04T10:00:00Z",
|
||||
},
|
||||
]
|
||||
resp = _make_response(200, rows)
|
||||
p, _ = _patch_httpx(resp)
|
||||
fetch_called = []
|
||||
|
||||
def fake_fetch(row, **kwargs):
|
||||
fetch_called.append((row.get("id"), kwargs["workspace_id"]))
|
||||
return "workspace:/local/foo.pdf"
|
||||
|
||||
with p, patch.object(inbox_uploads, "fetch_and_stage", fake_fetch):
|
||||
n = inbox._poll_once(state, "http://platform", "ws-1", {})
|
||||
|
||||
# Not enqueued + cursor advanced.
|
||||
assert n == 0
|
||||
assert state.peek(10) == []
|
||||
assert state.load_cursor() == "act-1"
|
||||
# fetch_and_stage was invoked with the row and workspace_id.
|
||||
assert fetch_called == [("act-1", "ws-1")]
|
||||
|
||||
|
||||
def test_poll_once_chat_upload_row_then_chat_message_rewrites_uri(state: inbox.InboxState, monkeypatch, tmp_path):
|
||||
"""The classic ordering: upload-receive row first (lower id), chat
|
||||
message referencing platform-pending: URI second. The chat message
|
||||
that lands in the inbox must have its URI rewritten to the local
|
||||
workspace: URI before the agent sees it.
|
||||
"""
|
||||
import inbox_uploads
|
||||
monkeypatch.setattr(inbox_uploads, "CHAT_UPLOAD_DIR", str(tmp_path / "chat-uploads"))
|
||||
cache = inbox_uploads.get_cache()
|
||||
cache.clear()
|
||||
|
||||
# Pretend the fetch already populated the cache. (The real flow
|
||||
# populates it inside fetch_and_stage; we patch that to keep the
|
||||
# test focused on the rewrite contract.)
|
||||
cache.set("platform-pending:ws-1/abc123", "workspace:/workspace/.molecule/chat-uploads/xx-foo.pdf")
|
||||
|
||||
rows = [
|
||||
{
|
||||
"id": "act-1",
|
||||
"source_id": None,
|
||||
"method": "chat_upload_receive",
|
||||
"summary": "chat_upload_receive: foo.pdf",
|
||||
"request_body": {
|
||||
"file_id": "abc123",
|
||||
"name": "foo.pdf",
|
||||
"mimeType": "application/pdf",
|
||||
"size": 4,
|
||||
"uri": "platform-pending:ws-1/abc123",
|
||||
},
|
||||
"created_at": "2026-05-04T10:00:00Z",
|
||||
},
|
||||
{
|
||||
"id": "act-2",
|
||||
"source_id": None,
|
||||
"method": "message/send",
|
||||
"summary": None,
|
||||
"request_body": {
|
||||
"params": {
|
||||
"message": {
|
||||
"parts": [
|
||||
{"kind": "text", "text": "look at this"},
|
||||
{
|
||||
"kind": "file",
|
||||
"file": {
|
||||
"uri": "platform-pending:ws-1/abc123",
|
||||
"name": "foo.pdf",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"created_at": "2026-05-04T10:00:01Z",
|
||||
},
|
||||
]
|
||||
resp = _make_response(200, rows)
|
||||
p, _ = _patch_httpx(resp)
|
||||
|
||||
def fake_fetch(row, **kwargs):
|
||||
return "workspace:/workspace/.molecule/chat-uploads/xx-foo.pdf"
|
||||
|
||||
with p, patch.object(inbox_uploads, "fetch_and_stage", fake_fetch):
|
||||
n = inbox._poll_once(state, "http://platform", "ws-1", {})
|
||||
|
||||
# Only the chat message is enqueued.
|
||||
assert n == 1
|
||||
queue = state.peek(10)
|
||||
assert len(queue) == 1
|
||||
msg = queue[0]
|
||||
assert msg.activity_id == "act-2"
|
||||
# The URI in the row's request_body was mutated by message_from_activity
|
||||
# → rewrite_request_body. Re-extracting reveals the rewritten value.
|
||||
rewritten = rows[1]["request_body"]["params"]["message"]["parts"][1]["file"]["uri"]
|
||||
assert rewritten == "workspace:/workspace/.molecule/chat-uploads/xx-foo.pdf"
|
||||
|
||||
|
||||
def test_poll_once_chat_upload_row_advances_cursor_even_on_fetch_failure(
|
||||
state: inbox.InboxState, monkeypatch, tmp_path
|
||||
):
|
||||
"""A permanent network failure on /content must NOT stall the cursor
|
||||
— otherwise one bad upload blocks all real chat traffic for the
|
||||
workspace. fetch_and_stage returns None on failure, but the row is
|
||||
still considered handled from the cursor's perspective."""
|
||||
import inbox_uploads
|
||||
monkeypatch.setattr(inbox_uploads, "CHAT_UPLOAD_DIR", str(tmp_path / "chat-uploads"))
|
||||
|
||||
rows = [
|
||||
{
|
||||
"id": "act-broken",
|
||||
"source_id": None,
|
||||
"method": "chat_upload_receive",
|
||||
"summary": "chat_upload_receive: doomed.pdf",
|
||||
"request_body": {
|
||||
"file_id": "doom",
|
||||
"name": "doomed.pdf",
|
||||
"uri": "platform-pending:ws-1/doom",
|
||||
},
|
||||
"created_at": "2026-05-04T10:00:00Z",
|
||||
},
|
||||
]
|
||||
resp = _make_response(200, rows)
|
||||
p, _ = _patch_httpx(resp)
|
||||
|
||||
def fake_fetch(row, **kwargs):
|
||||
return None # network failure
|
||||
|
||||
with p, patch.object(inbox_uploads, "fetch_and_stage", fake_fetch):
|
||||
inbox._poll_once(state, "http://platform", "ws-1", {})
|
||||
|
||||
assert state.peek(10) == []
|
||||
assert state.load_cursor() == "act-broken"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,6 +13,7 @@ from pathlib import Path
|
||||
import pytest
|
||||
|
||||
import mcp_cli
|
||||
import mcp_heartbeat
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -739,8 +740,13 @@ def test_heartbeat_loop_calls_persist_on_success(monkeypatch):
|
||||
def fake_persist(resp):
|
||||
saw.append(resp)
|
||||
|
||||
# Patch on mcp_heartbeat — that's where heartbeat_loop's internal
|
||||
# name resolution looks up persist_inbound_secret_from_heartbeat
|
||||
# after the RFC #2873 iter 3 split. The mcp_cli._persist_…_from_heartbeat
|
||||
# back-compat re-export still exists, but patching it here would not
|
||||
# affect the loop body.
|
||||
monkeypatch.setattr(
|
||||
mcp_cli, "_persist_inbound_secret_from_heartbeat", fake_persist
|
||||
mcp_heartbeat, "persist_inbound_secret_from_heartbeat", fake_persist
|
||||
)
|
||||
|
||||
class FakeResp:
|
||||
@@ -786,8 +792,8 @@ def test_heartbeat_loop_skips_persist_on_4xx(monkeypatch):
|
||||
"""Heartbeat 4xx error path must NOT invoke persist (no body to trust)."""
|
||||
saw: list[object] = []
|
||||
monkeypatch.setattr(
|
||||
mcp_cli,
|
||||
"_persist_inbound_secret_from_heartbeat",
|
||||
mcp_heartbeat,
|
||||
"persist_inbound_secret_from_heartbeat",
|
||||
lambda r: saw.append(r),
|
||||
)
|
||||
|
||||
@@ -899,7 +905,7 @@ def test_heartbeat_single_401_logs_warning_not_error(monkeypatch, caplog):
|
||||
transient platform blip. Log at WARNING; don't shout."""
|
||||
import logging
|
||||
|
||||
caplog.set_level(logging.WARNING, logger="mcp_cli")
|
||||
caplog.set_level(logging.WARNING, logger="mcp_heartbeat")
|
||||
|
||||
_multi_iter_runner(monkeypatch, [401])
|
||||
|
||||
@@ -923,7 +929,7 @@ def test_heartbeat_three_consecutive_401s_escalates_to_error(monkeypatch, caplog
|
||||
LOUD ERROR with re-onboard guidance — not buried at WARNING."""
|
||||
import logging
|
||||
|
||||
caplog.set_level(logging.WARNING, logger="mcp_cli")
|
||||
caplog.set_level(logging.WARNING, logger="mcp_heartbeat")
|
||||
|
||||
_multi_iter_runner(monkeypatch, [401, 401, 401])
|
||||
|
||||
@@ -949,7 +955,7 @@ def test_heartbeat_403_treated_same_as_401(monkeypatch, caplog):
|
||||
not authorized for this workspace). Same escalation path."""
|
||||
import logging
|
||||
|
||||
caplog.set_level(logging.WARNING, logger="mcp_cli")
|
||||
caplog.set_level(logging.WARNING, logger="mcp_heartbeat")
|
||||
|
||||
_multi_iter_runner(monkeypatch, [403, 403, 403])
|
||||
|
||||
@@ -963,7 +969,7 @@ def test_heartbeat_recovery_resets_consecutive_counter(monkeypatch, caplog):
|
||||
later should NOT immediately escalate."""
|
||||
import logging
|
||||
|
||||
caplog.set_level(logging.WARNING, logger="mcp_cli")
|
||||
caplog.set_level(logging.WARNING, logger="mcp_heartbeat")
|
||||
|
||||
# Two 401s, then 200, then one 401. If counter resets correctly,
|
||||
# the final 401 is "1 consecutive" and should NOT escalate.
|
||||
@@ -982,7 +988,7 @@ def test_heartbeat_500_does_not_increment_auth_counter(monkeypatch, caplog):
|
||||
misleading the operator."""
|
||||
import logging
|
||||
|
||||
caplog.set_level(logging.WARNING, logger="mcp_cli")
|
||||
caplog.set_level(logging.WARNING, logger="mcp_heartbeat")
|
||||
|
||||
_multi_iter_runner(monkeypatch, [500, 500, 500])
|
||||
|
||||
|
||||
@@ -0,0 +1,231 @@
|
||||
"""RFC #2873 iter 3 — drift gate + behavior tests for the post-split surface.
|
||||
|
||||
The bulk of the heartbeat / resolver behavior is exercised by
|
||||
``test_mcp_cli.py`` and ``test_mcp_cli_multi_workspace.py`` through the
|
||||
``mcp_cli._symbol`` back-compat aliases. This file pins:
|
||||
|
||||
1. The split is **behavior-neutral via aliasing** — every previously-
|
||||
exposed ``mcp_cli._foo`` symbol is the SAME callable as the new
|
||||
module's authoritative function. If a refactor accidentally drops
|
||||
an alias or points it at a stale copy, this fails.
|
||||
|
||||
2. ``mcp_inbox_pollers.start_inbox_pollers`` works for both single-
|
||||
workspace (legacy back-compat) and multi-workspace shapes.
|
||||
``mcp_cli`` had no direct test for this branch before the split.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
import mcp_cli
|
||||
import mcp_heartbeat
|
||||
import mcp_inbox_pollers
|
||||
import mcp_workspace_resolver
|
||||
|
||||
|
||||
# ============== Drift gate: back-compat aliases point at the real fn ==============
|
||||
|
||||
class TestBackCompatAliases:
|
||||
"""Pin that ``mcp_cli._foo is real_fn``. A test that re-implements
|
||||
the alias would still pass — the ``is`` check guarantees we didn't
|
||||
create a wrapper that drifts."""
|
||||
|
||||
def test_heartbeat_aliases(self):
|
||||
assert mcp_cli._build_agent_card is mcp_heartbeat.build_agent_card
|
||||
assert mcp_cli._platform_register is mcp_heartbeat.platform_register
|
||||
assert mcp_cli._heartbeat_loop is mcp_heartbeat.heartbeat_loop
|
||||
assert mcp_cli._log_heartbeat_auth_failure is mcp_heartbeat.log_heartbeat_auth_failure
|
||||
assert (
|
||||
mcp_cli._persist_inbound_secret_from_heartbeat
|
||||
is mcp_heartbeat.persist_inbound_secret_from_heartbeat
|
||||
)
|
||||
assert mcp_cli._start_heartbeat_thread is mcp_heartbeat.start_heartbeat_thread
|
||||
|
||||
def test_resolver_aliases(self):
|
||||
assert mcp_cli._resolve_workspaces is mcp_workspace_resolver.resolve_workspaces
|
||||
assert mcp_cli._print_missing_env_help is mcp_workspace_resolver.print_missing_env_help
|
||||
assert mcp_cli._read_token_file is mcp_workspace_resolver.read_token_file
|
||||
|
||||
def test_inbox_pollers_alias(self):
|
||||
assert mcp_cli._start_inbox_pollers is mcp_inbox_pollers.start_inbox_pollers
|
||||
|
||||
def test_constants_match(self):
|
||||
assert (
|
||||
mcp_cli.HEARTBEAT_INTERVAL_SECONDS
|
||||
== mcp_heartbeat.HEARTBEAT_INTERVAL_SECONDS
|
||||
)
|
||||
assert (
|
||||
mcp_cli._HEARTBEAT_AUTH_LOUD_THRESHOLD
|
||||
== mcp_heartbeat.HEARTBEAT_AUTH_LOUD_THRESHOLD
|
||||
)
|
||||
assert (
|
||||
mcp_cli._HEARTBEAT_AUTH_RELOG_INTERVAL
|
||||
== mcp_heartbeat.HEARTBEAT_AUTH_RELOG_INTERVAL
|
||||
)
|
||||
|
||||
|
||||
# ============== mcp_inbox_pollers — both shapes + degraded import ==============
|
||||
|
||||
class _FakeInboxState:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
def _install_fake_inbox(monkeypatch):
|
||||
"""Inject a fake ``inbox`` module so we observe the spawn calls
|
||||
without pulling in the real platform_auth dependency tree."""
|
||||
activations: list[_FakeInboxState] = []
|
||||
spawned: list[tuple[_FakeInboxState, str, str]] = []
|
||||
cursor_paths: list[str] = []
|
||||
|
||||
def default_cursor_path(wsid=None):
|
||||
# Mirror the real signature: optional wsid → distinct path per id,
|
||||
# absent → legacy single path.
|
||||
path = f"/tmp/.mcp_inbox_cursor.{wsid[:8]}" if wsid else "/tmp/.mcp_inbox_cursor"
|
||||
cursor_paths.append(path)
|
||||
return path
|
||||
|
||||
def activate(state):
|
||||
activations.append(state)
|
||||
|
||||
def start_poller_thread(state, platform_url, wsid):
|
||||
spawned.append((state, platform_url, wsid))
|
||||
|
||||
fake = types.ModuleType("inbox")
|
||||
fake.InboxState = _FakeInboxState
|
||||
fake.activate = activate
|
||||
fake.default_cursor_path = default_cursor_path
|
||||
fake.start_poller_thread = start_poller_thread
|
||||
monkeypatch.setitem(sys.modules, "inbox", fake)
|
||||
return activations, spawned, cursor_paths
|
||||
|
||||
|
||||
class TestStartInboxPollers:
|
||||
def test_single_workspace_uses_legacy_cursor_path(self, monkeypatch):
|
||||
"""Back-compat exact: single-workspace mode reuses the legacy
|
||||
cursor filename so an existing operator's on-disk state isn't
|
||||
invalidated by upgrade."""
|
||||
activations, spawned, cursor_paths = _install_fake_inbox(monkeypatch)
|
||||
|
||||
mcp_inbox_pollers.start_inbox_pollers(
|
||||
"https://test.moleculesai.app", ["ws-only-one"]
|
||||
)
|
||||
|
||||
assert len(activations) == 1, "exactly one inbox.activate call"
|
||||
assert len(spawned) == 1, "exactly one poller thread spawned"
|
||||
# Single-workspace path uses default_cursor_path() with no arg —
|
||||
# the cursor_path captured here must be the legacy filename
|
||||
# (no per-ws suffix).
|
||||
assert cursor_paths == ["/tmp/.mcp_inbox_cursor"]
|
||||
# State carries cursor_path, not cursor_paths
|
||||
state = activations[0]
|
||||
assert state.kwargs == {"cursor_path": "/tmp/.mcp_inbox_cursor"}
|
||||
# Spawned poller is for the right workspace
|
||||
assert spawned[0] == (state, "https://test.moleculesai.app", "ws-only-one")
|
||||
|
||||
def test_multi_workspace_uses_per_workspace_cursor_paths(self, monkeypatch):
|
||||
"""Multi-workspace path: per-workspace cursor file, one shared
|
||||
InboxState. N pollers, each pointed at the same state so the
|
||||
agent's inbox_peek/pop sees a merged view."""
|
||||
activations, spawned, _ = _install_fake_inbox(monkeypatch)
|
||||
|
||||
wsids = ["ws-aaaaaaaa", "ws-bbbbbbbb", "ws-cccccccc"]
|
||||
mcp_inbox_pollers.start_inbox_pollers(
|
||||
"https://test.moleculesai.app", wsids
|
||||
)
|
||||
|
||||
# One state, one activate, three pollers
|
||||
assert len(activations) == 1
|
||||
assert len(spawned) == 3
|
||||
state = activations[0]
|
||||
# Multi-workspace state carries cursor_paths (mapping)
|
||||
assert "cursor_paths" in state.kwargs
|
||||
assert set(state.kwargs["cursor_paths"].keys()) == set(wsids)
|
||||
# All pollers share the same state
|
||||
for s, _url, _wsid in spawned:
|
||||
assert s is state
|
||||
# All workspace ids covered
|
||||
assert sorted(t[2] for t in spawned) == sorted(wsids)
|
||||
|
||||
def test_inbox_module_unavailable_logs_and_returns(self, monkeypatch, caplog):
|
||||
"""If ``import inbox`` fails (older install or stripped
|
||||
runtime), spawn must NOT raise — log a warning and continue.
|
||||
The MCP server can still serve outbound tools."""
|
||||
import logging
|
||||
|
||||
# Force ImportError by injecting a module sentinel that raises.
|
||||
class _Boom:
|
||||
def __getattr__(self, _name):
|
||||
raise ImportError("inbox stripped from this build")
|
||||
|
||||
# Setting sys.modules["inbox"] to a broken object isn't enough —
|
||||
# the import statement reads sys.modules first; if the entry is
|
||||
# truthy, Python returns it. We need to force the import to raise.
|
||||
# Easiest: pre-poison sys.modules so the `import inbox` line
|
||||
# raises by setting the entry to None (Python special-cases None
|
||||
# as "explicit ImportError").
|
||||
monkeypatch.setitem(sys.modules, "inbox", None)
|
||||
|
||||
caplog.set_level(logging.WARNING, logger="mcp_inbox_pollers")
|
||||
# Should not raise.
|
||||
mcp_inbox_pollers.start_inbox_pollers(
|
||||
"https://test.moleculesai.app", ["ws-1"]
|
||||
)
|
||||
warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
|
||||
assert any("inbox module unavailable" in r.message for r in warnings), (
|
||||
f"expected a 'inbox module unavailable' warning, got: "
|
||||
f"{[r.message for r in warnings]}"
|
||||
)
|
||||
|
||||
|
||||
# ============== mcp_heartbeat.build_agent_card — short direct tests ==============
|
||||
|
||||
class TestBuildAgentCardDirect:
|
||||
"""Spot-check the new module's public surface; the full test matrix
|
||||
lives in ``test_mcp_cli.py`` reaching through ``mcp_cli._build_agent_card``.
|
||||
"""
|
||||
|
||||
def test_default_card_shape(self, monkeypatch):
|
||||
for v in ("MOLECULE_AGENT_NAME", "MOLECULE_AGENT_DESCRIPTION", "MOLECULE_AGENT_SKILLS"):
|
||||
monkeypatch.delenv(v, raising=False)
|
||||
card = mcp_heartbeat.build_agent_card("8dad3e29-c32a-4ec7-9ea7-94fe2d2d98ec")
|
||||
assert card == {"name": "molecule-mcp-8dad3e29", "skills": []}
|
||||
|
||||
def test_skills_csv_split_and_trim(self, monkeypatch):
|
||||
monkeypatch.setenv("MOLECULE_AGENT_SKILLS", "research, , code-review,memory-curation, ")
|
||||
card = mcp_heartbeat.build_agent_card("ws-1")
|
||||
assert card["skills"] == [
|
||||
{"name": "research"},
|
||||
{"name": "code-review"},
|
||||
{"name": "memory-curation"},
|
||||
]
|
||||
|
||||
|
||||
# ============== mcp_workspace_resolver — short direct tests ==============
|
||||
|
||||
class TestResolveWorkspacesDirect:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate(self, monkeypatch, tmp_path):
|
||||
for v in ("WORKSPACE_ID", "MOLECULE_WORKSPACE_TOKEN", "MOLECULE_WORKSPACES"):
|
||||
monkeypatch.delenv(v, raising=False)
|
||||
monkeypatch.setenv("CONFIGS_DIR", str(tmp_path))
|
||||
yield
|
||||
|
||||
def test_single_workspace_via_env(self, monkeypatch):
|
||||
monkeypatch.setenv("WORKSPACE_ID", "ws-1")
|
||||
monkeypatch.setenv("MOLECULE_WORKSPACE_TOKEN", "tok")
|
||||
out, errors = mcp_workspace_resolver.resolve_workspaces()
|
||||
assert out == [("ws-1", "tok")]
|
||||
assert errors == []
|
||||
|
||||
def test_multi_workspace_via_json_env(self, monkeypatch):
|
||||
monkeypatch.setenv(
|
||||
"MOLECULE_WORKSPACES",
|
||||
'[{"id":"ws-a","token":"a"},{"id":"ws-b","token":"b"}]',
|
||||
)
|
||||
out, errors = mcp_workspace_resolver.resolve_workspaces()
|
||||
assert out == [("ws-a", "a"), ("ws-b", "b")]
|
||||
assert errors == []
|
||||
Reference in New Issue
Block a user