#!/usr/bin/env python3
"""
export_bin.py - Export HuggingFace model to MacinAI .bin format

Converts a trained model from safetensors to the custom MacinAI binary
format for the C89 inference engine on 68K Macs.

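Supported architectures:
  - LLaMA-family (LLaMA, Mistral, Qwen, TinyLlama, SmolLM, etc.)
  - GPT-2 (Conv1D layout: transformer.wte / transformer.wpe / transformer.h.*).
    Other GPT-style model types (OPT, GPT-Neo, GPT-NeoX/Pythia, GPT-J, Falcon,
    Phi) are routed to the GPT-2 export path, but that path requires the
    GPT-2 tensor names above and exits with an error if they are missing.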

Format: 128-byte header + vocab section + weight tensors
All multi-byte values are big-endian (native 68K byte order).

Supports two quantization modes:
  --quantize f32   Float32 (default, ~378 MB for 94.5M params)
  --quantize q8    Q8_0 per-tensor quantization (~94.5 MB)

Q8 format per tensor: [4-byte float scale][N bytes int8 data]
  scale = max(abs(tensor)) / 127.0
  int8_data[i] = round(tensor[i] / scale)
  Norms (LayerNorm/RMSNorm weights and biases), projection biases and GPT-2
  positional embeddings are always stored as float32.

Usage:
    python export_bin.py --model-dir models/sft_v19 --quantize q8
    python export_bin.py --model-dir models/gpt2_hf --quantize q8
    python export_bin.py --model-dir models/sft_v19 --output models/macinai_local.bin

Requires: torch, transformers, numpy
"""

import argparse
import json
import os
import platform
import struct
import subprocess
import sys

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


# --- Constants ---
# Classic Mac OS four-character type/creator codes applied to the output file.
FILE_TYPE = "MCAI"      # MacinAI Model
FILE_CREATOR = "OAS "   # Old Apple Stuff (the trailing space is significant)

# First header long: the four ASCII bytes of FILE_TYPE read big-endian.
MAGIC = int.from_bytes(FILE_TYPE.encode("ascii"), "big")  # 0x4D434149 'MCAI'
FORMAT_VERSION = 2
HEADER_SIZE = 32 * 4   # 32 big-endian longs = 128 bytes

# Quantization types (must match ModelConfig.h QuantType enum)
QUANT_FLOAT32, QUANT_INT8 = 0, 1

# Per-layer tensor order for the LLaMA-family path (must match C engine
# expectations): q, k, v, o projections, then gate, up, down MLP projections,
# then the input and post-attention norms.
LAYER_WEIGHT_NAMES = (
    [f"self_attn.{proj}_proj.weight" for proj in ("q", "k", "v", "o")]
    + [f"mlp.{proj}_proj.weight" for proj in ("gate", "up", "down")]
    + ["input_layernorm.weight", "post_attention_layernorm.weight"]
)


def set_mac_type_creator(filepath):
    """Set classic Mac OS file type and creator codes on *filepath*.

    On macOS, tries ``SetFile`` (Xcode command line tools) first, then falls
    back to writing the ``com.apple.FinderInfo`` extended attribute through
    the optional ``xattr`` package. If neither succeeds, or when not running
    on macOS, prints a shell-safe ``SetFile`` command to run after
    transferring the file to a Mac.

    This is best-effort: a failed attempt is reported or skipped, never
    raised, so the export itself is not aborted.
    """
    import shlex

    if platform.system() == "Darwin":
        # Try SetFile first (from Xcode command line tools)
        try:
            subprocess.run(
                ["SetFile", "-t", FILE_TYPE, "-c", FILE_CREATOR, filepath],
                check=True, capture_output=True,
            )
            print(f"\n  File type set: type='{FILE_TYPE}' creator='{FILE_CREATOR}'")
            return
        except (OSError, subprocess.CalledProcessError):
            # SetFile missing, not executable, or failed: try xattr instead.
            pass

        # Fallback: xattr (works on macOS without Xcode tools)
        try:
            # Mac type/creator stored in com.apple.FinderInfo xattr
            # 32 bytes: [4 type][4 creator][24 other finder info]
            finder_info = (
                FILE_TYPE.encode("mac_roman")
                + FILE_CREATOR.encode("mac_roman")
                + b"\x00" * 24
            )
            import xattr
            xattr.setxattr(filepath, "com.apple.FinderInfo", finder_info)
        except ImportError:
            # Optional dependency not installed; fall through to the note.
            pass
        except Exception as exc:  # best-effort: report, don't abort export
            print(f"\n  WARNING: could not set FinderInfo via xattr: {exc}")
        else:
            print(f"\n  File type set via xattr: type='{FILE_TYPE}' "
                  f"creator='{FILE_CREATOR}'")
            return

    # Not on macOS or both methods failed. Quote every argument so the
    # command stays copy-pasteable for filenames with spaces or shell
    # metacharacters (plain names and 'MCAI' are emitted unquoted).
    print("\n  NOTE: After transferring to Mac, run:")
    print(f"    SetFile -t {shlex.quote(FILE_TYPE)} "
          f"-c {shlex.quote(FILE_CREATOR)} "
          f"{shlex.quote(os.path.basename(filepath))}")


def write_big_endian_long(f, value):
    """Emit *value* to *f* as a 4-byte big-endian two's-complement integer.

    Raises struct.error when *value* is outside the signed 32-bit range.
    """
    be_int32 = struct.Struct(">i")
    f.write(be_int32.pack(value))


def write_big_endian_float_tensor(f, tensor):
    """Write *tensor* to *f* as big-endian IEEE 754 float32 values.

    The tensor is detached, cast to float32, moved to CPU and flattened in
    row-major order of its logical shape, so transposed (non-contiguous)
    views are written in the order their indices read.

    Returns:
        Number of bytes written (4 per element).
    """
    data = tensor.detach().float().cpu().numpy().reshape(-1)
    # Vectorised byte-order conversion. The previous
    # struct.pack(">" + "f" * n, *data) built an n-character format string and
    # boxed every element as a Python float, which is very slow and
    # memory-hungry for embedding-sized tensors. Output bytes are identical.
    payload = data.astype(">f4", copy=False).tobytes()
    f.write(payload)
    return len(payload)


def quantize_tensor_q8(tensor):
    """Quantize *tensor* to Q8_0 using one scale for the whole tensor.

    The tensor is flattened to 1-D float32 first.

    Returns:
        (scale, codes) where scale = max(|tensor|) / 127.0 as a Python float
        and codes = round(tensor / scale) clipped to [-128, 127] as np.int8.
        An all-zero tensor uses a tiny placeholder peak (1e-10) so the
        division is defined; its codes are all zero.
    """
    flat = tensor.detach().float().cpu().numpy().reshape(-1)
    peak = np.abs(flat).max()
    if peak == 0:
        peak = 1e-10
    scale = float(peak / 127.0)
    codes = np.clip(np.round(flat / scale), -128, 127).astype(np.int8)
    return scale, codes


def write_q8_tensor(f, tensor):
    """Quantize *tensor* and write it as [BE float32 scale][int8 codes].

    Returns:
        Total bytes written: 4 for the scale plus one per element.
    """
    scale, codes = quantize_tensor_q8(tensor)
    scale_bytes = struct.pack(">f", scale)  # big-endian, like all other fields
    f.write(scale_bytes)
    # One byte per code, so byte order is irrelevant here.
    f.write(codes.tobytes())
    return len(scale_bytes) + len(codes)


def _write_length_prefixed(f, text):
    """Write *text* as a length-prefixed UTF-8 string: [1-byte len][bytes].

    Unencodable characters are replaced, and the payload is cut to 255 bytes
    so its length fits in one unsigned byte. Returns bytes written.
    """
    payload = text.encode("utf-8", errors="replace")[:255]
    f.write(struct.pack("B", len(payload)))
    f.write(payload)
    return 1 + len(payload)


def write_vocab_section(f, tokenizer_dir, model_vocab_size=None):
    """Write the vocab section: token strings followed by BPE merge rules.

    Token format: [1 byte length][N bytes UTF-8] for each token id
    0..count-1, ordered by id; ids with no token are written as empty
    strings.
    Merge format: [1 byte len1][string1][1 byte len2][string2] for each
    merge in tokenizer.json; malformed merge entries are written as two
    empty strings so the merge count stays aligned.

    Args:
        f: Binary file object positioned at the start of the vocab section.
        tokenizer_dir: Directory passed to AutoTokenizer.from_pretrained and
            searched for tokenizer.json (merge rules).
        model_vocab_size: If set, exactly this many token entries are written
            (padding with empty entries or truncating) to match the model's
            embedding matrix.

    Returns:
        (bytes_written, token_count_written, num_merges)
    """
    # Load the HF tokenizer to get the vocab (token string -> id).
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

    vocab = tokenizer.get_vocab()
    id_to_token = {v: k for k, v in vocab.items()}
    vocab_size = len(vocab)

    # Must write exactly model_vocab_size entries to match header/embedding matrix
    write_count = vocab_size
    if model_vocab_size is not None and model_vocab_size > vocab_size:
        print(f"  Tokenizer vocab: {vocab_size}, model vocab: {model_vocab_size}")
        print(f"  Padding {model_vocab_size - vocab_size} empty entries")
        write_count = model_vocab_size
    elif model_vocab_size is not None and model_vocab_size < vocab_size:
        print(f"  WARNING: Tokenizer vocab ({vocab_size}) > model ({model_vocab_size})")
        write_count = model_vocab_size

    print(f"  Vocab size: {write_count}")

    bytes_written = 0
    for token_id in range(write_count):
        bytes_written += _write_length_prefixed(f, id_to_token.get(token_id, ""))

    # Merge rules come from tokenizer.json.
    tokenizer_json_path = os.path.join(tokenizer_dir, "tokenizer.json")
    num_merges = 0
    if not os.path.exists(tokenizer_json_path):
        print("  WARNING: tokenizer.json not found, no merge rules written")
        return bytes_written, write_count, num_merges

    # Read as UTF-8 explicitly: merges may contain non-ASCII characters, and
    # the platform default encoding (e.g. cp1252 on Windows) would fail to
    # decode them or mis-decode them.
    with open(tokenizer_json_path, "r", encoding="utf-8") as tj:
        tok_data = json.load(tj)
    merges = tok_data.get("model", {}).get("merges", [])
    num_merges = len(merges)
    print(f"  BPE merges: {num_merges}")

    for merge_entry in merges:
        # Merges can be ["tok1", "tok2"] lists or "tok1 tok2" strings
        if isinstance(merge_entry, list):
            parts = merge_entry
        else:
            parts = merge_entry.split(" ", 1)
        if len(parts) != 2:
            # Placeholder: two zero-length strings (two 0x00 bytes).
            parts = ("", "")
        for part in parts:
            bytes_written += _write_length_prefixed(f, part)

    return bytes_written, write_count, num_merges


def detect_model_type(config):
    """Classify a HuggingFace config as GPT-2-family or LLaMA-family.

    Only ``config.model_type`` is inspected (case-insensitively; a missing
    attribute counts as ""). Anything outside the GPT-style set below is
    treated as LLaMA-family.

    NOTE(review): export_bin's GPT-2 path reads GPT-2 tensor names
    (transformer.wte, transformer.h.N.attn.c_attn, ...); other types in this
    set may not provide them — confirm before relying on the classification.

    Returns:
        "gpt2" or "llama".
    """
    gpt_style_types = frozenset((
        "gpt2", "gpt_neo", "gpt_neox", "gptj", "opt",
        "pythia", "falcon", "phi",
    ))
    normalized = getattr(config, "model_type", "").lower()
    if normalized not in gpt_style_types:
        return "llama"
    return "gpt2"


def export_bin(model_dir, tokenizer_dir, output_path, quantize="f32"):
    """Main export function."""

    print(f"Loading model from {model_dir}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_dir, torch_dtype=torch.float32
    )
    model.eval()

    config = model.config
    state_dict = model.state_dict()

    quant_type = QUANT_INT8 if quantize == "q8" else QUANT_FLOAT32
    is_q8 = (quant_type == QUANT_INT8)

    # Detect model family
    model_family = detect_model_type(config)
    is_gpt2 = (model_family == "gpt2")

    if is_gpt2:
        # GPT-2 config uses different attribute names
        num_layers = getattr(config, "n_layer",
                             getattr(config, "num_hidden_layers", 12))
        hidden_dim = getattr(config, "n_embd",
                             getattr(config, "hidden_size", 768))
        num_heads = getattr(config, "n_head",
                            getattr(config, "num_attention_heads", 12))
        num_kv_heads = num_heads  # GPT-2 has no GQA
        head_dim = hidden_dim // num_heads
        ffn_dim = getattr(config, "n_inner", None) or 4 * hidden_dim
        vocab_size = config.vocab_size
        max_seq_len = getattr(config, "n_positions",
                              getattr(config, "max_position_embeddings", 1024))
        rope_theta = 0  # GPT-2 uses learned positional embeddings, not RoPE
        tie_embeddings = getattr(config, "tie_word_embeddings", True)
    else:
        # LLaMA-family config
        num_layers = config.num_hidden_layers
        hidden_dim = config.hidden_size
        num_heads = config.num_attention_heads
        num_kv_heads = getattr(config, "num_key_value_heads", num_heads)
        head_dim = hidden_dim // num_heads
        ffn_dim = config.intermediate_size
        vocab_size = config.vocab_size
        max_seq_len = getattr(config, "max_position_embeddings", 1024)
        # rope_theta: check direct attribute first, then rope_parameters dict (Qwen)
        rope_theta = getattr(config, "rope_theta", None)
        if rope_theta is None:
            rope_params = getattr(config, "rope_parameters", {})
            if isinstance(rope_params, dict):
                rope_theta = rope_params.get("rope_theta", 10000)
            else:
                rope_theta = 10000
        rope_theta = int(rope_theta)
        tie_embeddings = getattr(config, "tie_word_embeddings", True)

    total_params = sum(p.numel() for p in model.parameters())

    arch_label = "GPT-2 (LayerNorm + GeLU + learned pos)" if is_gpt2 \
        else "LLaMA (RMSNorm + SwiGLU + RoPE)"
    print(f"\nModel architecture: {arch_label}")
    print(f"  Family:     {model_family}")
    print(f"  Layers:     {num_layers}")
    print(f"  Hidden:     {hidden_dim}")
    print(f"  Heads:      {num_heads} (KV: {num_kv_heads})")
    print(f"  Head dim:   {head_dim}")
    print(f"  FFN dim:    {ffn_dim}")
    print(f"  Vocab:      {vocab_size}")
    print(f"  Max seq:    {max_seq_len}")
    if not is_gpt2:
        print(f"  RoPE theta: {rope_theta}")
    else:
        print(f"  Pos embed:  learned ({max_seq_len} positions)")
    print(f"  Tie embeds: {tie_embeddings}")
    print(f"  Total params: {total_params:,}")

    # List all tensor names for debugging
    print(f"\nState dict keys ({len(state_dict)}):")
    for name, tensor in state_dict.items():
        print(f"  {name}: {list(tensor.shape)}")

    # --- Write the .bin file ---
    print(f"\nWriting {output_path}...")

    # Check for attention bias
    if is_gpt2:
        # GPT-2 always has bias on all projections
        has_attn_bias = True
        has_ffn_bias = True
        print(f"\n  GPT-2: All projections have bias (attn + FFN)")
    else:
        # LLaMA-family: check for Qwen2-style Q/K/V bias
        has_attn_bias = f"model.layers.0.self_attn.q_proj.bias" in state_dict
        has_ffn_bias = False
        if has_attn_bias:
            kv_dim = num_kv_heads * head_dim
            print(f"\n  Attention bias detected (Qwen2-style):")
            print(f"    q_bias: [{hidden_dim}] = {hidden_dim * 4} bytes")
            print(f"    k_bias: [{kv_dim}] = {kv_dim * 4} bytes")
            print(f"    v_bias: [{kv_dim}] = {kv_dim * 4} bytes")
            print(f"    Total per layer: {(hidden_dim + 2 * kv_dim) * 4} bytes")
        else:
            print(f"\n  No attention bias detected")

    with open(output_path, "wb") as f:
        # 1. Write placeholder header (will rewrite with correct offsets)
        f.write(b"\x00" * HEADER_SIZE)

        # 2. Write vocab section
        vocab_offset = HEADER_SIZE
        print("\nWriting vocab section...")
        vocab_bytes, actual_vocab_size, num_merges = write_vocab_section(
            f, tokenizer_dir, model_vocab_size=vocab_size
        )
        print(f"  Vocab section: {vocab_bytes:,} bytes")

        # Verify vocab size matches model
        if actual_vocab_size != vocab_size:
            print(f"  WARNING: Tokenizer vocab ({actual_vocab_size}) != "
                  f"model vocab ({vocab_size})")

        # 3. Write weight tensors
        weights_offset = HEADER_SIZE + vocab_bytes
        print(f"\nWriting weight tensors at offset {weights_offset:,}...")

        total_weight_bytes = 0

        # Weight writer function based on quantization mode
        # Norms are ALWAYS float32 (small, need precision)
        NORM_NAMES = {"input_layernorm.weight", "post_attention_layernorm.weight",
                       "ln_1.weight", "ln_1.bias", "ln_2.weight", "ln_2.bias",
                       "ln_f.weight", "ln_f.bias"}

        def write_tensor(f_out, tensor, name="", force_f32=False):
            """Write tensor in the selected format (Q8 or f32)."""
            is_norm = name in NORM_NAMES
            if is_q8 and not is_norm and not force_f32:
                return write_q8_tensor(f_out, tensor)
            else:
                return write_big_endian_float_tensor(f_out, tensor)

        quant_label = "Q8_0" if is_q8 else "float32"
        print(f"\n  Quantization: {quant_label}")

        if is_gpt2:
            # ============================================================
            # GPT-2 weight writing
            # ============================================================
            # GPT-2's Conv1D stores weights as [in_features, out_features].
            # Our .bin format expects [out_features, in_features] (Linear).
            # So ALL Conv1D weights need transposing.
            # ============================================================

            # 3a. Token embeddings (wte)
            embed_key = "transformer.wte.weight"
            if embed_key not in state_dict:
                print(f"  ERROR: {embed_key} not found!")
                sys.exit(1)
            print(f"  {embed_key}: {list(state_dict[embed_key].shape)}")
            nbytes = write_tensor(f, state_dict[embed_key])
            total_weight_bytes += nbytes
            print(f"    -> {nbytes:,} bytes ({quant_label})")

            # 3b. Positional embeddings (wpe) - ALWAYS float32
            pos_key = "transformer.wpe.weight"
            if pos_key not in state_dict:
                print(f"  ERROR: {pos_key} not found!")
                sys.exit(1)
            pos_tensor = state_dict[pos_key]
            print(f"  {pos_key}: {list(pos_tensor.shape)}")
            nbytes = write_big_endian_float_tensor(f, pos_tensor)
            total_weight_bytes += nbytes
            print(f"    -> {nbytes:,} bytes (float32, always)")

            # 3c. Per-layer weights
            for layer_idx in range(num_layers):
                layer_bytes = 0
                prefix = f"transformer.h.{layer_idx}"
                print(f"  Layer {layer_idx}:", end="")

                # --- Combined QKV: split and transpose ---
                c_attn_w_key = f"{prefix}.attn.c_attn.weight"
                c_attn_b_key = f"{prefix}.attn.c_attn.bias"
                if c_attn_w_key not in state_dict:
                    print(f"\n    ERROR: {c_attn_w_key} not found!")
                    sys.exit(1)

                # c_attn.weight: [in=hidden, out=3*hidden] (Conv1D format)
                # Transpose to [3*hidden, hidden], then split into Q, K, V
                c_attn_w = state_dict[c_attn_w_key].t()  # [3*hidden, hidden]
                q_w, k_w, v_w = c_attn_w.split(hidden_dim, dim=0)
                # Each is now [hidden, hidden] in [out, in] layout - correct

                # Write Q, K, V projection weights
                nbytes = write_tensor(f, q_w)
                total_weight_bytes += nbytes
                layer_bytes += nbytes

                nbytes = write_tensor(f, k_w)
                total_weight_bytes += nbytes
                layer_bytes += nbytes

                nbytes = write_tensor(f, v_w)
                total_weight_bytes += nbytes
                layer_bytes += nbytes

                # --- Output projection: transpose ---
                c_proj_w_key = f"{prefix}.attn.c_proj.weight"
                if c_proj_w_key not in state_dict:
                    print(f"\n    ERROR: {c_proj_w_key} not found!")
                    sys.exit(1)
                # c_proj.weight: [in=hidden, out=hidden] -> transpose
                o_w = state_dict[c_proj_w_key].t()  # [hidden, hidden]
                nbytes = write_tensor(f, o_w)
                total_weight_bytes += nbytes
                layer_bytes += nbytes

                # --- FFN fc1 (up): transpose ---
                c_fc_w_key = f"{prefix}.mlp.c_fc.weight"
                if c_fc_w_key not in state_dict:
                    print(f"\n    ERROR: {c_fc_w_key} not found!")
                    sys.exit(1)
                # c_fc.weight: [in=hidden, out=ffn] -> transpose to [ffn, hidden]
                fc1_w = state_dict[c_fc_w_key].t()  # [ffn, hidden]
                nbytes = write_tensor(f, fc1_w)
                total_weight_bytes += nbytes
                layer_bytes += nbytes

                # --- FFN fc2 (down): transpose ---
                c_proj_mlp_w_key = f"{prefix}.mlp.c_proj.weight"
                if c_proj_mlp_w_key not in state_dict:
                    print(f"\n    ERROR: {c_proj_mlp_w_key} not found!")
                    sys.exit(1)
                # mlp.c_proj.weight: [in=ffn, out=hidden] -> transpose to [hidden, ffn]
                fc2_w = state_dict[c_proj_mlp_w_key].t()  # [hidden, ffn]
                nbytes = write_tensor(f, fc2_w)
                total_weight_bytes += nbytes
                layer_bytes += nbytes

                # --- LayerNorm weights and biases (always float32) ---
                # input_norm (ln_1)
                ln1_w_key = f"{prefix}.ln_1.weight"
                ln1_b_key = f"{prefix}.ln_1.bias"
                if ln1_w_key not in state_dict or ln1_b_key not in state_dict:
                    print(f"\n    ERROR: {ln1_w_key} or {ln1_b_key} not found!")
                    sys.exit(1)
                nbytes = write_big_endian_float_tensor(f, state_dict[ln1_w_key])
                total_weight_bytes += nbytes
                layer_bytes += nbytes
                nbytes = write_big_endian_float_tensor(f, state_dict[ln1_b_key])
                total_weight_bytes += nbytes
                layer_bytes += nbytes

                # post_attn_norm (ln_2)
                ln2_w_key = f"{prefix}.ln_2.weight"
                ln2_b_key = f"{prefix}.ln_2.bias"
                if ln2_w_key not in state_dict or ln2_b_key not in state_dict:
                    print(f"\n    ERROR: {ln2_w_key} or {ln2_b_key} not found!")
                    sys.exit(1)
                nbytes = write_big_endian_float_tensor(f, state_dict[ln2_w_key])
                total_weight_bytes += nbytes
                layer_bytes += nbytes
                nbytes = write_big_endian_float_tensor(f, state_dict[ln2_b_key])
                total_weight_bytes += nbytes
                layer_bytes += nbytes

                # --- Bias terms (always float32) ---
                # Split combined QKV bias
                if c_attn_b_key not in state_dict:
                    print(f"\n    ERROR: {c_attn_b_key} not found!")
                    sys.exit(1)
                c_attn_b = state_dict[c_attn_b_key]
                q_b, k_b, v_b = c_attn_b.split(hidden_dim, dim=0)

                # q_bias, k_bias, v_bias
                nbytes = write_big_endian_float_tensor(f, q_b)
                total_weight_bytes += nbytes
                layer_bytes += nbytes
                nbytes = write_big_endian_float_tensor(f, k_b)
                total_weight_bytes += nbytes
                layer_bytes += nbytes
                nbytes = write_big_endian_float_tensor(f, v_b)
                total_weight_bytes += nbytes
                layer_bytes += nbytes

                # o_bias (c_proj bias)
                c_proj_b_key = f"{prefix}.attn.c_proj.bias"
                if c_proj_b_key not in state_dict:
                    print(f"\n    ERROR: {c_proj_b_key} not found!")
                    sys.exit(1)
                nbytes = write_big_endian_float_tensor(f, state_dict[c_proj_b_key])
                total_weight_bytes += nbytes
                layer_bytes += nbytes

                # ffn_up_bias (c_fc bias)
                c_fc_b_key = f"{prefix}.mlp.c_fc.bias"
                if c_fc_b_key not in state_dict:
                    print(f"\n    ERROR: {c_fc_b_key} not found!")
                    sys.exit(1)
                nbytes = write_big_endian_float_tensor(f, state_dict[c_fc_b_key])
                total_weight_bytes += nbytes
                layer_bytes += nbytes

                # ffn_down_bias (mlp.c_proj bias)
                c_proj_mlp_b_key = f"{prefix}.mlp.c_proj.bias"
                if c_proj_mlp_b_key not in state_dict:
                    print(f"\n    ERROR: {c_proj_mlp_b_key} not found!")
                    sys.exit(1)
                nbytes = write_big_endian_float_tensor(
                    f, state_dict[c_proj_mlp_b_key])
                total_weight_bytes += nbytes
                layer_bytes += nbytes

                fmt = "f32" if not is_q8 else "Q8+f32norms"
                fmt += "+bias"
                print(f" {layer_bytes:,} bytes ({fmt})")

            # 3d. Final LayerNorm (always float32)
            ln_f_w_key = "transformer.ln_f.weight"
            ln_f_b_key = "transformer.ln_f.bias"
            if ln_f_w_key not in state_dict or ln_f_b_key not in state_dict:
                print(f"  ERROR: {ln_f_w_key} or {ln_f_b_key} not found!")
                sys.exit(1)
            print(f"  {ln_f_w_key}: {list(state_dict[ln_f_w_key].shape)}")
            nbytes = write_big_endian_float_tensor(f, state_dict[ln_f_w_key])
            total_weight_bytes += nbytes
            print(f"    -> {nbytes:,} bytes (float32, always)")
            print(f"  {ln_f_b_key}: {list(state_dict[ln_f_b_key].shape)}")
            nbytes = write_big_endian_float_tensor(f, state_dict[ln_f_b_key])
            total_weight_bytes += nbytes
            print(f"    -> {nbytes:,} bytes (float32, always)")

            # 3e. lm_head: tied to wte for GPT-2
            if tie_embeddings:
                print("  lm_head: TIED to wte (not written)")
            else:
                lm_key = "lm_head.weight"
                if lm_key in state_dict:
                    print(f"  {lm_key}: {list(state_dict[lm_key].shape)}")
                    nbytes = write_tensor(f, state_dict[lm_key])
                    total_weight_bytes += nbytes
                    print(f"    -> {nbytes:,} bytes")

        else:
            # ============================================================
            # LLaMA-family weight writing (existing path)
            # ============================================================

            # 3a. Embedding table
            embed_key = "model.embed_tokens.weight"
            if embed_key not in state_dict:
                print(f"  ERROR: {embed_key} not found!")
                sys.exit(1)
            print(f"  {embed_key}: {list(state_dict[embed_key].shape)}")
            nbytes = write_tensor(f, state_dict[embed_key])
            total_weight_bytes += nbytes
            print(f"    -> {nbytes:,} bytes ({quant_label})")

            # 3b. Per-layer weights
            for layer_idx in range(num_layers):
                layer_bytes = 0
                print(f"  Layer {layer_idx}:", end="")
                for weight_name in LAYER_WEIGHT_NAMES:
                    full_key = f"model.layers.{layer_idx}.{weight_name}"
                    if full_key not in state_dict:
                        print(f"\n    ERROR: {full_key} not found!")
                        sys.exit(1)
                    tensor = state_dict[full_key]
                    nbytes = write_tensor(f, tensor, weight_name)
                    total_weight_bytes += nbytes
                    layer_bytes += nbytes
                # Write attention bias tensors if present (always float32)
                if has_attn_bias:
                    kv_dim = num_kv_heads * head_dim
                    bias_names = [
                        ("self_attn.q_proj.bias", hidden_dim),
                        ("self_attn.k_proj.bias", kv_dim),
                        ("self_attn.v_proj.bias", kv_dim),
                    ]
                    for bias_name, expected_size in bias_names:
                        full_key = f"model.layers.{layer_idx}.{bias_name}"
                        if full_key not in state_dict:
                            print(f"\n    ERROR: {full_key} not found!")
                            sys.exit(1)
                        bias_tensor = state_dict[full_key]
                        nbytes = write_big_endian_float_tensor(f, bias_tensor)
                        total_weight_bytes += nbytes
                        layer_bytes += nbytes

                fmt = "f32" if not is_q8 else "Q8+f32norms"
                if has_attn_bias:
                    fmt += "+bias"
                print(f" {layer_bytes:,} bytes ({fmt})")

            # 3c. Final RMSNorm (always float32)
            norm_key = "model.norm.weight"
            if norm_key not in state_dict:
                print(f"  ERROR: {norm_key} not found!")
                sys.exit(1)
            print(f"  {norm_key}: {list(state_dict[norm_key].shape)}")
            nbytes = write_big_endian_float_tensor(f, state_dict[norm_key])
            total_weight_bytes += nbytes
            print(f"    -> {nbytes:,} bytes (float32, always)")

            # 3d. lm_head (skip if tied)
            if tie_embeddings:
                print("  lm_head: TIED to embed_tokens (not written)")
            else:
                lm_key = "lm_head.weight"
                if lm_key in state_dict:
                    print(f"  {lm_key}: {list(state_dict[lm_key].shape)}")
                    nbytes = write_tensor(f, state_dict[lm_key])
                    total_weight_bytes += nbytes
                    print(f"    -> {nbytes:,} bytes")

        file_size = f.tell()
        print(f"\n  Weight section: {total_weight_bytes:,} bytes")
        print(f"  Total file size: {file_size:,} bytes "
              f"({file_size / 1024 / 1024:.1f} MB)")

        # 4. Rewrite header with correct offsets
        f.seek(0)

        # Detect architecture type, chat template, special tokens, pre-tokenizer
        arch_type = 1 if is_gpt2 else 0  # 0=LLaMA, 1=GPT-2
        flags = 0
        chat_template = 0  # 0=custom MacinAI, 1=ChatML, 2=raw
        im_start_token = 0
        im_end_token = 0
        pre_tokenizer_type = 0  # 0=GPT-2 (group digits), 1=Qwen (single digits)

        # Detect chat template and special tokens from tokenizer
        if is_gpt2:
            # GPT-2 is a text completion model, no chat template
            chat_template = 2  # kChatTemplateRaw
            try:
                tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
                eos_id = tokenizer.eos_token_id or 50256
                im_start_token = eos_id  # <|endoftext|>
                im_end_token = eos_id    # <|endoftext|>
                print(f"\n  GPT-2 chat template: Raw (text completion)")
                print(f"    <|endoftext|> = {eos_id}")
            except Exception as e:
                im_start_token = 50256
                im_end_token = 50256
                print(f"\n  GPT-2 chat template: Raw (default tokens)")
                print(f"    Warning: {e}")
        else:
            try:
                tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
                bos_str = getattr(tokenizer, 'bos_token', '')

                # Find <|im_start|> and <|im_end|> token IDs
                vocab = tokenizer.get_vocab()
                im_start_token = vocab.get('<|im_start|>', 0)
                im_end_token = vocab.get('<|im_end|>', 0)

                if bos_str == '[BOS]':
                    chat_template = 0  # Custom MacinAI
                    print(f"\n  Detected chat template: Custom MacinAI")
                elif im_start_token > 0 and im_end_token > 0:
                    chat_template = 1  # ChatML
                    print(f"\n  Detected chat template: ChatML")
                    print(f"    <|im_start|> = {im_start_token}")
                    print(f"    <|im_end|> = {im_end_token}")
                else:
                    # Check for Zephyr format via chat template text
                    chat_tpl_str = getattr(tokenizer, 'chat_template', '') or ''
                    is_zephyr = ('<|user|>' in chat_tpl_str
                                 or '<|assistant|>' in chat_tpl_str)
                    if is_zephyr:
                        chat_template = 3  # Zephyr
                        im_end_token = tokenizer.eos_token_id or 2
                        print(f"\n  Detected chat template: Zephyr "
                              f"(<|user|>/<|assistant|>)")
                        print(f"    EOS = {im_end_token}")
                    else:
                        chat_template = 1  # Default to ChatML
                        im_start_token = tokenizer.bos_token_id or 1
                        im_end_token = tokenizer.eos_token_id or 2
                        print(f"\n  Detected chat template: ChatML (fallback)")
                        print(f"    im_start={im_start_token}, "
                              f"im_end={im_end_token}")
            except Exception as e:
                print(f"\n  Warning: Could not detect chat template: {e}")

        if has_attn_bias:
            flags |= 0x01  # kFlagHasAttnBias
        if has_ffn_bias:
            flags |= 0x02  # kFlagHasFFNBias

        if not tie_embeddings:
            flags |= 0x04  # kFlagSeparateLMHead

        # Detect pre-tokenizer type from tokenizer.json
        try:
            tokenizer_json_path = os.path.join(tokenizer_dir, "tokenizer.json")
            if os.path.exists(tokenizer_json_path):
                with open(tokenizer_json_path) as tjf:
                    tj_data = json.load(tjf)
                pre_tok = tj_data.get("pre_tokenizer", {})
                pre_type = pre_tok.get("type", "")
                if pre_type == "Sequence":
                    # Check for Split+ByteLevel (Qwen style) or Metaspace inside Sequence
                    subs = pre_tok.get("pretokenizers", [])
                    found_metaspace = False
                    for sub in subs:
                        if sub.get("type") == "Metaspace":
                            pre_tokenizer_type = 2  # SentencePiece Metaspace
                            print(f"  Pre-tokenizer: SentencePiece Metaspace (in Sequence)")
                            found_metaspace = True
                            break
                    if not found_metaspace:
                        for sub in subs:
                            if sub.get("type") == "Split":
                                regex = sub.get("pattern", {}).get("Regex", "")
                                if "\\p{N}" in regex and "\\p{N}+" not in regex:
                                    pre_tokenizer_type = 1  # Qwen: single digits
                                    print(f"  Pre-tokenizer: Qwen (single-digit split)")
                                else:
                                    print(f"  Pre-tokenizer: GPT-2 (grouped digits)")
                                break
                elif pre_type == "ByteLevel":
                    print(f"  Pre-tokenizer: ByteLevel (GPT-2 compatible)")
                elif pre_type == "Metaspace":
                    pre_tokenizer_type = 2  # SentencePiece Metaspace
                    print(f"  Pre-tokenizer: SentencePiece Metaspace")
                else:
                    # Also check inside Sequence for Metaspace
                    print(f"  Pre-tokenizer: unknown ({pre_type}), using GPT-2")
        except Exception as e:
            print(f"  Pre-tokenizer detection failed: {e}")

        print(f"  archType={arch_type}, flags=0x{flags:02x}, "
              f"chatTemplate={chat_template}, preTokenizer={pre_tokenizer_type}")

        # 20 data fields + 12 reserved = 32 longs = 128 bytes
        write_big_endian_long(f, MAGIC)              # [0]  magic
        write_big_endian_long(f, FORMAT_VERSION)      # [1]  version
        write_big_endian_long(f, num_layers)          # [2]  numLayers
        write_big_endian_long(f, hidden_dim)          # [3]  hiddenDim
        write_big_endian_long(f, num_heads)           # [4]  numHeads
        write_big_endian_long(f, num_kv_heads)        # [5]  numKVHeads
        write_big_endian_long(f, head_dim)            # [6]  headDim
        write_big_endian_long(f, ffn_dim)             # [7]  ffnDim
        write_big_endian_long(f, vocab_size)          # [8]  vocabSize
        write_big_endian_long(f, max_seq_len)         # [9]  maxSeqLen
        write_big_endian_long(f, rope_theta)          # [10] ropeTheta
        write_big_endian_long(f, quant_type)          # [11] quantType
        write_big_endian_long(f, total_params)        # [12] totalParams
        write_big_endian_long(f, file_size)           # [13] fileSize
        write_big_endian_long(f, vocab_offset)        # [14] vocabOffset
        write_big_endian_long(f, weights_offset)      # [15] weightsOffset
        write_big_endian_long(f, num_merges)          # [16] numMerges
        write_big_endian_long(f, arch_type)           # [17] archType
        write_big_endian_long(f, flags)               # [18] flags
        write_big_endian_long(f, chat_template)       # [19] chatTemplate
        write_big_endian_long(f, im_start_token)      # [20] imStartToken
        write_big_endian_long(f, im_end_token)        # [21] imEndToken
        write_big_endian_long(f, pre_tokenizer_type) # [22] preTokenizerType

        # 9 reserved longs (zero)
        for _ in range(9):
            write_big_endian_long(f, 0)

    print(f"\nExport complete: {output_path}")
    print(f"  Header:  {HEADER_SIZE} bytes")
    print(f"  Vocab:   {vocab_bytes:,} bytes (at offset {vocab_offset})")
    print(f"  Weights: {total_weight_bytes:,} bytes "
          f"(at offset {weights_offset})")
    print(f"  Total:   {file_size:,} bytes ({file_size / 1024 / 1024:.1f} MB)")

    # Verify header
    print("\nVerifying header...")
    with open(output_path, "rb") as f:
        header = f.read(HEADER_SIZE)
        magic_check = struct.unpack(">I", header[0:4])[0]
        if magic_check == MAGIC:
            print("  Magic: OK (MCAI)")
        else:
            print(f"  Magic: FAILED (got 0x{magic_check:08X})")

        ver = struct.unpack(">i", header[4:8])[0]
        print(f"  Version: {ver}")

        nl = struct.unpack(">i", header[8:12])[0]
        print(f"  Layers: {nl}")

        hd = struct.unpack(">i", header[12:16])[0]
        print(f"  Hidden: {hd}")

        at = struct.unpack(">i", header[68:72])[0]
        print(f"  ArchType: {at} ({'GPT-2' if at == 1 else 'LLaMA'})")

        fl = struct.unpack(">i", header[72:76])[0]
        print(f"  Flags: 0x{fl:02x}")

        fs = struct.unpack(">i", header[52:56])[0]
        print(f"  File size in header: {fs:,}")
        print(f"  Actual file size:    {file_size:,}")
        if fs == file_size:
            print("  Size check: OK")
        else:
            print("  Size check: MISMATCH!")

    # Set Mac file type/creator codes
    # Type: 'MCAI' (MacinAI Model), Creator: 'OAS ' (Old Apple Stuff)
    set_mac_type_creator(output_path)


def main():
    """Parse command-line arguments and run the MacinAI export.

    Validates that the model and tokenizer inputs are existing directories,
    creates the output file's parent directory if necessary, then delegates
    to ``export_bin``.

    Exits with status 1, after printing an error message to stderr, if
    either input directory is missing.
    """
    parser = argparse.ArgumentParser(
        description="Export HuggingFace model to MacinAI .bin format"
    )
    parser.add_argument(
        "--model-dir",
        required=True,
        help="Path to HuggingFace model directory",
    )
    parser.add_argument(
        "--tokenizer-dir",
        default=None,
        help="Path to tokenizer directory (default: same as model-dir)",
    )
    parser.add_argument(
        "--output",
        default="models/macinai_local.bin",
        help="Output .bin file path (default: models/macinai_local.bin)",
    )
    parser.add_argument(
        "--quantize",
        default="f32",
        choices=["f32", "q8"],
        help="Quantization: f32 (default, ~378MB) or q8 (~94.5MB)",
    )

    args = parser.parse_args()

    # Use the model directory as tokenizer source by default.
    if args.tokenizer_dir is None:
        args.tokenizer_dir = args.model_dir

    # Both inputs are read as local directories, so a regular file is as
    # unusable as a missing path; check with isdir rather than exists.
    # Errors go to stderr so they are not mixed into normal progress output.
    for label, path in (("Model", args.model_dir),
                        ("Tokenizer", args.tokenizer_dir)):
        if not os.path.isdir(path):
            print(f"Error: {label} directory not found: {path}",
                  file=sys.stderr)
            sys.exit(1)

    # Create output directory if needed; an output path without a directory
    # component refers to the current directory.
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)

    export_bin(args.model_dir, args.tokenizer_dir, args.output,
               quantize=args.quantize)


if __name__ == "__main__":
    main()
