Dataset: ucirvine/sms_spam
A simple BiLSTM model in PyTorch, trained for spam detection on the SMS Spam Collection
(Almeida, Tiago and José Hidalgo. 2011. SMS Spam Collection.
UCI Machine Learning Repository. https://doi.org/10.24432/C5CC84).
The model outputs logits; apply torch.sigmoid to convert them to a probability.
Uses the bert-base-uncased tokenizer only for tokenization (the encoder is NOT BERT).
BiLSTMClassifier.safetensors: trained weights
BiLSTMClassifier.py: model definition
config.json: hyperparameters

import json
import torch
from transformers import BertTokenizer
from safetensors.torch import load_file
from BiLSTMClassifier import BiLSTMClassifier
# Load the hyperparameters used to build the model
with open("config.json") as f:
    cfg = json.load(f)
model = BiLSTMClassifier(**cfg)
state_dict = load_file("BiLSTMClassifier.safetensors")
model.load_state_dict(state_dict)
model.eval()
sample_text = "URGENT HIRING! Earn $500/day working from home. No experience needed."
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
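tokens = tokenizer(sample_text, return_tensors="pt")
# Inference only, so gradient tracking is not needed
with torch.no_grad():
    logits = model(tokens["input_ids"])
# Convert the raw logits to a probability in [0, 1]
prob = torch.sigmoid(logits)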
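
The snippet above stops at a probability. As a minimal sketch of the last step, the code below turns a raw logit into a label. It assumes the model returns one logit per message and that the positive class is spam. The `classify` helper, the "spam"/"ham" label names, and the 0.5 threshold are illustrative assumptions, not part of this repository. It uses only the standard library so it can be run without the model files.

```python
import math
from typing import Tuple


def sigmoid(x: float) -> float:
    """Numerically stable logistic function (same mapping as torch.sigmoid)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    # For negative inputs, use the equivalent form that avoids overflow in exp.
    z = math.exp(x)
    return z / (1.0 + z)


def classify(logit: float, threshold: float = 0.5) -> Tuple[str, float]:
    """Map a raw logit to a (label, probability) pair.

    Assumption: the positive class is spam and 0.5 is a reasonable cut-off.
    Both are hypothetical choices for illustration.
    """
    prob = sigmoid(logit)
    label = "spam" if prob >= threshold else "ham"
    return label, prob


if __name__ == "__main__":
    # Hypothetical logits, not real model outputs
    for value in (-3.0, 0.0, 3.0):
        label, prob = classify(value)
        print(f"logit={value:+.1f} -> {label} (p={prob:.3f})")
```

With the loading code above, and if the model does return a single logit for a single message, something like `classify(logits.item())` would give the label. Raising the threshold trades fewer false spam flags for more missed spam, so the cut-off is worth tuning on held-out data.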