Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0bf7eb92e5 |
+16
-24
@@ -49,7 +49,7 @@ on:
|
||||
# `merge_group` (GitHub merge-queue trigger) dropped — Gitea has no merge
|
||||
# queue. The .github/ original retains it; this Gitea-side copy drops it.
|
||||
|
||||
# Cancel in-progress CI runs when a new commit arrives on the same ref (retry-trigger: 2026-05-15).
|
||||
# Cancel in-progress CI runs when a new commit arrives on the same ref.
|
||||
# Stale runs queue up otherwise. PR refs and main/staging refs each get
|
||||
# their own group because github.ref differs.
|
||||
concurrency:
|
||||
@@ -145,11 +145,10 @@ jobs:
|
||||
# the diagnostic step with its own continue-on-error: true (line 203).
|
||||
# Flip confirmed by CI / Platform (Go) status = success on main HEAD 363905d3.
|
||||
continue-on-error: false
|
||||
# Job-level ceiling. The go test step below runs with a per-step 70m timeout;
|
||||
# this cap catches any step that leaks past that. Set well above 70m so
|
||||
# the per-step timeout is the active constraint. Raised to 75m
|
||||
# to account for golangci-lint ~17m + test suite ~20-30m on cold runner (mc#1099).
|
||||
timeout-minutes: 75
|
||||
# Job-level ceiling. The go test step below runs with a per-step 10m timeout;
|
||||
# this cap catches any step that leaks past that. Set well above 10m so
|
||||
# the per-step timeout is the active constraint.
|
||||
timeout-minutes: 15
|
||||
defaults:
|
||||
run:
|
||||
working-directory: workspace-server
|
||||
@@ -173,22 +172,16 @@ jobs:
|
||||
- if: always()
|
||||
name: Install golangci-lint
|
||||
run: go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.12.2
|
||||
- if: success()
|
||||
- if: always()
|
||||
name: Run golangci-lint
|
||||
# mc#1099: --no-config bypasses .golangci.yaml ceiling; --timeout 30m
|
||||
# is the active constraint. Cold runner: fetch-depth:0 clone (5-10m) + Go
|
||||
# toolchain (5-10m) + mod download (2-5m) + build + vet + install lint
|
||||
# (5m) = ~15-20m before linting even starts. 30m gives headroom.
|
||||
run: $(go env GOPATH)/bin/golangci-lint run --no-config --timeout 30m ./...
|
||||
- if: success()
|
||||
name: Diagnostic — per-package verbose 600s
|
||||
# mc#1099: step-level ceiling above the 600s Go timeout for cold-runner headroom.
|
||||
timeout-minutes: 20
|
||||
run: $(go env GOPATH)/bin/golangci-lint run --timeout 3m ./...
|
||||
- if: always()
|
||||
name: Diagnostic — per-package verbose 60s
|
||||
run: |
|
||||
set +e
|
||||
go test -race -v -timeout 600s ./internal/handlers/... 2>&1 | tee /tmp/test-handlers.log
|
||||
go test -race -v -timeout 60s ./internal/handlers/... 2>&1 | tee /tmp/test-handlers.log
|
||||
handlers_exit=$?
|
||||
go test -race -v -timeout 600s ./internal/pendinguploads/... 2>&1 | tee /tmp/test-pu.log
|
||||
go test -race -v -timeout 60s ./internal/pendinguploads/... 2>&1 | tee /tmp/test-pu.log
|
||||
pu_exit=$?
|
||||
echo "::group::handlers exit=$handlers_exit (last 100 lines)"
|
||||
tail -100 /tmp/test-handlers.log
|
||||
@@ -200,12 +193,11 @@ jobs:
|
||||
continue-on-error: true
|
||||
- if: always()
|
||||
name: Run tests with race detection and coverage
|
||||
# mc#1099: cold runner (~5-20m) + race detector (3-5x overhead) can push
|
||||
# the suite past 10m. Per-step ceiling must exceed Go-level timeout so
|
||||
# Go's timeout fires first (clean interrupt) rather than the step ceiling
|
||||
# (SIGKILL). Job-level ceiling (75m) is the outer backstop.
|
||||
timeout-minutes: 70
|
||||
run: go test -race -timeout 60m -coverprofile=coverage.out ./...
|
||||
# Explicit timeout: cold runner cache causes OOM kills at ~4m39s on the
|
||||
# full ./... suite with race detection + coverage. A 10m per-step timeout
|
||||
# lets the suite complete on cold cache (~5-7m) while failing cleanly
|
||||
# instead of OOM-killing. The job-level timeout (15m) is a backstop.
|
||||
run: go test -race -timeout 10m -coverprofile=coverage.out ./...
|
||||
|
||||
- if: always()
|
||||
name: Per-file coverage report
|
||||
|
||||
@@ -0,0 +1,408 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// setupQueueStatusHandlerDB creates a sqlmock DB with QueryMatcherEqual for exact SQL string matching.
|
||||
func setupQueueStatusHandlerDB(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock.New: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
return mock
|
||||
}
|
||||
|
||||
// Exact SQL strings used by the production code.
|
||||
const (
|
||||
sqlQueueRowAuthFields = `SELECT caller_id, workspace_id FROM a2a_queue WHERE id = $1`
|
||||
sqlQueueStatusByID = `
|
||||
SELECT
|
||||
q.id,
|
||||
q.workspace_id,
|
||||
q.status,
|
||||
q.priority,
|
||||
q.attempts,
|
||||
q.last_error,
|
||||
q.enqueued_at::text,
|
||||
q.dispatched_at::text,
|
||||
q.completed_at::text,
|
||||
q.expires_at::text,
|
||||
al.response_body::text
|
||||
FROM a2a_queue q
|
||||
LEFT JOIN activity_logs al
|
||||
ON al.method = 'delegate_result'
|
||||
AND al.target_id = q.workspace_id
|
||||
AND al.workspace_id = q.caller_id
|
||||
AND al.response_body->>'delegation_id' = (q.body->'params'->'message'->'metadata'->>'delegation_id')
|
||||
WHERE q.id = $1`
|
||||
)
|
||||
|
||||
// ── GetA2AQueueStatus HTTP handler tests ──────────────────────────────────────
|
||||
|
||||
// TestGetA2AQueueStatus_QueueIDEmpty_Returns400 exercises the handler directly
|
||||
// (not via router) so we can verify the empty-value branch without relying on
|
||||
// Gin route-matching behaviour.
|
||||
func TestGetA2AQueueStatus_QueueIDEmpty_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}}
|
||||
// queue_id param is empty string
|
||||
c.Params = gin.Params{
|
||||
{Key: "id", Value: "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
|
||||
{Key: "queue_id", Value: ""},
|
||||
}
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h.GetA2AQueueStatus(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("got %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_NoIdentity_NoOrgToken_Returns404(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/wsid/a2a/queue/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// No identity derivable → 404 (not 401) per existence-non-inference policy.
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("got %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_OrgToken_SkipsCallerCheck(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
|
||||
AddRow("other-ws", wsID)
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(authRows)
|
||||
|
||||
statusRows := sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "status", "priority", "attempts",
|
||||
"last_error", "enqueued_at", "dispatched_at", "completed_at", "expires_at",
|
||||
"response_body",
|
||||
}).AddRow(
|
||||
queueID, wsID, "queued", 50, 0,
|
||||
nil, "2026-01-01T00:00:00Z", nil, nil, nil, nil,
|
||||
)
|
||||
mock.ExpectQuery(sqlQueueStatusByID).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(statusRows)
|
||||
|
||||
r := gin.New()
|
||||
// Simulate org-token middleware setting org_token_id.
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", func(c *gin.Context) {
|
||||
c.Set("org_token_id", "org-admin")
|
||||
h.GetA2AQueueStatus(c)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/workspaces/wsid/a2a/queue/"+queueID, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_CallerWorkspaceMatchesCallerID_Success(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
|
||||
AddRow(callerID, wsID)
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(authRows)
|
||||
|
||||
statusRows := sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "status", "priority", "attempts",
|
||||
"last_error", "enqueued_at", "dispatched_at", "completed_at", "expires_at",
|
||||
"response_body",
|
||||
}).AddRow(
|
||||
queueID, wsID, "completed", 50, 1,
|
||||
nil, "2026-01-01T00:00:00Z", "2026-01-01T00:01:00Z", "2026-01-01T00:02:00Z",
|
||||
nil, []byte(`{"text":"result"}`),
|
||||
)
|
||||
mock.ExpectQuery(sqlQueueStatusByID).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(statusRows)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
|
||||
req.Header.Set("X-Workspace-ID", callerID)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_CallerWorkspaceMatchesWorkspaceID_Success(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
|
||||
AddRow(callerID, wsID)
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(authRows)
|
||||
|
||||
statusRows := sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "status", "priority", "attempts",
|
||||
"last_error", "enqueued_at", "dispatched_at", "completed_at", "expires_at",
|
||||
"response_body",
|
||||
}).AddRow(
|
||||
queueID, wsID, "queued", 50, 0,
|
||||
nil, "2026-01-01T00:00:00Z", nil, nil, nil, nil,
|
||||
)
|
||||
mock.ExpectQuery(sqlQueueStatusByID).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(statusRows)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
|
||||
req.Header.Set("X-Workspace-ID", wsID)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_QueueNotFound_Returns404(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
|
||||
req.Header.Set("X-Workspace-ID", callerID)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("got %d, want 404: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_QueueAuthFieldsDBError_Returns500(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
|
||||
req.Header.Set("X-Workspace-ID", callerID)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("got %d, want 500: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_WrongCallerWorkspace_Returns404(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
wrongCaller := "dddddddd-dddd-dddd-dddd-dddddddddddd"
|
||||
|
||||
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
|
||||
AddRow(callerID, wsID)
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(authRows)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
|
||||
req.Header.Set("X-Workspace-ID", wrongCaller)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("got %d, want 404: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_StatusFetchDBError_Returns500(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
|
||||
AddRow(callerID, wsID)
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(authRows)
|
||||
|
||||
mock.ExpectQuery(sqlQueueStatusByID).
|
||||
WithArgs(queueID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
|
||||
req.Header.Set("X-Workspace-ID", callerID)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("got %d, want 500: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_FullHappyPath_ReturnsJSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
|
||||
AddRow(callerID, wsID)
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(authRows)
|
||||
|
||||
respBody := []byte(`{"text":"delegation result"}`)
|
||||
statusRows := sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "status", "priority", "attempts",
|
||||
"last_error", "enqueued_at", "dispatched_at", "completed_at", "expires_at",
|
||||
"response_body",
|
||||
}).AddRow(
|
||||
queueID, wsID, "completed", 50, 1,
|
||||
nil, "2026-01-01T00:00:00Z", "2026-01-01T00:01:00Z", "2026-01-01T00:02:00Z",
|
||||
nil, respBody,
|
||||
)
|
||||
mock.ExpectQuery(sqlQueueStatusByID).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(statusRows)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
|
||||
req.Header.Set("X-Workspace-ID", wsID)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if w.Body.Len() == 0 {
|
||||
t.Error("response body is empty")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -79,18 +79,14 @@ func newTestBroadcaster() *events.Broadcaster {
|
||||
// for the duration of the test, so httptest.NewServer's loopback URLs
|
||||
// don't trip the SSRF guard. The 169.254 metadata, RFC-1918, TEST-NET,
|
||||
// CGNAT, and link-local guards stay active — only 127.0.0.0/8 and ::1
|
||||
// are relaxed. Protected by loopbackMu so concurrent tests don't race.
|
||||
// are relaxed. Always paired with t.Cleanup to restore; multiple
|
||||
// parallel tests won't race because Go test flips it sequentially per
|
||||
// test unless t.Parallel() is used, and these tests don't parallelize.
|
||||
func allowLoopbackForTest(t *testing.T) {
|
||||
t.Helper()
|
||||
loopbackMu.Lock()
|
||||
prev := testAllowLoopback
|
||||
testAllowLoopback = true
|
||||
t.Cleanup(func() {
|
||||
loopbackMu.Lock()
|
||||
defer loopbackMu.Unlock()
|
||||
testAllowLoopback = prev
|
||||
})
|
||||
loopbackMu.Unlock()
|
||||
t.Cleanup(func() { testAllowLoopback = prev })
|
||||
}
|
||||
|
||||
// expectBudgetCheck adds the sqlmock expectation for the budget-check
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// devModeAllowsLoopback reports whether the SSRF defence should permit
|
||||
@@ -36,20 +35,13 @@ func devModeAllowsLoopback() bool {
|
||||
// loopback URLs and fake hostnames (*.example) don't trigger SSRF
|
||||
// rejections. Production code never mutates this.
|
||||
var ssrfCheckEnabled = true
|
||||
var ssrfMu sync.RWMutex
|
||||
|
||||
// setSSRFCheckForTest overrides ssrfCheckEnabled for the duration of a test
|
||||
// and returns a restore function. Use with defer in *_test.go only.
|
||||
func setSSRFCheckForTest(enabled bool) func() {
|
||||
ssrfMu.Lock()
|
||||
defer ssrfMu.Unlock()
|
||||
prev := ssrfCheckEnabled
|
||||
ssrfCheckEnabled = enabled
|
||||
return func() {
|
||||
ssrfMu.Lock()
|
||||
defer ssrfMu.Unlock()
|
||||
ssrfCheckEnabled = prev
|
||||
}
|
||||
return func() { ssrfCheckEnabled = prev }
|
||||
}
|
||||
|
||||
// isSafeURL validates that a URL resolves to a publicly-routable address,
|
||||
@@ -62,22 +54,9 @@ func setSSRFCheckForTest(enabled bool) func() {
|
||||
// the same VPC and register by their VPC-private IP. Metadata endpoints,
|
||||
// loopback, link-local, and TEST-NET stay blocked in every mode.
|
||||
func isSafeURL(rawURL string) error {
|
||||
// Capture both test-flag states under lock before any validation logic.
|
||||
// Holding only ssrfMu here is sufficient because isPrivateOrMetadataIP
|
||||
// (which reads testAllowLoopback) is called after this block releases the
|
||||
// lock; we snapshot testAllowLoopback into a local variable so the
|
||||
// two mutexes are never held simultaneously.
|
||||
ssrfMu.RLock()
|
||||
enabled := ssrfCheckEnabled
|
||||
ssrfMu.RUnlock()
|
||||
if !enabled {
|
||||
if !ssrfCheckEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
loopbackMu.RLock()
|
||||
allowLoopback := testAllowLoopback
|
||||
loopbackMu.RUnlock()
|
||||
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL: %w", err)
|
||||
@@ -90,7 +69,7 @@ func isSafeURL(rawURL string) error {
|
||||
return fmt.Errorf("empty hostname")
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if (ip.IsLoopback() && !allowLoopback && !devModeAllowsLoopback()) || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() {
|
||||
if (ip.IsLoopback() && !testAllowLoopback && !devModeAllowsLoopback()) || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() {
|
||||
return fmt.Errorf("forbidden loopback/unspecified/link-local IP: %s", ip)
|
||||
}
|
||||
if isPrivateOrMetadataIP(ip) {
|
||||
@@ -110,7 +89,7 @@ func isSafeURL(rawURL string) error {
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
if (ip.IsLoopback() && !allowLoopback && !devModeAllowsLoopback()) || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() {
|
||||
if (ip.IsLoopback() && !testAllowLoopback && !devModeAllowsLoopback()) || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() {
|
||||
return fmt.Errorf("hostname %s resolves to forbidden link-local/loopback IP: %s", host, ip)
|
||||
}
|
||||
if isPrivateOrMetadataIP(ip) {
|
||||
@@ -129,7 +108,6 @@ func isSafeURL(rawURL string) error {
|
||||
// The 169.254 metadata, RFC-1918, TEST-NET, CGNAT, and link-local
|
||||
// guards are NOT relaxed by this flag — only loopback.
|
||||
var testAllowLoopback = false
|
||||
var loopbackMu sync.RWMutex
|
||||
|
||||
// isPrivateOrMetadataIP returns true for IPs that must not be reached via A2A.
|
||||
//
|
||||
@@ -189,10 +167,7 @@ func isPrivateOrMetadataIP(ip net.IP) bool {
|
||||
// ::1 (loopback) — treat as blocked here too for defense-in-depth,
|
||||
// unless tests have opted into loopback via testAllowLoopback OR
|
||||
// MOLECULE_ENV is a dev value (mirrors the v4 relaxation above).
|
||||
loopbackMu.RLock()
|
||||
allowLB := testAllowLoopback
|
||||
loopbackMu.RUnlock()
|
||||
if ip.IsLoopback() && !allowLB && !devModeAllowsLoopback() {
|
||||
if ip.IsLoopback() && !testAllowLoopback && !devModeAllowsLoopback() {
|
||||
return true
|
||||
}
|
||||
// Link-local fe80::/10 — always blocked.
|
||||
|
||||
@@ -0,0 +1,317 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// setupAbilitiesDB creates a sqlmock DB with QueryMatcherEqual for exact SQL matching.
|
||||
func setupAbilitiesDB(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock.New: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
return mock
|
||||
}
|
||||
|
||||
// Exact SQL strings used by the production handler.
|
||||
const (
|
||||
sqlPatchAbilitiesExists = `SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1 AND status != 'removed')`
|
||||
sqlPatchBroadcastEnabled = `UPDATE workspaces SET broadcast_enabled = $2, updated_at = now() WHERE id = $1`
|
||||
sqlPatchTalkToUserEnabled = `UPDATE workspaces SET talk_to_user_enabled = $2, updated_at = now() WHERE id = $1`
|
||||
)
|
||||
|
||||
// ── PatchAbilities HTTP handler tests ──────────────────────────────────────────
|
||||
|
||||
func TestPatchAbilities_InvalidWorkspaceID_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupAbilitiesDB(t)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "not-a-uuid"}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/not-a-uuid/abilities", nil)
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("got %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_InvalidBody_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupAbilitiesDB(t)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch,
|
||||
"/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/abilities",
|
||||
newFakeCloser([]byte("not json")))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("got %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_NoAbilityFields_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupAbilitiesDB(t)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch,
|
||||
"/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/abilities",
|
||||
newFakeCloser([]byte(`{}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("got %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_WorkspaceNotFound_Returns404(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupAbilitiesDB(t)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlPatchAbilitiesExists).
|
||||
WithArgs(wsID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
|
||||
newFakeCloser([]byte(`{"broadcast_enabled":true}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("got %d, want 404: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_WorkspaceNotFound_ExistsFalse_Returns404(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupAbilitiesDB(t)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlPatchAbilitiesExists).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
|
||||
newFakeCloser([]byte(`{"talk_to_user_enabled":false}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("got %d, want 404: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_UpdateBroadcastEnabled_Success(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupAbilitiesDB(t)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlPatchAbilitiesExists).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
|
||||
mock.ExpectExec(sqlPatchBroadcastEnabled).
|
||||
WithArgs(wsID, true).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
|
||||
newFakeCloser([]byte(`{"broadcast_enabled":true}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_UpdateTalkToUserEnabled_Success(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupAbilitiesDB(t)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlPatchAbilitiesExists).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
|
||||
mock.ExpectExec(sqlPatchTalkToUserEnabled).
|
||||
WithArgs(wsID, false).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
|
||||
newFakeCloser([]byte(`{"talk_to_user_enabled":false}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_UpdateBothAbilities_Success(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupAbilitiesDB(t)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlPatchAbilitiesExists).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
|
||||
mock.ExpectExec(sqlPatchBroadcastEnabled).
|
||||
WithArgs(wsID, true).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
mock.ExpectExec(sqlPatchTalkToUserEnabled).
|
||||
WithArgs(wsID, false).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
|
||||
newFakeCloser([]byte(`{"broadcast_enabled":true,"talk_to_user_enabled":false}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_BroadcastEnabledDBError_Returns500(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupAbilitiesDB(t)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlPatchAbilitiesExists).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
|
||||
mock.ExpectExec(sqlPatchBroadcastEnabled).
|
||||
WithArgs(wsID, true).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
|
||||
newFakeCloser([]byte(`{"broadcast_enabled":true}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("got %d, want 500: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_TalkToUserEnabledDBError_Returns500(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupAbilitiesDB(t)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlPatchAbilitiesExists).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
|
||||
mock.ExpectExec(sqlPatchTalkToUserEnabled).
|
||||
WithArgs(wsID, true).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
|
||||
newFakeCloser([]byte(`{"talk_to_user_enabled":true}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("got %d, want 500: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helpers ────────────────────────────────────────────────────────────────────
|
||||
|
||||
// newFakeCloser wraps a byte slice as an io.ReadCloser for request body injection.
|
||||
func newFakeCloser(data []byte) *fakeReadCloser {
|
||||
return &fakeReadCloser{data: data}
|
||||
}
|
||||
|
||||
type fakeReadCloser struct {
|
||||
data []byte
|
||||
pos int
|
||||
}
|
||||
|
||||
func (f *fakeReadCloser) Read(p []byte) (n int, err error) {
|
||||
if f.pos >= len(f.data) {
|
||||
return 0, nil
|
||||
}
|
||||
n = copy(p, f.data[f.pos:])
|
||||
f.pos += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (*fakeReadCloser) Close() error { return nil }
|
||||
@@ -3,7 +3,7 @@ package handlers
|
||||
// workspace_broadcast.go — POST /workspaces/:id/broadcast
|
||||
//
|
||||
// Allows a workspace with broadcast_enabled=true to send a message to every
|
||||
// non-removed agent workspace in the SAME ORG. The message is:
|
||||
// non-removed agent workspace in the org. The message is:
|
||||
//
|
||||
// • Persisted in each recipient's activity_logs (type='broadcast_receive')
|
||||
// so poll-mode agents pick it up via GET /activity.
|
||||
@@ -16,11 +16,6 @@ package handlers
|
||||
// Auth: WorkspaceAuth (the agent triggers this with its own bearer token).
|
||||
// The handler re-validates broadcast_enabled inside the DB lookup to prevent
|
||||
// TOCTOU — the middleware only proved the token is valid, not the ability.
|
||||
//
|
||||
// Org isolation (OFFSEC-015): recipients are scoped to the sender's org using
|
||||
// a recursive CTE that walks the parent_id chain to find the org root. This
|
||||
// prevents a compromised or misconfigured workspace from broadcasting to
|
||||
// workspaces in other tenants' orgs.
|
||||
|
||||
import (
|
||||
"log"
|
||||
@@ -79,49 +74,11 @@ func (h *BroadcastHandler) Broadcast(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Find the sender's org root by walking the parent_id chain.
|
||||
// Workspaces with parent_id = NULL are org roots; every other workspace
|
||||
// belongs to the org identified by its topmost ancestor.
|
||||
var orgRootID string
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
WITH RECURSIVE org_chain AS (
|
||||
SELECT id, parent_id, id AS root_id
|
||||
FROM workspaces
|
||||
WHERE id = $1
|
||||
UNION ALL
|
||||
SELECT w.id, w.parent_id, c.root_id
|
||||
FROM workspaces w
|
||||
JOIN org_chain c ON w.id = c.parent_id
|
||||
)
|
||||
SELECT root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1
|
||||
`, senderID).Scan(&orgRootID)
|
||||
if err != nil {
|
||||
log.Printf("Broadcast: org root lookup for %s: %v", senderID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
|
||||
return
|
||||
}
|
||||
|
||||
// Collect all non-removed agent workspaces in the SAME ORG (same root_id),
|
||||
// excluding the sender itself.
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
WITH RECURSIVE org_chain AS (
|
||||
SELECT id, parent_id, id AS root_id
|
||||
FROM workspaces
|
||||
WHERE parent_id IS NULL
|
||||
UNION ALL
|
||||
SELECT w.id, w.parent_id, c.root_id
|
||||
FROM workspaces w
|
||||
JOIN org_chain c ON w.parent_id = c.id
|
||||
)
|
||||
SELECT c.id
|
||||
FROM org_chain c
|
||||
WHERE c.root_id = $1
|
||||
AND c.id != $2
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM workspaces w
|
||||
WHERE w.id = c.id AND w.status != 'removed'
|
||||
)
|
||||
`, orgRootID, senderID)
|
||||
// Collect all non-removed agent workspaces (excludes the sender itself).
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
`SELECT id FROM workspaces WHERE status != 'removed' AND id != $1`,
|
||||
senderID,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Broadcast: recipient query failed for %s: %v", senderID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
|
||||
|
||||
@@ -1,428 +1,403 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"database/sql"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// -------- Org-scoped recipient query tests (OFFSEC-015) --------
|
||||
// broadcastBody is a convenience that returns an io.ReadCloser wrapping JSON body.
|
||||
func broadcastBody(body string) io.ReadCloser {
|
||||
return &broadcastFakeCloser{data: []byte(body)}
|
||||
}
|
||||
|
||||
// TestBroadcast_OrgScopedRecipients verifies that a broadcast from Org-A does
|
||||
// NOT reach workspaces belonging to Org-B. This is the core regression test
|
||||
// for OFFSEC-015: the original query had no org filter, so a workspace in
|
||||
// Org-A could broadcast to every non-removed workspace in the entire DB,
|
||||
// including workspaces owned by other tenants.
|
||||
func TestBroadcast_OrgScopedRecipients(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewBroadcastHandler(broadcaster)
|
||||
type broadcastFakeCloser struct {
|
||||
data []byte
|
||||
pos int
|
||||
}
|
||||
|
||||
// Org-A structure:
|
||||
// org-a-root (parent_id = NULL) ← sender
|
||||
// ├── ws-a-child
|
||||
// Org-B structure:
|
||||
// org-b-root (parent_id = NULL)
|
||||
// └── ws-b-child
|
||||
senderID := "00000000-0000-0000-0000-000000000001" // org-a-root
|
||||
wsAChild := "00000000-0000-0000-0000-000000000002"
|
||||
// ws-b-child is in Org-B (different root); the org-scoped query MUST NOT include it.
|
||||
func (f *broadcastFakeCloser) Read(p []byte) (n int, err error) {
|
||||
if f.pos >= len(f.data) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n = copy(p, f.data[f.pos:])
|
||||
f.pos += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// 1. Sender lookup
|
||||
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Org-A Root", true))
|
||||
func (*broadcastFakeCloser) Close() error { return nil }
|
||||
|
||||
// 2. Org root lookup — sender is its own root (parent_id = NULL)
|
||||
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(senderID))
|
||||
// setupBroadcastDB creates a sqlmock DB with QueryMatcherEqual.
|
||||
func setupBroadcastDB(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock.New: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
return mock
|
||||
}
|
||||
|
||||
// 3. Org-scoped recipient query — MUST include org filter so ws-b-child is NOT included.
|
||||
// The query joins on org_chain.root_id = orgRootID, which scopes to Org-A only.
|
||||
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
|
||||
WithArgs(senderID, senderID). // orgRootID, senderID (EXCLUDED)
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(wsAChild)) // only Org-A child
|
||||
// Exact SQL strings from the production handler (whitespace must match verbatim).
|
||||
const (
|
||||
sqlBroadcastWorkspaceLookup = `SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'`
|
||||
sqlBroadcastRecipients = `SELECT id FROM workspaces WHERE status != 'removed' AND id != $1`
|
||||
sqlBroadcastReceiveInsert = `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, status)
|
||||
VALUES ($1, 'broadcast_receive', 'broadcast', $2, $3, 'ok')`
|
||||
sqlBroadcastSentInsert = `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, summary, status)
|
||||
VALUES ($1, 'broadcast_sent', 'broadcast', $2, 'ok')`
|
||||
)
|
||||
|
||||
// Activity log inserts
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(wsAChild, senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// ── Broadcast HTTP handler tests ───────────────────────────────────────────────
|
||||
|
||||
func TestBroadcast_InvalidWorkspaceID_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/not-a-uuid/broadcast",
|
||||
broadcastBody(`{"message":"hello"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: senderID}}
|
||||
body := `{"message":"hello from org-a"}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if resp["status"] != "sent" {
|
||||
t.Errorf("expected status 'sent', got %v", resp["status"])
|
||||
}
|
||||
// ws-b-child is in a DIFFERENT org — the org-scoped query MUST NOT include it.
|
||||
// If it were included, the mock would have an unmet expectation.
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet mock expectations — cross-org workspace was included in broadcast: %v", err)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("got %d, want 400: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestBroadcast_OrgScoped_OrgRootSender verifies that when the sender IS the
|
||||
// org root (parent_id = NULL), broadcasts still reach sibling workspaces.
|
||||
func TestBroadcast_OrgScoped_OrgRootSender(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewBroadcastHandler(broadcaster)
|
||||
func TestBroadcast_MissingMessage_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
senderID := "00000000-0000-0000-0000-000000000001" // org-a-root
|
||||
siblingID := "00000000-0000-0000-0000-000000000002"
|
||||
|
||||
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Root Agent", true))
|
||||
|
||||
// Sender is the org root — CTE returns sender's own ID as root
|
||||
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(senderID))
|
||||
|
||||
// Recipients in same org, excluding sender
|
||||
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
|
||||
WithArgs(senderID, senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(siblingID))
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(siblingID, senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
"/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/broadcast",
|
||||
broadcastBody(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: senderID}}
|
||||
body := `{"message":"hello siblings"}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("got %d, want 400: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestBroadcast_OrgScoped_ChildWorkspaceSender verifies that a non-root child
|
||||
// workspace can broadcast to siblings in the same org.
|
||||
func TestBroadcast_OrgScoped_ChildWorkspaceSender(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewBroadcastHandler(broadcaster)
|
||||
func TestBroadcast_EmptyMessage_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
orgRootID := "00000000-0000-0000-0000-000000000001"
|
||||
senderID := "00000000-0000-0000-0000-000000000002" // child workspace
|
||||
siblingID := "00000000-0000-0000-0000-000000000003"
|
||||
|
||||
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Child Agent", true))
|
||||
|
||||
// Org root lookup — walk up to find org-a-root
|
||||
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(orgRootID))
|
||||
|
||||
// Recipients: same org, excluding sender
|
||||
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
|
||||
WithArgs(orgRootID, senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(siblingID))
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(siblingID, senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
"/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/broadcast",
|
||||
broadcastBody(`{"message":""}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: senderID}}
|
||||
body := `{"message":"child broadcasting"}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("got %d, want 400: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// -------- Non-regression cases --------
|
||||
func TestBroadcast_WorkspaceNotFound_Returns404(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
func TestBroadcast_NotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewBroadcastHandler(broadcaster)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
senderID := "00000000-0000-0000-0000-000000000099"
|
||||
// UUID is valid, but no workspace row matches
|
||||
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
|
||||
WithArgs(senderID).
|
||||
WillReturnError(errors.New("workspace not found"))
|
||||
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
|
||||
WithArgs(wsID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
|
||||
broadcastBody(`{"message":"hello"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: senderID}}
|
||||
body := `{"message":"test"}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Broadcast(c)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||
t.Errorf("got %d, want 404: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_Disabled(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewBroadcastHandler(broadcaster)
|
||||
func TestBroadcast_BroadcastDisabled_Returns403(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
senderID := "00000000-0000-0000-0000-000000000001"
|
||||
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Disabled Agent", false))
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
|
||||
AddRow("test-workspace", false))
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
|
||||
broadcastBody(`{"message":"hello"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: senderID}}
|
||||
body := `{"message":"should not send"}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Broadcast(c)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("expected 403, got %d: %s", w.Code, w.Body.String())
|
||||
t.Errorf("got %d, want 403: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
if resp["error"] != "broadcast_disabled" {
|
||||
t.Errorf("expected error 'broadcast_disabled', got %v", resp["error"])
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_EmptyOrg_NoRecipients(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewBroadcastHandler(broadcaster)
|
||||
func TestBroadcast_NoRecipients_Success(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
senderID := "00000000-0000-0000-0000-000000000001" // org root, only workspace in org
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Lone Root", true))
|
||||
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
|
||||
AddRow("test-workspace", true))
|
||||
|
||||
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(senderID))
|
||||
|
||||
// No other workspaces in this org
|
||||
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
|
||||
WithArgs(senderID, senderID).
|
||||
// No recipients (sender is the only non-removed workspace)
|
||||
mock.ExpectQuery(sqlBroadcastRecipients).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// Sender's own activity log: 2 args (workspaceID, summary)
|
||||
mock.ExpectExec(sqlBroadcastSentInsert).
|
||||
WithArgs(wsID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
|
||||
broadcastBody(`{"message":"hello everyone"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: senderID}}
|
||||
body := `{"message":"hello org"}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Broadcast(c)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
if resp["delivered"] != float64(0) {
|
||||
t.Errorf("expected delivered=0, got %v", resp["delivered"])
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_InvalidWorkspaceID(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewBroadcastHandler(broadcaster)
|
||||
func TestBroadcast_WithRecipients_Success_DeliversToAll(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
recipient1 := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
recipient2 := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
|
||||
AddRow("broadcaster-ws", true))
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastRecipients).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).
|
||||
AddRow(recipient1).
|
||||
AddRow(recipient2))
|
||||
|
||||
// broadcast_receive: 3 args (recipientID, senderID, summary)
|
||||
mock.ExpectExec(sqlBroadcastReceiveInsert).
|
||||
WithArgs(recipient1, sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
mock.ExpectExec(sqlBroadcastReceiveInsert).
|
||||
WithArgs(recipient2, sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// broadcast_sent: 2 args (workspaceID, summary)
|
||||
mock.ExpectExec(sqlBroadcastSentInsert).
|
||||
WithArgs(wsID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
|
||||
broadcastBody(`{"message":"hello team"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "not-a-uuid"}}
|
||||
body := `{"message":"test"}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/not-a-uuid/broadcast", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_MissingMessage(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewBroadcastHandler(broadcaster)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "00000000-0000-0000-0000-000000000001"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/00000000-0000-0000-0000-000000000001/broadcast", bytes.NewBufferString("{}"))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestBroadcast_OrgRootLookupFails verifies that if the recursive CTE for
|
||||
// finding the org root errors, the handler returns 500 instead of proceeding
|
||||
// with an un-scoped query that would broadcast to all orgs.
|
||||
func TestBroadcast_OrgRootLookupFails(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewBroadcastHandler(broadcaster)
|
||||
|
||||
senderID := "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Root Agent", true))
|
||||
|
||||
// Org root CTE fails
|
||||
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
|
||||
WithArgs(senderID).
|
||||
WillReturnError(context.DeadlineExceeded)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: senderID}}
|
||||
body := `{"message":"should not broadcast"}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Broadcast(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
// The recipient query MUST NOT be called — it would broadcast cross-org
|
||||
// if the org root lookup failed silently.
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBroadcast_OrgScoped_SelfBroadcastExcluded verifies that broadcasting
|
||||
// from a workspace does not send a broadcast_receive to the sender itself
|
||||
// (the sender logs broadcast_sent, not broadcast_receive).
|
||||
func TestBroadcast_OrgScoped_SelfBroadcastExcluded(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewBroadcastHandler(broadcaster)
|
||||
|
||||
senderID := "00000000-0000-0000-0000-000000000001"
|
||||
peerID := "00000000-0000-0000-0000-000000000002"
|
||||
|
||||
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Root Agent", true))
|
||||
|
||||
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
|
||||
WithArgs(senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(senderID))
|
||||
|
||||
// Recipient query MUST exclude sender via id != senderID
|
||||
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
|
||||
WithArgs(senderID, senderID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(peerID))
|
||||
|
||||
// Peer receives broadcast_receive
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(peerID, senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// Sender logs broadcast_sent (NOT broadcast_receive)
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: senderID}}
|
||||
body := `{"message":"no echo to self"}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Broadcast(c)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBroadcast_Truncate tests that messages are truncated with the Unicode ellipsis
|
||||
// TestBroadcast_Truncate tests that messages are truncated with the Unicode ellipsis
|
||||
// character (U+2026) when len(msg) > max. The truncated output is max runes + "…",
|
||||
// so truncating a 48-char string at max=20 produces 21 characters (20 runes + "…").
|
||||
func TestBroadcast_Truncate(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
max int
|
||||
expect string
|
||||
}{
|
||||
{"short", 120, "short"}, // under max — no truncation
|
||||
// exactly120chars (15) + 105 ones = 120 chars; at max=120 → unchanged
|
||||
{"exactly120chars1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111", 120, "exactly120chars111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111…"},
|
||||
// "this is a longer mes" = 20 runes; + "…" = 21 chars
|
||||
{"this is a longer message that needs truncating", 20, "this is a longer mes…"},
|
||||
// at-max boundary: 20 chars at max=20 → no truncation
|
||||
{"exactly twenty chars", 20, "exactly twenty chars"},
|
||||
// over max: 11 chars at max=10 → 10 + "…" = 11
|
||||
{"hello world!", 10, "hello worl…"},
|
||||
func TestBroadcast_RecipientInsertError_ContinuesAndSucceeds(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
recipient1 := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
recipient2 := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
|
||||
AddRow("broadcaster-ws", true))
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastRecipients).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).
|
||||
AddRow(recipient1).
|
||||
AddRow(recipient2))
|
||||
|
||||
// First recipient insert fails — handler logs and continues
|
||||
mock.ExpectExec(sqlBroadcastReceiveInsert).
|
||||
WithArgs(recipient1, sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
// Second recipient succeeds
|
||||
mock.ExpectExec(sqlBroadcastReceiveInsert).
|
||||
WithArgs(recipient2, sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
mock.ExpectExec(sqlBroadcastSentInsert).
|
||||
WithArgs(wsID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
|
||||
broadcastBody(`{"message":"partial delivery"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
for _, tc := range cases {
|
||||
result := broadcastTruncate(tc.msg, tc.max)
|
||||
if result != tc.expect {
|
||||
t.Errorf("broadcastTruncate(%q, %d) = %q; want %q", tc.msg, tc.max, result, tc.expect)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_SenderActivityLogError_StillReturns200(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
|
||||
AddRow("broadcaster-ws", true))
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastRecipients).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||
|
||||
mock.ExpectExec(sqlBroadcastSentInsert).
|
||||
WithArgs(wsID, sqlmock.AnyArg()).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
|
||||
broadcastBody(`{"message":"hello"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Handler logs error but still returns 200
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ── broadcastTruncate pure function tests ─────────────────────────────────────
|
||||
|
||||
func TestBroadcastTruncate_UnderLimit(t *testing.T) {
|
||||
input := "short message"
|
||||
got := broadcastTruncate(input, 50)
|
||||
if got != input {
|
||||
t.Errorf("broadcastTruncate(%q, 50) = %q, want %q", input, got, input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastTruncate_ExactlyAtLimit(t *testing.T) {
|
||||
input := "exactly fifty char"
|
||||
got := broadcastTruncate(input, 18)
|
||||
if got != input {
|
||||
t.Errorf("broadcastTruncate(%q, 18) = %q, want %q", input, got, input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastTruncate_OverLimit_TruncatesAndAddsEllipsis(t *testing.T) {
|
||||
// 150 ASCII chars → over 120 rune limit → truncate to 120 + ellipsis
|
||||
input := strings.Repeat("x", 150)
|
||||
got := broadcastTruncate(input, 120)
|
||||
if len([]rune(got)) != 121 { // 120 + 1 ellipsis rune
|
||||
t.Errorf("len(broadcastTruncate) = %d, want 121 (120 + ellipsis)", len([]rune(got)))
|
||||
}
|
||||
if got[:len(got)-len("…")] != strings.Repeat("x", 120) {
|
||||
t.Errorf("broadcastTruncate did not truncate correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastTruncate_UnicodeChars_TreatsAsRunes(t *testing.T) {
|
||||
// Each emoji is 1 rune but multiple bytes. 50 emojis > 30 limit.
|
||||
input := strings.Repeat("🎉", 50)
|
||||
got := broadcastTruncate(input, 30)
|
||||
if len([]rune(got)) != 31 { // 30 + ellipsis
|
||||
t.Errorf("len(broadcastTruncate with emoji) = %d, want 31", len([]rune(got)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastTruncate_ZeroLimit_ReturnsEllipsis(t *testing.T) {
|
||||
got := broadcastTruncate("hello", 0)
|
||||
if got != "…" {
|
||||
t.Errorf("broadcastTruncate with max=0 = %q, want …", got)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user