Compare commits
103 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b705270291 | |||
| 466f303015 | |||
| b6b14a38d2 | |||
| 7270b89a85 | |||
| 76609f4129 | |||
| 8439a066b6 | |||
| d7d376118d | |||
| 026d1c5fae | |||
| 48ad38e795 | |||
| 4bdb10b5e2 | |||
| 6452456f75 | |||
| 4978601032 | |||
| ec3e27a4ec | |||
| 4cc0e32a53 | |||
| e9693e12ff | |||
| bcca139caa | |||
| 6cf6e608d8 | |||
| 6947774e1b | |||
| 9afecfdfc7 | |||
| 220ee57d0c | |||
| 2751861b04 | |||
| da416caeca | |||
| 250af4df36 | |||
| 884bb8c09f | |||
| 0c152a24d2 | |||
| 3345544921 | |||
| 8e2597c877 | |||
| d241dd7f9e | |||
| d437c31da4 | |||
| ca7665f573 | |||
| 11d4b398b7 | |||
| 48f65bc456 | |||
| 408dd452df | |||
| 29d735e431 | |||
| a921851124 | |||
| 3c982587cc | |||
| d59daf87c9 | |||
| 301d84f616 | |||
| 53ac6444c7 | |||
| 447016e652 | |||
| c6a222904e | |||
| f5c476f0c0 | |||
| 858af52d6f | |||
| 4e8b40d1ea | |||
| d5e362690f | |||
| 9f7b87de21 | |||
| 686c330708 | |||
| d021272558 | |||
| 36e85c1950 | |||
| 74ae043a8c | |||
| dd5b1a823f | |||
| 5b554f8afe | |||
| 8b1c867ff0 | |||
| 591d166179 | |||
| c2aacaef2e | |||
| 676cef0656 | |||
| 927663d5bf | |||
| a3eee58dbd | |||
| 9cf997597d | |||
| b713491eda | |||
| bbdb753e82 | |||
| 40df07e94d | |||
| 5efbbd9fa8 | |||
| 3d669b35de | |||
| aea1223b2e | |||
| e6d50ff5ba | |||
| f04e475eab | |||
| 0e34816def | |||
| 60c28ed872 | |||
| 607ab35d7c | |||
| 4b76fe43b1 | |||
| 0afbf3e6d4 | |||
| 57886b714c | |||
| 283fa10415 | |||
| ae75557e6b | |||
| 21cbad5867 | |||
| 79e9e51865 | |||
| 95deb8b98e | |||
| 829b32b867 | |||
| 7709c6bd54 | |||
| e16abf15de | |||
| 6448b38dd9 | |||
| c446329aad | |||
| 51e889f2f3 | |||
| 6a3e854329 | |||
| b94218e5c1 | |||
| 3968bdd92a | |||
| 5a79ccde4c | |||
| 783c9dc6a3 | |||
| 689d454920 | |||
| bb1be0a277 | |||
| 466c510547 | |||
| 1bfff48e9c | |||
| aacf191b6a | |||
| 9c43f6a6e3 | |||
| 1db69d520b | |||
| ca80e3cc91 | |||
| a72ccbb034 | |||
| 42ccaf2da6 | |||
| 7c61e8315e | |||
| 62d3866764 | |||
| ac15906025 | |||
| 6cbf880b04 |
@@ -47,6 +47,15 @@ REQUIRED_CONTEXTS_RAW = _env(
|
||||
"sop-checklist / all-items-acked (pull_request)"
|
||||
),
|
||||
)
|
||||
# Required contexts for push (main/staging) runs. The push CI uses the same
|
||||
# aggregator names with " (push)" suffix. Checking these explicitly instead of
|
||||
# the combined state avoids false-pause when non-blocking jobs (e.g. Platform
|
||||
# Go with continue-on-error: true due to mc#774) have failed — their failures
|
||||
# pollute the combined state but do not block merges.
|
||||
PUSH_REQUIRED_CONTEXTS_RAW = _env(
|
||||
"PUSH_REQUIRED_CONTEXTS",
|
||||
default="CI / all-required (push)",
|
||||
)
|
||||
|
||||
OWNER, NAME = (REPO.split("/", 1) + [""])[:2] if REPO else ("", "")
|
||||
API = f"https://{GITEA_HOST}/api/v1" if GITEA_HOST else ""
|
||||
@@ -118,16 +127,24 @@ def required_contexts(raw: str) -> list[str]:
|
||||
return [part.strip() for part in raw.split(",") if part.strip()]
|
||||
|
||||
|
||||
def push_required_contexts() -> list[str]:
|
||||
"""Required contexts for push (branch) CI runs. See PUSH_REQUIRED_CONTEXTS_RAW."""
|
||||
return required_contexts(PUSH_REQUIRED_CONTEXTS_RAW)
|
||||
|
||||
|
||||
def status_state(status: dict) -> str:
|
||||
return str(status.get("status") or status.get("state") or "").lower()
|
||||
|
||||
|
||||
def latest_statuses_by_context(statuses: list[dict]) -> dict[str, dict]:
|
||||
# Gitea /statuses endpoint returns entries in ascending id order (oldest
|
||||
# first). We need the LAST occurrence of each context, so iterate in
|
||||
# reverse to prefer newer entries.
|
||||
latest: dict[str, dict] = {}
|
||||
for status in statuses:
|
||||
for status in reversed(statuses):
|
||||
context = status.get("context")
|
||||
if isinstance(context, str) and context not in latest:
|
||||
latest[context] = status
|
||||
if isinstance(context, str):
|
||||
latest[context] = status # overwrite: reverse order → newest wins
|
||||
return latest
|
||||
|
||||
|
||||
@@ -193,16 +210,23 @@ def evaluate_merge_readiness(
|
||||
required_contexts: list[str],
|
||||
pr_has_current_base: bool,
|
||||
) -> MergeDecision:
|
||||
main_state = str(main_status.get("state") or "").lower()
|
||||
if main_state != "success":
|
||||
return MergeDecision(False, "pause", f"main status is {main_state or 'missing'}")
|
||||
# Check push-required contexts explicitly instead of combined state.
|
||||
# Combined state can be "failure" due to non-blocking jobs
|
||||
# (continue-on-error: true) that don't actually gate merges.
|
||||
# CI / all-required (push) is the authoritative gate — it respects
|
||||
# continue-on-error and correctly aggregates all blocking failures.
|
||||
main_latest = latest_statuses_by_context(main_status.get("statuses") or [])
|
||||
main_ok, main_bad = required_contexts_green(main_latest, push_required_contexts())
|
||||
if not main_ok:
|
||||
return MergeDecision(False, "pause", "main required contexts not green: " + ", ".join(main_bad))
|
||||
if not pr_has_current_base:
|
||||
return MergeDecision(False, "update", "PR head does not contain current main")
|
||||
|
||||
pr_state = str(pr_status.get("state") or "").lower()
|
||||
if pr_state != "success":
|
||||
return MergeDecision(False, "wait", f"PR combined status is {pr_state or 'missing'}")
|
||||
|
||||
# Check explicit required contexts instead of combined state. Combined state
|
||||
# can be "failure" due to non-blocking jobs with continue-on-error: true
|
||||
# (e.g. publish-runtime-autobump/pr-validate, qa-review on stale tokens).
|
||||
# The required_contexts list is the authoritative gate — it includes only
|
||||
# the checks that actually block merges.
|
||||
latest = latest_statuses_by_context(pr_status.get("statuses") or [])
|
||||
ok, missing_or_bad = required_contexts_green(latest, required_contexts)
|
||||
if not ok:
|
||||
@@ -220,10 +244,37 @@ def get_branch_head(branch: str) -> str:
|
||||
|
||||
|
||||
def get_combined_status(sha: str) -> dict:
|
||||
_, body = api("GET", f"/repos/{OWNER}/{NAME}/commits/{sha}/status")
|
||||
if not isinstance(body, dict):
|
||||
"""Combined status + all individual statuses for `sha`.
|
||||
|
||||
The /status endpoint caps the `statuses` array at 30 entries (Gitea
|
||||
default page size), so we fetch the full list via /statuses with a
|
||||
higher limit. The combined `state` still comes from /status.
|
||||
"""
|
||||
_, combined = api("GET", f"/repos/{OWNER}/{NAME}/commits/{sha}/status")
|
||||
if not isinstance(combined, dict):
|
||||
raise ApiError(f"status for {sha} response not object")
|
||||
return body
|
||||
# Fetch full statuses list; 200 covers >99% of real-world runs.
|
||||
# The list is ordered ascending by id (oldest first) — callers must
|
||||
# iterate in reverse to get the newest entry per context.
|
||||
# Best-effort: large repos (main with 550+ statuses) may time out.
|
||||
# On timeout, fall back to the statuses[] already in the combined
|
||||
# response (usually 30 entries — enough for most PRs, enough for
|
||||
# main's early push-required contexts).
|
||||
try:
|
||||
_, all_statuses = api(
|
||||
"GET",
|
||||
f"/repos/{OWNER}/{NAME}/commits/{sha}/statuses",
|
||||
query={"limit": "50"},
|
||||
)
|
||||
if isinstance(all_statuses, list):
|
||||
combined["statuses"] = all_statuses
|
||||
except (ApiError, urllib.error.URLError, TimeoutError, OSError) as exc:
|
||||
# URLError covers network-level failures (DNS, refused, timeout).
|
||||
# TimeoutError and OSError cover socket-level timeouts.
|
||||
sys.stderr.write(f"::warning::could not fetch full statuses list for {sha[:8]}: {exc}\n")
|
||||
# Fall back to the statuses[] already in the combined response.
|
||||
pass
|
||||
return combined
|
||||
|
||||
|
||||
def list_queued_issues() -> list[dict]:
|
||||
@@ -294,8 +345,12 @@ def process_once(*, dry_run: bool = False) -> int:
|
||||
contexts = required_contexts(REQUIRED_CONTEXTS_RAW)
|
||||
main_sha = get_branch_head(WATCH_BRANCH)
|
||||
main_status = get_combined_status(main_sha)
|
||||
if str(main_status.get("state") or "").lower() != "success":
|
||||
print(f"::notice::queue paused: {WATCH_BRANCH}@{main_sha[:8]} is not green")
|
||||
# Check push-required contexts explicitly instead of combined state.
|
||||
# See evaluate_merge_readiness for rationale.
|
||||
main_latest = latest_statuses_by_context(main_status.get("statuses") or [])
|
||||
main_ok, main_bad = required_contexts_green(main_latest, push_required_contexts())
|
||||
if not main_ok:
|
||||
print(f"::notice::queue paused: {WATCH_BRANCH}@{main_sha[:8]} required contexts not green: {', '.join(main_bad)}")
|
||||
return 0
|
||||
|
||||
issue = choose_next_queued_issue(
|
||||
|
||||
@@ -118,17 +118,19 @@ _DIRECTIVE_RE = re.compile(
|
||||
def parse_directives(
|
||||
comment_body: str,
|
||||
numeric_aliases: dict[int, str],
|
||||
) -> list[tuple[str, str, str]]:
|
||||
) -> tuple[list[tuple[str, str, str]], list]:
|
||||
"""Extract /sop-ack and /sop-revoke directives from a comment body.
|
||||
|
||||
Returns a list of (kind, canonical_slug, note) tuples where:
|
||||
kind is "sop-ack" or "sop-revoke"
|
||||
canonical_slug is the normalized form (or "" if unparseable)
|
||||
note is the trailing free-text (may be "")
|
||||
Returns (directives, na_directives) where:
|
||||
directives is a list of (kind, canonical_slug, note) tuples
|
||||
kind is "sop-ack" or "sop-revoke"
|
||||
canonical_slug is the normalized form (or "" if unparseable)
|
||||
note is the trailing free-text (may be "")
|
||||
na_directives is reserved for future N/A handling (always [] for now)
|
||||
"""
|
||||
out: list[tuple[str, str, str]] = []
|
||||
if not comment_body:
|
||||
return out
|
||||
return out, []
|
||||
for m in _DIRECTIVE_RE.finditer(comment_body):
|
||||
kind = m.group(1)
|
||||
raw_slug = (m.group(2) or "").strip()
|
||||
@@ -159,7 +161,7 @@ def parse_directives(
|
||||
# If we collapsed multi-word slug into kebab and there's a
|
||||
# trailing-text group too, append it.
|
||||
out.append((kind, canonical, note_from_group))
|
||||
return out
|
||||
return out, []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -249,7 +251,8 @@ def compute_ack_state(
|
||||
user = (c.get("user") or {}).get("login", "")
|
||||
if not user:
|
||||
continue
|
||||
for kind, slug, _note in parse_directives(body, numeric_aliases):
|
||||
directives, _na = parse_directives(body, numeric_aliases)
|
||||
for kind, slug, _note in directives:
|
||||
if not slug:
|
||||
unparseable_per_user[user] = unparseable_per_user.get(user, 0) + 1
|
||||
continue
|
||||
|
||||
@@ -85,7 +85,10 @@ def test_pr_needs_update_when_base_sha_absent_from_commits():
|
||||
|
||||
def test_merge_decision_requires_main_green_pr_green_and_current_base():
|
||||
required = ["CI / all-required (pull_request)"]
|
||||
main_status = {"state": "success", "statuses": []}
|
||||
main_status = {
|
||||
"state": "success",
|
||||
"statuses": [{"context": "CI / all-required (push)", "status": "success"}],
|
||||
}
|
||||
pr_status = {
|
||||
"state": "success",
|
||||
"statuses": [{"context": "CI / all-required (pull_request)", "status": "success"}],
|
||||
@@ -104,7 +107,10 @@ def test_merge_decision_requires_main_green_pr_green_and_current_base():
|
||||
|
||||
def test_merge_decision_updates_stale_pr_before_merge():
|
||||
decision = mq.evaluate_merge_readiness(
|
||||
main_status={"state": "success", "statuses": []},
|
||||
main_status={
|
||||
"state": "success",
|
||||
"statuses": [{"context": "CI / all-required (push)", "status": "success"}],
|
||||
},
|
||||
pr_status={"state": "success", "statuses": [{"context": "CI / all-required (pull_request)", "status": "success"}]},
|
||||
required_contexts=["CI / all-required (pull_request)"],
|
||||
pr_has_current_base=False,
|
||||
|
||||
+38
-27
@@ -133,7 +133,6 @@ jobs:
|
||||
# the name match works on PRs that don't touch workspace-server/).
|
||||
platform-build:
|
||||
name: Platform (Go)
|
||||
needs: changes
|
||||
runs-on: ubuntu-latest
|
||||
# mc#774 (closed 2026-05-14): Phase 4 flip of the platform-build job.
|
||||
# Phase 4 (#656) originally flipped this to continue-on-error: false based on
|
||||
@@ -146,33 +145,37 @@ jobs:
|
||||
# the diagnostic step with its own continue-on-error: true (line 203).
|
||||
# Flip confirmed by CI / Platform (Go) status = success on main HEAD 363905d3.
|
||||
continue-on-error: false
|
||||
# Job-level ceiling. The go test step below runs with a per-step 10m timeout;
|
||||
# this cap catches any step that leaks past that. Set well above 10m so
|
||||
# the per-step timeout is the active constraint.
|
||||
timeout-minutes: 15
|
||||
defaults:
|
||||
run:
|
||||
working-directory: workspace-server
|
||||
steps:
|
||||
- if: needs.changes.outputs.platform != 'true'
|
||||
- if: false
|
||||
working-directory: .
|
||||
run: echo "No platform/** changes — skipping real build steps; this job always runs to satisfy the required-check name on branch protection."
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||
with:
|
||||
go-version: 'stable'
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
run: go mod download
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
run: go build ./cmd/server
|
||||
# CLI (molecli) moved to standalone repo: git.moleculesai.app/molecule-ai/molecule-cli
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
run: go vet ./...
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
name: Install golangci-lint
|
||||
run: go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.12.2
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
name: Run golangci-lint
|
||||
run: $(go env GOPATH)/bin/golangci-lint run --timeout 3m ./...
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
name: Diagnostic — per-package verbose 60s
|
||||
run: |
|
||||
set +e
|
||||
@@ -188,11 +191,15 @@ jobs:
|
||||
echo "::endgroup::"
|
||||
# mc#774: pre-existing continue-on-error mask; root-fix and remove, do not renew silently.
|
||||
continue-on-error: true
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
name: Run tests with race detection and coverage
|
||||
run: go test -race -coverprofile=coverage.out ./...
|
||||
# Explicit timeout: cold runner cache causes OOM kills at ~4m39s on the
|
||||
# full ./... suite with race detection + coverage. A 10m per-step timeout
|
||||
# lets the suite complete on cold cache (~5-7m) while failing cleanly
|
||||
# instead of OOM-killing. The job-level timeout (15m) is a backstop.
|
||||
run: go test -race -timeout 10m -coverprofile=coverage.out ./...
|
||||
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
name: Per-file coverage report
|
||||
# Advisory — lists every source file with its coverage so reviewers
|
||||
# can see at-a-glance where gaps are. Sorted ascending so the worst
|
||||
@@ -206,7 +213,7 @@ jobs:
|
||||
END {for (f in s) printf "%6.1f%% %s\n", s[f]/c[f], f}' \
|
||||
| sort -n
|
||||
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
name: Check coverage thresholds
|
||||
# Enforces two gates from #1823 Layer 1:
|
||||
# 1. Total floor (25% — ratchet plan in COVERAGE_FLOOR.md).
|
||||
@@ -294,28 +301,28 @@ jobs:
|
||||
# siblings — verified empirically on PR #2314).
|
||||
canvas-build:
|
||||
name: Canvas (Next.js)
|
||||
needs: changes
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
# Phase 4 (RFC #219 §1): confirmed green on main 2026-05-12.
|
||||
continue-on-error: false
|
||||
defaults:
|
||||
run:
|
||||
working-directory: canvas
|
||||
steps:
|
||||
- if: needs.changes.outputs.canvas != 'true'
|
||||
- if: false
|
||||
working-directory: .
|
||||
run: echo "No canvas/** changes — skipping real build steps; this job always runs to satisfy the required-check name on branch protection."
|
||||
- if: needs.changes.outputs.canvas == 'true'
|
||||
- if: always()
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- if: needs.changes.outputs.canvas == 'true'
|
||||
- if: always()
|
||||
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
|
||||
with:
|
||||
node-version: '22'
|
||||
- if: needs.changes.outputs.canvas == 'true'
|
||||
- if: always()
|
||||
run: rm -f package-lock.json && npm install
|
||||
- if: needs.changes.outputs.canvas == 'true'
|
||||
- if: always()
|
||||
run: npm run build
|
||||
- if: needs.changes.outputs.canvas == 'true'
|
||||
- if: always()
|
||||
name: Run tests with coverage
|
||||
# Coverage instrumentation is configured in canvas/vitest.config.ts
|
||||
# (provider: v8, reporters: text + html + json-summary). Step 2 of
|
||||
@@ -324,7 +331,7 @@ jobs:
|
||||
# tracked in #1815) after the team sees what current coverage is.
|
||||
run: npx vitest run --coverage
|
||||
- name: Upload coverage summary as artifact
|
||||
if: needs.changes.outputs.canvas == 'true' && always()
|
||||
if: always()
|
||||
# Pinned to v3 for Gitea act_runner v0.6 compatibility — v4+ uses
|
||||
# the GHES 3.10+ artifact protocol that Gitea 1.22.x does NOT
|
||||
# implement, surfacing as `GHESNotSupportedError: @actions/artifact
|
||||
@@ -391,15 +398,18 @@ jobs:
|
||||
scripts/promote-tenant-image.sh \
|
||||
scripts/test-promote-tenant-image.sh
|
||||
|
||||
# mc#959 root-fix (sre)
|
||||
|
||||
canvas-deploy-reminder:
|
||||
name: Canvas Deploy Reminder
|
||||
runs-on: ubuntu-latest
|
||||
# mc#774: pre-existing continue-on-error mask; root-fix and remove, do not renew silently.
|
||||
continue-on-error: true
|
||||
# mc#774 root-fix: added job-level `if:` so ci-required-drift.py's
|
||||
# ci_job_names() detects this as github.ref-gated and skips it from F1.
|
||||
# The step-level exit 0 handles the "not main push" case; the job-level
|
||||
# `if:` makes the gating explicit so the drift script sees it.
|
||||
# continue-on-error removed (was mc#774 mask): step exits 0 when not applicable.
|
||||
if: ${{ github.ref == 'refs/heads/staging' }}
|
||||
needs: [changes, canvas-build]
|
||||
# Keep the job itself always runnable. Gitea 1.22.6 leaves job-level
|
||||
# event/ref `if:` gates as pending on PRs, which blocks the combined
|
||||
# status even though this reminder is intentionally non-required.
|
||||
steps:
|
||||
- name: Write deploy reminder to step summary
|
||||
env:
|
||||
@@ -586,6 +596,7 @@ jobs:
|
||||
- canvas-build
|
||||
- shellcheck
|
||||
- python-lint
|
||||
- canvas-deploy-reminder
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- name: Assert every required dependency succeeded
|
||||
|
||||
@@ -48,4 +48,9 @@ jobs:
|
||||
REQUIRED_CONTEXTS: >-
|
||||
CI / all-required (pull_request),
|
||||
sop-checklist / all-items-acked (pull_request)
|
||||
# Push-side required contexts. Checking CI / all-required (push)
|
||||
# explicitly instead of the combined state avoids false-pause when
|
||||
# non-blocking jobs (continue-on-error: true) have failed — those
|
||||
# failures pollute combined state but do not gate merges.
|
||||
PUSH_REQUIRED_CONTEXTS: CI / all-required (push)
|
||||
run: python3 .gitea/scripts/gitea-merge-queue.py
|
||||
|
||||
@@ -344,7 +344,7 @@ function ProviderPickerModal({
|
||||
// wrapper's bounds instead of the viewport.
|
||||
if (typeof document === "undefined") return null;
|
||||
|
||||
const allSaved = entries.length > 0 && entries.every((e) => e.saved);
|
||||
const allSaved = entries.every((e) => e.saved);
|
||||
const anySaving = entries.some((e) => e.saving);
|
||||
const runtimeLabel = runtime
|
||||
.replace(/[-_]/g, " ")
|
||||
@@ -616,7 +616,7 @@ function AllKeysModal({
|
||||
if (!open) return null;
|
||||
if (typeof document === "undefined") return null;
|
||||
|
||||
const allSaved = entries.length > 0 && entries.every((e) => e.saved);
|
||||
const allSaved = entries.every((e) => e.saved);
|
||||
const anySaving = entries.some((e) => e.saving);
|
||||
const runtimeLabel = runtime
|
||||
.replace(/[-_]/g, " ")
|
||||
|
||||
@@ -62,11 +62,11 @@ export function ThemeToggle({ className = "" }: { className?: string }) {
|
||||
}
|
||||
setTheme(OPTIONS[next].value);
|
||||
// Move focus to the new button so arrow-key navigation is continuous.
|
||||
// Use direct-child query to scope strictly to this radiogroup's buttons
|
||||
// and avoid accidentally focusing unrelated [role=radio] elements
|
||||
// Query is already scoped to radiogroup so no child-combinator needed;
|
||||
// avoids accidentally focusing unrelated [role=radio] elements
|
||||
// elsewhere in the DOM (e.g. React Flow canvas nodes).
|
||||
const radiogroup = e.currentTarget.closest("[role=radiogroup]") as HTMLElement | null;
|
||||
const btns = radiogroup?.querySelectorAll<HTMLButtonElement>("> [role=radio]");
|
||||
const btns = radiogroup?.querySelectorAll<HTMLButtonElement>("[role=radio]");
|
||||
btns?.[next]?.focus();
|
||||
},
|
||||
[]
|
||||
|
||||
@@ -13,17 +13,20 @@ import { isExternalLikeRuntime } from "@/lib/externalRuntimes";
|
||||
|
||||
/** Descendant count for the "N sub" badge — children are first-class nodes
|
||||
* rendered as full cards inside this one via React Flow's native parentId,
|
||||
* so we don't need to subscribe to the actual child list here. */
|
||||
* so we don't need to subscribe to the actual child list here.
|
||||
* Selecting `nodes` stably avoids a new selector reference on every store
|
||||
* update (React error #185 / Zustand + React 19 Object.is strictness). */
|
||||
function useDescendantCount(nodeId: string): number {
|
||||
return useCanvasStore(
|
||||
useCallback((s) => countDescendants(nodeId, s.nodes), [nodeId])
|
||||
);
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
return useMemo(() => countDescendants(nodeId, nodes), [nodeId, nodes]);
|
||||
}
|
||||
|
||||
/** Boolean flag used to drive min-size and NodeResizer dimensions.
|
||||
* Selecting `nodes` stably avoids re-render loops (same issue as
|
||||
* useDescendantCount). */
|
||||
function useHasChildren(nodeId: string): boolean {
|
||||
return useCanvasStore(
|
||||
useCallback((s) => s.nodes.some((n) => n.data.parentId === nodeId), [nodeId])
|
||||
);
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
return useMemo(() => nodes.some((n) => n.data.parentId === nodeId), [nodes, nodeId]);
|
||||
}
|
||||
|
||||
/** Eject/extract arrow icon — visually distinct from delete ✕ */
|
||||
|
||||
@@ -24,16 +24,20 @@ import {
|
||||
*/
|
||||
export function DropTargetBadge() {
|
||||
const dragOverNodeId = useCanvasStore((s) => s.dragOverNodeId);
|
||||
const targetName = useCanvasStore((s) => {
|
||||
if (!s.dragOverNodeId) return null;
|
||||
const n = s.nodes.find((nn) => nn.id === s.dragOverNodeId);
|
||||
// Select nodes stably first — deriving targetName and childCount inside
|
||||
// the same selector creates a new return value on every store mutation
|
||||
// even when neither has changed (React error #185 / Zustand Object.is).
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
const targetName = (() => {
|
||||
if (!dragOverNodeId) return null;
|
||||
const n = nodes.find((nn) => nn.id === dragOverNodeId);
|
||||
return (n?.data as WorkspaceNodeData | undefined)?.name ?? null;
|
||||
});
|
||||
const childCount = useCanvasStore((s) =>
|
||||
!s.dragOverNodeId
|
||||
})();
|
||||
const childCount = (() =>
|
||||
!dragOverNodeId
|
||||
? 0
|
||||
: s.nodes.filter((n) => n.parentId === s.dragOverNodeId).length,
|
||||
);
|
||||
: nodes.filter((n) => n.parentId === dragOverNodeId).length
|
||||
)();
|
||||
const { getInternalNode, flowToScreenPosition } = useReactFlow();
|
||||
if (!dragOverNodeId || !targetName) return null;
|
||||
const internal = getInternalNode(dragOverNodeId);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useRef } from "react";
|
||||
import { useCallback, useEffect, useMemo, useRef } from "react";
|
||||
import { useReactFlow } from "@xyflow/react";
|
||||
import { useCanvasStore } from "@/store/canvas";
|
||||
import { appendClass, removeClass } from "@/store/classNames";
|
||||
@@ -153,10 +153,17 @@ export function useCanvasViewport() {
|
||||
// fit, the user has to manually pan + zoom to find what they just
|
||||
// created. Only fires when TRANSITIONING from some-provisioning to
|
||||
// zero-provisioning — not on every re-render.
|
||||
const provisioningCount = useCanvasStore(
|
||||
(s) => s.nodes.filter((n) => n.data.status === "provisioning").length,
|
||||
//
|
||||
// Selecting `nodes` stably (array reference) avoids the
|
||||
// `.filter().length` anti-pattern which creates a new number on every
|
||||
// store update and breaks the wasProvisioning/hasProvisioning
|
||||
// transition detection (React error #185 / Zustand + React 19).
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
const provisioningCount = useMemo(
|
||||
() => nodes.filter((n) => n.data.status === "provisioning").length,
|
||||
[nodes],
|
||||
);
|
||||
const nodeCount = useCanvasStore((s) => s.nodes.length);
|
||||
const nodeCount = nodes.length;
|
||||
|
||||
useEffect(() => {
|
||||
const hasProvisioning = provisioningCount > 0;
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
// that the desktop ChatTab uses, but with a slimmer surface: no
|
||||
// attachments, no A2A topology overlay, no conversation tracing.
|
||||
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
|
||||
import { api } from "@/lib/api";
|
||||
import { useCanvasStore } from "@/store/canvas";
|
||||
@@ -36,6 +36,20 @@ interface A2AResponseShape {
|
||||
error?: { message?: string };
|
||||
}
|
||||
|
||||
// Wire shape for GET /workspaces/:id/chat-history (chat_history.go → ChatHistoryResponse).
|
||||
interface ApiChatMessage {
|
||||
id: string;
|
||||
role: string; // "user" | "agent" | "system"
|
||||
content: string;
|
||||
timestamp: string;
|
||||
attachments?: Array<{ name: string; uri: string; mimeType?: string; size?: number }>;
|
||||
}
|
||||
|
||||
interface ChatHistoryResponse {
|
||||
messages: ApiChatMessage[];
|
||||
reached_end: boolean;
|
||||
}
|
||||
|
||||
const formatTime = (date: Date) =>
|
||||
date.toLocaleTimeString([], { hour: "numeric", minute: "2-digit" });
|
||||
|
||||
@@ -49,7 +63,10 @@ export function MobileChat({
|
||||
onBack: () => void;
|
||||
}) {
|
||||
const p = usePalette(dark);
|
||||
const node = useCanvasStore((s) => s.nodes.find((n) => n.id === agentId));
|
||||
// Selecting `nodes` stably avoids the `.find()` anti-pattern that
|
||||
// creates a new return value on every store update (React error #185).
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
const node = useMemo(() => nodes.find((n) => n.id === agentId), [nodes, agentId]);
|
||||
// Bootstrap from the canvas store's per-workspace message buffer so the
|
||||
// user sees their prior thread on entry. The store is updated by the
|
||||
// socket → ChatTab flows the desktop runs; on mobile we read from the
|
||||
@@ -58,18 +75,14 @@ export function MobileChat({
|
||||
// that creates a new [] reference on every store update when the key is
|
||||
// absent, causing infinite re-render (React error #185).
|
||||
const storedMessages = useCanvasStore((s) => s.agentMessages[agentId]);
|
||||
const [messages, setMessages] = useState<ChatMessage[]>(() =>
|
||||
(storedMessages ?? []).map((m) => ({
|
||||
id: m.id,
|
||||
role: "agent",
|
||||
text: m.content,
|
||||
ts: formatStoredTimestamp(m.timestamp),
|
||||
})),
|
||||
);
|
||||
// Start empty — history is loaded via useEffect below.
|
||||
const [messages, setMessages] = useState<ChatMessage[]>([]);
|
||||
const [draft, setDraft] = useState("");
|
||||
const [tab, setTab] = useState<SubTab>("my");
|
||||
const [sending, setSending] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [loading, setLoading] = useState(true); // history is loading on mount
|
||||
const [historyError, setHistoryError] = useState<string | null>(null);
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
// Synchronous re-entry guard. `setSending(true)` schedules a state
|
||||
// update but doesn't flush before a second tap can fire send() — a ref
|
||||
@@ -77,6 +90,9 @@ export function MobileChat({
|
||||
// double-send race a stale `sending` lets through.
|
||||
const sendInFlightRef = useRef(false);
|
||||
const composerRef = useRef<HTMLTextAreaElement>(null);
|
||||
// Guard: don't treat the initial store population as a live push.
|
||||
// Set to false after the first render completes.
|
||||
const initDoneRef = useRef(false);
|
||||
|
||||
// Auto-grow the textarea: reset height to 'auto' so the scrollHeight
|
||||
// shrinks when the user deletes text, then size to scrollHeight up to
|
||||
@@ -89,6 +105,75 @@ export function MobileChat({
|
||||
el.style.height = `${next}px`;
|
||||
}, [draft]);
|
||||
|
||||
// Fetch chat history on mount; keep merging live agentMessages while the
|
||||
// panel is open. InitDoneRef prevents the initial store snapshot from
|
||||
// triggering the live-merge path (the store buffer is populated by
|
||||
// ChatTab on desktop, not on mobile — this effect loads history as the
|
||||
// mobile-native path).
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
|
||||
const mapApiMessage = (m: ApiChatMessage): ChatMessage => ({
|
||||
id: m.id,
|
||||
role: m.role === "user" ? "user" : "agent",
|
||||
text: m.content,
|
||||
ts: formatStoredTimestamp(m.timestamp),
|
||||
});
|
||||
|
||||
const syncLive = () => {
|
||||
const live = useCanvasStore.getState().agentMessages[agentId] ?? [];
|
||||
if (live.length > 0) {
|
||||
setMessages((prev) => {
|
||||
const existingIds = new Set(prev.map((m) => m.id));
|
||||
const newOnes = live
|
||||
.filter((m) => !existingIds.has(m.id))
|
||||
.map((m) => ({
|
||||
id: m.id,
|
||||
role: "agent" as const,
|
||||
text: m.content,
|
||||
ts: formatStoredTimestamp(m.timestamp),
|
||||
}));
|
||||
return newOnes.length > 0 ? [...prev, ...newOnes] : prev;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const bootstrap = async (): Promise<(() => void) | undefined> => {
|
||||
setLoading(true);
|
||||
setHistoryError(null);
|
||||
try {
|
||||
const res = await api.get<ChatHistoryResponse>(
|
||||
`/workspaces/${agentId}/chat-history?limit=50`,
|
||||
);
|
||||
if (cancelled) return;
|
||||
const initial = (res.messages ?? []).map(mapApiMessage);
|
||||
setMessages(initial);
|
||||
// Mark init done BEFORE marking loading=false so any store push
|
||||
// that arrives in the same tick is treated as live, not init.
|
||||
initDoneRef.current = true;
|
||||
setLoading(false);
|
||||
// Subscribe to live pushes after init is complete.
|
||||
syncLive();
|
||||
const unsubscribe = useCanvasStore.subscribe(syncLive);
|
||||
return unsubscribe; // returned for cleanup
|
||||
} catch (e) {
|
||||
if (cancelled) return;
|
||||
setHistoryError(e instanceof Error ? e.message : "Failed to load chat history");
|
||||
setLoading(false);
|
||||
initDoneRef.current = true;
|
||||
return undefined;
|
||||
}
|
||||
};
|
||||
|
||||
let maybeUnsubscribe: (() => void) | undefined;
|
||||
bootstrap().then((fn) => { maybeUnsubscribe = fn; });
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
if (maybeUnsubscribe) maybeUnsubscribe();
|
||||
};
|
||||
}, [agentId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (scrollRef.current) {
|
||||
scrollRef.current.scrollTop = scrollRef.current.scrollHeight;
|
||||
@@ -308,7 +393,61 @@ export function MobileChat({
|
||||
Agent Comms — peer-to-peer A2A traffic surfaces in the Comms tab.
|
||||
</div>
|
||||
)}
|
||||
{tab === "my" && messages.length === 0 && (
|
||||
{tab === "my" && loading && (
|
||||
<div style={{ padding: "20px 4px", textAlign: "center", color: p.text3, fontSize: 13 }}>
|
||||
<div style={{ marginBottom: 6, opacity: 0.6, animation: "spin 1s linear infinite", display: "inline-block", fontSize: 16 }}>⟳</div>
|
||||
<div>Loading chat history…</div>
|
||||
</div>
|
||||
)}
|
||||
{tab === "my" && !loading && historyError && (
|
||||
<div
|
||||
role="alert"
|
||||
style={{
|
||||
padding: "14px 4px",
|
||||
textAlign: "center",
|
||||
color: p.failed,
|
||||
fontSize: 13,
|
||||
}}
|
||||
>
|
||||
<div style={{ marginBottom: 8 }}>Could not load chat history.</div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setLoading(true);
|
||||
setHistoryError(null);
|
||||
api.get(`/workspaces/${agentId}/chat-history?limit=50`).then(
|
||||
(res: unknown) => {
|
||||
const r = res as ChatHistoryResponse;
|
||||
setMessages((r.messages ?? []).map((m) => ({
|
||||
id: m.id,
|
||||
role: m.role === "user" ? "user" : "agent",
|
||||
text: m.content,
|
||||
ts: formatStoredTimestamp(m.timestamp),
|
||||
})));
|
||||
setLoading(false);
|
||||
initDoneRef.current = true;
|
||||
},
|
||||
).catch((e: unknown) => {
|
||||
setHistoryError(e instanceof Error ? e.message : "Failed to load");
|
||||
setLoading(false);
|
||||
initDoneRef.current = true;
|
||||
});
|
||||
}}
|
||||
style={{
|
||||
padding: "6px 14px",
|
||||
borderRadius: 14,
|
||||
border: `0.5px solid ${p.failed}`,
|
||||
background: "transparent",
|
||||
color: p.failed,
|
||||
fontSize: 12,
|
||||
cursor: "pointer",
|
||||
}}
|
||||
>
|
||||
Retry
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{tab === "my" && !loading && !historyError && messages.length === 0 && (
|
||||
<div style={{ padding: "20px 4px", textAlign: "center", color: p.text3, fontSize: 13 }}>
|
||||
Send a message to start chatting.
|
||||
</div>
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
// 03 · Agent detail — pills + tabbed content (Overview/Activity/Config/Memory).
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
|
||||
import { api } from "@/lib/api";
|
||||
import { useCanvasStore } from "@/store/canvas";
|
||||
@@ -32,7 +32,10 @@ export function MobileDetail({
|
||||
onChat: () => void;
|
||||
}) {
|
||||
const p = usePalette(dark);
|
||||
const node = useCanvasStore((s) => s.nodes.find((n) => n.id === agentId));
|
||||
// Selecting `nodes` stably avoids the `.find()` anti-pattern that
|
||||
// creates a new return value on every store update (React error #185).
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
const node = useMemo(() => nodes.find((n) => n.id === agentId), [nodes, agentId]);
|
||||
const [tab, setTab] = useState<TabId>("overview");
|
||||
|
||||
if (!node) {
|
||||
|
||||
@@ -8,11 +8,19 @@
|
||||
* NOTE: No @testing-library/jest-dom — use DOM APIs.
|
||||
*/
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { cleanup, render } from "@testing-library/react";
|
||||
import { act, cleanup, render, waitFor } from "@testing-library/react";
|
||||
import React from "react";
|
||||
|
||||
import { MobileChat } from "../MobileChat";
|
||||
|
||||
// ─── Mock API ─────────────────────────────────────────────────────────────────
|
||||
// vi.mock without a factory auto-mocks the module. In tests, we configure
|
||||
// api.get / api.post directly (they are vi.fn() from the auto-mock).
|
||||
// Tests that need specific behaviour use mockResolvedValueOnce on the
|
||||
// auto-mocked functions.
|
||||
vi.mock("@/lib/api");
|
||||
import { api } from "@/lib/api";
|
||||
|
||||
// ─── Mock store ───────────────────────────────────────────────────────────────
|
||||
|
||||
const mockAgentId = "ws-chat-test";
|
||||
@@ -32,8 +40,14 @@ const mockStoreState = {
|
||||
|
||||
vi.mock("@/store/canvas", () => ({
|
||||
useCanvasStore: Object.assign(
|
||||
vi.fn((sel) => sel(mockStoreState)),
|
||||
{ getState: () => mockStoreState },
|
||||
vi.fn((sel?: (state: typeof mockStoreState) => unknown) => {
|
||||
if (sel) return sel(mockStoreState);
|
||||
return mockStoreState;
|
||||
}),
|
||||
{
|
||||
getState: () => mockStoreState,
|
||||
subscribe: vi.fn(() => vi.fn()),
|
||||
},
|
||||
),
|
||||
summarizeWorkspaceCapabilities: vi.fn((data: Record<string, unknown>) => {
|
||||
const agentCard = data.agentCard as Record<string, unknown> | null;
|
||||
@@ -54,16 +68,6 @@ vi.mock("@/store/canvas", () => ({
|
||||
}),
|
||||
}));
|
||||
|
||||
// ─── Mock API ─────────────────────────────────────────────────────────────────
|
||||
|
||||
const { mockApiPost } = vi.hoisted(() => ({
|
||||
mockApiPost: vi.fn().mockResolvedValue({ result: { parts: [] } }),
|
||||
}));
|
||||
|
||||
vi.mock("@/lib/api", () => ({
|
||||
api: { post: mockApiPost },
|
||||
}));
|
||||
|
||||
// ─── Fixtures ────────────────────────────────────────────────────────────────
|
||||
|
||||
const onlineNode = {
|
||||
@@ -150,7 +154,15 @@ beforeEach(() => {
|
||||
mockOnBack.mockClear();
|
||||
mockStoreState.nodes = [];
|
||||
mockStoreState.agentMessages = {};
|
||||
mockApiPost.mockClear();
|
||||
// Set up spies on the real api methods. Tests override these per-call.
|
||||
const getSpy = vi.spyOn(api, "get");
|
||||
const postSpy = vi.spyOn(api, "post");
|
||||
getSpy.mockResolvedValue({ messages: [], reached_end: true });
|
||||
postSpy.mockResolvedValue({ result: { parts: [] } });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -266,15 +278,26 @@ describe("MobileChat — empty state", () => {
|
||||
mockStoreState.nodes = [onlineNode];
|
||||
});
|
||||
|
||||
it('shows "Send a message to start chatting." when no messages', () => {
|
||||
const { container } = renderChat(mockAgentId);
|
||||
it('shows "Send a message to start chatting." when no messages', async () => {
|
||||
// History fetch resolves immediately in tests (mockResolvedValue).
|
||||
// act() flushes the microtask queue so the component reaches its
|
||||
// post-load state before we assert.
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("Send a message to start chatting.");
|
||||
});
|
||||
|
||||
it("shows no messages when agentMessages[agentId] is absent (undefined)", () => {
|
||||
it("shows no messages when agentMessages[agentId] is absent (undefined)", async () => {
|
||||
// Explicitly set to empty to simulate no stored messages
|
||||
mockStoreState.agentMessages = {};
|
||||
const { container } = renderChat(mockAgentId);
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("Send a message to start chatting.");
|
||||
});
|
||||
});
|
||||
@@ -321,3 +344,132 @@ describe("MobileChat — dark mode", () => {
|
||||
expect(container.querySelector('[aria-label="Back"]')).toBeTruthy();
|
||||
});
|
||||
});
|
||||
|
||||
// ─── Chat history loading ────────────────────────────────────────────────────
|
||||
|
||||
describe("MobileChat — chat history", () => {
|
||||
beforeEach(() => {
|
||||
mockStoreState.nodes = [onlineNode];
|
||||
});
|
||||
|
||||
it("calls GET /workspaces/:id/chat-history on mount", async () => {
|
||||
await act(async () => {
|
||||
renderChat(mockAgentId);
|
||||
});
|
||||
expect(api.get).toHaveBeenCalledWith(
|
||||
`/workspaces/${mockAgentId}/chat-history?limit=50`,
|
||||
);
|
||||
});
|
||||
|
||||
it("shows loading state while history is fetching", () => {
|
||||
// Do NOT await — check the pre-resolve state.
|
||||
const { container } = renderChat(mockAgentId);
|
||||
expect(container.textContent ?? "").toContain("Loading chat history…");
|
||||
});
|
||||
|
||||
it("shows empty state after history resolves with no messages", async () => {
|
||||
// beforeEach already sets api.get to resolve with empty — no override needed.
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("Send a message to start chatting.");
|
||||
});
|
||||
|
||||
it("renders messages from history response", async () => {
|
||||
vi.spyOn(api, "get").mockResolvedValueOnce({
|
||||
messages: [
|
||||
{
|
||||
id: "msg-1",
|
||||
role: "user",
|
||||
content: "Hello agent",
|
||||
timestamp: "2026-04-25T10:00:00Z",
|
||||
},
|
||||
{
|
||||
id: "msg-2",
|
||||
role: "agent",
|
||||
content: "Hello back",
|
||||
timestamp: "2026-04-25T10:00:01Z",
|
||||
},
|
||||
],
|
||||
reached_end: true,
|
||||
});
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("Hello agent");
|
||||
expect(container.textContent ?? "").toContain("Hello back");
|
||||
});
|
||||
|
||||
it("maps user role from API correctly", async () => {
|
||||
vi.spyOn(api, "get").mockResolvedValueOnce({
|
||||
messages: [
|
||||
{
|
||||
id: "msg-u",
|
||||
role: "user",
|
||||
content: "user message",
|
||||
timestamp: "2026-04-25T10:00:00Z",
|
||||
},
|
||||
],
|
||||
reached_end: true,
|
||||
});
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
// User messages render right-aligned. The text content check is sufficient
|
||||
// to confirm the message appeared.
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("user message");
|
||||
});
|
||||
|
||||
it("shows error state when history fetch fails", async () => {
|
||||
vi.spyOn(api, "get").mockRejectedValue(new Error("Network error"));
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("Could not load chat history.");
|
||||
expect(container.textContent ?? "").toContain("Retry");
|
||||
});
|
||||
|
||||
it("Retry button re-fetches history after error", async () => {
|
||||
// Make the initial mount call fail so the Retry button appears, then
|
||||
// make the retry call succeed so we can verify the full flow.
|
||||
const getSpy = vi.spyOn(api, "get");
|
||||
getSpy
|
||||
.mockRejectedValueOnce(new Error("Network error"))
|
||||
.mockResolvedValueOnce({ messages: [], reached_end: true });
|
||||
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
|
||||
// Error state should be shown with Retry button.
|
||||
expect(container.textContent ?? "").toContain("Could not load chat history.");
|
||||
expect(container.textContent ?? "").toContain("Retry");
|
||||
|
||||
// Click Retry — the button's onClick fires api.get again.
|
||||
// The second mockResolvedValueOnce makes it succeed.
|
||||
const retryBtn = Array.from(container.querySelectorAll("button")).find(
|
||||
(b) => b.textContent?.trim() === "Retry",
|
||||
);
|
||||
expect(retryBtn).toBeTruthy();
|
||||
await act(async () => {
|
||||
retryBtn?.click();
|
||||
});
|
||||
|
||||
// waitFor polls until the retry resolves and component re-renders.
|
||||
await waitFor(() => {
|
||||
expect(container.textContent ?? "").toContain("Send a message to start chatting.");
|
||||
});
|
||||
// Initial call + retry = 2.
|
||||
expect(getSpy).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -243,7 +243,7 @@ export function BudgetSection({ workspaceId }: Props) {
|
||||
onClick={handleSave}
|
||||
disabled={saving}
|
||||
data-testid="budget-save-btn"
|
||||
className="px-4 py-1.5 bg-accent-strong hover:bg-accent active:bg-accent-strong rounded-lg text-xs font-medium text-white disabled:opacity-50 transition-colors"
|
||||
className="px-4 py-1.5 bg-accent-strong hover:bg-accent active:bg-accent-strong rounded-lg text-xs font-medium text-white disabled:opacity-50 transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
>
|
||||
{saving ? "Saving…" : "Save"}
|
||||
</button>
|
||||
|
||||
@@ -255,7 +255,7 @@ export function ChannelsTab({ workspaceId }: Props) {
|
||||
</h3>
|
||||
<button
|
||||
onClick={() => setShowForm(!showForm)}
|
||||
className="text-[10px] px-2.5 py-1 rounded bg-accent-strong/20 text-accent hover:bg-accent-strong/30 transition"
|
||||
className="text-[10px] px-2.5 py-1 rounded bg-accent-strong/20 text-accent hover:bg-accent-strong/30 transition focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
>
|
||||
{showForm ? "Cancel" : "+ Connect"}
|
||||
</button>
|
||||
@@ -308,7 +308,7 @@ export function ChannelsTab({ workspaceId }: Props) {
|
||||
<button
|
||||
onClick={handleDiscover}
|
||||
disabled={discovering || !formValues["bot_token"]}
|
||||
className="text-[10px] px-2 py-0.5 rounded bg-accent-strong/20 text-accent hover:bg-accent-strong/30 transition disabled:opacity-40"
|
||||
className="text-[10px] px-2 py-0.5 rounded bg-accent-strong/20 text-accent hover:bg-accent-strong/30 transition disabled:opacity-40 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
>
|
||||
{discovering ? "Detecting..." : "Detect Chats"}
|
||||
</button>
|
||||
|
||||
@@ -962,6 +962,32 @@ function MyChatPanel({ workspaceId, data }: Props) {
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{/* talk_to_user disabled banner — shown when the workspace has
|
||||
talk_to_user_enabled=false. The agent cannot send canvas messages;
|
||||
the user can re-enable the ability from here without opening settings. */}
|
||||
{data.talkToUserEnabled === false && (
|
||||
<div className="flex items-center gap-2 px-3 py-2 bg-surface-sunken border-b border-line/40 shrink-0">
|
||||
<svg width="14" height="14" viewBox="0 0 16 16" fill="none" aria-hidden="true" className="shrink-0 text-ink-mid">
|
||||
<path d="M8 1a7 7 0 1 0 0 14A7 7 0 0 0 8 1Zm0 10.5a.75.75 0 1 1 0-1.5.75.75 0 0 1 0 1.5ZM8 4a.75.75 0 0 1 .75.75v4a.75.75 0 0 1-1.5 0v-4A.75.75 0 0 1 8 4Z" fill="currentColor"/>
|
||||
</svg>
|
||||
<span className="text-[10px] text-ink-mid flex-1">
|
||||
Agent is not enabled to chat with you.
|
||||
</span>
|
||||
<button
|
||||
onClick={async () => {
|
||||
try {
|
||||
await api.patch(`/workspaces/${workspaceId}/abilities`, { talk_to_user_enabled: true });
|
||||
useCanvasStore.getState().updateNodeData(workspaceId, { talkToUserEnabled: true });
|
||||
} catch {
|
||||
// ignore — user will see no change and can retry
|
||||
}
|
||||
}}
|
||||
className="px-2 py-0.5 text-[10px] font-medium bg-accent/10 hover:bg-accent/20 text-accent rounded border border-accent/30 transition-colors shrink-0"
|
||||
>
|
||||
Enable
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{/* Messages */}
|
||||
<div ref={containerRef} className="flex-1 overflow-y-auto p-3 space-y-3">
|
||||
{loading && (
|
||||
|
||||
@@ -194,7 +194,7 @@ export function ScheduleTab({ workspaceId }: Props) {
|
||||
</span>
|
||||
<button
|
||||
onClick={() => { resetForm(); setShowForm(true); }}
|
||||
className="text-[11px] px-2 py-0.5 bg-accent-strong/20 text-accent rounded hover:bg-accent-strong/30 transition-colors"
|
||||
className="text-[11px] px-2 py-0.5 bg-accent-strong/20 text-accent rounded hover:bg-accent-strong/30 transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
>
|
||||
+ Add Schedule
|
||||
</button>
|
||||
@@ -339,7 +339,7 @@ export function ScheduleTab({ workspaceId }: Props) {
|
||||
? "Last run OK — click to disable"
|
||||
: "Never run — click to enable"
|
||||
}
|
||||
className={`w-2 h-2 rounded-full flex-shrink-0 ${
|
||||
className={`w-2 h-2 rounded-full flex-shrink-0 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900 ${
|
||||
sched.last_status === "error"
|
||||
? "bg-red-400"
|
||||
: sched.last_status === "ok"
|
||||
@@ -376,7 +376,7 @@ export function ScheduleTab({ workspaceId }: Props) {
|
||||
<button
|
||||
onClick={() => handleRunNow(sched)}
|
||||
aria-label={`Run schedule ${sched.name} now`}
|
||||
className="text-[11px] px-1.5 py-0.5 text-accent hover:bg-accent-strong/20 rounded transition-colors"
|
||||
className="text-[11px] px-1.5 py-0.5 text-accent hover:bg-accent-strong/20 rounded transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
title="Run now"
|
||||
>
|
||||
▶
|
||||
@@ -384,7 +384,7 @@ export function ScheduleTab({ workspaceId }: Props) {
|
||||
<button
|
||||
onClick={() => handleEdit(sched)}
|
||||
aria-label={`Edit schedule ${sched.name}`}
|
||||
className="text-[11px] px-1.5 py-0.5 text-ink-mid hover:bg-surface-card rounded transition-colors"
|
||||
className="text-[11px] px-1.5 py-0.5 text-ink-mid hover:bg-surface-card rounded transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
title="Edit"
|
||||
>
|
||||
✎
|
||||
@@ -392,7 +392,7 @@ export function ScheduleTab({ workspaceId }: Props) {
|
||||
<button
|
||||
onClick={() => setPendingDelete({ id: sched.id, name: sched.name })}
|
||||
aria-label={`Delete schedule ${sched.name}`}
|
||||
className="text-[11px] px-1.5 py-0.5 text-bad hover:bg-red-600/20 rounded transition-colors"
|
||||
className="text-[11px] px-1.5 py-0.5 text-bad hover:bg-red-600/20 rounded transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-red-400 focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
title="Delete"
|
||||
>
|
||||
✕
|
||||
|
||||
@@ -21,8 +21,8 @@ export function statusDotClass(status: string): string {
|
||||
export const TIER_CONFIG: Record<number, { label: string; color: string; border: string }> = {
|
||||
1: { label: "T1", color: "text-ink-mid bg-surface-card border border-line", border: "text-ink-mid border-line" },
|
||||
2: { label: "T2", color: "text-white bg-accent border border-accent-strong", border: "text-accent border-accent" },
|
||||
3: { label: "T3", color: "text-white bg-violet-600 border border-violet-700", border: "text-violet-600 border-violet-500" },
|
||||
4: { label: "T4", color: "text-white bg-warm border border-warm", border: "text-warm border-warm" },
|
||||
3: { label: "T3", color: "text-white bg-violet-600 border border-violet-700", border: "text-white border-violet-500" },
|
||||
4: { label: "T4", color: "text-white bg-warm border border-warm", border: "text-white border-warm" },
|
||||
};
|
||||
|
||||
export const COMM_TYPE_LABELS: Record<string, string> = {
|
||||
|
||||
@@ -519,6 +519,10 @@ export function buildNodesAndEdges(
|
||||
// #2054 — server-declared per-workspace provisioning timeout.
|
||||
// Falls through to the runtime profile when null/absent.
|
||||
provisionTimeoutMs: ws.provision_timeout_ms ?? null,
|
||||
// Workspace abilities — defaults preserved for old platform versions
|
||||
// that don't yet include these columns in the GET response.
|
||||
broadcastEnabled: ws.broadcast_enabled ?? false,
|
||||
talkToUserEnabled: ws.talk_to_user_enabled ?? true,
|
||||
},
|
||||
};
|
||||
if (hasParent) {
|
||||
|
||||
@@ -99,6 +99,13 @@ export interface WorkspaceNodeData extends Record<string, unknown> {
|
||||
* @/lib/runtimeProfiles. Lets a slow runtime declare its cold-boot
|
||||
* expectation without a canvas release. */
|
||||
provisionTimeoutMs?: number | null;
|
||||
/** When true the workspace may POST /broadcast to send org-wide messages.
|
||||
* Default false. Toggled by user/admin via PATCH /workspaces/:id/abilities. */
|
||||
broadcastEnabled?: boolean;
|
||||
/** When false the workspace cannot deliver canvas chat messages.
|
||||
* send_message_to_user / POST /notify return 403 and the canvas
|
||||
* shows a "not enabled" state with a button to re-enable. Default true. */
|
||||
talkToUserEnabled?: boolean;
|
||||
}
|
||||
|
||||
export type PanelTab = "details" | "skills" | "chat" | "terminal" | "config" | "schedule" | "channels" | "files" | "memory" | "traces" | "events" | "activity" | "audit";
|
||||
|
||||
@@ -299,6 +299,9 @@ export interface WorkspaceData {
|
||||
* `@/lib/runtimeProfiles` when absent (the default behavior for any
|
||||
* template that hasn't yet declared the field). */
|
||||
provision_timeout_ms?: number | null;
|
||||
/** Workspace ability flags (migration 20260514). */
|
||||
broadcast_enabled?: boolean;
|
||||
talk_to_user_enabled?: boolean;
|
||||
}
|
||||
|
||||
let socket: ReconnectingSocket | null = null;
|
||||
|
||||
Executable
+296
@@ -0,0 +1,296 @@
|
||||
#!/usr/bin/env bash
|
||||
# E2E test: workspace broadcast and talk-to-user platform abilities.
|
||||
#
|
||||
# What this proves:
|
||||
# 1. talk_to_user_enabled (default true) — POST /notify works out-of-the-box.
|
||||
# 2. PATCH /workspaces/:id/abilities { talk_to_user_enabled: false } disables
|
||||
# delivery: /notify → 403 with error="talk_to_user_disabled" + delegate hint.
|
||||
# 3. Re-enabling talk_to_user_enabled restores delivery.
|
||||
# 4. broadcast_enabled (default false) — POST /broadcast → 403 when disabled.
|
||||
# 5. PATCH { broadcast_enabled: true } enables fan-out.
|
||||
# 6. POST /broadcast delivers to all non-sender, non-removed workspaces:
|
||||
# - Returns {"status":"sent","delivered":N}
|
||||
# - Receiver's activity log has a broadcast_receive entry with the message.
|
||||
# - Sender's activity log has a broadcast_sent entry.
|
||||
# 7. The sender itself does NOT receive a broadcast_receive entry.
|
||||
#
|
||||
# Usage: tests/e2e/test_workspace_abilities_e2e.sh
|
||||
# Prereqs: workspace-server on http://localhost:8080, MOLECULE_ENV != production
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
source "$(dirname "$0")/_lib.sh"
|
||||
|
||||
PASS=0
|
||||
FAIL=0
|
||||
SENDER_ID=""
|
||||
RECEIVER_ID=""
|
||||
|
||||
cleanup() {
|
||||
for wid in "$SENDER_ID" "$RECEIVER_ID"; do
|
||||
if [ -n "$wid" ]; then
|
||||
curl -s -X DELETE "$BASE/workspaces/$wid?confirm=true" > /dev/null || true
|
||||
fi
|
||||
done
|
||||
}
|
||||
trap cleanup EXIT INT TERM
|
||||
|
||||
assert() {
|
||||
local label="$1" actual="$2" expected="$3"
|
||||
if [ "$actual" = "$expected" ]; then
|
||||
echo " PASS — $label"
|
||||
PASS=$((PASS+1))
|
||||
else
|
||||
echo " FAIL — $label"
|
||||
echo " expected: $expected"
|
||||
echo " actual: $actual"
|
||||
FAIL=$((FAIL+1))
|
||||
fi
|
||||
}
|
||||
|
||||
assert_contains() {
|
||||
local label="$1" haystack="$2" needle="$3"
|
||||
if echo "$haystack" | grep -qF "$needle"; then
|
||||
echo " PASS — $label"
|
||||
PASS=$((PASS+1))
|
||||
else
|
||||
echo " FAIL — $label"
|
||||
echo " needle: $needle"
|
||||
echo " haystack: $haystack"
|
||||
FAIL=$((FAIL+1))
|
||||
fi
|
||||
}
|
||||
|
||||
assert_not_contains() {
|
||||
local label="$1" haystack="$2" needle="$3"
|
||||
if ! echo "$haystack" | grep -qF "$needle"; then
|
||||
echo " PASS — $label"
|
||||
PASS=$((PASS+1))
|
||||
else
|
||||
echo " FAIL — $label (unexpected match)"
|
||||
echo " needle: $needle"
|
||||
echo " haystack: $haystack"
|
||||
FAIL=$((FAIL+1))
|
||||
fi
|
||||
}
|
||||
|
||||
# ── Pre-sweep: remove any stale leftover workspaces from a prior aborted run ──
|
||||
echo "=== Setup ==="
|
||||
for NAME in "Abilities Sender" "Abilities Receiver"; do
|
||||
PRIOR=$(curl -s "$BASE/workspaces" | python3 -c "
|
||||
import json, sys
|
||||
try:
|
||||
print(' '.join(w['id'] for w in json.load(sys.stdin) if w.get('name') == '$NAME'))
|
||||
except Exception:
|
||||
pass
|
||||
")
|
||||
for _wid in $PRIOR; do
|
||||
echo "Sweeping leftover '$NAME' workspace: $_wid"
|
||||
curl -s -X DELETE "$BASE/workspaces/$_wid?confirm=true" > /dev/null || true
|
||||
done
|
||||
done
|
||||
|
||||
R=$(curl -s -X POST "$BASE/workspaces" -H "Content-Type: application/json" \
|
||||
-d '{"name":"Abilities Sender","tier":1}')
|
||||
SENDER_ID=$(echo "$R" | python3 -c 'import json,sys;print(json.load(sys.stdin)["id"])' 2>/dev/null || true)
|
||||
[ -n "$SENDER_ID" ] || { echo "Failed to create sender workspace: $R"; exit 1; }
|
||||
echo "Created sender workspace: $SENDER_ID"
|
||||
|
||||
R=$(curl -s -X POST "$BASE/workspaces" -H "Content-Type: application/json" \
|
||||
-d '{"name":"Abilities Receiver","tier":1}')
|
||||
RECEIVER_ID=$(echo "$R" | python3 -c 'import json,sys;print(json.load(sys.stdin)["id"])' 2>/dev/null || true)
|
||||
[ -n "$RECEIVER_ID" ] || { echo "Failed to create receiver workspace: $R"; exit 1; }
|
||||
echo "Created receiver workspace: $RECEIVER_ID"
|
||||
|
||||
# Mint workspace-scoped bearer tokens (test-only endpoint, disabled in prod).
|
||||
SENDER_TOKEN=$(e2e_mint_test_token "$SENDER_ID")
|
||||
[ -n "$SENDER_TOKEN" ] || { echo "Failed to mint sender token"; exit 1; }
|
||||
SENDER_AUTH="Authorization: Bearer $SENDER_TOKEN"
|
||||
|
||||
# Admin token — any live workspace bearer satisfies AdminAuth in local dev.
|
||||
# In production-like envs, set MOLECULE_ADMIN_TOKEN.
|
||||
ADMIN_TOKEN="${MOLECULE_ADMIN_TOKEN:-$SENDER_TOKEN}"
|
||||
ADMIN_AUTH="Authorization: Bearer $ADMIN_TOKEN"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "=== Part 1: talk_to_user ability ==="
|
||||
|
||||
echo ""
|
||||
echo "--- 1a: /notify works with default talk_to_user_enabled=true ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Hello from sender"}')
|
||||
assert "POST /notify returns 200 when talk_to_user_enabled=true (default)" "$CODE" "200"
|
||||
|
||||
echo ""
|
||||
echo "--- 1b: Disable talk_to_user ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
|
||||
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
|
||||
-d '{"talk_to_user_enabled": false}')
|
||||
assert "PATCH /abilities talk_to_user_enabled=false returns 200" "$CODE" "200"
|
||||
|
||||
# Verify the flag is reflected in the workspace GET response.
|
||||
WS=$(curl -s "$BASE/workspaces/$SENDER_ID" -H "$SENDER_AUTH")
|
||||
FLAG=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("talk_to_user_enabled","MISSING"))')
|
||||
assert "GET /workspaces/:id reflects talk_to_user_enabled=false" "$FLAG" "False"
|
||||
|
||||
echo ""
|
||||
echo "--- 1c: /notify blocked when talk_to_user disabled ---"
|
||||
BODY=$(curl -s -w "" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Should be blocked"}')
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Should be blocked"}')
|
||||
assert "POST /notify returns 403 when talk_to_user_enabled=false" "$CODE" "403"
|
||||
|
||||
ERR=$(echo "$BODY" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("error",""))' 2>/dev/null || echo "")
|
||||
assert_contains "403 body contains talk_to_user_disabled error code" "$ERR" "talk_to_user_disabled"
|
||||
|
||||
HINT=$(echo "$BODY" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("hint",""))' 2>/dev/null || echo "")
|
||||
assert_contains "403 body contains delegate_task hint" "$HINT" "delegate_task"
|
||||
|
||||
echo ""
|
||||
echo "--- 1d: Re-enable talk_to_user and verify /notify works again ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
|
||||
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
|
||||
-d '{"talk_to_user_enabled": true}')
|
||||
assert "PATCH /abilities talk_to_user_enabled=true returns 200" "$CODE" "200"
|
||||
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Re-enabled, should work"}')
|
||||
assert "POST /notify returns 200 after re-enabling talk_to_user" "$CODE" "200"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "=== Part 2: broadcast ability ==="
|
||||
|
||||
echo ""
|
||||
echo "--- 2a: Broadcast blocked by default (broadcast_enabled=false) ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/broadcast" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Should be blocked"}')
|
||||
assert "POST /broadcast returns 403 when broadcast_enabled=false (default)" "$CODE" "403"
|
||||
|
||||
echo ""
|
||||
echo "--- 2b: Enable broadcast ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
|
||||
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
|
||||
-d '{"broadcast_enabled": true}')
|
||||
assert "PATCH /abilities broadcast_enabled=true returns 200" "$CODE" "200"
|
||||
|
||||
WS=$(curl -s "$BASE/workspaces/$SENDER_ID" -H "$SENDER_AUTH")
|
||||
FLAG=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("broadcast_enabled","MISSING"))')
|
||||
assert "GET /workspaces/:id reflects broadcast_enabled=true" "$FLAG" "True"
|
||||
|
||||
echo ""
|
||||
echo "--- 2c: Successful broadcast fan-out ---"
|
||||
BCAST=$(curl -s -X POST "$BASE/workspaces/$SENDER_ID/broadcast" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Org-wide notice: scheduled maintenance in 5 minutes."}')
|
||||
BSTATUS=$(echo "$BCAST" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("status",""))' 2>/dev/null || echo "")
|
||||
BDELIVERED=$(echo "$BCAST" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("delivered","-1"))' 2>/dev/null || echo "-1")
|
||||
assert "POST /broadcast returns status=sent" "$BSTATUS" "sent"
|
||||
|
||||
# delivered count must be >= 1 (the receiver workspace).
|
||||
echo " INFO — broadcast delivered=$BDELIVERED"
|
||||
if python3 -c "import sys; sys.exit(0 if int('$BDELIVERED') >= 1 else 1)" 2>/dev/null; then
|
||||
echo " PASS — delivered count >= 1"
|
||||
PASS=$((PASS+1))
|
||||
else
|
||||
echo " FAIL — expected delivered >= 1, got $BDELIVERED"
|
||||
FAIL=$((FAIL+1))
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "--- 2d: Receiver activity log has broadcast_receive entry ---"
|
||||
RECEIVER_TOKEN=$(e2e_mint_test_token "$RECEIVER_ID")
|
||||
[ -n "$RECEIVER_TOKEN" ] || { echo "Failed to mint receiver token"; exit 1; }
|
||||
RECEIVER_AUTH="Authorization: Bearer $RECEIVER_TOKEN"
|
||||
|
||||
ACT=$(curl -s -H "$RECEIVER_AUTH" "$BASE/workspaces/$RECEIVER_ID/activity?source=agent&limit=20")
|
||||
ROW=$(echo "$ACT" | python3 -c '
|
||||
import json, sys
|
||||
rows = json.load(sys.stdin) or []
|
||||
for r in rows:
|
||||
if r.get("activity_type") == "broadcast_receive":
|
||||
print(json.dumps(r))
|
||||
break
|
||||
')
|
||||
[ -n "$ROW" ] || {
|
||||
echo " FAIL — could not find broadcast_receive row in receiver activity"
|
||||
FAIL=$((FAIL+1))
|
||||
}
|
||||
|
||||
if [ -n "$ROW" ]; then
|
||||
# Message is stored in summary field.
|
||||
MSG=$(echo "$ROW" | python3 -c 'import json,sys;r=json.load(sys.stdin);print(r.get("summary",""))')
|
||||
assert_contains "broadcast_receive row summary has original message" "$MSG" "scheduled maintenance"
|
||||
# Sender ID is stored in source_id field.
|
||||
SRC=$(echo "$ROW" | python3 -c 'import json,sys;r=json.load(sys.stdin);print(r.get("source_id",""))')
|
||||
assert "broadcast_receive row source_id is sender workspace" "$SRC" "$SENDER_ID"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "--- 2e: Sender activity log has broadcast_sent entry ---"
|
||||
ACT_SENDER=$(curl -s -H "$SENDER_AUTH" "$BASE/workspaces/$SENDER_ID/activity?limit=20")
|
||||
SENT_ROW=$(echo "$ACT_SENDER" | python3 -c '
|
||||
import json, sys
|
||||
rows = json.load(sys.stdin) or []
|
||||
for r in rows:
|
||||
if r.get("activity_type") == "broadcast_sent":
|
||||
print(json.dumps(r))
|
||||
break
|
||||
')
|
||||
[ -n "$SENT_ROW" ] || {
|
||||
echo " FAIL — could not find broadcast_sent row in sender activity"
|
||||
FAIL=$((FAIL+1))
|
||||
}
|
||||
|
||||
if [ -n "$SENT_ROW" ]; then
|
||||
# Delivered count is baked into the summary field (no response_body for sender row).
|
||||
SUMMARY=$(echo "$SENT_ROW" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("summary",""))')
|
||||
assert_contains "broadcast_sent summary mentions workspace count" "$SUMMARY" "workspace"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "--- 2f: Sender does NOT receive a broadcast_receive entry ---"
|
||||
SELF_RECV=$(echo "$ACT_SENDER" | python3 -c '
|
||||
import json, sys
|
||||
rows = json.load(sys.stdin) or []
|
||||
for r in rows:
|
||||
if r.get("activity_type") == "broadcast_receive":
|
||||
print("found")
|
||||
break
|
||||
')
|
||||
assert_not_contains "sender has no broadcast_receive in own activity log" "${SELF_RECV:-}" "found"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "--- 2g: Empty message is rejected ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/broadcast" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":""}')
|
||||
assert "POST /broadcast with empty message returns 400" "$CODE" "400"
|
||||
|
||||
echo ""
|
||||
echo "--- 2h: Partial PATCH does not clobber other flags ---"
|
||||
# Set talk_to_user=false, then patch only broadcast — talk_to_user must stay false.
|
||||
curl -s -o /dev/null -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
|
||||
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
|
||||
-d '{"talk_to_user_enabled": false}'
|
||||
curl -s -o /dev/null -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
|
||||
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
|
||||
-d '{"broadcast_enabled": false}'
|
||||
WS=$(curl -s "$BASE/workspaces/$SENDER_ID" -H "$SENDER_AUTH")
|
||||
TUF=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("talk_to_user_enabled","MISSING"))')
|
||||
BEF=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("broadcast_enabled","MISSING"))')
|
||||
assert "partial PATCH preserves talk_to_user_enabled=false" "$TUF" "False"
|
||||
assert "partial PATCH sets broadcast_enabled=false" "$BEF" "False"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "=== Results: $PASS passed, $FAIL failed ==="
|
||||
[ "$FAIL" -eq 0 ]
|
||||
@@ -121,7 +121,7 @@ func main() {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
result, err := db.DB.ExecContext(ctx, `DELETE FROM activity_logs WHERE created_at < now() - ($1 || ' days')::interval`, retentionDays)
|
||||
result, err := db.GetDB().ExecContext(ctx, `DELETE FROM activity_logs WHERE created_at < now() - ($1 || ' days')::interval`, retentionDays)
|
||||
if err != nil {
|
||||
log.Printf("Activity log cleanup error: %v", err)
|
||||
} else if n, _ := result.RowsAffected(); n > 0 {
|
||||
@@ -184,7 +184,7 @@ func main() {
|
||||
// WorkspaceHandler) get the same plugin/resolver pair. memBundle
|
||||
// is nil when MEMORY_PLUGIN_URL is unset — every consumer
|
||||
// nil-checks before using.
|
||||
memBundle := memwiring.Build(db.DB)
|
||||
memBundle := memwiring.Build(db.GetDB())
|
||||
if memBundle != nil {
|
||||
wh.WithNamespaceCleanup(memBundle.NamespaceCleanupFn())
|
||||
}
|
||||
@@ -278,7 +278,7 @@ func main() {
|
||||
// pending_uploads table grows unbounded; even with the 24h hard TTL,
|
||||
// nothing actually deletes a row, just makes it un-fetchable.
|
||||
go supervised.RunWithRecover(ctx, "pending-uploads-sweeper", func(c context.Context) {
|
||||
pendinguploads.StartSweeper(c, pendinguploads.NewPostgres(db.DB), 0)
|
||||
pendinguploads.StartSweeper(c, pendinguploads.NewPostgres(db.GetDB()), 0)
|
||||
})
|
||||
|
||||
// Provision-timeout sweep — flips workspaces that have been stuck in
|
||||
@@ -513,7 +513,7 @@ func fixAdminTokenPlaceholder() {
|
||||
// Read the current stored value. We only upsert when the placeholder is
|
||||
// present so we don't repeatedly write rows that are already correct.
|
||||
var storedValue []byte
|
||||
err := db.DB.QueryRow(`SELECT encrypted_value FROM global_secrets WHERE key = $1`, "ADMIN_TOKEN").Scan(&storedValue)
|
||||
err := db.GetDB().QueryRow(`SELECT encrypted_value FROM global_secrets WHERE key = $1`, "ADMIN_TOKEN").Scan(&storedValue)
|
||||
if err != nil {
|
||||
// No row — nothing to fix. The control plane injects ADMIN_TOKEN via
|
||||
// Secrets Manager bootstrap; the global_secrets path is a legacy seed.
|
||||
@@ -545,7 +545,7 @@ func fixAdminTokenPlaceholder() {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = db.DB.Exec(`
|
||||
_, err = db.GetDB().Exec(`
|
||||
INSERT INTO global_secrets (key, encrypted_value, encryption_version)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
|
||||
@@ -28,7 +28,7 @@ func Export(ctx context.Context, workspaceID, configsDir string, dockerCli *clie
|
||||
var agentCard []byte
|
||||
var parentID *string
|
||||
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT name, COALESCE(role, ''), tier, status,
|
||||
COALESCE(agent_card, 'null'::jsonb), parent_id
|
||||
FROM workspaces WHERE id = $1
|
||||
@@ -79,7 +79,7 @@ func Export(ctx context.Context, workspaceID, configsDir string, dockerCli *clie
|
||||
}
|
||||
|
||||
// Recursively export sub-workspaces
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT id FROM workspaces WHERE parent_id = $1 AND status != 'removed'`, workspaceID)
|
||||
if err == nil {
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
@@ -41,7 +41,7 @@ func Import(
|
||||
}
|
||||
|
||||
// Create workspace record
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
_, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspaces (id, name, role, tier, status, parent_id, source_bundle_id)
|
||||
VALUES ($1, $2, $3, $4, 'provisioning', $5, $6)
|
||||
`, wsID, b.Name, nilIfEmpty(b.Description), b.Tier, parentID, b.ID)
|
||||
@@ -72,7 +72,7 @@ func Import(
|
||||
}
|
||||
}
|
||||
// Store runtime in DB
|
||||
_, _ = db.DB.ExecContext(ctx, `UPDATE workspaces SET runtime = $1 WHERE id = $2`, bundleRuntime, wsID)
|
||||
_, _ = db.GetDB().ExecContext(ctx, `UPDATE workspaces SET runtime = $1 WHERE id = $2`, bundleRuntime, wsID)
|
||||
|
||||
// Provision the container if provisioner is available
|
||||
if prov != nil {
|
||||
@@ -92,7 +92,7 @@ func Import(
|
||||
if err != nil {
|
||||
markFailed(provCtx, wsID, broadcaster, err)
|
||||
} else if url != "" {
|
||||
db.DB.ExecContext(provCtx, `UPDATE workspaces SET url = $1 WHERE id = $2`, url, wsID)
|
||||
db.GetDB().ExecContext(provCtx, `UPDATE workspaces SET url = $1 WHERE id = $2`, url, wsID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -139,7 +139,7 @@ func markFailed(ctx context.Context, wsID string, broadcaster *events.Broadcaste
|
||||
// markProvisionFailed in workspace-server/internal/handlers/
|
||||
// workspace_provision_shared.go.
|
||||
msg := err.Error()
|
||||
db.DB.ExecContext(ctx,
|
||||
db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, last_sample_error = $2, updated_at = now() WHERE id = $3`,
|
||||
models.StatusFailed, msg, wsID)
|
||||
broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceProvisionFailed), wsID, map[string]interface{}{
|
||||
|
||||
@@ -600,7 +600,7 @@ func TestManager_SendOutbound_NoChatID(t *testing.T) {
|
||||
|
||||
// The callback is a package-level var set by NewManager; we verify both its
|
||||
// default (safe no-op) and the wired-up path via a UPDATE assertion against
|
||||
// a sqlmock-backed db.DB. Two tests guard the contract: the var is callable
|
||||
// a sqlmock-backed db.GetDB(). Two tests guard the contract: the var is callable
|
||||
// at zero-value, and a wired callback issues the right UPDATE.
|
||||
|
||||
func TestDisableChannelByChatID_DefaultIsNoOp(t *testing.T) {
|
||||
|
||||
@@ -68,10 +68,10 @@ func NewManager(proxy A2AProxy, broadcaster Broadcaster) *Manager {
|
||||
// row disabled and reload in-memory manager state. Without this, outbound
|
||||
// messages keep trying the dead chat and log 403s forever.
|
||||
disableChannelByChatID = func(ctx context.Context, chatID string) {
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return
|
||||
}
|
||||
res, err := db.DB.ExecContext(ctx, `
|
||||
res, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET enabled = false, updated_at = now()
|
||||
WHERE channel_type = 'telegram'
|
||||
@@ -122,7 +122,7 @@ func (m *Manager) PausePollersForToken(workspaceID, botToken string) func() {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(context.Background(), `
|
||||
rows, err := db.GetDB().QueryContext(context.Background(), `
|
||||
SELECT id, channel_config FROM workspace_channels
|
||||
WHERE enabled = true AND workspace_id = $1
|
||||
`, workspaceID)
|
||||
@@ -185,7 +185,7 @@ func (m *Manager) Stop() {
|
||||
// Reload re-reads enabled channels from DB and diffs against running pollers.
|
||||
// New channels get started, removed/disabled channels get stopped.
|
||||
func (m *Manager) Reload(ctx context.Context) {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users
|
||||
FROM workspace_channels
|
||||
WHERE enabled = true
|
||||
@@ -374,8 +374,8 @@ func (m *Manager) HandleInbound(ctx context.Context, ch ChannelRow, msg *Inbound
|
||||
m.appendHistory(ctx, historyKey, msg.Username, msg.Text, replyText)
|
||||
|
||||
// Update stats in DB
|
||||
if db.DB != nil {
|
||||
db.DB.ExecContext(ctx, `
|
||||
if db.GetDB() != nil {
|
||||
db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET last_message_at = now(), message_count = message_count + 1, updated_at = now()
|
||||
WHERE id = $1
|
||||
@@ -419,8 +419,8 @@ func (m *Manager) SendOutbound(ctx context.Context, channelID string, text strin
|
||||
}
|
||||
}
|
||||
|
||||
if db.DB != nil {
|
||||
db.DB.ExecContext(ctx, `
|
||||
if db.GetDB() != nil {
|
||||
db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET last_message_at = now(), message_count = message_count + 1, updated_at = now()
|
||||
WHERE id = $1
|
||||
@@ -447,7 +447,7 @@ func (m *Manager) SendOutbound(ctx context.Context, channelID string, text strin
|
||||
// completion posts to both #mol-engineering AND #mol-firehose if the
|
||||
// workspace has both configured via chat_id comma-separation.
|
||||
func (m *Manager) BroadcastToWorkspaceChannels(ctx context.Context, workspaceID, text string) {
|
||||
if text == "" || db.DB == nil {
|
||||
if text == "" || db.GetDB() == nil {
|
||||
return
|
||||
}
|
||||
// Truncate to keep Slack messages digestible (rune-safe for CJK/emoji)
|
||||
@@ -457,7 +457,7 @@ func (m *Manager) BroadcastToWorkspaceChannels(ctx context.Context, workspaceID,
|
||||
}
|
||||
// Only auto-post to Slack channels. Telegram is CEO-only — explicit
|
||||
// escalations via the agent's outbound call, never auto-post from crons.
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id FROM workspace_channels
|
||||
WHERE workspace_id = $1 AND enabled = true AND channel_type = 'slack'
|
||||
`, workspaceID)
|
||||
@@ -478,10 +478,10 @@ func (m *Manager) BroadcastToWorkspaceChannels(ctx context.Context, workspaceID,
|
||||
// FetchWorkspaceChannelContext returns recent Slack channel messages formatted
|
||||
// as ambient context for cron prompts (Level 3).
|
||||
func (m *Manager) FetchWorkspaceChannelContext(ctx context.Context, workspaceID string) string {
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return ""
|
||||
}
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT channel_config FROM workspace_channels
|
||||
WHERE workspace_id = $1 AND channel_type = 'slack' AND enabled = true
|
||||
LIMIT 1
|
||||
@@ -548,7 +548,7 @@ func truncID(id string) string {
|
||||
func (m *Manager) loadChannel(ctx context.Context, channelID string) (ChannelRow, error) {
|
||||
var ch ChannelRow
|
||||
var configJSON, allowedJSON []byte
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users
|
||||
FROM workspace_channels WHERE id = $1
|
||||
`, channelID).Scan(&ch.ID, &ch.WorkspaceID, &ch.ChannelType, &configJSON, &ch.Enabled, &allowedJSON)
|
||||
|
||||
@@ -8,24 +8,57 @@ import (
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// mu guards DB against concurrent read/write. setupTestDB swaps the
|
||||
// connection during test cleanup; concurrent goroutines from the test
|
||||
// body may be reading DB at that moment.
|
||||
var mu sync.RWMutex
|
||||
|
||||
// DB is the package-level postgres connection. In production it is set
|
||||
// once by InitPostgres and never mutated. In tests, setupTestDB swaps it
|
||||
// for a sqlmock. Access via GetDB() to avoid data races.
|
||||
var DB *sql.DB
|
||||
|
||||
// GetDB returns the current *sql.DB, acquired under a read lock so that
|
||||
// concurrent readers (async goroutines from test bodies) and writers
|
||||
// (setupTestDB cleanup) do not race.
|
||||
func GetDB() *sql.DB {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
return DB
|
||||
}
|
||||
|
||||
// Lock acquires an exclusive write lock on the DB. Used by test helpers
|
||||
// (setupTestDB) to safely swap db.DB without racing against concurrent
|
||||
// GetDB() readers.
|
||||
func Lock() {
|
||||
mu.Lock()
|
||||
}
|
||||
|
||||
// Unlock releases the exclusive write lock acquired by Lock().
|
||||
func Unlock() {
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
func InitPostgres(databaseURL string) error {
|
||||
var err error
|
||||
DB, err = sql.Open("postgres", databaseURL)
|
||||
conn, err := sql.Open("postgres", databaseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open postgres: %w", err)
|
||||
}
|
||||
DB.SetMaxOpenConns(25)
|
||||
DB.SetMaxIdleConns(5)
|
||||
conn.SetMaxOpenConns(25)
|
||||
conn.SetMaxIdleConns(5)
|
||||
|
||||
if err := DB.Ping(); err != nil {
|
||||
if err := conn.Ping(); err != nil {
|
||||
return fmt.Errorf("ping postgres: %w", err)
|
||||
}
|
||||
mu.Lock()
|
||||
DB = conn
|
||||
mu.Unlock()
|
||||
log.Println("Connected to Postgres")
|
||||
return nil
|
||||
}
|
||||
@@ -51,8 +84,9 @@ func InitPostgres(databaseURL string) error {
|
||||
// Migration authors must write idempotent SQL. A real schema_migrations
|
||||
// tracking table would be better; tracked as follow-up.
|
||||
func RunMigrations(migrationsDir string) error {
|
||||
realDB := GetDB()
|
||||
// Create tracking table if it doesn't exist.
|
||||
if _, err := DB.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
if _, err := realDB.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
filename TEXT PRIMARY KEY,
|
||||
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
)`); err != nil {
|
||||
@@ -81,7 +115,7 @@ func RunMigrations(migrationsDir string) error {
|
||||
|
||||
// Check if already applied.
|
||||
var exists bool
|
||||
if err := DB.QueryRow("SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE filename = $1)", base).Scan(&exists); err != nil {
|
||||
if err := realDB.QueryRow("SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE filename = $1)", base).Scan(&exists); err != nil {
|
||||
return fmt.Errorf("check migration %s: %w", base, err)
|
||||
}
|
||||
if exists {
|
||||
@@ -94,12 +128,12 @@ func RunMigrations(migrationsDir string) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", f, err)
|
||||
}
|
||||
if _, err := DB.Exec(string(content)); err != nil {
|
||||
if _, err := realDB.Exec(string(content)); err != nil {
|
||||
return fmt.Errorf("exec %s: %w", base, err)
|
||||
}
|
||||
|
||||
// Record as applied.
|
||||
if _, err := DB.Exec("INSERT INTO schema_migrations (filename) VALUES ($1)", base); err != nil {
|
||||
if _, err := realDB.Exec("INSERT INTO schema_migrations (filename) VALUES ($1)", base); err != nil {
|
||||
return fmt.Errorf("record migration %s: %w", base, err)
|
||||
}
|
||||
applied++
|
||||
|
||||
@@ -17,7 +17,9 @@ func TestRunMigrations_FirstBoot_AppliesAndRecords(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_init.up.sql"), []byte("CREATE TABLE foo();"), 0o644)
|
||||
@@ -55,7 +57,9 @@ func TestRunMigrations_SecondBoot_SkipsApplied(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_init.up.sql"), []byte("CREATE TABLE foo();"), 0o644)
|
||||
@@ -92,7 +96,9 @@ func TestRunMigrations_MixedState_AppliesOnlyNew(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_old.up.sql"), []byte("SELECT 1;"), 0o644)
|
||||
@@ -135,7 +141,9 @@ func TestRunMigrations_SkipsDownSqlFilesEvenInTracking(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_init.up.sql"), []byte("CREATE TABLE foo();"), 0o644)
|
||||
|
||||
@@ -83,7 +83,7 @@ func TestWorkspaceStatusFailed_MustSetLastSampleError(t *testing.T) {
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
// Match db.DB.ExecContext / db.DB.QueryContext / db.DB.QueryRowContext
|
||||
// Match db.GetDB().ExecContext / db.GetDB().QueryContext / db.GetDB().QueryRowContext
|
||||
// — the three SQL execution surfaces this codebase uses.
|
||||
methodName := sel.Sel.Name
|
||||
if methodName != "ExecContext" && methodName != "QueryContext" && methodName != "QueryRowContext" {
|
||||
|
||||
@@ -63,7 +63,7 @@ func (b *Broadcaster) RecordAndBroadcast(ctx context.Context, eventType string,
|
||||
}
|
||||
|
||||
// Insert into structure_events — cast to jsonb explicitly
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO structure_events (event_type, workspace_id, payload)
|
||||
VALUES ($1, $2, $3::jsonb)
|
||||
`, eventType, workspaceID, string(payloadJSON))
|
||||
|
||||
@@ -276,7 +276,7 @@ func (h *WorkspaceHandler) ProxyA2A(c *gin.Context) {
|
||||
if callerID == "" {
|
||||
if _, isOrg := c.Get("org_token_id"); !isOrg {
|
||||
if tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")); tok != "" {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.DB, tok); err == nil {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.GetDB(), tok); err == nil {
|
||||
callerID = wsID
|
||||
}
|
||||
}
|
||||
@@ -332,7 +332,7 @@ func (h *WorkspaceHandler) ProxyA2A(c *gin.Context) {
|
||||
func (h *WorkspaceHandler) checkWorkspaceBudget(ctx context.Context, workspaceID string) *proxyA2AError {
|
||||
var budgetLimit sql.NullInt64
|
||||
var monthlySpend int64
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT budget_limit, COALESCE(monthly_spend, 0) FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&budgetLimit, &monthlySpend)
|
||||
@@ -623,7 +623,7 @@ func (h *WorkspaceHandler) resolveAgentURL(ctx context.Context, workspaceID stri
|
||||
if err != nil {
|
||||
var urlNullable sql.NullString
|
||||
var status string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT url, status FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&urlNullable, &status)
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -645,7 +645,7 @@ func (h *WorkspaceHandler) resolveAgentURL(ctx context.Context, workspaceID stri
|
||||
// the caller can retry once the workspace is back online (~10s).
|
||||
if status == "hibernated" {
|
||||
log.Printf("ProxyA2A: waking hibernated workspace %s", workspaceID)
|
||||
go h.RestartByID(workspaceID)
|
||||
h.goAsync(func() { h.RestartByID(workspaceID) })
|
||||
return "", &proxyA2AError{
|
||||
Status: http.StatusServiceUnavailable,
|
||||
Headers: map[string]string{"Retry-After": "15"},
|
||||
|
||||
@@ -161,7 +161,7 @@ func (h *WorkspaceHandler) handleA2ADispatchError(ctx context.Context, workspace
|
||||
// canvas-chat-to-dead-workspace incident traces to exactly this gap.
|
||||
func (h *WorkspaceHandler) maybeMarkContainerDead(ctx context.Context, workspaceID string) bool {
|
||||
var wsRuntime string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(runtime, 'langgraph') FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsRuntime)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(runtime, 'langgraph') FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsRuntime)
|
||||
if isExternalLikeRuntime(wsRuntime) {
|
||||
return false
|
||||
}
|
||||
@@ -189,7 +189,7 @@ func (h *WorkspaceHandler) maybeMarkContainerDead(ctx context.Context, workspace
|
||||
return false
|
||||
}
|
||||
log.Printf("ProxyA2A: container for %s is dead — marking offline and triggering restart", workspaceID)
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status NOT IN ('removed', 'provisioning')`, models.StatusOffline, workspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status NOT IN ('removed', 'provisioning')`, models.StatusOffline, workspaceID); err != nil {
|
||||
log.Printf("ProxyA2A: failed to mark workspace %s offline: %v", workspaceID, err)
|
||||
}
|
||||
db.ClearWorkspaceKeys(ctx, workspaceID)
|
||||
@@ -234,7 +234,7 @@ func (h *WorkspaceHandler) preflightContainerHealth(ctx context.Context, workspa
|
||||
// (same effect as maybeMarkContainerDead's branch), and return the
|
||||
// structured 503 immediately so the caller skips the forward.
|
||||
log.Printf("ProxyA2A preflight: container for %s is not running — marking offline and triggering restart (#36)", workspaceID)
|
||||
if _, dbErr := db.DB.ExecContext(ctx,
|
||||
if _, dbErr := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status NOT IN ('removed', 'provisioning')`,
|
||||
models.StatusOffline, workspaceID); dbErr != nil {
|
||||
log.Printf("ProxyA2A preflight: failed to mark workspace %s offline: %v", workspaceID, dbErr)
|
||||
@@ -257,7 +257,7 @@ func (h *WorkspaceHandler) preflightContainerHealth(ctx context.Context, workspa
|
||||
func (h *WorkspaceHandler) logA2AFailure(ctx context.Context, workspaceID, callerID string, body []byte, a2aMethod string, err error, durationMs int) {
|
||||
errMsg := err.Error()
|
||||
var errWsName string
|
||||
db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&errWsName)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&errWsName)
|
||||
if errWsName == "" {
|
||||
errWsName = workspaceID
|
||||
}
|
||||
@@ -289,7 +289,7 @@ func (h *WorkspaceHandler) logA2ASuccess(ctx context.Context, workspaceID, calle
|
||||
logStatus = "error"
|
||||
}
|
||||
var wsNameForLog string
|
||||
db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsNameForLog)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsNameForLog)
|
||||
if wsNameForLog == "" {
|
||||
wsNameForLog = workspaceID
|
||||
}
|
||||
@@ -301,7 +301,7 @@ func (h *WorkspaceHandler) logA2ASuccess(ctx context.Context, workspaceID, calle
|
||||
go func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if _, err := db.DB.ExecContext(bgCtx,
|
||||
if _, err := db.GetDB().ExecContext(bgCtx,
|
||||
`UPDATE workspaces SET last_outbound_at = NOW() WHERE id = $1`, callerID); err != nil {
|
||||
log.Printf("last_outbound_at update failed for %s: %v", callerID, err)
|
||||
}
|
||||
@@ -354,7 +354,7 @@ func nilIfEmpty(s string) *string {
|
||||
// On auth failure this writes the 401 via c and returns an error so the
|
||||
// handler aborts without running the proxy.
|
||||
func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) error {
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, callerID)
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.GetDB(), callerID)
|
||||
if err != nil {
|
||||
// Fail-open here matches the heartbeat path — A2A caller auth is
|
||||
// defense-in-depth on top of access-control hierarchy, not the
|
||||
@@ -371,7 +371,7 @@ func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) e
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing caller auth token"})
|
||||
return errInvalidCallerToken
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, callerID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), callerID, tok); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid caller auth token"})
|
||||
return err
|
||||
}
|
||||
@@ -475,7 +475,7 @@ func parseUsageFromA2AResponse(body []byte) (inputTokens, outputTokens int64) {
|
||||
// proxy-side read used for the short-circuit in proxyA2ARequest.
|
||||
func lookupDeliveryMode(ctx context.Context, workspaceID string) string {
|
||||
var mode sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT delivery_mode FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&mode)
|
||||
if err != nil {
|
||||
@@ -505,7 +505,7 @@ func lookupDeliveryMode(ctx context.Context, workspaceID string) string {
|
||||
// without a public URL.
|
||||
func (h *WorkspaceHandler) logA2AReceiveQueued(ctx context.Context, workspaceID, callerID string, body []byte, a2aMethod string) {
|
||||
var wsName string
|
||||
db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
if wsName == "" {
|
||||
wsName = workspaceID
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ func EnqueueA2A(
|
||||
// ON CONFLICT — only true CONSTRAINTs work for that). On conflict we
|
||||
// then look up the existing row's id so the caller always receives a
|
||||
// valid queue entry reference.
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO a2a_queue (workspace_id, caller_id, priority, body, method, idempotency_key, expires_at)
|
||||
VALUES ($1, $2, $3, $4::jsonb, $5, $6, $7)
|
||||
ON CONFLICT (workspace_id, idempotency_key)
|
||||
@@ -146,7 +146,7 @@ func EnqueueA2A(
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) && idempotencyKey != "" {
|
||||
// Conflict — look up the existing active row and use its id.
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id FROM a2a_queue
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2
|
||||
AND status IN ('queued','dispatched')
|
||||
@@ -160,7 +160,7 @@ func EnqueueA2A(
|
||||
}
|
||||
|
||||
// Return current queue depth for the caller's visibility.
|
||||
_ = db.DB.QueryRowContext(ctx, `
|
||||
_ = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT COUNT(*) FROM a2a_queue
|
||||
WHERE workspace_id = $1 AND status = 'queued'
|
||||
`, workspaceID).Scan(&depth)
|
||||
@@ -175,7 +175,7 @@ func EnqueueA2A(
|
||||
//
|
||||
// Returns (nil, nil) when the queue is empty — not an error.
|
||||
func DequeueNext(ctx context.Context, workspaceID string) (*QueuedItem, error) {
|
||||
tx, err := db.DB.BeginTx(ctx, nil)
|
||||
tx, err := db.GetDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -220,7 +220,7 @@ func DequeueNext(ctx context.Context, workspaceID string) (*QueuedItem, error) {
|
||||
// MarkQueueItemCompleted flips the queue row to 'completed' on a successful
|
||||
// drain dispatch.
|
||||
func MarkQueueItemCompleted(ctx context.Context, id string) {
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE a2a_queue SET status = 'completed', completed_at = now() WHERE id = $1`, id,
|
||||
); err != nil {
|
||||
log.Printf("A2AQueue: failed to mark %s completed: %v", id, err)
|
||||
@@ -233,7 +233,7 @@ func MarkQueueItemCompleted(ctx context.Context, id string) {
|
||||
// forever.
|
||||
func MarkQueueItemFailed(ctx context.Context, id, errMsg string) {
|
||||
const maxAttempts = 5
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE a2a_queue
|
||||
SET status = CASE WHEN attempts >= $2 THEN 'failed' ELSE 'queued' END,
|
||||
last_error = $3,
|
||||
@@ -249,7 +249,7 @@ func MarkQueueItemFailed(ctx context.Context, id, errMsg string) {
|
||||
// can see how many ahead of them.
|
||||
func QueueDepth(ctx context.Context, workspaceID string) int {
|
||||
var n int
|
||||
_ = db.DB.QueryRowContext(ctx,
|
||||
_ = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM a2a_queue WHERE workspace_id = $1 AND status = 'queued'`,
|
||||
workspaceID,
|
||||
).Scan(&n)
|
||||
@@ -266,7 +266,7 @@ func DropStaleQueueItems(ctx context.Context, workspaceID string, maxAgeMinutes
|
||||
var rows int64
|
||||
var err error
|
||||
if workspaceID != "" {
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
WITH dropped AS (
|
||||
UPDATE a2a_queue
|
||||
SET status = 'dropped',
|
||||
@@ -285,7 +285,7 @@ func DropStaleQueueItems(ctx context.Context, workspaceID string, maxAgeMinutes
|
||||
SELECT count(*) FROM dropped
|
||||
`, workspaceID, maxAgeMinutes).Scan(&rows)
|
||||
} else {
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
WITH dropped AS (
|
||||
UPDATE a2a_queue
|
||||
SET status = 'dropped',
|
||||
@@ -419,7 +419,7 @@ func (h *WorkspaceHandler) stitchDrainResponseToDelegation(ctx context.Context,
|
||||
"text": responseText,
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
res, err := db.DB.ExecContext(ctx, `
|
||||
res, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE activity_logs
|
||||
SET status = 'completed',
|
||||
summary = $1,
|
||||
|
||||
@@ -86,7 +86,7 @@ func QueueStatusByID(ctx context.Context, queueID string) (*QueueStatus, error)
|
||||
// so a completed delegation surfaces its result inline — non-delegation
|
||||
// queue rows simply won't have a matching activity_logs row and the field
|
||||
// stays null.
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT
|
||||
q.id,
|
||||
q.workspace_id,
|
||||
@@ -146,7 +146,7 @@ func QueueStatusByID(ctx context.Context, queueID string) (*QueueStatus, error)
|
||||
// the auth check without first projecting the public response.
|
||||
func queueRowAuthFields(ctx context.Context, queueID string) (callerID, workspaceID string, err error) {
|
||||
var callerNS, workspaceNS sql.NullString
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT caller_id, workspace_id FROM a2a_queue WHERE id = $1`,
|
||||
queueID,
|
||||
).Scan(&callerNS, &workspaceNS)
|
||||
@@ -185,7 +185,7 @@ func (h *WorkspaceHandler) GetA2AQueueStatus(c *gin.Context) {
|
||||
callerWorkspace := c.GetHeader("X-Workspace-ID")
|
||||
if !isOrg && callerWorkspace == "" {
|
||||
if tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")); tok != "" {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.DB, tok); err == nil {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.GetDB(), tok); err == nil {
|
||||
callerWorkspace = wsID
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,15 +25,16 @@ import (
|
||||
|
||||
// setupTestDBForQueueTests creates a sqlmock DB using QueryMatcherEqual (exact
|
||||
// string matching) so that ExpectQuery/ExpectExec patterns are compared verbatim.
|
||||
// Uses the same global db.DB as setupTestDB so the handler can use it.
|
||||
// Uses the same global db.GetDB() as setupTestDB so the handler can use it.
|
||||
func setupTestDBForQueueTests(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { mockDB.Close() })
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
return mock
|
||||
}
|
||||
|
||||
@@ -80,6 +81,54 @@ func TestExtractIdempotencyKey_emptyOnMissing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// extractExpiresInSeconds
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestExtractExpiresInSeconds_valid(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
body string
|
||||
want int
|
||||
}{
|
||||
{"positive int", `{"params":{"expires_in_seconds":30}}`, 30},
|
||||
{"zero", `{"params":{"expires_in_seconds":0}}`, 0},
|
||||
{"large TTL", `{"params":{"expires_in_seconds":3600}}`, 3600},
|
||||
{"nested message — not affected", `{"params":{"message":{"role":"user"},"expires_in_seconds":60}}`, 60},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := extractExpiresInSeconds([]byte(tc.body)); got != tc.want {
|
||||
t.Errorf("extractExpiresInSeconds = %d, want %d", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractExpiresInSeconds_invalidOrMissing(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
body string
|
||||
want int
|
||||
}{
|
||||
{"negative → 0", `{"params":{"expires_in_seconds":-5}}`, 0},
|
||||
{"missing expires_in_seconds", `{"params":{"message":{"role":"user"}}}`, 0},
|
||||
{"no params at all", `{"method":"message/send"}`, 0},
|
||||
{"malformed JSON", `not json`, 0},
|
||||
{"empty body", ``, 0},
|
||||
{"null value", `{"params":{"expires_in_seconds":null}}`, 0},
|
||||
{"string value", `{"params":{"expires_in_seconds":"30"}}`, 0},
|
||||
{"float value", `{"params":{"expires_in_seconds":30.5}}`, 30},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := extractExpiresInSeconds([]byte(tc.body)); got != tc.want {
|
||||
t.Errorf("extractExpiresInSeconds(%q) = %d, want %d", tc.body, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDelegationIDFromBody(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
|
||||
@@ -133,7 +133,7 @@ func (h *ActivityHandler) List(c *gin.Context) {
|
||||
var cursorTime time.Time
|
||||
usingCursor := false
|
||||
if sinceID != "" {
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT created_at FROM activity_logs WHERE id = $1 AND workspace_id = $2`,
|
||||
sinceID, workspaceID,
|
||||
).Scan(&cursorTime)
|
||||
@@ -222,7 +222,7 @@ func (h *ActivityHandler) List(c *gin.Context) {
|
||||
}
|
||||
args = append(args, limit)
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), query, args...)
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), query, args...)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Activity list error for %s: %v", workspaceID, err)
|
||||
@@ -285,7 +285,7 @@ func (h *ActivityHandler) SessionSearch(c *gin.Context) {
|
||||
|
||||
sqlQuery, args := buildSessionSearchQuery(workspaceID, query, limit)
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), sqlQuery, args...)
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), sqlQuery, args...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "session search failed"})
|
||||
return
|
||||
@@ -476,12 +476,19 @@ func (h *ActivityHandler) Notify(c *gin.Context) {
|
||||
for _, a := range body.Attachments {
|
||||
attachments = append(attachments, AgentMessageAttachment(a))
|
||||
}
|
||||
writer := NewAgentMessageWriter(db.DB, h.broadcaster)
|
||||
writer := NewAgentMessageWriter(db.GetDB(), h.broadcaster)
|
||||
if err := writer.Send(c.Request.Context(), workspaceID, body.Message, attachments); err != nil {
|
||||
if errors.Is(err, ErrWorkspaceNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, ErrTalkToUserDisabled) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "talk_to_user_disabled",
|
||||
"hint": "This workspace is not allowed to send messages directly to the user. Forward your update to a parent workspace using delegate_task — they may be able to reach the user.",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
|
||||
return
|
||||
}
|
||||
@@ -580,7 +587,7 @@ func (h *ActivityHandler) Report(c *gin.Context) {
|
||||
// most callers expect. For atomic-with-sibling-writes use LogActivityTx
|
||||
// and propagate the error.
|
||||
func LogActivity(ctx context.Context, broadcaster events.EventEmitter, params ActivityParams) {
|
||||
hook, err := logActivityExec(ctx, db.DB, broadcaster, params)
|
||||
hook, err := logActivityExec(ctx, db.GetDB(), broadcaster, params)
|
||||
if err != nil {
|
||||
log.Printf("LogActivity insert error: %v", err)
|
||||
return
|
||||
@@ -608,7 +615,7 @@ func LogActivityTx(ctx context.Context, tx *sql.Tx, broadcaster events.EventEmit
|
||||
|
||||
// activityExecutor is the SQL surface LogActivity[Tx] needs. *sql.Tx
|
||||
// and *sql.DB both satisfy it, so the same insert path serves the
|
||||
// fire-and-forget caller (db.DB) and the Tx-aware caller (*sql.Tx).
|
||||
// fire-and-forget caller (db.GetDB()) and the Tx-aware caller (*sql.Tx).
|
||||
type activityExecutor interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
@@ -388,9 +388,13 @@ func TestActivityList_BeforeTSRejectsInvalidFormat(t *testing.T) {
|
||||
// ---------- Activity type allowlist (#125: memory_write added) ----------
|
||||
|
||||
func TestActivityReport_AcceptsMemoryWriteType(t *testing.T) {
|
||||
mockDB, mock, _ := sqlmock.New()
|
||||
defer mockDB.Close()
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
@@ -413,9 +417,13 @@ func TestActivityReport_AcceptsMemoryWriteType(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestActivityReport_RejectsUnknownType(t *testing.T) {
|
||||
mockDB, _, _ := sqlmock.New()
|
||||
defer mockDB.Close()
|
||||
mockDB, _, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewActivityHandler(broadcaster)
|
||||
@@ -447,14 +455,18 @@ func TestNotify_PersistsToActivityLogsForReloadRecovery(t *testing.T) {
|
||||
// - Have source_id NULL (canvas-source filter)
|
||||
// - Carry the message text in response_body so extractResponseText
|
||||
// can reconstruct the agent reply on reload
|
||||
mockDB, mock, _ := sqlmock.New()
|
||||
defer mockDB.Close()
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
// Workspace existence check
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces`).
|
||||
mock.ExpectQuery(`SELECT name, talk_to_user_enabled FROM workspaces`).
|
||||
WithArgs("ws-notify").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("DD"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("DD", true))
|
||||
|
||||
// Persistence INSERT — verify shape
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
@@ -491,13 +503,17 @@ func TestNotify_WithAttachments_PersistsFilePartsForReload(t *testing.T) {
|
||||
// download chips after a page reload. Without `parts`, the bubble
|
||||
// shows up but the attachment chip is silently dropped on every
|
||||
// refresh.
|
||||
mockDB, mock, _ := sqlmock.New()
|
||||
defer mockDB.Close()
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces`).
|
||||
mock.ExpectQuery(`SELECT name, talk_to_user_enabled FROM workspaces`).
|
||||
WithArgs("ws-attach").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("DD"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("DD", true))
|
||||
|
||||
// Capture the JSONB arg so we can assert on the persisted shape
|
||||
// AFTER the call (must include parts[].kind=file so reload
|
||||
@@ -565,9 +581,13 @@ func TestNotify_RejectsAttachmentWithEmptyURIOrName(t *testing.T) {
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockDB, _, _ := sqlmock.New()
|
||||
defer mockDB.Close()
|
||||
mockDB, _, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
// No DB expectations — handler must reject with 400 BEFORE
|
||||
// reaching SELECT/INSERT. sqlmock will fail "expectations not met"
|
||||
// only if the handler unexpectedly queries.
|
||||
@@ -612,13 +632,17 @@ func TestNotify_DBFailure_StillBroadcastsAnd200(t *testing.T) {
|
||||
// WebSocket push (which the user is already seeing in their open
|
||||
// canvas). Pre-fix the WS push always succeeded; we don't want
|
||||
// the new persistence step to regress that path.
|
||||
mockDB, mock, _ := sqlmock.New()
|
||||
defer mockDB.Close()
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces`).
|
||||
mock.ExpectQuery(`SELECT name, talk_to_user_enabled FROM workspaces`).
|
||||
WithArgs("ws-x").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("DD"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("DD", true))
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WillReturnError(fmt.Errorf("simulated db hiccup"))
|
||||
|
||||
@@ -925,7 +949,7 @@ func TestLogActivityTx_DefersBroadcastUntilCommitHook(t *testing.T) {
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
tx, err := db.DB.BeginTx(context.Background(), nil)
|
||||
tx, err := db.GetDB().BeginTx(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("BeginTx: %v", err)
|
||||
}
|
||||
@@ -969,7 +993,7 @@ func TestLogActivityTx_InsertError_NoHook_NoBroadcast(t *testing.T) {
|
||||
WillReturnError(errors.New("constraint violation simulated"))
|
||||
mock.ExpectRollback()
|
||||
|
||||
tx, err := db.DB.BeginTx(context.Background(), nil)
|
||||
tx, err := db.GetDB().BeginTx(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("BeginTx: %v", err)
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ type AdminDelegationsHandler struct {
|
||||
|
||||
func NewAdminDelegationsHandler(handle *sql.DB) *AdminDelegationsHandler {
|
||||
if handle == nil {
|
||||
handle = db.DB
|
||||
handle = db.GetDB()
|
||||
}
|
||||
return &AdminDelegationsHandler{db: handle}
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ func (h *AdminMemoriesHandler) Export(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT am.id, am.content, am.scope, am.namespace, am.created_at,
|
||||
w.name AS workspace_name
|
||||
FROM agent_memories am
|
||||
@@ -183,7 +183,7 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
for _, entry := range entries {
|
||||
// 1. Resolve workspace by name
|
||||
var workspaceID string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT id FROM workspaces WHERE name = $1 LIMIT 1`,
|
||||
entry.WorkspaceName,
|
||||
).Scan(&workspaceID)
|
||||
@@ -205,7 +205,7 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
// secret (same placeholder output) are treated as duplicates.
|
||||
var exists bool
|
||||
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM agent_memories WHERE workspace_id = $1 AND content = $2 AND scope = $3)`,
|
||||
workspaceID, content, entry.Scope,
|
||||
).Scan(&exists)
|
||||
@@ -226,12 +226,12 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
}
|
||||
|
||||
if entry.CreatedAt != "" {
|
||||
_, err = db.DB.ExecContext(ctx,
|
||||
_, err = db.GetDB().ExecContext(ctx,
|
||||
`INSERT INTO agent_memories (workspace_id, content, scope, namespace, created_at) VALUES ($1, $2, $3, $4, $5)`,
|
||||
workspaceID, content, entry.Scope, namespace, entry.CreatedAt,
|
||||
)
|
||||
} else {
|
||||
_, err = db.DB.ExecContext(ctx,
|
||||
_, err = db.GetDB().ExecContext(ctx,
|
||||
`INSERT INTO agent_memories (workspace_id, content, scope, namespace) VALUES ($1, $2, $3, $4)`,
|
||||
workspaceID, content, entry.Scope, namespace,
|
||||
)
|
||||
@@ -277,7 +277,7 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
// N_workspaces resolver + N_workspaces plugin in the old code).
|
||||
func (h *AdminMemoriesHandler) exportViaPlugin(c *gin.Context, ctx context.Context) {
|
||||
// 1. One SQL pass: every workspace + its root id.
|
||||
wsRows, err := loadWorkspacesWithRoots(ctx, db.DB)
|
||||
wsRows, err := loadWorkspacesWithRoots(ctx, db.GetDB())
|
||||
if err != nil {
|
||||
log.Printf("admin/memories/export (cutover): workspaces query: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "export query failed"})
|
||||
@@ -445,7 +445,7 @@ func (h *AdminMemoriesHandler) importViaPlugin(c *gin.Context, ctx context.Conte
|
||||
|
||||
for _, entry := range entries {
|
||||
var workspaceID string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT id::text FROM workspaces WHERE name = $1 LIMIT 1`,
|
||||
entry.WorkspaceName,
|
||||
).Scan(&workspaceID); err != nil {
|
||||
|
||||
@@ -71,7 +71,7 @@ func (h *AdminPluginDriftHandler) Apply(c *gin.Context) {
|
||||
TrackedRef string `json:"tracked_ref"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT workspace_id, plugin_name, tracked_ref, status
|
||||
FROM plugin_update_queue
|
||||
WHERE id = $1
|
||||
@@ -108,7 +108,7 @@ func (h *AdminPluginDriftHandler) Apply(c *gin.Context) {
|
||||
|
||||
// Step 2: read the workspace_plugins row to get source_raw.
|
||||
var sourceRaw string
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT source_raw FROM workspace_plugins
|
||||
WHERE workspace_id = $1 AND plugin_name = $2
|
||||
`, entry.WorkspaceID, entry.PluginName).Scan(&sourceRaw)
|
||||
@@ -177,7 +177,7 @@ func (h *AdminPluginDriftHandler) Apply(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Step 4: mark queue entry as applied.
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE plugin_update_queue SET status = 'applied' WHERE id = $1
|
||||
`, queueID); err != nil {
|
||||
log.Printf("AdminPluginDrift: apply: failed to mark queue entry %s as applied: %v", queueID, err)
|
||||
|
||||
@@ -69,7 +69,7 @@ func (h *AdminSchedulesHealthHandler) Health(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
now := time.Now()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT
|
||||
w.id AS workspace_id,
|
||||
w.name AS workspace_name,
|
||||
|
||||
@@ -80,7 +80,7 @@ func (h *AdminTestTokenHandler) GetTestToken(c *gin.Context) {
|
||||
// Confirm the workspace exists — a missing workspace also 404s so we
|
||||
// can't be used to probe for arbitrary IDs.
|
||||
var exists string
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT id FROM workspaces WHERE id = $1`, workspaceID).Scan(&exists)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -91,7 +91,7 @@ func (h *AdminTestTokenHandler) GetTestToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := wsauth.IssueToken(c.Request.Context(), db.DB, workspaceID)
|
||||
token, err := wsauth.IssueToken(c.Request.Context(), db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "token issue failed"})
|
||||
return
|
||||
|
||||
@@ -123,7 +123,7 @@ func TestAdminTestToken_HappyPath_TokenValidates(t *testing.T) {
|
||||
mock.ExpectExec("UPDATE workspace_auth_tokens SET last_used_at").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
if err := wsauth.ValidateToken(c.Request.Context(), db.DB, "ws-1", resp.AuthToken); err != nil {
|
||||
if err := wsauth.ValidateToken(c.Request.Context(), db.GetDB(), "ws-1", resp.AuthToken); err != nil {
|
||||
t.Errorf("issued token failed to validate: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ func (h *AgentHandler) Assign(c *gin.Context) {
|
||||
|
||||
// Check workspace exists
|
||||
var status string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status FROM workspaces WHERE id = $1`, workspaceID).Scan(&status)
|
||||
if err == sql.ErrNoRows {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
@@ -46,7 +46,7 @@ func (h *AgentHandler) Assign(c *gin.Context) {
|
||||
|
||||
// Check no active agent already assigned
|
||||
var existingCount int
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM agents WHERE workspace_id = $1 AND status = 'active'`, workspaceID,
|
||||
).Scan(&existingCount); err != nil {
|
||||
log.Printf("Agent assign check error: %v", err)
|
||||
@@ -60,7 +60,7 @@ func (h *AgentHandler) Assign(c *gin.Context) {
|
||||
|
||||
// Insert agent
|
||||
var agentID string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`INSERT INTO agents (workspace_id, model) VALUES ($1, $2) RETURNING id`, workspaceID, body.Model,
|
||||
).Scan(&agentID)
|
||||
if err != nil {
|
||||
@@ -92,7 +92,7 @@ func (h *AgentHandler) Replace(c *gin.Context) {
|
||||
|
||||
// Deactivate current agent
|
||||
var oldModel string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`UPDATE agents SET status = 'replaced', removed_at = now(), removal_reason = 'model_replaced'
|
||||
WHERE workspace_id = $1 AND status = 'active' RETURNING model`,
|
||||
workspaceID,
|
||||
@@ -109,7 +109,7 @@ func (h *AgentHandler) Replace(c *gin.Context) {
|
||||
|
||||
// Insert new agent
|
||||
var agentID string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`INSERT INTO agents (workspace_id, model) VALUES ($1, $2) RETURNING id`, workspaceID, body.Model,
|
||||
).Scan(&agentID)
|
||||
if err != nil {
|
||||
@@ -133,7 +133,7 @@ func (h *AgentHandler) Remove(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var agentID, model string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`UPDATE agents SET status = 'removed', removed_at = now(), removal_reason = 'manual_removal'
|
||||
WHERE workspace_id = $1 AND status = 'active' RETURNING id, model`,
|
||||
workspaceID,
|
||||
@@ -171,7 +171,7 @@ func (h *AgentHandler) Move(c *gin.Context) {
|
||||
|
||||
// Check target workspace exists
|
||||
var targetStatus string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status FROM workspaces WHERE id = $1`, body.TargetWorkspaceID).Scan(&targetStatus)
|
||||
if err == sql.ErrNoRows {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "target workspace not found"})
|
||||
@@ -185,7 +185,7 @@ func (h *AgentHandler) Move(c *gin.Context) {
|
||||
|
||||
// Check target doesn't already have an agent
|
||||
var targetAgentCount int
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM agents WHERE workspace_id = $1 AND status = 'active'`, body.TargetWorkspaceID,
|
||||
).Scan(&targetAgentCount); err != nil {
|
||||
log.Printf("Move agent target check error: %v", err)
|
||||
@@ -199,7 +199,7 @@ func (h *AgentHandler) Move(c *gin.Context) {
|
||||
|
||||
// Move the agent: update workspace_id
|
||||
var agentID, model string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`UPDATE agents SET workspace_id = $2
|
||||
WHERE workspace_id = $1 AND status = 'active' RETURNING id, model`,
|
||||
sourceID, body.TargetWorkspaceID,
|
||||
|
||||
@@ -54,6 +54,11 @@ import (
|
||||
// timeout) surface as wrapped errors and should be treated as 503.
|
||||
var ErrWorkspaceNotFound = errors.New("agent_message: workspace not found")
|
||||
|
||||
// ErrTalkToUserDisabled is returned when the workspace has
|
||||
// talk_to_user_enabled=false. Callers surface HTTP 403 so the Python tool
|
||||
// can detect it and suggest forwarding to a parent workspace.
|
||||
var ErrTalkToUserDisabled = errors.New("agent_message: talk_to_user disabled")
|
||||
|
||||
// AgentMessageAttachment is one file attached to an agent → user
|
||||
// message. Identical to handlers.NotifyAttachment in field set; kept
|
||||
// distinct so the writer's API doesn't import a handler type with HTTP
|
||||
@@ -107,16 +112,20 @@ func (w *AgentMessageWriter) Send(
|
||||
// notify call surfaced as "workspace not found" and masked real
|
||||
// incidents in the alert path.
|
||||
var wsName string
|
||||
var talkToUserEnabled bool
|
||||
err := w.db.QueryRowContext(ctx,
|
||||
`SELECT name FROM workspaces WHERE id = $1 AND status != 'removed'`,
|
||||
`SELECT name, talk_to_user_enabled FROM workspaces WHERE id = $1 AND status != 'removed'`,
|
||||
workspaceID,
|
||||
).Scan(&wsName)
|
||||
).Scan(&wsName, &talkToUserEnabled)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return ErrWorkspaceNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("agent_message: workspace lookup: %w", err)
|
||||
}
|
||||
if !talkToUserEnabled {
|
||||
return ErrTalkToUserDisabled
|
||||
}
|
||||
|
||||
// 2. Build broadcast payload + WS-emit. Same shape that ChatTab's
|
||||
// AGENT_MESSAGE handler in canvas/src/store/canvas-events.ts has
|
||||
|
||||
@@ -86,11 +86,11 @@ func (c *capturingEmitter) RecordAndBroadcast(_ context.Context, eventType strin
|
||||
// path: workspace lookup, broadcast, INSERT, return nil.
|
||||
func TestAgentMessageWriter_Send_Success_NoAttachments(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-1").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
|
||||
WithArgs(
|
||||
@@ -114,11 +114,11 @@ func TestAgentMessageWriter_Send_Success_NoAttachments(t *testing.T) {
|
||||
// Drift here = chips disappear on chat reload.
|
||||
func TestAgentMessageWriter_Send_Success_WithAttachments(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-att").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("Ryan"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("Ryan", true))
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
|
||||
WithArgs(
|
||||
@@ -171,11 +171,11 @@ func TestAgentMessageWriter_Send_Success_WithAttachments(t *testing.T) {
|
||||
func TestAgentMessageWriter_Send_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
emitter := &capturingEmitter{}
|
||||
w := NewAgentMessageWriter(db.DB, emitter)
|
||||
w := NewAgentMessageWriter(db.GetDB(), emitter)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-missing").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}))
|
||||
|
||||
err := w.Send(context.Background(), "ws-missing", "lost in the void", nil)
|
||||
if !errors.Is(err, ErrWorkspaceNotFound) {
|
||||
@@ -200,11 +200,11 @@ func TestAgentMessageWriter_Send_WorkspaceNotFound(t *testing.T) {
|
||||
// broadcast.
|
||||
func TestAgentMessageWriter_Send_DBInsertFailureStillReturnsNil(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-dbfail").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WillReturnError(errors.New("transient db error"))
|
||||
@@ -221,11 +221,11 @@ func TestAgentMessageWriter_Send_DBInsertFailureStillReturnsNil(t *testing.T) {
|
||||
// table doesn't carry multi-KB summaries that bloat list queries.
|
||||
func TestAgentMessageWriter_Send_PreviewTruncation(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-trunc").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("Ryan"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("Ryan", true))
|
||||
|
||||
longMsg := strings.Repeat("x", 200)
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
@@ -261,11 +261,11 @@ func TestAgentMessageWriter_Send_PreviewTruncation(t *testing.T) {
|
||||
func TestAgentMessageWriter_Send_BroadcastsAgentMessageEvent(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
emitter := &capturingEmitter{}
|
||||
w := NewAgentMessageWriter(db.DB, emitter)
|
||||
w := NewAgentMessageWriter(db.GetDB(), emitter)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-bc").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("Workspace Name"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("Workspace Name", true))
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
@@ -312,10 +312,10 @@ func TestAgentMessageWriter_Send_BroadcastsAgentMessageEvent(t *testing.T) {
|
||||
// real incidents in alerting.
|
||||
func TestAgentMessageWriter_Send_DBErrorOnLookupReturnsWrapped(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
transientErr := errors.New("connection refused")
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-dbdown").
|
||||
WillReturnError(transientErr)
|
||||
|
||||
@@ -344,15 +344,15 @@ func TestAgentMessageWriter_Send_DBErrorOnLookupReturnsWrapped(t *testing.T) {
|
||||
// coverage. Now it does.
|
||||
func TestAgentMessageWriter_Send_NonASCIIMessagePersists(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
// 200-rune CJK message — exceeds the 80-rune cap, would have hit
|
||||
// the byte-slice bug.
|
||||
msg := strings.Repeat("你", 200)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-cjk").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WithArgs(
|
||||
@@ -393,11 +393,11 @@ func TestAgentMessageWriter_Send_NonASCIIMessagePersists(t *testing.T) {
|
||||
func TestAgentMessageWriter_Send_OmitsAttachmentsKeyWhenEmpty(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
emitter := &capturingEmitter{}
|
||||
w := NewAgentMessageWriter(db.DB, emitter)
|
||||
w := NewAgentMessageWriter(db.GetDB(), emitter)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-noatt").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("X"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("X", true))
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ func (h *ApprovalsHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
var approvalID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO approval_requests (workspace_id, task_id, action, reason, context)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb)
|
||||
RETURNING id
|
||||
@@ -60,7 +60,7 @@ func (h *ApprovalsHandler) Create(c *gin.Context) {
|
||||
|
||||
// Auto-escalate to parent
|
||||
var parentID *string
|
||||
db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
if parentID != nil {
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventApprovalEscalated), *parentID, map[string]interface{}{
|
||||
"approval_id": approvalID,
|
||||
@@ -80,12 +80,12 @@ func (h *ApprovalsHandler) ListAll(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Auto-expire stale approvals (older than 10 min)
|
||||
db.DB.ExecContext(ctx, `
|
||||
db.GetDB().ExecContext(ctx, `
|
||||
UPDATE approval_requests SET status = 'denied', decided_by = 'auto-expired', decided_at = now()
|
||||
WHERE status = 'pending' AND created_at < now() - interval '10 minutes'
|
||||
`)
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT a.id, a.workspace_id, w.name, a.action, a.reason, a.status, a.created_at
|
||||
FROM approval_requests a
|
||||
JOIN workspaces w ON w.id = a.workspace_id
|
||||
@@ -116,6 +116,9 @@ func (h *ApprovalsHandler) ListAll(c *gin.Context) {
|
||||
"created_at": createdAt,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("ListPendingApprovals rows.Err: %v", err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, approvals)
|
||||
}
|
||||
@@ -125,7 +128,7 @@ func (h *ApprovalsHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, task_id, action, reason, status, decided_by, decided_at, created_at
|
||||
FROM approval_requests WHERE workspace_id = $1
|
||||
ORDER BY created_at DESC LIMIT 50
|
||||
@@ -155,6 +158,9 @@ func (h *ApprovalsHandler) List(c *gin.Context) {
|
||||
"created_at": createdAt,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("ListApprovals rows.Err workspace=%s: %v", workspaceID, err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, approvals)
|
||||
}
|
||||
@@ -184,7 +190,7 @@ func (h *ApprovalsHandler) Decide(c *gin.Context) {
|
||||
decidedBy = "human"
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE approval_requests
|
||||
SET status = $1, decided_by = $2, decided_at = now()
|
||||
WHERE id = $3 AND workspace_id = $4 AND status = 'pending'
|
||||
|
||||
@@ -130,7 +130,7 @@ func (h *ArtifactsHandler) Create(c *gin.Context) {
|
||||
|
||||
// Reject if already linked.
|
||||
var exists bool
|
||||
db.DB.QueryRowContext(ctx,
|
||||
db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspace_artifacts WHERE workspace_id = $1)`,
|
||||
workspaceID,
|
||||
).Scan(&exists)
|
||||
@@ -193,7 +193,7 @@ func (h *ArtifactsHandler) Create(c *gin.Context) {
|
||||
remoteURL := stripCredentials(repo.RemoteURL)
|
||||
|
||||
var row workspaceArtifactRow
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO workspace_artifacts
|
||||
(workspace_id, cf_repo_name, cf_namespace, remote_url, description)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
@@ -223,7 +223,7 @@ func (h *ArtifactsHandler) Get(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var row workspaceArtifactRow
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id, workspace_id, cf_repo_name, cf_namespace, remote_url, description, created_at, updated_at
|
||||
FROM workspace_artifacts
|
||||
WHERE workspace_id = $1
|
||||
@@ -287,7 +287,7 @@ func (h *ArtifactsHandler) Fork(c *gin.Context) {
|
||||
|
||||
// Look up the source repo name.
|
||||
var cfRepoName string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT cf_repo_name FROM workspace_artifacts WHERE workspace_id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&cfRepoName)
|
||||
@@ -352,7 +352,7 @@ func (h *ArtifactsHandler) Token(c *gin.Context) {
|
||||
|
||||
// Look up the linked CF repo name.
|
||||
var cfRepoName string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT cf_repo_name FROM workspace_artifacts WHERE workspace_id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&cfRepoName)
|
||||
|
||||
@@ -179,7 +179,7 @@ func (h *AuditHandler) Query(c *gin.Context) {
|
||||
// Count total matching rows (for pagination) ----------------------------
|
||||
countQuery := "SELECT COUNT(*) FROM audit_events " + where
|
||||
var total int
|
||||
if err := db.DB.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
if err := db.GetDB().QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
log.Printf("audit: count query failed for workspace %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
return
|
||||
@@ -192,7 +192,7 @@ func (h *AuditHandler) Query(c *gin.Context) {
|
||||
FROM audit_events ` + where +
|
||||
fmt.Sprintf(" ORDER BY timestamp ASC, id ASC LIMIT $%d OFFSET $%d", idx, idx+1)
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, selectQuery, append(args, limit, offset)...)
|
||||
rows, err := db.GetDB().QueryContext(ctx, selectQuery, append(args, limit, offset)...)
|
||||
if err != nil {
|
||||
log.Printf("audit: query failed for workspace %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
|
||||
@@ -42,7 +42,7 @@ func (h *BudgetHandler) GetBudget(c *gin.Context) {
|
||||
|
||||
var budgetLimit sql.NullInt64
|
||||
var monthlySpend int64
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT budget_limit, COALESCE(monthly_spend, 0)
|
||||
FROM workspaces
|
||||
WHERE id = $1 AND status != 'removed'`,
|
||||
@@ -119,7 +119,7 @@ func (h *BudgetHandler) PatchBudget(c *gin.Context) {
|
||||
|
||||
// Existence check — return 404 for non-existent / removed workspaces.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1 AND status != 'removed')`,
|
||||
workspaceID,
|
||||
).Scan(&exists); err != nil || !exists {
|
||||
@@ -127,7 +127,7 @@ func (h *BudgetHandler) PatchBudget(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET budget_limit = $2, updated_at = now() WHERE id = $1`,
|
||||
workspaceID, budgetArg,
|
||||
); err != nil {
|
||||
@@ -140,7 +140,7 @@ func (h *BudgetHandler) PatchBudget(c *gin.Context) {
|
||||
// the DB, including the monthly_spend the agent has already accumulated.
|
||||
var newLimit sql.NullInt64
|
||||
var monthlySpend int64
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT budget_limit, COALESCE(monthly_spend, 0) FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&newLimit, &monthlySpend); err != nil {
|
||||
|
||||
@@ -41,7 +41,7 @@ func (h *ChannelHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users,
|
||||
last_message_at, message_count, created_at, updated_at
|
||||
FROM workspace_channels WHERE workspace_id = $1
|
||||
@@ -166,7 +166,7 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
var id string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO workspace_channels (workspace_id, channel_type, channel_config, enabled, allowed_users)
|
||||
VALUES ($1, $2, $3::jsonb, $4, $5::jsonb)
|
||||
RETURNING id
|
||||
@@ -222,7 +222,7 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
allowedArg = string(j)
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET channel_config = COALESCE($3::jsonb, channel_config),
|
||||
allowed_users = COALESCE($4::jsonb, allowed_users),
|
||||
@@ -252,7 +252,7 @@ func (h *ChannelHandler) Delete(c *gin.Context) {
|
||||
channelID := c.Param("channelId")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
DELETE FROM workspace_channels WHERE id = $1 AND workspace_id = $2
|
||||
`, channelID, workspaceID)
|
||||
if err != nil {
|
||||
@@ -291,7 +291,7 @@ func (h *ChannelHandler) Send(c *gin.Context) {
|
||||
// transient DB hiccup doesn't silently block outbound messages.
|
||||
var msgCount int
|
||||
var budget sql.NullInt64
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT message_count, channel_budget FROM workspace_channels WHERE id = $1`,
|
||||
channelID,
|
||||
).Scan(&msgCount, &budget); err != nil && err != sql.ErrNoRows {
|
||||
@@ -476,7 +476,7 @@ func (h *ChannelHandler) Webhook(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Look up channels by type and find one whose chat_id list contains msg.ChatID.
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users
|
||||
FROM workspace_channels
|
||||
WHERE channel_type = $1 AND enabled = true
|
||||
@@ -577,7 +577,7 @@ func (h *ChannelHandler) Webhook(c *gin.Context) {
|
||||
// the incoming request with 401 (fail-closed behaviour).
|
||||
func discordPublicKey(ctx context.Context) string {
|
||||
var pubKey string
|
||||
row := db.DB.QueryRowContext(ctx, `
|
||||
row := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT COALESCE(channel_config->>'app_public_key', '')
|
||||
FROM workspace_channels
|
||||
WHERE channel_type = 'discord' AND enabled = true
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
sqlmock "github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/channels"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -565,6 +566,20 @@ func TestChannelHandler_Discover_MissingToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestChannelHandler_Discover_UnsupportedType(t *testing.T) {
|
||||
// Set up db.GetDB() so PausePollersForToken (called inside Discover) doesn't panic.
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { mockDB.Close() })
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB })
|
||||
|
||||
mock.ExpectQuery(`SELECT id, channel_config FROM workspace_channels WHERE enabled = true AND workspace_id`).
|
||||
WithArgs("ws-test").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "channel_config"}))
|
||||
|
||||
handler := NewChannelHandler(newTestChannelManager())
|
||||
|
||||
// #329: workspace_id required — include so we actually reach the
|
||||
@@ -588,6 +603,20 @@ func TestChannelHandler_Discover_UnsupportedType(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestChannelHandler_Discover_InvalidBotToken(t *testing.T) {
|
||||
// Set up db.GetDB() so PausePollersForToken (called inside Discover) doesn't panic.
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { mockDB.Close() })
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB })
|
||||
|
||||
mock.ExpectQuery(`SELECT id, channel_config FROM workspace_channels WHERE enabled = true AND workspace_id`).
|
||||
WithArgs("ws-test").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "channel_config"}))
|
||||
|
||||
handler := NewChannelHandler(newTestChannelManager())
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
|
||||
@@ -133,7 +133,7 @@ const chatUploadMaxBytes = 50 * 1024 * 1024
|
||||
// extraction prevents that class on the consumer side.
|
||||
func resolveWorkspaceForwardCreds(c *gin.Context, ctx context.Context, workspaceID, op string) (wsURL, secret string, ok bool) {
|
||||
var deliveryMode sql.NullString
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COALESCE(url, ''), delivery_mode FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&wsURL, &deliveryMode); err != nil {
|
||||
log.Printf("chat_files %s: workspace lookup failed for %s: %v", op, workspaceID, err)
|
||||
@@ -468,7 +468,7 @@ func (h *ChatFilesHandler) streamWorkspaceResponse(
|
||||
// the workspace-side row IS the source of truth for the mode).
|
||||
func lookupUploadDeliveryMode(c *gin.Context, ctx context.Context, workspaceID string) (string, bool) {
|
||||
var mode sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT delivery_mode FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&mode)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
@@ -656,7 +656,7 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w
|
||||
// Commit — emitting an ACTIVITY_LOGGED event for a row that ends up
|
||||
// rolled back would leak a ghost message into the canvas's
|
||||
// optimistic UI.
|
||||
tx, err := db.DB.BeginTx(ctx, nil)
|
||||
tx, err := db.GetDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
log.Printf("chat_files uploadPollMode: begin tx for %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"})
|
||||
|
||||
@@ -3,7 +3,7 @@ package handlers
|
||||
// Unit tests for chat_files.go.
|
||||
//
|
||||
// Upload (HTTP-forward, RFC #2312 PR-C): exercised against an httptest
|
||||
// mock workspace + sqlmock-backed db.DB. The platform-side handler is
|
||||
// mock workspace + sqlmock-backed db.GetDB(). The platform-side handler is
|
||||
// now a streaming proxy; assertions focus on:
|
||||
// * input validation (400 on bad workspace id)
|
||||
// * resolution failures (404 missing row, 503 missing secret/url)
|
||||
|
||||
@@ -15,7 +15,7 @@ type CheckpointsHandler struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewCheckpointsHandler wires the handler to the given database. Pass db.DB
|
||||
// NewCheckpointsHandler wires the handler to the given database. Pass db.GetDB()
|
||||
// at router-setup time; pass a sqlmock DB in tests.
|
||||
func NewCheckpointsHandler(database *sql.DB) *CheckpointsHandler {
|
||||
return &CheckpointsHandler{db: database}
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
func newCheckpointsHandler(t *testing.T, mock sqlmock.Sqlmock) *CheckpointsHandler {
|
||||
t.Helper()
|
||||
_ = mock // surfaced for callers that need to set expectations
|
||||
return NewCheckpointsHandler(db.DB)
|
||||
return NewCheckpointsHandler(db.GetDB())
|
||||
}
|
||||
|
||||
// ---------- Upsert ----------
|
||||
|
||||
@@ -20,7 +20,7 @@ func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
|
||||
var data []byte
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT data FROM workspace_config WHERE workspace_id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&data)
|
||||
@@ -58,7 +58,7 @@ func (h *ConfigHandler) Patch(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = db.DB.ExecContext(c.Request.Context(), `
|
||||
_, err = db.GetDB().ExecContext(c.Request.Context(), `
|
||||
INSERT INTO workspace_config(workspace_id, data, updated_at)
|
||||
VALUES($1, $2::jsonb, NOW())
|
||||
ON CONFLICT(workspace_id) DO UPDATE
|
||||
|
||||
@@ -31,7 +31,7 @@ func (h *TemplatesHandler) findContainer(ctx context.Context, workspaceID string
|
||||
}
|
||||
// Also check by workspace name from DB
|
||||
var wsName string
|
||||
db.DB.QueryRowContext(ctx, `SELECT LOWER(REPLACE(name, ' ', '-')) FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT LOWER(REPLACE(name, ' ', '-')) FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
if wsName != "" {
|
||||
candidates = append(candidates, wsName)
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func pushDelegationResultToInbox(ctx context.Context, sourceID, delegationID, st
|
||||
if status == "failed" {
|
||||
summary = "Delegation failed"
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (
|
||||
workspace_id, activity_type, method, source_id,
|
||||
summary, request_body, response_body, status, error_detail
|
||||
@@ -207,7 +207,7 @@ func lookupIdempotentDelegation(ctx context.Context, c *gin.Context, sourceID, i
|
||||
return false
|
||||
}
|
||||
var existingID, existingStatus, existingTarget string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT request_body->>'delegation_id', status, target_id
|
||||
FROM activity_logs
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2
|
||||
@@ -217,7 +217,7 @@ func lookupIdempotentDelegation(ctx context.Context, c *gin.Context, sourceID, i
|
||||
return false
|
||||
}
|
||||
if existingStatus == "failed" {
|
||||
_, _ = db.DB.ExecContext(ctx, `
|
||||
_, _ = db.GetDB().ExecContext(ctx, `
|
||||
DELETE FROM activity_logs
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2 AND status = 'failed'
|
||||
`, sourceID, idempotencyKey)
|
||||
@@ -262,14 +262,20 @@ func insertDelegationRow(ctx context.Context, c *gin.Context, sourceID string, b
|
||||
"task": body.Task,
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
// Store delegation_id in response_body so agent check_delegation_status
|
||||
// (which reads response_body->>delegation_id) can locate this row even
|
||||
// when request_body hasn't propagated yet. Fixes mc#984.
|
||||
respJSON, _ := json.Marshal(map[string]interface{}{
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
var idemArg interface{}
|
||||
if body.IdempotencyKey != "" {
|
||||
idemArg = body.IdempotencyKey
|
||||
}
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, request_body, status, idempotency_key)
|
||||
VALUES ($1, 'delegation', 'delegate', $2, $3, $4, $5::jsonb, 'pending', $6)
|
||||
`, sourceID, sourceID, body.TargetID, "Delegating to "+body.TargetID, string(taskJSON), idemArg)
|
||||
_, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, request_body, response_body, status, idempotency_key)
|
||||
VALUES ($1, 'delegation', 'delegate', $2, $3, $4, $5::jsonb, $6::jsonb, 'pending', $7)
|
||||
`, sourceID, sourceID, body.TargetID, "Delegating to "+body.TargetID, string(taskJSON), string(respJSON), idemArg)
|
||||
if err == nil {
|
||||
// RFC #2829 #318 — mirror to the durable delegations ledger
|
||||
// (gated by DELEGATION_LEDGER_WRITE; default off → no-op).
|
||||
@@ -281,7 +287,7 @@ func insertDelegationRow(ctx context.Context, c *gin.Context, sourceID string, b
|
||||
// rather than a generic 500. Re-query to fetch the winner's id.
|
||||
if body.IdempotencyKey != "" {
|
||||
var winnerID, winnerStatus string
|
||||
if qerr := db.DB.QueryRowContext(ctx, `
|
||||
if qerr := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT request_body->>'delegation_id', status
|
||||
FROM activity_logs
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2
|
||||
@@ -377,7 +383,7 @@ func (h *DelegationHandler) executeDelegation(ctx context.Context, sourceID, tar
|
||||
log.Printf("Delegation %s: failed — %s", delegationID, proxyErr.Error())
|
||||
h.updateDelegationStatus(ctx, sourceID, delegationID, "failed", proxyErr.Error())
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, status, error_detail)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, 'failed', $5)
|
||||
`, sourceID, sourceID, targetID, "Delegation failed", proxyErr.Error()); err != nil {
|
||||
@@ -397,7 +403,7 @@ func (h *DelegationHandler) executeDelegation(ctx context.Context, sourceID, tar
|
||||
log.Printf("Delegation %s: step=handling_failure err=%s", delegationID, errMsg)
|
||||
h.updateDelegationStatus(ctx, sourceID, delegationID, "failed", errMsg)
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, status, error_detail)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, 'failed', $5)
|
||||
`, sourceID, sourceID, targetID, "Delegation failed", errMsg); err != nil {
|
||||
@@ -436,7 +442,7 @@ handleSuccess:
|
||||
"delegation_id": delegationID,
|
||||
"queued": true,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'queued')
|
||||
`, sourceID, sourceID, targetID, "Delegation queued — target at capacity", string(queuedJSON)); err != nil {
|
||||
@@ -459,7 +465,7 @@ handleSuccess:
|
||||
"text": responseText,
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'completed')
|
||||
`, sourceID, sourceID, targetID, "Delegation completed ("+textutil.TruncateBytes(responseText, 80)+")", string(respJSON)); err != nil {
|
||||
@@ -491,7 +497,7 @@ handleSuccess:
|
||||
// updateDelegationStatus updates the status of a delegation record in activity_logs.
|
||||
// ctx is used for DB operations; caller controls the timeout/retry budget.
|
||||
func (h *DelegationHandler) updateDelegationStatus(ctx context.Context, workspaceID, delegationID, status, errorDetail string) {
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE activity_logs
|
||||
SET status = $1, error_detail = CASE WHEN $2 = '' THEN error_detail ELSE $2 END
|
||||
WHERE workspace_id = $3
|
||||
@@ -544,10 +550,15 @@ func (h *DelegationHandler) Record(c *gin.Context) {
|
||||
"task": body.Task,
|
||||
"delegation_id": body.DelegationID,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, request_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate', $2, $3, $4, $5::jsonb, 'dispatched')
|
||||
`, sourceID, sourceID, body.TargetID, "Delegating to "+body.TargetID, string(taskJSON)); err != nil {
|
||||
// Store delegation_id in response_body so agent check_delegation_status
|
||||
// can locate this row. Fixes mc#984.
|
||||
respJSON, _ := json.Marshal(map[string]interface{}{
|
||||
"delegation_id": body.DelegationID,
|
||||
})
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, request_body, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate', $2, $3, $4, $5::jsonb, $6::jsonb, 'dispatched')
|
||||
`, sourceID, sourceID, body.TargetID, "Delegating to "+body.TargetID, string(taskJSON), string(respJSON)); err != nil {
|
||||
log.Printf("Delegation Record: insert failed for %s: %v", body.DelegationID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to record delegation"})
|
||||
return
|
||||
@@ -611,7 +622,7 @@ func (h *DelegationHandler) UpdateStatus(c *gin.Context) {
|
||||
"text": body.ResponsePreview,
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4::jsonb, 'completed')
|
||||
`, sourceID, sourceID, "Delegation completed ("+textutil.TruncateBytes(body.ResponsePreview, 80)+")", string(respJSON)); err != nil {
|
||||
@@ -669,7 +680,7 @@ func (h *DelegationHandler) ListDelegations(c *gin.Context) {
|
||||
// listDelegationsFromLedger queries the durable delegations table.
|
||||
// Returns nil on error so the caller can fall back to activity_logs.
|
||||
func (h *DelegationHandler) listDelegationsFromLedger(ctx context.Context, workspaceID string) []map[string]interface{} {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT d.delegation_id, d.caller_id, d.callee_id, d.task_preview,
|
||||
d.status, d.result_preview, d.error_detail, d.last_heartbeat,
|
||||
d.deadline, d.created_at, d.updated_at
|
||||
@@ -735,7 +746,7 @@ func (h *DelegationHandler) listDelegationsFromLedger(ctx context.Context, works
|
||||
// Kept for backward compatibility and for workspaces that never had
|
||||
// DELEGATION_LEDGER_WRITE=1 during their delegation lifecycle.
|
||||
func (h *DelegationHandler) listDelegationsFromActivityLogs(ctx context.Context, workspaceID string) []map[string]interface{} {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, activity_type, COALESCE(source_id::text, ''), COALESCE(target_id::text, ''),
|
||||
COALESCE(summary, ''), COALESCE(status, ''), COALESCE(error_detail, ''),
|
||||
COALESCE(response_body->>'text', response_body::text, ''),
|
||||
|
||||
@@ -46,7 +46,7 @@ type DelegationLedger struct {
|
||||
// Tests can construct one with a sqlmock-backed *sql.DB.
|
||||
func NewDelegationLedger(handle *sql.DB) *DelegationLedger {
|
||||
if handle == nil {
|
||||
handle = db.DB
|
||||
handle = db.GetDB()
|
||||
}
|
||||
return &DelegationLedger{db: handle}
|
||||
}
|
||||
|
||||
@@ -78,11 +78,17 @@ func integrationDB(t *testing.T) *sql.DB {
|
||||
t.Fatalf("cleanup: %v", err)
|
||||
}
|
||||
// Wire the package-level db.DB so production helpers (recordLedgerInsert,
|
||||
// recordLedgerStatus) see the same connection.
|
||||
// recordLedgerStatus) see the same connection. Guard the swap with mdb.Lock()
|
||||
// to prevent races with production goroutines that call GetDB() (which
|
||||
// acquires RLock) while t.Cleanup runs concurrently.
|
||||
prev := mdb.DB
|
||||
mdb.Lock()
|
||||
mdb.DB = conn
|
||||
mdb.Unlock()
|
||||
t.Cleanup(func() {
|
||||
mdb.Lock()
|
||||
mdb.DB = prev
|
||||
mdb.Unlock()
|
||||
conn.Close()
|
||||
})
|
||||
return conn
|
||||
|
||||
@@ -28,7 +28,7 @@ import (
|
||||
|
||||
func TestLedgerInsert_HappyPath(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
l := NewDelegationLedger(nil) // uses package db.DB which sqlmock replaced
|
||||
l := NewDelegationLedger(nil) // uses package db.GetDB() which sqlmock replaced
|
||||
|
||||
mock.ExpectExec(`INSERT INTO delegations`).
|
||||
WithArgs(
|
||||
|
||||
@@ -0,0 +1,447 @@
|
||||
package handlers
|
||||
|
||||
// delegation_list_test.go — unit tests for listDelegationsFromLedger and
|
||||
// listDelegationsFromActivityLogs. Both methods are the data-backend of the
|
||||
// ListDelegations handler; coverage was missing (cf. infra-sre review of PR #942).
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
)
|
||||
|
||||
// ---------- listDelegationsFromLedger ----------
|
||||
|
||||
func TestListDelegationsFromLedger_EmptyResult(t *testing.T) {
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"delegation_id", "caller_id", "callee_id", "task_preview",
|
||||
"status", "result_preview", "error_detail",
|
||||
"last_heartbeat", "deadline", "created_at", "updated_at",
|
||||
})
|
||||
mock.ExpectQuery("SELECT .+ FROM delegations").
|
||||
WithArgs("ws-1").
|
||||
WillReturnRows(rows)
|
||||
|
||||
broadcaster := newTestBroadcaster()
|
||||
wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
dh := NewDelegationHandler(wh, broadcaster)
|
||||
|
||||
got := dh.listDelegationsFromLedger(context.Background(), "ws-1")
|
||||
if got != nil {
|
||||
t.Errorf("empty result: expected nil, got %v", got)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDelegationsFromLedger_SingleRow(t *testing.T) {
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
now := time.Now()
|
||||
// Use time.Time{} for nullable *time.Time columns — sqlmock passes the
|
||||
// zero value to the handler's scan destination. The handler checks Valid
|
||||
// before using each nullable field, so zero values are safe.
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"delegation_id", "caller_id", "callee_id", "task_preview",
|
||||
"status", "result_preview", "error_detail",
|
||||
"last_heartbeat", "deadline", "created_at", "updated_at",
|
||||
}).AddRow(
|
||||
"del-1", "ws-1", "ws-2", "summarise the report",
|
||||
"completed", "the report is about Q1",
|
||||
"", now, now, now, now,
|
||||
)
|
||||
mock.ExpectQuery("SELECT .+ FROM delegations").
|
||||
WithArgs("ws-1").
|
||||
WillReturnRows(rows)
|
||||
|
||||
broadcaster := newTestBroadcaster()
|
||||
wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
dh := NewDelegationHandler(wh, broadcaster)
|
||||
|
||||
got := dh.listDelegationsFromLedger(context.Background(), "ws-1")
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 entry, got %d", len(got))
|
||||
}
|
||||
e := got[0]
|
||||
if e["delegation_id"] != "del-1" {
|
||||
t.Errorf("delegation_id: got %v, want del-1", e["delegation_id"])
|
||||
}
|
||||
if e["source_id"] != "ws-1" {
|
||||
t.Errorf("source_id: got %v, want ws-1", e["source_id"])
|
||||
}
|
||||
if e["target_id"] != "ws-2" {
|
||||
t.Errorf("target_id: got %v, want ws-2", e["target_id"])
|
||||
}
|
||||
if e["status"] != "completed" {
|
||||
t.Errorf("status: got %v, want completed", e["status"])
|
||||
}
|
||||
if e["response_preview"] != "the report is about Q1" {
|
||||
t.Errorf("response_preview: got %v", e["response_preview"])
|
||||
}
|
||||
if _, ok := e["error"]; ok {
|
||||
t.Errorf("error should be absent when empty, got %v", e["error"])
|
||||
}
|
||||
if e["_ledger"] != true {
|
||||
t.Errorf("_ledger marker: got %v, want true", e["_ledger"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDelegationsFromLedger_MultipleRows(t *testing.T) {
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
now := time.Now()
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"delegation_id", "caller_id", "callee_id", "task_preview",
|
||||
"status", "result_preview", "error_detail",
|
||||
"last_heartbeat", "deadline", "created_at", "updated_at",
|
||||
}).
|
||||
AddRow("del-a", "ws-1", "ws-2", "task a", "in_progress", "", "", now, now, now, now).
|
||||
AddRow("del-b", "ws-1", "ws-3", "task b", "failed", "", "timeout", now, now, now, now).
|
||||
AddRow("del-c", "ws-1", "ws-4", "task c", "completed", "result c", "", now, now, now, now)
|
||||
mock.ExpectQuery("SELECT .+ FROM delegations").
|
||||
WithArgs("ws-1").
|
||||
WillReturnRows(rows)
|
||||
|
||||
broadcaster := newTestBroadcaster()
|
||||
wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
dh := NewDelegationHandler(wh, broadcaster)
|
||||
|
||||
got := dh.listDelegationsFromLedger(context.Background(), "ws-1")
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("expected 3 entries, got %d", len(got))
|
||||
}
|
||||
if got[0]["delegation_id"] != "del-a" || got[1]["delegation_id"] != "del-b" || got[2]["delegation_id"] != "del-c" {
|
||||
t.Errorf("unexpected order: %v", got)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDelegationsFromLedger_QueryError(t *testing.T) {
|
||||
// Query failure returns nil — graceful fallback, no panic.
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
mock.ExpectQuery("SELECT .+ FROM delegations").
|
||||
WithArgs("ws-1").
|
||||
WillReturnError(context.DeadlineExceeded)
|
||||
|
||||
broadcaster := newTestBroadcaster()
|
||||
wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
dh := NewDelegationHandler(wh, broadcaster)
|
||||
|
||||
got := dh.listDelegationsFromLedger(context.Background(), "ws-1")
|
||||
if got != nil {
|
||||
t.Errorf("query error: expected nil, got %v", got)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDelegationsFromLedger_RowsErr(t *testing.T) {
|
||||
// rows.Err() mid-stream: handler collects partial results and returns them.
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
now := time.Now()
|
||||
// RowError(0) before AddRow(0): row 0 is "bad", rows.Next() returns false
|
||||
// on first call — the row never scans, result stays nil. To get partial
|
||||
// results (row 0 scanned) with rows.Err() non-nil, we use 2 rows and put
|
||||
// RowError(1) after AddRow(1): row 0 scans normally, row 1 is bad,
|
||||
// rows.Err() is error, handler returns partial result.
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"delegation_id", "caller_id", "callee_id", "task_preview",
|
||||
"status", "result_preview", "error_detail",
|
||||
"last_heartbeat", "deadline", "created_at", "updated_at",
|
||||
}).
|
||||
AddRow("del-1", "ws-1", "ws-2", "task", "queued", "", "", now, now, now, now).
|
||||
AddRow("del-2", "ws-1", "ws-3", "another task", "queued", "", "", now, now, now, now).
|
||||
RowError(1, context.DeadlineExceeded)
|
||||
mock.ExpectQuery("SELECT .+ FROM delegations").
|
||||
WithArgs("ws-1").
|
||||
WillReturnRows(rows)
|
||||
|
||||
broadcaster := newTestBroadcaster()
|
||||
wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
dh := NewDelegationHandler(wh, broadcaster)
|
||||
|
||||
got := dh.listDelegationsFromLedger(context.Background(), "ws-1")
|
||||
// Row 0 scanned and appended; row 1 is bad; rows.Err() is non-nil.
|
||||
// Handler logs the error but returns result (partial results because result != nil).
|
||||
if got == nil || len(got) != 1 {
|
||||
t.Errorf("rows.Err path: expected 1 partial result, got %v", got)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestListDelegationsFromLedger_ScanError is removed.
|
||||
//
|
||||
// In Go 1.25 sqlmock.NewRows validates column count at AddRow() time and
|
||||
// panics when len(values) != len(columns). The old pattern
|
||||
// sqlmock.NewRows([]string{}).AddRow("only-one-col")
|
||||
// therefore panics in test SETUP, not inside the handler. The handler has no
|
||||
// recover(), so a scan panic would propagate out of listDelegationsFromLedger
|
||||
// and crash the process — this is the correct behaviour (not silently skipping
|
||||
// a row). The correct way to cover this path is a real-DB integration test.
|
||||
//
|
||||
// ---------- listDelegationsFromActivityLogs ----------
|
||||
|
||||
func TestListDelegationsFromActivityLogs_EmptyResult(t *testing.T) {
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "activity_type", "source_id", "target_id",
|
||||
"summary", "status", "error_detail",
|
||||
"response_preview", "delegation_id", "created_at",
|
||||
})
|
||||
mock.ExpectQuery("SELECT .+ FROM activity_logs").
|
||||
WithArgs("ws-1").
|
||||
WillReturnRows(rows)
|
||||
|
||||
broadcaster := newTestBroadcaster()
|
||||
wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
dh := NewDelegationHandler(wh, broadcaster)
|
||||
|
||||
got := dh.listDelegationsFromActivityLogs(context.Background(), "ws-1")
|
||||
if len(got) != 0 {
|
||||
t.Errorf("empty result: expected empty slice, got %v", got)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDelegationsFromActivityLogs_SingleDelegateRow(t *testing.T) {
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
now := time.Now()
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "activity_type", "source_id", "target_id",
|
||||
"summary", "status", "error_detail",
|
||||
"response_preview", "delegation_id", "created_at",
|
||||
}).AddRow(
|
||||
"act-1", "delegate",
|
||||
"ws-1", "ws-2",
|
||||
"analyse Q1 numbers",
|
||||
"in_progress",
|
||||
"", "", "",
|
||||
now,
|
||||
)
|
||||
mock.ExpectQuery("SELECT .+ FROM activity_logs").
|
||||
WithArgs("ws-1").
|
||||
WillReturnRows(rows)
|
||||
|
||||
broadcaster := newTestBroadcaster()
|
||||
wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
dh := NewDelegationHandler(wh, broadcaster)
|
||||
|
||||
got := dh.listDelegationsFromActivityLogs(context.Background(), "ws-1")
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 entry, got %d", len(got))
|
||||
}
|
||||
e := got[0]
|
||||
if e["id"] != "act-1" {
|
||||
t.Errorf("id: got %v, want act-1", e["id"])
|
||||
}
|
||||
if e["type"] != "delegate" {
|
||||
t.Errorf("type: got %v, want delegate", e["type"])
|
||||
}
|
||||
if e["source_id"] != "ws-1" {
|
||||
t.Errorf("source_id: got %v, want ws-1", e["source_id"])
|
||||
}
|
||||
if e["target_id"] != "ws-2" {
|
||||
t.Errorf("target_id: got %v, want ws-2", e["target_id"])
|
||||
}
|
||||
if e["summary"] != "analyse Q1 numbers" {
|
||||
t.Errorf("summary: got %v", e["summary"])
|
||||
}
|
||||
if e["status"] != "in_progress" {
|
||||
t.Errorf("status: got %v", e["status"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDelegationsFromActivityLogs_DelegateResultWithError(t *testing.T) {
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
now := time.Now()
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "activity_type", "source_id", "target_id",
|
||||
"summary", "status", "error_detail",
|
||||
"response_preview", "delegation_id", "created_at",
|
||||
}).AddRow(
|
||||
"act-2", "delegate_result",
|
||||
"ws-1", "ws-2",
|
||||
"result summary",
|
||||
"failed",
|
||||
"Callee workspace not reachable",
|
||||
`{"text":"the result body text"}`,
|
||||
"del-abc",
|
||||
now,
|
||||
)
|
||||
mock.ExpectQuery("SELECT .+ FROM activity_logs").
|
||||
WithArgs("ws-1").
|
||||
WillReturnRows(rows)
|
||||
|
||||
broadcaster := newTestBroadcaster()
|
||||
wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
dh := NewDelegationHandler(wh, broadcaster)
|
||||
|
||||
got := dh.listDelegationsFromActivityLogs(context.Background(), "ws-1")
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 entry, got %d", len(got))
|
||||
}
|
||||
e := got[0]
|
||||
if e["type"] != "delegate_result" {
|
||||
t.Errorf("type: got %v", e["type"])
|
||||
}
|
||||
if e["error"] != "Callee workspace not reachable" {
|
||||
t.Errorf("error: got %v", e["error"])
|
||||
}
|
||||
if e["response_preview"] != `{"text":"the result body text"}` {
|
||||
t.Errorf("response_preview: got %v", e["response_preview"])
|
||||
}
|
||||
if e["delegation_id"] != "del-abc" {
|
||||
t.Errorf("delegation_id: got %v", e["delegation_id"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDelegationsFromActivityLogs_QueryError(t *testing.T) {
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
mock.ExpectQuery("SELECT .+ FROM activity_logs").
|
||||
WithArgs("ws-1").
|
||||
WillReturnError(context.DeadlineExceeded)
|
||||
|
||||
broadcaster := newTestBroadcaster()
|
||||
wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
dh := NewDelegationHandler(wh, broadcaster)
|
||||
|
||||
got := dh.listDelegationsFromActivityLogs(context.Background(), "ws-1")
|
||||
// Error → returns empty slice, not nil.
|
||||
if len(got) != 0 {
|
||||
t.Errorf("query error: expected empty slice, got %v", got)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDelegationsFromActivityLogs_RowsErr(t *testing.T) {
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
now := time.Now()
|
||||
// RowError(0) before AddRow(0): row 0 is "bad", rows.Next() returns false
|
||||
// on first call — the row never scans, result stays nil. To get partial
|
||||
// results (row 0 scanned) with rows.Err() non-nil, we use 2 rows and put
|
||||
// RowError(1) after AddRow(1): row 0 scans normally, row 1 is bad,
|
||||
// rows.Err() is error, handler returns partial result.
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "activity_type", "source_id", "target_id",
|
||||
"summary", "status", "error_detail",
|
||||
"response_preview", "delegation_id", "created_at",
|
||||
}).
|
||||
AddRow("act-1", "delegate", "ws-1", "ws-2", "task", "queued", "", "", "", now).
|
||||
AddRow("act-2", "delegate", "ws-1", "ws-3", "another task", "queued", "", "", "", now).
|
||||
RowError(1, context.DeadlineExceeded)
|
||||
mock.ExpectQuery("SELECT .+ FROM activity_logs").
|
||||
WithArgs("ws-1").
|
||||
WillReturnRows(rows)
|
||||
|
||||
broadcaster := newTestBroadcaster()
|
||||
wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
dh := NewDelegationHandler(wh, broadcaster)
|
||||
|
||||
got := dh.listDelegationsFromActivityLogs(context.Background(), "ws-1")
|
||||
// Row 0 scanned and appended; row 1 is bad; rows.Err() is non-nil.
|
||||
// Handler logs the error but returns result (partial results because result != nil).
|
||||
if got == nil || len(got) != 1 {
|
||||
t.Errorf("rows.Err path: expected 1 partial result, got %v", got)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestListDelegationsFromActivityLogs_ScanErrorSkipped is removed.
|
||||
//
|
||||
// Same reason as TestListDelegationsFromLedger_ScanError: Go 1.25 causes
|
||||
// sqlmock.NewRows([]string{}).AddRow(...) to panic in test SETUP. The handler
|
||||
// has no recover(), so a scan panic would crash the process — the correct
|
||||
// behaviour. Real-DB integration tests cover this path.
|
||||
@@ -80,13 +80,13 @@ type DelegationSweeper struct {
|
||||
threshold time.Duration
|
||||
}
|
||||
|
||||
// NewDelegationSweeper builds a sweeper bound to the package db.DB
|
||||
// NewDelegationSweeper builds a sweeper bound to the package db.GetDB()
|
||||
// (production wiring) or a test handle. Reads optional env overrides
|
||||
// at construction time so a long-running process picks them up via
|
||||
// restart, not mid-flight.
|
||||
func NewDelegationSweeper(handle *sql.DB, ledger *DelegationLedger) *DelegationSweeper {
|
||||
if handle == nil {
|
||||
handle = db.DB
|
||||
handle = db.GetDB()
|
||||
}
|
||||
if ledger == nil {
|
||||
ledger = NewDelegationLedger(handle)
|
||||
|
||||
@@ -133,9 +133,9 @@ func TestDelegate_Success(t *testing.T) {
|
||||
targetID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
|
||||
// Expect INSERT into activity_logs for delegation tracking
|
||||
// (6th arg is idempotency_key — nil here since the request omits it)
|
||||
// (6th arg is response_body, 7th is idempotency_key — nil here since the request omits it)
|
||||
mock.ExpectExec("INSERT INTO activity_logs").
|
||||
WithArgs("ws-source", "ws-source", targetID, "Delegating to "+targetID, sqlmock.AnyArg(), nil).
|
||||
WithArgs("ws-source", "ws-source", targetID, "Delegating to "+targetID, sqlmock.AnyArg(), sqlmock.AnyArg(), nil).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// Expect RecordAndBroadcast INSERT into structure_events
|
||||
@@ -189,9 +189,9 @@ func TestDelegate_DBInsertFails_Still202WithWarning(t *testing.T) {
|
||||
|
||||
targetID := "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
|
||||
// DB insert fails (6th arg = idempotency_key, nil for this test)
|
||||
// DB insert fails (6th arg = response_body, 7th = idempotency_key, nil for this test)
|
||||
mock.ExpectExec("INSERT INTO activity_logs").
|
||||
WithArgs("ws-source", "ws-source", targetID, "Delegating to "+targetID, sqlmock.AnyArg(), nil).
|
||||
WithArgs("ws-source", "ws-source", targetID, "Delegating to "+targetID, sqlmock.AnyArg(), sqlmock.AnyArg(), nil).
|
||||
WillReturnError(fmt.Errorf("database connection lost"))
|
||||
|
||||
// RecordAndBroadcast still fires
|
||||
@@ -491,6 +491,7 @@ func TestDelegationRecord_InsertsActivityLogRow(t *testing.T) {
|
||||
"550e8400-e29b-41d4-a716-446655440001", // target_id
|
||||
"Delegating to 550e8400-e29b-41d4-a716-446655440001", // summary
|
||||
sqlmock.AnyArg(), // request_body (jsonb)
|
||||
sqlmock.AnyArg(), // response_body (jsonb) — mc#984 fix
|
||||
).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// RecordAndBroadcast INSERT for DELEGATION_SENT
|
||||
@@ -699,9 +700,9 @@ func TestDelegate_IdempotentFailedRowIsReleasedAndReplaced(t *testing.T) {
|
||||
mock.ExpectExec("DELETE FROM activity_logs").
|
||||
WithArgs("ws-source", "retry-key").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// Fresh insert with the same idempotency key.
|
||||
// Fresh insert with the same idempotency key (response_body added as mc#984 fix).
|
||||
mock.ExpectExec("INSERT INTO activity_logs").
|
||||
WithArgs("ws-source", "ws-source", targetID, "Delegating to "+targetID, sqlmock.AnyArg(), "retry-key").
|
||||
WithArgs("ws-source", "ws-source", targetID, "Delegating to "+targetID, sqlmock.AnyArg(), sqlmock.AnyArg(), "retry-key").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec("INSERT INTO structure_events").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
@@ -745,9 +746,9 @@ func TestDelegate_IdempotentRaceUniqueViolationReturnsExisting(t *testing.T) {
|
||||
mock.ExpectQuery("SELECT request_body->>'delegation_id', status, target_id").
|
||||
WithArgs("ws-source", "race-key").
|
||||
WillReturnError(fmt.Errorf("sql: no rows in result set"))
|
||||
// Insert loses the race against a concurrent caller.
|
||||
// Insert loses the race against a concurrent caller (response_body added as mc#984 fix).
|
||||
mock.ExpectExec("INSERT INTO activity_logs").
|
||||
WithArgs("ws-source", "ws-source", targetID, "Delegating to "+targetID, sqlmock.AnyArg(), "race-key").
|
||||
WithArgs("ws-source", "ws-source", targetID, "Delegating to "+targetID, sqlmock.AnyArg(), sqlmock.AnyArg(), "race-key").
|
||||
WillReturnError(fmt.Errorf("pq: duplicate key value violates unique constraint \"activity_logs_idempotency_uniq\""))
|
||||
// Re-query returns the winner.
|
||||
mock.ExpectQuery("SELECT request_body->>'delegation_id', status").
|
||||
|
||||
@@ -73,7 +73,7 @@ func discoverHostPeer(ctx context.Context, c *gin.Context, targetID string) {
|
||||
var url sql.NullString
|
||||
var status string
|
||||
var forwardedTo sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT url, status, forwarded_to FROM workspaces WHERE id = $1`, targetID,
|
||||
).Scan(&url, &status, &forwardedTo)
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -89,7 +89,7 @@ func discoverHostPeer(ctx context.Context, c *gin.Context, targetID string) {
|
||||
resolvedID := targetID
|
||||
for i := 0; i < 5 && forwardedTo.Valid && forwardedTo.String != ""; i++ {
|
||||
resolvedID = forwardedTo.String
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT url, status, forwarded_to FROM workspaces WHERE id = $1`, resolvedID,
|
||||
).Scan(&url, &status, &forwardedTo)
|
||||
if err != nil {
|
||||
@@ -128,7 +128,7 @@ func discoverHostPeer(ctx context.Context, c *gin.Context, targetID string) {
|
||||
// of `callerID` and writes the JSON response (or an appropriate 404/503 error).
|
||||
func discoverWorkspacePeer(ctx context.Context, c *gin.Context, callerID, targetID string) {
|
||||
var wsName, wsRuntime string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(name,''), COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, targetID).Scan(&wsName, &wsRuntime)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(name,''), COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, targetID).Scan(&wsName, &wsRuntime)
|
||||
|
||||
// External workspaces: return their registered URL.
|
||||
// Rewrite 127.0.0.1/localhost → host.docker.internal ONLY when the
|
||||
@@ -149,7 +149,7 @@ func discoverWorkspacePeer(ctx context.Context, c *gin.Context, callerID, target
|
||||
}
|
||||
// Fallback: only synthesize a URL if the workspace exists and is online/degraded
|
||||
var wsStatus string
|
||||
dbErr := db.DB.QueryRowContext(ctx,
|
||||
dbErr := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status FROM workspaces WHERE id = $1`, targetID,
|
||||
).Scan(&wsStatus)
|
||||
if dbErr == nil && (wsStatus == "online" || wsStatus == "degraded") {
|
||||
@@ -174,13 +174,13 @@ func discoverWorkspacePeer(ctx context.Context, c *gin.Context, callerID, target
|
||||
// file, leaving the caller to fall through to the internal-URL path.
|
||||
func writeExternalWorkspaceURL(ctx context.Context, c *gin.Context, callerID, targetID, wsName string) bool {
|
||||
var wsURL string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(url,'') FROM workspaces WHERE id = $1`, targetID).Scan(&wsURL)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(url,'') FROM workspaces WHERE id = $1`, targetID).Scan(&wsURL)
|
||||
if wsURL == "" {
|
||||
return false
|
||||
}
|
||||
outURL := wsURL
|
||||
var callerRuntime string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, callerID).Scan(&callerRuntime)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, callerID).Scan(&callerRuntime)
|
||||
if !isExternalLikeRuntime(callerRuntime) {
|
||||
outURL = strings.Replace(outURL, "127.0.0.1", "host.docker.internal", 1)
|
||||
outURL = strings.Replace(outURL, "localhost", "host.docker.internal", 1)
|
||||
@@ -224,7 +224,7 @@ func (h *DiscoveryHandler) Peers(c *gin.Context) {
|
||||
}
|
||||
|
||||
var parentID sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).
|
||||
err := db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).
|
||||
Scan(&parentID)
|
||||
if err == sql.ErrNoRows {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
@@ -304,7 +304,7 @@ func filterPeersByQuery(peers []map[string]interface{}, q string) []map[string]i
|
||||
|
||||
// queryPeerMaps returns clean JSON-serializable maps instead of Workspace structs.
|
||||
func queryPeerMaps(query string, args ...interface{}) ([]map[string]interface{}, error) {
|
||||
rows, err := db.DB.Query(query, args...)
|
||||
rows, err := db.GetDB().Query(query, args...)
|
||||
if err != nil {
|
||||
log.Printf("queryPeerMaps error: %v", err)
|
||||
return nil, err
|
||||
@@ -377,7 +377,7 @@ func (h *DiscoveryHandler) CheckAccess(c *gin.Context) {
|
||||
// are already behind the existing `CanCommunicate` hierarchy check — a
|
||||
// momentary DB outage shouldn't take agent-to-agent discovery offline.
|
||||
func validateDiscoveryCaller(ctx context.Context, c *gin.Context, workspaceID string) error {
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("wsauth: discovery HasAnyLiveToken(%s) failed: %v — allowing request", workspaceID, err)
|
||||
return nil
|
||||
@@ -427,7 +427,7 @@ func validateDiscoveryCaller(ctx context.Context, c *gin.Context, workspaceID st
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return errors.New("missing token")
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, tok); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func NewEventsHandler() *EventsHandler {
|
||||
|
||||
// List handles GET /events
|
||||
func (h *EventsHandler) List(c *gin.Context) {
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), `
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), `
|
||||
SELECT id, event_type, workspace_id, payload, created_at
|
||||
FROM structure_events
|
||||
ORDER BY created_at DESC
|
||||
@@ -56,7 +56,7 @@ func (h *EventsHandler) List(c *gin.Context) {
|
||||
func (h *EventsHandler) ListByWorkspace(c *gin.Context) {
|
||||
workspaceID := c.Param("workspaceId")
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), `
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), `
|
||||
SELECT id, event_type, workspace_id, payload, created_at
|
||||
FROM structure_events
|
||||
WHERE workspace_id = $1
|
||||
|
||||
@@ -52,7 +52,7 @@ func (h *WorkspaceHandler) RotateExternalCredentials(c *gin.Context) {
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.DB, id)
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.GetDB(), id)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
return
|
||||
@@ -85,12 +85,12 @@ func (h *WorkspaceHandler) RotateExternalCredentials(c *gin.Context) {
|
||||
// that's better than the inverse where mint succeeds + revoke fails
|
||||
// and TWO live tokens end up valid (the previous one + the new one),
|
||||
// silently leaving the leaked credential alive.
|
||||
if err := wsauth.RevokeAllForWorkspace(ctx, db.DB, id); err != nil {
|
||||
if err := wsauth.RevokeAllForWorkspace(ctx, db.GetDB(), id); err != nil {
|
||||
log.Printf("RotateExternalCredentials(%s): revoke failed: %v", id, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "revoke failed"})
|
||||
return
|
||||
}
|
||||
tok, err := wsauth.IssueToken(ctx, db.DB, id)
|
||||
tok, err := wsauth.IssueToken(ctx, db.GetDB(), id)
|
||||
if err != nil {
|
||||
log.Printf("RotateExternalCredentials(%s): mint failed: %v", id, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "mint failed"})
|
||||
@@ -129,7 +129,7 @@ func (h *WorkspaceHandler) GetExternalConnection(c *gin.Context) {
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.DB, id)
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.GetDB(), id)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
return
|
||||
|
||||
@@ -230,20 +230,21 @@ func TestWorkspaceList_WithData(t *testing.T) {
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
// 21 cols — see scanWorkspaceRow for order (max_concurrent_tasks
|
||||
// lands between active_tasks and last_error_rate).
|
||||
// 23 cols — broadcast_enabled + talk_to_user_enabled added after monthly_spend
|
||||
// (migration 20260514). Column order must match scanWorkspaceRow exactly.
|
||||
columns := []string{
|
||||
"id", "name", "role", "tier", "status", "agent_card", "url",
|
||||
"parent_id", "active_tasks", "max_concurrent_tasks",
|
||||
"last_error_rate", "last_sample_error",
|
||||
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
|
||||
"budget_limit", "monthly_spend",
|
||||
"broadcast_enabled", "talk_to_user_enabled",
|
||||
}
|
||||
rows := sqlmock.NewRows(columns).
|
||||
AddRow("ws-1", "Agent One", "worker", 1, "online", []byte(`{"name":"agent1"}`), "http://localhost:8001",
|
||||
nil, 3, 1, 0.02, "", 7200, "processing", "langgraph", "", 10.0, 20.0, false, nil, int64(0)).
|
||||
nil, 3, 1, 0.02, "", 7200, "processing", "langgraph", "", 10.0, 20.0, false, nil, int64(0), false, true).
|
||||
AddRow("ws-2", "Agent Two", "", 2, "degraded", []byte("null"), "",
|
||||
nil, 0, 1, 0.6, "timeout", 100, "", "claude-code", "", 50.0, 60.0, true, nil, int64(0))
|
||||
nil, 0, 1, 0.6, "timeout", 100, "", "claude-code", "", 50.0, 60.0, true, nil, int64(0), false, true)
|
||||
|
||||
mock.ExpectQuery("SELECT w.id, w.name").
|
||||
WillReturnRows(rows)
|
||||
|
||||
@@ -26,17 +26,36 @@ func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// setupTestDB creates a sqlmock DB and assigns it to the global db.DB.
|
||||
// setupTestDB creates a sqlmock DB and assigns it to the global db.GetDB().
|
||||
// It also disables the SSRF URL check so that httptest.NewServer loopback
|
||||
// URLs and fake hostnames (*.example) used in tests don't trigger rejections.
|
||||
//
|
||||
// The mutex guards the swap: setup holds Lock while reading prevDB and writing
|
||||
// mockDB; cleanup holds Lock while restoring prevDB. Concurrent goroutines
|
||||
// from test bodies call GetDB() (RLock) so they block during the swap,
|
||||
// preventing the DATA RACE between cleanup's write and LogActivity's read
|
||||
// (activity.go:590) that mc#1176 fixed.
|
||||
func setupTestDB(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
db.Lock()
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { mockDB.Close() })
|
||||
db.Unlock()
|
||||
// Restore prevDB + close mock asynchronously so that concurrent goroutines
|
||||
// spawned by this test (e.g. provisionWorkspaceAuto goroutines) finish
|
||||
// before the swap-back. All GetDB() calls in those goroutines hold
|
||||
// RLock; the Lock here blocks them during the swap-back, guaranteeing
|
||||
// they see either the mock or prevDB, never an inconsistent state.
|
||||
t.Cleanup(func() {
|
||||
db.Lock()
|
||||
db.DB = prevDB
|
||||
db.Unlock()
|
||||
mockDB.Close()
|
||||
})
|
||||
|
||||
// Disable SSRF checks for the duration of this test only. Restore
|
||||
// the previous state via t.Cleanup so that TestIsSafeURL_* tests
|
||||
@@ -366,7 +385,7 @@ func TestBuildProvisionerConfig_IncludesAwarenessSettings(t *testing.T) {
|
||||
"ws-123",
|
||||
"/tmp/configs/template",
|
||||
map[string][]byte{"config.yaml": []byte("name: test")},
|
||||
models.CreateWorkspacePayload{Tier: 2, Runtime: "claude-code"},
|
||||
models.CreateWorkspacePayload{Tier: 2, Runtime: "claude-code", WorkspaceDir: "/tmp/workspace", WorkspaceAccess: "read_write"},
|
||||
map[string]string{"OPENAI_API_KEY": "sk-test"},
|
||||
"/tmp/plugins",
|
||||
"workspace:ws-123",
|
||||
@@ -391,21 +410,21 @@ func TestWorkspaceList(t *testing.T) {
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", "/tmp/configs")
|
||||
|
||||
// 21 cols: `max_concurrent_tasks` added between active_tasks and
|
||||
// last_error_rate (see scanWorkspaceRow + COALESCE(w.max_concurrent_tasks, 1)
|
||||
// in workspace.go). Column order must match that scan exactly.
|
||||
// 23 cols: broadcast_enabled + talk_to_user_enabled added after monthly_spend
|
||||
// (migration 20260514). Column order must match scanWorkspaceRow exactly.
|
||||
columns := []string{
|
||||
"id", "name", "role", "tier", "status", "agent_card", "url",
|
||||
"parent_id", "active_tasks", "max_concurrent_tasks",
|
||||
"last_error_rate", "last_sample_error",
|
||||
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
|
||||
"budget_limit", "monthly_spend",
|
||||
"broadcast_enabled", "talk_to_user_enabled",
|
||||
}
|
||||
rows := sqlmock.NewRows(columns).
|
||||
AddRow("ws-1", "Agent One", "worker", 1, "online", []byte("null"), "http://localhost:8001",
|
||||
nil, 0, 1, 0.0, "", 100, "", "claude-code", "", 10.0, 20.0, false, nil, int64(0)).
|
||||
nil, 0, 1, 0.0, "", 100, "", "claude-code", "", 10.0, 20.0, false, nil, int64(0), false, true).
|
||||
AddRow("ws-2", "Agent Two", "manager", 2, "provisioning", []byte("null"), "",
|
||||
nil, 0, 1, 0.0, "", 0, "", "langgraph", "", 50.0, 60.0, false, nil, int64(0))
|
||||
nil, 0, 1, 0.0, "", 0, "", "langgraph", "", 50.0, 60.0, false, nil, int64(0), false, true)
|
||||
|
||||
mock.ExpectQuery("SELECT w.id, w.name").
|
||||
WillReturnRows(rows)
|
||||
@@ -1119,13 +1138,14 @@ func TestWorkspaceGet_CurrentTask(t *testing.T) {
|
||||
"parent_id", "active_tasks", "max_concurrent_tasks", "last_error_rate", "last_sample_error",
|
||||
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
|
||||
"budget_limit", "monthly_spend",
|
||||
"broadcast_enabled", "talk_to_user_enabled",
|
||||
}
|
||||
mock.ExpectQuery("SELECT w.id, w.name").
|
||||
WithArgs("dddddddd-0004-0000-0000-000000000000").
|
||||
WillReturnRows(sqlmock.NewRows(columns).AddRow(
|
||||
"dddddddd-0004-0000-0000-000000000000", "Task Worker", "worker", 1, "online", []byte("null"), "http://localhost:9000",
|
||||
nil, 2, 1, 0.0, "", 300, "Analyzing document", "langgraph", "", 10.0, 20.0, false,
|
||||
nil, int64(0),
|
||||
nil, int64(0), false, true,
|
||||
))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -55,7 +55,7 @@ func (h *InstructionsHandler) List(c *gin.Context) {
|
||||
)
|
||||
ORDER BY CASE scope WHEN 'global' THEN 0 WHEN 'workspace' THEN 2 END,
|
||||
priority DESC`
|
||||
r, qErr := db.DB.QueryContext(ctx, query, workspaceID)
|
||||
r, qErr := db.GetDB().QueryContext(ctx, query, workspaceID)
|
||||
if qErr != nil {
|
||||
log.Printf("Instructions list error: %v", qErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
@@ -76,7 +76,7 @@ func (h *InstructionsHandler) List(c *gin.Context) {
|
||||
}
|
||||
query += ` ORDER BY scope, priority DESC, created_at`
|
||||
|
||||
r, qErr := db.DB.QueryContext(ctx, query, args...)
|
||||
r, qErr := db.GetDB().QueryContext(ctx, query, args...)
|
||||
if qErr != nil {
|
||||
log.Printf("Instructions list error: %v", qErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
@@ -118,7 +118,7 @@ func (h *InstructionsHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
var id string
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`INSERT INTO platform_instructions (scope, scope_target, title, content, priority)
|
||||
VALUES ($1, $2, $3, $4, $5) RETURNING id`,
|
||||
body.Scope, body.ScopeTarget, body.Title, body.Content, body.Priority,
|
||||
@@ -154,7 +154,7 @@ func (h *InstructionsHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(c.Request.Context(),
|
||||
result, err := db.GetDB().ExecContext(c.Request.Context(),
|
||||
`UPDATE platform_instructions SET
|
||||
title = COALESCE($2, title),
|
||||
content = COALESCE($3, content),
|
||||
@@ -180,7 +180,7 @@ func (h *InstructionsHandler) Update(c *gin.Context) {
|
||||
// DELETE /instructions/:id
|
||||
func (h *InstructionsHandler) Delete(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
result, err := db.DB.ExecContext(c.Request.Context(),
|
||||
result, err := db.GetDB().ExecContext(c.Request.Context(),
|
||||
`DELETE FROM platform_instructions WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "delete failed"})
|
||||
@@ -209,7 +209,7 @@ func (h *InstructionsHandler) Resolve(c *gin.Context) {
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT scope, title, content FROM platform_instructions
|
||||
WHERE enabled = true AND (
|
||||
scope = 'global'
|
||||
@@ -248,6 +248,9 @@ func (h *InstructionsHandler) Resolve(c *gin.Context) {
|
||||
b.WriteString(content)
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("ResolveInstructions rows.Err workspace=%s: %v", workspaceID, err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"workspace_id": workspaceID,
|
||||
@@ -258,6 +261,7 @@ func (h *InstructionsHandler) Resolve(c *gin.Context) {
|
||||
func scanInstructions(rows interface {
|
||||
Next() bool
|
||||
Scan(dest ...interface{}) error
|
||||
Err() error
|
||||
}) []Instruction {
|
||||
var instructions []Instruction
|
||||
for rows.Next() {
|
||||
@@ -269,6 +273,9 @@ func scanInstructions(rows interface {
|
||||
}
|
||||
instructions = append(instructions, inst)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("scanInstructions rows.Err: %v", err)
|
||||
}
|
||||
if instructions == nil {
|
||||
instructions = []Instruction{}
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ type MCPHandler struct {
|
||||
}
|
||||
|
||||
// NewMCPHandler wires the handler to db and broadcaster.
|
||||
// Pass db.DB and the platform broadcaster at router-setup time.
|
||||
// Pass db.GetDB() and the platform broadcaster at router-setup time.
|
||||
func NewMCPHandler(database *sql.DB, broadcaster *events.Broadcaster) *MCPHandler {
|
||||
return &MCPHandler{database: database, broadcaster: broadcaster}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ import (
|
||||
func newMCPHandler(t *testing.T) (*MCPHandler, sqlmock.Sqlmock) {
|
||||
t.Helper()
|
||||
mock := setupTestDB(t)
|
||||
h := NewMCPHandler(db.DB, newTestBroadcaster())
|
||||
h := NewMCPHandler(db.GetDB(), newTestBroadcaster())
|
||||
return h, mock
|
||||
}
|
||||
|
||||
@@ -751,9 +751,9 @@ func TestMCPHandler_SendMessageToUser_DBErrorLogsAndStill200s(t *testing.T) {
|
||||
t.Setenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE", "true")
|
||||
h, mock := newMCPHandler(t)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-err").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
// INSERT fails — must NOT abort the tool response.
|
||||
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
|
||||
@@ -802,9 +802,9 @@ func TestMCPHandler_SendMessageToUser_ResponseBodyShape(t *testing.T) {
|
||||
|
||||
const userMessage = "Hi there from the agent"
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-shape").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
// Capture the response_body argument and assert its exact shape.
|
||||
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
|
||||
@@ -861,9 +861,9 @@ func TestMCPHandler_SendMessageToUser_PersistsToActivityLog(t *testing.T) {
|
||||
// before it does anything else. Returning a name lets the
|
||||
// broadcast payload populate; the test doesn't assert on the
|
||||
// broadcast (no observable WS in this fake), only on the DB.
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-msg").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
// The persistence INSERT — pin the exact shape so a future
|
||||
// refactor that switches columns or drops `method='notify'`
|
||||
|
||||
@@ -166,7 +166,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
// GLOBAL scope: only root workspaces (no parent) can write
|
||||
if body.Scope == "GLOBAL" {
|
||||
var parentID *string
|
||||
db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
if parentID != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "only root workspaces can write GLOBAL memories"})
|
||||
return
|
||||
@@ -188,7 +188,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
}
|
||||
|
||||
var memoryID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO agent_memories (workspace_id, content, scope, namespace)
|
||||
VALUES ($1, $2, $3, $4) RETURNING id
|
||||
`, workspaceID, content, body.Scope, namespace).Scan(&memoryID)
|
||||
@@ -212,7 +212,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
"content_sha256": hex.EncodeToString(sum[:]),
|
||||
})
|
||||
summary := "GLOBAL memory written: id=" + memoryID + " namespace=" + namespace
|
||||
if _, auditErr := db.DB.ExecContext(ctx, `
|
||||
if _, auditErr := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, summary, request_body, status)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb, $6)
|
||||
`, workspaceID, "memory_write_global", workspaceID, summary, string(auditBody), "ok"); auditErr != nil {
|
||||
@@ -228,7 +228,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
log.Printf("Commit: embedding failed workspace=%s memory=%s: %v (stored without embedding)",
|
||||
workspaceID, memoryID, embedErr)
|
||||
} else if fmtVec := formatVector(vec); fmtVec != "" {
|
||||
if _, updateErr := db.DB.ExecContext(ctx,
|
||||
if _, updateErr := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE agent_memories SET embedding = $1::vector WHERE id = $2`,
|
||||
fmtVec, memoryID,
|
||||
); updateErr != nil {
|
||||
@@ -278,7 +278,7 @@ func (h *MemoriesHandler) Search(c *gin.Context) {
|
||||
|
||||
// Get workspace info for access control
|
||||
var parentID *string
|
||||
db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
|
||||
// Try to generate a query embedding for semantic search.
|
||||
// Falls back to the existing FTS/ILIKE path on failure or when no
|
||||
@@ -420,7 +420,7 @@ func (h *MemoriesHandler) Search(c *gin.Context) {
|
||||
args = append(args, limit)
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, sqlQuery, args...)
|
||||
rows, err := db.GetDB().QueryContext(ctx, sqlQuery, args...)
|
||||
if err != nil {
|
||||
log.Printf("Search memories error: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "search failed"})
|
||||
@@ -542,7 +542,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
// One round-trip rather than two: SELECT ... WHERE id AND
|
||||
// workspace_id covers the 404 path without an extra existence check.
|
||||
var existingScope, existingContent, existingNamespace string
|
||||
if err := db.DB.QueryRowContext(ctx, `
|
||||
if err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT scope, content, namespace
|
||||
FROM agent_memories
|
||||
WHERE id = $1 AND workspace_id = $2
|
||||
@@ -588,7 +588,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE agent_memories
|
||||
SET content = $1, namespace = $2, updated_at = now()
|
||||
WHERE id = $3 AND workspace_id = $4
|
||||
@@ -611,7 +611,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
"reason": "edited",
|
||||
})
|
||||
summary := "GLOBAL memory edited: id=" + memoryID + " namespace=" + newNamespace
|
||||
if _, auditErr := db.DB.ExecContext(ctx, `
|
||||
if _, auditErr := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, summary, request_body, status)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb, $6)
|
||||
`, workspaceID, "memory_edit_global", workspaceID, summary, string(auditBody), "ok"); auditErr != nil {
|
||||
@@ -628,7 +628,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
log.Printf("Update: embedding failed workspace=%s memory=%s: %v (kept stale embedding)",
|
||||
workspaceID, memoryID, embedErr)
|
||||
} else if fmtVec := formatVector(vec); fmtVec != "" {
|
||||
if _, updateErr := db.DB.ExecContext(ctx,
|
||||
if _, updateErr := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE agent_memories SET embedding = $1::vector WHERE id = $2`,
|
||||
fmtVec, memoryID,
|
||||
); updateErr != nil {
|
||||
@@ -652,7 +652,7 @@ func (h *MemoriesHandler) Delete(c *gin.Context) {
|
||||
memoryID := c.Param("memoryId")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx,
|
||||
result, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM agent_memories WHERE id = $1 AND workspace_id = $2`, memoryID, workspaceID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "delete failed"})
|
||||
|
||||
@@ -30,7 +30,7 @@ func NewMemoryHandler() *MemoryHandler { return &MemoryHandler{} }
|
||||
func (h *MemoryHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), `
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), `
|
||||
SELECT key, value, version, expires_at, updated_at
|
||||
FROM workspace_memory
|
||||
WHERE workspace_id = $1 AND (expires_at IS NULL OR expires_at > NOW())
|
||||
@@ -65,7 +65,7 @@ func (h *MemoryHandler) Get(c *gin.Context) {
|
||||
|
||||
var entry MemoryEntry
|
||||
var value []byte
|
||||
err := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
SELECT key, value, version, expires_at, updated_at
|
||||
FROM workspace_memory
|
||||
WHERE workspace_id = $1 AND key = $2 AND (expires_at IS NULL OR expires_at > NOW())
|
||||
@@ -134,7 +134,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// Path A — no version guard: unchanged last-write-wins upsert.
|
||||
if body.IfMatchVersion == nil {
|
||||
var newVersion int64
|
||||
err := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
INSERT INTO workspace_memory(id, workspace_id, key, value, expires_at, updated_at, version)
|
||||
VALUES(gen_random_uuid(), $1, $2, $3::jsonb, $4, NOW(), 1)
|
||||
ON CONFLICT(workspace_id, key) DO UPDATE
|
||||
@@ -168,7 +168,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// version-mismatch or something else.
|
||||
expected := *body.IfMatchVersion
|
||||
var newVersion int64
|
||||
updateErr := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
updateErr := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
UPDATE workspace_memory
|
||||
SET value = $3::jsonb,
|
||||
expires_at = $4,
|
||||
@@ -182,7 +182,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// Either the row doesn't exist yet, or version mismatch. Look
|
||||
// up the actual state so the 409 body carries useful context.
|
||||
var currentVersion sql.NullInt64
|
||||
probeErr := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
probeErr := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
SELECT version FROM workspace_memory
|
||||
WHERE workspace_id = $1 AND key = $2
|
||||
`, workspaceID, body.Key).Scan(¤tVersion)
|
||||
@@ -193,7 +193,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// non-existent key with version assertion).
|
||||
if expected == 0 {
|
||||
var createdVersion int64
|
||||
err := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
INSERT INTO workspace_memory(id, workspace_id, key, value, expires_at, updated_at, version)
|
||||
VALUES(gen_random_uuid(), $1, $2, $3::jsonb, $4, NOW(), 1)
|
||||
RETURNING version
|
||||
@@ -239,7 +239,7 @@ func (h *MemoryHandler) Delete(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
key := c.Param("key")
|
||||
|
||||
_, err := db.DB.ExecContext(c.Request.Context(), `
|
||||
_, err := db.GetDB().ExecContext(c.Request.Context(), `
|
||||
DELETE FROM workspace_memory WHERE workspace_id = $1 AND key = $2
|
||||
`, workspaceID, key)
|
||||
if err != nil {
|
||||
|
||||
@@ -90,7 +90,7 @@ func pickMockReply(workspaceID, requestID string) string {
|
||||
// genuine agent traffic.
|
||||
func lookupRuntime(ctx context.Context, workspaceID string) string {
|
||||
var runtime sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT runtime FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&runtime)
|
||||
if err != nil {
|
||||
|
||||
@@ -271,6 +271,62 @@ func (e EnvRequirement) IsSatisfied(configured map[string]struct{}) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// perWorkspaceUnsatisfied records a single unsatisfied RequiredEnv for a
|
||||
// specific workspace during org import preflight.
|
||||
type perWorkspaceUnsatisfied struct {
|
||||
Workspace string
|
||||
FilesDir string
|
||||
Unsatisfied EnvRequirement
|
||||
}
|
||||
|
||||
// collectPerWorkspaceUnsatisfied walks the workspace tree and returns every
|
||||
// RequiredEnv that is neither in `configured` (global secrets) nor resolvable
|
||||
// from the org root or workspace-level .env file. An empty orgBaseDir skips
|
||||
// the .env walk so all requirements appear unsatisfied (used by tests to
|
||||
// isolate the global-only path).
|
||||
func collectPerWorkspaceUnsatisfied(
|
||||
workspaces []OrgWorkspace,
|
||||
orgBaseDir string,
|
||||
configured map[string]struct{},
|
||||
) []perWorkspaceUnsatisfied {
|
||||
var result []perWorkspaceUnsatisfied
|
||||
for _, ws := range workspaces {
|
||||
result = append(result, checkWorkspaceRequiredEnv(ws, orgBaseDir, configured)...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func checkWorkspaceRequiredEnv(
|
||||
ws OrgWorkspace,
|
||||
orgBaseDir string,
|
||||
configured map[string]struct{},
|
||||
) []perWorkspaceUnsatisfied {
|
||||
var result []perWorkspaceUnsatisfied
|
||||
// Merge in .env vars from the org root and the workspace-specific dir.
|
||||
// Workspace-level vars override org-root vars, just as loadWorkspaceEnv
|
||||
// implements: org root first, then ws dir on top.
|
||||
if orgBaseDir != "" {
|
||||
wsEnv := loadWorkspaceEnv(orgBaseDir, ws.FilesDir)
|
||||
for k, v := range wsEnv {
|
||||
configured[k] = struct{}{}
|
||||
_ = v // value only used for merging into configured map
|
||||
}
|
||||
}
|
||||
for _, req := range ws.RequiredEnv {
|
||||
if !req.IsSatisfied(configured) {
|
||||
result = append(result, perWorkspaceUnsatisfied{
|
||||
Workspace: ws.Name,
|
||||
FilesDir: ws.FilesDir,
|
||||
Unsatisfied: req,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, child := range ws.Children {
|
||||
result = append(result, checkWorkspaceRequiredEnv(child, orgBaseDir, configured)...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// UnmarshalYAML accepts either a scalar (string → single) or a map
|
||||
// with an `any_of` list (→ group).
|
||||
func (e *EnvRequirement) UnmarshalYAML(value *yaml.Node) error {
|
||||
@@ -796,7 +852,7 @@ func (h *OrgHandler) Import(c *gin.Context) {
|
||||
// nothing (harmless) or, worse, match every workspace if a future
|
||||
// query rewrite drops the IN clause. Belt-and-suspenders.
|
||||
if len(importedNames) > 0 && len(importedIDs) > 0 {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id FROM workspaces
|
||||
WHERE name = ANY($1::text[])
|
||||
AND id != ALL($2::uuid[])
|
||||
@@ -923,7 +979,7 @@ func emitOrgEvent(ctx context.Context, eventType string, payload map[string]any)
|
||||
log.Printf("emitOrgEvent: marshal %s payload failed: %v", eventType, err)
|
||||
return
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO structure_events (event_type, payload, created_at)
|
||||
VALUES ($1, $2, now())
|
||||
`, eventType, payloadJSON); err != nil {
|
||||
|
||||
@@ -64,7 +64,9 @@ func resolvePromptRef(inline, fileRef, orgBaseDir, filesDir string) (string, err
|
||||
|
||||
// envVarRefPattern matches actual ${VAR} or $VAR references (not literal $).
|
||||
// Used to detect unresolved placeholders without false positives like "$5".
|
||||
var envVarRefPattern = regexp.MustCompile(`\$\{?[A-Za-z_][A-Za-z0-9_]*\}?`)
|
||||
// Requires [a-zA-Z_] as the first char after $ so $100 stays literal.
|
||||
// Two capture groups: (1) ${VAR} form, (2) $VAR form.
|
||||
var envVarRefPattern = regexp.MustCompile(`\$\{([a-zA-Z_][a-zA-Z0-9_]*)\}|\$([a-zA-Z_][a-zA-Z0-9_]*)`)
|
||||
|
||||
// hasUnresolvedVarRef returns true if the original string had a ${VAR} or $VAR
|
||||
// reference that the expanded string didn't fully replace (i.e. the var was unset).
|
||||
@@ -78,26 +80,103 @@ func hasUnresolvedVarRef(original, expanded string) bool {
|
||||
}
|
||||
|
||||
// expandWithEnv expands ${VAR} and $VAR references in s using the env map.
|
||||
// Falls back to the platform process env if a var isn't in the map.
|
||||
// Shell variables must start with a letter or '_' per POSIX; invalid identifiers
|
||||
// are returned literally so that "$100" and "$5" stay as-is.
|
||||
// Falls back to the platform process env only when the whole value is a
|
||||
// single variable reference; embedded process-env expansion is too broad for
|
||||
// imported org YAML because host variables such as HOME are not template data.
|
||||
func expandWithEnv(s string, env map[string]string) string {
|
||||
return os.Expand(s, func(key string) string {
|
||||
if len(key) == 0 {
|
||||
return "$"
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for i := 0; i < len(s); {
|
||||
if s[i] != '$' {
|
||||
b.WriteByte(s[i])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
c := key[0]
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') {
|
||||
return "$" + key // not a valid shell identifier — return literal
|
||||
|
||||
if i+1 >= len(s) {
|
||||
b.WriteByte('$')
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if v, ok := env[key]; ok {
|
||||
return v
|
||||
|
||||
if s[i+1] == '{' {
|
||||
end := strings.IndexByte(s[i+2:], '}')
|
||||
if end < 0 {
|
||||
b.WriteByte('$')
|
||||
i++
|
||||
continue
|
||||
}
|
||||
end += i + 2
|
||||
key := s[i+2 : end]
|
||||
ref := s[i : end+1]
|
||||
b.WriteString(expandEnvRef(key, ref, s, env))
|
||||
i = end + 1
|
||||
continue
|
||||
}
|
||||
return os.Getenv(key)
|
||||
})
|
||||
|
||||
if !isEnvIdentStart(s[i+1]) {
|
||||
b.WriteByte('$')
|
||||
i++
|
||||
continue
|
||||
}
|
||||
j := i + 2
|
||||
for j < len(s) && isEnvIdentPart(s[j]) {
|
||||
j++
|
||||
}
|
||||
key := s[i+1 : j]
|
||||
ref := s[i:j]
|
||||
b.WriteString(expandEnvRef(key, ref, s, env))
|
||||
i = j
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// loadWorkspaceEnv reads the org root .env and the workspace-specific .env
|
||||
// expandEnvRef resolves a single variable reference extracted from s.
|
||||
//
|
||||
// Guards:
|
||||
// - Empty key → "$$" escape, return "$"
|
||||
// - key[0] not POSIX ident start → "$" + partial chars, return "$<chars>"
|
||||
// - Key in env map → return the mapped value (template override wins)
|
||||
// - Otherwise → only fall back to os.Getenv if the whole input string IS the
|
||||
// variable reference (ref == whole).
|
||||
//
|
||||
// Bare $VAR format:
|
||||
// $HOME (alone) → ref==whole → os.Getenv ✓ (host HOME is org-template HOME)
|
||||
// $HOME/path (partial) → ref!=whole → literal "$HOME" ✓ (CWE-78: prevents host leak)
|
||||
//
|
||||
// Braced ${VAR} format:
|
||||
// ${HOME} (alone) → ref==whole → os.Getenv ✓
|
||||
// ${ROLE}/admin (partial) → ref!=whole → literal ✓
|
||||
// "yes and ${NOT_SET}" (embedded) → ref!=whole → literal ✓
|
||||
//
|
||||
// This is the CWE-78 fix from commit a3a358f9.
|
||||
func expandEnvRef(key, ref, whole string, env map[string]string) string {
|
||||
if key == "" {
|
||||
return "$"
|
||||
}
|
||||
if !isEnvIdentStart(key[0]) {
|
||||
return "$" + key
|
||||
}
|
||||
if v, ok := env[key]; ok {
|
||||
return v
|
||||
}
|
||||
if ref == whole {
|
||||
return os.Getenv(key)
|
||||
}
|
||||
return ref
|
||||
}
|
||||
|
||||
func isEnvIdentStart(c byte) bool {
|
||||
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_'
|
||||
}
|
||||
|
||||
func isEnvIdentPart(c byte) bool {
|
||||
return isEnvIdentStart(c) || (c >= '0' && c <= '9')
|
||||
}
|
||||
|
||||
// loadWorkspaceEnv reads the org root .env and the workspace-specific .env .env and the workspace-specific .env
|
||||
// (workspace overrides org root). Used by both secret injection and channel
|
||||
// config expansion.
|
||||
//
|
||||
@@ -349,7 +428,11 @@ func resolveInsideRoot(root, userPath string) (string, error) {
|
||||
return "", fmt.Errorf("root abs: %w", err)
|
||||
}
|
||||
joined := filepath.Join(absRoot, userPath)
|
||||
absJoined, err := filepath.Abs(joined)
|
||||
// filepath.Join preserves "." components when root is absolute; clean
|
||||
// them before computing the final absolute path so "./subdir/./file.txt"
|
||||
// resolves to root/subdir/file.txt (not root/./subdir/./file.txt).
|
||||
cleaned := filepath.Clean(joined)
|
||||
absJoined, err := filepath.Abs(cleaned)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("joined abs: %w", err)
|
||||
}
|
||||
|
||||
@@ -287,7 +287,7 @@ func TestRenderCategoryRoutingYAML_StableOrdering(t *testing.T) {
|
||||
if ai <= 0 || zi <= 0 || mi <= 0 {
|
||||
t.Fatalf("could not locate all keys in output: %s", out)
|
||||
}
|
||||
if !(ai < mi && mi < zi) {
|
||||
if ai >= mi || mi >= zi {
|
||||
t.Errorf("keys not sorted: alpha=%d middle=%d zebra=%d, output:\n%s", ai, mi, zi, out)
|
||||
}
|
||||
}
|
||||
@@ -462,8 +462,9 @@ func TestExpandWithEnv_LiteralDollar(t *testing.T) {
|
||||
func TestExpandWithEnv_PartiallyPresent(t *testing.T) {
|
||||
env := map[string]string{"SET": "yes"}
|
||||
result := expandWithEnv("${SET} and ${NOT_SET}", env)
|
||||
// ${SET} resolved; ${NOT_SET} -> "" via empty fallback.
|
||||
assert.Equal(t, "yes and ", result)
|
||||
// ${SET} resolved from env; ${NOT_SET} stays literal (not whole-string ref,
|
||||
// so os.Getenv fallback is NOT used — CWE-78 regression guard).
|
||||
assert.Equal(t, "yes and ${NOT_SET}", result)
|
||||
}
|
||||
|
||||
// mergeCategoryRouting tests — unions defaults with per-workspace routing.
|
||||
@@ -589,7 +590,7 @@ func TestRenderCategoryRoutingYAML_SpecialCharactersEscaped(t *testing.T) {
|
||||
// ── Additional coverage: appendYAMLBlock ───────────────────────────
|
||||
func TestAppendYAMLBlock_BothEmpty(t *testing.T) {
|
||||
result := appendYAMLBlock(nil, "")
|
||||
assert.Nil(t, result)
|
||||
assert.Nil(t, result) // append(nil, []byte("")...) returns nil in Go
|
||||
}
|
||||
|
||||
func TestAppendYAMLBlock_ExistingHasNewline(t *testing.T) {
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
func TestResolveInsideRoot_EmptyUserPath(t *testing.T) {
|
||||
_, err := resolveInsideRoot("/safe/root", "")
|
||||
if err == nil {
|
||||
t.Fatalf("empty userPath: expected error, got nil")
|
||||
t.Fatal("empty userPath: expected error, got nil")
|
||||
}
|
||||
if err.Error() != "path is empty" {
|
||||
t.Errorf("empty userPath: got %q, want %q", err.Error(), "path is empty")
|
||||
@@ -26,7 +26,7 @@ func TestResolveInsideRoot_EmptyUserPath(t *testing.T) {
|
||||
func TestResolveInsideRoot_AbsolutePathRejected(t *testing.T) {
|
||||
_, err := resolveInsideRoot("/safe/root", "/etc/passwd")
|
||||
if err == nil {
|
||||
t.Fatalf("absolute userPath: expected error, got nil")
|
||||
t.Fatal("absolute userPath: expected error, got nil")
|
||||
}
|
||||
if err.Error() != "absolute paths are not allowed" {
|
||||
t.Errorf("absolute userPath: got %q, want %q", err.Error(), "absolute paths are not allowed")
|
||||
@@ -44,24 +44,20 @@ func TestResolveInsideRoot_DotDotTraversal(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveInsideRoot_DotDotWithIntermediate verifies that a/b/../../c does NOT
|
||||
// escape when root=/safe/root. After normalization: a/b/../.. = ., so a/b/../../c = c,
|
||||
// which is a valid descendant of /safe/root. The original test expected an error
|
||||
// but resolveInsideRoot correctly returns nil (the path stays within root).
|
||||
// The OFFSEC-006 concern is covered by ../../etc/passwd which DOES escape.
|
||||
func TestResolveInsideRoot_DotDotWithIntermediate(t *testing.T) {
|
||||
// a/b/../../c normalises to "c" — a valid descendant inside any root.
|
||||
// Must use t.TempDir() for a real filesystem path so filepath.Abs resolves.
|
||||
root := t.TempDir()
|
||||
got, err := resolveInsideRoot(root, "a/b/../../c")
|
||||
if err != nil {
|
||||
t.Fatalf("a/b/../../c should resolve (normalizes to c within root): %v", err)
|
||||
t.Fatalf("a/b/../../c should resolve within root: %v", err)
|
||||
}
|
||||
// Verify result is inside root and ends with "c"
|
||||
if !strings.HasPrefix(got, root+string(filepath.Separator)) {
|
||||
t.Errorf("result should be inside root %q, got %q", root, got)
|
||||
}
|
||||
// Ensure the suffix is "c"
|
||||
parts := strings.Split(strings.TrimPrefix(got, root), string(filepath.Separator))
|
||||
if parts[len(parts)-1] != "c" {
|
||||
t.Errorf("expected filename 'c', got %q", got)
|
||||
if got[len(got)-1:] != "c" {
|
||||
t.Errorf("resolved path should end in 'c', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,16 +93,14 @@ func TestResolveInsideRoot_DotPathComponent(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("dot path component: unexpected error: %v", err)
|
||||
}
|
||||
// Verify the file component is subdir/file.txt regardless of root length.
|
||||
suffix := string(filepath.Separator) + "subdir" + string(filepath.Separator) + "file.txt"
|
||||
if !strings.HasSuffix(got, suffix) {
|
||||
t.Errorf("dot path component: got %q, want suffix %q", got, suffix)
|
||||
if !strings.HasSuffix(got, "/subdir/file.txt") {
|
||||
t.Errorf("dot path component: got %q, want suffix /subdir/file.txt", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveInsideRoot_NestedDotDotEscapes(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
// a/../../b from /tmp/xyz → /tmp/b (escapes temp dir)
|
||||
// a/../../b from /tmp/dirsomething → /tmp/b (escapes temp dir)
|
||||
got, err := resolveInsideRoot(root, "a/../../b")
|
||||
if err == nil {
|
||||
t.Fatalf("nested dotdot: expected error, got %q", got)
|
||||
@@ -143,21 +137,66 @@ func TestResolveInsideRoot_SiblingNotEscaped(t *testing.T) {
|
||||
}
|
||||
|
||||
// ── isSafeRoleName ────────────────────────────────────────────────────────────
|
||||
// isSafeRoleName is tested comprehensively in org_helpers_pure_test.go.
|
||||
// Only security-critical path-injection cases live here.
|
||||
|
||||
func TestIsSafeRoleName_Empty(t *testing.T) {
|
||||
if isSafeRoleName("") {
|
||||
t.Error("isSafeRoleName(\"\"): expected false, got true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafeRoleName_Dot(t *testing.T) {
|
||||
if isSafeRoleName(".") {
|
||||
t.Error("isSafeRoleName(\".\"): expected false, got true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafeRoleName_DotDot(t *testing.T) {
|
||||
if isSafeRoleName("..") {
|
||||
t.Error("isSafeRoleName(\"..\"): expected false, got true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafeRoleName_PathTraversal(t *testing.T) {
|
||||
unsafe := []string{
|
||||
"../etc",
|
||||
"foo/../../../etc",
|
||||
"foo/../../bar",
|
||||
}
|
||||
for _, name := range unsafe {
|
||||
if isSafeRoleName(name) {
|
||||
t.Errorf("isSafeRoleName(%q): expected false (path traversal), got true", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafeRoleName_SpecialChars(t *testing.T) {
|
||||
unsafe := []string{
|
||||
"foo:bar",
|
||||
"foo bar",
|
||||
"foo\tbar",
|
||||
"foo\nbar",
|
||||
"foo\x00bar",
|
||||
"foo@bar",
|
||||
"foo#bar",
|
||||
"foo$bar",
|
||||
}
|
||||
for _, name := range unsafe {
|
||||
if isSafeRoleName(name) {
|
||||
t.Errorf("isSafeRoleName(%q): expected false (special char), got true", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── mergeCategoryRouting ──────────────────────────────────────────────────────
|
||||
// Duplicate mergeCategoryRouting tests removed to avoid redeclaration with
|
||||
// org_helpers_pure_test.go. Only security-specific behaviour lives here.
|
||||
|
||||
func TestSecureRouting_BothNil(t *testing.T) {
|
||||
func TestMergeCategoryRouting_BothNil(t *testing.T) {
|
||||
got := mergeCategoryRouting(nil, nil)
|
||||
if len(got) != 0 {
|
||||
t.Errorf("both nil: got %v, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_DefaultOnly(t *testing.T) {
|
||||
func TestMergeCategoryRouting_DefaultOnly(t *testing.T) {
|
||||
defaultRouting := map[string][]string{
|
||||
"security": {"Backend Engineer", "DevOps"},
|
||||
}
|
||||
@@ -170,7 +209,7 @@ func TestSecureRouting_DefaultOnly(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_WorkspaceOnly(t *testing.T) {
|
||||
func TestMergeCategoryRouting_WorkspaceOnly(t *testing.T) {
|
||||
wsRouting := map[string][]string{
|
||||
"ui": {"Frontend Engineer"},
|
||||
}
|
||||
@@ -183,7 +222,7 @@ func TestSecureRouting_WorkspaceOnly(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_MergeNoOverlap(t *testing.T) {
|
||||
func TestMergeCategoryRouting_MergeNoOverlap(t *testing.T) {
|
||||
defaultRouting := map[string][]string{
|
||||
"security": {"Backend Engineer"},
|
||||
}
|
||||
@@ -196,7 +235,7 @@ func TestSecureRouting_MergeNoOverlap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_WsOverrideDropsDefault(t *testing.T) {
|
||||
func TestMergeCategoryRouting_WsOverrideDropsDefault(t *testing.T) {
|
||||
defaultRouting := map[string][]string{
|
||||
"security": {"Backend Engineer", "DevOps"},
|
||||
}
|
||||
@@ -212,34 +251,7 @@ func TestSecureRouting_WsOverrideDropsDefault(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_EmptyListDropsCategory(t *testing.T) {
|
||||
defaultRouting := map[string][]string{
|
||||
"security": {"Backend Engineer"},
|
||||
"ui": {"Frontend Engineer"},
|
||||
}
|
||||
wsRouting := map[string][]string{
|
||||
"security": {}, // empty list = opt out
|
||||
}
|
||||
got := mergeCategoryRouting(defaultRouting, wsRouting)
|
||||
if _, exists := got["security"]; exists {
|
||||
t.Error("empty ws list should delete the category from output")
|
||||
}
|
||||
if len(got["ui"]) != 1 {
|
||||
t.Errorf("ui should still exist: got %v", got["ui"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_EmptyKeySkipped(t *testing.T) {
|
||||
defaultRouting := map[string][]string{
|
||||
"": {"Backend Engineer"},
|
||||
}
|
||||
got := mergeCategoryRouting(defaultRouting, nil)
|
||||
if _, exists := got[""]; exists {
|
||||
t.Error("empty key should be skipped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_EmptyRolesInDefaultSkipped(t *testing.T) {
|
||||
func TestMergeCategoryRouting_EmptyRolesInDefaultSkipped(t *testing.T) {
|
||||
defaultRouting := map[string][]string{
|
||||
"security": {},
|
||||
}
|
||||
@@ -249,7 +261,7 @@ func TestSecureRouting_EmptyRolesInDefaultSkipped(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_OriginalMapsUnmodified(t *testing.T) {
|
||||
func TestMergeCategoryRouting_OriginalMapsUnmodified(t *testing.T) {
|
||||
defaultRouting := map[string][]string{
|
||||
"security": {"Backend Engineer"},
|
||||
}
|
||||
@@ -264,3 +276,121 @@ func TestSecureRouting_OriginalMapsUnmodified(t *testing.T) {
|
||||
t.Error("ws routing should be unmodified after merge")
|
||||
}
|
||||
}
|
||||
|
||||
// ── expandWithEnv ─────────────────────────────────────────────────────────────
|
||||
//
|
||||
// CWE-78 regression tests. The original fix (a3a358f9) ensures that partial
|
||||
// variable references like $HOME/path are NOT resolved via os.Getenv — the
|
||||
// host HOME env var must not leak into org template values. Only whole-string
|
||||
// references ($VAR or ${VAR}) may fall back to the host process environment.
|
||||
|
||||
func TestExpandWithEnv_PartialRefDollarHomePath(t *testing.T) {
|
||||
// $HOME/path must NOT resolve to the host's HOME env var.
|
||||
// The literal $HOME must be returned as-is.
|
||||
got := expandWithEnv("$HOME/path", nil)
|
||||
if got != "$HOME/path" {
|
||||
t.Errorf("$HOME/path: got %q, want literal $HOME/path", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_PartialRefBracedRoleAdmin(t *testing.T) {
|
||||
// ${ROLE}/admin — ROLE is not in env, so expand to the literal ${ROLE}/admin.
|
||||
got := expandWithEnv("${ROLE}/admin", nil)
|
||||
if got != "${ROLE}/admin" {
|
||||
t.Errorf("${ROLE}/admin: got %q, want literal ${ROLE}/admin", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_PartialRefMiddleOfString(t *testing.T) {
|
||||
// $ROLE in the middle of a string — literal, not os.Getenv.
|
||||
got := expandWithEnv("prefix/$ROLE/suffix", nil)
|
||||
if got != "prefix/$ROLE/suffix" {
|
||||
t.Errorf("prefix/$ROLE/suffix: got %q, want literal", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_WholeVarInEnv(t *testing.T) {
|
||||
// Whole-string $VAR that IS in env — env value wins.
|
||||
env := map[string]string{"FOO": "barvalue"}
|
||||
got := expandWithEnv("$FOO", env)
|
||||
if got != "barvalue" {
|
||||
t.Errorf("$FOO with FOO=barvalue: got %q, want barvalue", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_WholeVarBracedInEnv(t *testing.T) {
|
||||
// Whole-string ${VAR} that IS in env — env value wins.
|
||||
env := map[string]string{"FOO": "barvalue"}
|
||||
got := expandWithEnv("${FOO}", env)
|
||||
if got != "barvalue" {
|
||||
t.Errorf("${FOO} with FOO=barvalue: got %q, want barvalue", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_WholeVarNotInEnvBare(t *testing.T) {
|
||||
// Whole-string $VAR not in env — falls back to os.Getenv.
|
||||
// If the host has the var, we get the host value. If not, empty.
|
||||
// At minimum, the result must NOT be the literal "$UNDEFINED_VAR_9Z".
|
||||
got := expandWithEnv("$UNDEFINED_VAR_9Z", nil)
|
||||
if got == "$UNDEFINED_VAR_9Z" {
|
||||
t.Errorf("$UNDEFINED_VAR_9Z: should expand (whole-string fallback to os.Getenv), got literal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_WholeVarNotInEnvBraced(t *testing.T) {
|
||||
// Whole-string ${VAR} not in env — falls back to os.Getenv.
|
||||
got := expandWithEnv("${UNDEFINED_VAR_9Z}", nil)
|
||||
if got == "${UNDEFINED_VAR_9Z}" {
|
||||
t.Errorf("${UNDEFINED_VAR_9Z}: should expand (whole-string fallback to os.Getenv), got literal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_EmptyString(t *testing.T) {
|
||||
got := expandWithEnv("", map[string]string{"FOO": "bar"})
|
||||
if got != "" {
|
||||
t.Errorf("empty string: got %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_NoVarRefs(t *testing.T) {
|
||||
got := expandWithEnv("plain string with no vars", map[string]string{"FOO": "bar"})
|
||||
if got != "plain string with no vars" {
|
||||
t.Errorf("plain string: got %q, want unchanged", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_MultipleVarRefs(t *testing.T) {
|
||||
// Two vars, both whole — both expand from env.
|
||||
env := map[string]string{"A": "alpha", "B": "beta"}
|
||||
got := expandWithEnv("$A and $B and more", env)
|
||||
if got != "alpha and beta and more" {
|
||||
t.Errorf("multiple vars: got %q, want alpha and beta and more", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_NumericVarRef(t *testing.T) {
|
||||
// $5 — starts with digit, not a valid identifier start.
|
||||
// Must return the literal "$5", not expand via os.Getenv.
|
||||
got := expandWithEnv("$5", map[string]string{"5": "five"})
|
||||
if got != "$5" {
|
||||
t.Errorf("$5: got %q, want literal $5", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_DollarEscape(t *testing.T) {
|
||||
// $$ → both $ written literally (each $ is not followed by an identifier char,
|
||||
// so it is written as-is). No special escape sequence for $$.
|
||||
got := expandWithEnv("$$", nil)
|
||||
if got != "$$" {
|
||||
t.Errorf("$$: got %q, want literal $$", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_MixedPartialAndWhole(t *testing.T) {
|
||||
// $A is in env (whole), $HOME is partial — only $A expands.
|
||||
env := map[string]string{"A": "alpha"}
|
||||
got := expandWithEnv("$A at $HOME", env)
|
||||
if got != "alpha at $HOME" {
|
||||
t.Errorf("$A at $HOME: got %q, want alpha at $HOME", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// status != 'removed' — must match the partial-index predicate
|
||||
// EXACTLY for Postgres to consider the index applicable.
|
||||
var insertedID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO workspaces (id, name, role, tier, runtime, awareness_namespace, status, parent_id, workspace_dir, workspace_access, max_concurrent_tasks)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
ON CONFLICT (COALESCE(parent_id, '00000000-0000-0000-0000-000000000000'::uuid), name)
|
||||
@@ -224,7 +224,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// `collapsed` lives on canvas_layouts (005_canvas_layouts.sql), not
|
||||
// on workspaces; the UI-only flag is intentionally decoupled from
|
||||
// the workspace row.
|
||||
if _, err := db.DB.ExecContext(ctx, `INSERT INTO canvas_layouts (workspace_id, x, y, collapsed) VALUES ($1, $2, $3, $4)`, id, absX, absY, initialCollapsed); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `INSERT INTO canvas_layouts (workspace_id, x, y, collapsed) VALUES ($1, $2, $3, $4)`, id, absX, absY, initialCollapsed); err != nil {
|
||||
log.Printf("Org import: canvas layout insert failed for %s: %v", ws.Name, err)
|
||||
}
|
||||
|
||||
@@ -258,7 +258,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
|
||||
// Handle external workspaces
|
||||
if ws.External {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, url = $2 WHERE id = $3`, models.StatusOnline, ws.URL, id); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, url = $2 WHERE id = $3`, models.StatusOnline, ws.URL, id); err != nil {
|
||||
log.Printf("Org import: external workspace status update failed for %s: %v", ws.Name, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOnline), id, map[string]interface{}{
|
||||
@@ -273,7 +273,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// URL is set; the proxy never tries to resolve one for mock
|
||||
// runtimes. Built for the funding-demo "200-workspace mock
|
||||
// org" template — visual scale without real backend cost.
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1 WHERE id = $2`, models.StatusOnline, id); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1 WHERE id = $2`, models.StatusOnline, id); err != nil {
|
||||
log.Printf("Org import: mock workspace status update failed for %s: %v", ws.Name, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOnline), id, map[string]interface{}{
|
||||
@@ -512,7 +512,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
} else {
|
||||
encrypted = []byte(value) // store raw when encryption disabled
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_secrets (workspace_id, key, encrypted_value)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (workspace_id, key) DO UPDATE SET encrypted_value = $3, updated_at = now()
|
||||
@@ -570,7 +570,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
sched.Name, ws.Name, nextRunErr)
|
||||
continue
|
||||
}
|
||||
if _, err := db.DB.ExecContext(context.Background(), orgImportScheduleSQL,
|
||||
if _, err := db.GetDB().ExecContext(context.Background(), orgImportScheduleSQL,
|
||||
id, sched.Name, sched.CronExpr, tz, prompt, enabled, nextRun); err != nil {
|
||||
log.Printf("Org import: failed to upsert schedule '%s' for %s: %v", sched.Name, ws.Name, err)
|
||||
} else {
|
||||
@@ -644,7 +644,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
enabled = *ch.Enabled
|
||||
}
|
||||
// Idempotent insert — if same workspace+type already exists, update config
|
||||
if _, err := db.DB.ExecContext(context.Background(), `
|
||||
if _, err := db.GetDB().ExecContext(context.Background(), `
|
||||
INSERT INTO workspace_channels (workspace_id, channel_type, channel_config, enabled, allowed_users)
|
||||
VALUES ($1, $2, $3::jsonb, $4, $5::jsonb)
|
||||
ON CONFLICT (workspace_id, channel_type) DO UPDATE
|
||||
@@ -695,7 +695,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// abort the import. errors.Is unwraps.
|
||||
func (h *OrgHandler) lookupExistingChild(ctx context.Context, name string, parentID *string) (string, bool, error) {
|
||||
var existingID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id FROM workspaces
|
||||
WHERE name = $1
|
||||
AND parent_id IS NOT DISTINCT FROM $2
|
||||
@@ -952,56 +952,8 @@ type PerWorkspaceUnsatisfied struct {
|
||||
|
||||
// collectPerWorkspaceUnsatisfied recursively walks workspaces and returns
|
||||
// per-workspace RequiredEnv entries that are not covered by (a) a global
|
||||
// secret key or (b) a key present in the workspace's .env file(s) (org root
|
||||
// .env + per-workspace <files_dir>/.env). This complements
|
||||
// collectOrgEnv + loadConfiguredGlobalSecretKeys, which together only
|
||||
// validate global-level RequiredEnv against global_secrets. The .env
|
||||
// lookup mirrors the runtime resolution in createWorkspaceTree so that
|
||||
// the preflight result matches what the container actually receives at
|
||||
// start time.
|
||||
func collectPerWorkspaceUnsatisfied(workspaces []OrgWorkspace, orgBaseDir string, globalSecrets map[string]struct{}) []PerWorkspaceUnsatisfied {
|
||||
var out []PerWorkspaceUnsatisfied
|
||||
var walk func([]OrgWorkspace)
|
||||
walk = func(wsList []OrgWorkspace) {
|
||||
for _, ws := range wsList {
|
||||
// Build the set of keys available to this workspace from .env.
|
||||
// This is the same three-source stack that createWorkspaceTree
|
||||
// injects into the container:
|
||||
// 1. Org root .env (parseEnvFile, no filesDir)
|
||||
// 2. Workspace <files_dir>/.env (if filesDir is set)
|
||||
// 3. Persona bootstrap env (MOLECULE_PERSONA_ROOT/<filesDir>/env)
|
||||
// Items 1+2 are on-disk and testable; item 3 is host-only and
|
||||
// skipped here (persona env does NOT satisfy required_env —
|
||||
// it carries identity tokens, not workspace LLM keys).
|
||||
envFromFiles := loadWorkspaceEnv(orgBaseDir, ws.FilesDir)
|
||||
// Convert map[string]string (from .env files) to map[string]struct{}
|
||||
// to match IsSatisfied's signature.
|
||||
envSet := make(map[string]struct{}, len(envFromFiles))
|
||||
for k := range envFromFiles {
|
||||
envSet[k] = struct{}{}
|
||||
}
|
||||
for _, req := range ws.RequiredEnv {
|
||||
if req.IsSatisfied(globalSecrets) {
|
||||
continue // covered by a global secret
|
||||
}
|
||||
if req.IsSatisfied(envSet) {
|
||||
continue // covered by a per-workspace .env file
|
||||
}
|
||||
out = append(out, PerWorkspaceUnsatisfied{
|
||||
Workspace: ws.Name,
|
||||
FilesDir: ws.FilesDir,
|
||||
Unsatisfied: req,
|
||||
})
|
||||
}
|
||||
walk(ws.Children)
|
||||
}
|
||||
}
|
||||
walk(workspaces)
|
||||
return out
|
||||
}
|
||||
|
||||
func loadConfiguredGlobalSecretKeys(ctx context.Context) (map[string]struct{}, error) {
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key FROM global_secrets WHERE octet_length(encrypted_value) > 0 LIMIT $1`,
|
||||
globalSecretsPreflightLimit)
|
||||
if err != nil {
|
||||
|
||||
@@ -17,8 +17,11 @@ import (
|
||||
// when one exists, or the workspace's own ID when it is the org root.
|
||||
// Returns an empty string if the workspace is not found.
|
||||
func resolveOrgID(ctx context.Context, workspaceID string) (string, error) {
|
||||
if db.GetDB() == nil {
|
||||
return "", nil // nil in unit tests
|
||||
}
|
||||
var parentID sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT parent_id FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&parentID)
|
||||
@@ -53,7 +56,7 @@ func checkOrgPluginAllowlist(ctx context.Context, workspaceID, pluginName string
|
||||
}
|
||||
|
||||
var allowed bool
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM org_plugin_allowlist
|
||||
WHERE org_id = $1 AND plugin_name = $2
|
||||
@@ -69,7 +72,7 @@ func checkOrgPluginAllowlist(ctx context.Context, workspaceID, pluginName string
|
||||
|
||||
// Check whether an allowlist exists at all. Empty allowlist = allow-all.
|
||||
var count int
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM org_plugin_allowlist WHERE org_id = $1`,
|
||||
orgID,
|
||||
).Scan(&count); err != nil {
|
||||
@@ -135,7 +138,7 @@ func requireCallerOwnsOrg(c *gin.Context) (string, error) {
|
||||
// Look up the token's org_id (populated at mint time by orgTokenActor).
|
||||
// org_id is NULL for tokens minted before this migration or via
|
||||
// ADMIN_TOKEN bootstrap — those callers get callerOrg="" and are denied.
|
||||
orgID, err := orgtoken.OrgIDByTokenID(c.Request.Context(), db.DB, tokID)
|
||||
orgID, err := orgtoken.OrgIDByTokenID(c.Request.Context(), db.GetDB(), tokID)
|
||||
if err != nil {
|
||||
// DB error — deny by default rather than risk cross-org access.
|
||||
return "", fmt.Errorf("allowlist: requireCallerOwnsOrg: %v", err)
|
||||
@@ -196,7 +199,7 @@ func (h *OrgPluginAllowlistHandler) GetAllowlist(c *gin.Context) {
|
||||
|
||||
// Verify the org workspace exists.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`,
|
||||
orgID,
|
||||
).Scan(&exists); err != nil {
|
||||
@@ -216,7 +219,7 @@ func (h *OrgPluginAllowlistHandler) GetAllowlist(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT plugin_name, enabled_by, enabled_at
|
||||
FROM org_plugin_allowlist
|
||||
WHERE org_id = $1
|
||||
@@ -285,7 +288,7 @@ func (h *OrgPluginAllowlistHandler) PutAllowlist(c *gin.Context) {
|
||||
|
||||
// Verify the org workspace exists.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`,
|
||||
orgID,
|
||||
).Scan(&exists); err != nil {
|
||||
@@ -304,7 +307,7 @@ func (h *OrgPluginAllowlistHandler) PutAllowlist(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Replace atomically: delete all current entries, then insert the new set.
|
||||
tx, err := db.DB.BeginTx(ctx, nil)
|
||||
tx, err := db.GetDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
log.Printf("allowlist: begin tx failed for org %s: %v", orgID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start transaction"})
|
||||
|
||||
@@ -31,7 +31,7 @@ func NewOrgTokenHandler() *OrgTokenHandler {
|
||||
// List returns live (non-revoked) tokens, newest-first. Prefix only —
|
||||
// never plaintext or hash.
|
||||
func (h *OrgTokenHandler) List(c *gin.Context) {
|
||||
tokens, err := orgtoken.List(c.Request.Context(), db.DB)
|
||||
tokens, err := orgtoken.List(c.Request.Context(), db.GetDB())
|
||||
if err != nil {
|
||||
log.Printf("orgtoken list: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list tokens"})
|
||||
@@ -76,7 +76,7 @@ func (h *OrgTokenHandler) Create(c *gin.Context) {
|
||||
|
||||
createdBy, orgID := orgTokenActor(c)
|
||||
|
||||
plaintext, id, err := orgtoken.Issue(c.Request.Context(), db.DB, req.Name, createdBy, orgID)
|
||||
plaintext, id, err := orgtoken.Issue(c.Request.Context(), db.GetDB(), req.Name, createdBy, orgID)
|
||||
if err != nil {
|
||||
log.Printf("orgtoken issue: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to mint token"})
|
||||
@@ -101,7 +101,7 @@ func (h *OrgTokenHandler) Revoke(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "id required"})
|
||||
return
|
||||
}
|
||||
ok, err := orgtoken.Revoke(c.Request.Context(), db.DB, id)
|
||||
ok, err := orgtoken.Revoke(c.Request.Context(), db.GetDB(), id)
|
||||
if err != nil {
|
||||
log.Printf("orgtoken revoke: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to revoke"})
|
||||
@@ -143,7 +143,7 @@ func callerOrg(c *gin.Context) string {
|
||||
if !ok || tokID == "" {
|
||||
return ""
|
||||
}
|
||||
orgID, err := orgtoken.OrgIDByTokenID(c.Request.Context(), db.DB, tokID)
|
||||
orgID, err := orgtoken.OrgIDByTokenID(c.Request.Context(), db.GetDB(), tokID)
|
||||
if err != nil || orgID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// setupOrgTokenTest wires the package-global db.DB to a sqlmock for
|
||||
// setupOrgTokenTest wires the package-global db.GetDB() to a sqlmock for
|
||||
// the duration of a test, returning the handler + mock + cleanup.
|
||||
// Gin runs in release mode to suppress debug noise.
|
||||
func setupOrgTokenTest(t *testing.T) (*OrgTokenHandler, sqlmock.Sqlmock, func()) {
|
||||
|
||||
@@ -43,7 +43,7 @@ type PendingUploadsHandler struct {
|
||||
}
|
||||
|
||||
// NewPendingUploadsHandler constructs the handler with a concrete
|
||||
// Storage. Production wires up pendinguploads.NewPostgres(db.DB).
|
||||
// Storage. Production wires up pendinguploads.NewPostgres(db.GetDB()).
|
||||
func NewPendingUploadsHandler(storage pendinguploads.Storage) *PendingUploadsHandler {
|
||||
return &PendingUploadsHandler{storage: storage}
|
||||
}
|
||||
|
||||
@@ -215,6 +215,9 @@ func TestTarWalk_EmptyDirectory(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestTarWalk_NestedDirs is defined in plugins_atomic_tar_test.go to avoid
|
||||
// redeclaration. Deeply nested directory walk is tested there.
|
||||
|
||||
// TestTarWalk_DirEntryHasTrailingSlash: directory entries must end with '/'
|
||||
// per tar format; tar.Header.Typeflag '5' (dir) must produce "name/" not "name".
|
||||
func TestTarWalk_DirEntryHasTrailingSlash(t *testing.T) {
|
||||
|
||||
@@ -300,7 +300,7 @@ func (h *PluginsHandler) Download(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Auth gate — workspace token required (fail-closed on DB errors).
|
||||
hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if hlErr != nil {
|
||||
log.Printf("wsauth: plugin.Download HasAnyLiveToken(%s) failed: %v", workspaceID, hlErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "auth check failed"})
|
||||
@@ -312,7 +312,7 @@ func (h *PluginsHandler) Download(c *gin.Context) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, tok); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ func recordWorkspacePluginInstall(
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_plugins (workspace_id, plugin_name, source_raw, tracked_ref, installed_sha)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (workspace_id, plugin_name)
|
||||
@@ -86,7 +86,10 @@ func recordWorkspacePluginInstall(
|
||||
// pair. Called by the uninstall path so the row doesn't persist with a stale
|
||||
// installed_sha after the plugin has been removed from the container.
|
||||
func deleteWorkspacePluginRow(ctx context.Context, workspaceID, pluginName string) error {
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
if db.GetDB() == nil {
|
||||
return nil // nil in unit tests; no-op since the row is test-only
|
||||
}
|
||||
_, err := db.GetDB().ExecContext(ctx, `
|
||||
DELETE FROM workspace_plugins WHERE workspace_id = $1 AND plugin_name = $2
|
||||
`, workspaceID, pluginName)
|
||||
return err
|
||||
|
||||
@@ -146,7 +146,7 @@ func (h *RegistryHandler) resolveDeliveryMode(ctx context.Context, workspaceID,
|
||||
}
|
||||
var existing sql.NullString
|
||||
var runtime sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT delivery_mode, runtime FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&existing, &runtime)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
@@ -356,7 +356,7 @@ func (h *RegistryHandler) Register(c *gin.Context) {
|
||||
// the row. Without this guard, bulk deletes left tier-3 stragglers because
|
||||
// the last pre-teardown heartbeat flipped status back to 'online' after
|
||||
// Delete's UPDATE.
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspaces (id, name, url, agent_card, status, last_heartbeat_at, delivery_mode)
|
||||
VALUES ($1, $2, $3, $4::jsonb, 'online', now(), $5)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
@@ -393,7 +393,7 @@ func (h *RegistryHandler) Register(c *gin.Context) {
|
||||
// before consulting the URL cache anyway (see #2339 PR 2).
|
||||
cachedURL := payload.URL
|
||||
var dbURL string
|
||||
if err := db.DB.QueryRowContext(ctx, `SELECT url FROM workspaces WHERE id = $1`, payload.ID).Scan(&dbURL); err == nil {
|
||||
if err := db.GetDB().QueryRowContext(ctx, `SELECT url FROM workspaces WHERE id = $1`, payload.ID).Scan(&dbURL); err == nil {
|
||||
if strings.HasPrefix(dbURL, "http://127.0.0.1") {
|
||||
cachedURL = dbURL
|
||||
}
|
||||
@@ -433,8 +433,8 @@ func (h *RegistryHandler) Register(c *gin.Context) {
|
||||
// live token; they bootstrap one here on their next register call.
|
||||
// New workspaces always pass through this path on their first boot.
|
||||
response := gin.H{"status": "registered", "delivery_mode": effectiveMode}
|
||||
if hasLive, hasLiveErr := wsauth.HasAnyLiveToken(ctx, db.DB, payload.ID); hasLiveErr == nil && !hasLive {
|
||||
token, tokErr := wsauth.IssueToken(ctx, db.DB, payload.ID)
|
||||
if hasLive, hasLiveErr := wsauth.HasAnyLiveToken(ctx, db.GetDB(), payload.ID); hasLiveErr == nil && !hasLive {
|
||||
token, tokErr := wsauth.IssueToken(ctx, db.GetDB(), payload.ID)
|
||||
if tokErr != nil {
|
||||
// Don't fail the whole register on token-issuance error — the
|
||||
// agent is already online per the upsert above. Log and continue.
|
||||
@@ -502,7 +502,7 @@ func (h *RegistryHandler) Heartbeat(c *gin.Context) {
|
||||
|
||||
// Read previous current_task to detect changes (before the UPDATE)
|
||||
var prevTask string
|
||||
_ = db.DB.QueryRowContext(ctx, `SELECT COALESCE(current_task, '') FROM workspaces WHERE id = $1`, payload.WorkspaceID).Scan(&prevTask)
|
||||
_ = db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(current_task, '') FROM workspaces WHERE id = $1`, payload.WorkspaceID).Scan(&prevTask)
|
||||
|
||||
// #615: Clamp monthly_spend to a safe range before any DB write.
|
||||
// A malicious or buggy agent could report math.MaxInt64, causing
|
||||
@@ -528,7 +528,7 @@ func (h *RegistryHandler) Heartbeat(c *gin.Context) {
|
||||
// zero to avoid accidentally clearing a previously-reported spend value.
|
||||
var err error
|
||||
if payload.MonthlySpend > 0 {
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspaces SET
|
||||
last_heartbeat_at = now(),
|
||||
last_error_rate = $2,
|
||||
@@ -543,7 +543,7 @@ func (h *RegistryHandler) Heartbeat(c *gin.Context) {
|
||||
payload.ActiveTasks, payload.UptimeSeconds, payload.CurrentTask,
|
||||
payload.MonthlySpend)
|
||||
} else {
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspaces SET
|
||||
last_heartbeat_at = now(),
|
||||
last_error_rate = $2,
|
||||
@@ -655,7 +655,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var currentStatus string
|
||||
err := db.DB.QueryRowContext(ctx, `SELECT status FROM workspaces WHERE id = $1`, payload.WorkspaceID).
|
||||
err := db.GetDB().QueryRowContext(ctx, `SELECT status FROM workspaces WHERE id = $1`, payload.WorkspaceID).
|
||||
Scan(¤tStatus)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -672,7 +672,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// timeout — restart workspace"), which the canvas surfaces in the
|
||||
// degraded card without the operator scraping container logs.
|
||||
if payload.RuntimeState == "wedged" && currentStatus == "online" {
|
||||
_, err := db.DB.ExecContext(ctx,
|
||||
_, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'online'`,
|
||||
models.StatusDegraded, payload.WorkspaceID)
|
||||
if err != nil {
|
||||
@@ -696,7 +696,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
nativeStatus := runtimeOverrides.HasCapability(payload.WorkspaceID, "status_mgmt")
|
||||
|
||||
if !nativeStatus && currentStatus == "online" && payload.ErrorRate >= 0.5 {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusDegraded, payload.WorkspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusDegraded, payload.WorkspaceID); err != nil {
|
||||
log.Printf("Heartbeat: failed to mark %s degraded: %v", payload.WorkspaceID, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceDegraded), payload.WorkspaceID, map[string]interface{}{
|
||||
@@ -715,7 +715,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// Skipped under native_status_mgmt for the same reason as the
|
||||
// degrade branch above: the adapter owns the transition.
|
||||
if !nativeStatus && currentStatus == "degraded" && payload.ErrorRate < 0.1 && payload.RuntimeState == "" {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
log.Printf("Heartbeat: failed to recover %s to online: %v", payload.WorkspaceID, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOnline), payload.WorkspaceID, map[string]interface{}{})
|
||||
@@ -725,7 +725,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// #73 guard: `AND status = 'offline'` makes the flip conditional in a single statement,
|
||||
// so a Delete that races with this recovery can't flip 'removed' back to 'online'.
|
||||
if currentStatus == "offline" {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'offline'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'offline'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
log.Printf("Heartbeat: failed to recover %s from offline: %v", payload.WorkspaceID, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOnline), payload.WorkspaceID, map[string]interface{}{})
|
||||
@@ -738,7 +738,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// transition is the only mechanism that moves newly-started workspaces out of
|
||||
// the phantom-idle state. (#1784)
|
||||
if currentStatus == "provisioning" {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'provisioning'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'provisioning'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
log.Printf("Heartbeat: failed to transition %s from provisioning to online: %v", payload.WorkspaceID, err)
|
||||
} else {
|
||||
log.Printf("Heartbeat: transitioned %s from provisioning to online (heartbeat received)", payload.WorkspaceID)
|
||||
@@ -766,7 +766,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// heartbeats can't lift the workspace out of awaiting_agent on
|
||||
// their own.
|
||||
if currentStatus == "awaiting_agent" {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'awaiting_agent'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'awaiting_agent'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
log.Printf("Heartbeat: failed to recover %s from awaiting_agent: %v", payload.WorkspaceID, err)
|
||||
} else {
|
||||
log.Printf("Heartbeat: transitioned %s from awaiting_agent to online (heartbeat received)", payload.WorkspaceID)
|
||||
@@ -784,7 +784,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// timeouts, retry logic, and activity_logs wiring.
|
||||
if h.drainQueue != nil {
|
||||
var maxConcurrent int
|
||||
_ = db.DB.QueryRowContext(ctx,
|
||||
_ = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COALESCE(max_concurrent_tasks, 1) FROM workspaces WHERE id = $1`,
|
||||
payload.WorkspaceID,
|
||||
).Scan(&maxConcurrent)
|
||||
@@ -811,7 +811,7 @@ func (h *RegistryHandler) UpdateCard(c *gin.Context) {
|
||||
}
|
||||
|
||||
agentCardStr := string(payload.AgentCard)
|
||||
_, err := db.DB.ExecContext(c.Request.Context(), `
|
||||
_, err := db.GetDB().ExecContext(c.Request.Context(), `
|
||||
UPDATE workspaces SET agent_card = $2::jsonb, updated_at = now() WHERE id = $1
|
||||
`, payload.WorkspaceID, agentCardStr)
|
||||
if err != nil {
|
||||
@@ -849,7 +849,7 @@ func (h *RegistryHandler) UpdateCard(c *gin.Context) {
|
||||
func (h *RegistryHandler) requireWorkspaceToken(
|
||||
ctx gincontext, c *gin.Context, workspaceID string,
|
||||
) error {
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
// DB error checking token existence — fail open so we don't take
|
||||
// the whole heartbeat path down on a transient hiccup. Log loudly.
|
||||
@@ -865,7 +865,7 @@ func (h *RegistryHandler) requireWorkspaceToken(
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return errors.New("missing token")
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, token); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, token); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -120,7 +120,7 @@ func loadRestartContextData(ctx context.Context, workspaceID string) restartCont
|
||||
d := restartContextData{RestartAt: time.Now()}
|
||||
|
||||
var lastHB sql.NullTime
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT last_heartbeat_at FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&lastHB); err == nil && lastHB.Valid {
|
||||
d.PrevSessionAt = lastHB.Time
|
||||
@@ -132,7 +132,7 @@ func loadRestartContextData(ctx context.Context, workspaceID string) restartCont
|
||||
// the platform ever echoing secret material back into the
|
||||
// message bus.
|
||||
keySet := map[string]struct{}{}
|
||||
if rows, err := db.DB.QueryContext(ctx, `SELECT key FROM global_secrets`); err == nil {
|
||||
if rows, err := db.GetDB().QueryContext(ctx, `SELECT key FROM global_secrets`); err == nil {
|
||||
for rows.Next() {
|
||||
var k string
|
||||
if rows.Scan(&k) == nil {
|
||||
@@ -141,7 +141,7 @@ func loadRestartContextData(ctx context.Context, workspaceID string) restartCont
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
if rows, err := db.DB.QueryContext(ctx,
|
||||
if rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key FROM workspace_secrets WHERE workspace_id = $1`, workspaceID,
|
||||
); err == nil {
|
||||
for rows.Next() {
|
||||
@@ -166,7 +166,7 @@ func waitForWorkspaceOnline(ctx context.Context, workspaceID string, timeout tim
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
var status string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&status); err == nil && status == "online" {
|
||||
return true
|
||||
|
||||
@@ -125,7 +125,7 @@ func (h *WorkspaceHandler) resolveAgentURLForRestartSignal(ctx context.Context,
|
||||
|
||||
// Cache miss — fall back to DB.
|
||||
var urlNullable *string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT url FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&urlNullable)
|
||||
if err != nil {
|
||||
|
||||
@@ -97,7 +97,7 @@ func TestRewriteForDocker_LocalhostUrlRewritten(t *testing.T) {
|
||||
// TestResolveAgentURLForRestartSignal_CacheHit verifies that a Redis-cached
|
||||
// URL is returned without hitting the DB.
|
||||
func TestResolveAgentURLForRestartSignal_CacheHit(t *testing.T) {
|
||||
_ = setupTestDB(t) // db.DB must be set before setupTestRedisWithURL
|
||||
_ = setupTestDB(t) // db.GetDB() must be set before setupTestRedisWithURL
|
||||
_ = setupTestRedisWithURL(t, "http://cached.internal:9000/agent")
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
@@ -118,7 +118,7 @@ func TestResolveAgentURLForRestartSignal_CacheHit(t *testing.T) {
|
||||
// TestResolveAgentURLForRestartSignal_DBError verifies that a DB error is
|
||||
// returned and propagated when neither Redis cache nor DB lookup succeeds.
|
||||
func TestResolveAgentURLForRestartSignal_DBError(t *testing.T) {
|
||||
mock := setupTestDB(t) // must come before setupTestRedis so db.DB is correct
|
||||
mock := setupTestDB(t) // must come before setupTestRedis so db.GetDB() is correct
|
||||
_ = setupTestRedis(t) // empty → cache miss
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
@@ -140,7 +140,7 @@ func TestResolveAgentURLForRestartSignal_DBError(t *testing.T) {
|
||||
// TestResolveAgentURLForRestartSignal_CacheMiss verifies that on Redis miss,
|
||||
// the URL is fetched from the DB and cached.
|
||||
func TestResolveAgentURLForRestartSignal_CacheMiss(t *testing.T) {
|
||||
mock := setupTestDB(t) // must come before setupTestRedis so db.DB is correct
|
||||
mock := setupTestDB(t) // must come before setupTestRedis so db.GetDB() is correct
|
||||
_ = setupTestRedis(t) // empty → cache miss
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
|
||||
@@ -40,12 +40,12 @@ func resolveRuntimeImage(ctx context.Context, runtime string) string {
|
||||
if os.Getenv("WORKSPACE_IMAGE_LOCAL_OVERRIDE") != "" {
|
||||
return ""
|
||||
}
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var digest string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT digest FROM runtime_image_pins WHERE template_name = $1`, runtime,
|
||||
).Scan(&digest)
|
||||
if err != nil {
|
||||
|
||||
@@ -44,7 +44,7 @@ func (h *ScheduleHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, name, cron_expr, timezone, prompt, enabled,
|
||||
last_run_at, next_run_at, run_count, last_status, last_error,
|
||||
source, created_at, updated_at
|
||||
@@ -127,7 +127,7 @@ func (h *ScheduleHandler) Create(c *gin.Context) {
|
||||
// source='runtime' marks this row as user-created (Canvas/API). The
|
||||
// org/import path inserts with source='template' and only refreshes
|
||||
// template-source rows on re-import (issue #24), so runtime rows survive.
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO workspace_schedules (workspace_id, name, cron_expr, timezone, prompt, enabled, next_run_at, source)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, 'runtime')
|
||||
RETURNING id
|
||||
@@ -176,7 +176,7 @@ func (h *ScheduleHandler) Update(c *gin.Context) {
|
||||
var nextRunAt *time.Time
|
||||
if body.CronExpr != nil || body.Timezone != nil {
|
||||
var currentCron, currentTZ string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT cron_expr, timezone FROM workspace_schedules WHERE id = $1 AND workspace_id = $2`,
|
||||
scheduleID, workspaceID,
|
||||
).Scan(¤tCron, ¤tTZ)
|
||||
@@ -204,7 +204,7 @@ func (h *ScheduleHandler) Update(c *gin.Context) {
|
||||
nextRunAt = &nextRun
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_schedules SET
|
||||
name = COALESCE($2, name),
|
||||
cron_expr = COALESCE($3, cron_expr),
|
||||
@@ -235,7 +235,7 @@ func (h *ScheduleHandler) Delete(c *gin.Context) {
|
||||
workspaceID := c.Param("id") // #113: bind to owning workspace to prevent IDOR
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx,
|
||||
result, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM workspace_schedules WHERE id = $1 AND workspace_id = $2`,
|
||||
scheduleID, workspaceID)
|
||||
if err != nil {
|
||||
@@ -258,7 +258,7 @@ func (h *ScheduleHandler) RunNow(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var prompt string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT prompt FROM workspace_schedules WHERE id = $1 AND workspace_id = $2`,
|
||||
scheduleID, workspaceID,
|
||||
).Scan(&prompt)
|
||||
@@ -290,7 +290,7 @@ func (h *ScheduleHandler) History(c *gin.Context) {
|
||||
// #152: include error_detail in history so UI can show why a run failed.
|
||||
// activity_logs.error_detail is populated by scheduler.fireSchedule when
|
||||
// the A2A proxy returns non-2xx or the update SQL reports an error.
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT created_at, duration_ms, status,
|
||||
COALESCE(error_detail, '') as error_detail,
|
||||
COALESCE(request_body::text, '{}') as request_body
|
||||
@@ -390,7 +390,7 @@ func (h *ScheduleHandler) Health(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, name, enabled, last_run_at, next_run_at, run_count, last_status, last_error
|
||||
FROM workspace_schedules
|
||||
WHERE workspace_id = $1
|
||||
|
||||
@@ -0,0 +1,810 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// scheduleCols is the full column set returned by List.
|
||||
var scheduleCols = []string{
|
||||
"id", "workspace_id", "name", "cron_expr", "timezone", "prompt", "enabled",
|
||||
"last_run_at", "next_run_at", "run_count", "last_status", "last_error",
|
||||
"source", "created_at", "updated_at",
|
||||
}
|
||||
|
||||
// ==================== List ====================
|
||||
|
||||
func TestScheduleHandler_List_EmptyResult(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("SELECT .+ FROM workspace_schedules WHERE workspace_id").
|
||||
WithArgs("ws-list-empty").
|
||||
WillReturnRows(sqlmock.NewRows(scheduleCols))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-list-empty"}}
|
||||
c.Request = httptest.NewRequest("GET", "/workspaces/ws-list-empty/schedules", nil)
|
||||
|
||||
handler.List(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var schedules []interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &schedules); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
}
|
||||
if len(schedules) != 0 {
|
||||
t.Errorf("expected empty list, got %d items", len(schedules))
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_List_QueryError(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("SELECT .+ FROM workspace_schedules WHERE workspace_id").
|
||||
WithArgs("ws-list-err").
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-list-err"}}
|
||||
c.Request = httptest.NewRequest("GET", "/workspaces/ws-list-err/schedules", nil)
|
||||
|
||||
handler.List(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Create ====================
|
||||
|
||||
func TestScheduleHandler_Create_MissingCronExpr(t *testing.T) {
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
// prompt only — no cron_expr
|
||||
body := []byte(`{"prompt":"do the thing"}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for missing cron_expr, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_MissingPrompt(t *testing.T) {
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
// cron_expr only — no prompt
|
||||
body := []byte(`{"cron_expr":"0 9 * * *"}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for missing prompt, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_InvalidTimezone(t *testing.T) {
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cron_expr": "0 9 * * *",
|
||||
"prompt": "do the thing",
|
||||
"timezone": "Not/A/Timezone",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for invalid timezone, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if !strings.Contains(resp["error"], "invalid timezone") {
|
||||
t.Errorf("expected 'invalid timezone' error, got: %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_InvalidCron(t *testing.T) {
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cron_expr": "not-a-cron",
|
||||
"prompt": "do the thing",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for invalid cron, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if !strings.Contains(resp["error"], "invalid request body") {
|
||||
t.Errorf("expected 'invalid request body' error, got: %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_CRLFStripped(t *testing.T) {
|
||||
// Use setupTestDBForQueueTests which sets up QueryMatcherEqual for exact
|
||||
// string matching. The INSERT statement is deterministic enough for that.
|
||||
customSqlmock := setupTestDBForQueueTests(t)
|
||||
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
// Prompt with CRLF from a Windows-committed org-template file.
|
||||
// The handler strips \r before inserting so agent doesn't see empty responses.
|
||||
promptWithCRLF := "check\r\ndocs\r\nbefore merge"
|
||||
|
||||
// The handler strips \r → query should receive the LF-only version.
|
||||
customSqlmock.ExpectQuery("INSERT INTO workspace_schedules (workspace_id, name, cron_expr, timezone, prompt, enabled, next_run_at, source) VALUES ($1, $2, $3, $4, $5, $6, $7, 'runtime') RETURNING id").
|
||||
WithArgs("ws-crlf", "", "0 9 * * *", "UTC", "check\ndocs\nbefore merge", true, sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("sched-crlf"))
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"cron_expr": "0 9 * * *",
|
||||
"prompt": promptWithCRLF,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-crlf"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-crlf/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := customSqlmock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_DefaultEnabled(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
// enabled field absent — must default to true.
|
||||
mock.ExpectQuery("INSERT INTO workspace_schedules").
|
||||
WithArgs("ws-def-enable", "", "0 9 * * *", "UTC", "do thing", true, sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("sched-enable"))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cron_expr": "0 9 * * *",
|
||||
"prompt": "do thing",
|
||||
// no "enabled" field
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-def-enable"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-def-enable/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_DefaultTimezone(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
// timezone field absent — must default to UTC.
|
||||
mock.ExpectQuery("INSERT INTO workspace_schedules").
|
||||
WithArgs("ws-def-tz", "", "0 9 * * *", "UTC", "do thing", true, sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("sched-tz"))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cron_expr": "0 9 * * *",
|
||||
"prompt": "do thing",
|
||||
// no "timezone" field
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-def-tz"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-def-tz/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_ExplicitEnabledFalse(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
enabled := false
|
||||
mock.ExpectQuery("INSERT INTO workspace_schedules").
|
||||
WithArgs("ws-dis", "", "0 9 * * *", "UTC", "do thing", enabled, sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("sched-dis"))
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"cron_expr": "0 9 * * *",
|
||||
"prompt": "do thing",
|
||||
"enabled": false,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-dis"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-dis/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_DBError(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("INSERT INTO workspace_schedules").
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cron_expr": "0 9 * * *",
|
||||
"prompt": "do thing",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-db-err"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-db-err/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 for DB error, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_NextRunAtReturned(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("INSERT INTO workspace_schedules").
|
||||
WithArgs("ws-next", "", "0 9 * * *", "UTC", "do thing", true, sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("sched-next"))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cron_expr": "0 9 * * *",
|
||||
"prompt": "do thing",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-next"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-next/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["status"] != "created" {
|
||||
t.Errorf("expected status 'created', got %v", resp["status"])
|
||||
}
|
||||
if _, ok := resp["next_run_at"]; !ok {
|
||||
t.Error("expected next_run_at in response")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Update ====================
|
||||
|
||||
func TestScheduleHandler_Update_PartialRecomputeCron(t *testing.T) {
|
||||
// Uses QueryMatcherEqual so query strings are compared verbatim — no escaping needed.
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("SELECT cron_expr, timezone FROM workspace_schedules WHERE id = $1 AND workspace_id = $2").
|
||||
WithArgs("sched-recompute-cron", "ws-1").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"cron_expr", "timezone"}).
|
||||
AddRow("0 8 * * *", "UTC"))
|
||||
|
||||
mock.ExpectExec(`UPDATE workspace_schedules SET name = COALESCE($2, name), cron_expr = COALESCE($3, cron_expr), timezone = COALESCE($4, timezone), prompt = COALESCE($5, prompt), enabled = COALESCE($6, enabled), next_run_at = COALESCE($7, next_run_at), updated_at = now() WHERE id = $1 AND workspace_id = $8`).
|
||||
WithArgs("sched-recompute-cron", nil, "0 6 * * *", nil, nil, nil, sqlmock.AnyArg(), "ws-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"cron_expr": "0 6 * * *"})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-recompute-cron"}}
|
||||
c.Request = httptest.NewRequest("PATCH", "/workspaces/ws-1/schedules/sched-recompute-cron", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Update_PartialRecomputeTimezone(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("SELECT cron_expr, timezone FROM workspace_schedules WHERE id = $1 AND workspace_id = $2").
|
||||
WithArgs("sched-recompute-tz", "ws-1").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"cron_expr", "timezone"}).
|
||||
AddRow("0 9 * * *", "UTC"))
|
||||
|
||||
mock.ExpectExec(`UPDATE workspace_schedules SET name = COALESCE($2, name), cron_expr = COALESCE($3, cron_expr), timezone = COALESCE($4, timezone), prompt = COALESCE($5, prompt), enabled = COALESCE($6, enabled), next_run_at = COALESCE($7, next_run_at), updated_at = now() WHERE id = $1 AND workspace_id = $8`).
|
||||
WithArgs("sched-recompute-tz", nil, nil, "America/New_York", nil, nil, sqlmock.AnyArg(), "ws-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"timezone": "America/New_York"})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-recompute-tz"}}
|
||||
c.Request = httptest.NewRequest("PATCH", "/workspaces/ws-1/schedules/sched-recompute-tz", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Update_InvalidTimezone(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("SELECT cron_expr, timezone FROM workspace_schedules WHERE id = $1 AND workspace_id = $2").
|
||||
WithArgs("sched-bad-tz", "ws-1").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"cron_expr", "timezone"}).
|
||||
AddRow("0 9 * * *", "UTC"))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"timezone": "Definitely/Not/Real"})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-bad-tz"}}
|
||||
c.Request = httptest.NewRequest("PATCH", "/workspaces/ws-1/schedules/sched-bad-tz", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for invalid timezone, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if !strings.Contains(resp["error"], "invalid timezone") {
|
||||
t.Errorf("expected 'invalid timezone' error, got: %v", resp)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Update_InvalidCron(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("SELECT cron_expr, timezone FROM workspace_schedules WHERE id = $1 AND workspace_id = $2").
|
||||
WithArgs("sched-bad-cron", "ws-1").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"cron_expr", "timezone"}).
|
||||
AddRow("0 9 * * *", "UTC"))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"cron_expr": "rubbish"})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-bad-cron"}}
|
||||
c.Request = httptest.NewRequest("PATCH", "/workspaces/ws-1/schedules/sched-bad-cron", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for invalid cron, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Update_NotFound(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectExec(`UPDATE workspace_schedules SET name = COALESCE($2, name), cron_expr = COALESCE($3, cron_expr), timezone = COALESCE($4, timezone), prompt = COALESCE($5, prompt), enabled = COALESCE($6, enabled), next_run_at = COALESCE($7, next_run_at), updated_at = now() WHERE id = $1 AND workspace_id = $8`).
|
||||
WithArgs("sched-missing", "renamed", nil, nil, nil, nil, nil, "ws-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0)) // no rows affected
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"name": "renamed"})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-missing"}}
|
||||
c.Request = httptest.NewRequest("PATCH", "/workspaces/ws-1/schedules/sched-missing", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 for not found, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Update_DBError(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectExec(`UPDATE workspace_schedules SET name = COALESCE($2, name), cron_expr = COALESCE($3, cron_expr), timezone = COALESCE($4, timezone), prompt = COALESCE($5, prompt), enabled = COALESCE($6, enabled), next_run_at = COALESCE($7, next_run_at), updated_at = now() WHERE id = $1 AND workspace_id = $8`).
|
||||
WithArgs("sched-update-err", "updated", nil, nil, nil, nil, nil, "ws-1").
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"name": "updated"})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-update-err"}}
|
||||
c.Request = httptest.NewRequest("PATCH", "/workspaces/ws-1/schedules/sched-update-err", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 for DB error, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Update_PromptCRLFStripped(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
// Changing prompt with CRLF → handler strips \r before the UPDATE.
|
||||
mock.ExpectExec(`UPDATE workspace_schedules SET name = COALESCE($2, name), cron_expr = COALESCE($3, cron_expr), timezone = COALESCE($4, timezone), prompt = COALESCE($5, prompt), enabled = COALESCE($6, enabled), next_run_at = COALESCE($7, next_run_at), updated_at = now() WHERE id = $1 AND workspace_id = $8`).
|
||||
WithArgs("sched-crlf-upd", nil, nil, nil, "fix\nthat", nil, nil, "ws-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"prompt": "fix\r\nthat"})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-crlf-upd"}}
|
||||
c.Request = httptest.NewRequest("PATCH", "/workspaces/ws-1/schedules/sched-crlf-upd", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Delete ====================
|
||||
|
||||
func TestScheduleHandler_Delete_Success(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectExec(`DELETE FROM workspace_schedules WHERE id = $1 AND workspace_id = $2`).
|
||||
WithArgs("sched-del", "ws-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-del"}}
|
||||
c.Request = httptest.NewRequest("DELETE", "/workspaces/ws-1/schedules/sched-del", nil)
|
||||
|
||||
handler.Delete(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Delete_NotFound(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
// IDOR guard: row belongs to different workspace → 0 rows affected → 404.
|
||||
mock.ExpectExec(`DELETE FROM workspace_schedules WHERE id = $1 AND workspace_id = $2`).
|
||||
WithArgs("sched-idor", "ws-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-idor"}}
|
||||
c.Request = httptest.NewRequest("DELETE", "/workspaces/ws-1/schedules/sched-idor", nil)
|
||||
|
||||
handler.Delete(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 for not found, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Delete_DBError(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectExec(`DELETE FROM workspace_schedules WHERE id = $1 AND workspace_id = $2`).
|
||||
WithArgs("sched-del-err", "ws-1").
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-del-err"}}
|
||||
c.Request = httptest.NewRequest("DELETE", "/workspaces/ws-1/schedules/sched-del-err", nil)
|
||||
|
||||
handler.Delete(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 for DB error, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== RunNow ====================
|
||||
|
||||
func TestScheduleHandler_RunNow_Success(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery(`SELECT prompt FROM workspace_schedules WHERE id = \$1 AND workspace_id = \$2`).
|
||||
WithArgs("sched-run-ok", "ws-1").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"prompt"}).AddRow("run this prompt"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-run-ok"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/schedules/sched-run-ok/run", nil)
|
||||
|
||||
handler.RunNow(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["status"] != "fired" {
|
||||
t.Errorf("expected status 'fired', got %v", resp["status"])
|
||||
}
|
||||
if resp["prompt"] != "run this prompt" {
|
||||
t.Errorf("expected prompt 'run this prompt', got %q", resp["prompt"])
|
||||
}
|
||||
if resp["workspace_id"] != "ws-1" {
|
||||
t.Errorf("expected workspace_id 'ws-1', got %q", resp["workspace_id"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_RunNow_NotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery(`SELECT prompt FROM workspace_schedules WHERE id = \$1 AND workspace_id = \$2`).
|
||||
WithArgs("sched-run-missing", "ws-1").
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-run-missing"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/schedules/sched-run-missing/run", nil)
|
||||
|
||||
handler.RunNow(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 for not found, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_RunNow_DBError(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery(`SELECT prompt FROM workspace_schedules WHERE id = \$1 AND workspace_id = \$2`).
|
||||
WithArgs("sched-run-err", "ws-1").
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-run-err"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/schedules/sched-run-err/run", nil)
|
||||
|
||||
handler.RunNow(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 for DB error, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== History ====================
|
||||
|
||||
func TestScheduleHandler_History_EmptyResult(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery(`SELECT created_at, duration_ms, status`).
|
||||
WithArgs("ws-hist-empty", "sched-hist-empty").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"created_at", "duration_ms", "status", "error_detail", "request_body"}))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-hist-empty"}, {Key: "scheduleId", Value: "sched-hist-empty"}}
|
||||
c.Request = httptest.NewRequest("GET", "/workspaces/ws-hist-empty/schedules/sched-hist-empty/history", nil)
|
||||
|
||||
handler.History(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var entries []interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &entries)
|
||||
if len(entries) != 0 {
|
||||
t.Errorf("expected empty history, got %d entries", len(entries))
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_History_QueryError(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery(`SELECT created_at, duration_ms, status`).
|
||||
WithArgs("ws-hist-err", "sched-hist-err").
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-hist-err"}, {Key: "scheduleId", Value: "sched-hist-err"}}
|
||||
c.Request = httptest.NewRequest("GET", "/workspaces/ws-hist-err/schedules/sched-hist-err/history", nil)
|
||||
|
||||
handler.History(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 on query error, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_History_MultipleEntries(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
now := time.Now()
|
||||
cols := []string{"created_at", "duration_ms", "status", "error_detail", "request_body"}
|
||||
mock.ExpectQuery(`SELECT created_at, duration_ms, status`).
|
||||
WithArgs("ws-hist-multi", "sched-hist-multi").
|
||||
WillReturnRows(sqlmock.NewRows(cols).
|
||||
AddRow(now, 1200, "ok", "", `{"schedule_id":"sched-hist-multi"}`).
|
||||
AddRow(now, 3500, "error", "HTTP 502 — upstream timeout", `{"schedule_id":"sched-hist-multi"}`))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-hist-multi"}, {Key: "scheduleId", Value: "sched-hist-multi"}}
|
||||
c.Request = httptest.NewRequest("GET", "/workspaces/ws-hist-multi/schedules/sched-hist-multi/history", nil)
|
||||
|
||||
handler.History(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var entries []map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &entries)
|
||||
if len(entries) != 2 {
|
||||
t.Errorf("expected 2 entries, got %d: %s", len(entries), w.Body.String())
|
||||
}
|
||||
if entries[1]["error_detail"] != "HTTP 502 — upstream timeout" {
|
||||
t.Errorf("expected error_detail on second entry, got: %v", entries[1]["error_detail"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func (h *SecretsHandler) List(c *gin.Context) {
|
||||
wsKeys := map[string]bool{}
|
||||
secrets := make([]map[string]interface{}, 0)
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, created_at, updated_at FROM workspace_secrets WHERE workspace_id = $1 ORDER BY key`,
|
||||
workspaceID)
|
||||
if err != nil {
|
||||
@@ -63,9 +63,12 @@ func (h *SecretsHandler) List(c *gin.Context) {
|
||||
"updated_at": updatedAt,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("List workspace secrets iteration error: %v", err)
|
||||
}
|
||||
|
||||
// 2. Global secrets not overridden at workspace level
|
||||
globalRows, err := db.DB.QueryContext(ctx,
|
||||
globalRows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, created_at, updated_at FROM global_secrets ORDER BY key`)
|
||||
if err != nil {
|
||||
log.Printf("List global secrets (merged) error: %v", err)
|
||||
@@ -91,6 +94,9 @@ func (h *SecretsHandler) List(c *gin.Context) {
|
||||
"updated_at": updatedAt,
|
||||
})
|
||||
}
|
||||
if err := globalRows.Err(); err != nil {
|
||||
log.Printf("List global secrets iteration error: %v", err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, secrets)
|
||||
}
|
||||
@@ -121,7 +127,7 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
// Auth gate (Phase 30.1/30.2): enforce the bearer token when the
|
||||
// workspace has any live token on file. Grandfather legacy workspaces
|
||||
// through so a rolling upgrade doesn't lock them out.
|
||||
hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if hlErr != nil {
|
||||
// DB hiccup checking token existence — the handler's security
|
||||
// posture is "fail closed" here because unlike heartbeat, we're
|
||||
@@ -137,7 +143,7 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, tok); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return
|
||||
}
|
||||
@@ -151,7 +157,7 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
// instead of returning a partial bundle that boots a broken agent.
|
||||
var failedKeys []string
|
||||
|
||||
globalRows, gErr := db.DB.QueryContext(ctx,
|
||||
globalRows, gErr := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, encrypted_value, encryption_version FROM global_secrets`)
|
||||
if gErr == nil {
|
||||
defer globalRows.Close()
|
||||
@@ -174,9 +180,12 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
out[k] = string(decrypted)
|
||||
}
|
||||
}
|
||||
if err := globalRows.Err(); err != nil {
|
||||
log.Printf("secrets.Values: global rows iteration error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
wsRows, wErr := db.DB.QueryContext(ctx,
|
||||
wsRows, wErr := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, encrypted_value, encryption_version FROM workspace_secrets WHERE workspace_id = $1`,
|
||||
workspaceID)
|
||||
if wErr == nil {
|
||||
@@ -195,6 +204,9 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
out[k] = string(decrypted) // workspace override wins over global
|
||||
}
|
||||
}
|
||||
if err := wsRows.Err(); err != nil {
|
||||
log.Printf("secrets.Values: workspace rows iteration error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(failedKeys) > 0 {
|
||||
@@ -238,7 +250,7 @@ func (h *SecretsHandler) Set(c *gin.Context) {
|
||||
// also rewrites the version — re-setting a secret while encryption
|
||||
// is enabled upgrades a historical plaintext row to AES-GCM.
|
||||
version := crypto.CurrentEncryptionVersion()
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_secrets (workspace_id, key, encrypted_value, encryption_version)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (workspace_id, key) DO UPDATE
|
||||
@@ -268,7 +280,7 @@ func (h *SecretsHandler) Delete(c *gin.Context) {
|
||||
key := c.Param("key")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx,
|
||||
result, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM workspace_secrets WHERE workspace_id = $1 AND key = $2`,
|
||||
workspaceID, key)
|
||||
if err != nil {
|
||||
@@ -301,7 +313,7 @@ func (h *SecretsHandler) Delete(c *gin.Context) {
|
||||
// ListGlobal handles GET /admin/secrets
|
||||
func (h *SecretsHandler) ListGlobal(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, created_at, updated_at FROM global_secrets ORDER BY key`)
|
||||
if err != nil {
|
||||
log.Printf("List global secrets error: %v", err)
|
||||
@@ -324,6 +336,9 @@ func (h *SecretsHandler) ListGlobal(c *gin.Context) {
|
||||
"scope": "global",
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("ListGlobal iteration error: %v", err)
|
||||
}
|
||||
c.JSON(http.StatusOK, secrets)
|
||||
}
|
||||
|
||||
@@ -347,7 +362,7 @@ func (h *SecretsHandler) SetGlobal(c *gin.Context) {
|
||||
}
|
||||
|
||||
globalVersion := crypto.CurrentEncryptionVersion()
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO global_secrets (key, encrypted_value, encryption_version)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
@@ -379,7 +394,7 @@ func (h *SecretsHandler) restartAllAffectedByGlobalKey(key string) {
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id FROM workspaces
|
||||
WHERE status NOT IN ('removed', 'paused')
|
||||
AND COALESCE(runtime, '') <> 'external'
|
||||
@@ -400,6 +415,9 @@ func (h *SecretsHandler) restartAllAffectedByGlobalKey(key string) {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("restartAllAffectedByGlobalKey: iteration error: %v", err)
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -414,7 +432,7 @@ func (h *SecretsHandler) DeleteGlobal(c *gin.Context) {
|
||||
key := c.Param("key")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx,
|
||||
result, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM global_secrets WHERE key = $1`, key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to delete"})
|
||||
@@ -446,7 +464,7 @@ func (h *SecretsHandler) GetModel(c *gin.Context) {
|
||||
// Check if MODEL_PROVIDER secret exists
|
||||
var modelBytes []byte
|
||||
var modelVersion int
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT encrypted_value, encryption_version FROM workspace_secrets WHERE workspace_id = $1 AND key = 'MODEL_PROVIDER'`,
|
||||
workspaceID).Scan(&modelBytes, &modelVersion)
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -477,7 +495,7 @@ func (h *SecretsHandler) GetModel(c *gin.Context) {
|
||||
// the gin handler re-adds that after a successful write.
|
||||
func setModelSecret(ctx context.Context, workspaceID, model string) error {
|
||||
if model == "" {
|
||||
_, err := db.DB.ExecContext(ctx,
|
||||
_, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM workspace_secrets WHERE workspace_id = $1 AND key = 'MODEL_PROVIDER'`,
|
||||
workspaceID)
|
||||
return err
|
||||
@@ -487,7 +505,7 @@ func setModelSecret(ctx context.Context, workspaceID, model string) error {
|
||||
return err
|
||||
}
|
||||
version := crypto.CurrentEncryptionVersion()
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_secrets (workspace_id, key, encrypted_value, encryption_version)
|
||||
VALUES ($1, 'MODEL_PROVIDER', $2, $3)
|
||||
ON CONFLICT (workspace_id, key) DO UPDATE
|
||||
@@ -561,7 +579,7 @@ func (h *SecretsHandler) GetProvider(c *gin.Context) {
|
||||
|
||||
var bytesVal []byte
|
||||
var version int
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT encrypted_value, encryption_version FROM workspace_secrets WHERE workspace_id = $1 AND key = 'LLM_PROVIDER'`,
|
||||
workspaceID).Scan(&bytesVal, &version)
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -594,7 +612,7 @@ func (h *SecretsHandler) GetProvider(c *gin.Context) {
|
||||
// the gin handler re-adds that after a successful write.
|
||||
func setProviderSecret(ctx context.Context, workspaceID, provider string) error {
|
||||
if provider == "" {
|
||||
_, err := db.DB.ExecContext(ctx,
|
||||
_, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM workspace_secrets WHERE workspace_id = $1 AND key = 'LLM_PROVIDER'`,
|
||||
workspaceID)
|
||||
return err
|
||||
@@ -604,7 +622,7 @@ func setProviderSecret(ctx context.Context, workspaceID, provider string) error
|
||||
return err
|
||||
}
|
||||
version := crypto.CurrentEncryptionVersion()
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_secrets (workspace_id, key, encrypted_value, encryption_version)
|
||||
VALUES ($1, 'LLM_PROVIDER', $2, $3)
|
||||
ON CONFLICT (workspace_id, key) DO UPDATE
|
||||
|
||||
@@ -52,7 +52,7 @@ func (h *SocketHandler) HandleConnect(c *gin.Context) {
|
||||
// Authenticate workspace agents (not canvas browser clients).
|
||||
if workspaceID != "" {
|
||||
ctx := c.Request.Context()
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("wsauth: WebSocket HasAnyLiveToken(%s) failed: %v", workspaceID, err)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "auth check failed"})
|
||||
@@ -64,7 +64,7 @@ func (h *SocketHandler) HandleConnect(c *gin.Context) {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, tok); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func (h *SSEHandler) StreamEvents(c *gin.Context) {
|
||||
|
||||
// Verify the workspace exists — 404 early rather than serving an empty stream.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`,
|
||||
workspaceID,
|
||||
).Scan(&exists); err != nil {
|
||||
|
||||
@@ -193,7 +193,7 @@ func (h *TemplatesHandler) ReplaceFiles(c *gin.Context) {
|
||||
|
||||
ctx := c.Request.Context()
|
||||
var wsName, instanceID, runtime string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, COALESCE(instance_id, ''), COALESCE(runtime, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&wsName, &instanceID, &runtime); err != nil {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user