#!/usr/bin/env python3
"""
Kimi Code -> Litefuse hook (v2)

解析 Kimi Code 的 wire.jsonl，为每个完整的用户回合生成一条 Litefuse trace。
遵循《Litefuse Agent 集成 Trace 规范》v1.2（litefuse-agent-trace-spec.md）：

  Kimi Code — Turn 3                    AGENT（根，回合真实 wall-clock 时长）
  ├── plan (2 tools) #1                 GENERATION（usage_details、真实延迟、TTFT）
  ├── tool: bash (grep) #2              TOOL（metadata agent_plan_step 指回 #1）
  ├── tool: read (index.ts) #3          TOOL
  └── response                          GENERATION（最终回答，无 #N，携带自身 usage）

  #N 是每回合一个按时间递增的步骤计数器，generation 与 tool 共用；
  metadata 统一打平为 agent_ 前缀（agent_step_index / agent_plan_step / …）。

传输：直发 OTLP/HTTP JSON 到 <host>/api/public/otel/v1/traces，零第三方依赖。
每个 span 恰好发送一次（只发完整回合）；trace header（name/session/user/
input/tags）随每个 span 携带。traceId/spanId 由 session+turn 确定性派生，
状态重置后重跑不会产生重复 trace。

与事件型集成（pi v2）的差异（离线解析的已知取舍）：
  - generation 的 input 是回合内重建的消息序列（user prompt + 本回合内的
    assistant/tool 消息），不含跨回合历史与 system prompt——wire.jsonl 不含
    每次 API 调用的完整请求体。metadata 标注 agent_input_scope: "turn"。
  - 时间语义：wire 的 step.end 时刻含工具执行（它是 agent step 的结束），
    generation 的真实结束 = step.begin + llmFirstTokenLatencyMs +
    llmStreamDurationMs；tool span = tool.call 时刻 → tool.result 时刻
    （含审批等待在 generation 与 tool 之间表现为时间线空隙，符合真实情况）。
  - 子 agent 子树（规范 §2.5）：若 Agent 工具结果首部带有 agent_id /
    actual_subagent_type / status，且调用方传入了 agents_dir，则读取
    <agents_dir>/<agent_id>/wire.jsonl，在委派工具 span 下挂一个 subagent
    容器（AGENT）及其步骤；否则 Agent 工具按普通 TOOL observation 记录。

回合完整性：只发已结束的回合（最后一个 step 以 end_turn 收尾、或回合被
turn.cancel 取消、或后面已有新的 turn.prompt）。仍在进行的回合留在 offset
之后，下次轮询重读；超过 KIMI_LITEFUSE_STALE_MINUTES（默认 30）无新事件的
未完回合按"被中断"发出（root WARNING）。

环境变量（LITEFUSE_* 优先，LANGFUSE_* 为生态兼容 fallback）：
  TRACE_TO_LITEFUSE=true               # 启用追踪（保持 v1 约定）
  LITEFUSE_PUBLIC_KEY / LITEFUSE_SECRET_KEY
  LITEFUSE_BASE_URL | LITEFUSE_HOST    # 默认 https://litefuse.cloud
  LITEFUSE_TRACING_ENVIRONMENT         # 默认 production
  LITEFUSE_USER_ID                     # 默认系统用户名
  LITEFUSE_EXTRA_TARGETS               # 多目标 JSON 数组（可选）
  KIMI_LITEFUSE_DEBUG=true
  KIMI_LITEFUSE_MAX_CHARS              # input/output 截断阈值，默认 1000000
  KIMI_LITEFUSE_STALE_MINUTES          # 未完回合判定中断的闲置分钟数，默认 30
  KIMI_LITEFUSE_STATE_DIR              # 状态目录覆盖（测试用）

Fail-open：任何异常只写 ~/.kimi-code/state/litefuse_hook.log 并静默返回，
永不影响 Kimi Code 本身。
"""

import base64
import getpass
import hashlib
import json
import os
import re
import sys
import time
import urllib.error
import urllib.request
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# ----------------- Env file (no shell config needed) -----------------
def _load_env_file(path: Path) -> None:
    """Populate ``os.environ`` from a simple ``KEY=VALUE`` file, if it exists.

    This lets users configure the hook without touching their shell config.

    Parsing rules:
      - blank lines, lines starting with ``#`` and lines without ``=`` are
        ignored;
      - an optional leading ``export `` (shell-style env files) is accepted;
      - one pair of *matching* surrounding quotes (``"..."`` or ``'...'``) is
        removed from the value; mismatched quotes are kept verbatim;
      - variables already present in the environment are never overridden,
        so real environment variables take precedence over the file.

    Fail-open: I/O, decoding and invalid-name errors are ignored, because a
    broken env file must never affect Kimi Code itself.
    """
    if not path.exists():
        return
    try:
        with open(path, "r", encoding="utf-8") as f:
            for raw in f:
                line = raw.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                if line.startswith("export "):
                    line = line[len("export "):].lstrip()
                key, val = line.split("=", 1)
                key, val = key.strip(), val.strip()
                # Strip only a matching pair of quotes around the value.
                if len(val) >= 2 and val[0] == val[-1] and val[0] in ("'", '"'):
                    val = val[1:-1]
                if key and key not in os.environ:
                    os.environ[key] = val
    except (OSError, ValueError):
        # ValueError also covers UnicodeDecodeError and embedded-NUL names.
        pass


_load_env_file(Path.home() / ".kimi-code" / "litefuse.env")

# ----------------- Config -----------------
STATE_DIR = Path(os.environ.get("KIMI_LITEFUSE_STATE_DIR", "") or (Path.home() / ".kimi-code" / "state"))
LOG_FILE = STATE_DIR / "litefuse_hook.log"
STATE_FILE = STATE_DIR / "litefuse_state.json"
LOCK_FILE = STATE_DIR / "litefuse_state.lock"


def _positive_number_env(name: str, default: float) -> float:
    """Read a positive, finite number from env var *name*.

    A missing, malformed (e.g. "30m"), non-positive, NaN or infinite value
    yields *default* instead of raising. This runs at import time and the
    hook is fail-open, so a bad setting must not crash it.
    """
    raw = os.environ.get(name, "").strip()
    try:
        value = float(raw)
    except ValueError:
        return default
    # NaN fails `> 0`; inf would overflow the int() conversions below.
    return value if value > 0 and value != float("inf") else default


DEBUG = os.environ.get("KIMI_LITEFUSE_DEBUG", "").strip().lower() == "true"
# input/output truncation threshold in characters (at least 1).
MAX_CHARS = max(1, int(_positive_number_env("KIMI_LITEFUSE_MAX_CHARS", 1_000_000)))
# Idle time after which an unfinished turn is emitted as interrupted.
STALE_MS = int(_positive_number_env("KIMI_LITEFUSE_STALE_MINUTES", 30) * 60_000)
HTTP_TIMEOUT_S = 10

try:
    USER_ID = os.environ.get("LITEFUSE_USER_ID") or getpass.getuser() or "kimi-user"
except Exception:
    # getpass.getuser() raises when no login name can be resolved.
    USER_ID = "kimi-user"


def _env(*names: str, default: str = "") -> str:
    """Return the first non-empty value among the env vars *names*.

    Implements the ``LITEFUSE_*`` first, ``LANGFUSE_*`` fallback lookup.
    A variable set to the empty string counts as unset; *default* is
    returned when none of the names has a value.
    """
    candidates = (os.environ.get(name) for name in names)
    return next(filter(None, candidates), default)


def load_targets() -> List[Dict[str, str]]:
    """Collect the Litefuse endpoints that spans are exported to.

    The primary target comes from ``LITEFUSE_*`` env vars (``LANGFUSE_*`` as
    a fallback). Additional targets come from ``LITEFUSE_EXTRA_TARGETS``, a
    JSON array of objects accepting camelCase or snake_case keys
    (``publicKey``/``public_key``, ``secretKey``/``secret_key``,
    ``baseUrl``/``base_url``, ``environment``).

    Targets missing either key are skipped. Duplicates (same public key and
    base URL without trailing slash) are kept once, and the first occurrence
    wins. A malformed extra-targets value is logged and ignored (fail-open).

    Returns:
        A list of dicts with ``public_key``, ``secret_key``, ``base_url``
        and ``environment``.
    """
    targets: List[Dict[str, str]] = []
    seen = set()

    def add(pk: Any, sk: Any, url: Any, env: Any) -> None:
        p, s = str(pk or ""), str(sk or "")
        if not p or not s:
            return
        u = str(url or "https://litefuse.cloud").rstrip("/")
        key = f"{p}|{u}"
        if key in seen:
            return
        seen.add(key)
        targets.append({"public_key": p, "secret_key": s, "base_url": u,
                        "environment": str(env or "default")})

    add(
        _env("LITEFUSE_PUBLIC_KEY", "LANGFUSE_PUBLIC_KEY"),
        _env("LITEFUSE_SECRET_KEY", "LANGFUSE_SECRET_KEY"),
        _env("LITEFUSE_BASE_URL", "LITEFUSE_HOST", "LANGFUSE_BASE_URL", "LANGFUSE_HOST",
             default="https://litefuse.cloud"),
        _env("LITEFUSE_TRACING_ENVIRONMENT", "LANGFUSE_TRACING_ENVIRONMENT", default="production"),
    )

    # A blank variable means "no extra targets", not a parse error.
    raw_extra = os.environ.get("LITEFUSE_EXTRA_TARGETS", "").strip() or "[]"
    try:
        extra = json.loads(raw_extra)
    except ValueError as e:
        # Fail-open, but leave a trace instead of silently dropping every target.
        _log("WARN", f"ignoring malformed LITEFUSE_EXTRA_TARGETS: {e}")
        return targets
    if not isinstance(extra, list):
        _log("WARN", "ignoring LITEFUSE_EXTRA_TARGETS: expected a JSON array")
        return targets
    for t in extra:
        if isinstance(t, dict):
            add(t.get("publicKey") or t.get("public_key"),
                t.get("secretKey") or t.get("secret_key"),
                t.get("baseUrl") or t.get("base_url"),
                t.get("environment"))
    return targets


# ----------------- Logging (fail-open) -----------------
def _log(level: str, message: str) -> None:
    """Append one ``<timestamp> [<level>] <message>`` line to LOG_FILE.

    Creates STATE_DIR when needed. Never raises: logging is best effort and
    must not affect Kimi Code (fail-open).
    """
    try:
        STATE_DIR.mkdir(parents=True, exist_ok=True)
        stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        entry = f"{stamp} [{level}] {message}\n"
        with LOG_FILE.open("a", encoding="utf-8") as handle:
            handle.write(entry)
    except Exception:
        # Deliberately swallowed: a failing logger must stay silent.
        pass


def debug(msg: str) -> None:
    """Log *msg* at DEBUG level, only when KIMI_LITEFUSE_DEBUG=true."""
    if not DEBUG:
        return
    _log("DEBUG", msg)


def info(msg: str) -> None:
    """Log *msg* at INFO level (always written, regardless of DEBUG)."""
    _log(level="INFO", message=msg)


# ----------------- State -----------------
class FileLock:
    """Best-effort exclusive inter-process lock based on ``fcntl.flock``.

    Use as a context manager. Acquisition polls a non-blocking ``flock``
    every 50 ms for up to *timeout_s* seconds. If the lock cannot be obtained
    in time, or ``fcntl`` / ``flock`` is unavailable (e.g. on Windows), the
    ``with`` body still runs *without* the lock, because the hook is fail-open
    and must never block Kimi Code. Check ``acquired`` to know whether the
    lock is actually held.
    """

    def __init__(self, path: Path, timeout_s: float = 2.0):
        self.path = path
        self.timeout_s = timeout_s
        self.acquired = False
        self._fh = None

    def __enter__(self):
        # Create the lock file's own directory (not a global one) so any path works.
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._fh = open(self.path, "a+", encoding="utf-8")
        self.acquired = False
        try:
            import fcntl
        except ImportError:
            return self  # platform without flock: proceed unlocked
        # Monotonic clock: wall-clock jumps must not shorten/extend the wait.
        deadline = time.monotonic() + self.timeout_s
        while True:
            try:
                fcntl.flock(self._fh.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
                self.acquired = True
                break
            except BlockingIOError:
                # Held by someone else; keep polling until the deadline.
                if time.monotonic() > deadline:
                    break
                time.sleep(0.05)
            except OSError:
                # e.g. a filesystem without flock support: proceed unlocked.
                break
        return self

    def __exit__(self, exc_type, exc, tb):
        fh, self._fh = self._fh, None
        if fh is None:
            return
        try:
            if self.acquired:
                import fcntl
                fcntl.flock(fh.fileno(), fcntl.LOCK_UN)
        except Exception:
            pass  # closing the file below releases the lock anyway
        finally:
            self.acquired = False
            try:
                fh.close()
            except Exception:
                pass


def load_state(path: Optional[Path] = None) -> Dict[str, Any]:
    """Load the persisted hook state.

    Args:
        path: state file to read; defaults to STATE_FILE.

    Returns:
        The decoded JSON object, or ``{}`` when the file is missing,
        unreadable, not valid JSON, or holds something other than a JSON
        object. Callers therefore always get a dict they can index and update.
    """
    state_file = STATE_FILE if path is None else path
    try:
        data = json.loads(state_file.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # Missing/unreadable file or corrupt JSON: start from a clean state.
        return {}
    return data if isinstance(data, dict) else {}


def save_state(state: Dict[str, Any], path: Optional[Path] = None) -> None:
    """Atomically persist *state* as pretty-printed, key-sorted JSON.

    Writes a sibling ``.tmp`` file and ``os.replace``s it over the target, so
    readers never see a half-written file. Fail-open: on any error the
    failure is logged at debug level and the temp file is removed.

    Args:
        state: JSON-serializable mapping to store.
        path: destination file; defaults to STATE_FILE.
    """
    state_file = STATE_FILE if path is None else path
    tmp = state_file.with_suffix(".tmp")
    try:
        state_file.parent.mkdir(parents=True, exist_ok=True)
        tmp.write_text(json.dumps(state, indent=2, sort_keys=True), encoding="utf-8")
        os.replace(tmp, state_file)
    except Exception as e:
        debug(f"save_state failed: {e}")
        # Do not leave a stale partial file behind.
        try:
            tmp.unlink()
        except OSError:
            pass


def state_key(session_id: str, transcript_path: str) -> str:
    """Stable state key for a (session, transcript) pair: hex SHA-256 digest."""
    hasher = hashlib.sha256()
    hasher.update(f"{session_id}::{transcript_path}".encode("utf-8"))
    return hasher.hexdigest()


# ----------------- Helpers -----------------
def clip(s: str, max_len: int = 24) -> str:
    """Shorten *s* to at most *max_len* characters, ending a cut with "…"."""
    if len(s) > max_len:
        return s[: max_len - 1] + "…"
    return s


def serialize(v: Any, max_chars: int = 0) -> Tuple[str, bool, int]:
    """Render *v* as text and cap its length.

    Strings pass through unchanged. Anything else is JSON-encoded with
    non-ASCII kept as-is, or becomes "[unserializable]" if encoding fails.

    Args:
        v: value to render.
        max_chars: length cap; 0 (the default) means MAX_CHARS.

    Returns:
        ``(text, truncated, orig_len)``: *orig_len* is the length before
        truncation, and a truncated *text* keeps the first *max_chars*
        characters followed by "…".
    """
    limit = max_chars or MAX_CHARS
    if not isinstance(v, str):
        try:
            v = json.dumps(v, ensure_ascii=False)
        except Exception:
            v = "[unserializable]"
    text = v
    size = len(text)
    truncated = size > limit
    if truncated:
        text = text[:limit] + "…"
    return text, truncated, size


def det_id(seed: str, nbytes: int) -> str:
    """Derive a deterministic hex OTLP id of *nbytes* bytes from *seed*.

    Use 16 bytes for a traceId and 8 for a spanId. The same seed always gives
    the same id, which the module relies on to avoid duplicate traces when a
    turn is re-sent.
    """
    digest = hashlib.sha256(seed.encode("utf-8")).digest()
    return digest[:nbytes].hex()


def nano(ms: int) -> str:
    """Milliseconds -> OTLP nanosecond timestamp string (fraction dropped)."""
    whole_ms = int(ms)
    return f"{whole_ms * 10**6}"


def attr_kv(key: str, v: Any) -> Optional[Dict[str, Any]]:
    """Encode one OTLP/JSON attribute ``{"key", "value"}``; None -> no attribute.

    bools become ``boolValue``, ints ``intValue``, lists an ``arrayValue`` of
    stringified items, and everything else a ``stringValue``.
    """
    if v is None:
        return None
    if isinstance(v, bool):  # checked before int: bool is an int subclass
        value: Dict[str, Any] = {"boolValue": v}
    elif isinstance(v, int):
        value = {"intValue": v}
    elif isinstance(v, list):
        value = {"arrayValue": {"values": [{"stringValue": str(item)} for item in v]}}
    else:
        value = {"stringValue": str(v)}
    return {"key": key, "value": value}


def attrs(obj: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Convert a mapping to OTLP attribute entries, skipping None values."""
    encoded = (attr_kv(name, value) for name, value in obj.items())
    return [entry for entry in encoded if entry is not None]


def meta_json(meta: Dict[str, Any]) -> str:
    """Flatten to top-level agent_<key>; drop None (sparse, per spec §4.5/§6.2).

    Returns the JSON text, with non-ASCII characters kept as-is.
    """
    flat: Dict[str, Any] = {}
    for key, value in meta.items():
        if value is None:
            continue
        flat[f"agent_{key}"] = value
    return json.dumps(flat, ensure_ascii=False)


BASH_SKIP = {"cd", "export", "sudo", "time", "env", "exec", "nohup", "command"}

# Builtins whose operands are never the program being run (`cd src && make`
# must yield "make", not "src"): the rest of their command segment is skipped.
_BASH_OPERAND_BUILTINS = {"cd", "export"}
# Simple-command separators: ; | & ( ) and newlines (also covers && and ||).
_BASH_SEGMENT_RE = re.compile(r"[;|&()\n]+")
_BASH_EXE_RE = re.compile(r"^[a-zA-Z][a-zA-Z0-9._+-]*$")


def bash_exe(command: Any) -> str:
    """Best-effort name of the first program a shell command runs.

    The command is split into simple-command segments. Within each segment:
      - env assignments (``FOO=1``), options and quoted words are skipped;
      - wrapper words from BASH_SKIP (``sudo``, ``env``, ``time`` …) are skipped;
      - ``cd`` / ``export`` segments are ignored entirely.
    The first remaining word is reduced to its basename
    (``/usr/bin/python`` -> ``python``). If it looks like a program name it
    is returned clipped to 24 characters; otherwise the next segment is tried.

    Returns "" for non-string input or when nothing qualifies.
    """
    if not isinstance(command, str):
        return ""
    for segment in _BASH_SEGMENT_RE.split(command):
        for word in segment.split():
            if "=" in word or word.startswith(("-", '"', "'")):
                continue
            if word in _BASH_OPERAND_BUILTINS:
                break
            if word in BASH_SKIP:
                continue
            # This word is in command position: decide on it alone.
            name = word.rsplit("/", 1)[-1]
            if _BASH_EXE_RE.match(name):
                return clip(name)
            break
    return ""


def _basename(p: Any) -> str:
    """Last path component of *p* (trailing slashes ignored), clipped.

    Returns "" for non-strings and empty strings. A path made only of slashes
    is returned (clipped) unchanged.
    """
    if isinstance(p, str) and p:
        last = p.rstrip("/").rpartition("/")[2]
        return clip(last if last else p)
    return ""


def tool_key_info(name: str, args: Any) -> str:
    """Short noun-like key info for a tool span name (spec §3.3).

    Picks the most telling argument per tool (command executable, file
    basename, search pattern, URL host, query, skill/subagent name …) and
    clips it to 24 characters. Returns "" when nothing suitable is found.
    """
    if not isinstance(args, dict):
        return ""

    def pick(*keys: str) -> Any:
        # Same as `args.get(a) or args.get(b) or ...`: the first truthy value,
        # otherwise the last value looked up.
        found = None
        for k in keys:
            found = args.get(k)
            if found:
                break
        return found

    def clip_str(value: Any) -> str:
        return clip(value) if isinstance(value, str) else ""

    tool = name.lower()
    if tool == "bash":
        return bash_exe(args.get("command"))
    if tool in ("read", "write", "edit"):
        return _basename(pick("path", "file_path"))
    if tool in ("grep", "glob", "find"):
        pattern = args.get("pattern")
        if isinstance(pattern, str):
            return clip(pattern)
        return _basename(args.get("path"))
    if tool == "ls":
        return _basename(args.get("path"))
    if tool == "fetchurl":
        url = args.get("url")
        if not isinstance(url, str):
            return ""
        after_scheme = url.split("//", 1)[-1]
        return clip(after_scheme.partition("/")[0])
    if tool == "websearch":
        return clip_str(args.get("query"))
    if tool in ("skill", "agent"):
        return clip_str(pick("skill", "subagent_type", "name", "description"))

    # Unknown tool: prefer a path-like argument, then a name-like one.
    location = pick("path", "file_path")
    if isinstance(location, str):
        return _basename(location)
    return clip_str(pick("name", "pattern", "title"))


# ----------------- Wire.jsonl parsing -----------------
@dataclass
class ToolCall:
    """One tool invocation requested in an LLM step (a wire `tool.call` event)."""

    id: str       # toolCallId, falling back to the event uuid
    name: str     # tool name from the event
    args: Any     # raw arguments object from the event
    call_ms: int  # timestamp of the tool.call record


@dataclass
class ToolResult:
    """Outcome of one tool call (a wire `tool.result` event)."""

    output: Any     # result["output"], or the whole result when not a dict
    is_error: bool  # result["isError"], coerced to bool
    time_ms: int    # timestamp of the tool.result record


@dataclass
class Step:
    """One LLM step of a turn, assembled from step.begin … step.end events."""

    begin_ms: int                                 # step.begin timestamp
    end_ms: int = 0                               # end timestamp; 0 when none was recorded
    thinking: str = ""                            # joined "think" content parts
    text: str = ""                                # joined "text" content parts
    tool_calls: List[ToolCall] = field(default_factory=list)
    usage: Optional[Dict[str, Any]] = None        # step.end usage payload
    finish_reason: str = ""                       # step.end finishReason, e.g. "end_turn"
    ttft_ms: Optional[int] = None                 # llmFirstTokenLatencyMs
    stream_ms: Optional[int] = None               # llmStreamDurationMs
    model: str = ""                               # set from a later usage.record, if any
    uuid: str = ""                                # step.begin event uuid


@dataclass
class TurnChunk:
    """All wire records of one user turn, from its turn.prompt to the next one."""

    start_line: int                      # absolute line index of its turn.prompt
    prompt_text: str = ""                # prompt text blocks, newline-joined
    prompt_ms: int = 0                   # turn.prompt timestamp
    image_blocks: int = 0                # image/image_url/video blocks in the prompt
    turn_id: Optional[int] = None        # first turnId seen in the loop events
    steps: List[Step] = field(default_factory=list)
    tool_results: Dict[str, ToolResult] = field(default_factory=dict)  # by toolCallId
    cancelled: bool = False              # a turn.cancel record was seen
    last_event_ms: int = 0               # latest timestamp seen in the turn
    end_line: int = 0                    # absolute line index just past the chunk


def _prompt_text_and_images(input_obj: Any) -> Tuple[str, int]:
    """Split a turn.prompt ``input`` into ``(text, image_block_count)``.

    *input_obj* is either a plain string (returned as-is with 0 images) or a
    list of content blocks:
      - ``{"type": "text", "text": ...}`` blocks are joined with newlines;
      - ``image`` / ``image_url`` / ``video`` blocks are counted.
    Non-dict entries and other block types are ignored. A text block whose
    ``text`` is present but not a string (e.g. null) contributes nothing
    instead of crashing the join. Any other input yields ``("", 0)``.
    """
    if isinstance(input_obj, str):
        return input_obj, 0
    texts: List[str] = []
    images = 0
    if isinstance(input_obj, list):
        for block in input_obj:
            if not isinstance(block, dict):
                continue
            btype = block.get("type")
            if btype == "text":
                text = block.get("text", "")
                if isinstance(text, str):
                    texts.append(text)
            elif btype in ("image", "image_url", "video"):
                images += 1
    return "\n".join(texts), images


def parse_chunks(lines: List[Tuple[int, Optional[Dict[str, Any]]]]) -> List[TurnChunk]:
    """Group parsed wire.jsonl records into per-turn chunks.

    Args:
        lines: ``[(abs_line_no, parsed_obj_or_None)]`` in file order. ``None``
            (or any non-dict value) marks a line that could not be parsed.

    Returns:
        One TurnChunk per ``turn.prompt`` record, in order. Records before the
        first prompt are ignored. A chunk's ``end_line`` is the absolute index
        just past its last record, or the next prompt's line when another turn
        follows.

    Recognized records:
      - ``turn.prompt`` starts a chunk;
      - ``turn.cancel`` marks it cancelled;
      - ``usage.record`` gives the model of the last closed step;
      - ``context.append_loop_event`` carries ``step.begin`` /
        ``content.part`` / ``tool.call`` / ``step.end`` / ``tool.result``
        events.
    Malformed fields fall back to defaults instead of raising, so one bad
    record cannot block every later turn of the session.
    """
    chunks: List[TurnChunk] = []
    cur: Optional[TurnChunk] = None
    step: Optional[Step] = None
    think_parts: List[str] = []
    text_parts: List[str] = []

    def as_ms(value: Any) -> int:
        # Record timestamps; anything non-numeric counts as unknown (0).
        try:
            return int(value or 0)
        except (TypeError, ValueError, OverflowError):
            return 0

    def as_dict(value: Any) -> Dict[str, Any]:
        return value if isinstance(value, dict) else {}

    def close_step(end_ms: int = 0) -> None:
        """Finalize the open step (if any) and append it to the current chunk.

        *end_ms* is applied only when the step has no end time yet. Pass 0 to
        leave it without one (the span builder treats that as an incomplete
        LLM call).
        """
        nonlocal step, think_parts, text_parts
        if step is None or cur is None:
            return
        step.thinking = "\n".join(p for p in think_parts if p)
        step.text = "\n".join(p for p in text_parts if p)
        if end_ms and not step.end_ms:
            step.end_ms = end_ms
        cur.steps.append(step)
        step, think_parts, text_parts = None, [], []

    for line_no, obj in lines:
        if not isinstance(obj, dict):
            # Unparsable line: still counts as consumed by the current chunk.
            if cur is not None:
                cur.end_line = line_no + 1
            continue
        mtype = obj.get("type")
        t_ms = as_ms(obj.get("time"))

        if mtype == "turn.prompt":
            close_step()  # an unterminated step stays with the previous chunk
            if cur is not None:
                cur.end_line = line_no
            text, images = _prompt_text_and_images(obj.get("input"))
            cur = TurnChunk(start_line=line_no, prompt_text=text, prompt_ms=t_ms,
                            image_blocks=images, last_event_ms=t_ms, end_line=line_no + 1)
            chunks.append(cur)
            continue

        if cur is None:
            continue  # records before the first prompt
        cur.end_line = line_no + 1
        if t_ms:
            cur.last_event_ms = max(cur.last_event_ms, t_ms)

        if mtype == "turn.cancel":
            # Close the in-flight step WITHOUT an end time so it is reported as
            # an incomplete LLM call. The cancel time still bounds the turn
            # through last_event_ms.
            close_step()
            cur.cancelled = True
            continue

        if mtype == "usage.record":
            model = obj.get("model", "")
            if isinstance(model, str) and model and cur.steps:
                cur.steps[-1].model = model
            continue

        if mtype != "context.append_loop_event":
            continue
        event = as_dict(obj.get("event"))
        etype = event.get("type")
        if cur.turn_id is None:
            try:
                cur.turn_id = int(event.get("turnId"))
            except (TypeError, ValueError, OverflowError):
                pass

        if etype == "step.begin":
            # A previous step that never saw step.end ends where this one begins.
            close_step(t_ms)
            step = Step(begin_ms=t_ms, uuid=event.get("uuid", ""))
        elif etype == "content.part" and step is not None:
            part = as_dict(event.get("part"))
            if part.get("type") == "think":
                think = part.get("think", "") or part.get("thinking", "")
                if isinstance(think, str):
                    think_parts.append(think)
            elif part.get("type") == "text":
                text = part.get("text", "")
                if isinstance(text, str):
                    text_parts.append(text)
        elif etype == "tool.call" and step is not None:
            tool_name = event.get("name")
            step.tool_calls.append(ToolCall(
                id=event.get("toolCallId", "") or event.get("uuid", ""),
                # Downstream code calls .lower() on the name: keep it a str.
                name=tool_name if isinstance(tool_name, str) and tool_name else "unknown",
                args=event.get("args", {}),
                call_ms=t_ms,
            ))
        elif etype == "step.end":
            if step is not None:
                step.usage = event.get("usage")
                step.finish_reason = event.get("finishReason", "") or ""
                step.ttft_ms = event.get("llmFirstTokenLatencyMs")
                step.stream_ms = event.get("llmStreamDurationMs")
                step.end_ms = t_ms
            close_step(t_ms)
        elif etype == "tool.result":
            tcid = event.get("toolCallId", "")
            result = event.get("result", {})
            if tcid:
                output = result.get("output", result) if isinstance(result, dict) else result
                is_error = bool(result.get("isError")) if isinstance(result, dict) else False
                cur.tool_results[tcid] = ToolResult(output=output, is_error=is_error, time_ms=t_ms)

    close_step()  # step still open at end of input (turn in progress / interrupted)
    return chunks


def chunk_finished(chunk: TurnChunk) -> bool:
    """Whether a turn has ended and can be emitted.

    True when the turn was cancelled, or when its last step finished with
    ``end_turn`` and every tool call in the turn has a recorded result.
    """
    if chunk.cancelled:
        return True
    if not chunk.steps or chunk.steps[-1].finish_reason != "end_turn":
        return False
    return all(tc.id in chunk.tool_results
               for s in chunk.steps
               for tc in s.tool_calls)


# ----------------- Span building -----------------
SUBAGENT_RESULT_RE = re.compile(
    r"^agent_id:\s*(?P<id>agent-[\w.-]+)\s*\n"
    r"actual_subagent_type:\s*(?P<type>\S+)\s*\n"
    r"status:\s*(?P<status>\S+)")


def parse_subagent_result(output: Any) -> Optional[Dict[str, str]]:
    """Extract the subagent header from an Agent tool result.

    The result text must *start* with three lines::

        agent_id: agent-...
        actual_subagent_type: <type>
        status: <status>

    Returns ``{"id": ..., "type": ..., "status": ...}`` when the header is
    present, otherwise None (also for non-string output).
    """
    match = SUBAGENT_RESULT_RE.match(output) if isinstance(output, str) else None
    if match is None:
        return None
    return {group: match.group(group) for group in ("id", "type", "status")}


def usage_details(usage: Optional[Dict[str, Any]]) -> Optional[str]:
    """Map Kimi usage counters to Anthropic-style keys (spec §5.1).

    Only positive numeric counters are kept, truncated to int. Returns their
    JSON encoding, or None when *usage* is not a dict or nothing qualifies.
    """
    if not isinstance(usage, dict):
        return None
    renames = {
        "inputOther": "input",
        "output": "output",
        "inputCacheRead": "cache_read_input_tokens",
        "inputCacheCreation": "cache_creation_input_tokens",
    }
    details = {
        dst: int(usage[src])
        for src, dst in renames.items()
        if isinstance(usage.get(src), (int, float)) and usage[src] > 0
    }
    return json.dumps(details) if details else None


def assistant_blocks(step: Step) -> Any:
    """Assistant message content for a step, keeping block structure (spec §4.2).

    Emits thinking, text and toolCall blocks, in that order. A step that
    produced only text collapses to the plain text string; otherwise a
    (possibly empty) list of blocks is returned.
    """
    blocks: List[Dict[str, Any]] = []
    if step.thinking:
        blocks += [{"type": "thinking", "thinking": step.thinking}]
    if step.text:
        blocks += [{"type": "text", "text": step.text}]
    blocks.extend(
        {"type": "toolCall", "id": call.id, "name": call.name, "arguments": call.args}
        for call in step.tool_calls
    )
    text_only = len(blocks) == 1 and blocks[0]["type"] == "text"
    return blocks[0]["text"] if text_only else blocks


def build_turn_spans(
    session_id: str,
    work_dir: str,
    turn_num: int,
    chunk: TurnChunk,
    transcript_path: str,
    interrupted: bool = False,
    agents_dir: Optional[Path] = None,
) -> List[Dict[str, Any]]:
    """Build the OTLP span dicts for one user turn.

    Produces a root AGENT span named "Kimi Code — Turn <turn_num>". Under it
    go one GENERATION span per LLM step and one TOOL span per tool call, in
    time order and sharing one "#N" step counter. When *agents_dir* is given
    and an Agent tool result starts with a subagent header, the subagent's
    ``<agents_dir>/<agent_id>/wire.jsonl`` is parsed and emitted as a
    "subagent" AGENT container under that tool span.

    Trace/span ids are derived with det_id() from session_id, turn_num and
    step/tool identifiers, so rebuilding the same turn yields the same ids.

    Args:
        session_id: Kimi Code session id (trace session and id seed).
        work_dir: working directory, recorded as root metadata.
        turn_num: turn number used in the trace name, metadata and id seed.
        chunk: parsed turn to render.
        transcript_path: wire file path, recorded as root metadata.
        interrupted: the turn never finished; marks the root as WARNING.
        agents_dir: directory of subagent wire files; None disables
            subagent subtrees.

    Returns:
        The span dicts, root span last. Every span carries the trace-level
        header attributes; the root's ``parentSpanId`` is None.
    """
    # Deterministic ids: the same session/turn always maps to the same trace.
    trace_id = det_id(f"kimi-code:{session_id}:{turn_num}", 16)
    root_span_id = det_id(f"{trace_id}:root", 8)
    trace_name = f"Kimi Code — Turn {turn_num}"

    prompt_text, prompt_trunc, prompt_len = serialize(chunk.prompt_text)
    # Model of the latest step that reported one, else a fixed default.
    model = next((s.model for s in reversed(chunk.steps) if s.model), "") or "kimi-for-coding"
    tags = ["kimi-code", f"model:{model}"]

    def trace_headers() -> Dict[str, Any]:
        # Trace-level attributes repeated on every span.
        return {
            "langfuse.trace.name": trace_name,
            "langfuse.trace.input": prompt_text or None,
            "langfuse.trace.tags": tags,
            "session.id": session_id,
            "user.id": USER_ID,
        }

    spans: List[Dict[str, Any]] = []
    # Fallback end time for spans without an end event of their own.
    end_ms_floor = chunk.last_event_ms or chunk.prompt_ms

    def add_span(span_id: str, parent_id: Optional[str], name: str, otype: str,
                 start_ms: int, end_ms: int, extra: Dict[str, Any]) -> None:
        """Append one OTLP span dict; None-valued attributes are dropped by attrs()."""
        end_ms = max(int(end_ms), int(start_ms) + 1)  # §7.1: spans sharing a timestamp still get 1ms
        spans.append({
            "traceId": trace_id,
            "spanId": span_id,
            "parentSpanId": parent_id,
            "name": name,
            "kind": 1,  # OTLP SPAN_KIND_INTERNAL
            "startTimeUnixNano": nano(start_ms),
            "endTimeUnixNano": nano(end_ms),
            "attributes": attrs({
                **trace_headers(),
                "langfuse.observation.type": otype,
                **extra,
            }),
        })

    def emit_run(chunks: List[TurnChunk], parent_id: str, seed: str,
                 is_subagent: bool, floor_ms: int) -> Dict[str, Any]:
        """Attach every step of one agent run under parent_id and return stats.

        Each container has one time-ordered step counter shared by generations
        and tools (§3.4); a subagent container restarts at #1. *seed* prefixes
        the span-id seeds. *floor_ms* is the fallback end time for spans with
        no end event.

        Returns a dict with ``api_calls``, ``tool_calls``, ``steps`` and
        ``final_text`` (the text of the last step if it ended with end_turn).
        """
        step_index = 0
        api_calls = 0
        tools_total = 0
        final_text = ""
        # Message sequence rebuilt within this run (generation input; no
        # cross-turn history, see module docstring).
        messages: List[Dict[str, Any]] = []

        for c in chunks:
            messages.append({"role": "user", "content": c.prompt_text})
            for si, step in enumerate(c.steps):
                api_calls += 1
                step_index += 1
                gen_step = step_index
                n_tools = len(step.tool_calls)
                incomplete = not step.end_ms  # step closed without an end time: LLM call still in flight when the turn stopped

                # Span name reflects what the step produced.
                if n_tools > 0:
                    name = f"plan ({n_tools} {'tool' if n_tools == 1 else 'tools'}) #{gen_step}"
                elif step.text and step.finish_reason == "end_turn":
                    name = "subagent response" if is_subagent else "response"
                elif step.thinking and not step.text:
                    name = f"think #{gen_step}"
                else:
                    name = f"generation #{gen_step}"

                gen_input, in_trunc, in_len = serialize(messages)
                out_blocks = assistant_blocks(step)
                gen_output, out_trunc, out_len = serialize(out_blocks) if out_blocks else ("", False, 0)
                # The wire's step.end time includes tool execution; the LLM
                # call's real end = begin + ttft + stream.
                if step.begin_ms and step.ttft_ms is not None and step.stream_ms is not None:
                    gen_end = step.begin_ms + int(step.ttft_ms) + int(step.stream_ms)
                elif step.tool_calls and step.tool_calls[0].call_ms:
                    # No latency data: the first tool call marks the end of generation.
                    gen_end = step.tool_calls[0].call_ms
                else:
                    gen_end = step.end_ms or floor_ms
                completion_start = None
                if step.ttft_ms is not None and step.begin_ms:
                    # First-token time as a JSON-quoted ISO-8601 UTC string.
                    completion_start = json.dumps(
                        datetime.fromtimestamp((step.begin_ms + int(step.ttft_ms)) / 1000, tz=timezone.utc)
                        .isoformat().replace("+00:00", "Z"))

                level = status = None
                if step.finish_reason == "error":
                    level, status = "ERROR", (step.text or "LLM call failed")[:500]
                elif incomplete:
                    level, status = "WARNING", "turn ended before the LLM call completed"

                # NOTE(review): without a step uuid the seed falls back to the
                # per-chunk index `si`, so runs with several chunks (subagents)
                # could produce colliding span ids — confirm uuids are always set.
                add_span(
                    det_id(f"{seed}:gen:{step.uuid or si}", 8), parent_id, name, "generation",
                    step.begin_ms or c.prompt_ms, gen_end,
                    {
                        "langfuse.observation.model.name": step.model or model,
                        "langfuse.observation.input": gen_input or None,
                        "langfuse.observation.output": gen_output or None,
                        "langfuse.observation.completion_start_time": completion_start,
                        "langfuse.observation.usage_details": usage_details(step.usage),
                        "langfuse.observation.level": level,
                        "langfuse.observation.status_message": status,
                        "langfuse.observation.metadata": meta_json({
                            "turn_number": turn_num,
                            "step_index": gen_step,
                            "provider": "kimi",
                            "stop_reason": step.finish_reason or None,
                            "api_duration_ms": (gen_end - step.begin_ms) if step.begin_ms else None,
                            "time_to_first_token_ms": step.ttft_ms,
                            "stream_duration_ms": step.stream_ms,
                            "tool_call_count": n_tools or None,
                            "thinking_chars": len(step.thinking) or None,
                            "step_uuid": step.uuid or None,
                            "input_scope": "turn",
                            "input_truncated": in_trunc or None,
                            "input_orig_len": in_len if in_trunc else None,
                            "output_truncated": out_trunc or None,
                            "output_orig_len": out_len if out_trunc else None,
                        }),
                    },
                )

                messages.append({"role": "assistant", "content": out_blocks if out_blocks else ""})

                for tc in step.tool_calls:
                    tools_total += 1
                    step_index += 1
                    result = c.tool_results.get(tc.id)
                    sub = parse_subagent_result(result.output) if result is not None else None
                    is_delegation = tc.name.lower() == "agent"
                    t_start = tc.call_ms or step.end_ms or c.prompt_ms
                    args_text, args_trunc, args_len = serialize(tc.args)
                    if is_delegation:
                        # Delegation tool span (§2.5): Kimi's Agent tool spawns one subagent per call.
                        display = f"tool (1 subagent) #{step_index}"
                    else:
                        key_info = tool_key_info(tc.name, tc.args)
                        display = (f"tool: {tc.name.lower()} ({key_info}) #{step_index}"
                                   if key_info else f"tool: {tc.name.lower()} #{step_index}")

                    if result is not None:
                        out_text, t_out_trunc, t_out_len = serialize(result.output)
                        t_end = result.time_ms or t_start
                        t_level = "ERROR" if result.is_error else None
                        t_status = out_text[:500] if result.is_error else None
                    else:
                        # No tool.result recorded: close at the run's floor and warn.
                        out_text, t_out_trunc, t_out_len = "", False, 0
                        t_end = floor_ms
                        t_level, t_status = "WARNING", "turn ended before tool completed"

                    tool_span_id = det_id(f"{seed}:tool:{tc.id}", 8)
                    add_span(
                        tool_span_id, parent_id, display, "tool",
                        t_start, t_end,
                        {
                            "langfuse.observation.input": args_text or None,
                            "langfuse.observation.output": out_text or None,
                            "langfuse.observation.level": t_level,
                            "langfuse.observation.status_message": t_status,
                            "langfuse.observation.metadata": meta_json({
                                "tool_name": tc.name,
                                "tool_call_id": tc.id,
                                "step_index": step_index,
                                "plan_step": gen_step,
                                "turn_number": turn_num,
                                "duration_ms": (t_end - t_start) if result is not None else None,
                                "is_error": result.is_error if (result is not None and result.is_error) else None,
                                "subagent_id": sub["id"] if sub else None,
                                "input_truncated": args_trunc or None,
                                "input_orig_len": args_len if args_trunc else None,
                                "output_truncated": t_out_trunc or None,
                                "output_orig_len": t_out_len if t_out_trunc else None,
                            }),
                        },
                    )

                    messages.append({"role": "tool", "tool_call_id": tc.id,
                                     "content": result.output if result is not None else None})

                    # Subagent subtrees are only built when agents_dir was supplied.
                    if sub is not None and agents_dir is not None:
                        emit_subagent(sub, tool_span_id, seed, t_start, t_end)

                if si == len(c.steps) - 1 and step.text and step.finish_reason == "end_turn":
                    final_text = step.text

        return {"api_calls": api_calls, "tool_calls": tools_total,
                "steps": step_index, "final_text": final_text}

    def emit_subagent(sub: Dict[str, str], tool_span_id: str, parent_seed: str,
                      t_start: int, t_end: int) -> None:
        """§2.5 subtree: a "subagent" container (agent type) under the
        delegation tool span, holding the subagent's own steps; nested
        delegations recurse through emit_run with the same rules.

        Missing or unparsable subagent wire files are logged at debug level
        and skipped.
        """
        try:
            wire = agents_dir / sub["id"] / "wire.jsonl"
            if not wire.exists():
                debug(f"subagent wire missing: {wire}")
                return
            sub_chunks = parse_chunks(read_lines_from(wire, 0))
        except Exception as e:
            debug(f"subagent {sub['id']} parse failed: {e}")
            return
        if not sub_chunks:
            return
        seed = f"{parent_seed}:sub:{sub['id']}"
        container_id = det_id(f"{seed}:container", 8)
        c_start = sub_chunks[0].prompt_ms or t_start
        c_end = max((c.last_event_ms for c in sub_chunks if c.last_event_ms), default=t_end)
        stats = emit_run(sub_chunks, container_id, seed, True, c_end)

        cancelled = any(c.cancelled for c in sub_chunks)
        warn = None
        if cancelled:
            warn = "subagent cancelled"
        elif sub.get("status") != "completed":
            warn = f"subagent status: {sub.get('status')}"
        elif not stats["final_text"]:
            warn = "subagent ended without a final text response"
        in_text, _, _ = serialize(sub_chunks[0].prompt_text)
        out_text, _, _ = serialize(stats["final_text"]) if stats["final_text"] else ("", False, 0)
        add_span(
            container_id, tool_span_id, "subagent", "agent",
            c_start, c_end,
            {
                "langfuse.observation.input": in_text or None,
                "langfuse.observation.output": out_text or None,
                "langfuse.observation.level": "WARNING" if warn else None,
                "langfuse.observation.status_message": warn,
                "langfuse.observation.metadata": meta_json({
                    "subagent": True,
                    "subagent_id": sub["id"],
                    "subagent_type": sub.get("type"),
                    "subagent_status": sub.get("status"),
                    "turn_number": turn_num,
                    "api_calls": stats["api_calls"],
                    "tool_calls": stats["tool_calls"],
                    "steps": stats["steps"],
                    "duration_ms": (c_end - c_start) if c_start else None,
                }),
            },
        )

    stats = emit_run([chunk], root_span_id, trace_id, False, end_ms_floor)
    final_text = stats["final_text"]

    # Root AGENT span (turn container): built last and sent together with all other spans.
    output_text, root_out_trunc, root_out_len = serialize(final_text) if final_text else ("", False, 0)
    warn = None
    if chunk.cancelled:
        warn = "turn cancelled by user"
    elif interrupted:
        warn = "turn interrupted (no completion events; session likely ended mid-turn)"
    elif not final_text:
        warn = "turn ended without a final text response"

    root_end = end_ms_floor
    add_span(
        root_span_id, None, trace_name, "agent",
        chunk.prompt_ms or root_end, root_end,
        {
            "langfuse.observation.input": prompt_text or None,
            "langfuse.observation.output": output_text or None,
            "langfuse.observation.level": "WARNING" if warn else None,
            "langfuse.observation.status_message": warn,
            "langfuse.trace.output": output_text or None,
            "langfuse.trace.metadata": meta_json({
                "turn_number": turn_num,
                "session_id": session_id,
                "cwd": work_dir or None,
                "model": model,
                "provider": "kimi",
                "transcript_path": transcript_path,
                "image_blocks": chunk.image_blocks or None,
                "prompt_truncated": prompt_trunc or None,
                "prompt_orig_len": prompt_len if prompt_trunc else None,
                "api_calls": stats["api_calls"],
                "tool_calls": stats["tool_calls"],
                "steps": stats["steps"],
                "duration_ms": root_end - chunk.prompt_ms if chunk.prompt_ms else None,
                "cancelled": chunk.cancelled or None,
                "output_truncated": root_out_trunc or None,
                "output_orig_len": root_out_len if root_out_trunc else None,
            }),
        },
    )
    return spans


# ----------------- Transport (raw OTLP/HTTP JSON, §10) -----------------
def _int_env(name: str, default: int) -> int:
    """Positive int from env var *name*; *default* if unset, malformed or <= 0.

    Evaluated at import time. The hook is fail-open, so a bad value must not
    raise.
    """
    try:
        value = int(os.environ.get(name, "") or default)
    except ValueError:
        return default
    return value if value > 0 else default


# Max summed span bytes per OTLP request (the server limits request size).
BATCH_BYTES = _int_env("KIMI_LITEFUSE_BATCH_BYTES", 800000)
SHRINK_CHARS = 100_000  # on HTTP 413: per-span input/output are shrunk to this many chars


def _batches(spans: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
    """按序切分，使每个请求体不超过 BATCH_BYTES（服务端有请求大小上限）。"""
    out: List[List[Dict[str, Any]]] = []
    cur: List[Dict[str, Any]] = []
    cur_bytes = 0
    for s in spans:
        n = len(json.dumps(s, ensure_ascii=False).encode("utf-8"))
        if cur and cur_bytes + n > BATCH_BYTES:
            out.append(cur)
            cur, cur_bytes = [], 0
        cur.append(s)
        cur_bytes += n
    if cur:
        out.append(cur)
    return out


def _shrink_span(span: Dict[str, Any]) -> Dict[str, Any]:
    """最后手段：收缩超大 span 的 input/output（保留前 SHRINK_CHARS 字符）。"""
    out = dict(span)
    new_attrs = []
    for kv in span["attributes"]:
        if kv["key"] in ("langfuse.observation.input", "langfuse.observation.output",
                         "langfuse.trace.input", "langfuse.trace.output"):
            sv = kv["value"].get("stringValue")
            if isinstance(sv, str) and len(sv) > SHRINK_CHARS:
                kv = {"key": kv["key"],
                      "value": {"stringValue": sv[:SHRINK_CHARS] + "…[shrunk for transport]"}}
        new_attrs.append(kv)
    out["attributes"] = new_attrs
    return out


def _post_batch(target: Dict[str, str], auth: str, batch: List[Dict[str, Any]],
                shrunk: bool = False) -> None:
    """POST *batch* to ``<base_url>/api/public/otel/v1/traces`` as one OTLP/HTTP JSON request.

    Each span has its ``None``-valued fields dropped and gets a
    ``langfuse.environment`` attribute (from ``target["environment"]``)
    appended. HTTP and transport failures are logged via ``_log`` rather than
    raised.

    On HTTP 413 (payload too large), a batch of several spans is split in half
    and each half is retried recursively. A lone span is retried once after
    ``_shrink_span`` has cut its input/output; *shrunk* marks that retry so it
    is not repeated.

    Args:
        target: Destination dict providing ``base_url`` and ``environment``.
        auth: Credential string placed after ``Basic `` in the Authorization header.
        batch: OTLP span dicts, each carrying an ``attributes`` list.
        shrunk: True when *batch* has already been through ``_shrink_span``.
    """
    resource_attrs = attrs({"service.name": "kimi-code"})
    wire_spans = []
    for span in batch:
        entry = {key: val for key, val in span.items() if val is not None}
        entry["attributes"] = span["attributes"] + attrs(
            {"langfuse.environment": target["environment"]})
        wire_spans.append(entry)
    payload = {
        "resourceSpans": [{
            "resource": {"attributes": resource_attrs},
            "scopeSpans": [{
                "scope": {"name": "kimi-litefuse", "version": "2.0.0"},
                "spans": wire_spans,
            }],
        }],
    }
    body = json.dumps(payload).encode("utf-8")
    base_url = target["base_url"]
    request = urllib.request.Request(
        f"{base_url}/api/public/otel/v1/traces",
        data=body,
        headers={"Authorization": f"Basic {auth}", "Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(request, timeout=HTTP_TIMEOUT_S) as resp:
            debug(f"sent {len(batch)} span(s) ({len(body)}B) -> {base_url} HTTP {resp.status}")
    except urllib.error.HTTPError as err:
        if err.code == 413 and len(batch) > 1:
            # Payload too large: retry each half separately.
            half = len(batch) // 2
            _post_batch(target, auth, batch[:half], shrunk)
            _post_batch(target, auth, batch[half:], shrunk)
        elif err.code == 413 and not shrunk:
            # A lone span is still too large: shrink its input/output and retry once.
            _post_batch(target, auth, [_shrink_span(batch[0])], shrunk=True)
        else:
            _log("ERROR", f"send failed: {base_url} HTTP {err.code}")
    except Exception as err:
        _log("ERROR", f"send failed: {base_url} {err}")


def send_spans(targets: List[Dict[str, str]], spans: List[Dict[str, Any]]) -> None:
    """Send *spans* to every configured target via ``_post_batch``.

    The split into size-limited batches depends only on *spans*, so it is
    computed once and reused for every target. Previously each span was
    JSON-serialized again per target. Reuse is safe because ``_post_batch``
    only reads its batch (it slices and copies, never mutates).

    Each target authenticates with HTTP Basic using ``public_key:secret_key``.
    HTTP/transport failures are logged inside ``_post_batch`` rather than raised.
    """
    if not spans or not targets:
        return
    batches = _batches(spans)  # loop-invariant: identical for every target
    for target in targets:
        auth = base64.b64encode(
            f"{target['public_key']}:{target['secret_key']}".encode("utf-8")).decode("ascii")
        for batch in batches:
            _post_batch(target, auth, batch)


# ----------------- Incremental session processing -----------------
def read_lines_from(transcript_path: Path, offset: int) -> List[Tuple[int, Optional[Dict[str, Any]]]]:
    """Read JSONL records from *transcript_path*, starting at physical line *offset*.

    Returns ``(line_no, obj)`` pairs, where ``line_no`` is the 0-based physical
    line number. Blank or unparseable lines yield ``(line_no, None)`` so line
    numbers stay aligned with the file; callers store them as offsets.

    A final line that has no trailing newline and does not parse is treated as
    a record the writer is still appending, and it is omitted. The caller's
    offset then stops before it, and the next poll re-reads it once complete.

    Open/read errors are logged via ``debug``, and whatever was read so far is
    returned.
    """
    import itertools  # local import keeps this block self-contained

    out: List[Tuple[int, Optional[Dict[str, Any]]]] = []
    start = max(int(offset), 0)
    try:
        with open(transcript_path, "r", encoding="utf-8", errors="replace") as f:
            # islice skips the already-consumed prefix at C speed; enumerate keeps physical numbering.
            for no, raw in enumerate(itertools.islice(f, start, None), start):
                text = raw.strip()
                obj: Optional[Dict[str, Any]] = None
                parsed = False
                if text:
                    try:
                        obj = json.loads(text)
                        parsed = True
                    except Exception:
                        obj = None
                if not parsed and not raw.endswith("\n"):
                    # Unterminated, unparseable tail: the writer is mid-append. Do not consume it.
                    break
                out.append((no, obj))
    except Exception as e:
        debug(f"read_lines_from failed: {e}")
    return out


def list_sessions() -> List[Dict[str, str]]:
    """Return every session recorded in ``~/.kimi-code/session_index.jsonl``.

    When the same ``sessionId`` appears more than once, the last entry wins. It
    keeps the position of that id's first appearance.

    Each of the following is skipped on its own, so one malformed line cannot
    truncate the listing:
    - blank lines;
    - lines that are not valid JSON;
    - lines whose JSON is not an object;
    - entries without a non-empty string ``sessionId``.

    Undecodable bytes are replaced rather than aborting the read. If the file
    cannot be opened or read, the failure is logged via ``debug`` and whatever
    was collected so far is returned.
    """
    index_path = Path.home() / ".kimi-code" / "session_index.jsonl"
    sessions: Dict[str, Dict[str, str]] = {}
    try:
        with open(index_path, "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except Exception:  # any decode failure: skip just this line
                    continue
                if not isinstance(obj, dict):
                    continue
                sid = obj.get("sessionId", "")
                if sid and isinstance(sid, str):
                    sessions[sid] = obj
    except Exception as e:
        debug(f"list_sessions failed: {e}")
    return list(sessions.values())


def process_session(targets: List[Dict[str, str]], state: Dict[str, Any],
                    session: Dict[str, str], now_ms: int) -> int:
    """Emit traces for the newly completed turns of one session and advance its state.

    Reads ``<sessionDir>/agents/main/wire.jsonl`` from the line offset stored in
    ``state`` under ``state_key(session_id, wire_path)``. The new lines are split
    into per-turn chunks with ``parse_chunks``, and each turn is sent as one
    trace via ``build_turn_spans`` + ``send_spans``:

    - Every chunk except the last is emitted. It is marked interrupted when
      ``chunk_finished`` reports it did not finish, since a newer turn follows it.
    - The last chunk is emitted only if it is finished, or if its last event is
      older than ``STALE_MS``; a stale unfinished chunk is marked interrupted.
      Otherwise the offset is set to the chunk's start line, so the next poll
      re-reads it.
    - A turn whose build/send raises is logged and its lines are still
      consumed, so it is not retried.

    Writes ``state[key]`` in place with ``offset``, ``turn_count`` and an
    ``updated`` UTC ISO timestamp. State is left untouched when the session has
    no id, no wire file, or no lines past the stored offset.

    Args:
        targets: Litefuse destinations passed through to ``send_spans``.
        state: Mutable per-session state mapping (loaded/saved by the caller).
        session: One ``session_index.jsonl`` entry (``sessionId``, ``sessionDir``, ``workDir``).
        now_ms: Current wall-clock time in epoch milliseconds, used for the stale check.

    Returns:
        The number of turns for which building and sending completed without raising.
        HTTP failures are logged inside the transport layer and still count here.
    """
    session_id = session.get("sessionId", "")
    session_dir = session.get("sessionDir", "")
    work_dir = session.get("workDir", "")
    wire = Path(session_dir) / "agents" / "main" / "wire.jsonl"
    if not session_id or not wire.exists():
        return 0

    key = state_key(session_id, str(wire))
    ss = state.get(key, {})
    offset = int(ss.get("offset", 0))
    turn_count = int(ss.get("turn_count", 0))

    lines = read_lines_from(wire, offset)
    if not lines:
        return 0
    # One past the last physical line read; consuming up to here means "everything seen".
    total_end = lines[-1][0] + 1

    chunks = parse_chunks(lines)
    emitted = 0
    consumed = offset

    if not chunks:
        # No turn.prompt in the window: these leading lines cannot be attributed
        # to any turn (only happens when migrating old state), so consume them directly.
        consumed = total_end
    for i, chunk in enumerate(chunks):
        is_last = i == len(chunks) - 1
        finished = chunk_finished(chunk)
        # A chunk with no recorded event time is never considered stale.
        stale = now_ms - (chunk.last_event_ms or now_ms) > STALE_MS
        if is_last and not finished and not stale:
            consumed = chunk.start_line  # leave it for the next poll to re-read
            break
        # Prefer the turn id from the wire (0-based); otherwise continue the stored count.
        turn_num = (chunk.turn_id + 1) if chunk.turn_id is not None else (turn_count + 1)
        turn_count = max(turn_count, turn_num)
        try:
            spans = build_turn_spans(session_id, work_dir, turn_num, chunk, str(wire),
                                     interrupted=(not finished),
                                     agents_dir=Path(session_dir) / "agents")
            send_spans(targets, spans)
            emitted += 1
            debug(f"turn {turn_num} ({len(spans)} spans) session={session_id}"
                  + (" [interrupted]" if not finished else ""))
        except Exception as e:
            _log("ERROR", f"emit turn {turn_num} failed: {e}")
        # NOTE(review): end_line is presumably the first line after this chunk — confirm in parse_chunks.
        # The last chunk consumes through total_end so trailing unattributed lines are not re-read.
        consumed = chunk.end_line if not is_last else total_end

    state[key] = {
        "offset": consumed,
        "turn_count": turn_count,
        "updated": datetime.now(timezone.utc).isoformat(),
    }
    return emitted


# ----------------- Main -----------------
def _emit_all_sessions(targets: List[Dict[str, str]], state: Dict[str, Any], now_ms: int) -> int:
    """Run process_session over every indexed session; per-session failures are logged, not raised."""
    total = 0
    for sess in list_sessions():
        try:
            total += process_session(targets, state, sess, now_ms)
        except Exception as exc:
            _log("ERROR", f"process_session failed: {exc}")
    return total


def main() -> int:
    """Hook entry point: process every session once and return the exit status (always 0).

    Does nothing unless TRACE_TO_LITEFUSE is "true" (case-insensitive) and at
    least one Litefuse target is configured. All state loading, session
    processing and state saving happen while holding FileLock(LOCK_FILE).
    Unexpected errors are logged, never propagated as a failing exit code.
    """
    started = time.time()
    debug("Hook started")

    tracing_enabled = os.environ.get("TRACE_TO_LITEFUSE", "").lower() == "true"
    if not tracing_enabled:
        return 0

    targets = load_targets()
    if not targets:
        debug("No Litefuse target configured (missing LITEFUSE_PUBLIC_KEY/SECRET_KEY)")
        return 0

    now_ms = int(time.time() * 1000)
    try:
        with FileLock(LOCK_FILE):
            state = load_state()
            emitted = _emit_all_sessions(targets, state, now_ms)
            save_state(state)
        if emitted:
            elapsed = time.time() - started
            destinations = ", ".join(t["base_url"] for t in targets)
            info(f"Emitted {emitted} turn(s) in {elapsed:.2f}s -> {destinations}")
    except Exception as exc:
        _log("ERROR", f"Unexpected failure: {exc}")
    return 0


if __name__ == "__main__":
    # Equivalent to sys.exit(main()): main()'s return value becomes the process exit status.
    raise SystemExit(main())
