import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from config import get_config

config = get_config()

class RecipeGenerator:
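    """GPT-2 based recipe text generator.

    Loads the tokenizer and model named in the shared config once, then exposes
    single and batch generation helpers. If loading fails, the instance stays
    usable but reports itself as unavailable via `is_available()`.
    """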
    
    def __init__(self):
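        """Load the pretrained tokenizer and model onto the configured device."""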
        self.device = config.DEVICE
        self.model_name = config.MODEL_NAME
        
        try:
            self.tokenizer = GPT2Tokenizer.from_pretrained(self.model_name)
            self.model = GPT2LMHeadModel.from_pretrained(self.model_name)
            self.model.to(self.device)
            self.model.eval()
            print(f"Recipe model loaded on {self.device}")
        except Exception as e:
            print(f"Error loading model: {e}")
            self.model = None
            self.tokenizer = None
    
    def generate_recipe(self, ingredients, meal_type="", num_words=150):
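        """Generate a recipe from a list of ingredient strings.

        `meal_type`, if given, is appended to the prompt; `num_words` caps the
        number of newly sampled tokens (tokens, not literal words). Returns the
        decoded text, or an error message string if generation fails.
        """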
        
        if self.model is None or self.tokenizer is None:
            return "Model not available. Please check configuration."
        
        try:
            ingredients_str = ", ".join(ingredients)
            prompt = f"Recipe with {ingredients_str}"
            
            if meal_type:
                prompt += f" for {meal_type}"
            
            prompt += ":\n"
            input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)

            with torch.no_grad():
                output_ids = self.model.generate(
                    input_ids,
                    max_new_tokens=num_words,
                    num_return_sequences=1,
                    temperature=config.TEMPERATURE,
                    top_p=config.TOP_P,
                    do_sample=True,
                    no_repeat_ngram_size=2,
                    pad_token_id=self.tokenizer.eos_token_id,
                    attention_mask=torch.ones_like(input_ids)
                )
            
            # Decode the full sequence; the prompt text is included in the result
            recipe = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
            
            return recipe
        
        except Exception as e:
            print(f"Error generating recipe: {e}")
            return f"Error: Could not generate recipe. {str(e)}"
    
    def generate_batch(self, ingredients_list, meal_type="", num_words=150):
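        """Generate one recipe per ingredient list, sequentially."""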
        recipes = []
        for ingredients in ingredients_list:
            recipe = self.generate_recipe(ingredients, meal_type, num_words)
            recipes.append(recipe)
        
        return recipes
    
    def is_available(self):
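        """Return True if both the model and tokenizer loaded successfully."""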
        return self.model is not None and self.tokenizer is not None


# Global instance
recipe_generator = None

def get_recipe_generator():
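    """Return the shared RecipeGenerator, creating it lazily on first use."""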
    global recipe_generator
    if recipe_generator is None:
        recipe_generator = RecipeGenerator()
    return recipe_generator
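

# Illustrative usage sketch (assumes config.py provides DEVICE, MODEL_NAME,
# TEMPERATURE, and TOP_P as referenced above); the ingredient list below is
# just an example.
if __name__ == "__main__":
    generator = get_recipe_generator()
    if generator.is_available():
        print(generator.generate_recipe(["chicken", "rice", "broccoli"], meal_type="dinner"))
    else:
        print("Recipe generator unavailable; check MODEL_NAME and DEVICE in config.py")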