#!/usr/bin/env python3
"""pip-witness real-time viewer and scan orchestrator.

Serves a web UI that shows live scan progress, attestation results, and
allows subscribing to packages for automatic re-scanning on new releases.

Architecture:
- Flask app serving REST API + SSE (Server-Sent Events) for live updates
- SQLite database for scan results, subscriptions, and queue state
- Background scanner thread pulling from prioritized queue
- Watchdog on attestations/ directory for immediate result display
"""

import base64
import json
import os
import queue
import sqlite3
import subprocess
import sys
import threading
import time
from datetime import datetime, timezone
from pathlib import Path

from flask import Flask, Response, jsonify, request, send_file, send_from_directory

# Filesystem layout, resolved relative to this file: the SQLite DB and static UI
# live next to the app; attestations are under the project root (APP_DIR's parent).
APP_DIR = Path(__file__).parent
PROJECT_DIR = APP_DIR.parent
ATTESTATION_DIR = PROJECT_DIR / "attestations"
DB_PATH = APP_DIR / "pip_witness.db"
STATIC_DIR = APP_DIR / "static"

app = Flask(__name__, static_folder=str(STATIC_DIR))

# SSE clients: one queue per open /api/events stream. The list is mutated and
# iterated only while holding sse_lock (see broadcast_event / sse_stream).
sse_clients: list[queue.Queue] = []
sse_lock = threading.Lock()

# Three queues: scan (Docker), eval (OPA on new scans), re-eval (OPA on existing attestations)
# PriorityQueue.get() returns the smallest tuple first: a lower priority number is
# dequeued sooner, and the enqueue timestamp keeps equal priorities in FIFO order.
scan_queue: queue.PriorityQueue = queue.PriorityQueue()     # (priority, ts, pkg, ver, scan_id)
# NOTE(review): producer of eval_queue is not in this chunk; confirm item shape.
eval_queue: queue.Queue = queue.Queue()                      # (scan_id,) - new attestations to evaluate
reeval_queue: queue.Queue = queue.Queue()                    # scan_id (bare int) - re-evaluate with current policies

# Scan rate tracking for ETA
# NOTE(review): appended to outside this chunk; presumably time.time() values — confirm.
scan_completions: list = []  # timestamps of recent completions


# ---------------------------------------------------------------------------
# Database
# ---------------------------------------------------------------------------

def get_db():
    """Open a fresh SQLite connection to DB_PATH with name-addressable rows.

    Every call returns an independent connection; callers must close it.
    """
    conn = sqlite3.connect(os.fspath(DB_PATH))
    conn.row_factory = sqlite3.Row
    return conn


def init_db():
    """Create the application schema if it does not already exist.

    Idempotent: every statement uses ``IF NOT EXISTS``, so it is safe to call
    at every start-up. Tables: ``scans`` (one row per scan run), ``evaluations``
    (policy results, many per scan, history kept), ``subscriptions`` (packages
    to re-scan) and ``pypi_hashes`` (PyPI file hashes seen over time).
    """
    db = get_db()
    try:
        db.executescript("""
        -- Scans: the immutable attestation record (one per Docker run)
        CREATE TABLE IF NOT EXISTS scans (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            package TEXT NOT NULL,
            version TEXT,
            status TEXT NOT NULL DEFAULT 'queued',
            started_at TEXT,
            completed_at TEXT,
            attestation_path TEXT,
            gitoid TEXT,
            pre_analysis_path TEXT,
            network_connections INTEGER DEFAULT 0,
            files_opened INTEGER DEFAULT 0,
            processes_spawned INTEGER DEFAULT 0,
            dns_lookups INTEGER DEFAULT 0,
            sockets_created INTEGER DEFAULT 0,
            packages_installed INTEGER DEFAULT 0,
            error TEXT,
            created_at TEXT DEFAULT (datetime('now'))
        );

        -- Evaluations: policy results against an attestation (many per scan)
        -- Re-running policies creates a new evaluation row, old ones are kept
        CREATE TABLE IF NOT EXISTS evaluations (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            scan_id INTEGER NOT NULL REFERENCES scans(id),
            risk_score INTEGER DEFAULT 0,
            risk_level TEXT DEFAULT 'UNKNOWN',
            policy_results TEXT,
            policy_version TEXT,
            created_at TEXT DEFAULT (datetime('now'))
        );

        CREATE TABLE IF NOT EXISTS subscriptions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            package TEXT NOT NULL UNIQUE,
            last_scanned_version TEXT,
            last_scanned_at TEXT,
            scan_frequency TEXT DEFAULT 'on_release',
            priority INTEGER DEFAULT 50,
            notes TEXT,
            created_at TEXT DEFAULT (datetime('now'))
        );

        CREATE INDEX IF NOT EXISTS idx_scans_package ON scans(package);
        CREATE INDEX IF NOT EXISTS idx_scans_status ON scans(status);
        CREATE INDEX IF NOT EXISTS idx_evaluations_scan ON evaluations(scan_id);
        CREATE INDEX IF NOT EXISTS idx_subscriptions_package ON subscriptions(package);

        -- Track PyPI file hashes over time to detect tampering
        CREATE TABLE IF NOT EXISTS pypi_hashes (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            package TEXT NOT NULL,
            version TEXT NOT NULL,
            filename TEXT NOT NULL,
            sha256 TEXT NOT NULL,
            first_seen TEXT DEFAULT (datetime('now')),
            last_seen TEXT DEFAULT (datetime('now')),
            UNIQUE(package, version, filename, sha256)
        );
        CREATE INDEX IF NOT EXISTS idx_pypi_hashes_pkg ON pypi_hashes(package, version);
    """)
        db.commit()
    finally:
        # Close even when the DDL fails so the connection is not leaked.
        db.close()


# ---------------------------------------------------------------------------
# SSE (Server-Sent Events)
# ---------------------------------------------------------------------------

def broadcast_event(event_type: str, data: dict):
    """Push one SSE frame to every registered client queue.

    A client whose queue is full is considered dead and is dropped from
    ``sse_clients``; the list is rebuilt in place while holding ``sse_lock``.
    """
    frame = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
    with sse_lock:
        survivors = []
        for client in sse_clients:
            try:
                client.put_nowait(frame)
            except queue.Full:
                continue  # backlog full -> stop tracking this client
            survivors.append(client)
        sse_clients[:] = survivors


@app.route("/api/events")
def sse_stream():
    """SSE endpoint for real-time updates.

    Registers a bounded per-client queue (100 messages) in ``sse_clients``. It
    then streams:

    * an initial ``connected`` event,
    * every message broadcast by ``broadcast_event``,
    * a ``: keepalive`` comment after 30 s without messages.

    When the stream is closed (e.g. the client disconnects and the server closes
    the generator), the queue is unregistered immediately. Without this, it would
    linger until it filled up and ``broadcast_event`` pruned it.
    """
    q: queue.Queue = queue.Queue(maxsize=100)
    with sse_lock:
        sse_clients.append(q)

    def generate():
        try:
            yield "event: connected\ndata: {}\n\n"
            while True:
                try:
                    msg = q.get(timeout=30)
                except queue.Empty:
                    # SSE comment line; keeps idle connections from timing out.
                    yield ": keepalive\n\n"
                else:
                    yield msg
        finally:
            # If the generator never starts, this does not run and
            # broadcast_event's pruning of full queues remains the fallback.
            with sse_lock:
                if q in sse_clients:  # may already be pruned by broadcast_event
                    sse_clients.remove(q)

    return Response(generate(), mimetype="text/event-stream",
                    headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})


# ---------------------------------------------------------------------------
# API Routes
# ---------------------------------------------------------------------------

@app.route("/")
def index():
    """Serve the single-page web UI from the static directory."""
    static_root = str(STATIC_DIR)
    page = "index.html"
    return send_from_directory(static_root, page)


@app.route("/api/scans")
def list_scans():
    """Return the 100 most recently created scans (any status), newest first."""
    conn = get_db()
    cursor = conn.execute("SELECT * FROM scans ORDER BY created_at DESC LIMIT 100")
    recent = list(map(dict, cursor.fetchall()))
    conn.close()
    return jsonify(recent)


@app.route("/api/scans/<int:scan_id>")
def get_scan(scan_id):
    """Return one scan with its attestation, latest evaluation and pre-analysis.

    The response is the ``scans`` row plus, when available:

    * ``attestation``: the base64-decoded DSSE payload. If reading or decoding
      fails, ``attestation_error`` is set instead.
    * ``risk_score``, ``risk_level``, ``eval_id``, ``eval_date``,
      ``policy_version`` and ``policy_eval``: taken from the newest evaluation.
      With no evaluation, ``risk_score`` is 0 and ``risk_level`` is
      ``"UNEVALUATED"``.
    * ``pre_analysis``: parsed JSON from ``pre_analysis_path``.

    Returns 404 if the scan does not exist.
    """
    db = get_db()
    try:
        row = db.execute("SELECT * FROM scans WHERE id = ?", (scan_id,)).fetchone()
        if not row:
            return jsonify({"error": "not found"}), 404
        # Newest evaluation wins. created_at has one-second resolution, so id
        # (autoincrement) breaks ties between evaluations in the same second.
        eval_row = db.execute(
            "SELECT * FROM evaluations WHERE scan_id = ? "
            "ORDER BY created_at DESC, id DESC LIMIT 1",
            (scan_id,),
        ).fetchone()
    finally:
        db.close()

    result = dict(row)

    # Load attestation data if available
    if result.get("attestation_path") and os.path.exists(result["attestation_path"]):
        try:
            with open(result["attestation_path"]) as f:
                envelope = json.load(f)
            result["attestation"] = json.loads(base64.b64decode(envelope["payload"]))
        except Exception as e:  # report any read/decode problem to the UI
            result["attestation_error"] = str(e)

    if eval_row:
        result["risk_score"] = eval_row["risk_score"]
        result["risk_level"] = eval_row["risk_level"]
        result["eval_id"] = eval_row["id"]
        result["eval_date"] = eval_row["created_at"]
        result["policy_version"] = eval_row["policy_version"]
        try:
            result["policy_eval"] = json.loads(eval_row["policy_results"])
        except (TypeError, ValueError):
            pass  # NULL or malformed results: omit policy_eval
    else:
        result["risk_score"] = 0
        result["risk_level"] = "UNEVALUATED"

    # Load pre-analysis if available (best effort)
    if result.get("pre_analysis_path") and os.path.exists(result["pre_analysis_path"]):
        try:
            with open(result["pre_analysis_path"]) as f:
                result["pre_analysis"] = json.load(f)
        except (OSError, ValueError):
            pass

    return jsonify(result)


@app.route("/api/scan", methods=["POST"])
def start_scan():
    """Queue a package for scanning.

    JSON body fields:

    * ``package``: required, non-empty string.
    * ``version``: optional.
    * ``priority``: optional integer, default 50. ``scan_queue`` is a
      PriorityQueue, so lower values are dequeued first.

    On success, inserts a ``queued`` row into ``scans``, enqueues the job,
    broadcasts ``scan_queued`` and returns ``{"id", "status"}``. Returns 400 if
    the package is missing or the priority is not an integer.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    package = data.get("package")
    if not isinstance(package, str) or not package.strip():
        return jsonify({"error": "package required"}), 400
    package = package.strip()
    version = data.get("version")

    # Validate before touching the DB. PriorityQueue compares tuples on put(),
    # so a non-int priority (e.g. "high" or null) would raise TypeError after the
    # scans row was inserted, leaving it stuck in 'queued'.
    raw_priority = data.get("priority")
    try:
        priority = 50 if raw_priority is None else int(raw_priority)
    except (TypeError, ValueError):
        return jsonify({"error": "priority must be an integer"}), 400

    # Insert scan record
    db = get_db()
    try:
        scan_id = db.execute(
            "INSERT INTO scans (package, version, status) VALUES (?, ?, 'queued') RETURNING id",
            (package, version)
        ).fetchone()[0]
        db.commit()
    finally:
        db.close()

    scan_queue.put((priority, time.time(), package, version, scan_id))

    broadcast_event("scan_queued", {
        "id": scan_id, "package": package, "version": version,
        "priority": priority
    })

    return jsonify({"id": scan_id, "status": "queued"})


@app.route("/api/subscriptions", methods=["GET"])
def list_subscriptions():
    """Return all subscriptions, highest priority value first, then by package name."""
    conn = get_db()
    subs = conn.execute("SELECT * FROM subscriptions ORDER BY priority DESC, package").fetchall()
    conn.close()
    return jsonify(list(map(dict, subs)))


@app.route("/api/subscriptions", methods=["POST"])
def add_subscription():
    """Create or update a subscription and immediately queue a scan for it.

    JSON body fields: ``package`` (required, non-empty string), optional integer
    ``priority`` and optional ``notes``. Upserts on ``package``, so re-subscribing
    updates priority and notes.

    The subscription stores ``priority`` (default 50). The immediate scan uses the
    same priority when one is given, and 10 otherwise; these are the original
    defaults. Returns 400 for bad input and 500 if the database write fails.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    package = data.get("package")
    if not isinstance(package, str) or not package.strip():
        return jsonify({"error": "package required"}), 400
    package = package.strip()

    raw_priority = data.get("priority")
    if raw_priority is None:
        sub_priority, scan_priority = 50, 10
    else:
        try:
            sub_priority = scan_priority = int(raw_priority)
        except (TypeError, ValueError):
            return jsonify({"error": "priority must be an integer"}), 400

    db = get_db()
    try:
        db.execute(
            """INSERT INTO subscriptions (package, priority, notes)
               VALUES (?, ?, ?)
               ON CONFLICT(package) DO UPDATE SET
                   priority = excluded.priority,
                   notes = excluded.notes""",
            (package, sub_priority, data.get("notes", ""))
        )
        db.commit()
    except sqlite3.Error as e:
        return jsonify({"error": str(e)}), 500
    finally:
        db.close()

    broadcast_event("subscription_added", {"package": package})

    # Immediately queue a scan for this package
    start_scan_internal(package, priority=scan_priority, source="subscription")

    return jsonify({"status": "subscribed", "package": package})


@app.route("/api/subscriptions/<package>", methods=["DELETE"])
def remove_subscription(package):
    """Delete the subscription for ``package`` (no-op if absent) and notify SSE clients."""
    conn = get_db()
    conn.execute("DELETE FROM subscriptions WHERE package = ?", (package,))
    conn.commit()
    conn.close()

    notice = {"package": package}
    broadcast_event("subscription_removed", notice)
    return jsonify({"status": "removed"})


@app.route("/api/queue")
def get_queue():
    """Return queued/running scans (oldest first) and the active-scan counter.

    The counter is the module-level ``active_scans`` value, read under
    ``active_scans_lock``.
    """
    conn = get_db()
    pending = [
        dict(scan)
        for scan in conn.execute(
            "SELECT * FROM scans WHERE status IN ('queued', 'running') ORDER BY created_at"
        ).fetchall()
    ]
    conn.close()

    with active_scans_lock:
        in_flight = active_scans

    return jsonify({"queue": pending, "active_scans": in_flight})


@app.route("/api/packages")
def list_packages():
    """Package-centric view: one row per package with aggregate stats.

    Only completed scans are counted. Query parameters:

    * ``q``: substring filter on the package name (SQL LIKE).
    * ``risk``: keep packages whose ``worst_risk_level`` equals this
      (case-insensitive).
    * ``sort``: risk (default), name, scans, recent, files or network.
    * ``limit``: page size, clamped to 0..200 (default 50; malformed values
      use the default).
    * ``offset``: rows to skip, clamped to >= 0 (default 0; malformed values
      use the default).

    Risk is taken from the latest evaluation of each scan. ``max_risk_score``
    is the highest such score. ``worst_risk_level`` is the level of that
    highest-scoring evaluation, not a string MAX; lexicographically "LOW"
    sorts above "CRITICAL".

    Returns ``{"packages": [...], "total", "limit", "offset"}``.
    """
    q = request.args.get("q", "").strip()
    risk = request.args.get("risk", "")
    sort = request.args.get("sort", "risk")  # risk, name, scans, recent, files, network
    # type=int falls back to the default on malformed input instead of a 500.
    # Clamping also stops a negative LIMIT (SQLite reads it as "no limit") from
    # bypassing the 200-row cap.
    limit = max(0, min(request.args.get("limit", 50, type=int), 200))
    offset = max(0, request.args.get("offset", 0, type=int))

    conditions = ["s.status = 'completed'"]
    params = []

    if q:
        conditions.append("s.package LIKE ?")
        params.append(f"%{q}%")

    where = " AND ".join(conditions)

    # Whitelisted ORDER BY fragments (never interpolate user input directly).
    order = {
        "risk": "max_risk_score DESC, scan_count DESC",
        "name": "package ASC",
        "scans": "scan_count DESC",
        "recent": "last_scanned DESC",
        "files": "total_files DESC",
        "network": "total_net DESC",
    }.get(sort, "max_risk_score DESC")

    # Each scan is joined to its newest evaluation (id breaks same-second ties).
    query = f"""
        SELECT
            s.package,
            COUNT(*) as scan_count,
            COUNT(DISTINCT s.version) as version_count,
            COALESCE(MAX(e.risk_score), 0) as max_risk_score,
            COALESCE((
                SELECT e3.risk_level
                FROM scans s3
                JOIN evaluations e3 ON e3.id = (
                    SELECT e4.id FROM evaluations e4 WHERE e4.scan_id = s3.id
                    ORDER BY e4.created_at DESC, e4.id DESC LIMIT 1
                )
                WHERE s3.package = s.package AND s3.status = 'completed'
                ORDER BY e3.risk_score DESC, e3.id DESC
                LIMIT 1
            ), 'UNEVALUATED') as worst_risk_level,
            SUM(s.network_connections) as total_net,
            SUM(s.files_opened) as total_files,
            SUM(s.processes_spawned) as total_procs,
            MAX(s.completed_at) as last_scanned,
            GROUP_CONCAT(DISTINCT s.version) as versions
        FROM scans s
        LEFT JOIN evaluations e ON e.id = (
            SELECT e2.id FROM evaluations e2 WHERE e2.scan_id = s.id
            ORDER BY e2.created_at DESC, e2.id DESC LIMIT 1
        )
        WHERE {where}
        GROUP BY s.package
    """

    # Apply risk filter on aggregate
    if risk:
        query = f"SELECT * FROM ({query}) WHERE worst_risk_level = ?"
        params.append(risk.upper())

    db = get_db()
    try:
        total = db.execute(f"SELECT COUNT(*) FROM ({query})", params).fetchone()[0]
        # package is unique per row: a final tie-breaker keeps LIMIT/OFFSET pages stable.
        rows = db.execute(
            f"{query} ORDER BY {order}, package ASC LIMIT ? OFFSET ?",
            params + [limit, offset],
        ).fetchall()
    finally:
        db.close()

    packages = []
    for r in rows:
        versions_str = r["versions"] or ""
        packages.append({
            "package": r["package"],
            "scan_count": r["scan_count"],
            "version_count": r["version_count"],
            "max_risk_score": r["max_risk_score"] or 0,
            "worst_risk_level": r["worst_risk_level"] or "CLEAN",
            "total_net": r["total_net"] or 0,
            "total_files": r["total_files"] or 0,
            "total_procs": r["total_procs"] or 0,
            "last_scanned": r["last_scanned"],
            "versions": [v for v in versions_str.split(",") if v][:10],
        })

    return jsonify({"packages": packages, "total": total, "limit": limit, "offset": offset})


@app.route("/api/packages/<package>/scans")
def package_scans(package):
    """Return up to 50 completed scans of ``package``, most recently completed first."""
    sql = """SELECT * FROM scans WHERE package = ? AND status = 'completed'
           ORDER BY completed_at DESC LIMIT 50"""
    conn = get_db()
    found = conn.execute(sql, (package,)).fetchall()
    conn.close()
    return jsonify(list(map(dict, found)))


@app.route("/api/dashboard")
def dashboard():
    """Aggregate dashboard data for the overview.

    Includes:

    * counts by scan status;
    * evaluation progress (``evals_pending`` is the combined size of the
      in-memory eval and re-eval queues);
    * PyPI coverage figures;
    * risk distribution and top-risky packages, both based on each scan's
      latest evaluation (highest evaluation id);
    * the 15 most recent completions;
    * summed behaviour counters over completed scans.
    """
    db = get_db()
    try:
        d = {
            "total_packages": db.execute("SELECT COUNT(DISTINCT package) FROM scans WHERE status='completed'").fetchone()[0],
            "total_scans": db.execute("SELECT COUNT(*) FROM scans WHERE status='completed'").fetchone()[0],
            "queued": db.execute("SELECT COUNT(*) FROM scans WHERE status='queued'").fetchone()[0],
            "running": db.execute("SELECT COUNT(*) FROM scans WHERE status='running'").fetchone()[0],
            "failed": db.execute("SELECT COUNT(*) FROM scans WHERE status='failed'").fetchone()[0],
            "risk_distribution": {},
            "top_risky": [],
            "recent_completions": [],
            "aggregate": {},
        }

        d["evals_pending"] = eval_queue.qsize() + reeval_queue.qsize()
        d["evaluated"] = db.execute("SELECT COUNT(DISTINCT scan_id) FROM evaluations").fetchone()[0]

        # Coverage metrics
        try:
            indexed = db.execute("SELECT COUNT(*) FROM pypi_packages").fetchone()[0]
        except sqlite3.OperationalError:
            indexed = 0  # pypi_packages table may not exist in this database
        scanned_pkgs = d["total_packages"]
        unique_versions = db.execute(
            "SELECT COUNT(DISTINCT package || '==' || COALESCE(version,'')) FROM scans WHERE status='completed'"
        ).fetchone()[0]
        total_pypi = 780000  # approximate, hard-coded size of PyPI
        d["coverage"] = {
            "total_pypi": total_pypi,
            "indexed": indexed,
            "scanned_packages": scanned_pkgs,
            "scanned_versions": unique_versions,
            "pct_of_pypi": round(scanned_pkgs / total_pypi * 100, 4) if total_pypi else 0,
            "pct_of_indexed": round(scanned_pkgs / indexed * 100, 1) if indexed else 0,
        }

        # Risk distribution from each scan's latest evaluation. A package counts
        # once per level, but may appear under several levels across its scans.
        rows = db.execute("""
            SELECT e.risk_level, COUNT(DISTINCT s.package) as c
            FROM evaluations e JOIN scans s ON s.id = e.scan_id
            WHERE e.id IN (SELECT MAX(id) FROM evaluations GROUP BY scan_id)
            GROUP BY e.risk_level
        """).fetchall()
        d["risk_distribution"] = {r["risk_level"] or "UNEVALUATED": r["c"] for r in rows}

        # Top risky packages. "level" is the level of the highest-scoring latest
        # evaluation, not MAX(risk_level), which compares strings alphabetically.
        rows = db.execute("""
            SELECT s.package, MAX(e.risk_score) as score,
                   (SELECT e3.risk_level
                    FROM scans s3 JOIN evaluations e3 ON e3.scan_id = s3.id
                    WHERE s3.package = s.package
                    AND e3.risk_level NOT IN ('CLEAN')
                    AND e3.id IN (SELECT MAX(id) FROM evaluations GROUP BY scan_id)
                    ORDER BY e3.risk_score DESC, e3.id DESC LIMIT 1) as level,
                   COUNT(*) as scans, SUM(s.network_connections) as net
            FROM scans s JOIN evaluations e ON e.scan_id = s.id
            WHERE e.risk_level NOT IN ('CLEAN')
            AND e.id IN (SELECT MAX(id) FROM evaluations GROUP BY scan_id)
            GROUP BY s.package ORDER BY score DESC LIMIT 20
        """).fetchall()
        d["top_risky"] = [dict(r) for r in rows]

        # Recent completions (join with latest eval)
        rows = db.execute("""
            SELECT s.id, s.package, s.version, s.network_connections, s.files_opened, s.completed_at,
                   COALESCE(e.risk_level, 'UNEVALUATED') as risk_level,
                   COALESCE(e.risk_score, 0) as risk_score
            FROM scans s LEFT JOIN evaluations e ON e.scan_id = s.id
                AND e.id = (SELECT MAX(e2.id) FROM evaluations e2 WHERE e2.scan_id = s.id)
            WHERE s.status='completed'
            ORDER BY s.completed_at DESC LIMIT 15
        """).fetchall()
        d["recent_completions"] = [dict(r) for r in rows]

        # Aggregates
        row = db.execute("""
            SELECT SUM(network_connections) as net, SUM(files_opened) as files,
                   SUM(processes_spawned) as procs, SUM(dns_lookups) as dns,
                   SUM(packages_installed) as pkgs
            FROM scans WHERE status='completed'
        """).fetchone()
        d["aggregate"] = {
            "network_connections": row["net"] or 0,
            "files_opened": row["files"] or 0,
            "processes_spawned": row["procs"] or 0,
            "dns_lookups": row["dns"] or 0,
            "packages_installed": row["pkgs"] or 0,
        }
    finally:
        db.close()

    return jsonify(d)


@app.route("/api/reeval-all", methods=["POST"])
def reeval_all():
    """Queue every completed scan that has an attestation for policy re-evaluation.

    Only OPA is re-run against the stored attestations; no Docker scan is
    started. Broadcasts ``reeval_started`` with the number of scans queued.
    """
    conn = get_db()
    scan_ids = [
        rec["id"]
        for rec in conn.execute(
            "SELECT id FROM scans WHERE status='completed' AND attestation_path IS NOT NULL"
        ).fetchall()
    ]
    conn.close()

    for sid in scan_ids:
        reeval_queue.put(sid)
    queued = len(scan_ids)

    broadcast_event("reeval_started", {"count": queued})
    return jsonify({"status": "queued_for_reeval", "count": queued})


@app.route("/api/reeval/<int:scan_id>", methods=["POST"])
def reeval_single(scan_id):
    """Queue one scan for re-evaluation with the current policies (existence is not checked)."""
    reeval_queue.put(scan_id)
    body = {"status": "queued_for_reeval", "scan_id": scan_id}
    return jsonify(body)


@app.route("/api/scans/<int:scan_id>/raw")
def get_raw_attestation(scan_id):
    """Serve the attestation file of a scan verbatim as JSON; 404 if it is missing."""
    conn = get_db()
    rec = conn.execute("SELECT attestation_path FROM scans WHERE id = ?", (scan_id,)).fetchone()
    conn.close()

    path = rec["attestation_path"] if rec else None
    if not path or not os.path.exists(path):
        return jsonify({"error": "attestation not found"}), 404

    with open(path) as fh:
        raw = fh.read()
    return Response(raw, mimetype="application/json")


@app.route("/api/scans/<int:scan_id>/download/<artifact>")
def download_artifact(scan_id, artifact):
    """Download scan artifacts: attestation, sarif, sbom, import.

    Companion artifacts sit next to the attestation. Their names insert a tag
    after the ``pip-install-`` marker of the attestation's file name. For
    ``pip-install-X.json`` these are:

    * ``pip-install-sarif-X.sarif``
    * ``pip-install-sbom-X.cdx.json``
    * ``pip-install-import-X.json``

    Returns 404 if the scan or the requested file is missing, and 400 for an
    unknown artifact name.
    """
    db = get_db()
    try:
        row = db.execute(
            "SELECT attestation_path, package, version FROM scans WHERE id = ?", (scan_id,)
        ).fetchone()
    finally:
        db.close()
    if not row or not row["attestation_path"]:
        return jsonify({"error": "scan not found"}), 404

    attestation_path = row["attestation_path"]
    base_dir = os.path.dirname(attestation_path)
    base_name = os.path.basename(attestation_path)
    # Same "latest" fallback as verify_command, so files aren't named "pkg-None-...".
    label = f"{row['package']}-{row['version'] or 'latest'}"

    def companion(tag, new_ext=None):
        """Return the path of the ``tag`` companion file, or None if the name doesn't fit the convention."""
        marker = "pip-install-"
        if marker not in base_name:
            # Without the marker a plain str.replace() is a no-op, which made the
            # "import" artifact resolve to the attestation itself.
            return None
        name = base_name.replace(marker, f"{marker}{tag}-", 1)
        if new_ext is not None:
            if not name.endswith(".json"):
                return None
            # Swap only the trailing extension, not every ".json" in the name.
            name = name[:-len(".json")] + new_ext
        return os.path.join(base_dir, name)

    file_map = {
        "attestation": (attestation_path, "application/json", f"{label}-attestation.json"),
        "sarif": (companion("sarif", ".sarif"), "application/sarif+json", f"{label}.sarif"),
        "sbom": (companion("sbom", ".cdx.json"), "application/json", f"{label}-sbom.cdx.json"),
        "import": (companion("import"), "application/json", f"{label}-import-attestation.json"),
    }

    if artifact not in file_map:
        return jsonify({"error": f"unknown artifact: {artifact}"}), 400

    file_path, mime, filename = file_map[artifact]
    if not file_path or not os.path.exists(file_path):
        return jsonify({"error": f"{artifact} not available"}), 404

    return send_file(file_path, mimetype=mime, as_attachment=True, download_name=filename)


@app.route("/api/scans/<int:scan_id>/evaluations")
def get_scan_evaluations(scan_id):
    """Return the full evaluation history of a scan, newest first.

    ``policy_results`` is decoded from its stored JSON text when present.
    """
    conn = get_db()
    history = conn.execute(
        "SELECT * FROM evaluations WHERE scan_id = ? ORDER BY created_at DESC", (scan_id,)
    ).fetchall()
    conn.close()

    def decode(rec):
        entry = dict(rec)
        stored = entry.get("policy_results")
        if stored:
            entry["policy_results"] = json.loads(stored)
        return entry

    return jsonify([decode(rec) for rec in history])


@app.route("/api/explain/<int:scan_id>", methods=["POST"])
def explain_scan(scan_id):
    """Use Claude API to explain scan findings with actual evidence.

    Steps:

    1. Load the scan row and its newest evaluation. Risk level, score and policy
       results live in the ``evaluations`` table, not on the scan row.
    2. Extract concrete evidence from the attestation envelope on disk, on a
       best-effort basis: processes, network connections, DNS lookups, sockets,
       accessed files and installed packages.
    3. Ask the model for a forensic write-up.

    Returns 404 if the scan does not exist. Otherwise returns
    ``{"explanation": ...}``; SDK and API failures are reported in that text.
    """
    db = get_db()
    try:
        row = db.execute("SELECT * FROM scans WHERE id = ?", (scan_id,)).fetchone()
        eval_row = None
        if row:
            # Same "latest evaluation" ordering as get_scan.
            eval_row = db.execute(
                "SELECT * FROM evaluations WHERE scan_id = ? "
                "ORDER BY created_at DESC, id DESC LIMIT 1",
                (scan_id,),
            ).fetchone()
    finally:
        db.close()
    if not row:
        return jsonify({"error": "not found"}), 404

    scan = dict(row)
    risk_level = "UNKNOWN"
    risk_score = 0
    policy_eval = {}
    if eval_row:
        risk_level = eval_row["risk_level"] or "UNKNOWN"
        risk_score = eval_row["risk_score"] or 0
        if eval_row["policy_results"]:
            try:
                policy_eval = json.loads(eval_row["policy_results"])
            except ValueError:
                policy_eval = {}
    if not isinstance(policy_eval, dict):
        policy_eval = {}

    # Build policy violations list
    violations = []
    for name, pol in policy_eval.get("policies", {}).items():
        for v in pol.get("violations", []):
            violations.append(f"[{name}] {v}")

    # Extract ACTUAL evidence from the attestation
    evidence_network = []
    evidence_processes = []
    evidence_files_sensitive = []
    evidence_files_dirs = {}
    evidence_packages = []
    evidence_dns = []
    evidence_sockets = []
    sensitive_markers = (".ssh", ".aws", ".gnupg", ".kube", ".env", "shadow",
                         "credentials", "id_rsa", "id_ed25519", ".pypirc")

    if scan.get("attestation_path") and os.path.exists(scan["attestation_path"]):
        try:
            with open(scan["attestation_path"]) as f:
                envelope = json.load(f)
            payload = json.loads(base64.b64decode(envelope["payload"]))
            atts = payload.get("predicate", {}).get("attestations", [])

            for att in atts:
                data = att.get("attestation") or {}

                # command-run: extract processes, network, files
                for p in data.get("processes") or []:
                    pid = p.get("processid")
                    evidence_processes.append({
                        "pid": pid,
                        "program": p.get("program", ""),
                        # cmdline may be present but null; don't abort parsing on it
                        "cmdline": (p.get("cmdline") or "")[:200],
                    })

                    # Network connections with actual IPs
                    net = p.get("network")
                    if net:
                        for c in net.get("connections") or []:
                            evidence_network.append(
                                f"{c.get('syscall','?')} {c.get('family','?')} {c.get('address','')}:{c.get('port','')} (PID {pid})"
                            )
                        for d in net.get("dnsLookups") or []:
                            evidence_dns.append(f"{d.get('serverAddress','')}:{d.get('serverPort','')}")
                        for s in net.get("sockets") or []:
                            evidence_sockets.append(f"{s.get('family','')} {s.get('type','')} proto={s.get('protocol','')}")

                    # Files: group by dir, flag sensitive
                    for filepath in (p.get("openedfiles") or {}):
                        dirname = filepath.rsplit("/", 1)[0] if "/" in filepath else "/"
                        evidence_files_dirs[dirname] = evidence_files_dirs.get(dirname, 0) + 1
                        # Record each path once even if several markers match
                        # (e.g. ".ssh/id_rsa" matches ".ssh" and "id_rsa").
                        if any(marker in filepath for marker in sensitive_markers):
                            evidence_files_sensitive.append(filepath)

                # pip-install: extract package list
                for pkg in data.get("packages") or []:
                    evidence_packages.append(f"{pkg.get('name', '?')}=={pkg.get('version', '?')}")

        except (OSError, ValueError, KeyError, TypeError, AttributeError):
            # Best effort: a malformed attestation still yields an explanation
            # built from whatever evidence was collected before the failure.
            pass

    # Top file directories by count
    top_dirs = sorted(evidence_files_dirs.items(), key=lambda x: -x[1])[:20]
    total_files = sum(evidence_files_dirs.values())

    prompt = f"""You are a supply chain security forensic analyst. A pip package was installed inside a sandboxed Docker container with full ptrace system call tracing. Every syscall (execve, openat, socket, connect, bind, sendto) was intercepted and recorded. The attestation was then evaluated against Rego security policies.

Your job: analyze the EVIDENCE below and explain what happened. Cite specific IPs, file paths, process names, and policy violations. Do NOT speculate beyond what the evidence shows. If something is benign, explain WHY it's benign (e.g. "151.101.64.223 is Fastly CDN which hosts PyPI packages"). If something is suspicious, point to the specific evidence.

## Package
{scan['package']}=={scan.get('version') or 'latest'}

## Policy Evaluation
Risk Level: {risk_level} | Score: {risk_score} | Total violations: {len(violations)}

Violations:
{chr(10).join(violations) if violations else '(none)'}

## EVIDENCE: Network Connections ({len(evidence_network)} total)
{chr(10).join(evidence_network) if evidence_network else '(none)'}

## EVIDENCE: DNS Lookups
{chr(10).join(evidence_dns) if evidence_dns else '(none)'}

## EVIDENCE: Sockets Created ({len(evidence_sockets)})
{chr(10).join(evidence_sockets) if evidence_sockets else '(none)'}

## EVIDENCE: Processes Spawned ({len(evidence_processes)})
{chr(10).join(f"PID {p['pid']}: {p['program']} | {p['cmdline']}" for p in evidence_processes)}

## EVIDENCE: Files Accessed ({total_files} total)
Top directories:
{chr(10).join(f"  {d}: {c} files" for d, c in top_dirs)}

Sensitive files accessed:
{chr(10).join(evidence_files_sensitive) if evidence_files_sensitive else '(none)'}

## EVIDENCE: Packages Installed ({len(evidence_packages)})
{chr(10).join(evidence_packages)}

---

Analyze this evidence in 3 sections:

**What happened** - Walk through exactly what the install did: which processes ran, where network connections went (identify the IPs), what files were touched. Cite the evidence.

**Risk assessment** - For each policy violation, explain whether it represents real risk or expected behavior, citing the specific evidence. If connections went to known infrastructure (PyPI CDN, Docker DNS), say so. If unknown IPs appear, flag them explicitly.

**Verdict** - Is this package safe to install? Be direct. If it's clean, say so confidently and explain why. If there are concerns, identify exactly what needs investigation."""

    try:
        import anthropic
        client = anthropic.Anthropic()
        message = client.messages.create(
            model="claude-sonnet-4-20250514",
            max_tokens=1500,
            messages=[{"role": "user", "content": prompt}],
        )
        explanation = message.content[0].text
    except ImportError:
        explanation = "Anthropic SDK not installed. Run: pip install anthropic"
    except Exception as e:
        explanation = f"Error calling Claude API: {e}"

    return jsonify({"explanation": explanation})


@app.route("/api/verification-material")
def verification_material():
    """Public endpoint: everything needed to verify attestations independently.

    Returns:

    * the Archivista server URL, from ``ARCHIVISTA_SERVER`` with a default;
    * per-tier policy gitoids and download URLs from
      ``policies/policy-gitoids.txt``;
    * the single-policy gitoid, falling back to the ``standard`` tier;
    * the PEM public key;
    * policy metadata from ``policy_eval.get_policy_info()``.

    Missing files yield empty values rather than errors.
    """
    # Only get_policy_info is needed here; the unused compute_policy_version
    # import was dropped.
    from policy_eval import get_policy_info

    archivista = os.environ.get("ARCHIVISTA_SERVER", "https://archivista.testifysec.io")

    # Read policy gitoids for all tiers ("tier=gitoid" per line)
    policy_gitoids = {}
    gitoid_file = PROJECT_DIR / "policies" / "policy-gitoids.txt"
    if gitoid_file.exists():
        for line in gitoid_file.read_text().strip().split("\n"):
            if "=" in line:
                tier, gid = line.split("=", 1)
                policy_gitoids[tier.strip()] = gid.strip()

    # Fallback to single policy
    policy_gitoid_file = PROJECT_DIR / "policies" / "policy-gitoid.txt"
    policy_gitoid = ""
    if policy_gitoid_file.exists():
        policy_gitoid = policy_gitoid_file.read_text().strip()

    # Read public key
    pubkey_file = PROJECT_DIR / "policies" / "policy-key.pub"
    pubkey = ""
    if pubkey_file.exists():
        pubkey = pubkey_file.read_text().strip()

    policy_info = get_policy_info()

    return jsonify({
        "archivista_server": archivista,
        "policy_tiers": {
            tier: {
                "gitoid": gid,
                "url": f"{archivista}/download/{gid}",
            } for tier, gid in policy_gitoids.items()
        },
        "policy_gitoid": policy_gitoid or policy_gitoids.get("standard", ""),
        "public_key_pem": pubkey,
        "policy_version": policy_info["version"],
        "policy_count": policy_info["policy_count"],
        "policies": policy_info["policies"],
    })


@app.route("/api/policies")
def list_policies():
    """Public endpoint: return ``policy_eval.get_policy_info()`` (Rego policy metadata) as JSON."""
    from policy_eval import get_policy_info as policy_info_of
    info = policy_info_of()
    return jsonify(info)


@app.route("/api/policies/rego/<name>")
def get_policy_rego(name):
    """Public endpoint: download a specific Rego policy file.

    ``name`` is the policy's base name without the ``.rego`` extension. The file
    is served from ``<project>/policies/rego/``, the same ``policies`` directory
    the other policy endpoints read. Returns 404 for unknown names and for any
    name that would resolve outside that directory.
    """
    # The original used POLICY_DIR, which is not defined alongside the other
    # path constants (NameError). Its intended target was <project>/policies/rego.
    rego_dir = (PROJECT_DIR / "policies" / "rego").resolve()
    policy_file = (rego_dir / f"{name}.rego").resolve()
    if policy_file.parent != rego_dir or not policy_file.is_file():
        return jsonify({"error": "not found"}), 404
    return Response(policy_file.read_text(), mimetype="text/plain",
                    headers={"Content-Disposition": f'attachment; filename="{policy_file.name}"'})


@app.route("/api/scans/<int:scan_id>/verify-command")
def verify_command(scan_id):
    """Generate the exact cilock verify command for a specific scan.

    For each policy tier listed in ``policies/policy-gitoids.txt``, emits a bash
    script that:

    1. saves the public key;
    2. downloads the scan's attestation and the tier's policy from Archivista;
    3. runs ``cilock verify`` against the attestation's first subject digest.

    Placeholders are used when data is missing: ``ATTESTATION_GITOID`` when the
    scan has no recorded gitoid, ``sha256:SUBJECT_DIGEST`` when no subject could
    be read. Returns 404 for unknown scans.
    """
    db = get_db()
    try:
        row = db.execute("SELECT * FROM scans WHERE id = ?", (scan_id,)).fetchone()
    finally:
        db.close()
    if not row:
        return jsonify({"error": "not found"}), 404

    archivista = os.environ.get("ARCHIVISTA_SERVER", "https://archivista.testifysec.io")
    scan_gitoid = row["gitoid"] or None
    # Never emit ".../download/None" into the script.
    attestation_ref = scan_gitoid or "ATTESTATION_GITOID"

    # Load policy tiers ("tier=gitoid" per line)
    policy_gitoids = {}
    gitoid_file = PROJECT_DIR / "policies" / "policy-gitoids.txt"
    if gitoid_file.exists():
        for line in gitoid_file.read_text().strip().split("\n"):
            if "=" in line:
                tier, gid = line.split("=", 1)
                policy_gitoids[tier.strip()] = gid.strip()

    # Get subject digest from attestation (first subject)
    subject = None
    if row["attestation_path"] and os.path.exists(row["attestation_path"]):
        try:
            with open(row["attestation_path"]) as f:
                env = json.load(f)
            payload = json.loads(base64.b64decode(env["payload"]))
            subjects = payload.get("subject") or [{}]
            digests = subjects[0].get("digest") or {}
            if digests:
                # Prefer sha256. Otherwise label the digest with its real
                # algorithm instead of always claiming sha256.
                algo = "sha256" if "sha256" in digests else next(iter(digests))
                if digests[algo]:
                    subject = f"{algo}:{digests[algo]}"
        except (OSError, ValueError, KeyError, TypeError, AttributeError):
            pass  # fall back to the SUBJECT_DIGEST placeholder

    # Read public key
    pubkey_file = PROJECT_DIR / "policies" / "policy-key.pub"
    pubkey = pubkey_file.read_text().strip() if pubkey_file.exists() else ""

    # Build commands for each tier
    tiers = {}
    for tier, pgid in policy_gitoids.items():
        cmd = f"""#!/bin/bash
# pip-witness verification: {row['package']}=={row['version'] or 'latest'}
# Policy tier: {tier}

# 1. Save the public key
cat > policy.pub << 'PUBKEY'
{pubkey}
PUBKEY

# 2. Download attestation and policy from Archivista
curl -sL {archivista}/download/{attestation_ref} -o attestation.json
curl -sL {archivista}/download/{pgid} -o policy-{tier}.json

# 3. Verify
cilock verify \\
  --policy policy-{tier}.json \\
  --publickey policy.pub \\
  --attestations attestation.json \\
  --subjects "{subject or 'sha256:SUBJECT_DIGEST'}"

echo "Verification complete. Exit code: $?"
"""
        tiers[tier] = {"gitoid": pgid, "command": cmd}

    return jsonify({
        "package": row["package"],
        "version": row["version"],
        "scan_gitoid": scan_gitoid,
        "subject": subject,
        "archivista_server": archivista,
        "tiers": tiers,
        "attestation_url": f"{archivista}/download/{scan_gitoid}" if scan_gitoid else None,
    })


@app.route("/api/policies/coverage")
def policy_coverage():
    """Show how many attestations have been evaluated with each policy version.

    Response fields:
        current_policy_version: value of ``policy_eval.compute_policy_version()``.
        total_scans: completed scans that have an ``attestation_path``.
        evaluated_with_current: how many of *those* scans have at least one
            evaluation row for the current policy version.
        unevaluated: ``total_scans - evaluated_with_current``.
        pct_current: ``evaluated_with_current`` as a percentage of
            ``total_scans`` (0 when there are no scans).
        versions: number of distinct evaluated scans per policy version,
            over all rows in ``evaluations``.
    """
    from policy_eval import compute_policy_version
    current_version = compute_policy_version()

    db = get_db()
    try:
        total_scans = db.execute(
            "SELECT COUNT(*) FROM scans WHERE status='completed' AND attestation_path IS NOT NULL"
        ).fetchone()[0]

        # Count evals per policy version
        rows = db.execute("""
            SELECT policy_version, COUNT(DISTINCT scan_id) as count
            FROM evaluations GROUP BY policy_version
        """).fetchall()
        versions = {r["policy_version"]: r["count"] for r in rows}

        # Restrict to the same population as total_scans. Counting evaluations of
        # scans outside it (failed, no attestation) could make `unevaluated`
        # negative and `pct_current` exceed 100.
        current_count = db.execute("""
            SELECT COUNT(DISTINCT e.scan_id)
            FROM evaluations e JOIN scans s ON s.id = e.scan_id
            WHERE e.policy_version = ?
              AND s.status = 'completed' AND s.attestation_path IS NOT NULL
        """, (current_version,)).fetchone()[0]
    finally:
        db.close()

    unevaluated = total_scans - current_count

    return jsonify({
        "current_policy_version": current_version,
        "total_scans": total_scans,
        "evaluated_with_current": current_count,
        "unevaluated": unevaluated,
        "pct_current": round(current_count / max(total_scans, 1) * 100, 1),
        "versions": versions,
    })


@app.route("/api/reeval-unevaluated", methods=["POST"])
def reeval_unevaluated():
    """Re-evaluate only scans that haven't been evaluated with the current policy version.

    Every completed scan with an ``attestation_path`` that has no evaluation
    row for ``policy_eval.compute_policy_version()`` is pushed onto
    ``reeval_queue``.

    Returns:
        JSON ``{"count": <scans enqueued>, "policy_version": <current version>}``.
    """
    from policy_eval import compute_policy_version
    current = compute_policy_version()

    db = get_db()
    try:
        # NOT EXISTS rather than NOT IN: if the subquery ever produced a NULL
        # scan_id, `id NOT IN (...)` would be NULL for every row and silently
        # enqueue nothing.
        rows = db.execute("""
            SELECT s.id FROM scans s
            WHERE s.status = 'completed' AND s.attestation_path IS NOT NULL
            AND NOT EXISTS (
                SELECT 1 FROM evaluations e
                WHERE e.scan_id = s.id AND e.policy_version = ?
            )
        """, (current,)).fetchall()
    finally:
        db.close()

    for r in rows:
        reeval_queue.put(r["id"])

    return jsonify({"count": len(rows), "policy_version": current})


@app.route("/api/pypi-browse")
def pypi_browse():
    """Browse all known PyPI packages with scan status.

    Query parameters:
        q: substring of the package name (SQL ``LIKE``).
        filter: ``popular``, ``unscanned`` or ``scanned``; anything else
            applies no filter.
        limit: page size, clamped to 0..200 (default 50; non-integers fall
            back to the default).
        offset: rows to skip, clamped to >= 0 (default 0).

    Returns:
        JSON with ``packages`` (most downloaded first, each carrying its
        completed-scan count and last completion time), ``total``, ``limit``
        and ``offset``. If the ``pypi_packages`` table does not exist yet,
        returns an empty list and a hint to seed it.
    """
    q = request.args.get("q", "").strip()
    filter_type = request.args.get("filter", "")  # popular, unscanned, scanned
    # type=int returns the default for non-numeric input instead of raising
    # (which used to surface as HTTP 500). Also clamp from below: SQLite treats
    # a negative LIMIT as "no limit", so limit=-1 would dump the whole table.
    limit = max(0, min(request.args.get("limit", 50, type=int), 200))
    offset = max(0, request.args.get("offset", 0, type=int))

    db = get_db()
    try:
        # Check if the pypi_packages table exists
        table_exists = db.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='pypi_packages'"
        ).fetchone()

        if not table_exists:
            return jsonify({"packages": [], "total": 0, "message": "Run pypi_firehose.py --seed to populate"})

        # `where` is assembled only from the fixed fragments below; user input
        # goes through `params`.
        conditions = []
        params = []
        if q:
            conditions.append("p.name LIKE ?")
            params.append(f"%{q}%")
        if filter_type == "popular":
            conditions.append("p.is_popular = 1")
        elif filter_type == "unscanned":
            conditions.append("s.id IS NULL")
        elif filter_type == "scanned":
            conditions.append("s.id IS NOT NULL")

        where = " AND ".join(conditions) if conditions else "1=1"

        total = db.execute(f"""
            SELECT COUNT(DISTINCT p.name) FROM pypi_packages p
            LEFT JOIN scans s ON s.package = p.name AND s.status = 'completed'
            WHERE {where}
        """, params).fetchone()[0]

        rows = db.execute(f"""
            SELECT p.name, p.latest_version, p.download_count, p.is_popular,
                   p.last_release, p.summary,
                   COUNT(DISTINCT s.id) as scan_count,
                   MAX(s.completed_at) as last_scanned
            FROM pypi_packages p
            LEFT JOIN scans s ON s.package = p.name AND s.status = 'completed'
            WHERE {where}
            GROUP BY p.name
            ORDER BY p.download_count DESC, p.name
            LIMIT ? OFFSET ?
        """, params + [limit, offset]).fetchall()
    finally:
        db.close()

    return jsonify({
        "packages": [dict(r) for r in rows],
        "total": total,
        "limit": limit,
        "offset": offset,
    })


@app.route("/api/pypi/<package>/<version>/provenance")
def pypi_provenance(package, version):
    """Verify source provenance: PyPI → PEP 740 → GitHub commit → tag.

    Best-effort chain of evidence:

    1. Fetch the release metadata from PyPI, pick the first wheel (or else the
       first distribution file) and download its PEP 740 provenance.
    2. Record the publisher identity and scrape a commit SHA out of the
       Sigstore signing certificate.
    3. Ask the GitHub API whether that commit exists and whether one of the
       tags ``v{version}``, ``{version}`` or ``release/{version}`` points at it
       (annotated tags are dereferenced).

    Returns:
        JSON with ``pep740`` and ``github`` details (or None), ``verified``
        (True only when a tag resolves to the certificate's commit), and
        ``chain``: one human-readable line per step, including failures.
    """
    import re
    import urllib.parse as up
    import urllib.request as ur

    def fetch_json(url):
        # Context manager closes the HTTP response instead of leaving it to GC.
        with ur.urlopen(url, timeout=10) as resp:
            return json.loads(resp.read())

    # package/version come straight from the request path; percent-encode
    # them before splicing into upstream URLs (identity for valid names).
    pkg_q = up.quote(package, safe="")
    ver_q = up.quote(version, safe="")

    result = {
        "package": package, "version": version,
        "pep740": None, "github": None, "verified": False, "chain": [],
    }

    # Step 1: Get PEP 740 provenance from PyPI
    try:
        # Get filename first: prefer a wheel, else the first uploaded file.
        pypi_data = fetch_json(f"https://pypi.org/pypi/{pkg_q}/{ver_q}/json")
        files = pypi_data.get("urls", [])
        filename = next(
            (u["filename"] for u in files if u.get("packagetype") == "bdist_wheel"), None
        )
        if not filename and files:
            filename = files[0]["filename"]

        if filename:
            prov = fetch_json(
                f"https://pypi.org/integrity/{pkg_q}/{ver_q}/{up.quote(filename, safe='')}/provenance"
            )

            for bundle in prov.get("attestation_bundles", []):
                pub = bundle.get("publisher", {})
                result["pep740"] = {
                    "publisher_kind": pub.get("kind"),
                    "repository": pub.get("repository"),
                    "workflow": pub.get("workflow"),
                    "environment": pub.get("environment"),
                }
                result["chain"].append(f"PEP 740: Published by {pub.get('kind')} from {pub.get('repository')} via {pub.get('workflow')}")

                # Extract commit SHA from certificate: the DER bytes are read as
                # latin-1 text and searched for 40-hex-digit runs.
                for att in bundle.get("attestations", []):
                    cert_b64 = att.get("verification_material", {}).get("certificate", "")
                    if not cert_b64:
                        continue
                    cert_text = base64.b64decode(cert_b64).decode("latin-1")
                    # dict.fromkeys de-duplicates but keeps first-seen order. A set
                    # would make the chosen SHA depend on per-process hash seeding
                    # whenever the certificate contains more than one distinct SHA.
                    shas = list(dict.fromkeys(re.findall(r"[0-9a-f]{40}", cert_text)))
                    if shas:
                        result["pep740"]["commit_sha"] = shas[0]
                        result["chain"].append(f"Sigstore cert: commit {shas[0]}")
    except Exception as e:
        result["chain"].append(f"PEP 740: Not available ({e})")

    # Step 2: Verify against GitHub
    pep740 = result["pep740"] or {}
    repo = pep740.get("repository")
    commit = pep740.get("commit_sha")
    if repo and commit:
        repo_q = up.quote(repo, safe="/")
        try:
            # Check commit exists
            commit_data = fetch_json(f"https://api.github.com/repos/{repo_q}/commits/{commit}")
            author = commit_data["commit"]["author"]
            result["github"] = {
                "commit_exists": True,
                "commit_sha": commit_data["sha"],
                "author": author["name"],
                "author_email": author["email"],
                "date": author["date"],
                "message": commit_data["commit"]["message"].split("\n")[0][:100],
            }
            result["chain"].append(f"GitHub: commit exists, author: {author['name']}")

            # Check tag — try v{version}, {version}, and release/{version}
            tag_found = False
            for tag_fmt in (f"v{version}", version, f"release/{version}"):
                try:
                    tag_data = fetch_json(
                        f"https://api.github.com/repos/{repo_q}/git/refs/tags/{up.quote(tag_fmt, safe='/')}"
                    )
                except Exception:
                    continue
                # GitHub may answer this endpoint with a list of refs that only
                # share the prefix when no exact ref exists; that is not this tag.
                if not isinstance(tag_data, dict):
                    continue

                ref_obj = tag_data.get("object", {})
                tag_sha = ref_obj.get("sha", "")
                # Dereference annotated tags to get the actual commit
                if ref_obj.get("type", "") == "tag":
                    try:
                        tag_obj = fetch_json(f"https://api.github.com/repos/{repo_q}/git/tags/{tag_sha}")
                        tag_sha = tag_obj.get("object", {}).get("sha", tag_sha)
                    except Exception:
                        pass

                tag_matches = tag_sha == commit
                result["github"]["tag"] = tag_fmt
                result["github"]["tag_sha"] = tag_sha
                result["github"]["tag_matches_commit"] = tag_matches
                tag_found = True
                if tag_matches:
                    result["chain"].append(f"GitHub: tag {tag_fmt} → commit {commit[:12]} ✓")
                    result["verified"] = True
                else:
                    result["chain"].append(f"GitHub: tag {tag_fmt} → {tag_sha[:12]}, expected {commit[:12]}")
                break
            if not tag_found:
                result["chain"].append(f"GitHub: no matching tag found for version {version}")
        except Exception as e:
            result["chain"].append(f"GitHub: could not verify commit ({e})")

    return jsonify(result)


@app.route("/api/pypi/<package>/versions")
def pypi_versions(package):
    """Fetch all versions of a package from PyPI with scan status.

    Releases with no distribution files are skipped. The rest are sorted by
    their latest upload time, newest first. Each is annotated with
    ``scan_status``: the status of the most recent local scan of that
    version, or ``"not_scanned"``.

    Returns:
        JSON with project ``summary``/``author``/``home_page``, ``versions``
        and ``total_versions``; or ``{"error": ...}`` with HTTP 502 when PyPI
        cannot be fetched or parsed.
    """
    import urllib.parse as up
    import urllib.request as ur
    try:
        # Path parameter -> percent-encode before building the upstream URL.
        url = f"https://pypi.org/pypi/{up.quote(package, safe='')}/json"
        with ur.urlopen(url, timeout=10) as resp:
            data = json.loads(resp.read())
    except Exception as e:
        return jsonify({"error": str(e)}), 502

    releases = data.get("releases") or {}
    info = data.get("info") or {}

    versions = []
    for ver, dists in releases.items():
        if not dists:
            continue
        upload_dates = [d["upload_time"] for d in dists if d.get("upload_time")]
        sha256s = [d["digests"]["sha256"] for d in dists if d.get("digests", {}).get("sha256")]

        versions.append({
            "version": ver,
            "upload_date": max(upload_dates) if upload_dates else "",
            "has_sdist": any(d.get("packagetype") == "sdist" for d in dists),
            "has_wheel": any(d.get("packagetype") == "bdist_wheel" for d in dists),
            "dist_count": len(dists),
            "sha256": sha256s[0] if sha256s else "",
        })

    # Sort by upload date (ISO-8601 strings sort chronologically)
    versions.sort(key=lambda x: x["upload_date"], reverse=True)

    # Check which versions we've already scanned
    db = get_db()
    try:
        # ORDER BY id: a version may have been scanned several times; without an
        # order the status shown came from an arbitrary row. Now the newest wins.
        rows = db.execute(
            "SELECT version, status FROM scans WHERE package = ? ORDER BY id", (package,)
        ).fetchall()
    finally:
        db.close()
    scanned = {r["version"]: r["status"] for r in rows}

    for v in versions:
        v["scan_status"] = scanned.get(v["version"], "not_scanned")

    return jsonify({
        "package": package,
        "summary": info.get("summary", ""),
        "author": info.get("author", ""),
        "home_page": info.get("home_page", ""),
        "versions": versions,
        "total_versions": len(versions),
    })


@app.route("/api/search")
def search_scans():
    """Search scans by package name, risk level, or network activity.

    Query parameters:
        q: substring of the package name (SQL ``LIKE``).
        risk: ``risk_level`` to match exactly, upper-cased before matching.
        has_network: ``"true"`` for scans with any network connection,
            ``"suspicious"`` for more than 12 connections.
        limit: page size, clamped to 0..500 (default 100; non-integers fall
            back to the default).
        offset: rows to skip, clamped to >= 0 (default 0).

    Returns:
        JSON ``{"results", "total", "limit", "offset"}``, newest scans first.
    """
    q = request.args.get("q", "").strip()
    risk = request.args.get("risk", "")
    has_network = request.args.get("has_network", "")
    # type=int returns the default for non-numeric input instead of raising
    # (HTTP 500). Clamp from below too: SQLite treats a negative LIMIT as
    # "no limit".
    limit = max(0, min(request.args.get("limit", 100, type=int), 500))
    offset = max(0, request.args.get("offset", 0, type=int))

    # `where` is assembled only from the fixed fragments below; user input
    # goes through `params`.
    conditions = []
    params = []
    if q:
        conditions.append("package LIKE ?")
        params.append(f"%{q}%")
    if risk:
        conditions.append("risk_level = ?")
        params.append(risk.upper())
    if has_network == "true":
        conditions.append("network_connections > 0")
    elif has_network == "suspicious":
        # More than just PyPI CDN connections (> 9 is typical baseline)
        conditions.append("network_connections > 12")

    where = " AND ".join(conditions) if conditions else "1=1"

    db = get_db()
    try:
        total = db.execute(f"SELECT COUNT(*) FROM scans WHERE {where}", params).fetchone()[0]
        rows = db.execute(
            f"SELECT * FROM scans WHERE {where} ORDER BY created_at DESC LIMIT ? OFFSET ?",
            params + [limit, offset]
        ).fetchall()
    finally:
        db.close()

    return jsonify({"results": [dict(r) for r in rows], "total": total, "limit": limit, "offset": offset})


@app.route("/api/stats")
def get_stats():
    """Return dashboard counters plus a throughput/ETA estimate.

    Counters come from the ``scans``, ``evaluations`` and ``subscriptions``
    tables and from the sizes of the in-memory eval/re-eval queues. The ETA
    divides the number of queued-or-running scans by the scan rate observed
    in ``scan_completions`` over the last 10 minutes.
    """
    evals_pending = eval_queue.qsize() + reeval_queue.qsize()
    db = get_db()
    try:
        stats = {
            "total_scans": db.execute("SELECT COUNT(*) FROM scans").fetchone()[0],
            "completed_scans": db.execute("SELECT COUNT(*) FROM scans WHERE status='completed'").fetchone()[0],
            "failed_scans": db.execute("SELECT COUNT(*) FROM scans WHERE status='failed'").fetchone()[0],
            "queued_scans": db.execute("SELECT COUNT(*) FROM scans WHERE status='queued' OR status='running'").fetchone()[0],
            "evaluated": db.execute("SELECT COUNT(DISTINCT scan_id) FROM evaluations").fetchone()[0],
            "evals_pending": evals_pending,
            "subscriptions": db.execute("SELECT COUNT(*) FROM subscriptions").fetchone()[0],
            "high_risk": db.execute("SELECT COUNT(DISTINCT scan_id) FROM evaluations WHERE risk_level IN ('HIGH','CRITICAL')").fetchone()[0],
            "total_network_connections": db.execute("SELECT SUM(network_connections) FROM scans WHERE status='completed'").fetchone()[0] or 0,
            "total_files_opened": db.execute("SELECT SUM(files_opened) FROM scans WHERE status='completed'").fetchone()[0] or 0,
        }
    finally:
        db.close()

    # ETA calculation from rolling scan rate
    now = time.time()
    # Copy first: run_scan appends to / trims this list from the scanner thread.
    recent = [t for t in list(scan_completions) if now - t < 600]  # last 10 minutes
    if len(recent) >= 2:
        elapsed = recent[-1] - recent[0]
        # n completion timestamps bound n-1 intervals. Dividing n by the span
        # overstated the rate (two scans 60s apart is 1/min, not 2/min).
        rate = (len(recent) - 1) / (elapsed / 60) if elapsed > 0 else 0  # scans per minute
        stats["scans_per_minute"] = round(rate, 1)
        queued = stats["queued_scans"]
        if rate > 0 and queued > 0:
            eta_minutes = queued / rate
            stats["eta_minutes"] = round(eta_minutes, 1)
            if eta_minutes < 60:
                stats["eta_display"] = f"{int(eta_minutes)}m"
            elif eta_minutes < 1440:
                stats["eta_display"] = f"{eta_minutes/60:.1f}h"
            else:
                stats["eta_display"] = f"{eta_minutes/1440:.1f}d"
        else:
            stats["eta_minutes"] = 0
            stats["eta_display"] = "idle"
    else:
        stats["scans_per_minute"] = 0
        stats["eta_minutes"] = 0
        stats["eta_display"] = "calculating..."

    return jsonify(stats)


# ---------------------------------------------------------------------------
# Version Diff Engine
# ---------------------------------------------------------------------------

def load_attestation_data(attestation_path: str) -> dict:
    """Load a DSSE attestation envelope and flatten it for diffing.

    The envelope's base64 ``payload`` is decoded as JSON. Every entry of
    ``predicate.attestations`` is then scanned for ``processes``,
    ``packages`` and ``setupPyAnalysis`` data.

    Args:
        attestation_path: Path to the envelope JSON file; may be empty/None.

    Returns:
        A dict with keys ``network``, ``files`` (sorted, de-duplicated),
        ``processes``, ``packages``, ``sockets``, ``dns`` and ``setup_py``.
        Returns ``{}`` when the path is missing or the envelope cannot be
        decoded. Individual malformed entries (``null`` sub-objects, packages
        lacking ``name``/``version``) are skipped instead of discarding the
        whole attestation.
    """
    if not attestation_path or not os.path.exists(attestation_path):
        return {}
    try:
        with open(attestation_path) as f:
            envelope = json.load(f)
        payload = json.loads(base64.b64decode(envelope["payload"]))

        result = {"network": [], "files": set(), "processes": [], "packages": [],
                  "sockets": [], "dns": [], "setup_py": []}

        # `x.get(k) or default` tolerates keys that are present but null.
        predicate = payload.get("predicate") or {}
        for att in predicate.get("attestations") or []:
            data = att.get("attestation") or {}

            for p in data.get("processes") or []:
                result["processes"].append({
                    "pid": p.get("processid"),
                    "program": p.get("program", ""),
                    "cmdline": p.get("cmdline", ""),
                })
                result["files"].update((p.get("openedfiles") or {}).keys())
                net = p.get("network")
                if not net:
                    continue
                for c in net.get("connections") or []:
                    result["network"].append(
                        f"{c.get('syscall','')}:{c.get('family','')}:{c.get('address','')}:{c.get('port','')}"
                    )
                for s in net.get("sockets") or []:
                    result["sockets"].append(f"{s.get('family','')}:{s.get('type','')}:{s.get('protocol','')}")
                for d in net.get("dnsLookups") or []:
                    result["dns"].append(f"{d.get('serverAddress','')}:{d.get('serverPort','')}")

            for pkg in data.get("packages") or []:
                # Previously a single incomplete entry raised KeyError and the
                # whole attestation was reported as empty.
                if "name" in pkg and "version" in pkg:
                    result["packages"].append(f"{pkg['name']}=={pkg['version']}")

            for a in data.get("setupPyAnalysis") or []:
                result["setup_py"].extend(a.get("suspiciousCalls") or [])
                result["setup_py"].extend(a.get("networkImports") or [])

        result["files"] = sorted(result["files"])
        return result
    except Exception:
        # Deliberate best effort: an unreadable attestation diffs as "no data".
        return {}


def compute_diff(old_data: dict, new_data: dict) -> dict:
    """Compute the diff between two attestation datasets.

    Args:
        old_data: Output of ``load_attestation_data`` for the baseline scan
            (may be ``{}``).
        new_data: Same, for the scan being compared.

    Returns:
        For each of ``network``, ``files``, ``packages``, ``sockets``, ``dns``,
        ``processes`` and ``setup_py`` a dict ``{"added": [...], "removed": [...]}``.
        String categories are sorted. Processes are compared by ``program``
        and keep their original order. Also returns:

        - ``sensitive_files_added``: added files whose path contains a
          credential-ish marker.
        - ``change_risk_score``: weighted count of additions, truncated to int.
        - ``change_risk_level``: CRITICAL >= 50, HIGH >= 30, MEDIUM >= 15,
          LOW >= 5, else CLEAN.
    """
    diff = {}

    # Plain string categories: set difference in both directions.
    for key in ("network", "files", "packages", "sockets", "dns"):
        old_set = set(old_data.get(key, []))
        new_set = set(new_data.get(key, []))
        diff[key] = {"added": sorted(new_set - old_set), "removed": sorted(old_set - new_set)}

    # Processes are dicts; compare by program. Build each lookup set once
    # instead of once per element, which made this O(n*m).
    old_procs = old_data.get("processes", [])
    new_procs = new_data.get("processes", [])
    old_programs = {p["program"] for p in old_procs}
    new_programs = {p["program"] for p in new_procs}
    diff["processes"] = {
        "added": [p for p in new_procs if p["program"] not in old_programs],
        "removed": [p for p in old_procs if p["program"] not in new_programs],
    }

    old_setup = set(old_data.get("setup_py", []))
    new_setup = set(new_data.get("setup_py", []))
    diff["setup_py"] = {"added": sorted(new_setup - old_setup), "removed": sorted(old_setup - new_setup)}

    # Calculate a change risk score
    risk = 0
    risk += len(diff["network"]["added"]) * 10  # New network connections are very suspicious
    risk += len(diff["files"]["added"]) * 0.1    # New files are normal-ish
    risk += len(diff["packages"]["added"]) * 3    # New dependencies expand attack surface
    risk += len(diff["processes"]["added"]) * 8   # New subprocesses are suspicious
    risk += len(diff["dns"]["added"]) * 7         # New DNS targets
    risk += len(diff["setup_py"]["added"]) * 15   # New suspicious patterns in setup.py

    # Flag files touching sensitive paths (substring match)
    sensitive_markers = (".ssh", ".aws", ".gnupg", ".kube", "/etc/shadow",
                         "/etc/passwd", ".env", "credentials", ".pypirc")
    sensitive_files = [f for f in diff["files"]["added"]
                       if any(marker in f for marker in sensitive_markers)]
    risk += len(sensitive_files) * 20
    diff["sensitive_files_added"] = sensitive_files

    diff["change_risk_score"] = int(risk)
    diff["change_risk_level"] = (
        "CRITICAL" if risk >= 50 else
        "HIGH" if risk >= 30 else
        "MEDIUM" if risk >= 15 else
        "LOW" if risk >= 5 else
        "CLEAN"
    )

    return diff


@app.route("/api/diff/<int:old_scan_id>/<int:new_scan_id>")
def diff_scans(old_scan_id, new_scan_id):
    """Return the behavioural diff between two scans' attestations.

    Responds 404 when either scan id is unknown. Otherwise returns the
    ``compute_diff`` result, extended with ``old_scan`` and ``new_scan``
    summaries (id, package, version).
    """
    conn = get_db()
    before, after = (
        conn.execute("SELECT * FROM scans WHERE id = ?", (sid,)).fetchone()
        for sid in (old_scan_id, new_scan_id)
    )
    conn.close()

    if not before or not after:
        return jsonify({"error": "scan not found"}), 404

    payload = compute_diff(
        load_attestation_data(before["attestation_path"]),
        load_attestation_data(after["attestation_path"]),
    )
    for label, sid, scan in (("old_scan", old_scan_id, before), ("new_scan", new_scan_id, after)):
        payload[label] = {"id": sid, "package": scan["package"], "version": scan["version"]}

    return jsonify(payload)


@app.route("/api/package/<package>/versions")
def package_versions(package):
    """List every completed scan of *package*, most recently completed first.

    Used when choosing two scanned versions to diff. Each entry carries the
    scan id, version, status, risk fields, metric counters and completion time.
    """
    conn = get_db()
    cursor = conn.execute(
        """SELECT id, package, version, status, risk_level, risk_score,
                  network_connections, files_opened, completed_at
           FROM scans WHERE package = ? AND status = 'completed'
           ORDER BY completed_at DESC""",
        (package,)
    )
    scans = list(map(dict, cursor.fetchall()))
    conn.close()
    return jsonify(scans)


# ---------------------------------------------------------------------------
# Scanner Thread
# ---------------------------------------------------------------------------

def start_scan_internal(package, version=None, priority=50, source="internal"):
    """Insert a ``queued`` scan row and hand it to the scanner queue.

    Args:
        package: Package name to scan.
        version: Exact version, or None to scan whatever pip resolves.
        priority: ``scan_queue`` priority; as a PriorityQueue, lower values
            are taken first.
        source: Caller label; not used by this function.

    Returns:
        The id of the newly created ``scans`` row.
    """
    conn = get_db()
    inserted = conn.execute(
        "INSERT INTO scans (package, version, status) VALUES (?, ?, 'queued') RETURNING id",
        (package, version)
    ).fetchone()
    new_id = inserted[0]
    conn.commit()
    conn.close()

    enqueued_at = time.time()
    scan_queue.put((priority, enqueued_at, package, version, new_id))
    return new_id


def _kill_scan_container(scan_id: int) -> None:
    """Best-effort ``docker kill`` of the sandbox container(s) for *scan_id*.

    Killing the local ``docker run`` client process does not necessarily stop
    the container it started. ``run_scan`` labels each container with
    ``pip-witness.scan-id=<id>`` so it can be located and stopped here.
    Failures are printed to stderr and otherwise ignored.
    """
    try:
        listing = subprocess.run(
            ["docker", "ps", "-q", "--filter", f"label=pip-witness.scan-id={scan_id}"],
            capture_output=True, text=True, timeout=30,
        )
        container_ids = listing.stdout.split()
        if container_ids:
            subprocess.run(["docker", "kill", *container_ids],
                           capture_output=True, text=True, timeout=30)
    except (OSError, subprocess.SubprocessError) as e:
        print(f"Failed to stop container for scan {scan_id}: {e}", file=sys.stderr)


def run_scan(package: str, version: str | None, scan_id: int):
    """Execute a pip-witness Docker scan. Only produces attestation + metrics.
    Policy evaluation happens separately in the eval queue.

    Steps:
      1. Mark the scan ``running`` and start the ``pip-witness:latest``
         container with ``attestations/scan-<id>`` and a shared pip cache
         mounted.
      2. Stream the container's output, translating step markers into
         ``scan_progress`` events. A watchdog stops the run after
         ``SCAN_TIMEOUT`` seconds (env var, default 120), even if the
         container prints nothing.
      3. Choose the primary DSSE attestation, extract metrics, upload every
         DSSE envelope to Archivista, and record the primary's gitoid
         (computed locally if the upload returned none).
      4. Persist the results, enqueue the scan for evaluation (when an
         attestation was found) and broadcast ``scan_completed``.

    Any failure marks the scan ``failed`` (``error='timeout'`` on timeout)
    and broadcasts ``scan_failed``.
    """
    db = get_db()
    pkg_spec = f"{package}=={version}" if version else package

    db.execute("UPDATE scans SET status='running', started_at=datetime('now') WHERE id=?", (scan_id,))
    db.commit()
    broadcast_event("scan_started", {"id": scan_id, "package": package, "version": version})

    try:
        scan_att_dir = ATTESTATION_DIR / f"scan-{scan_id}"
        scan_att_dir.mkdir(parents=True, exist_ok=True)
        step_name = "pip-install"  # Must match policy step name for cilock verify

        archivista_server = os.environ.get("ARCHIVISTA_SERVER", "https://archivista.testifysec.io")

        docker_env = [
            "-e", f"STEP_NAME={step_name}",
            "-e", f"ARCHIVISTA_SERVER={archivista_server}",
        ]

        def progress(phase, detail=""):
            broadcast_event("scan_progress", {
                "id": scan_id, "package": package, "phase": phase, "detail": detail
            })

        progress("starting", "Launching sandbox container...")

        # Stream output from Docker to detect phases
        # Shared pip cache across containers (hash-addressed, safe to share)
        pip_cache_dir = ATTESTATION_DIR.parent / "cache" / "pip"
        pip_cache_dir.mkdir(parents=True, exist_ok=True)

        proc = subprocess.Popen(
            ["docker", "run", "--rm",
             # Label lets _kill_scan_container find this container on timeout.
             "--label", f"pip-witness.scan-id={scan_id}",
             "--cap-add=SYS_PTRACE", "--security-opt", "seccomp=unconfined",
             "-v", f"{scan_att_dir}:/attestations",
             "-v", f"{pip_cache_dir}:/root/.cache/pip",
             *docker_env,
             "pip-witness:latest", pkg_spec],
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
        )

        SCAN_TIMEOUT = int(os.environ.get("SCAN_TIMEOUT", "120"))  # 2 minutes max
        all_output = []
        gitoid = None

        # Watchdog: the read loop blocks on proc.stdout, so a deadline checked
        # only when a line arrives never fires for a container that hangs
        # silently. The timer kills the client (ending the loop with EOF) and
        # the labelled container.
        timed_out = threading.Event()

        def on_timeout():
            timed_out.set()
            proc.kill()
            _kill_scan_container(scan_id)

        watchdog = threading.Timer(SCAN_TIMEOUT, on_timeout)
        watchdog.daemon = True
        watchdog.start()
        try:
            for line in proc.stdout:
                all_output.append(line)
                stripped = line.strip()
                if "[1/" in stripped:
                    progress("download", "Downloading package source...")
                elif "[2/" in stripped:
                    progress("analyze", "Static analysis of package contents...")
                elif "[3/" in stripped:
                    progress("trace", "Installing with ptrace instrumentation...")
                elif "[4/" in stripped or "[5/" in stripped:
                    progress("complete", "Finalizing attestation...")
                elif "Collecting" in stripped and "from" not in stripped.lower():
                    progress("trace", f"Installing: {stripped[:60]}")
                elif "Starting" in stripped and "attestor" in stripped:
                    progress("attest", stripped[:60])
                if "gitoid" in stripped.lower() or "stored" in stripped.lower():
                    for p in stripped.split():
                        if len(p) == 64 and all(c in "0123456789abcdef" for c in p):
                            gitoid = p
            # A TimeoutExpired from here propagates to the outer timeout handler.
            proc.wait(timeout=10)
        finally:
            watchdog.cancel()
            if proc.poll() is None:
                # Early exit (wait timeout or an exception while parsing):
                # never leave the sandbox running.
                proc.kill()
                proc.wait()
                _kill_scan_container(scan_id)
            proc.stdout.close()

        if timed_out.is_set():
            raise subprocess.TimeoutExpired(cmd="pip-witness", timeout=SCAN_TIMEOUT)

        progress("metrics", "Extracting metrics from attestation...")

        # Find the primary DSSE attestation (exclude sbom/sarif/import variants)
        all_json = list(scan_att_dir.glob(f"{step_name}-*.json"))
        att_path = None
        for candidate in sorted(all_json):
            fname = candidate.name
            # Primary attestation: pip-install-<ts>.json (no sbom/sarif/import in name)
            if 'sbom' not in fname and 'sarif' not in fname and 'import' not in fname:
                att_path = str(candidate)
                break
        if not att_path and all_json:
            # Fallback: prefer DSSE envelopes over raw SBOMs
            for candidate in sorted(all_json):
                if not str(candidate).endswith('.cdx.json'):
                    try:
                        with open(candidate) as _cf:
                            _cd = json.load(_cf)
                        if 'payload' in _cd:
                            att_path = str(candidate)
                            break
                    except Exception:
                        pass

        pre_analysis = scan_att_dir / "pre-analysis.json"
        pre_path = str(pre_analysis) if pre_analysis.exists() else None
        metrics = extract_metrics(att_path) if att_path else {}

        # Upload ALL DSSE collections to Archivista (Issue #5)
        import urllib.request as ur
        import hashlib as _hl
        dsse_files = []
        for candidate in scan_att_dir.glob(f"{step_name}*.json"):
            # Skip non-DSSE files (raw SBOMs end in .cdx.json)
            if candidate.name.endswith('.cdx.json'):
                continue
            try:
                with open(candidate) as _f:
                    _d = json.load(_f)
                if 'payload' in _d and 'signatures' in _d:
                    dsse_files.append(str(candidate))
            except Exception:
                continue

        for dsse_path in dsse_files:
            try:
                with open(dsse_path, "rb") as f:
                    att_bytes = f.read()
                req = ur.Request(
                    f"{archivista_server}/upload",
                    data=att_bytes,
                    headers={"Content-Type": "application/json"},
                    method="POST",
                )
                with ur.urlopen(req, timeout=15) as resp:
                    resp_data = json.loads(resp.read())
                uploaded_gitoid = resp_data.get("gitoid", "")
                if uploaded_gitoid:
                    # Classify by file name only: the directory path may itself
                    # contain these words. Guard the index for short names.
                    fname = os.path.basename(dsse_path)
                    parts = fname.split('-')
                    if any(tag in fname for tag in ('import', 'sbom', 'sarif')) and len(parts) > 2:
                        step_label = parts[2]
                    else:
                        step_label = "install"
                    progress("archivista", f"Stored {step_label}: {uploaded_gitoid[:16]}...")
                # If this is the primary attestation, capture its gitoid
                if dsse_path == att_path and uploaded_gitoid:
                    gitoid = uploaded_gitoid
            except Exception as e:
                print(f"Archivista upload failed for {os.path.basename(dsse_path)}: {e}", file=sys.stderr)

        # Compute gitoid locally if Archivista didn't return one for primary
        # (sha256 over a git-style "blob <len>\0" header + content).
        if not gitoid and att_path:
            try:
                with open(att_path, "rb") as f:
                    content = f.read()
                header = f"blob {len(content)}\0".encode()
                gitoid = _hl.sha256(header + content).hexdigest()
                progress("archivista", f"Computed gitoid locally: {gitoid[:16]}...")
            except Exception:
                pass

        db.execute("""
            UPDATE scans SET status='completed', completed_at=datetime('now'),
                attestation_path=?, gitoid=?, pre_analysis_path=?,
                network_connections=?, files_opened=?, processes_spawned=?,
                dns_lookups=?, sockets_created=?, packages_installed=?
            WHERE id=?
        """, (att_path, gitoid, pre_path,
              metrics.get("network_connections", 0), metrics.get("files_opened", 0),
              metrics.get("processes_spawned", 0), metrics.get("dns_lookups", 0),
              metrics.get("sockets_created", 0), metrics.get("packages_installed", 0),
              scan_id))
        db.commit()

        # Post to eval queue for policy evaluation
        if att_path:
            eval_queue.put(scan_id)

        # Track completion time for ETA calculation
        scan_completions.append(time.time())
        # Keep last 100 completions for rate calculation
        while len(scan_completions) > 100:
            scan_completions.pop(0)

        broadcast_event("scan_completed", {
            "id": scan_id, "package": package, "version": version,
            "network_connections": metrics.get("network_connections", 0),
            "files_opened": metrics.get("files_opened", 0),
        })

    except subprocess.TimeoutExpired:
        db.execute("UPDATE scans SET status='failed', error='timeout', completed_at=datetime('now') WHERE id=?", (scan_id,))
        db.commit()
        broadcast_event("scan_failed", {"id": scan_id, "package": package, "error": "timeout"})
    except Exception as e:
        db.execute("UPDATE scans SET status='failed', error=?, completed_at=datetime('now') WHERE id=?", (str(e), scan_id))
        db.commit()
        broadcast_event("scan_failed", {"id": scan_id, "package": package, "error": str(e)})
    finally:
        db.close()


# Known policy tiers, most severe first, with the risk level and score that a
# verified rejection under that tier's policy assigns to a scan.
_TIER_SEVERITY = (
    ("critical", "CRITICAL", 3),
    ("standard", "HIGH", 2),
    ("strict", "MEDIUM", 1),
)

# Seconds to wait on Archivista when downloading a policy into the local cache.
_POLICY_FETCH_TIMEOUT = 30


def _load_policy_gitoids() -> dict:
    """Read policies/policy-gitoids.txt as ``tier=gitoid`` lines.

    Returns {tier: gitoid}; empty when the file is missing. Lines without
    ``=`` are ignored and both sides are whitespace-stripped.
    """
    gitoid_file = PROJECT_DIR / "policies" / "policy-gitoids.txt"
    if not gitoid_file.exists():
        return {}
    policy_gitoids = {}
    for line in gitoid_file.read_text().splitlines():
        if "=" in line:
            tier, gid = line.split("=", 1)
            policy_gitoids[tier.strip()] = gid.strip()
    return policy_gitoids


def _attestation_subject(att_path: str) -> "str | None":
    """Return ``sha256:<digest>`` for the first subject of a DSSE envelope.

    Prefers the subject's ``sha256`` digest entry and otherwise uses the first
    digest listed. Returns None when the envelope can't be read or decoded,
    or its first subject carries no digest.
    """
    try:
        with open(att_path) as f:
            env = json.load(f)
        payload = json.loads(base64.b64decode(env["payload"]))
        digests = payload.get("subject", [{}])[0].get("digest") or {}
        # Only a real sha256 value should be labelled "sha256:"; the first
        # entry of the digest map may belong to another algorithm.
        digest = digests.get("sha256") or next(iter(digests.values()), None)
    except (OSError, ValueError, KeyError, IndexError, TypeError, AttributeError):
        return None
    return f"sha256:{digest}" if digest else None


def _cached_policy_path(tier: str, pgid: str, archivista: str) -> str:
    """Return the local path of *tier*'s policy, downloading it once if absent.

    Policies are cached under APP_DIR/policy_cache and never re-fetched. The
    download is written to a unique temporary file and atomically renamed into
    place, so a failed or interrupted fetch (or two threads fetching at once)
    can never leave a truncated policy in the cache.

    Raises whatever urllib / filesystem error prevented the download.
    """
    import tempfile
    import urllib.request

    cache_dir = APP_DIR / "policy_cache"
    cache_dir.mkdir(exist_ok=True)
    policy_path = cache_dir / f"policy-{tier}.json"
    if policy_path.exists():
        return str(policy_path)

    fd, tmp_path = tempfile.mkstemp(dir=cache_dir, prefix=f".policy-{tier}-", suffix=".part")
    try:
        with os.fdopen(fd, "wb") as out:
            with urllib.request.urlopen(
                f"{archivista}/download/{pgid}", timeout=_POLICY_FETCH_TIMEOUT
            ) as resp:
                out.write(resp.read())
        os.replace(tmp_path, policy_path)
    finally:
        # After a successful os.replace the temp file no longer exists.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
    return str(policy_path)


def _verify_tier(cilock_bin: str, policy_path: str, pubkey_path: Path,
                 att_path: str, subject: str) -> dict:
    """Run ``cilock verify`` natively for one policy and summarize the result.

    Returns ``{"passed", "reasons", "exit_code"}``; *reasons* are the stderr
    lines mentioning a rejected collection or a denial.
    Raises subprocess.TimeoutExpired after 30 s.
    """
    result = subprocess.run(
        [cilock_bin, "verify",
         "--policy", policy_path,
         "--publickey", str(pubkey_path),
         "--attestations", att_path,
         "--subjects", subject],
        capture_output=True, text=True, timeout=30,
    )
    reasons = []
    for line in result.stderr.splitlines():
        lowered = line.lower()
        if "collection rejected" in lowered or "denied" in lowered:
            reasons.append(line.strip())
    return {
        "passed": result.returncode == 0,
        "reasons": reasons,
        "exit_code": result.returncode,
    }


def _classify_tiers(tier_results: dict) -> tuple:
    """Map per-tier verify results to ``(risk_level, risk_score)``.

    Tiers are walked from most to least severe. The first tier that cilock
    verifiably rejected decides the level. A tier that errored (download or
    subprocess failure) is not counted as a failure. However, reaching one
    before any verified rejection means the true level is unknown, which
    yields UNEVALUATED. Known tiers absent from *tier_results* count as
    passing. An error in any other tier still yields UNEVALUATED when no
    rejection was found.
    """
    for tier, level, score in _TIER_SEVERITY:
        res = tier_results.get(tier)
        if res is None:
            continue
        if "error" in res:
            return "UNEVALUATED", 0
        if not res.get("passed", True):
            return level, score
    if any("error" in r for r in tier_results.values()):
        return "UNEVALUATED", 0
    return "CLEAN", 0


def _evaluate_with_opa_fallback(db, scan_id: int, att_path: str, package, version) -> None:
    """Evaluate with the standalone OPA evaluator when cilock verify can't run.

    Stores an ``evaluations`` row and broadcasts ``eval_completed``. Errors are
    printed to stderr and swallowed. The caller owns and closes *db*.
    """
    try:
        from policy_eval import evaluate_all_policies
        results = evaluate_all_policies(att_path)
        policy_ver = results.get("policy_version", "unknown")
        db.execute("""
            INSERT INTO evaluations (scan_id, risk_score, risk_level, policy_results, policy_version)
            VALUES (?, ?, ?, ?, ?)
        """, (scan_id, results.get("risk_score", 0), results.get("risk_level", "CLEAN"),
              json.dumps(results), policy_ver))
        db.commit()
        broadcast_event("eval_completed", {
            "scan_id": scan_id, "package": package, "version": version,
            "risk_score": results.get("risk_score", 0), "risk_level": results.get("risk_level", "CLEAN"),
            "total_violations": results.get("total_violations", 0),
        })
    except Exception as e:
        print(f"OPA fallback eval error for scan {scan_id}: {e}", file=sys.stderr)


def evaluate_scan(scan_id: int):
    """Evaluate a scan's stored attestation against the signed policy tiers.

    For every tier listed in policies/policy-gitoids.txt, runs ``cilock verify``
    with that tier's cached policy and the policy public key. It then derives
    a risk level from which tiers reject the attestation, inserts an
    ``evaluations`` row, and broadcasts ``eval_completed``.

    When cilock's prerequisites are missing (no subject digest, no tier list,
    or no public key), the standalone OPA evaluator is used instead. Scans
    without an attestation path are skipped. The DB connection is always
    closed, even if evaluation raises.
    """
    db = get_db()
    try:
        row = db.execute(
            "SELECT attestation_path, package, version FROM scans WHERE id=?", (scan_id,)
        ).fetchone()
        if not row or not row["attestation_path"]:
            return

        att_path = row["attestation_path"]
        package = row["package"]
        version = row["version"]

        policy_gitoids = _load_policy_gitoids()
        pubkey_path = PROJECT_DIR / "policies" / "policy-key.pub"
        subject = _attestation_subject(att_path)

        if not subject or not policy_gitoids or not pubkey_path.exists():
            _evaluate_with_opa_fallback(db, scan_id, att_path, package, version)
            return

        # Prefer the project-local cilock binary, else whatever is on PATH.
        cilock_bin = str(PROJECT_DIR / "bin" / "cilock")
        if not os.path.exists(cilock_bin):
            cilock_bin = "cilock"

        archivista = os.environ.get("ARCHIVISTA_SERVER", "https://archivista.testifysec.io")
        tier_results = {}
        for tier, pgid in policy_gitoids.items():
            try:
                policy_path = _cached_policy_path(tier, pgid, archivista)
                tier_results[tier] = _verify_tier(cilock_bin, policy_path, pubkey_path,
                                                  att_path, subject)
            except Exception as e:
                # Recorded as an error (not a rejection); see _classify_tiers.
                tier_results[tier] = {"passed": False, "error": str(e), "reasons": []}

        overall_level, risk_score = _classify_tiers(tier_results)
        total_violations = sum(len(t.get("reasons", [])) for t in tier_results.values())

        from policy_eval import compute_policy_version
        policy_ver = compute_policy_version()

        results_json = json.dumps({
            "tiers": tier_results,
            "overall_level": overall_level,
            "risk_score": risk_score,
            "total_violations": total_violations,
            "evaluation_method": "cilock_verify",
            "policy_version": policy_ver,
        })

        db.execute("""
            INSERT INTO evaluations (scan_id, risk_score, risk_level, policy_results, policy_version)
            VALUES (?, ?, ?, ?, ?)
        """, (scan_id, risk_score, overall_level, results_json, policy_ver))
        db.commit()

        broadcast_event("eval_completed", {
            "scan_id": scan_id, "package": package, "version": version,
            "risk_score": risk_score, "risk_level": overall_level,
            "total_violations": total_violations,
            "tiers": {t: {"passed": r.get("passed")} for t, r in tier_results.items()},
        })
    finally:
        db.close()


def extract_metrics(attestation_path: str) -> dict:
    """Extract behavioral counters from an attestation envelope.

    Reads the JSON envelope at *attestation_path*, base64-decodes its
    ``payload`` and walks ``predicate.attestations``:

    - every attestation carrying ``processes`` adds its process count plus
      each process's opened files, sockets, connections and DNS lookups;
    - an attestation carrying ``totalInstalled`` sets ``packages_installed``.

    JSON ``null`` for any of those collections is treated as empty rather than
    aborting the walk. Best-effort: a missing path returns all zeros. A file
    that can't be read or decoded returns the counts gathered so far and
    logs the problem to stderr instead of raising.
    """
    metrics = {
        "network_connections": 0, "files_opened": 0, "processes_spawned": 0,
        "dns_lookups": 0, "sockets_created": 0, "packages_installed": 0,
    }
    if not attestation_path:
        return metrics
    try:
        with open(attestation_path) as f:
            envelope = json.load(f)
        payload = json.loads(base64.b64decode(envelope["payload"]))
        atts = (payload.get("predicate") or {}).get("attestations") or []

        for att in atts:
            data = att.get("attestation") or {}
            if "processes" in data:
                procs = data["processes"] or []
                # Sum across attestations, like every other counter.
                metrics["processes_spawned"] += len(procs)
                for p in procs:
                    metrics["files_opened"] += len(p.get("openedfiles") or {})
                    net = p.get("network") or {}
                    metrics["sockets_created"] += len(net.get("sockets") or [])
                    metrics["network_connections"] += len(net.get("connections") or [])
                    metrics["dns_lookups"] += len(net.get("dnsLookups") or [])
            if "totalInstalled" in data:
                metrics["packages_installed"] = data["totalInstalled"]
    except Exception as e:
        # Metrics are advisory; never let a malformed attestation fail a scan.
        print(f"extract_metrics: could not parse {attestation_path}: {e}", file=sys.stderr)
    return metrics


# ---------------------------------------------------------------------------
# Worker pools
#
#   scan_queue   -> Docker scan containers (slow, expensive; SCAN_WORKERS cap)
#   eval_queue   -> policy evaluation of freshly produced attestations
#   reeval_queue -> policy re-evaluation of already stored attestations
#
# Attestations are immutable once written; evaluations can be redone at any
# time as policies change.
# ---------------------------------------------------------------------------

SCAN_WORKERS = int(os.getenv("SCAN_WORKERS", "8"))   # concurrent Docker scan threads
EVAL_WORKERS = int(os.getenv("EVAL_WORKERS", "4"))   # threads per evaluation queue

# Count of scans currently executing; read/modify only while holding the lock.
active_scans_lock = threading.Lock()
active_scans = 0


def scan_worker():
    """Scan-pool thread body: execute queued scans forever.

    Waits on scan_queue (re-polling every 5 seconds when idle), runs each job
    through run_scan, and keeps active_scans in step for the status logger.
    Any exception from a scan is printed to stderr so the thread survives,
    and every dequeued job is marked done.
    """
    global active_scans
    while True:
        try:
            _prio, _enqueued_at, pkg, ver, sid = scan_queue.get(timeout=5)
        except queue.Empty:
            continue  # idle; poll again

        with active_scans_lock:
            active_scans += 1
        try:
            run_scan(pkg, ver, sid)
        except Exception as err:
            print(f"Scan worker error: {err}", file=sys.stderr)
        finally:
            with active_scans_lock:
                active_scans -= 1
            scan_queue.task_done()


def eval_worker():
    """Evaluation-pool thread body: evaluate newly produced attestations.

    Takes scan ids from eval_queue (re-polling every 5 seconds when idle) and
    passes each to evaluate_scan. Failures are reported on stderr and the
    loop continues; each dequeued id is marked done.
    """
    while True:
        try:
            sid = eval_queue.get(timeout=5)
        except queue.Empty:
            continue  # nothing pending yet

        try:
            evaluate_scan(sid)
        except Exception as err:
            print(f"Eval worker error: {err}", file=sys.stderr)
        finally:
            eval_queue.task_done()


def reeval_worker():
    """Re-evaluation thread body: re-run evaluate_scan on stored attestations.

    Takes scan ids from reeval_queue (re-polling every 5 seconds when idle).
    Failures are reported on stderr and the loop continues; each dequeued id
    is marked done.
    """
    while True:
        try:
            sid = reeval_queue.get(timeout=5)
        except queue.Empty:
            continue  # nothing pending yet

        try:
            evaluate_scan(sid)
        except Exception as err:
            print(f"Re-eval worker error: {err}", file=sys.stderr)
        finally:
            reeval_queue.task_done()


def status_logger():
    """Every 30 seconds, print queue depths and the running-scan count.

    Stays quiet while all three queues are empty and no scan is running.
    """
    while True:
        time.sleep(30)
        depths = (scan_queue.qsize(), eval_queue.qsize(), reeval_queue.qsize())
        with active_scans_lock:
            running = active_scans
        if not any(depths) and running <= 0:
            continue  # fully idle: nothing worth logging
        sq, eq, rq = depths
        print(f"Queues: scan={sq} (running={running}/{SCAN_WORKERS}) | eval={eq} | reeval={rq}")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def _requeue_unfinished_scans():
    """Put scans a previous process left 'queued' or 'running' back on scan_queue.

    Rows stuck in 'running' are reset to 'queued' in the DB. Recovered jobs
    get priority 50.
    """
    db = get_db()
    pending = db.execute(
        "SELECT id, package, version FROM scans WHERE status IN ('queued', 'running')"
    ).fetchall()
    db.execute("UPDATE scans SET status='queued' WHERE status='running'")  # reset stuck running
    db.commit()
    db.close()

    for rec in pending:
        scan_queue.put((50, time.time(), rec["package"], rec["version"], rec["id"]))
    if pending:
        print(f"Recovered {len(pending)} queued scans from DB")


def main():
    """Start pip-witness: prepare storage, recover unfinished scans, launch the
    scan / eval / re-eval worker pools and the status logger, then serve the
    Flask app (blocks until the server exits)."""
    init_db()
    ATTESTATION_DIR.mkdir(parents=True, exist_ok=True)

    _requeue_unfinished_scans()

    # (thread-name prefix, target, thread count), started in this order.
    pools = (
        ("scan", scan_worker, SCAN_WORKERS),     # Docker, capped
        ("eval", eval_worker, EVAL_WORKERS),     # new attestations
        ("reeval", reeval_worker, EVAL_WORKERS), # stored attestations
    )
    for prefix, target, count in pools:
        for idx in range(count):
            threading.Thread(target=target, daemon=True, name=f"{prefix}-{idx}").start()

    threading.Thread(target=status_logger, daemon=True).start()

    print(f"pip-witness: {SCAN_WORKERS} scan workers, {EVAL_WORKERS} eval workers, {EVAL_WORKERS} reeval workers")

    port = int(os.environ.get("PORT", 8080))
    print(f"pip-witness viewer running at http://localhost:{port}")
    app.run(host="0.0.0.0", port=port, debug=False, threaded=True)


if __name__ == "__main__":
    main()
