Skip to content

Commit

Permalink
bmcage#77 add JacTimesSetupFn and JacTimesVecFn to IDA
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed May 4, 2019
1 parent 9ba1fb4 commit 3b515fd
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 3 deletions.
4 changes: 2 additions & 2 deletions scikits/odes/sundials/c_ida.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,12 @@ cdef extern from "ida/ida_spils.h":
realtype c_j, realtype delta, void *user_data)
ctypedef int (*IDASpilsJacTimesSetupFn)(realtype tt, N_Vector yy,
N_Vector yp, N_Vector rr,
realtype c_j, void *user_data)
realtype c_j, void *user_data) except? -1
ctypedef int (*IDASpilsJacTimesVecFn)(realtype tt,
N_Vector yy, N_Vector yp, N_Vector rr,
N_Vector v, N_Vector Jv,
realtype c_j, void *user_data,
N_Vector tmp1, N_Vector tmp2)
N_Vector tmp1, N_Vector tmp2) except? -1

int IDASpilsSetLinearSolver(void *ida_mem, SUNLinearSolver LS)
int IDASpilsSetPreconditioner(void *ida_mem,
Expand Down
235 changes: 234 additions & 1 deletion scikits/odes/sundials/ida.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,178 @@ cdef int _prec_solvefn(realtype tt, N_Vector yy, N_Vector yp, N_Vector r,

return user_flag

# JacTimesVec function
cdef class IDA_JacTimesVecFunction:
"""
Prototype for jacobian times vector function.
Note that evaluate must return a integer, 0 for success, non-zero for error
(as per IDA documentation).
"""
cpdef int evaluate(self,
DTYPE_t t,
np.ndarray[DTYPE_t, ndim=1] yy,
np.ndarray[DTYPE_t, ndim=1] yp,
np.ndarray[DTYPE_t, ndim=1] rr,
np.ndarray[DTYPE_t, ndim=1] v,
np.ndarray[DTYPE_t, ndim=1] Jv,
DTYPE_t cj,
object userdata = None) except? -1:

"""
This function calculates the product of the Jacobian with a given vector v.
Use the userdata object to expose Jacobian related data to the solve function.
This is a generic class, you should subclass it for the problem specific
purposes.
"""
return 0

cdef class IDA_WrapJacTimesVecFunction(IDA_JacTimesVecFunction):
cpdef set_jac_times_vecfn(self, object jac_times_vecfn):
"""
Set some IDA_JacTimesVecFn executable class.
"""
"""
set a jacobian-times-vector method as a IDA_JacTimesVecFunction
executable class
"""
self.with_userdata = 0
nrarg = _get_num_args(jac_times_vecfn)
if nrarg > 8:
#hopefully a class method, self gives 9 arg!
self.with_userdata = 1
elif nrarg == 8 and inspect.isfunction(jac_times_vecfn):
self.with_userdata = 1
self._jac_times_vecfn = jac_times_vecfn

cpdef int evaluate(self,
DTYPE_t t,
np.ndarray[DTYPE_t, ndim=1] yy,
np.ndarray[DTYPE_t, ndim=1] yp,
np.ndarray[DTYPE_t, ndim=1] rr,
np.ndarray[DTYPE_t, ndim=1] v,
np.ndarray[DTYPE_t, ndim=1] Jv,
DTYPE_t cj,
object userdata = None) except? -1:
if self.with_userdata == 1:
user_flag = self._jac_times_vecfn(rr, yy, yp, rr, v, Jv, cj, userdata)
else:
user_flag = self._jac_times_vecfn(rr, yy, yp, rr, v, Jv, cj)
if user_flag is None:
user_flag = 0
return user_flag

cdef int _jac_times_vecfn(realtype t, N_Vector yy, N_Vector yp, N_Vector rr, N_Vector v,
N_Vector Jv, realtype cj, void *user_data, N_Vector tmp1, N_Vector tmp2) except? -1:
""" function with the signature of IDA_JacTimesVecFunction, that calls python function """
cdef np.ndarray[DTYPE_t, ndim=1] yy_tmp, yp_tmp, rr_tmp, v_tmp, Jv_tmp

aux_data = <IDA_data> user_data
cdef bint parallel_implementation = aux_data.parallel_implementation

if parallel_implementation:
raise NotImplemented

yy_tmp = aux_data.yy_tmp
yp_tmp = aux_data.yp_tmp
rr_tmp = aux_data.residual_tmp
v_tmp = aux_data.v_tmp
Jv_tmp = aux_data.z_tmp

nv_s2ndarray(yy, yy_tmp)
nv_s2ndarray(yp, yp_tmp)
nv_s2ndarray(rr, rr_tmp)
nv_s2ndarray(v, v_tmp)

user_flag = aux_data.jac_times_vecfn.evaluate(t, yy_tmp, yp_tmp, rr_tmp, v_tmp,
Jv_tmp, cj, aux_data.user_data)

ndarray2nv_s(Jv, Jv_tmp)

return user_flag

# JacTimesVec function
cdef class IDA_JacTimesSetupFunction:
"""
Prototype for jacobian times setup function.
Note that evaluate must return a integer, 0 for success, non-zero for error
(as per CVODE documentation), with >0 a recoverable error (step is retried).
"""
cpdef int evaluate(self,
DTYPE_t tt,
np.ndarray[DTYPE_t, ndim=1] yy,
np.ndarray[DTYPE_t, ndim=1] yp,
np.ndarray[DTYPE_t, ndim=1] rr,
DTYPE_t cj,
object userdata = None) except? -1:
"""
This function calculates the product of the Jacobian with a given vector v.
Use the userdata object to expose Jacobian related data to the solve function.
This is a generic class, you should subclass it for the problem specific
purposes.
"""
return 0

cdef class IDA_WrapJacTimesSetupFunction(IDA_JacTimesSetupFunction):
cpdef set_jac_times_setupfn(self, object jac_times_setupfn):
"""
Set some IDA_JacTimesSetupFn executable class.
"""
"""
set a jacobian-times-vector method setup as a IDA_JacTimesSetupFunction
executable class
"""
self.with_userdata = 0
nrarg = _get_num_args(jac_times_setupfn)
if nrarg > 6:
#hopefully a class method, self gives 7 arg!
self.with_userdata = 1
elif nrarg == 6 and inspect.isfunction(jac_times_setupfn):
self.with_userdata = 1
self._jac_times_setupfn = jac_times_setupfn

cpdef int evaluate(self,
DTYPE_t tt,
np.ndarray[DTYPE_t, ndim=1] yy,
np.ndarray[DTYPE_t, ndim=1] yp,
np.ndarray[DTYPE_t, ndim=1] rr,
DTYPE_t cj,
object userdata = None) except? -1:
if self.with_userdata == 1:
user_flag = self._jac_times_setupfn(tt, yy, yp, rr, cj, userdata)
else:
user_flag = self._jac_times_setupfn(tt, yy, yp, rr, cj)
if user_flag is None:
user_flag = 0
return user_flag

cdef int _jac_times_setupfn(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr,
realtype cj, void *user_data) except? -1:
""" function with the signature of IDA_JacTimesSetupFunction, that calls python function """
cdef np.ndarray[DTYPE_t, ndim=1] yy_tmp, yp_tmp, rr_tmp

aux_data = <IDA_data> user_data
cdef bint parallel_implementation = aux_data.parallel_implementation

if parallel_implementation:
raise NotImplemented

yy_tmp = aux_data.yy_tmp
yp_tmp = aux_data.yp_tmp
rr_tmp = aux_data.residual_tmp

nv_s2ndarray(yy, yy_tmp)
nv_s2ndarray(yp, yp_tmp)
nv_s2ndarray(rr, rr_tmp)

user_flag = aux_data.jac_times_setupfn.evaluate(tt, yy_tmp, yp_tmp, rr_tmp, cj, aux_data.user_data)

return user_flag


cdef class IDA_ContinuationFunction:
"""
Simple wrapper for functions called when ROOT or TSTOP are returned.
Expand Down Expand Up @@ -867,6 +1039,37 @@ cdef class IDA:
parameters gamma and delta, input flag lr that determines
the flavour of the preconditioner (left = 1, right = 2) and
optional userdata.
'jac_times_vecfn':
Values: function of class IDA_JacTimesVecFunction
Description:
Defines a function that solves the product of vector v
with an (approximate) Jacobian of the system J.
This function takes as input arguments:
tt is the current value of the independent variable.
yy is the current value of the dependent variable vector, y(t).
yp is the current value of ˙y(t).
rr is the current value of the residual vector F(t, y, y˙).
v is the vector by which the Jacobian must be multiplied to
the right.
Jv is the computed output vector.
cj is the scalar in the system Jacobian, proportional to the
inverse of the step size.
user data is a pointer to user data (optional)
'jac_times_setupfn':
Values: function of class IDA_JacTimesSetupFunction
Description:
Optional. Default is to internal finite difference with no
extra setup.
Defines a function that preprocesses and/or evaluates
Jacobian-related data needed by the Jacobiantimes-vector routine
This function takes as input arguments:
tt is the current value of the independent variable.
yy is the current value of the dependent variable vector, y(t).
yp is the current value of ˙y(t).
rr is the current value of the residual vector F(t, y, y˙).
cj is the scalar in the system Jacobian, proportional to the
inverse of the step size.
user data is a pointer to user data (optional)
'err_handler':
Values: function of class IDA_ErrHandler, default = None
Description:
Expand Down Expand Up @@ -1311,6 +1514,22 @@ cdef class IDA:
opts['prec_solvefn'] = tmpfun
self.aux_data.prec_solvefn = prec_solvefn

jac_times_vecfn = opts['jac_times_vecfn']
if jac_times_vecfn is not None and not isinstance(jac_times_vecfn, IDA_JacTimesVecFunction):
tmpfun = IDA_WrapJacTimesVecFunction()
tmpfun.set_jac_times_vecfn(jac_times_vecfn)
jac_times_vecfn = tmpfun
opts['jac_times_vecfn'] = tmpfun
self.aux_data.jac_times_vecfn = jac_times_vecfn

jac_times_setupfn = opts['jac_times_setupfn']
if jac_times_setupfn is not None and not isinstance(jac_times_setupfn, IDA_JacTimesSetupFunction):
tmpfun = IDA_WrapJacTimesSetupFunction()
tmpfun.set_jac_times_setupfn(jac_times_setupfn)
jac_times_setupfn = tmpfun
opts['jac_times_setupfn'] = tmpfun
self.aux_data.jac_times_setupfn = jac_times_setupfn

self._set_runtime_changeable_options(opts, supress_supported_check=True)

if flag == IDA_ILL_INPUT:
Expand Down Expand Up @@ -1409,7 +1628,7 @@ cdef class IDA:
elif flag == IDASPILS_ILL_INPUT:
raise MemoryError('linear solver memory was NULL')
elif flag != IDASPILS_SUCCESS:
raise ValueError('CVSpilsSetLinearSolver failed with code {}'
raise ValueError('IDASpilsSetLinearSolver failed with code {}'
.format(flag))
# TODO: make option for the Gram-Schmidt orthogonalization
#flag = SUNSPGMRSetGSType(LS, gstype);
Expand All @@ -1431,6 +1650,20 @@ cdef class IDA:
elif flag != IDASPILS_SUCCESS:
raise ValueError('IDASpilsSetPreconditioner failed with code {}'
.format(flag))

if self.aux_data.jac_times_vecfn:
if self.aux_data.jac_times_setupfn:
flag = IDASpilsSetJacTimes(ida_mem, _jac_times_setupfn, _jac_times_vecfn)
else:
flag = IDASpilsSetJacTimes(ida_mem, NULL, _jac_times_vecfn)
if flag == IDASPILS_MEM_NULL:
raise ValueError('LinSolver: The ida mem pointer is NULL.')
elif flag == IDASPILS_LMEM_NULL:
raise ValueError('LinSolver: The idaspils linear solver has '
'not been initialized.')
elif flag != IDASPILS_SUCCESS:
raise ValueError('IDASpilsSetJacTimes failed with code {}'
.format(flag))
else:
IF SUNDIALS_BLAS_LAPACK:
if linsolver == 'lapackdense':
Expand Down

0 comments on commit 3b515fd

Please sign in to comment.