#!/usr/bin/env python3
"""
Mavis (MiniMax Agent) -> Litefuse observability hook, v2.

Implements litefuse-agent-trace-spec.md v1.2:
- one user turn = one trace, named "Mavis <Agent> — Turn N"
- single AGENT root span, all observations flat under it
- one GENERATION per LLM call, named plan (n tools) #N / response / think #N
- one TOOL span per tool execution, named "tool: <name> (<info>) #N"
- real usage/cost from ~/.mavis/sqlite.db token_usage (no token estimation)
- each span sent exactly once; PostToolUse only records tool data, spans are emitted at turn finalization
- trace header attributes ride on every span
- zero dependencies: bare OTLP/HTTP JSON, stdlib only

Architecture: Mavis hook events are the *triggers*; the Mavis sqlite DB
(session_messages + token_usage) is the *data plane*. Hook payloads alone
carry no usage/model/message boundaries, the DB has all of it.

Events: SessionStart, UserPromptSubmit, MessageComplete,
        PreToolUse, PostToolUse, SessionEnd (fires at each turn end)

Invocation contract (unchanged from v1, hook .md files need no edits):
    litefuse_hook.py <EventName>   with {"input":..., "output":...} on stdin

Environment (LITEFUSE_* primary, LANGFUSE_* fallback):
    TRACE_TO_LITEFUSE              "false" disables (default: on if keys set)
    LITEFUSE_PUBLIC_KEY / LITEFUSE_SECRET_KEY
    LITEFUSE_HOST | LITEFUSE_BASE_URL   default https://litefuse.cloud
    LITEFUSE_TRACING_ENVIRONMENT        default "production"
    LITEFUSE_USER_ID
    LITEFUSE_EXTRA_TARGETS         JSON [{host, public_key, secret_key, environment}]
    MAVIS_LITEFUSE_DEBUG           "true" for verbose logging
    MAVIS_LITEFUSE_MAX_CHARS       truncation threshold, default 1_000_000
    MAVIS_LITEFUSE_DB              override sqlite path (testing)
"""

import base64
import fcntl
import getpass
import json
import os
import re
import secrets
import sqlite3
import sys
import time
import urllib.request
from datetime import datetime, timezone
from pathlib import Path

# Filesystem layout: everything this hook writes lives under ~/.mavis/hooks.
_HOOK_DIR = Path(os.path.expanduser("~/.mavis/hooks"))
_STATE_DIR = _HOOK_DIR / "litefuse_state"  # per-session state and lock files
_LOG_FILE = _HOOK_DIR / "litefuse_hook.log"  # file appended to by log()
_ENV_FILE = Path(os.path.expanduser("~/.mavis/.env"))  # optional KEY=VALUE file

# Import-time side effect: create the state directory if it is missing.
_STATE_DIR.mkdir(parents=True, exist_ok=True)

DEBUG = False  # resolved after env load
MAX_CHARS = 1_000_000  # truncation threshold applied by truncate_text()


# ---------------------------------------------------------------------------
# env + logging
# ---------------------------------------------------------------------------
def _load_env_file(path: Path):
    """Load ``KEY=VALUE`` pairs from *path* into ``os.environ``.

    Defensive .env load: hook .md scripts already `source ~/.mavis/.env`,
    but keep working if invoked another way. Never overrides existing env.

    Blank lines, ``#`` comments and lines without ``=`` are ignored. A
    leading ``export`` keyword (valid in a file meant to be sourced by a
    shell) is dropped from the key. Surrounding double/single quotes are
    stripped from the value. The loader is best-effort: a missing or
    unreadable file is silently ignored and never raises.
    """
    try:
        if not path.exists():
            return
        # Decode explicitly so the result does not depend on the locale;
        # undecodable bytes are replaced instead of aborting the whole load.
        text = path.read_text(encoding="utf-8", errors="replace")
    except Exception:
        return
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        key = key.strip()
        words = key.split()
        if len(words) == 2 and words[0] == "export":
            key = words[1]  # "export FOO=bar" -> key "FOO"
        value = value.strip().strip('"').strip("'")
        if not key or key in os.environ:
            continue
        try:
            os.environ[key] = value
        except (ValueError, OSError):
            # An invalid name/value must not stop later lines from loading.
            continue


def env(*names, default=None):
    """Return the first non-empty environment value among *names*.

    Names are tried in the order given (callers list the LITEFUSE_* name
    first and its fallbacks after it). Unset and empty variables are both
    skipped; *default* is returned when none of them has a value.
    """
    candidates = (os.environ.get(name) for name in names)
    return next((value for value in candidates if value), default)


def log(msg: str, level: str = "INFO"):
    """Append one timestamped line to the hook log file.

    ERROR lines are always written; every other level is written only when
    DEBUG is on. Logging is best-effort: any failure is swallowed so the
    hook itself never breaks because of it.
    """
    if not (DEBUG or level == "ERROR"):
        return
    try:
        # millisecond precision: drop the last three microsecond digits
        stamp = datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
        with open(_LOG_FILE, "a") as handle:
            handle.write(f"[{stamp}] [{level}] {msg}\n")
    except Exception:
        pass


# ---------------------------------------------------------------------------
# targets
# ---------------------------------------------------------------------------
def get_targets():
    """Return the export targets as ``[{host, auth, environment}]``.

    An empty list disables tracing. That happens when TRACE_TO_LITEFUSE is
    "false" (case-insensitive) or when no usable credentials are configured.

    The primary target comes from the LITEFUSE_* variables (LANGFUSE_* as
    fallback) and is only added when both keys are set. Additional targets
    come from LITEFUSE_EXTRA_TARGETS, a JSON list of objects with ``host``,
    ``public_key``, ``secret_key`` and optional ``environment``. Each extra
    entry is validated on its own, so one malformed entry is skipped (and
    logged) without discarding the entries after it.
    """
    if env("TRACE_TO_LITEFUSE", default="true").lower() == "false":
        return []

    def _basic_auth(public_key, secret_key) -> str:
        # HTTP Basic credential: base64("public:secret")
        return base64.b64encode(f"{public_key}:{secret_key}".encode()).decode()

    targets = []
    pk = env("LITEFUSE_PUBLIC_KEY", "LANGFUSE_PUBLIC_KEY")
    sk = env("LITEFUSE_SECRET_KEY", "LANGFUSE_SECRET_KEY")
    host = env("LITEFUSE_HOST", "LITEFUSE_BASE_URL", "LANGFUSE_HOST",
               "LANGFUSE_BASE_URL", default="https://litefuse.cloud")
    environment = env("LITEFUSE_TRACING_ENVIRONMENT",
                      "LANGFUSE_TRACING_ENVIRONMENT", default="production")
    if pk and sk:
        targets.append({
            "host": host.rstrip("/"),
            "auth": _basic_auth(pk, sk),
            "environment": environment,
        })

    extra = env("LITEFUSE_EXTRA_TARGETS")
    if not extra:
        return targets
    try:
        entries = json.loads(extra)
    except ValueError as e:
        log(f"bad LITEFUSE_EXTRA_TARGETS: {e}", "ERROR")
        return targets
    if not isinstance(entries, list):
        log("bad LITEFUSE_EXTRA_TARGETS: expected a JSON list", "ERROR")
        return targets
    for entry in entries:
        if not isinstance(entry, dict):
            log(f"bad LITEFUSE_EXTRA_TARGETS entry (not an object): {entry!r}", "ERROR")
            continue
        e_host = entry.get("host")
        e_pk = entry.get("public_key")
        e_sk = entry.get("secret_key")
        if not (e_pk and e_sk and e_host):
            continue  # incomplete entries are skipped silently
        if not isinstance(e_host, str):
            log(f"bad LITEFUSE_EXTRA_TARGETS entry (host is not a string): {e_host!r}", "ERROR")
            continue
        targets.append({
            "host": e_host.rstrip("/"),
            "auth": _basic_auth(e_pk, e_sk),
            # a missing, null or empty environment inherits the primary one
            "environment": entry.get("environment") or environment,
        })
    return targets


# ---------------------------------------------------------------------------
# state (per session, flock-guarded)
# ---------------------------------------------------------------------------
def state_path(session_id: str) -> Path:
    """Return the JSON state file path for *session_id* inside the state dir.

    Every character outside ``A-Z a-z 0-9 _ . -`` is replaced with ``_`` so
    the session id is usable as a single file name.
    """
    filename = re.sub(r"[^A-Za-z0-9_.-]", "_", session_id) + ".json"
    return _STATE_DIR / filename


class SessionLock:
    """Exclusive per-session ``flock`` lock, used as a context manager.

    The lock file sits next to the session's state file (same name, ``.lock``
    suffix). Entering blocks until the exclusive lock is acquired; leaving
    releases it and closes the file handle.
    """

    def __init__(self, session_id: str):
        self._path = state_path(session_id).with_suffix(".lock")
        self._fh = None  # open lock-file handle while the lock is held

    def __enter__(self):
        fh = open(self._path, "w")
        try:
            fcntl.flock(fh, fcntl.LOCK_EX)
        except BaseException:
            # do not leak the descriptor when locking fails or is interrupted
            fh.close()
            raise
        self._fh = fh
        return self

    def __exit__(self, *exc):
        fh, self._fh = self._fh, None
        if fh is None:
            return
        try:
            fcntl.flock(fh, fcntl.LOCK_UN)
        except Exception:
            pass
        finally:
            # always close, even if the explicit unlock failed; closing the
            # descriptor releases the lock as well
            try:
                fh.close()
            except Exception:
                pass


def load_state(session_id: str) -> dict:
    """Return the persisted state dict for *session_id*.

    Returns ``{}`` when the state file is missing, unreadable, not valid
    JSON, or holds valid JSON that is not an object -- callers rely on the
    result being a dict (they call ``.get`` / ``.setdefault`` on it).
    """
    p = state_path(session_id)
    try:
        # read directly instead of exists()-then-read: no race with a
        # concurrent delete, and a missing file is just an OSError
        loaded = json.loads(p.read_text())
    except (OSError, ValueError):
        return {}
    return loaded if isinstance(loaded, dict) else {}


def save_state(session_id: str, state: dict):
    """Persist *state* as JSON for *session_id*.

    The JSON is written to a sibling ``.tmp`` file which is then renamed
    over the real state file. Failures are logged at ERROR level and never
    raised.
    """
    try:
        target = state_path(session_id)
        scratch = target.with_suffix(".tmp")
        serialized = json.dumps(state, ensure_ascii=False)
        scratch.write_text(serialized)
        scratch.rename(target)
    except Exception as e:
        log(f"save_state failed: {e}", "ERROR")


def clear_state(session_id: str):
    """Delete the state file and the lock file of *session_id*.

    Best-effort: files that do not exist are fine and every error is
    ignored, separately for each of the two files.
    """
    for extension in (".json", ".lock"):
        try:
            doomed = state_path(session_id).with_suffix(extension)
            doomed.unlink(missing_ok=True)
        except Exception:
            pass


def is_child_session(session_id: str) -> bool:
    """Tell whether *session_id* names an opencode-internal child session.

    Subagent sessions spawned by the task tool use ``ses_*`` ids; Mavis
    daemon sessions use ``mvs_*``. Child tool hooks still fire through the
    opencode proxy, but their spans belong to the parent's subtree (built at
    the parent's PostToolUse), so emitting them from the child would create
    orphan traces.
    """
    child_prefix = "ses_"
    return session_id.startswith(child_prefix)


# ---------------------------------------------------------------------------
# mavis sqlite (read-only)
# ---------------------------------------------------------------------------
def db_connect():
    """Open the Mavis sqlite database read-only and return the connection.

    The path comes from MAVIS_LITEFUSE_DB when set, otherwise
    ``~/.mavis/sqlite.db``. The connection uses a SQLite URI with
    ``mode=ro`` and a 3 second busy timeout. The caller owns the connection
    and is responsible for closing it.
    """
    from urllib.parse import quote

    db = os.environ.get("MAVIS_LITEFUSE_DB") or os.path.expanduser("~/.mavis/sqlite.db")
    # Percent-encode the path: inside a URI a raw '?', '#' or '%' would be
    # read as query/fragment/escape syntax and silently open another file
    # (and could drop the mode=ro parameter).
    return sqlite3.connect(f"file:{quote(db)}?mode=ro", uri=True, timeout=3)


def db_user_turn_count(session_id: str) -> int:
    """Return how many ``role='user'`` rows session_messages holds for the session.

    Returns 0 (and logs at ERROR level) when the database cannot be read.
    """
    try:
        conn = db_connect()
        try:
            row = conn.execute(
                "SELECT COUNT(*) FROM session_messages WHERE session_id=? AND role='user'",
                (session_id,)).fetchone()
        finally:
            # `with connection:` only commits/rolls back; close explicitly
            conn.close()
        return row[0] if row else 0
    except Exception as e:
        log(f"db_user_turn_count: {e}", "ERROR")
        return 0


def db_max_rowid(session_id: str) -> int:
    """Return the largest session_messages ``id`` stored for the session.

    Returns 0 when the session has no rows, and 0 (logged at ERROR level)
    when the database cannot be read.
    """
    try:
        conn = db_connect()
        try:
            row = conn.execute(
                "SELECT COALESCE(MAX(id),0) FROM session_messages WHERE session_id=?",
                (session_id,)).fetchone()
        finally:
            # `with connection:` only commits/rolls back; close explicitly
            conn.close()
        return row[0] if row else 0
    except Exception as e:
        log(f"db_max_rowid: {e}", "ERROR")
        return 0


def db_turn_messages(session_id: str, after_rowid: int) -> list:
    """Assistant messages of the current turn, parsed, in arrival order.

    Reads the ``role='assistant'`` rows of the session whose id is greater
    than *after_rowid*, decodes each row's JSON ``data`` and keeps the
    objects whose ``msg_type`` is 1 or 2. Each kept dict gets a ``_rowid``
    key holding the row id. Rows that are not valid JSON, or whose JSON is
    not an object, are skipped individually. On a database error the
    failure is logged and whatever was collected so far is returned.
    """
    out = []
    try:
        conn = db_connect()
        try:
            rows = conn.execute(
                "SELECT id, data FROM session_messages "
                "WHERE session_id=? AND id>? AND role='assistant' ORDER BY id",
                (session_id, after_rowid)).fetchall()
        finally:
            # `with connection:` only commits/rolls back; close explicitly
            conn.close()
        for rowid, data in rows:
            try:
                d = json.loads(data)
            except (TypeError, ValueError):
                continue
            # a non-object payload must not abort the remaining rows
            if isinstance(d, dict) and d.get("msg_type") in (1, 2):
                d["_rowid"] = rowid
                out.append(d)
    except Exception as e:
        log(f"db_turn_messages: {e}", "ERROR")
    return out


def db_child_messages(session_id: str) -> list:
    """All parsed messages of a (subagent) session, in arrival order.

    Reads every session_messages row of the session (any role), decodes the
    JSON ``data`` and keeps the objects whose ``msg_type`` is 1 or 2. Each
    kept dict gets a ``_rowid`` key holding the row id. Rows that are not
    valid JSON, or whose JSON is not an object, are skipped individually.
    On a database error the failure is logged and whatever was collected so
    far is returned.
    """
    out = []
    try:
        conn = db_connect()
        try:
            rows = conn.execute(
                "SELECT id, data FROM session_messages WHERE session_id=? ORDER BY id",
                (session_id,)).fetchall()
        finally:
            # `with connection:` only commits/rolls back; close explicitly
            conn.close()
        for rowid, data in rows:
            try:
                d = json.loads(data)
            except (TypeError, ValueError):
                continue
            # a non-object payload must not abort the remaining rows
            if isinstance(d, dict) and d.get("msg_type") in (1, 2):
                d["_rowid"] = rowid
                out.append(d)
    except Exception as e:
        log(f"db_child_messages: {e}", "ERROR")
    return out


def db_usage_for(msg_ids: list) -> dict:
    """msg_id -> token_usage row dict.

    Looks up token_usage rows whose ``turn_id`` is one of *msg_ids* and
    returns ``{turn_id: {"model", "ts", "input", "output", "reasoning",
    "cache_read", "cache_write", "cost_usd"}}``. Ids without a usage row are
    simply absent. Returns ``{}`` for an empty input or (after logging) when
    the database cannot be read.
    """
    if not msg_ids:
        return {}
    ids = list(msg_ids)
    # Stay well below SQLite's bound-parameter limit (999 in older builds)
    # by querying in batches instead of one unbounded IN (...) list.
    batch_size = 900
    out = {}
    try:
        conn = db_connect()
        try:
            rows = []
            for start in range(0, len(ids), batch_size):
                batch = ids[start:start + batch_size]
                qs = ",".join("?" * len(batch))
                rows.extend(conn.execute(
                    f"SELECT turn_id, model, ts, input_tokens, output_tokens, "
                    f"reasoning_tokens, cache_read_tokens, cache_write_tokens, cost_usd "
                    f"FROM token_usage WHERE turn_id IN ({qs})", batch).fetchall())
        finally:
            # `with connection:` only commits/rolls back; close explicitly
            conn.close()
        for r in rows:
            out[r[0]] = {
                "model": r[1], "ts": r[2], "input": r[3], "output": r[4],
                "reasoning": r[5], "cache_read": r[6], "cache_write": r[7],
                "cost_usd": r[8],
            }
    except Exception as e:
        log(f"db_usage_for: {e}", "ERROR")
    return out


# ---------------------------------------------------------------------------
# naming helpers
# ---------------------------------------------------------------------------
# Command words that _bash_keyinfo() skips over when it looks for the name
# to show in a bash tool span (wrappers / builtins rather than the program).
_SHELL_PREFIXES = {"cd", "export", "sudo", "env", "nohup", "time", "exec",
                   "command", "builtin", "set", "source", "."}

# opencode-native delegation tools: the child runs as an internal ses_*
# session that never crosses the Mavis hook bridge — its trace subtree is
# reconstructed from sqlite when the delegation tool returns (spec §2.5)
DELEGATION_TOOLS = {"task"}


def parse_child_session_id(result_text: str):
    """Extract a child ``ses_*`` session id from a tool result text.

    An explicit ``task_id: ses_...`` marker wins; otherwise the first
    standalone ``ses_`` token followed by at least 10 alphanumerics is used.
    Returns ``None`` when neither is present (or the text is empty/None).
    """
    haystack = result_text or ""
    patterns = (
        r"task_id:\s*(ses_[A-Za-z0-9]+)",
        r"\b(ses_[A-Za-z0-9]{10,})\b",
    )
    for pattern in patterns:
        found = re.search(pattern, haystack)
        if found:
            return found.group(1)
    return None


def _bash_keyinfo(command: str) -> str:
    """Pick a short identifying word out of a bash command line.

    Only the first segment of the command is considered (everything before
    the first ``;``, ``|``, ``&`` or newline). Within it, tokens are skipped
    when they look like ``NAME=value`` assignments, start with ``-`` or a
    quote, or have a basename listed in _SHELL_PREFIXES. The basename of the
    first remaining token is returned, or "" when nothing qualifies.
    """
    # maxsplit is passed by keyword: passing it positionally to re.split is
    # deprecated since Python 3.13.
    first = re.split(r"[;|&\n]", command or "", maxsplit=1)[0].strip()
    for tok in first.split():
        if "=" in tok and not tok.startswith("="):
            continue  # FOO=bar env assignment
        if tok.startswith(("-", '"', "'")):
            continue  # option flag or quoted argument
        name = os.path.basename(tok)
        if name in _SHELL_PREFIXES:
            continue
        return name
    return ""


def tool_keyinfo(tool_name: str, args: dict) -> str:
    """Return a short (max 24 chars) label describing a tool call.

    The label depends on the tool: the command word for ``bash``, the file
    basename for ``read``/``write``/``edit``, the pattern for
    ``grep``/``glob``, the directory name for ``ls``, the host for
    ``webfetch``, a todo count for ``todowrite``, the skill identifier for
    ``skill``; any other tool uses the first non-empty well-known argument.
    Returns "" when *args* is not a dict, nothing suitable is found, or any
    error occurs.
    """
    def first_truthy(*keys):
        # value of the first key holding a truthy value, else None
        for key in keys:
            if args.get(key):
                return args[key]
        return None

    def raw_label():
        if tool_name == "bash":
            return _bash_keyinfo(str(args.get("command", "")))
        if tool_name in ("read", "write", "edit"):
            target = first_truthy("filePath", "file_path", "path") or ""
            return os.path.basename(str(target))
        if tool_name in ("grep", "glob"):
            return str(args.get("pattern", ""))
        if tool_name == "ls":
            directory = str(args.get("path") or "")
            return os.path.basename(directory.rstrip("/")) or directory or ""
        if tool_name == "webfetch":
            host = re.match(r"https?://([^/]+)", str(args.get("url", "")))
            return host.group(1) if host else ""
        if tool_name == "todowrite":
            todos = args.get("todos")
            if isinstance(todos, list):
                return f"{len(todos)} todos"
            return ""
        if tool_name == "skill":
            return str(first_truthy("name", "skill", "id") or "")
        # unknown tool: fall back to the first populated generic argument
        generic = first_truthy("filePath", "file_path", "path", "name",
                               "pattern", "query", "url", "command")
        if generic is None:
            return ""
        text = str(generic)
        return os.path.basename(text) if "/" in text else text

    try:
        if not isinstance(args, dict):
            return ""
        label = (raw_label() or "").strip()
        return label[:24]
    except Exception:
        return ""


def tool_span_name(tool_name: str, args: dict, step: int) -> str:
    """Build the span name ``tool: <name> (<info>) #<step>``.

    The ``(<info>)`` part is left out when tool_keyinfo() yields nothing.
    """
    info = tool_keyinfo(tool_name, args)
    detail = f" ({info})" if info else ""
    return f"tool: {tool_name}{detail} #{step}"


def split_model(model: str):
    """Split ``'provider/name'`` into ``(provider, name)``.

    'minimax/MiniMax-M2.7' -> ('minimax', 'MiniMax-M2.7'). Only the first
    slash separates. Without a slash (or with an empty/None model) the
    provider defaults to 'minimax' and the name is the input, or "".
    """
    if not model or "/" not in model:
        return ("minimax", model or "")
    provider, _sep, name = model.partition("/")
    return provider, name


def agent_title(agent_name: str) -> str:
    """Return the display title ``"Mavis <Agent>"`` for an agent name.

    The name is capitalized with ``str.capitalize`` (first letter upper,
    rest lower); an empty or None name becomes "Agent".
    """
    label = agent_name if agent_name else "agent"
    return "Mavis " + label.capitalize()


# ---------------------------------------------------------------------------
# truncation
# ---------------------------------------------------------------------------
def truncate_text(value, meta: dict, key_prefix: str):
    """Cap *value* at MAX_CHARS characters, recording truncation in *meta*.

    Non-string values are measured by their JSON encoding. When the text is
    within the limit the original *value* is returned untouched (so a
    non-string stays a non-string). Otherwise the first MAX_CHARS characters
    of the text are returned as a string and *meta* gains
    ``<key_prefix>_truncated`` (True) and ``<key_prefix>_orig_len``.
    """
    if isinstance(value, str):
        rendered = value
    else:
        rendered = json.dumps(value, ensure_ascii=False)
    size = len(rendered)
    if size <= MAX_CHARS:
        return value
    meta[f"{key_prefix}_truncated"] = True
    meta[f"{key_prefix}_orig_len"] = size
    return rendered[:MAX_CHARS]


# ---------------------------------------------------------------------------
# OTLP
# ---------------------------------------------------------------------------
def _attr(key, value):
    """Wrap *key*/*value* as one OTLP JSON attribute with a ``stringValue``."""
    wrapped = {"stringValue": value}
    return {"key": key, "value": wrapped}


def _ms_to_ns(ms) -> int:
    """Convert milliseconds to integer nanoseconds (*ms* is passed through ``int()`` first)."""
    ns_per_ms = 10 ** 6
    whole_ms = int(ms)
    return whole_ms * ns_per_ms


def _iso(ms) -> str:
    """Format an epoch timestamp in milliseconds as a UTC ISO-8601 string."""
    moment = datetime.fromtimestamp(ms / 1000, tz=timezone.utc)
    return moment.isoformat()


def build_span(turn: dict, *, span_id: str, parent_span_id, name: str,
               obs_type: str, start_ms: int, end_ms: int,
               input_val=None, output_val=None, model: str = None,
               usage_details: dict = None, cost_details: dict = None,
               level: str = None, status_message: str = None,
               metadata: dict = None, completion_start_ms: int = None,
               trace_output=None, session_id: str = None,
               user_id: str = None) -> dict:
    """One OTLP span dict (without langfuse.environment; appended per target).

    Args:
        turn: turn record; ``trace_id`` and ``trace_name`` are required,
            ``prompt``, ``tags`` and ``trace_metadata`` are optional.
        span_id / parent_span_id: hex ids; ``parentSpanId`` is omitted when
            *parent_span_id* is falsy.
        name, obs_type: span name and ``langfuse.observation.type``.
        start_ms / end_ms: epoch milliseconds; a non-positive duration is
            widened to 1 ms.
        input_val / output_val: observation input/output; non-strings are
            JSON-encoded, oversized values truncated (flags go to metadata).
        model, usage_details, cost_details, level, status_message,
        metadata: optional observation attributes, added only when truthy;
            *status_message* is capped at 500 characters.
        completion_start_ms: added only when strictly inside the span.
        trace_output: optional trace-level output (JSON-encoded if not str).
        session_id / user_id: added as ``session.id`` / ``user.id`` only
            when provided.

    Returns:
        The span as a JSON-serializable dict. ``status.code`` is 2 when
        *level* is "ERROR", otherwise 1.
    """
    def as_text(v):
        # attribute values are stringValues: serialize anything that is not text
        return v if isinstance(v, str) else json.dumps(v, ensure_ascii=False)

    if end_ms <= start_ms:
        end_ms = start_ms + 1
    meta = dict(metadata or {})
    attrs = [_attr("langfuse.observation.type", obs_type)]
    if input_val is not None:
        v = truncate_text(input_val, meta, "agent_input")
        attrs.append(_attr("langfuse.observation.input", as_text(v)))
    if output_val is not None:
        v = truncate_text(output_val, meta, "agent_output")
        attrs.append(_attr("langfuse.observation.output", as_text(v)))
    if model:
        attrs.append(_attr("langfuse.observation.model.name", model))
    if usage_details:
        attrs.append(_attr("langfuse.observation.usage_details", json.dumps(usage_details)))
    if cost_details:
        attrs.append(_attr("langfuse.observation.cost_details", json.dumps(cost_details)))
    if level:
        attrs.append(_attr("langfuse.observation.level", level))
    if status_message:
        attrs.append(_attr("langfuse.observation.status_message", status_message[:500]))
    if completion_start_ms and start_ms < completion_start_ms < end_ms:
        attrs.append(_attr("langfuse.observation.completion_start_time",
                           json.dumps(_iso(completion_start_ms))))
    if meta:
        # emitted after input/output so truncation flags are included
        attrs.append(_attr("langfuse.observation.metadata", json.dumps(meta, ensure_ascii=False)))

    # trace header rides on every span (spec §7.3)
    attrs.append(_attr("langfuse.trace.name", turn["trace_name"]))
    tmeta = {}  # truncation flags of trace-level fields are not reported
    attrs.append(_attr("langfuse.trace.input",
                       as_text(truncate_text(turn.get("prompt") or "", tmeta, "agent_input"))))
    if trace_output is not None:
        # serialize like observation output: a dict/list must not be placed
        # into a stringValue as-is
        attrs.append(_attr("langfuse.trace.output",
                           as_text(truncate_text(trace_output, tmeta, "agent_output"))))
    attrs.append(_attr("langfuse.trace.tags", json.dumps(turn.get("tags") or [])))
    if turn.get("trace_metadata"):
        attrs.append(_attr("langfuse.trace.metadata",
                           json.dumps(turn["trace_metadata"], ensure_ascii=False)))
    # omit identity attributes that were not supplied instead of sending a
    # null stringValue
    if session_id is not None:
        attrs.append(_attr("session.id", str(session_id)))
    if user_id is not None:
        attrs.append(_attr("user.id", str(user_id)))

    return {
        "traceId": turn["trace_id"],
        "spanId": span_id,
        **({"parentSpanId": parent_span_id} if parent_span_id else {}),
        "name": name,
        "kind": 1,
        "startTimeUnixNano": str(_ms_to_ns(start_ms)),
        "endTimeUnixNano": str(_ms_to_ns(end_ms)),
        "attributes": attrs,
        "status": {"code": 2 if level == "ERROR" else 1},
    }


def send_spans(spans: list):
    """POST spans to every target, fail-open, short timeout.

    Each target gets its own copy of the spans with a
    ``langfuse.environment`` attribute appended; the input spans are not
    modified. A failure for one target is logged at ERROR level and does
    not stop delivery to the others.
    """
    if not spans:
        return
    for target in get_targets():
        try:
            env_attr = _attr("langfuse.environment", target["environment"])
            stamped = [
                {**span, "attributes": span["attributes"] + [env_attr]}
                for span in spans
            ]
            resource = {"attributes": [_attr("service.name", "mavis-agent")]}
            scope_spans = [{"scope": {"name": "mavis-litefuse-hook"},
                            "spans": stamped}]
            payload = {"resourceSpans": [{"resource": resource,
                                          "scopeSpans": scope_spans}]}
            request = urllib.request.Request(
                f"{target['host']}/api/public/otel/v1/traces",
                data=json.dumps(payload).encode(),
                method="POST",
                headers={"Content-Type": "application/json",
                         "Authorization": f"Basic {target['auth']}"})
            with urllib.request.urlopen(request, timeout=10) as resp:
                log(f"sent {len(spans)} span(s) -> {target['host']} HTTP {resp.status}")
        except Exception as e:
            log(f"send to {target['host']} failed: {e}", "ERROR")


# ---------------------------------------------------------------------------
# turn helpers
# ---------------------------------------------------------------------------
def get_user_id() -> str:
    """Return the user id to attach to traces.

    Preference order: LITEFUSE_USER_ID, the USER environment variable, the
    login name from ``getpass.getuser()``, and finally "mavis-user" when
    the lookup raises.
    """
    configured = env("LITEFUSE_USER_ID")
    if configured:
        return configured
    try:
        login = os.environ.get("USER")
        return login if login else getpass.getuser()
    except Exception:
        return "mavis-user"


def new_turn(state: dict, session_id: str, prompt: str) -> dict:
    """Create a fresh turn record, store it as ``state["turn"]`` and return it.

    The turn number is one more than the larger of the user-message count
    in sqlite and ``state["turns_emitted"]``: the sqlite count handles
    resumed sessions, the state count keeps numbering monotonic if the db
    ever becomes unreadable. New random trace and root-span ids are
    generated, and the current highest message row id is remembered as
    ``start_rowid``.
    """
    counted_in_db = db_user_turn_count(session_id)
    counted_in_state = state.get("turns_emitted", 0)
    number = max(counted_in_db, counted_in_state) + 1
    title = agent_title(state.get("agent_name") or "agent")
    turn = dict(
        number=number,
        trace_id=secrets.token_hex(16),
        root_span_id=secrets.token_hex(8),
        trace_name=f"{title} — Turn {number}",
        start_ms=int(time.time() * 1000),
        prompt=prompt,
        step=0,
        model=state.get("model") or "",
        tools={},
        tags=[],
        trace_metadata={},
        final_content=None,
        final_content_ms=None,
    )
    turn["start_rowid"] = db_max_rowid(session_id)
    state["turn"] = turn
    return turn


def turn_tags(state: dict, turn: dict) -> list:
    """Return the trace tags: ``mavis-<agent>`` plus ``model:<name>`` when known.

    The model is taken from the turn first, then from the session state;
    only the part after the provider prefix is used.
    """
    agent = state.get("agent_name") or "agent"
    model_ref = turn.get("model") or state.get("model") or ""
    model_name = split_model(model_ref)[1]
    tags = [f"mavis-{agent}"]
    if model_name:
        tags.append(f"model:{model_name}")
    return tags


def litefuse_session_id(state: dict, session_id: str) -> str:
    """Return the session id under which traces are grouped.

    Subagent sessions group under the parent's session view (spec §1.2):
    when the state records a ``parent_session_id`` that id is used,
    otherwise the session's own id.
    """
    parent = state.get("parent_session_id")
    if parent:
        return parent
    return session_id


def base_trace_metadata(state: dict, turn: dict, session_id: str) -> dict:
    """Build the trace-level metadata dict for a turn.

    Always contains the turn number, session id, working directory and the
    model split into name and provider (turn model first, session model as
    fallback). When the state records a parent session, the parent id and
    an ``agent_subagent`` flag are added.
    """
    model_ref = turn.get("model") or state.get("model") or ""
    provider, model = split_model(model_ref)
    md = dict(
        agent_turn_number=turn["number"],
        agent_session_id=session_id,
        agent_cwd=state.get("workspace_dir") or "",
        agent_model=model,
        agent_provider=provider,
    )
    parent = state.get("parent_session_id")
    if parent:
        md.update(agent_parent_session_id=parent, agent_subagent=True)
    return md


def ensure_turn(state: dict, session_id: str) -> dict:
    """Return the open turn of *state*, creating an implicit one if needed.

    Defensive: hooks may fire on a session whose prompt we never saw. In
    that case a warning is logged and a new turn with an empty prompt is
    started.
    """
    current = state.get("turn")
    if not current:
        log(f"no open turn for {session_id}, creating implicit one", "WARN")
        current = new_turn(state, session_id, "")
    return current


# ---------------------------------------------------------------------------
# event handlers
# ---------------------------------------------------------------------------
def handle_session_start(inp: dict):
    """SessionStart: record session identity in the per-session state.

    Stores the session id, agent name, workspace directory and parent
    session id. Values missing from the payload keep what the state already
    had (agent name falls back to "agent", workspace to ""). An existing
    open turn is left alone; a ``turn`` key is created as None if absent.
    """
    session_id = inp.get("sessionId") or "unknown"
    with SessionLock(session_id):
        state = load_state(session_id)
        state["session_id"] = session_id
        state["agent_name"] = (inp.get("agentName")
                               or state.get("agent_name") or "agent")
        state["workspace_dir"] = (inp.get("workspaceDir")
                                  or state.get("workspace_dir") or "")
        state["parent_session_id"] = (inp.get("parentSessionId")
                                      or state.get("parent_session_id"))
        if "turn" not in state:
            state["turn"] = None
        save_state(session_id, state)
    log(f"SessionStart {session_id} agent={state['agent_name']} "
        f"parent={state.get('parent_session_id')}")


def handle_user_prompt_submit(inp: dict):
    """UserPromptSubmit: open a new turn for the submitted prompt.

    Ignored for child (``ses_*``) sessions. If a turn is still open -- which
    means the previous SessionEnd never fired -- it is finalized with reason
    "superseded" before the new turn is created.
    """
    session_id = inp.get("sessionId") or "unknown"
    if is_child_session(session_id):
        return
    with SessionLock(session_id):
        state = load_state(session_id)
        if "session_id" not in state:
            state["session_id"] = session_id
        if "agent_name" not in state:
            state["agent_name"] = inp.get("agentName") or "agent"
        stale_turn = state.get("turn")
        if stale_turn:
            log(f"open turn superseded by new prompt, finalizing ({session_id})", "WARN")
            finalize_turn(state, session_id, reason="superseded")
        new_turn(state, session_id, inp.get("prompt") or "")
        save_state(session_id, state)
    log(f"UserPromptSubmit {session_id} turn={state['turn']['number']}")


def handle_pre_tool_use(inp: dict):
    """PreToolUse: remember when (and with what args) a tool call started.

    For child (``ses_*``) sessions only the start time is stored, under
    ``child_tools``, for the parent's subtree rebuild. For regular sessions
    the call is recorded in the open turn (created implicitly if missing)
    and the model from the payload, if any, is remembered on both the turn
    and the session. A call id that is already recorded is left untouched.
    """
    session_id = inp.get("sessionId") or "unknown"
    call_id = inp.get("toolCallId") or ""
    now_ms = int(time.time() * 1000)

    if is_child_session(session_id):
        # only record real tool timings for the parent's subtree rebuild
        with SessionLock(session_id):
            state = load_state(session_id)
            child_tools = state.setdefault("child_tools", {})
            child_tools[call_id] = {"start_ms": now_ms}
            save_state(session_id, state)
        return

    with SessionLock(session_id):
        state = load_state(session_id)
        state.setdefault("session_id", session_id)
        state.setdefault("agent_name", inp.get("agentName") or "agent")
        turn = ensure_turn(state, session_id)
        model = inp.get("model")
        if model:
            turn["model"] = model
            state["model"] = model
        already_seen = call_id in turn["tools"]
        if not already_seen:
            # record real wall-clock timing only; numbering and emission
            # happen at turn finalization from the authoritative sqlite
            # message order
            raw_args = inp.get("toolArgs")
            turn["tools"][call_id] = {
                "name": inp.get("toolName") or "?",
                "start_ms": now_ms,
                "args": raw_args if isinstance(raw_args, dict) else {},
            }
        save_state(session_id, state)
    if already_seen:
        return
    log(f"PreToolUse {session_id} {inp.get('toolName')} call={call_id}")


def _result_text(tool_result):
    """Render a tool result as text: "" for None, strings as-is, else JSON."""
    if isinstance(tool_result, str):
        return tool_result
    if tool_result is None:
        return ""
    return json.dumps(tool_result, ensure_ascii=False)


def _result_is_error(tool_result) -> bool:
    """Heuristically decide whether a tool result represents a failure.

    A dict with a truthy ``isError``, ``is_error`` or ``error`` entry is an
    error. Otherwise the first 200 characters of the result text are
    checked (case-insensitively): an error is assumed when they start with
    "error", "exception" or "traceback", or contain ``"success": false`` or
    ``"iserror": true``.
    """
    if isinstance(tool_result, dict):
        flags = ("isError", "is_error", "error")
        if any(tool_result.get(flag) for flag in flags):
            return True
    head = _result_text(tool_result)[:200].lower()
    if head.startswith(("error", "exception", "traceback")):
        return True
    return '"success": false' in head or '"iserror": true' in head


def handle_post_tool_use(inp: dict):
    """PostToolUse: record the end time, error flag and result of a tool call.

    For child (``ses_*``) sessions only the end time is stored under
    ``child_tools``. For regular sessions nothing is emitted here: all
    spans are emitted at turn finalization, where step numbering can follow
    the authoritative message order in sqlite (the proxy hook events race
    far ahead of the daemon's message persistence). If the matching
    PreToolUse was never seen, an entry starting "now" is created.
    """
    session_id = inp.get("sessionId") or "unknown"
    call_id = inp.get("toolCallId") or ""
    now_ms = int(time.time() * 1000)

    if is_child_session(session_id):
        with SessionLock(session_id):
            state = load_state(session_id)
            timing = state.setdefault("child_tools", {}).setdefault(call_id, {})
            timing["end_ms"] = now_ms
            save_state(session_id, state)
        return

    with SessionLock(session_id):
        state = load_state(session_id)
        state.setdefault("session_id", session_id)
        state.setdefault("agent_name", inp.get("agentName") or "agent")
        turn = ensure_turn(state, session_id)
        tc = turn["tools"].get(call_id)
        if tc is None:
            # PreToolUse missed (hook installed mid-turn / daemon dedup quirk)
            raw_args = inp.get("toolArgs")
            tc = {
                "name": inp.get("toolName") or "?",
                "start_ms": now_ms,
                "args": raw_args if isinstance(raw_args, dict) else {},
            }
            turn["tools"][call_id] = tc
        tool_result = inp.get("toolResult")
        result_text = _result_text(tool_result)
        failed = _result_is_error(tool_result)
        tc["end_ms"] = now_ms
        tc["is_error"] = failed
        tc["result"] = result_text[:MAX_CHARS]
        save_state(session_id, state)
    log(f"PostToolUse recorded {tc['name']} ({now_ms - tc['start_ms']}ms)")


def handle_message_complete(inp: dict):
    """MessageComplete: store the message content and its arrival time on the open turn.

    Ignored for child (``ses_*``) sessions and when no turn is open.
    """
    session_id = inp.get("sessionId") or "unknown"
    if is_child_session(session_id):
        return
    content = inp.get("content") or ""
    with SessionLock(session_id):
        state = load_state(session_id)
        turn = state.get("turn")
        if not turn:
            return
        turn["final_content"] = content
        turn["final_content_ms"] = int(time.time() * 1000)
        save_state(session_id, state)
    log(f"MessageComplete {session_id} len={len(content)}")


def handle_session_end(inp: dict):
    """SessionEnd: finalize the open turn and send its spans.

    Ignored for child (``ses_*``) sessions. The turn is finalized and the
    state saved while holding the session lock; the network send happens
    after the lock is released. Without an open turn this only logs.
    """
    session_id = inp.get("sessionId") or "unknown"
    reason = inp.get("reason") or ""
    if is_child_session(session_id):
        return
    with SessionLock(session_id):
        state = load_state(session_id)
        if not state.get("turn"):
            log(f"SessionEnd {session_id} reason={reason}: no open turn")
            return
        spans = finalize_turn(state, session_id, reason=reason, send=False)
        save_state(session_id, state)
    emitted = len(spans or [])
    if spans:
        send_spans(spans)
    log(f"SessionEnd {session_id} reason={reason} emitted={emitted}")


# ---------------------------------------------------------------------------
# subagent subtree (spec §2.5): delegation tool span -> container -> steps
# ---------------------------------------------------------------------------
def build_subagent_subtree(turn: dict, lf_session: str, user_id: str,
                           parent_span_id: str, child_id: str,
                           task_input, depth: int = 0) -> list:
    """Reconstruct a finished ses_* child session from sqlite as a subtree
    under the delegation tool span. Steps renumber from #1 per container.

    Tool start/end inside the child come from the timings stored in the
    child's own state (``child_tools``) when present; otherwise they are
    estimated from message timestamps and the span is flagged with
    ``agent_times_estimated``.

    Args:
        turn: the parent turn dict, passed through to ``build_span``.
        lf_session: session id attached to every span.
        user_id: user id attached to every span.
        parent_span_id: span id of the delegation tool span the container
            hangs under.
        child_id: id of the child session to reconstruct.
        task_input: input shown on the container and on the child's first
            generation.
        depth: current nesting level; recursion stops once it exceeds 3.

    Returns:
        A list of spans (generations, tools, nested subtrees, and finally the
        container itself), or ``[]`` when the depth limit is exceeded or the
        child has no messages. On success the child's stored state is cleared.
    """
    if depth > 3:
        return []
    msgs = db_child_messages(child_id)
    if not msgs:
        log(f"subagent {child_id}: no messages in sqlite, skipping subtree", "WARN")
        return []
    usage_map = db_usage_for([m["msg_id"] for m in msgs if m.get("msg_id")])
    # real tool timings recorded live by the child's own (suppressed) hooks
    child_timings = (load_state(child_id) or {}).get("child_tools", {})

    container_id = secrets.token_hex(8)
    # Only real timestamps take part in the min: a single message without one
    # must not drag the container start to 0 (and from there to "now").
    msg_stamps = [m["timestamp"] for m in msgs if m.get("timestamp")]
    first_ts = min(msg_stamps) if msg_stamps else int(time.time() * 1000)
    # first_ts is included so the container can never end before it starts
    last_ts = max(msg_stamps
                  + [u["ts"] for u in usage_map.values() if u.get("ts")]
                  + [first_ts])
    content_msgs = [m for m in msgs
                    if m.get("msg_type") == 1 and (m.get("msg_content") or "").strip()]
    response_msg_id = content_msgs[-1]["msg_id"] if content_msgs else None
    final_output = content_msgs[-1].get("msg_content") if content_msgs else None

    spans = []
    step = 0
    api_calls = 0
    tool_count = 0
    prev_end = first_ts
    gen_input = task_input
    for i, m in enumerate(msgs):
        mid = m.get("msg_id")
        tool_calls = m.get("tool_calls") or []
        has_text = bool((m.get("msg_content") or "").strip())
        has_thinking = bool(m.get("thinking_content"))
        if not (tool_calls or has_text or has_thinking):
            continue  # nothing to show for this message -> takes no step number
        step += 1
        gen_step = step
        if tool_calls:
            n = len(tool_calls)
            name = f"plan ({n} tool{'s' if n != 1 else ''}) #{gen_step}"
        elif has_text and mid == response_msg_id:
            name = "subagent response"
        elif has_text:
            name = f"generation #{gen_step}"
        else:
            name = f"think #{gen_step}"
        u = usage_map.get(mid)
        # end: usage ts, else the message ts, else the end of the container
        end_ms = (u or {}).get("ts") or m.get("timestamp") or last_ts
        # start: where the previous step ended, but always before end_ms
        start_ms = min(prev_end, end_ms - 1)
        next_ts = (msgs[i + 1].get("timestamp") if i + 1 < len(msgs) else None) or last_ts
        meta = {
            "agent_step_index": gen_step,
            "agent_provider": split_model((u or {}).get("model") or "")[0],
            "agent_input_is_delta": True,
        }
        if m.get("finish_reason"):
            meta["agent_stop_reason"] = m["finish_reason"]
        if tool_calls:
            meta["agent_tool_call_count"] = len(tool_calls)
        if has_thinking:
            meta["agent_thinking_chars"] = len(m["thinking_content"])
        spans.append(build_span(
            turn,
            span_id=secrets.token_hex(8),
            parent_span_id=container_id,
            name=name,
            obs_type="generation",
            start_ms=start_ms, end_ms=end_ms,
            input_val=gen_input,
            output_val=_gen_output_blocks(m),
            model=split_model((u or {}).get("model") or "")[1] or None,
            usage_details=_usage_details(u),
            cost_details={"total": u["cost_usd"]} if u and u.get("cost_usd") else None,
            metadata=meta,
            completion_start_ms=m.get("timestamp"),
            session_id=lf_session, user_id=user_id,
        ))
        api_calls += 1
        tool_results = []
        for tc in tool_calls:
            step += 1
            tool_count += 1
            args = tc.get("tool_call_args")
            try:
                args = json.loads(args) if isinstance(args, str) else (args or {})
            except Exception:
                args = {"_raw": tc.get("tool_call_args")}
            result = tc.get("tool_call_result_data") or ""
            is_error = _result_is_error(result)
            tname = tc.get("tool_name") or "?"
            tool_span_id = secrets.token_hex(8)
            grandchild = None
            if tname in DELEGATION_TOOLS:
                # str(): same normalisation finalize_turn applies to results
                grandchild = parse_child_session_id(str(result))
            t_name = (f"tool (1 subagent) #{step}" if grandchild
                      else tool_span_name(tname, args, step))
            timing = child_timings.get(tc.get("tool_call_id") or "", {})
            # prefer recorded timings; fall back to the gap between this
            # generation's end and the next message
            t_start = timing.get("start_ms") or end_ms
            t_end = timing.get("end_ms") or next_ts
            t_meta = {
                "agent_tool_name": tname,
                "agent_tool_call_id": tc.get("tool_call_id") or "",
                "agent_step_index": step,
                "agent_plan_step": gen_step,
                "agent_is_error": is_error,
            }
            if timing.get("start_ms") and timing.get("end_ms"):
                t_meta["agent_duration_ms"] = t_end - t_start
            else:
                t_meta["agent_times_estimated"] = True
            spans.append(build_span(
                turn,
                span_id=tool_span_id,
                parent_span_id=container_id,
                name=t_name,
                obs_type="tool",
                start_ms=t_start, end_ms=t_end,
                input_val=args,
                output_val=str(result),
                level="ERROR" if is_error else None,
                status_message=str(result)[:500] if is_error else None,
                metadata=t_meta,
                session_id=lf_session, user_id=user_id,
            ))
            if grandchild:
                # args may decode to a non-dict JSON value (list, string, ...);
                # only a dict can carry a "prompt"
                sub_task = args.get("prompt") if isinstance(args, dict) else None
                spans.extend(build_subagent_subtree(
                    turn, lf_session, user_id, tool_span_id, grandchild,
                    sub_task or args, depth + 1))
            tool_results.append({"tool_call_id": tc.get("tool_call_id"),
                                 "tool_name": tname,
                                 "result": str(result)[:4000]})
        # input delta for the next generation: the tool results it will see,
        # or the text that was just produced
        if tool_results:
            gen_input = {"tool_results": tool_results}
        elif has_text:
            gen_input = m.get("msg_content") or ""
        prev_end = next_ts

    # the container is appended last, after all of its children
    spans.append(build_span(
        turn,
        span_id=container_id,
        parent_span_id=parent_span_id,
        name="subagent",
        obs_type="agent",
        start_ms=first_ts, end_ms=last_ts,
        input_val=task_input,
        output_val=final_output,
        metadata={
            "agent_subagent": True,
            "agent_session_id": child_id,
            "agent_api_calls": api_calls,
            "agent_tool_calls": tool_count,
            "agent_steps": step,
            "agent_duration_ms": last_ts - first_ts,
        },
        session_id=lf_session, user_id=user_id,
    ))
    clear_state(child_id)
    return spans


# ---------------------------------------------------------------------------
# turn finalization: generations + root from sqlite, single send each
# ---------------------------------------------------------------------------
def _gen_output_blocks(msg: dict) -> list:
    """Turn one assistant message into a list of typed output blocks.

    Order is fixed: an optional ``thinking`` block, an optional ``text``
    block, then one ``tool_call`` block per entry in ``tool_calls``. Tool
    arguments given as a string are JSON-decoded when possible and kept as
    the raw string otherwise.
    """
    thinking = msg.get("thinking_content")
    text = msg.get("msg_content")

    out = []
    if thinking:
        out.append({"type": "thinking", "thinking": thinking})
    if text:
        out.append({"type": "text", "text": text})

    for call in msg.get("tool_calls") or []:
        call_input = call.get("tool_call_args")
        if isinstance(call_input, str):
            try:
                call_input = json.loads(call_input)
            except Exception:
                pass  # not valid JSON: pass the raw string through
        out.append({
            "type": "tool_call",
            "id": call.get("tool_call_id"),
            "name": call.get("tool_name"),
            "input": call_input,
        })
    return out


def _usage_details(u: dict) -> dict:
    """Map a usage record onto the usage-details dict attached to a span.

    Returns ``None`` when ``u`` is empty or missing. ``input`` and ``output``
    are always present (0 when absent); the cache and reasoning counters are
    added only when they are non-zero.
    """
    if not u:
        return None
    details = {"input": u.get("input") or 0, "output": u.get("output") or 0}
    optional_fields = (
        ("cache_read", "cache_read_input_tokens"),
        ("cache_write", "cache_creation_input_tokens"),
        ("reasoning", "reasoning_tokens"),
    )
    for source_key, target_key in optional_fields:
        amount = u.get(source_key)
        if amount:
            details[target_key] = amount
    return details


def finalize_turn(state: dict, session_id: str, reason: str, send: bool = True):
    """Build every span of the open turn and close the turn in ``state``.

    Generations and their tool calls are taken from the sqlite messages of
    the turn; hook state (``turn["tools"]``, ``final_content``) contributes
    wall-clock timing, results, and a fallback when sqlite has nothing.

    Args:
        state: session state; ``state["turn"]`` must hold the open turn.
        session_id: Mavis session id used for the sqlite lookups.
        reason: why the turn ended; ``"aborted"``/``"superseded"`` mark the
            root as WARNING, ``"error"`` marks it as ERROR.
        send: when true the spans are sent before returning.

    Returns:
        The list of spans built for the turn, root agent span last.

    Side effects on ``state``: ``turn`` is set to ``None``, ``turns_emitted``
    is incremented, ``last_emitted_rowid`` is advanced, and ``model`` is
    filled from usage data when the turn had none.
    """
    turn = state["turn"]
    now_ms = int(time.time() * 1000)
    user_id = get_user_id()
    msgs = db_turn_messages(session_id, turn["start_rowid"])
    usage_map = db_usage_for([m["msg_id"] for m in msgs if m.get("msg_id")])

    # degradation detection: hooks saw activity the DB can't account for
    # (schema drift, db moved, ...) -> still emit a structurally complete
    # trace from hook data alone, flagged so dashboards can spot it
    degraded_reason = None
    if not msgs and (turn["tools"] or turn.get("final_content")):
        degraded_reason = "no_messages"
        log(f"degraded turn (sqlite returned no messages) session={session_id}", "ERROR")
    elif msgs and not usage_map:
        degraded_reason = "no_usage"
        log(f"degraded turn (sqlite returned no token_usage) session={session_id}", "ERROR")

    # adopt model from usage if hooks never told us (tool-less turns)
    if not turn.get("model"):
        for u in usage_map.values():
            if u.get("model"):
                turn["model"] = u["model"]
                state["model"] = u["model"]
                break
    provider, model_name = split_model(turn.get("model") or "")
    turn["tags"] = turn_tags(state, turn)
    turn["trace_metadata"] = base_trace_metadata(state, turn, session_id)
    lf_session = litefuse_session_id(state, session_id)

    # classify generations; response = last content-bearing msg
    content_msgs = [m for m in msgs
                    if m.get("msg_type") == 1 and (m.get("msg_content") or "").strip()]
    response_msg_id = content_msgs[-1]["msg_id"] if content_msgs else None

    # step indexes follow the authoritative message order: each LLM call gets
    # the next number, then its tool calls in order (spec §3.4). Hook events
    # only contribute wall-clock timing, matched by tool_call_id.
    gen_steps = {}
    tool_steps = {}
    for m in msgs:
        if not (m.get("tool_calls") or (m.get("msg_content") or "").strip()
                or m.get("thinking_content")):
            continue  # never emitted (spec §4.4) -> takes no number
        turn["step"] += 1
        gen_steps[m.get("msg_id")] = turn["step"]
        for tc in m.get("tool_calls") or []:
            turn["step"] += 1
            # Key with the same `or ""` normalisation the emit loop below uses
            # for its lookup; a missing id would otherwise be stored under
            # None and raise KeyError there. Calls without an id share the ""
            # slot, so they all report the step of the last such call.
            tool_steps[tc.get("tool_call_id") or ""] = turn["step"]

    step_ends = {}  # timeline for deriving generation starts
    matched_call_ids = set()

    spans = []
    last_tool_results = None  # input delta for the next generation
    prev_prompt_input = turn.get("prompt") or ""
    final_output = turn.get("final_content")

    for m in msgs:
        mid = m.get("msg_id")
        if mid not in gen_steps:
            continue  # no input and no output -> no observation (spec §4.4)
        step = gen_steps[mid]
        u = usage_map.get(mid)
        tool_calls = m.get("tool_calls") or []
        is_plan = bool(tool_calls)
        has_text = bool((m.get("msg_content") or "").strip())
        has_thinking = bool(m.get("thinking_content"))

        if is_plan:
            n = len(tool_calls)
            name = f"plan ({n} tool{'s' if n != 1 else ''}) #{step}"
        elif has_text and mid == response_msg_id:
            name = "response"
        elif has_text:
            name = f"generation #{step}"
        else:
            name = f"think #{step}"

        # end: the model has finished by the time its first tool starts (hook
        # wall-clock), else the usage flush ts, else the msg timestamp
        first_tool_start = min(
            (turn["tools"][tc.get("tool_call_id")]["start_ms"]
             for tc in tool_calls
             if tc.get("tool_call_id") in turn["tools"]), default=None)
        candidates = [t for t in (first_tool_start, (u or {}).get("ts")) if t]
        end_ms = min(candidates) if candidates else (m.get("timestamp") or now_ms)
        if name == "response" and turn.get("final_content_ms"):
            end_ms = max(end_ms, turn["final_content_ms"])
        # start: end of the latest earlier step (LLM call follows it immediately)
        prior_ends = [e for s, e in step_ends.items() if s < step] or [turn["start_ms"]]
        start_ms = max(max(prior_ends), turn["start_ms"])
        if start_ms >= end_ms:
            # keep the span at least 1 ms long without starting before the turn
            start_ms = max(turn["start_ms"], end_ms - 1)
        step_ends[step] = end_ms

        # input: what the model newly received (full request is not observable
        # from Mavis hooks — flagged via agent_input_is_delta)
        if last_tool_results is not None:
            gen_input = {"tool_results": last_tool_results}
        else:
            gen_input = prev_prompt_input
        meta = {
            "agent_turn_number": turn["number"],
            "agent_step_index": step,
            "agent_provider": provider,
            "agent_input_is_delta": True,
        }
        if m.get("finish_reason"):
            meta["agent_stop_reason"] = m["finish_reason"]
        if is_plan:
            meta["agent_tool_call_count"] = len(tool_calls)
        if has_thinking:
            meta["agent_thinking_chars"] = len(m["thinking_content"])
        if isinstance(m.get("usage"), dict) and m["usage"].get("total_tokens"):
            meta["agent_context_tokens"] = m["usage"]["total_tokens"]

        spans.append(build_span(
            turn,
            span_id=secrets.token_hex(8),
            parent_span_id=turn["root_span_id"],
            name=name,
            obs_type="generation",
            start_ms=start_ms, end_ms=end_ms,
            input_val=gen_input,
            output_val=_gen_output_blocks(m),
            model=split_model((u or {}).get("model") or turn.get("model") or "")[1] or None,
            usage_details=_usage_details(u),
            cost_details={"total": u["cost_usd"]} if u and u.get("cost_usd") else None,
            level="ERROR" if m.get("finish_reason") == "error" else None,
            metadata=meta,
            completion_start_ms=m.get("timestamp"),
            session_id=lf_session, user_id=user_id,
        ))

        if is_plan:
            last_tool_results = []
            for tc in tool_calls:
                call_id = tc.get("tool_call_id") or ""
                t_step = tool_steps[call_id]
                matched_call_ids.add(call_id)
                hook_tc = turn["tools"].get(call_id) or {}
                args = tc.get("tool_call_args")
                try:
                    args = json.loads(args) if isinstance(args, str) else (args or {})
                except Exception:
                    args = {"_raw": tc.get("tool_call_args")}
                if not args and hook_tc.get("args"):
                    args = hook_tc["args"]
                # hook-recorded result wins over the sqlite copy when present
                result = (hook_tc.get("result")
                          if hook_tc.get("result") is not None
                          else tc.get("tool_call_result_data") or "")
                is_error = hook_tc.get("is_error", _result_is_error(result))
                t_start = hook_tc.get("start_ms") or end_ms
                t_end = hook_tc.get("end_ms") or m.get("timestamp") or now_ms
                t_finished = bool(hook_tc.get("end_ms")
                                  or tc.get("tool_call_status") == 2)
                tname = tc.get("tool_name") or hook_tc.get("name") or "?"
                tool_span_id = secrets.token_hex(8)
                child_id = None
                if tname in DELEGATION_TOOLS and not is_error:
                    child_id = parse_child_session_id(str(result))
                t_meta = {
                    "agent_tool_name": tname,
                    "agent_tool_call_id": call_id,
                    "agent_step_index": t_step,
                    "agent_plan_step": step,
                    "agent_is_error": is_error,
                    "agent_turn_number": turn["number"],
                }
                if hook_tc.get("start_ms") and hook_tc.get("end_ms"):
                    t_meta["agent_duration_ms"] = t_end - t_start
                else:
                    t_meta["agent_times_estimated"] = True
                if child_id:
                    t_meta["agent_subagent_session_id"] = child_id
                level = None
                status_message = None
                if is_error:
                    level, status_message = "ERROR", str(result)[:500]
                elif not t_finished:
                    level, status_message = "WARNING", "turn ended before tool completed"
                    t_end = now_ms
                spans.append(build_span(
                    turn,
                    span_id=tool_span_id,
                    parent_span_id=turn["root_span_id"],
                    name=(f"tool (1 subagent) #{t_step}" if child_id
                          else tool_span_name(tname, args, t_step)),
                    obs_type="tool",
                    start_ms=t_start, end_ms=t_end,
                    input_val=args,
                    output_val=str(result) if result else None,
                    level=level,
                    status_message=status_message,
                    metadata=t_meta,
                    session_id=lf_session, user_id=user_id,
                ))
                if child_id:
                    task_input = args.get("prompt") if isinstance(args, dict) else None
                    spans.extend(build_subagent_subtree(
                        turn, lf_session, user_id, tool_span_id, child_id,
                        task_input or args))
                step_ends[t_step] = t_end
                r = str(result)
                last_tool_results.append({"tool_call_id": call_id,
                                          "tool_name": tname,
                                          "result": r[:4000]})
        elif has_text:
            last_tool_results = None
            prev_prompt_input = m.get("msg_content") or ""
        if mid == response_msg_id and not final_output:
            final_output = m.get("msg_content") or ""

    # degraded mode: synthesize the response generation from MessageComplete
    # (no usage/thinking available -- structure over silence)
    if degraded_reason == "no_messages" and (turn.get("final_content") or "").strip():
        turn["step"] += 1
        tool_ends = [tc["end_ms"] for tc in turn["tools"].values() if tc.get("end_ms")]
        start_ms = max(tool_ends) if tool_ends else turn["start_ms"]
        end_ms = turn.get("final_content_ms") or now_ms
        spans.append(build_span(
            turn,
            span_id=secrets.token_hex(8),
            parent_span_id=turn["root_span_id"],
            name="response",
            obs_type="generation",
            start_ms=start_ms, end_ms=end_ms,
            input_val=turn.get("prompt") or "",
            output_val=turn["final_content"],
            model=split_model(turn.get("model") or "")[1] or None,
            metadata={
                "agent_turn_number": turn["number"],
                "agent_step_index": turn["step"],
                "agent_degraded": True,
                "agent_degraded_reason": degraded_reason,
            },
            session_id=lf_session, user_id=user_id,
        ))
        final_output = turn["final_content"]

    # hook-seen tools whose message never reached sqlite: emit from hook
    # data alone (in-flight at abort, or persistence lost the message)
    for call_id, tc in sorted(turn["tools"].items(),
                              key=lambda kv: kv[1].get("start_ms") or 0):
        if call_id in matched_call_ids:
            continue
        turn["step"] += 1
        step = turn["step"]
        finished = bool(tc.get("end_ms"))
        meta = {"agent_tool_name": tc["name"], "agent_tool_call_id": call_id,
                "agent_step_index": step,
                "agent_is_error": tc.get("is_error", False),
                "agent_turn_number": turn["number"]}
        if finished:
            meta["agent_duration_ms"] = tc["end_ms"] - tc["start_ms"]
        spans.append(build_span(
            turn,
            span_id=secrets.token_hex(8),
            parent_span_id=turn["root_span_id"],
            name=tool_span_name(tc["name"], tc.get("args") or {}, step),
            obs_type="tool",
            start_ms=tc["start_ms"], end_ms=tc.get("end_ms") or now_ms,
            input_val=tc.get("args") or {},
            output_val=tc.get("result"),
            level=("ERROR" if tc.get("is_error")
                   else None if finished else "WARNING"),
            status_message=(str(tc.get("result"))[:500] if tc.get("is_error")
                            else None if finished
                            else "turn ended before tool completed"),
            metadata=meta,
            session_id=lf_session, user_id=user_id,
        ))

    # root agent span, trace input/output, turn rollup metadata
    def _count(obs_type):
        # only direct children of the root are counted, so spans nested in a
        # subagent subtree do not inflate the turn totals
        return sum(1 for s in spans
                   if s.get("parentSpanId") == turn["root_span_id"]
                   and any(a["key"] == "langfuse.observation.type"
                           and a["value"]["stringValue"] == obs_type
                           for a in s["attributes"]))
    n_tools = _count("tool")
    n_gens = _count("generation")
    root_level = None
    status_message = None
    if reason in ("aborted", "superseded") or not (final_output or "").strip():
        root_level = "WARNING"
        status_message = f"turn ended without final response (reason={reason or 'unknown'})"
    if reason == "error":
        root_level = "ERROR"
        status_message = "turn ended with error"
    root_meta = dict(turn["trace_metadata"])
    root_meta.update({
        "agent_api_calls": n_gens,
        "agent_tool_calls": n_tools,
        "agent_steps": turn["step"],
        "agent_duration_ms": now_ms - turn["start_ms"],
        "agent_message_count": len(msgs),
    })
    if degraded_reason:
        root_meta["agent_degraded"] = True
        root_meta["agent_degraded_reason"] = degraded_reason
    # latest message that reports a context window
    ctx = next((m["usage"] for m in reversed(msgs)
                if isinstance(m.get("usage"), dict) and m["usage"].get("context_window")), None)
    if ctx:
        root_meta["agent_context_tokens"] = ctx.get("total_tokens")
        root_meta["agent_context_window"] = ctx.get("context_window")
    turn["trace_metadata"] = root_meta

    spans.append(build_span(
        turn,
        span_id=turn["root_span_id"],
        parent_span_id=None,
        name=turn["trace_name"],
        obs_type="agent",
        start_ms=turn["start_ms"], end_ms=now_ms,
        input_val=turn.get("prompt") or "",
        output_val=final_output or None,
        level=root_level,
        status_message=status_message,
        metadata=root_meta,
        trace_output=final_output or "",
        session_id=lf_session, user_id=user_id,
    ))

    # close the turn in state
    state["last_emitted_rowid"] = max(
        [m["_rowid"] for m in msgs] + [turn["start_rowid"]])
    state["turns_emitted"] = state.get("turns_emitted", 0) + 1
    state["turn"] = None

    if send:
        send_spans(spans)
    return spans


# ---------------------------------------------------------------------------
# main
# ---------------------------------------------------------------------------
# Dispatch table: hook event name (argv[1]) -> handler taking the payload's
# "input" dict. main() logs any event missing from this table as unknown.
HANDLERS = {
    "SessionStart": handle_session_start,
    "UserPromptSubmit": handle_user_prompt_submit,
    "PreToolUse": handle_pre_tool_use,
    "PostToolUse": handle_post_tool_use,
    "MessageComplete": handle_message_complete,
    "SessionEnd": handle_session_end,
}


def main():
    """Hook entry point.

    Loads the env file, sets the DEBUG and MAX_CHARS globals, then reads the
    event name from argv[1] and the JSON payload from stdin and dispatches
    the payload's "input" to the matching handler. Handlers run only when
    get_targets() returns something truthy. Any exception raised during
    dispatch is logged, never propagated; the process always prints
    ``{"metadata": {}}`` and exits with status 0.
    """
    global DEBUG, MAX_CHARS
    _load_env_file(_ENV_FILE)
    DEBUG = os.environ.get("MAVIS_LITEFUSE_DEBUG", "").lower() == "true"
    raw_limit = os.environ.get("MAVIS_LITEFUSE_MAX_CHARS")
    try:
        MAX_CHARS = 1_000_000 if raw_limit is None else int(raw_limit)
    except ValueError:
        # unparsable override: fall back to the default threshold
        MAX_CHARS = 1_000_000

    # contract with mavis: always print JSON to stdout, always exit 0
    try:
        event_name = next(iter(sys.argv[1:]), "")
        envelope = json.load(sys.stdin)
        hook_input = envelope.get("input") or {}
        dispatch = HANDLERS.get(event_name)
        if not dispatch:
            log(f"unknown event {event_name}", "WARN")
        elif get_targets():
            dispatch(hook_input)
    except Exception as exc:
        import traceback
        log(f"hook error: {exc}\n{traceback.format_exc()}", "ERROR")
    print(json.dumps({"metadata": {}}))
    sys.exit(0)


# run the hook only when executed as a script, not when imported
if __name__ == "__main__":
    main()
