diff --git a/surya/table_rec/__init__.py b/surya/table_rec/__init__.py index 6673983..01f26b7 100644 --- a/surya/table_rec/__init__.py +++ b/surya/table_rec/__init__.py @@ -22,7 +22,7 @@ class TableRecPredictor(BasePredictor): default_batch_sizes = { "cpu": 8, "mps": 8, - "cuda": 128 + "cuda": 64 } def __call__(self, images: List[Image.Image], batch_size: int | None = None) -> List[TableResult]: @@ -295,13 +295,17 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si cell_polygonbox = PolygonBox(polygon=cell_polygon) intersection_pct = cell_polygonbox.intersection_pct(spanning_cell) # Make sure cells intersect, and that the spanning cell is wider than the current cell (takes up multiple columns) - if intersection_pct > .9 and spanning_cell.width > cell_polygonbox.width: - cell_added = True - if zz not in used_spanning_cells: - used_spanning_cells.add(zz) - spanning_cell.col_id = l - cells.append(spanning_cell) - skip_columns = spanning_cell.colspan - 1 # Skip columns that are part of the spanning cell + correct_col_width = sum([col.width for col in columns[l:l + spanning_cell.colspan]]) + if intersection_pct > .9: + if spanning_cell.width > (correct_col_width * .85): + cell_added = True + if zz not in used_spanning_cells: + used_spanning_cells.add(zz) + spanning_cell.col_id = l + cells.append(spanning_cell) + skip_columns = spanning_cell.colspan - 1 # Skip columns that are part of the spanning cell + else: + used_spanning_cells.add(zz) # Skip this spanning cell if not cell_added: cells.append(