"""RDKit-based molecular graph construction."""

from __future__ import annotations

from typing import List, Tuple

import networkx as nx
import numpy as np
from rdkit import Chem

from .data_loader import DATA
from .utils import read_xyz_file


def build_graph_rdkit(
    xyz_file: str | List[Tuple[str, Tuple[float, float, float]]],
    charge: int = 0,
    bohr_units: bool = False,
) -> nx.Graph:
    """Build a molecular graph using RDKit's ``DetermineBonds`` algorithm.

    The atoms are serialised to an XYZ block, parsed with
    ``Chem.MolFromXYZBlock`` and passed to
    ``rdDetermineBonds.DetermineBonds`` with ``useHueckel=True`` (per the
    RDKit documentation this requests extended Hueckel theory for the
    connectivity perception).  The perceived atoms and bonds are then copied
    into a :class:`networkx.Graph`.

    Parameters
    ----------
    xyz_file : str or list of (symbol, (x, y, z))
        Either path to XYZ file or list of atom tuples.
    charge : int
        Total molecular charge, forwarded to ``DetermineBonds``.
    bohr_units : bool
        Whether coordinates are in Bohr (only used if *xyz_file* is a path;
        it is forwarded to ``read_xyz_file``).

    Returns
    -------
    nx.Graph
        Molecular graph indexed by RDKit atom index.

        Node attributes: ``symbol``, ``atomic_number``, ``position``,
        ``charges`` (empty dict), ``formal_charge``, ``valence`` (sum of bond
        orders to non-metal neighbours), ``metal_valence`` (sum of bond
        orders to metal neighbours) and ``agg_charge`` (formal charge plus
        the formal charges of bonded hydrogens).

        Edge attributes: ``bond_order`` (1.0, 1.5, 2.0 or 3.0),
        ``distance`` (Euclidean distance between the input positions),
        ``bond_type`` (tuple of the two element symbols) and ``metal_coord``
        (True if either end is in ``DATA.metals``).

        Graph attributes: ``metadata`` (version, citation, source),
        ``total_charge`` and ``method``.

    Raises
    ------
    ValueError
        If no atoms are supplied, if RDKit fails to parse the structure or
        determine bonds, if no bonds are produced, or if RDKit reports an
        element symbol that is missing from ``DATA.s2n``.

    Notes
    -----
    RDKit has limited support for coordination complexes.  For
    metal-containing systems, consider ``build_graph()`` with
    ``method='cheminf'`` or ``build_graph_orca()`` instead.
    """
    import logging

    from rdkit.Chem import rdDetermineBonds

    # Handle input: a string is treated as a path, anything else as atoms.
    if isinstance(xyz_file, str):
        atoms = read_xyz_file(xyz_file, bohr_units=bohr_units)
    else:
        atoms = xyz_file

    nat = len(atoms)
    if nat == 0:
        # Fail early with a clear message instead of a confusing RDKit error.
        raise ValueError("No atoms provided; cannot build RDKit graph")

    symbols = [symbol for symbol, _ in atoms]
    positions = [pos for _, pos in atoms]
    metals = DATA.metals  # hoisted: looked up many times below

    # Build XYZ block for RDKit
    xyz_lines = [str(nat), f"Generated by xyzgraph build_graph_rdkit (charge={charge})"]
    xyz_lines.extend(f"{sym} {x:.6f} {y:.6f} {z:.6f}" for sym, (x, y, z) in zip(symbols, positions))
    xyz_block = "\n".join(xyz_lines) + "\n"

    # Parse with RDKit
    raw_mol = Chem.MolFromXYZBlock(xyz_block)
    if raw_mol is None:
        raise ValueError("RDKit MolFromXYZBlock failed to parse structure")

    # Determine bonds on a copy so the parsed molecule is left untouched.
    mol = Chem.Mol(raw_mol)
    try:
        rdDetermineBonds.DetermineBonds(mol, charge=charge, useHueckel=True)
    except Exception as e:
        if any(s in metals for s in symbols):
            raise ValueError(f"RDKit DetermineBonds failed (metal atoms detected): {e}") from e
        raise ValueError(f"RDKit DetermineBonds failed: {e}") from e

    if mol.GetNumBonds() == 0:
        raise ValueError("RDKit DetermineBonds produced no bonds")

    # Light sanitize.  This is deliberately best-effort: a failure here must
    # not abort graph construction, but it is logged rather than hidden.
    try:
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
    except Exception as exc:
        logging.getLogger(__name__).debug(
            "RDKit property sanitization failed; continuing without it: %s", exc
        )

    # Build NetworkX graph
    G = nx.Graph()

    for a in mol.GetAtoms():
        i = a.GetIdx()
        symbol = a.GetSymbol()
        atomic_number = DATA.s2n.get(symbol)
        if atomic_number is None:
            raise ValueError(f"Unknown element symbol: {symbol}")

        G.add_node(
            i,
            symbol=symbol,
            atomic_number=atomic_number,
            position=positions[i],
            charges={},  # documented as present-but-empty; fresh dict per node
            formal_charge=a.GetFormalCharge(),
        )

    # Non-aromatic bond types with an explicit order; anything else falls
    # back to 1.0 below.
    bond_orders = {
        Chem.BondType.SINGLE: 1.0,
        Chem.BondType.DOUBLE: 2.0,
        Chem.BondType.TRIPLE: 3.0,
    }
    # Convert coordinates once instead of twice per bond.
    coords = np.asarray(positions, dtype=float)

    for b in mol.GetBonds():
        i = b.GetBeginAtomIdx()
        j = b.GetEndAtomIdx()

        rd_type = b.GetBondType()
        if b.GetIsAromatic() or rd_type == Chem.BondType.AROMATIC:
            bo = 1.5
        else:
            bo = bond_orders.get(rd_type, 1.0)

        distance = float(np.linalg.norm(coords[i] - coords[j]))

        si = G.nodes[i]["symbol"]
        sj = G.nodes[j]["symbol"]

        G.add_edge(
            i,
            j,
            bond_order=bo,
            distance=distance,
            bond_type=(si, sj),
            metal_coord=(si in metals or sj in metals),
        )

    # Derived properties
    for node in G.nodes():
        # Split valence: organic (excludes metal bonds) and metal (coordination bonds)
        organic_val = sum(
            G[node][nbr].get("bond_order", 1.0)
            for nbr in G.neighbors(node)
            if G.nodes[nbr]["symbol"] not in metals
        )
        metal_val = sum(
            G[node][nbr].get("bond_order", 1.0)
            for nbr in G.neighbors(node)
            if G.nodes[nbr]["symbol"] in metals
        )
        G.nodes[node]["valence"] = organic_val
        G.nodes[node]["metal_valence"] = metal_val

        # Aggregate charge: own formal charge plus that of attached hydrogens.
        agg_charge = float(G.nodes[node]["formal_charge"])
        for nbr in G.neighbors(node):
            if G.nodes[nbr]["symbol"] == "H":
                agg_charge += G.nodes[nbr]["formal_charge"]
        G.nodes[node]["agg_charge"] = agg_charge

    # Metadata (imported here, at call time, as in the original code)
    from . import __citation__, __version__

    G.graph["metadata"] = {
        "version": __version__,
        "citation": __citation__,
        "source": "rdkit",
    }
    G.graph["total_charge"] = charge
    G.graph["method"] = "rdkit"

    return G
