#!/usr/bin/env python3

import hashlib
import argparse
from pathlib import Path


# ------------------------------------------------------------
# FINALNE MAPOWANIA ODTWORZONE Z ANALIZY
# ------------------------------------------------------------

# Data-bit permutation, indexed as DATA_PERM[new_bit] = old_bit:
# bit `new_bit` of each recovered byte is taken from bit `old_bit`
# of the corresponding byte in the raw dump.
DATA_PERM = (1, 0, 5, 4, 7, 6, 2, 3)

# Address-bit permutation, indexed as ADDR_PERM[logical_bit] = physical_bit:
# logical address bit `logical_bit` corresponds to bit `physical_bit` of the
# original (physical) address used in the raw dump.
ADDR_PERM = (2, 10, 11, 9, 3, 8, 7, 6, 5, 4, 0, 1)

# Address width in bits and the resulting image size in bytes (4 KB).
ADDR_BITS = 12
ROM_SIZE = 1 << ADDR_BITS


# ------------------------------------------------------------
# DANE
# ------------------------------------------------------------

def build_data_lut(perm):
    """
    Build a 256-entry translation table for a data-bit permutation.

    perm[new_bit] = old_bit: bit ``new_bit`` of every output byte is copied
    from bit ``old_bit`` of the input byte.

    Args:
        perm: sequence (or any iterable) of 8 distinct bit indices that
            together cover 0..7.

    Returns:
        ``bytes`` of length 256 where ``lut[b]`` is ``b`` with its bits
        rearranged according to ``perm``.

    Raises:
        ValueError: if ``perm`` is not a permutation of range(8). A
            duplicated or missing index would merge or drop bits and
            silently corrupt the recovered data.
    """
    # Materialise once: the inner loop iterates perm 256 times, which would
    # exhaust a one-shot iterator after the first byte.
    perm = tuple(perm)
    if sorted(perm) != list(range(8)):
        raise ValueError(
            f"data permutation must be a permutation of 0..7, got {perm}"
        )

    lut = bytearray(256)
    for value in range(256):
        out = 0
        for new_bit, old_bit in enumerate(perm):
            if (value >> old_bit) & 1:
                out |= 1 << new_bit
        lut[value] = out

    return bytes(lut)


def apply_data_permutation(data, perm):
    """
    Return a copy of ``data`` with the bits of every byte permuted.

    Args:
        data: bytes-like object (or iterable of ints 0..255) to convert.
        perm: data-bit permutation, perm[new_bit] = old_bit
            (see ``build_data_lut``).

    Returns:
        New ``bytes`` of the same length as ``data``.
    """
    lut = build_data_lut(perm)
    # bytes.translate applies the 256-entry table in C instead of a
    # per-byte Python generator.
    return bytes(data).translate(lut)


# ------------------------------------------------------------
# ADRESY
# ------------------------------------------------------------

def build_addr_map(logical_to_physical, nbits=ADDR_BITS):
    """
    Build the lookup table from logical addresses to physical addresses.

    logical_to_physical[logical_bit] = original_physical_bit: whenever bit
    ``logical_bit`` of a logical address is set, bit
    ``original_physical_bit`` of the matching physical address is set.

    Args:
        logical_to_physical: sequence of ``nbits`` distinct bit indices
            that together cover 0..nbits-1.
        nbits: address width in bits; the table has ``1 << nbits`` entries.

    Returns:
        list where ``addr_map[logical_address] == physical_address``.

    Raises:
        ValueError: if ``logical_to_physical`` is not a permutation of
            range(nbits). A non-bijective map would read some bytes twice,
            never read others, or point past the end of the image.
    """
    # Materialise once: the inner loop iterates the permutation for every
    # address, which would exhaust a one-shot iterator.
    perm = tuple(logical_to_physical)
    if sorted(perm) != list(range(nbits)):
        raise ValueError(
            f"address permutation must be a permutation of 0..{nbits - 1}, "
            f"got {perm}"
        )

    addr_map = []
    for logical_addr in range(1 << nbits):
        physical_addr = 0
        for logical_bit, physical_bit in enumerate(perm):
            if (logical_addr >> logical_bit) & 1:
                physical_addr |= 1 << physical_bit
        addr_map.append(physical_addr)

    return addr_map


def apply_address_permutation(data, logical_to_physical):
    """
    Reorder ``data`` so that logical address ``a`` holds the byte found at
    the corresponding physical address in the input.

    The address width is taken from ``len(logical_to_physical)``, so the
    lookup table always matches the permutation rather than a fixed size.

    Args:
        data: indexable sequence of byte values (e.g. ``bytes``) of length
            ``2 ** len(logical_to_physical)``.
        logical_to_physical: address-bit permutation,
            logical_to_physical[logical_bit] = original_physical_bit.

    Returns:
        New ``bytes`` of the same length as ``data``.

    Raises:
        ValueError: if ``len(data)`` does not match the address width.
            Without this check, some addresses would fall outside the table
            or the data, causing an opaque IndexError or silently wrong
            output.
    """
    perm = tuple(logical_to_physical)
    nbits = len(perm)
    expected = 1 << nbits
    if len(data) != expected:
        raise ValueError(
            f"a {nbits}-bit address permutation needs {expected} bytes, "
            f"got {len(data)}"
        )

    addr_map = build_addr_map(perm, nbits=nbits)
    return bytes(data[physical] for physical in addr_map)


# ------------------------------------------------------------
# MAIN
# ------------------------------------------------------------

def main():
    """
    Command-line entry point.

    Reads a 4 KB dump, undoes the data-bit and address-bit permutations
    (DATA_PERM, ADDR_PERM) and writes the result. It then prints the
    input/output paths, the permutations used and the MD5/SHA-256 of the
    output.

    Exits via SystemExit with a message if the input cannot be read, has
    the wrong size, would be overwritten by the output, or if the output
    cannot be written.
    """
    parser = argparse.ArgumentParser(
        description="Naprawia dump ROM ZABEZTUR: permutacja danych + adresów."
    )
    parser.add_argument("input", help="Oryginalny plik dumpa 4KB")
    parser.add_argument(
        "-o", "--output",
        help="Plik wynikowy; domyślnie <wejscie>_recovered.bin"
    )
    args = parser.parse_args()

    in_path = Path(args.input)
    try:
        raw = in_path.read_bytes()
    except OSError as err:
        raise SystemExit(f"Blad: nie mozna odczytac pliku {in_path}: {err}") from err

    if len(raw) != ROM_SIZE:
        raise SystemExit(
            f"Blad: oczekiwano pliku {ROM_SIZE} bajtow (4KB), a jest {len(raw)} bajtow."
        )

    # 1) Undo the data-bit permutation.
    data_fixed = apply_data_permutation(raw, DATA_PERM)

    # 2) Undo the address-bit permutation.
    rom_fixed = apply_address_permutation(data_fixed, ADDR_PERM)

    if args.output:
        out_path = Path(args.output)
    else:
        out_path = in_path.with_name(in_path.stem + "_recovered.bin")

    # Never clobber the original dump (e.g. `-o` pointing at the input file).
    if out_path.resolve() == in_path.resolve():
        raise SystemExit(
            f"Blad: plik wynikowy {out_path} nadpisalby plik wejsciowy."
        )

    try:
        out_path.write_bytes(rom_fixed)
    except OSError as err:
        raise SystemExit(f"Blad: nie mozna zapisac pliku {out_path}: {err}") from err

    # Hashes of the output file, for identifying/comparing the recovered image.
    md5 = hashlib.md5(rom_fixed).hexdigest()
    sha256 = hashlib.sha256(rom_fixed).hexdigest()

    print("Gotowe.")
    print(f"Wejscie : {in_path}")
    print(f"Wyjscie : {out_path}")
    print(f"Dane    : {DATA_PERM}")
    print(f"Adresy  : {ADDR_PERM}")
    print()
    print("Skróty pliku wynikowego:")
    print(f"  MD5    : {md5}")
    print(f"  SHA256 : {sha256}")


if __name__ == "__main__":
    main()

