| import gradio as gr |
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer, AutoModel |
| import pandas as pd |
| import sys |
| import os |
| import shutil |
| from pathlib import Path |
| import chromadb |
| from chromadb.config import Settings |
| import uuid |
| import tempfile |
|
|
| |
# Put the directory two levels above this file's folder on the import path so
# the project-local ``scripts`` package below can be imported. This must run
# before the ``from scripts...`` imports.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from scripts.core.ingestion.ingest import GitCrawler
from scripts.core.ingestion.chunk import RepoChunker
|
|
| |
# Hugging Face model identifiers of the two encoders being compared.
BASELINE_MODEL = "microsoft/codebert-base"
FINETUNED_MODEL = "shubharuidas/codebert-base-code-embed-mrl-langchain-langgraph"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Directory of the persistent Chroma store (resolved against the current
# working directory); created on import if it does not exist yet.
DB_DIR = Path(os.path.abspath("data/chroma_db_comparison"))
os.makedirs(DB_DIR, exist_ok=True)
|
|
def _load_encoder(model_name):
    """Load tokenizer and model for ``model_name``; the model is moved to DEVICE and put in eval mode."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.to(DEVICE)
    model.eval()
    return tokenizer, model


# Both encoders are loaded once at import time and shared by every handler.
print(f"Loading models on {DEVICE}...")
print("1. Loading baseline model...")
baseline_tokenizer, baseline_model = _load_encoder(BASELINE_MODEL)

print("2. Loading fine-tuned model...")
finetuned_tokenizer, finetuned_model = _load_encoder(FINETUNED_MODEL)
print("Both models loaded!")
|
|
| |
# One persistent Chroma client backs two collections (one per model); both are
# configured for cosine distance.
chroma_client = chromadb.PersistentClient(path=str(DB_DIR))
baseline_collection = chroma_client.get_or_create_collection(
    name="baseline_rag",
    metadata={"hnsw:space": "cosine"},
)
finetuned_collection = chroma_client.get_or_create_collection(
    name="finetuned_rag",
    metadata={"hnsw:space": "cosine"},
)
|
|
| |
def compute_baseline_embeddings(text_list):
    """Embed a batch of texts with the baseline model.

    Args:
        text_list: Sequence of strings to embed (truncated to 512 tokens each).

    Returns:
        A 2-D tensor with one L2-normalised row per input text, or ``None``
        when ``text_list`` is empty.
    """
    if not text_list:
        return None
    inputs = baseline_tokenizer(
        text_list,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(DEVICE)
    with torch.no_grad():
        hidden = baseline_model(**inputs).last_hidden_state
        # Mean-pool over real tokens only. A plain mean over dim=1 also
        # averages the padding positions, which makes a text's embedding
        # depend on the longest text in its batch (documents are embedded in
        # batches, queries one at a time). For a single un-padded text the
        # result is identical to the plain mean.
        # NOTE: vectors indexed before this change used the padded mean;
        # re-index for fully consistent scores.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)
        emb = (hidden * mask).sum(dim=1) / token_counts
        return F.normalize(emb, p=2, dim=1)
|
|
def compute_finetuned_embeddings(text_list):
    """Embed a batch of texts with the fine-tuned model.

    Args:
        text_list: Sequence of strings to embed (truncated to 512 tokens each).

    Returns:
        A 2-D tensor with one L2-normalised row per input text, or ``None``
        when ``text_list`` is empty.
    """
    if not text_list:
        return None
    inputs = finetuned_tokenizer(
        text_list,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(DEVICE)
    with torch.no_grad():
        hidden = finetuned_model(**inputs).last_hidden_state
        # Mean-pool over real tokens only. A plain mean over dim=1 also
        # averages the padding positions, which makes a text's embedding
        # depend on the longest text in its batch (documents are embedded in
        # batches, queries one at a time). For a single un-padded text the
        # result is identical to the plain mean.
        # NOTE: vectors indexed before this change used the padded mean;
        # re-index for fully consistent scores.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)
        emb = (hidden * mask).sum(dim=1) / token_counts
        return F.normalize(emb, p=2, dim=1)
|
|
| |
def reset_baseline():
    """Delete the baseline collection, recreate it empty, and return a status message."""
    global baseline_collection
    collection_name = "baseline_rag"
    chroma_client.delete_collection(collection_name)
    # Rebind the module-level handle so later calls use the fresh collection.
    baseline_collection = chroma_client.get_or_create_collection(
        name=collection_name,
        metadata={"hnsw:space": "cosine"},
    )
    return "Baseline database reset."
|
|
def reset_finetuned():
    """Delete the fine-tuned collection, recreate it empty, and return a status message."""
    global finetuned_collection
    collection_name = "finetuned_rag"
    chroma_client.delete_collection(collection_name)
    # Rebind the module-level handle so later calls use the fresh collection.
    finetuned_collection = chroma_client.get_or_create_collection(
        name=collection_name,
        metadata={"hnsw:space": "cosine"},
    )
    return "Fine-tuned database reset."
|
|
| |
def list_baseline_files():
    """Summarise the baseline collection as table rows.

    Returns:
        A list of ``[file name, chunk count, source url]`` rows ordered by
        chunk count (largest first). At most the first 1000 records are
        inspected. A single placeholder row is returned when the collection is
        empty or when reading it fails.
    """
    total = baseline_collection.count()
    if total == 0:
        return [["No data indexed yet", "-", "-"]]

    try:
        fetched = baseline_collection.get(limit=min(total, 1000), include=["metadatas"])
        per_file = {}
        for record in fetched['metadatas']:
            name = record.get("file_name", "unknown")
            # The URL of the first record seen for a file is the one reported.
            entry = per_file.setdefault(name, {"count": 0, "url": record.get("url", "unknown")})
            entry["count"] += 1

        rows = []
        for name, entry in per_file.items():
            rows.append([name, entry["count"], entry["url"]])
        rows.sort(key=lambda row: row[1], reverse=True)
        return rows
    except Exception as e:
        return [[f"Error: {str(e)}", "-", "-"]]
|
|
def list_finetuned_files():
    """Summarise the fine-tuned collection as table rows.

    Returns:
        A list of ``[file name, chunk count, source url]`` rows ordered by
        chunk count (largest first). At most the first 1000 records are
        inspected. A single placeholder row is returned when the collection is
        empty or when reading it fails.
    """
    total = finetuned_collection.count()
    if total == 0:
        return [["No data indexed yet", "-", "-"]]

    try:
        fetched = finetuned_collection.get(limit=min(total, 1000), include=["metadatas"])
        per_file = {}
        for record in fetched['metadatas']:
            name = record.get("file_name", "unknown")
            # The URL of the first record seen for a file is the one reported.
            entry = per_file.setdefault(name, {"count": 0, "url": record.get("url", "unknown")})
            entry["count"] += 1

        rows = []
        for name, entry in per_file.items():
            rows.append([name, entry["count"], entry["url"]])
        rows.sort(key=lambda row: row[1], reverse=True)
        return rows
    except Exception as e:
        return [[f"Error: {str(e)}", "-", "-"]]
|
|
| |
def get_files_list_baseline():
    """Return the sorted, de-duplicated file names stored in the baseline collection.

    Any failure while reading the collection is swallowed and reported as an
    empty list.
    """
    try:
        metadatas = baseline_collection.get(include=["metadatas"])['metadatas']
        if not metadatas:
            return []
        unique_names = {entry.get("file_name", "unknown") for entry in metadatas}
        return sorted(unique_names)
    except Exception:
        return []
|
|
def get_files_list_finetuned():
    """Return the sorted, de-duplicated file names stored in the fine-tuned collection.

    Any failure while reading the collection is swallowed and reported as an
    empty list.
    """
    try:
        metadatas = finetuned_collection.get(include=["metadatas"])['metadatas']
        if not metadatas:
            return []
        unique_names = {entry.get("file_name", "unknown") for entry in metadatas}
        return sorted(unique_names)
    except Exception:
        return []
|
|
def get_chunks_for_file_baseline(file_name):
    """Get all chunks for a specific file from baseline collection.

    Args:
        file_name: Value of the ``file_name`` metadata field to look up.

    Returns:
        ``{"file_name", "total_chunks", "chunks"}`` where each chunk carries a
        1-based ``chunk_id``, a preview of at most 500 characters, the full
        document length, its metadata and the embedding dimension; or
        ``{"error": message}`` when nothing is selected, nothing matches, or
        the lookup fails.
    """
    if not file_name:
        return {"error": "No file selected"}

    try:
        if baseline_collection.count() == 0:
            return {"error": "No chunks found"}

        # Let Chroma filter by metadata instead of loading every document and
        # embedding in the collection and filtering in Python.
        data = baseline_collection.get(
            where={"file_name": file_name},
            include=["documents", "metadatas", "embeddings"],
        )

        chunks = []
        for doc, meta, emb in zip(data['documents'], data['metadatas'], data['embeddings']):
            preview = doc[:500] + "..." if len(doc) > 500 else doc
            chunks.append({
                "chunk_id": len(chunks) + 1,
                "content": preview,
                "full_length": len(doc),
                "metadata": meta,
                "embedding_dim": len(emb) if emb is not None else 0,
            })

        if not chunks:
            return {"error": "No chunks found for this file"}

        return {
            "file_name": file_name,
            "total_chunks": len(chunks),
            "chunks": chunks,
        }
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"ERROR in get_chunks_for_file_baseline: {error_details}")
        return {"error": str(e)}
|
|
def get_chunks_for_file_finetuned(file_name):
    """Get all chunks for a specific file from fine-tuned collection.

    Args:
        file_name: Value of the ``file_name`` metadata field to look up.

    Returns:
        ``{"file_name", "total_chunks", "chunks"}`` where each chunk carries a
        1-based ``chunk_id``, a preview of at most 500 characters, the full
        document length, its metadata and the embedding dimension; or
        ``{"error": message}`` when nothing is selected, nothing matches, or
        the lookup fails.
    """
    if not file_name:
        return {"error": "No file selected"}

    try:
        if finetuned_collection.count() == 0:
            return {"error": "No chunks found"}

        # Let Chroma filter by metadata instead of loading every document and
        # embedding in the collection and filtering in Python.
        data = finetuned_collection.get(
            where={"file_name": file_name},
            include=["documents", "metadatas", "embeddings"],
        )

        chunks = []
        for doc, meta, emb in zip(data['documents'], data['metadatas'], data['embeddings']):
            preview = doc[:500] + "..." if len(doc) > 500 else doc
            chunks.append({
                "chunk_id": len(chunks) + 1,
                "content": preview,
                "full_length": len(doc),
                "metadata": meta,
                "embedding_dim": len(emb) if emb is not None else 0,
            })

        if not chunks:
            return {"error": "No chunks found for this file"}

        return {
            "file_name": file_name,
            "total_chunks": len(chunks),
            "chunks": chunks,
        }
    except Exception as e:
        # Log the traceback like the baseline variant does, so failures are
        # diagnosable from the server console.
        import traceback
        error_details = traceback.format_exc()
        print(f"ERROR in get_chunks_for_file_finetuned: {error_details}")
        return {"error": str(e)}
|
|
def download_chunks_baseline(file_name):
    """Export chunks to JSON file for baseline.

    Args:
        file_name: File whose chunks should be exported.

    Returns:
        Path of a newly written temporary ``.json`` file holding whatever
        ``get_chunks_for_file_baseline`` returned (including its error dict),
        or ``None`` when no file name is given.
    """
    if not file_name:
        return None

    import json
    import tempfile

    chunks_data = get_chunks_for_file_baseline(file_name)

    # The context manager closes the handle even if serialisation raises;
    # delete=False keeps the file on disk so its path can be returned.
    with tempfile.NamedTemporaryFile(
        mode='w', delete=False, suffix='.json', encoding='utf-8'
    ) as temp_file:
        json.dump(chunks_data, temp_file, indent=2)

    return temp_file.name
|
|
def download_chunks_finetuned(file_name):
    """Export chunks to JSON file for fine-tuned.

    Args:
        file_name: File whose chunks should be exported.

    Returns:
        Path of a newly written temporary ``.json`` file holding whatever
        ``get_chunks_for_file_finetuned`` returned (including its error dict),
        or ``None`` when no file name is given.
    """
    if not file_name:
        return None

    import json
    import tempfile

    chunks_data = get_chunks_for_file_finetuned(file_name)

    # The context manager closes the handle even if serialisation raises;
    # delete=False keeps the file on disk so its path can be returned.
    with tempfile.NamedTemporaryFile(
        mode='w', delete=False, suffix='.json', encoding='utf-8'
    ) as temp_file:
        json.dump(chunks_data, temp_file, indent=2)

    return temp_file.name
|
|
| |
def search_baseline(query, top_k=5):
    """Query the baseline collection with the baseline model's embedding of ``query``.

    Args:
        query: Text or code to search for.
        top_k: Maximum number of hits; coerced to an int and clamped to the
            range ``1..collection size``.

    Returns:
        A list of ``[file name, score, snippet]`` rows, where score is
        ``1 - distance`` formatted to four decimals and the snippet is at most
        300 characters of the chunk. Empty when the collection is empty.
    """
    total = baseline_collection.count()
    if total == 0:
        return []
    query_emb = compute_baseline_embeddings([query])
    if query_emb is None:
        return []
    query_vec = query_emb.cpu().numpy().tolist()[0]
    # UI sliders may hand over a float; Chroma needs a positive int that does
    # not exceed the number of stored items.
    n_results = max(1, min(int(top_k), total))
    results = baseline_collection.query(
        query_embeddings=[query_vec],
        n_results=n_results,
        include=["metadatas", "documents", "distances"],
    )
    output = []
    if results['ids']:
        hits = zip(results['metadatas'][0], results['documents'][0], results['distances'][0])
        for meta, code, dist in hits:
            score = 1 - dist
            # Only mark the snippet as truncated when it actually was.
            snippet = code[:300] + "..." if len(code) > 300 else code
            output.append([meta.get("file_name", "unknown"), f"{score:.4f}", snippet])
    return output
|
|
def search_finetuned(query, top_k=5):
    """Query the fine-tuned collection with the fine-tuned model's embedding of ``query``.

    Args:
        query: Text or code to search for.
        top_k: Maximum number of hits; coerced to an int and clamped to the
            range ``1..collection size``.

    Returns:
        A list of ``[file name, score, snippet]`` rows, where score is
        ``1 - distance`` formatted to four decimals and the snippet is at most
        300 characters of the chunk. Empty when the collection is empty.
    """
    total = finetuned_collection.count()
    if total == 0:
        return []
    query_emb = compute_finetuned_embeddings([query])
    if query_emb is None:
        return []
    query_vec = query_emb.cpu().numpy().tolist()[0]
    # UI sliders may hand over a float; Chroma needs a positive int that does
    # not exceed the number of stored items.
    n_results = max(1, min(int(top_k), total))
    results = finetuned_collection.query(
        query_embeddings=[query_vec],
        n_results=n_results,
        include=["metadatas", "documents", "distances"],
    )
    output = []
    if results['ids']:
        hits = zip(results['metadatas'][0], results['documents'][0], results['distances'][0])
        for meta, code, dist in hits:
            score = 1 - dist
            # Only mark the snippet as truncated when it actually was.
            snippet = code[:300] + "..." if len(code) > 300 else code
            output.append([meta.get("file_name", "unknown"), f"{score:.4f}", snippet])
    return output
|
|
def search_comparison(query, top_k=5):
    """Run one query against both collections; returns ``(baseline_rows, finetuned_rows)``."""
    return search_baseline(query, top_k), search_finetuned(query, top_k)
|
|
| |
def ingest_from_url(repo_url):
    """Clone a repository, chunk its files and index the chunks in both collections.

    This is a generator: it yields human-readable status strings as the work
    progresses, ending with a success or error message.

    Args:
        repo_url: ``http://`` or ``https://`` URL of the repository to clone.
            ``None``, empty, or any other scheme yields ``"Invalid URL"``.
    """
    import stat

    # Tolerate None / surrounding whitespace and require a real http(s) scheme
    # (a bare startswith("http") would also accept e.g. "httpfoo").
    repo_url = (repo_url or "").strip()
    if not repo_url.startswith(("http://", "https://")):
        yield "Invalid URL"
        return

    DATA_DIR = Path(os.path.abspath("data/raw_ingest"))

    def remove_readonly(func, path, _):
        # rmtree error hook: make the entry writable, then retry the operation.
        os.chmod(path, stat.S_IWRITE)
        func(path)

    try:
        # Start from a clean checkout directory on every ingest.
        if DATA_DIR.exists():
            shutil.rmtree(DATA_DIR, onerror=remove_readonly)

        yield f"Cloning {repo_url}..."
        crawler = GitCrawler(cache_dir=DATA_DIR)
        repo_path = crawler.clone_repository(repo_url)
        if not repo_path:
            yield "Failed to clone repository."
            return

        yield "Listing files..."
        files = crawler.list_files(repo_path, extensions={'.py', '.md', '.json', '.js', '.ts', '.java', '.cpp'})
        # list_files may return a tuple whose first element holds objects with
        # a ``.path`` attribute; reduce that form to a plain list of paths.
        if isinstance(files, tuple):
            files = [f.path for f in files[0]]

        total_files = len(files)
        yield f"Found {total_files} files. Chunking..."

        chunker = RepoChunker()
        all_chunks = []

        for i, file_path in enumerate(files):
            yield f"Chunking: {i+1}/{total_files} ({file_path.name})"
            try:
                meta = {"file_name": file_path.name, "url": repo_url}
                file_chunks = chunker.chunk_file(file_path, repo_metadata=meta)
                all_chunks.extend(file_chunks)
            except Exception as e:
                # Best effort: one unreadable file must not abort the ingest.
                print(f"Skipping {file_path}: {e}")

        if not all_chunks:
            yield "No valid chunks found."
            return

        total_chunks = len(all_chunks)
        batch_size = 64

        # Same batching loop for both models: (progress label, banner,
        # embedding function, target collection).
        stages = (
            (
                "Baseline",
                f"Generated {total_chunks} chunks. Embedding (BASELINE)...",
                compute_baseline_embeddings,
                baseline_collection,
            ),
            (
                "Fine-tuned",
                f"Embedding (FINE-TUNED)...",
                compute_finetuned_embeddings,
                finetuned_collection,
            ),
        )
        for label, banner, embed, collection in stages:
            yield banner
            for start in range(0, total_chunks, batch_size):
                batch = all_chunks[start:start + batch_size]
                texts = [c.code for c in batch]
                ids = [str(uuid.uuid4()) for _ in batch]
                metadatas = [{"file_name": Path(c.file_path).name, "url": repo_url} for c in batch]

                embeddings = embed(texts)
                if embeddings is not None:
                    collection.add(
                        ids=ids,
                        embeddings=embeddings.cpu().numpy().tolist(),
                        metadatas=metadatas,
                        documents=texts,
                    )
                yield f"{label}: {min(start + batch_size, total_chunks)}/{total_chunks}"

        yield f"SUCCESS! Indexed {total_chunks} chunks in both databases."
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {str(e)}"
|
|
def ingest_from_files(files):
    """Chunk uploaded files and index the chunks in both collections.

    This is a generator: it yields human-readable status strings as the work
    progresses, ending with a success or error message.

    Args:
        files: Uploaded files; each item is either an object with a ``name``
            attribute holding its path, or the path itself. ``None`` or an
            empty list yields ``"No files uploaded."``.
    """
    if not files:
        yield "No files uploaded."
        return

    try:
        yield f"Processing {len(files)} file(s)..."

        chunker = RepoChunker()
        all_chunks = []

        for i, file in enumerate(files):
            # Accept both upload wrappers exposing ``.name`` and plain paths.
            file_path = Path(getattr(file, "name", file))
            yield f"Chunking file {i+1}/{len(files)}: {file_path.name}"
            try:
                meta = {"file_name": file_path.name, "url": "uploaded"}
                file_chunks = chunker.chunk_file(file_path, repo_metadata=meta)
                all_chunks.extend(file_chunks)
            except Exception as e:
                # Report the failure but keep going with the remaining files.
                yield f"Error chunking {file_path.name}: {str(e)}"
                import traceback
                traceback.print_exc()

        if not all_chunks:
            yield "No valid chunks found."
            return

        total_chunks = len(all_chunks)
        batch_size = 64

        # Same batching loop for both models: (progress label, banner,
        # embedding function, target collection).
        stages = (
            (
                "Baseline",
                f"Generated {total_chunks} chunks. Embedding (BASELINE)...",
                compute_baseline_embeddings,
                baseline_collection,
            ),
            (
                "Fine-tuned",
                f"Embedding (FINE-TUNED)...",
                compute_finetuned_embeddings,
                finetuned_collection,
            ),
        )
        for label, banner, embed, collection in stages:
            yield banner
            for start in range(0, total_chunks, batch_size):
                batch = all_chunks[start:start + batch_size]
                texts = [c.code for c in batch]
                ids = [str(uuid.uuid4()) for _ in batch]
                metadatas = [{"file_name": Path(c.file_path).name, "url": "uploaded"} for c in batch]

                embeddings = embed(texts)
                if embeddings is not None:
                    collection.add(
                        ids=ids,
                        embeddings=embeddings.cpu().numpy().tolist(),
                        metadatas=metadatas,
                        documents=texts,
                    )
                yield f"{label}: {min(start + batch_size, total_chunks)}/{total_chunks}"

        yield f"SUCCESS! Indexed {total_chunks} chunks from uploaded files."
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {str(e)}"
|
|
| |
def analyze_embeddings_baseline():
    """Compute diversity metrics and a 2-D PCA scatter plot for the baseline collection.

    Returns:
        ``(metrics_text, image)`` where ``image`` is a PIL image of the plot,
        or ``(message, None)`` when fewer than 5 chunks are stored or the
        analysis fails. At most 2000 stored embeddings are analysed.
    """
    count = baseline_collection.count()
    if count < 5:
        return "Not enough data (Need > 5 chunks).", None

    try:
        import io
        import matplotlib.pyplot as plt
        from PIL import Image

        limit = min(count, 2000)
        data = baseline_collection.get(limit=limit, include=["embeddings", "metadatas"])

        X = torch.tensor(data['embeddings'], dtype=torch.float32)
        X_centered = X - torch.mean(X, 0)
        _, _, V = torch.pca_lowrank(X_centered, q=2)
        projected = torch.matmul(X_centered, V[:, :2]).numpy()

        # Sample WITHOUT replacement: drawing the same row twice would add an
        # off-diagonal similarity of 1.0 and inflate the average.
        sample_size = min(100, len(X))
        sample = X[torch.randperm(len(X))[:sample_size]]
        # Dot product is used as the similarity; this equals cosine similarity
        # when the stored vectors are unit length (presumably they are, as the
        # embedding functions in this file L2-normalise their output).
        sim_matrix = torch.mm(sample, sample.t())
        off_diagonal = ~torch.eye(sample_size, dtype=torch.bool)
        avg_sim = sim_matrix[off_diagonal].mean().item()
        diversity_score = 1.0 - avg_sim

        metrics = (
            f"BASELINE MODEL\n"
            f"Total Chunks: {count}\n"
            f"Analyzed: {len(X)}\n"
            f"Diversity Score: {diversity_score:.4f}\n"
            f"Avg Similarity: {avg_sim:.4f}"
        )

        plot_df = pd.DataFrame({
            "x": projected[:, 0],
            "y": projected[:, 1],
            "topic": [m.get("file_name", "unknown") for m in data['metadatas']],
        })

        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            fig.subplots_adjust(top=0.92)

            # One scatter series (and legend entry) per source file.
            for topic in plot_df["topic"].unique():
                rows = plot_df[plot_df["topic"] == topic]
                ax.scatter(rows["x"], rows["y"], label=topic, alpha=0.6, s=50)

            ax.set_xlabel("PC1")
            ax.set_ylabel("PC2")
            ax.set_title("Baseline Semantic Space (PCA)", fontsize=14, pad=20)
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
            ax.grid(True, alpha=0.3)
            fig.tight_layout()

            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        finally:
            # Always release this specific figure, even if plotting failed.
            plt.close(fig)

        buf.seek(0)
        img = Image.open(buf)

        return metrics, img
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error: {e}", None
|
|
def analyze_embeddings_finetuned():
    """Compute diversity metrics and a 2-D PCA scatter plot for the fine-tuned collection.

    Returns:
        ``(metrics_text, image)`` where ``image`` is a PIL image of the plot,
        or ``(message, None)`` when fewer than 5 chunks are stored or the
        analysis fails. At most 2000 stored embeddings are analysed.
    """
    count = finetuned_collection.count()
    if count < 5:
        return "Not enough data (Need > 5 chunks).", None

    try:
        import io
        import matplotlib.pyplot as plt
        from PIL import Image

        limit = min(count, 2000)
        data = finetuned_collection.get(limit=limit, include=["embeddings", "metadatas"])

        X = torch.tensor(data['embeddings'], dtype=torch.float32)
        X_centered = X - torch.mean(X, 0)
        _, _, V = torch.pca_lowrank(X_centered, q=2)
        projected = torch.matmul(X_centered, V[:, :2]).numpy()

        # Sample WITHOUT replacement: drawing the same row twice would add an
        # off-diagonal similarity of 1.0 and inflate the average.
        sample_size = min(100, len(X))
        sample = X[torch.randperm(len(X))[:sample_size]]
        # Dot product is used as the similarity; this equals cosine similarity
        # when the stored vectors are unit length (presumably they are, as the
        # embedding functions in this file L2-normalise their output).
        sim_matrix = torch.mm(sample, sample.t())
        off_diagonal = ~torch.eye(sample_size, dtype=torch.bool)
        avg_sim = sim_matrix[off_diagonal].mean().item()
        diversity_score = 1.0 - avg_sim

        metrics = (
            f"FINE-TUNED MODEL\n"
            f"Total Chunks: {count}\n"
            f"Analyzed: {len(X)}\n"
            f"Diversity Score: {diversity_score:.4f}\n"
            f"Avg Similarity: {avg_sim:.4f}"
        )

        plot_df = pd.DataFrame({
            "x": projected[:, 0],
            "y": projected[:, 1],
            "topic": [m.get("file_name", "unknown") for m in data['metadatas']],
        })

        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            fig.subplots_adjust(top=0.92)

            # One scatter series (and legend entry) per source file.
            for topic in plot_df["topic"].unique():
                rows = plot_df[plot_df["topic"] == topic]
                ax.scatter(rows["x"], rows["y"], label=topic, alpha=0.6, s=50)

            ax.set_xlabel("PC1")
            ax.set_ylabel("PC2")
            ax.set_title("Fine-tuned Semantic Space (PCA)", fontsize=14, pad=20)
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
            ax.grid(True, alpha=0.3)
            fig.tight_layout()

            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        finally:
            # Always release this specific figure, even if plotting failed.
            plt.close(fig)

        buf.seek(0)
        img = Image.open(buf)

        return metrics, img
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error: {e}", None
|
|
def evaluate_retrieval_baseline(sample_limit):
    """Estimate baseline retrieval quality with synthetic self-retrieval queries.

    For a random sample of stored chunks, the first three lines of the chunk
    are used as the query and the chunk's own id is looked for in the top-10
    results.

    This is a generator: it yields progress strings and finally a report with
    Recall@1, Recall@5 and MRR (or an error / not-enough-data message).

    Args:
        sample_limit: Number of chunks to evaluate; coerced to an int and
            capped at the number of fetched chunks (at most 2000).
    """
    count = baseline_collection.count()
    if count < 10:
        # This function is a generator, so the message has to be yielded;
        # a plain ``return "..."`` would never reach the caller.
        yield "Not enough data for evaluation (Need > 10 chunks)."
        return

    try:
        import random

        fetch_limit = min(count, 2000)
        data = baseline_collection.get(limit=fetch_limit, include=["documents"])
        ids = data['ids']
        documents = data['documents']

        # random.sample needs an int; keep at least one sample so the
        # averages below never divide by zero.
        actual_sample_size = max(1, min(int(sample_limit), len(ids)))
        sample_indices = random.sample(range(len(ids)), actual_sample_size)

        hits_at_1 = 0
        hits_at_5 = 0
        mrr_sum = 0

        yield f"BASELINE: Evaluating {actual_sample_size} chunks..."

        for i, idx in enumerate(sample_indices):
            target_id = ids[idx]
            code = documents[idx]
            query = "\n".join(code.split("\n")[:3])
            query_emb = compute_baseline_embeddings([query]).cpu().numpy().tolist()[0]
            results = baseline_collection.query(query_embeddings=[query_emb], n_results=10)
            found_ids = results['ids'][0]
            if target_id in found_ids:
                rank = found_ids.index(target_id) + 1
                mrr_sum += 1.0 / rank
                if rank == 1:
                    hits_at_1 += 1
                if rank <= 5:
                    hits_at_5 += 1
            if i % 10 == 0:
                yield f"Baseline: {i}/{actual_sample_size}..."

        recall_1 = hits_at_1 / actual_sample_size
        recall_5 = hits_at_5 / actual_sample_size
        mrr = mrr_sum / actual_sample_size

        report = (
            f"BASELINE EVALUATION ({actual_sample_size} chunks)\n"
            f"{'='*40}\n"
            f"Recall@1: {recall_1:.4f}\n"
            f"Recall@5: {recall_5:.4f}\n"
            f"MRR: {mrr:.4f}"
        )
        yield report
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {e}"
|
|
def evaluate_retrieval_finetuned(sample_limit):
    """Estimate fine-tuned retrieval quality with synthetic self-retrieval queries.

    For a random sample of stored chunks, the first three lines of the chunk
    are used as the query and the chunk's own id is looked for in the top-10
    results.

    This is a generator: it yields progress strings and finally a report with
    Recall@1, Recall@5 and MRR (or an error / not-enough-data message).

    Args:
        sample_limit: Number of chunks to evaluate; coerced to an int and
            capped at the number of fetched chunks (at most 2000).
    """
    count = finetuned_collection.count()
    if count < 10:
        # This function is a generator, so the message has to be yielded;
        # a plain ``return "..."`` would never reach the caller.
        yield "Not enough data for evaluation (Need > 10 chunks)."
        return

    try:
        import random

        fetch_limit = min(count, 2000)
        data = finetuned_collection.get(limit=fetch_limit, include=["documents"])
        ids = data['ids']
        documents = data['documents']

        # random.sample needs an int; keep at least one sample so the
        # averages below never divide by zero.
        actual_sample_size = max(1, min(int(sample_limit), len(ids)))
        sample_indices = random.sample(range(len(ids)), actual_sample_size)

        hits_at_1 = 0
        hits_at_5 = 0
        mrr_sum = 0

        yield f"FINE-TUNED: Evaluating {actual_sample_size} chunks..."

        for i, idx in enumerate(sample_indices):
            target_id = ids[idx]
            code = documents[idx]
            query = "\n".join(code.split("\n")[:3])
            query_emb = compute_finetuned_embeddings([query]).cpu().numpy().tolist()[0]
            results = finetuned_collection.query(query_embeddings=[query_emb], n_results=10)
            found_ids = results['ids'][0]
            if target_id in found_ids:
                rank = found_ids.index(target_id) + 1
                mrr_sum += 1.0 / rank
                if rank == 1:
                    hits_at_1 += 1
                if rank <= 5:
                    hits_at_5 += 1
            if i % 10 == 0:
                yield f"Fine-tuned: {i}/{actual_sample_size}..."

        recall_1 = hits_at_1 / actual_sample_size
        recall_5 = hits_at_5 / actual_sample_size
        mrr = mrr_sum / actual_sample_size

        report = (
            f"FINE-TUNED EVALUATION ({actual_sample_size} chunks)\n"
            f"{'='*40}\n"
            f"Recall@1: {recall_1:.4f}\n"
            f"Recall@5: {recall_5:.4f}\n"
            f"MRR: {mrr:.4f}"
        )
        yield report
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {e}"
|
|
| |
# Visual theme for the app: Gradio's Soft theme in slate tones with a few
# block-level overrides.
theme = gr.themes.Soft(
    primary_hue="slate",
    neutral_hue="slate",
    spacing_size="sm",
    radius_size="md",
).set(
    body_background_fill="*neutral_50",
    block_background_fill="white",
    block_border_width="1px",
    block_title_text_weight="600",
)


# Extra CSS: centred page title, capped container width, and the header style
# used above the two comparison columns.
css = """
h1 { text-align: center; font-family: 'Inter', sans-serif; margin-bottom: 1rem; color: #1e293b; }
.gradio-container { max-width: 1400px !important; margin: auto; }
.comparison-header { font-size: 1.1rem; font-weight: 600; color: #334155; text-align: center; padding: 0.5rem; }
"""
|
|
# Gradio UI definition. Four tabs: ingestion + inspection, side-by-side text
# search, code-similarity search, and embedding/retrieval monitoring.
# NOTE(review): the original indentation of this block was lost; the nesting
# below is reconstructed from the component order — confirm the layout.
with gr.Blocks(theme=theme, css=css, title="CodeMode - Baseline vs Fine-tuned") as demo:
    gr.Markdown("# CodeMode: Baseline vs Fine-tuned Model Comparison")
    gr.Markdown("Compare retrieval performance between **microsoft/codebert-base** (baseline) and **MRL-enhanced fine-tuned** model")

    with gr.Tabs():

        # ---- Tab 1: ingest data, reset and inspect the two collections ----
        with gr.Tab("1. Ingest Code"):
            with gr.Tabs():
                with gr.Tab("GitHub Repository"):
                    repo_input = gr.Textbox(label="GitHub URL", placeholder="https://github.com/pallets/flask")
                    ingest_url_btn = gr.Button("Ingest from URL", variant="primary")
                    url_status = gr.Textbox(label="Status")
                    # ingest_from_url is a generator; each yielded string replaces the status text.
                    ingest_url_btn.click(ingest_from_url, inputs=repo_input, outputs=url_status)

                with gr.Tab("Upload Python Files"):
                    file_upload = gr.File(label="Upload .py files", file_types=[".py"], file_count="multiple")
                    ingest_files_btn = gr.Button("Ingest Uploaded Files", variant="primary")
                    upload_status = gr.Textbox(label="Status")
                    ingest_files_btn.click(ingest_from_files, inputs=file_upload, outputs=upload_status)

            with gr.Row():
                reset_baseline_btn = gr.Button("Reset Baseline DB", variant="stop")
                reset_finetuned_btn = gr.Button("Reset Fine-tuned DB", variant="stop")
            # NOTE(review): assumed to sit below the button row rather than inside it — confirm.
            reset_status = gr.Textbox(label="Reset Status")

            # Both reset buttons report into the same status box.
            reset_baseline_btn.click(reset_baseline, inputs=[], outputs=reset_status)
            reset_finetuned_btn.click(reset_finetuned, inputs=[], outputs=reset_status)

            gr.Markdown("---")
            gr.Markdown("### Database Inspector")
            gr.Markdown("View indexed files in each collection")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Collection")
                    inspect_baseline_btn = gr.Button("Inspect Baseline DB", variant="secondary")
                    baseline_files_df = gr.Dataframe(
                        headers=["File Name", "Chunks", "Source URL"],
                        datatype=["str", "number", "str"],
                        interactive=False,
                        value=[["No data yet", "-", "-"]]
                    )
                    inspect_baseline_btn.click(list_baseline_files, inputs=[], outputs=baseline_files_df)

                with gr.Column():
                    gr.Markdown("#### Fine-tuned Collection")
                    inspect_finetuned_btn = gr.Button("Inspect Fine-tuned DB", variant="secondary")
                    finetuned_files_df = gr.Dataframe(
                        headers=["File Name", "Chunks", "Source URL"],
                        datatype=["str", "number", "str"],
                        interactive=False,
                        value=[["No data yet", "-", "-"]]
                    )
                    inspect_finetuned_btn.click(list_finetuned_files, inputs=[], outputs=finetuned_files_df)

            gr.Markdown("---")
            gr.Markdown("### Chunk Inspector")
            gr.Markdown("View detailed chunk information for indexed files (content, metadata, schema)")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Collection")
                    baseline_file_dropdown = gr.Dropdown(
                        label="Select File to Inspect",
                        choices=[],
                        interactive=True
                    )
                    baseline_refresh_files = gr.Button("Refresh File List", variant="secondary")
                    baseline_chunks_display = gr.JSON(label="Chunk Details")
                    baseline_download_btn = gr.Button("Download Chunks as JSON", variant="primary")
                    baseline_download_output = gr.File(label="Download")

                with gr.Column():
                    gr.Markdown("#### Fine-tuned Collection")
                    finetuned_file_dropdown = gr.Dropdown(
                        label="Select File to Inspect",
                        choices=[],
                        interactive=True
                    )
                    finetuned_refresh_files = gr.Button("Refresh File List", variant="secondary")
                    finetuned_chunks_display = gr.JSON(label="Chunk Details")
                    finetuned_download_btn = gr.Button("Download Chunks as JSON", variant="primary")
                    finetuned_download_output = gr.File(label="Download")

            # Chunk-inspector wiring: refresh repopulates the dropdown choices,
            # selecting a file shows its chunks, download exports them as JSON.
            baseline_refresh_files.click(
                lambda: gr.Dropdown(choices=get_files_list_baseline()),
                outputs=baseline_file_dropdown
            )
            baseline_file_dropdown.change(
                get_chunks_for_file_baseline,
                inputs=baseline_file_dropdown,
                outputs=baseline_chunks_display
            )
            baseline_download_btn.click(
                download_chunks_baseline,
                inputs=baseline_file_dropdown,
                outputs=baseline_download_output
            )

            finetuned_refresh_files.click(
                lambda: gr.Dropdown(choices=get_files_list_finetuned()),
                outputs=finetuned_file_dropdown
            )
            finetuned_file_dropdown.change(
                get_chunks_for_file_finetuned,
                inputs=finetuned_file_dropdown,
                outputs=finetuned_chunks_display
            )
            finetuned_download_btn.click(
                download_chunks_finetuned,
                inputs=finetuned_file_dropdown,
                outputs=finetuned_download_output
            )

        # ---- Tab 2: run one text query against both collections ----
        with gr.Tab("2. Comparison Search (Note: Semantic search is sensitive to query phrasing)"):
            gr.Markdown("### Side-by-Side Retrieval Comparison")
            search_query = gr.Textbox(label="Search Query", placeholder="e.g., 'Flask route decorator'")
            compare_btn = gr.Button("Compare Models", variant="primary")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("<div class='comparison-header'>BASELINE (CodeBERT)</div>", elem_classes="comparison-header")
                    baseline_results = gr.Dataframe(headers=["File", "Score", "Code Snippet"], datatype=["str", "str", "str"], interactive=False, wrap=True)

                with gr.Column():
                    gr.Markdown("<div class='comparison-header'>FINE-TUNED (MRL-Enhanced)</div>", elem_classes="comparison-header")
                    finetuned_results = gr.Dataframe(headers=["File", "Score", "Code Snippet"], datatype=["str", "str", "str"], interactive=False, wrap=True)

            # Only the query is passed, so search_comparison runs with its default top_k.
            compare_btn.click(search_comparison, inputs=search_query, outputs=[baseline_results, finetuned_results])

        # ---- Tab 3: find stored chunks similar to a pasted snippet ----
        with gr.Tab("3. Code Similarity Search"):
            gr.Markdown("### Find Similar Code Snippets")
            gr.Markdown("Paste a code snippet to find similar code in the database")

            with gr.Row():
                with gr.Column():
                    code_input = gr.Code(label="Paste Code Snippet", language="python", lines=10)
                    similarity_btn = gr.Button("Find Similar Code", variant="primary")

                with gr.Column():
                    gr.Markdown("#### Search Settings")
                    top_k_slider = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of Results")
                    model_choice = gr.Radio(["Baseline", "Fine-tuned", "Both"], value="Both", label="Model to Use")

            gr.Markdown("### Results")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Results")
                    baseline_code_results = gr.Dataframe(
                        headers=["File", "Similarity", "Code Snippet"],
                        datatype=["str", "str", "str"],
                        interactive=False,
                        wrap=True,
                        value=[["No search yet", "-", "-"]]
                    )

                with gr.Column():
                    gr.Markdown("#### Fine-tuned Results")
                    finetuned_code_results = gr.Dataframe(
                        headers=["File", "Similarity", "Code Snippet"],
                        datatype=["str", "str", "str"],
                        interactive=False,
                        wrap=True,
                        value=[["No search yet", "-", "-"]]
                    )

            def search_similar_code(code_snippet, top_k, model_choice):
                """Search the selected model(s) for chunks similar to ``code_snippet``.

                Returns ``(baseline_rows, finetuned_rows)``; a side that was not
                searched, or that found nothing, gets a single placeholder row.
                """
                if not code_snippet or len(code_snippet.strip()) == 0:
                    empty = [["Enter code to search", "-", "-"]]
                    return empty, empty

                baseline_res = []
                finetuned_res = []

                if model_choice in ["Baseline", "Both"]:
                    baseline_res = search_baseline(code_snippet, top_k)
                    if not baseline_res:
                        baseline_res = [["No results found", "-", "-"]]

                if model_choice in ["Fine-tuned", "Both"]:
                    finetuned_res = search_finetuned(code_snippet, top_k)
                    if not finetuned_res:
                        finetuned_res = [["No results found", "-", "-"]]

                # Fill the table of the model that was deliberately skipped.
                if model_choice == "Baseline":
                    finetuned_res = [["Not searched", "-", "-"]]
                elif model_choice == "Fine-tuned":
                    baseline_res = [["Not searched", "-", "-"]]

                return baseline_res, finetuned_res

            similarity_btn.click(
                search_similar_code,
                inputs=[code_input, top_k_slider, model_choice],
                outputs=[baseline_code_results, finetuned_code_results]
            )

        # ---- Tab 4: embedding-space analysis and retrieval evaluation ----
        with gr.Tab("4. Deployment Monitoring"):
            gr.Markdown("### Embedding Quality Analysis")
            gr.Markdown("Analyze the semantic space distribution and diversity of embeddings")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Model")
                    analyze_baseline_btn = gr.Button("Analyze Baseline Embeddings", variant="secondary")
                    baseline_metrics = gr.Textbox(label="Baseline Metrics")
                    baseline_plot = gr.Image()
                    analyze_baseline_btn.click(analyze_embeddings_baseline, inputs=[], outputs=[baseline_metrics, baseline_plot])

                with gr.Column():
                    gr.Markdown("#### Fine-tuned Model")
                    analyze_finetuned_btn = gr.Button("Analyze Fine-tuned Embeddings", variant="secondary")
                    finetuned_metrics = gr.Textbox(label="Fine-tuned Metrics")
                    finetuned_plot = gr.Image()
                    analyze_finetuned_btn.click(analyze_embeddings_finetuned, inputs=[], outputs=[finetuned_metrics, finetuned_plot])

            gr.Markdown("---")
            gr.Markdown("### Retrieval Performance Evaluation")
            gr.Markdown("Evaluate retrieval accuracy using synthetic queries (query = first 3 lines of code)")

            # One sample-size slider is shared by both evaluation buttons.
            eval_size = gr.Slider(minimum=10, maximum=500, value=50, step=10, label="Sample Size (Chunks to Evaluate)")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Evaluation")
                    eval_baseline_btn = gr.Button("Run Baseline Evaluation", variant="primary")
                    baseline_eval_output = gr.Textbox(label="Baseline Results")
                    eval_baseline_btn.click(evaluate_retrieval_baseline, inputs=[eval_size], outputs=baseline_eval_output)

                with gr.Column():
                    gr.Markdown("#### Fine-tuned Evaluation")
                    eval_finetuned_btn = gr.Button("Run Fine-tuned Evaluation", variant="primary")
                    finetuned_eval_output = gr.Textbox(label="Fine-tuned Results")
                    eval_finetuned_btn.click(evaluate_retrieval_finetuned, inputs=[eval_size], outputs=finetuned_eval_output)
|
|
if __name__ == "__main__":
    # Script entry point: call queue() on the app, then serve it on every
    # network interface at port 7860 with no public share link.
    queued_app = demo.queue()
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
|
|