#!/usr/bin/env python3
"""
Smoke-test script: load the Hunyuan model from the current directory and run one short chat generation.
"""

import os
import re
from transformers import AutoModelForCausalLM, AutoTokenizer

def _extract_tagged(text: str, tag: str):
    """Return the stripped content of the first ``<tag>...</tag>`` span in *text*.

    Returns ``None`` when no complete (opened *and* closed) span exists, e.g.
    because generation stopped before the closing tag was emitted. A span that
    is present but empty yields ``""``, not ``None``.
    """
    match = re.search(rf"<{re.escape(tag)}>(.*?)</{re.escape(tag)}>", text, re.DOTALL)
    return match.group(1).strip() if match else None


def main(model_path: str = ".", prompt: str = "What is 2+2?", max_new_tokens: int = 100) -> None:
    """Load the model, run one chat turn in thinking mode, and print the result.

    Args:
        model_path: Path passed to ``from_pretrained`` for both tokenizer and
            model. Defaults to the current directory, where the model files live.
        prompt: The single user message sent to the model.
        max_new_tokens: Upper bound on newly generated tokens. Thinking mode
            emits a ``<think>`` section before the ``<answer>``, so a small
            limit can cut the reply off before the answer appears.
    """
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype="auto",
    )

    print("Model loaded successfully!")

    messages = [{"role": "user", "content": prompt}]

    print("\nGenerating response...")
    # return_dict=True returns input_ids *and* attention_mask, so generate()
    # gets an explicit mask instead of having to infer one.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
        enable_thinking=True,  # extra kwarg forwarded to the chat template
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=20,
        top_p=0.8,
        repetition_penalty=1.05,
        temperature=0.7,
    )

    # Decode the whole sequence (prompt + completion) to show the full transcript.
    output_text = tokenizer.decode(outputs[0])
    print("\n" + "=" * 50)
    print("FULL OUTPUT:")
    print(output_text)
    print("=" * 50)

    think_content = _extract_tagged(output_text, "think")
    answer_content = _extract_tagged(output_text, "answer")

    if think_content is not None:
        print(f"\nTHINKING PROCESS:\n{think_content}")

    if answer_content is not None:
        print(f"\nFINAL ANSWER:\n{answer_content}")
        print("\nModel is working correctly! You can now use it interactively.")
        return

    # No complete <answer> block: don't claim success. The usual cause is that
    # generation hit max_new_tokens while the model was still thinking.
    generated = outputs.shape[-1] - prompt_len
    if generated >= max_new_tokens:
        print(
            f"\nNo complete <answer> block found: generation stopped at the "
            f"max_new_tokens limit ({max_new_tokens}). Increase max_new_tokens "
            f"to see the model's final answer."
        )
    else:
        print("\nNo complete <answer> block found in the model output.")


if __name__ == "__main__":
    main()
