forked from molecule-ai/molecule-core
Compare commits
39 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 184ce7ae4e | |||
| ff75aeb43e | |||
| 81cf0cbf98 | |||
| 412dec0d87 | |||
| 39931acd9c | |||
| 6f19b88fa7 | |||
| 83454e5efd | |||
| 8254bedf30 | |||
| ec72f199e6 | |||
| ae22a55675 | |||
| 08648bf4b1 | |||
| eec4ea2e7d | |||
| 6201d12533 | |||
| 81e83c05b7 | |||
| 5b5eacbb29 | |||
| c8fca1467e | |||
| 7c8b81c6eb | |||
| fc1c45789e | |||
| e3a18ed8e8 | |||
| 9f551319d2 | |||
| 1052f8bdb0 | |||
| 30fb507165 | |||
| 77e9a965ac | |||
| 5334d60de4 | |||
| d6c0227e3f | |||
| 27db090d3d | |||
| 0f25f6de97 | |||
| 9991057ad1 | |||
| b89a49ec93 | |||
| f5613bf099 | |||
| 9bd2a2c45f | |||
| a489ee1a7c | |||
| c79ba05ed5 | |||
| 6470e5f41b | |||
| aa560c0314 | |||
| 7644e82f2f | |||
| be18b9c8f9 | |||
| 2227a14b1e | |||
| e72f9ad107 |
@@ -48,16 +48,21 @@ export function EmptyState() {
|
||||
});
|
||||
|
||||
// "Create blank" bypasses templates entirely — no preflight, no
|
||||
// modal, just POST /workspaces with a default name and tier.
|
||||
// Deliberately NOT routed through useTemplateDeploy because it
|
||||
// has no `template.id` to deploy against.
|
||||
// modal, just POST /workspaces with a default name. Deliberately
|
||||
// NOT routed through useTemplateDeploy because it has no
|
||||
// `template.id` to deploy against.
|
||||
//
|
||||
// tier is omitted so the backend picks a SaaS-aware default
|
||||
// (T4 on SaaS, T3 on self-hosted — see WorkspaceHandler.DefaultTier).
|
||||
// The previous hardcoded `tier: 2` shipped every fresh-tenant agent
|
||||
// at Standard regardless of host, which surprised SaaS users whose
|
||||
// CreateWorkspaceDialog already defaults to T4.
|
||||
const createBlank = async () => {
|
||||
setBlankCreating(true);
|
||||
setBlankError(null);
|
||||
try {
|
||||
const ws = await api.post<{ id: string }>("/workspaces", {
|
||||
name: "My First Agent",
|
||||
tier: 2,
|
||||
canvas: firstDeployCoords(),
|
||||
});
|
||||
handleDeployed(ws.id);
|
||||
|
||||
@@ -286,6 +286,14 @@ function MyChatPanel({ workspaceId, data }: Props) {
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [confirmRestart, setConfirmRestart] = useState(false);
|
||||
const bottomRef = useRef<HTMLDivElement>(null);
|
||||
// First-mount scroll-to-bottom needs `behavior: "instant"` — long
|
||||
// conversations smooth-animate for ~300ms which any concurrent
|
||||
// re-render can interrupt, leaving the user stuck mid-conversation
|
||||
// when the chat tab opens. Subsequent appends (new agent messages)
|
||||
// keep `smooth` for the visual "landing" feel. Flipped the first
|
||||
// time messages.length goes positive, so a workspace switch (which
|
||||
// remounts ChatTab) gets a fresh instant jump too.
|
||||
const hasInitialScrollRef = useRef(false);
|
||||
// Lazy-load older history on scroll-up.
|
||||
// - containerRef = the scrollable messages viewport
|
||||
// - topRef = sentinel above the messages list; IO observes it
|
||||
@@ -545,6 +553,15 @@ function MyChatPanel({ workspaceId, data }: Props) {
|
||||
scrollAnchorRef.current = null;
|
||||
return;
|
||||
}
|
||||
// Instant on first arrival of messages — smooth-scroll on a long
|
||||
// conversation gets interrupted by concurrent renders and leaves
|
||||
// the user stuck in the middle. After the first jump, subsequent
|
||||
// appends animate as before.
|
||||
if (!hasInitialScrollRef.current && messages.length > 0) {
|
||||
hasInitialScrollRef.current = true;
|
||||
bottomRef.current?.scrollIntoView({ behavior: "instant" as ScrollBehavior });
|
||||
return;
|
||||
}
|
||||
bottomRef.current?.scrollIntoView({ behavior: "smooth" });
|
||||
}, [messages]);
|
||||
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
# Team Expansion (Recursive Workspaces)
|
||||
|
||||
When a workspace is expanded into a team, it gains sub-workspaces while its own agent remains as the **team lead** (coordinator). This is recursive — sub-workspaces can themselves be expanded into teams, infinitely deep.
|
||||
|
||||
## How It Works
|
||||
|
||||
When Developer PM is expanded into a team:
|
||||
|
||||
```
|
||||
Business Core
|
||||
|
|
||||
+-- Developer PM (agent stays, becomes coordinator)
|
||||
|
|
||||
+-- Frontend Agent (sub-workspace, private scope)
|
||||
+-- Backend Agent (sub-workspace, private scope)
|
||||
+-- QA Agent (sub-workspace, private scope)
|
||||
```
|
||||
|
||||
- Developer PM's agent **still exists** and acts as coordinator
|
||||
- Developer PM receives incoming A2A messages from Business Core
|
||||
- Developer PM's agent decides how to delegate to sub-workspaces
|
||||
- Sub-workspaces talk to Developer PM and to each other (same level)
|
||||
- Sub-workspaces **cannot** talk to Business Core or any workspace outside the team
|
||||
|
||||
## Communication Rules
|
||||
|
||||
| Direction | Allowed? | Example |
|
||||
|-----------|----------|---------|
|
||||
| Parent level -> team lead | Yes | Business Core -> Developer PM |
|
||||
| Team lead -> sub-workspaces | Yes | Developer PM -> Frontend Agent |
|
||||
| Sub-workspace -> team lead | Yes | Frontend Agent -> Developer PM |
|
||||
| Sub-workspace <-> sibling | Yes | Frontend Agent <-> Backend Agent |
|
||||
| Outside -> sub-workspace directly | No (403) | Business Core -> Frontend Agent |
|
||||
| Sub-workspace -> outside directly | No | Frontend Agent -> Business Core |
|
||||
|
||||
The team lead (Developer PM) is the **only** bridge between the team's internal world and the outside.
|
||||
|
||||
## Scoped Registry
|
||||
|
||||
Sub-workspaces register in the platform registry but with a **private scope**. The registry knows about them but enforces access control.
|
||||
|
||||
```
|
||||
Registry:
|
||||
Business Core :8001 scope: public
|
||||
Developer PM :8002 scope: public
|
||||
Frontend Agent :8010 scope: private, parent=Developer PM
|
||||
Backend Agent :8011 scope: private, parent=Developer PM
|
||||
QA Agent :8012 scope: private, parent=Developer PM
|
||||
```
|
||||
|
||||
- The platform can always discover any workspace (for provisioning, monitoring)
|
||||
- The parent workspace can discover its sub-workspaces
|
||||
- Sub-workspaces can discover their siblings (same parent)
|
||||
- Outside workspaces get a **403 Forbidden** if they try to discover a private sub-workspace
|
||||
|
||||
## How to Expand
|
||||
|
||||
Expansion is triggered via `POST /workspaces/:id/expand`. The platform reads the `sub_workspaces` list from the workspace's config and provisions each one. On the canvas, users right-click a workspace node and select "Expand into team."
|
||||
|
||||
Collapsing is the inverse: `POST /workspaces/:id/collapse`. Sub-workspaces are stopped and removed.
|
||||
|
||||
## What Happens on Expansion
|
||||
|
||||
When Developer PM is expanded into a team, the hierarchy changes but the outside view doesn't. Business Core's parent/child relationship to Developer PM is unaffected — Developer PM still responds to the same A2A endpoint.
|
||||
|
||||
The events fired:
|
||||
- `WORKSPACE_EXPANDED` with the new `sub_workspace_ids` in the payload
|
||||
- `WORKSPACE_PROVISIONING` for each new sub-workspace
|
||||
- `WORKSPACE_ONLINE` for each sub-workspace as they come up
|
||||
|
||||
Communication rules are automatically derived from the new hierarchy — no manual wiring needed.
|
||||
|
||||
## Canvas Behavior
|
||||
|
||||
- Children render as embedded mini-cards (`TeamMemberChip`) inside the parent node, not as separate canvas nodes
|
||||
- Each mini-card shows full status: gradient bar, name, tier badge, skills pills, active tasks, descendant count
|
||||
- **Recursive rendering** up to 3 levels deep (`MAX_NESTING_DEPTH = 3`) — sub-cards can contain their own "Team" sections
|
||||
- Parent node dynamically resizes: 210-280px (no children), 320-450px (children), 400-560px (grandchildren)
|
||||
- Eject button (sky-blue arrow icon) on hover extracts a child from the team
|
||||
- "Extract from Team" also available in the right-click context menu
|
||||
- Double-click a team node to zoom/fit to the parent area
|
||||
- The parent workspace node shows a badge with total descendant count
|
||||
|
||||
## Collapsing a Team
|
||||
|
||||
The inverse of expansion, triggered via `POST /workspaces/:id/collapse`:
|
||||
|
||||
1. Each sub-workspace agent wraps up current work and writes a handoff document to memory
|
||||
2. Sub-workspaces are stopped and removed
|
||||
3. The team lead's agent goes back to handling everything directly
|
||||
4. A `WORKSPACE_COLLAPSED` event fires
|
||||
|
||||
Sub-workspace memory is cleaned up based on backend (see [Memory — Cleanup](../architecture/memory.md#cleanup-on-workspace-deletion)).
|
||||
|
||||
## Deleting a Team Workspace
|
||||
|
||||
When a team workspace is deleted:
|
||||
1. Platform shows a warning listing all sub-workspaces that will be deleted
|
||||
2. User can **drag sub-workspaces out** of the team before confirming (promotes them to the parent level)
|
||||
3. On confirmation, cascade delete removes the parent and all remaining sub-workspaces
|
||||
4. `WORKSPACE_REMOVED` events fire for each deleted workspace
|
||||
|
||||
## Related Docs
|
||||
|
||||
- [Communication Rules](../api-protocol/communication-rules.md) — Full access control model
|
||||
- [Core Concepts](../product/core-concepts.md) — Workspace fundamentals
|
||||
- [System Prompt Structure](./system-prompt-structure.md) — How peer capabilities are injected
|
||||
- [Provisioner](../architecture/provisioner.md) — How sub-workspaces are deployed
|
||||
- [Registry & Heartbeat](../api-protocol/registry-and-heartbeat.md) — How registration works
|
||||
- [Event Log](../architecture/event-log.md) — Events fired during expansion
|
||||
- [Canvas UI](../frontend/canvas.md) — Visual behavior of teams
|
||||
@@ -41,8 +41,6 @@ Full contract: `docs/runbooks/admin-auth.md`.
|
||||
| GET | /admin/workspaces/:id/test-token | admin_test_token.go — mint a fresh bearer token for E2E scripts; returns 404 unless `MOLECULE_ENV != production` or `MOLECULE_ENABLE_TEST_TOKENS=1` |
|
||||
| GET/POST/DELETE | /admin/secrets[/:key] | secrets.go — legacy aliases for /settings/secrets |
|
||||
| WS | /workspaces/:id/terminal | terminal.go |
|
||||
| POST | /workspaces/:id/expand | team.go |
|
||||
| POST | /workspaces/:id/collapse | team.go |
|
||||
| POST/GET | /workspaces/:id/approvals | approvals.go |
|
||||
| POST | /workspaces/:id/approvals/:id/decide | approvals.go |
|
||||
| GET | /approvals/pending | approvals.go |
|
||||
|
||||
@@ -336,8 +336,6 @@ This same logic governs: A2A delegation, memory scope enforcement, activity visi
|
||||
|
||||
| Method | Endpoint | Purpose |
|
||||
|--------|----------|---------|
|
||||
| `POST` | `/workspaces/:id/expand` | Expand workspace into team (become coordinator) |
|
||||
| `POST` | `/workspaces/:id/collapse` | Collapse team back to single workspace |
|
||||
|
||||
### Files, Terminal, Templates, Bundles (8 endpoints)
|
||||
|
||||
|
||||
@@ -186,4 +186,3 @@ So the UI now exposes more operational failure state directly instead of silentl
|
||||
- [Quickstart](../quickstart.md)
|
||||
- [Platform API](../api-protocol/platform-api.md)
|
||||
- [Workspace Runtime](../agent-runtime/workspace-runtime.md)
|
||||
- [Team Expansion](../agent-runtime/team-expansion.md)
|
||||
|
||||
+1
-1
@@ -18,7 +18,7 @@ lands in the watch list with a colliding term, add a row here.
|
||||
| **plugin** | A directory under `plugins/` packaging one or more skills or an MCP server wrapper, installable per-workspace via `POST /workspaces/:id/plugins`. Governed by `plugin.yaml`. | **Langflow**: a visual UI node / component in a flowchart. **CrewAI**: a Python-importable callable registered as a capability. |
|
||||
| **agent** | A persistent containerized workspace running continuously — an identity with memory, a role, and a schedule. Not a one-shot invocation. | Most frameworks (AutoGPT, LangChain agents, OpenAI Assistants): a stateless function-call loop. No persistence between invocations unless explicitly checkpointed. |
|
||||
| **flow** | A task execution within a workspace — a request enters, the agent runs tools, emits a response, logs activity. No explicit graph abstraction. | **Langflow**: a directed graph of nodes you author visually. **LangGraph**: a stateful graph of callable nodes. Our "flow" is an imperative timeline, not a graph. |
|
||||
| **team** | A named cluster of workspaces under a PM (org template `expand_team`). Used for role grouping in Canvas. | **CrewAI**: a "crew" is a sequence of agents that pass a task through a declared order. Our "team" is an org-chart abstraction, not an execution order. |
|
||||
| **team** | A named cluster of workspaces under a PM . Used for role grouping in Canvas. | **CrewAI**: a "crew" is a sequence of agents that pass a task through a declared order. Our "team" is an org-chart abstraction, not an execution order. |
|
||||
| **skill** | A directory with `SKILL.md` that an agent invokes via the `Skill` tool. Skills are documentation + optional scripts that teach an agent a recipe. | **Anthropic Skills API**: nearly identical. **CrewAI tool**: closer to our plugin's MCP tool, not our skill. |
|
||||
| **channel** | An outbound/inbound social integration (Telegram, Slack, …) per-workspace, wired in `workspace_channels`. | Slack's "channel": the container for messages. We use "channel" for the adapter + credentials, not the conversation itself. |
|
||||
| **runtime** | The execution engine image tag for a workspace: one of `langgraph`, `claude-code`, `openclaw`, `crewai`, `autogen`, `deepagents`, `hermes`. | **LangGraph runtime**: the Python process running the graph. We use "runtime" for the Docker image + adapter pairing, not the inner process. |
|
||||
|
||||
@@ -166,8 +166,6 @@ list_workspaces
|
||||
|
||||
| MCP Tool | API Route | Method | Description |
|
||||
|----------|-----------|--------|-------------|
|
||||
| `expand_team` | `/workspaces/:id/expand` | POST | Expand team node |
|
||||
| `collapse_team` | `/workspaces/:id/collapse` | POST | Collapse team node |
|
||||
|
||||
### Templates & Bundles
|
||||
|
||||
|
||||
@@ -55,6 +55,7 @@ TOP_LEVEL_MODULES = {
|
||||
"a2a_executor",
|
||||
"a2a_mcp_server",
|
||||
"a2a_tools",
|
||||
"a2a_tools_delegation",
|
||||
"a2a_tools_rbac",
|
||||
"adapter_base",
|
||||
"agent",
|
||||
|
||||
@@ -94,6 +94,13 @@ services:
|
||||
CP_UPSTREAM_URL: "http://cp-stub:9090"
|
||||
RATE_LIMIT: "1000"
|
||||
CANVAS_PROXY_URL: "http://localhost:3000"
|
||||
# Memory v2 sidecar (PR #2906) bundles the plugin into the
|
||||
# tenant image and starts it before the main server. The plugin
|
||||
# runs `CREATE EXTENSION vector` on first boot, which fails on
|
||||
# the harness's plain postgres:15-alpine (no pgvector). The
|
||||
# harness doesn't exercise memory features, so disable the
|
||||
# sidecar via the entrypoint's documented escape hatch.
|
||||
MEMORY_PLUGIN_DISABLE: "1"
|
||||
networks: [harness-net]
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "wget -q -O- http://localhost:8080/health || exit 1"]
|
||||
@@ -142,6 +149,13 @@ services:
|
||||
CP_UPSTREAM_URL: "http://cp-stub:9090"
|
||||
RATE_LIMIT: "1000"
|
||||
CANVAS_PROXY_URL: "http://localhost:3000"
|
||||
# Memory v2 sidecar (PR #2906) bundles the plugin into the
|
||||
# tenant image and starts it before the main server. The plugin
|
||||
# runs `CREATE EXTENSION vector` on first boot, which fails on
|
||||
# the harness's plain postgres:15-alpine (no pgvector). The
|
||||
# harness doesn't exercise memory features, so disable the
|
||||
# sidecar via the entrypoint's documented escape hatch.
|
||||
MEMORY_PLUGIN_DISABLE: "1"
|
||||
networks: [harness-net]
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "wget -q -O- http://localhost:8080/health || exit 1"]
|
||||
|
||||
@@ -21,6 +21,14 @@ ARG GIT_SHA=dev
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags "-X github.com/Molecule-AI/molecule-monorepo/platform/internal/buildinfo.GitSHA=${GIT_SHA}" \
|
||||
-o /platform ./cmd/server
|
||||
# Bundle the built-in memory-plugin-postgres binary so an operator can
|
||||
# activate Memory v2 by setting MEMORY_V2_CUTOVER=true + (default)
|
||||
# MEMORY_PLUGIN_URL=http://localhost:9100. The entrypoint starts this
|
||||
# binary in the background; main /platform talks to it over loopback.
|
||||
# Stays inert until the operator flips the cutover env var.
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags "-X github.com/Molecule-AI/molecule-monorepo/platform/internal/buildinfo.GitSHA=${GIT_SHA}" \
|
||||
-o /memory-plugin ./cmd/memory-plugin-postgres
|
||||
|
||||
# Clone templates + plugins at build time from manifest.json
|
||||
FROM alpine:3.20 AS templates
|
||||
@@ -30,8 +38,9 @@ COPY scripts/clone-manifest.sh /scripts/clone-manifest.sh
|
||||
RUN chmod +x /scripts/clone-manifest.sh && /scripts/clone-manifest.sh /manifest.json /workspace-configs-templates /org-templates /plugins
|
||||
|
||||
FROM alpine:3.20
|
||||
RUN apk add --no-cache ca-certificates git tzdata
|
||||
RUN apk add --no-cache ca-certificates git tzdata wget
|
||||
COPY --from=builder /platform /platform
|
||||
COPY --from=builder /memory-plugin /memory-plugin
|
||||
COPY workspace-server/migrations /migrations
|
||||
COPY --from=templates /workspace-configs-templates /workspace-configs-templates
|
||||
COPY --from=templates /org-templates /org-templates
|
||||
@@ -41,6 +50,7 @@ RUN addgroup -g 1000 platform && adduser -u 1000 -G platform -s /bin/sh -D platf
|
||||
EXPOSE 8080
|
||||
COPY <<'ENTRY' /entrypoint.sh
|
||||
#!/bin/sh
|
||||
# Set up docker-socket group (unchanged from pre-sidecar entrypoint).
|
||||
if [ -S /var/run/docker.sock ]; then
|
||||
SOCK_GID=$(stat -c '%g' /var/run/docker.sock 2>/dev/null || stat -f '%g' /var/run/docker.sock 2>/dev/null)
|
||||
if [ -n "$SOCK_GID" ] && [ "$SOCK_GID" != "0" ]; then
|
||||
@@ -50,6 +60,61 @@ if [ -S /var/run/docker.sock ]; then
|
||||
addgroup platform root 2>/dev/null || true
|
||||
fi
|
||||
fi
|
||||
|
||||
# Memory v2 sidecar (built-in postgres plugin). Co-located with the
|
||||
# main server so operators flipping MEMORY_V2_CUTOVER=true don't need
|
||||
# to provision a separate service.
|
||||
#
|
||||
# Spawn-gating: only start the sidecar when the operator has indicated
|
||||
# they want it — either MEMORY_V2_CUTOVER=true OR MEMORY_PLUGIN_URL set.
|
||||
# Without that signal, the sidecar adds zero value (the platform's
|
||||
# wiring.go skips building the client too) but pays a real cost: the
|
||||
# plugin's first migration runs `CREATE EXTENSION vector`, which fails
|
||||
# on tenant Postgres without pgvector preinstalled and aborts container
|
||||
# boot via the 30s health gate. Caught on staging redeploy 2026-05-05.
|
||||
#
|
||||
# Env defaults (when sidecar IS spawned):
|
||||
# MEMORY_PLUGIN_DATABASE_URL = $DATABASE_URL (share existing Postgres;
|
||||
# plugin's `memory_namespaces` / `memory_records` tables coexist
|
||||
# with `agent_memories` and the rest of the platform schema —
|
||||
# no conflicts. Operator can override with a separate URL.)
|
||||
# MEMORY_PLUGIN_LISTEN_ADDR = 127.0.0.1:9100
|
||||
#
|
||||
# Set MEMORY_PLUGIN_DISABLE=1 to force-skip the sidecar even with
|
||||
# cutover env set (e.g. running the plugin externally on a separate host).
|
||||
memory_plugin_wanted=""
|
||||
if [ "$MEMORY_V2_CUTOVER" = "true" ] || [ -n "$MEMORY_PLUGIN_URL" ]; then
|
||||
memory_plugin_wanted=1
|
||||
fi
|
||||
if [ -z "$MEMORY_PLUGIN_DISABLE" ] && [ -n "$memory_plugin_wanted" ] && [ -n "$DATABASE_URL" ]; then
|
||||
: "${MEMORY_PLUGIN_DATABASE_URL:=$DATABASE_URL}"
|
||||
: "${MEMORY_PLUGIN_LISTEN_ADDR:=:9100}"
|
||||
export MEMORY_PLUGIN_DATABASE_URL MEMORY_PLUGIN_LISTEN_ADDR
|
||||
echo "memory-plugin: starting sidecar on $MEMORY_PLUGIN_LISTEN_ADDR" >&2
|
||||
# Drop privs to the platform user — the plugin doesn't need root and
|
||||
# runs unprivileged elsewhere (tenant image already starts as canvas).
|
||||
su-exec platform /memory-plugin &
|
||||
MEMORY_PLUGIN_PID=$!
|
||||
# Wait up to 30s for the plugin's /v1/health to return 200. Boot
|
||||
# failure here is fatal — better to crash-loop than to silently
|
||||
# serve cutover traffic against a dead plugin.
|
||||
health_port=${MEMORY_PLUGIN_LISTEN_ADDR#:}
|
||||
ready=0
|
||||
for _ in $(seq 1 30); do
|
||||
if wget -qO- --timeout=2 "http://localhost:${health_port}/v1/health" >/dev/null 2>&1; then
|
||||
ready=1
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
if [ "$ready" != "1" ]; then
|
||||
echo "memory-plugin: ❌ /v1/health never returned 200 after 30s — aborting boot. Check that DATABASE_URL is reachable, has the pgvector extension, and the plugin's migrations applied." >&2
|
||||
kill "$MEMORY_PLUGIN_PID" 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
echo "memory-plugin: ✅ sidecar healthy on :$health_port" >&2
|
||||
fi
|
||||
|
||||
exec su-exec platform /platform "$@"
|
||||
ENTRY
|
||||
RUN chmod +x /entrypoint.sh && apk add --no-cache su-exec
|
||||
|
||||
@@ -34,6 +34,13 @@ ARG GIT_SHA=dev
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags "-X github.com/Molecule-AI/molecule-monorepo/platform/internal/buildinfo.GitSHA=${GIT_SHA}" \
|
||||
-o /platform ./cmd/server
|
||||
# Memory v2 sidecar binary (Memory v2 #2728). Bundled so an operator
|
||||
# can activate cutover by flipping MEMORY_V2_CUTOVER=true without
|
||||
# provisioning a separate service. See entrypoint-tenant.sh for the
|
||||
# launch logic.
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags "-X github.com/Molecule-AI/molecule-monorepo/platform/internal/buildinfo.GitSHA=${GIT_SHA}" \
|
||||
-o /memory-plugin ./cmd/memory-plugin-postgres
|
||||
|
||||
# ── Stage 2: Canvas Next.js standalone ────────────────────────────────
|
||||
FROM node:20-alpine AS canvas-builder
|
||||
@@ -74,8 +81,9 @@ RUN deluser --remove-home node 2>/dev/null || true; \
|
||||
delgroup node 2>/dev/null || true; \
|
||||
addgroup -g 1000 canvas && adduser -u 1000 -G canvas -s /bin/sh -D canvas
|
||||
|
||||
# Go platform binary
|
||||
# Go platform binary + Memory v2 sidecar
|
||||
COPY --from=go-builder /platform /platform
|
||||
COPY --from=go-builder /memory-plugin /memory-plugin
|
||||
COPY workspace-server/migrations /migrations
|
||||
|
||||
# Templates + plugins (cloned from GitHub in stage 3)
|
||||
@@ -91,7 +99,7 @@ COPY --from=canvas-builder /canvas/public ./public
|
||||
|
||||
COPY workspace-server/entrypoint-tenant.sh /entrypoint.sh
|
||||
RUN chmod +x /entrypoint.sh && \
|
||||
chown -R canvas:canvas /canvas /platform /migrations
|
||||
chown -R canvas:canvas /canvas /platform /memory-plugin /migrations
|
||||
|
||||
EXPOSE 8080
|
||||
# entrypoint.sh starts as root to fix volume perms, then drops to
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestLoadConfig_DefaultListenAddrIsLoopback pins the default-bind contract.
|
||||
//
|
||||
// Why this matters: with the prior `:9100` default, the plugin listened on
|
||||
// every interface. Inside the container it didn't matter (no host port
|
||||
// mapping today), but a future change that publishes 9100 OR a cross-host
|
||||
// sidecar deploy would have exposed an unauth'd memory store. Loopback by
|
||||
// default is the least-privilege baseline; operators with a multi-host
|
||||
// topology override via MEMORY_PLUGIN_LISTEN_ADDR.
|
||||
func TestLoadConfig_DefaultListenAddrIsLoopback(t *testing.T) {
|
||||
t.Setenv("MEMORY_PLUGIN_DATABASE_URL", "postgres://stub")
|
||||
t.Setenv("MEMORY_PLUGIN_LISTEN_ADDR", "")
|
||||
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("loadConfig: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(cfg.ListenAddr, "127.0.0.1:") {
|
||||
t.Errorf("default ListenAddr must bind loopback-only, got %q "+
|
||||
"(security regression — would expose plugin on every interface)",
|
||||
cfg.ListenAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_ListenAddrEnvOverride(t *testing.T) {
|
||||
t.Setenv("MEMORY_PLUGIN_DATABASE_URL", "postgres://stub")
|
||||
t.Setenv("MEMORY_PLUGIN_LISTEN_ADDR", ":9100")
|
||||
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("loadConfig: %v", err)
|
||||
}
|
||||
if cfg.ListenAddr != ":9100" {
|
||||
t.Errorf("env override ignored: want :9100, got %q", cfg.ListenAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_MissingDatabaseURL(t *testing.T) {
|
||||
t.Setenv("MEMORY_PLUGIN_DATABASE_URL", "")
|
||||
|
||||
if _, err := loadConfig(); err == nil {
|
||||
t.Fatal("loadConfig must error when MEMORY_PLUGIN_DATABASE_URL is empty")
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sort"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -26,12 +28,28 @@ import (
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/pgplugin"
|
||||
)
|
||||
|
||||
// migrationsFS bundles the .up.sql files into the binary at build time
|
||||
// so the prebuilt image doesn't need the source tree at runtime. The
|
||||
// prior `os.ReadDir("cmd/memory-plugin-postgres/migrations")` path
|
||||
// only resolved during `go test` from the repo root — in the published
|
||||
// image the path didn't exist and boot failed after the 30s health gate
|
||||
// (caught on staging redeploy 2026-05-05 after PR #2906).
|
||||
//
|
||||
//go:embed migrations/*.up.sql
|
||||
var migrationsFS embed.FS
|
||||
|
||||
const (
|
||||
envDatabaseURL = "MEMORY_PLUGIN_DATABASE_URL"
|
||||
envListenAddr = "MEMORY_PLUGIN_LISTEN_ADDR"
|
||||
envSkipMigrate = "MEMORY_PLUGIN_SKIP_MIGRATE"
|
||||
|
||||
defaultListenAddr = ":9100"
|
||||
// Loopback-only by default (defense in depth). The platform talks to
|
||||
// the plugin over `http://localhost:9100` from the same container, so
|
||||
// binding to all interfaces would only widen the reachable surface
|
||||
// without enabling any in-design caller. Operators running the plugin
|
||||
// on a separate host override via MEMORY_PLUGIN_LISTEN_ADDR=:9100 (or
|
||||
// some other interface).
|
||||
defaultListenAddr = "127.0.0.1:9100"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -143,32 +161,71 @@ func openDB(databaseURL string) (*sql.DB, error) {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// runMigrations applies the schema migrations bundled at
|
||||
// cmd/memory-plugin-postgres/migrations/. Idempotent on repeat boot.
|
||||
// runMigrations applies the schema migrations bundled into the binary
|
||||
// via go:embed (see migrationsFS at the top of this file). Idempotent
|
||||
// on repeat boot — every migration file uses CREATE … IF NOT EXISTS.
|
||||
//
|
||||
// Implementation note: rather than embedding the full migrate engine,
|
||||
// we read the migration files at boot from a known relative path. The
|
||||
// down migrations are deliberately NOT applied here — that's a manual
|
||||
// operator action. This keeps the binary tiny and avoids dragging in
|
||||
// golang-migrate's drivers.
|
||||
// The down migrations are deliberately NOT applied here — that's a
|
||||
// manual operator action. This keeps the binary tiny and avoids
|
||||
// dragging in golang-migrate's drivers.
|
||||
//
|
||||
// MEMORY_PLUGIN_MIGRATIONS_DIR (filesystem path) is honored as an
|
||||
// override for operators who need to ship custom migrations alongside
|
||||
// the binary without rebuilding. When unset (the common case) we read
|
||||
// from the embedded FS.
|
||||
func runMigrations(db *sql.DB) error {
|
||||
// Find the migrations directory. In `go run` mode it's relative
|
||||
// to the cmd dir; in the prebuilt binary case it's expected next
|
||||
// to the binary OR via env var override.
|
||||
dir := os.Getenv("MEMORY_PLUGIN_MIGRATIONS_DIR")
|
||||
if dir == "" {
|
||||
// Best-effort: try the cwd-relative path that works for `go test`.
|
||||
dir = "cmd/memory-plugin-postgres/migrations"
|
||||
if dir := strings.TrimSpace(os.Getenv("MEMORY_PLUGIN_MIGRATIONS_DIR")); dir != "" {
|
||||
return runMigrationsFromDisk(db, dir)
|
||||
}
|
||||
entries, err := os.ReadDir(dir)
|
||||
return runMigrationsFromEmbed(db)
|
||||
}
|
||||
|
||||
// runMigrationsFromEmbed applies the *.up.sql files bundled into the
|
||||
// binary at build time. Order is alphabetical (matches the on-disk
|
||||
// behavior of os.ReadDir on Linux for the same set of names).
|
||||
func runMigrationsFromEmbed(db *sql.DB) error {
|
||||
entries, err := migrationsFS.ReadDir("migrations")
|
||||
if err != nil {
|
||||
return fmt.Errorf("read migrations dir %q: %w", dir, err)
|
||||
return fmt.Errorf("read embedded migrations: %w", err)
|
||||
}
|
||||
names := make([]string, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
|
||||
continue
|
||||
}
|
||||
path := dir + "/" + e.Name()
|
||||
names = append(names, e.Name())
|
||||
}
|
||||
sort.Strings(names)
|
||||
for _, name := range names {
|
||||
data, err := migrationsFS.ReadFile("migrations/" + name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read embedded %q: %w", name, err)
|
||||
}
|
||||
if _, err := db.Exec(string(data)); err != nil {
|
||||
return fmt.Errorf("apply %q: %w", name, err)
|
||||
}
|
||||
log.Printf("applied embedded migration %s", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runMigrationsFromDisk preserves the legacy filesystem-path mode for
|
||||
// operator-supplied custom migrations.
|
||||
func runMigrationsFromDisk(db *sql.DB, dir string) error {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read migrations dir %q: %w", dir, err)
|
||||
}
|
||||
names := make([]string, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
|
||||
continue
|
||||
}
|
||||
names = append(names, e.Name())
|
||||
}
|
||||
sort.Strings(names)
|
||||
for _, name := range names {
|
||||
path := dir + "/" + name
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %q: %w", path, err)
|
||||
@@ -176,7 +233,7 @@ func runMigrations(db *sql.DB) error {
|
||||
if _, err := db.Exec(string(data)); err != nil {
|
||||
return fmt.Errorf("apply %q: %w", path, err)
|
||||
}
|
||||
log.Printf("applied migration %s", e.Name())
|
||||
log.Printf("applied disk migration %s (from %s)", name, dir)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestMigrationsEmbedded_ContainsCreateTable pins that the migrations
|
||||
// are bundled into the binary at build time, NOT loaded from a
|
||||
// filesystem path that doesn't exist at runtime in the published image.
|
||||
//
|
||||
// Pre-fix: PR #2906 shipped the binary without the migrations dir;
|
||||
// `os.ReadDir("cmd/memory-plugin-postgres/migrations")` errored on every
|
||||
// tenant boot, the 30s health gate aborted the container, and the
|
||||
// staging redeploy fleet job marked all tenants as failed. Embedding
|
||||
// the migrations into the binary removes the runtime path entirely.
|
||||
func TestMigrationsEmbedded_ContainsCreateTable(t *testing.T) {
|
||||
entries, err := migrationsFS.ReadDir("migrations")
|
||||
if err != nil {
|
||||
t.Fatalf("embedded migrations dir unreadable: %v", err)
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
t.Fatal("embedded migrations dir is empty — go:embed pattern matched no files")
|
||||
}
|
||||
|
||||
var seenUp bool
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
|
||||
continue
|
||||
}
|
||||
seenUp = true
|
||||
data, err := migrationsFS.ReadFile("migrations/" + e.Name())
|
||||
if err != nil {
|
||||
t.Errorf("read embedded %q: %v", e.Name(), err)
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(string(data), "CREATE TABLE") {
|
||||
t.Errorf("embedded %q has no CREATE TABLE — wrong file embedded?", e.Name())
|
||||
}
|
||||
}
|
||||
if !seenUp {
|
||||
t.Fatal("no *.up.sql in embedded migrations — runtime would have no schema to apply")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRunMigrationsFromEmbed_OrderingIsAlphabetic pins that we apply
|
||||
// migrations in deterministic alphabetical order, not in whatever
|
||||
// arbitrary order migrationsFS.ReadDir happens to return. With one
|
||||
// migration today this is moot, but a future second migration ('002_…')
|
||||
// MUST run after '001_…' or the schema is broken.
|
||||
//
|
||||
// We can't easily exercise db.Exec here (no test DB); instead pin the
|
||||
// sort step on the directory listing itself.
|
||||
func TestRunMigrationsFromEmbed_OrderingIsAlphabetic(t *testing.T) {
|
||||
entries, err := migrationsFS.ReadDir("migrations")
|
||||
if err != nil {
|
||||
t.Fatalf("embedded migrations dir unreadable: %v", err)
|
||||
}
|
||||
var names []string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
|
||||
continue
|
||||
}
|
||||
names = append(names, e.Name())
|
||||
}
|
||||
for i := 1; i < len(names); i++ {
|
||||
if names[i-1] > names[i] {
|
||||
t.Errorf("ReadDir returned non-sorted names; runMigrationsFromEmbed must sort. "+
|
||||
"Got %q before %q", names[i-1], names[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,51 @@ cd /canvas
|
||||
PORT=3000 HOSTNAME=0.0.0.0 node server.js &
|
||||
CANVAS_PID=$!
|
||||
|
||||
# Memory v2 sidecar (built-in postgres plugin). See Dockerfile entrypoint
|
||||
# comment for rationale.
|
||||
#
|
||||
# Spawn-gating: only start the sidecar when the operator has indicated
|
||||
# they want it (MEMORY_V2_CUTOVER=true OR MEMORY_PLUGIN_URL set).
|
||||
# Without that signal, the sidecar adds zero value and risks aborting
|
||||
# tenant boot via the 30s health gate when the tenant Postgres lacks
|
||||
# pgvector. Caught on staging redeploy 2026-05-05:
|
||||
# pq: extension "vector" is not available
|
||||
#
|
||||
# Defaults (when sidecar IS spawned): MEMORY_PLUGIN_DATABASE_URL
|
||||
# falls back to the tenant's DATABASE_URL.
|
||||
MEMORY_PLUGIN_PID=""
|
||||
memory_plugin_wanted=""
|
||||
if [ "$MEMORY_V2_CUTOVER" = "true" ] || [ -n "$MEMORY_PLUGIN_URL" ]; then
|
||||
memory_plugin_wanted=1
|
||||
fi
|
||||
if [ -z "$MEMORY_PLUGIN_DISABLE" ] && [ -n "$memory_plugin_wanted" ] && [ -n "$DATABASE_URL" ]; then
|
||||
: "${MEMORY_PLUGIN_DATABASE_URL:=$DATABASE_URL}"
|
||||
: "${MEMORY_PLUGIN_LISTEN_ADDR:=:9100}"
|
||||
export MEMORY_PLUGIN_DATABASE_URL MEMORY_PLUGIN_LISTEN_ADDR
|
||||
echo "memory-plugin: starting sidecar on $MEMORY_PLUGIN_LISTEN_ADDR" >&2
|
||||
/memory-plugin &
|
||||
MEMORY_PLUGIN_PID=$!
|
||||
# Wait up to 30s for /v1/health. Boot failure is fatal so a misconfigured
|
||||
# tenant crash-loops instead of silently serving cutover traffic against
|
||||
# a dead plugin.
|
||||
health_port=${MEMORY_PLUGIN_LISTEN_ADDR#:}
|
||||
ready=0
|
||||
for _ in $(seq 1 30); do
|
||||
if wget -qO- --timeout=2 "http://localhost:${health_port}/v1/health" >/dev/null 2>&1; then
|
||||
ready=1
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
if [ "$ready" != "1" ]; then
|
||||
echo "memory-plugin: ❌ /v1/health never returned 200 after 30s — aborting boot. Check DATABASE_URL reachability + pgvector extension + migrations." >&2
|
||||
kill "$MEMORY_PLUGIN_PID" 2>/dev/null || true
|
||||
kill "$CANVAS_PID" 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
echo "memory-plugin: ✅ sidecar healthy on :$health_port" >&2
|
||||
fi
|
||||
|
||||
# Start Go platform in foreground-ish (we trap signals)
|
||||
# CANVAS_PROXY_URL tells the platform to proxy unmatched routes to Canvas.
|
||||
# CONTAINER_BACKEND: empty = Docker (default for self-hosted/local).
|
||||
@@ -29,15 +74,20 @@ cd /
|
||||
/platform &
|
||||
PLATFORM_PID=$!
|
||||
|
||||
# If either process exits, kill the other
|
||||
# If any process exits, kill the others
|
||||
cleanup() {
|
||||
kill $CANVAS_PID 2>/dev/null || true
|
||||
kill $PLATFORM_PID 2>/dev/null || true
|
||||
[ -n "$MEMORY_PLUGIN_PID" ] && kill $MEMORY_PLUGIN_PID 2>/dev/null || true
|
||||
}
|
||||
trap cleanup EXIT SIGTERM SIGINT
|
||||
|
||||
# Wait for either to exit — whichever exits first triggers cleanup
|
||||
wait -n $CANVAS_PID $PLATFORM_PID
|
||||
# Wait for any to exit — whichever exits first triggers cleanup
|
||||
if [ -n "$MEMORY_PLUGIN_PID" ]; then
|
||||
wait -n $CANVAS_PID $PLATFORM_PID $MEMORY_PLUGIN_PID
|
||||
else
|
||||
wait -n $CANVAS_PID $PLATFORM_PID
|
||||
fi
|
||||
EXIT_CODE=$?
|
||||
cleanup
|
||||
exit $EXIT_CODE
|
||||
|
||||
@@ -600,14 +600,21 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]uploadedFile, 0, len(headers))
|
||||
// Phase 1: pre-validate + read every part BEFORE any DB write.
|
||||
// A multi-file upload must commit all-or-nothing; a per-file
|
||||
// failure halfway through used to leave rows 1..K-1 in the table
|
||||
// while the client got a 500 and retried the whole batch — duplicate
|
||||
// rows, orphan activity rows. Validating up-front + atomic PutBatch
|
||||
// closes that gap.
|
||||
type prepped struct {
|
||||
Sanitized string
|
||||
Mimetype string
|
||||
Content []byte
|
||||
Original string // original (unsanitized) filename for error messages
|
||||
}
|
||||
prepReady := make([]prepped, 0, len(headers))
|
||||
items := make([]pendinguploads.PutItem, 0, len(headers))
|
||||
for _, fh := range headers {
|
||||
// Read full content. Per-file cap enforced post-read so an
|
||||
// oversized file fails with a clean 413 rather than a torn
|
||||
// stream. The +1 byte ReadAll trick that the Python side
|
||||
// uses isn't easy through multipart.FileHeader; instead we
|
||||
// rely on the multipart layer's ContentLength header and
|
||||
// short-circuit before opening the part.
|
||||
if fh.Size > pendinguploads.MaxFileBytes {
|
||||
log.Printf("chat_files uploadPollMode: per-file cap exceeded for %s: %s (%d bytes)",
|
||||
workspaceID, fh.Filename, fh.Size)
|
||||
@@ -621,45 +628,67 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w
|
||||
}
|
||||
content, err := readMultipartFile(fh)
|
||||
if err != nil {
|
||||
log.Printf("chat_files uploadPollMode: read part failed for %s/%s: %v", workspaceID, fh.Filename, err)
|
||||
log.Printf("chat_files uploadPollMode: read part failed for %s/%s: %v",
|
||||
workspaceID, fh.Filename, err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "could not read file part"})
|
||||
return
|
||||
}
|
||||
|
||||
sanitized := SanitizeFilename(fh.Filename)
|
||||
mimetype := fh.Header.Get("Content-Type")
|
||||
|
||||
fileID, err := h.pendingUploads.Put(ctx, wsUUID, content, sanitized, mimetype)
|
||||
if err != nil {
|
||||
if errors.Is(err, pendinguploads.ErrTooLarge) {
|
||||
// Belt + suspenders: the size check above already
|
||||
// caught this, but Storage.Put re-validates so a
|
||||
// malformed FileHeader can't slip through. 413 with
|
||||
// the same shape so the client sees one error class.
|
||||
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
|
||||
"error": "file exceeds per-file cap",
|
||||
"filename": fh.Filename,
|
||||
"size": len(content),
|
||||
"max": pendinguploads.MaxFileBytes,
|
||||
})
|
||||
return
|
||||
}
|
||||
log.Printf("chat_files uploadPollMode: storage.Put failed for %s/%s: %v",
|
||||
workspaceID, sanitized, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage file"})
|
||||
// Belt-and-braces post-read cap (multipart.FileHeader.Size can lie
|
||||
// on some clients that don't set Content-Length per part).
|
||||
if len(content) > pendinguploads.MaxFileBytes {
|
||||
log.Printf("chat_files uploadPollMode: per-file cap exceeded post-read for %s: %s (%d bytes)",
|
||||
workspaceID, fh.Filename, len(content))
|
||||
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
|
||||
"error": "file exceeds per-file cap",
|
||||
"filename": fh.Filename,
|
||||
"size": len(content),
|
||||
"max": pendinguploads.MaxFileBytes,
|
||||
})
|
||||
return
|
||||
}
|
||||
sanitized := SanitizeFilename(fh.Filename)
|
||||
mimetype := safeMimetype(fh.Header.Get("Content-Type"))
|
||||
prepReady = append(prepReady, prepped{
|
||||
Sanitized: sanitized, Mimetype: mimetype, Content: content, Original: fh.Filename,
|
||||
})
|
||||
items = append(items, pendinguploads.PutItem{
|
||||
Content: content, Filename: sanitized, Mimetype: mimetype,
|
||||
})
|
||||
}
|
||||
|
||||
// Activity row so the workspace's inbox poller picks this up
|
||||
// on its next cycle. activity_type=a2a_receive (NOT a new
|
||||
// type) so the existing poll filter
|
||||
// `?type=a2a_receive` catches it without poll-side changes;
|
||||
// method=chat_upload_receive is the discriminator the
|
||||
// workspace's adapter (Phase 2) uses to route to the upload
|
||||
// fetcher instead of the agent's message handler. Same
|
||||
// shape as A2A's tasks/send vs message/send method split.
|
||||
// Phase 2: atomic batch insert. On failure no rows commit.
|
||||
fileIDs, err := h.pendingUploads.PutBatch(ctx, wsUUID, items)
|
||||
if err != nil {
|
||||
if errors.Is(err, pendinguploads.ErrTooLarge) {
|
||||
// Belt + suspenders: pre-validation above already caught
|
||||
// this; surface a clean 413 if a malformed FileHeader
|
||||
// somehow slipped through.
|
||||
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
|
||||
"error": "one or more files exceed per-file cap",
|
||||
"max": pendinguploads.MaxFileBytes,
|
||||
})
|
||||
return
|
||||
}
|
||||
log.Printf("chat_files uploadPollMode: storage.PutBatch failed for %s: %v",
|
||||
workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"})
|
||||
return
|
||||
}
|
||||
|
||||
// Phase 3: write per-file activity rows and build the response. Activity
|
||||
// rows are written individually (not part of the same Tx as PutBatch)
|
||||
// because LogActivity is shared across many handlers and threading the
|
||||
// Tx through would be a bigger refactor. The trade-off: if an activity
|
||||
// write fails after the PutBatch commits, the pending_uploads rows
|
||||
// orphan until the 24h TTL — significantly better than the previous
|
||||
// "every multi-file upload could orphan" behavior, and the workspace's
|
||||
// fetcher handles soft-404 cleanly when activity rows reference a row
|
||||
// the platform later expired.
|
||||
out := make([]uploadedFile, 0, len(prepReady))
|
||||
for i, p := range prepReady {
|
||||
fileID := fileIDs[i]
|
||||
uri := fmt.Sprintf("platform-pending:%s/%s", workspaceID, fileID)
|
||||
summary := "chat_upload_receive: " + sanitized
|
||||
summary := "chat_upload_receive: " + p.Sanitized
|
||||
method := "chat_upload_receive"
|
||||
LogActivity(ctx, h.broadcaster, ActivityParams{
|
||||
WorkspaceID: workspaceID,
|
||||
@@ -669,28 +698,65 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w
|
||||
Summary: &summary,
|
||||
RequestBody: map[string]interface{}{
|
||||
"file_id": fileID.String(),
|
||||
"name": sanitized,
|
||||
"mimeType": mimetype,
|
||||
"size": len(content),
|
||||
"name": p.Sanitized,
|
||||
"mimeType": p.Mimetype,
|
||||
"size": len(p.Content),
|
||||
"uri": uri,
|
||||
},
|
||||
Status: "ok",
|
||||
})
|
||||
|
||||
log.Printf("chat_files uploadPollMode: staged %s/%s (file_id=%s size=%d mimetype=%q)",
|
||||
workspaceID, sanitized, fileID, len(content), mimetype)
|
||||
workspaceID, p.Sanitized, fileID, len(p.Content), p.Mimetype)
|
||||
|
||||
out = append(out, uploadedFile{
|
||||
URI: uri,
|
||||
Name: sanitized,
|
||||
Mimetype: mimetype,
|
||||
Size: int64(len(content)),
|
||||
Name: p.Sanitized,
|
||||
Mimetype: p.Mimetype,
|
||||
Size: int64(len(p.Content)),
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"files": out})
|
||||
}
|
||||
|
||||
// safeMimetype validates a multipart-supplied Content-Type header and
|
||||
// returns a sanitized value safe to store + serve back unmodified.
|
||||
//
|
||||
// The platform's GET /content handler reflects the stored mimetype as
|
||||
// the response Content-Type. An attacker-controlled header that
|
||||
// embedded CR/LF could split the response (header injection); a value
|
||||
// containing semicolons could carry an unexpected charset parameter
|
||||
// that confuses a downstream renderer. Strip CR/LF/control chars +
|
||||
// keep only the type/subtype prefix; reject anything that doesn't
|
||||
// match a basic `type/subtype` regex by falling back to the safe
|
||||
// default (application/octet-stream — the workspace-side handler does
|
||||
// the same fallback).
|
||||
func safeMimetype(raw string) string {
|
||||
const fallback = "application/octet-stream"
|
||||
// Trim parameters (`text/html; charset=utf-8` → `text/html`).
|
||||
if i := strings.IndexByte(raw, ';'); i >= 0 {
|
||||
raw = raw[:i]
|
||||
}
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
// Reject if any control char or whitespace is present (header
|
||||
// injection defense). RFC 7231 mimetype grammar forbids whitespace.
|
||||
for _, r := range raw {
|
||||
if r < 0x21 || r > 0x7e {
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
// Require exactly one slash separating type and subtype.
|
||||
parts := strings.Split(raw, "/")
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return fallback
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
// readMultipartFile reads a multipart part fully into memory. Wraps
|
||||
// the open + io.ReadAll + close idiom so the call site stays clean,
|
||||
// and so a future change (chunked reads / hashing) has one place to
|
||||
|
||||
@@ -67,6 +67,46 @@ func (s *inMemStorage) Put(_ context.Context, ws uuid.UUID, content []byte, file
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// PutBatch mirrors the production atomic-batch contract: any per-item
|
||||
// failure leaves the in-memory state unchanged, simulating Tx rollback.
|
||||
// Pre-validation matches PostgresStorage.PutBatch; oversized items
|
||||
// return ErrTooLarge before any row is added.
|
||||
func (s *inMemStorage) PutBatch(_ context.Context, ws uuid.UUID, items []pendinguploads.PutItem) ([]uuid.UUID, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.putErr != nil {
|
||||
return nil, s.putErr
|
||||
}
|
||||
// Pre-validate so an oversized item rejects the whole batch before
|
||||
// any state mutation — matches the Tx-rollback semantics.
|
||||
for _, it := range items {
|
||||
if len(it.Content) > pendinguploads.MaxFileBytes {
|
||||
return nil, pendinguploads.ErrTooLarge
|
||||
}
|
||||
}
|
||||
ids := make([]uuid.UUID, 0, len(items))
|
||||
stagedRows := make(map[uuid.UUID]pendinguploads.Record, len(items))
|
||||
stagedPuts := make([]putCall, 0, len(items))
|
||||
for _, it := range items {
|
||||
id := uuid.New()
|
||||
stagedRows[id] = pendinguploads.Record{
|
||||
FileID: id, WorkspaceID: ws, Content: it.Content,
|
||||
Filename: it.Filename, Mimetype: it.Mimetype,
|
||||
SizeBytes: int64(len(it.Content)), CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
stagedPuts = append(stagedPuts, putCall{
|
||||
WorkspaceID: ws, Filename: it.Filename, Mimetype: it.Mimetype, Size: len(it.Content),
|
||||
})
|
||||
ids = append(ids, id)
|
||||
}
|
||||
for id, r := range stagedRows {
|
||||
s.rows[id] = r
|
||||
}
|
||||
s.puts = append(s.puts, stagedPuts...)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (s *inMemStorage) Get(context.Context, uuid.UUID) (pendinguploads.Record, error) {
|
||||
return pendinguploads.Record{}, pendinguploads.ErrNotFound
|
||||
}
|
||||
@@ -161,7 +201,7 @@ func TestPollUpload_HappyPath_OneFile_StagesAndLogs(t *testing.T) {
|
||||
expectActivityInsert(mock)
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"report.pdf": []byte("PDF-bytes")})
|
||||
@@ -219,7 +259,7 @@ func TestPollUpload_MultipleFiles_AllStagedAndLogged(t *testing.T) {
|
||||
expectActivityInsert(mock)
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{
|
||||
@@ -257,7 +297,7 @@ func TestPollUpload_PushModeFallsThroughToForward(t *testing.T) {
|
||||
// URL empty + mode=push → 503 (no inbound secret check needed).
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x": []byte("data")})
|
||||
@@ -281,7 +321,7 @@ func TestPollUpload_NotConfigured_FallsThrough(t *testing.T) {
|
||||
wsID := "33333333-2222-3333-4444-555555555555"
|
||||
expectURLAndMode(mock, wsID, "", "poll") // resolveWorkspaceForwardCreds emits 422
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
// No WithPendingUploads — pendingUploads is nil.
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x": []byte("data")})
|
||||
@@ -302,7 +342,7 @@ func TestPollUpload_WorkspaceMissing_404(t *testing.T) {
|
||||
wsID := "44444444-2222-3333-4444-555555555555"
|
||||
expectPollDeliveryModeMissing(mock, wsID)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(newInMemStorage(), nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x": []byte("d")})
|
||||
@@ -322,7 +362,7 @@ func TestPollUpload_DeliveryModeLookupDBError_500(t *testing.T) {
|
||||
mock.ExpectQuery(`SELECT delivery_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).WillReturnError(errors.New("connection lost"))
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(newInMemStorage(), nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x": []byte("d")})
|
||||
@@ -342,7 +382,7 @@ func TestPollUpload_NoFilesField_400(t *testing.T) {
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
// Multipart with a non-files field — no actual files.
|
||||
@@ -367,7 +407,7 @@ func TestPollUpload_MalformedMultipart_400(t *testing.T) {
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
// Body that doesn't match the boundary in Content-Type.
|
||||
@@ -388,7 +428,7 @@ func TestPollUpload_StorageError_500(t *testing.T) {
|
||||
|
||||
store := newInMemStorage()
|
||||
store.putErr = errors.New("disk full")
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x.bin": []byte("data")})
|
||||
@@ -409,7 +449,7 @@ func TestPollUpload_StorageTooLarge_413(t *testing.T) {
|
||||
|
||||
store := newInMemStorage()
|
||||
store.putErr = pendinguploads.ErrTooLarge
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x.bin": []byte("data")})
|
||||
@@ -429,7 +469,7 @@ func TestPollUpload_TooManyFiles_400(t *testing.T) {
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
// 65 files — over the per-batch cap.
|
||||
@@ -464,7 +504,7 @@ func TestPollUpload_NullDeliveryMode_TreatedAsPush(t *testing.T) {
|
||||
expectURLAndMode(mock, wsID, "", "")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x.bin": []byte("data")})
|
||||
@@ -497,7 +537,7 @@ func TestPollUpload_PerFileCapPreStorage_413(t *testing.T) {
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
// 25 MB + 1 byte. Single file, large enough to trip the early
|
||||
@@ -532,7 +572,7 @@ func TestPollUpload_SanitizesFilenameInResponse(t *testing.T) {
|
||||
expectActivityInsert(mock)
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"hello world!.pdf": []byte("data")})
|
||||
@@ -557,6 +597,120 @@ func TestPollUpload_SanitizesFilenameInResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPollUpload_AtomicRollbackOnSecondFileTooLarge pins the
|
||||
// transactional contract introduced in phase 5: when one file in a
|
||||
// multi-file batch fails pre-validation (oversize), NONE of the files
|
||||
// in the batch land in storage. Previously a per-file Put loop would
|
||||
// stage rows 1..K-1 before failing on row K, leaving orphan
|
||||
// pending_uploads + activity rows the client would re-create on retry.
|
||||
//
|
||||
// Pinned via inMemStorage's PutBatch (which mirrors PostgresStorage's
|
||||
// Tx-rollback behavior on a per-item validation failure) — but the
|
||||
// real atomicity guarantee is the integration test in
|
||||
// pending_uploads_integration_test.go.
|
||||
func TestPollUpload_AtomicRollbackOnSecondFileTooLarge(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "aaaaaaaa-3333-3333-4444-555555555555"
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
// Two files: first OK, second over the per-file cap. Pre-validation
|
||||
// in uploadPollMode catches it BEFORE any Put — store.puts must
|
||||
// stay empty. (If the test ever sees len=1, the regression is
|
||||
// "first file slipped through into storage on a partial-failure
|
||||
// batch.")
|
||||
tooBig := bytes.Repeat([]byte{0x42}, pendinguploads.MaxFileBytes+1)
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{
|
||||
"ok.txt": []byte("small"),
|
||||
"huge.bin": tooBig,
|
||||
})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Errorf("status=%d body=%s, want 413", w.Code, w.Body.String())
|
||||
}
|
||||
if len(store.puts) != 0 {
|
||||
t.Errorf("expected zero Puts on rollback, got %d: %+v", len(store.puts), store.puts)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPollUpload_AtomicRollbackOnPutBatchError validates that an in-
|
||||
// flight PutBatch failure (e.g. simulated DB error) leaves zero rows
|
||||
// — same guarantee as the pre-validation path, but exercises the
|
||||
// "Tx-Rollback after BEGIN" branch via the fake.
|
||||
func TestPollUpload_AtomicRollbackOnPutBatchError(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
wsID := "bbbbbbbb-3333-3333-4444-555555555555"
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
store.putErr = errors.New("db down mid-batch")
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{
|
||||
"a.txt": []byte("aaa"),
|
||||
"b.txt": []byte("bbb"),
|
||||
"c.txt": []byte("ccc"),
|
||||
})
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("status=%d, want 500", w.Code)
|
||||
}
|
||||
if len(store.puts) != 0 {
|
||||
t.Errorf("expected zero Puts after PutBatch error, got %d", len(store.puts))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPollUpload_MimetypeWithCRLFInjectionStripped pins the safeMimetype
|
||||
// hardening: a multipart-supplied Content-Type header with CR/LF is
|
||||
// rewritten to application/octet-stream so the eventual /content
|
||||
// response can't be header-split on the wire.
|
||||
func TestPollUpload_MimetypeWithCRLFInjectionStripped(t *testing.T) {
|
||||
got := safeMimetype("text/html\r\nX-Injected: pwn")
|
||||
if got != "application/octet-stream" {
|
||||
t.Errorf("CRLF mimetype not stripped, got %q", got)
|
||||
}
|
||||
got = safeMimetype("image/png\x00")
|
||||
if got != "application/octet-stream" {
|
||||
t.Errorf("NUL byte mimetype not stripped, got %q", got)
|
||||
}
|
||||
got = safeMimetype("text/plain; charset=utf-8")
|
||||
if got != "text/plain" {
|
||||
t.Errorf("parameter not stripped, got %q", got)
|
||||
}
|
||||
got = safeMimetype("application/pdf")
|
||||
if got != "application/pdf" {
|
||||
t.Errorf("clean mime modified, got %q", got)
|
||||
}
|
||||
got = safeMimetype("")
|
||||
if got != "" {
|
||||
t.Errorf("empty input should pass through, got %q", got)
|
||||
}
|
||||
got = safeMimetype("notamime")
|
||||
if got != "application/octet-stream" {
|
||||
t.Errorf("non-type/subtype not coerced, got %q", got)
|
||||
}
|
||||
got = safeMimetype("/empty-type")
|
||||
if got != "application/octet-stream" {
|
||||
t.Errorf("missing type half not coerced, got %q", got)
|
||||
}
|
||||
got = safeMimetype("type/")
|
||||
if got != "application/octet-stream" {
|
||||
t.Errorf("missing subtype half not coerced, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPollUpload_ActivityRowDiscriminator pins the
|
||||
// activity_type / method shape that the workspace inbox poller depends
|
||||
// on. The poller filters `GET /workspaces/:id/activity?type=a2a_receive`
|
||||
@@ -580,7 +734,7 @@ func TestPollUpload_ActivityRowDiscriminator(t *testing.T) {
|
||||
expectActivityInsertWithTypeAndMethod(mock, wsID, "a2a_receive", "chat_upload_receive")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x.pdf": []byte("xx")})
|
||||
|
||||
@@ -105,7 +105,7 @@ func TestChatUpload_InvalidWorkspaceID(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
|
||||
c, w := makeUploadRequest(t, "not-a-uuid", &bytes.Buffer{}, "")
|
||||
h.Upload(c)
|
||||
@@ -122,7 +122,7 @@ func TestChatUpload_WorkspaceNotInDB(t *testing.T) {
|
||||
wsID := "00000000-0000-0000-0000-000000000099"
|
||||
expectURLMissing(mock, wsID)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -166,7 +166,7 @@ func TestChatUpload_NoInboundSecret_LazyHeal(t *testing.T) {
|
||||
WithArgs(sqlmock.AnyArg(), wsID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -203,7 +203,7 @@ func TestChatUpload_NoInboundSecret_LazyHealFailure(t *testing.T) {
|
||||
WithArgs(sqlmock.AnyArg(), wsID).
|
||||
WillReturnError(sql.ErrConnDone) // mint fails
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -231,7 +231,7 @@ func TestChatUpload_NoURL(t *testing.T) {
|
||||
wsID := "00000000-0000-0000-0000-000000000042"
|
||||
expectURLAndMode(mock, wsID, "", "push")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -256,7 +256,7 @@ func TestChatUpload_PollModeEmptyURL(t *testing.T) {
|
||||
wsID := "00000000-0000-0000-0000-000000000099"
|
||||
expectURLAndMode(mock, wsID, "", "poll")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -286,7 +286,7 @@ func TestChatUpload_NullModeEmptyURL(t *testing.T) {
|
||||
wsID := "30ba7f0b-b303-4a20-aefe-3a4a675b8aa4" // user's "mac laptop"
|
||||
expectURLNullMode(mock, wsID, "")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -338,7 +338,7 @@ func TestChatUpload_ForwardsToWorkspace_HappyPath(t *testing.T) {
|
||||
expectURL(mock, wsID, srv.URL)
|
||||
expectInboundSecret(mock, wsID, "super-secret-123")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -380,7 +380,7 @@ func TestChatUpload_ForwardsErrorStatusUnchanged(t *testing.T) {
|
||||
expectURL(mock, wsID, srv.URL)
|
||||
expectInboundSecret(mock, wsID, "tok")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -402,7 +402,7 @@ func TestChatUpload_WorkspaceUnreachable(t *testing.T) {
|
||||
expectURL(mock, wsID, "http://127.0.0.1:1")
|
||||
expectInboundSecret(mock, wsID, "tok")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -418,7 +418,7 @@ func TestChatDownload_InvalidPath(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
|
||||
cases := []struct {
|
||||
name, path, wantSubstr string
|
||||
@@ -507,7 +507,7 @@ func TestChatDownload_WorkspaceNotInDB(t *testing.T) {
|
||||
WithArgs(wsID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
c, w := makeDownloadRequest(t, wsID, "/workspace/foo.txt")
|
||||
h.Download(c)
|
||||
|
||||
@@ -533,7 +533,7 @@ func TestChatDownload_NoInboundSecret_LazyHeal(t *testing.T) {
|
||||
WithArgs(sqlmock.AnyArg(), wsID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
c, w := makeDownloadRequest(t, wsID, "/workspace/foo.txt")
|
||||
h.Download(c)
|
||||
|
||||
@@ -559,7 +559,7 @@ func TestChatDownload_NoInboundSecret_LazyHealFailure(t *testing.T) {
|
||||
WithArgs(sqlmock.AnyArg(), wsID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
c, w := makeDownloadRequest(t, wsID, "/workspace/foo.txt")
|
||||
h.Download(c)
|
||||
|
||||
@@ -592,7 +592,7 @@ func TestChatDownload_ForwardsToWorkspace_HappyPath(t *testing.T) {
|
||||
expectURL(mock, wsID, srv.URL)
|
||||
expectInboundSecret(mock, wsID, "the-secret")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
c, w := makeDownloadRequest(t, wsID, "/workspace/report.txt")
|
||||
h.Download(c)
|
||||
|
||||
@@ -634,7 +634,7 @@ func TestChatDownload_404FromWorkspacePropagated(t *testing.T) {
|
||||
expectURL(mock, wsID, srv.URL)
|
||||
expectInboundSecret(mock, wsID, "tok")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
c, w := makeDownloadRequest(t, wsID, "/workspace/missing.txt")
|
||||
h.Download(c)
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provlog"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/scheduler"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -61,7 +62,21 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
tier = defaults.Tier
|
||||
}
|
||||
if tier == 0 {
|
||||
tier = 2
|
||||
// Resolved via the same DefaultTier helper Create + Templates
|
||||
// use (#2910 PR-E). SaaS → T4 (one container per sibling EC2,
|
||||
// no neighbour to protect from), self-hosted → T3. Pre-#2910
|
||||
// this path returned T2 on self-hosted, asymmetric with
|
||||
// workspace.go's T3 — undocumented drift. Lifting to
|
||||
// DefaultTier collapses both call sites onto one source of
|
||||
// truth so a future tier-default change sweeps every entry
|
||||
// point at once. Templates that want a different floor still
|
||||
// declare `tier:` in config.yaml or `defaults.tier` in
|
||||
// org.yaml.
|
||||
if h.workspace != nil {
|
||||
tier = h.workspace.DefaultTier()
|
||||
} else {
|
||||
tier = 3
|
||||
}
|
||||
}
|
||||
|
||||
ctxLookup := context.Background()
|
||||
@@ -82,6 +97,16 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
}
|
||||
if existing {
|
||||
log.Printf("Org import: %q already exists (id=%s) — skipping create+provision, recursing into children for partial-match", ws.Name, existingID)
|
||||
parentRef := ""
|
||||
if parentID != nil {
|
||||
parentRef = *parentID
|
||||
}
|
||||
provlog.Event("provision.skip_existing", map[string]any{
|
||||
"name": ws.Name,
|
||||
"existing_id": existingID,
|
||||
"parent_id": parentRef,
|
||||
"tier": tier,
|
||||
})
|
||||
*results = append(*results, map[string]interface{}{
|
||||
"id": existingID,
|
||||
"name": ws.Name,
|
||||
|
||||
@@ -44,6 +44,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -273,6 +274,183 @@ func TestIntegration_PendingUploads_PutEnforcesSizeCap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_PendingUploads_PutBatch_HappyPath_AllRowsCommit pins the
|
||||
// "all rows commit" leg of the PutBatch atomicity contract against a real
|
||||
// Postgres. sqlmock can't catch a regression where the Go-side Tx machinery
|
||||
// silently no-ops the inserts (e.g., wrong driver options on BeginTx); only
|
||||
// COUNT(*) on the real table can.
|
||||
func TestIntegration_PendingUploads_PutBatch_HappyPath_AllRowsCommit(t *testing.T) {
|
||||
conn := integrationDB_PendingUploads(t)
|
||||
store := pendinguploads.NewPostgres(conn)
|
||||
ctx := context.Background()
|
||||
|
||||
wsID := uuid.New()
|
||||
|
||||
// Pre-existing row so the COUNT(*) baseline is non-zero — proves
|
||||
// PutBatch adds rows incrementally rather than overwriting.
|
||||
if _, err := store.Put(ctx, wsID, []byte("seed"), "seed.txt", "text/plain"); err != nil {
|
||||
t.Fatalf("seed Put: %v", err)
|
||||
}
|
||||
|
||||
items := []pendinguploads.PutItem{
|
||||
{Content: []byte("alpha"), Filename: "alpha.txt", Mimetype: "text/plain"},
|
||||
{Content: []byte("beta"), Filename: "beta.bin", Mimetype: "application/octet-stream"},
|
||||
{Content: []byte("gamma"), Filename: "gamma.pdf", Mimetype: "application/pdf"},
|
||||
}
|
||||
ids, err := store.PutBatch(ctx, wsID, items)
|
||||
if err != nil {
|
||||
t.Fatalf("PutBatch: %v", err)
|
||||
}
|
||||
if len(ids) != len(items) {
|
||||
t.Fatalf("ids length %d, want %d", len(ids), len(items))
|
||||
}
|
||||
|
||||
// Each returned id round-trips through Get with the right content.
|
||||
for i, id := range ids {
|
||||
rec, err := store.Get(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("Get item %d (%s): %v", i, id, err)
|
||||
}
|
||||
if string(rec.Content) != string(items[i].Content) {
|
||||
t.Errorf("item %d content = %q, want %q", i, rec.Content, items[i].Content)
|
||||
}
|
||||
if rec.Filename != items[i].Filename {
|
||||
t.Errorf("item %d filename = %q, want %q", i, rec.Filename, items[i].Filename)
|
||||
}
|
||||
}
|
||||
|
||||
var n int
|
||||
if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&n); err != nil {
|
||||
t.Fatalf("count: %v", err)
|
||||
}
|
||||
if n != 4 {
|
||||
t.Errorf("workspace row count = %d, want 4 (1 seed + 3 batch)", n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_PendingUploads_PutBatch_AtomicRollback_NoLeakOnFailure
|
||||
// proves the all-or-nothing contract end-to-end against real Postgres MVCC.
|
||||
//
|
||||
// Strategy: build a 3-item batch where item index 1 carries a filename with
|
||||
// an embedded NUL byte. lib/pq rejects NULs in TEXT columns at the protocol
|
||||
// layer (`pq: invalid byte sequence for encoding "UTF8": 0x00`), which
|
||||
// triggers the per-row INSERT error path in PutBatch. The first item's
|
||||
// INSERT…RETURNING already wrote a row to the Tx's snapshot, so a buggy
|
||||
// rollback would leave that row visible after PutBatch returns.
|
||||
//
|
||||
// Postgrest semantics: ROLLBACK is the only way a real DB can guarantee the
|
||||
// "no leak" contract; a unit test with sqlmock can prove the Go function
|
||||
// CALLED Rollback, but only this integration test proves Postgres actually
|
||||
// HONORED it.
|
||||
func TestIntegration_PendingUploads_PutBatch_AtomicRollback_NoLeakOnFailure(t *testing.T) {
|
||||
conn := integrationDB_PendingUploads(t)
|
||||
store := pendinguploads.NewPostgres(conn)
|
||||
ctx := context.Background()
|
||||
|
||||
wsID := uuid.New()
|
||||
|
||||
// Baseline COUNT(*) for this workspace — must remain 0 after a failed batch.
|
||||
var before int
|
||||
if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&before); err != nil {
|
||||
t.Fatalf("baseline count: %v", err)
|
||||
}
|
||||
if before != 0 {
|
||||
t.Fatalf("workspace not isolated: baseline = %d, want 0", before)
|
||||
}
|
||||
|
||||
// Item 1 has a NUL byte in the filename — Go-side pre-validation
|
||||
// (which only checks empty/length) lets it through, so the INSERT
|
||||
// reaches lib/pq, which rejects it at the protocol level. That's the
|
||||
// canonical "DB-side error mid-batch" we want to exercise.
|
||||
items := []pendinguploads.PutItem{
|
||||
{Content: []byte("ok"), Filename: "ok.txt", Mimetype: "text/plain"},
|
||||
{Content: []byte("bad"), Filename: "bad\x00name.txt", Mimetype: "text/plain"},
|
||||
{Content: []byte("never"), Filename: "never.txt", Mimetype: "text/plain"},
|
||||
}
|
||||
_, err := store.PutBatch(ctx, wsID, items)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error from NUL-byte filename, got nil")
|
||||
}
|
||||
|
||||
// THE assertion this whole test exists for: even though item 0's
|
||||
// INSERT…RETURNING succeeded inside the Tx, the rollback unwound
|
||||
// it — zero rows for this workspace, not one (let alone three).
|
||||
var after int
|
||||
if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&after); err != nil {
|
||||
t.Fatalf("post-failure count: %v", err)
|
||||
}
|
||||
if after != 0 {
|
||||
t.Errorf("Tx rollback leaked rows: workspace count = %d, want 0", after)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_PendingUploads_PutBatch_Oversize_NoTxOpened verifies the
|
||||
// pre-validation short-circuit: an oversized item rejects with ErrTooLarge
|
||||
// BEFORE any Tx opens, so the table is untouched. The unit test (sqlmock
|
||||
// with zero expectations) catches the Go-side path; this test sanity-checks
|
||||
// no real DB I/O happens by confirming COUNT(*) doesn't move.
|
||||
func TestIntegration_PendingUploads_PutBatch_Oversize_NoTxOpened(t *testing.T) {
|
||||
conn := integrationDB_PendingUploads(t)
|
||||
store := pendinguploads.NewPostgres(conn)
|
||||
ctx := context.Background()
|
||||
|
||||
wsID := uuid.New()
|
||||
tooBig := make([]byte, pendinguploads.MaxFileBytes+1)
|
||||
_, err := store.PutBatch(ctx, wsID, []pendinguploads.PutItem{
|
||||
{Content: []byte("ok"), Filename: "ok.txt"},
|
||||
{Content: tooBig, Filename: "too-big.bin"},
|
||||
})
|
||||
if err != pendinguploads.ErrTooLarge {
|
||||
t.Fatalf("expected ErrTooLarge, got %v", err)
|
||||
}
|
||||
var n int
|
||||
if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&n); err != nil {
|
||||
t.Fatalf("count: %v", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("pre-validation did NOT short-circuit: count = %d, want 0", n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_PendingUploads_AckedIndexExists verifies the Phase 5a
|
||||
// migration (20260505200000_pending_uploads_acked_index.up.sql) actually
|
||||
// created idx_pending_uploads_acked with the right partial-index predicate.
|
||||
//
|
||||
// Why pg_indexes and not EXPLAIN: the planner prefers Seq Scan on tiny
|
||||
// tables regardless of available indexes — a plan-shape check would be
|
||||
// flaky under real test loads. The contract we care about is "the index
|
||||
// exists with the predicate we wrote in the migration"; pg_indexes is
|
||||
// the canonical source for that, robust to row count and planner version.
|
||||
func TestIntegration_PendingUploads_AckedIndexExists(t *testing.T) {
|
||||
conn := integrationDB_PendingUploads(t)
|
||||
ctx := context.Background()
|
||||
|
||||
var indexdef string
|
||||
err := conn.QueryRowContext(ctx, `
|
||||
SELECT indexdef FROM pg_indexes
|
||||
WHERE schemaname = 'public'
|
||||
AND tablename = 'pending_uploads'
|
||||
AND indexname = 'idx_pending_uploads_acked'
|
||||
`).Scan(&indexdef)
|
||||
if err == sql.ErrNoRows {
|
||||
t.Fatal("idx_pending_uploads_acked is missing — migration 20260505200000 not applied")
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("pg_indexes query: %v", err)
|
||||
}
|
||||
|
||||
// Pin the partial-index predicate. Without "WHERE acked_at IS NOT NULL"
|
||||
// we'd be indexing the entire table (defeats the point — most rows are
|
||||
// unacked), and the existing idx_pending_uploads_unacked already covers
|
||||
// the inverse predicate.
|
||||
if !strings.Contains(indexdef, "(acked_at)") {
|
||||
t.Errorf("index missing acked_at column: %s", indexdef)
|
||||
}
|
||||
if !strings.Contains(indexdef, "WHERE (acked_at IS NOT NULL)") {
|
||||
t.Errorf("index missing partial predicate: %s", indexdef)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_PendingUploads_GetIgnoresExpiredAndAcked(t *testing.T) {
|
||||
conn := integrationDB_PendingUploads(t)
|
||||
store := pendinguploads.NewPostgres(conn)
|
||||
|
||||
@@ -77,6 +77,14 @@ func (f *fakeStorage) Sweep(_ context.Context, _ time.Duration) (pendinguploads.
|
||||
return pendinguploads.SweepResult{}, nil
|
||||
}
|
||||
|
||||
// PutBatch is required by the Storage interface; the upload handler
|
||||
// tests live in chat_files_poll_test.go and use a separate fake
|
||||
// (inMemStorage). Stubbed here because the Get/Ack tests don't drive
|
||||
// PutBatch, but the interface must be satisfied.
|
||||
func (f *fakeStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func newRouter(handler *handlers.PendingUploadsHandler) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
package handlers
|
||||
|
||||
// provlog_emit_test.go — pins that the structured-logging emit sites
|
||||
// added for #2867 PR-D actually fire when their boundary is crossed.
|
||||
//
|
||||
// These are call-site contract tests, not provlog package tests (those
|
||||
// live next to the helper). The assertion is "this dispatcher path
|
||||
// emits this event name" — if a refactor moves the call out of the
|
||||
// boundary helper, the gate fails. Fields are NOT pinned here on
|
||||
// purpose; the field set is convenience for ops, not contract for the
|
||||
// emit point. Pinning fields would block additive evolution of the
|
||||
// payload (see also feedback_behavior_based_ast_gates.md).
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
)
|
||||
|
||||
// captureProvLog redirects the global logger to a buffer for the test
|
||||
// duration. provlog.Event uses log.Printf, so this is the only seam.
|
||||
// Returned mutex protects against concurrent reads from the goroutine
|
||||
// fired by provisionWorkspaceAuto (the goroutine never returns in
|
||||
// these tests because Start() is stubbed, but the buffer can still be
|
||||
// touched by it racing the assertion).
|
||||
func captureProvLog(t *testing.T) (read func() string) {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
var mu sync.Mutex
|
||||
prevWriter := log.Writer()
|
||||
prevFlags := log.Flags()
|
||||
log.SetFlags(0)
|
||||
log.SetOutput(&safeWriter{buf: &buf, mu: &mu})
|
||||
t.Cleanup(func() {
|
||||
log.SetOutput(prevWriter)
|
||||
log.SetFlags(prevFlags)
|
||||
})
|
||||
return func() string {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return buf.String()
|
||||
}
|
||||
}
|
||||
|
||||
// TestProvisionWorkspaceAutoSync_EmitsProvisionStart — sync variant is
|
||||
// chosen for the assertion path because it returns once the (stubbed)
|
||||
// Start() has been called, so we know the emit has flushed. The async
|
||||
// variant would race a goroutine.
|
||||
func TestProvisionWorkspaceAutoSync_EmitsProvisionStart(t *testing.T) {
|
||||
read := captureProvLog(t)
|
||||
h := &WorkspaceHandler{cpProv: &trackingCPProv{}}
|
||||
// Best-effort: the body will hit DB code under provisionWorkspaceCP
|
||||
// — we only need the emit at the entry, which fires unconditionally
|
||||
// before the dispatch. Recovering from any later panic keeps the
|
||||
// test focused.
|
||||
defer func() { _ = recover() }()
|
||||
h.provisionWorkspaceAutoSync("ws-test-1", "tmpl", nil, models.CreateWorkspacePayload{
|
||||
Name: "n", Tier: 4, Runtime: "claude-code",
|
||||
})
|
||||
got := read()
|
||||
if !strings.Contains(got, "evt: provision.start ") {
|
||||
t.Fatalf("expected provision.start emit, got log:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, `"workspace_id":"ws-test-1"`) {
|
||||
t.Errorf("workspace_id not in payload: %s", got)
|
||||
}
|
||||
if !strings.Contains(got, `"sync":true`) {
|
||||
t.Errorf("sync flag not pinned for sync dispatcher: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStopForRestart_EmitsRestartPreStop — emit fires before the actual
|
||||
// Stop call, so the trackingCPProv stub doesn't need to be wired for
|
||||
// real Stop semantics. Backend label "cp" pinned because that's the
|
||||
// SaaS path; we don't pin "docker" or "none" branches here (separate
|
||||
// tests would only re-test the trivial branch label switch).
|
||||
func TestStopForRestart_EmitsRestartPreStop(t *testing.T) {
|
||||
read := captureProvLog(t)
|
||||
h := &WorkspaceHandler{cpProv: &trackingCPProv{}}
|
||||
defer func() { _ = recover() }()
|
||||
h.stopForRestart(context.Background(), "ws-restart-1")
|
||||
got := read()
|
||||
if !strings.Contains(got, "evt: restart.pre_stop ") {
|
||||
t.Fatalf("expected restart.pre_stop emit, got log:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, `"workspace_id":"ws-restart-1"`) {
|
||||
t.Errorf("workspace_id not in payload: %s", got)
|
||||
}
|
||||
if !strings.Contains(got, `"backend":"cp"`) {
|
||||
t.Errorf("backend label missing or wrong: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStopForRestart_EmitsBackendNoneWhenUnwired — pin the no-backend
|
||||
// branch so a future refactor that drops the label switch is caught.
|
||||
// This is the silent-Stop case (workspace_dispatchers.go:StopWorkspaceAuto
|
||||
// returns nil for unwired backends); the emit ensures the operator can
|
||||
// still see the boundary in the log.
|
||||
func TestStopForRestart_EmitsBackendNoneWhenUnwired(t *testing.T) {
|
||||
read := captureProvLog(t)
|
||||
h := &WorkspaceHandler{} // both nil
|
||||
h.stopForRestart(context.Background(), "ws-restart-2")
|
||||
got := read()
|
||||
if !strings.Contains(got, `"backend":"none"`) {
|
||||
t.Fatalf("expected backend=none for unwired handler: %s", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
|
||||
)
|
||||
|
||||
// Tests for the SaaS-aware default-tier resolution introduced in #2901
|
||||
// and hardened in #2910 (multi-model review of #2901 found the original
|
||||
// claim of "all green" was passing because no SaaS-mode test existed).
|
||||
//
|
||||
// These tests pin three invariants:
|
||||
//
|
||||
// 1. WorkspaceHandler.IsSaaS() returns true when cpProv is wired,
|
||||
// false otherwise.
|
||||
// 2. WorkspaceHandler.DefaultTier() returns 4 on SaaS, 3 self-hosted.
|
||||
// 3. generateDefaultConfig (TemplatesHandler.Import path) writes the
|
||||
// passed-in tier into the generated config.yaml — pre-#2910 it
|
||||
// was hardcoded to 3 and silently disagreed with the create-
|
||||
// handler default on SaaS.
|
||||
|
||||
// stubCPProv is a minimal stand-in for the CP provisioner — only
|
||||
// exercises the IsSaaS / HasProvisioner contract, never invoked in
|
||||
// these tests.
|
||||
type stubCPProv struct{}
|
||||
|
||||
func (stubCPProv) Start(_ interface{}, _ provisioner.WorkspaceConfig) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (stubCPProv) Stop(_ interface{}, _ string) error { return nil }
|
||||
func (stubCPProv) Restart(_ interface{}, _ provisioner.WorkspaceConfig) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func TestIsSaaS_TrueWhenCPProvWired(t *testing.T) {
|
||||
h := &WorkspaceHandler{cpProv: &trackingCPProv{}}
|
||||
if !h.IsSaaS() {
|
||||
t.Errorf("IsSaaS()=false with cpProv wired; expected true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSaaS_FalseWhenOnlyDocker(t *testing.T) {
|
||||
// provisioner field set, cpProv nil — the self-hosted path.
|
||||
// Use a non-nil sentinel so the check actually has something to
|
||||
// disagree with. trackingCPProv lives in workspace_provision_auto_test.go
|
||||
// and is the established stub for these handler-level tests.
|
||||
h := &WorkspaceHandler{provisioner: nil, cpProv: nil}
|
||||
if h.IsSaaS() {
|
||||
t.Errorf("IsSaaS()=true with both backends nil; expected false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultTier_SaaS_IsT4(t *testing.T) {
|
||||
h := &WorkspaceHandler{cpProv: &trackingCPProv{}}
|
||||
if got := h.DefaultTier(); got != 4 {
|
||||
t.Errorf("SaaS DefaultTier()=%d; expected 4", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultTier_SelfHosted_IsT3(t *testing.T) {
|
||||
h := &WorkspaceHandler{}
|
||||
if got := h.DefaultTier(); got != 3 {
|
||||
t.Errorf("self-hosted DefaultTier()=%d; expected 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
// generateDefaultConfig — pin that the tier param flows into the
|
||||
// emitted config.yaml verbatim. Pre-#2910 this was hardcoded "tier: 3"
|
||||
// regardless of caller intent.
|
||||
func TestGenerateDefaultConfig_RespectsTierParam(t *testing.T) {
|
||||
cfg := generateDefaultConfig("Test Agent", map[string]string{"system-prompt.md": ""}, 4)
|
||||
if !strings.Contains(cfg, "tier: 4\n") {
|
||||
t.Errorf("expected `tier: 4` in generated config, got:\n%s", cfg)
|
||||
}
|
||||
// The pre-#2910 hardcoded `tier: 3` line must NOT appear.
|
||||
if strings.Contains(cfg, "tier: 3\n") {
|
||||
t.Errorf("config should not contain `tier: 3` when caller passed 4, got:\n%s", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateDefaultConfig_SelfHostedTierT3(t *testing.T) {
|
||||
cfg := generateDefaultConfig("Test Agent", map[string]string{"system-prompt.md": ""}, 3)
|
||||
if !strings.Contains(cfg, "tier: 3\n") {
|
||||
t.Errorf("expected `tier: 3` in generated config, got:\n%s", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// Bounds check — caller passes 0 or out-of-range, helper falls back
|
||||
// to T3 (the safer-of-the-two when deployment mode can't be resolved).
|
||||
func TestGenerateDefaultConfig_OutOfRangeFallsBackToT3(t *testing.T) {
|
||||
for _, tier := range []int{0, -1, 99} {
|
||||
cfg := generateDefaultConfig("X", map[string]string{}, tier)
|
||||
if !strings.Contains(cfg, "tier: 3\n") {
|
||||
t.Errorf("invalid tier %d should fall back to T3, got:\n%s", tier, cfg)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -71,7 +71,7 @@ func TestSecurity_GetTemplates_NoAuth_Returns401(t *testing.T) {
|
||||
authDB, authMock := newEnrolledAuthDB(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmplh := NewTemplatesHandler(tmpDir, nil)
|
||||
tmplh := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/templates", middleware.AdminAuth(authDB), tmplh.List)
|
||||
@@ -98,7 +98,7 @@ func TestSecurity_GetTemplates_FreshInstall_FailsOpen(t *testing.T) {
|
||||
authDB, authMock := newFreshInstallAuthDB(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmplh := NewTemplatesHandler(tmpDir, nil)
|
||||
tmplh := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/templates", middleware.AdminAuth(authDB), tmplh.List)
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/events"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// TeamHandler now hosts only Collapse — the visual "expand" action is
|
||||
// canvas-side and creating children goes through the regular
|
||||
// WorkspaceHandler.Create path with parent_id set, like any other
|
||||
// workspace. Every workspace can have children; "team" is just the
|
||||
// state of having children. The old Expand handler bulk-created
|
||||
// children by reading sub_workspaces from a parent's config and was
|
||||
// non-idempotent — calling it N times leaked N×children EC2s, which
|
||||
// is how tenant-hongming accumulated 72 stale workspaces.
|
||||
type TeamHandler struct {
|
||||
wh *WorkspaceHandler
|
||||
b *events.Broadcaster
|
||||
}
|
||||
|
||||
// NewTeamHandler constructs a TeamHandler. wh is used by Collapse to
|
||||
// route StopWorkspaceAuto through the backend dispatcher.
|
||||
func NewTeamHandler(b *events.Broadcaster, wh *WorkspaceHandler, platformURL, configsDir string) *TeamHandler {
|
||||
return &TeamHandler{wh: wh, b: b}
|
||||
}
|
||||
|
||||
// Collapse handles POST /workspaces/:id/collapse
|
||||
// Stops and removes all child workspaces.
|
||||
func (h *TeamHandler) Collapse(c *gin.Context) {
|
||||
parentID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Find children
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
`SELECT id, name FROM workspaces WHERE parent_id = $1 AND status != 'removed'`, parentID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to query children"})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
removed := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var childID, childName string
|
||||
if rows.Scan(&childID, &childName) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Stop the workload via the backend dispatcher (CP for SaaS,
|
||||
// Docker for self-hosted). Pre-2026-05-05 this was
|
||||
// `if h.provisioner != nil { h.provisioner.Stop(...) }`, which
|
||||
// silently skipped on every SaaS tenant — child EC2s kept running
|
||||
// after team-collapse until the orphan sweeper caught them
|
||||
// (issue #2813).
|
||||
if err := h.wh.StopWorkspaceAuto(ctx, childID); err != nil {
|
||||
log.Printf("Team collapse: stop %s failed: %v — orphan sweeper will reconcile", childID, err)
|
||||
}
|
||||
|
||||
// Mark as removed
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusRemoved, childID); err != nil {
|
||||
log.Printf("Team collapse: failed to remove workspace %s: %v", childID, err)
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
`DELETE FROM canvas_layouts WHERE workspace_id = $1`, childID); err != nil {
|
||||
log.Printf("Team collapse: failed to delete layout for %s: %v", childID, err)
|
||||
}
|
||||
|
||||
h.b.RecordAndBroadcast(ctx, "WORKSPACE_REMOVED", childID, map[string]interface{}{})
|
||||
|
||||
removed = append(removed, childName)
|
||||
}
|
||||
|
||||
h.b.RecordAndBroadcast(ctx, "WORKSPACE_COLLAPSED", parentID, map[string]interface{}{
|
||||
"removed_children": removed,
|
||||
})
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "collapsed",
|
||||
"removed": removed,
|
||||
})
|
||||
}
|
||||
|
||||
// findTemplateDirByName resolves a workspace name to its template
|
||||
// directory. Kept here because callers outside this package may use
|
||||
// it, even though the in-package consumer (Expand) is gone.
|
||||
//
|
||||
// TODO: relocate alongside the templates handler if no other callers
|
||||
// surface, or delete entirely after a deprecation cycle.
|
||||
func findTemplateDirByName(configsDir, name string) string {
|
||||
normalized := normalizeName(name)
|
||||
|
||||
candidate := filepath.Join(configsDir, normalized)
|
||||
if _, err := os.Stat(filepath.Join(candidate, "config.yaml")); err == nil {
|
||||
return candidate
|
||||
}
|
||||
|
||||
// Fall back to scanning all dirs
|
||||
entries, err := os.ReadDir(configsDir)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
cfgPath := filepath.Join(configsDir, e.Name(), "config.yaml")
|
||||
data, err := os.ReadFile(cfgPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var cfg struct {
|
||||
Name string `yaml:"name"`
|
||||
}
|
||||
if json.Unmarshal(data, &cfg) == nil && cfg.Name == name {
|
||||
return filepath.Join(configsDir, e.Name())
|
||||
}
|
||||
if yaml.Unmarshal(data, &cfg) == nil && cfg.Name == name {
|
||||
return filepath.Join(configsDir, e.Name())
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ---------- TeamHandler: Collapse ----------
|
||||
|
||||
func TestTeamCollapse_NoChildren(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewTeamHandler(broadcaster, NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()), "http://localhost:8080", "/tmp/configs")
|
||||
|
||||
// No children
|
||||
mock.ExpectQuery("SELECT id, name FROM workspaces WHERE parent_id").
|
||||
WithArgs("ws-parent").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
|
||||
|
||||
// WORKSPACE_COLLAPSED broadcast
|
||||
mock.ExpectExec("INSERT INTO structure_events").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-parent"}}
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
|
||||
handler.Collapse(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["status"] != "collapsed" {
|
||||
t.Errorf("expected status 'collapsed', got %v", resp["status"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamCollapse_WithChildren(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewTeamHandler(broadcaster, NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()), "http://localhost:8080", "/tmp/configs")
|
||||
|
||||
// Two children
|
||||
mock.ExpectQuery("SELECT id, name FROM workspaces WHERE parent_id").
|
||||
WithArgs("ws-parent").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
|
||||
AddRow("child-1", "Worker A").
|
||||
AddRow("child-2", "Worker B"))
|
||||
|
||||
// UPDATE + DELETE + broadcast for child-1
|
||||
mock.ExpectExec("UPDATE workspaces SET status =").
|
||||
WithArgs("child-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("DELETE FROM canvas_layouts").
|
||||
WithArgs("child-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("INSERT INTO structure_events").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// UPDATE + DELETE + broadcast for child-2
|
||||
mock.ExpectExec("UPDATE workspaces SET status =").
|
||||
WithArgs("child-2").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("DELETE FROM canvas_layouts").
|
||||
WithArgs("child-2").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("INSERT INTO structure_events").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// WORKSPACE_COLLAPSED broadcast for parent
|
||||
mock.ExpectExec("INSERT INTO structure_events").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-parent"}}
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
|
||||
handler.Collapse(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
removed, ok := resp["removed"].([]interface{})
|
||||
if !ok || len(removed) != 2 {
|
||||
t.Errorf("expected 2 removed children, got %v", resp["removed"])
|
||||
}
|
||||
}
|
||||
// ---------- findTemplateDirByName helper ----------
|
||||
|
||||
func TestFindTemplateDirByName_DirectMatch(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
subDir := filepath.Join(dir, "mybot")
|
||||
os.MkdirAll(subDir, 0755)
|
||||
os.WriteFile(filepath.Join(subDir, "config.yaml"), []byte("name: MyBot"), 0644)
|
||||
|
||||
result := findTemplateDirByName(dir, "mybot")
|
||||
if result != subDir {
|
||||
t.Errorf("expected %s, got %s", subDir, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindTemplateDirByName_NotFound(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
result := findTemplateDirByName(dir, "nonexistent")
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindTemplateDirByName_InvalidConfigsDir(t *testing.T) {
|
||||
result := findTemplateDirByName("/nonexistent/path", "anything")
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string for invalid dir, got %s", result)
|
||||
}
|
||||
}
|
||||
@@ -36,8 +36,14 @@ func normalizeName(name string) string {
|
||||
return result
|
||||
}
|
||||
|
||||
// generateDefaultConfig creates a config.yaml from detected prompt files and skills.
|
||||
func generateDefaultConfig(name string, files map[string]string) string {
|
||||
// generateDefaultConfig creates a config.yaml from detected prompt files
|
||||
// and skills. tier is the deployment-aware default (caller passes
|
||||
// h.wh.DefaultTier() — T4 on SaaS, T3 on self-hosted) so the generated
|
||||
// file matches what POST /workspaces would default to. Pre-#2910 this
|
||||
// was hardcoded to 3, which split-brained with the create-handler
|
||||
// default on SaaS (T4) and pinned newly-imported templates at T3 even
|
||||
// when downstream Create paths picked T4.
|
||||
func generateDefaultConfig(name string, files map[string]string, tier int) string {
|
||||
promptFiles := []string{}
|
||||
skillSet := map[string]bool{}
|
||||
|
||||
@@ -74,9 +80,15 @@ func generateDefaultConfig(name string, files map[string]string) string {
|
||||
var cfg strings.Builder
|
||||
cfg.WriteString(`name: "` + escaped + `"` + "\n")
|
||||
cfg.WriteString("description: Imported agent\n")
|
||||
// Default to tier 3 ("Privileged") — matches the workspace.go
|
||||
// create handler default. See its comment for rationale.
|
||||
cfg.WriteString("version: 1.0.0\ntier: 3\n")
|
||||
// Tier is SaaS-aware via the caller's DefaultTier (#2910 PR-B).
|
||||
// Bounds-checked: invalid input falls back to T3 (the historical
|
||||
// default + the safer-of-the-two when the deployment mode can't
|
||||
// be resolved).
|
||||
if tier < 1 || tier > 4 {
|
||||
tier = 3
|
||||
}
|
||||
cfg.WriteString("version: 1.0.0\n")
|
||||
cfg.WriteString(fmt.Sprintf("tier: %d\n", tier))
|
||||
cfg.WriteString("model: anthropic:claude-haiku-4-5-20251001\n")
|
||||
cfg.WriteString("\nprompt_files:\n")
|
||||
if len(promptFiles) > 0 {
|
||||
@@ -148,7 +160,11 @@ func (h *TemplatesHandler) Import(c *gin.Context) {
|
||||
|
||||
// Auto-generate config.yaml if not provided
|
||||
if _, exists := body.Files["config.yaml"]; !exists {
|
||||
cfg := generateDefaultConfig(body.Name, body.Files)
|
||||
tier := 3
|
||||
if h.wh != nil {
|
||||
tier = h.wh.DefaultTier()
|
||||
}
|
||||
cfg := generateDefaultConfig(body.Name, body.Files, tier)
|
||||
if err := os.WriteFile(filepath.Join(destDir, "config.yaml"), []byte(cfg), 0600); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to write config.yaml"})
|
||||
return
|
||||
@@ -227,7 +243,11 @@ func (h *TemplatesHandler) ReplaceFiles(c *gin.Context) {
|
||||
if _, exists := body.Files["config.yaml"]; !exists {
|
||||
// Check if config.yaml exists in container
|
||||
if _, err := h.execInContainer(ctx, containerName, []string{"test", "-f", "/configs/config.yaml"}); err != nil {
|
||||
cfg := generateDefaultConfig(wsName, body.Files)
|
||||
tier := 3
|
||||
if h.wh != nil {
|
||||
tier = h.wh.DefaultTier()
|
||||
}
|
||||
cfg := generateDefaultConfig(wsName, body.Files, tier)
|
||||
singleFile := map[string]string{"config.yaml": cfg}
|
||||
h.copyFilesToContainer(ctx, containerName, "/configs", singleFile)
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func TestGenerateDefaultConfig_WithFiles(t *testing.T) {
|
||||
"skills/review/templates.md": "Templates",
|
||||
}
|
||||
|
||||
cfg := generateDefaultConfig("Test Agent", files)
|
||||
cfg := generateDefaultConfig("Test Agent", files, 3)
|
||||
|
||||
// Name is emitted as a double-quoted scalar (#221 sanitizer).
|
||||
if !strings.Contains(cfg, `name: "Test Agent"`) {
|
||||
@@ -85,7 +85,7 @@ func TestGenerateDefaultConfig_Empty(t *testing.T) {
|
||||
"data/something.json": `{"key": "value"}`,
|
||||
}
|
||||
|
||||
cfg := generateDefaultConfig("Empty Agent", files)
|
||||
cfg := generateDefaultConfig("Empty Agent", files, 3)
|
||||
|
||||
if !strings.Contains(cfg, `name: "Empty Agent"`) {
|
||||
t.Errorf("config should contain quoted agent name, got:\n%s", cfg)
|
||||
@@ -134,7 +134,7 @@ func TestGenerateDefaultConfig_YAMLInjection(t *testing.T) {
|
||||
|
||||
for _, tc := range adversarialCases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
cfg := generateDefaultConfig(tc.name, map[string]string{})
|
||||
cfg := generateDefaultConfig(tc.name, map[string]string{}, 3)
|
||||
var parsed map[string]interface{}
|
||||
if err := yaml.Unmarshal([]byte(cfg), &parsed); err != nil {
|
||||
t.Fatalf("sanitized config does not parse as YAML: %v\n--- config ---\n%s", err, cfg)
|
||||
@@ -205,7 +205,7 @@ func TestImport_Success(t *testing.T) {
|
||||
setupTestRedis(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
body := `{
|
||||
"name": "New Agent",
|
||||
@@ -245,7 +245,7 @@ func TestImport_MissingName(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
body := `{"files": {"test.md": "content"}}`
|
||||
|
||||
@@ -265,7 +265,7 @@ func TestImport_TooManyFiles(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
files := make(map[string]string)
|
||||
for i := 0; i <= maxUploadFiles; i++ {
|
||||
@@ -296,7 +296,7 @@ func TestImport_AlreadyExists(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
os.MkdirAll(filepath.Join(tmpDir, "existing-agent"), 0755)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
body := `{"name": "Existing Agent", "files": {"test.md": "content"}}`
|
||||
|
||||
@@ -317,7 +317,7 @@ func TestImport_WithConfigYaml(t *testing.T) {
|
||||
setupTestRedis(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
body := `{
|
||||
"name": "Custom Agent",
|
||||
@@ -354,7 +354,7 @@ func TestReplaceFiles_MissingBody(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -373,7 +373,7 @@ func TestReplaceFiles_TooManyFiles(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
files := make(map[string]string)
|
||||
for i := 0; i <= maxUploadFiles; i++ {
|
||||
@@ -398,7 +398,7 @@ func TestReplaceFiles_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
// ReplaceFiles now selects (name, instance_id, runtime) for the
|
||||
// restart-cascade. Match the full column list rather than just the
|
||||
@@ -429,7 +429,7 @@ func TestReplaceFiles_PathTraversal(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT name, COALESCE\(instance_id, ''\), COALESCE\(runtime, ''\) FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-rf-pt").
|
||||
|
||||
@@ -31,10 +31,20 @@ const maxUploadFiles = 200
|
||||
type TemplatesHandler struct {
|
||||
configsDir string
|
||||
docker *client.Client
|
||||
// wh is used by Import and ReplaceFiles to call DefaultTier() so a
|
||||
// generated config.yaml's tier matches the SaaS-vs-self-hosted
|
||||
// boundary (#2910 PR-B). nil-tolerant — the field is unused when
|
||||
// the caller doesn't import templates that need a fresh config
|
||||
// generated.
|
||||
wh *WorkspaceHandler
|
||||
}
|
||||
|
||||
func NewTemplatesHandler(configsDir string, dockerCli *client.Client) *TemplatesHandler {
|
||||
return &TemplatesHandler{configsDir: configsDir, docker: dockerCli}
|
||||
// NewTemplatesHandler constructs a TemplatesHandler. wh may be nil for
|
||||
// callers that only use the read-only template surfaces (List,
|
||||
// ReadFile, ListFiles). Import + ReplaceFiles need wh non-nil so the
|
||||
// generated config.yaml picks the SaaS-aware default tier.
|
||||
func NewTemplatesHandler(configsDir string, dockerCli *client.Client, wh *WorkspaceHandler) *TemplatesHandler {
|
||||
return &TemplatesHandler{configsDir: configsDir, docker: dockerCli, wh: wh}
|
||||
}
|
||||
|
||||
// modelSpec describes a single supported model on a template: its id (sent
|
||||
|
||||
@@ -53,7 +53,7 @@ func TestTemplatesList_EmptyDir(t *testing.T) {
|
||||
setupTestRedis(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -99,7 +99,7 @@ skills:
|
||||
// Create a directory without config.yaml (should be skipped)
|
||||
os.MkdirAll(filepath.Join(tmpDir, "no-config"), 0755)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -160,7 +160,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -237,7 +237,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -315,7 +315,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -434,7 +434,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -512,7 +512,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -555,7 +555,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -589,7 +589,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -661,7 +661,7 @@ skills: []
|
||||
log.SetOutput(&logBuf)
|
||||
defer log.SetOutput(prevOutput)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -698,7 +698,7 @@ func TestTemplatesList_NonexistentDir(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler("/nonexistent/path/to/templates", nil)
|
||||
handler := NewTemplatesHandler("/nonexistent/path/to/templates", nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -723,7 +723,7 @@ func TestListFiles_InvalidRoot(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -748,7 +748,7 @@ func TestListFiles_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces WHERE id =").
|
||||
WithArgs("ws-nonexist").
|
||||
@@ -775,7 +775,7 @@ func TestListFiles_FallbackToHost_NoTemplate(t *testing.T) {
|
||||
setupTestRedis(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil) // nil docker = no container
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil) // nil docker = no container
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces WHERE id =").
|
||||
WithArgs("ws-fallback").
|
||||
@@ -815,7 +815,7 @@ func TestListFiles_FallbackToHost_WithTemplate(t *testing.T) {
|
||||
os.WriteFile(filepath.Join(tmplDir, "config.yaml"), []byte("name: Test Agent\n"), 0644)
|
||||
os.WriteFile(filepath.Join(tmplDir, "system-prompt.md"), []byte("# prompt"), 0644)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces WHERE id =").
|
||||
WithArgs("ws-tmpl").
|
||||
@@ -849,7 +849,7 @@ func TestReadFile_PathTraversal(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -870,7 +870,7 @@ func TestReadFile_InvalidRoot(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -892,7 +892,7 @@ func TestReadFile_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT name, COALESCE\(instance_id, ''\), COALESCE\(runtime, ''\) FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-nf").
|
||||
@@ -926,7 +926,7 @@ func TestReadFile_FallbackToHost_Success(t *testing.T) {
|
||||
os.MkdirAll(tmplDir, 0755)
|
||||
os.WriteFile(filepath.Join(tmplDir, "config.yaml"), []byte("name: Reader Agent\ntier: 1\n"), 0644)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
// instance_id="" → SaaS branch skipped → falls through to local
|
||||
// Docker / template-dir host fallback (the only path the test
|
||||
@@ -967,7 +967,7 @@ func TestReadFile_FallbackToHost_NotFound(t *testing.T) {
|
||||
setupTestRedis(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT name, COALESCE\(instance_id, ''\), COALESCE\(runtime, ''\) FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-nofile").
|
||||
@@ -999,7 +999,7 @@ func TestWriteFile_PathTraversal(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -1023,7 +1023,7 @@ func TestWriteFile_InvalidBody(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -1046,7 +1046,7 @@ func TestWriteFile_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT name, COALESCE\(instance_id, ''\), COALESCE\(runtime, ''\) FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-wf-nf").
|
||||
@@ -1080,7 +1080,7 @@ func TestDeleteFile_PathTraversal(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -1101,7 +1101,7 @@ func TestDeleteFile_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces WHERE id =").
|
||||
WithArgs("ws-del-nf").
|
||||
@@ -1133,7 +1133,7 @@ func TestResolveTemplateDir_ByNormalizedName(t *testing.T) {
|
||||
tmplDir := filepath.Join(tmpDir, "my-agent")
|
||||
os.MkdirAll(tmplDir, 0755)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
result := handler.resolveTemplateDir("My Agent")
|
||||
|
||||
if result != tmplDir {
|
||||
@@ -1143,7 +1143,7 @@ func TestResolveTemplateDir_ByNormalizedName(t *testing.T) {
|
||||
|
||||
func TestResolveTemplateDir_NotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
result := handler.resolveTemplateDir("Nonexistent Agent")
|
||||
|
||||
if result != "" {
|
||||
@@ -1177,7 +1177,7 @@ func TestCWE78_DeleteFile_TraversalVariants(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
@@ -148,15 +148,15 @@ func (h *WorkspaceHandler) Create(c *gin.Context) {
|
||||
id := uuid.New().String()
|
||||
awarenessNamespace := workspaceAwarenessNamespace(id)
|
||||
if payload.Tier == 0 {
|
||||
// Default to T3 ("Privileged"). T3 gives agents a read_write
|
||||
// workspace mount + Docker daemon access — the level most
|
||||
// templates need to do real work. Lower tiers (T1 sandboxed,
|
||||
// T2 standard) stay available as explicit opt-ins for
|
||||
// low-trust agents. Matches the Canvas CreateWorkspaceDialog
|
||||
// default for self-hosted hosts (SaaS defaults to T4 via
|
||||
// CreateWorkspaceDialog because each SaaS workspace runs on
|
||||
// its own sibling EC2).
|
||||
payload.Tier = 3
|
||||
// SaaS-aware default. SaaS → T4 (full host access; each
|
||||
// workspace runs on its own sibling EC2 so the tier boundary
|
||||
// is a Docker resource limit on the only container present —
|
||||
// no neighbour to protect from). Self-hosted → T3 (read-write
|
||||
// workspace mount + Docker daemon access, most templates'
|
||||
// baseline). Lower tiers (T1 sandboxed, T2 standard) remain
|
||||
// explicit opt-ins for low-trust agents. Matches the canvas
|
||||
// CreateWorkspaceDialog defaults so the API and the UI agree.
|
||||
payload.Tier = h.DefaultTier()
|
||||
}
|
||||
|
||||
// Detect runtime + default model from template config.yaml when the
|
||||
|
||||
@@ -35,6 +35,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provlog"
|
||||
)
|
||||
|
||||
// HasProvisioner reports whether either backend (CP or local Docker) is
|
||||
@@ -49,6 +50,32 @@ func (h *WorkspaceHandler) HasProvisioner() bool {
|
||||
return h.cpProv != nil || h.provisioner != nil
|
||||
}
|
||||
|
||||
// IsSaaS reports whether the CP (EC2) provisioner is wired. Each SaaS
|
||||
// workspace runs on its own sibling EC2, so the per-workspace tier
|
||||
// boundary is a Docker resource limit applied to the only container
|
||||
// on that EC2 — there's no neighbour to protect from. Self-hosted
|
||||
// runs many workspaces in one Docker daemon on a single host, so
|
||||
// the tier-2-by-default safe-neighbour-share posture stays.
|
||||
//
|
||||
// Tier defaults across Create / OrgImport / canvas EmptyState branch
|
||||
// on IsSaaS so SaaS users get T4 (full host access) by default and
|
||||
// self-hosted users keep the lower-trust caps.
|
||||
func (h *WorkspaceHandler) IsSaaS() bool {
|
||||
return h.cpProv != nil
|
||||
}
|
||||
|
||||
// DefaultTier is the SaaS-aware default tier. T4 on SaaS (single
|
||||
// container per EC2 — full host access matches the boundary), T3 on
|
||||
// self-hosted (read-write workspace mount + Docker daemon access,
|
||||
// most templates' baseline). Callers default to this when the user
|
||||
// hasn't explicitly picked a tier.
|
||||
func (h *WorkspaceHandler) DefaultTier() int {
|
||||
if h.IsSaaS() {
|
||||
return 4
|
||||
}
|
||||
return 3
|
||||
}
|
||||
|
||||
// provisionWorkspaceAuto picks the backend (CP for SaaS, local Docker
|
||||
// for self-hosted) and starts provisioning in a goroutine. Returns true
|
||||
// when a backend was kicked off, false when neither is wired.
|
||||
@@ -75,6 +102,14 @@ func (h *WorkspaceHandler) HasProvisioner() bool {
|
||||
// lives in prepareProvisionContext (shared by both per-backend
|
||||
// goroutines).
|
||||
func (h *WorkspaceHandler) provisionWorkspaceAuto(workspaceID, templatePath string, configFiles map[string][]byte, payload models.CreateWorkspacePayload) bool {
|
||||
provlog.Event("provision.start", map[string]any{
|
||||
"workspace_id": workspaceID,
|
||||
"name": payload.Name,
|
||||
"tier": payload.Tier,
|
||||
"runtime": payload.Runtime,
|
||||
"template": payload.Template,
|
||||
"sync": false,
|
||||
})
|
||||
if h.cpProv != nil {
|
||||
go h.provisionWorkspaceCP(workspaceID, templatePath, configFiles, payload)
|
||||
return true
|
||||
@@ -110,6 +145,14 @@ func (h *WorkspaceHandler) provisionWorkspaceAuto(workspaceID, templatePath stri
|
||||
// Keep these two helpers in sync — when one grows a new arm (third
|
||||
// backend, retry semantics), the other should too.
|
||||
func (h *WorkspaceHandler) provisionWorkspaceAutoSync(workspaceID, templatePath string, configFiles map[string][]byte, payload models.CreateWorkspacePayload) bool {
|
||||
provlog.Event("provision.start", map[string]any{
|
||||
"workspace_id": workspaceID,
|
||||
"name": payload.Name,
|
||||
"tier": payload.Tier,
|
||||
"runtime": payload.Runtime,
|
||||
"template": payload.Template,
|
||||
"sync": true,
|
||||
})
|
||||
if h.cpProv != nil {
|
||||
h.provisionWorkspaceCP(workspaceID, templatePath, configFiles, payload)
|
||||
return true
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provlog"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -431,6 +432,16 @@ func coalesceRestart(workspaceID string, cycle func()) {
|
||||
// NPE'd before reaching the reprovision step — which is why every SaaS dead-
|
||||
// agent incident pre-this-fix required manual restart from canvas.
|
||||
func (h *WorkspaceHandler) stopForRestart(ctx context.Context, workspaceID string) {
|
||||
backend := "none"
|
||||
if h.provisioner != nil {
|
||||
backend = "docker"
|
||||
} else if h.cpProv != nil {
|
||||
backend = "cp"
|
||||
}
|
||||
provlog.Event("restart.pre_stop", map[string]any{
|
||||
"workspace_id": workspaceID,
|
||||
"backend": backend,
|
||||
})
|
||||
if h.provisioner != nil {
|
||||
h.provisioner.Stop(ctx, workspaceID)
|
||||
return
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestINSERTworkspacesAllowlist enumerates every function in this
|
||||
// package that emits an `INSERT INTO workspaces (` SQL literal, and
|
||||
// pins the result against an explicit allowlist. New entries fail the
|
||||
// build until a reviewer adds them — forcing the question "what
|
||||
// makes this INSERT idempotent?" at PR-review time, not after the
|
||||
// next bulk-create leak.
|
||||
//
|
||||
// Pairs with TestCreateWorkspaceTree_CallsLookupBeforeInsert (the
|
||||
// behavior pin for the one bulk path). Together they close the
|
||||
// regression class: this test catches "did a new function start
|
||||
// inserting workspaces?", that test catches "did the existing bulk
|
||||
// path drop its idempotency check?". Either fires immediately when
|
||||
// drift happens.
|
||||
//
|
||||
// Why allowlist rather than pure behavior gate (per memory
|
||||
// feedback_behavior_based_ast_gates.md): the bulk-create leak class
|
||||
// is small + stable (1 path today), and a behavior gate would have
|
||||
// to disambiguate "iterating a YAML array of workspaces" from the
|
||||
// many other `for ... range` patterns in a Create handler (config
|
||||
// lines, secrets map, channels). Type-info-aware AST analysis would
|
||||
// catch the YAML-iteration shape but is heavy. Allowlisting is the
|
||||
// minimum-viable pin: any PR that adds a new INSERT site is forced
|
||||
// to pause, add an entry here, and document the safety mechanism in
|
||||
// the comment alongside.
|
||||
//
|
||||
// RFC #2867 class 1.
|
||||
func TestINSERTworkspacesAllowlist(t *testing.T) {
|
||||
// expected[key] = safety mechanism. Keep the comment pinned to
|
||||
// what makes that function safe — if the safety changes, the
|
||||
// allowlist must be re-reviewed.
|
||||
expected := map[string]string{
|
||||
// org_import.createWorkspaceTree: lookupExistingChild
|
||||
// before INSERT (#2868 phase 3). Also pinned by
|
||||
// TestCreateWorkspaceTree_CallsLookupBeforeInsert.
|
||||
"org_import.go:createWorkspaceTree": "lookup-then-insert via lookupExistingChild",
|
||||
// registry.Register: external workspace registers itself with
|
||||
// its known UUID; INSERT is idempotent via ON CONFLICT (id)
|
||||
// DO UPDATE — re-registration upserts, never duplicates.
|
||||
"registry.go:Register": "ON CONFLICT (id) DO UPDATE",
|
||||
// workspace.Create: single-workspace POST /workspaces from a
|
||||
// human or automation. No iteration; payload describes one
|
||||
// workspace; UUID is server-generated. Caller intent IS to
|
||||
// create, so no idempotency check is needed.
|
||||
"workspace.go:Create": "single-workspace POST, server-generated UUID",
|
||||
}
|
||||
|
||||
actual := map[string]string{}
|
||||
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("getwd: %v", err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(wd)
|
||||
if err != nil {
|
||||
t.Fatalf("readdir %s: %v", wd, err)
|
||||
}
|
||||
for _, ent := range entries {
|
||||
name := ent.Name()
|
||||
if ent.IsDir() {
|
||||
continue
|
||||
}
|
||||
if !strings.HasSuffix(name, ".go") {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(name, "_test.go") {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(wd, name)
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
|
||||
if err != nil {
|
||||
t.Fatalf("parse %s: %v", path, err)
|
||||
}
|
||||
// For each top-level FuncDecl, walk its body and check for an
|
||||
// `INSERT INTO workspaces (` SQL literal in any CallExpr arg.
|
||||
for _, decl := range file.Decls {
|
||||
fn, ok := decl.(*ast.FuncDecl)
|
||||
if !ok || fn.Body == nil {
|
||||
continue
|
||||
}
|
||||
var foundInsert bool
|
||||
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||
lit, ok := n.(*ast.BasicLit)
|
||||
if !ok || lit.Kind != token.STRING {
|
||||
return true
|
||||
}
|
||||
raw := lit.Value
|
||||
if unq, err := strconv.Unquote(raw); err == nil {
|
||||
raw = unq
|
||||
}
|
||||
if workspacesInsertRE.MatchString(raw) {
|
||||
foundInsert = true
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if foundInsert {
|
||||
key := name + ":" + fn.Name.Name
|
||||
actual[key] = "(observed via AST walk)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute set diffs so failures point at the specific drift.
|
||||
missing := []string{}
|
||||
unexpected := []string{}
|
||||
for k := range expected {
|
||||
if _, ok := actual[k]; !ok {
|
||||
missing = append(missing, k)
|
||||
}
|
||||
}
|
||||
for k := range actual {
|
||||
if _, ok := expected[k]; !ok {
|
||||
unexpected = append(unexpected, k)
|
||||
}
|
||||
}
|
||||
sort.Strings(missing)
|
||||
sort.Strings(unexpected)
|
||||
|
||||
if len(unexpected) > 0 {
|
||||
t.Errorf(`new function(s) emit `+"`INSERT INTO workspaces (`"+` and aren't in the allowlist:
|
||||
%s
|
||||
|
||||
If this is a legitimate addition, add an entry to expected[] in this test
|
||||
with the safety mechanism pinned in the comment alongside (lookup-then-
|
||||
insert / ON CONFLICT / single-workspace path / etc.). The bulk-create
|
||||
regression class needs explicit per-handler review, not silent drift.
|
||||
|
||||
Reference: RFC #2867 class 1, sibling test
|
||||
TestCreateWorkspaceTree_CallsLookupBeforeInsert.`,
|
||||
strings.Join(unexpected, "\n "))
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
t.Errorf(`expected function(s) no longer emit `+"`INSERT INTO workspaces (`"+`:
|
||||
%s
|
||||
|
||||
Either the function was renamed/deleted (update the allowlist) or the
|
||||
INSERT was moved out (verify the new home is also covered). Don't just
|
||||
delete the entry — confirm the safety mechanism is still in place
|
||||
elsewhere or that the workspace-create path was intentionally
|
||||
restructured.`,
|
||||
strings.Join(missing, "\n "))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package pendinguploads
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// StartSweeperWithIntervalForTest exposes startSweeperWithInterval to
|
||||
// the external test package. The production code uses StartSweeper
|
||||
// (which pins the canonical SweepInterval); tests pin a short interval
|
||||
// to exercise the ticker-driven cycle without burning real wall-clock
|
||||
// time. The Go convention `export_test.go` keeps this seam OUT of the
|
||||
// production binary — files ending in _test.go are stripped at build
|
||||
// time, so this re-export only exists during `go test`.
|
||||
func StartSweeperWithIntervalForTest(ctx context.Context, storage Storage, ackRetention, interval time.Duration) {
|
||||
startSweeperWithInterval(ctx, storage, ackRetention, interval)
|
||||
}
|
||||
@@ -85,6 +85,15 @@ type SweepResult struct {
|
||||
// Total returns the sum of Acked + Expired — convenient for log lines.
|
||||
func (r SweepResult) Total() int { return r.Acked + r.Expired }
|
||||
|
||||
// PutItem is one file in a PutBatch call. Same per-field rules as Put —
|
||||
// empty content, missing filename, or content > MaxFileBytes is rejected
|
||||
// up-front so a bad item in the batch doesn't poison the transaction.
|
||||
type PutItem struct {
|
||||
Content []byte
|
||||
Filename string
|
||||
Mimetype string
|
||||
}
|
||||
|
||||
// Storage is the platform-side persistence boundary for poll-mode chat
|
||||
// uploads. The Postgres implementation backs all callers today; an S3-
|
||||
// backed implementation can drop in once RFC #2789 lands by making
|
||||
@@ -99,6 +108,17 @@ type Storage interface {
|
||||
// content > MaxFileBytes return errors before any DB write.
|
||||
Put(ctx context.Context, workspaceID uuid.UUID, content []byte, filename, mimetype string) (uuid.UUID, error)
|
||||
|
||||
// PutBatch inserts N uploads atomically — either all rows commit or
|
||||
// none do. Returns assigned file_ids in input order on success;
|
||||
// returns an error and does NOT insert any row on failure.
|
||||
//
|
||||
// Use this from multi-file upload handlers so a per-row failure on
|
||||
// row K doesn't leave rows 1..K-1 orphaned in the table (a client
|
||||
// retry would then double-insert them on success). All-or-nothing
|
||||
// semantics match the multipart request the canvas sends — either
|
||||
// the whole batch succeeds or the user re-uploads.
|
||||
PutBatch(ctx context.Context, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error)
|
||||
|
||||
// Get returns the full row including content. Returns ErrNotFound
|
||||
// when the row is absent, acked, or past expires_at. Caller should
|
||||
// not differentiate the three cases in the response — from the
|
||||
@@ -174,6 +194,64 @@ func (p *PostgresStorage) Put(ctx context.Context, workspaceID uuid.UUID, conten
|
||||
return fileID, nil
|
||||
}
|
||||
|
||||
// PutBatch inserts every item atomically inside a single Tx. On any
|
||||
// per-item validation or per-row INSERT error the Tx is rolled back and
|
||||
// the caller sees the error without any rows committed — no partial
|
||||
// orphans for a multi-file upload that fails mid-batch.
|
||||
//
|
||||
// Validation runs BEFORE BEGIN so a bad input shape (empty content,
|
||||
// over-cap size) doesn't even open a Tx. Once we're in the Tx, the only
|
||||
// failures expected are DB-side (broken connection, statement timeout)
|
||||
// — those abort cleanly via Rollback.
|
||||
func (p *PostgresStorage) PutBatch(ctx context.Context, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error) {
|
||||
if len(items) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
for i, it := range items {
|
||||
if len(it.Content) == 0 {
|
||||
return nil, fmt.Errorf("pendinguploads: item %d: empty content", i)
|
||||
}
|
||||
if len(it.Content) > MaxFileBytes {
|
||||
return nil, ErrTooLarge
|
||||
}
|
||||
if it.Filename == "" {
|
||||
return nil, fmt.Errorf("pendinguploads: item %d: empty filename", i)
|
||||
}
|
||||
if len(it.Filename) > 100 {
|
||||
return nil, fmt.Errorf("pendinguploads: item %d: filename exceeds 100 chars", i)
|
||||
}
|
||||
}
|
||||
|
||||
tx, err := p.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pendinguploads: begin tx: %w", err)
|
||||
}
|
||||
// Defer-rollback is safe even after a successful Commit — the second
|
||||
// Rollback is a no-op (database/sql tracks tx state).
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
out := make([]uuid.UUID, 0, len(items))
|
||||
for i, it := range items {
|
||||
var fid uuid.UUID
|
||||
err := tx.QueryRowContext(ctx, `
|
||||
INSERT INTO pending_uploads (workspace_id, content, size_bytes, filename, mimetype)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING file_id
|
||||
`, workspaceID, it.Content, int64(len(it.Content)), it.Filename, it.Mimetype).Scan(&fid)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pendinguploads: batch insert item %d: %w", i, err)
|
||||
}
|
||||
out = append(out, fid)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("pendinguploads: commit batch: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (p *PostgresStorage) Get(ctx context.Context, fileID uuid.UUID) (Record, error) {
|
||||
// The expires_at + acked_at filter in the WHERE clause means a
|
||||
// caller sees ErrNotFound for absent / acked / expired without
|
||||
|
||||
@@ -511,3 +511,223 @@ func TestSweepResult_TotalSumsCounts(t *testing.T) {
|
||||
t.Errorf("zero Total = %d, want 0", z.Total())
|
||||
}
|
||||
}
|
||||
|
||||
// ----- PutBatch -------------------------------------------------------------
|
||||
//
|
||||
// PutBatch is the multi-file atomic insert path used by uploadPollMode in
|
||||
// chat_files.go. The contract that callers rely on:
|
||||
//
|
||||
// - Either ALL rows commit, or NONE do — a per-row INSERT failure must
|
||||
// leave the table unchanged (no orphaned rows from a half-applied batch).
|
||||
// - Per-item validation runs BEFORE the Tx opens so a bad input shape
|
||||
// never wastes a BEGIN round-trip.
|
||||
// - Returned []uuid.UUID is in input order — handler maps response back
|
||||
// to the multipart Files[i].
|
||||
//
|
||||
// sqlmock's ExpectBegin / ExpectQuery / ExpectCommit / ExpectRollback let us
|
||||
// pin the exact tx-lifecycle shape; if a future refactor swaps Begin for
|
||||
// BeginTx-with-options, the test fails until we re-pin.
|
||||
|
||||
func TestPutBatch_HappyPath_AllCommitInOrder(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
wsID := uuid.New()
|
||||
id1, id2, id3 := uuid.New(), uuid.New(), uuid.New()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(wsID, []byte("aaa"), int64(3), "a.txt", "text/plain").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1))
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(wsID, []byte("bbbb"), int64(4), "b.bin", "application/octet-stream").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id2))
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(wsID, []byte("ccccc"), int64(5), "c.pdf", "application/pdf").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id3))
|
||||
mock.ExpectCommit()
|
||||
// Rollback after Commit is a no-op in database/sql; sqlmock allows it
|
||||
// when ExpectCommit was already matched, so we don't need to expect it.
|
||||
|
||||
got, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{
|
||||
{Content: []byte("aaa"), Filename: "a.txt", Mimetype: "text/plain"},
|
||||
{Content: []byte("bbbb"), Filename: "b.bin", Mimetype: "application/octet-stream"},
|
||||
{Content: []byte("ccccc"), Filename: "c.pdf", Mimetype: "application/pdf"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PutBatch: %v", err)
|
||||
}
|
||||
if len(got) != 3 || got[0] != id1 || got[1] != id2 || got[2] != id3 {
|
||||
t.Errorf("ids out of order or missing: got %v want [%s %s %s]", got, id1, id2, id3)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_EmptyItems_NoTxNoError(t *testing.T) {
|
||||
db, _ := newMockDB(t) // zero expectations — must NOT round-trip
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
got, err := store.PutBatch(context.Background(), uuid.New(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error on empty batch, got %v", err)
|
||||
}
|
||||
if got != nil {
|
||||
t.Errorf("expected nil ids on empty batch, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_RejectsEmptyContent_NoTx(t *testing.T) {
|
||||
db, _ := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
_, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{
|
||||
{Content: []byte("ok"), Filename: "a.txt"},
|
||||
{Content: nil, Filename: "b.txt"},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "item 1") || !strings.Contains(err.Error(), "empty content") {
|
||||
t.Fatalf("expected item-1 empty-content error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_RejectsOversize_ReturnsErrTooLarge(t *testing.T) {
|
||||
db, _ := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
too := make([]byte, pendinguploads.MaxFileBytes+1)
|
||||
_, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{
|
||||
{Content: []byte("ok"), Filename: "small.txt"},
|
||||
{Content: too, Filename: "huge.bin"},
|
||||
})
|
||||
if !errors.Is(err, pendinguploads.ErrTooLarge) {
|
||||
t.Fatalf("expected ErrTooLarge, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_RejectsEmptyFilename_NoTx(t *testing.T) {
|
||||
db, _ := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
_, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{
|
||||
{Content: []byte("hi"), Filename: ""},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "item 0") || !strings.Contains(err.Error(), "empty filename") {
|
||||
t.Fatalf("expected item-0 empty-filename error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_RejectsLongFilename_NoTx(t *testing.T) {
|
||||
db, _ := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
long := strings.Repeat("z", 101)
|
||||
_, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{
|
||||
{Content: []byte("hi"), Filename: "ok.txt"},
|
||||
{Content: []byte("hi"), Filename: long},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "item 1") || !strings.Contains(err.Error(), "exceeds 100 chars") {
|
||||
t.Fatalf("expected item-1 too-long-filename error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_BeginTxError_Wrapped(t *testing.T) {
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
mock.ExpectBegin().WillReturnError(errors.New("conn refused"))
|
||||
|
||||
_, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{
|
||||
{Content: []byte("hi"), Filename: "a.txt"},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "begin tx") {
|
||||
t.Fatalf("expected wrapped begin-tx error, got %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_RollsBackOnPerRowError_NoCommit(t *testing.T) {
|
||||
// First INSERT succeeds, second errors. PutBatch MUST NOT issue
|
||||
// Commit; the deferred Rollback unwinds row 1 so neither row commits.
|
||||
// This is the contract that prevents orphan rows on a failed batch.
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
wsID := uuid.New()
|
||||
id1 := uuid.New()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(wsID, []byte("aaa"), int64(3), "a.txt", "").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1))
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(wsID, []byte("bb"), int64(2), "b.txt", "").
|
||||
WillReturnError(errors.New("statement timeout"))
|
||||
// Critical: Rollback expected, NOT Commit. If a future refactor
|
||||
// accidentally swallows the per-row error and Commits anyway, this
|
||||
// test fails because the unmet ExpectCommit-vs-Rollback shape diverges.
|
||||
mock.ExpectRollback()
|
||||
|
||||
_, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{
|
||||
{Content: []byte("aaa"), Filename: "a.txt"},
|
||||
{Content: []byte("bb"), Filename: "b.txt"},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "batch insert item 1") {
|
||||
t.Fatalf("expected wrapped per-row insert error, got %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations (must rollback, no commit): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_RollsBackOnFirstRowError(t *testing.T) {
|
||||
// Edge case: very first INSERT fails. No rows ever staged — but the
|
||||
// Tx still needs to roll back to release the snapshot.
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
wsID := uuid.New()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(wsID, []byte("oops"), int64(4), "a.txt", "").
|
||||
WillReturnError(errors.New("constraint violation"))
|
||||
mock.ExpectRollback()
|
||||
|
||||
_, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{
|
||||
{Content: []byte("oops"), Filename: "a.txt"},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "batch insert item 0") {
|
||||
t.Fatalf("expected wrapped item-0 insert error, got %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutBatch_CommitError_Wrapped(t *testing.T) {
|
||||
// Commit fails after every INSERT succeeded. Postgres has already
|
||||
// rolled back the Tx by this point; we surface the error so the
|
||||
// handler returns 500 and the client retries.
|
||||
db, mock := newMockDB(t)
|
||||
store := pendinguploads.NewPostgres(db)
|
||||
|
||||
wsID := uuid.New()
|
||||
id1 := uuid.New()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(insertSQL).
|
||||
WithArgs(wsID, []byte("hi"), int64(2), "a.txt", "").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1))
|
||||
mock.ExpectCommit().WillReturnError(errors.New("commit broken"))
|
||||
|
||||
_, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{
|
||||
{Content: []byte("hi"), Filename: "a.txt"},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "commit batch") {
|
||||
t.Fatalf("expected wrapped commit error, got %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,13 +66,13 @@ const sweepDeadline = 30 * time.Second
|
||||
// to exercise the ticker-driven sweep path without burning real wall-
|
||||
// clock time.
|
||||
func StartSweeper(ctx context.Context, storage Storage, ackRetention time.Duration) {
|
||||
StartSweeperWithInterval(ctx, storage, ackRetention, SweepInterval)
|
||||
startSweeperWithInterval(ctx, storage, ackRetention, SweepInterval)
|
||||
}
|
||||
|
||||
// StartSweeperWithInterval is the test-friendly variant of StartSweeper
|
||||
// startSweeperWithInterval is the test-friendly variant of StartSweeper
|
||||
// — same loop, but the cadence is caller-specified. Production code
|
||||
// should use StartSweeper to keep the SweepInterval constant pinned.
|
||||
func StartSweeperWithInterval(ctx context.Context, storage Storage, ackRetention, interval time.Duration) {
|
||||
func startSweeperWithInterval(ctx context.Context, storage Storage, ackRetention, interval time.Duration) {
|
||||
if storage == nil {
|
||||
log.Println("pendinguploads sweeper: storage is nil — sweeper disabled")
|
||||
return
|
||||
|
||||
@@ -44,6 +44,9 @@ func (f *fakeSweepStorage) MarkFetched(_ context.Context, _ uuid.UUID) error {
|
||||
func (f *fakeSweepStorage) Ack(_ context.Context, _ uuid.UUID) error {
|
||||
return errors.New("not used")
|
||||
}
|
||||
func (f *fakeSweepStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) {
|
||||
return nil, errors.New("not used")
|
||||
}
|
||||
func (f *fakeSweepStorage) Sweep(_ context.Context, ackRetention time.Duration) (pendinguploads.SweepResult, error) {
|
||||
idx := int(f.calls.Load())
|
||||
f.calls.Add(1)
|
||||
@@ -65,6 +68,15 @@ func (f *fakeSweepStorage) Sweep(_ context.Context, ackRetention time.Duration)
|
||||
|
||||
// waitForCycle blocks until at least one Sweep completes, with a deadline.
|
||||
// Tests use this instead of time.Sleep to avoid flakes on slow CI hosts.
|
||||
//
|
||||
// CAVEAT: cycleDone fires from inside fakeSweepStorage.Sweep's defer,
|
||||
// which runs as Sweep returns its result — BEFORE the StartSweeper
|
||||
// loop has processed the (result, error) tuple and called the
|
||||
// metric recorders. Tests that assert on metric counters must NOT
|
||||
// rely on this wait alone; use waitForMetricDelta instead so the
|
||||
// metric increment race (Sweep returns → cycleDone fires → test
|
||||
// reads counter → only then does StartSweeper's loop call
|
||||
// metrics.PendingUploadsSweepError) doesn't produce a flake.
|
||||
func (f *fakeSweepStorage) waitForCycle(t *testing.T, n int, timeout time.Duration) {
|
||||
t.Helper()
|
||||
deadline := time.NewTimer(timeout)
|
||||
@@ -78,6 +90,33 @@ func (f *fakeSweepStorage) waitForCycle(t *testing.T, n int, timeout time.Durati
|
||||
}
|
||||
}
|
||||
|
||||
// waitForMetricDelta polls the supplied delta function until it returns
|
||||
// `want` or the timeout elapses. Use after waitForCycle when the test
|
||||
// asserts on a metric counter — closes the race between cycleDone
|
||||
// (signalled inside fakeSweepStorage.Sweep's defer, BEFORE Sweep
|
||||
// returns to StartSweeper) and the metric recording (which happens in
|
||||
// StartSweeper's loop AFTER Sweep returns). On a slow CI host the test
|
||||
// goroutine wins the read before StartSweeper's goroutine writes the
|
||||
// counter; the polling assert preserves the determinism of "the metric
|
||||
// MUST be N" without timing-based flakes.
|
||||
//
|
||||
// Per memory feedback_question_test_when_unexpected.md: the failure
|
||||
// mode "delta=0, want=1" looked like a real bug at first glance —
|
||||
// "metric never incremented" — but instrumented analysis showed the
|
||||
// metric DID increment, just AFTER the test's read. The fix is the
|
||||
// test's wait shape, not the production code.
|
||||
func waitForMetricDelta(t *testing.T, delta func() int64, want int64, timeout time.Duration) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if delta() == want {
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("waited %s for metric delta=%d, last seen %d", timeout, want, delta())
|
||||
}
|
||||
|
||||
func TestStartSweeper_NilStorageDoesNotPanic(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
@@ -144,7 +183,7 @@ func TestStartSweeperWithInterval_TickerFiresAdditionalCycles(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go pendinguploads.StartSweeperWithInterval(ctx, store, time.Hour, 30*time.Millisecond)
|
||||
go pendinguploads.StartSweeperWithIntervalForTest(ctx, store, time.Hour, 30*time.Millisecond)
|
||||
|
||||
// Immediate cycle + at least one tick-driven cycle.
|
||||
store.waitForCycle(t, 2, 2*time.Second)
|
||||
@@ -220,12 +259,13 @@ func TestStartSweeper_RecordsMetricsOnSuccess(t *testing.T) {
|
||||
go pendinguploads.StartSweeper(ctx, store, time.Hour)
|
||||
store.waitForCycle(t, 1, 2*time.Second)
|
||||
|
||||
if got := deltaAcked(); got != 3 {
|
||||
t.Errorf("acked counter delta = %d, want 3", got)
|
||||
}
|
||||
if got := deltaExpired(); got != 5 {
|
||||
t.Errorf("expired counter delta = %d, want 5", got)
|
||||
}
|
||||
// Poll for the success counters to settle — closes the cycleDone-
|
||||
// vs-metric-record race (see waitForMetricDelta comment).
|
||||
waitForMetricDelta(t, deltaAcked, 3, 2*time.Second)
|
||||
waitForMetricDelta(t, deltaExpired, 5, 2*time.Second)
|
||||
// Error counter MUST stay at zero on the success path. Read after
|
||||
// the success counters have settled — once those are correct,
|
||||
// StartSweeper has fully processed this cycle's result.
|
||||
if got := deltaError(); got != 0 {
|
||||
t.Errorf("error counter delta = %d, want 0", got)
|
||||
}
|
||||
@@ -244,7 +284,11 @@ func TestStartSweeper_RecordsMetricsOnError(t *testing.T) {
|
||||
go pendinguploads.StartSweeper(ctx, store, time.Hour)
|
||||
store.waitForCycle(t, 1, 2*time.Second)
|
||||
|
||||
if got := deltaError(); got != 1 {
|
||||
t.Errorf("error counter delta = %d, want 1", got)
|
||||
}
|
||||
// Poll for the error counter to settle — cycleDone fires inside
|
||||
// the fake's Sweep defer, BEFORE StartSweeper's loop receives the
|
||||
// returned error and calls metrics.PendingUploadsSweepError. On
|
||||
// slow CI hosts a direct deltaError() read here returns 0 even
|
||||
// though the metric WILL be 1 a few ms later. See
|
||||
// waitForMetricDelta comment.
|
||||
waitForMetricDelta(t, deltaError, 1, 2*time.Second)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provlog"
|
||||
)
|
||||
|
||||
// CPProvisionerAPI is the contract WorkspaceHandler uses to talk to the
|
||||
@@ -214,6 +215,13 @@ func (p *CPProvisioner) Start(ctx context.Context, cfg WorkspaceConfig) (string,
|
||||
}
|
||||
|
||||
log.Printf("CP provisioner: workspace %s → EC2 instance %s (%s)", cfg.WorkspaceID, result.InstanceID, result.State)
|
||||
provlog.Event("provision.ec2_started", map[string]any{
|
||||
"workspace_id": cfg.WorkspaceID,
|
||||
"instance_id": result.InstanceID,
|
||||
"state": result.State,
|
||||
"tier": cfg.Tier,
|
||||
"runtime": cfg.Runtime,
|
||||
})
|
||||
return result.InstanceID, nil
|
||||
}
|
||||
|
||||
@@ -273,6 +281,10 @@ func (p *CPProvisioner) Stop(ctx context.Context, workspaceID string) error {
|
||||
return fmt.Errorf("cp provisioner: stop %s: unexpected %d: %s",
|
||||
workspaceID, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
provlog.Event("provision.ec2_stopped", map[string]any{
|
||||
"workspace_id": workspaceID,
|
||||
"instance_id": instanceID,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
// Package provlog emits structured, single-line JSON log records for
|
||||
// provisioning-lifecycle boundaries (workspace create, EC2 start/stop,
|
||||
// restart, idempotency skips). Records share a stable `evt:` prefix and
|
||||
// JSON payload so a future grep|jq pipeline (or a Loki/Datadog ingest)
|
||||
// can reconstruct the per-workspace timeline without parsing the
|
||||
// human-prose log lines that already exist.
|
||||
//
|
||||
// Existing log.Printf lines are intentionally NOT replaced — they
|
||||
// remain the operator-facing message. Event() emits a paired structured
|
||||
// record alongside, additive only.
|
||||
//
|
||||
// Event taxonomy (extend by appending; never rename):
|
||||
//
|
||||
// provision.start — workspace row inserted, EC2 about to launch
|
||||
// provision.skip_existing — idempotency hit, no new EC2
|
||||
// provision.ec2_started — RunInstances returned an instance id
|
||||
// provision.ec2_stopped — TerminateInstances acknowledged
|
||||
// restart.pre_stop — Restart handler about to call Stop
|
||||
//
|
||||
// Required fields per event are documented at each call site.
|
||||
package provlog
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
)
|
||||
|
||||
// Event writes a single line of the form:
|
||||
//
|
||||
// evt: <name> {"k":"v",...}
|
||||
//
|
||||
// to the standard logger. JSON encoding errors are silently swallowed —
|
||||
// a logging helper must never panic the request path. fields may be
|
||||
// nil; the empty payload `{}` is still useful to mark an event boundary.
|
||||
func Event(name string, fields map[string]any) {
|
||||
if fields == nil {
|
||||
fields = map[string]any{}
|
||||
}
|
||||
payload, err := json.Marshal(fields)
|
||||
if err != nil {
|
||||
// Fall back to a static payload so the event boundary still
|
||||
// appears in the log. The marshal error itself is recorded
|
||||
// on a best-effort basis.
|
||||
log.Printf("evt: %s {\"_marshal_err\":%q}", name, err.Error())
|
||||
return
|
||||
}
|
||||
log.Printf("evt: %s %s", name, payload)
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package provlog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// captureLog redirects the default logger to a buffer for the duration
|
||||
// of fn and returns whatever was written.
|
||||
func captureLog(t *testing.T, fn func()) string {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
prevWriter := log.Writer()
|
||||
prevFlags := log.Flags()
|
||||
log.SetOutput(&buf)
|
||||
log.SetFlags(0) // strip date/time so assertions stay deterministic
|
||||
t.Cleanup(func() {
|
||||
log.SetOutput(prevWriter)
|
||||
log.SetFlags(prevFlags)
|
||||
})
|
||||
fn()
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func TestEvent_EmitsEvtPrefixAndJSONPayload(t *testing.T) {
|
||||
out := captureLog(t, func() {
|
||||
Event("provision.start", map[string]any{
|
||||
"workspace_id": "ws-123",
|
||||
"tier": 4,
|
||||
"runtime": "claude-code",
|
||||
})
|
||||
})
|
||||
out = strings.TrimSpace(out)
|
||||
if !strings.HasPrefix(out, "evt: provision.start ") {
|
||||
t.Fatalf("expected evt-prefixed line, got %q", out)
|
||||
}
|
||||
jsonPart := strings.TrimPrefix(out, "evt: provision.start ")
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonPart), &got); err != nil {
|
||||
t.Fatalf("payload not valid JSON: %v (raw=%q)", err, jsonPart)
|
||||
}
|
||||
if got["workspace_id"] != "ws-123" {
|
||||
t.Errorf("workspace_id field lost: %+v", got)
|
||||
}
|
||||
// JSON unmarshal turns numbers into float64 — exact-equal compare.
|
||||
if got["tier"].(float64) != 4 {
|
||||
t.Errorf("tier field lost: %+v", got)
|
||||
}
|
||||
if got["runtime"] != "claude-code" {
|
||||
t.Errorf("runtime field lost: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_NilFieldsEmitsEmptyObject(t *testing.T) {
|
||||
out := captureLog(t, func() {
|
||||
Event("restart.pre_stop", nil)
|
||||
})
|
||||
if !strings.Contains(out, "evt: restart.pre_stop {}") {
|
||||
t.Fatalf("nil fields should emit empty object, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_PreservesEventBoundaryOnUnmarshalableValue(t *testing.T) {
|
||||
// A channel cannot be marshaled by encoding/json — verify we still
|
||||
// emit the event boundary with a recorded marshal error. This is
|
||||
// the structural guarantee: the call site never sees a panic, and
|
||||
// the event name is always present in the log.
|
||||
out := captureLog(t, func() {
|
||||
Event("provision.ec2_started", map[string]any{
|
||||
"chan": make(chan int),
|
||||
})
|
||||
})
|
||||
if !strings.Contains(out, "evt: provision.ec2_started ") {
|
||||
t.Fatalf("event boundary missing on marshal error: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "_marshal_err") {
|
||||
t.Fatalf("expected _marshal_err sentinel, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_SingleLineOutput(t *testing.T) {
|
||||
// Log aggregators line-split on \n. A multi-line emit would silently
|
||||
// fragment the JSON across two records — pin single-line shape.
|
||||
out := captureLog(t, func() {
|
||||
Event("provision.skip_existing", map[string]any{
|
||||
"existing_id": "ws-abc",
|
||||
"name": "child-1",
|
||||
})
|
||||
})
|
||||
trimmed := strings.TrimRight(out, "\n")
|
||||
if strings.Contains(trimmed, "\n") {
|
||||
t.Fatalf("event line must be single-line, got %q", out)
|
||||
}
|
||||
}
|
||||
@@ -243,13 +243,15 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
|
||||
// entire platform. Gated behind AdminAuth (issue #180).
|
||||
r.GET("/approvals/pending", middleware.AdminAuth(db.DB), apph.ListAll)
|
||||
|
||||
// Team handlers — Collapse only. The bulk-Expand path is gone:
|
||||
// every workspace can have children via the regular CreateWorkspace
|
||||
// flow with parent_id set, so a separate handler that bulk-creates
|
||||
// from sub_workspaces (and was non-idempotent — calling it twice
|
||||
// duplicated the team) earned its way out.
|
||||
teamh := handlers.NewTeamHandler(broadcaster, wh, platformURL, configsDir)
|
||||
wsAuth.POST("/collapse", teamh.Collapse)
|
||||
// (TeamHandler is gone — #2864.) The visual canvas Collapse
|
||||
// button calls PATCH /workspaces/:id { collapsed: true/false }
|
||||
// (presentational toggle on canvas_layouts), NOT the destructive
|
||||
// POST /collapse that stopped + removed children. The
|
||||
// destructive route had zero UI callers (verified via grep
|
||||
// across canvas/, scripts/, and the MCP tool registry — only
|
||||
// docs referenced it). team.go + team_test.go + the route
|
||||
// + helpers (findTemplateDirByName, NewTeamHandler) are
|
||||
// deleted; visual collapse is unaffected.
|
||||
|
||||
// Agents
|
||||
ah := handlers.NewAgentHandler(broadcaster)
|
||||
@@ -519,8 +521,9 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
|
||||
r.GET("/canvas/viewport", vh.Get)
|
||||
r.PUT("/canvas/viewport", middleware.CanvasOrBearer(db.DB), vh.Save)
|
||||
|
||||
// Templates
|
||||
tmplh := handlers.NewTemplatesHandler(configsDir, dockerCli)
|
||||
// Templates — wh threaded so generateDefaultConfig picks the
|
||||
// SaaS-aware default tier in Import + ReplaceFiles (#2910 PR-B).
|
||||
tmplh := handlers.NewTemplatesHandler(configsDir, dockerCli, wh)
|
||||
// #686: GET /templates lists all template names+metadata from configsDir.
|
||||
// Open access lets unauthenticated callers enumerate org configurations and
|
||||
// installed plugins. AdminAuth-gate it alongside POST /templates/import.
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
-- Reversal of 20260505200000_pending_uploads_acked_index.up.sql.
|
||||
DROP INDEX IF EXISTS idx_pending_uploads_acked;
|
||||
@@ -0,0 +1,30 @@
|
||||
-- 20260505200000_pending_uploads_acked_index.up.sql
|
||||
--
|
||||
-- Adds the missing partial index for the acked-retention arm of the
|
||||
-- pendinguploads.Sweep query. The Phase 1 migration created two
|
||||
-- partial indexes both gated on `acked_at IS NULL` (workspace-fetch
|
||||
-- hot path + expires_at sweep arm); the third query path —
|
||||
-- `WHERE acked_at IS NOT NULL AND acked_at < now() - interval` — was
|
||||
-- left to a seq scan.
|
||||
--
|
||||
-- For a high-traffic deployment that's a real cost: the table
|
||||
-- accumulates one row per chat-attached file; the sweeper runs every
|
||||
-- 5 minutes and DELETEs rows past the 1-hour ack retention. A seq
|
||||
-- scan over 100K-1M acked rows holds an AccessShare lock for seconds
|
||||
-- on every cycle. Partial-indexing the inverse predicate reduces
|
||||
-- this to a btree range scan and lets the DELETE complete in
|
||||
-- low-millisecond range.
|
||||
--
|
||||
-- WHERE acked_at IS NOT NULL is intentionally inverse of the other
|
||||
-- two indexes — they cover the unacked working set; this covers the
|
||||
-- terminal-state set the sweeper visits. Disjoint subsets, so the
|
||||
-- two indexes don't overlap.
|
||||
--
|
||||
-- Caught in self-review on the parent RFC's Phase 4 PR; filed as
|
||||
-- a follow-up rather than a Phase 1 fix because the cost only
|
||||
-- materializes at a row count we don't expect to hit before the
|
||||
-- sweeper has had a chance to keep up.
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_pending_uploads_acked
|
||||
ON pending_uploads (acked_at)
|
||||
WHERE acked_at IS NOT NULL;
|
||||
+12
-318
@@ -115,324 +115,18 @@ async def report_activity(
|
||||
pass # Best-effort — don't block delegation on activity reporting
|
||||
|
||||
|
||||
# RFC #2829 PR-5 cutover constants. The poll cadence + timeout are
|
||||
# intentionally generous: 3s gives the platform's executeDelegation
|
||||
# goroutine room to dispatch + the callee to respond + the result to
|
||||
# write to activity_logs without thrashing the platform with rapid
|
||||
# polls; the budget matches the legacy DELEGATION_TIMEOUT (300s) so
|
||||
# operators don't see behavior change beyond "no more 600s timeouts".
|
||||
_SYNC_POLL_INTERVAL_S = 3.0
|
||||
_SYNC_POLL_BUDGET_S = float(os.environ.get("DELEGATION_TIMEOUT", "300.0"))
|
||||
|
||||
|
||||
async def _delegate_sync_via_polling(
|
||||
workspace_id: str,
|
||||
task: str,
|
||||
src: str,
|
||||
) -> str:
|
||||
"""RFC #2829 PR-5: durable async delegation + poll for terminal status.
|
||||
|
||||
Sidesteps the platform proxy's blocking `message/send` HTTP path that
|
||||
hits a hard 600s ceiling. Instead:
|
||||
|
||||
1. POST /workspaces/<src>/delegate (async, returns 202 + delegation_id)
|
||||
— platform's executeDelegation goroutine handles A2A dispatch in
|
||||
the background. No client-side timeout dependency on the platform
|
||||
holding a connection open.
|
||||
2. Poll GET /workspaces/<src>/delegations every 3s for a row with
|
||||
matching delegation_id reaching terminal status (completed/failed).
|
||||
3. Return the response_preview text on completed; surface error_detail
|
||||
on failed (with the same _A2A_ERROR_PREFIX wrapping the legacy
|
||||
path uses, so caller error-detection logic is unchanged).
|
||||
|
||||
Both /delegate and /delegations are existing endpoints — this helper
|
||||
just composes them into a polling synchronous facade. The result is
|
||||
available the moment the platform writes the terminal status row;
|
||||
no extra latency vs. the legacy proxy-blocked path on fast cases.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32]
|
||||
|
||||
# 1. Dispatch via /delegate (the async, durable path).
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegate",
|
||||
json={
|
||||
"target_id": workspace_id,
|
||||
"task": task,
|
||||
"idempotency_key": idem_key,
|
||||
},
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: {e}"
|
||||
|
||||
if resp.status_code != 202 and resp.status_code != 200:
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: HTTP {resp.status_code} {resp.text[:200]}"
|
||||
|
||||
try:
|
||||
dispatch = resp.json()
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch returned non-JSON: {e}"
|
||||
|
||||
delegation_id = dispatch.get("delegation_id", "")
|
||||
if not delegation_id:
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch missing delegation_id: {dispatch}"
|
||||
|
||||
# 2. Poll for terminal status with a deadline. Each poll is a cheap
|
||||
# /delegations GET — bounded by the platform's existing rate limit.
|
||||
deadline = time.monotonic() + _SYNC_POLL_BUDGET_S
|
||||
last_status = "unknown"
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
poll = await client.get(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegations",
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
# Transient — keep polling. The platform IS holding the
|
||||
# delegation row; we just lost a network request.
|
||||
last_status = f"poll-error: {e}"
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
|
||||
if poll.status_code != 200:
|
||||
last_status = f"poll HTTP {poll.status_code}"
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
|
||||
try:
|
||||
rows = poll.json()
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
last_status = f"poll non-JSON: {e}"
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
|
||||
# /delegations returns a flat list of delegation events. Filter to
|
||||
# our delegation_id; pick the first terminal one. The list may
|
||||
# have multiple rows per delegation_id (one for the original
|
||||
# dispatch, one per status update); we want the latest terminal.
|
||||
if not isinstance(rows, list):
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
terminal = None
|
||||
for r in rows:
|
||||
if not isinstance(r, dict):
|
||||
continue
|
||||
if r.get("delegation_id") != delegation_id:
|
||||
continue
|
||||
status = (r.get("status") or "").lower()
|
||||
last_status = status
|
||||
if status in ("completed", "failed"):
|
||||
terminal = r
|
||||
break
|
||||
if terminal:
|
||||
if (terminal.get("status") or "").lower() == "completed":
|
||||
return terminal.get("response_preview") or ""
|
||||
err = (
|
||||
terminal.get("error_detail")
|
||||
or terminal.get("summary")
|
||||
or "delegation failed"
|
||||
)
|
||||
return f"{_A2A_ERROR_PREFIX}{err}"
|
||||
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
|
||||
# Budget exhausted — the platform's row is still in flight (or queued).
|
||||
# Surface as an error so the caller can decide to retry or fall back;
|
||||
# the platform DOES still have the durable row, so the work isn't
|
||||
# lost — it'll complete eventually and a future check_task_status
|
||||
# will surface the result.
|
||||
return (
|
||||
f"{_A2A_ERROR_PREFIX}polling timeout after {_SYNC_POLL_BUDGET_S}s "
|
||||
f"(delegation_id={delegation_id}, last_status={last_status}); "
|
||||
f"the platform is still working on it — call check_task_status('{delegation_id}') to retrieve later"
|
||||
)
|
||||
|
||||
|
||||
async def tool_delegate_task(
|
||||
workspace_id: str,
|
||||
task: str,
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Delegate a task to another workspace via A2A (synchronous — waits for response).
|
||||
|
||||
``source_workspace_id`` selects which registered workspace this
|
||||
delegation originates from — drives auth + the X-Workspace-ID source
|
||||
header so the platform's a2a_proxy logs the correct sender. Single-
|
||||
workspace operators leave it None and routing falls back to the
|
||||
module-level WORKSPACE_ID.
|
||||
"""
|
||||
if not workspace_id or not task:
|
||||
return "Error: workspace_id and task are required"
|
||||
|
||||
# Auto-route: if source not specified, look up which registered
|
||||
# workspace last saw this peer (populated by tool_list_peers). Falls
|
||||
# back to the legacy WORKSPACE_ID for single-workspace operators.
|
||||
src = source_workspace_id or _peer_to_source.get(workspace_id) or None
|
||||
|
||||
# Discover the target. discover_peer is the access-control gate +
|
||||
# name/status lookup. The peer's reported ``url`` field is NOT used
|
||||
# for routing — see send_a2a_message, which constructs the URL via
|
||||
# the platform's A2A proxy.
|
||||
peer = await discover_peer(workspace_id, source_workspace_id=src)
|
||||
if not peer:
|
||||
return f"Error: workspace {workspace_id} not found or not accessible (check access control)"
|
||||
|
||||
if (peer.get("status") or "").lower() == "offline":
|
||||
return f"Error: workspace {workspace_id} is offline"
|
||||
|
||||
# Report delegation start — include the task text for traceability
|
||||
peer_name = peer.get("name") or _peer_names.get(workspace_id) or workspace_id[:8]
|
||||
_peer_names[workspace_id] = peer_name # cache for future use
|
||||
# Brief summary for canvas display — just the delegation target
|
||||
await report_activity("a2a_send", workspace_id, f"Delegating to {peer_name}", task_text=task)
|
||||
|
||||
# RFC #2829 PR-5: agent-side cutover. When DELEGATION_SYNC_VIA_INBOX=1,
|
||||
# use the platform's durable async delegation API (POST /delegate +
|
||||
# poll /delegations) instead of the proxy-blocked message/send path.
|
||||
# This sidesteps the 600s message/send timeout class that broke
|
||||
# iteration-14/90-style long-running delegations on 2026-05-05.
|
||||
#
|
||||
# Default off — staging-canary first, flip default after PR-2's
|
||||
# result-push flag (DELEGATION_RESULT_INBOX_PUSH) has been on for
|
||||
# ≥1 week without incident.
|
||||
if os.environ.get("DELEGATION_SYNC_VIA_INBOX") == "1":
|
||||
result = await _delegate_sync_via_polling(workspace_id, task, src or WORKSPACE_ID)
|
||||
else:
|
||||
# send_a2a_message routes through ${PLATFORM_URL}/workspaces/{id}/a2a
|
||||
# (the platform proxy) so the same code works for in-container and
|
||||
# external (standalone molecule-mcp) callers.
|
||||
result = await send_a2a_message(workspace_id, task, source_workspace_id=src)
|
||||
|
||||
# Detect delegation failures — wrap them clearly so the calling agent
|
||||
# can decide to retry, use another peer, or handle the task itself.
|
||||
is_error = result.startswith(_A2A_ERROR_PREFIX)
|
||||
# Strip the sentinel prefix so error_detail is the human-readable
|
||||
# cause directly. The Activity tab's red error chip surfaces this
|
||||
# without the user having to scroll into the raw response JSON.
|
||||
#
|
||||
# Cap at 4096 chars before sending — the platform's
|
||||
# activity_logs.error_detail column is unbounded TEXT and a
|
||||
# malicious or buggy peer could otherwise stream an arbitrarily
|
||||
# large error message into the caller's activity log. 4096 is
|
||||
# comfortably above any real exception traceback we've seen and
|
||||
# well below an obvious-DoS threshold.
|
||||
error_detail = result[len(_A2A_ERROR_PREFIX):].strip()[:4096] if is_error else ""
|
||||
await report_activity(
|
||||
"a2a_receive", workspace_id,
|
||||
f"{peer_name} responded ({len(result)} chars)" if not is_error else f"{peer_name} failed: {error_detail[:120]}",
|
||||
task_text=task, response_text=result,
|
||||
status="error" if is_error else "ok",
|
||||
error_detail=error_detail,
|
||||
)
|
||||
if is_error:
|
||||
return (
|
||||
f"DELEGATION FAILED to {peer_name}: {result}\n"
|
||||
f"You should either: (1) try a different peer, (2) handle this task yourself, "
|
||||
f"or (3) inform the user that {peer_name} is unavailable and provide your best answer."
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def tool_delegate_task_async(
|
||||
workspace_id: str,
|
||||
task: str,
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Delegate a task via the platform's async delegation API (fire-and-forget).
|
||||
|
||||
Uses POST /workspaces/:id/delegate which runs the A2A request in the background.
|
||||
Results are tracked in the platform DB and broadcast via WebSocket.
|
||||
Use check_task_status to poll for results.
|
||||
|
||||
``source_workspace_id`` selects the sending workspace (which one of
|
||||
this agent's registered workspaces gets logged as the originator);
|
||||
auto-routes via the peer→source cache when omitted.
|
||||
"""
|
||||
if not workspace_id or not task:
|
||||
return "Error: workspace_id and task are required"
|
||||
|
||||
src = source_workspace_id or _peer_to_source.get(workspace_id) or WORKSPACE_ID
|
||||
|
||||
# Idempotency key: SHA-256 of (source, target, task) so that a
|
||||
# restarted agent firing the same delegation gets the same key and
|
||||
# the platform returns the existing delegation_id instead of
|
||||
# creating a duplicate. Fixes #1456. Source is in the key so the
|
||||
# SAME task delegated from two different registered workspaces
|
||||
# produces two distinct delegations (the right behavior — one per
|
||||
# tenant audit trail).
|
||||
idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32]
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegate",
|
||||
json={"target_id": workspace_id, "task": task, "idempotency_key": idem_key},
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
if resp.status_code == 202:
|
||||
data = resp.json()
|
||||
return json.dumps({
|
||||
"delegation_id": data.get("delegation_id", ""),
|
||||
"workspace_id": workspace_id,
|
||||
"status": "delegated",
|
||||
"note": "Task delegated. The platform runs it in the background. Use check_task_status to poll for results.",
|
||||
})
|
||||
else:
|
||||
return f"Error: delegation failed with status {resp.status_code}: {resp.text[:200]}"
|
||||
except Exception as e:
|
||||
return f"Error: delegation failed — {e}"
|
||||
|
||||
|
||||
async def tool_check_task_status(
|
||||
workspace_id: str,
|
||||
task_id: str,
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Check delegations for this workspace via the platform API.
|
||||
|
||||
Args:
|
||||
workspace_id: Ignored (kept for backward compat). Checks
|
||||
``source_workspace_id``'s delegations (the workspace that
|
||||
FIRED the delegations), not the target's.
|
||||
task_id: Optional delegation_id to filter. If empty, returns all recent delegations.
|
||||
source_workspace_id: Which registered workspace's delegation log
|
||||
to query. Defaults to the module-level WORKSPACE_ID.
|
||||
"""
|
||||
src = source_workspace_id or WORKSPACE_ID
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegations",
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return f"Error: failed to check delegations ({resp.status_code})"
|
||||
delegations = resp.json()
|
||||
if task_id:
|
||||
# Filter by delegation_id
|
||||
matching = [d for d in delegations if d.get("delegation_id") == task_id]
|
||||
if matching:
|
||||
return json.dumps(matching[0])
|
||||
return json.dumps({"status": "not_found", "delegation_id": task_id})
|
||||
# Return all recent delegations
|
||||
summary = []
|
||||
for d in delegations[:10]:
|
||||
summary.append({
|
||||
"delegation_id": d.get("delegation_id", ""),
|
||||
"target_id": d.get("target_id", ""),
|
||||
"status": d.get("status", ""),
|
||||
"summary": d.get("summary", ""),
|
||||
"response_preview": d.get("response_preview", ""),
|
||||
})
|
||||
return json.dumps({"delegations": summary, "count": len(delegations)})
|
||||
except Exception as e:
|
||||
return f"Error checking delegations: {e}"
|
||||
# Delegation tool handlers — extracted to a2a_tools_delegation
|
||||
# (RFC #2873 iter 4b). Re-imported here so call sites + tests that
|
||||
# reference ``a2a_tools.tool_delegate_task`` /
|
||||
# ``a2a_tools._delegate_sync_via_polling`` keep resolving identically.
|
||||
from a2a_tools_delegation import ( # noqa: E402 (import after the from-a2a_client block)
|
||||
_SYNC_POLL_BUDGET_S,
|
||||
_SYNC_POLL_INTERVAL_S,
|
||||
_delegate_sync_via_polling,
|
||||
tool_check_task_status,
|
||||
tool_delegate_task,
|
||||
tool_delegate_task_async,
|
||||
)
|
||||
|
||||
|
||||
async def _upload_chat_files(
|
||||
|
||||
@@ -0,0 +1,372 @@
|
||||
"""Delegation tool handlers — single-concern slice of the a2a_tools surface.
|
||||
|
||||
Extracted from ``a2a_tools.py`` (RFC #2873 iter 4b). Owns the three
|
||||
delegation MCP tools + the RFC #2829 PR-5 sync-via-polling helper they
|
||||
share.
|
||||
|
||||
Public surface:
|
||||
|
||||
* ``tool_delegate_task`` — synchronous delegation, waits for response.
|
||||
* ``tool_delegate_task_async`` — fire-and-forget delegation; returns
|
||||
``{delegation_id, ...}``.
|
||||
* ``tool_check_task_status`` — poll the platform's ``/delegations`` log.
|
||||
|
||||
Internal:
|
||||
|
||||
* ``_delegate_sync_via_polling`` — durable async + poll for terminal
|
||||
status (RFC #2829 PR-5 cutover path; toggled by
|
||||
``DELEGATION_SYNC_VIA_INBOX=1``).
|
||||
* ``_SYNC_POLL_INTERVAL_S`` / ``_SYNC_POLL_BUDGET_S`` constants.
|
||||
|
||||
Circular-import note: this module calls ``report_activity`` from
|
||||
``a2a_tools`` to emit activity rows around the delegate dispatch.
|
||||
``a2a_tools`` imports the public symbols here at module-load time,
|
||||
so we use a LAZY import for ``report_activity`` inside the function
|
||||
that needs it. Without the lazy hop Python raises an ImportError
|
||||
on first ``a2a_tools`` import.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
|
||||
import httpx
|
||||
|
||||
from a2a_client import (
|
||||
PLATFORM_URL,
|
||||
WORKSPACE_ID,
|
||||
_A2A_ERROR_PREFIX,
|
||||
_peer_names,
|
||||
_peer_to_source,
|
||||
discover_peer,
|
||||
send_a2a_message,
|
||||
)
|
||||
from a2a_tools_rbac import auth_headers_for_heartbeat as _auth_headers_for_heartbeat
|
||||
|
||||
|
||||
# RFC #2829 PR-5 cutover constants. The poll cadence + timeout are
|
||||
# intentionally generous: 3s gives the platform's executeDelegation
|
||||
# goroutine room to dispatch + the callee to respond + the result to
|
||||
# write to activity_logs without thrashing the platform with rapid
|
||||
# polls; the budget matches the legacy DELEGATION_TIMEOUT (300s) so
|
||||
# operators don't see behavior change beyond "no more 600s timeouts".
|
||||
_SYNC_POLL_INTERVAL_S = 3.0
|
||||
_SYNC_POLL_BUDGET_S = float(os.environ.get("DELEGATION_TIMEOUT", "300.0"))
|
||||
|
||||
|
||||
async def _delegate_sync_via_polling(
|
||||
workspace_id: str,
|
||||
task: str,
|
||||
src: str,
|
||||
) -> str:
|
||||
"""RFC #2829 PR-5: durable async delegation + poll for terminal status.
|
||||
|
||||
Sidesteps the platform proxy's blocking `message/send` HTTP path that
|
||||
hits a hard 600s ceiling. Instead:
|
||||
|
||||
1. POST /workspaces/<src>/delegate (async, returns 202 + delegation_id)
|
||||
— platform's executeDelegation goroutine handles A2A dispatch in
|
||||
the background. No client-side timeout dependency on the platform
|
||||
holding a connection open.
|
||||
2. Poll GET /workspaces/<src>/delegations every 3s for a row with
|
||||
matching delegation_id reaching terminal status (completed/failed).
|
||||
3. Return the response_preview text on completed; surface error_detail
|
||||
on failed (with the same _A2A_ERROR_PREFIX wrapping the legacy
|
||||
path uses, so caller error-detection logic is unchanged).
|
||||
|
||||
Both /delegate and /delegations are existing endpoints — this helper
|
||||
just composes them into a polling synchronous facade. The result is
|
||||
available the moment the platform writes the terminal status row;
|
||||
no extra latency vs. the legacy proxy-blocked path on fast cases.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32]
|
||||
|
||||
# 1. Dispatch via /delegate (the async, durable path).
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegate",
|
||||
json={
|
||||
"target_id": workspace_id,
|
||||
"task": task,
|
||||
"idempotency_key": idem_key,
|
||||
},
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: {e}"
|
||||
|
||||
if resp.status_code != 202 and resp.status_code != 200:
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: HTTP {resp.status_code} {resp.text[:200]}"
|
||||
|
||||
try:
|
||||
dispatch = resp.json()
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch returned non-JSON: {e}"
|
||||
|
||||
delegation_id = dispatch.get("delegation_id", "")
|
||||
if not delegation_id:
|
||||
return f"{_A2A_ERROR_PREFIX}delegate dispatch missing delegation_id: {dispatch}"
|
||||
|
||||
# 2. Poll for terminal status with a deadline. Each poll is a cheap
|
||||
# /delegations GET — bounded by the platform's existing rate limit.
|
||||
deadline = time.monotonic() + _SYNC_POLL_BUDGET_S
|
||||
last_status = "unknown"
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
poll = await client.get(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegations",
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
# Transient — keep polling. The platform IS holding the
|
||||
# delegation row; we just lost a network request.
|
||||
last_status = f"poll-error: {e}"
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
|
||||
if poll.status_code != 200:
|
||||
last_status = f"poll HTTP {poll.status_code}"
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
|
||||
try:
|
||||
rows = poll.json()
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
last_status = f"poll non-JSON: {e}"
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
|
||||
# /delegations returns a flat list of delegation events. Filter to
|
||||
# our delegation_id; pick the first terminal one. The list may
|
||||
# have multiple rows per delegation_id (one for the original
|
||||
# dispatch, one per status update); we want the latest terminal.
|
||||
if not isinstance(rows, list):
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
continue
|
||||
terminal = None
|
||||
for r in rows:
|
||||
if not isinstance(r, dict):
|
||||
continue
|
||||
if r.get("delegation_id") != delegation_id:
|
||||
continue
|
||||
status = (r.get("status") or "").lower()
|
||||
last_status = status
|
||||
if status in ("completed", "failed"):
|
||||
terminal = r
|
||||
break
|
||||
if terminal:
|
||||
if (terminal.get("status") or "").lower() == "completed":
|
||||
return terminal.get("response_preview") or ""
|
||||
err = (
|
||||
terminal.get("error_detail")
|
||||
or terminal.get("summary")
|
||||
or "delegation failed"
|
||||
)
|
||||
return f"{_A2A_ERROR_PREFIX}{err}"
|
||||
|
||||
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
|
||||
|
||||
# Budget exhausted — the platform's row is still in flight (or queued).
|
||||
# Surface as an error so the caller can decide to retry or fall back;
|
||||
# the platform DOES still have the durable row, so the work isn't
|
||||
# lost — it'll complete eventually and a future check_task_status
|
||||
# will surface the result.
|
||||
return (
|
||||
f"{_A2A_ERROR_PREFIX}polling timeout after {_SYNC_POLL_BUDGET_S}s "
|
||||
f"(delegation_id={delegation_id}, last_status={last_status}); "
|
||||
f"the platform is still working on it — call check_task_status('{delegation_id}') to retrieve later"
|
||||
)
|
||||
|
||||
|
||||
async def tool_delegate_task(
|
||||
workspace_id: str,
|
||||
task: str,
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Delegate a task to another workspace via A2A (synchronous — waits for response).
|
||||
|
||||
``source_workspace_id`` selects which registered workspace this
|
||||
delegation originates from — drives auth + the X-Workspace-ID source
|
||||
header so the platform's a2a_proxy logs the correct sender. Single-
|
||||
workspace operators leave it None and routing falls back to the
|
||||
module-level WORKSPACE_ID.
|
||||
"""
|
||||
if not workspace_id or not task:
|
||||
return "Error: workspace_id and task are required"
|
||||
|
||||
# Auto-route: if source not specified, look up which registered
|
||||
# workspace last saw this peer (populated by tool_list_peers). Falls
|
||||
# back to the legacy WORKSPACE_ID for single-workspace operators.
|
||||
src = source_workspace_id or _peer_to_source.get(workspace_id) or None
|
||||
|
||||
# Discover the target. discover_peer is the access-control gate +
|
||||
# name/status lookup. The peer's reported ``url`` field is NOT used
|
||||
# for routing — see send_a2a_message, which constructs the URL via
|
||||
# the platform's A2A proxy.
|
||||
peer = await discover_peer(workspace_id, source_workspace_id=src)
|
||||
if not peer:
|
||||
return f"Error: workspace {workspace_id} not found or not accessible (check access control)"
|
||||
|
||||
if (peer.get("status") or "").lower() == "offline":
|
||||
return f"Error: workspace {workspace_id} is offline"
|
||||
|
||||
# Lazy import: a2a_tools imports this module at top-level, so a
|
||||
# top-level import of report_activity from a2a_tools would create a
|
||||
# circular dependency at first-import time. Lazy resolution inside
|
||||
# the function body breaks the cycle without forcing a ground-up
|
||||
# restructure of the activity-reporting layer.
|
||||
from a2a_tools import report_activity
|
||||
|
||||
# Report delegation start — include the task text for traceability
|
||||
peer_name = peer.get("name") or _peer_names.get(workspace_id) or workspace_id[:8]
|
||||
_peer_names[workspace_id] = peer_name # cache for future use
|
||||
# Brief summary for canvas display — just the delegation target
|
||||
await report_activity("a2a_send", workspace_id, f"Delegating to {peer_name}", task_text=task)
|
||||
|
||||
# RFC #2829 PR-5: agent-side cutover. When DELEGATION_SYNC_VIA_INBOX=1,
|
||||
# use the platform's durable async delegation API (POST /delegate +
|
||||
# poll /delegations) instead of the proxy-blocked message/send path.
|
||||
# This sidesteps the 600s message/send timeout class that broke
|
||||
# iteration-14/90-style long-running delegations on 2026-05-05.
|
||||
#
|
||||
# Default off — staging-canary first, flip default after PR-2's
|
||||
# result-push flag (DELEGATION_RESULT_INBOX_PUSH) has been on for
|
||||
# ≥1 week without incident.
|
||||
if os.environ.get("DELEGATION_SYNC_VIA_INBOX") == "1":
|
||||
result = await _delegate_sync_via_polling(workspace_id, task, src or WORKSPACE_ID)
|
||||
else:
|
||||
# send_a2a_message routes through ${PLATFORM_URL}/workspaces/{id}/a2a
|
||||
# (the platform proxy) so the same code works for in-container and
|
||||
# external (standalone molecule-mcp) callers.
|
||||
result = await send_a2a_message(workspace_id, task, source_workspace_id=src)
|
||||
|
||||
# Detect delegation failures — wrap them clearly so the calling agent
|
||||
# can decide to retry, use another peer, or handle the task itself.
|
||||
is_error = result.startswith(_A2A_ERROR_PREFIX)
|
||||
# Strip the sentinel prefix so error_detail is the human-readable
|
||||
# cause directly. The Activity tab's red error chip surfaces this
|
||||
# without the user having to scroll into the raw response JSON.
|
||||
#
|
||||
# Cap at 4096 chars before sending — the platform's
|
||||
# activity_logs.error_detail column is unbounded TEXT and a
|
||||
# malicious or buggy peer could otherwise stream an arbitrarily
|
||||
# large error message into the caller's activity log. 4096 is
|
||||
# comfortably above any real exception traceback we've seen and
|
||||
# well below an obvious-DoS threshold.
|
||||
error_detail = result[len(_A2A_ERROR_PREFIX):].strip()[:4096] if is_error else ""
|
||||
await report_activity(
|
||||
"a2a_receive", workspace_id,
|
||||
f"{peer_name} responded ({len(result)} chars)" if not is_error else f"{peer_name} failed: {error_detail[:120]}",
|
||||
task_text=task, response_text=result,
|
||||
status="error" if is_error else "ok",
|
||||
error_detail=error_detail,
|
||||
)
|
||||
if is_error:
|
||||
return (
|
||||
f"DELEGATION FAILED to {peer_name}: {result}\n"
|
||||
f"You should either: (1) try a different peer, (2) handle this task yourself, "
|
||||
f"or (3) inform the user that {peer_name} is unavailable and provide your best answer."
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def tool_delegate_task_async(
|
||||
workspace_id: str,
|
||||
task: str,
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Delegate a task via the platform's async delegation API (fire-and-forget).
|
||||
|
||||
Uses POST /workspaces/:id/delegate which runs the A2A request in the background.
|
||||
Results are tracked in the platform DB and broadcast via WebSocket.
|
||||
Use check_task_status to poll for results.
|
||||
|
||||
``source_workspace_id`` selects the sending workspace (which one of
|
||||
this agent's registered workspaces gets logged as the originator);
|
||||
auto-routes via the peer→source cache when omitted.
|
||||
"""
|
||||
if not workspace_id or not task:
|
||||
return "Error: workspace_id and task are required"
|
||||
|
||||
src = source_workspace_id or _peer_to_source.get(workspace_id) or WORKSPACE_ID
|
||||
|
||||
# Idempotency key: SHA-256 of (source, target, task) so that a
|
||||
# restarted agent firing the same delegation gets the same key and
|
||||
# the platform returns the existing delegation_id instead of
|
||||
# creating a duplicate. Fixes #1456. Source is in the key so the
|
||||
# SAME task delegated from two different registered workspaces
|
||||
# produces two distinct delegations (the right behavior — one per
|
||||
# tenant audit trail).
|
||||
idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32]
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegate",
|
||||
json={"target_id": workspace_id, "task": task, "idempotency_key": idem_key},
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
if resp.status_code == 202:
|
||||
data = resp.json()
|
||||
return json.dumps({
|
||||
"delegation_id": data.get("delegation_id", ""),
|
||||
"workspace_id": workspace_id,
|
||||
"status": "delegated",
|
||||
"note": "Task delegated. The platform runs it in the background. Use check_task_status to poll for results.",
|
||||
})
|
||||
else:
|
||||
return f"Error: delegation failed with status {resp.status_code}: {resp.text[:200]}"
|
||||
except Exception as e:
|
||||
return f"Error: delegation failed — {e}"
|
||||
|
||||
|
||||
async def tool_check_task_status(
|
||||
workspace_id: str,
|
||||
task_id: str,
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Check delegations for this workspace via the platform API.
|
||||
|
||||
Args:
|
||||
workspace_id: Ignored (kept for backward compat). Checks
|
||||
``source_workspace_id``'s delegations (the workspace that
|
||||
FIRED the delegations), not the target's.
|
||||
task_id: Optional delegation_id to filter. If empty, returns all recent delegations.
|
||||
source_workspace_id: Which registered workspace's delegation log
|
||||
to query. Defaults to the module-level WORKSPACE_ID.
|
||||
"""
|
||||
src = source_workspace_id or WORKSPACE_ID
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/delegations",
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
return f"Error: failed to check delegations ({resp.status_code})"
|
||||
delegations = resp.json()
|
||||
if task_id:
|
||||
# Filter by delegation_id
|
||||
matching = [d for d in delegations if d.get("delegation_id") == task_id]
|
||||
if matching:
|
||||
return json.dumps(matching[0])
|
||||
return json.dumps({"status": "not_found", "delegation_id": task_id})
|
||||
# Return all recent delegations
|
||||
summary = []
|
||||
for d in delegations[:10]:
|
||||
summary.append({
|
||||
"delegation_id": d.get("delegation_id", ""),
|
||||
"target_id": d.get("target_id", ""),
|
||||
"status": d.get("status", ""),
|
||||
"summary": d.get("summary", ""),
|
||||
"response_preview": d.get("response_preview", ""),
|
||||
})
|
||||
return json.dumps({"delegations": summary, "count": len(delegations)})
|
||||
except Exception as e:
|
||||
return f"Error checking delegations: {e}"
|
||||
+44
-8
@@ -553,10 +553,26 @@ def _poll_once(
|
||||
# Imported lazily at use-site so a runtime that never sees an
|
||||
# upload-receive row never imports the module. Cheap on the hot
|
||||
# path because Python caches the import.
|
||||
from inbox_uploads import is_chat_upload_row, fetch_and_stage
|
||||
from inbox_uploads import is_chat_upload_row, BatchFetcher
|
||||
|
||||
new_count = 0
|
||||
last_id: str | None = None
|
||||
# ``batch_fetcher`` is lazy: a poll batch with no upload rows pays
|
||||
# zero overhead. Once the first upload row appears we open one
|
||||
# BatchFetcher and submit every subsequent upload row to its thread
|
||||
# pool; before processing the FIRST non-upload row we drain the
|
||||
# pool (wait_all) so the URI cache is hot when message rewriting
|
||||
# runs. Without the barrier, the chat message that references the
|
||||
# upload would arrive at the agent with the un-rewritten
|
||||
# platform-pending: URI.
|
||||
batch_fetcher: BatchFetcher | None = None
|
||||
|
||||
def _drain_uploads(bf: BatchFetcher | None) -> None:
|
||||
if bf is None:
|
||||
return
|
||||
bf.wait_all()
|
||||
bf.close()
|
||||
|
||||
for row in rows:
|
||||
if not isinstance(row, dict):
|
||||
continue
|
||||
@@ -570,14 +586,21 @@ def _poll_once(
|
||||
# message_from_activity. We DO advance the cursor past
|
||||
# this row so a permanent network outage on /content
|
||||
# doesn't stall the cursor and block real chat traffic.
|
||||
fetch_and_stage(
|
||||
row,
|
||||
platform_url=platform_url,
|
||||
workspace_id=workspace_id,
|
||||
headers=headers,
|
||||
)
|
||||
if batch_fetcher is None:
|
||||
batch_fetcher = BatchFetcher(
|
||||
platform_url=platform_url,
|
||||
workspace_id=workspace_id,
|
||||
headers=headers,
|
||||
)
|
||||
batch_fetcher.submit(row)
|
||||
last_id = str(row.get("id", "")) or last_id
|
||||
continue
|
||||
# Non-upload row: drain any pending uploads first so the URI
|
||||
# cache is populated before we run rewrite_request_body /
|
||||
# message_from_activity on a row that may reference one.
|
||||
if batch_fetcher is not None:
|
||||
_drain_uploads(batch_fetcher)
|
||||
batch_fetcher = None
|
||||
if _is_self_notify_row(row):
|
||||
# The workspace-server's `/notify` handler writes the agent's
|
||||
# own send_message_to_user POSTs to activity_logs with
|
||||
@@ -612,6 +635,13 @@ def _poll_once(
|
||||
last_id = message.activity_id
|
||||
new_count += 1
|
||||
|
||||
# Drain any uploads still in flight if the batch ended with upload
|
||||
# rows (no chat-message row to trigger the inline drain). Without
|
||||
# this, a future poll that picks up the chat-message row first
|
||||
# would race with the still-running fetches.
|
||||
if batch_fetcher is not None:
|
||||
_drain_uploads(batch_fetcher)
|
||||
|
||||
if last_id is not None:
|
||||
state.save_cursor(last_id, cursor_key)
|
||||
return new_count
|
||||
@@ -654,6 +684,7 @@ def start_poller_thread(
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
interval: float = POLL_INTERVAL_SECONDS,
|
||||
stop_event: threading.Event | None = None,
|
||||
) -> threading.Thread:
|
||||
"""Spawn the poller as a daemon thread. Returns the Thread handle.
|
||||
|
||||
@@ -665,13 +696,18 @@ def start_poller_thread(
|
||||
operator running ``ps -eL`` or eyeballing ``threading.enumerate()``
|
||||
can tell which thread is which without reverse-engineering it from
|
||||
crash tracebacks.
|
||||
|
||||
Pass ``stop_event`` to enable graceful shutdown — used by tests so
|
||||
the daemon thread doesn't outlive the test that started it and race
|
||||
with later tests' httpx patches. Production code passes None and
|
||||
relies on the daemon flag for process-exit cleanup.
|
||||
"""
|
||||
name = "molecule-mcp-inbox-poller"
|
||||
if workspace_id:
|
||||
name = f"{name}-{workspace_id[:8]}"
|
||||
t = threading.Thread(
|
||||
target=_poll_loop,
|
||||
args=(state, platform_url, workspace_id, interval),
|
||||
args=(state, platform_url, workspace_id, interval, stop_event),
|
||||
name=name,
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
+264
-15
@@ -37,6 +37,7 @@ read another tenant's bytes even if a token is misrouted.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
@@ -68,6 +69,24 @@ MAX_FILE_BYTES = 25 * 1024 * 1024
|
||||
# 10s default for /activity calls — both are user-perceived latency.
|
||||
DEFAULT_FETCH_TIMEOUT = 60.0
|
||||
|
||||
# Concurrency cap for ``BatchFetcher``. Four workers is enough headroom
|
||||
# for the realistic "user dragged 3-4 files into chat at once" case
|
||||
# while bounding the platform's per-workspace fan-out. The cap matters
|
||||
# because the platform's /content endpoint reads bytea from Postgres in
|
||||
# a single round-trip per request — N workers = N concurrent DB reads
|
||||
# of up to 25 MB each, so a higher cap could pressure platform memory
|
||||
# without much UX win (network bandwidth is the bottleneck once the
|
||||
# bytes are buffered).
|
||||
DEFAULT_BATCH_FETCH_WORKERS = 4
|
||||
|
||||
# Upper bound on how long ``BatchFetcher.wait_all`` blocks the inbox
|
||||
# poll loop before giving up on still-in-flight fetches. Aligned with
|
||||
# DEFAULT_FETCH_TIMEOUT so a single hung fetch can't stall the loop
|
||||
# longer than its own deadline. A timeout fires only if a worker thread
|
||||
# is stuck past the underlying httpx timeout — pathological case;
|
||||
# normal completion is bounded by per-fetch timeout × ceil(N/W).
|
||||
DEFAULT_BATCH_WAIT_TIMEOUT = DEFAULT_FETCH_TIMEOUT + 5.0
|
||||
|
||||
# Cap on the URI cache. A long-lived workspace handling thousands of
|
||||
# uploads shouldn't grow without bound; an LRU cap of 1024 keeps the
|
||||
# entries-needed-for-a-typical-conversation well within memory.
|
||||
@@ -275,6 +294,7 @@ def fetch_and_stage(
|
||||
workspace_id: str,
|
||||
headers: dict[str, str],
|
||||
timeout_secs: float = DEFAULT_FETCH_TIMEOUT,
|
||||
client: Any = None,
|
||||
) -> str | None:
|
||||
"""Fetch the row's bytes, stage them under chat-uploads, and ack.
|
||||
|
||||
@@ -289,6 +309,11 @@ def fetch_and_stage(
|
||||
On success, the URI cache is updated so a subsequent chat message
|
||||
referencing the same ``platform-pending:`` URI is rewritten before
|
||||
the agent sees it.
|
||||
|
||||
Pass ``client`` to reuse a shared ``httpx.Client`` for both GET and
|
||||
POST ack (saves one TLS handshake per row vs. constructing one
|
||||
per-call). ``BatchFetcher`` does this across an entire poll batch so
|
||||
N concurrent fetches share one connection pool.
|
||||
"""
|
||||
body = _request_body_dict(row)
|
||||
if body is None:
|
||||
@@ -317,25 +342,58 @@ def fetch_and_stage(
|
||||
if not isinstance(filename, str):
|
||||
filename = "file"
|
||||
|
||||
# Lazy httpx import: the standalone MCP path uses httpx; an in-
|
||||
# container caller that imports this module by accident shouldn't
|
||||
# explode at import time.
|
||||
try:
|
||||
import httpx # noqa: WPS433
|
||||
except ImportError:
|
||||
logger.error("inbox_uploads: httpx not installed; cannot fetch %s", file_id)
|
||||
return None
|
||||
# Caller-supplied client: reuse for both GET + POST ack. Otherwise
|
||||
# build a one-shot client and close it on the way out. Lazy httpx
|
||||
# import keeps the standalone MCP path's optional dep optional.
|
||||
own_client = client is None
|
||||
if own_client:
|
||||
try:
|
||||
import httpx # noqa: WPS433
|
||||
except ImportError:
|
||||
logger.error("inbox_uploads: httpx not installed; cannot fetch %s", file_id)
|
||||
return None
|
||||
client = httpx.Client(timeout=timeout_secs)
|
||||
|
||||
try:
|
||||
return _fetch_and_stage_with_client(
|
||||
client,
|
||||
platform_url=platform_url,
|
||||
workspace_id=workspace_id,
|
||||
headers=headers,
|
||||
file_id=file_id,
|
||||
pending_uri=pending_uri,
|
||||
filename=filename,
|
||||
body=body,
|
||||
)
|
||||
finally:
|
||||
if own_client:
|
||||
try:
|
||||
client.close()
|
||||
except Exception: # noqa: BLE001 — close should never crash the caller
|
||||
pass
|
||||
|
||||
|
||||
def _fetch_and_stage_with_client(
|
||||
client: Any,
|
||||
*,
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
headers: dict[str, str],
|
||||
file_id: str,
|
||||
pending_uri: str,
|
||||
filename: str,
|
||||
body: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""Inner body of fetch_and_stage. Always uses the supplied client for
|
||||
both GET and POST so the connection pool is shared across the call.
|
||||
"""
|
||||
content_url = f"{platform_url}/workspaces/{workspace_id}/pending-uploads/{file_id}/content"
|
||||
ack_url = f"{platform_url}/workspaces/{workspace_id}/pending-uploads/{file_id}/ack"
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=timeout_secs) as client:
|
||||
resp = client.get(content_url, headers=headers)
|
||||
resp = client.get(content_url, headers=headers)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"inbox_uploads: GET %s failed: %s", content_url, exc
|
||||
)
|
||||
logger.warning("inbox_uploads: GET %s failed: %s", content_url, exc)
|
||||
return None
|
||||
|
||||
if resp.status_code == 404:
|
||||
@@ -403,8 +461,7 @@ def fetch_and_stage(
|
||||
# back the on-disk file — the platform's sweep will clean up
|
||||
# eventually.
|
||||
try:
|
||||
with httpx.Client(timeout=timeout_secs) as client:
|
||||
ack_resp = client.post(ack_url, headers=headers)
|
||||
ack_resp = client.post(ack_url, headers=headers)
|
||||
if ack_resp.status_code >= 400:
|
||||
logger.warning(
|
||||
"inbox_uploads: ack %s returned %d: %s",
|
||||
@@ -418,6 +475,198 @@ def fetch_and_stage(
|
||||
return local_uri
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BatchFetcher — concurrent fetch across a single poll batch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BatchFetcher:
|
||||
"""Fetch + stage + ack a batch of upload-receive rows concurrently.
|
||||
|
||||
Why this exists: the inbox poll loop used to call ``fetch_and_stage``
|
||||
serially per row. With N upload rows in a batch (a user dragging
|
||||
multiple files into chat at once), the loop blocked for
|
||||
``N × per_fetch_latency`` before processing the chat message that
|
||||
referenced them — a 4-file upload at 5s each = 20s of stall
|
||||
before the agent saw the user's prompt. ``BatchFetcher`` runs the
|
||||
fetches on a small thread pool (default 4 workers) so the stall is
|
||||
bounded by ``ceil(N/W) × per_fetch_latency`` instead.
|
||||
|
||||
Connection reuse: one ``httpx.Client`` is shared across every fetch
|
||||
in the batch. httpx clients carry a connection pool, so a second
|
||||
fetch to the same platform host reuses the TCP+TLS handshake from
|
||||
the first — measurable win when fetches happen back-to-back.
|
||||
|
||||
Correctness invariant the caller MUST preserve: the inbox loop is
|
||||
expected to call ``wait_all()`` before processing the chat-message
|
||||
activity row that REFERENCES one of these uploads. Without the
|
||||
barrier, the URI cache is empty when ``rewrite_request_body`` runs
|
||||
and the agent sees the un-rewritten ``platform-pending:`` URI. The
|
||||
caller-side test ``test_poll_once_waits_for_uploads_before_messages``
|
||||
pins this end-to-end.
|
||||
|
||||
Use as a context manager so the executor + client are torn down
|
||||
even if the caller raises mid-batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
headers: dict[str, str],
|
||||
timeout_secs: float = DEFAULT_FETCH_TIMEOUT,
|
||||
max_workers: int = DEFAULT_BATCH_FETCH_WORKERS,
|
||||
client: Any = None,
|
||||
):
|
||||
self._platform_url = platform_url
|
||||
self._workspace_id = workspace_id
|
||||
self._headers = dict(headers) # copy so caller mutations don't leak in
|
||||
self._timeout_secs = timeout_secs
|
||||
|
||||
# Caller can inject a client (tests do this); production callers
|
||||
# let us build one. Track ownership so we only close ours.
|
||||
self._own_client = client is None
|
||||
if self._own_client:
|
||||
try:
|
||||
import httpx # noqa: WPS433
|
||||
except ImportError:
|
||||
# Match fetch_and_stage's behavior: log + degrade rather
|
||||
# than raising at construction time. submit() will then
|
||||
# return None for every row.
|
||||
logger.error("inbox_uploads: httpx not installed; BatchFetcher inert")
|
||||
self._client: Any = None
|
||||
else:
|
||||
self._client = httpx.Client(timeout=timeout_secs)
|
||||
else:
|
||||
self._client = client
|
||||
|
||||
self._executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=max_workers,
|
||||
thread_name_prefix="upload-fetch",
|
||||
)
|
||||
self._futures: list[concurrent.futures.Future[Any]] = []
|
||||
self._closed = False
|
||||
# Flipped to True by wait_all when the timeout fires; close()
|
||||
# reads this to decide between drain-and-wait vs cancel-queued.
|
||||
self._timed_out = False
|
||||
|
||||
def submit(self, row: dict[str, Any]) -> concurrent.futures.Future[Any] | None:
|
||||
"""Submit ``row`` for fetch + stage + ack. Non-blocking — the
|
||||
worker thread runs ``fetch_and_stage`` with the shared client.
|
||||
|
||||
Returns the Future so a caller that wants per-row outcome can
|
||||
await it; ``None`` if the BatchFetcher is in a degraded state
|
||||
(httpx missing).
|
||||
"""
|
||||
if self._closed:
|
||||
raise RuntimeError("BatchFetcher: submit after close")
|
||||
if self._client is None:
|
||||
return None
|
||||
fut = self._executor.submit(
|
||||
fetch_and_stage,
|
||||
row,
|
||||
platform_url=self._platform_url,
|
||||
workspace_id=self._workspace_id,
|
||||
headers=self._headers,
|
||||
timeout_secs=self._timeout_secs,
|
||||
client=self._client,
|
||||
)
|
||||
self._futures.append(fut)
|
||||
return fut
|
||||
|
||||
def wait_all(self, timeout: float | None = DEFAULT_BATCH_WAIT_TIMEOUT) -> None:
|
||||
"""Block until every submitted future completes (or times out).
|
||||
|
||||
Per-future exceptions are logged + swallowed — ``fetch_and_stage``
|
||||
already converts every error path to ``return None``, so a real
|
||||
exception propagating up to here is unexpected and we don't want
|
||||
one bad fetch to abort the whole batch.
|
||||
|
||||
Timeouts are also logged + swallowed AND record the timed-out
|
||||
futures on ``self._timed_out`` so ``close`` can cancel them
|
||||
without paying their full latency. Without this hand-off,
|
||||
``close()``'s ``shutdown(wait=True)`` would block on the leaked
|
||||
workers and undo the user-facing timeout — the inbox poll loop
|
||||
would stall indefinitely on a hung /content fetch.
|
||||
"""
|
||||
if not self._futures:
|
||||
return
|
||||
try:
|
||||
done, not_done = concurrent.futures.wait(
|
||||
self._futures,
|
||||
timeout=timeout,
|
||||
return_when=concurrent.futures.ALL_COMPLETED,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001 — concurrent.futures shouldn't raise here
|
||||
logger.warning("inbox_uploads: BatchFetcher.wait_all crashed: %s", exc)
|
||||
return
|
||||
for fut in done:
|
||||
exc = fut.exception()
|
||||
if exc is not None:
|
||||
logger.warning(
|
||||
"inbox_uploads: BatchFetcher worker raised: %s", exc
|
||||
)
|
||||
if not_done:
|
||||
logger.warning(
|
||||
"inbox_uploads: BatchFetcher.wait_all left %d in-flight after %ss timeout",
|
||||
len(not_done),
|
||||
timeout,
|
||||
)
|
||||
# Mark these futures so close() knows to cancel-not-wait. We
|
||||
# cancel queued-but-not-started ones immediately; futures
|
||||
# already running can't be cancelled (Python's threading
|
||||
# model), but close() will pass cancel_futures=True so any
|
||||
# remaining queued items don't run.
|
||||
for fut in not_done:
|
||||
fut.cancel()
|
||||
self._timed_out = True
|
||||
|
||||
def close(self) -> None:
|
||||
"""Tear down the executor + (if owned) the httpx client.
|
||||
|
||||
Idempotent. After close, ``submit`` raises and the BatchFetcher
|
||||
cannot be reused — construct a fresh one for the next poll.
|
||||
|
||||
If ``wait_all`` reported a timeout, shutdown skips the
|
||||
``wait=True`` drain and instead asks the executor to drop queued
|
||||
futures (``cancel_futures=True``). Currently-running workers
|
||||
can't be interrupted by Python's threading model, but the poll
|
||||
loop returns immediately rather than blocking on a hung fetch.
|
||||
"""
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
timed_out = getattr(self, "_timed_out", False)
|
||||
try:
|
||||
if timed_out:
|
||||
# cancel_futures landed in Python 3.9 — guarded for older
|
||||
# interpreters via a TypeError fallback. Drop queued
|
||||
# tasks; running ones will exit when their httpx call
|
||||
# eventually returns or the daemon thread dies.
|
||||
try:
|
||||
self._executor.shutdown(wait=False, cancel_futures=True)
|
||||
except TypeError:
|
||||
self._executor.shutdown(wait=False)
|
||||
else:
|
||||
# Healthy path: wait for in-flight work so we don't
|
||||
# interrupt a fetch mid-write.
|
||||
self._executor.shutdown(wait=True)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("inbox_uploads: executor shutdown error: %s", exc)
|
||||
if self._own_client and self._client is not None:
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("inbox_uploads: client close error: %s", exc)
|
||||
|
||||
def __enter__(self) -> "BatchFetcher":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URI rewrite for incoming chat messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -339,8 +339,8 @@ class TestToolDelegateTaskAutoRouting:
|
||||
seen_send_src["src"] = source_workspace_id
|
||||
return "ok"
|
||||
|
||||
with patch("a2a_tools.discover_peer", side_effect=fake_discover), \
|
||||
patch("a2a_tools.send_a2a_message", side_effect=fake_send), \
|
||||
with patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
await a2a_tools.tool_delegate_task(peer_id, "do thing")
|
||||
|
||||
@@ -367,8 +367,8 @@ class TestToolDelegateTaskAutoRouting:
|
||||
seen["send"] = source_workspace_id
|
||||
return "ok"
|
||||
|
||||
with patch("a2a_tools.discover_peer", side_effect=fake_discover), \
|
||||
patch("a2a_tools.send_a2a_message", side_effect=fake_send), \
|
||||
with patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
await a2a_tools.tool_delegate_task(
|
||||
peer_id, "do thing", source_workspace_id=ws_explicit,
|
||||
@@ -395,8 +395,8 @@ class TestToolDelegateTaskAutoRouting:
|
||||
seen["send"] = source_workspace_id
|
||||
return "ok"
|
||||
|
||||
with patch("a2a_tools.discover_peer", side_effect=fake_discover), \
|
||||
patch("a2a_tools.send_a2a_message", side_effect=fake_send), \
|
||||
with patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
await a2a_tools.tool_delegate_task(peer_id, "do thing")
|
||||
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
"""Drift gate + direct surface tests for ``a2a_tools_delegation`` (RFC #2873 iter 4b).
|
||||
|
||||
The full behavior matrix for the three delegation MCP tools lives in
|
||||
``test_a2a_tools_impl.py`` (TestToolDelegateTask + TestToolDelegateTaskAsync
|
||||
+ TestToolCheckTaskStatus). Those exercise call paths through the
|
||||
``a2a_tools_delegation.foo`` module (after the iter 4b retarget).
|
||||
|
||||
This file owns the post-split contract:
|
||||
|
||||
1. **Drift gate** — every previously-public symbol on ``a2a_tools``
|
||||
(``tool_delegate_task``, ``tool_delegate_task_async``,
|
||||
``tool_check_task_status``, ``_delegate_sync_via_polling``,
|
||||
``_SYNC_POLL_INTERVAL_S``, ``_SYNC_POLL_BUDGET_S``) is the EXACT
|
||||
same callable / value as the new module's public name. A wrapper
|
||||
that drifted would silently bypass tests targeting the wrapper.
|
||||
|
||||
2. **Smoke import** — both modules import in either order without
|
||||
raising (the lazy ``report_activity`` import inside
|
||||
``tool_delegate_task`` is the contract that prevents a circular
|
||||
import; this test pins it).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _require_workspace_id(monkeypatch):
|
||||
monkeypatch.setenv("WORKSPACE_ID", "00000000-0000-0000-0000-000000000000")
|
||||
monkeypatch.setenv("PLATFORM_URL", "http://test.invalid")
|
||||
yield
|
||||
|
||||
|
||||
# ============== Drift gate ==============
|
||||
|
||||
class TestBackCompatAliases:
|
||||
def test_tool_delegate_task_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_delegation
|
||||
assert a2a_tools.tool_delegate_task is a2a_tools_delegation.tool_delegate_task
|
||||
|
||||
def test_tool_delegate_task_async_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_delegation
|
||||
assert (
|
||||
a2a_tools.tool_delegate_task_async
|
||||
is a2a_tools_delegation.tool_delegate_task_async
|
||||
)
|
||||
|
||||
def test_tool_check_task_status_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_delegation
|
||||
assert (
|
||||
a2a_tools.tool_check_task_status
|
||||
is a2a_tools_delegation.tool_check_task_status
|
||||
)
|
||||
|
||||
def test_delegate_sync_via_polling_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_delegation
|
||||
assert (
|
||||
a2a_tools._delegate_sync_via_polling
|
||||
is a2a_tools_delegation._delegate_sync_via_polling
|
||||
)
|
||||
|
||||
def test_constants_match(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_delegation
|
||||
assert (
|
||||
a2a_tools._SYNC_POLL_INTERVAL_S
|
||||
== a2a_tools_delegation._SYNC_POLL_INTERVAL_S
|
||||
)
|
||||
assert (
|
||||
a2a_tools._SYNC_POLL_BUDGET_S
|
||||
== a2a_tools_delegation._SYNC_POLL_BUDGET_S
|
||||
)
|
||||
|
||||
|
||||
# ============== Smoke imports ==============
|
||||
|
||||
class TestImportContracts:
|
||||
def test_delegation_imports_without_a2a_tools_loaded(self, monkeypatch):
|
||||
"""``a2a_tools_delegation`` should NOT pull in ``a2a_tools`` at
|
||||
module-load time. The lazy ``from a2a_tools import report_activity``
|
||||
inside ``tool_delegate_task`` is the only legitimate hop.
|
||||
|
||||
Pin this so a future refactor that adds a top-level
|
||||
``from a2a_tools import …`` re-introduces the circular-import
|
||||
crash that motivated the lazy pattern.
|
||||
"""
|
||||
import sys
|
||||
# Drop both modules so we re-import in a controlled order
|
||||
for mod in ("a2a_tools", "a2a_tools_delegation"):
|
||||
sys.modules.pop(mod, None)
|
||||
|
||||
# Importing delegation first must succeed without a2a_tools
|
||||
# being loaded (because a2a_tools imports delegation, the
|
||||
# circular path ONLY closes if delegation top-level imports
|
||||
# something from a2a_tools).
|
||||
import a2a_tools_delegation # noqa: F401
|
||||
# If we got here, no circular import.
|
||||
assert "a2a_tools_delegation" in sys.modules
|
||||
|
||||
def test_a2a_tools_imports_via_delegation_re_export(self):
|
||||
"""The opposite direction: importing a2a_tools must trigger the
|
||||
delegation re-export so a2a_tools.tool_delegate_task resolves."""
|
||||
import a2a_tools
|
||||
assert hasattr(a2a_tools, "tool_delegate_task")
|
||||
assert hasattr(a2a_tools, "tool_delegate_task_async")
|
||||
assert hasattr(a2a_tools, "tool_check_task_status")
|
||||
|
||||
|
||||
# ============== Sync-poll budget env override ==============
|
||||
|
||||
class TestPollBudgetEnvOverride:
|
||||
def test_default_budget_when_env_unset(self):
|
||||
"""Module-level constant. Set DELEGATION_TIMEOUT before importing
|
||||
a2a_tools_delegation to override; default is 300.0."""
|
||||
# The constant is computed at module-load time. To verify the
|
||||
# override path we'd need to reload — skipped here because it's
|
||||
# tested at boot. This test pins the default for catch-the-eye
|
||||
# documentation.
|
||||
import a2a_tools_delegation
|
||||
# Whatever was set when the module first loaded — assert it's
|
||||
# numeric and >= the documented floor (180s healthsweep budget).
|
||||
assert isinstance(a2a_tools_delegation._SYNC_POLL_BUDGET_S, float)
|
||||
assert a2a_tools_delegation._SYNC_POLL_BUDGET_S >= 180.0
|
||||
@@ -226,16 +226,16 @@ class TestToolDelegateTask:
|
||||
|
||||
async def test_peer_not_found_returns_error(self):
|
||||
import a2a_tools
|
||||
with patch("a2a_tools.discover_peer", return_value=None):
|
||||
with patch("a2a_tools_delegation.discover_peer", return_value=None):
|
||||
result = await a2a_tools.tool_delegate_task("ws-missing", "task")
|
||||
assert "not found" in result or "Error" in result
|
||||
|
||||
async def test_offline_peer_returns_error(self):
|
||||
"""A peer with status=offline short-circuits before we hit the proxy."""
|
||||
import a2a_tools
|
||||
with patch("a2a_tools.discover_peer", return_value={"id": "ws-1", "status": "offline"}):
|
||||
with patch("a2a_tools_delegation.discover_peer", return_value={"id": "ws-1", "status": "offline"}):
|
||||
mc = _make_http_mock()
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_delegate_task("ws-1", "task")
|
||||
assert "offline" in result.lower()
|
||||
|
||||
@@ -261,8 +261,8 @@ class TestToolDelegateTask:
|
||||
captured["source"] = source_workspace_id
|
||||
return "ok"
|
||||
|
||||
with patch("a2a_tools.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools.send_a2a_message", side_effect=fake_send), \
|
||||
with patch("a2a_tools_delegation.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
await a2a_tools.tool_delegate_task(peer_id, "do thing")
|
||||
|
||||
@@ -274,8 +274,8 @@ class TestToolDelegateTask:
|
||||
import a2a_tools
|
||||
|
||||
peer = {"id": "ws-1", "url": "http://ws-1.svc/a2a", "name": "Worker"}
|
||||
with patch("a2a_tools.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools.send_a2a_message", return_value="Task completed!"), \
|
||||
with patch("a2a_tools_delegation.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", return_value="Task completed!"), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
result = await a2a_tools.tool_delegate_task("ws-1", "do something")
|
||||
|
||||
@@ -287,8 +287,8 @@ class TestToolDelegateTask:
|
||||
|
||||
peer = {"id": "ws-1", "url": "http://ws-1.svc/a2a", "name": "Worker"}
|
||||
error_msg = f"{a2a_tools._A2A_ERROR_PREFIX}Agent error: something bad"
|
||||
with patch("a2a_tools.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools.send_a2a_message", return_value=error_msg), \
|
||||
with patch("a2a_tools_delegation.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", return_value=error_msg), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
result = await a2a_tools.tool_delegate_task("ws-1", "do something")
|
||||
|
||||
@@ -302,8 +302,8 @@ class TestToolDelegateTask:
|
||||
# Pre-populate the cache
|
||||
a2a_tools._peer_names["ws-cached"] = "CachedName"
|
||||
peer = {"id": "ws-cached", "url": "http://ws-cached.svc/a2a"} # no 'name'
|
||||
with patch("a2a_tools.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools.send_a2a_message", return_value="done"), \
|
||||
with patch("a2a_tools_delegation.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", return_value="done"), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
result = await a2a_tools.tool_delegate_task("ws-cached", "task")
|
||||
|
||||
@@ -316,8 +316,8 @@ class TestToolDelegateTask:
|
||||
# Ensure not in cache
|
||||
a2a_tools._peer_names.pop("ws-nona000", None)
|
||||
peer = {"id": "ws-nona000", "url": "http://x.svc/a2a"} # no 'name'
|
||||
with patch("a2a_tools.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools.send_a2a_message", return_value="ok"), \
|
||||
with patch("a2a_tools_delegation.discover_peer", return_value=peer), \
|
||||
patch("a2a_tools_delegation.send_a2a_message", return_value="ok"), \
|
||||
patch("a2a_tools.report_activity", new=AsyncMock()):
|
||||
result = await a2a_tools.tool_delegate_task("ws-nona000", "task")
|
||||
|
||||
@@ -349,7 +349,7 @@ class TestToolDelegateTaskAsync:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_resp=_resp(202, {"delegation_id": "d-123", "status": "delegated"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_delegate_task_async("ws-1", "do task")
|
||||
|
||||
data = json.loads(result)
|
||||
@@ -362,7 +362,7 @@ class TestToolDelegateTaskAsync:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_resp=_resp(500, {"error": "internal"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_delegate_task_async("ws-1", "do task")
|
||||
|
||||
assert "Error" in result
|
||||
@@ -372,7 +372,7 @@ class TestToolDelegateTaskAsync:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_exc=httpx.ConnectError("connection refused"))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_delegate_task_async("ws-1", "do task")
|
||||
|
||||
assert "Error" in result or "failed" in result.lower()
|
||||
@@ -393,7 +393,7 @@ class TestToolCheckTaskStatus:
|
||||
{"delegation_id": "d-2", "target_id": "ws-u", "status": "pending", "summary": "waiting"},
|
||||
]
|
||||
mc = _make_http_mock(get_resp=_resp(200, delegations))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_check_task_status("ws-1", "")
|
||||
|
||||
data = json.loads(result)
|
||||
@@ -409,7 +409,7 @@ class TestToolCheckTaskStatus:
|
||||
{"delegation_id": "d-2", "status": "pending"},
|
||||
]
|
||||
mc = _make_http_mock(get_resp=_resp(200, delegations))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_check_task_status("ws-1", "d-1")
|
||||
|
||||
data = json.loads(result)
|
||||
@@ -421,7 +421,7 @@ class TestToolCheckTaskStatus:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(get_resp=_resp(200, []))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_check_task_status("ws-1", "d-missing")
|
||||
|
||||
data = json.loads(result)
|
||||
@@ -432,7 +432,7 @@ class TestToolCheckTaskStatus:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(get_resp=_resp(500, {"error": "db down"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
result = await a2a_tools.tool_check_task_status("ws-1", "d-1")
|
||||
|
||||
assert "Error" in result or "failed" in result.lower()
|
||||
|
||||
@@ -80,10 +80,10 @@ class TestFlagOffLegacyPath:
|
||||
async def fake_report_activity(*_a, **_kw):
|
||||
return None
|
||||
|
||||
with patch("a2a_tools.send_a2a_message", side_effect=fake_send), \
|
||||
patch("a2a_tools.discover_peer", side_effect=fake_discover), \
|
||||
with patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \
|
||||
patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \
|
||||
patch("a2a_tools.report_activity", side_effect=fake_report_activity), \
|
||||
patch("a2a_tools._delegate_sync_via_polling", new=AsyncMock()) as poll_mock:
|
||||
patch("a2a_tools_delegation._delegate_sync_via_polling", new=AsyncMock()) as poll_mock:
|
||||
result = await a2a_tools.tool_delegate_task(
|
||||
"ws-target", "task body", source_workspace_id="ws-self"
|
||||
)
|
||||
@@ -105,7 +105,7 @@ class TestFlagOnDispatchFailures:
|
||||
import a2a_tools
|
||||
mc = _make_client(post_exc=httpx.ConnectError("network down"))
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
@@ -119,7 +119,7 @@ class TestFlagOnDispatchFailures:
|
||||
import a2a_tools
|
||||
mc = _make_client(post_resp=_resp(403, {"error": "forbidden"}))
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
@@ -134,7 +134,7 @@ class TestFlagOnDispatchFailures:
|
||||
# 202 Accepted but no delegation_id field — defensive shape check.
|
||||
mc = _make_client(post_resp=_resp(202, {"status": "delegated"}))
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
@@ -168,7 +168,7 @@ class TestFlagOnPollingOutcomes:
|
||||
get_resps=[_resp(200, [completed_row])],
|
||||
)
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
@@ -196,7 +196,7 @@ class TestFlagOnPollingOutcomes:
|
||||
get_resps=[_resp(200, [failed_row])],
|
||||
)
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
@@ -234,7 +234,7 @@ class TestFlagOnPollingOutcomes:
|
||||
get_resps=get_seq,
|
||||
)
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
@@ -266,7 +266,7 @@ class TestFlagOnPollingOutcomes:
|
||||
get_resps=get_seq,
|
||||
)
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
@@ -304,7 +304,7 @@ class TestFlagOnPollingOutcomes:
|
||||
get_resps=[first_poll, second_poll],
|
||||
)
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
|
||||
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
|
||||
res = await a2a_tools._delegate_sync_via_polling(
|
||||
"ws-target", "task", "ws-self"
|
||||
)
|
||||
|
||||
@@ -555,16 +555,34 @@ def test_poll_once_self_notify_does_not_fire_notification(state: inbox.InboxStat
|
||||
def test_start_poller_thread_is_daemon(state: inbox.InboxState):
|
||||
"""Daemon flag is required so the poller dies with the parent
|
||||
process; a non-daemon poller would leak across `claude` restarts
|
||||
and write to a stale workspace."""
|
||||
and write to a stale workspace.
|
||||
|
||||
Stop_event is plumbed so the thread cleans up at the end of the
|
||||
test instead of leaking into later tests. Without cleanup, the
|
||||
daemon's ~10ms tick races with later tests that patch httpx.Client
|
||||
— the leaked thread sees their patched response and runs an
|
||||
unwanted iteration of _poll_once that double-counts mocked calls
|
||||
(caught when test_batch_fetcher_owns_client_when_not_supplied
|
||||
surfaced this on Python 3.11 CI but not 3.13 local).
|
||||
"""
|
||||
resp = _make_response(200, [])
|
||||
p, _ = _patch_httpx(resp)
|
||||
stop_event = threading.Event()
|
||||
with p, patch("platform_auth.auth_headers", return_value={}):
|
||||
# Use a very short interval so the loop body runs at least once
|
||||
# before we exit the test.
|
||||
t = inbox.start_poller_thread(state, "http://platform", "ws-1", interval=0.01)
|
||||
t = inbox.start_poller_thread(
|
||||
state, "http://platform", "ws-1", interval=0.01, stop_event=stop_event
|
||||
)
|
||||
time.sleep(0.05)
|
||||
assert t.daemon is True
|
||||
assert t.is_alive()
|
||||
assert t.daemon is True
|
||||
assert t.is_alive()
|
||||
# Signal shutdown + wait for the thread to actually exit before
|
||||
# we leave the test scope. Without this join, the leaked thread
|
||||
# races with later tests' httpx patches.
|
||||
stop_event.set()
|
||||
t.join(timeout=2.0)
|
||||
assert not t.is_alive(), "poller thread did not exit on stop_event"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -577,6 +595,219 @@ def test_default_cursor_path_uses_configs_dir(monkeypatch, tmp_path: Path):
|
||||
assert inbox.default_cursor_path() == tmp_path / ".mcp_inbox_cursor"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 5b — BatchFetcher integration with the poll loop
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# These tests pin the cross-module contract between inbox._poll_once and
|
||||
# inbox_uploads.BatchFetcher: chat_upload_receive rows must be submitted
|
||||
# to a single BatchFetcher AND drained (URI cache populated) before any
|
||||
# subsequent message row is processed. Without the drain, the
|
||||
# rewrite_request_body path inside message_from_activity surfaces the
|
||||
# un-rewritten ``platform-pending:`` URI to the agent.
|
||||
|
||||
|
||||
def _upload_row(act_id: str, file_id: str) -> dict:
|
||||
return {
|
||||
"id": act_id,
|
||||
"source_id": None,
|
||||
"method": "chat_upload_receive",
|
||||
"summary": f"chat_upload_receive: {file_id}.pdf",
|
||||
"request_body": {
|
||||
"file_id": file_id,
|
||||
"name": f"{file_id}.pdf",
|
||||
"uri": f"platform-pending:ws-1/{file_id}",
|
||||
"mimeType": "application/pdf",
|
||||
"size": 3,
|
||||
},
|
||||
"created_at": "2026-05-04T10:00:00Z",
|
||||
}
|
||||
|
||||
|
||||
def _message_row_referencing(act_id: str, file_id: str) -> dict:
|
||||
return {
|
||||
"id": act_id,
|
||||
"source_id": None,
|
||||
"method": "message/send",
|
||||
"summary": None,
|
||||
"request_body": {
|
||||
"params": {
|
||||
"message": {
|
||||
"parts": [
|
||||
{"kind": "text", "text": "have a look"},
|
||||
{
|
||||
"kind": "file",
|
||||
"file": {
|
||||
"uri": f"platform-pending:ws-1/{file_id}",
|
||||
"name": f"{file_id}.pdf",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"created_at": "2026-05-04T10:00:01Z",
|
||||
}
|
||||
|
||||
|
||||
def _patch_httpx_routing(activity_rows: list[dict], upload_bytes: bytes = b"PDF"):
|
||||
"""Replace ``httpx.Client`` so:
|
||||
|
||||
- GET /activity returns ``activity_rows``
|
||||
- GET /workspaces/.../content returns ``upload_bytes`` with content-type
|
||||
- POST /ack returns 200
|
||||
|
||||
Returns the patch context manager; tests use ``with p:``. Each new
|
||||
Client(...) gets a fresh MagicMock so the test can verify
|
||||
constructor-count expectations without pinning singletons.
|
||||
"""
|
||||
def _client_factory(*args, **kwargs):
|
||||
c = MagicMock()
|
||||
c.__enter__ = MagicMock(return_value=c)
|
||||
c.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
def _get(url, params=None, headers=None):
|
||||
if "/activity" in url:
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = activity_rows
|
||||
resp.text = ""
|
||||
return resp
|
||||
if "/pending-uploads/" in url and "/content" in url:
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.content = upload_bytes
|
||||
resp.headers = {"content-type": "application/pdf"}
|
||||
resp.text = ""
|
||||
return resp
|
||||
resp = MagicMock()
|
||||
resp.status_code = 404
|
||||
resp.text = ""
|
||||
return resp
|
||||
|
||||
def _post(url, headers=None):
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.text = ""
|
||||
return resp
|
||||
|
||||
c.get = MagicMock(side_effect=_get)
|
||||
c.post = MagicMock(side_effect=_post)
|
||||
c.close = MagicMock()
|
||||
return c
|
||||
|
||||
return patch("httpx.Client", side_effect=_client_factory)
|
||||
|
||||
|
||||
def test_poll_once_drains_uploads_before_processing_message_row(state: inbox.InboxState, tmp_path):
|
||||
"""The chat-message row's file.uri MUST be rewritten to the local
|
||||
workspace: URI by the time it lands in the InboxState queue. This
|
||||
requires BatchFetcher.wait_all() to run before message_from_activity
|
||||
on the second row.
|
||||
"""
|
||||
import inbox_uploads
|
||||
inbox_uploads.get_cache().clear()
|
||||
# Sandbox the on-disk staging dir so the test can't pollute the
|
||||
# workspace's real chat-uploads.
|
||||
real_dir = inbox_uploads.CHAT_UPLOAD_DIR
|
||||
inbox_uploads.CHAT_UPLOAD_DIR = str(tmp_path / "chat-uploads")
|
||||
try:
|
||||
rows = [
|
||||
_upload_row("act-1", "file-A"),
|
||||
_message_row_referencing("act-2", "file-A"),
|
||||
]
|
||||
state.save_cursor("act-old")
|
||||
with _patch_httpx_routing(rows, upload_bytes=b"PDF-bytes"):
|
||||
n = inbox._poll_once(state, "http://platform", "ws-1", {})
|
||||
finally:
|
||||
inbox_uploads.CHAT_UPLOAD_DIR = real_dir
|
||||
inbox_uploads.get_cache().clear()
|
||||
|
||||
assert n == 1, "exactly one message row should be enqueued (the upload row is a side-effect, not a message)"
|
||||
queued = state.peek(10)
|
||||
assert len(queued) == 1
|
||||
# The contract this test exists to pin: the platform-pending: URI
|
||||
# was rewritten to workspace: BEFORE the message landed in the
|
||||
# state queue. message_from_activity mutates row['request_body']
|
||||
# in-place, so the rewritten URI is observable on the row dict
|
||||
# we passed in.
|
||||
rewritten_part = rows[1]["request_body"]["params"]["message"]["parts"][1]
|
||||
assert rewritten_part["file"]["uri"].startswith("workspace:"), (
|
||||
f"upload barrier broken: file.uri = {rewritten_part['file']['uri']!r}; "
|
||||
"rewrite_request_body ran before BatchFetcher.wait_all populated the cache"
|
||||
)
|
||||
# Cursor advanced past BOTH rows — upload-receive (act-1) is
|
||||
# acknowledged via the inbox cursor regardless of fetch outcome.
|
||||
assert state.load_cursor() == "act-2"
|
||||
|
||||
|
||||
def test_poll_once_with_only_upload_rows_drains_at_loop_end(state: inbox.InboxState, tmp_path):
|
||||
"""End-of-batch drain: a poll that contains ONLY upload rows (no
|
||||
chat-message row to trigger the inline drain) must still drain the
|
||||
BatchFetcher before _poll_once returns. Otherwise a future poll
|
||||
that picks up the corresponding chat-message row would race with
|
||||
in-flight fetches from the previous batch.
|
||||
"""
|
||||
import inbox_uploads
|
||||
inbox_uploads.get_cache().clear()
|
||||
real_dir = inbox_uploads.CHAT_UPLOAD_DIR
|
||||
inbox_uploads.CHAT_UPLOAD_DIR = str(tmp_path / "chat-uploads")
|
||||
try:
|
||||
rows = [_upload_row("act-1", "file-A"), _upload_row("act-2", "file-B")]
|
||||
state.save_cursor("act-old")
|
||||
with _patch_httpx_routing(rows, upload_bytes=b"PDF"):
|
||||
n = inbox._poll_once(state, "http://platform", "ws-1", {})
|
||||
# By the time _poll_once returned, the URI cache must be hot
|
||||
# for both file_ids — proves the end-of-loop drain ran.
|
||||
assert inbox_uploads.get_cache().get("platform-pending:ws-1/file-A") is not None
|
||||
assert inbox_uploads.get_cache().get("platform-pending:ws-1/file-B") is not None
|
||||
finally:
|
||||
inbox_uploads.CHAT_UPLOAD_DIR = real_dir
|
||||
inbox_uploads.get_cache().clear()
|
||||
# Upload rows are NOT message rows; queue stays empty.
|
||||
assert n == 0
|
||||
# Cursor advances past both upload rows.
|
||||
assert state.load_cursor() == "act-2"
|
||||
|
||||
|
||||
def test_poll_once_no_uploads_does_not_construct_batch_fetcher(state: inbox.InboxState):
|
||||
"""A batch with no upload-receive rows must not pay the BatchFetcher
|
||||
construction cost — the executor + httpx client allocation is
|
||||
deferred until the first upload row appears.
|
||||
"""
|
||||
import inbox_uploads
|
||||
|
||||
constructed: list[Any] = []
|
||||
|
||||
def _patched_init(self, **kwargs):
|
||||
constructed.append(kwargs)
|
||||
# Don't actually run __init__; we never hit submit/wait_all.
|
||||
self._closed = False
|
||||
self._futures = []
|
||||
self._executor = MagicMock()
|
||||
self._client = MagicMock()
|
||||
self._own_client = False
|
||||
|
||||
rows = [
|
||||
{
|
||||
"id": "act-1",
|
||||
"source_id": None,
|
||||
"method": "message/send",
|
||||
"summary": None,
|
||||
"request_body": {"parts": [{"type": "text", "text": "hi"}]},
|
||||
"created_at": "2026-04-30T22:00:00Z",
|
||||
},
|
||||
]
|
||||
state.save_cursor("act-old")
|
||||
resp = _make_response(200, rows)
|
||||
p, _ = _patch_httpx(resp)
|
||||
with patch.object(inbox_uploads.BatchFetcher, "__init__", _patched_init), p:
|
||||
n = inbox._poll_once(state, "http://platform", "ws-1", {})
|
||||
|
||||
assert n == 1
|
||||
assert constructed == [], "BatchFetcher must not be constructed when no upload rows are present"
|
||||
|
||||
|
||||
def test_default_cursor_path_falls_back_to_default(tmp_path, monkeypatch):
|
||||
"""When CONFIGS_DIR is unset, the cursor path resolves through
|
||||
configs_dir.resolve() — /configs in-container, ~/.molecule-workspace
|
||||
|
||||
@@ -695,3 +695,426 @@ def test_rewrite_request_body_handles_non_list_parts():
|
||||
def test_rewrite_request_body_handles_non_dict_file():
|
||||
body = {"parts": [{"kind": "file", "file": "not a dict"}]}
|
||||
inbox_uploads.rewrite_request_body(body) # must not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fetch_and_stage with shared client — Phase 5b client-reuse contract
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# When a caller passes ``client=`` to fetch_and_stage, that client must be
|
||||
# used for BOTH the GET /content and the POST /ack — no fresh
|
||||
# ``httpx.Client(...)`` constructions should happen. The pre-Phase-5b
|
||||
# implementation made one new client for GET and another for ack; the new
|
||||
# shape lets BatchFetcher share one connection pool across an entire batch.
|
||||
|
||||
|
||||
def test_fetch_and_stage_with_supplied_client_does_not_construct_new_client(monkeypatch):
|
||||
row = _row(uri="platform-pending:ws-1/file-1")
|
||||
get_resp = _make_resp(200, content=b"PDF", content_type="application/pdf")
|
||||
ack_resp = _make_resp(200)
|
||||
supplied = MagicMock()
|
||||
supplied.get = MagicMock(return_value=get_resp)
|
||||
supplied.post = MagicMock(return_value=ack_resp)
|
||||
# Sentinel: any code path that constructs httpx.Client when one was
|
||||
# already supplied is a regression — count constructions.
|
||||
constructed: list[Any] = []
|
||||
|
||||
class _ShouldNotBeCalled:
|
||||
def __init__(self, *a, **kw):
|
||||
constructed.append((a, kw))
|
||||
|
||||
monkeypatch.setattr("httpx.Client", _ShouldNotBeCalled)
|
||||
|
||||
local_uri = inbox_uploads.fetch_and_stage(
|
||||
row,
|
||||
platform_url="http://plat",
|
||||
workspace_id="ws-1",
|
||||
headers={"Authorization": "Bearer t"},
|
||||
client=supplied,
|
||||
)
|
||||
assert local_uri is not None
|
||||
assert constructed == [], "supplied client must be reused; no new Client should be constructed"
|
||||
# GET + POST ack both went through the supplied client.
|
||||
supplied.get.assert_called_once()
|
||||
supplied.post.assert_called_once()
|
||||
# Caller-owned client must NOT be closed by fetch_and_stage; the
|
||||
# batch fetcher (or test) closes it once the whole batch is done.
|
||||
supplied.close.assert_not_called()
|
||||
|
||||
|
||||
def test_fetch_and_stage_without_supplied_client_constructs_and_closes_one(monkeypatch):
|
||||
row = _row(uri="platform-pending:ws-1/file-1")
|
||||
get_resp = _make_resp(200, content=b"PDF", content_type="application/pdf")
|
||||
ack_resp = _make_resp(200)
|
||||
built: list[MagicMock] = []
|
||||
|
||||
def _factory(*args, **kwargs):
|
||||
c = MagicMock()
|
||||
c.get = MagicMock(return_value=get_resp)
|
||||
c.post = MagicMock(return_value=ack_resp)
|
||||
built.append(c)
|
||||
return c
|
||||
|
||||
monkeypatch.setattr("httpx.Client", _factory)
|
||||
|
||||
local_uri = inbox_uploads.fetch_and_stage(
|
||||
row, platform_url="http://plat", workspace_id="ws-1", headers={}
|
||||
)
|
||||
assert local_uri is not None
|
||||
# Pre-Phase-5b built TWO clients (one for GET, one for ack); now exactly one.
|
||||
assert len(built) == 1, f"expected 1 httpx.Client construction, got {len(built)}"
|
||||
# Same client must serve BOTH calls.
|
||||
built[0].get.assert_called_once()
|
||||
built[0].post.assert_called_once()
|
||||
# Owned client must be closed by fetch_and_stage on the way out.
|
||||
built[0].close.assert_called_once()
|
||||
|
||||
|
||||
def test_fetch_and_stage_with_supplied_client_does_not_close_caller_client():
|
||||
# Even on failure the supplied client must not be closed — the
|
||||
# BatchFetcher owns the lifecycle for the whole batch.
|
||||
row = _row(uri="platform-pending:ws-1/file-1")
|
||||
supplied = MagicMock()
|
||||
supplied.get = MagicMock(side_effect=RuntimeError("network down"))
|
||||
supplied.post = MagicMock() # should not be reached on GET failure
|
||||
inbox_uploads.fetch_and_stage(
|
||||
row,
|
||||
platform_url="http://plat",
|
||||
workspace_id="ws-1",
|
||||
headers={},
|
||||
client=supplied,
|
||||
)
|
||||
supplied.close.assert_not_called()
|
||||
supplied.post.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BatchFetcher — concurrent fetch + URI cache barrier
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _row_with_id(act_id: str, file_id: str) -> dict:
|
||||
"""Helper: an upload-receive row with a distinct activity id + file id."""
|
||||
return {
|
||||
"id": act_id,
|
||||
"method": "chat_upload_receive",
|
||||
"request_body": {
|
||||
"file_id": file_id,
|
||||
"name": f"{file_id}.pdf",
|
||||
"uri": f"platform-pending:ws-1/{file_id}",
|
||||
"mimeType": "application/pdf",
|
||||
"size": 1,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _stub_client_for_batch(get_responses: dict[str, MagicMock]) -> MagicMock:
|
||||
"""Build one MagicMock client that returns per-file_id responses
|
||||
based on the file_id segment of the URL.
|
||||
"""
|
||||
client = MagicMock()
|
||||
|
||||
def _get(url: str, headers: dict[str, str] | None = None) -> MagicMock:
|
||||
for fid, resp in get_responses.items():
|
||||
if f"/pending-uploads/{fid}/content" in url:
|
||||
return resp
|
||||
return _make_resp(404)
|
||||
|
||||
def _post(url: str, headers: dict[str, str] | None = None) -> MagicMock:
|
||||
return _make_resp(200)
|
||||
|
||||
client.get = MagicMock(side_effect=_get)
|
||||
client.post = MagicMock(side_effect=_post)
|
||||
return client
|
||||
|
||||
|
||||
def test_batch_fetcher_runs_submitted_rows_concurrently():
|
||||
# Three rows whose .get() blocks for ~120ms each. With 4 workers the
|
||||
# batch should complete in ~120ms (parallel), not ~360ms (serial).
|
||||
# The 250ms ceiling accommodates CI scheduler jitter while still
|
||||
# discriminating concurrent (~120ms) from serial (~360ms).
|
||||
import time
|
||||
|
||||
barrier_start = [0.0]
|
||||
|
||||
def _slow_get(url: str, headers: dict[str, str] | None = None) -> MagicMock:
|
||||
time.sleep(0.12)
|
||||
for fid in ("a", "b", "c"):
|
||||
if f"/pending-uploads/{fid}/content" in url:
|
||||
return _make_resp(200, content=b"X", content_type="text/plain")
|
||||
return _make_resp(404)
|
||||
|
||||
client = MagicMock()
|
||||
client.get = MagicMock(side_effect=_slow_get)
|
||||
client.post = MagicMock(return_value=_make_resp(200))
|
||||
|
||||
bf = inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat",
|
||||
workspace_id="ws-1",
|
||||
headers={},
|
||||
client=client,
|
||||
max_workers=4,
|
||||
)
|
||||
barrier_start[0] = time.time()
|
||||
for fid in ("a", "b", "c"):
|
||||
bf.submit(_row_with_id(f"act-{fid}", fid))
|
||||
bf.wait_all()
|
||||
elapsed = time.time() - barrier_start[0]
|
||||
bf.close()
|
||||
|
||||
assert elapsed < 0.25, (
|
||||
f"3 rows × 120ms with 4 workers should finish in <250ms; got {elapsed:.3f}s "
|
||||
"(suggests serial execution — Phase 5b regression)"
|
||||
)
|
||||
assert client.get.call_count == 3
|
||||
assert client.post.call_count == 3
|
||||
|
||||
|
||||
def test_batch_fetcher_wait_all_blocks_until_uri_cache_populated():
|
||||
"""Pin the correctness invariant: when wait_all returns, the URI
|
||||
cache is hot for every submitted row. Without this barrier the
|
||||
inbox loop would process the chat-message row before its uploads
|
||||
were staged, and rewrite_request_body would surface the un-rewritten
|
||||
platform-pending: URI to the agent.
|
||||
"""
|
||||
import time
|
||||
|
||||
def _slow_get(url: str, headers: dict[str, str] | None = None) -> MagicMock:
|
||||
time.sleep(0.05)
|
||||
return _make_resp(200, content=b"data", content_type="text/plain")
|
||||
|
||||
client = MagicMock()
|
||||
client.get = MagicMock(side_effect=_slow_get)
|
||||
client.post = MagicMock(return_value=_make_resp(200))
|
||||
|
||||
inbox_uploads.get_cache().clear()
|
||||
with inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}, client=client
|
||||
) as bf:
|
||||
bf.submit(_row_with_id("act-a", "a"))
|
||||
bf.submit(_row_with_id("act-b", "b"))
|
||||
bf.wait_all()
|
||||
# Cache must be hot for BOTH rows by the time wait_all returns.
|
||||
assert inbox_uploads.get_cache().get("platform-pending:ws-1/a") is not None
|
||||
assert inbox_uploads.get_cache().get("platform-pending:ws-1/b") is not None
|
||||
|
||||
|
||||
def test_batch_fetcher_isolates_per_row_failure():
|
||||
"""One failing fetch must not abort siblings. Sibling rows complete,
|
||||
URI cache populates for them; the bad row's cache entry stays absent.
|
||||
"""
|
||||
def _get(url: str, headers: dict[str, str] | None = None) -> MagicMock:
|
||||
if "/pending-uploads/bad/content" in url:
|
||||
return _make_resp(500, text="upstream broken")
|
||||
return _make_resp(200, content=b"ok", content_type="text/plain")
|
||||
|
||||
client = MagicMock()
|
||||
client.get = MagicMock(side_effect=_get)
|
||||
client.post = MagicMock(return_value=_make_resp(200))
|
||||
|
||||
inbox_uploads.get_cache().clear()
|
||||
with inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}, client=client
|
||||
) as bf:
|
||||
bf.submit(_row_with_id("act-1", "good1"))
|
||||
bf.submit(_row_with_id("act-2", "bad"))
|
||||
bf.submit(_row_with_id("act-3", "good2"))
|
||||
bf.wait_all()
|
||||
|
||||
cache = inbox_uploads.get_cache()
|
||||
assert cache.get("platform-pending:ws-1/good1") is not None
|
||||
assert cache.get("platform-pending:ws-1/good2") is not None
|
||||
assert cache.get("platform-pending:ws-1/bad") is None
|
||||
|
||||
|
||||
def test_batch_fetcher_reuses_one_client_across_all_submits():
|
||||
"""Every row in the batch must share the same client instance. This
|
||||
is the connection-pool-reuse leg of the perf win: a second fetch
|
||||
to the same host reuses the TCP+TLS handshake from the first.
|
||||
"""
|
||||
client = MagicMock()
|
||||
client.get = MagicMock(return_value=_make_resp(200, content=b"x", content_type="text/plain"))
|
||||
client.post = MagicMock(return_value=_make_resp(200))
|
||||
|
||||
with inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}, client=client
|
||||
) as bf:
|
||||
for fid in ("a", "b", "c"):
|
||||
bf.submit(_row_with_id(f"act-{fid}", fid))
|
||||
bf.wait_all()
|
||||
|
||||
# 3 GETs + 3 POST acks all on the same client — no per-row Client
|
||||
# construction.
|
||||
assert client.get.call_count == 3
|
||||
assert client.post.call_count == 3
|
||||
|
||||
|
||||
def test_batch_fetcher_close_idempotent():
|
||||
client = MagicMock()
|
||||
bf = inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}, client=client
|
||||
)
|
||||
bf.close()
|
||||
bf.close() # second call must not raise
|
||||
|
||||
|
||||
def test_batch_fetcher_submit_after_close_raises():
|
||||
client = MagicMock()
|
||||
bf = inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}, client=client
|
||||
)
|
||||
bf.close()
|
||||
with pytest.raises(RuntimeError, match="submit after close"):
|
||||
bf.submit(_row_with_id("act-x", "x"))
|
||||
|
||||
|
||||
def test_batch_fetcher_owns_client_when_not_supplied(monkeypatch):
|
||||
built: list[MagicMock] = []
|
||||
|
||||
def _factory(*args, **kwargs):
|
||||
c = MagicMock()
|
||||
c.get = MagicMock(return_value=_make_resp(200, content=b"x", content_type="text/plain"))
|
||||
c.post = MagicMock(return_value=_make_resp(200))
|
||||
built.append(c)
|
||||
return c
|
||||
|
||||
monkeypatch.setattr("httpx.Client", _factory)
|
||||
|
||||
bf = inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}
|
||||
)
|
||||
bf.submit(_row_with_id("act-a", "a"))
|
||||
bf.wait_all()
|
||||
bf.close()
|
||||
|
||||
assert len(built) == 1, "expected one owned client per BatchFetcher"
|
||||
built[0].close.assert_called_once()
|
||||
|
||||
|
||||
def test_batch_fetcher_does_not_close_supplied_client():
|
||||
client = MagicMock()
|
||||
client.get = MagicMock(return_value=_make_resp(200, content=b"x", content_type="text/plain"))
|
||||
client.post = MagicMock(return_value=_make_resp(200))
|
||||
with inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}, client=client
|
||||
) as bf:
|
||||
bf.submit(_row_with_id("act-a", "a"))
|
||||
bf.wait_all()
|
||||
# Supplied client survives the BatchFetcher's close — caller's lifecycle.
|
||||
client.close.assert_not_called()
|
||||
|
||||
|
||||
def test_batch_fetcher_wait_all_no_op_on_empty_batch():
|
||||
client = MagicMock()
|
||||
with inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}, client=client
|
||||
) as bf:
|
||||
bf.wait_all() # nothing submitted; must not block, must not raise
|
||||
client.get.assert_not_called()
|
||||
client.post.assert_not_called()
|
||||
|
||||
|
||||
def test_batch_fetcher_httpx_missing_makes_submit_a_noop(monkeypatch):
|
||||
# No client supplied + httpx import fails → BatchFetcher degrades
|
||||
# gracefully: submit() returns None and the row is silently skipped.
|
||||
import sys
|
||||
|
||||
real_httpx = sys.modules.pop("httpx", None)
|
||||
monkeypatch.setitem(sys.modules, "httpx", None)
|
||||
try:
|
||||
bf = inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}
|
||||
)
|
||||
result = bf.submit(_row_with_id("act-a", "a"))
|
||||
bf.wait_all()
|
||||
bf.close()
|
||||
finally:
|
||||
if real_httpx is not None:
|
||||
sys.modules["httpx"] = real_httpx
|
||||
else:
|
||||
sys.modules.pop("httpx", None)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_batch_fetcher_close_after_timeout_does_not_block_on_running_workers():
|
||||
"""The deadline contract: when wait_all times out, close() must NOT
|
||||
block waiting for the leaked worker threads. Otherwise the inbox
|
||||
poll loop stalls indefinitely on a hung /content fetch — undoing
|
||||
the user-facing timeout.
|
||||
|
||||
Strategy: build a client whose .get() blocks on a threading.Event
|
||||
that the test never sets. Submit a row, wait_all with a tiny
|
||||
timeout, then time close(). If close() drained-and-waited it would
|
||||
block until we set the event (i.e., forever in this test).
|
||||
"""
|
||||
import threading
|
||||
import time
|
||||
|
||||
blocker = threading.Event() # never set — workers stay running
|
||||
|
||||
def _hang_get(url, headers=None):
|
||||
# Wait at most ~5s so a buggy implementation eventually unblocks
|
||||
# the test instead of timing out the whole pytest run, but
|
||||
# nothing legitimate should reach this fallback.
|
||||
blocker.wait(timeout=5.0)
|
||||
return _make_resp(200, content=b"x", content_type="text/plain")
|
||||
|
||||
client = MagicMock()
|
||||
client.get = MagicMock(side_effect=_hang_get)
|
||||
client.post = MagicMock(return_value=_make_resp(200))
|
||||
|
||||
bf = inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat",
|
||||
workspace_id="ws-1",
|
||||
headers={},
|
||||
client=client,
|
||||
max_workers=1, # serialize so submitting 1 keeps the worker busy
|
||||
)
|
||||
bf.submit(_row_with_id("act-a", "a"))
|
||||
# Tiny timeout — wait_all must report the future as not_done.
|
||||
bf.wait_all(timeout=0.05)
|
||||
t0 = time.time()
|
||||
bf.close()
|
||||
elapsed = time.time() - t0
|
||||
# Unblock the lingering worker so it doesn't pollute later tests.
|
||||
blocker.set()
|
||||
|
||||
# Without the cancel-on-timeout fix, close() would block until
|
||||
# blocker.set() — i.e., the full ~5s. With the fix it returns
|
||||
# immediately because shutdown(wait=False) doesn't drain.
|
||||
assert elapsed < 1.0, (
|
||||
f"close() blocked for {elapsed:.2f}s after wait_all timeout — "
|
||||
"cancel-on-timeout regression: close() is draining instead of bailing"
|
||||
)
|
||||
|
||||
|
||||
def test_batch_fetcher_close_without_timeout_still_drains():
|
||||
"""Negative leg of the timeout contract: when wait_all completes
|
||||
cleanly (no timeout), close() must KEEP its drain-and-wait
|
||||
behavior so a still-queued ack POST isn't dropped mid-write.
|
||||
"""
|
||||
import time
|
||||
|
||||
def _slow_get(url, headers=None):
|
||||
time.sleep(0.05)
|
||||
return _make_resp(200, content=b"x", content_type="text/plain")
|
||||
|
||||
client = MagicMock()
|
||||
client.get = MagicMock(side_effect=_slow_get)
|
||||
client.post = MagicMock(return_value=_make_resp(200))
|
||||
|
||||
bf = inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat",
|
||||
workspace_id="ws-1",
|
||||
headers={},
|
||||
client=client,
|
||||
max_workers=2,
|
||||
)
|
||||
bf.submit(_row_with_id("act-a", "a"))
|
||||
bf.submit(_row_with_id("act-b", "b"))
|
||||
bf.wait_all() # generous default timeout — should not fire
|
||||
bf.close()
|
||||
|
||||
# All 2 GETs + 2 ACK POSTs ran to completion via drain-and-wait.
|
||||
assert client.get.call_count == 2
|
||||
assert client.post.call_count == 2
|
||||
|
||||
Reference in New Issue
Block a user