#!/usr/bin/env python3
"""
aos_rag.py - minimal local RAG helper for ArchitectOS

Goals:
  - Index useful local artifacts (command logs, scripts/docs).
  - Let the LLM search "what did we do before / what tools exist?"
  - Stay simple, transparent, and easy to reason about.

Storage:
  - SQLite DB:  ~/.architectos/data/aos_rag.db  (RAG store)
  - Single table 'items' with:
      id        INTEGER PRIMARY KEY AUTOINCREMENT
      kind      TEXT    ('command', 'script', 'doc', etc.)
      path      TEXT    (file path or log path)
      summary   TEXT    (short description)
      content   TEXT    (command line or file snippet)
      embedding TEXT    (JSON-encoded list[float])

Usage examples:

  # Index command logs
  aos_rag.py index-commands /home/lathem/ArchitectOS/logs/commands

  # Index scripts/docs
  aos_rag.py index-files /home/lathem/ArchitectOS --kind script --ext .py .rb

  # Search
  aos_rag.py search "python formatting utility or cleanup script"

In the ArchitectOS environment, you should normally invoke it with the
project venv Python, e.g.:

  /home/lathem/ArchitectOS/venv/bin/python3 tools/aos_rag.py search "..."
"""

import argparse
import json
import math
import sqlite3
import sys
import os
from pathlib import Path
from typing import Iterable, List, Tuple

from openai import OpenAI, BadRequestError


# ------------------------------------------------------------
#  Configuration
# ------------------------------------------------------------

# Root of ArchitectOS state; override with the AOS_HOME environment variable
# (defaults to ~/.architectos, with "~" expanded).
AOS_HOME = Path(os.getenv("AOS_HOME", str(Path.home() / ".architectos"))).expanduser()
# Persistent data directory. NOTE: created at import time, so merely importing
# this module (or running the CLI with --help) touches the filesystem.
DATA_DIR = AOS_HOME / "data"
DATA_DIR.mkdir(parents=True, exist_ok=True)

# SQLite file backing the RAG store.
DB_PATH = DATA_DIR / "aos_rag.db"

# Embedding model used for both indexing and querying. Stored vectors are only
# comparable with query vectors from the same model, so re-index after
# switching models.
EMBED_MODEL = "text-embedding-3-large"

# Shared OpenAI client, constructed at import time. NOTE(review): presumably
# reads OPENAI_API_KEY from the environment and may fail at import if it is
# missing — confirm against the openai SDK version in use.
client = OpenAI()


# ------------------------------------------------------------
#  DB helpers
# ------------------------------------------------------------

def get_db() -> sqlite3.Connection:
    """
    Open the RAG SQLite database at DB_PATH, creating the schema if needed.

    Ensures the 'items' table and its kind/path indexes exist, then returns
    the open connection. The caller owns the connection and must close it.
    """
    schema_statements = (
        """
        CREATE TABLE IF NOT EXISTS items (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            kind TEXT NOT NULL,
            path TEXT,
            summary TEXT NOT NULL,
            content TEXT NOT NULL,
            embedding TEXT NOT NULL
        )
        """,
        "CREATE INDEX IF NOT EXISTS idx_items_kind ON items(kind)",
        "CREATE INDEX IF NOT EXISTS idx_items_path ON items(path)",
    )

    db = sqlite3.connect(DB_PATH)
    for ddl in schema_statements:
        db.execute(ddl)
    return db


def embed(text: str) -> List[float]:
    """
    Return the embedding vector for ``text`` as a list of floats.

    Sends a single-input request to the OpenAI embeddings endpoint with
    EMBED_MODEL via the module-level client, which expects OPENAI_API_KEY
    to be set in the environment.
    """
    response = client.embeddings.create(input=text, model=EMBED_MODEL)
    first_item = response.data[0]
    return first_item.embedding


def safe_embed_for_file(summary: str, content: str, max_bytes: int) -> List[float]:
    """
    Embed ``summary`` plus a leading slice of ``content``, shrinking on errors.

    The first attempt always uses ``min(max_bytes, 8192)`` characters of
    content (despite the parameter name, slicing is by characters, not bytes).
    Whenever the embeddings API raises BadRequestError, the slice is halved
    and retried as long as the new slice is still at least 1024 characters.
    After that, only the summary is embedded.

    Because the first attempt is unconditional, a small ``max_bytes`` (below
    the 1024-character retry floor) still embeds that much file content
    instead of silently falling back to the summary alone.

    Args:
        summary: Short description prepended to the content.
        content: File text; only a prefix of it is sent.
        max_bytes: Upper bound on the number of content characters sent.
            Negative values are treated as 0.

    Returns:
        The embedding vector as a list of floats.

    Raises:
        BadRequestError: if even the summary-only request is rejected.
        Any other error raised by ``embed`` propagates immediately.
    """
    # Conservative ceiling to avoid huge requests; clamp negatives so the
    # slice below never means "drop characters from the end".
    limit = max(0, min(max_bytes, 8192))
    min_retry_chars = 1024

    while True:
        try:
            return embed(summary + "\n" + content[:limit])
        except BadRequestError:
            # Any BadRequestError is treated as "input too large"; this is
            # robust against changes in the API's error message text.
            limit //= 2
            if limit < min_retry_chars:
                break

    # Last resort: the summary alone. If this fails too, the error propagates.
    return embed(summary)


def cosine(a: List[float], b: List[float]) -> float:
    """
    Return the cosine similarity between vectors ``a`` and ``b``.

    Components are paired with ``zip``, so when the lengths differ only the
    shared leading prefix contributes (to the dot product and to both norms).
    Returns 0.0 if either paired vector has a non-positive squared norm; a
    1e-9 term is added to the denominator.
    """
    dot_product = sq_norm_a = sq_norm_b = 0.0
    for left, right in zip(a, b):
        dot_product += left * right
        sq_norm_a += left * left
        sq_norm_b += right * right

    # Degenerate (all-zero) vector: similarity is undefined, report 0.0.
    if sq_norm_a <= 0.0 or sq_norm_b <= 0.0:
        return 0.0

    denominator = math.sqrt(sq_norm_a) * math.sqrt(sq_norm_b) + 1e-9
    return dot_product / denominator


# ------------------------------------------------------------
#  Index: command logs
# ------------------------------------------------------------

def _first_command_line(log_file: Path) -> str | None:
    """
    Return the command from the first ``COMMAND: `` line of ``log_file``.

    Returns None if the file cannot be read (a warning is printed to stderr),
    has no such line, or its first such line carries an empty command.
    """
    try:
        text = log_file.read_text(encoding="utf-8", errors="ignore")
    except OSError as e:
        print(f"[aos-rag] Skipping {log_file}: {e}", file=sys.stderr)
        return None

    prefix = "COMMAND: "
    for line in text.splitlines():
        if line.startswith(prefix):
            # Only the first COMMAND line counts, even if it is blank.
            return line[len(prefix):].strip() or None
    return None


def index_commands(log_dir: Path) -> None:
    """
    Index the command log files (``*.log``) directly inside ``log_dir``.

    Each log file is expected to contain a line like:
      COMMAND: <actual command>
    Only the first such line is used; unreadable files and files without one
    are skipped. An indexed file replaces any existing 'command' item with
    the same path, so re-indexing does not create duplicates.

    Each file's delete+insert is committed as its own transaction. If a later
    file fails (e.g. the embeddings API raises), items indexed so far are kept
    rather than rolled back, and the DB connection is always closed.

    Exits with status 1 if ``log_dir`` is not a directory.
    """
    if not log_dir.is_dir():
        print(f"[aos-rag] log_dir does not exist or is not a directory: {log_dir}", file=sys.stderr)
        sys.exit(1)

    conn = get_db()
    try:
        log_files = sorted(log_dir.glob("*.log"))
        if not log_files:
            print(f"[aos-rag] No .log files found in {log_dir}", file=sys.stderr)

        indexed = 0
        for log_file in log_files:
            cmd_line = _first_command_line(log_file)
            if cmd_line is None:
                # Unreadable, or not a normal command log; skip.
                continue

            summary = f"shell command: {cmd_line}"
            emb_vec = embed(summary + "\n" + cmd_line)

            # `with conn` commits on success and rolls back on error, so each
            # file's delete+insert is atomic and persisted immediately.
            with conn:
                # Avoid duplicates for this specific log file.
                conn.execute(
                    "DELETE FROM items WHERE kind = ? AND path = ?",
                    ("command", str(log_file)),
                )
                conn.execute(
                    """
                    INSERT INTO items (kind, path, summary, content, embedding)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    (
                        "command",
                        str(log_file),
                        summary,
                        cmd_line,
                        json.dumps(emb_vec),
                    ),
                )
            indexed += 1
    finally:
        conn.close()

    print(
        f"[aos-rag] Indexed {indexed} commands from log files in {log_dir}"
    )


# ------------------------------------------------------------
#  Index: files (scripts/docs)
# ------------------------------------------------------------

def iter_files(root: Path, extensions: Iterable[str]) -> Iterable[Path]:
    """Yield files under root whose suffix matches one of the given extensions."""
    exts = {e.lower() for e in extensions}
    for path in root.rglob("*"):
        if not path.is_file():
            continue
        if path.suffix.lower() in exts:
            yield path


def index_files(root: Path, kind: str, exts: List[str], max_bytes: int) -> None:
    """
    Index files under a root directory (scripts, docs, etc.).

    Every file yielded by ``iter_files(root, exts)`` is embedded (via
    safe_embed_for_file) together with the summary
    "<kind> file: <path relative to root>" and stored as an item of the given
    ``kind``. At most ``max_bytes`` characters of each file are stored and
    embedded. Files larger than ``4 * max_bytes`` bytes on disk are skipped,
    as are files that cannot be stat'ed or read. Re-indexing a file replaces
    its previous item of the same kind.

    Each file is committed in its own transaction. A failure part-way through
    keeps the files indexed so far, and the DB connection is always closed.

    Example:
      kind = "script", exts = [".py", ".rb"]

    Exits with status 1 if ``root`` is not a directory.
    """
    if not root.is_dir():
        print(f"[aos-rag] root does not exist or is not a directory: {root}", file=sys.stderr)
        sys.exit(1)

    # Skip absurdly large files (> 4x max_bytes) outright.
    size_limit = max_bytes * 4
    count = 0

    conn = get_db()
    try:
        for path in iter_files(root, exts):
            try:
                size = path.stat().st_size
            except OSError as e:
                print(f"[aos-rag] Skipping {path}: cannot stat file ({e})", file=sys.stderr)
                continue

            if size > size_limit:
                print(f"[aos-rag] Skipping {path}: size {size} > {size_limit} bytes", file=sys.stderr)
                continue

            try:
                raw = path.read_text(encoding="utf-8", errors="ignore")
            except OSError as e:
                print(f"[aos-rag] Skipping {path}: {e}", file=sys.stderr)
                continue

            # NOTE: this slices characters, not bytes.
            content = raw[:max_bytes]
            try:
                rel_str = str(path.relative_to(root))
            except ValueError:
                rel_str = str(path)

            summary = f"{kind} file: {rel_str}"

            # Use safe embedding that backs off if context is too large.
            emb_vec = safe_embed_for_file(summary, content, max_bytes)

            # `with conn` commits on success and rolls back on error, so each
            # file's delete+insert is atomic and persisted immediately.
            with conn:
                # Avoid duplicates for this specific file.
                conn.execute(
                    "DELETE FROM items WHERE kind = ? AND path = ?",
                    (kind, str(path)),
                )
                conn.execute(
                    """
                    INSERT INTO items (kind, path, summary, content, embedding)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    (
                        kind,
                        str(path),
                        summary,
                        content,
                        json.dumps(emb_vec),
                    ),
                )
            count += 1
    finally:
        conn.close()

    print(
        f"[aos-rag] Indexed {count} files under {root} as kind='{kind}' "
        f"with extensions {', '.join(exts)}"
    )


# ------------------------------------------------------------
#  Search
# ------------------------------------------------------------

def search(query: str, limit: int = 5, kind_filter: str | None = None) -> None:
    """
    Search indexed items for the query and print ranked results.

    The query is embedded and compared by cosine similarity with every stored
    item (only items of ``kind_filter`` when one is given). The ``limit``
    best-scoring items are printed, highest score first. Rows whose stored
    embedding cannot be decoded as JSON are skipped. A ``limit`` below 1
    prints no results.

    For 'command' items the full command line is printed; for other kinds a
    single-line snippet of at most 200 characters.
    """
    base_sql = "SELECT id, kind, path, summary, content, embedding FROM items"
    conn = get_db()
    try:
        if kind_filter:
            rows = conn.execute(base_sql + " WHERE kind = ?", (kind_filter,)).fetchall()
        else:
            rows = conn.execute(base_sql).fetchall()
    finally:
        conn.close()

    if not rows:
        if kind_filter:
            print(f"[aos-rag] No items of kind '{kind_filter}' indexed yet.")
        else:
            print("[aos-rag] No items indexed yet. Try 'index-commands' or 'index-files'.")
        return

    q_vec = embed(query)

    scored: List[Tuple[float, Tuple]] = []
    for row in rows:
        try:
            emb_vec = json.loads(row[5])
        except (ValueError, TypeError):
            # Undecodable stored embedding: skip the row rather than fail.
            continue
        scored.append((cosine(q_vec, emb_vec), row))

    scored.sort(key=lambda item: item[0], reverse=True)
    # Clamp so a negative limit cannot slice items off the end of the ranking.
    top = scored[:max(limit, 0)]

    print(f'RAG search results for: "{query}"')
    if kind_filter:
        print(f"(kind filter: {kind_filter})")
    print()

    if not top:
        print("[aos-rag] No results to show.")
        return

    for rank, (score, row) in enumerate(top, start=1):
        _id, kind, path, summary, content, _emb_json = row
        print(f"{rank}) [{kind}] path: {path}")
        print(f"   score: {score:.3f}")
        print(f"   summary: {summary}")

        # For commands, show the command line; for others, show a brief snippet.
        if kind == "command":
            print(f"   command: {content}")
        else:
            snippet = content.strip().replace("\n", " ")
            if len(snippet) > 200:
                snippet = snippet[:197] + "..."
            print(f"   snippet: {snippet}")
        print("---")


# ------------------------------------------------------------
#  CLI
# ------------------------------------------------------------

def main() -> None:
    """Parse command-line arguments and dispatch to the chosen subcommand."""
    parser = argparse.ArgumentParser(description="Minimal RAG helper for ArchitectOS")
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    # index-commands <log_dir>
    commands_parser = subparsers.add_parser("index-commands", help="Index command log files")
    commands_parser.add_argument(
        "log_dir",
        type=Path,
        help="Directory containing command logs (e.g. ~/ArchitectOS/logs/commands)",
    )

    # index-files <root> [--kind KIND] [--ext EXT ...] [--max-bytes N]
    files_parser = subparsers.add_parser("index-files", help="Index scripts/docs under a root")
    files_parser.add_argument("root", type=Path, help="Root directory to scan (e.g. ~/ArchitectOS)")
    files_parser.add_argument(
        "--kind",
        type=str,
        default="script",
        help="Logical kind label for these files (default: script)",
    )
    files_parser.add_argument(
        "--ext",
        type=str,
        nargs="+",
        default=[".py"],
        help="File extensions to include (default: .py). Example: --ext .py .rb",
    )
    files_parser.add_argument(
        "--max-bytes",
        type=int,
        default=16384,
        help="Maximum bytes of each file to embed (default: 16384)",
    )

    # search <query> [--limit N] [--kind KIND]
    search_parser = subparsers.add_parser("search", help="Search indexed items")
    search_parser.add_argument("query", type=str, help="Query text")
    search_parser.add_argument(
        "--limit",
        type=int,
        default=5,
        help="Maximum number of results to show (default: 5)",
    )
    search_parser.add_argument(
        "--kind",
        type=str,
        default=None,
        help="Optional kind filter (e.g. 'command', 'script')",
    )

    args = parser.parse_args()

    handlers = {
        "index-commands": lambda: index_commands(args.log_dir),
        "index-files": lambda: index_files(args.root, args.kind, args.ext, args.max_bytes),
        "search": lambda: search(args.query, args.limit, args.kind),
    }
    handler = handlers.get(args.cmd)
    if handler is None:
        parser.print_help()
    else:
        handler()


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()