Skip to content

Commit

Permalink
Make _encode and _decode C methods of Database more generic
Browse files (browse the repository at this point in the history)
  • Loading branch information
althonos committed Jan 15, 2024
1 parent 3a06cde commit 0a3bda7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
4 changes: 2 additions & 2 deletions pyopal/lib.pxd
Original file line number | Diff line number | Diff line change
Expand Up @@ -107,8 +107,8 @@ cdef class Database:
cdef vector[int] _lengths
cdef searchfn_t _search

cdef seq_t _encode(self, object sequence) except *
cdef str _decode(self, seq_t encoded, int length) except *
cdef digit_t* _encode(self, object sequence) except *
cdef str _decode(self, digit_t* encoded, int length) except *

cpdef void clear(self) except *
cpdef void extend(self, object sequences) except *
Expand Down
30 changes: 15 additions & 15 deletions pyopal/lib.pyx
Original file line number | Diff line number | Diff line change
Expand Up @@ -749,30 +749,30 @@ cdef class Database:

# --- Encoding -------------------------------------------------------------

cdef seq_t _encode(self, object sequence) except *:
cdef bytes encoded
cdef char* indices
cdef seq_t dst
cdef size_t length = len(sequence)
cdef digit_t* _encode(self, object sequence) except *:
cdef bytes encoded
cdef char* indices
cdef digit_t* dst
cdef size_t length = len(sequence)

dst = pyshared(<digit_t*> PyMem_Calloc(length, sizeof(digit_t)))
dst = <digit_t*> PyMem_Calloc(length, sizeof(digit_t))
if dst == nullptr:
raise MemoryError("Failed to allocate sequence data")

if SYS_IMPLEMENTATION_NAME == "cpython":
if isinstance(sequence, str):
sequence = sequence.encode('ascii')
view = PyMemoryView_FromMemory(<char*> dst.get(), length * sizeof(digit_t), PyBUF_WRITE)
view = PyMemoryView_FromMemory(<char*> dst, length * sizeof(digit_t), PyBUF_WRITE)
self.alphabet.encode_raw(sequence, view)
else:
encoded = self.alphabet.encode(sequence)
indices = <char*> encoded
memcpy(<void*> &dst.get()[0], <void*> indices, length * sizeof(digit_t))
memcpy(<void*> &dst[0], <void*> indices, length * sizeof(digit_t))

return dst

cdef str _decode(self, seq_t encoded, int length) except *:
view = PyMemoryView_FromMemory(<char*> encoded.get(), length, PyBUF_READ)
cdef str _decode(self, digit_t* encoded, int length) except *:
view = PyMemoryView_FromMemory(<char*> encoded, length, PyBUF_READ)
return self.alphabet.decode(view)

# --- Sequence interface ---------------------------------------------------
Expand All @@ -795,13 +795,13 @@ cdef class Database:
if index_ < 0 or (<size_t> index_) >= size:
raise IndexError(index)

return self._decode(self._sequences[index_], self._lengths[index_])
return self._decode(self._sequences[index_].get(), self._lengths[index_])

def __setitem__(self, ssize_t index, object sequence):
cdef size_t size
cdef ssize_t index_ = index
cdef int length = len(sequence)
cdef seq_t encoded = self._encode(sequence)
cdef seq_t encoded = pyshared(self._encode(sequence))

with self.lock.write:
size = self._sequences.size()
Expand Down Expand Up @@ -888,7 +888,7 @@ cdef class Database:
"""
cdef int length = len(sequence)
cdef seq_t encoded = self._encode(sequence)
cdef seq_t encoded = pyshared(self._encode(sequence))

with self.lock.write:
self._sequences.push_back(encoded)
Expand Down Expand Up @@ -931,7 +931,7 @@ cdef class Database:
cdef size_t size
cdef int length = len(sequence)
cdef ssize_t index_ = index
cdef seq_t encoded = self._encode(sequence)
cdef seq_t encoded = pyshared(self._encode(sequence))

with self.lock.write:
size = self._sequences.size()
Expand Down Expand Up @@ -1131,7 +1131,7 @@ cdef class Database:
raise ValueError("database and score matrix have different alphabets")

# encode query
query = self._encode(sequence)
query = pyshared(self._encode(sequence))

# search database
with self.lock.read:
Expand Down

0 comments on commit 0a3bda7

Please sign in to comment.