Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6d104c20a8 |
@@ -0,0 +1,227 @@
|
||||
"""Tests for molecule_audit.hooks — EU AI Act Art. 12 pipeline hooks.
|
||||
|
||||
Covers:
|
||||
- LedgerHooks context manager (session lifecycle)
|
||||
- on_task_start / on_llm_call / on_tool_call / on_task_end hook methods
|
||||
- _safe_append error swallowing
|
||||
- _to_bytes helper
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Ensure workspace root is on path
|
||||
_ws_root = __file__.rsplit("/tests/", 1)[0]
|
||||
if _ws_root not in sys.path:
|
||||
sys.path.insert(0, _ws_root)
|
||||
|
||||
from molecule_audit.hooks import LedgerHooks, _to_bytes
|
||||
|
||||
|
||||
class TestToBytes:
|
||||
"""Unit tests for the _to_bytes helper."""
|
||||
|
||||
def test_none_returns_none(self):
|
||||
assert _to_bytes(None) is None
|
||||
|
||||
def test_bytes_passthrough(self):
|
||||
data = b"hello"
|
||||
assert _to_bytes(data) == data
|
||||
|
||||
def test_str_encoded_utf8(self):
|
||||
assert _to_bytes("hello") == b"hello"
|
||||
assert _to_bytes("こんにちは") == "こんにちは".encode("utf-8")
|
||||
|
||||
def test_dict_json_serialized(self):
|
||||
result = _to_bytes({"key": "value", "num": 42})
|
||||
assert b'"key"' in result and b'"value"' in result and b'"num"' in result
|
||||
|
||||
def test_list_json_serialized(self):
|
||||
result = _to_bytes([1, 2, "three"])
|
||||
# JSON encodes "three" as a string, so it has quotes
|
||||
assert b'"three"' in result or b'three' in result
|
||||
assert b"1" in result
|
||||
|
||||
def test_dict_sorted_keys(self):
|
||||
"""Dicts are JSON-serialized with sorted keys for deterministic output."""
|
||||
a = _to_bytes({"b": 1, "a": 2})
|
||||
b = _to_bytes({"a": 2, "b": 1})
|
||||
assert a == b
|
||||
|
||||
|
||||
class TestLedgerHooksInit:
|
||||
"""LedgerHooks constructor and defaults."""
|
||||
|
||||
def test_session_id_required(self):
|
||||
hooks = LedgerHooks(session_id="task-123")
|
||||
assert hooks.session_id == "task-123"
|
||||
|
||||
def test_agent_id_from_env(self):
|
||||
import os
|
||||
env_id = os.environ.get("WORKSPACE_ID", "unknown-agent")
|
||||
hooks = LedgerHooks(session_id="s1")
|
||||
assert hooks.agent_id == env_id
|
||||
|
||||
def test_agent_id_override(self):
|
||||
hooks = LedgerHooks(session_id="s1", agent_id="explicit-agent")
|
||||
assert hooks.agent_id == "explicit-agent"
|
||||
|
||||
def test_db_url_stored(self):
|
||||
hooks = LedgerHooks(session_id="s1", db_url="sqlite:///:memory:")
|
||||
assert hooks._db_url == "sqlite:///:memory:"
|
||||
|
||||
def test_human_oversight_default(self):
|
||||
hooks = LedgerHooks(session_id="s1")
|
||||
assert hooks._default_human_oversight is False
|
||||
|
||||
def test_human_oversight_true(self):
|
||||
hooks = LedgerHooks(session_id="s1", human_oversight_flag=True)
|
||||
assert hooks._default_human_oversight is True
|
||||
|
||||
|
||||
class TestLedgerHooksContextManager:
|
||||
"""LedgerHooks context manager lifecycle."""
|
||||
|
||||
def test_enter_returns_self(self):
|
||||
hooks = LedgerHooks(session_id="s1")
|
||||
with hooks as entered:
|
||||
assert entered is hooks
|
||||
|
||||
def test_exit_closes_session(self):
|
||||
mock_session = MagicMock()
|
||||
hooks = LedgerHooks(session_id="s1", db_url="sqlite:///:memory:")
|
||||
# Pre-open the session via a mock
|
||||
hooks._session = mock_session
|
||||
hooks.__exit__(None, None, None)
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
def test_exit_no_session_noop(self):
|
||||
hooks = LedgerHooks(session_id="s1")
|
||||
# No session opened — should not raise
|
||||
hooks.__exit__(None, None, None)
|
||||
|
||||
|
||||
class TestLedgerHooksOpenSession:
|
||||
"""Lazy session opening."""
|
||||
|
||||
def test_opens_session_on_first_call(self):
|
||||
mock_factory = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_factory.return_value = mock_session
|
||||
|
||||
hooks = LedgerHooks(session_id="s1", db_url="sqlite:///:memory:")
|
||||
with patch("molecule_audit.hooks.get_session_factory", return_value=mock_factory):
|
||||
session = hooks._open_session()
|
||||
assert session is mock_session
|
||||
mock_factory.assert_called_once()
|
||||
|
||||
def test_reuses_same_session(self):
|
||||
mock_factory = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_factory.return_value = mock_session
|
||||
|
||||
hooks = LedgerHooks(session_id="s1", db_url="sqlite:///:memory:")
|
||||
with patch("molecule_audit.hooks.get_session_factory", return_value=mock_factory):
|
||||
s1 = hooks._open_session()
|
||||
s2 = hooks._open_session()
|
||||
assert s1 is s2
|
||||
# Factory called only once (lazy)
|
||||
assert mock_factory.call_count == 1
|
||||
|
||||
def test_close_resets_session(self):
|
||||
mock_session = MagicMock()
|
||||
hooks = LedgerHooks(session_id="s1", db_url="sqlite:///:memory:")
|
||||
hooks._session = mock_session
|
||||
hooks.close()
|
||||
assert hooks._session is None
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
|
||||
class TestLedgerHooksHookMethods:
|
||||
"""Hook methods call _safe_append with correct kwargs."""
|
||||
|
||||
def _mock_hooks(self):
|
||||
"""Return a LedgerHooks with all ledger functions mocked."""
|
||||
hooks = LedgerHooks(session_id="s1", db_url="sqlite:///:memory:")
|
||||
mock_session = MagicMock()
|
||||
hooks._session = mock_session
|
||||
return hooks, mock_session
|
||||
|
||||
def test_on_task_start_calls_append(self):
|
||||
hooks, mock_session = self._mock_hooks()
|
||||
with patch("molecule_audit.hooks.append_event") as mock_append:
|
||||
hooks.on_task_start(input_text="user prompt", risk_flag=True)
|
||||
mock_append.assert_called_once()
|
||||
call_kwargs = mock_append.call_args.kwargs
|
||||
assert call_kwargs["operation"] == "task_start"
|
||||
assert call_kwargs["human_oversight_flag"] is False
|
||||
assert call_kwargs["risk_flag"] is True
|
||||
|
||||
def test_on_task_start_respects_human_oversight_override(self):
|
||||
hooks, _ = self._mock_hooks()
|
||||
with patch("molecule_audit.hooks.append_event") as mock_append:
|
||||
hooks.on_task_start(human_oversight_flag=True)
|
||||
assert mock_append.call_args.kwargs["human_oversight_flag"] is True
|
||||
|
||||
def test_on_llm_call_includes_model(self):
|
||||
hooks, _ = self._mock_hooks()
|
||||
with patch("molecule_audit.hooks.append_event") as mock_append:
|
||||
hooks.on_llm_call(model="hermes-4-405b", input_text="prompt", output_text="response")
|
||||
call_kwargs = mock_append.call_args.kwargs
|
||||
assert call_kwargs["operation"] == "llm_call"
|
||||
assert call_kwargs["model_used"] == "hermes-4-405b"
|
||||
|
||||
def test_on_tool_call_includes_tool_name(self):
|
||||
hooks, _ = self._mock_hooks()
|
||||
with patch("molecule_audit.hooks.append_event") as mock_append:
|
||||
hooks.on_tool_call(tool_name="search", input_data={"q": "test"}, output_data=["result"])
|
||||
call_kwargs = mock_append.call_args.kwargs
|
||||
assert call_kwargs["operation"] == "tool_call"
|
||||
assert call_kwargs["model_used"] == "search"
|
||||
|
||||
def test_on_task_end_records_output(self):
|
||||
hooks, _ = self._mock_hooks()
|
||||
with patch("molecule_audit.hooks.append_event") as mock_append:
|
||||
hooks.on_task_end(output_text="final result")
|
||||
call_kwargs = mock_append.call_args.kwargs
|
||||
assert call_kwargs["operation"] == "task_end"
|
||||
|
||||
def test_hooks_use_instance_agent_id(self):
|
||||
hooks = LedgerHooks(session_id="s1", agent_id="my-workspace", db_url="sqlite:///:memory:")
|
||||
mock_session = MagicMock()
|
||||
hooks._session = mock_session
|
||||
with patch("molecule_audit.hooks.append_event") as mock_append:
|
||||
hooks.on_task_start()
|
||||
assert mock_append.call_args.kwargs["agent_id"] == "my-workspace"
|
||||
|
||||
|
||||
class TestLedgerHooksSafeAppend:
|
||||
"""_safe_append swallows all exceptions without re-raising."""
|
||||
|
||||
def test_append_error_swallowed_and_logged(self, caplog):
|
||||
hooks = LedgerHooks(session_id="s1", db_url="sqlite:///:memory:")
|
||||
mock_session = MagicMock()
|
||||
hooks._session = mock_session
|
||||
|
||||
error = RuntimeError("DB write failed")
|
||||
with patch("molecule_audit.hooks.append_event", side_effect=error):
|
||||
with caplog.at_level(logging.WARNING):
|
||||
# Should not raise
|
||||
hooks.on_task_start(input_text="test")
|
||||
assert any("DB write failed" in r.message for r in caplog.records)
|
||||
|
||||
def test_multiple_exceptions_swallowed(self):
|
||||
hooks = LedgerHooks(session_id="s1", db_url="sqlite:///:memory:")
|
||||
mock_session = MagicMock()
|
||||
hooks._session = mock_session
|
||||
|
||||
with patch("molecule_audit.hooks.append_event", side_effect=RuntimeError("err")):
|
||||
# All three calls should succeed (no exception)
|
||||
hooks.on_task_start(input_text="a")
|
||||
hooks.on_llm_call(model="m", input_text="b")
|
||||
hooks.on_tool_call(tool_name="t", input_data={})
|
||||
@@ -1,266 +0,0 @@
|
||||
"""Tests for shared_runtime helper functions.
|
||||
|
||||
Covers the untested helpers in shared_runtime.py:
|
||||
- _extract_part_text
|
||||
- extract_message_text
|
||||
- format_conversation_history
|
||||
- build_task_text
|
||||
- append_peer_guidance
|
||||
- brief_task
|
||||
|
||||
Does NOT cover set_current_task (async, covered in test_a2a_executor.py).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
# Ensure the workspace root is on the path so 'shared_runtime' resolves
|
||||
_ws_root = __file__.rsplit("/tests/", 1)[0]
|
||||
if _ws_root not in sys.path:
|
||||
sys.path.insert(0, _ws_root)
|
||||
|
||||
from shared_runtime import (
|
||||
_extract_part_text,
|
||||
extract_message_text,
|
||||
format_conversation_history,
|
||||
build_task_text,
|
||||
append_peer_guidance,
|
||||
brief_task,
|
||||
)
|
||||
|
||||
|
||||
# ─── _extract_part_text ──────────────────────────────────────────────────────
|
||||
|
||||
class TestExtractPartText:
|
||||
def test_dict_with_text(self):
|
||||
assert _extract_part_text({"text": "hello world"}) == "hello world"
|
||||
|
||||
def test_dict_with_nested_root_text(self):
|
||||
assert _extract_part_text({"root": {"text": "nested text"}}) == "nested text"
|
||||
|
||||
def test_dict_prefers_text_over_root(self):
|
||||
# When both text and root exist, text wins (outer text)
|
||||
assert _extract_part_text({"text": "outer", "root": {"text": "inner"}}) == "outer"
|
||||
|
||||
def test_dict_empty_text_and_root(self):
|
||||
assert _extract_part_text({"kind": "text"}) == ""
|
||||
|
||||
def test_dict_missing_fields(self):
|
||||
assert _extract_part_text({"kind": "image"}) == ""
|
||||
|
||||
def test_dict_mixed_with_extra_fields(self):
|
||||
assert _extract_part_text({"kind": "text", "text": "foo", "url": "http://..."}) == "foo"
|
||||
|
||||
def test_object_with_text_attribute(self):
|
||||
class PartObj:
|
||||
text = "object text"
|
||||
|
||||
assert _extract_part_text(PartObj()) == "object text"
|
||||
|
||||
def test_object_with_root_text_attribute(self):
|
||||
class RootObj:
|
||||
text = "root object text"
|
||||
|
||||
class PartObj:
|
||||
root = RootObj()
|
||||
|
||||
assert _extract_part_text(PartObj()) == "root object text"
|
||||
|
||||
def test_object_empty_text(self):
|
||||
class EmptyObj:
|
||||
text = ""
|
||||
|
||||
assert _extract_part_text(EmptyObj()) == ""
|
||||
|
||||
def test_object_no_text_or_root(self):
|
||||
class NoTextObj:
|
||||
pass
|
||||
|
||||
assert _extract_part_text(NoTextObj()) == ""
|
||||
|
||||
def test_none_like(self):
|
||||
assert _extract_part_text(None) == ""
|
||||
|
||||
|
||||
# ─── extract_message_text ────────────────────────────────────────────────────
|
||||
|
||||
class TestExtractMessageText:
|
||||
def test_list_of_dict_parts(self):
|
||||
parts = [{"text": "hello"}, {"text": "world"}]
|
||||
assert extract_message_text(parts) == "hello world"
|
||||
|
||||
def test_single_part(self):
|
||||
parts = [{"text": "only one"}]
|
||||
assert extract_message_text(parts) == "only one"
|
||||
|
||||
def test_empty_list(self):
|
||||
assert extract_message_text([]) == ""
|
||||
|
||||
def test_none_parts(self):
|
||||
assert extract_message_text(None) == ""
|
||||
|
||||
def test_object_with_message_parts(self):
|
||||
"""Object with .message.parts attribute (A2A RequestContext pattern)."""
|
||||
msg = type("Message", (), {"parts": [{"text": "from context"}, {"text": "message"}]})()
|
||||
ctx = type("Context", (), {"message": msg})()
|
||||
assert extract_message_text(ctx) == "from context message"
|
||||
|
||||
def test_joins_with_single_space(self):
|
||||
# Inter-part join uses single space; internal whitespace within parts is preserved
|
||||
parts = [{"text": "hello"}, {"text": "world"}]
|
||||
assert extract_message_text(parts) == "hello world"
|
||||
|
||||
def test_preserves_within_part_whitespace(self):
|
||||
parts = [{"text": " spaced "}, {"text": "\ttext\t"}]
|
||||
# Leading/trailing whitespace stripped; internal whitespace within parts preserved
|
||||
assert extract_message_text(parts) == "spaced \ttext"
|
||||
|
||||
def test_skips_parts_without_text(self):
|
||||
parts = [{"kind": "image"}, {"text": "visible"}, {"url": "http://x"}]
|
||||
assert extract_message_text(parts) == "visible"
|
||||
|
||||
|
||||
# ─── format_conversation_history ──────────────────────────────────────────────
|
||||
|
||||
class TestFormatConversationHistory:
|
||||
def test_empty_history(self):
|
||||
assert format_conversation_history([]) == ""
|
||||
|
||||
def test_single_user_message(self):
|
||||
result = format_conversation_history([("human", "hello")])
|
||||
assert "User: hello" in result
|
||||
|
||||
def test_single_agent_message(self):
|
||||
result = format_conversation_history([("ai", "hi there")])
|
||||
assert "Agent: hi there" in result
|
||||
|
||||
def test_interleaved_history(self):
|
||||
history = [
|
||||
("human", "first"),
|
||||
("ai", "response one"),
|
||||
("human", "second"),
|
||||
("ai", "response two"),
|
||||
]
|
||||
result = format_conversation_history(history)
|
||||
lines = result.strip().split("\n")
|
||||
assert len(lines) == 4
|
||||
assert lines[0] == "User: first"
|
||||
assert lines[1] == "Agent: response one"
|
||||
assert lines[2] == "User: second"
|
||||
assert lines[3] == "Agent: response two"
|
||||
|
||||
|
||||
# ─── build_task_text ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestBuildTaskText:
|
||||
def test_no_history_returns_user_message(self):
|
||||
assert build_task_text("hello", []) == "hello"
|
||||
|
||||
def test_history_prepends_transcript(self):
|
||||
history = [("human", "hi"), ("ai", "hello")]
|
||||
result = build_task_text("send email", history)
|
||||
assert "Conversation so far:" in result
|
||||
assert "User: hi" in result
|
||||
assert "Agent: hello" in result
|
||||
assert "Current request: send email" in result
|
||||
|
||||
def test_empty_history_returns_user_message(self):
|
||||
# Empty list should behave like no history
|
||||
assert build_task_text("hello", []) == "hello"
|
||||
|
||||
def test_single_history_entry(self):
|
||||
result = build_task_text("bye", [("human", "last")])
|
||||
assert "User: last" in result
|
||||
assert "Current request: bye" in result
|
||||
|
||||
|
||||
# ─── append_peer_guidance ─────────────────────────────────────────────────────
|
||||
|
||||
class TestAppendPeerGuidance:
|
||||
def test_no_base_text_uses_default(self):
|
||||
result = append_peer_guidance(
|
||||
None,
|
||||
"peer info here",
|
||||
default_text="default",
|
||||
tool_name="delegate_task",
|
||||
)
|
||||
assert "peer info here" in result
|
||||
assert "## Peers" in result
|
||||
assert "delegate_task" in result
|
||||
assert "default" in result
|
||||
|
||||
def test_base_text_preserved(self):
|
||||
result = append_peer_guidance(
|
||||
"my prompt",
|
||||
"peer info",
|
||||
default_text="fallback",
|
||||
tool_name="delegate_task",
|
||||
)
|
||||
assert "my prompt" in result
|
||||
assert "## Peers" in result
|
||||
|
||||
def test_empty_peers_info_skipped(self):
|
||||
result = append_peer_guidance(
|
||||
"my prompt",
|
||||
"",
|
||||
default_text="fallback",
|
||||
tool_name="delegate_task",
|
||||
)
|
||||
assert result == "my prompt"
|
||||
|
||||
def test_whitespace_trimmed(self):
|
||||
result = append_peer_guidance(
|
||||
" prompt ",
|
||||
" peers ",
|
||||
default_text="fallback",
|
||||
tool_name="delegate_task",
|
||||
)
|
||||
# Should not double-space
|
||||
assert " " not in result
|
||||
|
||||
def test_tool_name_injected(self):
|
||||
result = append_peer_guidance(
|
||||
None,
|
||||
"peer info",
|
||||
default_text="default",
|
||||
tool_name="my_tool",
|
||||
)
|
||||
assert "my_tool" in result
|
||||
|
||||
|
||||
# ─── brief_task ───────────────────────────────────────────────────────────────
|
||||
|
||||
class TestBriefTask:
|
||||
def test_short_text_unchanged(self):
|
||||
assert brief_task("hello world") == "hello world"
|
||||
|
||||
def test_exactly_at_limit(self):
|
||||
text = "a" * 60
|
||||
assert brief_task(text) == text
|
||||
|
||||
def test_over_limit_truncates(self):
|
||||
text = "a" * 100
|
||||
result = brief_task(text)
|
||||
assert len(result) == 63 # 60 + "..."
|
||||
assert result.endswith("...")
|
||||
|
||||
def test_under_limit_no_ellipsis(self):
|
||||
text = "a" * 59
|
||||
result = brief_task(text)
|
||||
assert result == text
|
||||
assert "..." not in result
|
||||
|
||||
def test_default_limit_60(self):
|
||||
text = "a" * 70
|
||||
result = brief_task(text, limit=60)
|
||||
assert len(result) == 63
|
||||
|
||||
def test_custom_limit(self):
|
||||
text = "a" * 20
|
||||
result = brief_task(text, limit=10)
|
||||
assert len(result) == 13 # 10 + "..."
|
||||
|
||||
def test_empty_string(self):
|
||||
assert brief_task("") == ""
|
||||
assert brief_task("") == "" # no ellipsis for empty
|
||||
Reference in New Issue
Block a user