import os
import logging
import ast

import torch
from torch.utils.data import Dataset
from datasets import load_dataset, load_from_disk
import pandas as pd
import nltk

from config import MODEL_NAME, MAX_LENGTH, OVERLAP, PREPROCESSED_DIR, tokenizer, nlp

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")


def process_data():
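    """Tokenize SNLI premise/hypothesis pairs, attach dependency-graph
    annotations, and cache the processed dataset to PREPROCESSED_DIR."""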
    if not os.path.exists(PREPROCESSED_DIR):
        logging.info("Preprocessing data... This may take a while.")

        snli = load_dataset("snli")
        # Drop examples without a gold label (labelled -1 in SNLI).
        snli = snli.filter(lambda x: x["label"] != -1)

        def build_dependency_graph(sentence):
            # Parse with spaCy and return the token list plus head<->dependent
            # edges in both directions, so the dependency graph is undirected.
            doc = nlp(sentence)
            tokens = [tok.text for tok in doc]
            edges = []
            for tok in doc:
                if tok.head.i != tok.i:
                    edges.extend([(tok.i, tok.head.i), (tok.head.i, tok.i)])
            return tokens, edges

        def preprocess(examples):
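            # For every premise/hypothesis pair, record the spaCy tokens, the
            # dependency edges, and the wordpiece index of each graph node so the
            # graph annotations can be aligned with the transformer inputs.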
            premises = examples["premise"]
            hypotheses = examples["hypothesis"]
            labels = examples["label"]
            tokenized = tokenizer(premises, hypotheses,
                                  truncation=True, padding="max_length",
                                  max_length=MAX_LENGTH)
            tokenized["labels"] = labels

            p_tokens_list, p_edges_list, p_idx_list = [], [], []
            h_tokens_list, h_edges_list, h_idx_list = [], [], []

            for p, h, input_ids in zip(premises, hypotheses, tokenized["input_ids"]):
                p_toks, p_edges = build_dependency_graph(p)
                h_toks, h_edges = build_dependency_graph(h)
                wp_tokens = tokenizer.convert_ids_to_tokens(input_ids)

                def align_tokens(spacy_tokens, wp_tokens, start, stop):
                    # Map each spaCy token to the index of its first wordpiece in
                    # [start, stop), skipping "##" continuation pieces (assumes a
                    # BERT-style WordPiece vocabulary).
                    node_indices, wp_idx = [], start
                    for _ in spacy_tokens:
                        if wp_idx >= stop:
                            break
                        node_indices.append(wp_idx)
                        wp_idx += 1
                        while wp_idx < stop and wp_tokens[wp_idx].startswith("##"):
                            wp_idx += 1
                    return node_indices

                # The pair is encoded as [CLS] premise [SEP] hypothesis [SEP] [PAD]...,
                # so align the premise against the first segment and the hypothesis
                # against the second, rather than starting both at position 1.
                sep_positions = [k for k, t in enumerate(wp_tokens) if t == tokenizer.sep_token]
                p_idx = align_tokens(p_toks, wp_tokens, 1, sep_positions[0])
                h_idx = align_tokens(h_toks, wp_tokens, sep_positions[0] + 1, sep_positions[-1])
                p_tokens_list.append(p_toks)
                p_edges_list.append(p_edges)
                p_idx_list.append(p_idx)

                h_tokens_list.append(h_toks)
                h_edges_list.append(h_edges)
                h_idx_list.append(h_idx)

            tokenized.update({
                "premise_graph_tokens": p_tokens_list,
                "premise_graph_edges": p_edges_list,
                "premise_node_indices": p_idx_list,
                "hypothesis_graph_tokens": h_tokens_list,
                "hypothesis_graph_edges": h_edges_list,
                "hypothesis_node_indices": h_idx_list,
            })
            return tokenized

        snli = snli.map(preprocess, batched=True)
        snli.save_to_disk(PREPROCESSED_DIR)
        logging.info(f"Preprocessing complete. Saved to {PREPROCESSED_DIR}")
    else:
        logging.info("Using existing preprocessed data at %s", PREPROCESSED_DIR)
def chunk_transcript(transcript_text, start_idx, end_idx, tokenizer):
    """Encode a transcript without truncation and split it into overlapping chunks."""
    # No truncation or padding here: the full transcript is chunked manually below.
    encoded = tokenizer(transcript_text,
                        return_offsets_mapping=True,
                        add_special_tokens=True,
                        return_tensors=None,
                        padding=False,
                        truncation=False)
    all_input_ids = encoded["input_ids"]
    all_offsets = encoded["offset_mapping"]
    chunks = []
    i = 0
    while i < len(all_input_ids):
        chunk_ids = all_input_ids[i : i + MAX_LENGTH]
        chunk_offsets = all_offsets[i : i + MAX_LENGTH]
        attention_mask = [1] * len(chunk_ids)

        # Locate the answer span inside this chunk; if the character span does not
        # fall fully within the chunk, mark it as a "no span" chunk.
        no_span = 1
        start_token, end_token = -1, -1
        if start_idx >= 0 and end_idx >= 0:
            for j, (off_s, off_e) in enumerate(chunk_offsets):
                if off_s <= start_idx < off_e:
                    start_token = j
                if off_s < end_idx <= off_e:
                    end_token = j
                    break
            if 0 <= start_token <= end_token:
                no_span = 0
            else:
                start_token, end_token = -1, -1

        chunks.append({
            "input_ids": torch.tensor(chunk_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "start_label": start_token,
            "end_label": end_token,
            "no_span_label": no_span,
        })
        # Slide the window forward, keeping OVERLAP tokens of shared context.
        i += (MAX_LENGTH - OVERLAP)
    return chunks
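# Window arithmetic (illustrative numbers, not the actual config values): with
# MAX_LENGTH = 512 and OVERLAP = 128, chunk starts fall at tokens 0, 384, 768, ...
# so OVERLAP tokens of context are shared between consecutive chunks.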
class SpanExtractionChunkedDataset(Dataset):
    def __init__(self, data):
        self.samples = []
        for item in data:
            chunks = chunk_transcript(
                item.get("transcript", ""),
                item.get("start_idx", -1),
                item.get("end_idx", -1),
                tokenizer)
            self.samples.extend(chunks)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
def span_collate_fn(batch):
    """Pad variable-length chunks in a batch to the length of the longest chunk."""
    max_len = max(len(x["input_ids"]) for x in batch)
    inputs, masks, starts, ends, nos = [], [], [], [], []
    for x in batch:
        # Right-pad with id 0; this assumes the tokenizer's pad token id is 0
        # (true for BERT-style vocabularies).
        pad = max_len - len(x["input_ids"])
        inputs.append(torch.cat([x["input_ids"], torch.zeros(pad, dtype=torch.long)]).unsqueeze(0))
        masks.append(torch.cat([x["attention_mask"], torch.zeros(pad, dtype=torch.long)]).unsqueeze(0))
        starts.append(x["start_label"])
        ends.append(x["end_label"])
        nos.append(x["no_span_label"])
    return {
        "input_ids": torch.cat(inputs, dim=0),
        "attention_mask": torch.cat(masks, dim=0),
        "start_positions": torch.tensor(starts, dtype=torch.long),
        "end_positions": torch.tensor(ends, dtype=torch.long),
        "no_span_label": torch.tensor(nos, dtype=torch.long),
    }
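# Sketch of how these pieces fit together (assumes `records` is an iterable of
# dicts with "transcript", "start_idx" and "end_idx" keys, as consumed above):
#     loader = torch.utils.data.DataLoader(SpanExtractionChunkedDataset(records),
#                                          batch_size=8, shuffle=True,
#                                          collate_fn=span_collate_fn)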
# Sentence splitting in SentenceDataset relies on NLTK's punkt models.
nltk.download('punkt')
nltk.download('punkt_tab')


class SentenceDataset(Dataset):
    def __init__(self,
                 excel_path: str,
                 tokenizer,
                 max_length: int = 128):
        df = pd.read_excel(excel_path)
        self.samples = []

        for _, row in df.iterrows():
            transcript = str(row['Claude_Call'])
            gold_sentences = row['Sel_K']
            if isinstance(gold_sentences, str):
                # The gold sentences are stored as a stringified list; parse them
                # with ast.literal_eval rather than eval for safety.
                gold_sentences = ast.literal_eval(gold_sentences)

            sentences = nltk.sent_tokenize(transcript)
            for sent in sentences:
                # A sentence is positive if it appears verbatim among the gold sentences.
                label = 1 if sent in gold_sentences else 0

                enc = tokenizer.encode_plus(
                    sent,
                    max_length=max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                self.samples.append({
                    'input_ids': enc['input_ids'].squeeze(0),
                    'attention_mask': enc['attention_mask'].squeeze(0),
                    'label': torch.tensor(label, dtype=torch.float)
                })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
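# Example usage (a sketch; the Excel file is assumed to contain the `Claude_Call`
# and `Sel_K` columns read above, and the path below is hypothetical):
#     ds = SentenceDataset("annotations.xlsx", tokenizer, max_length=128)
#     loader = torch.utils.data.DataLoader(ds, batch_size=32, shuffle=True)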