from collections import defaultdict
from pathlib import Path
import webbrowser

from pyvis import network as net
import pandas as pd
import numpy as np

class BayesianKG:
    """Knowledge graph whose edge, node and predicate beliefs are Beta posteriors.

    Each tracked item holds a Beta(alpha, beta) pseudo-count pair:

    * ``edge_beliefs``     -- ``(subj, pred, obj)`` -> (alpha, beta)
    * ``node_reliability`` -- node -> (alpha, beta)
    * ``predicate_priors`` -- predicate -> (alpha, beta)

    Every point estimate is the posterior mean ``alpha / (alpha + beta)``.

    Args:
        prior_strength: Pseudo-count given to *both* alpha and beta of a node
            or predicate before any evidence (the default 0.5 starts them at
            Beta(0.5, 0.5)).  Must be > 0 so every posterior mean is defined.
        evidence_scale: Weight one observation carries on an edge when both
            endpoints have reliability 1; the applied weight is this times the
            mean reliability of the two endpoints.  Must be >= 0.
        predicate_prior_cap: Optional ceiling on the total pseudo-count
            (alpha + beta) that a newly created edge copies from its
            predicate.  Predicate counts grow by one per observation while a
            single observation adds at most ``evidence_scale`` to an edge, so
            without a cap a new edge of a frequent predicate starts from a
            prior that its own evidence can barely move.  Capping rescales
            alpha and beta together, leaving the prior mean unchanged.
            ``None`` (default) keeps the uncapped behaviour.

    Raises:
        ValueError: if a parameter is outside the ranges above.
    """

    def __init__(self, prior_strength = 0.5, evidence_scale = 3.0, predicate_prior_cap = None):
        # `not x > 0` form is used so NaN is rejected as well.
        if not prior_strength > 0:
            raise ValueError(f"prior_strength must be > 0, got {prior_strength!r}")
        if not evidence_scale >= 0:
            raise ValueError(f"evidence_scale must be >= 0, got {evidence_scale!r}")
        if predicate_prior_cap is not None and not predicate_prior_cap > 0:
            raise ValueError(
                f"predicate_prior_cap must be > 0 or None, got {predicate_prior_cap!r}"
            )

        self.edge_beliefs = {}  # (subj, pred, obj) -> (alpha, beta); created on first observation
        # Unseen nodes / predicates start from Beta(prior_strength, prior_strength).
        self.node_reliability = defaultdict(lambda: (prior_strength, prior_strength))  # node -> (a, b)
        self.predicate_priors = defaultdict(lambda: (prior_strength, prior_strength))  # pred -> (a, b)
        self.prior_strength = prior_strength
        self.evidence_scale = evidence_scale
        self.predicate_prior_cap = predicate_prior_cap

    @staticmethod
    def _posterior_mean(params):
        """Return ``alpha / (alpha + beta)`` for an ``(alpha, beta)`` pair."""
        alpha, beta = params
        return alpha / (alpha + beta)

    @staticmethod
    def _check_confidence(confidence):
        """Raise ValueError unless ``0 <= confidence <= 1`` (NaN is rejected too)."""
        # Out-of-range input would push alpha or beta negative, letting the
        # posterior mean leave [0, 1] or divide by zero.
        if not 0.0 <= confidence <= 1.0:
            raise ValueError(f"confidence must be in [0, 1], got {confidence!r}")

    def get_reliability(self, node):
        """Return the posterior-mean reliability of *node*.

        ``node_reliability`` is a defaultdict, so reading an unseen node
        registers it with the prior pseudo-counts.
        """
        return self._posterior_mean(self.node_reliability[node])

    def update_node_reliability(self, node, confidence):
        """Add one unit of evidence to *node*: ``confidence`` to alpha, the rest to beta.

        Raises:
            ValueError: if *confidence* is not in [0, 1].
        """
        self._check_confidence(confidence)
        alpha, beta = self.node_reliability[node]
        self.node_reliability[node] = (alpha + confidence, beta + (1 - confidence))

    def get_node_reliability(self, node):
        """Alias of :meth:`get_reliability`."""
        return self.get_reliability(node)

    def get_predicate_prior(self, pred):
        """Return the posterior mean learned so far for predicate *pred*."""
        return self._posterior_mean(self.predicate_priors[pred])

    def update_predicate_prior(self, pred, confidence):
        """Add one unit of evidence to predicate *pred* (same rule as for nodes).

        Raises:
            ValueError: if *confidence* is not in [0, 1].
        """
        self._check_confidence(confidence)
        alpha, beta = self.predicate_priors[pred]
        self.predicate_priors[pred] = (alpha + confidence, beta + (1 - confidence))

    def get_edge_confidence(self, subj, pred, obj):
        """Return the posterior-mean confidence of an edge, or 0.5 if never observed."""
        edge_key = (subj, pred, obj)
        if edge_key not in self.edge_beliefs:
            return 0.5
        return self._posterior_mean(self.edge_beliefs[edge_key])

    def get_edge_uncertainty(self, subj, pred, obj):
        """Return the variance of an edge's Beta posterior.

        The variance is ``alpha * beta / (n**2 * (n + 1))`` with
        ``n = alpha + beta``.  A never-observed edge reports 0.25, the largest
        variance any distribution on [0, 1] can have.
        """
        edge_key = (subj, pred, obj)
        if edge_key not in self.edge_beliefs:
            return 0.25
        alpha, beta = self.edge_beliefs[edge_key]
        n = alpha + beta
        return (alpha * beta) / (n * n * (n + 1))

    def _initial_edge_beliefs(self, pred):
        """Pseudo-counts for a new edge of *pred*: the predicate's, optionally capped."""
        alpha, beta = self.predicate_priors[pred]
        cap = self.predicate_prior_cap
        total = alpha + beta
        if cap is not None and total > cap:
            scale = cap / total  # same ratio, hence same mean, smaller weight
            alpha, beta = alpha * scale, beta * scale
        return alpha, beta

    def add_observation(self, subj, pred, obj, confidence):
        """Fold one observed triple into the graph and return the edge's new belief.

        A first-seen edge starts from its predicate's current pseudo-counts
        (rescaled if ``predicate_prior_cap`` is set).  The observation is then
        weighted by ``evidence_scale`` times the mean reliability of the two
        endpoints, as read *before* this observation updates them, and split
        between alpha (``confidence``) and beta (``1 - confidence``).  Finally
        both endpoints and the predicate get one unweighted unit of evidence.

        Args:
            subj, pred, obj: The triple; ``(subj, pred, obj)`` is the edge key.
            confidence: Source confidence for this observation, in [0, 1].

        Returns:
            The edge's posterior mean after this observation.

        Raises:
            ValueError: if *confidence* is not in [0, 1]; nothing is modified.
        """
        # Validate before touching any state so a bad row cannot half-apply.
        self._check_confidence(confidence)
        edge_key = (subj, pred, obj)

        # Endpoint reliability scales how much this single observation counts.
        node_weight = (self.get_reliability(subj) + self.get_reliability(obj)) / 2

        if edge_key not in self.edge_beliefs:
            self.edge_beliefs[edge_key] = self._initial_edge_beliefs(pred)

        alpha, beta = self.edge_beliefs[edge_key]
        evidence_strength = node_weight * self.evidence_scale
        alpha += confidence * evidence_strength
        beta += (1 - confidence) * evidence_strength
        self.edge_beliefs[edge_key] = (alpha, beta)

        self.update_node_reliability(subj, confidence)
        self.update_node_reliability(obj, confidence)
        self.update_predicate_prior(pred, confidence)

        return alpha / (alpha + beta)

def color_to_confidence(conf, low=0.3, high=1.0):
    """Map a confidence score to a red-to-green ``rgb(r,g,b)`` CSS colour.

    Despite its name, this converts a confidence *into* a colour.  Scores at
    or below *low* are pure red, at or above *high* pure green, and scores in
    between are linearly interpolated; blue is always 0.

    Args:
        conf: Confidence score.  NaN is drawn as the lowest confidence (red).
            Without the explicit check it would pass the clamp as 1 and be
            drawn as full-confidence green.
        low: Score mapped to pure red (default 0.3).
        high: Score mapped to pure green (default 1.0).

    Returns:
        A string such as ``'rgb(127,127,0)'``.

    Raises:
        ValueError: if *high* is not greater than *low*.
    """
    if not high > low:
        raise ValueError(f"high ({high!r}) must be greater than low ({low!r})")
    if np.isnan(conf):
        norm_conf = 0
    else:
        norm_conf = (conf - low) / (high - low)
        norm_conf = max(0, min(1, norm_conf))
    r = int(255 * (1 - norm_conf))
    g = int(255 * norm_conf)
    b = 0
    return f'rgb({r},{g},{b})'

# Colour anchors for node_color_from_conf: red at 0, pastel blue at 0.5, green at 1.
_NODE_RED = (255, 0, 0)
_NODE_PASTEL = (180, 200, 255)
_NODE_GREEN = (0, 255, 0)


def node_color_from_conf(conf):
    """Map a confidence to an ``rgb(r,g,b)`` string via red -> pastel -> green.

    The confidence is clamped to [0, 1].  Values up to 0.5 interpolate
    linearly from red (255,0,0) to pastel (180,200,255); values above 0.5
    interpolate from pastel to green (0,255,0).  Channels are truncated to int.

    Args:
        conf: Confidence score.  NaN is treated as 0 (red).  Without the
            explicit check it would pass the clamp as 1 and be drawn as
            full-confidence green.

    Returns:
        A string such as ``'rgb(180,200,255)'``.
    """
    if np.isnan(conf):
        conf = 0.0
    conf = max(0, min(1, conf))

    if conf <= 0.5:
        start, end, t = _NODE_RED, _NODE_PASTEL, conf / 0.5
    else:
        start, end, t = _NODE_PASTEL, _NODE_GREEN, (conf - 0.5) / 0.5

    r, g, b = (int(s + t * (e - s)) for s, e in zip(start, end))
    return f"rgb({r},{g},{b})"

# Load the triples and feed them to the Bayesian model in file order.
df = pd.read_csv('MedicalExample/Medical3.csv')

bkg = BayesianKG()

# add_observation returns the edge's posterior mean right after that row is
# absorbed, so this column holds a running (per-row) estimate.
bayesian_confidences = []
triples = zip(df['Subject'], df['Predicate'], df['Object'], df['confidence'])
for subj, pred, obj, conf in triples:
    bayesian_confidences.append(bkg.add_observation(subj, pred, obj, conf))

df['bayesian_confidence'] = bayesian_confidences

# Per-node summary for the visualisation: learned reliability, mean per-row
# Bayesian confidence over every edge touching the node, and that edge count.
#
# Built in one pass over the rows instead of re-filtering the whole DataFrame
# for every node (O(nodes x edges)).  Nodes are kept in first-appearance
# order, so they are added to the graph in the same order on every run;
# iterating a set of strings is not stable across interpreter runs.
all_nodes = set(df['Subject']).union(set(df['Object']))

confs_by_node = defaultdict(list)
for subj, obj, row_conf in zip(df['Subject'], df['Object'], df['bayesian_confidence']):
    confs_by_node[subj].append(row_conf)
    # A self-loop counts once for its node, matching the former
    # "Subject == node OR Object == node" row filter.
    if obj != subj:
        confs_by_node[obj].append(row_conf)

node_stats = {
    node: {
        'reliability': bkg.get_reliability(node),
        # Series.mean skips NaN, like the former filtered-column mean.
        'avg_conf': pd.Series(confs).mean(),
        'edge_count': len(confs),
    }
    for node, confs in confs_by_node.items()
}

# Directed, interactive pyvis canvas with node selection and filter menus.
graph = net.Network(
    notebook = False,
    cdn_resources = 'remote',
    directed = True,
    height = '750px',
    width = '100%',
    select_menu = True,
    filter_menu = True
)
# vis-network options.  `avoidOverlap` is a per-solver physics setting, so it
# belongs inside the "forceAtlas2Based" section.  Directly under "physics" it
# is not a documented vis-network option.
graph.set_options("""
const options = {
  "physics": {
    "enabled": true,
    "solver": "forceAtlas2Based",
    "forceAtlas2Based": {"avoidOverlap": 0.5}
  },
  "interaction": {"navigationButtons": true, "keyboard": true},
  "nodes": {"size": 15},
  "edges": {"width": 2, "smooth": true, "font": {"size": 15, "align": "top"}}
}
""")

# One box-shaped node per entity.  The tooltip lists the node's statistics.
# The fill colour encodes the mean per-row Bayesian confidence of its edges.
# The size grows with learned reliability (10 at reliability 0, 30 at 1).
# NOTE(review): vis-network describes `size` as applying to shapes whose label
# is drawn outside the node, so it may have no visible effect with
# shape="box" -- confirm how reliability should be shown.
for node, stats in node_stats.items():
    tooltip = "\n".join([
        f"Node: {node}",
        f"Reliability: {stats['reliability']:.3f}",
        f"Avg Confidence: {stats['avg_conf']:.3f}",
        f"Edges: {stats['edge_count']}",
    ])
    fill = node_color_from_conf(stats['avg_conf'])
    graph.add_node(
        node,
        label=node,
        shape="box",
        title=tooltip,
        color=fill,
        size=10 + (stats['reliability'] * 20),
    )

# One directed edge per CSV row.  Colour and width follow the per-row Bayesian
# confidence.  The tooltip compares it with the raw input confidence and shows
# the variance of the edge's posterior.
# NOTE(review): `bayesian_confidence` is the posterior right after *that* row,
# while the uncertainty is read from the edge's final state.  A triple that
# appears on several rows is therefore added once per row, with differing values.
edge_rows = zip(
    df['Subject'], df['Predicate'], df['Object'],
    df['confidence'], df['bayesian_confidence'],
)
for subj, pred, obj, original_conf, bayesian_conf in edge_rows:
    uncertainty = bkg.get_edge_uncertainty(subj, pred, obj)
    tooltip = (
        f"{subj} {pred} {obj}\n"
        f"Original: {original_conf:.2f}\n"
        f"Bayesian: {bayesian_conf:.2f}\n"
        f"Uncertainty: {uncertainty:.4f}"
    )
    graph.add_edge(
        subj, obj,
        label=f"{pred}",
        title=tooltip,
        color=color_to_confidence(bayesian_conf),
        width=1 + (bayesian_conf * 4),
    )
# Write the interactive HTML page and open it in the default browser.
output_file = 'MedicalExample/Medical_BKG.html'
graph.write_html(output_file)
# webbrowser.open() takes a URL.  Its docs call opening a bare filename
# unsupported and non-portable, so pass an absolute file:// URI instead.
webbrowser.open(Path(output_file).resolve().as_uri())

# Console summary of the run.
print("\n=== Bayesian Knowledge Graph Statistics ===")
print(f"Total nodes: {len(node_stats)}")
print(f"Total edges: {len(df)}")

print("\nTop 5 most reliable nodes:")
ranked_nodes = sorted(node_stats, key=lambda n: node_stats[n]['reliability'], reverse=True)
for node in ranked_nodes[:5]:
    print(f"  {node}: {node_stats[node]['reliability']:.3f}")

print("\nEdges with largest confidence updates:")
# Absolute gap between the raw input confidence and the per-row posterior.
df['confidence_change'] = (df['bayesian_confidence'] - df['confidence']).abs()
summary_columns = ['Subject', 'Predicate', 'Object', 'confidence', 'bayesian_confidence']
biggest_moves = df.nlargest(5, 'confidence_change')[summary_columns]
print(biggest_moves.to_string(index=False))
    
