Compare commits
49 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 226e57a942 | |||
| e9eb3868d5 | |||
| cb70d3d437 | |||
| a1d202723d | |||
| fc30b5c9de | |||
| ef67dc513e | |||
| 23d3f057d3 | |||
| 8ca027ddf3 | |||
| 46a4ef83bb | |||
| a6afc18de5 | |||
| 423d58d42c | |||
| 9386f1d399 | |||
| a766e5ce48 | |||
| 5ad2669f88 | |||
| 0ca4e431c1 | |||
| 184ce7ae4e | |||
| 2bf6a7005f | |||
| 16ead69641 | |||
| 60afcd43c9 | |||
| ff75aeb43e | |||
| 81cf0cbf98 | |||
| 412dec0d87 | |||
| 9a53529047 | |||
| 39931acd9c | |||
| 6f19b88fa7 | |||
| 83454e5efd | |||
| 575f893f4e | |||
| 4cac4e7710 | |||
| 8254bedf30 | |||
| ec72f199e6 | |||
| ae22a55675 | |||
| 08648bf4b1 | |||
| eec4ea2e7d | |||
| 6201d12533 | |||
| 81e83c05b7 | |||
| 5b5eacbb29 | |||
| c8fca1467e | |||
| 7c8b81c6eb | |||
| fc1c45789e | |||
| e3a18ed8e8 | |||
| 9f551319d2 | |||
| 1052f8bdb0 | |||
| 30fb507165 | |||
| 5334d60de4 | |||
| d6c0227e3f | |||
| 0f25f6de97 | |||
| b89a49ec93 | |||
| 3d0a7c381b | |||
| 210a26d31a |
@@ -172,6 +172,9 @@ jobs:
|
||||
- name: Run poll-mode + since_id cursor E2E (#2339)
|
||||
if: needs.detect-changes.outputs.api == 'true'
|
||||
run: bash tests/e2e/test_poll_mode_e2e.sh
|
||||
- name: Run poll-mode chat upload E2E (RFC #2891)
|
||||
if: needs.detect-changes.outputs.api == 'true'
|
||||
run: bash tests/e2e/test_poll_mode_chat_upload_e2e.sh
|
||||
- name: Dump platform log on failure
|
||||
if: failure() && needs.detect-changes.outputs.api == 'true'
|
||||
run: cat workspace-server/platform.log || true
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
// quick bounce between signup and either Checkout or the tenant UI.
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
import { fetchSession, redirectToLogin, type Session } from "@/lib/auth";
|
||||
import { fetchSession, redirectToLogin, signOut, type Session } from "@/lib/auth";
|
||||
import { PLATFORM_URL } from "@/lib/api";
|
||||
import { formatCredits, pillTone, bannerKind } from "@/lib/credits";
|
||||
import { TermsGate } from "@/components/TermsGate";
|
||||
@@ -129,7 +129,7 @@ export default function OrgsPage() {
|
||||
return <EmptyState banner={justCheckedOut ? <CheckoutBanner /> : null} />;
|
||||
}
|
||||
return (
|
||||
<Shell>
|
||||
<Shell session={session}>
|
||||
{justCheckedOut && <CheckoutBanner />}
|
||||
<ul className="space-y-3">
|
||||
{orgs.map((o) => (
|
||||
@@ -160,11 +160,21 @@ function CheckoutBanner() {
|
||||
);
|
||||
}
|
||||
|
||||
function Shell({ children }: { children: React.ReactNode }) {
|
||||
function Shell({
|
||||
children,
|
||||
session,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
// Optional: when present, the header renders the signed-in email +
|
||||
// a Sign-out button. The empty-state Shell call doesn't have a
|
||||
// session in scope, so accept null and skip the header chrome there.
|
||||
session?: Session | null;
|
||||
}) {
|
||||
return (
|
||||
<main className="min-h-screen bg-surface text-ink">
|
||||
<TermsGate>
|
||||
<div className="mx-auto max-w-2xl px-6 pt-20 pb-12">
|
||||
{session ? <AccountBar session={session} /> : null}
|
||||
<h1 className="text-3xl font-bold text-ink">Your organizations</h1>
|
||||
<p className="mt-2 text-ink-mid">
|
||||
Each org is an isolated Molecule workspace.
|
||||
@@ -177,6 +187,40 @@ function Shell({ children }: { children: React.ReactNode }) {
|
||||
);
|
||||
}
|
||||
|
||||
// AccountBar renders the signed-in email + a Sign-out button at the
|
||||
// top of the page. Without this the user has no way to log out — the
|
||||
// /cp/auth/signout endpoint exists on the control plane but no UI ever
|
||||
// called it. Reported externally on 2026-05-05; this is the fix.
|
||||
//
|
||||
// Click → calls signOut() which POSTs /cp/auth/signout (clears the
|
||||
// WorkOS session cookie + revokes at the provider) then bounces to
|
||||
// /cp/auth/login. The signOut helper is best-effort — even on a 5xx
|
||||
// or network failure the redirect fires so the user never gets stuck
|
||||
// on an authed-looking page after they clicked Sign out.
|
||||
function AccountBar({ session }: { session: Session }) {
|
||||
const [signingOut, setSigningOut] = useState(false);
|
||||
return (
|
||||
<div className="mb-6 flex items-center justify-between text-sm text-ink-mid">
|
||||
<span title="Signed-in user">{session.email}</span>
|
||||
<button
|
||||
type="button"
|
||||
disabled={signingOut}
|
||||
onClick={async () => {
|
||||
setSigningOut(true);
|
||||
await signOut();
|
||||
// Redirect happens inside signOut; this line is for tests +
|
||||
// edge cases (jsdom, blocked navigation) where it doesn't.
|
||||
setSigningOut(false);
|
||||
}}
|
||||
className="rounded border border-line bg-surface-card px-3 py-1 text-xs text-ink hover:bg-surface-card disabled:opacity-50"
|
||||
aria-label="Sign out"
|
||||
>
|
||||
{signingOut ? "Signing out…" : "Sign out"}
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// DataResidencyNotice surfaces where workspace data lives so EU-based
|
||||
// signups can make an informed choice (GDPR Art. 13 disclosure
|
||||
// requirement). Plain text, no icon — the goal is clarity, not
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect, useMemo, useRef } from "react";
|
||||
import { useState, useEffect, useLayoutEffect, useMemo, useRef, useCallback } from "react";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import { api } from "@/lib/api";
|
||||
@@ -184,13 +184,23 @@ function unwrapErrorText(raw: string | null): string {
|
||||
export function AgentCommsPanel({ workspaceId }: { workspaceId: string }) {
|
||||
const [messages, setMessages] = useState<CommMessage[]>([]);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [loadError, setLoadError] = useState<string | null>(null);
|
||||
// Dedup by timestamp+type+peer to handle API load + WebSocket race
|
||||
const seenKeys = useRef(new Set<string>());
|
||||
const bottomRef = useRef<HTMLDivElement>(null);
|
||||
// Mirrors the my-chat scroll behaviour from ChatTab (PR #2903) —
|
||||
// smooth-scroll on a long history gets interrupted by concurrent
|
||||
// renders and lands the panel mid-conversation. Switch the first
|
||||
// arrival to instant; subsequent appends animate.
|
||||
const hasInitialScrollRef = useRef(false);
|
||||
|
||||
// Load history
|
||||
useEffect(() => {
|
||||
// Load history. Extracted so the error-state retry button can
|
||||
// re-invoke without remount. ChatTab uses the same shape
|
||||
// (loadInitial → loadError state → retry button).
|
||||
const loadInitial = useCallback(() => {
|
||||
setLoading(true);
|
||||
setLoadError(null);
|
||||
seenKeys.current.clear();
|
||||
api.get<ActivityEntry[]>(`/workspaces/${workspaceId}/activity?source=agent&limit=50`)
|
||||
.then((entries) => {
|
||||
const filtered = (entries ?? [])
|
||||
@@ -234,10 +244,15 @@ export function AgentCommsPanel({ workspaceId }: { workspaceId: string }) {
|
||||
// the .then body) — the panel just sat on the empty state
|
||||
// with zero signal.
|
||||
console.warn("AgentCommsPanel: load activity failed", err);
|
||||
setLoadError(err instanceof Error ? err.message : String(err));
|
||||
setLoading(false);
|
||||
});
|
||||
}, [workspaceId]);
|
||||
|
||||
useEffect(() => {
|
||||
loadInitial();
|
||||
}, [loadInitial]);
|
||||
|
||||
// Live updates routed through the global ReconnectingSocket. The
|
||||
// previous pattern of `new WebSocket(WS_URL)` per panel had no
|
||||
// onclose / no reconnect, so any drop (idle timeout, browser
|
||||
@@ -358,7 +373,18 @@ export function AgentCommsPanel({ workspaceId }: { workspaceId: string }) {
|
||||
} catch { /* ignore */ }
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
// useLayoutEffect (not useEffect) so the scroll runs BEFORE paint —
|
||||
// otherwise the user sees the panel jump for one frame on every
|
||||
// append. Mirrors ChatTab's MyChatPanel scroll block.
|
||||
useLayoutEffect(() => {
|
||||
if (!hasInitialScrollRef.current && messages.length > 0) {
|
||||
// Instant on first arrival — smooth-scroll on a long history
|
||||
// gets interrupted by concurrent renders and lands the panel
|
||||
// mid-conversation (the chat-opens-in-middle bug class).
|
||||
hasInitialScrollRef.current = true;
|
||||
bottomRef.current?.scrollIntoView({ behavior: "instant" as ScrollBehavior });
|
||||
return;
|
||||
}
|
||||
bottomRef.current?.scrollIntoView({ behavior: "smooth" });
|
||||
}, [messages]);
|
||||
|
||||
@@ -366,6 +392,27 @@ export function AgentCommsPanel({ workspaceId }: { workspaceId: string }) {
|
||||
return <div className="text-xs text-ink-soft text-center py-8">Loading agent communications...</div>;
|
||||
}
|
||||
|
||||
if (loadError !== null && messages.length === 0) {
|
||||
// Mirrors ChatTab my-chat error UI — surfaces the load failure
|
||||
// with a retry button instead of silently rendering empty state.
|
||||
return (
|
||||
<div
|
||||
role="alert"
|
||||
className="mx-2 mt-2 rounded-lg border border-red-800/50 bg-red-950/30 px-3 py-2.5"
|
||||
>
|
||||
<p className="text-[11px] text-bad mb-1.5">
|
||||
Failed to load agent communications: {loadError}
|
||||
</p>
|
||||
<button
|
||||
onClick={loadInitial}
|
||||
className="text-[10px] px-2 py-0.5 rounded bg-red-800/40 text-bad hover:bg-red-700/50 transition-colors"
|
||||
>
|
||||
Retry
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (messages.length === 0) {
|
||||
return (
|
||||
<div className="text-xs text-ink-soft text-center py-8">
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
// @vitest-environment jsdom
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { render, screen, fireEvent, waitFor } from "@testing-library/react";
|
||||
|
||||
// API mock — tests can override per case via apiGetMock.mockImplementationOnce.
|
||||
const apiGetMock = vi.fn<(url: string) => Promise<unknown>>();
|
||||
vi.mock("@/lib/api", () => ({
|
||||
api: {
|
||||
get: (url: string) => apiGetMock(url),
|
||||
},
|
||||
}));
|
||||
|
||||
// useSocketEvent — no-op for these render tests; live updates aren't
|
||||
// what we're verifying here.
|
||||
vi.mock("@/hooks/useSocketEvent", () => ({
|
||||
useSocketEvent: () => {},
|
||||
}));
|
||||
|
||||
// Canvas store — peer name resolution.
|
||||
vi.mock("@/store/canvas", () => ({
|
||||
useCanvasStore: {
|
||||
getState: () => ({
|
||||
nodes: [
|
||||
{ id: "ws-self", data: { name: "Self" } },
|
||||
{ id: "ws-peer", data: { name: "Peer Agent" } },
|
||||
],
|
||||
}),
|
||||
},
|
||||
}));
|
||||
|
||||
// Toaster shim — AgentCommsPanel imports showToast.
|
||||
vi.mock("../../Toaster", () => ({
|
||||
showToast: vi.fn(),
|
||||
}));
|
||||
|
||||
import { AgentCommsPanel } from "../AgentCommsPanel";
|
||||
|
||||
// jsdom doesn't implement scrollIntoView. Tests that observe the call
|
||||
// install a spy here; tests that don't care still need a no-op stub
|
||||
// so the component doesn't throw.
|
||||
const scrollSpy = vi.fn<(opts?: ScrollIntoViewOptions | boolean) => void>();
|
||||
beforeEach(() => {
|
||||
apiGetMock.mockReset();
|
||||
scrollSpy.mockReset();
|
||||
Element.prototype.scrollIntoView = scrollSpy as unknown as Element["scrollIntoView"];
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("AgentCommsPanel — initial-state parity with ChatTab my-chat", () => {
|
||||
it("shows loading text while history fetch is in flight", () => {
|
||||
apiGetMock.mockReturnValueOnce(new Promise(() => { /* never resolves */ }));
|
||||
render(<AgentCommsPanel workspaceId="ws-self" />);
|
||||
expect(screen.getByText("Loading agent communications...")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders error UI with a Retry button when the history fetch rejects", async () => {
|
||||
apiGetMock.mockRejectedValueOnce(new Error("network down"));
|
||||
render(<AgentCommsPanel workspaceId="ws-self" />);
|
||||
|
||||
// Wait for the error state to render — loading→error transition is async.
|
||||
const alert = await waitFor(() => screen.getByRole("alert"));
|
||||
expect(alert.textContent).toMatch(/Failed to load agent communications/);
|
||||
expect(alert.textContent).toMatch(/network down/);
|
||||
|
||||
// Retry button must be present and trigger a refetch.
|
||||
const retry = screen.getByRole("button", { name: "Retry" });
|
||||
apiGetMock.mockResolvedValueOnce([]); // success on retry
|
||||
fireEvent.click(retry);
|
||||
|
||||
// Two calls total: initial load + retry. Pin via mock call count.
|
||||
await waitFor(() => expect(apiGetMock.mock.calls.length).toBe(2));
|
||||
});
|
||||
|
||||
it("falls back to empty-state copy when load succeeds with zero rows", async () => {
|
||||
apiGetMock.mockResolvedValueOnce([]);
|
||||
render(<AgentCommsPanel workspaceId="ws-self" />);
|
||||
await waitFor(() =>
|
||||
expect(screen.getByText("No agent-to-agent communications yet.")).toBeDefined(),
|
||||
);
|
||||
});
|
||||
|
||||
it("scrollIntoView is called with behavior=instant on the first message arrival", async () => {
|
||||
apiGetMock.mockResolvedValueOnce([
|
||||
{
|
||||
id: "act-1",
|
||||
activity_type: "a2a_send",
|
||||
source_id: "ws-self",
|
||||
target_id: "ws-peer",
|
||||
method: "message/send",
|
||||
summary: "Delegating",
|
||||
request_body: { message: { parts: [{ text: "hi" }] } },
|
||||
response_body: null,
|
||||
status: "ok",
|
||||
created_at: "2026-04-25T18:00:00Z",
|
||||
},
|
||||
]);
|
||||
render(<AgentCommsPanel workspaceId="ws-self" />);
|
||||
|
||||
// useLayoutEffect is what makes the first call instant — wait for
|
||||
// the panel to render at least one message.
|
||||
await waitFor(() => expect(scrollSpy.mock.calls.length).toBeGreaterThan(0));
|
||||
|
||||
// The pinned contract: SOME call uses behavior: "instant" — the
|
||||
// first-arrival case. Subsequent appends use "smooth", but those
|
||||
// can't fire here (no live update yet).
|
||||
const sawInstant = scrollSpy.mock.calls.some((args) => {
|
||||
const opts = args[0];
|
||||
return typeof opts === "object" && opts !== null && "behavior" in opts && opts.behavior === "instant";
|
||||
});
|
||||
expect(sawInstant).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -2,7 +2,7 @@
|
||||
* @vitest-environment jsdom
|
||||
*/
|
||||
import { describe, it, expect, vi, afterEach } from "vitest";
|
||||
import { fetchSession, redirectToLogin } from "../auth";
|
||||
import { fetchSession, redirectToLogin, signOut } from "../auth";
|
||||
|
||||
afterEach(() => {
|
||||
vi.unstubAllGlobals();
|
||||
@@ -110,3 +110,157 @@ describe("redirectToLogin", () => {
|
||||
expect((window.location as unknown as { href: string }).href).toBe(signupHref);
|
||||
});
|
||||
});
|
||||
|
||||
describe("signOut", () => {
|
||||
// Helper — most tests need the same window.location stub.
|
||||
function stubLocation(): void {
|
||||
Object.defineProperty(window, "location", {
|
||||
writable: true,
|
||||
value: {
|
||||
href: "https://acme.moleculesai.app/orgs",
|
||||
pathname: "/orgs",
|
||||
hostname: "acme.moleculesai.app",
|
||||
protocol: "https:",
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
it("POSTs to /cp/auth/signout with credentials:include", async () => {
|
||||
stubLocation();
|
||||
const fetchMock = vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({ ok: true, logout_url: "" }),
|
||||
});
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
|
||||
await signOut();
|
||||
|
||||
expect(fetchMock).toHaveBeenCalledTimes(1);
|
||||
expect(fetchMock).toHaveBeenCalledWith(
|
||||
expect.stringContaining("/cp/auth/signout"),
|
||||
expect.objectContaining({ method: "POST", credentials: "include" }),
|
||||
);
|
||||
});
|
||||
|
||||
it("navigates to provider logout_url when the response includes one", async () => {
|
||||
// The hosted-logout path is what actually breaks the SSO re-auth
|
||||
// loop reported on PR #2913. Without this, AuthKit's browser
|
||||
// cookie keeps the user signed in via SSO and any subsequent
|
||||
// /cp/auth/login silently re-auths.
|
||||
stubLocation();
|
||||
const hostedLogout =
|
||||
"https://api.workos.com/user_management/sessions/logout?session_id=cookie&return_to=https%3A%2F%2Fapp.moleculesai.app%2Forgs";
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({ ok: true, logout_url: hostedLogout }),
|
||||
}),
|
||||
);
|
||||
|
||||
await signOut();
|
||||
|
||||
const after = (window.location as unknown as { href: string }).href;
|
||||
expect(after).toBe(hostedLogout);
|
||||
});
|
||||
|
||||
it("falls back to /cp/auth/login when logout_url is empty (DisabledProvider / dev)", async () => {
|
||||
// DisabledProvider returns "" — the local /cp/auth/login redirect
|
||||
// works in dev/test where there's no SSO session to escape.
|
||||
stubLocation();
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({ ok: true, logout_url: "" }),
|
||||
}),
|
||||
);
|
||||
|
||||
await signOut();
|
||||
|
||||
const after = (window.location as unknown as { href: string }).href;
|
||||
// Tenant subdomain (acme.moleculesai.app) → auth origin is app.moleculesai.app.
|
||||
expect(after).toBe("https://app.moleculesai.app/cp/auth/login");
|
||||
});
|
||||
|
||||
it("redirects even when the POST fails so the user isn't stuck on an authed page", async () => {
|
||||
// Critical UX invariant: clicking 'Sign out' MUST navigate away from
|
||||
// the authenticated app, even if the network is down or the cookie
|
||||
// is already invalid. Anything else looks like the button is
|
||||
// broken — the precise complaint that triggered this fix.
|
||||
stubLocation();
|
||||
vi.stubGlobal("fetch", vi.fn().mockRejectedValue(new Error("network down")));
|
||||
|
||||
await signOut();
|
||||
|
||||
const after = (window.location as unknown as { href: string }).href;
|
||||
expect(after).toBe("https://app.moleculesai.app/cp/auth/login");
|
||||
});
|
||||
|
||||
it("redirects on 401 (session already invalid) just like 200", async () => {
|
||||
// A user with an already-invalid cookie should still see the
|
||||
// logout flow complete — no error, no stuck-on-app dead end.
|
||||
// Note: 401 means res.ok=false → we don't read .json() at all,
|
||||
// so a missing body is fine.
|
||||
stubLocation();
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: false,
|
||||
status: 401,
|
||||
json: async () => ({}),
|
||||
}),
|
||||
);
|
||||
|
||||
await signOut();
|
||||
|
||||
const after = (window.location as unknown as { href: string }).href;
|
||||
expect(after).toBe("https://app.moleculesai.app/cp/auth/login");
|
||||
});
|
||||
|
||||
it("falls back to /cp/auth/login when the response body is malformed", async () => {
|
||||
// Defensive parsing: a body that isn't valid JSON, or doesn't
|
||||
// have logout_url, or has logout_url as the wrong type — none of
|
||||
// these should strand the user on the authed page. Fallback path
|
||||
// takes over.
|
||||
stubLocation();
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => {
|
||||
throw new Error("not json");
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
await signOut();
|
||||
|
||||
const after = (window.location as unknown as { href: string }).href;
|
||||
expect(after).toBe("https://app.moleculesai.app/cp/auth/login");
|
||||
});
|
||||
|
||||
it("falls back to /cp/auth/login when logout_url is the wrong type", async () => {
|
||||
// Even valid JSON should be type-checked: a non-string logout_url
|
||||
// (e.g. server-side bug, version drift) must not crash or open-
|
||||
// redirect the user.
|
||||
stubLocation();
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({ ok: true, logout_url: 42 }),
|
||||
}),
|
||||
);
|
||||
|
||||
await signOut();
|
||||
|
||||
const after = (window.location as unknown as { href: string }).href;
|
||||
expect(after).toBe("https://app.moleculesai.app/cp/auth/login");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -67,3 +67,80 @@ export function redirectToLogin(screenHint: "sign-up" | "sign-in" = "sign-in"):
|
||||
const dest = `${authOrigin}${AUTH_BASE}/${path}?return_to=${encodeURIComponent(returnTo)}`;
|
||||
window.location.href = dest;
|
||||
}
|
||||
|
||||
/**
|
||||
* signOut posts to /cp/auth/signout to clear the WorkOS session cookie
|
||||
* + revoke at the provider, then navigates the browser to the
|
||||
* provider-supplied hosted logout URL (so the provider's BROWSER-side
|
||||
* SSO cookie is cleared too — without this, AuthKit silently re-auths
|
||||
* via SSO on the next /cp/auth/login and the user is "still signed
|
||||
* in" after pressing Sign out).
|
||||
*
|
||||
* Two-layer flow:
|
||||
* 1. POST /cp/auth/signout → CP clears OUR session cookie + revokes
|
||||
* session_id at the provider API. Response includes
|
||||
* `logout_url` — the AuthKit hosted URL the BROWSER must navigate
|
||||
* to so the provider's own browser cookie is cleared.
|
||||
* 2. window.location.href = <logout_url> → AuthKit clears its
|
||||
* session, then redirects the browser to the configured
|
||||
* return_to (defaults to APP_URL/orgs).
|
||||
*
|
||||
* Best-effort by design: a 5xx, network failure, missing logout_url
|
||||
* (DisabledProvider, dev), or stale cookie still results in the
|
||||
* browser navigating away — leaving the user on a logged-in-looking
|
||||
* page after they clicked "Sign out" is the worst possible UX. The
|
||||
* fallback path navigates to /cp/auth/login on the auth origin, which
|
||||
* works correctly in environments without a hosted logout flow (dev,
|
||||
* tests, DisabledProvider).
|
||||
*
|
||||
* Throws nothing — callers can disable the button optimistically or
|
||||
* await this and trust it returns. On a redirect-blocked test
|
||||
* environment (jsdom under vitest) we still exit cleanly so unit tests
|
||||
* can spy on the fetch call.
|
||||
*/
|
||||
export async function signOut(): Promise<void> {
|
||||
let logoutURL: string | undefined;
|
||||
// Fire-and-tolerate the POST. credentials:include is mandatory cross-
|
||||
// origin so the SaaS canvas (acme.moleculesai.app) can hit
|
||||
// app.moleculesai.app/cp/auth/signout with the session cookie.
|
||||
try {
|
||||
const res = await fetch(`${getAuthOrigin()}${AUTH_BASE}/signout`, {
|
||||
method: "POST",
|
||||
credentials: "include",
|
||||
});
|
||||
if (res.ok) {
|
||||
// Body shape: {"ok": true, "logout_url": "..."}. logout_url is
|
||||
// empty for DisabledProvider (dev/local) — we fall back to
|
||||
// /cp/auth/login below. Defensive parsing: a malformed body
|
||||
// shouldn't strand the user on the authed page.
|
||||
const body: unknown = await res.json().catch(() => null);
|
||||
if (
|
||||
body &&
|
||||
typeof body === "object" &&
|
||||
"logout_url" in body &&
|
||||
typeof (body as { logout_url: unknown }).logout_url === "string" &&
|
||||
(body as { logout_url: string }).logout_url
|
||||
) {
|
||||
logoutURL = (body as { logout_url: string }).logout_url;
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Ignore — we still redirect below.
|
||||
}
|
||||
if (typeof window === "undefined") return;
|
||||
if (logoutURL) {
|
||||
// Hosted logout: AuthKit clears its SSO cookie + redirects to
|
||||
// return_to (configured server-side). This is the path that
|
||||
// actually breaks the SSO re-auth loop.
|
||||
window.location.href = logoutURL;
|
||||
return;
|
||||
}
|
||||
// Fallback: no hosted logout (dev, DisabledProvider, network
|
||||
// failure). Land on the login screen rather than the current URL:
|
||||
// returning to a tenant URL after signout would just re-redirect
|
||||
// through /cp/auth/login due to AuthGate. Send the user straight
|
||||
// there with no return_to so they don't loop back into the org they
|
||||
// just left.
|
||||
const authOrigin = getAuthOrigin();
|
||||
window.location.href = `${authOrigin}${AUTH_BASE}/login`;
|
||||
}
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
# Team Expansion (Recursive Workspaces)
|
||||
|
||||
When a workspace is expanded into a team, it gains sub-workspaces while its own agent remains as the **team lead** (coordinator). This is recursive — sub-workspaces can themselves be expanded into teams, infinitely deep.
|
||||
|
||||
## How It Works
|
||||
|
||||
When Developer PM is expanded into a team:
|
||||
|
||||
```
|
||||
Business Core
|
||||
|
|
||||
+-- Developer PM (agent stays, becomes coordinator)
|
||||
|
|
||||
+-- Frontend Agent (sub-workspace, private scope)
|
||||
+-- Backend Agent (sub-workspace, private scope)
|
||||
+-- QA Agent (sub-workspace, private scope)
|
||||
```
|
||||
|
||||
- Developer PM's agent **still exists** and acts as coordinator
|
||||
- Developer PM receives incoming A2A messages from Business Core
|
||||
- Developer PM's agent decides how to delegate to sub-workspaces
|
||||
- Sub-workspaces talk to Developer PM and to each other (same level)
|
||||
- Sub-workspaces **cannot** talk to Business Core or any workspace outside the team
|
||||
|
||||
## Communication Rules
|
||||
|
||||
| Direction | Allowed? | Example |
|
||||
|-----------|----------|---------|
|
||||
| Parent level -> team lead | Yes | Business Core -> Developer PM |
|
||||
| Team lead -> sub-workspaces | Yes | Developer PM -> Frontend Agent |
|
||||
| Sub-workspace -> team lead | Yes | Frontend Agent -> Developer PM |
|
||||
| Sub-workspace <-> sibling | Yes | Frontend Agent <-> Backend Agent |
|
||||
| Outside -> sub-workspace directly | No (403) | Business Core -> Frontend Agent |
|
||||
| Sub-workspace -> outside directly | No | Frontend Agent -> Business Core |
|
||||
|
||||
The team lead (Developer PM) is the **only** bridge between the team's internal world and the outside.
|
||||
|
||||
## Scoped Registry
|
||||
|
||||
Sub-workspaces register in the platform registry but with a **private scope**. The registry knows about them but enforces access control.
|
||||
|
||||
```
|
||||
Registry:
|
||||
Business Core :8001 scope: public
|
||||
Developer PM :8002 scope: public
|
||||
Frontend Agent :8010 scope: private, parent=Developer PM
|
||||
Backend Agent :8011 scope: private, parent=Developer PM
|
||||
QA Agent :8012 scope: private, parent=Developer PM
|
||||
```
|
||||
|
||||
- The platform can always discover any workspace (for provisioning, monitoring)
|
||||
- The parent workspace can discover its sub-workspaces
|
||||
- Sub-workspaces can discover their siblings (same parent)
|
||||
- Outside workspaces get a **403 Forbidden** if they try to discover a private sub-workspace
|
||||
|
||||
## How to Expand
|
||||
|
||||
Expansion is triggered via `POST /workspaces/:id/expand`. The platform reads the `sub_workspaces` list from the workspace's config and provisions each one. On the canvas, users right-click a workspace node and select "Expand into team."
|
||||
|
||||
Collapsing is the inverse: `POST /workspaces/:id/collapse`. Sub-workspaces are stopped and removed.
|
||||
|
||||
## What Happens on Expansion
|
||||
|
||||
When Developer PM is expanded into a team, the hierarchy changes but the outside view doesn't. Business Core's parent/child relationship to Developer PM is unaffected — Developer PM still responds to the same A2A endpoint.
|
||||
|
||||
The events fired:
|
||||
- `WORKSPACE_EXPANDED` with the new `sub_workspace_ids` in the payload
|
||||
- `WORKSPACE_PROVISIONING` for each new sub-workspace
|
||||
- `WORKSPACE_ONLINE` for each sub-workspace as they come up
|
||||
|
||||
Communication rules are automatically derived from the new hierarchy — no manual wiring needed.
|
||||
|
||||
## Canvas Behavior
|
||||
|
||||
- Children render as embedded mini-cards (`TeamMemberChip`) inside the parent node, not as separate canvas nodes
|
||||
- Each mini-card shows full status: gradient bar, name, tier badge, skills pills, active tasks, descendant count
|
||||
- **Recursive rendering** up to 3 levels deep (`MAX_NESTING_DEPTH = 3`) — sub-cards can contain their own "Team" sections
|
||||
- Parent node dynamically resizes: 210-280px (no children), 320-450px (children), 400-560px (grandchildren)
|
||||
- Eject button (sky-blue arrow icon) on hover extracts a child from the team
|
||||
- "Extract from Team" also available in the right-click context menu
|
||||
- Double-click a team node to zoom/fit to the parent area
|
||||
- The parent workspace node shows a badge with total descendant count
|
||||
|
||||
## Collapsing a Team
|
||||
|
||||
The inverse of expansion, triggered via `POST /workspaces/:id/collapse`:
|
||||
|
||||
1. Each sub-workspace agent wraps up current work and writes a handoff document to memory
|
||||
2. Sub-workspaces are stopped and removed
|
||||
3. The team lead's agent goes back to handling everything directly
|
||||
4. A `WORKSPACE_COLLAPSED` event fires
|
||||
|
||||
Sub-workspace memory is cleaned up based on backend (see [Memory — Cleanup](../architecture/memory.md#cleanup-on-workspace-deletion)).
|
||||
|
||||
## Deleting a Team Workspace
|
||||
|
||||
When a team workspace is deleted:
|
||||
1. Platform shows a warning listing all sub-workspaces that will be deleted
|
||||
2. User can **drag sub-workspaces out** of the team before confirming (promotes them to the parent level)
|
||||
3. On confirmation, cascade delete removes the parent and all remaining sub-workspaces
|
||||
4. `WORKSPACE_REMOVED` events fire for each deleted workspace
|
||||
|
||||
## Related Docs
|
||||
|
||||
- [Communication Rules](../api-protocol/communication-rules.md) — Full access control model
|
||||
- [Core Concepts](../product/core-concepts.md) — Workspace fundamentals
|
||||
- [System Prompt Structure](./system-prompt-structure.md) — How peer capabilities are injected
|
||||
- [Provisioner](../architecture/provisioner.md) — How sub-workspaces are deployed
|
||||
- [Registry & Heartbeat](../api-protocol/registry-and-heartbeat.md) — How registration works
|
||||
- [Event Log](../architecture/event-log.md) — Events fired during expansion
|
||||
- [Canvas UI](../frontend/canvas.md) — Visual behavior of teams
|
||||
@@ -41,8 +41,6 @@ Full contract: `docs/runbooks/admin-auth.md`.
|
||||
| GET | /admin/workspaces/:id/test-token | admin_test_token.go — mint a fresh bearer token for E2E scripts; returns 404 unless `MOLECULE_ENV != production` or `MOLECULE_ENABLE_TEST_TOKENS=1` |
|
||||
| GET/POST/DELETE | /admin/secrets[/:key] | secrets.go — legacy aliases for /settings/secrets |
|
||||
| WS | /workspaces/:id/terminal | terminal.go |
|
||||
| POST | /workspaces/:id/expand | team.go |
|
||||
| POST | /workspaces/:id/collapse | team.go |
|
||||
| POST/GET | /workspaces/:id/approvals | approvals.go |
|
||||
| POST | /workspaces/:id/approvals/:id/decide | approvals.go |
|
||||
| GET | /approvals/pending | approvals.go |
|
||||
|
||||
@@ -336,8 +336,6 @@ This same logic governs: A2A delegation, memory scope enforcement, activity visi
|
||||
|
||||
| Method | Endpoint | Purpose |
|
||||
|--------|----------|---------|
|
||||
| `POST` | `/workspaces/:id/expand` | Expand workspace into team (become coordinator) |
|
||||
| `POST` | `/workspaces/:id/collapse` | Collapse team back to single workspace |
|
||||
|
||||
### Files, Terminal, Templates, Bundles (8 endpoints)
|
||||
|
||||
|
||||
@@ -186,4 +186,3 @@ So the UI now exposes more operational failure state directly instead of silentl
|
||||
- [Quickstart](../quickstart.md)
|
||||
- [Platform API](../api-protocol/platform-api.md)
|
||||
- [Workspace Runtime](../agent-runtime/workspace-runtime.md)
|
||||
- [Team Expansion](../agent-runtime/team-expansion.md)
|
||||
|
||||
+1
-1
@@ -18,7 +18,7 @@ lands in the watch list with a colliding term, add a row here.
|
||||
| **plugin** | A directory under `plugins/` packaging one or more skills or an MCP server wrapper, installable per-workspace via `POST /workspaces/:id/plugins`. Governed by `plugin.yaml`. | **Langflow**: a visual UI node / component in a flowchart. **CrewAI**: a Python-importable callable registered as a capability. |
|
||||
| **agent** | A persistent containerized workspace running continuously — an identity with memory, a role, and a schedule. Not a one-shot invocation. | Most frameworks (AutoGPT, LangChain agents, OpenAI Assistants): a stateless function-call loop. No persistence between invocations unless explicitly checkpointed. |
|
||||
| **flow** | A task execution within a workspace — a request enters, the agent runs tools, emits a response, logs activity. No explicit graph abstraction. | **Langflow**: a directed graph of nodes you author visually. **LangGraph**: a stateful graph of callable nodes. Our "flow" is an imperative timeline, not a graph. |
|
||||
| **team** | A named cluster of workspaces under a PM (org template `expand_team`). Used for role grouping in Canvas. | **CrewAI**: a "crew" is a sequence of agents that pass a task through a declared order. Our "team" is an org-chart abstraction, not an execution order. |
|
||||
| **team** | A named cluster of workspaces under a PM . Used for role grouping in Canvas. | **CrewAI**: a "crew" is a sequence of agents that pass a task through a declared order. Our "team" is an org-chart abstraction, not an execution order. |
|
||||
| **skill** | A directory with `SKILL.md` that an agent invokes via the `Skill` tool. Skills are documentation + optional scripts that teach an agent a recipe. | **Anthropic Skills API**: nearly identical. **CrewAI tool**: closer to our plugin's MCP tool, not our skill. |
|
||||
| **channel** | An outbound/inbound social integration (Telegram, Slack, …) per-workspace, wired in `workspace_channels`. | Slack's "channel": the container for messages. We use "channel" for the adapter + credentials, not the conversation itself. |
|
||||
| **runtime** | The execution engine image tag for a workspace: one of `langgraph`, `claude-code`, `openclaw`, `crewai`, `autogen`, `deepagents`, `hermes`. | **LangGraph runtime**: the Python process running the graph. We use "runtime" for the Docker image + adapter pairing, not the inner process. |
|
||||
|
||||
@@ -166,8 +166,6 @@ list_workspaces
|
||||
|
||||
| MCP Tool | API Route | Method | Description |
|
||||
|----------|-----------|--------|-------------|
|
||||
| `expand_team` | `/workspaces/:id/expand` | POST | Expand team node |
|
||||
| `collapse_team` | `/workspaces/:id/collapse` | POST | Collapse team node |
|
||||
|
||||
### Templates & Bundles
|
||||
|
||||
|
||||
@@ -1,5 +1,14 @@
|
||||
# Workspace Runtime PyPI Package
|
||||
|
||||
## Requires Python >= 3.11
|
||||
|
||||
The wheel pins `requires_python>=3.11`. On Python 3.10 or older, `pip install
|
||||
molecule-ai-workspace-runtime` fails with `Could not find a version that
|
||||
satisfies the requirement (from versions: none)` — the pin filters the only
|
||||
available artifact before pip even attempts install. Upgrade the interpreter
|
||||
(`brew install python@3.12` / `apt install python3.12` / etc.) or use a
|
||||
3.11+ venv.
|
||||
|
||||
## Overview
|
||||
|
||||
The shared workspace runtime infrastructure has **one editable source** and
|
||||
|
||||
@@ -56,6 +56,7 @@ TOP_LEVEL_MODULES = {
|
||||
"a2a_mcp_server",
|
||||
"a2a_tools",
|
||||
"a2a_tools_delegation",
|
||||
"a2a_tools_memory",
|
||||
"a2a_tools_rbac",
|
||||
"adapter_base",
|
||||
"agent",
|
||||
|
||||
Executable
+295
@@ -0,0 +1,295 @@
|
||||
#!/usr/bin/env bash
|
||||
# E2E for poll-mode chat upload (RFC #2891 phases 1-5b).
|
||||
#
|
||||
# Round-trip: register a workspace as poll-mode (no callback URL) → POST a
|
||||
# multi-file chat upload → verify each file becomes (a) one
|
||||
# `chat_upload_receive` activity row and (b) one /pending-uploads row → fetch
|
||||
# the bytes back via the poll endpoint → ack → verify the row 404s on
|
||||
# subsequent fetch. Also pins cross-workspace bleed protection: workspace B
|
||||
# cannot read workspace A's pending uploads even with its own valid bearer.
|
||||
#
|
||||
# Why this exists separately from test_chat_upload_e2e.sh: that script
|
||||
# covers the PUSH path (the workspace's own /internal/chat/uploads/ingest).
|
||||
# This script covers the POLL path: the same canvas-side request lands on
|
||||
# the platform's pendinguploads.Storage instead, and the workspace fetches
|
||||
# it later. The two paths share zero handler code on the platform side, so
|
||||
# both need their own E2E.
|
||||
#
|
||||
# Requires: platform running on localhost:8080 with migrations applied.
|
||||
# bash workspace-server/scripts/dev-start.sh
|
||||
# bash workspace-server/scripts/run-migrations.sh
|
||||
#
|
||||
# Idempotent: each run uses fresh per-script workspace UUIDs so reruns
|
||||
# don't collide. Best-effort cleanup on EXIT — does NOT call
|
||||
# e2e_cleanup_all_workspaces (see
|
||||
# `feedback_never_run_cluster_cleanup_tests_on_live_platform.md`).
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
source "$(dirname "$0")/_lib.sh"
|
||||
|
||||
PASS=0
|
||||
FAIL=0
|
||||
TIMEOUT="${A2A_TIMEOUT:-30}"
|
||||
|
||||
gen_uuid() {
|
||||
if command -v uuidgen >/dev/null 2>&1; then
|
||||
uuidgen | tr '[:upper:]' '[:lower:]'
|
||||
else
|
||||
python3 -c 'import uuid; print(uuid.uuid4())'
|
||||
fi
|
||||
}
|
||||
WS_A="$(gen_uuid)"
|
||||
WS_B="$(gen_uuid)"
|
||||
|
||||
# Per-run scratch dir collected under one trap so every assertion-failure
|
||||
# path drops the temp files it made (see test_chat_attachments_e2e.sh).
|
||||
TMPDIR_E2E=$(mktemp -d -t poll-chat-upload-e2e-XXXXXX)
|
||||
|
||||
cleanup() {
|
||||
local rc=$?
|
||||
curl -s -X DELETE "$BASE/workspaces/$WS_A?confirm=true" >/dev/null 2>&1 || true
|
||||
curl -s -X DELETE "$BASE/workspaces/$WS_B?confirm=true" >/dev/null 2>&1 || true
|
||||
rm -rf "$TMPDIR_E2E"
|
||||
exit $rc
|
||||
}
|
||||
trap cleanup EXIT INT TERM
|
||||
|
||||
check() {
|
||||
local desc="$1" expected="$2" actual="$3"
|
||||
if echo "$actual" | grep -qF -- "$expected"; then
|
||||
echo "PASS: $desc"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo "FAIL: $desc"
|
||||
echo " expected to contain: $expected"
|
||||
echo " got: $(echo "$actual" | head -10)"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
}
|
||||
|
||||
check_eq() {
|
||||
local desc="$1" expected="$2" actual="$3"
|
||||
if [ "$actual" = "$expected" ]; then
|
||||
echo "PASS: $desc"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo "FAIL: $desc"
|
||||
echo " expected: $expected"
|
||||
echo " got: $actual"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
}
|
||||
|
||||
echo "=== Poll-Mode Chat Upload E2E ==="
|
||||
echo " base: $BASE"
|
||||
echo " workspace A: $WS_A"
|
||||
echo " workspace B: $WS_B"
|
||||
echo ""
|
||||
|
||||
# ---------- Phase 1: register poll-mode workspace ----------
|
||||
echo "--- Phase 1: Register poll-mode workspace A ---"
|
||||
|
||||
REG_A=$(curl -s -X POST "$BASE/registry/register" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{
|
||||
\"id\": \"$WS_A\",
|
||||
\"delivery_mode\": \"poll\",
|
||||
\"agent_card\": {\"name\": \"poll-chat-upload-test-a\"}
|
||||
}")
|
||||
check "register accepts poll mode without URL" '"status":"registered"' "$REG_A"
|
||||
TOK_A=$(echo "$REG_A" | e2e_extract_token || true)
|
||||
[ -n "$TOK_A" ] || { echo "FAIL: no auth_token in register response (ws A)"; FAIL=$((FAIL + 1)); exit 1; }
|
||||
|
||||
# ---------- Phase 2: multi-file chat upload ----------
|
||||
echo ""
|
||||
echo "--- Phase 2: POST /chat/uploads with two files ---"
|
||||
|
||||
FILE1="$TMPDIR_E2E/alpha.txt"
|
||||
FILE2="$TMPDIR_E2E/beta.txt"
|
||||
EXPECTED1="alpha-secret-$(openssl rand -hex 4)"
|
||||
EXPECTED2="beta-secret-$(openssl rand -hex 4)"
|
||||
printf '%s' "$EXPECTED1" > "$FILE1"
|
||||
printf '%s' "$EXPECTED2" > "$FILE2"
|
||||
|
||||
UPLOAD=$(curl -s -X POST "$BASE/workspaces/$WS_A/chat/uploads" \
|
||||
-H "Authorization: Bearer $TOK_A" \
|
||||
-F "files=@$FILE1;filename=alpha.txt;type=text/plain" \
|
||||
-F "files=@$FILE2;filename=beta.txt;type=text/plain" \
|
||||
-w "\nHTTP_CODE=%{http_code}\n")
|
||||
UPLOAD_CODE=$(echo "$UPLOAD" | grep -oE 'HTTP_CODE=[0-9]+' | cut -d= -f2)
|
||||
UPLOAD_BODY=$(echo "$UPLOAD" | sed '/^HTTP_CODE=/,$d')
|
||||
|
||||
check_eq "upload returns 200" "200" "$UPLOAD_CODE"
|
||||
check "upload response has files array" '"files":' "$UPLOAD_BODY"
|
||||
|
||||
# Pull file_ids out of the URI in the response. URI shape is
|
||||
# `platform-pending:<wsid>/<file_id>` — proves the response came from the
|
||||
# poll-mode branch, not the push-mode internal-ingest branch.
|
||||
URI1=$(echo "$UPLOAD_BODY" | python3 -c 'import sys,json; d=json.load(sys.stdin); print(d["files"][0]["uri"])')
|
||||
URI2=$(echo "$UPLOAD_BODY" | python3 -c 'import sys,json; d=json.load(sys.stdin); print(d["files"][1]["uri"])')
|
||||
check "URI 1 has platform-pending: scheme" "platform-pending:$WS_A/" "$URI1"
|
||||
check "URI 2 has platform-pending: scheme" "platform-pending:$WS_A/" "$URI2"
|
||||
|
||||
FID1="${URI1##*/}"
|
||||
FID2="${URI2##*/}"
|
||||
[ -n "$FID1" ] && [ -n "$FID2" ] || { echo "FAIL: could not extract file IDs"; FAIL=$((FAIL + 1)); exit 1; }
|
||||
echo " file_id 1: $FID1"
|
||||
echo " file_id 2: $FID2"
|
||||
|
||||
# ---------- Phase 3: activity rows visible to the workspace ----------
|
||||
echo ""
|
||||
echo "--- Phase 3: /activity shows two chat_upload_receive rows ---"
|
||||
|
||||
# activity_logs INSERTs run in a goroutine — give them a moment.
|
||||
sleep 1
|
||||
ACT=$(curl -s --max-time "$TIMEOUT" -H "Authorization: Bearer $TOK_A" \
|
||||
"$BASE/workspaces/$WS_A/activity?type=a2a_receive&limit=20")
|
||||
check "activity feed has the alpha file" "$FID1" "$ACT"
|
||||
check "activity feed has the beta file" "$FID2" "$ACT"
|
||||
check "activity rows tagged chat_upload_receive" '"method":"chat_upload_receive"' "$ACT"
|
||||
check "activity rows record alpha mimetype" '"mimeType":"text/plain"' "$ACT"
|
||||
|
||||
CHAT_UPLOAD_COUNT=$(echo "$ACT" | python3 -c '
|
||||
import json, sys
|
||||
rows = json.load(sys.stdin)
|
||||
n = sum(1 for r in rows if (r.get("method") or "") == "chat_upload_receive")
|
||||
print(n)
|
||||
')
|
||||
check_eq "exactly two chat_upload_receive rows" "2" "$CHAT_UPLOAD_COUNT"
|
||||
|
||||
# ---------- Phase 4: GET /pending-uploads/:file_id/content ----------
|
||||
echo ""
|
||||
echo "--- Phase 4: Fetch content for each pending upload ---"
|
||||
|
||||
GOT1=$(curl -s --max-time "$TIMEOUT" -H "Authorization: Bearer $TOK_A" \
|
||||
"$BASE/workspaces/$WS_A/pending-uploads/$FID1/content")
|
||||
check_eq "alpha bytes round-trip" "$EXPECTED1" "$GOT1"
|
||||
|
||||
GOT2=$(curl -s --max-time "$TIMEOUT" -H "Authorization: Bearer $TOK_A" \
|
||||
"$BASE/workspaces/$WS_A/pending-uploads/$FID2/content")
|
||||
check_eq "beta bytes round-trip" "$EXPECTED2" "$GOT2"
|
||||
|
||||
# Mimetype + Content-Disposition headers should match what was uploaded.
|
||||
HEAD1=$(curl -s -D - -o /dev/null --max-time "$TIMEOUT" -H "Authorization: Bearer $TOK_A" \
|
||||
"$BASE/workspaces/$WS_A/pending-uploads/$FID1/content")
|
||||
check "alpha response carries text/plain Content-Type" "Content-Type: text/plain" "$HEAD1"
|
||||
check "alpha response carries Content-Disposition with filename" 'filename="alpha.txt"' "$HEAD1"
|
||||
|
||||
# ---------- Phase 5: idempotent re-fetch (until ack) ----------
|
||||
echo ""
|
||||
echo "--- Phase 5: Re-fetch before ack returns the same bytes ---"
|
||||
|
||||
RE_GOT1=$(curl -s --max-time "$TIMEOUT" -H "Authorization: Bearer $TOK_A" \
|
||||
"$BASE/workspaces/$WS_A/pending-uploads/$FID1/content")
|
||||
check_eq "re-fetch returns same alpha bytes" "$EXPECTED1" "$RE_GOT1"
|
||||
|
||||
# ---------- Phase 6: ack each row ----------
|
||||
echo ""
|
||||
echo "--- Phase 6: Ack each pending upload ---"
|
||||
|
||||
ACK1=$(curl -s -X POST --max-time "$TIMEOUT" -H "Authorization: Bearer $TOK_A" \
|
||||
"$BASE/workspaces/$WS_A/pending-uploads/$FID1/ack")
|
||||
check "alpha ack returns acked:true" '"acked":true' "$ACK1"
|
||||
|
||||
ACK2=$(curl -s -X POST --max-time "$TIMEOUT" -H "Authorization: Bearer $TOK_A" \
|
||||
"$BASE/workspaces/$WS_A/pending-uploads/$FID2/ack")
|
||||
check "beta ack returns acked:true" '"acked":true' "$ACK2"
|
||||
|
||||
# Re-ack should still 200 (idempotent — the row's gone but the workspace's
|
||||
# at-least-once intent was already honored, and the second ack hits the
|
||||
# raced path which also returns 200).
|
||||
RE_ACK1=$(curl -s -w '\n%{http_code}' -X POST --max-time "$TIMEOUT" \
|
||||
-H "Authorization: Bearer $TOK_A" \
|
||||
"$BASE/workspaces/$WS_A/pending-uploads/$FID1/ack")
|
||||
RE_ACK1_CODE=$(printf '%s' "$RE_ACK1" | tail -n1)
|
||||
# Acked rows return 404 on Get-before-Ack (the row's still in the table
|
||||
# but Get filters acked_at IS NULL); workspace would not normally re-ack
|
||||
# since it already saw the success. Accept both 200 and 404 here so the
|
||||
# test pins the contract without being brittle on the inner ordering.
|
||||
case "$RE_ACK1_CODE" in
|
||||
200|404)
|
||||
echo "PASS: re-ack returns 200 or 404 ($RE_ACK1_CODE)"
|
||||
PASS=$((PASS + 1))
|
||||
;;
|
||||
*)
|
||||
echo "FAIL: re-ack returned unexpected $RE_ACK1_CODE"
|
||||
FAIL=$((FAIL + 1))
|
||||
;;
|
||||
esac
|
||||
|
||||
# ---------- Phase 7: GET content after ack returns 404 ----------
|
||||
echo ""
|
||||
echo "--- Phase 7: Acked file 404s on subsequent fetch ---"
|
||||
|
||||
POST_ACK=$(curl -s -w '\n%{http_code}' --max-time "$TIMEOUT" -H "Authorization: Bearer $TOK_A" \
|
||||
"$BASE/workspaces/$WS_A/pending-uploads/$FID1/content")
|
||||
POST_ACK_CODE=$(printf '%s' "$POST_ACK" | tail -n1)
|
||||
check_eq "acked alpha returns HTTP 404" "404" "$POST_ACK_CODE"
|
||||
|
||||
# ---------- Phase 8: cross-workspace bleed protection ----------
|
||||
echo ""
|
||||
echo "--- Phase 8: Workspace B cannot read workspace A's pending uploads ---"
|
||||
|
||||
# Stage a fresh upload on workspace A so we have an UN-acked row to probe.
|
||||
PROBE_FILE="$TMPDIR_E2E/probe.txt"
|
||||
printf '%s' "probe-bytes-$(openssl rand -hex 4)" > "$PROBE_FILE"
|
||||
PROBE_UP=$(curl -s -X POST "$BASE/workspaces/$WS_A/chat/uploads" \
|
||||
-H "Authorization: Bearer $TOK_A" \
|
||||
-F "files=@$PROBE_FILE;filename=probe.txt;type=text/plain")
|
||||
PROBE_FID=$(echo "$PROBE_UP" | python3 -c 'import sys,json; d=json.load(sys.stdin); print(d["files"][0]["uri"].split("/")[-1])')
|
||||
[ -n "$PROBE_FID" ] || { echo "FAIL: probe upload returned no file_id"; FAIL=$((FAIL + 1)); exit 1; }
|
||||
|
||||
# Register a SECOND poll-mode workspace and capture its bearer.
|
||||
REG_B=$(curl -s -X POST "$BASE/registry/register" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{
|
||||
\"id\": \"$WS_B\",
|
||||
\"delivery_mode\": \"poll\",
|
||||
\"agent_card\": {\"name\": \"poll-chat-upload-test-b\"}
|
||||
}")
|
||||
check "second workspace registers" '"status":"registered"' "$REG_B"
|
||||
TOK_B=$(echo "$REG_B" | e2e_extract_token || true)
|
||||
[ -n "$TOK_B" ] || { echo "FAIL: no auth_token (ws B)"; FAIL=$((FAIL + 1)); exit 1; }
|
||||
|
||||
# B's bearer hitting B's URL with A's file_id → 404 (handler checks the row's
|
||||
# workspace_id matches the URL :id, not the bearer's workspace).
|
||||
CROSS_RESP=$(curl -s -w '\n%{http_code}' --max-time "$TIMEOUT" \
|
||||
-H "Authorization: Bearer $TOK_B" \
|
||||
"$BASE/workspaces/$WS_B/pending-uploads/$PROBE_FID/content")
|
||||
CROSS_CODE=$(printf '%s' "$CROSS_RESP" | tail -n1)
|
||||
check_eq "B's URL with A's file_id returns 404" "404" "$CROSS_CODE"
|
||||
|
||||
# B's bearer hitting A's URL → 401 (wsAuth pins bearer to :id). This is the
|
||||
# strictest cross-workspace check: a presented-but-wrong bearer is rejected
|
||||
# in EVERY platform posture (dev-mode fail-open only triggers when no bearer
|
||||
# is presented at all — invalid tokens always 401).
|
||||
WRONG_BEARER=$(curl -s -w '\n%{http_code}' --max-time "$TIMEOUT" \
|
||||
-H "Authorization: Bearer $TOK_B" \
|
||||
"$BASE/workspaces/$WS_A/pending-uploads/$PROBE_FID/content")
|
||||
WRONG_CODE=$(printf '%s' "$WRONG_BEARER" | tail -n1)
|
||||
check_eq "B's bearer on A's URL returns 401" "401" "$WRONG_CODE"
|
||||
|
||||
# NB: a fully bearerless request to /pending-uploads/:fid/content returns
|
||||
# 401 ONLY when the platform has MOLECULE_ENV != development (production /
|
||||
# staging). On local-dev with MOLECULE_ENV=development the wsauth middleware
|
||||
# fail-opens for bearerless requests so the canvas at :3000 can talk to the
|
||||
# platform at :8080 without per-call token plumbing — see middleware/
|
||||
# devmode.go. The strict bearerless-401 contract is covered by the wsauth
|
||||
# unit + middleware tests; we don't reassert it here because the result
|
||||
# depends on platform posture, not the poll-mode upload contract.
|
||||
|
||||
# ---------- Phase 9: invalid file_id rejected at the URL parser ----------
|
||||
echo ""
|
||||
echo "--- Phase 9: Invalid file_id returns 400 ---"
|
||||
|
||||
BAD_FID=$(curl -s -w '\n%{http_code}' --max-time "$TIMEOUT" \
|
||||
-H "Authorization: Bearer $TOK_A" \
|
||||
"$BASE/workspaces/$WS_A/pending-uploads/not-a-uuid/content")
|
||||
BAD_FID_CODE=$(printf '%s' "$BAD_FID" | tail -n1)
|
||||
check_eq "invalid file_id UUID returns 400" "400" "$BAD_FID_CODE"
|
||||
|
||||
# ---------- Results ----------
|
||||
echo ""
|
||||
echo "=== Results: $PASS passed, $FAIL failed ==="
|
||||
[ "$FAIL" -eq 0 ]
|
||||
@@ -94,6 +94,13 @@ services:
|
||||
CP_UPSTREAM_URL: "http://cp-stub:9090"
|
||||
RATE_LIMIT: "1000"
|
||||
CANVAS_PROXY_URL: "http://localhost:3000"
|
||||
# Memory v2 sidecar (PR #2906) bundles the plugin into the
|
||||
# tenant image and starts it before the main server. The plugin
|
||||
# runs `CREATE EXTENSION vector` on first boot, which fails on
|
||||
# the harness's plain postgres:15-alpine (no pgvector). The
|
||||
# harness doesn't exercise memory features, so disable the
|
||||
# sidecar via the entrypoint's documented escape hatch.
|
||||
MEMORY_PLUGIN_DISABLE: "1"
|
||||
networks: [harness-net]
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "wget -q -O- http://localhost:8080/health || exit 1"]
|
||||
@@ -142,6 +149,13 @@ services:
|
||||
CP_UPSTREAM_URL: "http://cp-stub:9090"
|
||||
RATE_LIMIT: "1000"
|
||||
CANVAS_PROXY_URL: "http://localhost:3000"
|
||||
# Memory v2 sidecar (PR #2906) bundles the plugin into the
|
||||
# tenant image and starts it before the main server. The plugin
|
||||
# runs `CREATE EXTENSION vector` on first boot, which fails on
|
||||
# the harness's plain postgres:15-alpine (no pgvector). The
|
||||
# harness doesn't exercise memory features, so disable the
|
||||
# sidecar via the entrypoint's documented escape hatch.
|
||||
MEMORY_PLUGIN_DISABLE: "1"
|
||||
networks: [harness-net]
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "wget -q -O- http://localhost:8080/health || exit 1"]
|
||||
|
||||
@@ -21,6 +21,14 @@ ARG GIT_SHA=dev
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags "-X github.com/Molecule-AI/molecule-monorepo/platform/internal/buildinfo.GitSHA=${GIT_SHA}" \
|
||||
-o /platform ./cmd/server
|
||||
# Bundle the built-in memory-plugin-postgres binary so an operator can
|
||||
# activate Memory v2 by setting MEMORY_V2_CUTOVER=true + (default)
|
||||
# MEMORY_PLUGIN_URL=http://localhost:9100. The entrypoint starts this
|
||||
# binary in the background; main /platform talks to it over loopback.
|
||||
# Stays inert until the operator flips the cutover env var.
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags "-X github.com/Molecule-AI/molecule-monorepo/platform/internal/buildinfo.GitSHA=${GIT_SHA}" \
|
||||
-o /memory-plugin ./cmd/memory-plugin-postgres
|
||||
|
||||
# Clone templates + plugins at build time from manifest.json
|
||||
FROM alpine:3.20 AS templates
|
||||
@@ -30,8 +38,9 @@ COPY scripts/clone-manifest.sh /scripts/clone-manifest.sh
|
||||
RUN chmod +x /scripts/clone-manifest.sh && /scripts/clone-manifest.sh /manifest.json /workspace-configs-templates /org-templates /plugins
|
||||
|
||||
FROM alpine:3.20
|
||||
RUN apk add --no-cache ca-certificates git tzdata
|
||||
RUN apk add --no-cache ca-certificates git tzdata wget
|
||||
COPY --from=builder /platform /platform
|
||||
COPY --from=builder /memory-plugin /memory-plugin
|
||||
COPY workspace-server/migrations /migrations
|
||||
COPY --from=templates /workspace-configs-templates /workspace-configs-templates
|
||||
COPY --from=templates /org-templates /org-templates
|
||||
@@ -41,6 +50,7 @@ RUN addgroup -g 1000 platform && adduser -u 1000 -G platform -s /bin/sh -D platf
|
||||
EXPOSE 8080
|
||||
COPY <<'ENTRY' /entrypoint.sh
|
||||
#!/bin/sh
|
||||
# Set up docker-socket group (unchanged from pre-sidecar entrypoint).
|
||||
if [ -S /var/run/docker.sock ]; then
|
||||
SOCK_GID=$(stat -c '%g' /var/run/docker.sock 2>/dev/null || stat -f '%g' /var/run/docker.sock 2>/dev/null)
|
||||
if [ -n "$SOCK_GID" ] && [ "$SOCK_GID" != "0" ]; then
|
||||
@@ -50,6 +60,61 @@ if [ -S /var/run/docker.sock ]; then
|
||||
addgroup platform root 2>/dev/null || true
|
||||
fi
|
||||
fi
|
||||
|
||||
# Memory v2 sidecar (built-in postgres plugin). Co-located with the
|
||||
# main server so operators flipping MEMORY_V2_CUTOVER=true don't need
|
||||
# to provision a separate service.
|
||||
#
|
||||
# Spawn-gating: only start the sidecar when the operator has indicated
|
||||
# they want it — either MEMORY_V2_CUTOVER=true OR MEMORY_PLUGIN_URL set.
|
||||
# Without that signal, the sidecar adds zero value (the platform's
|
||||
# wiring.go skips building the client too) but pays a real cost: the
|
||||
# plugin's first migration runs `CREATE EXTENSION vector`, which fails
|
||||
# on tenant Postgres without pgvector preinstalled and aborts container
|
||||
# boot via the 30s health gate. Caught on staging redeploy 2026-05-05.
|
||||
#
|
||||
# Env defaults (when sidecar IS spawned):
|
||||
# MEMORY_PLUGIN_DATABASE_URL = $DATABASE_URL (share existing Postgres;
|
||||
# plugin's `memory_namespaces` / `memory_records` tables coexist
|
||||
# with `agent_memories` and the rest of the platform schema —
|
||||
# no conflicts. Operator can override with a separate URL.)
|
||||
# MEMORY_PLUGIN_LISTEN_ADDR = 127.0.0.1:9100
|
||||
#
|
||||
# Set MEMORY_PLUGIN_DISABLE=1 to force-skip the sidecar even with
|
||||
# cutover env set (e.g. running the plugin externally on a separate host).
|
||||
memory_plugin_wanted=""
|
||||
if [ "$MEMORY_V2_CUTOVER" = "true" ] || [ -n "$MEMORY_PLUGIN_URL" ]; then
|
||||
memory_plugin_wanted=1
|
||||
fi
|
||||
if [ -z "$MEMORY_PLUGIN_DISABLE" ] && [ -n "$memory_plugin_wanted" ] && [ -n "$DATABASE_URL" ]; then
|
||||
: "${MEMORY_PLUGIN_DATABASE_URL:=$DATABASE_URL}"
|
||||
: "${MEMORY_PLUGIN_LISTEN_ADDR:=:9100}"
|
||||
export MEMORY_PLUGIN_DATABASE_URL MEMORY_PLUGIN_LISTEN_ADDR
|
||||
echo "memory-plugin: starting sidecar on $MEMORY_PLUGIN_LISTEN_ADDR" >&2
|
||||
# Drop privs to the platform user — the plugin doesn't need root and
|
||||
# runs unprivileged elsewhere (tenant image already starts as canvas).
|
||||
su-exec platform /memory-plugin &
|
||||
MEMORY_PLUGIN_PID=$!
|
||||
# Wait up to 30s for the plugin's /v1/health to return 200. Boot
|
||||
# failure here is fatal — better to crash-loop than to silently
|
||||
# serve cutover traffic against a dead plugin.
|
||||
health_port=${MEMORY_PLUGIN_LISTEN_ADDR#:}
|
||||
ready=0
|
||||
for _ in $(seq 1 30); do
|
||||
if wget -qO- --timeout=2 "http://localhost:${health_port}/v1/health" >/dev/null 2>&1; then
|
||||
ready=1
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
if [ "$ready" != "1" ]; then
|
||||
echo "memory-plugin: ❌ /v1/health never returned 200 after 30s — aborting boot. Check that DATABASE_URL is reachable, has the pgvector extension, and the plugin's migrations applied." >&2
|
||||
kill "$MEMORY_PLUGIN_PID" 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
echo "memory-plugin: ✅ sidecar healthy on :$health_port" >&2
|
||||
fi
|
||||
|
||||
exec su-exec platform /platform "$@"
|
||||
ENTRY
|
||||
RUN chmod +x /entrypoint.sh && apk add --no-cache su-exec
|
||||
|
||||
@@ -34,6 +34,13 @@ ARG GIT_SHA=dev
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags "-X github.com/Molecule-AI/molecule-monorepo/platform/internal/buildinfo.GitSHA=${GIT_SHA}" \
|
||||
-o /platform ./cmd/server
|
||||
# Memory v2 sidecar binary (Memory v2 #2728). Bundled so an operator
|
||||
# can activate cutover by flipping MEMORY_V2_CUTOVER=true without
|
||||
# provisioning a separate service. See entrypoint-tenant.sh for the
|
||||
# launch logic.
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build \
|
||||
-ldflags "-X github.com/Molecule-AI/molecule-monorepo/platform/internal/buildinfo.GitSHA=${GIT_SHA}" \
|
||||
-o /memory-plugin ./cmd/memory-plugin-postgres
|
||||
|
||||
# ── Stage 2: Canvas Next.js standalone ────────────────────────────────
|
||||
FROM node:20-alpine AS canvas-builder
|
||||
@@ -74,8 +81,9 @@ RUN deluser --remove-home node 2>/dev/null || true; \
|
||||
delgroup node 2>/dev/null || true; \
|
||||
addgroup -g 1000 canvas && adduser -u 1000 -G canvas -s /bin/sh -D canvas
|
||||
|
||||
# Go platform binary
|
||||
# Go platform binary + Memory v2 sidecar
|
||||
COPY --from=go-builder /platform /platform
|
||||
COPY --from=go-builder /memory-plugin /memory-plugin
|
||||
COPY workspace-server/migrations /migrations
|
||||
|
||||
# Templates + plugins (cloned from GitHub in stage 3)
|
||||
@@ -91,7 +99,7 @@ COPY --from=canvas-builder /canvas/public ./public
|
||||
|
||||
COPY workspace-server/entrypoint-tenant.sh /entrypoint.sh
|
||||
RUN chmod +x /entrypoint.sh && \
|
||||
chown -R canvas:canvas /canvas /platform /migrations
|
||||
chown -R canvas:canvas /canvas /platform /memory-plugin /migrations
|
||||
|
||||
EXPOSE 8080
|
||||
# entrypoint.sh starts as root to fix volume perms, then drops to
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestLoadConfig_DefaultListenAddrIsLoopback pins the default-bind contract.
|
||||
//
|
||||
// Why this matters: with the prior `:9100` default, the plugin listened on
|
||||
// every interface. Inside the container it didn't matter (no host port
|
||||
// mapping today), but a future change that publishes 9100 OR a cross-host
|
||||
// sidecar deploy would have exposed an unauth'd memory store. Loopback by
|
||||
// default is the least-privilege baseline; operators with a multi-host
|
||||
// topology override via MEMORY_PLUGIN_LISTEN_ADDR.
|
||||
func TestLoadConfig_DefaultListenAddrIsLoopback(t *testing.T) {
|
||||
t.Setenv("MEMORY_PLUGIN_DATABASE_URL", "postgres://stub")
|
||||
t.Setenv("MEMORY_PLUGIN_LISTEN_ADDR", "")
|
||||
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("loadConfig: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(cfg.ListenAddr, "127.0.0.1:") {
|
||||
t.Errorf("default ListenAddr must bind loopback-only, got %q "+
|
||||
"(security regression — would expose plugin on every interface)",
|
||||
cfg.ListenAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_ListenAddrEnvOverride(t *testing.T) {
|
||||
t.Setenv("MEMORY_PLUGIN_DATABASE_URL", "postgres://stub")
|
||||
t.Setenv("MEMORY_PLUGIN_LISTEN_ADDR", ":9100")
|
||||
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("loadConfig: %v", err)
|
||||
}
|
||||
if cfg.ListenAddr != ":9100" {
|
||||
t.Errorf("env override ignored: want :9100, got %q", cfg.ListenAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_MissingDatabaseURL(t *testing.T) {
|
||||
t.Setenv("MEMORY_PLUGIN_DATABASE_URL", "")
|
||||
|
||||
if _, err := loadConfig(); err == nil {
|
||||
t.Fatal("loadConfig must error when MEMORY_PLUGIN_DATABASE_URL is empty")
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sort"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -26,12 +28,28 @@ import (
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/pgplugin"
|
||||
)
|
||||
|
||||
// migrationsFS bundles the .up.sql files into the binary at build time
|
||||
// so the prebuilt image doesn't need the source tree at runtime. The
|
||||
// prior `os.ReadDir("cmd/memory-plugin-postgres/migrations")` path
|
||||
// only resolved during `go test` from the repo root — in the published
|
||||
// image the path didn't exist and boot failed after the 30s health gate
|
||||
// (caught on staging redeploy 2026-05-05 after PR #2906).
|
||||
//
|
||||
//go:embed migrations/*.up.sql
|
||||
var migrationsFS embed.FS
|
||||
|
||||
const (
|
||||
envDatabaseURL = "MEMORY_PLUGIN_DATABASE_URL"
|
||||
envListenAddr = "MEMORY_PLUGIN_LISTEN_ADDR"
|
||||
envSkipMigrate = "MEMORY_PLUGIN_SKIP_MIGRATE"
|
||||
|
||||
defaultListenAddr = ":9100"
|
||||
// Loopback-only by default (defense in depth). The platform talks to
|
||||
// the plugin over `http://localhost:9100` from the same container, so
|
||||
// binding to all interfaces would only widen the reachable surface
|
||||
// without enabling any in-design caller. Operators running the plugin
|
||||
// on a separate host override via MEMORY_PLUGIN_LISTEN_ADDR=:9100 (or
|
||||
// some other interface).
|
||||
defaultListenAddr = "127.0.0.1:9100"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -143,32 +161,71 @@ func openDB(databaseURL string) (*sql.DB, error) {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// runMigrations applies the schema migrations bundled at
|
||||
// cmd/memory-plugin-postgres/migrations/. Idempotent on repeat boot.
|
||||
// runMigrations applies the schema migrations bundled into the binary
|
||||
// via go:embed (see migrationsFS at the top of this file). Idempotent
|
||||
// on repeat boot — every migration file uses CREATE … IF NOT EXISTS.
|
||||
//
|
||||
// Implementation note: rather than embedding the full migrate engine,
|
||||
// we read the migration files at boot from a known relative path. The
|
||||
// down migrations are deliberately NOT applied here — that's a manual
|
||||
// operator action. This keeps the binary tiny and avoids dragging in
|
||||
// golang-migrate's drivers.
|
||||
// The down migrations are deliberately NOT applied here — that's a
|
||||
// manual operator action. This keeps the binary tiny and avoids
|
||||
// dragging in golang-migrate's drivers.
|
||||
//
|
||||
// MEMORY_PLUGIN_MIGRATIONS_DIR (filesystem path) is honored as an
|
||||
// override for operators who need to ship custom migrations alongside
|
||||
// the binary without rebuilding. When unset (the common case) we read
|
||||
// from the embedded FS.
|
||||
func runMigrations(db *sql.DB) error {
|
||||
// Find the migrations directory. In `go run` mode it's relative
|
||||
// to the cmd dir; in the prebuilt binary case it's expected next
|
||||
// to the binary OR via env var override.
|
||||
dir := os.Getenv("MEMORY_PLUGIN_MIGRATIONS_DIR")
|
||||
if dir == "" {
|
||||
// Best-effort: try the cwd-relative path that works for `go test`.
|
||||
dir = "cmd/memory-plugin-postgres/migrations"
|
||||
if dir := strings.TrimSpace(os.Getenv("MEMORY_PLUGIN_MIGRATIONS_DIR")); dir != "" {
|
||||
return runMigrationsFromDisk(db, dir)
|
||||
}
|
||||
entries, err := os.ReadDir(dir)
|
||||
return runMigrationsFromEmbed(db)
|
||||
}
|
||||
|
||||
// runMigrationsFromEmbed applies the *.up.sql files bundled into the
|
||||
// binary at build time. Order is alphabetical (matches the on-disk
|
||||
// behavior of os.ReadDir on Linux for the same set of names).
|
||||
func runMigrationsFromEmbed(db *sql.DB) error {
|
||||
entries, err := migrationsFS.ReadDir("migrations")
|
||||
if err != nil {
|
||||
return fmt.Errorf("read migrations dir %q: %w", dir, err)
|
||||
return fmt.Errorf("read embedded migrations: %w", err)
|
||||
}
|
||||
names := make([]string, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
|
||||
continue
|
||||
}
|
||||
path := dir + "/" + e.Name()
|
||||
names = append(names, e.Name())
|
||||
}
|
||||
sort.Strings(names)
|
||||
for _, name := range names {
|
||||
data, err := migrationsFS.ReadFile("migrations/" + name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read embedded %q: %w", name, err)
|
||||
}
|
||||
if _, err := db.Exec(string(data)); err != nil {
|
||||
return fmt.Errorf("apply %q: %w", name, err)
|
||||
}
|
||||
log.Printf("applied embedded migration %s", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runMigrationsFromDisk preserves the legacy filesystem-path mode for
|
||||
// operator-supplied custom migrations.
|
||||
func runMigrationsFromDisk(db *sql.DB, dir string) error {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read migrations dir %q: %w", dir, err)
|
||||
}
|
||||
names := make([]string, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
|
||||
continue
|
||||
}
|
||||
names = append(names, e.Name())
|
||||
}
|
||||
sort.Strings(names)
|
||||
for _, name := range names {
|
||||
path := dir + "/" + name
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %q: %w", path, err)
|
||||
@@ -176,7 +233,7 @@ func runMigrations(db *sql.DB) error {
|
||||
if _, err := db.Exec(string(data)); err != nil {
|
||||
return fmt.Errorf("apply %q: %w", path, err)
|
||||
}
|
||||
log.Printf("applied migration %s", e.Name())
|
||||
log.Printf("applied disk migration %s (from %s)", name, dir)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestMigrationsEmbedded_ContainsCreateTable pins that the migrations
|
||||
// are bundled into the binary at build time, NOT loaded from a
|
||||
// filesystem path that doesn't exist at runtime in the published image.
|
||||
//
|
||||
// Pre-fix: PR #2906 shipped the binary without the migrations dir;
|
||||
// `os.ReadDir("cmd/memory-plugin-postgres/migrations")` errored on every
|
||||
// tenant boot, the 30s health gate aborted the container, and the
|
||||
// staging redeploy fleet job marked all tenants as failed. Embedding
|
||||
// the migrations into the binary removes the runtime path entirely.
|
||||
func TestMigrationsEmbedded_ContainsCreateTable(t *testing.T) {
|
||||
entries, err := migrationsFS.ReadDir("migrations")
|
||||
if err != nil {
|
||||
t.Fatalf("embedded migrations dir unreadable: %v", err)
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
t.Fatal("embedded migrations dir is empty — go:embed pattern matched no files")
|
||||
}
|
||||
|
||||
var seenUp bool
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
|
||||
continue
|
||||
}
|
||||
seenUp = true
|
||||
data, err := migrationsFS.ReadFile("migrations/" + e.Name())
|
||||
if err != nil {
|
||||
t.Errorf("read embedded %q: %v", e.Name(), err)
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(string(data), "CREATE TABLE") {
|
||||
t.Errorf("embedded %q has no CREATE TABLE — wrong file embedded?", e.Name())
|
||||
}
|
||||
}
|
||||
if !seenUp {
|
||||
t.Fatal("no *.up.sql in embedded migrations — runtime would have no schema to apply")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRunMigrationsFromEmbed_OrderingIsAlphabetic pins that we apply
|
||||
// migrations in deterministic alphabetical order, not in whatever
|
||||
// arbitrary order migrationsFS.ReadDir happens to return. With one
|
||||
// migration today this is moot, but a future second migration ('002_…')
|
||||
// MUST run after '001_…' or the schema is broken.
|
||||
//
|
||||
// We can't easily exercise db.Exec here (no test DB); instead pin the
|
||||
// sort step on the directory listing itself.
|
||||
func TestRunMigrationsFromEmbed_OrderingIsAlphabetic(t *testing.T) {
|
||||
entries, err := migrationsFS.ReadDir("migrations")
|
||||
if err != nil {
|
||||
t.Fatalf("embedded migrations dir unreadable: %v", err)
|
||||
}
|
||||
var names []string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
|
||||
continue
|
||||
}
|
||||
names = append(names, e.Name())
|
||||
}
|
||||
for i := 1; i < len(names); i++ {
|
||||
if names[i-1] > names[i] {
|
||||
t.Errorf("ReadDir returned non-sorted names; runMigrationsFromEmbed must sort. "+
|
||||
"Got %q before %q", names[i-1], names[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,51 @@ cd /canvas
|
||||
PORT=3000 HOSTNAME=0.0.0.0 node server.js &
|
||||
CANVAS_PID=$!
|
||||
|
||||
# Memory v2 sidecar (built-in postgres plugin). See Dockerfile entrypoint
|
||||
# comment for rationale.
|
||||
#
|
||||
# Spawn-gating: only start the sidecar when the operator has indicated
|
||||
# they want it (MEMORY_V2_CUTOVER=true OR MEMORY_PLUGIN_URL set).
|
||||
# Without that signal, the sidecar adds zero value and risks aborting
|
||||
# tenant boot via the 30s health gate when the tenant Postgres lacks
|
||||
# pgvector. Caught on staging redeploy 2026-05-05:
|
||||
# pq: extension "vector" is not available
|
||||
#
|
||||
# Defaults (when sidecar IS spawned): MEMORY_PLUGIN_DATABASE_URL
|
||||
# falls back to the tenant's DATABASE_URL.
|
||||
MEMORY_PLUGIN_PID=""
|
||||
memory_plugin_wanted=""
|
||||
if [ "$MEMORY_V2_CUTOVER" = "true" ] || [ -n "$MEMORY_PLUGIN_URL" ]; then
|
||||
memory_plugin_wanted=1
|
||||
fi
|
||||
if [ -z "$MEMORY_PLUGIN_DISABLE" ] && [ -n "$memory_plugin_wanted" ] && [ -n "$DATABASE_URL" ]; then
|
||||
: "${MEMORY_PLUGIN_DATABASE_URL:=$DATABASE_URL}"
|
||||
: "${MEMORY_PLUGIN_LISTEN_ADDR:=:9100}"
|
||||
export MEMORY_PLUGIN_DATABASE_URL MEMORY_PLUGIN_LISTEN_ADDR
|
||||
echo "memory-plugin: starting sidecar on $MEMORY_PLUGIN_LISTEN_ADDR" >&2
|
||||
/memory-plugin &
|
||||
MEMORY_PLUGIN_PID=$!
|
||||
# Wait up to 30s for /v1/health. Boot failure is fatal so a misconfigured
|
||||
# tenant crash-loops instead of silently serving cutover traffic against
|
||||
# a dead plugin.
|
||||
health_port=${MEMORY_PLUGIN_LISTEN_ADDR#:}
|
||||
ready=0
|
||||
for _ in $(seq 1 30); do
|
||||
if wget -qO- --timeout=2 "http://localhost:${health_port}/v1/health" >/dev/null 2>&1; then
|
||||
ready=1
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
if [ "$ready" != "1" ]; then
|
||||
echo "memory-plugin: ❌ /v1/health never returned 200 after 30s — aborting boot. Check DATABASE_URL reachability + pgvector extension + migrations." >&2
|
||||
kill "$MEMORY_PLUGIN_PID" 2>/dev/null || true
|
||||
kill "$CANVAS_PID" 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
echo "memory-plugin: ✅ sidecar healthy on :$health_port" >&2
|
||||
fi
|
||||
|
||||
# Start Go platform in foreground-ish (we trap signals)
|
||||
# CANVAS_PROXY_URL tells the platform to proxy unmatched routes to Canvas.
|
||||
# CONTAINER_BACKEND: empty = Docker (default for self-hosted/local).
|
||||
@@ -29,15 +74,20 @@ cd /
|
||||
/platform &
|
||||
PLATFORM_PID=$!
|
||||
|
||||
# If either process exits, kill the other
|
||||
# If any process exits, kill the others
|
||||
cleanup() {
|
||||
kill $CANVAS_PID 2>/dev/null || true
|
||||
kill $PLATFORM_PID 2>/dev/null || true
|
||||
[ -n "$MEMORY_PLUGIN_PID" ] && kill $MEMORY_PLUGIN_PID 2>/dev/null || true
|
||||
}
|
||||
trap cleanup EXIT SIGTERM SIGINT
|
||||
|
||||
# Wait for either to exit — whichever exits first triggers cleanup
|
||||
wait -n $CANVAS_PID $PLATFORM_PID
|
||||
# Wait for any to exit — whichever exits first triggers cleanup
|
||||
if [ -n "$MEMORY_PLUGIN_PID" ]; then
|
||||
wait -n $CANVAS_PID $PLATFORM_PID $MEMORY_PLUGIN_PID
|
||||
else
|
||||
wait -n $CANVAS_PID $PLATFORM_PID
|
||||
fi
|
||||
EXIT_CODE=$?
|
||||
cleanup
|
||||
exit $EXIT_CODE
|
||||
|
||||
@@ -201,7 +201,7 @@ func TestPollUpload_HappyPath_OneFile_StagesAndLogs(t *testing.T) {
|
||||
expectActivityInsert(mock)
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"report.pdf": []byte("PDF-bytes")})
|
||||
@@ -259,7 +259,7 @@ func TestPollUpload_MultipleFiles_AllStagedAndLogged(t *testing.T) {
|
||||
expectActivityInsert(mock)
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{
|
||||
@@ -297,7 +297,7 @@ func TestPollUpload_PushModeFallsThroughToForward(t *testing.T) {
|
||||
// URL empty + mode=push → 503 (no inbound secret check needed).
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x": []byte("data")})
|
||||
@@ -321,7 +321,7 @@ func TestPollUpload_NotConfigured_FallsThrough(t *testing.T) {
|
||||
wsID := "33333333-2222-3333-4444-555555555555"
|
||||
expectURLAndMode(mock, wsID, "", "poll") // resolveWorkspaceForwardCreds emits 422
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
// No WithPendingUploads — pendingUploads is nil.
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x": []byte("data")})
|
||||
@@ -342,7 +342,7 @@ func TestPollUpload_WorkspaceMissing_404(t *testing.T) {
|
||||
wsID := "44444444-2222-3333-4444-555555555555"
|
||||
expectPollDeliveryModeMissing(mock, wsID)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(newInMemStorage(), nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x": []byte("d")})
|
||||
@@ -362,7 +362,7 @@ func TestPollUpload_DeliveryModeLookupDBError_500(t *testing.T) {
|
||||
mock.ExpectQuery(`SELECT delivery_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).WillReturnError(errors.New("connection lost"))
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(newInMemStorage(), nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x": []byte("d")})
|
||||
@@ -382,7 +382,7 @@ func TestPollUpload_NoFilesField_400(t *testing.T) {
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
// Multipart with a non-files field — no actual files.
|
||||
@@ -407,7 +407,7 @@ func TestPollUpload_MalformedMultipart_400(t *testing.T) {
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
// Body that doesn't match the boundary in Content-Type.
|
||||
@@ -428,7 +428,7 @@ func TestPollUpload_StorageError_500(t *testing.T) {
|
||||
|
||||
store := newInMemStorage()
|
||||
store.putErr = errors.New("disk full")
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x.bin": []byte("data")})
|
||||
@@ -449,7 +449,7 @@ func TestPollUpload_StorageTooLarge_413(t *testing.T) {
|
||||
|
||||
store := newInMemStorage()
|
||||
store.putErr = pendinguploads.ErrTooLarge
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x.bin": []byte("data")})
|
||||
@@ -469,7 +469,7 @@ func TestPollUpload_TooManyFiles_400(t *testing.T) {
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
// 65 files — over the per-batch cap.
|
||||
@@ -504,7 +504,7 @@ func TestPollUpload_NullDeliveryMode_TreatedAsPush(t *testing.T) {
|
||||
expectURLAndMode(mock, wsID, "", "")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x.bin": []byte("data")})
|
||||
@@ -537,7 +537,7 @@ func TestPollUpload_PerFileCapPreStorage_413(t *testing.T) {
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
// 25 MB + 1 byte. Single file, large enough to trip the early
|
||||
@@ -572,7 +572,7 @@ func TestPollUpload_SanitizesFilenameInResponse(t *testing.T) {
|
||||
expectActivityInsert(mock)
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"hello world!.pdf": []byte("data")})
|
||||
@@ -616,7 +616,7 @@ func TestPollUpload_AtomicRollbackOnSecondFileTooLarge(t *testing.T) {
|
||||
expectPollDeliveryMode(mock, wsID, "poll")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
// Two files: first OK, second over the per-file cap. Pre-validation
|
||||
@@ -653,7 +653,7 @@ func TestPollUpload_AtomicRollbackOnPutBatchError(t *testing.T) {
|
||||
|
||||
store := newInMemStorage()
|
||||
store.putErr = errors.New("db down mid-batch")
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{
|
||||
@@ -734,7 +734,7 @@ func TestPollUpload_ActivityRowDiscriminator(t *testing.T) {
|
||||
expectActivityInsertWithTypeAndMethod(mock, wsID, "a2a_receive", "chat_upload_receive")
|
||||
|
||||
store := newInMemStorage()
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
|
||||
WithPendingUploads(store, nil)
|
||||
|
||||
body, ct := pollUploadFixture(t, map[string][]byte{"x.pdf": []byte("xx")})
|
||||
|
||||
@@ -105,7 +105,7 @@ func TestChatUpload_InvalidWorkspaceID(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
|
||||
c, w := makeUploadRequest(t, "not-a-uuid", &bytes.Buffer{}, "")
|
||||
h.Upload(c)
|
||||
@@ -122,7 +122,7 @@ func TestChatUpload_WorkspaceNotInDB(t *testing.T) {
|
||||
wsID := "00000000-0000-0000-0000-000000000099"
|
||||
expectURLMissing(mock, wsID)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -166,7 +166,7 @@ func TestChatUpload_NoInboundSecret_LazyHeal(t *testing.T) {
|
||||
WithArgs(sqlmock.AnyArg(), wsID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -203,7 +203,7 @@ func TestChatUpload_NoInboundSecret_LazyHealFailure(t *testing.T) {
|
||||
WithArgs(sqlmock.AnyArg(), wsID).
|
||||
WillReturnError(sql.ErrConnDone) // mint fails
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -231,7 +231,7 @@ func TestChatUpload_NoURL(t *testing.T) {
|
||||
wsID := "00000000-0000-0000-0000-000000000042"
|
||||
expectURLAndMode(mock, wsID, "", "push")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -256,7 +256,7 @@ func TestChatUpload_PollModeEmptyURL(t *testing.T) {
|
||||
wsID := "00000000-0000-0000-0000-000000000099"
|
||||
expectURLAndMode(mock, wsID, "", "poll")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -286,7 +286,7 @@ func TestChatUpload_NullModeEmptyURL(t *testing.T) {
|
||||
wsID := "30ba7f0b-b303-4a20-aefe-3a4a675b8aa4" // user's "mac laptop"
|
||||
expectURLNullMode(mock, wsID, "")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -338,7 +338,7 @@ func TestChatUpload_ForwardsToWorkspace_HappyPath(t *testing.T) {
|
||||
expectURL(mock, wsID, srv.URL)
|
||||
expectInboundSecret(mock, wsID, "super-secret-123")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -380,7 +380,7 @@ func TestChatUpload_ForwardsErrorStatusUnchanged(t *testing.T) {
|
||||
expectURL(mock, wsID, srv.URL)
|
||||
expectInboundSecret(mock, wsID, "tok")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -402,7 +402,7 @@ func TestChatUpload_WorkspaceUnreachable(t *testing.T) {
|
||||
expectURL(mock, wsID, "http://127.0.0.1:1")
|
||||
expectInboundSecret(mock, wsID, "tok")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
body, ct := uploadFixture(t)
|
||||
c, w := makeUploadRequest(t, wsID, body, ct)
|
||||
h.Upload(c)
|
||||
@@ -418,7 +418,7 @@ func TestChatDownload_InvalidPath(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
|
||||
cases := []struct {
|
||||
name, path, wantSubstr string
|
||||
@@ -507,7 +507,7 @@ func TestChatDownload_WorkspaceNotInDB(t *testing.T) {
|
||||
WithArgs(wsID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
c, w := makeDownloadRequest(t, wsID, "/workspace/foo.txt")
|
||||
h.Download(c)
|
||||
|
||||
@@ -533,7 +533,7 @@ func TestChatDownload_NoInboundSecret_LazyHeal(t *testing.T) {
|
||||
WithArgs(sqlmock.AnyArg(), wsID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
c, w := makeDownloadRequest(t, wsID, "/workspace/foo.txt")
|
||||
h.Download(c)
|
||||
|
||||
@@ -559,7 +559,7 @@ func TestChatDownload_NoInboundSecret_LazyHealFailure(t *testing.T) {
|
||||
WithArgs(sqlmock.AnyArg(), wsID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
c, w := makeDownloadRequest(t, wsID, "/workspace/foo.txt")
|
||||
h.Download(c)
|
||||
|
||||
@@ -592,7 +592,7 @@ func TestChatDownload_ForwardsToWorkspace_HappyPath(t *testing.T) {
|
||||
expectURL(mock, wsID, srv.URL)
|
||||
expectInboundSecret(mock, wsID, "the-secret")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
c, w := makeDownloadRequest(t, wsID, "/workspace/report.txt")
|
||||
h.Download(c)
|
||||
|
||||
@@ -634,7 +634,7 @@ func TestChatDownload_404FromWorkspacePropagated(t *testing.T) {
|
||||
expectURL(mock, wsID, srv.URL)
|
||||
expectInboundSecret(mock, wsID, "tok")
|
||||
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil))
|
||||
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil))
|
||||
c, w := makeDownloadRequest(t, wsID, "/workspace/missing.txt")
|
||||
h.Download(c)
|
||||
|
||||
|
||||
@@ -0,0 +1,468 @@
|
||||
package handlers
|
||||
|
||||
// class1_ast_gate_test.go — generic Class 1 leak gate per #2867 PR-A.
|
||||
//
|
||||
// What this gate prevents:
|
||||
// The tenant-hongming leak class — a handler iterates a YAML-derived
|
||||
// slice (ws.Children, sub_workspaces, etc.) and calls
|
||||
// `INSERT INTO workspaces` inside the loop body without first
|
||||
// checking whether a workspace with the same (parent_id, name) is
|
||||
// already there. Each call to such a handler doubles the tree.
|
||||
//
|
||||
// Why this is broader than TestCreateWorkspaceTree_CallsLookupBeforeInsert:
|
||||
// The existing gate is hard-coded to org_import.go's createWorkspaceTree.
|
||||
// That catches the specific function that triggered the original
|
||||
// incident — but a future handler written from scratch in a different
|
||||
// file would not be covered. This gate walks every production handler
|
||||
// .go file and applies a structural rule that does not depend on
|
||||
// function or file names.
|
||||
//
|
||||
// The rule (verbatim from #2867 PR-A):
|
||||
//
|
||||
// "No handler in handlers/ may iterate a slice (any RangeStmt) AND
|
||||
// call INSERT INTO workspaces inside the loop body without a
|
||||
// preceding SELECT id FROM workspaces WHERE name=$1 AND parent_id IS
|
||||
// NOT DISTINCT FROM $2 in the same function (== a lookupExistingChild
|
||||
// call, OR an ON CONFLICT clause baked into the same INSERT, OR an
|
||||
// explicit allowlist annotation)."
|
||||
//
|
||||
// Allowlist mechanism: a function whose body contains the exact comment
|
||||
// string `// class1-gate: idempotent-by-design` is treated as safe.
|
||||
// Use this only after writing a unit test that pins WHY the function
|
||||
// is safe. The annotation is intentionally awkward to type — it should
|
||||
// be rare.
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// reINSERTWorkspaces matches the exact statement shape we care about.
|
||||
// Tightened (vs bytes.Index "INSERT INTO workspaces") so the audit
|
||||
// table `workspaces_audit` literal — or any other lookalike — does not
|
||||
// false-positive trigger this gate. The same regex is used in the
|
||||
// existing createWorkspaceTree gate (workspaces_insert_allowlist_test.go)
|
||||
// — keep them in sync if either changes.
|
||||
var reINSERTWorkspaces = regexp.MustCompile(`(?m)^\s*INSERT INTO workspaces\s*\(`)
|
||||
|
||||
// reONCONFLICT matches ON CONFLICT clauses anywhere in the same SQL
|
||||
// literal. An UPSERT (INSERT ... ON CONFLICT ... DO UPDATE) is
|
||||
// idempotent by definition, so the gate exempts it.
|
||||
var reONCONFLICT = regexp.MustCompile(`(?i)\bON CONFLICT\b`)
|
||||
|
||||
// gateAllowlistComment is the magic comment a function author writes
|
||||
// to opt out of this gate. Forces an explicit decision.
|
||||
const gateAllowlistComment = "// class1-gate: idempotent-by-design"
|
||||
|
||||
// preflightCallNames are function names whose presence in a function
|
||||
// body counts as "did a SELECT-by-(parent_id, name) preflight". Add
|
||||
// new names here as new preflight helpers are introduced. Keep the
|
||||
// list TIGHT — any sloppy addition weakens the gate.
|
||||
var preflightCallNames = map[string]bool{
|
||||
"lookupExistingChild": true,
|
||||
}
|
||||
|
||||
// TestClass1_NoUnpreflightedInsertInsideRange walks every production
|
||||
// .go file in this package, parses the AST, and fails the test if any
|
||||
// FuncDecl violates the rule above.
|
||||
//
|
||||
// Failure message must include: file path, function name, line of
|
||||
// the offending INSERT, line of the enclosing range, and a hint at
|
||||
// the three escape hatches (preflight call, ON CONFLICT, allowlist
|
||||
// comment).
|
||||
func TestClass1_NoUnpreflightedInsertInsideRange(t *testing.T) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("getwd: %v", err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(wd)
|
||||
if err != nil {
|
||||
t.Fatalf("readdir %s: %v", wd, err)
|
||||
}
|
||||
|
||||
type violation struct {
|
||||
file string
|
||||
fn string
|
||||
insertLine int
|
||||
rangeLine int
|
||||
}
|
||||
var violations []violation
|
||||
scanned := 0
|
||||
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
if e.IsDir() || !strings.HasSuffix(name, ".go") {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(name, "_test.go") {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(wd, name)
|
||||
src, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("read %s: %v", path, err)
|
||||
}
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, name, src, parser.ParseComments)
|
||||
if err != nil {
|
||||
t.Fatalf("parse %s: %v", path, err)
|
||||
}
|
||||
scanned++
|
||||
|
||||
// Walk every function declaration and apply the rule.
|
||||
for _, decl := range file.Decls {
|
||||
fd, ok := decl.(*ast.FuncDecl)
|
||||
if !ok || fd.Body == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Allowlist: skip if the function body contains the magic
|
||||
// comment. We check via the source range of the function
|
||||
// — comments inside the body are in file.Comments and
|
||||
// must overlap the function's Pos/End range.
|
||||
if functionHasAllowlistComment(file, fd) {
|
||||
continue
|
||||
}
|
||||
|
||||
// First pass: locate every INSERT INTO workspaces literal
|
||||
// in this function. We treat each such literal as a
|
||||
// candidate violation and try to clear it via the rules.
|
||||
candidates := findInsertWorkspacesLiterals(fd, src, fset)
|
||||
if len(candidates) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Has the function called a preflight helper? Single
|
||||
// pass — if any preflight name appears, every INSERT in
|
||||
// the function is considered preflighted. This is more
|
||||
// permissive than position-aware (preflight could be
|
||||
// AFTER the INSERT and still satisfy the gate), but the
|
||||
// existing org_import.go gate already pins the position
|
||||
// invariant for createWorkspaceTree, and a function that
|
||||
// preflights AFTER inserting would fail the position
|
||||
// gate in a separate test.
|
||||
hasPreflight := functionCallsAny(fd, preflightCallNames)
|
||||
|
||||
for _, c := range candidates {
|
||||
if c.hasONCONFLICT {
|
||||
continue
|
||||
}
|
||||
if hasPreflight {
|
||||
continue
|
||||
}
|
||||
if c.enclosingRangeLine == 0 {
|
||||
// INSERT not inside any RangeStmt — single-shot,
|
||||
// not the bug pattern.
|
||||
continue
|
||||
}
|
||||
violations = append(violations, violation{
|
||||
file: name,
|
||||
fn: fd.Name.Name,
|
||||
insertLine: c.insertLine,
|
||||
rangeLine: c.enclosingRangeLine,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if scanned == 0 {
|
||||
t.Fatal("scanned 0 .go files — wrong working directory? gate would always pass")
|
||||
}
|
||||
|
||||
if len(violations) > 0 {
|
||||
// Stable sort so the failure message is deterministic across
|
||||
// reruns.
|
||||
sort.Slice(violations, func(i, j int) bool {
|
||||
if violations[i].file != violations[j].file {
|
||||
return violations[i].file < violations[j].file
|
||||
}
|
||||
return violations[i].insertLine < violations[j].insertLine
|
||||
})
|
||||
var b strings.Builder
|
||||
b.WriteString("Class 1 leak gate (#2867 PR-A) — these handler functions iterate a slice and INSERT INTO workspaces inside the loop body without a (parent_id, name) preflight.\n\n")
|
||||
b.WriteString("This is the bug shape that triggered the tenant-hongming leak (TeamHandler.Expand re-inserting the entire sub_workspaces tree on every call). To fix any reported violation, choose ONE of:\n")
|
||||
b.WriteString(" 1. Call h.lookupExistingChild(ctx, name, parentID) before the INSERT and skip the INSERT when it returns existing=true. (preferred)\n")
|
||||
b.WriteString(" 2. Use INSERT ... ON CONFLICT ... DO ... (idempotent UPSERT, like registry.go).\n")
|
||||
b.WriteString(" 3. Annotate the function with a `// class1-gate: idempotent-by-design` comment AND a unit test that pins why the function is structurally idempotent. (rare; require code review)\n\n")
|
||||
b.WriteString("Violations:\n")
|
||||
for _, v := range violations {
|
||||
b.WriteString(" - ")
|
||||
b.WriteString(v.file)
|
||||
b.WriteString(":")
|
||||
b.WriteString(itoa(v.insertLine))
|
||||
b.WriteString(" — function ")
|
||||
b.WriteString(v.fn)
|
||||
b.WriteString("() INSERTs inside RangeStmt at line ")
|
||||
b.WriteString(itoa(v.rangeLine))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
t.Fatal(b.String())
|
||||
}
|
||||
}
|
||||
|
||||
func itoa(n int) string {
|
||||
// Avoid strconv import for one call site — keeps the test focused.
|
||||
if n == 0 {
|
||||
return "0"
|
||||
}
|
||||
neg := n < 0
|
||||
if neg {
|
||||
n = -n
|
||||
}
|
||||
var buf [20]byte
|
||||
i := len(buf)
|
||||
for n > 0 {
|
||||
i--
|
||||
buf[i] = byte('0' + n%10)
|
||||
n /= 10
|
||||
}
|
||||
if neg {
|
||||
i--
|
||||
buf[i] = '-'
|
||||
}
|
||||
return string(buf[i:])
|
||||
}
|
||||
|
||||
// candidateInsert holds the per-INSERT facts needed to decide whether
|
||||
// the gate fires.
|
||||
type candidateInsert struct {
|
||||
insertLine int
|
||||
hasONCONFLICT bool
|
||||
enclosingRangeLine int // 0 means not inside any range
|
||||
}
|
||||
|
||||
// findInsertWorkspacesLiterals walks fd's body and returns one
|
||||
// candidateInsert per INSERT INTO workspaces string literal.
|
||||
//
|
||||
// Position-based detection: collect every RangeStmt's body span first,
|
||||
// then for each INSERT literal check if its position is inside any
|
||||
// span. ast.Inspect's nil-call ordering does NOT give per-node pop
|
||||
// semantics, so a stack-based approach against ast.Inspect would
|
||||
// silently miscount. Position spans are deterministic and easy to
|
||||
// reason about.
|
||||
func findInsertWorkspacesLiterals(fd *ast.FuncDecl, src []byte, fset *token.FileSet) []candidateInsert {
|
||||
var out []candidateInsert
|
||||
|
||||
type span struct{ start, end token.Pos }
|
||||
var ranges []span
|
||||
ast.Inspect(fd.Body, func(n ast.Node) bool {
|
||||
rs, ok := n.(*ast.RangeStmt)
|
||||
if !ok || rs.Body == nil {
|
||||
return true
|
||||
}
|
||||
ranges = append(ranges, span{rs.Body.Lbrace, rs.Body.Rbrace})
|
||||
return true
|
||||
})
|
||||
|
||||
enclosingRangeLineFor := func(p token.Pos) int {
|
||||
// Pick the innermost enclosing range — i.e., the one with the
|
||||
// largest start that still covers p. Innermost is the one
|
||||
// whose body actually contains the INSERT, which is the line
|
||||
// most useful in a violation message.
|
||||
bestStart := token.NoPos
|
||||
bestLine := 0
|
||||
for _, s := range ranges {
|
||||
if p > s.start && p < s.end && s.start > bestStart {
|
||||
bestStart = s.start
|
||||
bestLine = fset.Position(s.start).Line
|
||||
}
|
||||
}
|
||||
return bestLine
|
||||
}
|
||||
|
||||
ast.Inspect(fd.Body, func(n ast.Node) bool {
|
||||
bl, ok := n.(*ast.BasicLit)
|
||||
if !ok || bl.Kind != token.STRING {
|
||||
return true
|
||||
}
|
||||
// Strip surrounding backticks/quotes — value includes them.
|
||||
lit := bl.Value
|
||||
if len(lit) >= 2 {
|
||||
lit = lit[1 : len(lit)-1]
|
||||
}
|
||||
if !reINSERTWorkspaces.MatchString(lit) {
|
||||
return true
|
||||
}
|
||||
out = append(out, candidateInsert{
|
||||
insertLine: fset.Position(bl.Pos()).Line,
|
||||
hasONCONFLICT: reONCONFLICT.MatchString(lit),
|
||||
enclosingRangeLine: enclosingRangeLineFor(bl.Pos()),
|
||||
})
|
||||
return true
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
// functionCallsAny returns true if any CallExpr in fd's body has a
|
||||
// function name (either a SelectorExpr Sel.Name or an Ident name)
|
||||
// matching a key in names.
|
||||
func functionCallsAny(fd *ast.FuncDecl, names map[string]bool) bool {
|
||||
found := false
|
||||
ast.Inspect(fd.Body, func(n ast.Node) bool {
|
||||
if found {
|
||||
return false
|
||||
}
|
||||
ce, ok := n.(*ast.CallExpr)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
switch fun := ce.Fun.(type) {
|
||||
case *ast.Ident:
|
||||
if names[fun.Name] {
|
||||
found = true
|
||||
return false
|
||||
}
|
||||
case *ast.SelectorExpr:
|
||||
if names[fun.Sel.Name] {
|
||||
found = true
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return found
|
||||
}
|
||||
|
||||
// functionHasAllowlistComment returns true if the function body
|
||||
// (between fd.Body.Lbrace and fd.Body.Rbrace) contains a comment
|
||||
// equal to gateAllowlistComment.
|
||||
func functionHasAllowlistComment(file *ast.File, fd *ast.FuncDecl) bool {
|
||||
if fd.Body == nil {
|
||||
return false
|
||||
}
|
||||
start := fd.Body.Lbrace
|
||||
end := fd.Body.Rbrace
|
||||
for _, cg := range file.Comments {
|
||||
for _, c := range cg.List {
|
||||
if c.Pos() < start || c.Pos() > end {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(c.Text) == gateAllowlistComment {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TestClass1_GateFiresOnSyntheticBuggySource — proves the gate actually
|
||||
// catches the bug shape it's named after. Without this, a regression
|
||||
// to "always pass" would not be noticed until the leak shipped again.
|
||||
// Per memory feedback_assert_exact_not_substring.md: tighten the test
|
||||
// + verify it FAILS on old-shape source before merging.
|
||||
func TestClass1_GateFiresOnSyntheticBuggySource(t *testing.T) {
|
||||
const buggySrc = `package handlers
|
||||
|
||||
import "context"
|
||||
|
||||
type fakeDB struct{}
|
||||
func (fakeDB) ExecContext(ctx context.Context, sql string, args ...interface{}) {}
|
||||
|
||||
func buggyExpand(db fakeDB, ctx context.Context, children []string) {
|
||||
for _, child := range children {
|
||||
// Bug shape: INSERT inside the range body, no preflight.
|
||||
db.ExecContext(ctx, ` + "`INSERT INTO workspaces (id, name) VALUES ($1, $2)`" + `, "x", child)
|
||||
}
|
||||
}
|
||||
`
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, "buggy.go", buggySrc, parser.ParseComments)
|
||||
if err != nil {
|
||||
t.Fatalf("parse synthetic source: %v", err)
|
||||
}
|
||||
for _, decl := range file.Decls {
|
||||
fd, ok := decl.(*ast.FuncDecl)
|
||||
if !ok || fd.Name.Name != "buggyExpand" {
|
||||
continue
|
||||
}
|
||||
candidates := findInsertWorkspacesLiterals(fd, []byte(buggySrc), fset)
|
||||
if len(candidates) != 1 {
|
||||
t.Fatalf("expected 1 INSERT literal, got %d", len(candidates))
|
||||
}
|
||||
c := candidates[0]
|
||||
if c.enclosingRangeLine == 0 {
|
||||
t.Errorf("synthetic INSERT inside `for _, child := range` should be detected as enclosed by range, got enclosingRangeLine=0 — gate would miss the bug shape")
|
||||
}
|
||||
if c.hasONCONFLICT {
|
||||
t.Errorf("synthetic INSERT has no ON CONFLICT, gate falsely treated it as idempotent")
|
||||
}
|
||||
if functionCallsAny(fd, preflightCallNames) {
|
||||
t.Errorf("synthetic function does not call lookupExistingChild — gate falsely treated it as preflighted")
|
||||
}
|
||||
// All three guards say the gate WOULD fire. Pass.
|
||||
return
|
||||
}
|
||||
t.Fatal("buggyExpand FuncDecl not found in synthetic source")
|
||||
}
|
||||
|
||||
// TestClass1_GateAllowsONCONFLICT — pins that an INSERT with ON
|
||||
// CONFLICT inside a range body is NOT flagged. registry.go's
|
||||
// upsert pattern is the prod example.
|
||||
func TestClass1_GateAllowsONCONFLICT(t *testing.T) {
|
||||
const safeSrc = `package handlers
|
||||
|
||||
import "context"
|
||||
|
||||
type fakeDB struct{}
|
||||
func (fakeDB) ExecContext(ctx context.Context, sql string, args ...interface{}) {}
|
||||
|
||||
func upsertLoop(db fakeDB, ctx context.Context, children []string) {
|
||||
for _, child := range children {
|
||||
db.ExecContext(ctx, ` + "`INSERT INTO workspaces (id, name) VALUES ($1, $2) ON CONFLICT (id) DO UPDATE SET name = $2`" + `, "x", child)
|
||||
}
|
||||
}
|
||||
`
|
||||
fset := token.NewFileSet()
|
||||
file, _ := parser.ParseFile(fset, "safe.go", safeSrc, parser.ParseComments)
|
||||
for _, decl := range file.Decls {
|
||||
fd, ok := decl.(*ast.FuncDecl)
|
||||
if !ok || fd.Name.Name != "upsertLoop" {
|
||||
continue
|
||||
}
|
||||
candidates := findInsertWorkspacesLiterals(fd, []byte(safeSrc), fset)
|
||||
if len(candidates) != 1 {
|
||||
t.Fatalf("expected 1 candidate, got %d", len(candidates))
|
||||
}
|
||||
if !candidates[0].hasONCONFLICT {
|
||||
t.Errorf("ON CONFLICT clause should be detected, was missed — gate would falsely flag idempotent UPSERTs")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestClass1_GateAllowsAllowlistAnnotation — pins the escape hatch
|
||||
// works. Annotated functions are skipped at the FuncDecl level.
|
||||
func TestClass1_GateAllowsAllowlistAnnotation(t *testing.T) {
|
||||
const annotatedSrc = `package handlers
|
||||
|
||||
import "context"
|
||||
|
||||
type fakeDB struct{}
|
||||
func (fakeDB) ExecContext(ctx context.Context, sql string, args ...interface{}) {}
|
||||
|
||||
func intentionallyUnpreflighted(db fakeDB, ctx context.Context, children []string) {
|
||||
// class1-gate: idempotent-by-design
|
||||
for _, child := range children {
|
||||
db.ExecContext(ctx, ` + "`INSERT INTO workspaces (id, name) VALUES ($1, $2)`" + `, "x", child)
|
||||
}
|
||||
}
|
||||
`
|
||||
fset := token.NewFileSet()
|
||||
file, _ := parser.ParseFile(fset, "annotated.go", annotatedSrc, parser.ParseComments)
|
||||
for _, decl := range file.Decls {
|
||||
fd, ok := decl.(*ast.FuncDecl)
|
||||
if !ok || fd.Name.Name != "intentionallyUnpreflighted" {
|
||||
continue
|
||||
}
|
||||
if !functionHasAllowlistComment(file, fd) {
|
||||
t.Error("allowlist comment should be detected for the intentionallyUnpreflighted function — escape hatch not working")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -198,6 +198,13 @@ const externalUniversalMcpTemplate = `# Universal MCP — standalone register +
|
||||
# Pair with the Claude Code or Python SDK tab if your runtime needs
|
||||
# inbound A2A delivery (canvas messages → agent conversation turns).
|
||||
|
||||
# Requires Python >= 3.11. On 3.10 or older pip says
|
||||
# "Could not find a version that satisfies the requirement
|
||||
# (from versions: none)" — the wheel's requires_python pin filters
|
||||
# the only available artifact before pip even attempts install.
|
||||
# Upgrade the interpreter (brew install python@3.12 / apt install
|
||||
# python3.12 / etc.) or use a 3.11+ venv.
|
||||
|
||||
# 1. Install the workspace runtime wheel:
|
||||
pip install molecule-ai-workspace-runtime
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provlog"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/scheduler"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -61,20 +63,33 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
tier = defaults.Tier
|
||||
}
|
||||
if tier == 0 {
|
||||
// SaaS-aware fallback. SaaS → T4 (one container per sibling
|
||||
// EC2, no neighbour to protect from). Self-hosted → T2
|
||||
// (safe shared-Docker-daemon default — many workspaces in
|
||||
// one kernel). Templates that want a different floor
|
||||
// declare `tier:` in their config.yaml or the org-template's
|
||||
// `defaults.tier`.
|
||||
if h.workspace != nil && h.workspace.IsSaaS() {
|
||||
tier = 4
|
||||
// Resolved via the same DefaultTier helper Create + Templates
|
||||
// use (#2910 PR-E). SaaS → T4 (one container per sibling EC2,
|
||||
// no neighbour to protect from), self-hosted → T3. Pre-#2910
|
||||
// this path returned T2 on self-hosted, asymmetric with
|
||||
// workspace.go's T3 — undocumented drift. Lifting to
|
||||
// DefaultTier collapses both call sites onto one source of
|
||||
// truth so a future tier-default change sweeps every entry
|
||||
// point at once. Templates that want a different floor still
|
||||
// declare `tier:` in config.yaml or `defaults.tier` in
|
||||
// org.yaml.
|
||||
if h.workspace != nil {
|
||||
tier = h.workspace.DefaultTier()
|
||||
} else {
|
||||
tier = 2
|
||||
tier = 3
|
||||
}
|
||||
}
|
||||
|
||||
ctxLookup := context.Background()
|
||||
// 5s timeout bounds the lookup independently of any HTTP request
|
||||
// context. createWorkspaceTree runs in goroutines spawned from the
|
||||
// /org/import handler, so plumbing the request context here would
|
||||
// cascade-cancel into provisionWorkspaceAuto and abort in-flight
|
||||
// EC2 provisioning if the client disconnected mid-import — that's
|
||||
// the wrong behaviour. A short bounded timeout protects the
|
||||
// per-row SELECT against a wedged DB without taking the
|
||||
// drop-everything-on-disconnect tradeoff.
|
||||
ctxLookup, cancelLookup := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancelLookup()
|
||||
// Idempotency: if a workspace with the same (parent_id, name) already
|
||||
// exists, skip the INSERT + canvas_layouts + broadcast + provisioning.
|
||||
// This is what makes /org/import safe to call multiple times — the
|
||||
@@ -86,12 +101,31 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// (parent exists, some children missing) backfill the missing children
|
||||
// instead of either no-op'ing the whole subtree or duplicating the
|
||||
// existing children.
|
||||
//
|
||||
// /org/import is ADDITIVE-ONLY, never destructive. Children present
|
||||
// in the existing tree but absent from the new template are
|
||||
// preserved (no DELETE on diff). Skip-path also does NOT propagate
|
||||
// updates to existing nodes — a re-import that adds an
|
||||
// initial_memory or schedule to an existing workspace is silently
|
||||
// dropped (the function bypasses seedInitialMemories, schedule SQL,
|
||||
// channel config for skipped rows). To force-update an existing
|
||||
// tree, delete and re-import or use a future /org/sync route.
|
||||
existingID, existing, lookupErr := h.lookupExistingChild(ctxLookup, ws.Name, parentID)
|
||||
if lookupErr != nil {
|
||||
return fmt.Errorf("idempotency check for %s: %w", ws.Name, lookupErr)
|
||||
}
|
||||
if existing {
|
||||
log.Printf("Org import: %q already exists (id=%s) — skipping create+provision, recursing into children for partial-match", ws.Name, existingID)
|
||||
parentRef := ""
|
||||
if parentID != nil {
|
||||
parentRef = *parentID
|
||||
}
|
||||
provlog.Event("provision.skip_existing", map[string]any{
|
||||
"name": ws.Name,
|
||||
"existing_id": existingID,
|
||||
"parent_id": parentRef,
|
||||
"tier": tier,
|
||||
})
|
||||
*results = append(*results, map[string]interface{}{
|
||||
"id": existingID,
|
||||
"name": ws.Name,
|
||||
@@ -590,6 +624,12 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
//
|
||||
// On sql.ErrNoRows: returns ("", false, nil) — caller should INSERT.
|
||||
// On a real DB error: returns ("", false, err) — caller propagates.
|
||||
//
|
||||
// errors.Is is wrap-safe — a future caller wrapping the error
|
||||
// (database/sql can wrap driver errors with %w in some setups) would
|
||||
// silently break a `err == sql.ErrNoRows` equality check, causing the
|
||||
// no-rows path to fall through to the "real DB error" branch and
|
||||
// abort the import. errors.Is unwraps.
|
||||
func (h *OrgHandler) lookupExistingChild(ctx context.Context, name string, parentID *string) (string, bool, error) {
|
||||
var existingID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
@@ -599,7 +639,7 @@ func (h *OrgHandler) lookupExistingChild(ctx context.Context, name string, paren
|
||||
AND status != 'removed'
|
||||
LIMIT 1
|
||||
`, name, parentID).Scan(&existingID)
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return "", false, nil
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
@@ -2,7 +2,9 @@ package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
@@ -123,6 +125,36 @@ func TestLookupExistingChild_DBError_Propagates(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestLookupExistingChild_WrappedNoRows_TreatedAsNotFound — pins the
|
||||
// wrap-safety of the errors.Is(err, sql.ErrNoRows) check. The previous
|
||||
// `err == sql.ErrNoRows` equality would fall through to the
|
||||
// "real DB error" branch on a wrapped no-rows error, aborting the
|
||||
// import for what is in fact the no-rows happy path. driver/sql
|
||||
// wrapping is currently a non-issue but a future driver change or a
|
||||
// caller that wraps the result via fmt.Errorf("…: %w", err) would
|
||||
// silently break the equality check. errors.Is unwraps.
|
||||
func TestLookupExistingChild_WrappedNoRows_TreatedAsNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
parent := "parent-1"
|
||||
wrapped := fmt.Errorf("driver-wrapped: %w", sql.ErrNoRows)
|
||||
mock.ExpectQuery(`SELECT id FROM workspaces`).
|
||||
WithArgs("Alpha", &parent).
|
||||
WillReturnError(wrapped)
|
||||
|
||||
h := &OrgHandler{}
|
||||
id, found, err := h.lookupExistingChild(context.Background(), "Alpha", &parent)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected wrapped no-rows to be treated as not-found (err=nil), got: %v", err)
|
||||
}
|
||||
if found {
|
||||
t.Errorf("expected found=false on wrapped no-rows, got found=true")
|
||||
}
|
||||
if id != "" {
|
||||
t.Errorf("expected empty id on wrapped no-rows, got %q", id)
|
||||
}
|
||||
}
|
||||
|
||||
// workspacesInsertRE matches a SQL literal that begins (after optional
|
||||
// leading whitespace) with `INSERT INTO workspaces` followed by `(` —
|
||||
// requiring the open-paren rules out lookalikes like
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
package handlers
|
||||
|
||||
// provlog_emit_test.go — pins that the structured-logging emit sites
|
||||
// added for #2867 PR-D actually fire when their boundary is crossed.
|
||||
//
|
||||
// These are call-site contract tests, not provlog package tests (those
|
||||
// live next to the helper). The assertion is "this dispatcher path
|
||||
// emits this event name" — if a refactor moves the call out of the
|
||||
// boundary helper, the gate fails. Fields are NOT pinned here on
|
||||
// purpose; the field set is convenience for ops, not contract for the
|
||||
// emit point. Pinning fields would block additive evolution of the
|
||||
// payload (see also feedback_behavior_based_ast_gates.md).
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
)
|
||||
|
||||
// captureProvLog redirects the global logger to a buffer for the test
|
||||
// duration. provlog.Event uses log.Printf, so this is the only seam.
|
||||
// Returned mutex protects against concurrent reads from the goroutine
|
||||
// fired by provisionWorkspaceAuto (the goroutine never returns in
|
||||
// these tests because Start() is stubbed, but the buffer can still be
|
||||
// touched by it racing the assertion).
|
||||
func captureProvLog(t *testing.T) (read func() string) {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
var mu sync.Mutex
|
||||
prevWriter := log.Writer()
|
||||
prevFlags := log.Flags()
|
||||
log.SetFlags(0)
|
||||
log.SetOutput(&safeWriter{buf: &buf, mu: &mu})
|
||||
t.Cleanup(func() {
|
||||
log.SetOutput(prevWriter)
|
||||
log.SetFlags(prevFlags)
|
||||
})
|
||||
return func() string {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return buf.String()
|
||||
}
|
||||
}
|
||||
|
||||
// TestProvisionWorkspaceAutoSync_EmitsProvisionStart — sync variant is
|
||||
// chosen for the assertion path because it returns once the (stubbed)
|
||||
// Start() has been called, so we know the emit has flushed. The async
|
||||
// variant would race a goroutine.
|
||||
func TestProvisionWorkspaceAutoSync_EmitsProvisionStart(t *testing.T) {
|
||||
read := captureProvLog(t)
|
||||
h := &WorkspaceHandler{cpProv: &trackingCPProv{}}
|
||||
// Best-effort: the body will hit DB code under provisionWorkspaceCP
|
||||
// — we only need the emit at the entry, which fires unconditionally
|
||||
// before the dispatch. Recovering from any later panic keeps the
|
||||
// test focused.
|
||||
defer func() { _ = recover() }()
|
||||
h.provisionWorkspaceAutoSync("ws-test-1", "tmpl", nil, models.CreateWorkspacePayload{
|
||||
Name: "n", Tier: 4, Runtime: "claude-code",
|
||||
})
|
||||
got := read()
|
||||
if !strings.Contains(got, "evt: provision.start ") {
|
||||
t.Fatalf("expected provision.start emit, got log:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, `"workspace_id":"ws-test-1"`) {
|
||||
t.Errorf("workspace_id not in payload: %s", got)
|
||||
}
|
||||
if !strings.Contains(got, `"sync":true`) {
|
||||
t.Errorf("sync flag not pinned for sync dispatcher: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStopForRestart_EmitsRestartPreStop — emit fires before the actual
|
||||
// Stop call, so the trackingCPProv stub doesn't need to be wired for
|
||||
// real Stop semantics. Backend label "cp" pinned because that's the
|
||||
// SaaS path; we don't pin "docker" or "none" branches here (separate
|
||||
// tests would only re-test the trivial branch label switch).
|
||||
func TestStopForRestart_EmitsRestartPreStop(t *testing.T) {
|
||||
read := captureProvLog(t)
|
||||
h := &WorkspaceHandler{cpProv: &trackingCPProv{}}
|
||||
defer func() { _ = recover() }()
|
||||
h.stopForRestart(context.Background(), "ws-restart-1")
|
||||
got := read()
|
||||
if !strings.Contains(got, "evt: restart.pre_stop ") {
|
||||
t.Fatalf("expected restart.pre_stop emit, got log:\n%s", got)
|
||||
}
|
||||
if !strings.Contains(got, `"workspace_id":"ws-restart-1"`) {
|
||||
t.Errorf("workspace_id not in payload: %s", got)
|
||||
}
|
||||
if !strings.Contains(got, `"backend":"cp"`) {
|
||||
t.Errorf("backend label missing or wrong: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStopForRestart_EmitsBackendNoneWhenUnwired — pin the no-backend
|
||||
// branch so a future refactor that drops the label switch is caught.
|
||||
// This is the silent-Stop case (workspace_dispatchers.go:StopWorkspaceAuto
|
||||
// returns nil for unwired backends); the emit ensures the operator can
|
||||
// still see the boundary in the log.
|
||||
func TestStopForRestart_EmitsBackendNoneWhenUnwired(t *testing.T) {
|
||||
read := captureProvLog(t)
|
||||
h := &WorkspaceHandler{} // both nil
|
||||
h.stopForRestart(context.Background(), "ws-restart-2")
|
||||
got := read()
|
||||
if !strings.Contains(got, `"backend":"none"`) {
|
||||
t.Fatalf("expected backend=none for unwired handler: %s", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
|
||||
)
|
||||
|
||||
// Tests for the SaaS-aware default-tier resolution introduced in #2901
|
||||
// and hardened in #2910 (multi-model review of #2901 found the original
|
||||
// claim of "all green" was passing because no SaaS-mode test existed).
|
||||
//
|
||||
// These tests pin three invariants:
|
||||
//
|
||||
// 1. WorkspaceHandler.IsSaaS() returns true when cpProv is wired,
|
||||
// false otherwise.
|
||||
// 2. WorkspaceHandler.DefaultTier() returns 4 on SaaS, 3 self-hosted.
|
||||
// 3. generateDefaultConfig (TemplatesHandler.Import path) writes the
|
||||
// passed-in tier into the generated config.yaml — pre-#2910 it
|
||||
// was hardcoded to 3 and silently disagreed with the create-
|
||||
// handler default on SaaS.
|
||||
|
||||
// stubCPProv is a minimal stand-in for the CP provisioner — only
|
||||
// exercises the IsSaaS / HasProvisioner contract, never invoked in
|
||||
// these tests.
|
||||
type stubCPProv struct{}
|
||||
|
||||
func (stubCPProv) Start(_ interface{}, _ provisioner.WorkspaceConfig) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (stubCPProv) Stop(_ interface{}, _ string) error { return nil }
|
||||
func (stubCPProv) Restart(_ interface{}, _ provisioner.WorkspaceConfig) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func TestIsSaaS_TrueWhenCPProvWired(t *testing.T) {
|
||||
h := &WorkspaceHandler{cpProv: &trackingCPProv{}}
|
||||
if !h.IsSaaS() {
|
||||
t.Errorf("IsSaaS()=false with cpProv wired; expected true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSaaS_FalseWhenOnlyDocker(t *testing.T) {
|
||||
// provisioner field set, cpProv nil — the self-hosted path.
|
||||
// Use a non-nil sentinel so the check actually has something to
|
||||
// disagree with. trackingCPProv lives in workspace_provision_auto_test.go
|
||||
// and is the established stub for these handler-level tests.
|
||||
h := &WorkspaceHandler{provisioner: nil, cpProv: nil}
|
||||
if h.IsSaaS() {
|
||||
t.Errorf("IsSaaS()=true with both backends nil; expected false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultTier_SaaS_IsT4(t *testing.T) {
|
||||
h := &WorkspaceHandler{cpProv: &trackingCPProv{}}
|
||||
if got := h.DefaultTier(); got != 4 {
|
||||
t.Errorf("SaaS DefaultTier()=%d; expected 4", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultTier_SelfHosted_IsT3(t *testing.T) {
|
||||
h := &WorkspaceHandler{}
|
||||
if got := h.DefaultTier(); got != 3 {
|
||||
t.Errorf("self-hosted DefaultTier()=%d; expected 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
// generateDefaultConfig — pin that the tier param flows into the
|
||||
// emitted config.yaml verbatim. Pre-#2910 this was hardcoded "tier: 3"
|
||||
// regardless of caller intent.
|
||||
func TestGenerateDefaultConfig_RespectsTierParam(t *testing.T) {
|
||||
cfg := generateDefaultConfig("Test Agent", map[string]string{"system-prompt.md": ""}, 4)
|
||||
if !strings.Contains(cfg, "tier: 4\n") {
|
||||
t.Errorf("expected `tier: 4` in generated config, got:\n%s", cfg)
|
||||
}
|
||||
// The pre-#2910 hardcoded `tier: 3` line must NOT appear.
|
||||
if strings.Contains(cfg, "tier: 3\n") {
|
||||
t.Errorf("config should not contain `tier: 3` when caller passed 4, got:\n%s", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateDefaultConfig_SelfHostedTierT3(t *testing.T) {
|
||||
cfg := generateDefaultConfig("Test Agent", map[string]string{"system-prompt.md": ""}, 3)
|
||||
if !strings.Contains(cfg, "tier: 3\n") {
|
||||
t.Errorf("expected `tier: 3` in generated config, got:\n%s", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// Bounds check — caller passes 0 or out-of-range, helper falls back
|
||||
// to T3 (the safer-of-the-two when deployment mode can't be resolved).
|
||||
func TestGenerateDefaultConfig_OutOfRangeFallsBackToT3(t *testing.T) {
|
||||
for _, tier := range []int{0, -1, 99} {
|
||||
cfg := generateDefaultConfig("X", map[string]string{}, tier)
|
||||
if !strings.Contains(cfg, "tier: 3\n") {
|
||||
t.Errorf("invalid tier %d should fall back to T3, got:\n%s", tier, cfg)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -71,7 +71,7 @@ func TestSecurity_GetTemplates_NoAuth_Returns401(t *testing.T) {
|
||||
authDB, authMock := newEnrolledAuthDB(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmplh := NewTemplatesHandler(tmpDir, nil)
|
||||
tmplh := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/templates", middleware.AdminAuth(authDB), tmplh.List)
|
||||
@@ -98,7 +98,7 @@ func TestSecurity_GetTemplates_FreshInstall_FailsOpen(t *testing.T) {
|
||||
authDB, authMock := newFreshInstallAuthDB(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmplh := NewTemplatesHandler(tmpDir, nil)
|
||||
tmplh := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/templates", middleware.AdminAuth(authDB), tmplh.List)
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/events"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// TeamHandler now hosts only Collapse — the visual "expand" action is
|
||||
// canvas-side and creating children goes through the regular
|
||||
// WorkspaceHandler.Create path with parent_id set, like any other
|
||||
// workspace. Every workspace can have children; "team" is just the
|
||||
// state of having children. The old Expand handler bulk-created
|
||||
// children by reading sub_workspaces from a parent's config and was
|
||||
// non-idempotent — calling it N times leaked N×children EC2s, which
|
||||
// is how tenant-hongming accumulated 72 stale workspaces.
|
||||
type TeamHandler struct {
|
||||
wh *WorkspaceHandler
|
||||
b *events.Broadcaster
|
||||
}
|
||||
|
||||
// NewTeamHandler constructs a TeamHandler. wh is used by Collapse to
|
||||
// route StopWorkspaceAuto through the backend dispatcher.
|
||||
func NewTeamHandler(b *events.Broadcaster, wh *WorkspaceHandler, platformURL, configsDir string) *TeamHandler {
|
||||
return &TeamHandler{wh: wh, b: b}
|
||||
}
|
||||
|
||||
// Collapse handles POST /workspaces/:id/collapse
|
||||
// Stops and removes all child workspaces.
|
||||
func (h *TeamHandler) Collapse(c *gin.Context) {
|
||||
parentID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Find children
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
`SELECT id, name FROM workspaces WHERE parent_id = $1 AND status != 'removed'`, parentID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to query children"})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
removed := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var childID, childName string
|
||||
if rows.Scan(&childID, &childName) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Stop the workload via the backend dispatcher (CP for SaaS,
|
||||
// Docker for self-hosted). Pre-2026-05-05 this was
|
||||
// `if h.provisioner != nil { h.provisioner.Stop(...) }`, which
|
||||
// silently skipped on every SaaS tenant — child EC2s kept running
|
||||
// after team-collapse until the orphan sweeper caught them
|
||||
// (issue #2813).
|
||||
if err := h.wh.StopWorkspaceAuto(ctx, childID); err != nil {
|
||||
log.Printf("Team collapse: stop %s failed: %v — orphan sweeper will reconcile", childID, err)
|
||||
}
|
||||
|
||||
// Mark as removed
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusRemoved, childID); err != nil {
|
||||
log.Printf("Team collapse: failed to remove workspace %s: %v", childID, err)
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
`DELETE FROM canvas_layouts WHERE workspace_id = $1`, childID); err != nil {
|
||||
log.Printf("Team collapse: failed to delete layout for %s: %v", childID, err)
|
||||
}
|
||||
|
||||
h.b.RecordAndBroadcast(ctx, "WORKSPACE_REMOVED", childID, map[string]interface{}{})
|
||||
|
||||
removed = append(removed, childName)
|
||||
}
|
||||
|
||||
h.b.RecordAndBroadcast(ctx, "WORKSPACE_COLLAPSED", parentID, map[string]interface{}{
|
||||
"removed_children": removed,
|
||||
})
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "collapsed",
|
||||
"removed": removed,
|
||||
})
|
||||
}
|
||||
|
||||
// findTemplateDirByName resolves a workspace name to its template
|
||||
// directory. Kept here because callers outside this package may use
|
||||
// it, even though the in-package consumer (Expand) is gone.
|
||||
//
|
||||
// TODO: relocate alongside the templates handler if no other callers
|
||||
// surface, or delete entirely after a deprecation cycle.
|
||||
func findTemplateDirByName(configsDir, name string) string {
|
||||
normalized := normalizeName(name)
|
||||
|
||||
candidate := filepath.Join(configsDir, normalized)
|
||||
if _, err := os.Stat(filepath.Join(candidate, "config.yaml")); err == nil {
|
||||
return candidate
|
||||
}
|
||||
|
||||
// Fall back to scanning all dirs
|
||||
entries, err := os.ReadDir(configsDir)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
cfgPath := filepath.Join(configsDir, e.Name(), "config.yaml")
|
||||
data, err := os.ReadFile(cfgPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var cfg struct {
|
||||
Name string `yaml:"name"`
|
||||
}
|
||||
if json.Unmarshal(data, &cfg) == nil && cfg.Name == name {
|
||||
return filepath.Join(configsDir, e.Name())
|
||||
}
|
||||
if yaml.Unmarshal(data, &cfg) == nil && cfg.Name == name {
|
||||
return filepath.Join(configsDir, e.Name())
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ---------- TeamHandler: Collapse ----------
|
||||
|
||||
func TestTeamCollapse_NoChildren(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewTeamHandler(broadcaster, NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()), "http://localhost:8080", "/tmp/configs")
|
||||
|
||||
// No children
|
||||
mock.ExpectQuery("SELECT id, name FROM workspaces WHERE parent_id").
|
||||
WithArgs("ws-parent").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
|
||||
|
||||
// WORKSPACE_COLLAPSED broadcast
|
||||
mock.ExpectExec("INSERT INTO structure_events").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-parent"}}
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
|
||||
handler.Collapse(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["status"] != "collapsed" {
|
||||
t.Errorf("expected status 'collapsed', got %v", resp["status"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamCollapse_WithChildren(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewTeamHandler(broadcaster, NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()), "http://localhost:8080", "/tmp/configs")
|
||||
|
||||
// Two children
|
||||
mock.ExpectQuery("SELECT id, name FROM workspaces WHERE parent_id").
|
||||
WithArgs("ws-parent").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
|
||||
AddRow("child-1", "Worker A").
|
||||
AddRow("child-2", "Worker B"))
|
||||
|
||||
// UPDATE + DELETE + broadcast for child-1
|
||||
mock.ExpectExec("UPDATE workspaces SET status =").
|
||||
WithArgs("child-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("DELETE FROM canvas_layouts").
|
||||
WithArgs("child-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("INSERT INTO structure_events").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// UPDATE + DELETE + broadcast for child-2
|
||||
mock.ExpectExec("UPDATE workspaces SET status =").
|
||||
WithArgs("child-2").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("DELETE FROM canvas_layouts").
|
||||
WithArgs("child-2").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("INSERT INTO structure_events").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// WORKSPACE_COLLAPSED broadcast for parent
|
||||
mock.ExpectExec("INSERT INTO structure_events").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-parent"}}
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
|
||||
handler.Collapse(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
removed, ok := resp["removed"].([]interface{})
|
||||
if !ok || len(removed) != 2 {
|
||||
t.Errorf("expected 2 removed children, got %v", resp["removed"])
|
||||
}
|
||||
}
|
||||
// ---------- findTemplateDirByName helper ----------
|
||||
|
||||
func TestFindTemplateDirByName_DirectMatch(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
subDir := filepath.Join(dir, "mybot")
|
||||
os.MkdirAll(subDir, 0755)
|
||||
os.WriteFile(filepath.Join(subDir, "config.yaml"), []byte("name: MyBot"), 0644)
|
||||
|
||||
result := findTemplateDirByName(dir, "mybot")
|
||||
if result != subDir {
|
||||
t.Errorf("expected %s, got %s", subDir, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindTemplateDirByName_NotFound(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
result := findTemplateDirByName(dir, "nonexistent")
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindTemplateDirByName_InvalidConfigsDir(t *testing.T) {
|
||||
result := findTemplateDirByName("/nonexistent/path", "anything")
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string for invalid dir, got %s", result)
|
||||
}
|
||||
}
|
||||
@@ -36,8 +36,14 @@ func normalizeName(name string) string {
|
||||
return result
|
||||
}
|
||||
|
||||
// generateDefaultConfig creates a config.yaml from detected prompt files and skills.
|
||||
func generateDefaultConfig(name string, files map[string]string) string {
|
||||
// generateDefaultConfig creates a config.yaml from detected prompt files
|
||||
// and skills. tier is the deployment-aware default (caller passes
|
||||
// h.wh.DefaultTier() — T4 on SaaS, T3 on self-hosted) so the generated
|
||||
// file matches what POST /workspaces would default to. Pre-#2910 this
|
||||
// was hardcoded to 3, which split-brained with the create-handler
|
||||
// default on SaaS (T4) and pinned newly-imported templates at T3 even
|
||||
// when downstream Create paths picked T4.
|
||||
func generateDefaultConfig(name string, files map[string]string, tier int) string {
|
||||
promptFiles := []string{}
|
||||
skillSet := map[string]bool{}
|
||||
|
||||
@@ -74,9 +80,15 @@ func generateDefaultConfig(name string, files map[string]string) string {
|
||||
var cfg strings.Builder
|
||||
cfg.WriteString(`name: "` + escaped + `"` + "\n")
|
||||
cfg.WriteString("description: Imported agent\n")
|
||||
// Default to tier 3 ("Privileged") — matches the workspace.go
|
||||
// create handler default. See its comment for rationale.
|
||||
cfg.WriteString("version: 1.0.0\ntier: 3\n")
|
||||
// Tier is SaaS-aware via the caller's DefaultTier (#2910 PR-B).
|
||||
// Bounds-checked: invalid input falls back to T3 (the historical
|
||||
// default + the safer-of-the-two when the deployment mode can't
|
||||
// be resolved).
|
||||
if tier < 1 || tier > 4 {
|
||||
tier = 3
|
||||
}
|
||||
cfg.WriteString("version: 1.0.0\n")
|
||||
cfg.WriteString(fmt.Sprintf("tier: %d\n", tier))
|
||||
cfg.WriteString("model: anthropic:claude-haiku-4-5-20251001\n")
|
||||
cfg.WriteString("\nprompt_files:\n")
|
||||
if len(promptFiles) > 0 {
|
||||
@@ -148,7 +160,11 @@ func (h *TemplatesHandler) Import(c *gin.Context) {
|
||||
|
||||
// Auto-generate config.yaml if not provided
|
||||
if _, exists := body.Files["config.yaml"]; !exists {
|
||||
cfg := generateDefaultConfig(body.Name, body.Files)
|
||||
tier := 3
|
||||
if h.wh != nil {
|
||||
tier = h.wh.DefaultTier()
|
||||
}
|
||||
cfg := generateDefaultConfig(body.Name, body.Files, tier)
|
||||
if err := os.WriteFile(filepath.Join(destDir, "config.yaml"), []byte(cfg), 0600); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to write config.yaml"})
|
||||
return
|
||||
@@ -227,7 +243,11 @@ func (h *TemplatesHandler) ReplaceFiles(c *gin.Context) {
|
||||
if _, exists := body.Files["config.yaml"]; !exists {
|
||||
// Check if config.yaml exists in container
|
||||
if _, err := h.execInContainer(ctx, containerName, []string{"test", "-f", "/configs/config.yaml"}); err != nil {
|
||||
cfg := generateDefaultConfig(wsName, body.Files)
|
||||
tier := 3
|
||||
if h.wh != nil {
|
||||
tier = h.wh.DefaultTier()
|
||||
}
|
||||
cfg := generateDefaultConfig(wsName, body.Files, tier)
|
||||
singleFile := map[string]string{"config.yaml": cfg}
|
||||
h.copyFilesToContainer(ctx, containerName, "/configs", singleFile)
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func TestGenerateDefaultConfig_WithFiles(t *testing.T) {
|
||||
"skills/review/templates.md": "Templates",
|
||||
}
|
||||
|
||||
cfg := generateDefaultConfig("Test Agent", files)
|
||||
cfg := generateDefaultConfig("Test Agent", files, 3)
|
||||
|
||||
// Name is emitted as a double-quoted scalar (#221 sanitizer).
|
||||
if !strings.Contains(cfg, `name: "Test Agent"`) {
|
||||
@@ -85,7 +85,7 @@ func TestGenerateDefaultConfig_Empty(t *testing.T) {
|
||||
"data/something.json": `{"key": "value"}`,
|
||||
}
|
||||
|
||||
cfg := generateDefaultConfig("Empty Agent", files)
|
||||
cfg := generateDefaultConfig("Empty Agent", files, 3)
|
||||
|
||||
if !strings.Contains(cfg, `name: "Empty Agent"`) {
|
||||
t.Errorf("config should contain quoted agent name, got:\n%s", cfg)
|
||||
@@ -134,7 +134,7 @@ func TestGenerateDefaultConfig_YAMLInjection(t *testing.T) {
|
||||
|
||||
for _, tc := range adversarialCases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
cfg := generateDefaultConfig(tc.name, map[string]string{})
|
||||
cfg := generateDefaultConfig(tc.name, map[string]string{}, 3)
|
||||
var parsed map[string]interface{}
|
||||
if err := yaml.Unmarshal([]byte(cfg), &parsed); err != nil {
|
||||
t.Fatalf("sanitized config does not parse as YAML: %v\n--- config ---\n%s", err, cfg)
|
||||
@@ -205,7 +205,7 @@ func TestImport_Success(t *testing.T) {
|
||||
setupTestRedis(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
body := `{
|
||||
"name": "New Agent",
|
||||
@@ -245,7 +245,7 @@ func TestImport_MissingName(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
body := `{"files": {"test.md": "content"}}`
|
||||
|
||||
@@ -265,7 +265,7 @@ func TestImport_TooManyFiles(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
files := make(map[string]string)
|
||||
for i := 0; i <= maxUploadFiles; i++ {
|
||||
@@ -296,7 +296,7 @@ func TestImport_AlreadyExists(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
os.MkdirAll(filepath.Join(tmpDir, "existing-agent"), 0755)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
body := `{"name": "Existing Agent", "files": {"test.md": "content"}}`
|
||||
|
||||
@@ -317,7 +317,7 @@ func TestImport_WithConfigYaml(t *testing.T) {
|
||||
setupTestRedis(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
body := `{
|
||||
"name": "Custom Agent",
|
||||
@@ -354,7 +354,7 @@ func TestReplaceFiles_MissingBody(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -373,7 +373,7 @@ func TestReplaceFiles_TooManyFiles(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
files := make(map[string]string)
|
||||
for i := 0; i <= maxUploadFiles; i++ {
|
||||
@@ -398,7 +398,7 @@ func TestReplaceFiles_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
// ReplaceFiles now selects (name, instance_id, runtime) for the
|
||||
// restart-cascade. Match the full column list rather than just the
|
||||
@@ -429,7 +429,7 @@ func TestReplaceFiles_PathTraversal(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT name, COALESCE\(instance_id, ''\), COALESCE\(runtime, ''\) FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-rf-pt").
|
||||
|
||||
@@ -31,10 +31,20 @@ const maxUploadFiles = 200
|
||||
type TemplatesHandler struct {
|
||||
configsDir string
|
||||
docker *client.Client
|
||||
// wh is used by Import and ReplaceFiles to call DefaultTier() so a
|
||||
// generated config.yaml's tier matches the SaaS-vs-self-hosted
|
||||
// boundary (#2910 PR-B). nil-tolerant — the field is unused when
|
||||
// the caller doesn't import templates that need a fresh config
|
||||
// generated.
|
||||
wh *WorkspaceHandler
|
||||
}
|
||||
|
||||
func NewTemplatesHandler(configsDir string, dockerCli *client.Client) *TemplatesHandler {
|
||||
return &TemplatesHandler{configsDir: configsDir, docker: dockerCli}
|
||||
// NewTemplatesHandler constructs a TemplatesHandler. wh may be nil for
|
||||
// callers that only use the read-only template surfaces (List,
|
||||
// ReadFile, ListFiles). Import + ReplaceFiles need wh non-nil so the
|
||||
// generated config.yaml picks the SaaS-aware default tier.
|
||||
func NewTemplatesHandler(configsDir string, dockerCli *client.Client, wh *WorkspaceHandler) *TemplatesHandler {
|
||||
return &TemplatesHandler{configsDir: configsDir, docker: dockerCli, wh: wh}
|
||||
}
|
||||
|
||||
// modelSpec describes a single supported model on a template: its id (sent
|
||||
|
||||
@@ -53,7 +53,7 @@ func TestTemplatesList_EmptyDir(t *testing.T) {
|
||||
setupTestRedis(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -99,7 +99,7 @@ skills:
|
||||
// Create a directory without config.yaml (should be skipped)
|
||||
os.MkdirAll(filepath.Join(tmpDir, "no-config"), 0755)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -160,7 +160,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -237,7 +237,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -315,7 +315,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -434,7 +434,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -512,7 +512,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -555,7 +555,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -589,7 +589,7 @@ skills: []
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -661,7 +661,7 @@ skills: []
|
||||
log.SetOutput(&logBuf)
|
||||
defer log.SetOutput(prevOutput)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/templates", nil)
|
||||
@@ -698,7 +698,7 @@ func TestTemplatesList_NonexistentDir(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler("/nonexistent/path/to/templates", nil)
|
||||
handler := NewTemplatesHandler("/nonexistent/path/to/templates", nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -723,7 +723,7 @@ func TestListFiles_InvalidRoot(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -748,7 +748,7 @@ func TestListFiles_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces WHERE id =").
|
||||
WithArgs("ws-nonexist").
|
||||
@@ -775,7 +775,7 @@ func TestListFiles_FallbackToHost_NoTemplate(t *testing.T) {
|
||||
setupTestRedis(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil) // nil docker = no container
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil) // nil docker = no container
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces WHERE id =").
|
||||
WithArgs("ws-fallback").
|
||||
@@ -815,7 +815,7 @@ func TestListFiles_FallbackToHost_WithTemplate(t *testing.T) {
|
||||
os.WriteFile(filepath.Join(tmplDir, "config.yaml"), []byte("name: Test Agent\n"), 0644)
|
||||
os.WriteFile(filepath.Join(tmplDir, "system-prompt.md"), []byte("# prompt"), 0644)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces WHERE id =").
|
||||
WithArgs("ws-tmpl").
|
||||
@@ -849,7 +849,7 @@ func TestReadFile_PathTraversal(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -870,7 +870,7 @@ func TestReadFile_InvalidRoot(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -892,7 +892,7 @@ func TestReadFile_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT name, COALESCE\(instance_id, ''\), COALESCE\(runtime, ''\) FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-nf").
|
||||
@@ -926,7 +926,7 @@ func TestReadFile_FallbackToHost_Success(t *testing.T) {
|
||||
os.MkdirAll(tmplDir, 0755)
|
||||
os.WriteFile(filepath.Join(tmplDir, "config.yaml"), []byte("name: Reader Agent\ntier: 1\n"), 0644)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
// instance_id="" → SaaS branch skipped → falls through to local
|
||||
// Docker / template-dir host fallback (the only path the test
|
||||
@@ -967,7 +967,7 @@ func TestReadFile_FallbackToHost_NotFound(t *testing.T) {
|
||||
setupTestRedis(t)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT name, COALESCE\(instance_id, ''\), COALESCE\(runtime, ''\) FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-nofile").
|
||||
@@ -999,7 +999,7 @@ func TestWriteFile_PathTraversal(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -1023,7 +1023,7 @@ func TestWriteFile_InvalidBody(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -1046,7 +1046,7 @@ func TestWriteFile_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT name, COALESCE\(instance_id, ''\), COALESCE\(runtime, ''\) FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-wf-nf").
|
||||
@@ -1080,7 +1080,7 @@ func TestDeleteFile_PathTraversal(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -1101,7 +1101,7 @@ func TestDeleteFile_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces WHERE id =").
|
||||
WithArgs("ws-del-nf").
|
||||
@@ -1133,7 +1133,7 @@ func TestResolveTemplateDir_ByNormalizedName(t *testing.T) {
|
||||
tmplDir := filepath.Join(tmpDir, "my-agent")
|
||||
os.MkdirAll(tmplDir, 0755)
|
||||
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
result := handler.resolveTemplateDir("My Agent")
|
||||
|
||||
if result != tmplDir {
|
||||
@@ -1143,7 +1143,7 @@ func TestResolveTemplateDir_ByNormalizedName(t *testing.T) {
|
||||
|
||||
func TestResolveTemplateDir_NotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
handler := NewTemplatesHandler(tmpDir, nil)
|
||||
handler := NewTemplatesHandler(tmpDir, nil, nil)
|
||||
result := handler.resolveTemplateDir("Nonexistent Agent")
|
||||
|
||||
if result != "" {
|
||||
@@ -1177,7 +1177,7 @@ func TestCWE78_DeleteFile_TraversalVariants(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil)
|
||||
handler := NewTemplatesHandler(t.TempDir(), nil, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
@@ -35,6 +35,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provlog"
|
||||
)
|
||||
|
||||
// HasProvisioner reports whether either backend (CP or local Docker) is
|
||||
@@ -101,6 +102,14 @@ func (h *WorkspaceHandler) DefaultTier() int {
|
||||
// lives in prepareProvisionContext (shared by both per-backend
|
||||
// goroutines).
|
||||
func (h *WorkspaceHandler) provisionWorkspaceAuto(workspaceID, templatePath string, configFiles map[string][]byte, payload models.CreateWorkspacePayload) bool {
|
||||
provlog.Event("provision.start", map[string]any{
|
||||
"workspace_id": workspaceID,
|
||||
"name": payload.Name,
|
||||
"tier": payload.Tier,
|
||||
"runtime": payload.Runtime,
|
||||
"template": payload.Template,
|
||||
"sync": false,
|
||||
})
|
||||
if h.cpProv != nil {
|
||||
go h.provisionWorkspaceCP(workspaceID, templatePath, configFiles, payload)
|
||||
return true
|
||||
@@ -136,6 +145,14 @@ func (h *WorkspaceHandler) provisionWorkspaceAuto(workspaceID, templatePath stri
|
||||
// Keep these two helpers in sync — when one grows a new arm (third
|
||||
// backend, retry semantics), the other should too.
|
||||
func (h *WorkspaceHandler) provisionWorkspaceAutoSync(workspaceID, templatePath string, configFiles map[string][]byte, payload models.CreateWorkspacePayload) bool {
|
||||
provlog.Event("provision.start", map[string]any{
|
||||
"workspace_id": workspaceID,
|
||||
"name": payload.Name,
|
||||
"tier": payload.Tier,
|
||||
"runtime": payload.Runtime,
|
||||
"template": payload.Template,
|
||||
"sync": true,
|
||||
})
|
||||
if h.cpProv != nil {
|
||||
h.provisionWorkspaceCP(workspaceID, templatePath, configFiles, payload)
|
||||
return true
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provlog"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -431,6 +432,16 @@ func coalesceRestart(workspaceID string, cycle func()) {
|
||||
// NPE'd before reaching the reprovision step — which is why every SaaS dead-
|
||||
// agent incident pre-this-fix required manual restart from canvas.
|
||||
func (h *WorkspaceHandler) stopForRestart(ctx context.Context, workspaceID string) {
|
||||
backend := "none"
|
||||
if h.provisioner != nil {
|
||||
backend = "docker"
|
||||
} else if h.cpProv != nil {
|
||||
backend = "cp"
|
||||
}
|
||||
provlog.Event("restart.pre_stop", map[string]any{
|
||||
"workspace_id": workspaceID,
|
||||
"backend": backend,
|
||||
})
|
||||
if h.provisioner != nil {
|
||||
h.provisioner.Stop(ctx, workspaceID)
|
||||
return
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestINSERTworkspacesAllowlist enumerates every function in this
|
||||
// package that emits an `INSERT INTO workspaces (` SQL literal, and
|
||||
// pins the result against an explicit allowlist. New entries fail the
|
||||
// build until a reviewer adds them — forcing the question "what
|
||||
// makes this INSERT idempotent?" at PR-review time, not after the
|
||||
// next bulk-create leak.
|
||||
//
|
||||
// Pairs with TestCreateWorkspaceTree_CallsLookupBeforeInsert (the
|
||||
// behavior pin for the one bulk path). Together they close the
|
||||
// regression class: this test catches "did a new function start
|
||||
// inserting workspaces?", that test catches "did the existing bulk
|
||||
// path drop its idempotency check?". Either fires immediately when
|
||||
// drift happens.
|
||||
//
|
||||
// Why allowlist rather than pure behavior gate (per memory
|
||||
// feedback_behavior_based_ast_gates.md): the bulk-create leak class
|
||||
// is small + stable (1 path today), and a behavior gate would have
|
||||
// to disambiguate "iterating a YAML array of workspaces" from the
|
||||
// many other `for ... range` patterns in a Create handler (config
|
||||
// lines, secrets map, channels). Type-info-aware AST analysis would
|
||||
// catch the YAML-iteration shape but is heavy. Allowlisting is the
|
||||
// minimum-viable pin: any PR that adds a new INSERT site is forced
|
||||
// to pause, add an entry here, and document the safety mechanism in
|
||||
// the comment alongside.
|
||||
//
|
||||
// RFC #2867 class 1.
|
||||
func TestINSERTworkspacesAllowlist(t *testing.T) {
|
||||
// expected[key] = safety mechanism. Keep the comment pinned to
|
||||
// what makes that function safe — if the safety changes, the
|
||||
// allowlist must be re-reviewed.
|
||||
expected := map[string]string{
|
||||
// org_import.createWorkspaceTree: lookupExistingChild
|
||||
// before INSERT (#2868 phase 3). Also pinned by
|
||||
// TestCreateWorkspaceTree_CallsLookupBeforeInsert.
|
||||
"org_import.go:createWorkspaceTree": "lookup-then-insert via lookupExistingChild",
|
||||
// registry.Register: external workspace registers itself with
|
||||
// its known UUID; INSERT is idempotent via ON CONFLICT (id)
|
||||
// DO UPDATE — re-registration upserts, never duplicates.
|
||||
"registry.go:Register": "ON CONFLICT (id) DO UPDATE",
|
||||
// workspace.Create: single-workspace POST /workspaces from a
|
||||
// human or automation. No iteration; payload describes one
|
||||
// workspace; UUID is server-generated. Caller intent IS to
|
||||
// create, so no idempotency check is needed.
|
||||
"workspace.go:Create": "single-workspace POST, server-generated UUID",
|
||||
}
|
||||
|
||||
actual := map[string]string{}
|
||||
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("getwd: %v", err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(wd)
|
||||
if err != nil {
|
||||
t.Fatalf("readdir %s: %v", wd, err)
|
||||
}
|
||||
for _, ent := range entries {
|
||||
name := ent.Name()
|
||||
if ent.IsDir() {
|
||||
continue
|
||||
}
|
||||
if !strings.HasSuffix(name, ".go") {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(name, "_test.go") {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(wd, name)
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
|
||||
if err != nil {
|
||||
t.Fatalf("parse %s: %v", path, err)
|
||||
}
|
||||
// For each top-level FuncDecl, walk its body and check for an
|
||||
// `INSERT INTO workspaces (` SQL literal in any CallExpr arg.
|
||||
for _, decl := range file.Decls {
|
||||
fn, ok := decl.(*ast.FuncDecl)
|
||||
if !ok || fn.Body == nil {
|
||||
continue
|
||||
}
|
||||
var foundInsert bool
|
||||
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||
lit, ok := n.(*ast.BasicLit)
|
||||
if !ok || lit.Kind != token.STRING {
|
||||
return true
|
||||
}
|
||||
raw := lit.Value
|
||||
if unq, err := strconv.Unquote(raw); err == nil {
|
||||
raw = unq
|
||||
}
|
||||
if workspacesInsertRE.MatchString(raw) {
|
||||
foundInsert = true
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if foundInsert {
|
||||
key := name + ":" + fn.Name.Name
|
||||
actual[key] = "(observed via AST walk)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute set diffs so failures point at the specific drift.
|
||||
missing := []string{}
|
||||
unexpected := []string{}
|
||||
for k := range expected {
|
||||
if _, ok := actual[k]; !ok {
|
||||
missing = append(missing, k)
|
||||
}
|
||||
}
|
||||
for k := range actual {
|
||||
if _, ok := expected[k]; !ok {
|
||||
unexpected = append(unexpected, k)
|
||||
}
|
||||
}
|
||||
sort.Strings(missing)
|
||||
sort.Strings(unexpected)
|
||||
|
||||
if len(unexpected) > 0 {
|
||||
t.Errorf(`new function(s) emit `+"`INSERT INTO workspaces (`"+` and aren't in the allowlist:
|
||||
%s
|
||||
|
||||
If this is a legitimate addition, add an entry to expected[] in this test
|
||||
with the safety mechanism pinned in the comment alongside (lookup-then-
|
||||
insert / ON CONFLICT / single-workspace path / etc.). The bulk-create
|
||||
regression class needs explicit per-handler review, not silent drift.
|
||||
|
||||
Reference: RFC #2867 class 1, sibling test
|
||||
TestCreateWorkspaceTree_CallsLookupBeforeInsert.`,
|
||||
strings.Join(unexpected, "\n "))
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
t.Errorf(`expected function(s) no longer emit `+"`INSERT INTO workspaces (`"+`:
|
||||
%s
|
||||
|
||||
Either the function was renamed/deleted (update the allowlist) or the
|
||||
INSERT was moved out (verify the new home is also covered). Don't just
|
||||
delete the entry — confirm the safety mechanism is still in place
|
||||
elsewhere or that the workspace-create path was intentionally
|
||||
restructured.`,
|
||||
strings.Join(missing, "\n "))
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provlog"
|
||||
)
|
||||
|
||||
// CPProvisionerAPI is the contract WorkspaceHandler uses to talk to the
|
||||
@@ -214,6 +215,13 @@ func (p *CPProvisioner) Start(ctx context.Context, cfg WorkspaceConfig) (string,
|
||||
}
|
||||
|
||||
log.Printf("CP provisioner: workspace %s → EC2 instance %s (%s)", cfg.WorkspaceID, result.InstanceID, result.State)
|
||||
provlog.Event("provision.ec2_started", map[string]any{
|
||||
"workspace_id": cfg.WorkspaceID,
|
||||
"instance_id": result.InstanceID,
|
||||
"state": result.State,
|
||||
"tier": cfg.Tier,
|
||||
"runtime": cfg.Runtime,
|
||||
})
|
||||
return result.InstanceID, nil
|
||||
}
|
||||
|
||||
@@ -273,6 +281,10 @@ func (p *CPProvisioner) Stop(ctx context.Context, workspaceID string) error {
|
||||
return fmt.Errorf("cp provisioner: stop %s: unexpected %d: %s",
|
||||
workspaceID, resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
provlog.Event("provision.ec2_stopped", map[string]any{
|
||||
"workspace_id": workspaceID,
|
||||
"instance_id": instanceID,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
// Package provlog emits structured, single-line JSON log records for
|
||||
// provisioning-lifecycle boundaries (workspace create, EC2 start/stop,
|
||||
// restart, idempotency skips). Records share a stable `evt:` prefix and
|
||||
// JSON payload so a future grep|jq pipeline (or a Loki/Datadog ingest)
|
||||
// can reconstruct the per-workspace timeline without parsing the
|
||||
// human-prose log lines that already exist.
|
||||
//
|
||||
// Existing log.Printf lines are intentionally NOT replaced — they
|
||||
// remain the operator-facing message. Event() emits a paired structured
|
||||
// record alongside, additive only.
|
||||
//
|
||||
// Event taxonomy (extend by appending; never rename):
|
||||
//
|
||||
// provision.start — workspace row inserted, EC2 about to launch
|
||||
// provision.skip_existing — idempotency hit, no new EC2
|
||||
// provision.ec2_started — RunInstances returned an instance id
|
||||
// provision.ec2_stopped — TerminateInstances acknowledged
|
||||
// restart.pre_stop — Restart handler about to call Stop
|
||||
//
|
||||
// Required fields per event are documented at each call site.
|
||||
package provlog
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
)
|
||||
|
||||
// Event writes a single line of the form:
|
||||
//
|
||||
// evt: <name> {"k":"v",...}
|
||||
//
|
||||
// to the standard logger. JSON encoding errors are silently swallowed —
|
||||
// a logging helper must never panic the request path. fields may be
|
||||
// nil; the empty payload `{}` is still useful to mark an event boundary.
|
||||
func Event(name string, fields map[string]any) {
|
||||
if fields == nil {
|
||||
fields = map[string]any{}
|
||||
}
|
||||
payload, err := json.Marshal(fields)
|
||||
if err != nil {
|
||||
// Fall back to a static payload so the event boundary still
|
||||
// appears in the log. The marshal error itself is recorded
|
||||
// on a best-effort basis.
|
||||
log.Printf("evt: %s {\"_marshal_err\":%q}", name, err.Error())
|
||||
return
|
||||
}
|
||||
log.Printf("evt: %s %s", name, payload)
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package provlog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// captureLog redirects the default logger to a buffer for the duration
|
||||
// of fn and returns whatever was written.
|
||||
func captureLog(t *testing.T, fn func()) string {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
prevWriter := log.Writer()
|
||||
prevFlags := log.Flags()
|
||||
log.SetOutput(&buf)
|
||||
log.SetFlags(0) // strip date/time so assertions stay deterministic
|
||||
t.Cleanup(func() {
|
||||
log.SetOutput(prevWriter)
|
||||
log.SetFlags(prevFlags)
|
||||
})
|
||||
fn()
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func TestEvent_EmitsEvtPrefixAndJSONPayload(t *testing.T) {
|
||||
out := captureLog(t, func() {
|
||||
Event("provision.start", map[string]any{
|
||||
"workspace_id": "ws-123",
|
||||
"tier": 4,
|
||||
"runtime": "claude-code",
|
||||
})
|
||||
})
|
||||
out = strings.TrimSpace(out)
|
||||
if !strings.HasPrefix(out, "evt: provision.start ") {
|
||||
t.Fatalf("expected evt-prefixed line, got %q", out)
|
||||
}
|
||||
jsonPart := strings.TrimPrefix(out, "evt: provision.start ")
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonPart), &got); err != nil {
|
||||
t.Fatalf("payload not valid JSON: %v (raw=%q)", err, jsonPart)
|
||||
}
|
||||
if got["workspace_id"] != "ws-123" {
|
||||
t.Errorf("workspace_id field lost: %+v", got)
|
||||
}
|
||||
// JSON unmarshal turns numbers into float64 — exact-equal compare.
|
||||
if got["tier"].(float64) != 4 {
|
||||
t.Errorf("tier field lost: %+v", got)
|
||||
}
|
||||
if got["runtime"] != "claude-code" {
|
||||
t.Errorf("runtime field lost: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_NilFieldsEmitsEmptyObject(t *testing.T) {
|
||||
out := captureLog(t, func() {
|
||||
Event("restart.pre_stop", nil)
|
||||
})
|
||||
if !strings.Contains(out, "evt: restart.pre_stop {}") {
|
||||
t.Fatalf("nil fields should emit empty object, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_PreservesEventBoundaryOnUnmarshalableValue(t *testing.T) {
|
||||
// A channel cannot be marshaled by encoding/json — verify we still
|
||||
// emit the event boundary with a recorded marshal error. This is
|
||||
// the structural guarantee: the call site never sees a panic, and
|
||||
// the event name is always present in the log.
|
||||
out := captureLog(t, func() {
|
||||
Event("provision.ec2_started", map[string]any{
|
||||
"chan": make(chan int),
|
||||
})
|
||||
})
|
||||
if !strings.Contains(out, "evt: provision.ec2_started ") {
|
||||
t.Fatalf("event boundary missing on marshal error: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "_marshal_err") {
|
||||
t.Fatalf("expected _marshal_err sentinel, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_SingleLineOutput(t *testing.T) {
|
||||
// Log aggregators line-split on \n. A multi-line emit would silently
|
||||
// fragment the JSON across two records — pin single-line shape.
|
||||
out := captureLog(t, func() {
|
||||
Event("provision.skip_existing", map[string]any{
|
||||
"existing_id": "ws-abc",
|
||||
"name": "child-1",
|
||||
})
|
||||
})
|
||||
trimmed := strings.TrimRight(out, "\n")
|
||||
if strings.Contains(trimmed, "\n") {
|
||||
t.Fatalf("event line must be single-line, got %q", out)
|
||||
}
|
||||
}
|
||||
@@ -243,13 +243,15 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
|
||||
// entire platform. Gated behind AdminAuth (issue #180).
|
||||
r.GET("/approvals/pending", middleware.AdminAuth(db.DB), apph.ListAll)
|
||||
|
||||
// Team handlers — Collapse only. The bulk-Expand path is gone:
|
||||
// every workspace can have children via the regular CreateWorkspace
|
||||
// flow with parent_id set, so a separate handler that bulk-creates
|
||||
// from sub_workspaces (and was non-idempotent — calling it twice
|
||||
// duplicated the team) earned its way out.
|
||||
teamh := handlers.NewTeamHandler(broadcaster, wh, platformURL, configsDir)
|
||||
wsAuth.POST("/collapse", teamh.Collapse)
|
||||
// (TeamHandler is gone — #2864.) The visual canvas Collapse
|
||||
// button calls PATCH /workspaces/:id { collapsed: true/false }
|
||||
// (presentational toggle on canvas_layouts), NOT the destructive
|
||||
// POST /collapse that stopped + removed children. The
|
||||
// destructive route had zero UI callers (verified via grep
|
||||
// across canvas/, scripts/, and the MCP tool registry — only
|
||||
// docs referenced it). team.go + team_test.go + the route
|
||||
// + helpers (findTemplateDirByName, NewTeamHandler) are
|
||||
// deleted; visual collapse is unaffected.
|
||||
|
||||
// Agents
|
||||
ah := handlers.NewAgentHandler(broadcaster)
|
||||
@@ -519,8 +521,9 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
|
||||
r.GET("/canvas/viewport", vh.Get)
|
||||
r.PUT("/canvas/viewport", middleware.CanvasOrBearer(db.DB), vh.Save)
|
||||
|
||||
// Templates
|
||||
tmplh := handlers.NewTemplatesHandler(configsDir, dockerCli)
|
||||
// Templates — wh threaded so generateDefaultConfig picks the
|
||||
// SaaS-aware default tier in Import + ReplaceFiles (#2910 PR-B).
|
||||
tmplh := handlers.NewTemplatesHandler(configsDir, dockerCli, wh)
|
||||
// #686: GET /templates lists all template names+metadata from configsDir.
|
||||
// Open access lets unauthenticated callers enumerate org configurations and
|
||||
// installed plugins. AdminAuth-gate it alongside POST /templates/import.
|
||||
|
||||
+56
-111
@@ -325,115 +325,14 @@ async def tool_get_workspace_info(source_workspace_id: str | None = None) -> str
|
||||
return json.dumps(info, indent=2)
|
||||
|
||||
|
||||
async def tool_commit_memory(
|
||||
content: str,
|
||||
scope: str = "LOCAL",
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Save important information to persistent memory.
|
||||
|
||||
GLOBAL scope is writable only by root workspaces (tier == 0).
|
||||
RBAC memory.write permission is required for all scope levels.
|
||||
The source workspace_id is embedded in every record so the platform
|
||||
can enforce cross-workspace isolation and audit trail.
|
||||
|
||||
``source_workspace_id`` selects which registered workspace this
|
||||
memory belongs to when the agent is registered into multiple
|
||||
workspaces (PR-1 / multi-workspace mode). When unset, falls back
|
||||
to the module-level WORKSPACE_ID — single-workspace operators see
|
||||
no behaviour change.
|
||||
"""
|
||||
if not content:
|
||||
return "Error: content is required"
|
||||
content = _redact_secrets(content)
|
||||
scope = scope.upper()
|
||||
if scope not in ("LOCAL", "TEAM", "GLOBAL"):
|
||||
scope = "LOCAL"
|
||||
|
||||
# RBAC: require memory.write permission (mirrors builtin_tools/memory.py)
|
||||
if not _check_memory_write_permission():
|
||||
return (
|
||||
"Error: RBAC — this workspace does not have the 'memory.write' "
|
||||
"permission for this operation."
|
||||
)
|
||||
|
||||
# Scope enforcement: only root workspaces (tier 0) can write GLOBAL memory.
|
||||
# This prevents tenant workspaces from poisoning org-wide memory (GH#1610).
|
||||
if scope == "GLOBAL" and not _is_root_workspace():
|
||||
return (
|
||||
"Error: RBAC — only root workspaces (tier 0) can write to GLOBAL scope. "
|
||||
"Non-root workspaces may use LOCAL or TEAM scope."
|
||||
)
|
||||
|
||||
src = source_workspace_id or WORKSPACE_ID
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/memories",
|
||||
json={
|
||||
"content": content,
|
||||
"scope": scope,
|
||||
# Embed source workspace so the platform can namespace-isolate
|
||||
# and audit cross-workspace writes (GH#1610 fix).
|
||||
"workspace_id": src,
|
||||
},
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
data = resp.json()
|
||||
if resp.status_code in (200, 201):
|
||||
return json.dumps({"success": True, "id": data.get("id"), "scope": scope})
|
||||
return f"Error: {data.get('error', resp.text)}"
|
||||
except Exception as e:
|
||||
return f"Error saving memory: {e}"
|
||||
|
||||
|
||||
async def tool_recall_memory(
|
||||
query: str = "",
|
||||
scope: str = "",
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Search persistent memory for previously saved information.
|
||||
|
||||
RBAC memory.read permission is required (mirrors builtin_tools/memory.py).
|
||||
The workspace_id is sent as a query parameter so the platform can
|
||||
cross-validate it against the auth token and defend against any future
|
||||
path traversal / cross-tenant read bugs in the platform itself.
|
||||
|
||||
``source_workspace_id`` selects which registered workspace's memories
|
||||
to search when the agent is registered into multiple workspaces.
|
||||
Unset → defaults to the module-level WORKSPACE_ID.
|
||||
"""
|
||||
# RBAC: require memory.read permission (mirrors builtin_tools/memory.py)
|
||||
if not _check_memory_read_permission():
|
||||
return (
|
||||
"Error: RBAC — this workspace does not have the 'memory.read' "
|
||||
"permission for this operation."
|
||||
)
|
||||
|
||||
src = source_workspace_id or WORKSPACE_ID
|
||||
params: dict[str, str] = {"workspace_id": src}
|
||||
if query:
|
||||
params["q"] = query
|
||||
if scope:
|
||||
params["scope"] = scope.upper()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/memories",
|
||||
params=params,
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
data = resp.json()
|
||||
if isinstance(data, list):
|
||||
if not data:
|
||||
return "No memories found."
|
||||
lines = []
|
||||
for m in data:
|
||||
lines.append(f"[{m.get('scope', '?')}] {m.get('content', '')}")
|
||||
return "\n".join(lines)
|
||||
return json.dumps(data)
|
||||
except Exception as e:
|
||||
return f"Error recalling memory: {e}"
|
||||
# Memory tool handlers — extracted to a2a_tools_memory (RFC #2873 iter 4c).
|
||||
# Re-imported here so call sites + tests that reference
|
||||
# ``a2a_tools.tool_commit_memory`` / ``tool_recall_memory`` keep
|
||||
# resolving identically.
|
||||
from a2a_tools_memory import ( # noqa: E402 (import after the top-of-module imports)
|
||||
tool_commit_memory,
|
||||
tool_recall_memory,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -550,6 +449,52 @@ async def tool_chat_history(
|
||||
return json.dumps(rows)
|
||||
|
||||
|
||||
def _enrich_inbound_for_agent(d: dict) -> dict:
|
||||
"""Add peer_name / peer_role / agent_card_url to a poll-path message.
|
||||
|
||||
The PUSH path (a2a_mcp_server._build_channel_notification) already
|
||||
enriches the meta dict with these fields, so a Claude Code host
|
||||
with channel-push sees them. The POLL path goes through
|
||||
InboxMessage.to_dict, which is intentionally identity-free (the
|
||||
storage layer doesn't know about the registry cache). Without this
|
||||
helper, every non-Claude-Code MCP client that uses inbox_peek /
|
||||
wait_for_message gets a plain message and the receiving agent
|
||||
can't tell who's writing — breaking the contract documented in
|
||||
a2a_mcp_server.py:303-345 ("In both paths the same fields apply").
|
||||
|
||||
Cache-first non-blocking enrichment (same shape as push): on cache
|
||||
miss the helper returns the bare message; the next call within the
|
||||
5-min TTL hits the warm cache. Failure to enrich is non-fatal —
|
||||
the agent still gets text + peer_id + kind + activity_id, just
|
||||
without the friendly identity.
|
||||
"""
|
||||
peer_id = d.get("peer_id") or ""
|
||||
if not peer_id:
|
||||
# canvas_user — no peer to enrich; helper returns the plain
|
||||
# message unchanged so the canvas reply path still works.
|
||||
return d
|
||||
try:
|
||||
from a2a_client import ( # local import — avoid module-load cycle
|
||||
_agent_card_url_for,
|
||||
enrich_peer_metadata_nonblocking,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
# If a2a_client is unavailable (test harness, partial install),
|
||||
# degrade gracefully — agent still gets the bare envelope.
|
||||
return d
|
||||
record = enrich_peer_metadata_nonblocking(peer_id)
|
||||
if record is not None:
|
||||
if name := record.get("name"):
|
||||
d["peer_name"] = name
|
||||
if role := record.get("role"):
|
||||
d["peer_role"] = role
|
||||
# agent_card_url is constructable from peer_id alone — surface it
|
||||
# even when registry enrichment misses, so the receiving agent has
|
||||
# a single endpoint to hit for the peer's full capability list.
|
||||
d["agent_card_url"] = _agent_card_url_for(peer_id)
|
||||
return d
|
||||
|
||||
|
||||
async def tool_inbox_peek(limit: int = 10) -> str:
|
||||
"""Return up to ``limit`` pending inbound messages without removing them."""
|
||||
import inbox # local import — avoids a circular dep at module load
|
||||
@@ -558,7 +503,7 @@ async def tool_inbox_peek(limit: int = 10) -> str:
|
||||
if state is None:
|
||||
return _INBOX_NOT_ENABLED_MSG
|
||||
messages = state.peek(limit=limit if isinstance(limit, int) else 10)
|
||||
return json.dumps([m.to_dict() for m in messages])
|
||||
return json.dumps([_enrich_inbound_for_agent(m.to_dict()) for m in messages])
|
||||
|
||||
|
||||
async def tool_inbox_pop(activity_id: str) -> str:
|
||||
@@ -606,4 +551,4 @@ async def tool_wait_for_message(timeout_secs: float = 60.0) -> str:
|
||||
message = await loop.run_in_executor(None, state.wait, timeout)
|
||||
if message is None:
|
||||
return json.dumps({"timeout": True, "timeout_secs": timeout})
|
||||
return json.dumps(message.to_dict())
|
||||
return json.dumps(_enrich_inbound_for_agent(message.to_dict()))
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
"""Memory tool handlers — single-concern slice of the a2a_tools surface.
|
||||
|
||||
Extracted from ``a2a_tools.py`` (RFC #2873 iter 4c). Owns the two
|
||||
agent-memory MCP tools:
|
||||
|
||||
* ``tool_commit_memory`` — write to the workspace's persistent memory.
|
||||
* ``tool_recall_memory`` — search the workspace's persistent memory.
|
||||
|
||||
Both go through the platform's ``/workspaces/:id/memories`` endpoint;
|
||||
the platform is the source of truth for namespace isolation + audit
|
||||
trail. Local responsibility here is RBAC enforcement BEFORE hitting
|
||||
the network so a denied operation surfaces a clear in-band error
|
||||
instead of an opaque platform 403.
|
||||
|
||||
Imports the RBAC primitives from ``a2a_tools_rbac`` (iter 4a).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from a2a_client import PLATFORM_URL, WORKSPACE_ID
|
||||
from a2a_tools_rbac import (
|
||||
auth_headers_for_heartbeat as _auth_headers_for_heartbeat,
|
||||
check_memory_read_permission as _check_memory_read_permission,
|
||||
check_memory_write_permission as _check_memory_write_permission,
|
||||
is_root_workspace as _is_root_workspace,
|
||||
)
|
||||
from builtin_tools.security import _redact_secrets
|
||||
|
||||
|
||||
async def tool_commit_memory(
|
||||
content: str,
|
||||
scope: str = "LOCAL",
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Save important information to persistent memory.
|
||||
|
||||
GLOBAL scope is writable only by root workspaces (tier == 0).
|
||||
RBAC memory.write permission is required for all scope levels.
|
||||
The source workspace_id is embedded in every record so the platform
|
||||
can enforce cross-workspace isolation and audit trail.
|
||||
|
||||
``source_workspace_id`` selects which registered workspace this
|
||||
memory belongs to when the agent is registered into multiple
|
||||
workspaces (PR-1 / multi-workspace mode). When unset, falls back
|
||||
to the module-level WORKSPACE_ID — single-workspace operators see
|
||||
no behaviour change.
|
||||
"""
|
||||
if not content:
|
||||
return "Error: content is required"
|
||||
content = _redact_secrets(content)
|
||||
scope = scope.upper()
|
||||
if scope not in ("LOCAL", "TEAM", "GLOBAL"):
|
||||
scope = "LOCAL"
|
||||
|
||||
# RBAC: require memory.write permission (mirrors builtin_tools/memory.py)
|
||||
if not _check_memory_write_permission():
|
||||
return (
|
||||
"Error: RBAC — this workspace does not have the 'memory.write' "
|
||||
"permission for this operation."
|
||||
)
|
||||
|
||||
# Scope enforcement: only root workspaces (tier 0) can write GLOBAL memory.
|
||||
# This prevents tenant workspaces from poisoning org-wide memory (GH#1610).
|
||||
if scope == "GLOBAL" and not _is_root_workspace():
|
||||
return (
|
||||
"Error: RBAC — only root workspaces (tier 0) can write to GLOBAL scope. "
|
||||
"Non-root workspaces may use LOCAL or TEAM scope."
|
||||
)
|
||||
|
||||
src = source_workspace_id or WORKSPACE_ID
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/memories",
|
||||
json={
|
||||
"content": content,
|
||||
"scope": scope,
|
||||
# Embed source workspace so the platform can namespace-isolate
|
||||
# and audit cross-workspace writes (GH#1610 fix).
|
||||
"workspace_id": src,
|
||||
},
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
data = resp.json()
|
||||
if resp.status_code in (200, 201):
|
||||
return json.dumps({"success": True, "id": data.get("id"), "scope": scope})
|
||||
return f"Error: {data.get('error', resp.text)}"
|
||||
except Exception as e:
|
||||
return f"Error saving memory: {e}"
|
||||
|
||||
|
||||
async def tool_recall_memory(
|
||||
query: str = "",
|
||||
scope: str = "",
|
||||
source_workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Search persistent memory for previously saved information.
|
||||
|
||||
RBAC memory.read permission is required (mirrors builtin_tools/memory.py).
|
||||
The workspace_id is sent as a query parameter so the platform can
|
||||
cross-validate it against the auth token and defend against any future
|
||||
path traversal / cross-tenant read bugs in the platform itself.
|
||||
|
||||
``source_workspace_id`` selects which registered workspace's memories
|
||||
to search when the agent is registered into multiple workspaces.
|
||||
Unset → defaults to the module-level WORKSPACE_ID.
|
||||
"""
|
||||
# RBAC: require memory.read permission (mirrors builtin_tools/memory.py)
|
||||
if not _check_memory_read_permission():
|
||||
return (
|
||||
"Error: RBAC — this workspace does not have the 'memory.read' "
|
||||
"permission for this operation."
|
||||
)
|
||||
|
||||
src = source_workspace_id or WORKSPACE_ID
|
||||
params: dict[str, str] = {"workspace_id": src}
|
||||
if query:
|
||||
params["q"] = query
|
||||
if scope:
|
||||
params["scope"] = scope.upper()
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(
|
||||
f"{PLATFORM_URL}/workspaces/{src}/memories",
|
||||
params=params,
|
||||
headers=_auth_headers_for_heartbeat(src),
|
||||
)
|
||||
data = resp.json()
|
||||
if isinstance(data, list):
|
||||
if not data:
|
||||
return "No memories found."
|
||||
lines = []
|
||||
for m in data:
|
||||
lines.append(f"[{m.get('scope', '?')}] {m.get('content', '')}")
|
||||
return "\n".join(lines)
|
||||
return json.dumps(data)
|
||||
except Exception as e:
|
||||
return f"Error recalling memory: {e}"
|
||||
+44
-8
@@ -553,10 +553,26 @@ def _poll_once(
|
||||
# Imported lazily at use-site so a runtime that never sees an
|
||||
# upload-receive row never imports the module. Cheap on the hot
|
||||
# path because Python caches the import.
|
||||
from inbox_uploads import is_chat_upload_row, fetch_and_stage
|
||||
from inbox_uploads import is_chat_upload_row, BatchFetcher
|
||||
|
||||
new_count = 0
|
||||
last_id: str | None = None
|
||||
# ``batch_fetcher`` is lazy: a poll batch with no upload rows pays
|
||||
# zero overhead. Once the first upload row appears we open one
|
||||
# BatchFetcher and submit every subsequent upload row to its thread
|
||||
# pool; before processing the FIRST non-upload row we drain the
|
||||
# pool (wait_all) so the URI cache is hot when message rewriting
|
||||
# runs. Without the barrier, the chat message that references the
|
||||
# upload would arrive at the agent with the un-rewritten
|
||||
# platform-pending: URI.
|
||||
batch_fetcher: BatchFetcher | None = None
|
||||
|
||||
def _drain_uploads(bf: BatchFetcher | None) -> None:
|
||||
if bf is None:
|
||||
return
|
||||
bf.wait_all()
|
||||
bf.close()
|
||||
|
||||
for row in rows:
|
||||
if not isinstance(row, dict):
|
||||
continue
|
||||
@@ -570,14 +586,21 @@ def _poll_once(
|
||||
# message_from_activity. We DO advance the cursor past
|
||||
# this row so a permanent network outage on /content
|
||||
# doesn't stall the cursor and block real chat traffic.
|
||||
fetch_and_stage(
|
||||
row,
|
||||
platform_url=platform_url,
|
||||
workspace_id=workspace_id,
|
||||
headers=headers,
|
||||
)
|
||||
if batch_fetcher is None:
|
||||
batch_fetcher = BatchFetcher(
|
||||
platform_url=platform_url,
|
||||
workspace_id=workspace_id,
|
||||
headers=headers,
|
||||
)
|
||||
batch_fetcher.submit(row)
|
||||
last_id = str(row.get("id", "")) or last_id
|
||||
continue
|
||||
# Non-upload row: drain any pending uploads first so the URI
|
||||
# cache is populated before we run rewrite_request_body /
|
||||
# message_from_activity on a row that may reference one.
|
||||
if batch_fetcher is not None:
|
||||
_drain_uploads(batch_fetcher)
|
||||
batch_fetcher = None
|
||||
if _is_self_notify_row(row):
|
||||
# The workspace-server's `/notify` handler writes the agent's
|
||||
# own send_message_to_user POSTs to activity_logs with
|
||||
@@ -612,6 +635,13 @@ def _poll_once(
|
||||
last_id = message.activity_id
|
||||
new_count += 1
|
||||
|
||||
# Drain any uploads still in flight if the batch ended with upload
|
||||
# rows (no chat-message row to trigger the inline drain). Without
|
||||
# this, a future poll that picks up the chat-message row first
|
||||
# would race with the still-running fetches.
|
||||
if batch_fetcher is not None:
|
||||
_drain_uploads(batch_fetcher)
|
||||
|
||||
if last_id is not None:
|
||||
state.save_cursor(last_id, cursor_key)
|
||||
return new_count
|
||||
@@ -654,6 +684,7 @@ def start_poller_thread(
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
interval: float = POLL_INTERVAL_SECONDS,
|
||||
stop_event: threading.Event | None = None,
|
||||
) -> threading.Thread:
|
||||
"""Spawn the poller as a daemon thread. Returns the Thread handle.
|
||||
|
||||
@@ -665,13 +696,18 @@ def start_poller_thread(
|
||||
operator running ``ps -eL`` or eyeballing ``threading.enumerate()``
|
||||
can tell which thread is which without reverse-engineering it from
|
||||
crash tracebacks.
|
||||
|
||||
Pass ``stop_event`` to enable graceful shutdown — used by tests so
|
||||
the daemon thread doesn't outlive the test that started it and race
|
||||
with later tests' httpx patches. Production code passes None and
|
||||
relies on the daemon flag for process-exit cleanup.
|
||||
"""
|
||||
name = "molecule-mcp-inbox-poller"
|
||||
if workspace_id:
|
||||
name = f"{name}-{workspace_id[:8]}"
|
||||
t = threading.Thread(
|
||||
target=_poll_loop,
|
||||
args=(state, platform_url, workspace_id, interval),
|
||||
args=(state, platform_url, workspace_id, interval, stop_event),
|
||||
name=name,
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
+264
-15
@@ -37,6 +37,7 @@ read another tenant's bytes even if a token is misrouted.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
@@ -68,6 +69,24 @@ MAX_FILE_BYTES = 25 * 1024 * 1024
|
||||
# 10s default for /activity calls — both are user-perceived latency.
|
||||
DEFAULT_FETCH_TIMEOUT = 60.0
|
||||
|
||||
# Concurrency cap for ``BatchFetcher``. Four workers is enough headroom
|
||||
# for the realistic "user dragged 3-4 files into chat at once" case
|
||||
# while bounding the platform's per-workspace fan-out. The cap matters
|
||||
# because the platform's /content endpoint reads bytea from Postgres in
|
||||
# a single round-trip per request — N workers = N concurrent DB reads
|
||||
# of up to 25 MB each, so a higher cap could pressure platform memory
|
||||
# without much UX win (network bandwidth is the bottleneck once the
|
||||
# bytes are buffered).
|
||||
DEFAULT_BATCH_FETCH_WORKERS = 4
|
||||
|
||||
# Upper bound on how long ``BatchFetcher.wait_all`` blocks the inbox
|
||||
# poll loop before giving up on still-in-flight fetches. Aligned with
|
||||
# DEFAULT_FETCH_TIMEOUT so a single hung fetch can't stall the loop
|
||||
# longer than its own deadline. A timeout fires only if a worker thread
|
||||
# is stuck past the underlying httpx timeout — pathological case;
|
||||
# normal completion is bounded by per-fetch timeout × ceil(N/W).
|
||||
DEFAULT_BATCH_WAIT_TIMEOUT = DEFAULT_FETCH_TIMEOUT + 5.0
|
||||
|
||||
# Cap on the URI cache. A long-lived workspace handling thousands of
|
||||
# uploads shouldn't grow without bound; an LRU cap of 1024 keeps the
|
||||
# entries-needed-for-a-typical-conversation well within memory.
|
||||
@@ -275,6 +294,7 @@ def fetch_and_stage(
|
||||
workspace_id: str,
|
||||
headers: dict[str, str],
|
||||
timeout_secs: float = DEFAULT_FETCH_TIMEOUT,
|
||||
client: Any = None,
|
||||
) -> str | None:
|
||||
"""Fetch the row's bytes, stage them under chat-uploads, and ack.
|
||||
|
||||
@@ -289,6 +309,11 @@ def fetch_and_stage(
|
||||
On success, the URI cache is updated so a subsequent chat message
|
||||
referencing the same ``platform-pending:`` URI is rewritten before
|
||||
the agent sees it.
|
||||
|
||||
Pass ``client`` to reuse a shared ``httpx.Client`` for both GET and
|
||||
POST ack (saves one TLS handshake per row vs. constructing one
|
||||
per-call). ``BatchFetcher`` does this across an entire poll batch so
|
||||
N concurrent fetches share one connection pool.
|
||||
"""
|
||||
body = _request_body_dict(row)
|
||||
if body is None:
|
||||
@@ -317,25 +342,58 @@ def fetch_and_stage(
|
||||
if not isinstance(filename, str):
|
||||
filename = "file"
|
||||
|
||||
# Lazy httpx import: the standalone MCP path uses httpx; an in-
|
||||
# container caller that imports this module by accident shouldn't
|
||||
# explode at import time.
|
||||
try:
|
||||
import httpx # noqa: WPS433
|
||||
except ImportError:
|
||||
logger.error("inbox_uploads: httpx not installed; cannot fetch %s", file_id)
|
||||
return None
|
||||
# Caller-supplied client: reuse for both GET + POST ack. Otherwise
|
||||
# build a one-shot client and close it on the way out. Lazy httpx
|
||||
# import keeps the standalone MCP path's optional dep optional.
|
||||
own_client = client is None
|
||||
if own_client:
|
||||
try:
|
||||
import httpx # noqa: WPS433
|
||||
except ImportError:
|
||||
logger.error("inbox_uploads: httpx not installed; cannot fetch %s", file_id)
|
||||
return None
|
||||
client = httpx.Client(timeout=timeout_secs)
|
||||
|
||||
try:
|
||||
return _fetch_and_stage_with_client(
|
||||
client,
|
||||
platform_url=platform_url,
|
||||
workspace_id=workspace_id,
|
||||
headers=headers,
|
||||
file_id=file_id,
|
||||
pending_uri=pending_uri,
|
||||
filename=filename,
|
||||
body=body,
|
||||
)
|
||||
finally:
|
||||
if own_client:
|
||||
try:
|
||||
client.close()
|
||||
except Exception: # noqa: BLE001 — close should never crash the caller
|
||||
pass
|
||||
|
||||
|
||||
def _fetch_and_stage_with_client(
|
||||
client: Any,
|
||||
*,
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
headers: dict[str, str],
|
||||
file_id: str,
|
||||
pending_uri: str,
|
||||
filename: str,
|
||||
body: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""Inner body of fetch_and_stage. Always uses the supplied client for
|
||||
both GET and POST so the connection pool is shared across the call.
|
||||
"""
|
||||
content_url = f"{platform_url}/workspaces/{workspace_id}/pending-uploads/{file_id}/content"
|
||||
ack_url = f"{platform_url}/workspaces/{workspace_id}/pending-uploads/{file_id}/ack"
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=timeout_secs) as client:
|
||||
resp = client.get(content_url, headers=headers)
|
||||
resp = client.get(content_url, headers=headers)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"inbox_uploads: GET %s failed: %s", content_url, exc
|
||||
)
|
||||
logger.warning("inbox_uploads: GET %s failed: %s", content_url, exc)
|
||||
return None
|
||||
|
||||
if resp.status_code == 404:
|
||||
@@ -403,8 +461,7 @@ def fetch_and_stage(
|
||||
# back the on-disk file — the platform's sweep will clean up
|
||||
# eventually.
|
||||
try:
|
||||
with httpx.Client(timeout=timeout_secs) as client:
|
||||
ack_resp = client.post(ack_url, headers=headers)
|
||||
ack_resp = client.post(ack_url, headers=headers)
|
||||
if ack_resp.status_code >= 400:
|
||||
logger.warning(
|
||||
"inbox_uploads: ack %s returned %d: %s",
|
||||
@@ -418,6 +475,198 @@ def fetch_and_stage(
|
||||
return local_uri
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BatchFetcher — concurrent fetch across a single poll batch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BatchFetcher:
|
||||
"""Fetch + stage + ack a batch of upload-receive rows concurrently.
|
||||
|
||||
Why this exists: the inbox poll loop used to call ``fetch_and_stage``
|
||||
serially per row. With N upload rows in a batch (a user dragging
|
||||
multiple files into chat at once), the loop blocked for
|
||||
``N × per_fetch_latency`` before processing the chat message that
|
||||
referenced them — a 4-file upload at 5s each = 20s of stall
|
||||
before the agent saw the user's prompt. ``BatchFetcher`` runs the
|
||||
fetches on a small thread pool (default 4 workers) so the stall is
|
||||
bounded by ``ceil(N/W) × per_fetch_latency`` instead.
|
||||
|
||||
Connection reuse: one ``httpx.Client`` is shared across every fetch
|
||||
in the batch. httpx clients carry a connection pool, so a second
|
||||
fetch to the same platform host reuses the TCP+TLS handshake from
|
||||
the first — measurable win when fetches happen back-to-back.
|
||||
|
||||
Correctness invariant the caller MUST preserve: the inbox loop is
|
||||
expected to call ``wait_all()`` before processing the chat-message
|
||||
activity row that REFERENCES one of these uploads. Without the
|
||||
barrier, the URI cache is empty when ``rewrite_request_body`` runs
|
||||
and the agent sees the un-rewritten ``platform-pending:`` URI. The
|
||||
caller-side test ``test_poll_once_waits_for_uploads_before_messages``
|
||||
pins this end-to-end.
|
||||
|
||||
Use as a context manager so the executor + client are torn down
|
||||
even if the caller raises mid-batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
headers: dict[str, str],
|
||||
timeout_secs: float = DEFAULT_FETCH_TIMEOUT,
|
||||
max_workers: int = DEFAULT_BATCH_FETCH_WORKERS,
|
||||
client: Any = None,
|
||||
):
|
||||
self._platform_url = platform_url
|
||||
self._workspace_id = workspace_id
|
||||
self._headers = dict(headers) # copy so caller mutations don't leak in
|
||||
self._timeout_secs = timeout_secs
|
||||
|
||||
# Caller can inject a client (tests do this); production callers
|
||||
# let us build one. Track ownership so we only close ours.
|
||||
self._own_client = client is None
|
||||
if self._own_client:
|
||||
try:
|
||||
import httpx # noqa: WPS433
|
||||
except ImportError:
|
||||
# Match fetch_and_stage's behavior: log + degrade rather
|
||||
# than raising at construction time. submit() will then
|
||||
# return None for every row.
|
||||
logger.error("inbox_uploads: httpx not installed; BatchFetcher inert")
|
||||
self._client: Any = None
|
||||
else:
|
||||
self._client = httpx.Client(timeout=timeout_secs)
|
||||
else:
|
||||
self._client = client
|
||||
|
||||
self._executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=max_workers,
|
||||
thread_name_prefix="upload-fetch",
|
||||
)
|
||||
self._futures: list[concurrent.futures.Future[Any]] = []
|
||||
self._closed = False
|
||||
# Flipped to True by wait_all when the timeout fires; close()
|
||||
# reads this to decide between drain-and-wait vs cancel-queued.
|
||||
self._timed_out = False
|
||||
|
||||
def submit(self, row: dict[str, Any]) -> concurrent.futures.Future[Any] | None:
|
||||
"""Submit ``row`` for fetch + stage + ack. Non-blocking — the
|
||||
worker thread runs ``fetch_and_stage`` with the shared client.
|
||||
|
||||
Returns the Future so a caller that wants per-row outcome can
|
||||
await it; ``None`` if the BatchFetcher is in a degraded state
|
||||
(httpx missing).
|
||||
"""
|
||||
if self._closed:
|
||||
raise RuntimeError("BatchFetcher: submit after close")
|
||||
if self._client is None:
|
||||
return None
|
||||
fut = self._executor.submit(
|
||||
fetch_and_stage,
|
||||
row,
|
||||
platform_url=self._platform_url,
|
||||
workspace_id=self._workspace_id,
|
||||
headers=self._headers,
|
||||
timeout_secs=self._timeout_secs,
|
||||
client=self._client,
|
||||
)
|
||||
self._futures.append(fut)
|
||||
return fut
|
||||
|
||||
def wait_all(self, timeout: float | None = DEFAULT_BATCH_WAIT_TIMEOUT) -> None:
|
||||
"""Block until every submitted future completes (or times out).
|
||||
|
||||
Per-future exceptions are logged + swallowed — ``fetch_and_stage``
|
||||
already converts every error path to ``return None``, so a real
|
||||
exception propagating up to here is unexpected and we don't want
|
||||
one bad fetch to abort the whole batch.
|
||||
|
||||
Timeouts are also logged + swallowed AND record the timed-out
|
||||
futures on ``self._timed_out`` so ``close`` can cancel them
|
||||
without paying their full latency. Without this hand-off,
|
||||
``close()``'s ``shutdown(wait=True)`` would block on the leaked
|
||||
workers and undo the user-facing timeout — the inbox poll loop
|
||||
would stall indefinitely on a hung /content fetch.
|
||||
"""
|
||||
if not self._futures:
|
||||
return
|
||||
try:
|
||||
done, not_done = concurrent.futures.wait(
|
||||
self._futures,
|
||||
timeout=timeout,
|
||||
return_when=concurrent.futures.ALL_COMPLETED,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001 — concurrent.futures shouldn't raise here
|
||||
logger.warning("inbox_uploads: BatchFetcher.wait_all crashed: %s", exc)
|
||||
return
|
||||
for fut in done:
|
||||
exc = fut.exception()
|
||||
if exc is not None:
|
||||
logger.warning(
|
||||
"inbox_uploads: BatchFetcher worker raised: %s", exc
|
||||
)
|
||||
if not_done:
|
||||
logger.warning(
|
||||
"inbox_uploads: BatchFetcher.wait_all left %d in-flight after %ss timeout",
|
||||
len(not_done),
|
||||
timeout,
|
||||
)
|
||||
# Mark these futures so close() knows to cancel-not-wait. We
|
||||
# cancel queued-but-not-started ones immediately; futures
|
||||
# already running can't be cancelled (Python's threading
|
||||
# model), but close() will pass cancel_futures=True so any
|
||||
# remaining queued items don't run.
|
||||
for fut in not_done:
|
||||
fut.cancel()
|
||||
self._timed_out = True
|
||||
|
||||
def close(self) -> None:
|
||||
"""Tear down the executor + (if owned) the httpx client.
|
||||
|
||||
Idempotent. After close, ``submit`` raises and the BatchFetcher
|
||||
cannot be reused — construct a fresh one for the next poll.
|
||||
|
||||
If ``wait_all`` reported a timeout, shutdown skips the
|
||||
``wait=True`` drain and instead asks the executor to drop queued
|
||||
futures (``cancel_futures=True``). Currently-running workers
|
||||
can't be interrupted by Python's threading model, but the poll
|
||||
loop returns immediately rather than blocking on a hung fetch.
|
||||
"""
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
timed_out = getattr(self, "_timed_out", False)
|
||||
try:
|
||||
if timed_out:
|
||||
# cancel_futures landed in Python 3.9 — guarded for older
|
||||
# interpreters via a TypeError fallback. Drop queued
|
||||
# tasks; running ones will exit when their httpx call
|
||||
# eventually returns or the daemon thread dies.
|
||||
try:
|
||||
self._executor.shutdown(wait=False, cancel_futures=True)
|
||||
except TypeError:
|
||||
self._executor.shutdown(wait=False)
|
||||
else:
|
||||
# Healthy path: wait for in-flight work so we don't
|
||||
# interrupt a fetch mid-write.
|
||||
self._executor.shutdown(wait=True)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("inbox_uploads: executor shutdown error: %s", exc)
|
||||
if self._own_client and self._client is not None:
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("inbox_uploads: client close error: %s", exc)
|
||||
|
||||
def __enter__(self) -> "BatchFetcher":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URI rewrite for incoming chat messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -702,9 +702,9 @@ class TestToolCommitMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_resp=_resp(201, {"id": "mem-1"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools._is_root_workspace", return_value=False):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools_memory._is_root_workspace", return_value=False):
|
||||
result = await a2a_tools.tool_commit_memory("Remember this", scope="local")
|
||||
|
||||
data = json.loads(result)
|
||||
@@ -716,9 +716,9 @@ class TestToolCommitMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_resp=_resp(200, {"id": "mem-2"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools._is_root_workspace", return_value=False):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools_memory._is_root_workspace", return_value=False):
|
||||
result = await a2a_tools.tool_commit_memory("Remember this", scope="INVALID")
|
||||
|
||||
data = json.loads(result)
|
||||
@@ -728,9 +728,9 @@ class TestToolCommitMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_resp=_resp(200, {"id": "mem-3"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools._is_root_workspace", return_value=False):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools_memory._is_root_workspace", return_value=False):
|
||||
result = await a2a_tools.tool_commit_memory("Team info", scope="TEAM")
|
||||
|
||||
data = json.loads(result)
|
||||
@@ -741,9 +741,9 @@ class TestToolCommitMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_resp=_resp(201, {"id": "mem-4"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools._is_root_workspace", return_value=True):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools_memory._is_root_workspace", return_value=True):
|
||||
result = await a2a_tools.tool_commit_memory("Global info", scope="GLOBAL")
|
||||
|
||||
data = json.loads(result)
|
||||
@@ -753,9 +753,9 @@ class TestToolCommitMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_resp=_resp(200, {"id": "mem-5"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools._is_root_workspace", return_value=False):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools_memory._is_root_workspace", return_value=False):
|
||||
result = await a2a_tools.tool_commit_memory("info")
|
||||
|
||||
data = json.loads(result)
|
||||
@@ -766,9 +766,9 @@ class TestToolCommitMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_resp=_resp(201, {"id": "mem-6"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools._is_root_workspace", return_value=False):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools_memory._is_root_workspace", return_value=False):
|
||||
result = await a2a_tools.tool_commit_memory("info")
|
||||
|
||||
data = json.loads(result)
|
||||
@@ -779,9 +779,9 @@ class TestToolCommitMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_resp=_resp(400, {"error": "bad request payload"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools._is_root_workspace", return_value=False):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools_memory._is_root_workspace", return_value=False):
|
||||
result = await a2a_tools.tool_commit_memory("info")
|
||||
|
||||
assert "Error" in result
|
||||
@@ -791,9 +791,9 @@ class TestToolCommitMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_exc=RuntimeError("storage failure"))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools._is_root_workspace", return_value=False):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools_memory._is_root_workspace", return_value=False):
|
||||
result = await a2a_tools.tool_commit_memory("info")
|
||||
|
||||
assert "Error saving memory" in result
|
||||
@@ -808,9 +808,9 @@ class TestToolCommitMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_resp=_resp(201, {"id": "mem-poison"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools._is_root_workspace", return_value=False):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools_memory._is_root_workspace", return_value=False):
|
||||
result = await a2a_tools.tool_commit_memory("poisoned GLOBAL memory", scope="GLOBAL")
|
||||
|
||||
# Must NOT have called the platform — early rejection
|
||||
@@ -824,9 +824,9 @@ class TestToolCommitMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_resp=_resp(201, {"id": "mem-7"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_write_permission", return_value=False), \
|
||||
patch("a2a_tools._is_root_workspace", return_value=False):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_write_permission", return_value=False), \
|
||||
patch("a2a_tools_memory._is_root_workspace", return_value=False):
|
||||
result = await a2a_tools.tool_commit_memory("should be denied", scope="LOCAL")
|
||||
|
||||
mc.post.assert_not_called()
|
||||
@@ -838,9 +838,9 @@ class TestToolCommitMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(post_resp=_resp(201, {"id": "mem-8"}))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools._is_root_workspace", return_value=False):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_write_permission", return_value=True), \
|
||||
patch("a2a_tools_memory._is_root_workspace", return_value=False):
|
||||
await a2a_tools.tool_commit_memory("test content", scope="LOCAL")
|
||||
|
||||
call_kwargs = mc.post.call_args.kwargs
|
||||
@@ -865,8 +865,8 @@ class TestToolRecallMemory:
|
||||
{"scope": "TEAM", "content": "We use Python 3.11"},
|
||||
]
|
||||
mc = _make_http_mock(get_resp=_resp(200, memories))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_read_permission", return_value=True):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_read_permission", return_value=True):
|
||||
result = await a2a_tools.tool_recall_memory(query="capital")
|
||||
|
||||
assert "[LOCAL]" in result
|
||||
@@ -878,8 +878,8 @@ class TestToolRecallMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(get_resp=_resp(200, []))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_read_permission", return_value=True):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_read_permission", return_value=True):
|
||||
result = await a2a_tools.tool_recall_memory(query="anything")
|
||||
|
||||
assert result == "No memories found."
|
||||
@@ -890,8 +890,8 @@ class TestToolRecallMemory:
|
||||
|
||||
payload = {"error": "search unavailable"}
|
||||
mc = _make_http_mock(get_resp=_resp(200, payload))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_read_permission", return_value=True):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_read_permission", return_value=True):
|
||||
result = await a2a_tools.tool_recall_memory()
|
||||
|
||||
parsed = json.loads(result)
|
||||
@@ -901,8 +901,8 @@ class TestToolRecallMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(get_exc=RuntimeError("search service down"))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_read_permission", return_value=True):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_read_permission", return_value=True):
|
||||
result = await a2a_tools.tool_recall_memory(query="test")
|
||||
|
||||
assert "Error recalling memory" in result
|
||||
@@ -913,8 +913,8 @@ class TestToolRecallMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(get_resp=_resp(200, []))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_read_permission", return_value=True):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_read_permission", return_value=True):
|
||||
await a2a_tools.tool_recall_memory(query="paris", scope="local")
|
||||
|
||||
call_kwargs = mc.get.call_args.kwargs
|
||||
@@ -928,8 +928,8 @@ class TestToolRecallMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(get_resp=_resp(200, []))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_read_permission", return_value=True):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_read_permission", return_value=True):
|
||||
await a2a_tools.tool_recall_memory()
|
||||
|
||||
call_kwargs = mc.get.call_args.kwargs
|
||||
@@ -942,8 +942,8 @@ class TestToolRecallMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(get_resp=_resp(200, []))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_read_permission", return_value=True):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_read_permission", return_value=True):
|
||||
await a2a_tools.tool_recall_memory(scope="team")
|
||||
|
||||
call_kwargs = mc.get.call_args.kwargs
|
||||
@@ -960,8 +960,8 @@ class TestToolRecallMemory:
|
||||
import a2a_tools
|
||||
|
||||
mc = _make_http_mock(get_resp=_resp(200, [{"scope": "GLOBAL", "content": "secret"}]))
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools._check_memory_read_permission", return_value=False):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=mc), \
|
||||
patch("a2a_tools_memory._check_memory_read_permission", return_value=False):
|
||||
result = await a2a_tools.tool_recall_memory(query="secret")
|
||||
|
||||
mc.get.assert_not_called()
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
"""Tests for `_enrich_inbound_for_agent` — the poll-path companion to
|
||||
the push-path enrichment in `a2a_mcp_server._build_channel_notification`.
|
||||
|
||||
The MCP poll path (inbox_peek / wait_for_message) returns
|
||||
`InboxMessage.to_dict()`, which has `activity_id, text, peer_id, kind,
|
||||
method, created_at` but NOT the registry-resolved `peer_name`,
|
||||
`peer_role`, or `agent_card_url`. The receiving agent then sees a
|
||||
plain message and can't tell who's writing — breaking the universal
|
||||
contract documented in `a2a_mcp_server.py:303-345` ("In both paths
|
||||
the same fields apply").
|
||||
|
||||
The enrichment helper closes that gap. These tests pin:
|
||||
- canvas_user (peer_id="") passes through unchanged
|
||||
- peer_agent with cache hit gets peer_name + peer_role + agent_card_url
|
||||
- peer_agent with cache miss still gets agent_card_url (constructable
|
||||
from peer_id alone)
|
||||
- a2a_client unavailable (test harness without registry) degrades
|
||||
gracefully — agent still gets the bare envelope
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
# a2a_client.py reads WORKSPACE_ID at import time and raises if it's
|
||||
# unset. Stamp a stub before any test pulls in a2a_tools (which transitively
|
||||
# imports a2a_client). conftest.py mocks the SDK but not this env var.
|
||||
os.environ.setdefault("WORKSPACE_ID", "00000000-0000-0000-0000-000000000001")
|
||||
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
PEER_UUID = "11111111-2222-3333-4444-555555555555"
|
||||
|
||||
|
||||
def test_canvas_user_passes_through_unchanged():
|
||||
from a2a_tools import _enrich_inbound_for_agent
|
||||
|
||||
base = {
|
||||
"activity_id": "act-1",
|
||||
"text": "hello from canvas",
|
||||
"peer_id": "",
|
||||
"kind": "canvas_user",
|
||||
"method": "message/send",
|
||||
"created_at": "2026-05-05T11:00:00Z",
|
||||
}
|
||||
|
||||
out = _enrich_inbound_for_agent(dict(base))
|
||||
|
||||
# Plain pass-through — no enrichment fields added for canvas_user.
|
||||
assert out == base
|
||||
assert "peer_name" not in out
|
||||
assert "peer_role" not in out
|
||||
assert "agent_card_url" not in out
|
||||
|
||||
|
||||
def test_peer_agent_cache_hit_adds_name_role_and_card_url():
|
||||
from a2a_tools import _enrich_inbound_for_agent
|
||||
|
||||
record = {"name": "ops-agent", "role": "sre"}
|
||||
card_url = f"https://platform.example/registry/{PEER_UUID}/agent-card"
|
||||
|
||||
with patch(
|
||||
"a2a_client.enrich_peer_metadata_nonblocking",
|
||||
return_value=record,
|
||||
), patch(
|
||||
"a2a_client._agent_card_url_for",
|
||||
return_value=card_url,
|
||||
):
|
||||
out = _enrich_inbound_for_agent({
|
||||
"activity_id": "act-2",
|
||||
"text": "ping",
|
||||
"peer_id": PEER_UUID,
|
||||
"kind": "peer_agent",
|
||||
"method": "message/send",
|
||||
"created_at": "2026-05-05T11:01:00Z",
|
||||
})
|
||||
|
||||
assert out["peer_name"] == "ops-agent"
|
||||
assert out["peer_role"] == "sre"
|
||||
assert out["agent_card_url"] == card_url
|
||||
|
||||
|
||||
def test_peer_agent_cache_miss_still_gets_agent_card_url():
|
||||
"""agent_card_url is constructable from peer_id alone — surface it
|
||||
even when registry enrichment misses, so the receiving agent has a
|
||||
single endpoint to hit for the peer's full capability list."""
|
||||
from a2a_tools import _enrich_inbound_for_agent
|
||||
|
||||
card_url = f"https://platform.example/registry/{PEER_UUID}/agent-card"
|
||||
|
||||
with patch(
|
||||
"a2a_client.enrich_peer_metadata_nonblocking",
|
||||
return_value=None, # cache miss
|
||||
), patch(
|
||||
"a2a_client._agent_card_url_for",
|
||||
return_value=card_url,
|
||||
):
|
||||
out = _enrich_inbound_for_agent({
|
||||
"activity_id": "act-3",
|
||||
"text": "ping",
|
||||
"peer_id": PEER_UUID,
|
||||
"kind": "peer_agent",
|
||||
"method": "message/send",
|
||||
"created_at": "2026-05-05T11:02:00Z",
|
||||
})
|
||||
|
||||
assert "peer_name" not in out
|
||||
assert "peer_role" not in out
|
||||
assert out["agent_card_url"] == card_url
|
||||
|
||||
|
||||
def test_peer_agent_a2a_client_unavailable_degrades_gracefully(monkeypatch):
|
||||
"""If a2a_client can't be imported (test harness, partial install),
|
||||
return the bare envelope — agent still gets text + peer_id + kind +
|
||||
activity_id, just without the friendly identity."""
|
||||
from a2a_tools import _enrich_inbound_for_agent
|
||||
|
||||
# Stub a2a_client import to fail.
|
||||
real_module = sys.modules.pop("a2a_client", None)
|
||||
fake = types.ModuleType("a2a_client")
|
||||
# Deliberately omit enrich_peer_metadata_nonblocking and
|
||||
# _agent_card_url_for so the helper's fallback path fires.
|
||||
sys.modules["a2a_client"] = fake
|
||||
|
||||
try:
|
||||
out = _enrich_inbound_for_agent({
|
||||
"activity_id": "act-4",
|
||||
"text": "ping",
|
||||
"peer_id": PEER_UUID,
|
||||
"kind": "peer_agent",
|
||||
"method": "message/send",
|
||||
"created_at": "2026-05-05T11:03:00Z",
|
||||
})
|
||||
finally:
|
||||
if real_module is not None:
|
||||
sys.modules["a2a_client"] = real_module
|
||||
else:
|
||||
sys.modules.pop("a2a_client", None)
|
||||
|
||||
# Bare envelope passes through — receiving agent still has enough
|
||||
# to act, even if the friendly identity is missing.
|
||||
assert out["peer_id"] == PEER_UUID
|
||||
assert out["text"] == "ping"
|
||||
assert out["kind"] == "peer_agent"
|
||||
assert "peer_name" not in out
|
||||
assert "peer_role" not in out
|
||||
assert "agent_card_url" not in out
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Drift gate + smoke tests for ``a2a_tools_memory`` (RFC #2873 iter 4c).
|
||||
|
||||
The full behavior matrix (RBAC denies, scope enforcement, platform
|
||||
HTTP error paths) lives in ``test_a2a_tools_impl.py`` (TestToolCommitMemory
|
||||
+ TestToolRecallMemory) which patches `a2a_tools_memory.foo` after the
|
||||
iter 4c retarget.
|
||||
|
||||
This file pins:
|
||||
|
||||
1. **Drift gate** — every previously-public symbol on ``a2a_tools``
|
||||
(``tool_commit_memory``, ``tool_recall_memory``) is the EXACT same
|
||||
callable as ``a2a_tools_memory.foo``. Refactor wrapping silently
|
||||
loses the existing test coverage; this gate makes that drift fail
|
||||
fast.
|
||||
2. **Import contract** — ``a2a_tools_memory`` does NOT pull in
|
||||
``a2a_tools`` at module-load time. The handlers depend on
|
||||
``a2a_tools_rbac`` (the layered architecture) and ``a2a_client``,
|
||||
not on the kitchen-sink module that re-exports them.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _require_workspace_id(monkeypatch):
|
||||
monkeypatch.setenv("WORKSPACE_ID", "00000000-0000-0000-0000-000000000000")
|
||||
monkeypatch.setenv("PLATFORM_URL", "http://test.invalid")
|
||||
yield
|
||||
|
||||
|
||||
# ============== Drift gate ==============
|
||||
|
||||
class TestBackCompatAliases:
|
||||
def test_tool_commit_memory_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_memory
|
||||
assert a2a_tools.tool_commit_memory is a2a_tools_memory.tool_commit_memory
|
||||
|
||||
def test_tool_recall_memory_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_memory
|
||||
assert a2a_tools.tool_recall_memory is a2a_tools_memory.tool_recall_memory
|
||||
|
||||
|
||||
# ============== Import contract ==============
|
||||
|
||||
class TestImportContract:
|
||||
def test_memory_module_does_not_load_a2a_tools(self, monkeypatch):
|
||||
"""`a2a_tools_memory` must depend on `a2a_tools_rbac` (the layered
|
||||
architecture) and `a2a_client`, NEVER on the kitchen-sink
|
||||
`a2a_tools`. Top-level `from a2a_tools import …` would defeat
|
||||
the modularization goal and risk a circular-import."""
|
||||
# Drop both modules to control import order
|
||||
for m in ("a2a_tools", "a2a_tools_memory"):
|
||||
sys.modules.pop(m, None)
|
||||
|
||||
# Import memory module. Should succeed without a2a_tools loaded.
|
||||
import a2a_tools_memory # noqa: F401
|
||||
assert "a2a_tools_memory" in sys.modules
|
||||
|
||||
def test_a2a_tools_re_exports_memory_handlers(self):
|
||||
"""The opposite direction: a2a_tools must surface every memory
|
||||
symbol so existing call sites + tests work unchanged."""
|
||||
import a2a_tools
|
||||
assert hasattr(a2a_tools, "tool_commit_memory")
|
||||
assert hasattr(a2a_tools, "tool_recall_memory")
|
||||
@@ -555,16 +555,34 @@ def test_poll_once_self_notify_does_not_fire_notification(state: inbox.InboxStat
|
||||
def test_start_poller_thread_is_daemon(state: inbox.InboxState):
|
||||
"""Daemon flag is required so the poller dies with the parent
|
||||
process; a non-daemon poller would leak across `claude` restarts
|
||||
and write to a stale workspace."""
|
||||
and write to a stale workspace.
|
||||
|
||||
Stop_event is plumbed so the thread cleans up at the end of the
|
||||
test instead of leaking into later tests. Without cleanup, the
|
||||
daemon's ~10ms tick races with later tests that patch httpx.Client
|
||||
— the leaked thread sees their patched response and runs an
|
||||
unwanted iteration of _poll_once that double-counts mocked calls
|
||||
(caught when test_batch_fetcher_owns_client_when_not_supplied
|
||||
surfaced this on Python 3.11 CI but not 3.13 local).
|
||||
"""
|
||||
resp = _make_response(200, [])
|
||||
p, _ = _patch_httpx(resp)
|
||||
stop_event = threading.Event()
|
||||
with p, patch("platform_auth.auth_headers", return_value={}):
|
||||
# Use a very short interval so the loop body runs at least once
|
||||
# before we exit the test.
|
||||
t = inbox.start_poller_thread(state, "http://platform", "ws-1", interval=0.01)
|
||||
t = inbox.start_poller_thread(
|
||||
state, "http://platform", "ws-1", interval=0.01, stop_event=stop_event
|
||||
)
|
||||
time.sleep(0.05)
|
||||
assert t.daemon is True
|
||||
assert t.is_alive()
|
||||
assert t.daemon is True
|
||||
assert t.is_alive()
|
||||
# Signal shutdown + wait for the thread to actually exit before
|
||||
# we leave the test scope. Without this join, the leaked thread
|
||||
# races with later tests' httpx patches.
|
||||
stop_event.set()
|
||||
t.join(timeout=2.0)
|
||||
assert not t.is_alive(), "poller thread did not exit on stop_event"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -577,6 +595,219 @@ def test_default_cursor_path_uses_configs_dir(monkeypatch, tmp_path: Path):
|
||||
assert inbox.default_cursor_path() == tmp_path / ".mcp_inbox_cursor"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 5b — BatchFetcher integration with the poll loop
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# These tests pin the cross-module contract between inbox._poll_once and
|
||||
# inbox_uploads.BatchFetcher: chat_upload_receive rows must be submitted
|
||||
# to a single BatchFetcher AND drained (URI cache populated) before any
|
||||
# subsequent message row is processed. Without the drain, the
|
||||
# rewrite_request_body path inside message_from_activity surfaces the
|
||||
# un-rewritten ``platform-pending:`` URI to the agent.
|
||||
|
||||
|
||||
def _upload_row(act_id: str, file_id: str) -> dict:
|
||||
return {
|
||||
"id": act_id,
|
||||
"source_id": None,
|
||||
"method": "chat_upload_receive",
|
||||
"summary": f"chat_upload_receive: {file_id}.pdf",
|
||||
"request_body": {
|
||||
"file_id": file_id,
|
||||
"name": f"{file_id}.pdf",
|
||||
"uri": f"platform-pending:ws-1/{file_id}",
|
||||
"mimeType": "application/pdf",
|
||||
"size": 3,
|
||||
},
|
||||
"created_at": "2026-05-04T10:00:00Z",
|
||||
}
|
||||
|
||||
|
||||
def _message_row_referencing(act_id: str, file_id: str) -> dict:
|
||||
return {
|
||||
"id": act_id,
|
||||
"source_id": None,
|
||||
"method": "message/send",
|
||||
"summary": None,
|
||||
"request_body": {
|
||||
"params": {
|
||||
"message": {
|
||||
"parts": [
|
||||
{"kind": "text", "text": "have a look"},
|
||||
{
|
||||
"kind": "file",
|
||||
"file": {
|
||||
"uri": f"platform-pending:ws-1/{file_id}",
|
||||
"name": f"{file_id}.pdf",
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"created_at": "2026-05-04T10:00:01Z",
|
||||
}
|
||||
|
||||
|
||||
def _patch_httpx_routing(activity_rows: list[dict], upload_bytes: bytes = b"PDF"):
|
||||
"""Replace ``httpx.Client`` so:
|
||||
|
||||
- GET /activity returns ``activity_rows``
|
||||
- GET /workspaces/.../content returns ``upload_bytes`` with content-type
|
||||
- POST /ack returns 200
|
||||
|
||||
Returns the patch context manager; tests use ``with p:``. Each new
|
||||
Client(...) gets a fresh MagicMock so the test can verify
|
||||
constructor-count expectations without pinning singletons.
|
||||
"""
|
||||
def _client_factory(*args, **kwargs):
|
||||
c = MagicMock()
|
||||
c.__enter__ = MagicMock(return_value=c)
|
||||
c.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
def _get(url, params=None, headers=None):
|
||||
if "/activity" in url:
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = activity_rows
|
||||
resp.text = ""
|
||||
return resp
|
||||
if "/pending-uploads/" in url and "/content" in url:
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.content = upload_bytes
|
||||
resp.headers = {"content-type": "application/pdf"}
|
||||
resp.text = ""
|
||||
return resp
|
||||
resp = MagicMock()
|
||||
resp.status_code = 404
|
||||
resp.text = ""
|
||||
return resp
|
||||
|
||||
def _post(url, headers=None):
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.text = ""
|
||||
return resp
|
||||
|
||||
c.get = MagicMock(side_effect=_get)
|
||||
c.post = MagicMock(side_effect=_post)
|
||||
c.close = MagicMock()
|
||||
return c
|
||||
|
||||
return patch("httpx.Client", side_effect=_client_factory)
|
||||
|
||||
|
||||
def test_poll_once_drains_uploads_before_processing_message_row(state: inbox.InboxState, tmp_path):
|
||||
"""The chat-message row's file.uri MUST be rewritten to the local
|
||||
workspace: URI by the time it lands in the InboxState queue. This
|
||||
requires BatchFetcher.wait_all() to run before message_from_activity
|
||||
on the second row.
|
||||
"""
|
||||
import inbox_uploads
|
||||
inbox_uploads.get_cache().clear()
|
||||
# Sandbox the on-disk staging dir so the test can't pollute the
|
||||
# workspace's real chat-uploads.
|
||||
real_dir = inbox_uploads.CHAT_UPLOAD_DIR
|
||||
inbox_uploads.CHAT_UPLOAD_DIR = str(tmp_path / "chat-uploads")
|
||||
try:
|
||||
rows = [
|
||||
_upload_row("act-1", "file-A"),
|
||||
_message_row_referencing("act-2", "file-A"),
|
||||
]
|
||||
state.save_cursor("act-old")
|
||||
with _patch_httpx_routing(rows, upload_bytes=b"PDF-bytes"):
|
||||
n = inbox._poll_once(state, "http://platform", "ws-1", {})
|
||||
finally:
|
||||
inbox_uploads.CHAT_UPLOAD_DIR = real_dir
|
||||
inbox_uploads.get_cache().clear()
|
||||
|
||||
assert n == 1, "exactly one message row should be enqueued (the upload row is a side-effect, not a message)"
|
||||
queued = state.peek(10)
|
||||
assert len(queued) == 1
|
||||
# The contract this test exists to pin: the platform-pending: URI
|
||||
# was rewritten to workspace: BEFORE the message landed in the
|
||||
# state queue. message_from_activity mutates row['request_body']
|
||||
# in-place, so the rewritten URI is observable on the row dict
|
||||
# we passed in.
|
||||
rewritten_part = rows[1]["request_body"]["params"]["message"]["parts"][1]
|
||||
assert rewritten_part["file"]["uri"].startswith("workspace:"), (
|
||||
f"upload barrier broken: file.uri = {rewritten_part['file']['uri']!r}; "
|
||||
"rewrite_request_body ran before BatchFetcher.wait_all populated the cache"
|
||||
)
|
||||
# Cursor advanced past BOTH rows — upload-receive (act-1) is
|
||||
# acknowledged via the inbox cursor regardless of fetch outcome.
|
||||
assert state.load_cursor() == "act-2"
|
||||
|
||||
|
||||
def test_poll_once_with_only_upload_rows_drains_at_loop_end(state: inbox.InboxState, tmp_path):
|
||||
"""End-of-batch drain: a poll that contains ONLY upload rows (no
|
||||
chat-message row to trigger the inline drain) must still drain the
|
||||
BatchFetcher before _poll_once returns. Otherwise a future poll
|
||||
that picks up the corresponding chat-message row would race with
|
||||
in-flight fetches from the previous batch.
|
||||
"""
|
||||
import inbox_uploads
|
||||
inbox_uploads.get_cache().clear()
|
||||
real_dir = inbox_uploads.CHAT_UPLOAD_DIR
|
||||
inbox_uploads.CHAT_UPLOAD_DIR = str(tmp_path / "chat-uploads")
|
||||
try:
|
||||
rows = [_upload_row("act-1", "file-A"), _upload_row("act-2", "file-B")]
|
||||
state.save_cursor("act-old")
|
||||
with _patch_httpx_routing(rows, upload_bytes=b"PDF"):
|
||||
n = inbox._poll_once(state, "http://platform", "ws-1", {})
|
||||
# By the time _poll_once returned, the URI cache must be hot
|
||||
# for both file_ids — proves the end-of-loop drain ran.
|
||||
assert inbox_uploads.get_cache().get("platform-pending:ws-1/file-A") is not None
|
||||
assert inbox_uploads.get_cache().get("platform-pending:ws-1/file-B") is not None
|
||||
finally:
|
||||
inbox_uploads.CHAT_UPLOAD_DIR = real_dir
|
||||
inbox_uploads.get_cache().clear()
|
||||
# Upload rows are NOT message rows; queue stays empty.
|
||||
assert n == 0
|
||||
# Cursor advances past both upload rows.
|
||||
assert state.load_cursor() == "act-2"
|
||||
|
||||
|
||||
def test_poll_once_no_uploads_does_not_construct_batch_fetcher(state: inbox.InboxState):
|
||||
"""A batch with no upload-receive rows must not pay the BatchFetcher
|
||||
construction cost — the executor + httpx client allocation is
|
||||
deferred until the first upload row appears.
|
||||
"""
|
||||
import inbox_uploads
|
||||
|
||||
constructed: list[Any] = []
|
||||
|
||||
def _patched_init(self, **kwargs):
|
||||
constructed.append(kwargs)
|
||||
# Don't actually run __init__; we never hit submit/wait_all.
|
||||
self._closed = False
|
||||
self._futures = []
|
||||
self._executor = MagicMock()
|
||||
self._client = MagicMock()
|
||||
self._own_client = False
|
||||
|
||||
rows = [
|
||||
{
|
||||
"id": "act-1",
|
||||
"source_id": None,
|
||||
"method": "message/send",
|
||||
"summary": None,
|
||||
"request_body": {"parts": [{"type": "text", "text": "hi"}]},
|
||||
"created_at": "2026-04-30T22:00:00Z",
|
||||
},
|
||||
]
|
||||
state.save_cursor("act-old")
|
||||
resp = _make_response(200, rows)
|
||||
p, _ = _patch_httpx(resp)
|
||||
with patch.object(inbox_uploads.BatchFetcher, "__init__", _patched_init), p:
|
||||
n = inbox._poll_once(state, "http://platform", "ws-1", {})
|
||||
|
||||
assert n == 1
|
||||
assert constructed == [], "BatchFetcher must not be constructed when no upload rows are present"
|
||||
|
||||
|
||||
def test_default_cursor_path_falls_back_to_default(tmp_path, monkeypatch):
|
||||
"""When CONFIGS_DIR is unset, the cursor path resolves through
|
||||
configs_dir.resolve() — /configs in-container, ~/.molecule-workspace
|
||||
|
||||
@@ -695,3 +695,426 @@ def test_rewrite_request_body_handles_non_list_parts():
|
||||
def test_rewrite_request_body_handles_non_dict_file():
|
||||
body = {"parts": [{"kind": "file", "file": "not a dict"}]}
|
||||
inbox_uploads.rewrite_request_body(body) # must not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fetch_and_stage with shared client — Phase 5b client-reuse contract
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# When a caller passes ``client=`` to fetch_and_stage, that client must be
|
||||
# used for BOTH the GET /content and the POST /ack — no fresh
|
||||
# ``httpx.Client(...)`` constructions should happen. The pre-Phase-5b
|
||||
# implementation made one new client for GET and another for ack; the new
|
||||
# shape lets BatchFetcher share one connection pool across an entire batch.
|
||||
|
||||
|
||||
def test_fetch_and_stage_with_supplied_client_does_not_construct_new_client(monkeypatch):
|
||||
row = _row(uri="platform-pending:ws-1/file-1")
|
||||
get_resp = _make_resp(200, content=b"PDF", content_type="application/pdf")
|
||||
ack_resp = _make_resp(200)
|
||||
supplied = MagicMock()
|
||||
supplied.get = MagicMock(return_value=get_resp)
|
||||
supplied.post = MagicMock(return_value=ack_resp)
|
||||
# Sentinel: any code path that constructs httpx.Client when one was
|
||||
# already supplied is a regression — count constructions.
|
||||
constructed: list[Any] = []
|
||||
|
||||
class _ShouldNotBeCalled:
|
||||
def __init__(self, *a, **kw):
|
||||
constructed.append((a, kw))
|
||||
|
||||
monkeypatch.setattr("httpx.Client", _ShouldNotBeCalled)
|
||||
|
||||
local_uri = inbox_uploads.fetch_and_stage(
|
||||
row,
|
||||
platform_url="http://plat",
|
||||
workspace_id="ws-1",
|
||||
headers={"Authorization": "Bearer t"},
|
||||
client=supplied,
|
||||
)
|
||||
assert local_uri is not None
|
||||
assert constructed == [], "supplied client must be reused; no new Client should be constructed"
|
||||
# GET + POST ack both went through the supplied client.
|
||||
supplied.get.assert_called_once()
|
||||
supplied.post.assert_called_once()
|
||||
# Caller-owned client must NOT be closed by fetch_and_stage; the
|
||||
# batch fetcher (or test) closes it once the whole batch is done.
|
||||
supplied.close.assert_not_called()
|
||||
|
||||
|
||||
def test_fetch_and_stage_without_supplied_client_constructs_and_closes_one(monkeypatch):
|
||||
row = _row(uri="platform-pending:ws-1/file-1")
|
||||
get_resp = _make_resp(200, content=b"PDF", content_type="application/pdf")
|
||||
ack_resp = _make_resp(200)
|
||||
built: list[MagicMock] = []
|
||||
|
||||
def _factory(*args, **kwargs):
|
||||
c = MagicMock()
|
||||
c.get = MagicMock(return_value=get_resp)
|
||||
c.post = MagicMock(return_value=ack_resp)
|
||||
built.append(c)
|
||||
return c
|
||||
|
||||
monkeypatch.setattr("httpx.Client", _factory)
|
||||
|
||||
local_uri = inbox_uploads.fetch_and_stage(
|
||||
row, platform_url="http://plat", workspace_id="ws-1", headers={}
|
||||
)
|
||||
assert local_uri is not None
|
||||
# Pre-Phase-5b built TWO clients (one for GET, one for ack); now exactly one.
|
||||
assert len(built) == 1, f"expected 1 httpx.Client construction, got {len(built)}"
|
||||
# Same client must serve BOTH calls.
|
||||
built[0].get.assert_called_once()
|
||||
built[0].post.assert_called_once()
|
||||
# Owned client must be closed by fetch_and_stage on the way out.
|
||||
built[0].close.assert_called_once()
|
||||
|
||||
|
||||
def test_fetch_and_stage_with_supplied_client_does_not_close_caller_client():
|
||||
# Even on failure the supplied client must not be closed — the
|
||||
# BatchFetcher owns the lifecycle for the whole batch.
|
||||
row = _row(uri="platform-pending:ws-1/file-1")
|
||||
supplied = MagicMock()
|
||||
supplied.get = MagicMock(side_effect=RuntimeError("network down"))
|
||||
supplied.post = MagicMock() # should not be reached on GET failure
|
||||
inbox_uploads.fetch_and_stage(
|
||||
row,
|
||||
platform_url="http://plat",
|
||||
workspace_id="ws-1",
|
||||
headers={},
|
||||
client=supplied,
|
||||
)
|
||||
supplied.close.assert_not_called()
|
||||
supplied.post.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BatchFetcher — concurrent fetch + URI cache barrier
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _row_with_id(act_id: str, file_id: str) -> dict:
|
||||
"""Helper: an upload-receive row with a distinct activity id + file id."""
|
||||
return {
|
||||
"id": act_id,
|
||||
"method": "chat_upload_receive",
|
||||
"request_body": {
|
||||
"file_id": file_id,
|
||||
"name": f"{file_id}.pdf",
|
||||
"uri": f"platform-pending:ws-1/{file_id}",
|
||||
"mimeType": "application/pdf",
|
||||
"size": 1,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _stub_client_for_batch(get_responses: dict[str, MagicMock]) -> MagicMock:
|
||||
"""Build one MagicMock client that returns per-file_id responses
|
||||
based on the file_id segment of the URL.
|
||||
"""
|
||||
client = MagicMock()
|
||||
|
||||
def _get(url: str, headers: dict[str, str] | None = None) -> MagicMock:
|
||||
for fid, resp in get_responses.items():
|
||||
if f"/pending-uploads/{fid}/content" in url:
|
||||
return resp
|
||||
return _make_resp(404)
|
||||
|
||||
def _post(url: str, headers: dict[str, str] | None = None) -> MagicMock:
|
||||
return _make_resp(200)
|
||||
|
||||
client.get = MagicMock(side_effect=_get)
|
||||
client.post = MagicMock(side_effect=_post)
|
||||
return client
|
||||
|
||||
|
||||
def test_batch_fetcher_runs_submitted_rows_concurrently():
|
||||
# Three rows whose .get() blocks for ~120ms each. With 4 workers the
|
||||
# batch should complete in ~120ms (parallel), not ~360ms (serial).
|
||||
# The 250ms ceiling accommodates CI scheduler jitter while still
|
||||
# discriminating concurrent (~120ms) from serial (~360ms).
|
||||
import time
|
||||
|
||||
barrier_start = [0.0]
|
||||
|
||||
def _slow_get(url: str, headers: dict[str, str] | None = None) -> MagicMock:
|
||||
time.sleep(0.12)
|
||||
for fid in ("a", "b", "c"):
|
||||
if f"/pending-uploads/{fid}/content" in url:
|
||||
return _make_resp(200, content=b"X", content_type="text/plain")
|
||||
return _make_resp(404)
|
||||
|
||||
client = MagicMock()
|
||||
client.get = MagicMock(side_effect=_slow_get)
|
||||
client.post = MagicMock(return_value=_make_resp(200))
|
||||
|
||||
bf = inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat",
|
||||
workspace_id="ws-1",
|
||||
headers={},
|
||||
client=client,
|
||||
max_workers=4,
|
||||
)
|
||||
barrier_start[0] = time.time()
|
||||
for fid in ("a", "b", "c"):
|
||||
bf.submit(_row_with_id(f"act-{fid}", fid))
|
||||
bf.wait_all()
|
||||
elapsed = time.time() - barrier_start[0]
|
||||
bf.close()
|
||||
|
||||
assert elapsed < 0.25, (
|
||||
f"3 rows × 120ms with 4 workers should finish in <250ms; got {elapsed:.3f}s "
|
||||
"(suggests serial execution — Phase 5b regression)"
|
||||
)
|
||||
assert client.get.call_count == 3
|
||||
assert client.post.call_count == 3
|
||||
|
||||
|
||||
def test_batch_fetcher_wait_all_blocks_until_uri_cache_populated():
|
||||
"""Pin the correctness invariant: when wait_all returns, the URI
|
||||
cache is hot for every submitted row. Without this barrier the
|
||||
inbox loop would process the chat-message row before its uploads
|
||||
were staged, and rewrite_request_body would surface the un-rewritten
|
||||
platform-pending: URI to the agent.
|
||||
"""
|
||||
import time
|
||||
|
||||
def _slow_get(url: str, headers: dict[str, str] | None = None) -> MagicMock:
|
||||
time.sleep(0.05)
|
||||
return _make_resp(200, content=b"data", content_type="text/plain")
|
||||
|
||||
client = MagicMock()
|
||||
client.get = MagicMock(side_effect=_slow_get)
|
||||
client.post = MagicMock(return_value=_make_resp(200))
|
||||
|
||||
inbox_uploads.get_cache().clear()
|
||||
with inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}, client=client
|
||||
) as bf:
|
||||
bf.submit(_row_with_id("act-a", "a"))
|
||||
bf.submit(_row_with_id("act-b", "b"))
|
||||
bf.wait_all()
|
||||
# Cache must be hot for BOTH rows by the time wait_all returns.
|
||||
assert inbox_uploads.get_cache().get("platform-pending:ws-1/a") is not None
|
||||
assert inbox_uploads.get_cache().get("platform-pending:ws-1/b") is not None
|
||||
|
||||
|
||||
def test_batch_fetcher_isolates_per_row_failure():
|
||||
"""One failing fetch must not abort siblings. Sibling rows complete,
|
||||
URI cache populates for them; the bad row's cache entry stays absent.
|
||||
"""
|
||||
def _get(url: str, headers: dict[str, str] | None = None) -> MagicMock:
|
||||
if "/pending-uploads/bad/content" in url:
|
||||
return _make_resp(500, text="upstream broken")
|
||||
return _make_resp(200, content=b"ok", content_type="text/plain")
|
||||
|
||||
client = MagicMock()
|
||||
client.get = MagicMock(side_effect=_get)
|
||||
client.post = MagicMock(return_value=_make_resp(200))
|
||||
|
||||
inbox_uploads.get_cache().clear()
|
||||
with inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}, client=client
|
||||
) as bf:
|
||||
bf.submit(_row_with_id("act-1", "good1"))
|
||||
bf.submit(_row_with_id("act-2", "bad"))
|
||||
bf.submit(_row_with_id("act-3", "good2"))
|
||||
bf.wait_all()
|
||||
|
||||
cache = inbox_uploads.get_cache()
|
||||
assert cache.get("platform-pending:ws-1/good1") is not None
|
||||
assert cache.get("platform-pending:ws-1/good2") is not None
|
||||
assert cache.get("platform-pending:ws-1/bad") is None
|
||||
|
||||
|
||||
def test_batch_fetcher_reuses_one_client_across_all_submits():
|
||||
"""Every row in the batch must share the same client instance. This
|
||||
is the connection-pool-reuse leg of the perf win: a second fetch
|
||||
to the same host reuses the TCP+TLS handshake from the first.
|
||||
"""
|
||||
client = MagicMock()
|
||||
client.get = MagicMock(return_value=_make_resp(200, content=b"x", content_type="text/plain"))
|
||||
client.post = MagicMock(return_value=_make_resp(200))
|
||||
|
||||
with inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}, client=client
|
||||
) as bf:
|
||||
for fid in ("a", "b", "c"):
|
||||
bf.submit(_row_with_id(f"act-{fid}", fid))
|
||||
bf.wait_all()
|
||||
|
||||
# 3 GETs + 3 POST acks all on the same client — no per-row Client
|
||||
# construction.
|
||||
assert client.get.call_count == 3
|
||||
assert client.post.call_count == 3
|
||||
|
||||
|
||||
def test_batch_fetcher_close_idempotent():
|
||||
client = MagicMock()
|
||||
bf = inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}, client=client
|
||||
)
|
||||
bf.close()
|
||||
bf.close() # second call must not raise
|
||||
|
||||
|
||||
def test_batch_fetcher_submit_after_close_raises():
|
||||
client = MagicMock()
|
||||
bf = inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}, client=client
|
||||
)
|
||||
bf.close()
|
||||
with pytest.raises(RuntimeError, match="submit after close"):
|
||||
bf.submit(_row_with_id("act-x", "x"))
|
||||
|
||||
|
||||
def test_batch_fetcher_owns_client_when_not_supplied(monkeypatch):
|
||||
built: list[MagicMock] = []
|
||||
|
||||
def _factory(*args, **kwargs):
|
||||
c = MagicMock()
|
||||
c.get = MagicMock(return_value=_make_resp(200, content=b"x", content_type="text/plain"))
|
||||
c.post = MagicMock(return_value=_make_resp(200))
|
||||
built.append(c)
|
||||
return c
|
||||
|
||||
monkeypatch.setattr("httpx.Client", _factory)
|
||||
|
||||
bf = inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}
|
||||
)
|
||||
bf.submit(_row_with_id("act-a", "a"))
|
||||
bf.wait_all()
|
||||
bf.close()
|
||||
|
||||
assert len(built) == 1, "expected one owned client per BatchFetcher"
|
||||
built[0].close.assert_called_once()
|
||||
|
||||
|
||||
def test_batch_fetcher_does_not_close_supplied_client():
|
||||
client = MagicMock()
|
||||
client.get = MagicMock(return_value=_make_resp(200, content=b"x", content_type="text/plain"))
|
||||
client.post = MagicMock(return_value=_make_resp(200))
|
||||
with inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}, client=client
|
||||
) as bf:
|
||||
bf.submit(_row_with_id("act-a", "a"))
|
||||
bf.wait_all()
|
||||
# Supplied client survives the BatchFetcher's close — caller's lifecycle.
|
||||
client.close.assert_not_called()
|
||||
|
||||
|
||||
def test_batch_fetcher_wait_all_no_op_on_empty_batch():
|
||||
client = MagicMock()
|
||||
with inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}, client=client
|
||||
) as bf:
|
||||
bf.wait_all() # nothing submitted; must not block, must not raise
|
||||
client.get.assert_not_called()
|
||||
client.post.assert_not_called()
|
||||
|
||||
|
||||
def test_batch_fetcher_httpx_missing_makes_submit_a_noop(monkeypatch):
|
||||
# No client supplied + httpx import fails → BatchFetcher degrades
|
||||
# gracefully: submit() returns None and the row is silently skipped.
|
||||
import sys
|
||||
|
||||
real_httpx = sys.modules.pop("httpx", None)
|
||||
monkeypatch.setitem(sys.modules, "httpx", None)
|
||||
try:
|
||||
bf = inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat", workspace_id="ws-1", headers={}
|
||||
)
|
||||
result = bf.submit(_row_with_id("act-a", "a"))
|
||||
bf.wait_all()
|
||||
bf.close()
|
||||
finally:
|
||||
if real_httpx is not None:
|
||||
sys.modules["httpx"] = real_httpx
|
||||
else:
|
||||
sys.modules.pop("httpx", None)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_batch_fetcher_close_after_timeout_does_not_block_on_running_workers():
|
||||
"""The deadline contract: when wait_all times out, close() must NOT
|
||||
block waiting for the leaked worker threads. Otherwise the inbox
|
||||
poll loop stalls indefinitely on a hung /content fetch — undoing
|
||||
the user-facing timeout.
|
||||
|
||||
Strategy: build a client whose .get() blocks on a threading.Event
|
||||
that the test never sets. Submit a row, wait_all with a tiny
|
||||
timeout, then time close(). If close() drained-and-waited it would
|
||||
block until we set the event (i.e., forever in this test).
|
||||
"""
|
||||
import threading
|
||||
import time
|
||||
|
||||
blocker = threading.Event() # never set — workers stay running
|
||||
|
||||
def _hang_get(url, headers=None):
|
||||
# Wait at most ~5s so a buggy implementation eventually unblocks
|
||||
# the test instead of timing out the whole pytest run, but
|
||||
# nothing legitimate should reach this fallback.
|
||||
blocker.wait(timeout=5.0)
|
||||
return _make_resp(200, content=b"x", content_type="text/plain")
|
||||
|
||||
client = MagicMock()
|
||||
client.get = MagicMock(side_effect=_hang_get)
|
||||
client.post = MagicMock(return_value=_make_resp(200))
|
||||
|
||||
bf = inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat",
|
||||
workspace_id="ws-1",
|
||||
headers={},
|
||||
client=client,
|
||||
max_workers=1, # serialize so submitting 1 keeps the worker busy
|
||||
)
|
||||
bf.submit(_row_with_id("act-a", "a"))
|
||||
# Tiny timeout — wait_all must report the future as not_done.
|
||||
bf.wait_all(timeout=0.05)
|
||||
t0 = time.time()
|
||||
bf.close()
|
||||
elapsed = time.time() - t0
|
||||
# Unblock the lingering worker so it doesn't pollute later tests.
|
||||
blocker.set()
|
||||
|
||||
# Without the cancel-on-timeout fix, close() would block until
|
||||
# blocker.set() — i.e., the full ~5s. With the fix it returns
|
||||
# immediately because shutdown(wait=False) doesn't drain.
|
||||
assert elapsed < 1.0, (
|
||||
f"close() blocked for {elapsed:.2f}s after wait_all timeout — "
|
||||
"cancel-on-timeout regression: close() is draining instead of bailing"
|
||||
)
|
||||
|
||||
|
||||
def test_batch_fetcher_close_without_timeout_still_drains():
|
||||
"""Negative leg of the timeout contract: when wait_all completes
|
||||
cleanly (no timeout), close() must KEEP its drain-and-wait
|
||||
behavior so a still-queued ack POST isn't dropped mid-write.
|
||||
"""
|
||||
import time
|
||||
|
||||
def _slow_get(url, headers=None):
|
||||
time.sleep(0.05)
|
||||
return _make_resp(200, content=b"x", content_type="text/plain")
|
||||
|
||||
client = MagicMock()
|
||||
client.get = MagicMock(side_effect=_slow_get)
|
||||
client.post = MagicMock(return_value=_make_resp(200))
|
||||
|
||||
bf = inbox_uploads.BatchFetcher(
|
||||
platform_url="http://plat",
|
||||
workspace_id="ws-1",
|
||||
headers={},
|
||||
client=client,
|
||||
max_workers=2,
|
||||
)
|
||||
bf.submit(_row_with_id("act-a", "a"))
|
||||
bf.submit(_row_with_id("act-b", "b"))
|
||||
bf.wait_all() # generous default timeout — should not fire
|
||||
bf.close()
|
||||
|
||||
# All 2 GETs + 2 ACK POSTs ran to completion via drain-and-wait.
|
||||
assert client.get.call_count == 2
|
||||
assert client.post.call_count == 2
|
||||
|
||||
@@ -63,7 +63,7 @@ async def test_commit_memory_success(monkeypatch):
|
||||
mcp = _load_mcp()
|
||||
|
||||
client = FakeClient()
|
||||
monkeypatch.setattr("a2a_tools.httpx.AsyncClient", lambda **kw: client)
|
||||
monkeypatch.setattr("a2a_tools_memory.httpx.AsyncClient", lambda **kw: client)
|
||||
|
||||
result = await mcp.handle_tool_call("commit_memory", {
|
||||
"content": "Architecture decision: use Go for backend",
|
||||
@@ -92,7 +92,7 @@ async def test_commit_memory_default_scope(monkeypatch):
|
||||
mcp = _load_mcp()
|
||||
|
||||
client = FakeClient()
|
||||
monkeypatch.setattr("a2a_tools.httpx.AsyncClient", lambda **kw: client)
|
||||
monkeypatch.setattr("a2a_tools_memory.httpx.AsyncClient", lambda **kw: client)
|
||||
|
||||
result = await mcp.handle_tool_call("commit_memory", {
|
||||
"content": "Some note",
|
||||
@@ -108,7 +108,7 @@ async def test_recall_memory_success(monkeypatch):
|
||||
mcp = _load_mcp()
|
||||
|
||||
client = FakeClient()
|
||||
monkeypatch.setattr("a2a_tools.httpx.AsyncClient", lambda **kw: client)
|
||||
monkeypatch.setattr("a2a_tools_memory.httpx.AsyncClient", lambda **kw: client)
|
||||
|
||||
result = await mcp.handle_tool_call("recall_memory", {"query": "architecture"})
|
||||
|
||||
@@ -127,7 +127,7 @@ async def test_recall_memory_empty(monkeypatch):
|
||||
async def get(self, url, params=None, headers=None, **kwargs):
|
||||
return FakeResponse(200, [])
|
||||
|
||||
monkeypatch.setattr("a2a_tools.httpx.AsyncClient", lambda **kw: EmptyClient())
|
||||
monkeypatch.setattr("a2a_tools_memory.httpx.AsyncClient", lambda **kw: EmptyClient())
|
||||
|
||||
result = await mcp.handle_tool_call("recall_memory", {})
|
||||
assert "No memories found" in result
|
||||
@@ -139,7 +139,7 @@ async def test_recall_memory_with_scope_filter(monkeypatch):
|
||||
mcp = _load_mcp()
|
||||
|
||||
client = FakeClient()
|
||||
monkeypatch.setattr("a2a_tools.httpx.AsyncClient", lambda **kw: client)
|
||||
monkeypatch.setattr("a2a_tools_memory.httpx.AsyncClient", lambda **kw: client)
|
||||
|
||||
await mcp.handle_tool_call("recall_memory", {"scope": "TEAM"})
|
||||
|
||||
|
||||
@@ -357,7 +357,7 @@ class TestA2AToolCommitMemoryRedactsSecrets:
|
||||
|
||||
fake_client.post = _capture
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=fake_client):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=fake_client):
|
||||
await a2a_tools.tool_commit_memory(content_with_secret)
|
||||
|
||||
stored = captured.get("content", "")
|
||||
@@ -385,7 +385,7 @@ class TestA2AToolCommitMemoryRedactsSecrets:
|
||||
|
||||
fake_client.post = _capture
|
||||
|
||||
with patch("a2a_tools.httpx.AsyncClient", return_value=fake_client):
|
||||
with patch("a2a_tools_memory.httpx.AsyncClient", return_value=fake_client):
|
||||
await a2a_tools.tool_commit_memory(f"key={key}")
|
||||
|
||||
stored = captured.get("content", "")
|
||||
|
||||
Reference in New Issue
Block a user