#!/usr/bin/env python3
"""Build a hash allowlist from observed-clean scan data.

Scans the attestation database for files observed across multiple clean scans.
Files seen in N+ clean scans with the same hash are considered known-good.

The allowlist is exported as JSON for OPA to consume as external data.
(A compact bloom/cuckoo filter for fast membership testing is planned
but not produced by this script.)

Architecture:
  - Scan all CLEAN attestations
  - Extract (filepath_pattern, sha256) pairs for .pth, .pkl, .pickle files
  - Files seen in >= MIN_SCANS clean scans are allowlisted
  - Export as policies/allowlist.json for OPA data bundle
  - Policy references data.allowlist.pth_hashes / data.allowlist.pkl_hashes
"""

import base64
import json
import os
import sqlite3
import sys
from collections import defaultdict
from pathlib import Path

DB_PATH = Path(__file__).parent / "pip_witness.db"
ALLOWLIST_PATH = Path(__file__).parent.parent / "policies" / "allowlist.json"
MIN_SCANS = 1  # Minimum clean scans to be considered known-good (start at 1, increase over time)


def main():
    """Build policies/allowlist.json from the latest CLEAN evaluation of each scan.

    Reads the attestation envelope of every completed scan whose most recent
    evaluation is CLEAN, collects sha256 digests of interesting files
    (.pth under site-packages; .pkl/.pickle/.joblib anywhere), keeps digests
    observed in at least MIN_SCANS *distinct* clean scans, and writes the
    result as JSON for the OPA data bundle.

    Side effects: reads DB_PATH (sqlite), reads attestation files from disk,
    writes ALLOWLIST_PATH, prints a summary to stdout and warnings to stderr.
    """
    db = sqlite3.connect(str(DB_PATH))
    db.row_factory = sqlite3.Row
    try:
        # Only the LATEST evaluation per scan counts; a scan that was later
        # re-evaluated as risky must not contribute to the allowlist.
        rows = db.execute("""
            SELECT s.attestation_path, s.package
            FROM scans s
            JOIN evaluations e ON e.scan_id = s.id
            WHERE s.status = 'completed'
              AND s.attestation_path IS NOT NULL
              AND e.risk_level = 'CLEAN'
              AND e.id = (SELECT MAX(e2.id) FROM evaluations e2 WHERE e2.scan_id = s.id)
        """).fetchall()
    finally:
        db.close()

    print(f"Processing {len(rows)} clean scans...")

    # sha256 -> {files seen as, packages seen in, distinct scans seen in}.
    # We record the *scan index* per hash (a set) rather than a raw counter:
    # a file opened many times within one scan is a single observation, which
    # is what "Minimum clean scans" (MIN_SCANS) actually means.
    def _bucket():
        return {"files": set(), "packages": set(), "scans": set()}

    pth_hashes = defaultdict(_bucket)
    pkl_hashes = defaultdict(_bucket)

    for scan_idx, row in enumerate(rows):
        att_path = row["attestation_path"]
        pkg = row["package"]
        if not att_path or not os.path.exists(att_path):
            continue

        try:
            # DSSE-style envelope: the in-toto statement is base64 in "payload".
            with open(att_path) as f:
                envelope = json.load(f)
            payload = json.loads(base64.b64decode(envelope["payload"]))

            for att in payload.get("predicate", {}).get("attestations", []):
                data = att.get("attestation", {})
                for proc in data.get("processes", []):
                    for filepath, digest in (proc.get("openedfiles") or {}).items():
                        if not digest:
                            continue
                        sha = digest.get("sha256", "")
                        if not sha:
                            continue

                        fname = filepath.split("/")[-1]

                        if filepath.endswith(".pth") and "site-packages" in filepath:
                            bucket = pth_hashes[sha]
                        elif filepath.endswith((".pkl", ".pickle", ".joblib")):
                            bucket = pkl_hashes[sha]
                        else:
                            continue

                        bucket["files"].add(fname)
                        bucket["packages"].add(pkg)
                        bucket["scans"].add(scan_idx)

        except Exception as exc:
            # Best-effort: one malformed attestation must not abort the whole
            # build, but it should not vanish silently either.
            print(f"  WARNING: skipping attestation {att_path}: {exc}", file=sys.stderr)
            continue

    def _finalize(hashes):
        """Keep hashes seen in >= MIN_SCANS distinct scans; make values JSON-safe."""
        return {
            sha: {
                "files": sorted(info["files"]),
                "packages": sorted(info["packages"]),
                "count": len(info["scans"]),
            }
            for sha, info in hashes.items()
            if len(info["scans"]) >= MIN_SCANS
        }

    safe_pth = _finalize(pth_hashes)
    safe_pkl = _finalize(pkl_hashes)

    allowlist = {
        "pth_hashes": safe_pth,
        "pkl_hashes": safe_pkl,
        # Flat sorted lists so OPA can do fast set-membership lookups.
        "pth_sha_set": sorted(safe_pth.keys()),
        "pkl_sha_set": sorted(safe_pkl.keys()),
        "generated_from": len(rows),
        "min_scans": MIN_SCANS,
    }

    with open(ALLOWLIST_PATH, "w") as f:
        json.dump(allowlist, f, indent=2)

    print(f"Allowlist written to {ALLOWLIST_PATH}")
    print(f"  .pth hashes: {len(safe_pth)}")
    print(f"  .pkl hashes: {len(safe_pkl)}")
    print(f"  From {len(rows)} clean scans")

    # Human-readable summary, most-frequently-seen hashes first
    # (all .pth entries; .pkl capped at 10 since there can be many).
    for sha, info in sorted(safe_pth.items(), key=lambda x: -x[1]["count"]):
        print(f"  PTH: {sha[:16]}... {', '.join(info['files']):30s} (seen {info['count']}x in {', '.join(info['packages'][:3])})")
    for sha, info in sorted(safe_pkl.items(), key=lambda x: -x[1]["count"])[:10]:
        print(f"  PKL: {sha[:16]}... {', '.join(info['files']):30s} (seen {info['count']}x)")


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
