
"""
AI Posture Correction Assistant - single-file runner
Features:
- Uses MediaPipe Pose and OpenCV to detect pose landmarks
- Overlays the measured proportions and a squat recommendation on each frame (joint-angle posture checks are not implemented yet)
- Computes torso vs leg length ratio to identify long legs / long torso
- Gives personalized recommendation for squats (e.g., elevate heels) based on proportions
- Works with webcam (default) or a video file passed with --video path
- Saves a short CSV summary with detected proportion class and basic stats

Run on macOS (assumes python3 and packages installed):
pip install opencv-python mediapipe numpy pandas

Usage:
python ai_posture_assistant.py            # use webcam
python ai_posture_assistant.py --video demo.mp4   # use video file
"""

import cv2
import mediapipe as mp
import numpy as np
import argparse
import time
import pandas as pd
from collections import Counter
import textwrap

# Aliases for the MediaPipe "solutions" modules.
# mp_pose supplies the Pose estimator (created in main) and the PoseLandmark enum
# used to index landmark lists. mp_drawing is not referenced anywhere in this file.
mp_drawing = mp.solutions.drawing_utils
mp_pose = mp.solutions.pose

def to_pixel_coords(landmark, image_shape):
    """Convert a landmark with frame-normalized ``x``/``y`` into integer pixel coordinates.

    ``image_shape`` is an array-style frame shape ``(height, width[, channels])``.
    Scaled values are truncated toward zero by ``int``.
    """
    height, width = image_shape[:2]
    px = int(landmark.x * width)
    py = int(landmark.y * height)
    return px, py

def distance(a, b):
    """Return the Euclidean distance between two points given as equal-length sequences."""
    delta = np.subtract(a, b)
    return np.linalg.norm(delta)

def classify_body_proportions(landmarks, image_shape, threshold=1.25, min_visibility=0.5):
    """Classify leg-to-torso proportions from one frame of pose landmarks.

    Torso length is the mean shoulder-to-hip distance of the two body sides.
    Leg length is the mean of (hip->knee + knee->ankle) per side, i.e. thigh plus
    shin, so the measurement does not shrink when the knees bend (a straight
    hip-to-ankle distance roughly halves at the bottom of a squat).
    All distances are measured in pixels so the frame's aspect ratio is respected.

    Args:
        landmarks: sequence of pose landmarks indexable by ``mp_pose.PoseLandmark``;
            each item needs frame-normalized ``x``/``y`` and may carry ``visibility``
            (treated as 1.0 when absent).
        image_shape: frame shape ``(height, width[, channels])``.
        threshold: ``leg/torso`` above this -> "long_legs"; ``torso/leg`` above
            this -> "long_torso". Intended to be >= 1.
        min_visibility: if any required landmark (shoulders, hips, knees, ankles)
            has ``visibility`` below this value, the result is "unknown".
            Pass 0.0 to disable the check.

    Returns:
        ``(classification, torso_px, leg_px, ratio)``. ``classification`` is one of
        "long_legs", "long_torso", "balanced" or "unknown". ``ratio`` is
        ``leg_px / torso_px``, or the placeholder 1.0 when the classification
        is "unknown".
    """
    pl = mp_pose.PoseLandmark
    sides = (
        (pl.LEFT_SHOULDER, pl.LEFT_HIP, pl.LEFT_KNEE, pl.LEFT_ANKLE),
        (pl.RIGHT_SHOULDER, pl.RIGHT_HIP, pl.RIGHT_KNEE, pl.RIGHT_ANKLE),
    )

    torso_lengths = []
    leg_lengths = []
    all_visible = True
    for joints in sides:
        # Landmarks are still reported for joints that are occluded or outside
        # the frame, only with a low visibility score; their coordinates are
        # guesses and would produce meaningless ratios.
        if any(getattr(landmarks[j], "visibility", 1.0) < min_visibility for j in joints):
            all_visible = False
        shoulder, hip, knee, ankle = (to_pixel_coords(landmarks[j], image_shape) for j in joints)
        torso_lengths.append(distance(shoulder, hip))
        # Thigh + shin instead of hip->ankle so knee flexion does not shorten the leg.
        leg_lengths.append(distance(hip, knee) + distance(knee, ankle))

    torso = float(np.mean(torso_lengths))
    leg = float(np.mean(leg_lengths))

    # Unreliable landmarks or degenerate geometry: no meaningful ratio exists
    # (a zero leg length would also make torso/leg divide by zero).
    if not all_visible or torso < 1e-3 or leg < 1e-3:
        return "unknown", torso, leg, 1.0

    ratio = leg / torso
    if ratio > threshold:
        return "long_legs", torso, leg, ratio
    # Same as torso / leg > threshold, written without dividing by ratio.
    if ratio < 1.0 / threshold:
        return "long_torso", torso, leg, ratio
    return "balanced", torso, leg, ratio

def get_recommendation_for_squat(classification):
    """Return squat coaching advice for a proportion class.

    Recognised classes are "long_legs", "long_torso" and "balanced"; any other
    value yields a prompt to make the whole body visible to the camera.
    """
    advice_by_class = (
        ("long_legs", "You have proportionally long legs. Try elevating your heels on a small plate or wedge to improve squat depth."),
        ("long_torso", "You have a relatively long torso. Keep your chest more upright and consider a slightly wider stance."),
        ("balanced", "Body proportions are balanced. Standard squat cues apply."),
    )
    fallback = "Proportions unclear. Ensure whole body is visible to camera."
    return next(
        (advice for label, advice in advice_by_class if classification == label),
        fallback,
    )

def _draw_proportion_overlay(frame, cls, torso, leg, ratio):
    """Draw the proportion class, segment lengths and squat advice onto *frame* in place."""
    rec = get_recommendation_for_squat(cls)
    cv2.putText(frame, f"Proportions: {cls.replace('_',' ')} ({ratio:.2f})", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
    cv2.putText(frame, f"Torso: {int(torso)} px  Leg: {int(leg)} px", (10, 60),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)
    # cv2.putText does not wrap text, so the recommendation is split manually.
    y0 = 90
    for i, line in enumerate(textwrap.wrap(rec, width=60)):
        cv2.putText(frame, line, (10, y0 + i * 25), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 255), 2)


def main(video_source=0, output_csv="/tmp/posture_summary.csv", max_frames=None):
    """Run the posture assistant on a video source and save a one-row CSV summary.

    Each frame is passed to MediaPipe Pose. When a pose is found, body proportions
    are classified and the result is drawn on the frame, which is shown in an
    OpenCV window. Press ``q`` to stop early.

    Args:
        video_source: webcam index (int) or video file path, passed to
            ``cv2.VideoCapture``.
        output_csv: path of the summary CSV written at the end of the run.
        max_frames: stop after this many frames; ``None`` (or 0) means no limit.

    The summary records the frames processed, wall time, estimated processing FPS,
    the most frequent class, per-class counts, and the mean/std of the leg/torso
    ratio over frames whose class is not "unknown" (``None`` if there were none).
    """
    cap = cv2.VideoCapture(video_source)
    if not cap.isOpened():
        print("Error opening video source:", video_source)
        return

    reported_fps = cap.get(cv2.CAP_PROP_FPS)
    fps = reported_fps if reported_fps > 0 else 30
    print("FPS:", fps)

    frame_count = 0
    class_counts = Counter()
    ratio_samples = []

    start_time = time.time()
    try:
        # The context manager closes the Pose graph even if the loop raises.
        with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frame_count += 1
                # MediaPipe expects RGB while OpenCV delivers BGR. Drawing happens on
                # the original BGR frame, so no conversion back is needed.
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                results = pose.process(rgb)

                if results.pose_landmarks:
                    cls, torso, leg, ratio = classify_body_proportions(
                        results.pose_landmarks.landmark, frame.shape)
                    class_counts[cls] += 1
                    # "unknown" carries a placeholder ratio of 1.0; keep it out of
                    # the statistics so it does not bias the mean/std.
                    if cls != "unknown":
                        ratio_samples.append(ratio)
                    _draw_proportion_overlay(frame, cls, torso, leg, ratio)
                else:
                    cv2.putText(frame, "No full pose detected. Please show entire body.", (10, 30),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

                cv2.imshow("AI Posture Correction Assistant", frame)
                key = cv2.waitKey(1) & 0xFF
                if key == ord('q'):
                    break
                if max_frames and frame_count >= max_frames:
                    break
            elapsed = time.time() - start_time
    finally:
        # Always release the camera/file handle and close the preview window,
        # even on Ctrl+C or an exception from MediaPipe/OpenCV.
        cap.release()
        cv2.destroyAllWindows()

    summary = {
        "frames_processed": frame_count,
        "time_sec": elapsed,
        "fps_est": frame_count / elapsed if elapsed > 0 else 0,
        "most_common_class": class_counts.most_common(1)[0][0] if class_counts else "none",
        # Plain dict so the CSV cell reads "{...}" rather than "Counter({...})".
        "class_counts": dict(class_counts),
        "ratio_mean": float(np.mean(ratio_samples)) if ratio_samples else None,
        "ratio_std": float(np.std(ratio_samples)) if ratio_samples else None,
    }
    try:
        df = pd.DataFrame([summary])
        df.to_csv(output_csv, index=False)
        print("Saved summary to", output_csv)
    except Exception as e:
        # Best effort: a failed write should not hide the run's console output.
        print("Could not save CSV:", e)
    print("Done. Processed frames:", frame_count)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="AI Posture Correction Assistant: classify body proportions from a webcam or video file."
    )
    parser.add_argument("--video", help="Path to video file. If omitted, webcam is used.", default=None)
    parser.add_argument("--camera", help="Webcam index used when --video is omitted (default: 0).",
                        type=int, default=0)
    parser.add_argument("--out", help="CSV output path", default="/tmp/posture_summary.csv")
    parser.add_argument("--max-frames", help="Stop after N frames", type=int, default=None)
    args = parser.parse_args()
    # An empty --video value also falls back to the webcam.
    source = args.video if args.video else args.camera
    main(video_source=source, output_csv=args.out, max_frames=args.max_frames)
