#!/usr/bin/env python3
"""Evaluate Rego policies against attestation data using OPA.

Runs each risk-dimension policy against the appropriate attestation type
and returns structured results per-policy. This replaces the ad-hoc Python
risk scoring with formal, auditable policy evaluation.

Policy dimensions (a representative subset; POLICY_ATTESTATION_MAP below is the
complete, authoritative list, including the IOC and emerging-threat policies):
  - network_exfiltration: evaluates command-run attestation
  - credential_harvesting: evaluates command-run attestation
  - code_execution: evaluates command-run attestation
  - persistence: evaluates command-run attestation
  - supply_chain: evaluates pip-install attestation
  - package_signing: evaluates pip-install attestation
"""

import base64
import hashlib
import json
import os
import subprocess
import tempfile
from pathlib import Path
from typing import Optional

# Directory holding the *.rego policy files (and the optional allowlist.json
# data file), located two directory levels above this file.
POLICY_DIR = Path(__file__).parent.parent / "policies" / "rego"

# Memoized result of compute_policy_version(). None means "not computed yet":
# the hash is computed lazily on the first call (not at import time) and then
# reused until invalidate_policy_cache() resets it to None.
_policy_version_cache: Optional[str] = None


def compute_policy_version() -> str:
    """Return a short, deterministic fingerprint of the Rego policy set.

    Every ``*.rego`` file in POLICY_DIR is visited in sorted order; its file
    name and raw bytes are both fed into a single SHA-256 digest, so adding,
    editing, renaming or deleting a policy yields a different value. The first
    16 hex characters of the digest are returned.

    The value is memoized in ``_policy_version_cache``; call
    ``invalidate_policy_cache()`` to force the files to be rehashed.
    """
    global _policy_version_cache
    if not _policy_version_cache:
        digest = hashlib.sha256()
        for rego_path in sorted(POLICY_DIR.glob("*.rego")):
            digest.update(rego_path.name.encode())
            digest.update(rego_path.read_bytes())
        # 16 hex chars (64 bits) are plenty to tell policy sets apart.
        _policy_version_cache = digest.hexdigest()[:16]
    return _policy_version_cache


def invalidate_policy_cache():
    """Forget the memoized policy version hash.

    The next ``compute_policy_version()`` call re-reads and re-hashes every
    ``.rego`` file. Call this after policies are changed on disk.
    """
    global _policy_version_cache
    _policy_version_cache = None  # None == "not computed yet"


def get_policy_info() -> dict:
    """Describe the installed Rego policies.

    Returns:
        A dict with ``version`` (from compute_policy_version()),
        ``policy_count`` and ``policies``. ``policies`` has one entry per
        ``*.rego`` file in POLICY_DIR (sorted by path). Each entry gives:
        - ``name``: the file stem;
        - ``file``: the file name;
        - ``description``: a one-line description;
        - ``attestation_type``: the type the policy evaluates, or
          ``"unknown"`` when the policy is not in POLICY_ATTESTATION_MAP.

    The description is taken from the first column-0 ``#`` comment whose text
    meets three conditions:
    - it is non-empty;
    - it is not a ``# Copyright`` line;
    - it does not start with ``import`` or ``package``.
    When no comment qualifies, the description is ``""``.
    """
    version = compute_policy_version()
    policies = []
    for f in sorted(POLICY_DIR.glob("*.rego")):
        desc = ""
        # Rego sources are expected to be UTF-8; don't depend on the
        # platform's locale encoding.
        with open(f, encoding="utf-8") as fh:
            for line in fh:
                if not line.startswith("#") or line.startswith("# Copyright"):
                    continue
                candidate = line.lstrip("# ").strip()
                # Keep only a qualifying line, so a file without one yields ""
                # rather than whichever comment happened to be scanned last.
                if candidate and not candidate.startswith(("import", "package")):
                    desc = candidate
                    break
        policies.append({
            "name": f.stem,
            "file": f.name,
            "description": desc,
            "attestation_type": POLICY_ATTESTATION_MAP.get(f.stem, "unknown"),
        })
    return {
        "version": version,
        "policy_count": len(policies),
        "policies": policies,
    }

# Map policy name -> which attestation type it evaluates.
# The value is matched as a substring of each attestation's "type" field by
# extract_attestation_by_type(). Insertion order matters: evaluate_all_policies()
# iterates this dict in order, which fixes the order of its report entries.
# Only policies listed here are run by evaluate_all_policies(); unlisted .rego
# files are reported by get_policy_info() with attestation_type "unknown".
POLICY_ATTESTATION_MAP = {
    # Core runtime policies (evaluate command-run attestation)
    "network_exfiltration": "command-run",
    "credential_harvesting": "command-run",
    "code_execution": "command-run",
    "persistence": "command-run",
    # Attack-specific IOC detection
    "pth_injection": "command-run",
    "teampcp_ioc": "command-run",
    "openclaw_ioc": "command-run",
    # Theoretical/emerging threat detection (command-run, except pickle_model
    # which evaluates the pip-install attestation)
    "container_escape": "command-run",
    "dns_exfiltration": "command-run",
    "evasion_detection": "command-run",
    "pickle_model": "pip-install",
    "import_time_risk": "command-run",
    "attack_sequences": "command-run",
    # Supply chain policies (evaluate pip-install attestation)
    "supply_chain": "pip-install",
    "package_signing": "pip-install",
    "dependency_confusion": "pip-install",
    "release_integrity": "pip-install",
}

# Risk-score weight for each violation severity prefix. Policies are expected to
# prefix each deny message with "CRITICAL:", "HIGH:", "MEDIUM:", "LOW:" or
# "INFO:"; score_violation() maps that prefix to a weight from this table
# (messages without a recognised prefix score a flat 5).
VIOLATION_WEIGHTS = {
    "CRITICAL": 20,
    "HIGH": 10,
    "MEDIUM": 4,
    "LOW": 1,
    "INFO": 0,  # informational, doesn't add to risk score
}

# Nominal severity per dimension, reported as dimensions[<policy>]["severity"]
# by evaluate_all_policies() (policies not listed report 5).
# NOTE: this table is NOT used by score_violation() and does not affect
# risk_score; unprefixed violations get that function's flat default instead.
POLICY_SEVERITY = {
    "network_exfiltration": 10,
    "credential_harvesting": 15,
    "code_execution": 8,
    "persistence": 12,
    "supply_chain": 5,
    "package_signing": 3,
}


def score_violation(violation: str) -> int:
    """Return the risk weight encoded in a violation's severity prefix.

    The text before the first ``:`` (e.g. ``"HIGH"`` in
    ``"HIGH: opened socket"``) is looked up in VIOLATION_WEIGHTS. Messages
    without a recognised ``SEVERITY:`` prefix score 5.
    """
    severity, colon, _rest = violation.partition(":")
    if colon and severity in VIOLATION_WEIGHTS:
        return VIOLATION_WEIGHTS[severity]
    return 5  # default if no prefix


def extract_attestation_by_type(attestation_path: str, att_type_fragment: str) -> Optional[dict]:
    """Return the ``attestation`` body of the first matching entry in an envelope.

    The file at *attestation_path* is parsed as a JSON envelope whose
    ``payload`` field is base64-encoded JSON. Entries are read from
    ``payload["predicate"]["attestations"]``, and the first one whose ``type``
    contains *att_type_fragment* (substring match) is selected. Its
    ``attestation`` value is returned, or ``{}`` if that entry has none.

    Returns None in two cases: no entry matches, or the file cannot be read,
    decoded or does not have the expected shape. Only those expected failure
    modes are swallowed; unrelated programming errors propagate.
    """
    try:
        # Binary mode lets json detect the UTF-8/16/32 encoding itself instead
        # of relying on the platform's default text encoding.
        with open(attestation_path, "rb") as f:
            envelope = json.load(f)
        payload = json.loads(base64.b64decode(envelope["payload"]))
        atts = payload.get("predicate", {}).get("attestations", [])
        for att in atts:
            if att_type_fragment in att.get("type", ""):
                return att.get("attestation", {})
    except (
        OSError,         # missing / unreadable file
        ValueError,      # bad JSON, bad base64 (binascii.Error), bad text encoding
        KeyError,        # envelope has no "payload"
        TypeError,       # unexpected JSON types (e.g. list envelope, null "type")
        AttributeError,  # non-object where an object was expected
        RecursionError,  # pathologically deeply nested JSON
    ):
        pass
    return None


def evaluate_policy(policy_name: str, input_data: dict) -> dict:
    """Run one Rego policy's ``deny`` rule against *input_data* using OPA.

    Evaluates ``data.pip_witness.<policy_name>.deny`` from
    ``POLICY_DIR/<policy_name>.rego`` via ``opa eval``. *input_data* is the
    input document, and ``POLICY_DIR/allowlist.json`` is loaded as extra data
    when that file exists.

    Args:
        policy_name: Policy file stem; its Rego package is expected to be
            ``pip_witness.<policy_name>``.
        input_data: JSON-serialisable OPA input document.

    Returns:
        A dict that always contains ``policy``, ``passed``, ``violations`` and
        ``violation_count``.

        On success, ``violations`` lists the deny messages and ``passed`` is
        True iff that list is empty.

        On failure, ``passed`` is False, ``violations`` is empty and ``error``
        describes the problem. Failures include a missing policy file, a
        non-zero OPA exit, the 30 s timeout, non-serialisable input,
        unparsable output, and ``opa`` not being runnable. These are reported
        in the result rather than raised.
    """

    def _failure(message: str) -> dict:
        # Same keys as a successful result so callers can read any of them.
        return {
            "policy": policy_name,
            "error": message,
            "passed": False,
            "violations": [],
            "violation_count": 0,
        }

    policy_file = POLICY_DIR / f"{policy_name}.rego"
    if not policy_file.exists():
        return _failure(f"policy file not found: {policy_file}")

    input_path = None
    try:
        # OPA reads its input from a file. delete=False keeps the file on disk
        # after the handle closes (so OPA can open it); ``finally`` removes it,
        # even if serialisation fails part-way.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            input_path = f.name
            json.dump(input_data, f)

        # Run OPA eval, optionally with allowlist data
        package_name = f"pip_witness.{policy_name}"
        allowlist_path = POLICY_DIR / "allowlist.json"
        cmd = [
            "opa", "eval",
            "--data", str(policy_file),
            "--input", input_path,
            "--format", "json",
        ]
        if allowlist_path.exists():
            cmd.extend(["--data", str(allowlist_path)])
        cmd.append(f"data.{package_name}.deny")

        result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
        if result.returncode != 0:
            return _failure(result.stderr.strip())

        opa_result = json.loads(result.stdout)

        # The deny value may arrive as a list of messages, an object whose
        # keys are the messages, or a single string; accept all three.
        violations = []
        for r in opa_result.get("result", []):
            for expr in r.get("expressions", []):
                val = expr.get("value")
                if isinstance(val, list):
                    violations.extend(str(v) for v in val)
                elif isinstance(val, dict):
                    violations.extend(str(k) for k in val)
                elif isinstance(val, str):
                    violations.append(val)

        return {
            "policy": policy_name,
            "passed": not violations,
            "violations": violations,
            "violation_count": len(violations),
        }

    except subprocess.TimeoutExpired:
        return _failure("timeout")
    except Exception as e:  # best-effort: surface any other failure in the result
        return _failure(str(e))
    finally:
        if input_path is not None:
            os.unlink(input_path)


def evaluate_all_policies(attestation_path: str) -> dict:
    """Run every policy in POLICY_ATTESTATION_MAP against one attestation file.

    Each policy is given the attestation body that extract_attestation_by_type()
    returns for the policy's attestation type. A policy whose type is absent
    from the envelope is recorded as passed, with a ``note``.

    Returns:
        A dict with the following keys:
        - ``policies``: per-policy results.
        - ``dimensions``: a per-policy summary.
        - ``total_violations``.
        - ``risk_score``: the sum of score_violation() weights.
        - ``risk_level``: CLEAN below 5, then LOW / MEDIUM / HIGH / CRITICAL
          at thresholds 5 / 20 / 50 / 100.
        - ``policy_version``.
        - ``evaluation_errors``: names of policies whose evaluation reported
          an error.
        - ``overall_passed``: False if any single violation weighs 10 or more
          (HIGH/CRITICAL), or if any policy could not be evaluated.
    """
    results = {
        "policies": {},
        "overall_passed": True,
        "total_violations": 0,
        "risk_score": 0,
        "risk_level": "CLEAN",
        "dimensions": {},
        "policy_version": compute_policy_version(),
        "evaluation_errors": [],
    }

    # Many policies share one attestation type; extract each type once instead
    # of re-reading and re-decoding the envelope file for every policy.
    att_by_type = {}

    for policy_name, att_type in POLICY_ATTESTATION_MAP.items():
        if att_type not in att_by_type:
            att_by_type[att_type] = extract_attestation_by_type(attestation_path, att_type)
        att_data = att_by_type[att_type]
        if att_data is None:
            results["policies"][policy_name] = {
                "policy": policy_name,
                "passed": True,
                "violations": [],
                "violation_count": 0,
                "note": f"No {att_type} attestation found",
            }
            continue

        policy_result = evaluate_policy(policy_name, att_data)
        results["policies"][policy_name] = policy_result

        errored = "error" in policy_result
        if errored:
            # A policy that could not be evaluated (OPA missing, policy file
            # missing, timeout, ...) is not evidence of safety: fail closed.
            results["evaluation_errors"].append(policy_name)
            results["overall_passed"] = False

        violations = policy_result.get("violations", [])
        results["total_violations"] += len(violations)
        for v in violations:
            w = score_violation(v)
            results["risk_score"] += w
            if w >= 10:  # HIGH or CRITICAL
                results["overall_passed"] = False

        # Dimension summary
        results["dimensions"][policy_name] = {
            "passed": policy_result.get("passed", not errored),
            "count": policy_result.get("violation_count", 0),
            "severity": POLICY_SEVERITY.get(policy_name, 5),
        }

    # Calculate risk level from score
    score = results["risk_score"]
    if score >= 100:
        results["risk_level"] = "CRITICAL"
    elif score >= 50:
        results["risk_level"] = "HIGH"
    elif score >= 20:
        results["risk_level"] = "MEDIUM"
    elif score >= 5:
        results["risk_level"] = "LOW"
    else:
        results["risk_level"] = "CLEAN"

    return results


def main():
    """Command-line entry point: ``policy_eval.py <attestation-path>``.

    Prints the evaluate_all_policies() result as indented JSON on stdout.
    Without an argument, prints usage to stderr and exits with status 1.
    """
    import sys
    if len(sys.argv) < 2:
        # stdout carries only the machine-readable JSON; diagnostics go to stderr.
        print("Usage: policy_eval.py <attestation-path>", file=sys.stderr)
        sys.exit(1)

    results = evaluate_all_policies(sys.argv[1])
    print(json.dumps(results, indent=2))


if __name__ == "__main__":
    main()
