Compare commits

..

1 Commits

Author SHA1 Message Date
fullstack-engineer 27fea03c07 fix(provisioner): wire collectCPConfigFiles into CPProvisioner.Start (OFFSEC-010)
Block internal-flavored paths / Block forbidden paths (pull_request) Successful in 36s
CI / Detect changes (pull_request) Successful in 36s
Harness Replays / detect-changes (pull_request) Successful in 25s
Handlers Postgres Integration / detect-changes (pull_request) Successful in 45s
E2E API Smoke Test / detect-changes (pull_request) Successful in 53s
Runtime PR-Built Compatibility / detect-changes (pull_request) Successful in 37s
Secret scan / Scan diff for credential-shaped strings (pull_request) Successful in 20s
gate-check-v3 / gate-check (pull_request) Successful in 31s
security-review / approved (pull_request) Successful in 27s
qa-review / approved (pull_request) Successful in 30s
sop-tier-check / tier-check (pull_request) Successful in 21s
lint-required-no-paths / lint-required-no-paths (pull_request) Successful in 1m25s
CI / Canvas (Next.js) (pull_request) Successful in 8s
CI / Shellcheck (E2E scripts) (pull_request) Successful in 7s
CI / Python Lint & Test (pull_request) Successful in 17s
Handlers Postgres Integration / Handlers Postgres Integration (pull_request) Successful in 18s
Harness Replays / Harness Replays (pull_request) Successful in 16s
Runtime PR-Built Compatibility / PR-built wheel + import smoke (pull_request) Successful in 15s
CI / Canvas Deploy Reminder (pull_request) Has been skipped
sop-checklist / all-items-acked (pull_request) [info tier:low] acked: 0/7 — missing: comprehensive-testing, local-postgres-e2e, staging-smoke, +4 — body-unfilled: comprehensive-testing, l
E2E API Smoke Test / E2E API Smoke Test (pull_request) Successful in 2m38s
audit-force-merge / audit (pull_request) Has been skipped
CI / Platform (Go) (pull_request) Failing after 8m27s
CI / all-required (pull_request) Successful in 6s
OFFSEC-010 added collectCPConfigFiles but it was unreachable dead code —
CPProvisioner.Start never called it, so the control plane received no
config files at boot time. Both Core-BE and Core-DevOps A2A provision
backends depend on this wiring.

Two wiring steps:
1. Add ConfigFiles map[string]string to cpProvisionRequest. Value is
   base64-encoded (done by collectCPConfigFiles) so it passes safely
   over the wire without shell metacharacter concerns.
2. Call collectCPConfigFiles(cfg) before building the request. Errors
   propagate immediately — a workspace with no config files cannot
   function.

Also adds the collectCPConfigFiles function itself (OFFSEC-010 from PR
#1075) on top of staging, and fixes missing os + path/filepath imports
in cp_provisioner_test.go (pre-existing issue from the conflict-
resolution merge that introduced the original tests).

Test coverage: TestCollectCPConfigFiles_SkipsSymlinks,
TestCollectCPConfigFiles_RejectsRootSymlink,
TestStart_PassesConfigFilesToCP, TestStart_CollectConfigFilesErrorSurfaces.

Closes #1077
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-14 22:14:26 +00:00
44 changed files with 364 additions and 2318 deletions
+8 -11
View File
@@ -118,19 +118,17 @@ _DIRECTIVE_RE = re.compile(
def parse_directives(
comment_body: str,
numeric_aliases: dict[int, str],
) -> tuple[list[tuple[str, str, str]], list]:
) -> list[tuple[str, str, str]]:
"""Extract /sop-ack and /sop-revoke directives from a comment body.
Returns (directives, na_directives) where:
directives is a list of (kind, canonical_slug, note) tuples
kind is "sop-ack" or "sop-revoke"
canonical_slug is the normalized form (or "" if unparseable)
note is the trailing free-text (may be "")
na_directives is reserved for future N/A handling (always [] for now)
Returns a list of (kind, canonical_slug, note) tuples where:
kind is "sop-ack" or "sop-revoke"
canonical_slug is the normalized form (or "" if unparseable)
note is the trailing free-text (may be "")
"""
out: list[tuple[str, str, str]] = []
if not comment_body:
return out, []
return out
for m in _DIRECTIVE_RE.finditer(comment_body):
kind = m.group(1)
raw_slug = (m.group(2) or "").strip()
@@ -161,7 +159,7 @@ def parse_directives(
# If we collapsed multi-word slug into kebab and there's a
# trailing-text group too, append it.
out.append((kind, canonical, note_from_group))
return out, []
return out
# ---------------------------------------------------------------------------
@@ -251,8 +249,7 @@ def compute_ack_state(
user = (c.get("user") or {}).get("login", "")
if not user:
continue
directives, _na = parse_directives(body, numeric_aliases)
for kind, slug, _note in directives:
for kind, slug, _note in parse_directives(body, numeric_aliases):
if not slug:
unparseable_per_user[user] = unparseable_per_user.get(user, 0) + 1
continue
+21 -20
View File
@@ -133,6 +133,7 @@ jobs:
# the name match works on PRs that don't touch workspace-server/).
platform-build:
name: Platform (Go)
needs: changes
runs-on: ubuntu-latest
# mc#774 (closed 2026-05-14): Phase 4 flip of the platform-build job.
# Phase 4 (#656) originally flipped this to continue-on-error: false based on
@@ -153,29 +154,29 @@ jobs:
run:
working-directory: workspace-server
steps:
- if: false
- if: needs.changes.outputs.platform != 'true'
working-directory: .
run: echo "No platform/** changes — skipping real build steps; this job always runs to satisfy the required-check name on branch protection."
- if: always()
- if: needs.changes.outputs.platform == 'true'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- if: always()
- if: needs.changes.outputs.platform == 'true'
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
with:
go-version: 'stable'
- if: always()
- if: needs.changes.outputs.platform == 'true'
run: go mod download
- if: always()
- if: needs.changes.outputs.platform == 'true'
run: go build ./cmd/server
# CLI (molecli) moved to standalone repo: git.moleculesai.app/molecule-ai/molecule-cli
- if: always()
- if: needs.changes.outputs.platform == 'true'
run: go vet ./...
- if: always()
- if: needs.changes.outputs.platform == 'true'
name: Install golangci-lint
run: go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.12.2
- if: always()
- if: needs.changes.outputs.platform == 'true'
name: Run golangci-lint
run: $(go env GOPATH)/bin/golangci-lint run --timeout 3m ./...
- if: always()
- if: needs.changes.outputs.platform == 'true'
name: Diagnostic — per-package verbose 60s
run: |
set +e
@@ -191,7 +192,7 @@ jobs:
echo "::endgroup::"
# mc#774: pre-existing continue-on-error mask; root-fix and remove, do not renew silently.
continue-on-error: true
- if: always()
- if: needs.changes.outputs.platform == 'true'
name: Run tests with race detection and coverage
# Explicit timeout: cold runner cache causes OOM kills at ~4m39s on the
# full ./... suite with race detection + coverage. A 10m per-step timeout
@@ -199,7 +200,7 @@ jobs:
# instead of OOM-killing. The job-level timeout (15m) is a backstop.
run: go test -race -timeout 10m -coverprofile=coverage.out ./...
- if: always()
- if: needs.changes.outputs.platform == 'true'
name: Per-file coverage report
# Advisory — lists every source file with its coverage so reviewers
# can see at-a-glance where gaps are. Sorted ascending so the worst
@@ -213,7 +214,7 @@ jobs:
END {for (f in s) printf "%6.1f%% %s\n", s[f]/c[f], f}' \
| sort -n
- if: always()
- if: needs.changes.outputs.platform == 'true'
name: Check coverage thresholds
# Enforces two gates from #1823 Layer 1:
# 1. Total floor (25% — ratchet plan in COVERAGE_FLOOR.md).
@@ -301,28 +302,28 @@ jobs:
# siblings — verified empirically on PR #2314).
canvas-build:
name: Canvas (Next.js)
needs: changes
runs-on: ubuntu-latest
timeout-minutes: 20
# Phase 4 (RFC #219 §1): confirmed green on main 2026-05-12.
continue-on-error: false
defaults:
run:
working-directory: canvas
steps:
- if: false
- if: needs.changes.outputs.canvas != 'true'
working-directory: .
run: echo "No canvas/** changes — skipping real build steps; this job always runs to satisfy the required-check name on branch protection."
- if: always()
- if: needs.changes.outputs.canvas == 'true'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- if: always()
- if: needs.changes.outputs.canvas == 'true'
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
with:
node-version: '22'
- if: always()
- if: needs.changes.outputs.canvas == 'true'
run: rm -f package-lock.json && npm install
- if: always()
- if: needs.changes.outputs.canvas == 'true'
run: npm run build
- if: always()
- if: needs.changes.outputs.canvas == 'true'
name: Run tests with coverage
# Coverage instrumentation is configured in canvas/vitest.config.ts
# (provider: v8, reporters: text + html + json-summary). Step 2 of
@@ -331,7 +332,7 @@ jobs:
# tracked in #1815) after the team sees what current coverage is.
run: npx vitest run --coverage
- name: Upload coverage summary as artifact
if: always()
if: needs.changes.outputs.canvas == 'true' && always()
# Pinned to v3 for Gitea act_runner v0.6 compatibility — v4+ uses
# the GHES 3.10+ artifact protocol that Gitea 1.22.x does NOT
# implement, surfacing as `GHESNotSupportedError: @actions/artifact
-26
View File
@@ -962,32 +962,6 @@ function MyChatPanel({ workspaceId, data }: Props) {
</div>
</div>
)}
{/* talk_to_user disabled banner — shown when the workspace has
talk_to_user_enabled=false. The agent cannot send canvas messages;
the user can re-enable the ability from here without opening settings. */}
{data.talkToUserEnabled === false && (
<div className="flex items-center gap-2 px-3 py-2 bg-surface-sunken border-b border-line/40 shrink-0">
<svg width="14" height="14" viewBox="0 0 16 16" fill="none" aria-hidden="true" className="shrink-0 text-ink-mid">
<path d="M8 1a7 7 0 1 0 0 14A7 7 0 0 0 8 1Zm0 10.5a.75.75 0 1 1 0-1.5.75.75 0 0 1 0 1.5ZM8 4a.75.75 0 0 1 .75.75v4a.75.75 0 0 1-1.5 0v-4A.75.75 0 0 1 8 4Z" fill="currentColor"/>
</svg>
<span className="text-[10px] text-ink-mid flex-1">
Agent is not enabled to chat with you.
</span>
<button
onClick={async () => {
try {
await api.patch(`/workspaces/${workspaceId}/abilities`, { talk_to_user_enabled: true });
useCanvasStore.getState().updateNodeData(workspaceId, { talkToUserEnabled: true });
} catch {
// ignore — user will see no change and can retry
}
}}
className="px-2 py-0.5 text-[10px] font-medium bg-accent/10 hover:bg-accent/20 text-accent rounded border border-accent/30 transition-colors shrink-0"
>
Enable
</button>
</div>
)}
{/* Messages */}
<div ref={containerRef} className="flex-1 overflow-y-auto p-3 space-y-3">
{loading && (
-4
View File
@@ -519,10 +519,6 @@ export function buildNodesAndEdges(
// #2054 — server-declared per-workspace provisioning timeout.
// Falls through to the runtime profile when null/absent.
provisionTimeoutMs: ws.provision_timeout_ms ?? null,
// Workspace abilities — defaults preserved for old platform versions
// that don't yet include these columns in the GET response.
broadcastEnabled: ws.broadcast_enabled ?? false,
talkToUserEnabled: ws.talk_to_user_enabled ?? true,
},
};
if (hasParent) {
-7
View File
@@ -99,13 +99,6 @@ export interface WorkspaceNodeData extends Record<string, unknown> {
* @/lib/runtimeProfiles. Lets a slow runtime declare its cold-boot
* expectation without a canvas release. */
provisionTimeoutMs?: number | null;
/** When true the workspace may POST /broadcast to send org-wide messages.
* Default false. Toggled by user/admin via PATCH /workspaces/:id/abilities. */
broadcastEnabled?: boolean;
/** When false the workspace cannot deliver canvas chat messages.
* send_message_to_user / POST /notify return 403 and the canvas
* shows a "not enabled" state with a button to re-enable. Default true. */
talkToUserEnabled?: boolean;
}
export type PanelTab = "details" | "skills" | "chat" | "terminal" | "config" | "schedule" | "channels" | "files" | "memory" | "traces" | "events" | "activity" | "audit";
-3
View File
@@ -299,9 +299,6 @@ export interface WorkspaceData {
* `@/lib/runtimeProfiles` when absent (the default behavior for any
* template that hasn't yet declared the field). */
provision_timeout_ms?: number | null;
/** Workspace ability flags (migration 20260514). */
broadcast_enabled?: boolean;
talk_to_user_enabled?: boolean;
}
let socket: ReconnectingSocket | null = null;
-296
View File
@@ -1,296 +0,0 @@
#!/usr/bin/env bash
# E2E test: workspace broadcast and talk-to-user platform abilities.
#
# What this proves:
# 1. talk_to_user_enabled (default true) — POST /notify works out-of-the-box.
# 2. PATCH /workspaces/:id/abilities { talk_to_user_enabled: false } disables
# delivery: /notify → 403 with error="talk_to_user_disabled" + delegate hint.
# 3. Re-enabling talk_to_user_enabled restores delivery.
# 4. broadcast_enabled (default false) — POST /broadcast → 403 when disabled.
# 5. PATCH { broadcast_enabled: true } enables fan-out.
# 6. POST /broadcast delivers to all non-sender, non-removed workspaces:
# - Returns {"status":"sent","delivered":N}
# - Receiver's activity log has a broadcast_receive entry with the message.
# - Sender's activity log has a broadcast_sent entry.
# 7. The sender itself does NOT receive a broadcast_receive entry.
#
# Usage: tests/e2e/test_workspace_abilities_e2e.sh
# Prereqs: workspace-server on http://localhost:8080, MOLECULE_ENV != production
set -euo pipefail
source "$(dirname "$0")/_lib.sh"
PASS=0
FAIL=0
SENDER_ID=""
RECEIVER_ID=""
cleanup() {
for wid in "$SENDER_ID" "$RECEIVER_ID"; do
if [ -n "$wid" ]; then
curl -s -X DELETE "$BASE/workspaces/$wid?confirm=true" > /dev/null || true
fi
done
}
trap cleanup EXIT INT TERM
assert() {
local label="$1" actual="$2" expected="$3"
if [ "$actual" = "$expected" ]; then
echo " PASS — $label"
PASS=$((PASS+1))
else
echo " FAIL — $label"
echo " expected: $expected"
echo " actual: $actual"
FAIL=$((FAIL+1))
fi
}
assert_contains() {
local label="$1" haystack="$2" needle="$3"
if echo "$haystack" | grep -qF "$needle"; then
echo " PASS — $label"
PASS=$((PASS+1))
else
echo " FAIL — $label"
echo " needle: $needle"
echo " haystack: $haystack"
FAIL=$((FAIL+1))
fi
}
assert_not_contains() {
local label="$1" haystack="$2" needle="$3"
if ! echo "$haystack" | grep -qF "$needle"; then
echo " PASS — $label"
PASS=$((PASS+1))
else
echo " FAIL — $label (unexpected match)"
echo " needle: $needle"
echo " haystack: $haystack"
FAIL=$((FAIL+1))
fi
}
# ── Pre-sweep: remove any stale leftover workspaces from a prior aborted run ──
echo "=== Setup ==="
for NAME in "Abilities Sender" "Abilities Receiver"; do
PRIOR=$(curl -s "$BASE/workspaces" | python3 -c "
import json, sys
try:
print(' '.join(w['id'] for w in json.load(sys.stdin) if w.get('name') == '$NAME'))
except Exception:
pass
")
for _wid in $PRIOR; do
echo "Sweeping leftover '$NAME' workspace: $_wid"
curl -s -X DELETE "$BASE/workspaces/$_wid?confirm=true" > /dev/null || true
done
done
R=$(curl -s -X POST "$BASE/workspaces" -H "Content-Type: application/json" \
-d '{"name":"Abilities Sender","tier":1}')
SENDER_ID=$(echo "$R" | python3 -c 'import json,sys;print(json.load(sys.stdin)["id"])' 2>/dev/null || true)
[ -n "$SENDER_ID" ] || { echo "Failed to create sender workspace: $R"; exit 1; }
echo "Created sender workspace: $SENDER_ID"
R=$(curl -s -X POST "$BASE/workspaces" -H "Content-Type: application/json" \
-d '{"name":"Abilities Receiver","tier":1}')
RECEIVER_ID=$(echo "$R" | python3 -c 'import json,sys;print(json.load(sys.stdin)["id"])' 2>/dev/null || true)
[ -n "$RECEIVER_ID" ] || { echo "Failed to create receiver workspace: $R"; exit 1; }
echo "Created receiver workspace: $RECEIVER_ID"
# Mint workspace-scoped bearer tokens (test-only endpoint, disabled in prod).
SENDER_TOKEN=$(e2e_mint_test_token "$SENDER_ID")
[ -n "$SENDER_TOKEN" ] || { echo "Failed to mint sender token"; exit 1; }
SENDER_AUTH="Authorization: Bearer $SENDER_TOKEN"
# Admin token — any live workspace bearer satisfies AdminAuth in local dev.
# In production-like envs, set MOLECULE_ADMIN_TOKEN.
ADMIN_TOKEN="${MOLECULE_ADMIN_TOKEN:-$SENDER_TOKEN}"
ADMIN_AUTH="Authorization: Bearer $ADMIN_TOKEN"
# ─────────────────────────────────────────────────────────────────────────────
echo ""
echo "=== Part 1: talk_to_user ability ==="
echo ""
echo "--- 1a: /notify works with default talk_to_user_enabled=true ---"
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
-d '{"message":"Hello from sender"}')
assert "POST /notify returns 200 when talk_to_user_enabled=true (default)" "$CODE" "200"
echo ""
echo "--- 1b: Disable talk_to_user ---"
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
-d '{"talk_to_user_enabled": false}')
assert "PATCH /abilities talk_to_user_enabled=false returns 200" "$CODE" "200"
# Verify the flag is reflected in the workspace GET response.
WS=$(curl -s "$BASE/workspaces/$SENDER_ID" -H "$SENDER_AUTH")
FLAG=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("talk_to_user_enabled","MISSING"))')
assert "GET /workspaces/:id reflects talk_to_user_enabled=false" "$FLAG" "False"
echo ""
echo "--- 1c: /notify blocked when talk_to_user disabled ---"
BODY=$(curl -s -w "" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
-d '{"message":"Should be blocked"}')
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
-d '{"message":"Should be blocked"}')
assert "POST /notify returns 403 when talk_to_user_enabled=false" "$CODE" "403"
ERR=$(echo "$BODY" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("error",""))' 2>/dev/null || echo "")
assert_contains "403 body contains talk_to_user_disabled error code" "$ERR" "talk_to_user_disabled"
HINT=$(echo "$BODY" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("hint",""))' 2>/dev/null || echo "")
assert_contains "403 body contains delegate_task hint" "$HINT" "delegate_task"
echo ""
echo "--- 1d: Re-enable talk_to_user and verify /notify works again ---"
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
-d '{"talk_to_user_enabled": true}')
assert "PATCH /abilities talk_to_user_enabled=true returns 200" "$CODE" "200"
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
-d '{"message":"Re-enabled, should work"}')
assert "POST /notify returns 200 after re-enabling talk_to_user" "$CODE" "200"
# ─────────────────────────────────────────────────────────────────────────────
echo ""
echo "=== Part 2: broadcast ability ==="
echo ""
echo "--- 2a: Broadcast blocked by default (broadcast_enabled=false) ---"
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/broadcast" \
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
-d '{"message":"Should be blocked"}')
assert "POST /broadcast returns 403 when broadcast_enabled=false (default)" "$CODE" "403"
echo ""
echo "--- 2b: Enable broadcast ---"
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
-d '{"broadcast_enabled": true}')
assert "PATCH /abilities broadcast_enabled=true returns 200" "$CODE" "200"
WS=$(curl -s "$BASE/workspaces/$SENDER_ID" -H "$SENDER_AUTH")
FLAG=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("broadcast_enabled","MISSING"))')
assert "GET /workspaces/:id reflects broadcast_enabled=true" "$FLAG" "True"
echo ""
echo "--- 2c: Successful broadcast fan-out ---"
BCAST=$(curl -s -X POST "$BASE/workspaces/$SENDER_ID/broadcast" \
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
-d '{"message":"Org-wide notice: scheduled maintenance in 5 minutes."}')
BSTATUS=$(echo "$BCAST" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("status",""))' 2>/dev/null || echo "")
BDELIVERED=$(echo "$BCAST" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("delivered","-1"))' 2>/dev/null || echo "-1")
assert "POST /broadcast returns status=sent" "$BSTATUS" "sent"
# delivered count must be >= 1 (the receiver workspace).
echo " INFO — broadcast delivered=$BDELIVERED"
if python3 -c "import sys; sys.exit(0 if int('$BDELIVERED') >= 1 else 1)" 2>/dev/null; then
echo " PASS — delivered count >= 1"
PASS=$((PASS+1))
else
echo " FAIL — expected delivered >= 1, got $BDELIVERED"
FAIL=$((FAIL+1))
fi
echo ""
echo "--- 2d: Receiver activity log has broadcast_receive entry ---"
RECEIVER_TOKEN=$(e2e_mint_test_token "$RECEIVER_ID")
[ -n "$RECEIVER_TOKEN" ] || { echo "Failed to mint receiver token"; exit 1; }
RECEIVER_AUTH="Authorization: Bearer $RECEIVER_TOKEN"
ACT=$(curl -s -H "$RECEIVER_AUTH" "$BASE/workspaces/$RECEIVER_ID/activity?source=agent&limit=20")
ROW=$(echo "$ACT" | python3 -c '
import json, sys
rows = json.load(sys.stdin) or []
for r in rows:
if r.get("activity_type") == "broadcast_receive":
print(json.dumps(r))
break
')
[ -n "$ROW" ] || {
echo " FAIL — could not find broadcast_receive row in receiver activity"
FAIL=$((FAIL+1))
}
if [ -n "$ROW" ]; then
# Message is stored in summary field.
MSG=$(echo "$ROW" | python3 -c 'import json,sys;r=json.load(sys.stdin);print(r.get("summary",""))')
assert_contains "broadcast_receive row summary has original message" "$MSG" "scheduled maintenance"
# Sender ID is stored in source_id field.
SRC=$(echo "$ROW" | python3 -c 'import json,sys;r=json.load(sys.stdin);print(r.get("source_id",""))')
assert "broadcast_receive row source_id is sender workspace" "$SRC" "$SENDER_ID"
fi
echo ""
echo "--- 2e: Sender activity log has broadcast_sent entry ---"
ACT_SENDER=$(curl -s -H "$SENDER_AUTH" "$BASE/workspaces/$SENDER_ID/activity?limit=20")
SENT_ROW=$(echo "$ACT_SENDER" | python3 -c '
import json, sys
rows = json.load(sys.stdin) or []
for r in rows:
if r.get("activity_type") == "broadcast_sent":
print(json.dumps(r))
break
')
[ -n "$SENT_ROW" ] || {
echo " FAIL — could not find broadcast_sent row in sender activity"
FAIL=$((FAIL+1))
}
if [ -n "$SENT_ROW" ]; then
# Delivered count is baked into the summary field (no response_body for sender row).
SUMMARY=$(echo "$SENT_ROW" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("summary",""))')
assert_contains "broadcast_sent summary mentions workspace count" "$SUMMARY" "workspace"
fi
echo ""
echo "--- 2f: Sender does NOT receive a broadcast_receive entry ---"
SELF_RECV=$(echo "$ACT_SENDER" | python3 -c '
import json, sys
rows = json.load(sys.stdin) or []
for r in rows:
if r.get("activity_type") == "broadcast_receive":
print("found")
break
')
assert_not_contains "sender has no broadcast_receive in own activity log" "${SELF_RECV:-}" "found"
# ─────────────────────────────────────────────────────────────────────────────
echo ""
echo "--- 2g: Empty message is rejected ---"
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/broadcast" \
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
-d '{"message":""}')
assert "POST /broadcast with empty message returns 400" "$CODE" "400"
echo ""
echo "--- 2h: Partial PATCH does not clobber other flags ---"
# Set talk_to_user=false, then patch only broadcast — talk_to_user must stay false.
curl -s -o /dev/null -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
-d '{"talk_to_user_enabled": false}'
curl -s -o /dev/null -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
-d '{"broadcast_enabled": false}'
WS=$(curl -s "$BASE/workspaces/$SENDER_ID" -H "$SENDER_AUTH")
TUF=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("talk_to_user_enabled","MISSING"))')
BEF=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("broadcast_enabled","MISSING"))')
assert "partial PATCH preserves talk_to_user_enabled=false" "$TUF" "False"
assert "partial PATCH sets broadcast_enabled=false" "$BEF" "False"
# ─────────────────────────────────────────────────────────────────────────────
echo ""
echo "=== Results: $PASS passed, $FAIL failed ==="
[ "$FAIL" -eq 0 ]
@@ -645,7 +645,7 @@ func (h *WorkspaceHandler) resolveAgentURL(ctx context.Context, workspaceID stri
// the caller can retry once the workspace is back online (~10s).
if status == "hibernated" {
log.Printf("ProxyA2A: waking hibernated workspace %s", workspaceID)
h.goAsync(func() { h.RestartByID(workspaceID) })
go h.RestartByID(workspaceID)
return "", &proxyA2AError{
Status: http.StatusServiceUnavailable,
Headers: map[string]string{"Retry-After": "15"},
@@ -1,408 +0,0 @@
package handlers
import (
"database/sql"
"net/http"
"net/http/httptest"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
"github.com/gin-gonic/gin"
)
// setupQueueStatusHandlerDB creates a sqlmock DB with QueryMatcherEqual for exact SQL string matching.
func setupQueueStatusHandlerDB(t *testing.T) sqlmock.Sqlmock {
t.Helper()
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
if err != nil {
t.Fatalf("sqlmock.New: %v", err)
}
prevDB := db.DB
db.DB = mockDB
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
return mock
}
// Exact SQL strings used by the production code.
const (
sqlQueueRowAuthFields = `SELECT caller_id, workspace_id FROM a2a_queue WHERE id = $1`
sqlQueueStatusByID = `
SELECT
q.id,
q.workspace_id,
q.status,
q.priority,
q.attempts,
q.last_error,
q.enqueued_at::text,
q.dispatched_at::text,
q.completed_at::text,
q.expires_at::text,
al.response_body::text
FROM a2a_queue q
LEFT JOIN activity_logs al
ON al.method = 'delegate_result'
AND al.target_id = q.workspace_id
AND al.workspace_id = q.caller_id
AND al.response_body->>'delegation_id' = (q.body->'params'->'message'->'metadata'->>'delegation_id')
WHERE q.id = $1`
)
// ── GetA2AQueueStatus HTTP handler tests ──────────────────────────────────────
// TestGetA2AQueueStatus_QueueIDEmpty_Returns400 exercises the handler directly
// (not via router) so we can verify the empty-value branch without relying on
// Gin route-matching behaviour.
func TestGetA2AQueueStatus_QueueIDEmpty_Returns400(t *testing.T) {
gin.SetMode(gin.TestMode)
setupQueueStatusHandlerDB(t)
h := &WorkspaceHandler{}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}}
// queue_id param is empty string
c.Params = gin.Params{
{Key: "id", Value: "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
{Key: "queue_id", Value: ""},
}
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h.GetA2AQueueStatus(c)
if w.Code != http.StatusBadRequest {
t.Errorf("got %d, want 400", w.Code)
}
}
func TestGetA2AQueueStatus_NoIdentity_NoOrgToken_Returns404(t *testing.T) {
gin.SetMode(gin.TestMode)
setupQueueStatusHandlerDB(t)
h := &WorkspaceHandler{}
r := gin.New()
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
req := httptest.NewRequest(http.MethodGet,
"/workspaces/wsid/a2a/queue/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
// No identity derivable → 404 (not 401) per existence-non-inference policy.
if w.Code != http.StatusNotFound {
t.Errorf("got %d, want 404", w.Code)
}
}
func TestGetA2AQueueStatus_OrgToken_SkipsCallerCheck(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupQueueStatusHandlerDB(t)
h := &WorkspaceHandler{}
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
AddRow("other-ws", wsID)
mock.ExpectQuery(sqlQueueRowAuthFields).
WithArgs(queueID).
WillReturnRows(authRows)
statusRows := sqlmock.NewRows([]string{
"id", "workspace_id", "status", "priority", "attempts",
"last_error", "enqueued_at", "dispatched_at", "completed_at", "expires_at",
"response_body",
}).AddRow(
queueID, wsID, "queued", 50, 0,
nil, "2026-01-01T00:00:00Z", nil, nil, nil, nil,
)
mock.ExpectQuery(sqlQueueStatusByID).
WithArgs(queueID).
WillReturnRows(statusRows)
r := gin.New()
// Simulate org-token middleware setting org_token_id.
r.GET("/workspaces/:id/a2a/queue/:queue_id", func(c *gin.Context) {
c.Set("org_token_id", "org-admin")
h.GetA2AQueueStatus(c)
})
req := httptest.NewRequest(http.MethodGet, "/workspaces/wsid/a2a/queue/"+queueID, nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestGetA2AQueueStatus_CallerWorkspaceMatchesCallerID_Success(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupQueueStatusHandlerDB(t)
h := &WorkspaceHandler{}
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
AddRow(callerID, wsID)
mock.ExpectQuery(sqlQueueRowAuthFields).
WithArgs(queueID).
WillReturnRows(authRows)
statusRows := sqlmock.NewRows([]string{
"id", "workspace_id", "status", "priority", "attempts",
"last_error", "enqueued_at", "dispatched_at", "completed_at", "expires_at",
"response_body",
}).AddRow(
queueID, wsID, "completed", 50, 1,
nil, "2026-01-01T00:00:00Z", "2026-01-01T00:01:00Z", "2026-01-01T00:02:00Z",
nil, []byte(`{"text":"result"}`),
)
mock.ExpectQuery(sqlQueueStatusByID).
WithArgs(queueID).
WillReturnRows(statusRows)
r := gin.New()
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
req := httptest.NewRequest(http.MethodGet,
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
req.Header.Set("X-Workspace-ID", callerID)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestGetA2AQueueStatus_CallerWorkspaceMatchesWorkspaceID_Success(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupQueueStatusHandlerDB(t)
h := &WorkspaceHandler{}
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
AddRow(callerID, wsID)
mock.ExpectQuery(sqlQueueRowAuthFields).
WithArgs(queueID).
WillReturnRows(authRows)
statusRows := sqlmock.NewRows([]string{
"id", "workspace_id", "status", "priority", "attempts",
"last_error", "enqueued_at", "dispatched_at", "completed_at", "expires_at",
"response_body",
}).AddRow(
queueID, wsID, "queued", 50, 0,
nil, "2026-01-01T00:00:00Z", nil, nil, nil, nil,
)
mock.ExpectQuery(sqlQueueStatusByID).
WithArgs(queueID).
WillReturnRows(statusRows)
r := gin.New()
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
req := httptest.NewRequest(http.MethodGet,
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
req.Header.Set("X-Workspace-ID", wsID)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestGetA2AQueueStatus_QueueNotFound_Returns404(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupQueueStatusHandlerDB(t)
h := &WorkspaceHandler{}
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
mock.ExpectQuery(sqlQueueRowAuthFields).
WithArgs(queueID).
WillReturnError(sql.ErrNoRows)
r := gin.New()
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
req := httptest.NewRequest(http.MethodGet,
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
req.Header.Set("X-Workspace-ID", callerID)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("got %d, want 404: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestGetA2AQueueStatus_QueueAuthFieldsDBError_Returns500(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupQueueStatusHandlerDB(t)
h := &WorkspaceHandler{}
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
mock.ExpectQuery(sqlQueueRowAuthFields).
WithArgs(queueID).
WillReturnError(sql.ErrConnDone)
r := gin.New()
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
req := httptest.NewRequest(http.MethodGet,
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
req.Header.Set("X-Workspace-ID", callerID)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("got %d, want 500: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestGetA2AQueueStatus_WrongCallerWorkspace_Returns404(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupQueueStatusHandlerDB(t)
h := &WorkspaceHandler{}
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
wrongCaller := "dddddddd-dddd-dddd-dddd-dddddddddddd"
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
AddRow(callerID, wsID)
mock.ExpectQuery(sqlQueueRowAuthFields).
WithArgs(queueID).
WillReturnRows(authRows)
r := gin.New()
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
req := httptest.NewRequest(http.MethodGet,
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
req.Header.Set("X-Workspace-ID", wrongCaller)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("got %d, want 404: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestGetA2AQueueStatus_StatusFetchDBError_Returns500(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupQueueStatusHandlerDB(t)
h := &WorkspaceHandler{}
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
AddRow(callerID, wsID)
mock.ExpectQuery(sqlQueueRowAuthFields).
WithArgs(queueID).
WillReturnRows(authRows)
mock.ExpectQuery(sqlQueueStatusByID).
WithArgs(queueID).
WillReturnError(sql.ErrConnDone)
r := gin.New()
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
req := httptest.NewRequest(http.MethodGet,
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
req.Header.Set("X-Workspace-ID", callerID)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("got %d, want 500: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestGetA2AQueueStatus_FullHappyPath_ReturnsJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupQueueStatusHandlerDB(t)
h := &WorkspaceHandler{}
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
AddRow(callerID, wsID)
mock.ExpectQuery(sqlQueueRowAuthFields).
WithArgs(queueID).
WillReturnRows(authRows)
respBody := []byte(`{"text":"delegation result"}`)
statusRows := sqlmock.NewRows([]string{
"id", "workspace_id", "status", "priority", "attempts",
"last_error", "enqueued_at", "dispatched_at", "completed_at", "expires_at",
"response_body",
}).AddRow(
queueID, wsID, "completed", 50, 1,
nil, "2026-01-01T00:00:00Z", "2026-01-01T00:01:00Z", "2026-01-01T00:02:00Z",
nil, respBody,
)
mock.ExpectQuery(sqlQueueStatusByID).
WithArgs(queueID).
WillReturnRows(statusRows)
r := gin.New()
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
req := httptest.NewRequest(http.MethodGet,
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
req.Header.Set("X-Workspace-ID", wsID)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
}
if w.Body.Len() == 0 {
t.Error("response body is empty")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
@@ -482,13 +482,6 @@ func (h *ActivityHandler) Notify(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
return
}
if errors.Is(err, ErrTalkToUserDisabled) {
c.JSON(http.StatusForbidden, gin.H{
"error": "talk_to_user_disabled",
"hint": "This workspace is not allowed to send messages directly to the user. Forward your update to a parent workspace using delegate_task — they may be able to reach the user.",
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
return
}
@@ -464,9 +464,9 @@ func TestNotify_PersistsToActivityLogsForReloadRecovery(t *testing.T) {
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
// Workspace existence check
mock.ExpectQuery(`SELECT name, talk_to_user_enabled FROM workspaces`).
mock.ExpectQuery(`SELECT name FROM workspaces`).
WithArgs("ws-notify").
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("DD", true))
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("DD"))
// Persistence INSERT — verify shape
mock.ExpectExec(`INSERT INTO activity_logs`).
@@ -511,9 +511,9 @@ func TestNotify_WithAttachments_PersistsFilePartsForReload(t *testing.T) {
db.DB = mockDB
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
mock.ExpectQuery(`SELECT name, talk_to_user_enabled FROM workspaces`).
mock.ExpectQuery(`SELECT name FROM workspaces`).
WithArgs("ws-attach").
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("DD", true))
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("DD"))
// Capture the JSONB arg so we can assert on the persisted shape
// AFTER the call (must include parts[].kind=file so reload
@@ -640,9 +640,9 @@ func TestNotify_DBFailure_StillBroadcastsAnd200(t *testing.T) {
db.DB = mockDB
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
mock.ExpectQuery(`SELECT name, talk_to_user_enabled FROM workspaces`).
mock.ExpectQuery(`SELECT name FROM workspaces`).
WithArgs("ws-x").
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("DD", true))
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("DD"))
mock.ExpectExec(`INSERT INTO activity_logs`).
WillReturnError(fmt.Errorf("simulated db hiccup"))
@@ -54,11 +54,6 @@ import (
// timeout) surface as wrapped errors and should be treated as 503.
var ErrWorkspaceNotFound = errors.New("agent_message: workspace not found")
// ErrTalkToUserDisabled is returned when the workspace has
// talk_to_user_enabled=false. Callers surface HTTP 403 so the Python tool
// can detect it and suggest forwarding to a parent workspace.
var ErrTalkToUserDisabled = errors.New("agent_message: talk_to_user disabled")
// AgentMessageAttachment is one file attached to an agent → user
// message. Identical to handlers.NotifyAttachment in field set; kept
// distinct so the writer's API doesn't import a handler type with HTTP
@@ -112,20 +107,16 @@ func (w *AgentMessageWriter) Send(
// notify call surfaced as "workspace not found" and masked real
// incidents in the alert path.
var wsName string
var talkToUserEnabled bool
err := w.db.QueryRowContext(ctx,
`SELECT name, talk_to_user_enabled FROM workspaces WHERE id = $1 AND status != 'removed'`,
`SELECT name FROM workspaces WHERE id = $1 AND status != 'removed'`,
workspaceID,
).Scan(&wsName, &talkToUserEnabled)
).Scan(&wsName)
if errors.Is(err, sql.ErrNoRows) {
return ErrWorkspaceNotFound
}
if err != nil {
return fmt.Errorf("agent_message: workspace lookup: %w", err)
}
if !talkToUserEnabled {
return ErrTalkToUserDisabled
}
// 2. Build broadcast payload + WS-emit. Same shape that ChatTab's
// AGENT_MESSAGE handler in canvas/src/store/canvas-events.ts has
@@ -88,9 +88,9 @@ func TestAgentMessageWriter_Send_Success_NoAttachments(t *testing.T) {
mock := setupTestDB(t)
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
mock.ExpectQuery("SELECT name FROM workspaces").
WithArgs("ws-1").
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
WithArgs(
@@ -116,9 +116,9 @@ func TestAgentMessageWriter_Send_Success_WithAttachments(t *testing.T) {
mock := setupTestDB(t)
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
mock.ExpectQuery("SELECT name FROM workspaces").
WithArgs("ws-att").
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("Ryan", true))
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("Ryan"))
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
WithArgs(
@@ -173,9 +173,9 @@ func TestAgentMessageWriter_Send_WorkspaceNotFound(t *testing.T) {
emitter := &capturingEmitter{}
w := NewAgentMessageWriter(db.DB, emitter)
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
mock.ExpectQuery("SELECT name FROM workspaces").
WithArgs("ws-missing").
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}))
WillReturnRows(sqlmock.NewRows([]string{"name"}))
err := w.Send(context.Background(), "ws-missing", "lost in the void", nil)
if !errors.Is(err, ErrWorkspaceNotFound) {
@@ -202,9 +202,9 @@ func TestAgentMessageWriter_Send_DBInsertFailureStillReturnsNil(t *testing.T) {
mock := setupTestDB(t)
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
mock.ExpectQuery("SELECT name FROM workspaces").
WithArgs("ws-dbfail").
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
mock.ExpectExec(`INSERT INTO activity_logs`).
WillReturnError(errors.New("transient db error"))
@@ -223,9 +223,9 @@ func TestAgentMessageWriter_Send_PreviewTruncation(t *testing.T) {
mock := setupTestDB(t)
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
mock.ExpectQuery("SELECT name FROM workspaces").
WithArgs("ws-trunc").
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("Ryan", true))
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("Ryan"))
longMsg := strings.Repeat("x", 200)
mock.ExpectExec(`INSERT INTO activity_logs`).
@@ -263,9 +263,9 @@ func TestAgentMessageWriter_Send_BroadcastsAgentMessageEvent(t *testing.T) {
emitter := &capturingEmitter{}
w := NewAgentMessageWriter(db.DB, emitter)
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
mock.ExpectQuery("SELECT name FROM workspaces").
WithArgs("ws-bc").
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("Workspace Name", true))
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("Workspace Name"))
mock.ExpectExec(`INSERT INTO activity_logs`).
WillReturnResult(sqlmock.NewResult(1, 1))
@@ -315,7 +315,7 @@ func TestAgentMessageWriter_Send_DBErrorOnLookupReturnsWrapped(t *testing.T) {
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
transientErr := errors.New("connection refused")
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
mock.ExpectQuery("SELECT name FROM workspaces").
WithArgs("ws-dbdown").
WillReturnError(transientErr)
@@ -350,9 +350,9 @@ func TestAgentMessageWriter_Send_NonASCIIMessagePersists(t *testing.T) {
// the byte-slice bug.
msg := strings.Repeat("你", 200)
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
mock.ExpectQuery("SELECT name FROM workspaces").
WithArgs("ws-cjk").
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
mock.ExpectExec(`INSERT INTO activity_logs`).
WithArgs(
@@ -395,9 +395,9 @@ func TestAgentMessageWriter_Send_OmitsAttachmentsKeyWhenEmpty(t *testing.T) {
emitter := &capturingEmitter{}
w := NewAgentMessageWriter(db.DB, emitter)
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
mock.ExpectQuery("SELECT name FROM workspaces").
WithArgs("ws-noatt").
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("X", true))
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("X"))
mock.ExpectExec(`INSERT INTO activity_logs`).
WillReturnResult(sqlmock.NewResult(1, 1))
@@ -116,9 +116,6 @@ func (h *ApprovalsHandler) ListAll(c *gin.Context) {
"created_at": createdAt,
})
}
if err := rows.Err(); err != nil {
log.Printf("ListPendingApprovals rows.Err: %v", err)
}
c.JSON(http.StatusOK, approvals)
}
@@ -158,9 +155,6 @@ func (h *ApprovalsHandler) List(c *gin.Context) {
"created_at": createdAt,
})
}
if err := rows.Err(); err != nil {
log.Printf("ListApprovals rows.Err workspace=%s: %v", workspaceID, err)
}
c.JSON(http.StatusOK, approvals)
}
@@ -230,21 +230,20 @@ func TestWorkspaceList_WithData(t *testing.T) {
broadcaster := newTestBroadcaster()
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
// 23 cols — broadcast_enabled + talk_to_user_enabled added after monthly_spend
// (migration 20260514). Column order must match scanWorkspaceRow exactly.
// 21 cols — see scanWorkspaceRow for order (max_concurrent_tasks
// lands between active_tasks and last_error_rate).
columns := []string{
"id", "name", "role", "tier", "status", "agent_card", "url",
"parent_id", "active_tasks", "max_concurrent_tasks",
"last_error_rate", "last_sample_error",
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
"budget_limit", "monthly_spend",
"broadcast_enabled", "talk_to_user_enabled",
}
rows := sqlmock.NewRows(columns).
AddRow("ws-1", "Agent One", "worker", 1, "online", []byte(`{"name":"agent1"}`), "http://localhost:8001",
nil, 3, 1, 0.02, "", 7200, "processing", "langgraph", "", 10.0, 20.0, false, nil, int64(0), false, true).
nil, 3, 1, 0.02, "", 7200, "processing", "langgraph", "", 10.0, 20.0, false, nil, int64(0)).
AddRow("ws-2", "Agent Two", "", 2, "degraded", []byte("null"), "",
nil, 0, 1, 0.6, "timeout", 100, "", "claude-code", "", 50.0, 60.0, true, nil, int64(0), false, true)
nil, 0, 1, 0.6, "timeout", 100, "", "claude-code", "", 50.0, 60.0, true, nil, int64(0))
mock.ExpectQuery("SELECT w.id, w.name").
WillReturnRows(rows)
@@ -392,21 +392,21 @@ func TestWorkspaceList(t *testing.T) {
broadcaster := newTestBroadcaster()
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", "/tmp/configs")
// 23 cols: broadcast_enabled + talk_to_user_enabled added after monthly_spend
// (migration 20260514). Column order must match scanWorkspaceRow exactly.
// 21 cols: `max_concurrent_tasks` added between active_tasks and
// last_error_rate (see scanWorkspaceRow + COALESCE(w.max_concurrent_tasks, 1)
// in workspace.go). Column order must match that scan exactly.
columns := []string{
"id", "name", "role", "tier", "status", "agent_card", "url",
"parent_id", "active_tasks", "max_concurrent_tasks",
"last_error_rate", "last_sample_error",
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
"budget_limit", "monthly_spend",
"broadcast_enabled", "talk_to_user_enabled",
}
rows := sqlmock.NewRows(columns).
AddRow("ws-1", "Agent One", "worker", 1, "online", []byte("null"), "http://localhost:8001",
nil, 0, 1, 0.0, "", 100, "", "claude-code", "", 10.0, 20.0, false, nil, int64(0), false, true).
nil, 0, 1, 0.0, "", 100, "", "claude-code", "", 10.0, 20.0, false, nil, int64(0)).
AddRow("ws-2", "Agent Two", "manager", 2, "provisioning", []byte("null"), "",
nil, 0, 1, 0.0, "", 0, "", "langgraph", "", 50.0, 60.0, false, nil, int64(0), false, true)
nil, 0, 1, 0.0, "", 0, "", "langgraph", "", 50.0, 60.0, false, nil, int64(0))
mock.ExpectQuery("SELECT w.id, w.name").
WillReturnRows(rows)
@@ -1120,14 +1120,13 @@ func TestWorkspaceGet_CurrentTask(t *testing.T) {
"parent_id", "active_tasks", "max_concurrent_tasks", "last_error_rate", "last_sample_error",
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
"budget_limit", "monthly_spend",
"broadcast_enabled", "talk_to_user_enabled",
}
mock.ExpectQuery("SELECT w.id, w.name").
WithArgs("dddddddd-0004-0000-0000-000000000000").
WillReturnRows(sqlmock.NewRows(columns).AddRow(
"dddddddd-0004-0000-0000-000000000000", "Task Worker", "worker", 1, "online", []byte("null"), "http://localhost:9000",
nil, 2, 1, 0.0, "", 300, "Analyzing document", "langgraph", "", 10.0, 20.0, false,
nil, int64(0), false, true,
nil, int64(0),
))
w := httptest.NewRecorder()
@@ -248,9 +248,6 @@ func (h *InstructionsHandler) Resolve(c *gin.Context) {
b.WriteString(content)
b.WriteString("\n\n")
}
if err := rows.Err(); err != nil {
log.Printf("ResolveInstructions rows.Err workspace=%s: %v", workspaceID, err)
}
c.JSON(http.StatusOK, gin.H{
"workspace_id": workspaceID,
@@ -261,7 +258,6 @@ func (h *InstructionsHandler) Resolve(c *gin.Context) {
func scanInstructions(rows interface {
Next() bool
Scan(dest ...interface{}) error
Err() error
}) []Instruction {
var instructions []Instruction
for rows.Next() {
@@ -273,9 +269,6 @@ func scanInstructions(rows interface {
}
instructions = append(instructions, inst)
}
if err := rows.Err(); err != nil {
log.Printf("scanInstructions rows.Err: %v", err)
}
if instructions == nil {
instructions = []Instruction{}
}
@@ -751,9 +751,9 @@ func TestMCPHandler_SendMessageToUser_DBErrorLogsAndStill200s(t *testing.T) {
t.Setenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE", "true")
h, mock := newMCPHandler(t)
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
mock.ExpectQuery("SELECT name FROM workspaces").
WithArgs("ws-err").
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
// INSERT fails — must NOT abort the tool response.
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
@@ -802,9 +802,9 @@ func TestMCPHandler_SendMessageToUser_ResponseBodyShape(t *testing.T) {
const userMessage = "Hi there from the agent"
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
mock.ExpectQuery("SELECT name FROM workspaces").
WithArgs("ws-shape").
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
// Capture the response_body argument and assert its exact shape.
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
@@ -861,9 +861,9 @@ func TestMCPHandler_SendMessageToUser_PersistsToActivityLog(t *testing.T) {
// before it does anything else. Returning a name lets the
// broadcast payload populate; the test doesn't assert on the
// broadcast (no observable WS in this fake), only on the DB.
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
mock.ExpectQuery("SELECT name FROM workspaces").
WithArgs("ws-msg").
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
// The persistence INSERT — pin the exact shape so a future
// refactor that switches columns or drops `method='notify'`
@@ -287,7 +287,7 @@ func TestRenderCategoryRoutingYAML_StableOrdering(t *testing.T) {
if ai <= 0 || zi <= 0 || mi <= 0 {
t.Fatalf("could not locate all keys in output: %s", out)
}
if ai >= mi || mi >= zi {
if !(ai < mi && mi < zi) {
t.Errorf("keys not sorted: alpha=%d middle=%d zebra=%d, output:\n%s", ai, mi, zi, out)
}
}
@@ -67,9 +67,6 @@ func (h *TokenHandler) List(c *gin.Context) {
}
tokens = append(tokens, t)
}
if err := rows.Err(); err != nil {
log.Printf("ListTokens rows.Err workspace=%s: %v", workspaceID, err)
}
c.JSON(http.StatusOK, gin.H{
"tokens": tokens,
@@ -15,7 +15,6 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/crypto"
@@ -74,22 +73,6 @@ type WorkspaceHandler struct {
// memory plugin). main.go sets this to plugin.DeleteNamespace
// when MEMORY_PLUGIN_URL is configured.
namespaceCleanupFn func(ctx context.Context, workspaceID string)
// asyncWG tracks goroutines launched by goAsync so tests can wait
// for async DB users (restart, provision) before asserting results.
// Matches the pattern from main commit 1c3b4ff3.
asyncWG sync.WaitGroup
}
func (h *WorkspaceHandler) goAsync(fn func()) {
h.asyncWG.Add(1)
go func() {
defer h.asyncWG.Done()
fn()
}()
}
func (h *WorkspaceHandler) waitAsyncForTest() {
h.asyncWG.Wait()
}
func NewWorkspaceHandler(b events.EventEmitter, p *provisioner.Provisioner, platformURL, configsDir string) *WorkspaceHandler {
@@ -595,7 +578,7 @@ func scanWorkspaceRow(rows interface {
var id, name, role, status, url, sampleError, currentTask, runtime, workspaceDir string
var tier, activeTasks, maxConcurrentTasks, uptimeSeconds int
var errorRate, x, y float64
var collapsed, broadcastEnabled, talkToUserEnabled bool
var collapsed bool
var parentID *string
var agentCard []byte
var budgetLimit sql.NullInt64
@@ -604,7 +587,7 @@ func scanWorkspaceRow(rows interface {
err := rows.Scan(&id, &name, &role, &tier, &status, &agentCard, &url,
&parentID, &activeTasks, &maxConcurrentTasks, &errorRate, &sampleError, &uptimeSeconds,
&currentTask, &runtime, &workspaceDir, &x, &y, &collapsed,
&budgetLimit, &monthlySpend, &broadcastEnabled, &talkToUserEnabled)
&budgetLimit, &monthlySpend)
if err != nil {
return nil, err
}
@@ -628,8 +611,6 @@ func scanWorkspaceRow(rows interface {
"x": x,
"y": y,
"collapsed": collapsed,
"broadcast_enabled": broadcastEnabled,
"talk_to_user_enabled": talkToUserEnabled,
}
// budget_limit: nil when no limit set, int64 otherwise
@@ -665,8 +646,7 @@ const workspaceListQuery = `
COALESCE(w.current_task, ''), COALESCE(w.runtime, 'langgraph'),
COALESCE(w.workspace_dir, ''),
COALESCE(cl.x, 0), COALESCE(cl.y, 0), COALESCE(cl.collapsed, false),
w.budget_limit, COALESCE(w.monthly_spend, 0),
w.broadcast_enabled, w.talk_to_user_enabled
w.budget_limit, COALESCE(w.monthly_spend, 0)
FROM workspaces w
LEFT JOIN canvas_layouts cl ON cl.workspace_id = w.id
WHERE w.status != 'removed'
@@ -726,8 +706,7 @@ func (h *WorkspaceHandler) Get(c *gin.Context) {
COALESCE(w.current_task, ''), COALESCE(w.runtime, 'langgraph'),
COALESCE(w.workspace_dir, ''),
COALESCE(cl.x, 0), COALESCE(cl.y, 0), COALESCE(cl.collapsed, false),
w.budget_limit, COALESCE(w.monthly_spend, 0),
w.broadcast_enabled, w.talk_to_user_enabled
w.budget_limit, COALESCE(w.monthly_spend, 0)
FROM workspaces w
LEFT JOIN canvas_layouts cl ON cl.workspace_id = w.id
WHERE w.id = $1
@@ -1,82 +0,0 @@
package handlers
// workspace_abilities.go — PATCH /workspaces/:id/abilities
//
// Allows users and admin agents to toggle two workspace-level ability flags:
//
// broadcast_enabled — workspace may POST /broadcast to send org-wide messages
// talk_to_user_enabled — workspace may deliver canvas chat messages via
// send_message_to_user / POST /notify
//
// Gated behind AdminAuth so workspace agents cannot self-modify their own
// ability flags (that would let any agent grant itself broadcast rights or
// suppress its own chat-silence constraint).
import (
"log"
"net/http"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
"github.com/gin-gonic/gin"
)
// AbilitiesPayload carries the subset of ability flags the caller wants to
// update. Fields are pointers so that the handler can distinguish "caller
// supplied false" from "caller omitted the field" (omitempty semantics).
type AbilitiesPayload struct {
BroadcastEnabled *bool `json:"broadcast_enabled"`
TalkToUserEnabled *bool `json:"talk_to_user_enabled"`
}
// PatchAbilities handles PATCH /workspaces/:id/abilities (AdminAuth).
func PatchAbilities(c *gin.Context) {
id := c.Param("id")
if err := validateWorkspaceID(id); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace ID"})
return
}
var body AbilitiesPayload
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
}
if body.BroadcastEnabled == nil && body.TalkToUserEnabled == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "at least one ability field required"})
return
}
ctx := c.Request.Context()
var exists bool
if err := db.DB.QueryRowContext(ctx,
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1 AND status != 'removed')`, id,
).Scan(&exists); err != nil || !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
return
}
if body.BroadcastEnabled != nil {
if _, err := db.DB.ExecContext(ctx,
`UPDATE workspaces SET broadcast_enabled = $2, updated_at = now() WHERE id = $1`,
id, *body.BroadcastEnabled,
); err != nil {
log.Printf("PatchAbilities broadcast_enabled for %s: %v", id, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "update failed"})
return
}
}
if body.TalkToUserEnabled != nil {
if _, err := db.DB.ExecContext(ctx,
`UPDATE workspaces SET talk_to_user_enabled = $2, updated_at = now() WHERE id = $1`,
id, *body.TalkToUserEnabled,
); err != nil {
log.Printf("PatchAbilities talk_to_user_enabled for %s: %v", id, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "update failed"})
return
}
}
c.JSON(http.StatusOK, gin.H{"status": "updated"})
}
@@ -1,317 +0,0 @@
package handlers
import (
"database/sql"
"net/http"
"net/http/httptest"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
"github.com/gin-gonic/gin"
)
// setupAbilitiesDB creates a sqlmock DB with QueryMatcherEqual for exact SQL matching.
func setupAbilitiesDB(t *testing.T) sqlmock.Sqlmock {
t.Helper()
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
if err != nil {
t.Fatalf("sqlmock.New: %v", err)
}
prevDB := db.DB
db.DB = mockDB
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
return mock
}
// Exact SQL strings used by the production handler.
const (
sqlPatchAbilitiesExists = `SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1 AND status != 'removed')`
sqlPatchBroadcastEnabled = `UPDATE workspaces SET broadcast_enabled = $2, updated_at = now() WHERE id = $1`
sqlPatchTalkToUserEnabled = `UPDATE workspaces SET talk_to_user_enabled = $2, updated_at = now() WHERE id = $1`
)
// ── PatchAbilities HTTP handler tests ──────────────────────────────────────────
func TestPatchAbilities_InvalidWorkspaceID_Returns400(t *testing.T) {
gin.SetMode(gin.TestMode)
setupAbilitiesDB(t)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "not-a-uuid"}}
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/not-a-uuid/abilities", nil)
PatchAbilities(c)
if w.Code != http.StatusBadRequest {
t.Errorf("got %d, want 400", w.Code)
}
}
func TestPatchAbilities_InvalidBody_Returns400(t *testing.T) {
gin.SetMode(gin.TestMode)
setupAbilitiesDB(t)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}}
c.Request = httptest.NewRequest(http.MethodPatch,
"/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/abilities",
newFakeCloser([]byte("not json")))
c.Request.Header.Set("Content-Type", "application/json")
PatchAbilities(c)
if w.Code != http.StatusBadRequest {
t.Errorf("got %d, want 400", w.Code)
}
}
func TestPatchAbilities_NoAbilityFields_Returns400(t *testing.T) {
gin.SetMode(gin.TestMode)
setupAbilitiesDB(t)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}}
c.Request = httptest.NewRequest(http.MethodPatch,
"/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/abilities",
newFakeCloser([]byte(`{}`)))
c.Request.Header.Set("Content-Type", "application/json")
PatchAbilities(c)
if w.Code != http.StatusBadRequest {
t.Errorf("got %d, want 400", w.Code)
}
}
func TestPatchAbilities_WorkspaceNotFound_Returns404(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupAbilitiesDB(t)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(sqlPatchAbilitiesExists).
WithArgs(wsID).
WillReturnError(sql.ErrNoRows)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: wsID}}
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
newFakeCloser([]byte(`{"broadcast_enabled":true}`)))
c.Request.Header.Set("Content-Type", "application/json")
PatchAbilities(c)
if w.Code != http.StatusNotFound {
t.Errorf("got %d, want 404: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestPatchAbilities_WorkspaceNotFound_ExistsFalse_Returns404(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupAbilitiesDB(t)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(sqlPatchAbilitiesExists).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: wsID}}
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
newFakeCloser([]byte(`{"talk_to_user_enabled":false}`)))
c.Request.Header.Set("Content-Type", "application/json")
PatchAbilities(c)
if w.Code != http.StatusNotFound {
t.Errorf("got %d, want 404: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestPatchAbilities_UpdateBroadcastEnabled_Success(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupAbilitiesDB(t)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(sqlPatchAbilitiesExists).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectExec(sqlPatchBroadcastEnabled).
WithArgs(wsID, true).
WillReturnResult(sqlmock.NewResult(0, 1))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: wsID}}
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
newFakeCloser([]byte(`{"broadcast_enabled":true}`)))
c.Request.Header.Set("Content-Type", "application/json")
PatchAbilities(c)
if w.Code != http.StatusOK {
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestPatchAbilities_UpdateTalkToUserEnabled_Success(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupAbilitiesDB(t)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(sqlPatchAbilitiesExists).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectExec(sqlPatchTalkToUserEnabled).
WithArgs(wsID, false).
WillReturnResult(sqlmock.NewResult(0, 1))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: wsID}}
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
newFakeCloser([]byte(`{"talk_to_user_enabled":false}`)))
c.Request.Header.Set("Content-Type", "application/json")
PatchAbilities(c)
if w.Code != http.StatusOK {
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestPatchAbilities_UpdateBothAbilities_Success(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupAbilitiesDB(t)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(sqlPatchAbilitiesExists).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectExec(sqlPatchBroadcastEnabled).
WithArgs(wsID, true).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(sqlPatchTalkToUserEnabled).
WithArgs(wsID, false).
WillReturnResult(sqlmock.NewResult(0, 1))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: wsID}}
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
newFakeCloser([]byte(`{"broadcast_enabled":true,"talk_to_user_enabled":false}`)))
c.Request.Header.Set("Content-Type", "application/json")
PatchAbilities(c)
if w.Code != http.StatusOK {
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestPatchAbilities_BroadcastEnabledDBError_Returns500(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupAbilitiesDB(t)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(sqlPatchAbilitiesExists).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectExec(sqlPatchBroadcastEnabled).
WithArgs(wsID, true).
WillReturnError(sql.ErrConnDone)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: wsID}}
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
newFakeCloser([]byte(`{"broadcast_enabled":true}`)))
c.Request.Header.Set("Content-Type", "application/json")
PatchAbilities(c)
if w.Code != http.StatusInternalServerError {
t.Errorf("got %d, want 500: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestPatchAbilities_TalkToUserEnabledDBError_Returns500(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupAbilitiesDB(t)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(sqlPatchAbilitiesExists).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectExec(sqlPatchTalkToUserEnabled).
WithArgs(wsID, true).
WillReturnError(sql.ErrConnDone)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: wsID}}
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
newFakeCloser([]byte(`{"talk_to_user_enabled":true}`)))
c.Request.Header.Set("Content-Type", "application/json")
PatchAbilities(c)
if w.Code != http.StatusInternalServerError {
t.Errorf("got %d, want 500: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
// ── Helpers ────────────────────────────────────────────────────────────────────
// newFakeCloser wraps a byte slice as an io.ReadCloser for request body injection.
func newFakeCloser(data []byte) *fakeReadCloser {
return &fakeReadCloser{data: data}
}
type fakeReadCloser struct {
data []byte
pos int
}
func (f *fakeReadCloser) Read(p []byte) (n int, err error) {
if f.pos >= len(f.data) {
return 0, nil
}
n = copy(p, f.data[f.pos:])
f.pos += n
return n, nil
}
func (*fakeReadCloser) Close() error { return nil }
@@ -1,142 +0,0 @@
package handlers
// workspace_broadcast.go — POST /workspaces/:id/broadcast
//
// Allows a workspace with broadcast_enabled=true to send a message to every
// non-removed agent workspace in the org. The message is:
//
// • Persisted in each recipient's activity_logs (type='broadcast_receive')
// so poll-mode agents pick it up via GET /activity.
// • Broadcast via WebSocket BROADCAST_MESSAGE event so canvas panels can
// show a real-time banner for each recipient workspace.
//
// The sender's own workspace logs a 'broadcast_sent' activity row for
// traceability.
//
// Auth: WorkspaceAuth (the agent triggers this with its own bearer token).
// The handler re-validates broadcast_enabled inside the DB lookup to prevent
// TOCTOU — the middleware only proved the token is valid, not the ability.
import (
"log"
"net/http"
"strconv"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/events"
"github.com/gin-gonic/gin"
)
// BroadcastHandler is constructed once and shared across requests.
type BroadcastHandler struct {
broadcaster *events.Broadcaster
}
// NewBroadcastHandler creates a BroadcastHandler.
func NewBroadcastHandler(b *events.Broadcaster) *BroadcastHandler {
return &BroadcastHandler{broadcaster: b}
}
// Broadcast handles POST /workspaces/:id/broadcast.
func (h *BroadcastHandler) Broadcast(c *gin.Context) {
senderID := c.Param("id")
if err := validateWorkspaceID(senderID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace ID"})
return
}
var body struct {
Message string `json:"message" binding:"required"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "message is required"})
return
}
ctx := c.Request.Context()
// Verify sender exists and has broadcast_enabled=true.
var senderName string
var broadcastEnabled bool
err := db.DB.QueryRowContext(ctx,
`SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'`,
senderID,
).Scan(&senderName, &broadcastEnabled)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
return
}
if !broadcastEnabled {
c.JSON(http.StatusForbidden, gin.H{
"error": "broadcast_disabled",
"hint": "This workspace does not have the broadcast ability. Ask a user or admin to enable it via PATCH /workspaces/:id/abilities.",
})
return
}
// Collect all non-removed agent workspaces (excludes the sender itself).
rows, err := db.DB.QueryContext(ctx,
`SELECT id FROM workspaces WHERE status != 'removed' AND id != $1`,
senderID,
)
if err != nil {
log.Printf("Broadcast: recipient query failed for %s: %v", senderID, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
return
}
defer rows.Close()
var recipientIDs []string
for rows.Next() {
var rid string
if rows.Scan(&rid) == nil {
recipientIDs = append(recipientIDs, rid)
}
}
if err := rows.Err(); err != nil {
log.Printf("Broadcast: recipient rows error for %s: %v", senderID, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
return
}
broadcastPayload := map[string]interface{}{
"message": body.Message,
"sender_id": senderID,
"sender": senderName,
}
// Persist broadcast_receive in each recipient's activity log + emit WS event.
delivered := 0
for _, rid := range recipientIDs {
if _, err := db.DB.ExecContext(ctx, `
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, status)
VALUES ($1, 'broadcast_receive', 'broadcast', $2, $3, 'ok')
`, rid, senderID, "Broadcast from "+senderName+": "+broadcastTruncate(body.Message, 120)); err != nil {
log.Printf("Broadcast: activity_logs insert for recipient %s: %v", rid, err)
continue
}
h.broadcaster.BroadcastOnly(rid, "BROADCAST_MESSAGE", broadcastPayload)
delivered++
}
// Record the send on the sender's own log.
if _, err := db.DB.ExecContext(ctx, `
INSERT INTO activity_logs (workspace_id, activity_type, method, summary, status)
VALUES ($1, 'broadcast_sent', 'broadcast', $2, 'ok')
`, senderID, "Broadcast sent to "+strconv.Itoa(delivered)+" workspace(s)"); err != nil {
log.Printf("Broadcast: sender activity_log for %s: %v", senderID, err)
}
c.JSON(http.StatusOK, gin.H{
"status": "sent",
"delivered": delivered,
})
}
func broadcastTruncate(s string, max int) string {
runes := []rune(s)
if len(runes) <= max {
return s
}
return string(runes[:max]) + "…"
}
@@ -1,403 +0,0 @@
package handlers
import (
"database/sql"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
"github.com/gin-gonic/gin"
)
// broadcastBody is a convenience that returns an io.ReadCloser wrapping JSON body.
func broadcastBody(body string) io.ReadCloser {
return &broadcastFakeCloser{data: []byte(body)}
}
type broadcastFakeCloser struct {
data []byte
pos int
}
func (f *broadcastFakeCloser) Read(p []byte) (n int, err error) {
if f.pos >= len(f.data) {
return 0, io.EOF
}
n = copy(p, f.data[f.pos:])
f.pos += n
return n, nil
}
func (*broadcastFakeCloser) Close() error { return nil }
// setupBroadcastDB creates a sqlmock DB with QueryMatcherEqual.
func setupBroadcastDB(t *testing.T) sqlmock.Sqlmock {
t.Helper()
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
if err != nil {
t.Fatalf("sqlmock.New: %v", err)
}
prevDB := db.DB
db.DB = mockDB
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
return mock
}
// Exact SQL strings from the production handler (whitespace must match verbatim).
const (
sqlBroadcastWorkspaceLookup = `SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'`
sqlBroadcastRecipients = `SELECT id FROM workspaces WHERE status != 'removed' AND id != $1`
sqlBroadcastReceiveInsert = `
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, status)
VALUES ($1, 'broadcast_receive', 'broadcast', $2, $3, 'ok')`
sqlBroadcastSentInsert = `
INSERT INTO activity_logs (workspace_id, activity_type, method, summary, status)
VALUES ($1, 'broadcast_sent', 'broadcast', $2, 'ok')`
)
// ── Broadcast HTTP handler tests ───────────────────────────────────────────────
func TestBroadcast_InvalidWorkspaceID_Returns400(t *testing.T) {
gin.SetMode(gin.TestMode)
setupBroadcastDB(t)
h := NewBroadcastHandler(newTestBroadcaster())
r := gin.New()
r.POST("/workspaces/:id/broadcast", h.Broadcast)
req := httptest.NewRequest(http.MethodPost, "/workspaces/not-a-uuid/broadcast",
broadcastBody(`{"message":"hello"}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("got %d, want 400: %s", w.Code, w.Body.String())
}
}
func TestBroadcast_MissingMessage_Returns400(t *testing.T) {
gin.SetMode(gin.TestMode)
setupBroadcastDB(t)
h := NewBroadcastHandler(newTestBroadcaster())
r := gin.New()
r.POST("/workspaces/:id/broadcast", h.Broadcast)
req := httptest.NewRequest(http.MethodPost,
"/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/broadcast",
broadcastBody(`{}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("got %d, want 400: %s", w.Code, w.Body.String())
}
}
func TestBroadcast_EmptyMessage_Returns400(t *testing.T) {
gin.SetMode(gin.TestMode)
setupBroadcastDB(t)
h := NewBroadcastHandler(newTestBroadcaster())
r := gin.New()
r.POST("/workspaces/:id/broadcast", h.Broadcast)
req := httptest.NewRequest(http.MethodPost,
"/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/broadcast",
broadcastBody(`{"message":""}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("got %d, want 400: %s", w.Code, w.Body.String())
}
}
func TestBroadcast_WorkspaceNotFound_Returns404(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupBroadcastDB(t)
h := NewBroadcastHandler(newTestBroadcaster())
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
WithArgs(wsID).
WillReturnError(sql.ErrNoRows)
r := gin.New()
r.POST("/workspaces/:id/broadcast", h.Broadcast)
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
broadcastBody(`{"message":"hello"}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("got %d, want 404: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestBroadcast_BroadcastDisabled_Returns403(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupBroadcastDB(t)
h := NewBroadcastHandler(newTestBroadcaster())
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
AddRow("test-workspace", false))
r := gin.New()
r.POST("/workspaces/:id/broadcast", h.Broadcast)
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
broadcastBody(`{"message":"hello"}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("got %d, want 403: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestBroadcast_NoRecipients_Success(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupBroadcastDB(t)
h := NewBroadcastHandler(newTestBroadcaster())
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
AddRow("test-workspace", true))
// No recipients (sender is the only non-removed workspace)
mock.ExpectQuery(sqlBroadcastRecipients).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"id"}))
// Sender's own activity log: 2 args (workspaceID, summary)
mock.ExpectExec(sqlBroadcastSentInsert).
WithArgs(wsID, sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(0, 1))
r := gin.New()
r.POST("/workspaces/:id/broadcast", h.Broadcast)
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
broadcastBody(`{"message":"hello everyone"}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestBroadcast_WithRecipients_Success_DeliversToAll(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupBroadcastDB(t)
h := NewBroadcastHandler(newTestBroadcaster())
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
recipient1 := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
recipient2 := "cccccccc-cccc-cccc-cccc-cccccccccccc"
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
AddRow("broadcaster-ws", true))
mock.ExpectQuery(sqlBroadcastRecipients).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"id"}).
AddRow(recipient1).
AddRow(recipient2))
// broadcast_receive: 3 args (recipientID, senderID, summary)
mock.ExpectExec(sqlBroadcastReceiveInsert).
WithArgs(recipient1, sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(sqlBroadcastReceiveInsert).
WithArgs(recipient2, sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(0, 1))
// broadcast_sent: 2 args (workspaceID, summary)
mock.ExpectExec(sqlBroadcastSentInsert).
WithArgs(wsID, sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(0, 1))
r := gin.New()
r.POST("/workspaces/:id/broadcast", h.Broadcast)
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
broadcastBody(`{"message":"hello team"}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestBroadcast_RecipientInsertError_ContinuesAndSucceeds(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupBroadcastDB(t)
h := NewBroadcastHandler(newTestBroadcaster())
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
recipient1 := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
recipient2 := "cccccccc-cccc-cccc-cccc-cccccccccccc"
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
AddRow("broadcaster-ws", true))
mock.ExpectQuery(sqlBroadcastRecipients).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"id"}).
AddRow(recipient1).
AddRow(recipient2))
// First recipient insert fails — handler logs and continues
mock.ExpectExec(sqlBroadcastReceiveInsert).
WithArgs(recipient1, sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnError(sql.ErrConnDone)
// Second recipient succeeds
mock.ExpectExec(sqlBroadcastReceiveInsert).
WithArgs(recipient2, sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(sqlBroadcastSentInsert).
WithArgs(wsID, sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(0, 1))
r := gin.New()
r.POST("/workspaces/:id/broadcast", h.Broadcast)
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
broadcastBody(`{"message":"partial delivery"}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestBroadcast_SenderActivityLogError_StillReturns200(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := setupBroadcastDB(t)
h := NewBroadcastHandler(newTestBroadcaster())
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
AddRow("broadcaster-ws", true))
mock.ExpectQuery(sqlBroadcastRecipients).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"id"}))
mock.ExpectExec(sqlBroadcastSentInsert).
WithArgs(wsID, sqlmock.AnyArg()).
WillReturnError(sql.ErrConnDone)
r := gin.New()
r.POST("/workspaces/:id/broadcast", h.Broadcast)
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
broadcastBody(`{"message":"hello"}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
// Handler logs error but still returns 200
if w.Code != http.StatusOK {
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
// ── broadcastTruncate pure function tests ─────────────────────────────────────
func TestBroadcastTruncate_UnderLimit(t *testing.T) {
input := "short message"
got := broadcastTruncate(input, 50)
if got != input {
t.Errorf("broadcastTruncate(%q, 50) = %q, want %q", input, got, input)
}
}
func TestBroadcastTruncate_ExactlyAtLimit(t *testing.T) {
input := "exactly fifty char"
got := broadcastTruncate(input, 18)
if got != input {
t.Errorf("broadcastTruncate(%q, 18) = %q, want %q", input, got, input)
}
}
func TestBroadcastTruncate_OverLimit_TruncatesAndAddsEllipsis(t *testing.T) {
// 150 ASCII chars → over 120 rune limit → truncate to 120 + ellipsis
input := strings.Repeat("x", 150)
got := broadcastTruncate(input, 120)
if len([]rune(got)) != 121 { // 120 + 1 ellipsis rune
t.Errorf("len(broadcastTruncate) = %d, want 121 (120 + ellipsis)", len([]rune(got)))
}
if got[:len(got)-len("…")] != strings.Repeat("x", 120) {
t.Errorf("broadcastTruncate did not truncate correctly")
}
}
func TestBroadcastTruncate_UnicodeChars_TreatsAsRunes(t *testing.T) {
// Each emoji is 1 rune but multiple bytes. 50 emojis > 30 limit.
input := strings.Repeat("🎉", 50)
got := broadcastTruncate(input, 30)
if len([]rune(got)) != 31 { // 30 + ellipsis
t.Errorf("len(broadcastTruncate with emoji) = %d, want 31", len([]rune(got)))
}
}
func TestBroadcastTruncate_ZeroLimit_ReturnsEllipsis(t *testing.T) {
got := broadcastTruncate("hello", 0)
if got != "…" {
t.Errorf("broadcastTruncate with max=0 = %q, want …", got)
}
}
@@ -33,7 +33,6 @@ var wsColumns = []string{
"parent_id", "active_tasks", "max_concurrent_tasks", "last_error_rate", "last_sample_error",
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
"budget_limit", "monthly_spend",
"broadcast_enabled", "talk_to_user_enabled",
}
// ==================== GET — financial fields stripped from open endpoint ====================
@@ -53,10 +52,8 @@ func TestWorkspaceBudget_Get_NilLimit(t *testing.T) {
[]byte(`{}`), "http://localhost:9001",
nil, 0, 1, 0.0, "", 0, "", "langgraph", "",
0.0, 0.0, false,
nil, // budget_limit NULL
0, // monthly_spend 0
false, // broadcast_enabled
true)) // talk_to_user_enabled
nil, // budget_limit NULL
0)) // monthly_spend 0
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -99,8 +96,7 @@ func TestWorkspaceBudget_Get_WithLimit(t *testing.T) {
nil, 0, 1, 0.0, "", 0, "", "langgraph", "",
0.0, 0.0, false,
int64(500), // budget_limit = $5.00 in DB
int64(123), // monthly_spend = $1.23 in DB
false, true)) // broadcast_enabled, talk_to_user_enabled
int64(123))) // monthly_spend = $1.23 in DB
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -111,11 +111,11 @@ func (h *WorkspaceHandler) provisionWorkspaceAuto(workspaceID, templatePath stri
"sync": false,
})
if h.cpProv != nil {
h.goAsync(func() { h.provisionWorkspaceCP(workspaceID, templatePath, configFiles, payload) })
go h.provisionWorkspaceCP(workspaceID, templatePath, configFiles, payload)
return true
}
if h.provisioner != nil {
h.goAsync(func() { h.provisionWorkspace(workspaceID, templatePath, configFiles, payload) })
go h.provisionWorkspace(workspaceID, templatePath, configFiles, payload)
return true
}
// No backend wired — mark failed so the workspace doesn't linger in
@@ -275,13 +275,13 @@ func (h *WorkspaceHandler) RestartWorkspaceAutoOpts(ctx context.Context, workspa
if h.cpProv != nil {
h.cpStopWithRetry(ctx, workspaceID, "RestartWorkspaceAuto")
// resetClaudeSession is Docker-only — CP has no session state to clear.
h.goAsync(func() { h.provisionWorkspaceCP(workspaceID, templatePath, configFiles, payload) })
go h.provisionWorkspaceCP(workspaceID, templatePath, configFiles, payload)
return true
}
if h.provisioner != nil {
// Docker.Stop has no retry — see docstring rationale.
h.provisioner.Stop(ctx, workspaceID)
h.goAsync(func() { h.provisionWorkspaceOpts(workspaceID, templatePath, configFiles, payload, resetClaudeSession) })
go h.provisionWorkspaceOpts(workspaceID, templatePath, configFiles, payload, resetClaudeSession)
return true
}
// No backend wired — same shape as provisionWorkspaceAuto's no-backend
@@ -805,9 +805,6 @@ func loadWorkspaceSecrets(ctx context.Context, workspaceID string) (map[string]s
envVars[k] = string(decrypted)
}
}
if err := globalRows.Err(); err != nil {
log.Printf("Provisioner: global_secrets rows.Err workspace=%s: %v", workspaceID, err)
}
}
wsRows, err := db.DB.QueryContext(ctx,
`SELECT key, encrypted_value, encryption_version FROM workspace_secrets WHERE workspace_id = $1`, workspaceID)
@@ -826,9 +823,6 @@ func loadWorkspaceSecrets(ctx context.Context, workspaceID string) (map[string]s
envVars[k] = string(decrypted)
}
}
if err := wsRows.Err(); err != nil {
log.Printf("Provisioner: workspace_secrets rows.Err workspace=%s: %v", workspaceID, err)
}
}
return envVars, ""
}
@@ -29,7 +29,6 @@ func TestWorkspaceGet_Success(t *testing.T) {
"parent_id", "active_tasks", "max_concurrent_tasks", "last_error_rate", "last_sample_error",
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
"budget_limit", "monthly_spend",
"broadcast_enabled", "talk_to_user_enabled",
}
mock.ExpectQuery("SELECT w.id, w.name").
WithArgs("cccccccc-0001-0000-0000-000000000000").
@@ -37,7 +36,7 @@ func TestWorkspaceGet_Success(t *testing.T) {
AddRow("cccccccc-0001-0000-0000-000000000000", "My Agent", "worker", 1, "online", []byte(`{"name":"test"}`),
"http://localhost:8001", nil, 2, 1, 0.05, "", 3600, "working", "langgraph",
"", 10.0, 20.0, false,
nil, 0, false, true))
nil, 0))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -119,7 +118,6 @@ func TestWorkspaceGet_RemovedReturns410(t *testing.T) {
"parent_id", "active_tasks", "max_concurrent_tasks", "last_error_rate", "last_sample_error",
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
"budget_limit", "monthly_spend",
"broadcast_enabled", "talk_to_user_enabled",
}
mock.ExpectQuery("SELECT w.id, w.name").
WithArgs(id).
@@ -127,7 +125,7 @@ func TestWorkspaceGet_RemovedReturns410(t *testing.T) {
AddRow(id, "Old Agent", "worker", 1, string(models.StatusRemoved), []byte(`null`),
"", nil, 0, 1, 0.0, "", 0, "", "langgraph",
"", 0.0, 0.0, false,
nil, 0, false, true))
nil, 0))
mock.ExpectQuery(`SELECT updated_at FROM workspaces`).
WithArgs(id).
WillReturnRows(sqlmock.NewRows([]string{"updated_at"}).AddRow(removedAt))
@@ -183,7 +181,6 @@ func TestWorkspaceGet_RemovedReturns410WithNullRemovedAtOnTimestampFetchFailure(
"parent_id", "active_tasks", "max_concurrent_tasks", "last_error_rate", "last_sample_error",
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
"budget_limit", "monthly_spend",
"broadcast_enabled", "talk_to_user_enabled",
}
mock.ExpectQuery("SELECT w.id, w.name").
WithArgs(id).
@@ -191,7 +188,7 @@ func TestWorkspaceGet_RemovedReturns410WithNullRemovedAtOnTimestampFetchFailure(
AddRow(id, "Vanished", "worker", 1, string(models.StatusRemoved), []byte(`null`),
"", nil, 0, 1, 0.0, "", 0, "", "langgraph",
"", 0.0, 0.0, false,
nil, 0, false, true))
nil, 0))
// Simulate the row vanishing between the two queries.
mock.ExpectQuery(`SELECT updated_at FROM workspaces`).
WithArgs(id).
@@ -246,7 +243,6 @@ func TestWorkspaceGet_RemovedWithIncludeQueryReturns200(t *testing.T) {
"parent_id", "active_tasks", "max_concurrent_tasks", "last_error_rate", "last_sample_error",
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
"budget_limit", "monthly_spend",
"broadcast_enabled", "talk_to_user_enabled",
}
mock.ExpectQuery("SELECT w.id, w.name").
WithArgs(id).
@@ -254,7 +250,7 @@ func TestWorkspaceGet_RemovedWithIncludeQueryReturns200(t *testing.T) {
AddRow(id, "Audit Agent", "worker", 1, string(models.StatusRemoved), []byte(`null`),
"", nil, 0, 1, 0.0, "", 0, "", "langgraph",
"", 0.0, 0.0, false,
nil, 0, false, true))
nil, 0))
// last_outbound_at follow-up query (existing path)
mock.ExpectQuery(`SELECT last_outbound_at FROM workspaces`).
WithArgs(id).
@@ -680,7 +676,6 @@ func TestWorkspaceList_Empty(t *testing.T) {
"parent_id", "active_tasks", "last_error_rate", "last_sample_error",
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
"budget_limit", "monthly_spend",
"broadcast_enabled", "talk_to_user_enabled",
}))
w := httptest.NewRecorder()
@@ -1384,7 +1379,6 @@ func TestWorkspaceGet_FinancialFieldsStripped(t *testing.T) {
"parent_id", "active_tasks", "max_concurrent_tasks", "last_error_rate", "last_sample_error",
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
"budget_limit", "monthly_spend",
"broadcast_enabled", "talk_to_user_enabled",
}
// Populate with non-zero financial values to confirm they are stripped.
mock.ExpectQuery("SELECT w.id, w.name").
@@ -1393,7 +1387,7 @@ func TestWorkspaceGet_FinancialFieldsStripped(t *testing.T) {
AddRow("cccccccc-0010-0000-0000-000000000000", "Finance Test", "worker", 1, "online", []byte(`{}`),
"http://localhost:9001", nil, 0, 1, 0.0, "", 0, "", "langgraph",
"", 0.0, 0.0, false,
int64(50000), int64(12500), false, true)) // budget_limit=500 USD, spend=125 USD
int64(50000), int64(12500))) // budget_limit=500 USD, spend=125 USD
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -1441,7 +1435,6 @@ func TestWorkspaceGet_SensitiveFieldsStripped(t *testing.T) {
"parent_id", "active_tasks", "max_concurrent_tasks", "last_error_rate", "last_sample_error",
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
"budget_limit", "monthly_spend",
"broadcast_enabled", "talk_to_user_enabled",
}
mock.ExpectQuery("SELECT w.id, w.name").
WithArgs("cccccccc-0955-0000-0000-000000000000").
@@ -1454,7 +1447,7 @@ func TestWorkspaceGet_SensitiveFieldsStripped(t *testing.T) {
"langgraph",
"/home/user/secret-projects/client-work",
0.0, 0.0, false,
nil, 0, false, true))
nil, 0))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -36,15 +36,6 @@ type Workspace struct {
// to activity_logs, agent reads via GET /activity?since_id=). See
// migration 045 + RFC #2339.
DeliveryMode string `json:"delivery_mode" db:"delivery_mode"`
// BroadcastEnabled: when true the workspace may call POST /broadcast to
// deliver a message to all non-removed agent workspaces in the org.
// Default false — only privileged orchestrators should hold this ability.
BroadcastEnabled bool `json:"broadcast_enabled" db:"broadcast_enabled"`
// TalkToUserEnabled: when false the workspace's send_message_to_user calls
// and POST /notify requests are rejected with HTTP 403 so the agent is
// forced to route updates through a parent workspace. Default true
// (preserves existing behaviour for all workspaces).
TalkToUserEnabled bool `json:"talk_to_user_enabled" db:"talk_to_user_enabled"`
// Canvas layout fields (from JOIN)
X float64 `json:"x"`
Y float64 `json:"y"`
@@ -158,10 +158,9 @@ type cpProvisionRequest struct {
Tier int `json:"tier"`
PlatformURL string `json:"platform_url"`
Env map[string]string `json:"env"`
// ConfigFiles are template + generated config files to write into the
// EC2 instance's /configs directory. OFFSEC-010: collected by
// collectCPConfigFiles which rejects symlinks and non-regular files
// before including them. Serialised as base64 to avoid JSON escaping.
// ConfigFiles are base64-encoded config files collected from the template
// directory and inline ConfigFiles. Sent to the control plane so CP can
// write them into the EC2 instance at boot time (OFFSEC-010 wiring fix).
ConfigFiles map[string]string `json:"config_files,omitempty"`
}
@@ -186,11 +185,11 @@ func (p *CPProvisioner) Start(ctx context.Context, cfg WorkspaceConfig) (string,
}
env["ADMIN_TOKEN"] = p.adminToken
}
// Collect template files and generated configs, with OFFSEC-010 guards:
// - Rejects symlinks at the template root (prevents bypass via symlink traversal)
// - Skips symlinks during WalkDir (prevents /etc/passwd etc. inclusion)
// - Validates all paths are relative and non-escaping
// - Caps total size at 12 KiB to prevent payload bloat
// Collect config files from the template directory and inline ConfigFiles.
// Without this wiring, OFFSEC-010 config-file collection is dead code and
// CP receives no files. An error here is fatal — a workspace with no
// config files cannot function.
configFiles, err := collectCPConfigFiles(cfg)
if err != nil {
return "", fmt.Errorf("cp provisioner: collect config files: %w", err)
@@ -257,16 +256,17 @@ func (p *CPProvisioner) Start(ctx context.Context, cfg WorkspaceConfig) (string,
const cpConfigFilesMaxBytes = 12 << 10
// isCPTemplateConfigFile restricts which files from a template directory are
// eligible for transport to the control plane. Only config.yaml (the runtime
// entrypoint config) and files under prompts/ (system prompts) are needed;
// shipping arbitrary files (e.g. adapter.py, Dockerfile) is both unnecessary
// and a potential data-exfiltration surface.
func isCPTemplateConfigFile(name string) bool {
name = filepath.ToSlash(filepath.Clean(name))
return name == "config.yaml" || strings.HasPrefix(name, "prompts/")
}
// collectCPConfigFiles walks cfg.TemplatePath and collects regular files as
// base64-encoded key→value entries, plus any inline cfg.ConfigFiles. Used to
// ship template config into the EC2 instance via the control plane.
//
// Security (OFFSEC-010):
// - cfg.TemplatePath itself must not be a symlink (os.Lstat check).
// - WalkDir skips symlinks inside the tree (d.Type()&os.ModeSymlink guard).
// - All file paths are cleaned and relativized; absolute/traversal paths
// are rejected by addFile's sanitization checks.
//
// Returns nil, nil when no files are found — caller must handle empty case.
func collectCPConfigFiles(cfg WorkspaceConfig) (map[string]string, error) {
files := make(map[string]string)
total := 0
@@ -301,8 +301,7 @@ func collectCPConfigFiles(cfg WorkspaceConfig) (map[string]string, error) {
// Skip symlinks — WalkDir follows them by default, which means
// a symlink inside the template dir pointing to /etc/passwd
// would be traversed even though the resulting relative-path
// check would correctly reject it. Defense-in-depth: don't
// follow symlinks at all. (OFFSEC-010)
// check would correctly reject it. Defense-in-depth. (OFFSEC-010)
if d.Type()&os.ModeSymlink != 0 {
return nil
}
@@ -320,9 +319,6 @@ func collectCPConfigFiles(cfg WorkspaceConfig) (map[string]string, error) {
if err != nil {
return err
}
if !isCPTemplateConfigFile(rel) {
return nil
}
data, err := os.ReadFile(path)
if err != nil {
return err
@@ -343,6 +339,7 @@ func collectCPConfigFiles(cfg WorkspaceConfig) (map[string]string, error) {
}
return files, nil
}
// Stop terminates the workspace's EC2 instance via the control plane.
//
// Looks up the actual EC2 instance_id from the workspaces table before
@@ -1,9 +1,7 @@
package provisioner
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"io"
"net/http"
@@ -283,105 +281,6 @@ func TestStart_TransportFailureSurfaces(t *testing.T) {
}
}
// TestStart_CollectsConfigFiles — verify that collectCPConfigFiles is called and
// its result is included in the cpProvisionRequest sent to the control plane.
// Tests the OFFSEC-010 wiring: the function's symlink guards are only effective
// if the call site actually invokes it.
func TestStart_CollectsConfigFiles(t *testing.T) {
tmpl := t.TempDir()
if err := os.WriteFile(filepath.Join(tmpl, "config.yaml"), []byte("name: test\n"), 0o600); err != nil {
t.Fatal(err)
}
// adapter.py is within the size limit but is NOT config.yaml or prompts/,
// so isCPTemplateConfigFile must exclude it from the transport.
if err := os.WriteFile(filepath.Join(tmpl, "adapter.py"), bytes.Repeat([]byte("x"), cpConfigFilesMaxBytes), 0o600); err != nil {
t.Fatal(err)
}
var gotBody cpProvisionRequest
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewDecoder(r.Body).Decode(&gotBody)
w.WriteHeader(http.StatusCreated)
_, _ = io.WriteString(w, `{"instance_id":"i-abc123","state":"pending"}`)
}))
defer srv.Close()
p := &CPProvisioner{baseURL: srv.URL, orgID: "org-1", httpClient: srv.Client()}
_, err := p.Start(context.Background(), WorkspaceConfig{
WorkspaceID: "ws-1",
Runtime: "python",
Tier: 1,
PlatformURL: "http://tenant",
TemplatePath: tmpl,
ConfigFiles: map[string][]byte{"generated.json": []byte(`{"key":"value"}`)},
})
if err != nil {
t.Fatalf("Start: %v", err)
}
// config.yaml from TemplatePath must be base64-encoded in ConfigFiles
if len(gotBody.ConfigFiles) == 0 {
t.Fatal("ConfigFiles is empty: collectCPConfigFiles was not called")
}
// Find config.yaml entry and verify it's valid base64 + correct content
var foundTemplate, foundGenerated bool
for name, encoded := range gotBody.ConfigFiles {
decoded, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
t.Errorf("ConfigFiles[%q] is not valid base64: %v", name, err)
continue
}
if name == "config.yaml" && string(decoded) == "name: test\n" {
foundTemplate = true
}
if name == "generated.json" && string(decoded) == `{"key":"value"}` {
foundGenerated = true
}
}
if !foundTemplate {
t.Errorf("ConfigFiles missing config.yaml from TemplatePath")
}
if !foundGenerated {
t.Errorf("ConfigFiles missing generated.json from ConfigFiles")
}
// adapter.py must NOT be in ConfigFiles — isCPTemplateConfigFile filters it out
for name := range gotBody.ConfigFiles {
if name == "adapter.py" {
t.Errorf("adapter.py should not be in ConfigFiles — isCPTemplateConfigFile must filter it out")
}
}
}
// TestStart_SymlinkTemplatePathError — a symlink TemplatePath should cause
// collectCPConfigFiles to return an error, which Start must propagate.
// Without this wiring, OFFSEC-010's root-symlink guard is dead code.
func TestStart_SymlinkTemplatePathError(t *testing.T) {
// Create a temp file and a symlink pointing to it
tmp := t.TempDir()
realFile := filepath.Join(tmp, "real")
if err := os.WriteFile(realFile, []byte("data"), 0o600); err != nil {
t.Fatal(err)
}
symlink := filepath.Join(tmp, "template_link")
if err := os.Symlink(realFile, symlink); err != nil {
t.Fatal(err)
}
p := &CPProvisioner{baseURL: "http://unused", orgID: "org-1", httpClient: &http.Client{Timeout: time.Second}}
_, err := p.Start(context.Background(), WorkspaceConfig{
WorkspaceID: "ws-1",
Runtime: "python",
TemplatePath: symlink, // symlink root → OFFSEC-010 guard should fire
})
if err == nil {
t.Fatal("expected error for symlink TemplatePath, got nil")
}
if !strings.Contains(err.Error(), "symlink") {
t.Errorf("error should mention symlink, got %q", err.Error())
}
}
// TestStop_SendsBothAuthHeaders — verify #118/#130 compliance on the
// teardown path. Any call to /cp/workspaces/:id must carry both the
// platform-wide shared secret AND the per-tenant admin token, or the
@@ -949,19 +848,16 @@ func TestIsRunning_EmptyInstanceIDReturnsFalse(t *testing.T) {
// TestCollectCPConfigFiles_SkipsSymlinks — WalkDir follows symlinks by default,
// but collectCPConfigFiles must skip them so a symlink inside a template dir
// pointing outside (e.g. ln -s /etc snapshot) cannot be traversed.
// Verifies OFFSEC-010 defense-in-depth fix. (OFFSEC-010)
// Verifies OFFSEC-010 defense-in-depth fix.
func TestCollectCPConfigFiles_SkipsSymlinks(t *testing.T) {
tmpl := t.TempDir()
// Write a real file that should be included.
if err := os.WriteFile(filepath.Join(tmpl, "config.yaml"), []byte("name: real\n"), 0o600); err != nil {
t.Fatal(err)
}
// Create a subdir with a file that will be symlinked-outside.
sensitiveDir := t.TempDir()
if err := os.WriteFile(filepath.Join(sensitiveDir, "secret.txt"), []byte("SENSITIVE\n"), 0o600); err != nil {
t.Fatal(err)
}
// Symlink inside template dir pointing to outside path.
symlinkPath := filepath.Join(tmpl, "snapshot")
if err := os.Symlink(sensitiveDir, symlinkPath); err != nil {
t.Fatal(err)
@@ -974,12 +870,9 @@ func TestCollectCPConfigFiles_SkipsSymlinks(t *testing.T) {
if files == nil {
t.Fatal("files should not be nil")
}
// config.yaml must be present.
if _, ok := files["config.yaml"]; !ok {
t.Errorf("config.yaml missing from files")
}
// The symlinked path must NOT be included (even though WalkDir would
// traverse it, the d.Type()&os.ModeSymlink guard skips the entry).
for k := range files {
if strings.Contains(k, "snapshot") || strings.Contains(k, "secret") {
t.Errorf("symlink path %q should not be in files — OFFSEC-010 regression", k)
@@ -989,8 +882,7 @@ func TestCollectCPConfigFiles_SkipsSymlinks(t *testing.T) {
// TestCollectCPConfigFiles_RejectsRootSymlink — if cfg.TemplatePath itself is
// a symlink, WalkDir would follow it to an arbitrary directory, bypassing the
// cfg.TemplatePath boundary. The function must reject this case explicitly.
// (OFFSEC-010)
// cfg.TemplatePath boundary. The function must reject this explicitly.
func TestCollectCPConfigFiles_RejectsRootSymlink(t *testing.T) {
real := t.TempDir()
if err := os.WriteFile(filepath.Join(real, "config.yaml"), []byte("name: real\n"), 0o600); err != nil {
@@ -1009,3 +901,110 @@ func TestCollectCPConfigFiles_RejectsRootSymlink(t *testing.T) {
t.Errorf("expected symlink-related error, got: %v", err)
}
}
// TestStart_PassesConfigFilesToCP — Start must call collectCPConfigFiles and
// include the result (base64-encoded files) in the cpProvisionRequest body sent
// to the control plane. Without this wiring, OFFSEC-010 config-file collection
// is dead code and CP receives no files. (OFFSEC-010 wiring fix)
func TestStart_PassesConfigFilesToCP(t *testing.T) {
tmpl := t.TempDir()
if err := os.WriteFile(filepath.Join(tmpl, "agent.yaml"), []byte("name: test\n"), 0o600); err != nil {
t.Fatal(err)
}
inlineFiles := map[string][]byte{
"runtime.yaml": []byte("runtime: python\n"),
}
var sawBody cpProvisionRequest
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&sawBody); err != nil {
t.Fatalf("decode CP request body: %v", err)
}
w.WriteHeader(http.StatusCreated)
_, _ = io.WriteString(w, `{"instance_id":"i-cfgt","state":"running"}`)
}))
defer srv.Close()
p := &CPProvisioner{baseURL: srv.URL, orgID: "org-1", httpClient: srv.Client()}
_, err := p.Start(context.Background(), WorkspaceConfig{
WorkspaceID: "ws-cfg",
Runtime: "python",
Tier: 2,
PlatformURL: "https://tenant.example",
TemplatePath: tmpl,
ConfigFiles: inlineFiles,
})
if err != nil {
t.Fatalf("Start: %v", err)
}
if sawBody.WorkspaceID != "ws-cfg" {
t.Errorf("WorkspaceID = %q, want ws-cfg", sawBody.WorkspaceID)
}
if sawBody.ConfigFiles == nil {
t.Fatal("ConfigFiles must not be nil — collectCPConfigFiles returned nil")
}
foundAgent := false
foundRuntime := false
for name := range sawBody.ConfigFiles {
if name == "agent.yaml" {
foundAgent = true
}
if name == "runtime.yaml" {
foundRuntime = true
}
}
if !foundAgent {
t.Errorf("agent.yaml missing from ConfigFiles; keys = %v", mapKeys(sawBody.ConfigFiles))
}
if !foundRuntime {
t.Errorf("runtime.yaml missing from ConfigFiles; keys = %v", mapKeys(sawBody.ConfigFiles))
}
}
// TestStart_CollectConfigFilesErrorSurfaces — if collectCPConfigFiles returns
// an error (e.g. TemplatePath is a symlink, or files exceed the 12 KiB cap),
// Start must propagate that error rather than silently continuing. (OFFSEC-010)
func TestStart_CollectConfigFilesErrorSurfaces(t *testing.T) {
realDir := t.TempDir()
if err := os.WriteFile(filepath.Join(realDir, "config.yaml"), []byte("data\n"), 0o600); err != nil {
t.Fatal(err)
}
symlinkPath := filepath.Join(t.TempDir(), "template-link")
if err := os.Symlink(realDir, symlinkPath); err != nil {
t.Fatal(err)
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("CP must not be called when collectCPConfigFiles fails")
w.WriteHeader(http.StatusCreated)
}))
defer srv.Close()
p := &CPProvisioner{baseURL: srv.URL, orgID: "org-1", httpClient: srv.Client()}
_, err := p.Start(context.Background(), WorkspaceConfig{
WorkspaceID: "ws-symlink",
Runtime: "python",
TemplatePath: symlinkPath,
})
if err == nil {
t.Fatal("expected error when TemplatePath is a symlink, got nil")
}
if !strings.Contains(err.Error(), "symlink") {
t.Errorf("error should mention symlink, got: %v", err)
}
if strings.Contains(err.Error(), "cp provisioner: send") {
t.Errorf("CP must not be called; error should be from collectConfigFiles, not HTTP send: %v", err)
}
}
// mapKeys returns the keys of a map — helper for test assertions.
func mapKeys(m map[string]string) []string {
if m == nil {
return nil
}
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
return keys
}
@@ -481,22 +481,6 @@ func (p *Provisioner) Start(ctx context.Context, cfg WorkspaceConfig) (string, e
return "", fmt.Errorf("failed to create container: %w", err)
}
// Seed /configs before the entrypoint starts. molecule-runtime reads
// /configs/config.yaml immediately; post-start copy races fast runtimes
// into a FileNotFoundError crash loop.
if cfg.TemplatePath != "" {
if err := p.CopyTemplateToContainer(ctx, resp.ID, cfg.TemplatePath); err != nil {
_ = p.cli.ContainerRemove(ctx, resp.ID, container.RemoveOptions{Force: true})
return "", fmt.Errorf("failed to copy template to container %s before start: %w", name, err)
}
}
if len(cfg.ConfigFiles) > 0 {
if err := p.WriteFilesToContainer(ctx, resp.ID, cfg.ConfigFiles); err != nil {
_ = p.cli.ContainerRemove(ctx, resp.ID, container.RemoveOptions{Force: true})
return "", fmt.Errorf("failed to write config files to container %s before start: %w", name, err)
}
}
if err := p.cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil {
// Clean up created container on start failure
_ = p.cli.ContainerRemove(ctx, resp.ID, container.RemoveOptions{Force: true})
@@ -512,6 +496,20 @@ func (p *Provisioner) Start(ctx context.Context, cfg WorkspaceConfig) (string, e
// /configs and /workspace, then drops to agent via gosu). No per-start
// chown needed here.
// Copy template files into /configs if TemplatePath is set
if cfg.TemplatePath != "" {
if err := p.CopyTemplateToContainer(ctx, resp.ID, cfg.TemplatePath); err != nil {
log.Printf("Provisioner: warning — failed to copy template to container %s: %v", name, err)
}
}
// Write generated config files into /configs if ConfigFiles is set
if len(cfg.ConfigFiles) > 0 {
if err := p.WriteFilesToContainer(ctx, resp.ID, cfg.ConfigFiles); err != nil {
log.Printf("Provisioner: warning — failed to write config files to container %s: %v", name, err)
}
}
// Resolve the host-mapped port. Retry inspect up to 3 times if Docker hasn't
// bound the ephemeral port yet (rare race under heavy load).
hostURL := InternalURL(cfg.WorkspaceID) // fallback to Docker-internal
@@ -146,9 +146,6 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
wsAdmin.GET("/workspaces", wh.List)
wsAdmin.POST("/workspaces", wh.Create)
wsAdmin.DELETE("/workspaces/:id", wh.Delete)
// Ability toggles — admin-only so workspace agents cannot self-modify
// broadcast_enabled or talk_to_user_enabled.
wsAdmin.PATCH("/workspaces/:id/abilities", handlers.PatchAbilities)
// Out-of-band bootstrap signal: CP's watcher POSTs here when it
// detects "RUNTIME CRASHED" in a workspace EC2 console output,
// so the canvas flips to failed in seconds instead of waiting
@@ -204,12 +201,6 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
// to 'hibernated'. The workspace auto-wakes on the next A2A message.
wsAuth.POST("/hibernate", wh.Hibernate)
// Broadcast — send a message to all non-removed workspaces in the org.
// Requires broadcast_enabled=true on the source workspace (checked
// inside the handler). WorkspaceAuth on wsAuth proves token ownership.
broadcastH := handlers.NewBroadcastHandler(broadcaster)
wsAuth.POST("/broadcast", broadcastH.Broadcast)
// External-workspace credential lifecycle (issue #319 follow-up to
// the Create flow). Both endpoints reject runtime ≠ external with
// 400 — see external_rotate.go for the rationale.
@@ -1,3 +0,0 @@
ALTER TABLE workspaces
DROP COLUMN IF EXISTS broadcast_enabled,
DROP COLUMN IF EXISTS talk_to_user_enabled;
@@ -1,16 +0,0 @@
-- Workspace abilities: opt-in flags that gate platform-level behaviours.
--
-- broadcast_enabled (default FALSE): when TRUE the workspace may call
-- POST /workspaces/:id/broadcast to send a message to every non-removed
-- agent workspace in the org. Off by default — only privileged
-- orchestrator workspaces should hold this ability.
--
-- talk_to_user_enabled (default TRUE): when FALSE the workspace is not
-- allowed to deliver messages to the canvas user via send_message_to_user /
-- POST /notify. The platform returns HTTP 403 so the agent can forward its
-- update to a parent workspace instead. Default TRUE preserves existing
-- behaviour for all current workspaces.
ALTER TABLE workspaces
ADD COLUMN IF NOT EXISTS broadcast_enabled BOOLEAN NOT NULL DEFAULT FALSE,
ADD COLUMN IF NOT EXISTS talk_to_user_enabled BOOLEAN NOT NULL DEFAULT TRUE;
-6
View File
@@ -29,7 +29,6 @@ from typing import Callable
import inbox
from a2a_tools import (
tool_broadcast_message,
tool_chat_history,
tool_check_task_status,
tool_commit_memory,
@@ -161,11 +160,6 @@ async def handle_tool_call(name: str, arguments: dict) -> str:
arguments.get("before_ts", ""),
source_workspace_id=arguments.get("source_workspace_id") or None,
)
elif name == "broadcast_message":
return await tool_broadcast_message(
arguments.get("message", ""),
workspace_id=arguments.get("workspace_id") or None,
)
return f"Unknown tool: {name}"
-1
View File
@@ -137,7 +137,6 @@ from a2a_tools_delegation import ( # noqa: E402 (import after the from-a2a_cli
# identically.
from a2a_tools_messaging import ( # noqa: E402 (import after the top-of-module imports)
_upload_chat_files,
tool_broadcast_message,
tool_chat_history,
tool_get_workspace_info,
tool_list_peers,
-58
View File
@@ -101,50 +101,6 @@ async def _upload_chat_files(
return uploaded, None
async def tool_broadcast_message(
message: str,
workspace_id: str | None = None,
) -> str:
"""Send a broadcast message to ALL agent workspaces in the org.
Requires the workspace to have broadcast_enabled=true (set by a user or
admin via PATCH /workspaces/:id/abilities). Use for urgent org-wide
signals status changes, critical alerts, coordination instructions.
Every non-removed workspace receives the message in its activity log so
poll-mode agents pick it up, and push-mode canvases get a real-time
BROADCAST_MESSAGE WebSocket event.
Args:
message: The broadcast text. Keep it concise all agents receive
this, so avoid lengthy prose that floods every context.
workspace_id: Optional. Which registered workspace to send the
broadcast from. Single-workspace agents omit this.
"""
if not message:
return "Error: message is required"
target_workspace_id = (workspace_id or "").strip() or WORKSPACE_ID
try:
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.post(
f"{PLATFORM_URL}/workspaces/{target_workspace_id}/broadcast",
json={"message": message},
headers=_auth_headers_for_heartbeat(target_workspace_id),
)
if resp.status_code == 200:
data = resp.json()
delivered = data.get("delivered", "?")
return f"Broadcast sent to {delivered} workspace(s)"
if resp.status_code == 403:
try:
hint = resp.json().get("hint", "")
except Exception:
hint = ""
return f"Error: broadcast ability not enabled.{(' ' + hint) if hint else ''}"
return f"Error: platform returned {resp.status_code}"
except Exception as e:
return f"Error sending broadcast: {e}"
async def tool_send_message_to_user(
message: str,
attachments: list[str] | None = None,
@@ -195,20 +151,6 @@ async def tool_send_message_to_user(
if uploaded:
return f"Message sent to user with {len(uploaded)} attachment(s)"
return "Message sent to user"
if resp.status_code == 403:
try:
body = resp.json()
if body.get("error") == "talk_to_user_disabled":
hint = body.get("hint", "")
return (
"Error: this workspace is not allowed to send messages "
"directly to the user (talk_to_user is disabled). "
+ (hint + " " if hint else "")
+ "Use delegate_task to forward your update to a parent "
"or supervisor workspace that can reach the user."
)
except Exception:
pass
return f"Error: platform returned {resp.status_code}"
except Exception as e:
return f"Error sending message: {e}"
-48
View File
@@ -3,57 +3,9 @@
import logging
import os
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any
# ---------------------------------------------------------------------------
# Provider routing — type alias + resolver used by individual adapters.
# Each adapter defines its own ProviderRegistry with the providers it accepts.
# ---------------------------------------------------------------------------
# Maps prefix → (ordered_auth_env_vars, default_base_url).
ProviderRegistry = dict[str, tuple[tuple[str, ...], str]]
def resolve_provider_routing(
model_str: str,
env: Mapping[str, str],
*,
registry: ProviderRegistry,
runtime_config: dict[str, Any] | None = None,
) -> tuple[str, str, str]:
"""Resolve a ``provider:model`` string to ``(api_key, base_url, bare_model_id)``.
URL precedence (highest to lowest):
1. ``<PREFIX>_BASE_URL`` env var
2. ``runtime_config["provider_url"]``
3. registry default for the prefix
Unknown prefixes fall back to OPENAI_API_KEY + api.openai.com.
Raises RuntimeError when no API key env var is set for the prefix.
"""
if ":" in model_str:
prefix, model_id = model_str.split(":", 1)
else:
prefix, model_id = "openai", model_str
env_vars, default_url = registry.get(
prefix, (("OPENAI_API_KEY",), "https://api.openai.com/v1")
)
api_key = next((env[v] for v in env_vars if env.get(v)), "")
if not api_key:
raise RuntimeError(
f"No API key found for provider {prefix!r} "
f"(checked: {', '.join(env_vars)}). Set one in workspace secrets."
)
env_url = env.get(f"{prefix.upper()}_BASE_URL", "")
config_url = (runtime_config or {}).get("provider_url", "")
base_url = env_url or config_url or default_url
return api_key, base_url, model_id
from a2a.server.agent_execution import AgentExecutor
from event_log import DisabledEventLog, EventLogBackend
-4
View File
@@ -340,10 +340,6 @@ _CLI_A2A_COMMAND_KEYWORDS: dict[str, str | None] = {
"delegate_task_async": "delegate --async",
"check_task_status": "status",
"get_workspace_info": "info",
# `broadcast_message` is not exposed via the CLI subprocess interface
# today — it's an MCP-first capability. If a2a_cli grows a `broadcast`
# subcommand, map it here and the alignment test will gate the change.
"broadcast_message": None,
# `send_message_to_user` is not exposed via the CLI subprocess
# interface today — it requires a structured `attachments` field
# that wouldn't survive a positional-arg shell invocation cleanly.
-40
View File
@@ -51,7 +51,6 @@ from dataclasses import dataclass
from typing import Any, Literal
from a2a_tools import (
tool_broadcast_message,
tool_chat_history,
tool_check_task_status,
tool_commit_memory,
@@ -289,44 +288,6 @@ _GET_WORKSPACE_INFO = ToolSpec(
section=A2A_SECTION,
)
_BROADCAST_MESSAGE = ToolSpec(
name="broadcast_message",
short=(
"Send a message to ALL agent workspaces in the org simultaneously. "
"Requires broadcast_enabled=true on this workspace (set by user/admin)."
),
when_to_use=(
"Use for urgent, org-wide signals: critical status changes, emergency "
"stop instructions, coordinated task announcements. Every non-removed "
"workspace receives the message in its activity log (poll-mode agents "
"see it on their next poll; push-mode canvases get a real-time banner). "
"This tool returns an error if broadcast_enabled is false — a user or "
"admin must enable it via the workspace abilities settings first."
),
input_schema={
"type": "object",
"properties": {
"message": {
"type": "string",
"description": (
"The broadcast text. Keep it concise — every agent in the "
"org receives this in their activity feed."
),
},
"workspace_id": {
"type": "string",
"description": (
"Optional. Multi-workspace mode: the registered workspace "
"to broadcast from. Single-workspace agents omit this."
),
},
},
"required": ["message"],
},
impl=tool_broadcast_message,
section=A2A_SECTION,
)
_SEND_MESSAGE_TO_USER = ToolSpec(
name="send_message_to_user",
short=(
@@ -642,7 +603,6 @@ TOOLS: list[ToolSpec] = [
_CHECK_TASK_STATUS,
_LIST_PEERS,
_GET_WORKSPACE_INFO,
_BROADCAST_MESSAGE,
_SEND_MESSAGE_TO_USER,
# Inbox (standalone-only; in-container returns informational error)
_WAIT_FOR_MESSAGE,
@@ -5,7 +5,6 @@
- **check_task_status**: Poll the status of a task started with delegate_task_async; returns result when done.
- **list_peers**: List the workspaces this agent can communicate with — name, ID, status, role for each.
- **get_workspace_info**: Get this workspace's own info — ID, name, role, tier, parent, status.
- **broadcast_message**: Send a message to ALL agent workspaces in the org simultaneously. Requires broadcast_enabled=true on this workspace (set by user/admin).
- **send_message_to_user**: Send a message directly to the user's canvas chat — pushed instantly via WebSocket. Use this to: (1) acknowledge a task immediately ('Got it, I'll start working on this'), (2) send interim progress updates while doing long work, (3) deliver follow-up results after delegation completes, (4) attach files (zip, pdf, csv, image) for the user to download via the `attachments` field (NEVER paste file URLs in `message`). The message appears in the user's chat as if you're proactively reaching out.
- **wait_for_message**: Block until the next inbound message (canvas user OR peer agent) arrives, or until ``timeout_secs`` elapses.
- **inbox_peek**: List pending inbound messages without removing them.
@@ -27,9 +26,6 @@ Call this first when you need to delegate but don't know the target's ID. Access
### get_workspace_info
Use to introspect your own identity (e.g. before reporting back to the user, or to determine whether you're a tier-0 root that can write GLOBAL memory).
### broadcast_message
Use for urgent, org-wide signals: critical status changes, emergency stop instructions, coordinated task announcements. Every non-removed workspace receives the message in its activity log (poll-mode agents see it on their next poll; push-mode canvases get a real-time banner). This tool returns an error if broadcast_enabled is false — a user or admin must enable it via the workspace abilities settings first.
### send_message_to_user
Use proactively across the lifecycle of a task — early to acknowledge, mid-flight to update, late to deliver. Never paste file URLs in the message body — always pass absolute paths in `attachments` so the platform serves them as download chips (works on SaaS where external file hosts are unreachable).
+131 -119
View File
@@ -1,141 +1,153 @@
"""Unit tests for resolve_provider_routing in adapter_base.
"""Unit tests for OpenClaw adapter env-var key selection and provider URL routing.
Covers provider routing, URL-override precedence, and the missing-key error path.
Each adapter defines its own registry; this test file defines one inline that
mirrors what the openclaw adapter uses.
The key-selection and URL-routing logic lives inline in OpenClawAdapter.setup()
(adapter.py lines 84-92). Since setup() carries heavy subprocess dependencies,
these tests isolate the selection logic by reproducing the exact Python expressions
from the adapter source if the adapter's logic changes, these tests must be kept
in sync.
Organisation:
TestEnvKeyChain priority order of the 3 currently supported keys
TestProviderUrlMapping model-prefix provider URL dict correctness
TestNegativeAndFallback no keys set / unsupported keys
xfail stubs AISTUDIO + QIANFAN documented as not-yet-implemented
"""
from __future__ import annotations
import os
from unittest.mock import patch
import pytest
from adapter_base import ProviderRegistry, resolve_provider_routing
# Mirror of the registry in openclaw's adapter.py — kept in sync manually.
PROVIDER_REGISTRY: ProviderRegistry = {
"openai": (("OPENAI_API_KEY",), "https://api.openai.com/v1"),
"groq": (("GROQ_API_KEY",), "https://api.groq.com/openai/v1"),
"openrouter": (("OPENROUTER_API_KEY",), "https://openrouter.ai/api/v1"),
"qianfan": (("QIANFAN_API_KEY", "AISTUDIO_API_KEY"), "https://qianfan.baidubce.com/v2"),
"minimax": (("MINIMAX_API_KEY",), "https://api.minimaxi.com/v1"),
"moonshot": (("KIMI_API_KEY",), "https://api.moonshot.ai/v1"),
# ---------------------------------------------------------------------------
# Helpers — mirror the exact expressions from adapter.py lines 84-92.
# Must be kept in sync with the adapter source.
# ---------------------------------------------------------------------------
def _select_key(env: dict) -> str:
"""Mirror line 84: nested os.environ.get priority chain."""
return env.get("OPENAI_API_KEY",
env.get("GROQ_API_KEY",
env.get("OPENROUTER_API_KEY", "")))
_PROVIDER_URLS: dict[str, str] = {
"openai": "https://api.openai.com/v1",
"groq": "https://api.groq.com/openai/v1",
"openrouter": "https://openrouter.ai/api/v1",
}
class TestProviderRouting:
def test_openai_key_and_url(self):
api_key, base_url, model_id = resolve_provider_routing(
"openai:gpt-4o", {"OPENAI_API_KEY": "sk-openai"}, registry=PROVIDER_REGISTRY
)
assert api_key == "sk-openai"
assert base_url == "https://api.openai.com/v1"
assert model_id == "gpt-4o"
def test_groq_key_and_url(self):
api_key, base_url, model_id = resolve_provider_routing(
"groq:llama-3.3-70b", {"GROQ_API_KEY": "sk-groq"}, registry=PROVIDER_REGISTRY
)
assert api_key == "sk-groq"
assert base_url == "https://api.groq.com/openai/v1"
assert model_id == "llama-3.3-70b"
def test_openrouter_key_and_url(self):
api_key, base_url, model_id = resolve_provider_routing(
"openrouter:anthropic/claude-sonnet-4-5", {"OPENROUTER_API_KEY": "sk-or"}, registry=PROVIDER_REGISTRY
)
assert api_key == "sk-or"
assert base_url == "https://openrouter.ai/api/v1"
assert model_id == "anthropic/claude-sonnet-4-5"
def test_qianfan_primary_key(self):
api_key, _, _ = resolve_provider_routing(
"qianfan:ernie-4.5", {"QIANFAN_API_KEY": "sk-qf", "AISTUDIO_API_KEY": "sk-ai"}, registry=PROVIDER_REGISTRY
)
assert api_key == "sk-qf"
def test_qianfan_fallback_to_aistudio(self):
api_key, base_url, _ = resolve_provider_routing(
"qianfan:ernie-4.5", {"AISTUDIO_API_KEY": "sk-ai"}, registry=PROVIDER_REGISTRY
)
assert api_key == "sk-ai"
assert base_url == "https://qianfan.baidubce.com/v2"
def test_minimax_key_and_url(self):
api_key, base_url, model_id = resolve_provider_routing(
"minimax:MiniMax-M2.7", {"MINIMAX_API_KEY": "sk-mm"}, registry=PROVIDER_REGISTRY
)
assert api_key == "sk-mm"
assert base_url == "https://api.minimaxi.com/v1"
assert model_id == "MiniMax-M2.7"
def test_moonshot_key_and_url(self):
api_key, base_url, model_id = resolve_provider_routing(
"moonshot:kimi-k2.5", {"KIMI_API_KEY": "sk-kimi"}, registry=PROVIDER_REGISTRY
)
assert api_key == "sk-kimi"
assert base_url == "https://api.moonshot.ai/v1"
assert model_id == "kimi-k2.5"
def test_bare_model_id_defaults_to_openai(self):
api_key, base_url, model_id = resolve_provider_routing(
"gpt-4o", {"OPENAI_API_KEY": "sk-openai"}, registry=PROVIDER_REGISTRY
)
assert base_url == "https://api.openai.com/v1"
assert model_id == "gpt-4o"
def test_unknown_prefix_falls_back_to_openai_url(self):
api_key, base_url, model_id = resolve_provider_routing(
"custom-shim:my-model", {"OPENAI_API_KEY": "sk-openai"}, registry=PROVIDER_REGISTRY
)
assert base_url == "https://api.openai.com/v1"
assert model_id == "my-model"
def _select_url(model: str, runtime_config: dict | None = None) -> str:
"""Mirror lines 86-92: model-prefix → provider URL with optional override."""
prefix = model.split(":")[0] if ":" in model else "openai"
return (runtime_config or {}).get(
"provider_url",
_PROVIDER_URLS.get(prefix, "https://api.openai.com/v1"),
)
class TestUrlOverridePrecedence:
# ---------------------------------------------------------------------------
# 1. Env-var key priority chain (3 keys currently in adapter.py)
# ---------------------------------------------------------------------------
def test_env_base_url_beats_registry_default(self):
_, base_url, _ = resolve_provider_routing(
"minimax:MiniMax-M2.7",
{"MINIMAX_API_KEY": "sk-mm", "MINIMAX_BASE_URL": "https://api.minimax.chat/v1"},
registry=PROVIDER_REGISTRY,
)
assert base_url == "https://api.minimax.chat/v1"
class TestEnvKeyChain:
def test_runtime_config_provider_url_beats_registry_default(self):
_, base_url, _ = resolve_provider_routing(
"openai:gpt-4o",
{"OPENAI_API_KEY": "sk-openai"},
registry=PROVIDER_REGISTRY,
runtime_config={"provider_url": "https://proxy.example.com/v1"},
)
assert base_url == "https://proxy.example.com/v1"
def test_openai_key_selected(self):
with patch.dict(os.environ, {"OPENAI_API_KEY": "sk-openai-test"}, clear=True):
assert _select_key(os.environ) == "sk-openai-test"
def test_env_base_url_beats_runtime_config(self):
_, base_url, _ = resolve_provider_routing(
"openai:gpt-4o",
{"OPENAI_API_KEY": "sk-openai", "OPENAI_BASE_URL": "https://env-wins.com/v1"},
registry=PROVIDER_REGISTRY,
runtime_config={"provider_url": "https://config-loses.com/v1"},
)
assert base_url == "https://env-wins.com/v1"
def test_groq_key_selected_when_openai_absent(self):
with patch.dict(os.environ, {"GROQ_API_KEY": "sk-groq-test"}, clear=True):
assert _select_key(os.environ) == "sk-groq-test"
def test_openrouter_key_selected_when_openai_and_groq_absent(self):
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "sk-or-test"}, clear=True):
assert _select_key(os.environ) == "sk-or-test"
def test_openai_beats_groq_when_both_set(self):
with patch.dict(os.environ, {"OPENAI_API_KEY": "openai", "GROQ_API_KEY": "groq"}, clear=True):
assert _select_key(os.environ) == "openai"
def test_groq_beats_openrouter_when_openai_absent(self):
with patch.dict(os.environ, {"GROQ_API_KEY": "groq", "OPENROUTER_API_KEY": "or"}, clear=True):
assert _select_key(os.environ) == "groq"
class TestMissingKey:
# ---------------------------------------------------------------------------
# 2. Model-prefix → provider URL routing
# ---------------------------------------------------------------------------
def test_raises_when_no_key_set(self):
with pytest.raises(RuntimeError, match="No API key found for provider 'minimax'"):
resolve_provider_routing("minimax:MiniMax-M2.7", {}, registry=PROVIDER_REGISTRY)
class TestProviderUrlMapping:
def test_raises_lists_checked_vars_in_message(self):
with pytest.raises(RuntimeError, match="MINIMAX_API_KEY"):
resolve_provider_routing("minimax:MiniMax-M2.7", {}, registry=PROVIDER_REGISTRY)
def test_openai_prefix_routes_to_openai(self):
assert _select_url("openai:gpt-4o") == "https://api.openai.com/v1"
def test_groq_prefix_routes_to_groq(self):
assert _select_url("groq:llama3-70b") == "https://api.groq.com/openai/v1"
def test_openrouter_prefix_routes_to_openrouter(self):
assert _select_url("openrouter:meta-llama/llama-3.3-70b") == "https://openrouter.ai/api/v1"
def test_runtime_config_override_wins_over_prefix(self):
url = _select_url("openai:gpt-4o", {"provider_url": "https://custom.example.com/v1"})
assert url == "https://custom.example.com/v1"
def test_unknown_prefix_falls_back_to_openai(self):
assert _select_url("some-unknown-model") == "https://api.openai.com/v1"
class TestRegistryCompleteness:
"""Smoke-check that every provider in the registry has a non-empty entry."""
# ---------------------------------------------------------------------------
# 3. Negative / fallback cases
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("prefix", PROVIDER_REGISTRY)
def test_all_providers_have_key_vars_and_url(self, prefix):
env_vars, base_url = PROVIDER_REGISTRY[prefix]
assert env_vars, f"{prefix}: env_vars is empty"
assert base_url.startswith("https://"), f"{prefix}: base_url looks wrong: {base_url}"
class TestNegativeAndFallback:
def test_no_keys_returns_empty_string(self):
with patch.dict(os.environ, {}, clear=True):
assert _select_key(os.environ) == ""
def test_unsupported_aistudio_key_returns_empty(self):
"""Documents that AISTUDIO_API_KEY is NOT yet in the adapter's key chain."""
with patch.dict(os.environ, {"AISTUDIO_API_KEY": "sk-ai"}, clear=True):
assert _select_key(os.environ) == ""
def test_unsupported_qianfan_key_returns_empty(self):
"""Documents that QIANFAN_API_KEY is NOT yet in the adapter's key chain."""
with patch.dict(os.environ, {"QIANFAN_API_KEY": "sk-qf"}, clear=True):
assert _select_key(os.environ) == ""
# ---------------------------------------------------------------------------
# 4. AISTUDIO + QIANFAN — xfail stubs (not yet implemented in adapter.py)
# These fail now; they should be promoted to passing tests once the adapter
# adds AISTUDIO_API_KEY and QIANFAN_API_KEY to its key chain and provider_urls.
# ---------------------------------------------------------------------------
@pytest.mark.xfail(
strict=True,
reason=(
"AISTUDIO_API_KEY not yet in openclaw adapter env-var chain — "
"add to adapter.py line 84 and provider_urls dict with "
"URL https://generativelanguage.googleapis.com/v1beta/openai"
),
)
def test_aistudio_key_routes_to_aistudio_url():
with patch.dict(os.environ, {"AISTUDIO_API_KEY": "sk-ai-test"}, clear=True):
assert _select_key(os.environ) == "sk-ai-test"
assert _select_url("gemini-2.5-flash") == "https://generativelanguage.googleapis.com/v1beta/openai"
@pytest.mark.xfail(
strict=True,
reason=(
"QIANFAN_API_KEY not yet in openclaw adapter env-var chain — "
"add to adapter.py line 84 and provider_urls dict with "
"URL https://qianfan.baidubce.com/v2"
),
)
def test_qianfan_key_routes_to_qianfan_url():
with patch.dict(os.environ, {"QIANFAN_API_KEY": "sk-qf-test"}, clear=True):
assert _select_key(os.environ) == "sk-qf-test"
assert _select_url("ernie-4.5") == "https://qianfan.baidubce.com/v2"