import webbrowser
from pathlib import Path

import numpy as np
import pandas as pd
from pyvis import network as net

# === Load CSV ===
# Each row becomes one directed edge below; the columns "Subject", "Predicate",
# "Object" and "confidence" are read when building the graph.
df = pd.read_csv(filepath_or_buffer="MedicalExample/Medical3.csv")

# === Edge color based on confidence ===
def color_from_conf(conf, low=0.3, high=1.0):
    """Map a confidence score onto a red -> green CSS colour ramp.

    ``low`` maps to pure red ``rgb(255,0,0)`` and ``high`` to pure green
    ``rgb(0,255,0)``; values in between are interpolated linearly and values
    outside ``[low, high]`` are clamped to the nearest end. A missing (NaN)
    confidence is treated as the least confident value (red) so that absent
    scores are never rendered as fully trusted.

    Args:
        conf: Confidence score; anything ``float()`` accepts.
        low: Confidence rendered as pure red. Defaults to 0.3.
        high: Confidence rendered as pure green. Defaults to 1.0.

    Returns:
        A CSS colour string of the form ``"rgb(r,g,0)"``.

    Raises:
        ValueError: If ``high`` is not greater than ``low``.
    """
    if high <= low:
        raise ValueError(f"high ({high}) must be greater than low ({low})")

    conf = float(conf)
    if np.isnan(conf):
        # Without this, min(1, nan) evaluates to 1 and NaN would show as green.
        norm_conf = 0.0
    else:
        norm_conf = (conf - low) / (high - low)
        norm_conf = max(0.0, min(1.0, norm_conf))

    r = int(255 * (1 - norm_conf))
    g = int(255 * norm_conf)
    return f"rgb({r},{g},0)"

# === Create PyVis graph ===
# Directed, full-width graph; select_menu / filter_menu turn on pyvis'
# node-selection and filtering widgets in the generated page.
graph = net.Network(
    notebook=False,
    cdn_resources='remote',
    directed=True,
    height='750px',
    width='100%',
    select_menu=True,
    filter_menu=True
)

# vis-network options. `avoidOverlap` is a per-solver physics setting in
# vis-network (physics.<solver>.avoidOverlap); written directly under
# "physics" it is not a recognised option and has no effect, so it is nested
# under the forceAtlas2Based solver that is actually selected.
graph.set_options("""
const options = {
  "physics": {
    "enabled": true,
    "solver": "forceAtlas2Based",
    "forceAtlas2Based": {"avoidOverlap": 0.6}
  },
  "interaction": {"navigationButtons": true, "keyboard": true},
  "nodes": {"size": 15},
  "edges": {"width": 2, "smooth": true, "font": {"size": 15, "align": "top"}}
}
""")

# === Add nodes (uniform appearance) ===
# Every distinct value that appears as a Subject or an Object becomes a node.
all_nodes = set(df["Subject"]) | set(df["Object"])

# Add nodes in a fixed order: iteration order of a set of strings depends on
# the per-process hash seed, so iterating the set directly reshuffles the node
# list in the generated HTML on every run. key=str keeps the sort safe even if
# the columns contain mixed value types.
for node in sorted(all_nodes, key=str):
    graph.add_node(
        node,
        label=node,
        shape="box",
        color="rgb(180,200,255)",  # static color for all nodes
        size=15,  # NOTE(review): vis-network sizes "box" nodes from their label; confirm this has any effect
        title=f"Node: {node}",
    )

# === Add edges exactly as in CSV ===
# One directed edge per CSV row (duplicate rows give duplicate edges). The
# confidence drives both the edge colour (via color_from_conf) and its width.
# itertuples avoids building a pandas Series per row (which is slow and
# coerces each row to a common dtype, as iterrows does).
edge_rows = df[["Subject", "Predicate", "Object", "confidence"]].itertuples(
    index=False, name=None
)
for subj, pred, obj, conf in edge_rows:
    edge_color = color_from_conf(conf)
    edge_width = 1 + (conf * 4)  # width grows linearly with confidence

    graph.add_edge(
        subj,
        obj,
        label=pred,
        title=f"{subj} {pred} {obj}\nConfidence: {conf:.2f}",
        color=edge_color,
        width=edge_width
    )

# === Export HTML ===
output_file = "MedicalExample/Medical_Raw_KG.html"
graph.write_html(output_file)
# webbrowser.open expects a URL. A bare relative filesystem path is not
# reliably understood by every platform's browser launcher, so pass an
# absolute file:// URI instead.
webbrowser.open(Path(output_file).resolve().as_uri())

print("\n=== Raw Medical Knowledge Graph ===")
print(f"Total nodes: {len(all_nodes)}")
print(f"Total edges: {len(df)}")  # one edge is added per CSV row
print("Visualization written to:", output_file)
