#!/usr/bin/env python3
"""PyPI incremental crawler with risk-based package prioritization.

Prioritization algorithm for selecting which packages to scan first:

TIER 1 (Priority 1-10): Subscribed packages - user-requested monitoring
TIER 2 (Priority 11-20): Recently updated packages with sdist (setup.py execution)
TIER 3 (Priority 21-30): High-download packages never scanned
TIER 4 (Priority 31-40): New packages (< 7 days old) - highest typosquat risk
TIER 5 (Priority 41-50): Packages with new maintainers or ownership transfers
TIER 6 (Priority 51-70): Medium-download packages, stale scans (> 30 days)
TIER 7 (Priority 71-90): Low-download, never-scanned packages
TIER 8 (Priority 91-99): Everything else (alphabetical fill)

Risk signals that increase priority (lower number = higher priority):
- Has sdist distribution (setup.py runs arbitrary code)
- Single maintainer with no other packages
- Name similar to popular package (typosquatting)
- Recently changed maintainer
- Unusually high number of releases in short time
- No source repo link
- Very new account publishing packages
"""

import json
import math
import os
import sqlite3
import sys
import time
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional
from urllib.request import urlopen, Request
from urllib.error import URLError

# Scan-results DB owned by the viewer service; this crawler only reads it.
VIEWER_DB = Path(__file__).parent.parent / "viewer" / "pip_witness.db"
# Crawler's own state DB (package metadata, priorities, crawl bookkeeping).
CRAWLER_DB = Path(__file__).parent / "crawler.db"
# PyPI Simple index (full package list); JSON variant requested via Accept header.
PYPI_SIMPLE = "https://pypi.org/simple/"
# Per-package metadata endpoint; format with the package name.
PYPI_JSON = "https://pypi.org/pypi/{}/json"
# Recent download stats (last day/week/month); format with the package name.
PYPI_STATS = "https://pypistats.org/api/packages/{}/recent"

# Popular packages for typosquat detection
# Names are compared after lowercasing and stripping "-"/"_" separators.
POPULAR_PACKAGES = {
    "requests", "boto3", "urllib3", "setuptools", "certifi", "idna",
    "charset-normalizer", "typing-extensions", "python-dateutil", "packaging",
    "botocore", "s3transfer", "numpy", "pyyaml", "pip", "six", "jmespath",
    "cryptography", "cffi", "pycparser", "wheel", "attrs", "pyasn1",
    "platformdirs", "importlib-metadata", "colorama", "awscli", "jinja2",
    "markupsafe", "tomli", "pytz", "zipp", "rsa", "filelock", "protobuf",
    "click", "scipy", "pandas", "pillow", "grpcio", "google-api-core",
    "flask", "django", "fastapi", "uvicorn", "pydantic", "sqlalchemy",
    "celery", "redis", "psycopg2", "pymongo", "httpx", "aiohttp",
    "transformers", "torch", "tensorflow", "langchain", "openai",
    "anthropic", "litellm", "llama-index", "huggingface-hub",
}


def get_crawler_db():
    """Open the crawler state database with name-based row access enabled."""
    connection = sqlite3.connect(str(CRAWLER_DB))
    connection.row_factory = sqlite3.Row
    return connection


def init_crawler_db():
    """Create the crawler schema if it does not exist yet.

    Idempotent: all DDL uses IF NOT EXISTS, so this is safe to call at
    the start of every crawl cycle.

    Tables:
      packages    -- one row per known PyPI package with enrichment
                     metadata and the computed priority_score
                     (1 = highest priority, 99 = lowest; default 50).
      crawl_state -- simple key/value store for crawl bookkeeping
                     (e.g. timestamp of the last index fetch).
    """
    db = get_crawler_db()
    db.executescript("""
        CREATE TABLE IF NOT EXISTS packages (
            name TEXT PRIMARY KEY,
            last_checked TEXT,
            latest_version TEXT,
            has_sdist INTEGER DEFAULT 0,
            download_count INTEGER DEFAULT 0,
            maintainer_count INTEGER DEFAULT 0,
            first_seen TEXT,
            release_count INTEGER DEFAULT 0,
            has_source_repo INTEGER DEFAULT 1,
            priority_score INTEGER DEFAULT 50,
            last_release_date TEXT,
            is_new INTEGER DEFAULT 0
        );

        CREATE TABLE IF NOT EXISTS crawl_state (
            key TEXT PRIMARY KEY,
            value TEXT
        );

        CREATE INDEX IF NOT EXISTS idx_packages_priority ON packages(priority_score);
        CREATE INDEX IF NOT EXISTS idx_packages_checked ON packages(last_checked);
    """)
    db.commit()
    db.close()


def levenshtein(s1: str, s2: str) -> int:
    """Simple Levenshtein distance for typosquat detection.

    Classic two-row dynamic-programming form; the shorter string is kept
    as the inner dimension so each row stays as small as possible.
    """
    # Swap in place instead of recursing so the shorter string is s2.
    if len(s2) > len(s1):
        s1, s2 = s2, s1
    if not s2:
        return len(s1)
    row = list(range(len(s2) + 1))
    for i, a in enumerate(s1, start=1):
        next_row = [i]
        for j, b in enumerate(s2, start=1):
            substitution_cost = 0 if a == b else 1
            next_row.append(min(
                row[j] + 1,                    # deletion
                next_row[j - 1] + 1,           # insertion
                row[j - 1] + substitution_cost # substitution
            ))
        row = next_row
    return row[-1]


def is_typosquat_candidate(name: str) -> bool:
    """Check if a package name is suspiciously similar to a popular package.

    Names are normalized (lowercased, "-"/"_" stripped) before comparison.
    A name that *is* a popular package is never a candidate: previously
    popular packages within edit distance 2 of each other were flagged
    (e.g. "zipp" was reported as a typosquat of "pip").

    Signals checked against every popular package:
      - Levenshtein distance <= 2 (only for names longer than 3 chars)
      - popular name plus a short (<= 3 char) prefix or suffix addition
    """
    normalized = name.lower().replace("-", "").replace("_", "")
    popular_normalized = {
        p.lower().replace("-", "").replace("_", "") for p in POPULAR_PACKAGES
    }
    # A popular package is, by definition, not a typosquat of another one.
    if normalized in popular_normalized:
        return False
    for pop_norm in popular_normalized:
        # Distance is bounded below by the length difference, so skip the
        # O(n*m) Levenshtein computation when the gap alone exceeds 2.
        if len(normalized) > 3 and abs(len(normalized) - len(pop_norm)) <= 2:
            if levenshtein(normalized, pop_norm) <= 2:
                return True
        # Check for common typosquat patterns
        if normalized.startswith(pop_norm) and len(normalized) - len(pop_norm) <= 3:
            return True
        if normalized.endswith(pop_norm) and len(normalized) - len(pop_norm) <= 3:
            return True
    return False


def _parse_utc(timestamp: str) -> Optional[datetime]:
    """Parse an ISO-8601 timestamp, assuming UTC when no zone is present.

    PyPI ``upload_time`` values (e.g. "2023-01-01T00:00:00") are naive;
    subtracting one from an aware ``datetime.now(timezone.utc)`` raises
    TypeError. Returns None for empty/invalid input.
    """
    if not timestamp:
        return None
    try:
        dt = datetime.fromisoformat(timestamp)
    except (ValueError, TypeError):
        return None
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt


def calculate_priority(pkg_info: dict) -> int:
    """Calculate priority score (1=highest, 99=lowest) based on risk signals.

    Starts at 50 and subtracts points per risk signal; the result is
    clamped to [1, 99]. Expects the dict produced by fetch_package_info()
    (plus "download_count"); all keys are optional.

    Bug fixed: naive PyPI timestamps previously caused a TypeError when
    subtracted from an aware datetime, which the broad except swallowed,
    so the age- and release-burst signals silently never fired.
    """
    score = 50  # default middle priority

    # Download count factor (higher downloads = scan sooner)
    downloads = pkg_info.get("download_count", 0)
    if downloads > 1_000_000:
        score -= 20
    elif downloads > 100_000:
        score -= 15
    elif downloads > 10_000:
        score -= 10
    elif downloads > 1_000:
        score -= 5

    # Has sdist (runs setup.py = much higher risk)
    if pkg_info.get("has_sdist"):
        score -= 10

    # New package (< 7 days old)
    first_seen = _parse_utc(pkg_info.get("first_seen") or "")
    if first_seen is not None:
        age = datetime.now(timezone.utc) - first_seen
        if age < timedelta(days=7):
            score -= 15  # very new = suspicious
        elif age < timedelta(days=30):
            score -= 8

    # Typosquatting signal
    if is_typosquat_candidate(pkg_info.get("name", "")):
        score -= 20

    # Single maintainer
    if pkg_info.get("maintainer_count", 0) <= 1:
        score -= 5

    # No source repo
    if not pkg_info.get("has_source_repo"):
        score -= 8

    # Recent release activity (many releases in short time is suspicious).
    # NOTE: the original also re-tested release_count > 5 inside, which is
    # redundant once release_count > 20 has already passed.
    if pkg_info.get("release_count", 0) > 20:
        last_release = _parse_utc(pkg_info.get("last_release_date") or "")
        if last_release is not None:
            if (datetime.now(timezone.utc) - last_release).days < 7:
                score -= 10

    return max(1, min(99, score))


def fetch_package_list() -> list[str]:
    """Fetch list of all packages from PyPI Simple API.

    Tries the JSON Simple API first; on any failure falls back to
    scraping the HTML index. Returns an empty list when both attempts
    fail. HTTP responses are closed via context managers (the original
    leaked them -- the index body is several MB).
    """
    print("Fetching PyPI package index...")
    req = Request(PYPI_SIMPLE, headers={"Accept": "application/vnd.pypi.simple.v1+json"})
    try:
        with urlopen(req, timeout=60) as resp:
            data = json.loads(resp.read())
        packages = [p["name"] for p in data.get("projects", [])]
        print(f"  Found {len(packages)} packages")
        return packages
    except Exception as e:
        print(f"  Error fetching package list: {e}")
        # Fallback: HTML parsing
        try:
            with urlopen(PYPI_SIMPLE, timeout=60) as resp:
                html = resp.read().decode()
            import re
            packages = re.findall(r'href="[^"]*">([^<]+)</a>', html)
            print(f"  Found {len(packages)} packages (HTML fallback)")
            return packages
        except Exception as e2:
            print(f"  HTML fallback also failed: {e2}")
            return []

def fetch_package_info(name: str) -> Optional[dict]:
    """Fetch package metadata from PyPI JSON API.

    Returns a dict of risk-relevant fields (sdist presence, maintainer
    count, source-repo presence, release count/dates) or None on any
    failure -- callers treat None as "skip this package for now".
    The HTTP response is closed via a context manager (previously leaked).
    """
    try:
        url = PYPI_JSON.format(name)
        with urlopen(url, timeout=15) as resp:
            data = json.loads(resp.read())

        info = data.get("info", {})
        releases = data.get("releases", {})

        # Check for sdist: any sdist distribution in the latest release
        # means setup.py executes arbitrary code on install.
        has_sdist = False
        latest_version = info.get("version", "")
        if latest_version and latest_version in releases:
            for dist in releases[latest_version]:
                if dist.get("packagetype") == "sdist":
                    has_sdist = True
                    break

        # Count maintainers (author + maintainer fields).
        # NOTE(review): author and author_email usually describe the same
        # person, so this can count one maintainer twice -- it is only a
        # rough "single maintainer" signal.
        maintainers = set()
        if info.get("author"):
            maintainers.add(info["author"])
        if info.get("maintainer"):
            maintainers.add(info["maintainer"])
        if info.get("author_email"):
            maintainers.add(info["author_email"])

        # Check for source repo: home_page/project_url, or any project_urls
        # entry pointing at GitHub/GitLab.
        has_source = bool(
            info.get("home_page") or
            info.get("project_url") or
            any("github" in (u or "").lower() or "gitlab" in (u or "").lower()
                for u in (info.get("project_urls") or {}).values())
        )

        # Release dates: collect every distribution's upload_time. ISO-8601
        # strings sort chronologically as plain text.
        release_dates = []
        for ver, dists in releases.items():
            for d in dists:
                if d.get("upload_time"):
                    release_dates.append(d["upload_time"])
        release_dates.sort()

        first_seen = release_dates[0] if release_dates else None
        last_release = release_dates[-1] if release_dates else None

        return {
            "name": name,
            "latest_version": latest_version,
            "has_sdist": has_sdist,
            "maintainer_count": len(maintainers),
            "has_source_repo": has_source,
            "release_count": len(releases),
            "first_seen": first_seen,
            "last_release_date": last_release,
        }
    except Exception:
        # Best effort: network errors, 404s, and schema surprises all mean
        # "no metadata available right now".
        return None


def fetch_download_count(name: str) -> int:
    """Fetch recent (last-month) download count from pypistats.

    Returns 0 on any failure -- download stats are a best-effort signal.
    The HTTP response is closed via a context manager (previously leaked).
    """
    try:
        with urlopen(PYPI_STATS.format(name), timeout=10) as resp:
            data = json.loads(resp.read())
        return data.get("data", {}).get("last_month", 0)
    except Exception:
        return 0


def get_already_scanned() -> set:
    """Get set of package names already scanned in viewer DB.

    Returns an empty set when the DB is missing or unreadable (the viewer
    may not be installed or its schema may differ). The connection is now
    closed even when the query raises -- the original leaked it on error.
    """
    if not VIEWER_DB.exists():
        return set()
    db = None
    try:
        db = sqlite3.connect(str(VIEWER_DB))
        rows = db.execute("SELECT DISTINCT package FROM scans WHERE status='completed'").fetchall()
        return {r[0] for r in rows}
    except Exception:
        # Best effort: a missing table / locked DB just means "none known".
        return set()
    finally:
        if db is not None:
            db.close()


def get_subscribed_packages() -> dict:
    """Get subscribed packages with their priorities ({name: priority}).

    Returns an empty dict when the viewer DB is missing or unreadable.
    The connection is now closed even when the query raises -- the
    original leaked it on error.
    """
    if not VIEWER_DB.exists():
        return {}
    db = None
    try:
        db = sqlite3.connect(str(VIEWER_DB))
        rows = db.execute("SELECT package, priority FROM subscriptions").fetchall()
        return {r[0]: r[1] for r in rows}
    except Exception:
        # Best effort: treat any failure as "no subscriptions".
        return {}
    finally:
        if db is not None:
            db.close()


def submit_scan(package: str, priority: int = 50, source: str = "crawler"):
    """Submit a scan request to the local viewer service.

    POSTs {package, priority, source} to the viewer's /api/scan endpoint.
    Returns True on success, False on any failure (viewer down, timeout).
    The response is closed via a context manager (previously leaked).
    """
    import urllib.request
    data = json.dumps({"package": package, "priority": priority, "source": source}).encode()
    req = urllib.request.Request(
        "http://localhost:8085/api/scan",
        data=data,
        headers={"Content-Type": "application/json"},
        method="POST"
    )
    try:
        with urllib.request.urlopen(req, timeout=5):
            pass  # body is irrelevant; a 2xx response means accepted
        return True
    except Exception as e:
        print(f"  Failed to submit scan: {e}")
        return False


def crawl_and_prioritize(batch_size: int = 50, enrich_batch: int = 20):
    """Main crawl loop: fetch packages, enrich metadata, prioritize, submit scans.

    One cycle does three things:
      1. Refresh the package index from PyPI if the cached copy is > 1h old,
         inserting newly seen names with is_new=1.
      2. Enrich up to ``enrich_batch`` packages that have never been checked
         (metadata + download count), computing their priority_score.
      3. Submit up to ``batch_size`` highest-priority (lowest score),
         not-yet-scanned packages to the viewer.

    Args:
        batch_size: max number of scan submissions per cycle.
        enrich_batch: max number of packages to enrich per cycle.
    """
    init_crawler_db()
    db = get_crawler_db()

    # Get crawl state
    row = db.execute("SELECT value FROM crawl_state WHERE key='last_index_fetch'").fetchone()
    last_fetch = row[0] if row else None

    # Re-fetch index if > 1 hour old
    should_refetch = True
    if last_fetch:
        try:
            age = datetime.now(timezone.utc) - datetime.fromisoformat(last_fetch)
            should_refetch = age > timedelta(hours=1)
        except (ValueError, TypeError):
            # Unparseable state value: fall through with should_refetch=True.
            pass

    if should_refetch:
        packages = fetch_package_list()
        if packages:
            # Insert new packages
            # Diff against what we already track; only genuinely new names
            # are inserted (is_new=1 boosts them in the enrichment ORDER BY).
            existing = {r[0] for r in db.execute("SELECT name FROM packages").fetchall()}
            new_pkgs = set(packages) - existing
            if new_pkgs:
                print(f"Found {len(new_pkgs)} new packages")
                for name in new_pkgs:
                    db.execute(
                        "INSERT OR IGNORE INTO packages (name, is_new) VALUES (?, 1)",
                        (name,)
                    )
                db.commit()

            # Record fetch time only on success, so a failed fetch retries
            # on the next cycle.
            db.execute(
                "INSERT OR REPLACE INTO crawl_state (key, value) VALUES ('last_index_fetch', ?)",
                (datetime.now(timezone.utc).isoformat(),)
            )
            db.commit()

    # Phase 1: Enrich un-enriched packages (fetch metadata)
    # New packages first, then a random sample of the backlog.
    unenriched = db.execute(
        "SELECT name FROM packages WHERE last_checked IS NULL ORDER BY is_new DESC, RANDOM() LIMIT ?",
        (enrich_batch,)
    ).fetchall()

    scanned = get_already_scanned()
    subscribed = get_subscribed_packages()

    for row in unenriched:
        name = row[0]
        print(f"Enriching: {name}")
        info = fetch_package_info(name)
        if info:
            info["download_count"] = fetch_download_count(name)
            priority = calculate_priority(info)

            # Boost priority if subscribed
            # (lower number = higher priority, so take the minimum)
            if name in subscribed:
                priority = min(priority, subscribed[name])

            # Lower priority if already scanned
            if name in scanned:
                priority = min(priority + 30, 99)

            # Persist metadata and clear the is_new flag; commit per package
            # so progress survives an interrupted run.
            db.execute("""
                UPDATE packages SET
                    last_checked=datetime('now'),
                    latest_version=?,
                    has_sdist=?,
                    download_count=?,
                    maintainer_count=?,
                    first_seen=?,
                    release_count=?,
                    has_source_repo=?,
                    priority_score=?,
                    last_release_date=?,
                    is_new=0
                WHERE name=?
            """, (
                info["latest_version"], int(info["has_sdist"]),
                info.get("download_count", 0), info["maintainer_count"],
                info["first_seen"], info["release_count"],
                int(info["has_source_repo"]), priority,
                info["last_release_date"], name
            ))
            db.commit()
            time.sleep(0.5)  # rate limit
        # NOTE(review): packages where fetch_package_info() fails keep
        # last_checked=NULL and will be retried on a later cycle.

    # Phase 2: Submit top priority packages for scanning
    # Get packages not yet scanned, ordered by priority
    # (score >= 80 means "everything else" tier -- not worth a scan slot yet)
    candidates = db.execute("""
        SELECT name, priority_score, latest_version, has_sdist, download_count
        FROM packages
        WHERE last_checked IS NOT NULL
            AND priority_score < 80
        ORDER BY priority_score ASC
        LIMIT ?
    """, (batch_size,)).fetchall()

    submitted = 0
    for row in candidates:
        name, priority, version, has_sdist, downloads = row
        if name in scanned:
            continue

        print(f"  Submitting: {name} (priority={priority}, sdist={has_sdist}, downloads={downloads})")
        if submit_scan(name, priority, "crawler"):
            submitted += 1
            scanned.add(name)  # avoid double-submitting within this cycle

        if submitted >= batch_size:
            break

        time.sleep(0.2)

    print(f"\nSubmitted {submitted} packages for scanning")
    db.close()


def main():
    """CLI entry point: parse arguments and run one or more crawl cycles."""
    import argparse
    parser = argparse.ArgumentParser(description="PyPI incremental crawler")
    parser.add_argument("--batch", type=int, default=10, help="Number of packages to submit per run")
    parser.add_argument("--enrich", type=int, default=30, help="Number of packages to enrich per run")
    parser.add_argument("--continuous", action="store_true", help="Run continuously")
    parser.add_argument("--interval", type=int, default=300, help="Seconds between crawl cycles (with --continuous)")
    args = parser.parse_args()

    # Single-shot mode: one cycle and exit.
    if not args.continuous:
        crawl_and_prioritize(batch_size=args.batch, enrich_batch=args.enrich)
        return

    # Continuous mode: keep cycling; a failed cycle is logged, not fatal.
    print("Starting continuous crawl mode...")
    while True:
        try:
            crawl_and_prioritize(batch_size=args.batch, enrich_batch=args.enrich)
        except Exception as e:
            print(f"Crawl error: {e}")
        print(f"Sleeping {args.interval}s until next cycle...")
        time.sleep(args.interval)


if __name__ == "__main__":
    main()
