Skip to content

Commit

Permalink
Fix benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 7, 2025
1 parent 8530131 commit 8c3eb5f
Show file tree
Hide file tree
Showing 9 changed files with 16 additions and 255 deletions.
11 changes: 5 additions & 6 deletions benchmark/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from surya.benchmark.bbox import get_pdf_lines
from surya.benchmark.metrics import precision_recall
from surya.benchmark.tesseract import tesseract_parallel
from surya.model.detection.model import load_model, load_processor
from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb
from surya.detection import batch_text_detection
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.postprocessing.util import rescale_bbox
from surya.settings import settings
from surya.detection import DetectionPredictor

import os
import time
from tabulate import tabulate
Expand All @@ -27,8 +27,7 @@ def main():
parser.add_argument("--tesseract", action="store_true", help="Run tesseract as well.", default=False)
args = parser.parse_args()

model = load_model()
processor = load_processor()
det_predictor = DetectionPredictor()

if args.pdf_path is not None:
pathname = args.pdf_path
Expand Down Expand Up @@ -56,10 +55,10 @@ def main():

if settings.DETECTOR_STATIC_CACHE:
# Run through one batch to compile the model
batch_text_detection(images[:1], model, processor)
det_predictor(images[:1])

start = time.time()
predictions = batch_text_detection(images, model, processor)
predictions = det_predictor(images)
surya_time = time.time() - start

if args.tesseract:
Expand Down
149 changes: 0 additions & 149 deletions benchmark/gcloud_label.py

This file was deleted.

11 changes: 4 additions & 7 deletions benchmark/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
import json

from surya.benchmark.metrics import precision_recall
from surya.model.layout.model import load_model
from surya.model.layout.processor import load_processor
from surya.layout import LayoutPredictor
from surya.input.processing import convert_if_not_rgb
from surya.layout import batch_layout_detection
from surya.postprocessing.heatmap import draw_bboxes_on_image
from surya.settings import settings
import os
Expand All @@ -23,8 +21,7 @@ def main():
parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
args = parser.parse_args()

model = load_model()
processor = load_processor()
layout_predictor = LayoutPredictor()

pathname = "layout_bench"
# These have already been shuffled randomly, so sampling from the start is fine
Expand All @@ -33,10 +30,10 @@ def main():
images = convert_if_not_rgb(images)

if settings.LAYOUT_STATIC_CACHE:
batch_layout_detection(images[:1], model, processor)
layout_predictor(images[:1])

start = time.time()
layout_predictions = batch_layout_detection(images, model, processor)
layout_predictions = layout_predictor(images)
surya_time = time.time() - start

folder_name = os.path.basename(pathname).split(".")[0]
Expand Down
10 changes: 3 additions & 7 deletions benchmark/ordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import json

from surya.input.processing import convert_if_not_rgb
from surya.layout import batch_layout_detection
from surya.model.layout.model import load_model
from surya.model.layout.processor import load_processor
from surya.layout import LayoutPredictor
from surya.schema import Bbox
from surya.settings import settings
from surya.benchmark.metrics import rank_accuracy
Expand All @@ -21,9 +19,7 @@ def main():
parser.add_argument("--max", type=int, help="Maximum number of images to run benchmark on.", default=None)
args = parser.parse_args()

model = load_model()
processor = load_processor()

layout_predictor = LayoutPredictor()
pathname = "order_bench"
# These have already been shuffled randomly, so sampling from the start is fine
split = "train"
Expand All @@ -34,7 +30,7 @@ def main():
images = convert_if_not_rgb(images)

start = time.time()
layout_predictions = batch_layout_detection(images, model, processor)
layout_predictions = layout_predictor(images)
surya_time = time.time() - start

folder_name = os.path.basename(pathname).split(".")[0]
Expand Down
1 change: 0 additions & 1 deletion benchmark/profile.sh

This file was deleted.

39 changes: 0 additions & 39 deletions benchmark/pymupdf_test.py

This file was deleted.

11 changes: 4 additions & 7 deletions benchmark/table_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from tabulate import tabulate

from surya.input.processing import convert_if_not_rgb
from surya.model.table_rec.model import load_model
from surya.model.table_rec.processor import load_processor
from surya.tables import batch_table_recognition
from surya.table_rec import TableRecPredictor
from surya.settings import settings
from surya.benchmark.metrics import penalized_iou_score
from surya.benchmark.tatr import load_tatr, batch_inference_tatr
Expand All @@ -23,8 +21,7 @@ def main():
parser.add_argument("--tatr", action="store_true", help="Run table transformer.", default=False)
args = parser.parse_args()

model = load_model()
processor = load_processor()
table_rec_predictor = TableRecPredictor()

pathname = "table_rec_bench"
# These have already been shuffled randomly, so sampling from the start is fine
Expand All @@ -37,10 +34,10 @@ def main():

if settings.TABLE_REC_STATIC_CACHE:
# Run through one batch to compile the model
batch_table_recognition(images[:1], model, processor)
table_rec_predictor(images[:1])

start = time.time()
table_rec_predictions = batch_table_recognition(images, model, processor)
table_rec_predictions = table_rec_predictor(images)
surya_time = time.time() - start

folder_name = os.path.basename(pathname).split(".")[0]
Expand Down
38 changes: 0 additions & 38 deletions benchmark/tesseract_test.py

This file was deleted.

1 change: 0 additions & 1 deletion benchmark/viz.sh

This file was deleted.

0 comments on commit 8c3eb5f

Please sign in to comment.