#!/usr/bin/env python3
import sys
import numpy as np
import numpy.fft as fft
import gzip
import argparse
#import ase.io

DESCRIPTION = """calculates velocity-velocity autocorrelation"""

# Zero-padding factor for the FFT: a value of N appends N times as many
# zeros as there are data points.
NEXTENSION = 10

# Numeric columns of one atom line (everything after the first word),
# all stored as 64-bit floats.
datatype = np.dtype([(column, 'f8') for column in (
    'posx', 'posy', 'posz',
    'charge',
    'vx', 'vy', 'vz',
)])

# How the columns are grouped into named arrays:
# (array name, array dtype, names of the source columns)
fields = [
    ('positions', 'f8', ('posx', 'posy', 'posz')),
    ('charges', 'f8', ('charge',)),
    ('velocities', 'f8', ('vx', 'vy', 'vz')),
]


def parse_arguments():
    """Parse the command line of the script.

    Returns:
        The argparse namespace with the positional arguments geofile,
        prefix, firstframe, lastframe, corframes and timestep.
    """
    # (argument name, keyword options for add_argument), in positional order
    positional = (
        ("geofile", {"help": "geometry file to process"}),
        ("prefix", {"help": "prefix in output file names"}),
        ("firstframe", {"type": int, "help": "First frame to consider"}),
        ("lastframe", {"type": int, "help": "Last frame to consider"}),
        ("corframes", {
            "type": int,
            "help": "Nr. of frames to consider for the correlation",
        }),
        ("timestep", {"type": float, "help": "time step in picoseconds"}),
    )
    cli = argparse.ArgumentParser(description=DESCRIPTION)
    for argname, options in positional:
        cli.add_argument(argname, **options)
    return cli.parse_args()



def get_nr_atoms_and_frames(fp):
    """Determine the number of atoms and the number of frames in a stream.

    The atom count is read from the first line at the current position and
    is used as the length of every frame. Each frame is taken to consist of
    one header line, one further line and natom atom lines. Counting stops
    at end of file or at the first blank line found where a frame header is
    expected, so trailing empty lines do not count as an additional frame.

    Args:
        fp: Seekable file object positioned at the first line of a frame.

    Returns:
        Tuple (natom, nframe).

    Raises:
        ValueError: If the first line cannot be converted to an integer.

    The position of the stream is restored before returning.
    """
    startpos = fp.tell()
    natom = int(fp.readline())
    nframe = 0
    while True:
        # The header of this frame has already been consumed; skip the
        # second line and the atom lines.
        fp.readline()
        for _ in range(natom):
            fp.readline()
        nframe += 1
        header = fp.readline()
        # An empty string is end of file; a whitespace-only line is trailing
        # padding rather than the start of another frame.
        if not header.strip():
            break
    fp.seek(startpos)
    return natom, nframe
    

def readframe(fp, datatype, fields, natom):
    """Read one frame from the stream and return its data as named arrays.

    Args:
        fp: File object positioned at the first line of a frame.
        datatype: Structured numpy dtype describing the columns that follow
            the first word of every atom line.
        fields: Iterable of (array name, array dtype, column names) tuples
            selecting which columns make up each returned array.
        natom: Number of atom lines in the frame.

    Returns:
        Dictionary mapping each array name to an array of shape
        (natom, number of columns); arrays built from a single column are
        one-dimensional with shape (natom,).
    """
    # The two leading lines of the frame carry no per-atom data.
    for _ in range(2):
        fp.readline()

    symbols = []
    rows = []
    for _ in range(natom):
        tokens = fp.readline().split()
        # First word is the atom label, the rest are the numeric columns.
        symbols.append(tokens[0])
        rows.append(tuple(tokens[1:]))
    table = np.array(rows, dtype=datatype)

    result = {}
    for name, elemtype, columns in fields:
        ncolumn = len(columns)
        block = np.empty((natom, ncolumn), dtype=elemtype)
        for icol, column in enumerate(columns):
            block[:, icol] = table[column]
        # Single-column fields are returned as flat vectors.
        result[name] = block.reshape(natom) if ncolumn == 1 else block
    return result


def get_velocity_autocorrelation_full(velocities):
    """Compute the normalised autocorrelation over all available time shifts.

    The autocorrelations of all rows are summed, every time shift is divided
    by the number of time origins contributing to it, and the result is
    scaled so that the value at zero shift equals one.

    See: dx.doi.org/10.1021/ct300152t
    See: http://www.timteatro.net/2010/09/29/velocity-autocorrelation-and-vibrational-spectrum-calculation/

    Args:
        velocities: 2D array with one row per velocity component and one
            column per time step.

    Returns:
        1D float array with one value per time step.
    """
    nstep = velocities.shape[1]
    summed = np.zeros(nstep, dtype=float)
    for component in velocities:
        # The 'full' result has 2 * nstep - 1 entries with the zero shift at
        # index nstep - 1; only the non-negative shifts are kept.
        summed += np.correlate(component, component, mode='full')[nstep - 1:]
    # A shift of k steps can start from nstep - k different time origins.
    norigin = np.arange(nstep, 0, -1)
    averaged = summed / norigin
    return averaged / averaged[0]


def get_velocity_autocorrelation_valid(velocities, maxtime):
    """Compute the normalised autocorrelation for shifts 0 .. maxtime.

    Only the first (number of time steps - maxtime) steps are used as time
    origins, so every shift is summed over the same number of origins. The
    contributions of all rows are added up and the result is scaled so that
    the value at zero shift equals one.

    Args:
        velocities: 2D array with one row per velocity component and one
            column per time step.
        maxtime: Largest time shift (in steps) to evaluate.

    Returns:
        1D float array with maxtime + 1 entries.

    Raises:
        ValueError: If maxtime is negative or not smaller than the number
            of time steps, as no time origin would be left in that case.
    """
    nframe_all = velocities.shape[1]
    if not 0 <= maxtime < nframe_all:
        raise ValueError(
            "maxtime must satisfy 0 <= maxtime < {:d} (number of time steps), "
            "got {}".format(nframe_all, maxtime))
    nframe_valid = nframe_all - maxtime
    velo_autocorr = np.zeros((maxtime + 1,), dtype=float)
    for vt in velocities:
        # 'valid' slides the shortened series over the full one and yields
        # exactly one value per shift 0 .. maxtime.
        velo_autocorr += np.correlate(vt, vt[:nframe_valid], mode='valid')
    velo_autocorr /= velo_autocorr[0]
    return velo_autocorr


def von_hann_filter(signal):
    """Damp a signal with a von Hann (hanning) type cosine window.

    Sample n of N is multiplied by 0.5 * (1 + cos(2 * pi * n / (2 * N))),
    i.e. the weight starts at one for the first sample and decreases
    towards zero at the end of the signal.

    Args:
        signal: Sequence of samples.

    Returns:
        The element-wise product of the signal and the window.
    """
    nsample = len(signal)
    phase_step = (2.0 * np.pi) / (2.0 * nsample)
    window = 0.5 * (1.0 + np.cos(phase_step * np.arange(nsample)))
    return signal * window

def do_fft(signal, filter=None, extension=0):
    """Return the real-input FFT of a signal, optionally windowed and padded.

    Args:
        signal: Sequence of real samples.
        filter: Optional callable applied to the signal before the
            transform (e.g. a window function). None leaves the signal
            unchanged.
        extension: Zero-padding factor; the transform length is
            (1 + extension) * len(signal), so 0 means no padding and N
            appends N times as many zeros as there are samples.

    Returns:
        Complex array as returned by numpy.fft.rfft for the padded length.
    """
    if filter is not None:
        signal = filter(signal)
    # Honour the requested padding instead of a fixed module-wide factor.
    npoints = (1 + extension) * len(signal)
    return fft.rfft(signal, n=npoints)


def write_data(fname, dx, yvals):
    """Write equally spaced data points to a two-column text file.

    Line number i (counted from zero) holds the abscissa i * dx followed by
    the i-th value, both in exponential notation. The name of the file is
    reported on standard output before writing.

    Args:
        fname: Name of the file to create or overwrite.
        dx: Spacing between consecutive abscissa values.
        yvals: Iterable with the ordinate values.
    """
    print("Writing data file '{}'".format(fname))
    row = "{:20.12E} {:20.12E}\n"
    with open(fname, "w") as out:
        out.writelines(row.format(idx * dx, value)
                       for idx, value in enumerate(yvals))
    


def main():
    """Compute the velocity autocorrelation and its spectrum for a trajectory.

    Reads all frames of the geometry file given on the command line
    (gzip-compressed when its name ends in ".gz"), keeps the frames
    firstframe .. lastframe, and writes the normalised velocity
    autocorrelation to "<prefix>.vv.dat" and the real part of its Fourier
    transform to "<prefix>.spectrum.dat".
    """
    args = parse_arguments()
    print("Reading xyz")
    source = args.geofile
    # Open both variants in text mode so that parsing always sees str, and
    # make sure the file is closed once all frames have been read.
    opener = gzip.open if source.endswith(".gz") else open
    with opener(source, "rt") as fp:
        natom, nframe_all = get_nr_atoms_and_frames(fp)
        print("Found {:d} atoms in {:d} frames".format(natom, nframe_all))
        frames = [readframe(fp, datatype, fields, natom)
                  for _ in range(nframe_all)]

    firstframe = args.firstframe
    lastframe = args.lastframe
    if lastframe < 0:
        # Negative values count from the end; -1 includes the last frame.
        lastframe = nframe_all + lastframe + 1
    frames = frames[firstframe:lastframe]

    print("Extracting velocities")
    tmp = [frame['velocities'].flat for frame in frames]
    # velocities -> ( 3 * Natom, Ntimesteps )
    velocities = np.vstack(tmp).transpose()

    print("Calculating autocorrelation (valid)")
    velo_autocorr = get_velocity_autocorrelation_valid(velocities,
                                                       args.corframes)

    print("FFT")
    spectrum = np.real(do_fft(velo_autocorr, extension=NEXTENSION,
                              filter=von_hann_filter))

    print("Writing results")
    write_data(args.prefix + ".vv.dat", args.timestep, velo_autocorr)

    dcm1 = (1.0 / args.timestep) * 3.33565e1   #  1 / ps -> cm^-1
    # Spacing of the spectrum points: divide by the zero-padded FFT length.
    dcm1 /= (1 + NEXTENSION) * len(velo_autocorr)
    write_data(args.prefix + ".spectrum.dat", dcm1, spectrum)


# Run the analysis only when the file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
