Skip to content
Snippets Groups Projects
Commit 838c2580 authored by Ivan Kondov's avatar Ivan Kondov
Browse files

corrected the shape of input arrays for positions and momenta

parent e3abbde3
No related branches found
No related tags found
No related merge requests found
......@@ -14,12 +14,14 @@ class MPLPTest(unittest.TestCase):
ffparams = [{'class': 'FFLennardJones', 'pair': ((0, 1, 2), (3, 4, 5))}]
symbols = ['0x:z', '1x:z']
masses = np.array([1, 1, 1, 1, 1, 1])
masses = np.array([1, 1])
def test_qvec_pvec_taylor(self):
""" test with a pre-calculated reference with only subs numerics """
qval = np.array([1, 2, 3, 1, 1, 1])
pval = np.array([0, 0, 0, 1, 1, 1])
qval = np.array([[1, 2, 3], [1, 1, 1]])
pval = np.array([[0, 0, 0], [1, 1, 1]])
qval_flat = qval.flatten()
pval_flat = pval.flatten()
numeric = {'tool': 'subs'}
qref = [[1.0, 2.0, 3.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
[0.0, -0.0377856, -0.0755712, 0.0, 0.0377856, 0.0755712]]
......@@ -36,7 +38,7 @@ class MPLPTest(unittest.TestCase):
kwargs = {} if not var else {'variant': var}
mplpobj = mplpcls(3, system, numeric=numeric, recursion='taylor',
**kwargs)
qvec, pvec = mplpobj.get_val_float(qval, pval)
qvec, pvec = mplpobj.get_val_float(qval_flat, pval_flat)
self.assertTrue(np.allclose(qvec, qref))
self.assertTrue(np.allclose(pvec, pref))
......@@ -44,8 +46,10 @@ class MPLPTest(unittest.TestCase):
""" cross-check all classes with all numerics and variants """
from newtonian import leja_points_ga
qval = np.random.rand(6)
pval = np.random.rand(6)
qval = np.random.rand(2, 3)
pval = np.random.rand(2, 3)
qval_flat = qval.flatten()
pval_flat = pval.flatten()
cases = (
(MPLPQP, None, {'tool': 'subs'}),
......@@ -86,14 +90,14 @@ class MPLPTest(unittest.TestCase):
mplpobj = mplpcls(nterms, system, numeric=numeric, **pdict,
**kwargs)
if qref and pref:
qvec, pvec = mplpobj.get_val_float(qval, pval)
qvec, pvec = mplpobj.get_val_float(qval_flat, pval_flat)
relerr = (np.array(pref)-np.array(pvec))/np.array(pref)
msg = (repr(mplpcls.__name__)+repr(variant)+repr(numeric)+
repr(qval)+repr(pval)+repr(relerr))
repr(qval_flat)+repr(pval_flat)+repr(relerr))
self.assertTrue(np.allclose(qvec, qref, rtol=0.001), msg)
self.assertTrue(np.allclose(pvec, pref, rtol=0.001), msg)
else:
qref, pref = mplpobj.get_val_float(qval, pval)
qref, pref = mplpobj.get_val_float(qval_flat, pval_flat)
class ChebyshevSymplecticTest(unittest.TestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment