Commit

Merge pull request #148 from VikParuchuri/dev
New architecture for detection/layout
VikParuchuri committed Jul 12, 2024
2 parents f7c6c04 + 821160a commit 03b859e
Showing 16 changed files with 895 additions and 591 deletions.
43 changes: 21 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
@@ -63,12 +63,12 @@ Install with:
pip install surya-ocr
```

Model weights will automatically download the first time you run surya. Note that this does not work with the latest version of transformers `4.37+` [yet](https://github.com/huggingface/transformers/issues/28846#issuecomment-1926109135), so you will need to keep `4.36.2`, which is installed with surya.
Model weights will automatically download the first time you run surya.

# Usage

- Inspect the settings in `surya/settings.py`. You can override any settings with environment variables.
- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`. For text detection, the `mps` device has a bug (on the [Apple side](https://github.com/pytorch/pytorch/issues/84936)) that may prevent it from working properly.
- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`.

## Interactive App

@@ -79,8 +79,6 @@ pip install streamlit
surya_gui
```

Pass the `--math` command line argument to use the math text detection model instead of the default model. This will detect math better, but will be worse at everything else.

## OCR (text recognition)

This command will write out a json file with the detected text and bboxes:
@@ -151,7 +149,6 @@ surya_detect DATA_PATH --images
- `--images` will save images of the pages and detected text lines (optional)
- `--max` specifies the maximum number of pages to process if you don't want to process everything
- `--results_dir` specifies the directory to save results to instead of the default
- `--math` uses a specialized math detection model instead of the default model. This will be better at math, but worse at everything else.

The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:

@@ -166,14 +163,14 @@ The `results.json` file will contain a json dictionary where the keys are the in

**Performance tips**

Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `280MB` of VRAM, so very high batch sizes are possible. The default is a batch size `32`, which will use about 9GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `2`.
Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `440MB` of VRAM, so very high batch sizes are possible. The default is a batch size `36`, which will use about 16GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `6`.

### From python

```python
from PIL import Image
from surya.detection import batch_text_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.model.detection.model import load_model, load_processor

image = Image.open(IMAGE_PATH)
model, processor = load_model(), load_processor()
@@ -207,15 +204,15 @@ The `results.json` file will contain a json dictionary where the keys are the in

**Performance tips**

Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `280MB` of VRAM, so very high batch sizes are possible. The default is a batch size `32`, which will use about 9GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `2`.
Setting the `DETECTOR_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `400MB` of VRAM, so very high batch sizes are possible. The default is a batch size `36`, which will use about 16GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `6`.

### From python

```python
from PIL import Image
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.model.detection.model import load_model, load_processor
from surya.settings import settings

image = Image.open(IMAGE_PATH)
@@ -334,16 +331,16 @@ For Google Cloud, I aligned the output from Google Cloud with the ground truth.

![Benchmark chart](static/images/benchmark_chart_small.png)

| Model | Time (s) | Time per page (s) | precision | recall |
| Model | Time (s) | Time per page (s) | precision | recall |
|-----------|------------|---------------------|-------------|----------|
| surya | 52.6892 | 0.205817 | 0.844426 | 0.937818 |
| tesseract | 74.4546 | 0.290838 | 0.631498 | 0.997694 |
| surya | 50.2099 | 0.196133 | 0.821061 | 0.956556 |
| tesseract | 74.4546 | 0.290838 | 0.631498 | 0.997694 |


Tesseract is CPU-based, and surya is CPU or GPU. I ran the benchmarks on a system with an A6000 GPU, and a 32 core CPU. This was the resource usage:
Tesseract is CPU-based, and surya is CPU or GPU. I ran the benchmarks on a system with an A10 GPU, and a 32 core CPU. This was the resource usage:

- tesseract - 32 CPU cores, or 8 workers using 4 cores each
- surya - 32 batch size, for 9GB VRAM usage
- surya - 36 batch size, for 16GB VRAM usage

**Methodology**

@@ -362,14 +359,14 @@ Then we calculate precision and recall for the whole dataset.

![Benchmark chart](static/images/benchmark_layout_chart.png)

| Layout Type | precision | recall |
|---------------|-------------|----------|
| Image | 0.95 | 0.99 |
| Table | 0.95 | 0.96 |
| Text | 0.89 | 0.95 |
| Title | 0.92 | 0.89 |
| Layout Type | precision | recall |
| ----------- | --------- | ------ |
| Image | 0.97 | 0.96 |
| Table | 0.99 | 0.99 |
| Text | 0.9 | 0.97 |
| Title | 0.94 | 0.88 |

Time per image - .79 seconds on GPU (A6000).
Time per image - .4 seconds on GPU (A10).

**Methodology**

@@ -446,7 +443,7 @@ python benchmark/ordering.py

# Training

Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified segformer architecture that reduces inference RAM requirements.
Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified efficientvit architecture for semantic segmentation.

Text recognition was trained on 4x A6000s for 2 weeks. It was trained using a modified donut model (GQA, MoE layer, UTF-16 decoding, layer config changes).

@@ -455,6 +452,8 @@ Text recognition was trained on 4x A6000s for 2 weeks. It was trained using a m
This work would not have been possible without amazing open source AI work:

- [Segformer](https://arxiv.org/pdf/2105.15203.pdf) from NVIDIA
- [EfficientViT](https://github.com/mit-han-lab/efficientvit) from MIT
- [timm](https://github.com/huggingface/pytorch-image-models) from Ross Wightman
- [Donut](https://github.com/clovaai/donut) from Naver
- [transformers](https://github.com/huggingface/transformers) from huggingface
- [CRAFT](https://github.com/clovaai/CRAFT-pytorch), a great scene text detection model
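
The updated performance tips in the README hunks above quote roughly `440MB` of VRAM per batch item and about 16GB for the default batch size of `36`. The arithmetic behind that can be sketched as below. This is only a back-of-the-envelope estimate: the helper names and the 10% headroom are assumptions made for illustration, not part of surya, and real usage depends on the GPU and input images.

```python
import math

# Per-item figure quoted in the updated README text; measure on your own hardware.
PER_ITEM_MB = 440


def estimate_vram_gb(batch_size: int, per_item_mb: int = PER_ITEM_MB) -> float:
    """Approximate VRAM for one detection batch, in GB (1 GB = 1000 MB here)."""
    return batch_size * per_item_mb / 1000


def max_batch_size(vram_gb: float, per_item_mb: int = PER_ITEM_MB, headroom: float = 0.1) -> int:
    """Largest batch size that fits after reserving a fraction of VRAM as headroom.

    The headroom fraction is a hypothetical safety margin, not a surya setting.
    """
    usable_mb = vram_gb * 1000 * (1 - headroom)
    return max(1, math.floor(usable_mb / per_item_mb))


# Default batch size from the README: lands close to the quoted 16GB.
print(f"{estimate_vram_gb(36):.2f} GB")
# A candidate value for DETECTOR_BATCH_SIZE on a 24GB card.
print(max_batch_size(24))
```

A value obtained this way could then be passed through the `DETECTOR_BATCH_SIZE` environment variable that the README describes.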
2 changes: 1 addition & 1 deletion benchmark/detection.py
@@ -6,7 +6,7 @@
from surya.benchmark.bbox import get_pdf_lines
from surya.benchmark.metrics import precision_recall
from surya.benchmark.tesseract import tesseract_parallel
from surya.model.detection.segformer import load_model, load_processor
from surya.model.detection.model import load_model, load_processor
from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb
from surya.detection import batch_text_detection
from surya.postprocessing.heatmap import draw_polys_on_image
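
This commit moves `load_model` and `load_processor` from `surya.model.detection.segformer` to `surya.model.detection.model`. For downstream code that has to run against releases on both sides of the rename, one option is a fallback import. The helper name `import_first` is hypothetical and not part of surya; the sketch is exercised with standard-library modules so that it runs without surya installed.

```python
import importlib


def import_first(*module_paths):
    """Return the first module in module_paths that imports successfully."""
    last_error = None
    for path in module_paths:
        try:
            return importlib.import_module(path)
        except ImportError as exc:
            last_error = exc
    raise ImportError(f"none of {module_paths} could be imported") from last_error


# Demonstration with a made-up module name followed by a real stdlib module.
mod = import_first("definitely_not_a_real_module_xyz", "json")
print(mod.__name__)

# With surya installed, the same pattern would try the new path first:
# detection = import_first("surya.model.detection.model",
#                          "surya.model.detection.segformer")
# model, processor = detection.load_model(), detection.load_processor()
```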
2 changes: 1 addition & 1 deletion benchmark/layout.py
@@ -5,7 +5,7 @@

from surya.benchmark.metrics import precision_recall
from surya.detection import batch_text_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.model.detection.model import load_model, load_processor
from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb
from surya.layout import batch_layout_detection
from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
2 changes: 1 addition & 1 deletion detect_layout.py
@@ -6,7 +6,7 @@
from surya.detection import batch_text_detection
from surya.input.load import load_from_folder, load_from_file
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.model.detection.model import load_model, load_processor
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.settings import settings
import os
10 changes: 7 additions & 3 deletions detect_text.py
@@ -1,10 +1,11 @@
import argparse
import copy
import json
import time
from collections import defaultdict

from surya.input.load import load_from_folder, load_from_file
from surya.model.detection.segformer import load_model, load_processor
from surya.model.detection.model import load_model, load_processor
from surya.detection import batch_text_detection
from surya.postprocessing.affinity import draw_lines_on_image
from surya.postprocessing.heatmap import draw_polys_on_image
@@ -20,10 +21,9 @@ def main():
    parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
    parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False)
    parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
    parser.add_argument("--math", action="store_true", help="Use math model for detection", default=False)
    args = parser.parse_args()

    checkpoint = settings.DETECTOR_MATH_MODEL_CHECKPOINT if args.math else settings.DETECTOR_MODEL_CHECKPOINT
    checkpoint = settings.DETECTOR_MODEL_CHECKPOINT
    model = load_model(checkpoint=checkpoint)
    processor = load_processor(checkpoint=checkpoint)

@@ -34,9 +34,13 @@ def main():
        images, names = load_from_file(args.input_path, args.max)
        folder_name = os.path.basename(args.input_path).split(".")[0]

    start = time.time()
    predictions = batch_text_detection(images, model, processor)
    result_path = os.path.join(args.results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)
    end = time.time()
    if args.debug:
        print(f"Detection took {end - start} seconds")

    if args.images:
        for idx, (image, pred, name) in enumerate(zip(images, predictions, names)):
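
The timing added in the hunk above takes `end` after `os.makedirs`, so the reported figure also includes directory creation (normally negligible). A minimal sketch of a reusable timer that wraps only the call of interest is given below. The `timed` helper and its parameters are hypothetical, not part of surya, and the `sum(...)` call is a stand-in for `batch_text_detection(images, model, processor)`.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, enabled=True, results=None):
    """Measure the wall-clock time of the enclosed block.

    perf_counter is used instead of time.time because it is monotonic and
    intended for measuring intervals.
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        if results is not None:
            results[label] = elapsed  # keep the raw number for later use
        if enabled:
            print(f"{label} took {elapsed:.3f} seconds")


timings = {}
with timed("Detection", enabled=True, results=timings):
    # Stand-in workload; in detect_text.py this would be the detection call.
    sum(range(100_000))
```

In the script, `enabled` could be tied to `args.debug` so the message only prints in debug mode, matching the behaviour in the diff.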
15 changes: 2 additions & 13 deletions ocr_app.py
@@ -1,13 +1,11 @@
import os
import argparse
import io
from typing import List

import pypdfium2
import streamlit as st
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.model.detection.model import load_model, load_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.ordering.processor import load_processor as load_order_processor
@@ -22,18 +20,9 @@
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult
from surya.settings import settings

parser = argparse.ArgumentParser(description="Run OCR on an image or PDF.")
parser.add_argument("--math", action="store_true", help="Use math model for detection", default=False)

try:
    args = parser.parse_args()
except SystemExit as e:
    print(f"Error parsing arguments: {e}")
    os._exit(e.code)

@st.cache_resource()
def load_det_cached():
    checkpoint = settings.DETECTOR_MATH_MODEL_CHECKPOINT if args.math else settings.DETECTOR_MODEL_CHECKPOINT
    checkpoint = settings.DETECTOR_MODEL_CHECKPOINT
    return load_model(checkpoint=checkpoint), load_processor(checkpoint=checkpoint)


2 changes: 1 addition & 1 deletion ocr_text.py
@@ -7,7 +7,7 @@

from surya.input.langs import replace_lang_with_code, get_unique_langs
from surya.input.load import load_from_folder, load_from_file, load_lang_file
from surya.model.detection.segformer import load_model as load_detection_model, load_processor as load_detection_processor
from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.model.recognition.tokenizer import _tokenize
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "surya-ocr"
version = "0.4.14"
version = "0.4.15"
description = "OCR, layout, reading order, and line detection in 90+ languages"
authors = ["Vik Paruchuri <[email protected]>"]
readme = "README.md"
2 changes: 1 addition & 1 deletion reading_order.py
@@ -7,7 +7,7 @@
from surya.detection import batch_text_detection
from surya.input.load import load_from_folder, load_from_file
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.ordering.model import load_model
from surya.model.ordering.processor import load_processor
from surya.ordering import batch_ordering
7 changes: 0 additions & 7 deletions run_ocr_app.py
@@ -4,16 +4,9 @@


def run_app():
    parser = argparse.ArgumentParser(description="Run the streamlit OCR app")
    parser.add_argument("--math", action="store_true", help="Use math model for detection", default=False)
    args = parser.parse_args()

    cur_dir = os.path.dirname(os.path.abspath(__file__))
    ocr_app_path = os.path.join(cur_dir, "ocr_app.py")
    cmd = ["streamlit", "run", ocr_app_path]
    if args.math:
        cmd.append("--")
        cmd.append("--math")
    subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"})

if __name__ == "__main__":
8 changes: 5 additions & 3 deletions surya/detection.py
@@ -4,7 +4,7 @@
import numpy as np
from PIL import Image

from surya.model.detection.segformer import SegformerForRegressionMask
from surya.model.detection.model import EfficientViTForSemanticSegmentation
from surya.postprocessing.heatmap import get_and_clean_boxes
from surya.postprocessing.affinity import get_vertical_lines
from surya.input.processing import prepare_image_detection, split_image, get_total_splits, convert_if_not_rgb
@@ -19,12 +19,14 @@ def get_batch_size():
    batch_size = settings.DETECTOR_BATCH_SIZE
    if batch_size is None:
        batch_size = 6
        if settings.TORCH_DEVICE_MODEL == "mps":
            batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 24
            batch_size = 36
    return batch_size


def batch_detection(images: List, model: SegformerForRegressionMask, processor, batch_size=None) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]:
def batch_detection(images: List, model: EfficientViTForSemanticSegmentation, processor, batch_size=None) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]:
    assert all([isinstance(image, Image.Image) for image in images])
    if batch_size is None:
        batch_size = get_batch_size()
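
After this change, the default detection batch size resolves to `6` on CPU, `8` on `mps`, and `36` on `cuda`, unless `DETECTOR_BATCH_SIZE` is set. A standalone sketch of that selection logic is shown below. It does not use surya's `settings` object: `resolve_batch_size`, `override_from_env`, and the `DEFAULTS` table are hypothetical stand-ins written so the example runs on its own.

```python
import os
from typing import Mapping, Optional

# Post-commit defaults as they appear in the hunk above.
DEFAULTS = {"cpu": 6, "mps": 8, "cuda": 36}


def override_from_env(env: Mapping[str, str] = os.environ) -> Optional[int]:
    """Read DETECTOR_BATCH_SIZE from a mapping, returning None when it is unset."""
    raw = env.get("DETECTOR_BATCH_SIZE")
    return int(raw) if raw else None


def resolve_batch_size(device: str, override: Optional[int] = None) -> int:
    """An explicit override wins; otherwise fall back to the per-device default."""
    if override is not None:
        return override
    # Unknown devices fall back to the CPU default in this sketch.
    return DEFAULTS.get(device, DEFAULTS["cpu"])


for device in ("cpu", "mps", "cuda"):
    print(device, resolve_batch_size(device))
print("override", resolve_batch_size("cuda", override_from_env({"DETECTOR_BATCH_SIZE": "12"})))
```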
51 changes: 51 additions & 0 deletions surya/model/detection/config.py
@@ -0,0 +1,51 @@
from transformers import PretrainedConfig


class EfficientViTConfig(PretrainedConfig):
    r"""
    ```"""

    model_type = "efficientvit"

    def __init__(
        self,
        num_classes=2,
        num_channels=3,
        widths=(32, 64, 128, 256, 512),
        head_dim=32,
        num_stages=4,
        depths=(1, 1, 1, 6, 6),
        strides=(2, 2, 2, 2, 2),
        hidden_sizes=(32, 64, 160, 256),
        patch_size=(7, 7),
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        classifier_dropout_prob=0.0,
        layer_norm_eps=1e-6,
        decoder_layer_hidden_size=128,
        decoder_hidden_size=512,
        semantic_loss_ignore_index=255,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.widths = widths
        self.head_dim = head_dim

        self.num_channels = num_channels
        self.num_stages = num_stages
        self.depths = depths
        self.strides = strides
        self.hidden_sizes = hidden_sizes
        self.patch_size = patch_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.classifier_dropout_prob = classifier_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        self.decoder_hidden_size = decoder_hidden_size
        self.decoder_layer_hidden_size = decoder_layer_hidden_size
        self.semantic_loss_ignore_index = semantic_loss_ignore_index

        self.initializer_range = initializer_range