Dataset: rth/sroie-2019-v2
How to use devashish-pisal/layoutlmv3-sroie-token-classification with Transformers:
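```python
# Use a pipeline as a high-level helper
# note: the generic token-classification pipeline passes plain text only, while
# LayoutLMv3 normally also needs word boxes and the page image, so this helper may
# not work for this model; the processor-based example below is the more reliable route
from transformers import pipeline

pipe = pipeline("token-classification", model="devashish-pisal/layoutlmv3-sroie-token-classification")

# Load model directly
from transformers import AutoProcessor, AutoModelForTokenClassification

processor = AutoProcessor.from_pretrained("devashish-pisal/layoutlmv3-sroie-token-classification")
model = AutoModelForTokenClassification.from_pretrained("devashish-pisal/layoutlmv3-sroie-token-classification")
```

This model is a fine-tuned version of LayoutLMv3 for token classification of scanned receipts (invoices), trained on the SROIE dataset.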
Token classification for document understanding:
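```python
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from PIL import Image
import torch
import pytesseract  # any other OCR library can be used instead

MODEL_ID = "devashish-pisal/layoutlmv3-sroie-token-classification"

# load model & processor
# apply_ocr=False because OCR is run below and words/boxes are passed explicitly;
# with apply_ocr=True the processor runs Tesseract itself and rejects user-supplied boxes
processor = LayoutLMv3Processor.from_pretrained(MODEL_ID, apply_ocr=False)
model = LayoutLMv3ForTokenClassification.from_pretrained(MODEL_ID)
model.eval()


def find_words_and_bboxes(ocr_data, width, height):
    """Illustrative helper (not part of transformers): keep non-empty OCR words
    and scale their pixel boxes to the 0-1000 range LayoutLMv3 expects."""
    words, boxes = [], []
    for text, left, top, w, h in zip(
        ocr_data["text"], ocr_data["left"], ocr_data["top"],
        ocr_data["width"], ocr_data["height"],
    ):
        if not text.strip():
            continue  # Tesseract also emits empty entries for blocks/lines
        words.append(text)
        boxes.append([
            int(1000 * left / width),
            int(1000 * top / height),
            int(1000 * (left + w) / width),
            int(1000 * (top + h) / height),
        ])
    return words, boxes


# load image to perform inference
IMAGE_PATH = "path/to/the/image.jpg"
img = Image.open(IMAGE_PATH).convert("RGB")
width, height = img.size

# perform OCR
# note: this step can be skipped by loading the processor with apply_ocr=True
# and calling processor(img, ...) without words and boxes
ocr_data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT)
words, boxes = find_words_and_bboxes(ocr_data, width, height)

# prepare input for the model
encoding = processor(
    img,
    words,
    boxes=boxes,
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_length=512,
)

# perform inference
with torch.no_grad():
    outputs = model(**encoding)
predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()

# decode predictions
tokens = processor.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].tolist())
attention = encoding["attention_mask"][0].tolist()

# print result (padding tokens are skipped)
id2label = model.config.id2label
print("\nToken predictions:\n")
for token, pred, mask in zip(tokens, predictions, attention):
    if mask:
        print(f"{token:15} -> {id2label[pred]}")
# additional processing is required to convert tokens into words and sentences
```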
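The final comment in the example notes that token-level predictions still have to be mapped back to the OCR words. One common approach, sketched here as an assumption rather than the author's method, is to give each word the label predicted for its first sub-token. With a fast tokenizer, `encoding.word_ids(0)` returns the word index of each token (`None` for special and padding tokens). The helper name `word_level_labels` is hypothetical.

```python
from typing import List, Optional


def word_level_labels(
    word_ids: List[Optional[int]],
    token_labels: List[str],
    num_words: int,
) -> List[Optional[str]]:
    """Hypothetical helper: assign each word the label of its first sub-token.

    word_ids[i] is the index of the word that token i came from, or None for
    special/padding tokens. Words removed by truncation keep the label None.
    """
    labels: List[Optional[str]] = [None] * num_words
    for word_id, label in zip(word_ids, token_labels):
        # skip special/padding tokens and every sub-token after the first
        if word_id is not None and labels[word_id] is None:
            labels[word_id] = label
    return labels


# toy example: 3 words, the second one split into two sub-tokens
example_word_ids = [None, 0, 1, 1, 2, None]
example_labels = ["O", "B-COMPANY", "I-COMPANY", "I-COMPANY", "O", "O"]
print(word_level_labels(example_word_ids, example_labels, num_words=3))
# ['B-COMPANY', 'I-COMPANY', 'O']
```

With the variables from the inference example, a call would look roughly like `word_level_labels(encoding.word_ids(0), [id2label[p] for p in predictions], len(words))`, assuming the processor loaded the fast tokenizer (the usual default).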
| Tag | Meaning | Description |
|---|---|---|
| B-COMPANY | Beginning of Company | First token of a company name |
| I-COMPANY | Inside Company | Subsequent token of a company name |
| B-DATE | Beginning of Date | First token of a date expression |
| I-DATE | Inside Date | Subsequent token of a date |
| B-ADDRESS | Beginning of Address | First token of an address |
| I-ADDRESS | Inside Address | Subsequent token of an address |
| B-TOTAL | Beginning of Total | First token of a total amount |
| I-TOTAL | Inside Total | Subsequent token of a total amount |
| O | Outside | Token is not part of any entity |
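The tags follow the BIO scheme: a `B-` tag opens an entity, `I-` tags continue it, and `O` marks everything else. A minimal sketch of grouping word-level tags into entity strings, assuming one tag per word (for instance, the first sub-token's tag). The function name `group_entities` is hypothetical, the sample receipt words are made up for illustration, and a stray `I-` tag with no open entity of the same type is leniently treated as the start of a new entity.

```python
from typing import List, Optional, Tuple


def group_entities(
    words: List[str], tags: List[Optional[str]]
) -> List[Tuple[str, str]]:
    """Hypothetical helper: merge BIO-tagged words into (entity_type, text) pairs."""
    entities: List[Tuple[str, List[str]]] = []
    current: Optional[str] = None  # type of the entity currently being built
    for word, tag in zip(words, tags):
        if tag is None or tag == "O":
            current = None
            continue
        prefix, _, entity_type = tag.partition("-")
        if prefix == "B" or entity_type != current:
            # start a new entity (also covers an I- tag with no open entity)
            entities.append((entity_type, [word]))
            current = entity_type
        else:
            # I- tag continuing the open entity of the same type
            entities[-1][1].append(word)
    return [(entity_type, " ".join(parts)) for entity_type, parts in entities]


words = ["ABC", "TRADING", "SDN", "BHD", "Date:", "25/12/2018", "TOTAL", "12.50"]
tags = ["B-COMPANY", "I-COMPANY", "I-COMPANY", "I-COMPANY", "O", "B-DATE", "O", "B-TOTAL"]
print(group_entities(words, tags))
# [('COMPANY', 'ABC TRADING SDN BHD'), ('DATE', '25/12/2018'), ('TOTAL', '12.50')]
```

Joining words with single spaces is a simplification; receipts with multi-line addresses may need the original OCR line structure to reproduce the exact layout.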
Base model: microsoft/layoutlmv3-base