Skip to content
Snippets Groups Projects
Commit ab7fb70c authored by Ivan Kondov's avatar Ivan Kondov
Browse files

introduce a step (stride) to reduce trajectory output and optimizations

parent 26943420
No related branches found
No related tags found
No related merge requests found
......@@ -56,24 +56,25 @@ class Propagator:
""" perform a time step - stub method """
return (coordinate, momentum)
def get_trajectory(self):
def get_trajectory(self, step=1):
""" return a trajectory - a time series of coordinate and momentum """
time = np.array(self.time, dtype='float')
coordinate = np.array(self.sq, dtype='float')
momentum = np.array(self.sp, dtype='float')
time = np.array(self.time, dtype='float')[::step]
coordinate = np.array(self.sq, dtype='float')[::step]
momentum = np.array(self.sp, dtype='float')[::step]
return time, coordinate, momentum
def get_trajectory_3d(self):
def get_trajectory_3d(self, step=1):
""" return a trajectory of coordinates grouped in 3-tuples """
ti = np.array(self.time, dtype='float')
qt = np.array([list(zip(t[0::3], t[1::3], t[2::3])) for t in self.sq], dtype=float)
pt = np.array([list(zip(t[0::3], t[1::3], t[2::3])) for t in self.sp], dtype=float)
ti, qt, pt = self.get_trajectory(step)
qt.shape = (len(qt), -1, 3)
pt.shape = (len(pt), -1, 3)
return ti, qt, pt
def analyse(self):
def analyse(self, step=1):
""" analyse the trajectory """
energy = [self.syst.energy(q, p) for q, p in zip(self.sq, self.sp)]
self.en = np.array(energy, dtype=float)
from itertools import islice
qptup = islice(zip(self.sq, self.sp), 0, None, step)
self.en = np.array([self.syst.energy(*t) for t in qptup], dtype=float)
self.er = (self.en-self.en0)/self.en0
def plot(self):
......
......@@ -2,14 +2,14 @@ from ase import Atoms
from ase.io import write
from units import LJUnitsElement
def write_ase_trajectory(propagator, atoms, filename='trajectory.traj'):
def write_ase_trajectory(propagator, atoms, filename='trajectory.traj', step=1):
""" currently restricted to a single element, no export of time axis """
symbols = atoms.get_chemical_symbols()
assert len(set(symbols)) == 1
ljunits = LJUnitsElement(symbols[0])
pbc = atoms.get_pbc()
cell = atoms.get_cell()
traj_3d = propagator.get_trajectory_3d()
traj_3d = propagator.get_trajectory_3d(step)
qs = traj_3d[1]*ljunits.sigmaA
ps = traj_3d[2]*ljunits.momentumA
traj = [Atoms(symbols, cell=cell, pbc=pbc, positions=q, momenta=p)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment