Skip to content

Commit

Permalink
Create a A @ B.dag() operation
Browse files Browse the repository at this point in the history
  • Loading branch information
Ericgig committed Jan 15, 2025
1 parent 216fb31 commit a5671aa
Show file tree
Hide file tree
Showing 2 changed files with 226 additions and 0 deletions.
204 changes: 204 additions & 0 deletions qutip/core/data/matmul.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ cnp.import_array()
cdef extern from *:
void *PyMem_Calloc(size_t n, size_t elsize)


cdef extern from "<complex>" namespace "std" nogil:
double complex conj "conj"(double complex x)

# This function is templated over integral types on import to allow `idxint` to
# be any signed integer (though likely things will only work for >=32-bit). To
# change integral types, you only need to change the `idxint` definitions in
Expand All @@ -56,6 +60,7 @@ __all__ = [
'matmul', 'matmul_csr', 'matmul_dense', 'matmul_dia',
'matmul_csr_dense_dense', 'matmul_dia_dense_dense', 'matmul_dense_dia_dense',
'multiply', 'multiply_csr', 'multiply_dense', 'multiply_dia',
'matmul_dag', 'matmul_dag_data', 'matmul_dag_dense', 'matmul_dag_dense_csr_dense',
]


Expand Down Expand Up @@ -518,6 +523,162 @@ cpdef Dense matmul_dense_dia_dense(Dense left, Dia right, double complex scale=1
return out


cpdef Dense matmul_dag_dense_csr_dense(
Dense left, CSR right,
double complex scale=1, Dense out=None
):
"""
Perform the operation
``out := scale * (left @ right) + out``
where `left`, `right` and `out` are matrices. `scale` is a complex scalar,
defaulting to 1.
If `out` is not given, it will be allocated as if it were a zero matrix.
"""
if left.shape[1] != right.shape[1]:
raise ValueError(
"incompatible matrix shapes "
+ str(left.shape)
+ " and "
+ str(right.shape)
)
if (
out is not None and (
out.shape[0] != left.shape[0]
or out.shape[1] != right.shape[0]
)
):
raise ValueError(
"incompatible output shape, got "
+ str(out.shape)
+ " but needed "
+ str((left.shape[0], right.shape[0]))
)
cdef Dense tmp = None
if out is None:
out = dense.zeros(left.shape[0], right.shape[0], left.fortran)
if bool(left.fortran) != bool(out.fortran):
msg = (
"out matrix is {}-ordered".format('Fortran' if out.fortran else 'C')
+ " but input is {}-ordered".format('Fortran' if left.fortran else 'C')
)
warnings.warn(msg, OrderEfficiencyWarning)
# Rather than making loads of copies of the same code, we just moan at
# the user and then transpose one of the arrays. We prefer to have
# `right` in Fortran-order for cache efficiency.
if left.fortran:
tmp = out
out = out.reorder()
else:
left = left.reorder()
cdef idxint row, col, ptr, idx_l, idx_out, out_row
cdef idxint stride_in_col, stride_in_row, stride_out_row, stride_out_col
cdef idxint nrows=left.shape[0], ncols=right.shape[1]
cdef double complex val
stride_in_col = left.shape[0] if left.fortran else 1
stride_in_row = 1 if left.fortran else left.shape[1]
stride_out_col = out.shape[0] if out.fortran else 1
stride_out_row = 1 if out.fortran else out.shape[1]

# A @ B.dag = (B* @ A.T).T
# Todo: make a conj version of _matmul_csr_vector?
for row in range(right.shape[0]):
for ptr in range(right.row_index[row], right.row_index[row + 1]):
val = scale * conj(right.data[ptr])
col = right.col_index[ptr]
for out_row in range(out.shape[0]):
idx_out = row * stride_out_col + out_row * stride_out_row
idx_l = col * stride_in_col + out_row * stride_in_row
out.data[idx_out] += val * left.data[idx_l]
if tmp is None:
return out
memcpy(tmp.data, out.data, ncols * nrows * sizeof(double complex))
return tmp


cpdef Dense matmul_dag_dense(
Dense left, Dense right,
double complex scale=1., Dense out=None
):
# blas support matmul for normal, transpose, adjoint for fortran ordered
# matrices.
if left.shape[1] != right.shape[1]:
raise ValueError(
"incompatible matrix shapes "
+ str(left.shape)
+ " and "
+ str(right.shape)
)
if (
out is not None and (
out.shape[0] != left.shape[0]
or out.shape[1] != right.shape[0]
)
):
raise ValueError(
"incompatible output shape, got "
+ str(out.shape)
+ " but needed "
+ str((left.shape[0], right.shape[0]))
)
cdef Dense a, b, out_add=None
cdef double complex alpha = 1., out_scale = 0.
cdef int m, n, k = left.shape[1], lda, ldb, ldc
cdef char left_code, right_code

if not right.fortran:
# Need a conjugate, we compute the transpose of the desired results.
# A.conj @ B^op -> (B^T^op @ A.dag)^T
if out is not None and out.fortran:
# out is not the right order, create an empty out and add it back.
out_add = out
out = dense.empty(left.shape[0], right.shape[0], False)
elif out is None:
out = dense.empty(left.shape[0], right.shape[0], False)
else:
out_scale = 1.
m = right.shape[0]
n = left.shape[0]
a, b = right, left

lda = right.shape[1]
ldb = left.shape[0] if left.fortran else left.shape[1]
ldc = right.shape[0]

left_code = b'C'
right_code = b'T' if left.fortran else b'N'
else:
if out is not None and not out.fortran:
out_add = out
out = dense.empty(left.shape[0], right.shape[0], True)
elif out is None:
out = dense.empty(left.shape[0], right.shape[0], True)
else:
out_scale = 1.

m = left.shape[0]
n = right.shape[0]
a, b = left, right

lda = left.shape[0] if left.fortran else left.shape[1]
ldb = right.shape[0]
ldc = left.shape[0]

left_code = b'N' if left.fortran else b'T'
right_code = b'C'

blas.zgemm(
&left_code, &right_code, &m, &n, &k,
&scale, a.data, &lda, b.data, &ldb,
&out_scale, out.data, &ldc
)

if out_add is not None:
out = iadd_dense(out, out_add)

return out


cpdef CSR multiply_csr(CSR left, CSR right):
"""Element-wise multiplication of CSR matrices."""
if left.shape[0] != right.shape[0] or left.shape[1] != right.shape[1]:
Expand Down Expand Up @@ -732,6 +893,49 @@ multiply.add_specialisations([
], _defer=True)


cpdef Data matmul_dag_data(
Data left, Data right,
double complex scale=1, Dense out=None
):
return matmul(left, right.adjoint(), scale, out)


matmul_dag = _Dispatcher(
_inspect.Signature([
_inspect.Parameter('left', _inspect.Parameter.POSITIONAL_ONLY),
_inspect.Parameter('right', _inspect.Parameter.POSITIONAL_ONLY),
_inspect.Parameter('scale', _inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=1),
]),
name='matmul_dag',
module=__name__,
inputs=('left', 'right'),
out=True,
)
matmul_dag.__doc__ =\
"""
Compute the matrix multiplication of two matrices, with the operation
scale * (left @ right.dag)
where `scale` is (optionally) a scalar, and `left` and `right` are
matrices.
Parameters
----------
left : Data
The left operand as either a bra or a ket matrix.
right : Data
The right operand as a ket matrix.
scale : complex, optional
The scalar to multiply the output by.
"""
matmul_dag.add_specialisations([
(Dense, CSR, Dense, matmul_dag_dense_csr_dense),
(Dense, Dense, Dense, matmul_dag_dense),
(Data, Data, Data, matmul_dag_data),
], _defer=True)

del _inspect, _Dispatcher


Expand Down
22 changes: 22 additions & 0 deletions qutip/tests/core/data/test_mathematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,28 @@ def op_numpy(self, left, right):
]


class TestMatmulDag(BinaryOpMixin):
def op_numpy(self, left, right):
return np.matmul(left, right)

shapes = shapes_binary_matmul()
bad_shapes = shapes_binary_bad_matmul()
specialisations = [
pytest.param(
lambda l, r: data.matmul_dag_data(l, r.adjoint()),
CSR, CSR, CSR
),
pytest.param(
lambda l, r: data.matmul_dag_dense_csr_dense(l, r.adjoint()),
Dense, CSR, Dense
),
pytest.param(
lambda l, r: data.matmul_dag_dense(l, r.adjoint()),
Dense, Dense, Dense
),
]


class TestMultiply(BinaryOpMixin):
def op_numpy(self, left, right):
return left * right
Expand Down

0 comments on commit a5671aa

Please sign in to comment.