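"""Smart Code Assistant.

A small RAG-style coding assistant: HumanEval prompt/solution pairs are embedded
into a ChromaDB collection, a LangGraph workflow classifies each request with a
Groq-hosted LLM and routes it to a generate / explain / test node, and a Gradio
UI exposes the whole pipeline.
"""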
import datetime
import hashlib
import json
import os
import subprocess
import sys
import tempfile
from typing import TypedDict, Optional, List

import chromadb
import gradio as gr
import pandas as pd
from chromadb.config import Settings
from groq import Groq
from langgraph.graph import StateGraph, END
from sentence_transformers import SentenceTransformer

# Groq API key is read from the environment
api_key_coder = os.environ.get('api_key_coder')


class CodeAssistantState(TypedDict):
    user_input: str
    similar_examples: Optional[List[str]]
    generated_code: Optional[str]
    generated_tests: Optional[str]
    test_results: Optional[dict]
    test_error: Optional[str]
    metadata: Optional[dict]
    error: Optional[str]
    task_type: Optional[str]
    evaluation_result: Optional[str]


# HumanEval prompt/solution pairs serve as the retrieval corpus
df = pd.read_parquet("hf://datasets/openai/openai_humaneval/openai_humaneval/test-00000-of-00001.parquet")
extracted_data = df[['task_id', 'prompt', 'canonical_solution']]


# Embedding model, Groq client, and the ChromaDB collection used for retrieval
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
groq_client = Groq(api_key=api_key_coder)

client = chromadb.Client(Settings(
    anonymized_telemetry=False,
    persist_directory="rag_db"
))
collection = client.get_or_create_collection(
    name="code_examples",
    metadata={"hnsw:space": "cosine"}
)


def initialize_db(state: CodeAssistantState):
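    """Embed each HumanEval prompt and store its canonical solution in the ChromaDB collection."""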
    try:
        # Index the corpus only once; skip if the collection is already populated
        if collection.count() > 0:
            return state
        for _, row in extracted_data.iterrows():
            embedding = embedding_model.encode([row['prompt'].strip()])[0]
            doc_id = hashlib.md5(row['prompt'].encode()).hexdigest()
            collection.add(
                documents=[row['canonical_solution'].strip()],
                metadatas=[{"prompt": row['prompt'], "type": "code_example"}],
                ids=[doc_id],
                embeddings=[embedding]
            )
        return state
    except Exception as e:
        state["error"] = f"DB initialization failed: {str(e)}"
        return state


def retrieve_examples(state: CodeAssistantState):
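    """Retrieve the two stored solutions whose prompts are most similar to the user request."""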
    try:
        embedding = embedding_model.encode([state["user_input"]])[0]
        results = collection.query(
            query_embeddings=[embedding],
            n_results=2
        )
        state["similar_examples"] = results['documents'][0] if results['documents'] else None
        return state
    except Exception as e:
        state["error"] = f"Retrieval failed: {str(e)}"
        return state


def classify_task_llm(state: CodeAssistantState) -> CodeAssistantState:
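    """Classify the request as "generate", "explain", or "test" via the LLM, defaulting to "generate" on failure."""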
    if not isinstance(state, dict):
        raise ValueError("State must be a dictionary")

    if "user_input" not in state or not state["user_input"].strip():
        state["error"] = "No user input provided for classification"
        state["task_type"] = "generate"
        return state

    try:
        prompt = f"""You are a helpful code assistant. Classify the user request as one of the following tasks:
- "generate": if the user wants to write or generate code
- "explain": if the user wants to understand what a code snippet does
- "test": if the user wants to test existing code
Return ONLY a JSON object in the format: {{"task": "...", "user_input": "..."}} — no explanation.
User request: {state["user_input"]}
"""
        completion = groq_client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[
                {"role": "system", "content": "Classify code-related user input. Respond with ONLY JSON."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3,
            max_tokens=200,
            response_format={"type": "json_object"}
        )

        content = completion.choices[0].message.content.strip()

        try:
            result = json.loads(content)
            if not isinstance(result, dict):
                raise ValueError("Response is not a JSON object")
        except (json.JSONDecodeError, ValueError) as e:
            state["error"] = f"Invalid response format from LLM: {str(e)}. Content: {content}"
            state["task_type"] = "generate"
            return state

        task_type = result.get("task", "").lower()
        if task_type not in ["generate", "explain", "test"]:
            state["error"] = f"Invalid task type received: {task_type}"
            task_type = "generate"

        state["task_type"] = task_type
        state["user_input"] = result.get("user_input", state["user_input"])
        return state

    except Exception as e:
        state["error"] = f"LLM-based classification failed: {str(e)}"
        state["task_type"] = "generate"
        return state


def test_code(state: CodeAssistantState) -> CodeAssistantState:
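    """Generate unit tests for the supplied code with the LLM, then execute them from a temporary file."""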
    if not isinstance(state, dict):
        raise ValueError("State must be a dictionary")

    if "user_input" not in state or not state["user_input"].strip():
        state["error"] = "Please provide the code you want to test"
        return state

    try:
        messages = [
            {"role": "system", "content": """You are a Python testing expert. Generate unit tests for the provided code.
Return the test code in the following format:
```python
# Test code here
```"""},
            {"role": "user", "content": f"Generate comprehensive unit tests for this Python code:\n\n{state['user_input']}"}
        ]

        completion = groq_client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=messages,
            temperature=0.5,
            max_tokens=2048,
        )

        # Strip a surrounding markdown code fence, if present
        tests = completion.choices[0].message.content
        if tests.startswith('```python'):
            tests = tests[9:-3] if tests.endswith('```') else tests[9:]
        elif tests.startswith('```'):
            tests = tests[3:-3] if tests.endswith('```') else tests[3:]
        tests = tests.strip()

        state["generated_tests"] = tests
        state["metadata"] = {
            "model": "llama-3.3-70b-versatile",
            "timestamp": datetime.datetime.now().isoformat()
        }

        test_file_path = None
        try:
            # Write the user's code and the generated tests into one file so the
            # tested names are defined when the tests run
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as test_file:
                test_file.write(state['user_input'] + "\n\n" + tests + "\n")
                test_file_path = test_file.name

            result = subprocess.run(
                [sys.executable, test_file_path],
                capture_output=True,
                text=True,
                timeout=10
            )

            state["test_results"] = {
                "returncode": result.returncode,
                "stdout": result.stdout,
                "stderr": result.stderr
            }

        except Exception as e:
            state["test_error"] = f"Error executing tests: {str(e)}"
        finally:
            if test_file_path and os.path.exists(test_file_path):
                os.unlink(test_file_path)

        print(f"\nGenerated Tests:\n{tests}\n")
        if "test_results" in state:
            print(f"Test Execution Results:\n{state['test_results']['stdout']}")
            if state["test_results"]["stderr"]:
                print(f"Errors:\n{state['test_results']['stderr']}")

        return state

    except Exception as e:
        state["error"] = f"Error generating tests: {str(e)}"
        return state


def generate_code(state: CodeAssistantState) -> CodeAssistantState:
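    """Generate Python code for the request, using retrieved examples as extra context when available."""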
    if not isinstance(state, dict):
        raise ValueError("State must be a dictionary")

    if "user_input" not in state or not state["user_input"].strip():
        state["error"] = "Please enter your code request"
        return state

    try:
        # Feed the retrieved examples to the model as reference material
        user_content = state["user_input"].strip()
        if state.get("similar_examples"):
            examples = "\n\n".join(state["similar_examples"])
            user_content += f"\n\nHere are similar examples for reference:\n{examples}"

        messages = [
            {"role": "system", "content": "You are a Python coding assistant. Return only clean, production-ready code."},
            {"role": "user", "content": user_content}
        ]

        completion = groq_client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=messages,
            temperature=0.7,
            max_tokens=2048,
        )

        # Strip a surrounding markdown code fence, if present
        code = completion.choices[0].message.content
        if code.startswith('```python'):
            code = code[9:-3] if code.endswith('```') else code[9:]
        elif code.startswith('```'):
            code = code[3:-3] if code.endswith('```') else code[3:]

        state["generated_code"] = code.strip()
        state["metadata"] = {
            "model": "llama-3.3-70b-versatile",
            "timestamp": datetime.datetime.now().isoformat()
        }

        print(f"\nGenerated Code:\n{code.strip()}\n")

        return state

    except Exception as e:
        state["error"] = f"Error generating code: {str(e)}"
        return state


def explain_code(state: CodeAssistantState) -> CodeAssistantState:
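    """Produce a plain-language explanation of the supplied code."""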
    try:
        messages = [
            {"role": "system", "content": "You are a Python expert. Explain what the following code does in plain language."},
            {"role": "user", "content": state["user_input"].strip()}
        ]

        completion = groq_client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=messages,
            temperature=0.5,
            max_tokens=1024
        )

        explanation = completion.choices[0].message.content.strip()
        # Reuse the generated_code field so the UI handler reads a single output slot
        state["generated_code"] = explanation
        state["metadata"] = {
            "model": "llama-3.3-70b-versatile",
            "timestamp": datetime.datetime.now().isoformat()
        }

        print(f"Explanation:\n{explanation}")

        return state

    except Exception as e:
        state["error"] = f"Error explaining code: {str(e)}"
        return state
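

# Assemble the LangGraph workflow:
# initialize_db -> retrieve_examples -> classify_task -> (generate_code | explain_code | test_code) -> END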
workflow = StateGraph(CodeAssistantState)

workflow.add_node("initialize_db", initialize_db)
workflow.add_node("retrieve_examples", retrieve_examples)
workflow.add_node("classify_task", classify_task_llm)
workflow.add_node("generate_code", generate_code)
workflow.add_node("explain_code", explain_code)
workflow.add_node("test_code", test_code)

workflow.set_entry_point("initialize_db")
workflow.add_edge("initialize_db", "retrieve_examples")
workflow.add_edge("retrieve_examples", "classify_task")

workflow.add_conditional_edges(
    "classify_task",
    lambda state: state["task_type"],
    {
        "generate": "generate_code",
        "explain": "explain_code",
        "test": "test_code"
    }
)

workflow.add_edge("generate_code", END)
workflow.add_edge("explain_code", END)
workflow.add_edge("test_code", END)

app_workflow = workflow.compile()


def process_input(user_input: str):
    """Run the workflow for a single Gradio request and format the result."""
    initial_state = {
        "user_input": user_input,
        "similar_examples": None,
        "generated_code": None,
        "error": None,
        "task_type": None
    }

    result = app_workflow.invoke(initial_state)

    if result.get("error"):
        return f"Error: {result['error']}"

    if result["task_type"] == "generate":
        return f"Generated Code:\n\n{result['generated_code']}"
    elif result["task_type"] == "test":
        output = f"Generated Tests:\n\n{result.get('generated_tests', '')}"
        test_results = result.get("test_results")
        if test_results:
            output += f"\n\nTest Execution Results:\n{test_results['stdout']}"
            if test_results["stderr"]:
                output += f"\n\nErrors:\n{test_results['stderr']}"
        return output
    else:
        return f"Code Explanation:\n\n{result['generated_code']}"
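

# Gradio front end: one request text box in, formatted result out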
with gr.Blocks(title="Smart Code Assistant") as demo:
    gr.Markdown("""
    # Smart Code Assistant
    Enter a request to generate new code, explain existing code, or test a code snippet
    """)

    with gr.Row():
        input_text = gr.Textbox(label="Enter your request", placeholder="Example: Write a function to add two numbers... or Explain this code...")
        output_text = gr.Textbox(label="Result", interactive=False)

    submit_btn = gr.Button("Execute")
    submit_btn.click(fn=process_input, inputs=input_text, outputs=output_text)

    gr.Examples(
        examples=[
            ["Write a Python function to add two numbers"],
            ["Explain this code: for i in range(5): print(i)"],
            ["Create a function to convert temperature from Fahrenheit to Celsius"],
            ["test for i in range(3): print('Hello from test', i)"]
        ],
        inputs=input_text
    )


if __name__ == "__main__":
    demo.launch()