Skip to content

Commit

Permalink
Refactor benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 11, 2025
1 parent 03e722f commit 72d0064
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 83 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,21 @@ jobs:
poetry install
- name: Run detection benchmark test
run: |
poetry run python benchmark/detection.py --max 2
poetry run python benchmark/detection.py --max_rows 2
poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection
- name: Run recognition benchmark test
run: |
poetry run python benchmark/recognition.py --max 2
poetry run python benchmark/recognition.py --max_rows 2
poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/rec_bench/results.json --bench_type recognition
- name: Run layout benchmark test
run: |
poetry run python benchmark/layout.py --max 5
poetry run python benchmark/layout.py --max_rows 5
poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout
- name: Run ordering benchmark
run: |
poetry run python benchmark/ordering.py --max 5
poetry run python benchmark/ordering.py --max_rows 5
poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering
- name: Run table recognition benchmark
run: |
poetry run python benchmark/table_recognition.py --max 5
poetry run python benchmark/table_recognition.py --max_rows 5
poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/table_rec_bench/results.json --bench_type table_recognition
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -425,10 +425,10 @@ You can benchmark the performance of surya on your machine.
This will evaluate tesseract and surya for text line detection across a randomly sampled set of images from [doclaynet](https://huggingface.co/datasets/vikp/doclaynet_bench).

```shell
python benchmark/detection.py --max 256
python benchmark/detection.py --max_rows 256
```

- `--max` controls how many images to process for the benchmark
- `--max_rows` controls how many images to process for the benchmark
- `--debug` will render images and detected bboxes
- `--pdf_path` will let you specify a pdf to benchmark instead of the default data
- `--results_dir` will let you specify a directory to save results to instead of the default one
Expand All @@ -441,7 +441,7 @@ This will evaluate surya and optionally tesseract on multilingual pdfs from comm
python benchmark/recognition.py --tesseract
```

- `--max` controls how many images to process for the benchmark
- `--max_rows` controls how many images to process for the benchmark
- `--debug 2` will render images with detected text
- `--results_dir` will let you specify a directory to save results to instead of the default one
- `--tesseract` will run the benchmark with tesseract. You have to run `sudo apt-get install tesseract-ocr-all` to install all tesseract data, and set `TESSDATA_PREFIX` to the path to the tesseract data folder.
Expand All @@ -457,7 +457,7 @@ This will evaluate surya on the publaynet dataset.
python benchmark/layout.py
```

- `--max` controls how many images to process for the benchmark
- `--max_rows` controls how many images to process for the benchmark
- `--debug` will render images with layout bboxes drawn on them
- `--results_dir` will let you specify a directory to save results to instead of the default one

Expand All @@ -467,17 +467,17 @@ python benchmark/layout.py
python benchmark/ordering.py
```

- `--max` controls how many images to process for the benchmark
- `--max_rows` controls how many images to process for the benchmark
- `--debug` will render images with detected text
- `--results_dir` will let you specify a directory to save results to instead of the default one

**Table Recognition**

```shell
python benchmark/table_recognition.py --max 1024 --tatr
python benchmark/table_recognition.py --max_rows 1024 --tatr
```

- `--max` controls how many images to process for the benchmark
- `--max_rows` controls how many images to process for the benchmark
- `--debug` will render images with detected text
- `--results_dir` will let you specify a directory to save results to instead of the default one
- `--tatr` specifies whether to also run table transformer
Expand Down
40 changes: 20 additions & 20 deletions benchmark/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import copy
import json

import click

from benchmark.utils.bbox import get_pdf_lines
from benchmark.utils.metrics import precision_recall
from benchmark.utils.tesseract import tesseract_parallel
Expand All @@ -18,33 +20,31 @@
import datasets


def main():
parser = argparse.ArgumentParser(description="Detect bboxes in a PDF.")
parser.add_argument("--pdf_path", type=str, help="Path to PDF to detect bboxes in.", default=None)
parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
parser.add_argument("--max", type=int, help="Maximum number of pdf pages to OCR.", default=100)
parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
parser.add_argument("--tesseract", action="store_true", help="Run tesseract as well.", default=False)
args = parser.parse_args()

@click.command(help="Benchmark detection model.")
@click.option("--pdf_path", type=str, help="Path to PDF to detect bboxes in.", default=None)
@click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
@click.option("--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=100)
@click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
@click.option("--tesseract", is_flag=True, help="Run tesseract as well.", default=False)
def main(pdf_path: str, results_dir: str, max_rows: int, debug: bool, tesseract: bool):
det_predictor = DetectionPredictor()

if args.pdf_path is not None:
pathname = args.pdf_path
doc = open_pdf(args.pdf_path)
if pdf_path is not None:
pathname = pdf_path
doc = open_pdf(pdf_path)
page_count = len(doc)
page_indices = list(range(page_count))
page_indices = page_indices[:args.max]
page_indices = page_indices[:max_rows]

images = get_page_images(doc, page_indices)
doc.close()

image_sizes = [img.size for img in images]
correct_boxes = get_pdf_lines(args.pdf_path, image_sizes)
correct_boxes = get_pdf_lines(pdf_path, image_sizes)
else:
pathname = "det_bench"
# These have already been shuffled randomly, so sampling from the start is fine
dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{args.max}]")
dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{max_rows}]")
images = list(dataset["image"])
images = convert_if_not_rgb(images)
correct_boxes = []
Expand All @@ -61,7 +61,7 @@ def main():
predictions = det_predictor(images)
surya_time = time.time() - start

if args.tesseract:
if tesseract:
start = time.time()
tess_predictions = tesseract_parallel(images)
tess_time = time.time() - start
Expand All @@ -70,7 +70,7 @@ def main():
tess_time = None

folder_name = os.path.basename(pathname).split(".")[0]
result_path = os.path.join(args.results_dir, folder_name)
result_path = os.path.join(results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)

page_metrics = collections.OrderedDict()
Expand All @@ -89,14 +89,14 @@ def main():
"tesseract": tess_metrics
}

if args.debug:
if debug:
bbox_image = draw_polys_on_image(surya_polys, copy.deepcopy(images[idx]))
bbox_image.save(os.path.join(result_path, f"{idx}_bbox.png"))

mean_metrics = {}
metric_types = sorted(page_metrics[0]["surya"].keys())
models = ["surya"]
if args.tesseract:
if tesseract:
models.append("tesseract")

for k in models:
Expand Down Expand Up @@ -124,7 +124,7 @@ def main():
table_data = [
["surya", surya_time, surya_time / len(images)] + [mean_metrics["surya"][m] for m in metric_types],
]
if args.tesseract:
if tesseract:
table_data.append(
["tesseract", tess_time, tess_time / len(images)] + [mean_metrics["tesseract"][m] for m in metric_types]
)
Expand Down
21 changes: 10 additions & 11 deletions benchmark/layout.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import argparse
import collections
import copy
import json

import click

from benchmark.utils.metrics import precision_recall
from surya.layout import LayoutPredictor
from surya.input.processing import convert_if_not_rgb
Expand All @@ -14,18 +15,16 @@
import datasets


def main():
parser = argparse.ArgumentParser(description="Benchmark surya layout model.")
parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
parser.add_argument("--max", type=int, help="Maximum number of images to run benchmark on.", default=100)
parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
args = parser.parse_args()

@click.command(help="Benchmark surya layout model.")
@click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
@click.option("--max_rows", type=int, help="Maximum number of images to run benchmark on.", default=100)
@click.option("--debug", is_flag=True, help="Run in debug mode.", default=False)
def main(results_dir: str, max_rows: int, debug: bool):
layout_predictor = LayoutPredictor()

pathname = "layout_bench"
# These have already been shuffled randomly, so sampling from the start is fine
dataset = datasets.load_dataset(settings.LAYOUT_BENCH_DATASET_NAME, split=f"train[:{args.max}]")
dataset = datasets.load_dataset(settings.LAYOUT_BENCH_DATASET_NAME, split=f"train[:{max_rows}]")
images = list(dataset["image"])
images = convert_if_not_rgb(images)

Expand All @@ -37,7 +36,7 @@ def main():
surya_time = time.time() - start

folder_name = os.path.basename(pathname).split(".")[0]
result_path = os.path.join(args.results_dir, folder_name)
result_path = os.path.join(results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)

label_alignment = { # First is publaynet, second is surya
Expand Down Expand Up @@ -66,7 +65,7 @@ def main():

page_metrics[idx] = page_results

if args.debug:
if debug:
bbox_image = draw_bboxes_on_image(all_correct_bboxes, copy.deepcopy(images[idx]))
bbox_image.save(os.path.join(result_path, f"{idx}_layout.png"))

Expand Down
19 changes: 9 additions & 10 deletions benchmark/ordering.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import argparse
import collections
import json

import click

from surya.input.processing import convert_if_not_rgb
from surya.layout import LayoutPredictor
from surya.common.polygon import PolygonBox
Expand All @@ -12,18 +13,16 @@
import datasets


def main():
parser = argparse.ArgumentParser(description="Benchmark surya layout for reading order.")
parser.add_argument("--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
parser.add_argument("--max", type=int, help="Maximum number of images to run benchmark on.", default=None)
args = parser.parse_args()

@click.command(help="Benchmark surya layout for reading order.")
@click.option("--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
@click.option("--max_rows", type=int, help="Maximum number of images to run benchmark on.", default=None)
def main(results_dir: str, max_rows: int):
layout_predictor = LayoutPredictor()
pathname = "order_bench"
# These have already been shuffled randomly, so sampling from the start is fine
split = "train"
if args.max is not None:
split = f"train[:{args.max}]"
if max_rows is not None:
split = f"train[:{max_rows}]"
dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split)
images = list(dataset["image"])
images = convert_if_not_rgb(images)
Expand All @@ -33,7 +32,7 @@ def main():
surya_time = time.time() - start

folder_name = os.path.basename(pathname).split(".")[0]
result_path = os.path.join(args.results_dir, folder_name)
result_path = os.path.join(results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)

page_metrics = collections.OrderedDict()
Expand Down
45 changes: 22 additions & 23 deletions benchmark/recognition.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import argparse
from collections import defaultdict

import click

from benchmark.utils.scoring import overlap_score
from surya.input.processing import convert_if_not_rgb
from surya.debug.text import draw_text_on_image
Expand All @@ -16,28 +18,25 @@

KEY_LANGUAGES = ["Chinese", "Spanish", "English", "Arabic", "Hindi", "Bengali", "Russian", "Japanese"]


def main():
parser = argparse.ArgumentParser(description="Detect bboxes in a PDF.")
parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
parser.add_argument("--max", type=int, help="Maximum number of pdf pages to OCR.", default=None)
parser.add_argument("--debug", type=int, help="Debug level - 1 dumps bad detection info, 2 writes out images.", default=0)
parser.add_argument("--tesseract", action="store_true", help="Run tesseract instead of surya.", default=False)
parser.add_argument("--langs", type=str, help="Specify certain languages to benchmark.", default=None)
parser.add_argument("--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28)
parser.add_argument("--specify_language", action="store_true", help="Pass language codes into the model.", default=False)
args = parser.parse_args()

@click.command(help="Benchmark recognition model.")
@click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
@click.option("--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=None)
@click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
@click.option("--tesseract", is_flag=True, help="Run tesseract instead of surya.", default=False)
@click.option("--langs", type=str, help="Specify certain languages to benchmark.", default=None)
@click.option("--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28)
@click.option("--specify_language", is_flag=True, help="Pass language codes into the model.", default=False)
def main(results_dir: str, max_rows: int, debug: bool, tesseract: bool, langs: str, tess_cpus: int, specify_language: bool):
rec_predictor = RecognitionPredictor()

split = "train"
if args.max:
split = f"train[:{args.max}]"
if max_rows:
split = f"train[:{max_rows}]"

dataset = datasets.load_dataset(settings.RECOGNITION_BENCH_DATASET_NAME, split=split)

if args.langs:
langs = args.langs.split(",")
if langs:
langs = langs.split(",")
dataset = dataset.filter(lambda x: x["language"] in langs, num_proc=4)

images = list(dataset["image"])
Expand All @@ -61,7 +60,7 @@ def main():
rec_predictor(images[:1], lang_list[:1], bboxes=bboxes[:1])

start = time.time()
predictions_by_image = rec_predictor(images, lang_list if args.specify_language else n_list, bboxes=bboxes)
predictions_by_image = rec_predictor(images, lang_list if specify_language else n_list, bboxes=bboxes)
surya_time = time.time() - start

surya_scores = defaultdict(list)
Expand All @@ -82,13 +81,13 @@ def main():
}
}

result_path = os.path.join(args.results_dir, "rec_bench")
result_path = os.path.join(results_dir, "rec_bench")
os.makedirs(result_path, exist_ok=True)

with open(os.path.join(result_path, "surya_scores.json"), "w+") as f:
json.dump(surya_scores, f)

if args.tesseract:
if tesseract:
tess_valid = []
tess_langs = []
for idx, lang in enumerate(lang_list):
Expand All @@ -104,7 +103,7 @@ def main():
tess_bboxes = [bboxes[i] for i in tess_valid]
tess_reference = [line_text[i] for i in tess_valid]
start = time.time()
tess_predictions = tesseract_ocr_parallel(tess_imgs, tess_bboxes, tess_langs, cpus=args.tess_cpus)
tess_predictions = tesseract_ocr_parallel(tess_imgs, tess_bboxes, tess_langs, cpus=tess_cpus)
tesseract_time = time.time() - start

tess_scores = defaultdict(list)
Expand All @@ -130,15 +129,15 @@ def main():
table_data = [
["surya", benchmark_stats["surya"]["time_per_img"], benchmark_stats["surya"]["avg_score"]] + [benchmark_stats["surya"]["lang_scores"][l] for l in key_languages],
]
if args.tesseract:
if tesseract:
table_data.append(
["tesseract", benchmark_stats["tesseract"]["time_per_img"], benchmark_stats["tesseract"]["avg_score"]] + [benchmark_stats["tesseract"]["lang_scores"].get(l, 0) for l in key_languages]
)

print(tabulate(table_data, headers=table_headers, tablefmt="github"))
print("Only a few major languages are displayed. See the result path for additional languages.")

if args.debug >= 1:
if debug >= 1:
bad_detections = []
for idx, (score, lang) in enumerate(zip(flat_surya_scores, lang_list)):
if score < .8:
Expand All @@ -147,7 +146,7 @@ def main():
with open(os.path.join(result_path, "bad_detections.json"), "w+") as f:
json.dump(bad_detections, f)

if args.debug == 2:
if debug == 2:
for idx, (image, pred, ref_text, bbox, lang) in enumerate(zip(images, predictions_by_image, line_text, bboxes, lang_list)):
pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png"
ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png"
Expand Down
Loading

0 comments on commit 72d0064

Please sign in to comment.