#!/usr/bin/env python3
"""Generate SARIF report from pip-witness attestation findings.

Combines:
1. Policy evaluation findings (from cilock verify Rego policies)
2. SBOM vulnerability scan (from pip-audit against installed packages)
3. Pickle bytecode analysis results
4. Static analysis of installed files

Output is SARIF v2.1.0 compatible with GitHub Code Scanning, VS Code, etc.
The SARIF file gets attested by cilock's SARIF attestor → signed → Archivista.
"""

import json
import os
import subprocess
import sys
from datetime import datetime, timezone


# Policy finding key -> rule metadata. "level" is the severity used for results
# of that rule; it is emitted on the rule itself as defaultConfiguration.level.
_RULE_MAP = {
    "CRITICAL:UNKNOWN_DESTINATION": {"id": "PW001", "level": "error", "name": "UnknownNetworkDestination",
        "shortDescription": {"text": "Package connected to unknown IP during install"}},
    "CRITICAL:MALICIOUS_PICKLE": {"id": "PW002", "level": "error", "name": "MaliciousPickle",
        "shortDescription": {"text": "Pickle file imports dangerous modules (os, subprocess)"}},
    "CRITICAL:UNKNOWN_PTH": {"id": "PW003", "level": "error", "name": "UnknownPthFile",
        "shortDescription": {"text": "Non-standard .pth file executes on every Python startup"}},
    "CRITICAL:REVERSE_SHELL": {"id": "PW004", "level": "error", "name": "ReverseShell",
        "shortDescription": {"text": "Socket-to-stdio redirection pattern detected"}},
    "CRITICAL:STAGED_PAYLOAD": {"id": "PW005", "level": "error", "name": "StagedPayload",
        "shortDescription": {"text": "Download + write + chmod +x = staged binary execution"}},
    "CRITICAL:CRED_EXFIL_SEQUENCE": {"id": "PW006", "level": "error", "name": "CredentialExfiltration",
        "shortDescription": {"text": "Read credentials + network connection = exfiltration"}},
    "HIGH:SHELL_SPAWNED": {"id": "PW010", "level": "warning", "name": "ShellSpawned",
        "shortDescription": {"text": "Shell interpreter invoked during pip install"}},
    "HIGH:FILELESS_EXEC": {"id": "PW011", "level": "warning", "name": "FilelessExecution",
        "shortDescription": {"text": "memfd_create used for anonymous code execution"}},
    "HIGH:SHELLCODE": {"id": "PW012", "level": "warning", "name": "ShellcodeInjection",
        "shortDescription": {"text": "Memory made executable via mprotect"}},
    "HIGH:UNATTESTED_PRERELEASE": {"id": "PW013", "level": "warning", "name": "UnattestedPrerelease",
        "shortDescription": {"text": "Pre-release version without PEP 740 attestation"}},
    "HIGH:DEP_CONFUSION_VERSION": {"id": "PW014", "level": "warning", "name": "DependencyConfusion",
        "shortDescription": {"text": "Suspiciously high version number"}},
    "MEDIUM:UNKNOWN_PROGRAM": {"id": "PW020", "level": "note", "name": "UnknownProgram",
        "shortDescription": {"text": "Unexpected program spawned during install"}},
    "MEDIUM:UNKNOWN_PICKLE": {"id": "PW021", "level": "note", "name": "UnknownPickle",
        "shortDescription": {"text": "Pickle file with unknown hash in site-packages"}},
    "LOW:PICKLE_WITH_RECONSTRUCT": {"id": "PW030", "level": "note", "name": "PickleReconstruct",
        "shortDescription": {"text": "Pickle uses REDUCE/GLOBAL for safe data reconstruction"}},
}

# Pickle global references starting with one of these are reported as PW002.
# NOTE(review): this is a bare prefix match, so e.g. "ossaudiodev" also matches
# "os" - confirm whether a module-boundary match is wanted.
_DANGEROUS_PICKLE_PREFIXES = ("os", "subprocess", "builtins.exec", "socket", "http")

# Address prefixes that are never reported as unknown destinations.
# NOTE(review): presumably package-index CDN, local and public-DNS ranges -
# verify against the policy this mirrors.
_KNOWN_NETWORK_PREFIXES = ("151.101.", "199.232.", "192.168.65.", "127.", "8.8.", "1.1.1.", "1.0.0.")

# Syscall name -> policy finding key for the syscall events that are reported.
_SYSCALL_RULES = {
    "mprotect": "HIGH:SHELLCODE",
    "memfd_create": "HIGH:FILELESS_EXEC",
}


def _rule_descriptor(rule_key):
    """Return the SARIF rule object for *rule_key* with its level under defaultConfiguration."""
    rule = dict(_RULE_MAP[rule_key])
    rule["defaultConfiguration"] = {"level": rule.pop("level")}
    return rule


def _add_finding(run, rules_added, rule_key, message, uri=None):
    """Append one result to *run*, registering its rule the first time it is seen."""
    rule = _RULE_MAP[rule_key]
    rule_id = rule["id"]
    if rule_id not in rules_added:
        run["tool"]["driver"]["rules"].append(_rule_descriptor(rule_key))
        rules_added.add(rule_id)

    result = {
        "ruleId": rule_id,
        "level": rule["level"],
        "message": {"text": message},
    }
    if uri:
        result["locations"] = [{"physicalLocation": {
            "artifactLocation": {"uri": uri}
        }}]
    run["results"].append(result)


def _load_attestation_payload(attestation_path):
    """Read the envelope at *attestation_path* and decode its base64 JSON payload."""
    import base64

    with open(attestation_path) as f:
        env = json.load(f)
    return json.loads(base64.b64decode(env["payload"]))


def _collect_pickle_findings(run, rules_added, data):
    """Report every global reference of each pickle not marked safe."""
    analysis = data.get("installedFileAnalysis") or {}
    for pa in analysis.get("pickleAnalysis") or []:
        if pa.get("isSafe"):
            continue
        path = pa.get("path")
        shown = path if path else "<unknown>"
        for ref in pa.get("globalRefs") or []:
            ref = str(ref)
            if ref.startswith(_DANGEROUS_PICKLE_PREFIXES):
                rule_key = "CRITICAL:MALICIOUS_PICKLE"
                msg = f"Pickle {shown} imports {ref} — executes on pickle.load()"
            else:
                rule_key = "LOW:PICKLE_WITH_RECONSTRUCT"
                msg = f"Pickle {shown} uses {ref} for data reconstruction"
            _add_finding(run, rules_added, rule_key, msg, uri=path)


def _collect_network_findings(run, rules_added, proc):
    """Report connections of *proc* to addresses outside the known prefixes."""
    net = proc.get("network") or {}
    pid = proc.get("processid", "?")
    for conn in net.get("connections") or []:
        if conn.get("family") == "AF_UNIX":
            continue
        addr = conn.get("address", "")
        # NOTE(review): every address starting with "::" is exempt, which
        # includes IPv4-mapped forms such as "::ffff:a.b.c.d" - confirm.
        if not addr or addr == "0.0.0.0" or addr.startswith("::"):
            continue
        if addr.startswith(_KNOWN_NETWORK_PREFIXES):
            continue
        _add_finding(
            run, rules_added, "CRITICAL:UNKNOWN_DESTINATION",
            f"PID {pid} connected to unknown IP {addr}:{conn.get('port','')}",
        )


def _collect_syscall_findings(run, rules_added, proc):
    """Report mprotect / memfd_create syscall events recorded for *proc*."""
    pid = proc.get("processid", "?")
    for evt in proc.get("syscallEvents") or []:
        rule_key = _SYSCALL_RULES.get(evt.get("syscall"))
        if rule_key is None:
            # Unreported syscall, or an event without a "syscall" field.
            continue
        _add_finding(run, rules_added, rule_key, f"PID {pid}: {evt.get('detail', '')}")


def _build_vulnerability_run():
    """Run pip-audit and convert its JSON output into a SARIF run.

    The returned run has no results when pip-audit is not installed or
    produced no output. Any other failure is recorded as an AUDIT999 note.
    """
    vuln_run = {
        "tool": {
            "driver": {
                "name": "pip-audit",
                "version": "2.0",
                "informationUri": "https://github.com/pypa/pip-audit",
                "rules": []
            }
        },
        "results": [],
    }

    try:
        completed = subprocess.run(
            ["pip-audit", "--format=json", "--output=-"],
            capture_output=True, text=True, timeout=60
        )
        if completed.stdout:
            audit_data = json.loads(completed.stdout)
            seen_rule_ids = set()
            for dep in audit_data.get("dependencies", []):
                pkg = f"{dep.get('name', 'unknown')}=={dep.get('version', 'unknown')}"
                for vuln in dep.get("vulns", []):
                    cve = vuln.get("id", "UNKNOWN")
                    rule_id = cve.replace("-", "")
                    desc = (vuln.get("description") or "No description")[:200]

                    # One advisory can affect several packages; register its
                    # rule once so rule ids stay unique within the run.
                    if rule_id not in seen_rule_ids:
                        seen_rule_ids.add(rule_id)
                        vuln_run["tool"]["driver"]["rules"].append({
                            "id": rule_id,
                            "name": cve,
                            "shortDescription": {"text": f"{cve} in {pkg}"},
                            "helpUri": f"https://osv.dev/vulnerability/{cve}",
                        })
                    vuln_run["results"].append({
                        "ruleId": rule_id,
                        "level": "warning",
                        "message": {"text": f"{cve}: {desc}"},
                        "locations": [{"logicalLocation": {
                            "name": pkg,
                            "kind": "package"
                        }}]
                    })
    except FileNotFoundError:
        pass  # pip-audit not installed: the vulnerability run stays empty
    except Exception as e:
        vuln_run["results"].append({
            "ruleId": "AUDIT999",
            "level": "note",
            "message": {"text": f"pip-audit error: {e}"},
        })

    return vuln_run


def create_sarif_report(package, version, attestation_path=None, sbom_path=None):
    """Generate a SARIF report for a scanned package.

    Args:
        package: Package name. Accepted for the caller's convenience; this
            function does not currently read it.
        version: Package version. Not currently read by this function.
        attestation_path: Optional path to a JSON envelope whose "payload"
            field is base64-encoded JSON. When the path is missing or does
            not exist, the policy run is emitted with no results.
        sbom_path: Optional SBOM path. Not currently read by this function;
            pip-audit is invoked without it.

    Returns:
        A dict holding a SARIF v2.1.0 log. The first run is always the
        pip-witness policy run; a second pip-audit run is appended only when
        it has at least one result.

    Attestation problems never raise: they are recorded as a PW999 note in
    the policy run. Findings are collected per process independently, so a
    process without network data still has its syscall events reported.
    """
    sarif = {
        "$schema": "https://raw.githubusercontent.com/oasis-tcs/sarif-spec/master/Schemata/sarif-schema-2.1.0.json",
        "version": "2.1.0",
        "runs": []
    }

    # --- Run 1: pip-witness policy findings ---
    policy_run = {
        "tool": {
            "driver": {
                "name": "pip-witness",
                "version": "1.0.0",
                "informationUri": "https://github.com/aflock-ai/rookery",
                "rules": []
            }
        },
        "results": [],
        "invocations": [{
            "executionSuccessful": True,
            # RFC 3339 UTC timestamp using the compact "Z" suffix.
            "startTimeUtc": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        }]
    }

    if attestation_path and os.path.exists(attestation_path):
        rules_added = set()
        try:
            payload = _load_attestation_payload(attestation_path)
            predicate = payload.get("predicate") or {}
            for att in predicate.get("attestations") or []:
                data = att.get("attestation") or {}
                _collect_pickle_findings(policy_run, rules_added, data)
                for proc in data.get("processes") or []:
                    _collect_network_findings(policy_run, rules_added, proc)
                    _collect_syscall_findings(policy_run, rules_added, proc)
        except Exception as e:
            # Best-effort boundary: a broken attestation is surfaced in the
            # report itself instead of aborting report generation.
            policy_run["results"].append({
                "ruleId": "PW999",
                "level": "note",
                "message": {"text": f"Attestation parse error: {e}"},
            })

    sarif["runs"].append(policy_run)

    # --- Run 2: SBOM vulnerability scan ---
    vuln_run = _build_vulnerability_run()
    if vuln_run["results"]:
        sarif["runs"].append(vuln_run)

    return sarif


def main():
    """Command-line entry point: build the SARIF report and write it to disk.

    Positional arguments are ``<package> <version> [attestation_path]
    [sbom_path]``. With fewer than two arguments the usage line is printed
    and the process exits with status 1. The report is written to the path
    in the ``SARIF_OUTPUT`` environment variable, falling back to
    ``/tmp/<package>-<version>.sarif``, and a one-line summary is printed.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) < 2:
        print("Usage: generate_sarif.py <package> <version> [attestation_path] [sbom_path]")
        sys.exit(1)

    package, version = cli_args[0], cli_args[1]
    # Pad the optional tail so absent arguments come out as None.
    att_path, sbom_path = (cli_args[2:4] + [None, None])[:2]

    sarif = create_sarif_report(package, version, att_path, sbom_path)

    # Output path
    default_output = f"/tmp/{package}-{version}.sarif"
    output = os.environ.get("SARIF_OUTPUT", default_output)
    with open(output, "w") as f:
        f.write(json.dumps(sarif, indent=2))

    runs = sarif["runs"]
    result_count = 0
    for run in runs:
        result_count += len(run.get("results", []))
    print(f"SARIF: {output} ({result_count} findings across {len(runs)} runs)")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
