Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0262e59c60 |
@@ -1,408 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// setupQueueStatusHandlerDB creates a sqlmock DB with QueryMatcherEqual for exact SQL string matching.
|
||||
func setupQueueStatusHandlerDB(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock.New: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
return mock
|
||||
}
|
||||
|
||||
// Exact SQL strings used by the production code.
|
||||
const (
|
||||
sqlQueueRowAuthFields = `SELECT caller_id, workspace_id FROM a2a_queue WHERE id = $1`
|
||||
sqlQueueStatusByID = `
|
||||
SELECT
|
||||
q.id,
|
||||
q.workspace_id,
|
||||
q.status,
|
||||
q.priority,
|
||||
q.attempts,
|
||||
q.last_error,
|
||||
q.enqueued_at::text,
|
||||
q.dispatched_at::text,
|
||||
q.completed_at::text,
|
||||
q.expires_at::text,
|
||||
al.response_body::text
|
||||
FROM a2a_queue q
|
||||
LEFT JOIN activity_logs al
|
||||
ON al.method = 'delegate_result'
|
||||
AND al.target_id = q.workspace_id
|
||||
AND al.workspace_id = q.caller_id
|
||||
AND al.response_body->>'delegation_id' = (q.body->'params'->'message'->'metadata'->>'delegation_id')
|
||||
WHERE q.id = $1`
|
||||
)
|
||||
|
||||
// ── GetA2AQueueStatus HTTP handler tests ──────────────────────────────────────
|
||||
|
||||
// TestGetA2AQueueStatus_QueueIDEmpty_Returns400 exercises the handler directly
|
||||
// (not via router) so we can verify the empty-value branch without relying on
|
||||
// Gin route-matching behaviour.
|
||||
func TestGetA2AQueueStatus_QueueIDEmpty_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}}
|
||||
// queue_id param is empty string
|
||||
c.Params = gin.Params{
|
||||
{Key: "id", Value: "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"},
|
||||
{Key: "queue_id", Value: ""},
|
||||
}
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h.GetA2AQueueStatus(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("got %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_NoIdentity_NoOrgToken_Returns404(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/wsid/a2a/queue/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// No identity derivable → 404 (not 401) per existence-non-inference policy.
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("got %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_OrgToken_SkipsCallerCheck(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
|
||||
AddRow("other-ws", wsID)
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(authRows)
|
||||
|
||||
statusRows := sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "status", "priority", "attempts",
|
||||
"last_error", "enqueued_at", "dispatched_at", "completed_at", "expires_at",
|
||||
"response_body",
|
||||
}).AddRow(
|
||||
queueID, wsID, "queued", 50, 0,
|
||||
nil, "2026-01-01T00:00:00Z", nil, nil, nil, nil,
|
||||
)
|
||||
mock.ExpectQuery(sqlQueueStatusByID).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(statusRows)
|
||||
|
||||
r := gin.New()
|
||||
// Simulate org-token middleware setting org_token_id.
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", func(c *gin.Context) {
|
||||
c.Set("org_token_id", "org-admin")
|
||||
h.GetA2AQueueStatus(c)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/workspaces/wsid/a2a/queue/"+queueID, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_CallerWorkspaceMatchesCallerID_Success(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
|
||||
AddRow(callerID, wsID)
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(authRows)
|
||||
|
||||
statusRows := sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "status", "priority", "attempts",
|
||||
"last_error", "enqueued_at", "dispatched_at", "completed_at", "expires_at",
|
||||
"response_body",
|
||||
}).AddRow(
|
||||
queueID, wsID, "completed", 50, 1,
|
||||
nil, "2026-01-01T00:00:00Z", "2026-01-01T00:01:00Z", "2026-01-01T00:02:00Z",
|
||||
nil, []byte(`{"text":"result"}`),
|
||||
)
|
||||
mock.ExpectQuery(sqlQueueStatusByID).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(statusRows)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
|
||||
req.Header.Set("X-Workspace-ID", callerID)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_CallerWorkspaceMatchesWorkspaceID_Success(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
|
||||
AddRow(callerID, wsID)
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(authRows)
|
||||
|
||||
statusRows := sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "status", "priority", "attempts",
|
||||
"last_error", "enqueued_at", "dispatched_at", "completed_at", "expires_at",
|
||||
"response_body",
|
||||
}).AddRow(
|
||||
queueID, wsID, "queued", 50, 0,
|
||||
nil, "2026-01-01T00:00:00Z", nil, nil, nil, nil,
|
||||
)
|
||||
mock.ExpectQuery(sqlQueueStatusByID).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(statusRows)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
|
||||
req.Header.Set("X-Workspace-ID", wsID)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_QueueNotFound_Returns404(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
|
||||
req.Header.Set("X-Workspace-ID", callerID)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("got %d, want 404: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_QueueAuthFieldsDBError_Returns500(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
|
||||
req.Header.Set("X-Workspace-ID", callerID)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("got %d, want 500: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_WrongCallerWorkspace_Returns404(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
wrongCaller := "dddddddd-dddd-dddd-dddd-dddddddddddd"
|
||||
|
||||
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
|
||||
AddRow(callerID, wsID)
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(authRows)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
|
||||
req.Header.Set("X-Workspace-ID", wrongCaller)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("got %d, want 404: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_StatusFetchDBError_Returns500(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
|
||||
AddRow(callerID, wsID)
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(authRows)
|
||||
|
||||
mock.ExpectQuery(sqlQueueStatusByID).
|
||||
WithArgs(queueID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
|
||||
req.Header.Set("X-Workspace-ID", callerID)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("got %d, want 500: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetA2AQueueStatus_FullHappyPath_ReturnsJSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupQueueStatusHandlerDB(t)
|
||||
h := &WorkspaceHandler{}
|
||||
|
||||
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
authRows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
|
||||
AddRow(callerID, wsID)
|
||||
mock.ExpectQuery(sqlQueueRowAuthFields).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(authRows)
|
||||
|
||||
respBody := []byte(`{"text":"delegation result"}`)
|
||||
statusRows := sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "status", "priority", "attempts",
|
||||
"last_error", "enqueued_at", "dispatched_at", "completed_at", "expires_at",
|
||||
"response_body",
|
||||
}).AddRow(
|
||||
queueID, wsID, "completed", 50, 1,
|
||||
nil, "2026-01-01T00:00:00Z", "2026-01-01T00:01:00Z", "2026-01-01T00:02:00Z",
|
||||
nil, respBody,
|
||||
)
|
||||
mock.ExpectQuery(sqlQueueStatusByID).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(statusRows)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/workspaces/:id/a2a/queue/:queue_id", h.GetA2AQueueStatus)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/workspaces/"+wsID+"/a2a/queue/"+queueID, nil)
|
||||
req.Header.Set("X-Workspace-ID", wsID)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if w.Body.Len() == 0 {
|
||||
t.Error("response body is empty")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1,317 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// setupAbilitiesDB creates a sqlmock DB with QueryMatcherEqual for exact SQL matching.
|
||||
func setupAbilitiesDB(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock.New: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
return mock
|
||||
}
|
||||
|
||||
// Exact SQL strings used by the production handler.
|
||||
const (
|
||||
sqlPatchAbilitiesExists = `SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1 AND status != 'removed')`
|
||||
sqlPatchBroadcastEnabled = `UPDATE workspaces SET broadcast_enabled = $2, updated_at = now() WHERE id = $1`
|
||||
sqlPatchTalkToUserEnabled = `UPDATE workspaces SET talk_to_user_enabled = $2, updated_at = now() WHERE id = $1`
|
||||
)
|
||||
|
||||
// ── PatchAbilities HTTP handler tests ──────────────────────────────────────────
|
||||
|
||||
func TestPatchAbilities_InvalidWorkspaceID_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupAbilitiesDB(t)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "not-a-uuid"}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/not-a-uuid/abilities", nil)
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("got %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_InvalidBody_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupAbilitiesDB(t)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch,
|
||||
"/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/abilities",
|
||||
newFakeCloser([]byte("not json")))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("got %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_NoAbilityFields_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupAbilitiesDB(t)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch,
|
||||
"/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/abilities",
|
||||
newFakeCloser([]byte(`{}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("got %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_WorkspaceNotFound_Returns404(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupAbilitiesDB(t)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlPatchAbilitiesExists).
|
||||
WithArgs(wsID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
|
||||
newFakeCloser([]byte(`{"broadcast_enabled":true}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("got %d, want 404: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_WorkspaceNotFound_ExistsFalse_Returns404(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupAbilitiesDB(t)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlPatchAbilitiesExists).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
|
||||
newFakeCloser([]byte(`{"talk_to_user_enabled":false}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("got %d, want 404: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_UpdateBroadcastEnabled_Success(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupAbilitiesDB(t)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlPatchAbilitiesExists).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
|
||||
mock.ExpectExec(sqlPatchBroadcastEnabled).
|
||||
WithArgs(wsID, true).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
|
||||
newFakeCloser([]byte(`{"broadcast_enabled":true}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_UpdateTalkToUserEnabled_Success(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupAbilitiesDB(t)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlPatchAbilitiesExists).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
|
||||
mock.ExpectExec(sqlPatchTalkToUserEnabled).
|
||||
WithArgs(wsID, false).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
|
||||
newFakeCloser([]byte(`{"talk_to_user_enabled":false}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_UpdateBothAbilities_Success(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupAbilitiesDB(t)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlPatchAbilitiesExists).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
|
||||
mock.ExpectExec(sqlPatchBroadcastEnabled).
|
||||
WithArgs(wsID, true).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
mock.ExpectExec(sqlPatchTalkToUserEnabled).
|
||||
WithArgs(wsID, false).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
|
||||
newFakeCloser([]byte(`{"broadcast_enabled":true,"talk_to_user_enabled":false}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_BroadcastEnabledDBError_Returns500(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupAbilitiesDB(t)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlPatchAbilitiesExists).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
|
||||
mock.ExpectExec(sqlPatchBroadcastEnabled).
|
||||
WithArgs(wsID, true).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
|
||||
newFakeCloser([]byte(`{"broadcast_enabled":true}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("got %d, want 500: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchAbilities_TalkToUserEnabledDBError_Returns500(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupAbilitiesDB(t)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlPatchAbilitiesExists).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
|
||||
mock.ExpectExec(sqlPatchTalkToUserEnabled).
|
||||
WithArgs(wsID, true).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest(http.MethodPatch, "/workspaces/"+wsID+"/abilities",
|
||||
newFakeCloser([]byte(`{"talk_to_user_enabled":true}`)))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PatchAbilities(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("got %d, want 500: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helpers ────────────────────────────────────────────────────────────────────
|
||||
|
||||
// newFakeCloser wraps a byte slice as an io.ReadCloser for request body injection.
|
||||
func newFakeCloser(data []byte) *fakeReadCloser {
|
||||
return &fakeReadCloser{data: data}
|
||||
}
|
||||
|
||||
type fakeReadCloser struct {
|
||||
data []byte
|
||||
pos int
|
||||
}
|
||||
|
||||
func (f *fakeReadCloser) Read(p []byte) (n int, err error) {
|
||||
if f.pos >= len(f.data) {
|
||||
return 0, nil
|
||||
}
|
||||
n = copy(p, f.data[f.pos:])
|
||||
f.pos += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (*fakeReadCloser) Close() error { return nil }
|
||||
@@ -1,403 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// broadcastBody is a convenience that returns an io.ReadCloser wrapping JSON body.
|
||||
func broadcastBody(body string) io.ReadCloser {
|
||||
return &broadcastFakeCloser{data: []byte(body)}
|
||||
}
|
||||
|
||||
type broadcastFakeCloser struct {
|
||||
data []byte
|
||||
pos int
|
||||
}
|
||||
|
||||
func (f *broadcastFakeCloser) Read(p []byte) (n int, err error) {
|
||||
if f.pos >= len(f.data) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n = copy(p, f.data[f.pos:])
|
||||
f.pos += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (*broadcastFakeCloser) Close() error { return nil }
|
||||
|
||||
// setupBroadcastDB creates a sqlmock DB with QueryMatcherEqual.
|
||||
func setupBroadcastDB(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock.New: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
return mock
|
||||
}
|
||||
|
||||
// Exact SQL strings from the production handler (whitespace must match verbatim).
|
||||
const (
|
||||
sqlBroadcastWorkspaceLookup = `SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'`
|
||||
sqlBroadcastRecipients = `SELECT id FROM workspaces WHERE status != 'removed' AND id != $1`
|
||||
sqlBroadcastReceiveInsert = `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, status)
|
||||
VALUES ($1, 'broadcast_receive', 'broadcast', $2, $3, 'ok')`
|
||||
sqlBroadcastSentInsert = `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, summary, status)
|
||||
VALUES ($1, 'broadcast_sent', 'broadcast', $2, 'ok')`
|
||||
)
|
||||
|
||||
// ── Broadcast HTTP handler tests ───────────────────────────────────────────────
|
||||
|
||||
func TestBroadcast_InvalidWorkspaceID_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/not-a-uuid/broadcast",
|
||||
broadcastBody(`{"message":"hello"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("got %d, want 400: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_MissingMessage_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
"/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/broadcast",
|
||||
broadcastBody(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("got %d, want 400: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_EmptyMessage_Returns400(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost,
|
||||
"/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/broadcast",
|
||||
broadcastBody(`{"message":""}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("got %d, want 400: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_WorkspaceNotFound_Returns404(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
|
||||
WithArgs(wsID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
|
||||
broadcastBody(`{"message":"hello"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("got %d, want 404: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_BroadcastDisabled_Returns403(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
|
||||
AddRow("test-workspace", false))
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
|
||||
broadcastBody(`{"message":"hello"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("got %d, want 403: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_NoRecipients_Success(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
|
||||
AddRow("test-workspace", true))
|
||||
|
||||
// No recipients (sender is the only non-removed workspace)
|
||||
mock.ExpectQuery(sqlBroadcastRecipients).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||
|
||||
// Sender's own activity log: 2 args (workspaceID, summary)
|
||||
mock.ExpectExec(sqlBroadcastSentInsert).
|
||||
WithArgs(wsID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
|
||||
broadcastBody(`{"message":"hello everyone"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_WithRecipients_Success_DeliversToAll(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
recipient1 := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
recipient2 := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
|
||||
AddRow("broadcaster-ws", true))
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastRecipients).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).
|
||||
AddRow(recipient1).
|
||||
AddRow(recipient2))
|
||||
|
||||
// broadcast_receive: 3 args (recipientID, senderID, summary)
|
||||
mock.ExpectExec(sqlBroadcastReceiveInsert).
|
||||
WithArgs(recipient1, sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
mock.ExpectExec(sqlBroadcastReceiveInsert).
|
||||
WithArgs(recipient2, sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// broadcast_sent: 2 args (workspaceID, summary)
|
||||
mock.ExpectExec(sqlBroadcastSentInsert).
|
||||
WithArgs(wsID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
|
||||
broadcastBody(`{"message":"hello team"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_RecipientInsertError_ContinuesAndSucceeds(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
recipient1 := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
|
||||
recipient2 := "cccccccc-cccc-cccc-cccc-cccccccccccc"
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
|
||||
AddRow("broadcaster-ws", true))
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastRecipients).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).
|
||||
AddRow(recipient1).
|
||||
AddRow(recipient2))
|
||||
|
||||
// First recipient insert fails — handler logs and continues
|
||||
mock.ExpectExec(sqlBroadcastReceiveInsert).
|
||||
WithArgs(recipient1, sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
// Second recipient succeeds
|
||||
mock.ExpectExec(sqlBroadcastReceiveInsert).
|
||||
WithArgs(recipient2, sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
mock.ExpectExec(sqlBroadcastSentInsert).
|
||||
WithArgs(wsID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
|
||||
broadcastBody(`{"message":"partial delivery"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_SenderActivityLogError_StillReturns200(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupBroadcastDB(t)
|
||||
h := NewBroadcastHandler(newTestBroadcaster())
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastWorkspaceLookup).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).
|
||||
AddRow("broadcaster-ws", true))
|
||||
|
||||
mock.ExpectQuery(sqlBroadcastRecipients).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||
|
||||
mock.ExpectExec(sqlBroadcastSentInsert).
|
||||
WithArgs(wsID, sqlmock.AnyArg()).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/workspaces/:id/broadcast", h.Broadcast)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/workspaces/"+wsID+"/broadcast",
|
||||
broadcastBody(`{"message":"hello"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Handler logs error but still returns 200
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("got %d, want 200: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ── broadcastTruncate pure function tests ─────────────────────────────────────
|
||||
|
||||
func TestBroadcastTruncate_UnderLimit(t *testing.T) {
|
||||
input := "short message"
|
||||
got := broadcastTruncate(input, 50)
|
||||
if got != input {
|
||||
t.Errorf("broadcastTruncate(%q, 50) = %q, want %q", input, got, input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastTruncate_ExactlyAtLimit(t *testing.T) {
|
||||
input := "exactly fifty char"
|
||||
got := broadcastTruncate(input, 18)
|
||||
if got != input {
|
||||
t.Errorf("broadcastTruncate(%q, 18) = %q, want %q", input, got, input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastTruncate_OverLimit_TruncatesAndAddsEllipsis(t *testing.T) {
|
||||
// 150 ASCII chars → over 120 rune limit → truncate to 120 + ellipsis
|
||||
input := strings.Repeat("x", 150)
|
||||
got := broadcastTruncate(input, 120)
|
||||
if len([]rune(got)) != 121 { // 120 + 1 ellipsis rune
|
||||
t.Errorf("len(broadcastTruncate) = %d, want 121 (120 + ellipsis)", len([]rune(got)))
|
||||
}
|
||||
if got[:len(got)-len("…")] != strings.Repeat("x", 120) {
|
||||
t.Errorf("broadcastTruncate did not truncate correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastTruncate_UnicodeChars_TreatsAsRunes(t *testing.T) {
|
||||
// Each emoji is 1 rune but multiple bytes. 50 emojis > 30 limit.
|
||||
input := strings.Repeat("🎉", 50)
|
||||
got := broadcastTruncate(input, 30)
|
||||
if len([]rune(got)) != 31 { // 30 + ellipsis
|
||||
t.Errorf("len(broadcastTruncate with emoji) = %d, want 31", len([]rune(got)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastTruncate_ZeroLimit_ReturnsEllipsis(t *testing.T) {
|
||||
got := broadcastTruncate("hello", 0)
|
||||
if got != "…" {
|
||||
t.Errorf("broadcastTruncate with max=0 = %q, want …", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,432 @@
|
||||
"""BaseAdapter coverage gap tests — fills uncovered branches in adapter_base.py.
|
||||
|
||||
Covers:
|
||||
- resolve_provider_routing(): all URL-precedence branches + unknown prefix
|
||||
- RuntimeCapabilities.to_dict(): all flag combinations
|
||||
- BaseAdapter.capabilities(): returns RuntimeCapabilities() (platform-owns-everything)
|
||||
- BaseAdapter.idle_timeout_override(): returns None (use platform default)
|
||||
- BaseAdapter.get_config_schema(): returns {} (override per-subclass)
|
||||
- BaseAdapter.memory_filename(): returns "CLAUDE.md"
|
||||
- BaseAdapter.register_tool_hook(): no-op (override for dynamic registry)
|
||||
- BaseAdapter.register_subagent_hook(): no-op (override for DeepAgents)
|
||||
- BaseAdapter.transcript_lines(): returns supported=False dict
|
||||
- BaseAdapter.append_to_memory_hook(): idempotent append, marker deduplication
|
||||
- BaseAdapter.pre_stop_state(): captures session_id from executor + transcript_lines
|
||||
- BaseAdapter.restore_state(): stores session_id + transcript_lines from snapshot
|
||||
- BaseAdapter.inject_plugins(): delegates to install_plugins_via_registry
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
WORKSPACE_DIR = Path(__file__).parent.parent
|
||||
if str(WORKSPACE_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(WORKSPACE_DIR))
|
||||
|
||||
from a2a.server.agent_execution import AgentExecutor
|
||||
|
||||
from adapter_base import (
|
||||
AdapterConfig,
|
||||
BaseAdapter,
|
||||
ProviderRegistry,
|
||||
RuntimeCapabilities,
|
||||
resolve_provider_routing,
|
||||
)
|
||||
|
||||
|
||||
class _StubAdapter(BaseAdapter):
|
||||
"""Minimal concrete adapter for testing base-class default behaviour."""
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
return "stub"
|
||||
|
||||
@staticmethod
|
||||
def display_name() -> str:
|
||||
return "Stub"
|
||||
|
||||
@staticmethod
|
||||
def description() -> str:
|
||||
return "test stub"
|
||||
|
||||
async def setup(self, config: AdapterConfig) -> None:
|
||||
return None
|
||||
|
||||
async def create_executor(self, config: AdapterConfig) -> AgentExecutor: # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_provider_routing tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_resolve_provider_routing_parses_prefix_and_model():
|
||||
"""'anthropic:claude-sonnet-4-6' splits into prefix + bare model."""
|
||||
api_key, base_url, model_id = resolve_provider_routing(
|
||||
"anthropic:claude-sonnet-4-6",
|
||||
{"ANTHROPIC_API_KEY": "sk-ant-test"},
|
||||
registry={"anthropic": (("ANTHROPIC_API_KEY",), "https://api.anthropic.com")},
|
||||
)
|
||||
assert api_key == "sk-ant-test"
|
||||
assert base_url == "https://api.anthropic.com"
|
||||
assert model_id == "claude-sonnet-4-6"
|
||||
|
||||
|
||||
def test_resolve_provider_routing_falls_back_to_openai():
|
||||
"""Bare model without colon defaults to openai prefix."""
|
||||
api_key, base_url, model_id = resolve_provider_routing(
|
||||
"gpt-4o",
|
||||
{"OPENAI_API_KEY": "sk-openai-test"},
|
||||
registry={},
|
||||
)
|
||||
assert api_key == "sk-openai-test"
|
||||
assert base_url == "https://api.openai.com/v1"
|
||||
assert model_id == "gpt-4o"
|
||||
|
||||
|
||||
def test_resolve_provider_routing_url_from_env_var():
|
||||
"""PREFIX_BASE_URL env var takes precedence over registry default."""
|
||||
env = {
|
||||
"OPENAI_API_KEY": "sk-test",
|
||||
"OPENAI_BASE_URL": "https://my-proxy.example.com/v1",
|
||||
}
|
||||
api_key, base_url, model_id = resolve_provider_routing(
|
||||
"openai:gpt-4o", env, registry={}
|
||||
)
|
||||
assert base_url == "https://my-proxy.example.com/v1"
|
||||
|
||||
|
||||
def test_resolve_provider_routing_url_from_runtime_config():
|
||||
"""runtime_config['provider_url'] takes precedence over registry default."""
|
||||
env = {"OPENAI_API_KEY": "sk-test"}
|
||||
api_key, base_url, model_id = resolve_provider_routing(
|
||||
"openai:gpt-4o",
|
||||
env,
|
||||
registry={},
|
||||
runtime_config={"provider_url": "https://config-proxy.example.com/v1"},
|
||||
)
|
||||
assert base_url == "https://config-proxy.example.com/v1"
|
||||
|
||||
|
||||
def test_resolve_provider_routing_env_overrides_runtime_config():
|
||||
"""env var PREFIX_BASE_URL wins over runtime_config['provider_url']."""
|
||||
env = {
|
||||
"OPENAI_API_KEY": "sk-test",
|
||||
"OPENAI_BASE_URL": "https://env-proxy.example.com/v1",
|
||||
}
|
||||
_, base_url, _ = resolve_provider_routing(
|
||||
"openai:gpt-4o",
|
||||
env,
|
||||
registry={},
|
||||
runtime_config={"provider_url": "https://config-proxy.example.com/v1"},
|
||||
)
|
||||
assert base_url == "https://env-proxy.example.com/v1"
|
||||
|
||||
|
||||
def test_resolve_provider_routing_falls_back_to_openai_on_unknown_prefix():
|
||||
"""Unknown provider prefix falls back to OPENAI_API_KEY + openai.com."""
|
||||
env = {"OPENAI_API_KEY": "sk-fallback"}
|
||||
api_key, base_url, model_id = resolve_provider_routing(
|
||||
"unknown:some-model", env, registry={}
|
||||
)
|
||||
assert api_key == "sk-fallback"
|
||||
assert base_url == "https://api.openai.com/v1"
|
||||
assert model_id == "some-model"
|
||||
|
||||
|
||||
def test_resolve_provider_routing_raises_when_no_api_key():
|
||||
"""RuntimeError raised when no API key env var is set for the prefix."""
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
resolve_provider_routing(
|
||||
"anthropic:claude-sonnet-4-6",
|
||||
{}, # empty env — no ANTHROPIC_API_KEY
|
||||
registry={"anthropic": (("ANTHROPIC_API_KEY",), "https://api.anthropic.com")},
|
||||
)
|
||||
assert "No API key found" in str(exc_info.value)
|
||||
assert "anthropic" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_resolve_provider_routing_multiple_env_vars_first_found():
|
||||
"""registry tuple with multiple env vars — first present in env is used."""
|
||||
env = {
|
||||
# ANTHROPIC_API_KEY not set; ANTHROPIC_SECONDARY_KEY is
|
||||
"ANTHROPIC_SECONDARY_KEY": "sk-secondary",
|
||||
}
|
||||
api_key, _, _ = resolve_provider_routing(
|
||||
"anthropic:claude-sonnet-4-6",
|
||||
env,
|
||||
registry={"anthropic": (("ANTHROPIC_API_KEY", "ANTHROPIC_SECONDARY_KEY"), "https://api.anthropic.com")},
|
||||
)
|
||||
assert api_key == "sk-secondary"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RuntimeCapabilities tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_runtime_capabilities_to_dict_all_defaults():
|
||||
"""All flags default to False."""
|
||||
caps = RuntimeCapabilities()
|
||||
d = caps.to_dict()
|
||||
assert d == {
|
||||
"heartbeat": False,
|
||||
"scheduler": False,
|
||||
"session": False,
|
||||
"status_mgmt": False,
|
||||
"retry": False,
|
||||
"activity_decoration": False,
|
||||
"channel_dispatch": False,
|
||||
}
|
||||
|
||||
|
||||
def test_runtime_capabilities_to_dict_all_true():
|
||||
"""All flags can be set to True."""
|
||||
caps = RuntimeCapabilities(
|
||||
provides_native_heartbeat=True,
|
||||
provides_native_scheduler=True,
|
||||
provides_native_session=True,
|
||||
provides_native_status_mgmt=True,
|
||||
provides_native_retry=True,
|
||||
provides_activity_decoration=True,
|
||||
provides_channel_dispatch=True,
|
||||
)
|
||||
d = caps.to_dict()
|
||||
assert all(v is True for v in d.values())
|
||||
|
||||
|
||||
def test_runtime_capabilities_partial_flags():
|
||||
"""Partial flag set — only heartbeat and session True."""
|
||||
caps = RuntimeCapabilities(
|
||||
provides_native_heartbeat=True,
|
||||
provides_native_session=True,
|
||||
)
|
||||
d = caps.to_dict()
|
||||
assert d["heartbeat"] is True
|
||||
assert d["session"] is True
|
||||
assert d["scheduler"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BaseAdapter method default behaviour tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_capabilities_returns_empty_runtime_capabilities():
|
||||
"""Default capabilities() returns RuntimeCapabilities() with all flags off."""
|
||||
adapter = _StubAdapter()
|
||||
caps = adapter.capabilities()
|
||||
assert isinstance(caps, RuntimeCapabilities)
|
||||
d = caps.to_dict()
|
||||
assert all(v is False for v in d.values())
|
||||
|
||||
|
||||
def test_idle_timeout_override_returns_none():
|
||||
"""Default idle_timeout_override() returns None — use platform default."""
|
||||
adapter = _StubAdapter()
|
||||
assert adapter.idle_timeout_override() is None
|
||||
|
||||
|
||||
def test_get_config_schema_returns_empty_dict():
|
||||
"""Default get_config_schema() returns {} — override per-subclass."""
|
||||
adapter = _StubAdapter()
|
||||
assert adapter.get_config_schema() == {}
|
||||
|
||||
|
||||
def test_memory_filename_returns_claude_md():
|
||||
"""Default memory_filename() returns 'CLAUDE.md'."""
|
||||
adapter = _StubAdapter()
|
||||
assert adapter.memory_filename() == "CLAUDE.md"
|
||||
|
||||
|
||||
def test_register_tool_hook_returns_none():
|
||||
"""Default register_tool_hook() is a no-op that returns None."""
|
||||
adapter = _StubAdapter()
|
||||
result = adapter.register_tool_hook("some-plugin", MagicMock())
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_register_subagent_hook_returns_none():
|
||||
"""Default register_subagent_hook() is a no-op that returns None."""
|
||||
adapter = _StubAdapter()
|
||||
result = adapter.register_subagent_hook("deep-agent", {"name": "agent"})
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_lines_returns_unsupported():
|
||||
"""Default transcript_lines() returns supported=False (runtime doesn't expose a log)."""
|
||||
adapter = _StubAdapter()
|
||||
result = await adapter.transcript_lines(since=10, limit=50)
|
||||
assert result["supported"] is False
|
||||
assert result["lines"] == []
|
||||
assert result["cursor"] == 10 # preserved from since arg
|
||||
assert result["more"] is False
|
||||
assert result["source"] is None
|
||||
assert result["runtime"] == "stub"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# append_to_memory_hook tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_append_to_memory_hook_creates_new_file():
|
||||
"""append_to_memory_hook creates the target file if it doesn't exist."""
|
||||
adapter = _StubAdapter()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = AdapterConfig(model="test", config_path=tmpdir)
|
||||
content = "# Plugin: test-plugin\nsome content"
|
||||
adapter.append_to_memory_hook(config, "CLAUDE.md", content)
|
||||
|
||||
path = os.path.join(tmpdir, "CLAUDE.md")
|
||||
assert os.path.exists(path)
|
||||
with open(path) as f:
|
||||
assert content in f.read()
|
||||
|
||||
|
||||
def test_append_to_memory_hook_idempotent_with_marker():
|
||||
"""Second append with same marker is skipped (idempotent)."""
|
||||
adapter = _StubAdapter()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = AdapterConfig(model="test", config_path=tmpdir)
|
||||
marker_content = "# Plugin: test-plugin\nsome content"
|
||||
|
||||
adapter.append_to_memory_hook(config, "CLAUDE.md", marker_content)
|
||||
adapter.append_to_memory_hook(config, "CLAUDE.md", marker_content)
|
||||
|
||||
path = os.path.join(tmpdir, "CLAUDE.md")
|
||||
with open(path) as f:
|
||||
text = f.read()
|
||||
# Should appear only once (second append skipped)
|
||||
lines = [l for l in text.splitlines() if l.startswith("# Plugin: test-plugin")]
|
||||
assert len(lines) == 1
|
||||
|
||||
|
||||
def test_append_to_memory_hook_appends_without_marker():
|
||||
"""Appends when the marker line is not present (no deduplication needed)."""
|
||||
adapter = _StubAdapter()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = AdapterConfig(model="test", config_path=tmpdir)
|
||||
|
||||
adapter.append_to_memory_hook(config, "CLAUDE.md", "# First plugin\ncontent A")
|
||||
adapter.append_to_memory_hook(config, "CLAUDE.md", "# Second plugin\ncontent B")
|
||||
|
||||
path = os.path.join(tmpdir, "CLAUDE.md")
|
||||
with open(path) as f:
|
||||
text = f.read()
|
||||
assert "# First plugin" in text
|
||||
assert "# Second plugin" in text
|
||||
|
||||
|
||||
def test_append_to_memory_hook_creates_parent_dirs():
|
||||
"""append_to_memory_hook creates intermediate directories."""
|
||||
adapter = _StubAdapter()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config = AdapterConfig(model="test", config_path=tmpdir)
|
||||
adapter.append_to_memory_hook(config, "subdir/CLAUDE.md", "# Nested")
|
||||
|
||||
path = os.path.join(tmpdir, "subdir", "CLAUDE.md")
|
||||
assert os.path.exists(path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# pre_stop_state tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_pre_stop_state_empty_when_no_executor():
|
||||
"""pre_stop_state returns {} when no _executor is attached."""
|
||||
adapter = _StubAdapter()
|
||||
state = adapter.pre_stop_state()
|
||||
assert state == {}
|
||||
|
||||
|
||||
def test_pre_stop_state_captures_session_id():
|
||||
"""pre_stop_state reads _executor._session_id when present."""
|
||||
adapter = _StubAdapter()
|
||||
mock_executor = MagicMock(spec=AgentExecutor)
|
||||
mock_executor._session_id = "session-abc123"
|
||||
adapter._executor = mock_executor
|
||||
|
||||
state = adapter.pre_stop_state()
|
||||
assert state["session_id"] == "session-abc123"
|
||||
|
||||
|
||||
def test_pre_stop_state_captures_transcript_lines():
|
||||
"""pre_stop_state calls transcript_lines() and includes lines when supported."""
|
||||
adapter = _StubAdapter()
|
||||
adapter._executor = None # no session_id
|
||||
|
||||
# Override transcript_lines to return supported=True
|
||||
adapter.transcript_lines = MagicMock(return_value={
|
||||
"runtime": "stub",
|
||||
"supported": True,
|
||||
"lines": [{"role": "user", "content": "hello"}],
|
||||
"cursor": 0,
|
||||
"more": False,
|
||||
"source": "/tmp/transcript.jsonl",
|
||||
})
|
||||
|
||||
state = adapter.pre_stop_state()
|
||||
assert state["transcript_lines"] == [{"role": "user", "content": "hello"}]
|
||||
|
||||
|
||||
def test_pre_stop_state_suppresses_transcript_on_exception():
|
||||
"""pre_stop_state never raises — transcript capture is best-effort."""
|
||||
adapter = _StubAdapter()
|
||||
adapter._executor = None
|
||||
|
||||
def broken_transcript(*args, **kwargs):
|
||||
raise RuntimeError("disk error")
|
||||
|
||||
adapter.transcript_lines = broken_transcript
|
||||
|
||||
# Must not raise
|
||||
state = adapter.pre_stop_state()
|
||||
assert state == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# restore_state tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_restore_state_stores_session_id():
|
||||
"""restore_state stores snapshot['session_id'] as _snapshot_session_id."""
|
||||
adapter = _StubAdapter()
|
||||
adapter.restore_state({"session_id": "restored-session-xyz"})
|
||||
assert adapter._snapshot_session_id == "restored-session-xyz"
|
||||
|
||||
|
||||
def test_restore_state_stores_transcript_lines():
|
||||
"""restore_state stores snapshot['transcript_lines'] as _snapshot_transcript."""
|
||||
adapter = _StubAdapter()
|
||||
lines = [{"role": "user", "content": "prior context"}]
|
||||
adapter.restore_state({"transcript_lines": lines})
|
||||
assert adapter._snapshot_transcript == lines
|
||||
|
||||
|
||||
def test_restore_state_handles_missing_keys():
|
||||
"""restore_state works when snapshot lacks session_id or transcript_lines."""
|
||||
adapter = _StubAdapter()
|
||||
adapter.restore_state({})
|
||||
assert adapter._snapshot_session_id is None
|
||||
assert adapter._snapshot_transcript is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# inject_plugins tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_plugins_delegates_to_install_plugins_via_registry():
|
||||
"""inject_plugins calls install_plugins_via_registry (default migration path)."""
|
||||
from unittest.mock import AsyncMock
|
||||
adapter = _StubAdapter()
|
||||
|
||||
with patch.object(adapter, "install_plugins_via_registry", new_callable=AsyncMock) as mock_install:
|
||||
mock_install.return_value = []
|
||||
await adapter.inject_plugins(AdapterConfig(model="test", config_path="/tmp"), MagicMock())
|
||||
mock_install.assert_called_once()
|
||||
Reference in New Issue
Block a user