From 8c3eb5f6ce23710f850770351bfb71d40eaf5652 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Mon, 6 Jan 2025 22:04:07 -0500 Subject: [PATCH] Fix benchmarks --- benchmark/detection.py | 11 ++- benchmark/gcloud_label.py | 149 --------------------------------- benchmark/layout.py | 11 +-- benchmark/ordering.py | 10 +-- benchmark/profile.sh | 1 - benchmark/pymupdf_test.py | 39 --------- benchmark/table_recognition.py | 11 +-- benchmark/tesseract_test.py | 38 --------- benchmark/viz.sh | 1 - 9 files changed, 16 insertions(+), 255 deletions(-) delete mode 100644 benchmark/gcloud_label.py delete mode 100644 benchmark/profile.sh delete mode 100644 benchmark/pymupdf_test.py delete mode 100644 benchmark/tesseract_test.py delete mode 100644 benchmark/viz.sh diff --git a/benchmark/detection.py b/benchmark/detection.py index ab310a27..c121a2fb 100644 --- a/benchmark/detection.py +++ b/benchmark/detection.py @@ -6,12 +6,12 @@ from surya.benchmark.bbox import get_pdf_lines from surya.benchmark.metrics import precision_recall from surya.benchmark.tesseract import tesseract_parallel -from surya.model.detection.model import load_model, load_processor from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb -from surya.detection import batch_text_detection from surya.postprocessing.heatmap import draw_polys_on_image from surya.postprocessing.util import rescale_bbox from surya.settings import settings +from surya.detection import DetectionPredictor + import os import time from tabulate import tabulate @@ -27,8 +27,7 @@ def main(): parser.add_argument("--tesseract", action="store_true", help="Run tesseract as well.", default=False) args = parser.parse_args() - model = load_model() - processor = load_processor() + det_predictor = DetectionPredictor() if args.pdf_path is not None: pathname = args.pdf_path @@ -56,10 +55,10 @@ def main(): if settings.DETECTOR_STATIC_CACHE: # Run through one batch to compile the model - batch_text_detection(images[:1], model, processor) + det_predictor(images[:1]) start = time.time() - predictions = batch_text_detection(images, model, processor) + predictions = det_predictor(images) surya_time = time.time() - start if args.tesseract: diff --git a/benchmark/gcloud_label.py b/benchmark/gcloud_label.py deleted file mode 100644 index 5c9012df..00000000 --- a/benchmark/gcloud_label.py +++ /dev/null @@ -1,149 +0,0 @@ -import argparse -import json -from collections import defaultdict - -import datasets -from surya.settings import settings -from google.cloud import vision -import hashlib -import os -from tqdm import tqdm -import io - -DATA_DIR = os.path.join(settings.BASE_DIR, settings.DATA_DIR) -RESULT_DIR = os.path.join(settings.BASE_DIR, settings.RESULT_DIR) - -rtl_langs = ["ar", "fa", "he", "ur", "ps", "sd", "yi", "ug"] - -def polygon_to_bbox(polygon): - x = [vertex["x"] for vertex in polygon["vertices"]] - y = [vertex["y"] for vertex in polygon["vertices"]] - return (min(x), min(y), max(x), max(y)) - - -def text_with_break(text, property, is_rtl=False): - break_type = None - prefix = False - if property: - if "detectedBreak" in property: - if "type" in property["detectedBreak"]: - break_type = property["detectedBreak"]["type"] - if "isPrefix" in property["detectedBreak"]: - prefix = property["detectedBreak"]["isPrefix"] - break_char = "" - if break_type == 1: - break_char = " " - if break_type == 5: - break_char = "\n" - - if is_rtl: - prefix = not prefix - - if prefix: - text = break_char + text - else: - text = text + break_char - return text - - -def bbox_overlap_pct(box1, box2): - x1, y1, x2, y2 = box1 - x3, y3, x4, y4 = box2 - dx = min(x2, x4) - max(x1, x3) - dy = min(y2, y4) - max(y1, y3) - if (dx >= 0) and (dy >= 0): - return dx * dy / ((x2 - x1) * (y2 - y1)) - return 0 - - -def annotate_image(img, client, language, cache_dir): - img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format=img.format) - img_byte_arr = img_byte_arr.getvalue() - - img_hash = hashlib.sha256(img_byte_arr).hexdigest() - cache_path = os.path.join(cache_dir, f"{img_hash}.json") - if os.path.exists(cache_path): - with open(cache_path, "r") as f: - response = json.load(f) - return response - - gc_image = vision.Image(content=img_byte_arr) - context = vision.ImageContext(language_hints=[language]) - response = client.document_text_detection(image=gc_image, image_context=context) - response_json = vision.AnnotateImageResponse.to_json(response) - loaded_response = json.loads(response_json) - with open(cache_path, "w+") as f: - json.dump(loaded_response, f) - return loaded_response - - -def get_line_text(response, lines, is_rtl=False): - document = response["fullTextAnnotation"] - - bounds = [] - for page in document["pages"]: - for block in page["blocks"]: - for paragraph in block["paragraphs"]: - for word in paragraph["words"]: - for symbol in word["symbols"]: - bounds.append((symbol["boundingBox"], symbol["text"], symbol.get("property"))) - - bboxes = [(polygon_to_bbox(b[0]), text_with_break(b[1], b[2], is_rtl)) for b in bounds] - line_boxes = defaultdict(list) - for i, bbox in enumerate(bboxes): - max_overlap_pct = 0 - max_overlap_idx = None - for j, line in enumerate(lines): - overlap = bbox_overlap_pct(bbox[0], line) - if overlap > max_overlap_pct: - max_overlap_pct = overlap - max_overlap_idx = j - if max_overlap_idx is not None: - line_boxes[max_overlap_idx].append(bbox) - - ocr_lines = [] - for j, line in enumerate(lines): - ocr_bboxes = sorted(line_boxes[j], key=lambda x: x[0][0]) - if is_rtl: - ocr_bboxes = list(reversed(ocr_bboxes)) - ocr_text = "".join([b[1] for b in ocr_bboxes]) - ocr_lines.append(ocr_text) - - assert len(ocr_lines) == len(lines) - return ocr_lines - - -def main(): - parser = argparse.ArgumentParser(description="Label text in dataset with google cloud vision.") - parser.add_argument("--project_id", type=str, help="Google cloud project id.", required=True) - parser.add_argument("--service_account", type=str, help="Path to service account json.", required=True) - parser.add_argument("--max", type=int, help="Maximum number of pages to label.", default=None) - args = parser.parse_args() - - cache_dir = os.path.join(DATA_DIR, "gcloud_cache") - os.makedirs(cache_dir, exist_ok=True) - - dataset = datasets.load_dataset(settings.RECOGNITION_BENCH_DATASET_NAME, split="train") - client = vision.ImageAnnotatorClient.from_service_account_json(args.service_account) - - all_gc_lines = [] - for i in tqdm(range(len(dataset))): - img = dataset[i]["image"] - lines = dataset[i]["bboxes"] - language = dataset[i]["language"] - - response = annotate_image(img, client, language, cache_dir) - ocr_lines = get_line_text(response, lines, is_rtl=language in rtl_langs) - - all_gc_lines.append(ocr_lines) - - if args.max is not None and i >= args.max: - break - - with open(os.path.join(RESULT_DIR, "gcloud_ocr.json"), "w+") as f: - json.dump(all_gc_lines, f) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/benchmark/layout.py b/benchmark/layout.py index 41c7b771..a8294039 100644 --- a/benchmark/layout.py +++ b/benchmark/layout.py @@ -4,10 +4,8 @@ import json from surya.benchmark.metrics import precision_recall -from surya.model.layout.model import load_model -from surya.model.layout.processor import load_processor +from surya.layout import LayoutPredictor from surya.input.processing import convert_if_not_rgb -from surya.layout import batch_layout_detection from surya.postprocessing.heatmap import draw_bboxes_on_image from surya.settings import settings import os @@ -23,8 +21,7 @@ def main(): parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False) args = parser.parse_args() - model = load_model() - processor = load_processor() + layout_predictor = LayoutPredictor() pathname = "layout_bench" # These have already been shuffled randomly, so sampling from the start is fine @@ -33,10 +30,10 @@ def main(): images = convert_if_not_rgb(images) if settings.LAYOUT_STATIC_CACHE: - batch_layout_detection(images[:1], model, processor) + layout_predictor(images[:1]) start = time.time() - layout_predictions = batch_layout_detection(images, model, processor) + layout_predictions = layout_predictor(images) surya_time = time.time() - start folder_name = os.path.basename(pathname).split(".")[0] diff --git a/benchmark/ordering.py b/benchmark/ordering.py index b48041ec..a514217d 100644 --- a/benchmark/ordering.py +++ b/benchmark/ordering.py @@ -4,9 +4,7 @@ import json from surya.input.processing import convert_if_not_rgb -from surya.layout import batch_layout_detection -from surya.model.layout.model import load_model -from surya.model.layout.processor import load_processor +from surya.layout import LayoutPredictor from surya.schema import Bbox from surya.settings import settings from surya.benchmark.metrics import rank_accuracy @@ -21,9 +19,7 @@ def main(): parser.add_argument("--max", type=int, help="Maximum number of images to run benchmark on.", default=None) args = parser.parse_args() - model = load_model() - processor = load_processor() - + layout_predictor = LayoutPredictor() pathname = "order_bench" # These have already been shuffled randomly, so sampling from the start is fine split = "train" @@ -34,7 +30,7 @@ def main(): images = convert_if_not_rgb(images) start = time.time() - layout_predictions = batch_layout_detection(images, model, processor) + layout_predictions = layout_predictor(images) surya_time = time.time() - start folder_name = os.path.basename(pathname).split(".")[0] diff --git a/benchmark/profile.sh b/benchmark/profile.sh deleted file mode 100644 index f6cecc41..00000000 --- a/benchmark/profile.sh +++ /dev/null @@ -1 +0,0 @@ -python -m cProfile -s time -o data/profile.pstats detect_text.py data/benchmark/nyt2.pdf --max 10 \ No newline at end of file diff --git a/benchmark/pymupdf_test.py b/benchmark/pymupdf_test.py deleted file mode 100644 index 7d82b617..00000000 --- a/benchmark/pymupdf_test.py +++ /dev/null @@ -1,39 +0,0 @@ -import argparse -import os - -from surya.benchmark.bbox import get_pdf_lines -from surya.postprocessing.heatmap import draw_bboxes_on_image - -from surya.input.processing import open_pdf, get_page_images -from surya.settings import settings - - -def main(): - parser = argparse.ArgumentParser(description="Draw pymupdf line bboxes on images.") - parser.add_argument("pdf_path", type=str, help="Path to PDF to detect bboxes in.") - parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "pymupdf")) - args = parser.parse_args() - - doc = open_pdf(args.pdf_path) - page_count = len(doc) - page_indices = list(range(page_count)) - - images = get_page_images(doc, page_indices) - doc.close() - - image_sizes = [img.size for img in images] - pdf_lines = get_pdf_lines(args.pdf_path, image_sizes) - - folder_name = os.path.basename(args.pdf_path).split(".")[0] - result_path = os.path.join(args.results_dir, folder_name) - os.makedirs(result_path, exist_ok=True) - - for idx, (img, bboxes) in enumerate(zip(images, pdf_lines)): - bbox_image = draw_bboxes_on_image(bboxes, img) - bbox_image.save(os.path.join(result_path, f"{idx}_bbox.png")) - - print(f"Wrote results to {result_path}") - -if __name__ == "__main__": - main() - diff --git a/benchmark/table_recognition.py b/benchmark/table_recognition.py index 376543de..cfa96d59 100644 --- a/benchmark/table_recognition.py +++ b/benchmark/table_recognition.py @@ -5,9 +5,7 @@ from tabulate import tabulate from surya.input.processing import convert_if_not_rgb -from surya.model.table_rec.model import load_model -from surya.model.table_rec.processor import load_processor -from surya.tables import batch_table_recognition +from surya.table_rec import TableRecPredictor from surya.settings import settings from surya.benchmark.metrics import penalized_iou_score from surya.benchmark.tatr import load_tatr, batch_inference_tatr @@ -23,8 +21,7 @@ def main(): parser.add_argument("--tatr", action="store_true", help="Run table transformer.", default=False) args = parser.parse_args() - model = load_model() - processor = load_processor() + table_rec_predictor = TableRecPredictor() pathname = "table_rec_bench" # These have already been shuffled randomly, so sampling from the start is fine @@ -37,10 +34,10 @@ def main(): if settings.TABLE_REC_STATIC_CACHE: # Run through one batch to compile the model - batch_table_recognition(images[:1], model, processor) + table_rec_predictor(images[:1]) start = time.time() - table_rec_predictions = batch_table_recognition(images, model, processor) + table_rec_predictions = table_rec_predictor(images) surya_time = time.time() - start folder_name = os.path.basename(pathname).split(".")[0] diff --git a/benchmark/tesseract_test.py b/benchmark/tesseract_test.py deleted file mode 100644 index 49ca86b4..00000000 --- a/benchmark/tesseract_test.py +++ /dev/null @@ -1,38 +0,0 @@ -import argparse -import os - -from surya.benchmark.tesseract import tesseract_bboxes -from surya.postprocessing.heatmap import draw_bboxes_on_image - -from surya.input.processing import open_pdf, get_page_images -from surya.settings import settings - - -def main(): - parser = argparse.ArgumentParser(description="Draw tesseract bboxes on imagese.") - parser.add_argument("pdf_path", type=str, help="Path to PDF to detect bboxes in.") - parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "tesseract")) - args = parser.parse_args() - - doc = open_pdf(args.pdf_path) - page_count = len(doc) - page_indices = list(range(page_count)) - - images = get_page_images(doc, page_indices) - doc.close() - - img_boxes = [tesseract_bboxes(img) for img in images] - - folder_name = os.path.basename(args.pdf_path).split(".")[0] - result_path = os.path.join(args.results_dir, folder_name) - os.makedirs(result_path, exist_ok=True) - - for idx, (img, bboxes) in enumerate(zip(images, img_boxes)): - bbox_image = draw_bboxes_on_image(bboxes, img) - bbox_image.save(os.path.join(result_path, f"{idx}_bbox.png")) - - print(f"Wrote results to {result_path}") - -if __name__ == "__main__": - main() - diff --git a/benchmark/viz.sh b/benchmark/viz.sh deleted file mode 100644 index f642982d..00000000 --- a/benchmark/viz.sh +++ /dev/null @@ -1 +0,0 @@ -snakeviz data/profile.pstats \ No newline at end of file