Skip to content

Commit

Permalink
Lower batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 20, 2025
1 parent 9f4bbee commit 317f71d
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions surya/table_rec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TableRecPredictor(BasePredictor):
default_batch_sizes = {
"cpu": 8,
"mps": 8,
"cuda": 128
"cuda": 64
}

def __call__(self, images: List[Image.Image], batch_size: int | None = None) -> List[TableResult]:
Expand Down Expand Up @@ -295,13 +295,17 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si
cell_polygonbox = PolygonBox(polygon=cell_polygon)
intersection_pct = cell_polygonbox.intersection_pct(spanning_cell)
# Make sure cells intersect, and that the spanning cell is wider than the current cell (takes up multiple columns)
if intersection_pct > .9 and spanning_cell.width > cell_polygonbox.width:
cell_added = True
if zz not in used_spanning_cells:
used_spanning_cells.add(zz)
spanning_cell.col_id = l
cells.append(spanning_cell)
skip_columns = spanning_cell.colspan - 1 # Skip columns that are part of the spanning cell
correct_col_width = sum([col.width for col in columns[l:l + spanning_cell.colspan]])
if intersection_pct > .9:
if spanning_cell.width > (correct_col_width * .85):
cell_added = True
if zz not in used_spanning_cells:
used_spanning_cells.add(zz)
spanning_cell.col_id = l
cells.append(spanning_cell)
skip_columns = spanning_cell.colspan - 1 # Skip columns that are part of the spanning cell
else:
used_spanning_cells.add(zz) # Skip this spanning cell

if not cell_added:
cells.append(
Expand Down

0 comments on commit 317f71d

Please sign in to comment.