Skip to content

Commit

Permalink
bmcage#77 fix bugs with jac_times_vecfc and jac_times_setupfn
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed May 5, 2019
1 parent 4e8303c commit eb7f3fd
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 25 deletions.
43 changes: 42 additions & 1 deletion scikits/odes/sundials/ida.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,44 @@ cdef class IDA_WrapPrecSolveFunction(IDA_PrecSolveFunction):
cdef int with_userdata
cpdef set_prec_solvefn(self, object prec_solvefn)

cdef class IDA_JacTimesVecFunction:
cpdef int evaluate(self,
DTYPE_t t,
np.ndarray[DTYPE_t, ndim=1] yy,
np.ndarray[DTYPE_t, ndim=1] yp,
np.ndarray[DTYPE_t, ndim=1] rr,
np.ndarray[DTYPE_t, ndim=1] v,
np.ndarray[DTYPE_t, ndim=1] Jv,
DTYPE_t cj,
object userdata = *) except? -1

cdef class IDA_WrapJacTimesVecFunction(IDA_JacTimesVecFunction):
cpdef object _jac_times_vecfn
cdef int with_userdata
cpdef set_jac_times_vecfn(self, object jac_times_vecfn)

cdef class IDA_JacTimesSetupFunction:
cpdef int evaluate(self,
DTYPE_t tt,
np.ndarray[DTYPE_t, ndim=1] yy,
np.ndarray[DTYPE_t, ndim=1] yp,
np.ndarray[DTYPE_t, ndim=1] rr,
DTYPE_t cj,
object userdata = *) except? -1

cdef class IDA_WrapJacTimesSetupFunction(IDA_JacTimesSetupFunction):
cpdef object _jac_times_setupfn
cdef int with_userdata
cpdef set_jac_times_setupfn(self, object jac_times_setupfn)

cpdef int evaluate(self,
DTYPE_t tt,
np.ndarray[DTYPE_t, ndim=1] yy,
np.ndarray[DTYPE_t, ndim=1] yp,
np.ndarray[DTYPE_t, ndim=1] rr,
DTYPE_t cj,
object userdata = *) except? -1

cdef class IDA_ContinuationFunction:
cpdef object _fn
cpdef int evaluate(self, DTYPE_t t, np.ndarray[DTYPE_t, ndim=1] y,
Expand All @@ -91,12 +129,15 @@ cdef class IDA_WrapErrHandler(IDA_ErrHandler):


cdef class IDA_data:
cdef np.ndarray yy_tmp, yp_tmp, residual_tmp, jac_tmp, g_tmp, z_tmp, rvec_tmp
cdef np.ndarray yy_tmp, yp_tmp, residual_tmp, jac_tmp
cdef np.ndarray g_tmp, z_tmp, rvec_tmp, v_tmp
cdef IDA_RhsFunction res
cdef IDA_JacRhsFunction jac
cdef IDA_RootFunction rootfn
cdef IDA_PrecSetupFunction prec_setupfn
cdef IDA_PrecSolveFunction prec_solvefn
cdef IDA_JacTimesVecFunction jac_times_vecfn
cdef IDA_JacTimesSetupFunction jac_times_setupfn
cdef bint parallel_implementation
cdef object user_data
cdef IDA_ErrHandler err_handler
Expand Down
4 changes: 4 additions & 0 deletions scikits/odes/sundials/ida.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,8 @@ cdef class IDA_data:
self.g_tmp = None
self.z_tmp = None
self.rvec_tmp = None
self.v_tmp = np.empty(N, DTYPE)
self.z_tmp = np.empty(N, DTYPE)

cdef class IDA:

Expand Down Expand Up @@ -812,6 +814,8 @@ cdef class IDA:
'precond_type': 'NONE',
'prec_setupfn': None,
'prec_solvefn': None,
'jac_times_vecfn': None,
'jac_times_setupfn': None,
'err_handler': None,
'err_user_data': None,
'old_api': None,
Expand Down
68 changes: 44 additions & 24 deletions scikits/odes/tests/test_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,51 @@ def _do_problem(self, problem, integrator, old_api=True, **integrator_params):
jac = problem.jac
res = problem.res

jac_tmp = None

def jac_times_vec(tt, yy, yp, rr, v, Jv, cj):
J = empty(len(yy), len(yy))
jac(tt, yy, yp, cj, J)
Js = sparse.csr_matrix(J)
Jv[:] = Js*v

def jac_times_vec2(tt, yy, yp, rr, v, Jv, cj, userdata):
Jv[:] = userdata * v

def jac_times_setupfn(tt, yy, yp, rr, cj, userdata):
J = empty(len(yy), len(yy))
jac(tt, yy, yp, cj, J)
userdata = sparse.csr_matrix(J)

igs = (
dae(integrator, res, jacfn=jac, old_api=old_api),
dae(integrator, res, jac_times_vec=jac_times_vec, old_api=old_api),
dae(integrator, res, jac_times_vec=jac_times_vec,
jac_times_setupfn=jac_times_setupfn, old_api=old_api,
user_data=jac_tmp),
)
class UserData:
def __init__(self):
self.J = None

my_userdata = UserData()

jac_times_vec = None
jac_times_vec2 = None
jac_times_setupfn = None

if jac is not None and integrator == 'ida':
def jac_times_vec(tt, yy, yp, rr, v, Jv, cj):
J = zeros((len(yy), len(yy)), DTYPE)
jac(tt, yy, yp, rr, cj, J)
Js = sparse.csr_matrix(J)
Jv[:] = Js * v
return 0

def jac_times_vec2(tt, yy, yp, rr, v, Jv, cj, userdata):
Jv[:] = userdata.J * v
return 0

def jac_times_setupfn(tt, yy, yp, rr, cj, userdata):
J = zeros((len(yy), len(yy)), DTYPE)
jac(tt, yy, yp, rr, cj, J)
userdata.J = sparse.csr_matrix(J)
return 0

igs = [dae(integrator, res, jacfn=jac, old_api=old_api)]

if integrator == 'ida':
igs.append(
dae(integrator, res, linsolver='spgmr',
jac_times_vecfn=jac_times_vec,
old_api=old_api)
)
igs.append(
dae(integrator, res, linsolver='spgmr',
jac_times_vecfn=jac_times_vec2,
jac_times_setupfn=jac_times_setupfn,
old_api=old_api,
user_data=my_userdata)
)

for ig in igs:
ig = dae(integrator, res, jacfn=jac, old_api=old_api)
ig.set_options(old_api=old_api, **integrator_params)
z = empty((1+len(problem.stop_t),alen(problem.z0)), DTYPE)
zprime = empty((1+len(problem.stop_t),alen(problem.z0)), DTYPE)
Expand All @@ -72,6 +91,7 @@ def jac_times_setupfn(tt, yy, yp, rr, cj, userdata):
assert problem.verify(array(z), array(zprime), [0.]+problem.stop_t), \
(problem.info(),)


def test_ddaspk(self):
"""Check the ddaspk solver"""
for problem_cls in PROBLEMS_DDASPK:
Expand Down

0 comments on commit eb7f3fd

Please sign in to comment.