Skip to content

Commit

Permalink
Update pyopal.align to support reusing a thread pool
Browse files Browse the repository at this point in the history
  • Loading branch information
althonos committed Jan 18, 2024
1 parent c0e4278 commit dc94320
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion pyopal/_align.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import functools
import multiprocessing.pool
import os
Expand All @@ -22,6 +23,19 @@
ALIGN_OVERFLOW = Literal["simple", "buckets"]
ALIGN_ALGORITHM = Literal["nw", "hw", "ov", "sw"]

T = typing.TypeVar("T")

@contextlib.contextmanager
def nullcontext(enter_result: T) -> typing.Iterator[T]:
"""Return a context manager that returns its input and does nothing.
Adapted from `contextlib.nullcontext` for backwards compatibility
with Python 3.6.
"""
yield obj


@typing.overload
def align(
query: typing.Union[str, bytes, bytearray],
Expand Down Expand Up @@ -63,6 +77,7 @@ def align(
overflow: Literal["simple", "buckets"] = "buckets",
algorithm: Literal["nw", "hw", "ov", "sw"] = "sw",
threads: int = 0,
pool: typing.Optional[multiprocessing.pool.ThreadPool] = None
) -> typing.Iterator[ScoreResult]:
"""Align the query sequence to every database sequence in parallel.
Expand Down Expand Up @@ -102,6 +117,11 @@ def align(
of threads reported by `os.cpu_count`. If one given, use
the main threads for aligning, otherwise spawns a
`multiprocessing.pool.ThreadPool`.
pool (`multiprocessing.pool.ThreadPool`): A running pool
instance to use for parallelization. Useful for reusing
the same pool across several calls of `~pyopal.align`.
If `None` give, spawn a new pool based on the ``threads``
argument.
Yields:
`~pyopal.ScoreResult`: Results for the alignment of the query
Expand Down Expand Up @@ -143,9 +163,14 @@ def align(
algorithm=algorithm
)
else:
#
if pool is None:
_pool_context = multiprocessing.pool.ThreadPool(threads)
else:
_pool_context = nullcontext(pool)
# cut the database in chunks of similar length
chunk_length = len(database) // threads
with multiprocessing.pool.ThreadPool(threads) as pool:
with _pool_context as pool:
align = functools.partial(
aligner._align_slice, # type: ignore
query,
Expand Down

0 comments on commit dc94320

Please sign in to comment.