Generative AI for document classification (i.e., identifying the type of a document, e.g. driving license, passport, US tax forms like W2 and W9, or any business-specific document) is a crucial initial step for many complex automated tasks, such as auto-filling forms, email classification, information extraction from documents, and data labelling for AI model training.
A wide range of traditional methods is available for document classification, each with its own pros and cons. The most popular is the use of CNNs (Convolutional Neural Networks), as they excel at capturing an image's visual and spatial features more effectively than other machine learning algorithms or neural networks. While this is an effective technique, it has its own set of disadvantages, such as the need for a large volume of training data, longer training time, and higher-specification infrastructure for good performance.
Figure 1: CNN architecture
GenAI: Document Classification Beyond CNN
The GenAI-based solution offers an interesting alternative route for document classification that addresses the challenges faced by CNNs and other traditional document classifiers. The advantages of the GenAI-based approach over the CNN-based method are:
- Good performance even with small training samples: Strong results with very little data (4-5 samples per class) make it a perfect choice for few-shot text classification.
- No specific training required: Leveraging state-of-the-art OCR and large language models eliminates the need for conventional model training, as LLMs excel at discerning semantic similarities thanks to extensive pre-training on vast textual datasets.
- Fewer false positives due to a text- and layout-focused approach: By analyzing both content and layout, this method significantly reduces false positives, proving especially effective when documents from different classes share similar layout or text content.
How the GenAI-based document classifier works
The basic idea is to use OCR (optical character recognition) to extract text from the images, then pass this information to an LLM (large language model, e.g. GPT, LLaMA, etc.) to generate embeddings, which are then used for document classification. Let's look at an implementation of this approach using a real-world use case.
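At a high level, the flow can be summarized in a few lines of pseudocode. This is only an illustrative sketch; the helper names encode_as_string and most_similar_class are hypothetical, and the real functions are built step by step below.

# illustrative pipeline sketch; helper names are hypothetical
def classify(image, trained_embeddings):
    ocr_df = apply_ocr(image)                  # Step 2: extract words + bounding boxes
    doc_string = encode_as_string(ocr_df)      # Step 3: layout-aware string encoding
    embedding = get_embedding(doc_string)      # Step 4: embed with the LLM
    return most_similar_class(embedding, trained_embeddings)  # Steps 5-6: cosine similarity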
Problem statement – Classify scanned copies of USA tax forms
Available classes: W2, W9 and 1040 tax forms.
Description: USA tax forms (like W2, W9 & 1040) vary in structure and are text heavy, making them difficult to classify accurately using common CV-based image classification algorithms. But using a GenAI-based approach, we can classify them even with very little training data.
Solution:
Figure 2: Solution Architecture
Technologies used:
- Optical Character Recognition with Google tesseract (pip install pytesseract==0.3.10)
- OpenAI API key for the embedding model (pip install openai==0.28.0)
- Python libraries like NumPy, OpenCV (pip install numpy==1.24.4 opencv-python==4.8.0.76)
Step #1 – Data Collection
First, we need to build training data of shape N (number of classes) * K (number of training samples per class). Let's collect a few samples (e.g. 2 samples each) of W2, W9 & 1040 forms from the web.
Figure 3: Dataset collection
In this case, parameters are
- Classes: [w9, w2, 1040]
- N = 3 and K = 2
Step #1A – Project Structure
This is the training data directory structure; training data needs to be placed in the relevant class directories, as sketched below –
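Here is a minimal sketch of the expected layout, inferred from the fit_model function in train.py (the sample file names are hypothetical):

support_set/
├── w2/
│   ├── w2_sample_1.png
│   └── w2_sample_2.png
├── w9/
│   ├── w9_sample_1.png
│   └── w9_sample_2.png
└── 1040/
    ├── 1040_sample_1.png
    └── 1040_sample_2.png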
Step #2 – Text Extraction (text_extraction.py):
We will use pytesseract as the OCR engine to extract each word and its bounding-box coordinates from the document image:
import pytesseract
import numpy as np

# point pytesseract to the locally installed Tesseract binary
pytesseract.pytesseract.tesseract_cmd = "<path_to_tesseract.exe>"

def apply_ocr(img):
    # one row per detected word, with its text and bounding-box coordinates
    ocr_df = pytesseract.image_to_data(img, output_type='data.frame')
    return ocr_df
Note: The output of this code snippet is a pandas DataFrame that includes the columns [text, left, top, width, height].
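As a quick sanity check, we can run apply_ocr on one of our training images and inspect the columns we will use (the file name below is hypothetical):

# quick sanity check (hypothetical file name)
df = apply_ocr('support_set/w9/w9_sample_1.png')
print(df[['text', 'left', 'top', 'width', 'height']].head())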
Step #3 – Convert Data frame to String Representation (text_extraction.py)
Now that we have extracted each word from the document along with its bounding box, let's convert the DataFrame representation into a string representation. Using this, embeddings can be generated for each document.
Instead of passing only a concatenated text string to the LLM, we will encode it in a format that retains the text's content and how the words are spread across the document. Let's look at an example (Figure 4):
Figure 4: Representing text documents as strings
Let's write code to convert the OCR results into this formatted string:
def text_to_string_encode(coordinates, min_x, min_y, max_x, max_y):
    # map every word onto a fixed 50 x 50 grid based on its (x, y) position
    grid_size = 50
    matrix = [[0] * grid_size for _ in range(grid_size)]
    # fill matrix with values
    for x, y, word in coordinates:
        # skip empty OCR cells (NaN) and degenerate page extents
        if not isinstance(word, str) or max_x == min_x or max_y == min_y:
            continue
        matrix_x = int((x - min_x) / (max_x - min_x) * (grid_size - 1))
        matrix_y = int((y - min_y) / (max_y - min_y) * (grid_size - 1))
        matrix[matrix_y][matrix_x] = word  # a later word overwrites any collision
    # return matrix as string: cells joined per row, rows separated by '/'
    return '/'.join(''.join(str(cell) for cell in row) for row in matrix)
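To see what the encoding produces, here is a tiny example with made-up coordinates: three words on a page whose extents run from (10, 10) to (200, 300).

# toy example with made-up coordinates
words = [(10, 10, 'Form'), (200, 10, 'W-9'), (10, 300, 'Name')]
encoded = text_to_string_encode(words, 10, 10, 200, 300)
# 'Form' lands at grid cell (0, 0), 'W-9' at (0, 49), 'Name' at (49, 0);
# unused cells stay 0, so row 0 starts with 'Form' and ends with 'W-9'
print(encoded[:60])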
Note: Considering the input length constraint of GPT models, we always scale the document to a 50 * 50 grid layout, assuming the number of words on a document page will not exceed 2500.
Step #4 – Generate Embeddings (embeddings.py)
Let's write code to generate embeddings, in both the training and inference phases, for a given document's text string:
import openai

# set up the OpenAI API key
api_key = '<your api key>'
openai.api_key = api_key

def get_embedding(formatted_string):
    # returns the embedding vector for the input string
    return openai.Embedding.create(
        input=[formatted_string],
        engine="text-embedding-ada-002")['data'][0]['embedding']
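For reference, text-embedding-ada-002 returns a 1536-dimensional vector, which you can verify with a quick call (the input string here is just a placeholder):

# quick check with a placeholder input
emb = get_embedding('Form00...0W-9/...')
print(len(emb))  # 1536 for text-embedding-ada-002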
Step #5 – AI Model Training (train.py)
Now, with everything in place, we are ready to write a model training script and train our model using train.py:
import os
from embeddings import *
from text_extraction import *
import pandas as pd

def fit_model(train_data='support_set/'):
    # map each class name to the list of its training files
    class_names = os.listdir(train_data)
    support_set = dict.fromkeys(class_names)
    for cls in support_set:
        support_set[cls] = os.listdir(train_data + cls)
Let's process the individual files in our training dataset using the apply_ocr, text_to_string_encode, and get_embedding functions. Please note that this code is also part of the fit_model function.
    label = []
    embds = []
    for cls in support_set:
        for file in support_set[cls]:
            print(f'Processing file {file} of class {cls}')
            ocr_data = apply_ocr(train_data + cls + '/' + file)
            # page extents used to normalize word coordinates onto the grid
            min_x = ocr_data['left'].min()
            min_y = ocr_data['top'].min()
            max_x = (ocr_data['left'] + ocr_data['width']).max()
            max_y = (ocr_data['top'] + ocr_data['height']).max()
            formatted_string = text_to_string_encode(
                zip(ocr_data['left'], ocr_data['top'], ocr_data['text']),
                min_x, min_y, max_x, max_y)
            embd = get_embedding(formatted_string)
            label.append(cls)
            embds.append(embd)
Note: This generates a class-label-to-embedding mapping for all our training files, which serves as our trained model. We store this mapping as a CSV file to refer to in the inference phase.
    # export the mapping as a CSV file, which will serve as our trained model
    pd.DataFrame({'Label': label, 'Embedding': embds}).to_csv('support_set.csv')
Step #6 – Getting Inference (inference.py)
Now we are ready to run inference on a test image. The flow for a test image is: image -> apply_ocr -> text_to_string_encode -> get_embedding.
import numpy as np
from numpy.linalg import norm
from embeddings import *
from text_extraction import *
from train import *
import pandas as pd

def test(image):
    ocr_data = apply_ocr(image)
    min_x = ocr_data['left'].min()
    min_y = ocr_data['top'].min()
    max_x = (ocr_data['left'] + ocr_data['width']).max()
    max_y = (ocr_data['top'] + ocr_data['height']).max()
    formatted_string = text_to_string_encode(
        zip(ocr_data['left'], ocr_data['top'], ocr_data['text']),
        min_x, min_y, max_x, max_y)
    test_embd = get_embedding(formatted_string)
    return test_embd
Let's calculate the cosine similarity of the test image with all training set images and predict the class based on the mean similarity score per class. The higher the similarity score, the more similar the test document is to that class.
# train the model and save as support_set.csv
fit_model()

# load the test image and compute its embedding
test_embd = test('w9-test-image.png')

# compare against every training embedding using cosine similarity
model = pd.read_csv('support_set.csv')
for n in range(len(model)):
    label = model.loc[n, 'Label']
    i = eval(model.loc[n, 'Embedding'])  # the CSV stores each embedding as a string
    print(np.dot(i, test_embd) / (norm(i) * norm(test_embd)), label)
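The loop above only prints per-sample scores, while the prose mentions predicting by mean similarity per class. Here is one possible way to aggregate the scores into a prediction, reusing the model and test_embd variables from the script above (a sketch, not part of the original script):

# sketch: average the per-sample similarities per class, then pick the best class
scores = {}
for n in range(len(model)):
    label = model.loc[n, 'Label']
    emb = eval(model.loc[n, 'Embedding'])
    sim = np.dot(emb, test_embd) / (norm(emb) * norm(test_embd))
    scores.setdefault(label, []).append(sim)
predicted = max(scores, key=lambda c: np.mean(scores[c]))
print('Predicted class:', predicted)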
Step #7 – Analyze Inference output
Figure 5: Inference on the test image, which is a W9 form
Note: Above are the similarity scores of the test image with each image in the training set.
In Figure 5, we can clearly see that the mean score for the W9 class label is higher than for the W2 and 1040 class labels, meaning our model classifies the test document correctly.
Conclusion
In this blog we shared an approach for classifying text-rich documents using a combination of an OCR engine and text embeddings. We hope this post was helpful to Generative AI enthusiasts like you.