Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f89f7a34d9 | |||
| aff7f810bc | |||
| fb0a35f22c | |||
| 6a08219724 | |||
| 0466a228e2 |
@@ -176,7 +176,7 @@ export function deriveProvidersFromModels(models: ModelSpec[]): string[] {
|
||||
// exactly the point of the platform adaptor. The deep `~/.hermes/
|
||||
// config.yaml` on the container is a separate runtime-internal file,
|
||||
// not this one.
|
||||
const RUNTIMES_WITH_OWN_CONFIG = new Set<string>(["external", "kimi", "kimi-cli"]);
|
||||
const RUNTIMES_WITH_OWN_CONFIG = new Set<string>(["external", "kimi", "kimi-cli", "openclaw"]);
|
||||
|
||||
const FALLBACK_RUNTIME_OPTIONS: RuntimeOption[] = [
|
||||
{ value: "", label: "LangGraph (default)", models: [], providers: [] },
|
||||
|
||||
+12
-8
@@ -8,14 +8,18 @@ import { getTenantSlug } from "./tenant";
|
||||
export const PLATFORM_URL =
|
||||
process.env.NEXT_PUBLIC_PLATFORM_URL ?? "http://localhost:8080";
|
||||
|
||||
// 15s is long enough for slow CP queries but short enough that a
|
||||
// hung backend doesn't leave the UI spinning forever. The abort
|
||||
// propagates through AbortController so React components can observe
|
||||
// the error and render a retry affordance. Callers that know the
|
||||
// endpoint is intentionally slow (org import walks a tree of
|
||||
// workspaces with server-side pacing) can pass `timeoutMs` to
|
||||
// override.
|
||||
const DEFAULT_TIMEOUT_MS = 15_000;
|
||||
// 35s is long enough for the slowest server-side path (EIC SSH
|
||||
// tunnel for tenant EC2 file operations, bounded server-side by
|
||||
// `eicFileOpTimeout = 30 * time.Second` in
|
||||
// workspace-server/internal/handlers/template_files_eic.go) so the
|
||||
// canvas surfaces the server's real error instead of aborting first
|
||||
// with a generic timeout. Shorter values caused "Save & Restart" to
|
||||
// time out at the client before the backend returned its 5xx. The
|
||||
// abort still propagates through AbortController so React components
|
||||
// can render a retry affordance. Callers that know an endpoint is
|
||||
// intentionally slow (org import walks a tree of workspaces with
|
||||
// server-side pacing) can pass `timeoutMs` to override.
|
||||
const DEFAULT_TIMEOUT_MS = 35_000;
|
||||
|
||||
export interface RequestOptions {
|
||||
timeoutMs?: number;
|
||||
|
||||
@@ -3,7 +3,7 @@ package handlers
|
||||
// workspace_broadcast.go — POST /workspaces/:id/broadcast
|
||||
//
|
||||
// Allows a workspace with broadcast_enabled=true to send a message to every
|
||||
// non-removed agent workspace in the org. The message is:
|
||||
// non-removed agent workspace in the SAME ORG. The message is:
|
||||
//
|
||||
// • Persisted in each recipient's activity_logs (type='broadcast_receive')
|
||||
// so poll-mode agents pick it up via GET /activity.
|
||||
@@ -16,6 +16,11 @@ package handlers
|
||||
// Auth: WorkspaceAuth (the agent triggers this with its own bearer token).
|
||||
// The handler re-validates broadcast_enabled inside the DB lookup to prevent
|
||||
// TOCTOU — the middleware only proved the token is valid, not the ability.
|
||||
//
|
||||
// Org isolation (OFFSEC-015): recipients are scoped to the sender's org using
|
||||
// a recursive CTE that walks the parent_id chain to find the org root. This
|
||||
// prevents a compromised or misconfigured workspace from broadcasting to
|
||||
// workspaces in other tenants' orgs.
|
||||
|
||||
import (
|
||||
"log"
|
||||
@@ -74,11 +79,49 @@ func (h *BroadcastHandler) Broadcast(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect all non-removed agent workspaces (excludes the sender itself).
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
`SELECT id FROM workspaces WHERE status != 'removed' AND id != $1`,
|
||||
senderID,
|
||||
)
|
||||
// Find the sender's org root by walking the parent_id chain.
|
||||
// Workspaces with parent_id = NULL are org roots; every other workspace
|
||||
// belongs to the org identified by its topmost ancestor.
|
||||
var orgRootID string
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
WITH RECURSIVE org_chain AS (
|
||||
SELECT id, parent_id, id AS root_id
|
||||
FROM workspaces
|
||||
WHERE id = $1
|
||||
UNION ALL
|
||||
SELECT w.id, w.parent_id, c.root_id
|
||||
FROM workspaces w
|
||||
JOIN org_chain c ON w.id = c.parent_id
|
||||
)
|
||||
SELECT root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1
|
||||
`, senderID).Scan(&orgRootID)
|
||||
if err != nil {
|
||||
log.Printf("Broadcast: org root lookup for %s: %v", senderID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
|
||||
return
|
||||
}
|
||||
|
||||
// Collect all non-removed agent workspaces in the SAME ORG (same root_id),
|
||||
// excluding the sender itself.
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
WITH RECURSIVE org_chain AS (
|
||||
SELECT id, parent_id, id AS root_id
|
||||
FROM workspaces
|
||||
WHERE parent_id IS NULL
|
||||
UNION ALL
|
||||
SELECT w.id, w.parent_id, c.root_id
|
||||
FROM workspaces w
|
||||
JOIN org_chain c ON w.parent_id = c.id
|
||||
)
|
||||
SELECT c.id
|
||||
FROM org_chain c
|
||||
WHERE c.root_id = $1
|
||||
AND c.id != $2
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM workspaces w
|
||||
WHERE w.id = c.id AND w.status != 'removed'
|
||||
)
|
||||
`, orgRootID, senderID)
|
||||
if err != nil {
|
||||
log.Printf("Broadcast: recipient query failed for %s: %v", senderID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
|
||||
|
||||
@@ -0,0 +1,666 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/events"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/ws"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// setupBroadcastDB uses QueryMatcherEqual so SQL strings with quoted literals
|
||||
// (e.g. status != 'removed') are compared verbatim, not as regex.
|
||||
func setupBroadcastDB(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
return mock
|
||||
}
|
||||
|
||||
// broadcastTestUUID is a properly formatted test UUID.
|
||||
const broadcastTestUUID = "bbbbbbbb-0001-0001-0001-000000000001"
|
||||
|
||||
// buildBroadcastCtx creates a gin.Context wired for POST /workspaces/:id/broadcast.
|
||||
func buildBroadcastCtx(id, body string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+id+"/broadcast", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req.WithContext(context.Background())
|
||||
c.Params = gin.Params{{Key: "id", Value: id}}
|
||||
return c, w
|
||||
}
|
||||
|
||||
// ─── Pure function ────────────────────────────────────────────────────────────
|
||||
|
||||
func TestBroadcastTruncate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
max int
|
||||
want string
|
||||
}{
|
||||
{"empty string", "", 10, ""},
|
||||
{"under limit", "hello", 10, "hello"},
|
||||
{"exactly at limit", "hello", 5, "hello"},
|
||||
{"over limit", "hello world", 5, "hello…"},
|
||||
{"unicode over limit", "こんにちは世界", 5, "こんにちは…"},
|
||||
{"ascii over limit", "abcdefghij", 5, "abcde…"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := broadcastTruncate(tc.s, tc.max)
|
||||
if got != tc.want {
|
||||
t.Errorf("broadcastTruncate(%q, %d) = %q; want %q", tc.s, tc.max, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Validation ────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestBroadcast_InvalidWorkspaceID(t *testing.T) {
|
||||
c, w := buildBroadcastCtx("not-a-uuid", `{"message":"hello"}`)
|
||||
handler := NewBroadcastHandler(events.NewBroadcaster(ws.NewHub(nil)))
|
||||
handler.Broadcast(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("want 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_MissingMessage(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
c, w := buildBroadcastCtx(broadcastTestUUID, `{}`)
|
||||
|
||||
handler := NewBroadcastHandler(events.NewBroadcaster(ws.NewHub(nil)))
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("want 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet mock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_MalformedJSON(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
c, w := buildBroadcastCtx(broadcastTestUUID, `not json`)
|
||||
|
||||
handler := NewBroadcastHandler(events.NewBroadcaster(ws.NewHub(nil)))
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("want 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet mock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Auth / Authz ─────────────────────────────────────────────────────────────
|
||||
|
||||
func TestBroadcast_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
c, w := buildBroadcastCtx(broadcastTestUUID, `{"message":"hello"}`)
|
||||
|
||||
// Workspace lookup returns no rows.
|
||||
mock.ExpectQuery("SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
handler := NewBroadcastHandler(events.NewBroadcaster(ws.NewHub(nil)))
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("want 404, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet mock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_WorkspaceLookupQueryError(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
c, w := buildBroadcastCtx(broadcastTestUUID, `{"message":"hello"}`)
|
||||
|
||||
mock.ExpectQuery("SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
handler := NewBroadcastHandler(events.NewBroadcaster(ws.NewHub(nil)))
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("want 404, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet mock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_BroadcastDisabled(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
c, w := buildBroadcastCtx(broadcastTestUUID, `{"message":"hello"}`)
|
||||
|
||||
// Workspace found but broadcast_enabled=false.
|
||||
rows := sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
|
||||
AddRow("test-workspace", false)
|
||||
mock.ExpectQuery("SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(rows)
|
||||
|
||||
handler := NewBroadcastHandler(events.NewBroadcaster(ws.NewHub(nil)))
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("want 403, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet mock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Org root lookup error (blocks cross-org broadcast) ──────────────────────
|
||||
|
||||
func TestBroadcast_OrgRootLookupError(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
c, w := buildBroadcastCtx(broadcastTestUUID, `{"message":"hello"}`)
|
||||
|
||||
// Workspace lookup succeeds.
|
||||
mock.ExpectQuery("SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("test-workspace", true))
|
||||
|
||||
// Org root CTE fails — handler must NOT proceed to the recipient query
|
||||
// (which would broadcast cross-org if org root lookup failed silently).
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE id = $1 UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.id = c.parent_id ) SELECT root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnError(context.DeadlineExceeded)
|
||||
|
||||
handler := NewBroadcastHandler(events.NewBroadcaster(ws.NewHub(nil)))
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("want 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── DB error paths ───────────────────────────────────────────────────────────
|
||||
|
||||
func TestBroadcast_RecipientQueryError(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
c, w := buildBroadcastCtx(broadcastTestUUID, `{"message":"hello"}`)
|
||||
|
||||
// Workspace lookup succeeds with broadcast_enabled=true.
|
||||
mock.ExpectQuery("SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("test-workspace", true))
|
||||
|
||||
// Org root lookup succeeds.
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE id = $1 UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.id = c.parent_id ) SELECT root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(broadcastTestUUID))
|
||||
|
||||
// Recipient query fails.
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE parent_id IS NULL UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.parent_id = c.id ) SELECT c.id FROM org_chain c WHERE c.root_id = $1 AND c.id != $2 AND EXISTS ( SELECT 1 FROM workspaces w WHERE w.id = c.id AND w.status != 'removed' )").
|
||||
WithArgs(broadcastTestUUID, broadcastTestUUID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
handler := NewBroadcastHandler(events.NewBroadcaster(ws.NewHub(nil)))
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("want 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet mock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_RecipientRowsError(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
c, w := buildBroadcastCtx(broadcastTestUUID, `{"message":"hello"}`)
|
||||
|
||||
mock.ExpectQuery("SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("test-workspace", true))
|
||||
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE id = $1 UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.id = c.parent_id ) SELECT root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(broadcastTestUUID))
|
||||
|
||||
// Recipient query succeeds but rows.Err() fails.
|
||||
badRows := sqlmock.NewRows([]string{"id"}).AddRow("ws-2").RowError(0, sql.ErrConnDone)
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE parent_id IS NULL UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.parent_id = c.id ) SELECT c.id FROM org_chain c WHERE c.root_id = $1 AND c.id != $2 AND EXISTS ( SELECT 1 FROM workspaces w WHERE w.id = c.id AND w.status != 'removed' )").
|
||||
WithArgs(broadcastTestUUID, broadcastTestUUID).
|
||||
WillReturnRows(badRows)
|
||||
|
||||
handler := NewBroadcastHandler(events.NewBroadcaster(ws.NewHub(nil)))
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("want 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet mock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Success paths ───────────────────────────────────────────────────────────
|
||||
|
||||
func TestBroadcast_Success_OneRecipient(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
c, w := buildBroadcastCtx(broadcastTestUUID, `{"message":"hello world"}`)
|
||||
|
||||
// Workspace lookup.
|
||||
mock.ExpectQuery("SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("sender-workspace", true))
|
||||
|
||||
// Org root lookup.
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE id = $1 UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.id = c.parent_id ) SELECT root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(broadcastTestUUID))
|
||||
|
||||
// Recipient query: one recipient.
|
||||
recipRows := sqlmock.NewRows([]string{"id"}).AddRow("ws-recipient-1")
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE parent_id IS NULL UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.parent_id = c.id ) SELECT c.id FROM org_chain c WHERE c.root_id = $1 AND c.id != $2 AND EXISTS ( SELECT 1 FROM workspaces w WHERE w.id = c.id AND w.status != 'removed' )").
|
||||
WithArgs(broadcastTestUUID, broadcastTestUUID).
|
||||
WillReturnRows(recipRows)
|
||||
|
||||
// Activity log insert for recipient.
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, status) VALUES ($1, 'broadcast_receive', 'broadcast', $2, $3, 'ok')").
|
||||
WithArgs("ws-recipient-1", broadcastTestUUID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// Activity log insert for sender (broadcast_sent).
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, summary, status) VALUES ($1, 'broadcast_sent', 'broadcast', $2, 'ok')").
|
||||
WithArgs(broadcastTestUUID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
handler := NewBroadcastHandler(events.NewBroadcaster(ws.NewHub(nil)))
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("want 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet mock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_Success_NoRecipients(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
c, w := buildBroadcastCtx(broadcastTestUUID, `{"message":"hello"}`)
|
||||
|
||||
mock.ExpectQuery("SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("solo-workspace", true))
|
||||
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE id = $1 UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.id = c.parent_id ) SELECT root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(broadcastTestUUID))
|
||||
|
||||
// No recipients.
|
||||
recipRows := sqlmock.NewRows([]string{"id"})
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE parent_id IS NULL UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.parent_id = c.id ) SELECT c.id FROM org_chain c WHERE c.root_id = $1 AND c.id != $2 AND EXISTS ( SELECT 1 FROM workspaces w WHERE w.id = c.id AND w.status != 'removed' )").
|
||||
WithArgs(broadcastTestUUID, broadcastTestUUID).
|
||||
WillReturnRows(recipRows)
|
||||
|
||||
// Activity log insert for sender (broadcast_sent).
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, summary, status) VALUES ($1, 'broadcast_sent', 'broadcast', $2, 'ok')").
|
||||
WithArgs(broadcastTestUUID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
handler := NewBroadcastHandler(events.NewBroadcaster(ws.NewHub(nil)))
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("want 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet mock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_Success_MultipleRecipients(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
c, w := buildBroadcastCtx(broadcastTestUUID, `{"message":"hello"}`)
|
||||
|
||||
mock.ExpectQuery("SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("broadcaster", true))
|
||||
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE id = $1 UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.id = c.parent_id ) SELECT root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(broadcastTestUUID))
|
||||
|
||||
// Three recipients.
|
||||
recipRows := sqlmock.NewRows([]string{"id"}).
|
||||
AddRow("ws-1").AddRow("ws-2").AddRow("ws-3")
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE parent_id IS NULL UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.parent_id = c.id ) SELECT c.id FROM org_chain c WHERE c.root_id = $1 AND c.id != $2 AND EXISTS ( SELECT 1 FROM workspaces w WHERE w.id = c.id AND w.status != 'removed' )").
|
||||
WithArgs(broadcastTestUUID, broadcastTestUUID).
|
||||
WillReturnRows(recipRows)
|
||||
|
||||
// Each recipient gets a broadcast_receive log.
|
||||
for _, rid := range []string{"ws-1", "ws-2", "ws-3"} {
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, status) VALUES ($1, 'broadcast_receive', 'broadcast', $2, $3, 'ok')").
|
||||
WithArgs(rid, broadcastTestUUID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
}
|
||||
|
||||
// Sender log.
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, summary, status) VALUES ($1, 'broadcast_sent', 'broadcast', $2, 'ok')").
|
||||
WithArgs(broadcastTestUUID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
handler := NewBroadcastHandler(events.NewBroadcaster(ws.NewHub(nil)))
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("want 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet mock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Recipient insert failure (logged, continues) ─────────────────────────────
|
||||
|
||||
func TestBroadcast_RecipientInsertError_ContinuesAndSucceeds(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
c, w := buildBroadcastCtx(broadcastTestUUID, `{"message":"hello"}`)
|
||||
|
||||
mock.ExpectQuery("SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("broadcaster", true))
|
||||
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE id = $1 UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.id = c.parent_id ) SELECT root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(broadcastTestUUID))
|
||||
|
||||
// Two recipients.
|
||||
recipRows := sqlmock.NewRows([]string{"id"}).AddRow("ws-1").AddRow("ws-2")
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE parent_id IS NULL UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.parent_id = c.id ) SELECT c.id FROM org_chain c WHERE c.root_id = $1 AND c.id != $2 AND EXISTS ( SELECT 1 FROM workspaces w WHERE w.id = c.id AND w.status != 'removed' )").
|
||||
WithArgs(broadcastTestUUID, broadcastTestUUID).
|
||||
WillReturnRows(recipRows)
|
||||
|
||||
// First recipient insert fails (logged, continues).
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, status) VALUES ($1, 'broadcast_receive', 'broadcast', $2, $3, 'ok')").
|
||||
WithArgs("ws-1", broadcastTestUUID, sqlmock.AnyArg()).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
// Second recipient insert succeeds.
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, status) VALUES ($1, 'broadcast_receive', 'broadcast', $2, $3, 'ok')").
|
||||
WithArgs("ws-2", broadcastTestUUID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// Sender log.
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, summary, status) VALUES ($1, 'broadcast_sent', 'broadcast', $2, 'ok')").
|
||||
WithArgs(broadcastTestUUID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
handler := NewBroadcastHandler(events.NewBroadcaster(ws.NewHub(nil)))
|
||||
handler.Broadcast(c)
|
||||
|
||||
// Handler returns 200 even though one insert failed — it logs and continues.
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("want 200 despite insert error, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet mock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Sender activity log insert failure (logged, still 200) ───────────────────
|
||||
|
||||
func TestBroadcast_SenderLogInsertError_Still200(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
c, w := buildBroadcastCtx(broadcastTestUUID, `{"message":"hello"}`)
|
||||
|
||||
mock.ExpectQuery("SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("broadcaster", true))
|
||||
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE id = $1 UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.id = c.parent_id ) SELECT root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1").
|
||||
WithArgs(broadcastTestUUID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(broadcastTestUUID))
|
||||
|
||||
recipRows := sqlmock.NewRows([]string{"id"}).AddRow("ws-1")
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE parent_id IS NULL UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.parent_id = c.id ) SELECT c.id FROM org_chain c WHERE c.root_id = $1 AND c.id != $2 AND EXISTS ( SELECT 1 FROM workspaces w WHERE w.id = c.id AND w.status != 'removed' )").
|
||||
WithArgs(broadcastTestUUID, broadcastTestUUID).
|
||||
WillReturnRows(recipRows)
|
||||
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, status) VALUES ($1, 'broadcast_receive', 'broadcast', $2, $3, 'ok')").
|
||||
WithArgs("ws-1", broadcastTestUUID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// Sender log fails — but handler still returns 200 (logged only).
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, summary, status) VALUES ($1, 'broadcast_sent', 'broadcast', $2, 'ok')").
|
||||
WithArgs(broadcastTestUUID, sqlmock.AnyArg()).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
handler := NewBroadcastHandler(events.NewBroadcaster(ws.NewHub(nil)))
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("want 200 despite sender log error, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet mock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Org-scoped recipient query tests (OFFSEC-015) ────────────────────────────
|
||||
|
||||
// TestBroadcast_OrgScopedRecipients verifies that a broadcast from Org-A does
|
||||
// NOT reach workspaces belonging to Org-B. This is the core regression test
|
||||
// for OFFSEC-015: the original query had no org filter, so a workspace in
|
||||
// Org-A could broadcast to every non-removed workspace in the entire DB.
|
||||
func TestBroadcast_OrgScopedRecipients(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewBroadcastHandler(broadcaster)
|
||||
|
||||
senderID := "00000000-0000-0000-0000-000000000001" // org-a-root
|
||||
wsAChild := "00000000-0000-0000-0000-000000000002"
|
||||
|
||||
// 1. Sender lookup
|
||||
mock.ExpectQuery("SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'").
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Org-A Root", true))
|
||||
|
||||
// 2. Org root lookup — sender is its own root (parent_id = NULL)
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE id = $1 UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.id = c.parent_id ) SELECT root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1").
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(senderID))
|
||||
|
||||
// 3. Org-scoped recipient query — MUST include org filter so ws-b-child is NOT included.
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE parent_id IS NULL UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.parent_id = c.id ) SELECT c.id FROM org_chain c WHERE c.root_id = $1 AND c.id != $2 AND EXISTS ( SELECT 1 FROM workspaces w WHERE w.id = c.id AND w.status != 'removed' )").
|
||||
WithArgs(senderID, senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(wsAChild))
|
||||
|
||||
// Activity log inserts
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, status) VALUES ($1, 'broadcast_receive', 'broadcast', $2, $3, 'ok')").
|
||||
WithArgs(wsAChild, senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, summary, status) VALUES ($1, 'broadcast_sent', 'broadcast', $2, 'ok')").
|
||||
WithArgs(senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: senderID}}
|
||||
body := `{"message":"hello from org-a"}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if resp["status"] != "sent" {
|
||||
t.Errorf("expected status 'sent', got %v", resp["status"])
|
||||
}
|
||||
// ws-b-child is in a DIFFERENT org — the org-scoped query MUST NOT include it.
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet mock expectations — cross-org workspace was included: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBroadcast_OrgScoped_OrgRootSender verifies that when the sender IS the
|
||||
// org root (parent_id = NULL), broadcasts still reach sibling workspaces.
|
||||
func TestBroadcast_OrgScoped_OrgRootSender(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewBroadcastHandler(broadcaster)
|
||||
|
||||
senderID := "00000000-0000-0000-0000-000000000001"
|
||||
siblingID := "00000000-0000-0000-0000-000000000002"
|
||||
|
||||
mock.ExpectQuery("SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'").
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Root Agent", true))
|
||||
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE id = $1 UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.id = c.parent_id ) SELECT root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1").
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(senderID))
|
||||
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE parent_id IS NULL UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.parent_id = c.id ) SELECT c.id FROM org_chain c WHERE c.root_id = $1 AND c.id != $2 AND EXISTS ( SELECT 1 FROM workspaces w WHERE w.id = c.id AND w.status != 'removed' )").
|
||||
WithArgs(senderID, senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(siblingID))
|
||||
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, status) VALUES ($1, 'broadcast_receive', 'broadcast', $2, $3, 'ok')").
|
||||
WithArgs(siblingID, senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, summary, status) VALUES ($1, 'broadcast_sent', 'broadcast', $2, 'ok')").
|
||||
WithArgs(senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: senderID}}
|
||||
body := `{"message":"hello siblings"}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBroadcast_OrgScoped_ChildWorkspaceSender verifies that a non-root child
|
||||
// workspace can broadcast to siblings in the same org.
|
||||
func TestBroadcast_OrgScoped_ChildWorkspaceSender(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewBroadcastHandler(broadcaster)
|
||||
|
||||
orgRootID := "00000000-0000-0000-0000-000000000001"
|
||||
senderID := "00000000-0000-0000-0000-000000000002"
|
||||
siblingID := "00000000-0000-0000-0000-000000000003"
|
||||
|
||||
mock.ExpectQuery("SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'").
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Child Agent", true))
|
||||
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE id = $1 UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.id = c.parent_id ) SELECT root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1").
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(orgRootID))
|
||||
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE parent_id IS NULL UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.parent_id = c.id ) SELECT c.id FROM org_chain c WHERE c.root_id = $1 AND c.id != $2 AND EXISTS ( SELECT 1 FROM workspaces w WHERE w.id = c.id AND w.status != 'removed' )").
|
||||
WithArgs(orgRootID, senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(siblingID))
|
||||
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, status) VALUES ($1, 'broadcast_receive', 'broadcast', $2, $3, 'ok')").
|
||||
WithArgs(siblingID, senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, summary, status) VALUES ($1, 'broadcast_sent', 'broadcast', $2, 'ok')").
|
||||
WithArgs(senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: senderID}}
|
||||
body := `{"message":"child broadcasting"}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBroadcast_OrgScoped_SelfBroadcastExcluded verifies that broadcasting
|
||||
// from a workspace does not send a broadcast_receive to the sender itself.
|
||||
func TestBroadcast_OrgScoped_SelfBroadcastExcluded(t *testing.T) {
|
||||
mock := setupBroadcastDB(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewBroadcastHandler(broadcaster)
|
||||
|
||||
senderID := "00000000-0000-0000-0000-000000000001"
|
||||
peerID := "00000000-0000-0000-0000-000000000002"
|
||||
|
||||
mock.ExpectQuery("SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'").
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Root Agent", true))
|
||||
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE id = $1 UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.id = c.parent_id ) SELECT root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1").
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(senderID))
|
||||
|
||||
// Recipient query MUST exclude sender via id != senderID
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS ( SELECT id, parent_id, id AS root_id FROM workspaces WHERE parent_id IS NULL UNION ALL SELECT w.id, w.parent_id, c.root_id FROM workspaces w JOIN org_chain c ON w.parent_id = c.id ) SELECT c.id FROM org_chain c WHERE c.root_id = $1 AND c.id != $2 AND EXISTS ( SELECT 1 FROM workspaces w WHERE w.id = c.id AND w.status != 'removed' )").
|
||||
WithArgs(senderID, senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(peerID))
|
||||
|
||||
// Peer receives broadcast_receive
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, status) VALUES ($1, 'broadcast_receive', 'broadcast', $2, $3, 'ok')").
|
||||
WithArgs(peerID, senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// Sender logs broadcast_sent (NOT broadcast_receive)
|
||||
mock.ExpectExec("INSERT INTO activity_logs (workspace_id, activity_type, method, summary, status) VALUES ($1, 'broadcast_sent', 'broadcast', $2, 'ok')").
|
||||
WithArgs(senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: senderID}}
|
||||
body := `{"message":"no echo to self"}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -35,12 +35,14 @@ from a2a_tools import (
|
||||
tool_commit_memory,
|
||||
tool_delegate_task,
|
||||
tool_delegate_task_async,
|
||||
tool_get_runtime_identity,
|
||||
tool_get_workspace_info,
|
||||
tool_inbox_peek,
|
||||
tool_inbox_pop,
|
||||
tool_list_peers,
|
||||
tool_recall_memory,
|
||||
tool_send_message_to_user,
|
||||
tool_update_agent_card,
|
||||
tool_wait_for_message,
|
||||
)
|
||||
from platform_tools.registry import TOOLS as _PLATFORM_TOOL_SPECS
|
||||
@@ -130,6 +132,10 @@ async def handle_tool_call(name: str, arguments: dict) -> str:
|
||||
return await tool_get_workspace_info(
|
||||
source_workspace_id=arguments.get("source_workspace_id") or None,
|
||||
)
|
||||
elif name == "get_runtime_identity":
|
||||
return await tool_get_runtime_identity()
|
||||
elif name == "update_agent_card":
|
||||
return await tool_update_agent_card(arguments.get("card"))
|
||||
elif name == "commit_memory":
|
||||
return await tool_commit_memory(
|
||||
arguments.get("content", ""),
|
||||
|
||||
@@ -167,3 +167,15 @@ from a2a_tools_inbox import ( # noqa: E402 (import after the top-of-module imp
|
||||
tool_inbox_pop,
|
||||
tool_wait_for_message,
|
||||
)
|
||||
|
||||
|
||||
# Identity tool handlers — extracted to a2a_tools_identity. Ports the
|
||||
# two T4-tier MCP tools (``tool_get_runtime_identity`` +
|
||||
# ``tool_update_agent_card``) from molecule-ai-workspace-runtime PR#17.
|
||||
# That repo is mirror-only (reference_runtime_repo_is_mirror_only);
|
||||
# this is the canonical edit point, and the wheel mirror is
|
||||
# regenerated by publish-runtime.yml on merge.
|
||||
from a2a_tools_identity import ( # noqa: E402 (import after the top-of-module imports)
|
||||
tool_get_runtime_identity,
|
||||
tool_update_agent_card,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
"""Identity tool handlers — single-concern slice of the a2a_tools surface.
|
||||
|
||||
Owns the two MCP tools that close the T4-tier workspace owner-permission
|
||||
gaps reported via the canvas:
|
||||
|
||||
* ``tool_get_runtime_identity`` — env-only; returns model, model_provider,
|
||||
molecule_model, anthropic_base_url, tier, workspace_id, runtime
|
||||
(ADAPTER_MODULE). No HTTP call. Always permitted by RBAC — even
|
||||
read-only agents may know what model they are.
|
||||
|
||||
* ``tool_update_agent_card`` — POSTs the card to ``/registry/update-card``
|
||||
with the workspace's own bearer (same auth path as ``tool_commit_memory``
|
||||
via ``a2a_tools_rbac.auth_headers_for_heartbeat``). The platform
|
||||
replaces the stored card and broadcasts an ``agent_card_updated``
|
||||
event so the canvas reflects the new card live. Gated on
|
||||
``memory.write`` capability via the existing RBAC permission map so
|
||||
read-only roles can't silently rewrite the platform card.
|
||||
|
||||
Both originated as a port of molecule-ai-workspace-runtime PR#17
|
||||
(``feat(mcp): add update_agent_card + get_runtime_identity tools``).
|
||||
The mirror-only PR#17 was closed without merge per
|
||||
``reference_runtime_repo_is_mirror_only``; the canonical edit point is
|
||||
this monorepo at ``workspace/`` and the wheel mirror is regenerated
|
||||
automatically by the publish-runtime workflow.
|
||||
|
||||
Imports the auth-header primitive from ``a2a_tools_rbac`` (iter 4a) —
|
||||
NOT from ``a2a_tools`` — to avoid a circular import with the
|
||||
kitchen-sink re-export module.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from a2a_client import PLATFORM_URL
|
||||
from a2a_tools_rbac import (
|
||||
auth_headers_for_heartbeat as _auth_headers_for_heartbeat,
|
||||
check_memory_write_permission as _check_memory_write_permission,
|
||||
)
|
||||
|
||||
|
||||
def _runtime_identity_payload() -> dict[str, Any]:
|
||||
"""Build the identity dict — env-only, no I/O.
|
||||
|
||||
Factored out from ``tool_get_runtime_identity`` so tests can assert
|
||||
against the exact key set without re-parsing JSON. The MCP tool
|
||||
handler ``tool_get_runtime_identity`` is the only public caller in
|
||||
production; tests call this helper directly.
|
||||
"""
|
||||
return {
|
||||
"model": os.environ.get("MODEL", ""),
|
||||
"model_provider": os.environ.get("MODEL_PROVIDER", ""),
|
||||
"molecule_model": os.environ.get("MOLECULE_MODEL", ""),
|
||||
"anthropic_base_url": os.environ.get("ANTHROPIC_BASE_URL", ""),
|
||||
"tier": os.environ.get("TIER", ""),
|
||||
"workspace_id": os.environ.get("WORKSPACE_ID", ""),
|
||||
# Adapter module is the closest thing the runtime has to a
|
||||
# "template slug" — e.g. "adapter" for claude-code-default,
|
||||
# "hermes" for hermes-template, etc. Picked from
|
||||
# $ADAPTER_MODULE env baked by each template's Dockerfile.
|
||||
"runtime": os.environ.get("ADAPTER_MODULE", ""),
|
||||
}
|
||||
|
||||
|
||||
async def tool_get_runtime_identity() -> str:
|
||||
"""Return this runtime's identity — model, provider, tier, IDs.
|
||||
|
||||
Env-only; no HTTP call. Useful so the agent can answer "what model
|
||||
am I?" correctly instead of guessing from a stale system prompt
|
||||
that the operator may have changed between boots.
|
||||
|
||||
Returns the identity as a JSON-encoded string (the dispatch contract
|
||||
every MCP tool in this module follows). Tests that want to assert
|
||||
individual fields can call ``_runtime_identity_payload()`` directly,
|
||||
or ``json.loads`` the return value.
|
||||
|
||||
Always permitted by RBAC — there is no sensitive information here
|
||||
that isn't already available to the process via ``os.environ``.
|
||||
The point of the tool is to surface those env values to the agent
|
||||
layer in a stable, documented shape rather than expecting every
|
||||
agent runtime to know to ``echo $MODEL``.
|
||||
"""
|
||||
return json.dumps(_runtime_identity_payload(), indent=2)
|
||||
|
||||
|
||||
async def tool_update_agent_card(card: Any) -> str:
|
||||
"""Update this workspace's agent_card on the platform.
|
||||
|
||||
POSTs the provided card to ``/registry/update-card`` with the
|
||||
workspace's own bearer token (same auth path as ``tool_commit_memory``
|
||||
and ``tool_get_workspace_info``). The platform validates required
|
||||
fields server-side, replaces the stored card, and broadcasts an
|
||||
``agent_card_updated`` event so the canvas updates live.
|
||||
|
||||
Args:
|
||||
card: A JSON-serialisable object (typically a dict) holding the
|
||||
new card. The platform validates required fields server-side.
|
||||
|
||||
Returns:
|
||||
JSON-encoded string. Body:
|
||||
- ``{"success": true, "status": "updated"}`` on success;
|
||||
- ``{"success": false, "error": "<msg>", "status_code": <int>}``
|
||||
on platform error;
|
||||
- ``{"success": false, "error": "<reason>"}`` on local validation
|
||||
(non-dict card, missing WORKSPACE_ID, network error).
|
||||
|
||||
Permission gate: this tool requires the ``memory.write`` RBAC
|
||||
capability — same gate as ``tool_commit_memory``. The check runs
|
||||
inline rather than at the dispatcher layer to keep ``a2a_mcp_server``
|
||||
permission-agnostic (the gate sits with the implementation, not the
|
||||
transport). Read-only roles get a clear error string back instead
|
||||
of a 403 from the platform.
|
||||
|
||||
We re-check ``isinstance(card, dict)`` here defensively rather than
|
||||
trust the MCP schema validator alone — the schema only constrains
|
||||
the transport, not the in-process call surface used by tests and
|
||||
sibling modules.
|
||||
"""
|
||||
payload = await _update_agent_card_impl(card)
|
||||
return json.dumps(payload, indent=2)
|
||||
|
||||
|
||||
async def _update_agent_card_impl(card: Any) -> dict[str, Any]:
|
||||
"""Dict-returning core of ``tool_update_agent_card``.
|
||||
|
||||
Split out so tests can assert against the raw dict shape (status
|
||||
codes, error messages) without re-parsing JSON on every assertion.
|
||||
The string-returning ``tool_update_agent_card`` is a thin wrapper
|
||||
invoked by the MCP dispatcher.
|
||||
"""
|
||||
# RBAC: require memory.write permission. Same gate as
|
||||
# tool_commit_memory (the agent already needs this capability to
|
||||
# persist anything outbound). Read-only roles can still call
|
||||
# get_runtime_identity / get_workspace_info to introspect — those
|
||||
# are env-only / read-only and have no inline gate.
|
||||
if not _check_memory_write_permission():
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
"RBAC — this workspace does not have the 'memory.write' "
|
||||
"permission required to update the agent_card."
|
||||
),
|
||||
}
|
||||
if not isinstance(card, dict):
|
||||
return {
|
||||
"success": False,
|
||||
"error": "card must be a JSON object (dict)",
|
||||
}
|
||||
ws_id = os.environ.get("WORKSPACE_ID", "")
|
||||
if not ws_id:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "WORKSPACE_ID env not set; cannot identify caller",
|
||||
}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(
|
||||
f"{PLATFORM_URL}/registry/update-card",
|
||||
json={"workspace_id": ws_id, "agent_card": card},
|
||||
headers=_auth_headers_for_heartbeat(),
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
body: dict[str, Any] = {}
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"success": True,
|
||||
"status": body.get("status", "updated"),
|
||||
}
|
||||
# Non-200 — surface what the platform returned.
|
||||
error_msg = ""
|
||||
try:
|
||||
error_msg = resp.json().get("error", "") or resp.text
|
||||
except Exception:
|
||||
error_msg = resp.text
|
||||
return {
|
||||
"success": False,
|
||||
"status_code": resp.status_code,
|
||||
"error": error_msg,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": f"network error: {e}"}
|
||||
@@ -340,6 +340,16 @@ _CLI_A2A_COMMAND_KEYWORDS: dict[str, str | None] = {
|
||||
"delegate_task_async": "delegate --async",
|
||||
"check_task_status": "status",
|
||||
"get_workspace_info": "info",
|
||||
# `get_runtime_identity` + `update_agent_card` are MCP-first
|
||||
# capabilities — the CLI subprocess interface doesn't expose them
|
||||
# today. `get_runtime_identity` is env-only and an agent on a
|
||||
# CLI-only runtime can already `echo $MODEL` etc, so there's no
|
||||
# functional gap. `update_agent_card` requires a JSON object
|
||||
# argument that wouldn't survive a positional-arg shell invocation
|
||||
# cleanly. Mapped to None — flip to a keyword if a2a_cli grows
|
||||
# `identity` / `card` subcommands in the future.
|
||||
"get_runtime_identity": None,
|
||||
"update_agent_card": None,
|
||||
# `broadcast_message` is not exposed via the CLI subprocess interface
|
||||
# today — it's an MCP-first capability. If a2a_cli grows a `broadcast`
|
||||
# subcommand, map it here and the alignment test will gate the change.
|
||||
|
||||
@@ -57,12 +57,14 @@ from a2a_tools import (
|
||||
tool_commit_memory,
|
||||
tool_delegate_task,
|
||||
tool_delegate_task_async,
|
||||
tool_get_runtime_identity,
|
||||
tool_get_workspace_info,
|
||||
tool_inbox_peek,
|
||||
tool_inbox_pop,
|
||||
tool_list_peers,
|
||||
tool_recall_memory,
|
||||
tool_send_message_to_user,
|
||||
tool_update_agent_card,
|
||||
tool_wait_for_message,
|
||||
)
|
||||
|
||||
@@ -289,6 +291,61 @@ _GET_WORKSPACE_INFO = ToolSpec(
|
||||
section=A2A_SECTION,
|
||||
)
|
||||
|
||||
_GET_RUNTIME_IDENTITY = ToolSpec(
|
||||
name="get_runtime_identity",
|
||||
short=(
|
||||
"Return this runtime's identity — model, model_provider, tier, "
|
||||
"workspace_id, runtime template. Reads from process env; no HTTP call."
|
||||
),
|
||||
when_to_use=(
|
||||
"Use this to answer 'what model am I?' truthfully instead of "
|
||||
"guessing from a stale system prompt — the operator may have "
|
||||
"routed you to a different model via persona env between boots. "
|
||||
"Always permitted by RBAC: even read-only agents may know what "
|
||||
"model they are. Distinct from get_workspace_info — that one "
|
||||
"calls the platform for ID/role/tier/parent (workspace metadata); "
|
||||
"this one returns the live process env (MODEL, MODEL_PROVIDER, "
|
||||
"MOLECULE_MODEL, ANTHROPIC_BASE_URL, TIER, WORKSPACE_ID, "
|
||||
"ADAPTER_MODULE)."
|
||||
),
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
impl=tool_get_runtime_identity,
|
||||
section=A2A_SECTION,
|
||||
)
|
||||
|
||||
_UPDATE_AGENT_CARD = ToolSpec(
|
||||
name="update_agent_card",
|
||||
short=(
|
||||
"Replace this workspace's agent_card on the platform. The "
|
||||
"platform validates required fields and broadcasts an "
|
||||
"agent_card_updated event so the canvas reflects the change live."
|
||||
),
|
||||
when_to_use=(
|
||||
"Use when the workspace's capabilities, skills, description, or "
|
||||
"name change and the canvas display needs to follow. The "
|
||||
"platform stores the new card and pushes an "
|
||||
"``agent_card_updated`` event to subscribers. Gated behind the "
|
||||
"``memory.write`` RBAC capability — read-only roles cannot "
|
||||
"rewrite the card. Tier-1+ owners always have this capability."
|
||||
),
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"card": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"The new agent_card object (name, version, "
|
||||
"description, skills, etc). Server-side validation "
|
||||
"rejects payloads missing required fields."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["card"],
|
||||
},
|
||||
impl=tool_update_agent_card,
|
||||
section=A2A_SECTION,
|
||||
)
|
||||
|
||||
_BROADCAST_MESSAGE = ToolSpec(
|
||||
name="broadcast_message",
|
||||
short=(
|
||||
@@ -642,6 +699,8 @@ TOOLS: list[ToolSpec] = [
|
||||
_CHECK_TASK_STATUS,
|
||||
_LIST_PEERS,
|
||||
_GET_WORKSPACE_INFO,
|
||||
_GET_RUNTIME_IDENTITY,
|
||||
_UPDATE_AGENT_CARD,
|
||||
_BROADCAST_MESSAGE,
|
||||
_SEND_MESSAGE_TO_USER,
|
||||
# Inbox (standalone-only; in-container returns informational error)
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
- **check_task_status**: Poll the status of a task started with delegate_task_async; returns result when done.
|
||||
- **list_peers**: List the workspaces this agent can communicate with — name, ID, status, role for each.
|
||||
- **get_workspace_info**: Get this workspace's own info — ID, name, role, tier, parent, status.
|
||||
- **get_runtime_identity**: Return this runtime's identity — model, model_provider, tier, workspace_id, runtime template. Reads from process env; no HTTP call.
|
||||
- **update_agent_card**: Replace this workspace's agent_card on the platform. The platform validates required fields and broadcasts an agent_card_updated event so the canvas reflects the change live.
|
||||
- **broadcast_message**: Send a message to ALL agent workspaces in the org simultaneously. Requires broadcast_enabled=true on this workspace (set by user/admin).
|
||||
- **send_message_to_user**: Send a message directly to the user's canvas chat — pushed instantly via WebSocket. Use this to: (1) acknowledge a task immediately ('Got it, I'll start working on this'), (2) send interim progress updates while doing long work, (3) deliver follow-up results after delegation completes, (4) attach files (zip, pdf, csv, image) for the user to download via the `attachments` field (NEVER paste file URLs in `message`). The message appears in the user's chat as if you're proactively reaching out.
|
||||
- **wait_for_message**: Block until the next inbound message (canvas user OR peer agent) arrives, or until ``timeout_secs`` elapses.
|
||||
@@ -27,6 +29,12 @@ Call this first when you need to delegate but don't know the target's ID. Access
|
||||
### get_workspace_info
|
||||
Use to introspect your own identity (e.g. before reporting back to the user, or to determine whether you're a tier-0 root that can write GLOBAL memory).
|
||||
|
||||
### get_runtime_identity
|
||||
Use this to answer 'what model am I?' truthfully instead of guessing from a stale system prompt — the operator may have routed you to a different model via persona env between boots. Always permitted by RBAC: even read-only agents may know what model they are. Distinct from get_workspace_info — that one calls the platform for ID/role/tier/parent (workspace metadata); this one returns the live process env (MODEL, MODEL_PROVIDER, MOLECULE_MODEL, ANTHROPIC_BASE_URL, TIER, WORKSPACE_ID, ADAPTER_MODULE).
|
||||
|
||||
### update_agent_card
|
||||
Use when the workspace's capabilities, skills, description, or name change and the canvas display needs to follow. The platform stores the new card and pushes an ``agent_card_updated`` event to subscribers. Gated behind the ``memory.write`` RBAC capability — read-only roles cannot rewrite the card. Tier-1+ owners always have this capability.
|
||||
|
||||
### broadcast_message
|
||||
Use for urgent, org-wide signals: critical status changes, emergency stop instructions, coordinated task announcements. Every non-removed workspace receives the message in its activity log (poll-mode agents see it on their next poll; push-mode canvases get a real-time banner). This tool returns an error if broadcast_enabled is false — a user or admin must enable it via the workspace abilities settings first.
|
||||
|
||||
|
||||
@@ -0,0 +1,390 @@
|
||||
"""Tests for ``tool_get_runtime_identity`` and ``tool_update_agent_card``.
|
||||
|
||||
These two MCP tools close the T4-tier workspace owner-permission gaps
|
||||
reported via the canvas:
|
||||
|
||||
- the agent could not update its own ``agent_card`` (no MCP tool
|
||||
wrapped the existing ``POST /registry/update-card`` endpoint);
|
||||
- the agent could not identify which model it was running (the
|
||||
``MODEL`` env var is injected by ``provisioner.workspace_provision``
|
||||
but nothing surfaced it back to the agent).
|
||||
|
||||
Ported from molecule-ai-workspace-runtime PR#17 (mirror-only repo;
|
||||
canonical edit point per ``reference_runtime_repo_is_mirror_only``).
|
||||
Adapted to core's conventions:
|
||||
|
||||
* tool functions return ``str`` (JSON-encoded), matching every other
|
||||
tool in ``a2a_tools_*`` modules. Tests ``json.loads`` to inspect.
|
||||
* permission check ``memory.write`` runs inline in
|
||||
``tool_update_agent_card`` (same pattern as
|
||||
``a2a_tools_memory.tool_commit_memory``).
|
||||
* ``WORKSPACE_ID`` is read directly from ``os.environ`` — core does
|
||||
not have the runtime's validated-cache layer (``molecule_runtime.
|
||||
builtin_tools.validation``).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# --- Drift gate: re-export aliases on a2a_tools ------------------------------
|
||||
|
||||
class TestBackCompatAliases:
|
||||
"""Pin that ``a2a_tools.tool_*`` resolves to the same callable as
|
||||
``a2a_tools_identity.tool_*``. Refactor wrapping (e.g. a doc-string
|
||||
wrapper that loses the function identity) silently breaks call
|
||||
sites that ``patch("a2a_tools.tool_update_agent_card", ...)`` —
|
||||
this gate makes that drift fail fast."""
|
||||
|
||||
def test_tool_get_runtime_identity_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_identity
|
||||
assert a2a_tools.tool_get_runtime_identity is a2a_tools_identity.tool_get_runtime_identity
|
||||
|
||||
def test_tool_update_agent_card_alias(self):
|
||||
import a2a_tools
|
||||
import a2a_tools_identity
|
||||
assert a2a_tools.tool_update_agent_card is a2a_tools_identity.tool_update_agent_card
|
||||
|
||||
|
||||
# --- tool_get_runtime_identity ----------------------------------------------
|
||||
|
||||
class TestGetRuntimeIdentity:
|
||||
"""The tool returns env-derived runtime identity. No HTTP call."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_all_known_env_fields(self, monkeypatch):
|
||||
from a2a_tools_identity import tool_get_runtime_identity
|
||||
|
||||
monkeypatch.setenv("MODEL", "claude-opus-4-7")
|
||||
monkeypatch.setenv("MODEL_PROVIDER", "anthropic")
|
||||
monkeypatch.setenv("TIER", "T4")
|
||||
monkeypatch.setenv("WORKSPACE_ID", "ws-abc")
|
||||
monkeypatch.setenv("ADAPTER_MODULE", "adapter")
|
||||
monkeypatch.setenv("MOLECULE_MODEL", "claude-opus-4-7")
|
||||
monkeypatch.setenv("ANTHROPIC_BASE_URL", "https://api.anthropic.com")
|
||||
|
||||
out = await tool_get_runtime_identity()
|
||||
# MCP tools return JSON-encoded strings (matches the contract
|
||||
# every other tool_* in a2a_tools_* uses).
|
||||
assert isinstance(out, str)
|
||||
parsed = json.loads(out)
|
||||
|
||||
assert parsed["model"] == "claude-opus-4-7"
|
||||
assert parsed["model_provider"] == "anthropic"
|
||||
assert parsed["tier"] == "T4"
|
||||
assert parsed["workspace_id"] == "ws-abc"
|
||||
assert parsed["runtime"] == "adapter"
|
||||
assert parsed["molecule_model"] == "claude-opus-4-7"
|
||||
assert parsed["anthropic_base_url"] == "https://api.anthropic.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_env_returns_empty_strings(self, monkeypatch):
|
||||
"""Tool MUST NOT raise when env vars are absent — every key is
|
||||
present but the value is the empty string. The agent then knows
|
||||
the slot exists but is unset."""
|
||||
from a2a_tools_identity import tool_get_runtime_identity
|
||||
|
||||
for var in (
|
||||
"MODEL", "MODEL_PROVIDER", "TIER", "WORKSPACE_ID",
|
||||
"ADAPTER_MODULE", "MOLECULE_MODEL", "ANTHROPIC_BASE_URL",
|
||||
):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
parsed = json.loads(await tool_get_runtime_identity())
|
||||
assert parsed["model"] == ""
|
||||
assert parsed["model_provider"] == ""
|
||||
assert parsed["tier"] == ""
|
||||
assert parsed["workspace_id"] == ""
|
||||
assert parsed["runtime"] == ""
|
||||
assert parsed["molecule_model"] == ""
|
||||
assert parsed["anthropic_base_url"] == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_http_call_made(self, monkeypatch):
|
||||
"""``get_runtime_identity`` is env-only — must not open
|
||||
httpx.AsyncClient even if the call would otherwise succeed.
|
||||
Tripwire any client construction."""
|
||||
import httpx
|
||||
|
||||
from a2a_tools_identity import tool_get_runtime_identity
|
||||
|
||||
class _Tripwire:
|
||||
def __init__(self, *_a, **_kw):
|
||||
raise AssertionError(
|
||||
"tool_get_runtime_identity must not open httpx.AsyncClient"
|
||||
)
|
||||
|
||||
monkeypatch.setattr(httpx, "AsyncClient", _Tripwire)
|
||||
# Must not raise.
|
||||
await tool_get_runtime_identity()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_helper_dict_matches_string_payload(self, monkeypatch):
|
||||
"""``_runtime_identity_payload`` is the dict-returning helper
|
||||
used by both the public tool and tests. Verify the public tool
|
||||
json.dumps the same dict — no field is dropped or renamed by
|
||||
the encoding step."""
|
||||
from a2a_tools_identity import (
|
||||
_runtime_identity_payload,
|
||||
tool_get_runtime_identity,
|
||||
)
|
||||
|
||||
monkeypatch.setenv("MODEL", "claude-opus-4-7")
|
||||
monkeypatch.setenv("TIER", "T4")
|
||||
monkeypatch.setenv("WORKSPACE_ID", "ws-helper-check")
|
||||
|
||||
helper = _runtime_identity_payload()
|
||||
tool_str = await tool_get_runtime_identity()
|
||||
assert json.loads(tool_str) == helper
|
||||
|
||||
|
||||
# --- tool_update_agent_card -------------------------------------------------
|
||||
|
||||
|
||||
class _MockResponse:
|
||||
def __init__(self, status_code: int, payload: dict):
|
||||
self.status_code = status_code
|
||||
self._payload = payload
|
||||
self.text = json.dumps(payload)
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class _MockClient:
|
||||
"""Drop-in for httpx.AsyncClient context manager.
|
||||
|
||||
Records the URL + json body + headers the tool POSTed so the test
|
||||
can assert against them. Returns the canned _MockResponse passed
|
||||
in at construction time.
|
||||
"""
|
||||
|
||||
def __init__(self, *, response: _MockResponse, captured: dict):
|
||||
self._response = response
|
||||
self._captured = captured
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *_args):
|
||||
return False
|
||||
|
||||
async def post(self, url, *, json=None, headers=None, **_kw): # noqa: A002
|
||||
self._captured["url"] = url
|
||||
self._captured["json"] = json
|
||||
self._captured["headers"] = headers
|
||||
return self._response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _grant_memory_write(monkeypatch):
|
||||
"""Force the inline RBAC gate inside ``tool_update_agent_card`` to
|
||||
succeed. The gate calls
|
||||
``a2a_tools_rbac.check_memory_write_permission`` which inspects
|
||||
``$MOLECULE_ROLES`` / the role table; the patch sidesteps that
|
||||
machinery so tests can focus on the platform-call shape.
|
||||
"""
|
||||
import a2a_tools_identity
|
||||
monkeypatch.setattr(
|
||||
a2a_tools_identity, "_check_memory_write_permission", lambda: True
|
||||
)
|
||||
|
||||
|
||||
class TestUpdateAgentCard:
|
||||
@pytest.mark.asyncio
|
||||
async def test_posts_to_registry_update_card(
|
||||
self, monkeypatch, _grant_memory_write,
|
||||
):
|
||||
"""Hits POST {PLATFORM_URL}/registry/update-card with the
|
||||
workspace bearer and the {workspace_id, agent_card} body shape
|
||||
the platform handler expects (workspace-server
|
||||
``internal/handlers/registry.go``)."""
|
||||
import a2a_tools_identity
|
||||
|
||||
monkeypatch.setenv("WORKSPACE_ID", "ws-42")
|
||||
# Ensure PLATFORM_URL re-import sees a deterministic value —
|
||||
# a2a_client imports it at module load so we patch the symbol
|
||||
# on a2a_tools_identity directly (the module's own reference).
|
||||
monkeypatch.setattr(a2a_tools_identity, "PLATFORM_URL", "http://test.invalid")
|
||||
|
||||
captured: dict = {}
|
||||
response = _MockResponse(200, {"status": "updated"})
|
||||
|
||||
def _client_factory(*_a, **_kw):
|
||||
return _MockClient(response=response, captured=captured)
|
||||
|
||||
monkeypatch.setattr(a2a_tools_identity.httpx, "AsyncClient", _client_factory)
|
||||
monkeypatch.setattr(
|
||||
a2a_tools_identity, "_auth_headers_for_heartbeat",
|
||||
lambda: {"Authorization": "Bearer ws-token-xyz"},
|
||||
)
|
||||
|
||||
card = {"name": "agent-foo", "version": "0.1.0", "description": "demo"}
|
||||
result_str = await a2a_tools_identity.tool_update_agent_card(card)
|
||||
result = json.loads(result_str)
|
||||
|
||||
# URL: PLATFORM_URL + /registry/update-card
|
||||
assert captured["url"] == "http://test.invalid/registry/update-card"
|
||||
|
||||
# The platform handler expects {workspace_id, agent_card}; the
|
||||
# agent_card is the raw object the agent submitted.
|
||||
body = captured["json"]
|
||||
assert body["workspace_id"] == "ws-42"
|
||||
assert body["agent_card"] == card
|
||||
|
||||
# Auth header from auth_headers_for_heartbeat is forwarded
|
||||
# verbatim — same path commit_memory uses.
|
||||
assert captured["headers"]["Authorization"] == "Bearer ws-token-xyz"
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["status"] == "updated"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_propagates_server_error(
|
||||
self, monkeypatch, _grant_memory_write,
|
||||
):
|
||||
"""Non-200 from platform surfaces as a structured error to the
|
||||
agent. The agent sees {success:false, status_code, error} and
|
||||
can decide whether to retry, fall back, or escalate."""
|
||||
import a2a_tools_identity
|
||||
|
||||
monkeypatch.setenv("WORKSPACE_ID", "ws-42")
|
||||
monkeypatch.setattr(a2a_tools_identity, "PLATFORM_URL", "http://test.invalid")
|
||||
|
||||
captured: dict = {}
|
||||
response = _MockResponse(400, {"error": "invalid card"})
|
||||
|
||||
monkeypatch.setattr(
|
||||
a2a_tools_identity.httpx, "AsyncClient",
|
||||
lambda *a, **kw: _MockClient(response=response, captured=captured),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
a2a_tools_identity, "_auth_headers_for_heartbeat", lambda: {},
|
||||
)
|
||||
|
||||
result = json.loads(
|
||||
await a2a_tools_identity.tool_update_agent_card({"name": "x"})
|
||||
)
|
||||
assert result["success"] is False
|
||||
assert result["status_code"] == 400
|
||||
assert "invalid card" in str(result["error"]).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_non_dict_card(self, _grant_memory_write):
|
||||
"""The MCP schema constrains transport callers to pass a dict;
|
||||
in-process callers (tests, sibling modules) can still pass any
|
||||
type. Reject non-dict defensively so the platform isn't asked
|
||||
to validate JSON-encoded strings or lists."""
|
||||
from a2a_tools_identity import tool_update_agent_card
|
||||
|
||||
result = json.loads(await tool_update_agent_card("not-a-dict"))
|
||||
assert result["success"] is False
|
||||
assert "dict" in str(result["error"]).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workspace_id_missing_returns_error(
|
||||
self, monkeypatch, _grant_memory_write,
|
||||
):
|
||||
"""If WORKSPACE_ID is not set the tool refuses to issue the
|
||||
request — it would otherwise POST with an empty workspace_id
|
||||
and let the platform return a confusing 400."""
|
||||
from a2a_tools_identity import tool_update_agent_card
|
||||
|
||||
monkeypatch.delenv("WORKSPACE_ID", raising=False)
|
||||
|
||||
result = json.loads(await tool_update_agent_card({"name": "x"}))
|
||||
assert result["success"] is False
|
||||
assert "workspace_id" in str(result["error"]).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_denies_when_memory_write_permission_missing(self, monkeypatch):
|
||||
"""The agent's RBAC role must grant ``memory.write`` to update
|
||||
the card. Read-only roles get an RBAC error string back
|
||||
immediately, never touching the platform."""
|
||||
import a2a_tools_identity
|
||||
|
||||
monkeypatch.setenv("WORKSPACE_ID", "ws-42")
|
||||
monkeypatch.setattr(
|
||||
a2a_tools_identity, "_check_memory_write_permission", lambda: False,
|
||||
)
|
||||
|
||||
# Tripwire httpx — must not be called when RBAC denies.
|
||||
import httpx
|
||||
|
||||
class _Tripwire:
|
||||
def __init__(self, *_a, **_kw):
|
||||
raise AssertionError("RBAC denial must short-circuit before httpx call")
|
||||
|
||||
monkeypatch.setattr(httpx, "AsyncClient", _Tripwire)
|
||||
|
||||
result = json.loads(
|
||||
await a2a_tools_identity.tool_update_agent_card({"name": "x"}),
|
||||
)
|
||||
assert result["success"] is False
|
||||
assert "memory.write" in str(result["error"]).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_exception_returns_structured_error(
|
||||
self, monkeypatch, _grant_memory_write,
|
||||
):
|
||||
"""A network exception (DNS failure, connect timeout, etc) is
|
||||
wrapped into a structured error dict instead of bubbling up
|
||||
to the MCP transport layer."""
|
||||
import a2a_tools_identity
|
||||
|
||||
monkeypatch.setenv("WORKSPACE_ID", "ws-42")
|
||||
monkeypatch.setattr(a2a_tools_identity, "PLATFORM_URL", "http://test.invalid")
|
||||
|
||||
class _ExplodingClient:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *_a):
|
||||
return False
|
||||
|
||||
async def post(self, *_a, **_kw):
|
||||
raise RuntimeError("simulated DNS failure")
|
||||
|
||||
monkeypatch.setattr(
|
||||
a2a_tools_identity.httpx, "AsyncClient",
|
||||
lambda *a, **kw: _ExplodingClient(),
|
||||
)
|
||||
|
||||
result = json.loads(
|
||||
await a2a_tools_identity.tool_update_agent_card({"name": "x"})
|
||||
)
|
||||
assert result["success"] is False
|
||||
assert "network" in str(result["error"]).lower()
|
||||
|
||||
|
||||
# --- Registry contract ------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistryContract:
|
||||
"""Pin the new tools' registration in platform_tools.registry. The
|
||||
structural tests in ``test_platform_tools.py`` already check
|
||||
registry↔MCP alignment; these are tighter assertions specific to
|
||||
the two new tools so a future contributor deleting one entry sees
|
||||
a focused failure."""
|
||||
|
||||
def test_get_runtime_identity_in_registry(self):
|
||||
from platform_tools.registry import by_name
|
||||
spec = by_name("get_runtime_identity")
|
||||
assert spec.section == "a2a"
|
||||
# No input parameters — env-only call.
|
||||
assert spec.input_schema == {"type": "object", "properties": {}}
|
||||
# impl points at the actual tool function, not a shim.
|
||||
from a2a_tools_identity import tool_get_runtime_identity
|
||||
assert spec.impl is tool_get_runtime_identity
|
||||
|
||||
def test_update_agent_card_in_registry(self):
|
||||
from platform_tools.registry import by_name
|
||||
spec = by_name("update_agent_card")
|
||||
assert spec.section == "a2a"
|
||||
assert "card" in spec.input_schema["properties"]
|
||||
assert spec.input_schema["required"] == ["card"]
|
||||
from a2a_tools_identity import tool_update_agent_card
|
||||
assert spec.impl is tool_update_agent_card
|
||||
Reference in New Issue
Block a user