import re
import tempfile

import numpy as np
import numpy.testing as npt

from ase import Atoms
from ase.io.cube import ATOMS, DATA, read_cube, read_cube_data, write_cube
from ase.units import Bohr

# Real cube file produced by Octopus (see its header comment), used as
# parser input by test_cube_reading and test_reading_using_io.
# Layout: line 3 holds the atom count (5: one C and four H) and the origin;
# the next three lines are grid vectors with 5, 3 and 5 points, the third
# vector being all zeros; then come the five atom lines and the
# 5 * 3 * 5 = 75 volumetric values, five per row.
# The backslash at the end of the second header line is a Python line
# continuation inside the string literal: it joins the build date onto that
# line so the file header occupies exactly two comment lines.
file_content = """Generated by octopus maya
 git: 5b62360b6ec96e210d8ebbbd817bea74da18dfa8 build: \
Fri Oct  8 14:58:44 CEST 2021
     5   -3.779452   -1.889726   -3.779452
     5    1.889726    0.000000    0.000000
     3    0.000000    1.889726    0.000000
     5    0.000000    0.000000    0.000000
     6    0.000000    0.000000    0.000000    0.000000
     1    0.000000    0.000000    0.000000    2.057912
     1    0.000000    1.940220    0.000000   -0.685971
     1    0.000000   -0.970110   -1.680269   -0.685971
     1    0.000000   -0.970110    1.680269   -0.685971
   0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00
   0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00
   0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00
   0.000000E+00  0.284545E-01  0.706656E-01  0.000000E+00  0.000000E+00
   0.000000E+00  0.000000E+00  0.391115E-01  0.000000E+00  0.000000E+00
   0.000000E+00  0.284545E-01  0.706656E-01  0.000000E+00  0.000000E+00
   0.000000E+00  0.399121E-01  0.926179E-01  0.000000E+00  0.000000E+00
   0.000000E+00  0.325065E-01  0.272831E+00  0.173787E+00  0.149557E-01
   0.000000E+00  0.399121E-01  0.926179E-01  0.000000E+00  0.000000E+00
   0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00
   0.000000E+00  0.506939E-01  0.138292E+00  0.000000E+00  0.000000E+00
   0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00
   0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00
   0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00
   0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00  0.000000E+00"""


def test_cube_writing():
    """Write a small N2 system with 2x2x2 data and check the file contents.

    Checks the default comment line, the fixed loop-order line, the origin
    (converted back from Bohr), both atom lines, and that every volumetric
    value is written, in x-outer / z-inner (C) order, with none missing.
    """
    d = 1.104  # N2 bondlength; the second atom is placed d Bohr along z
    at = Atoms("N2", [(0, 0, 0), (0, 0, d * Bohr)])

    dummydata = np.arange(8).reshape((2, 2, 2))
    origin_in = (42 * Bohr, 0, 0)
    comment_regex = r"(Cube file from ASE, written on )([a-zA-Z ])*([0-9: ])*"

    # create output
    with tempfile.NamedTemporaryFile(mode="r+") as outfil:
        write_cube(outfil, at, data=dummydata, origin=origin_in)
        # rewind so the freshly written file can be read back from the start
        outfil.seek(0)

        # Check default comment
        comment_line = outfil.readline()
        assert re.match(comment_regex, comment_line)

        # Check constant string
        assert outfil.readline() == ("OUTER LOOP: X, MIDDLE LOOP: Y, "
                                     "INNER LOOP: Z\n")

        # Check origin. The file stores it in Bohr with a fixed number of
        # decimals, so compare with a tolerance rather than exact float
        # equality (assert_allclose also checks there are three components).
        origin_from_file = [float(p) * Bohr
                            for p in outfil.readline().split()[1:]]
        npt.assert_allclose(origin_from_file, origin_in, atol=1e-6 * Bohr)

        # skip the three grid-vector lines
        outfil.readline()
        outfil.readline()
        outfil.readline()

        # check Atoms and positions
        atom1 = outfil.readline().split()
        assert atom1 == ["7", "0.000000", "0.000000", "0.000000", "0.000000"]
        atom2 = outfil.readline().split()
        assert atom2 == ["7", "0.000000", "0.000000",
                         "0.000000", f"{d:.6f}"]

        # Check data. Read every remaining whitespace-separated token so the
        # check does not depend on how many values are written per line, and
        # compare against the whole flattened array so that missing or extra
        # values fail the test instead of passing vacuously.
        values = [float(token) for token in outfil.read().split()]
        npt.assert_array_equal(values, dummydata.ravel())


def test_cube_reading():
    """Parse the Octopus sample and check atoms, grid, spacing, origin, PBC."""
    with tempfile.NamedTemporaryFile(mode="r+") as handle:
        # Dump the sample text, then rewind so the reader starts at line 1.
        handle.write(file_content)
        handle.seek(0)

        parsed = read_cube(handle)
        atoms = parsed[ATOMS]

        # One carbon followed by four hydrogens, in file order.
        npt.assert_equal(
            atoms.get_atomic_numbers(), np.array([6, 1, 1, 1, 1])
        )
        assert isinstance(parsed, dict)

        # The volumetric block keeps the grid shape declared in the header.
        assert parsed[DATA].shape == (5, 3, 5)

        spacing = parsed["spacing"]
        expected_steps = np.array([1.889726, 1.889726, 0.000000])
        assert spacing.shape == (3, 3)
        # Diagonal entries: grid vectors were read in the right order.
        npt.assert_almost_equal(spacing.diagonal() / Bohr, expected_steps)
        # Column sums must match as well, guarding against stray
        # off-diagonal values.
        npt.assert_almost_equal(spacing.sum(axis=0) / Bohr, expected_steps)

        origin = parsed["origin"]
        assert origin.shape == (3,)
        npt.assert_almost_equal(
            origin, np.array([-3.779452, -1.889726, -3.779452]) * Bohr
        )

        # The third grid vector is all zeros; the reader is expected to
        # report periodicity only along the first two axes.
        npt.assert_array_equal(atoms.get_pbc(), [True, True, False])


# Cube file holding two volumetric datasets (the title names MO=HOMO,LUMO)
# for benzene (six C followed by six H), used by test_cube_reading_multiple.
# The negative atom count (-12) on line 3 marks a multi-dataset file: the
# extra line after the atom block ("2   21   22") gives the number of
# datasets followed by their labels. The grid is 3 x 3 x 3, and each of the
# 27 points carries one value per dataset, giving the 54 values below.
file_content_multiple = """ Benzene_Opt_Freq_B3LYP_6_31G_d_p_ MO=HOMO,LUMO
 MO coefficients
  -12   -8.797610   -9.151024   -6.512752    2
    3    8.063676    0.000000    0.000000
    3    0.000000    8.063676    0.000000
    3    0.000000    0.000000    8.063676
    6    6.000000   -2.284858   -1.319192    0.000000
    6    6.000000   -2.284825    1.319137    0.000000
    6    6.000000    0.000000   -2.638272    0.000000
    6    6.000000    2.284845   -1.319103    0.000000
    6    6.000000   -0.000008    2.638273    0.000000
    6    6.000000    2.284870    1.319165    0.000000
    1    1.000000   -4.062255   -2.345299    0.000000
    1    1.000000   -0.000025   -4.690600    0.000000
    1    1.000000   -4.062297    2.345119    0.000000
    1    1.000000    0.000075    4.690601    0.000000
    1    1.000000    4.062174    2.345444    0.000000
    1    1.000000    4.062187   -2.345312    0.000000
    2   21   22
 -2.74760E-12 -8.90988E-12  5.59016E-10  1.81277E-09  8.76453E-16  2.84215E-15
 -1.43957E-07 -2.02610E-07  2.92889E-05  4.12223E-05  4.59206E-11  6.46303E-11
 -6.71453E-10  9.41323E-10  1.36612E-07 -1.91518E-07  2.14186E-13 -3.00272E-13
  5.12861E-08  6.37111E-08 -1.04345E-05 -1.29624E-05 -1.63597E-11 -2.03231E-11
 -3.53200E-05 -7.80160E-05  1.33966E-02  3.06849E-02  1.12667E-08  2.48862E-08
 -3.26637E-06  4.17535E-06  6.66354E-04 -8.51133E-04  1.04193E-09 -1.33189E-09
  8.90476E-11  1.24804E-10 -1.81173E-08 -2.53921E-08 -2.84052E-14 -3.98110E-14
  3.14350E-06  1.73228E-06 -6.39634E-04 -3.52493E-04 -1.00274E-09 -5.52577E-10
  6.74853E-09 -2.18944E-08 -1.37303E-06  4.45455E-06 -2.15271E-12  6.98406E-12
  """


def test_cube_reading_multiple():
    """Parse a two-dataset cube file and check every dataset is exposed."""
    with tempfile.NamedTemporaryFile(mode="r+") as handle:
        # Dump the sample text, then rewind so the reader starts at line 1.
        handle.write(file_content_multiple)
        handle.seek(0)

        parsed = read_cube(handle)
        atoms = parsed[ATOMS]

        # Six carbons followed by six hydrogens, in file order.
        npt.assert_equal(atoms.get_atomic_numbers(), [6] * 6 + [1] * 6)
        assert isinstance(parsed, dict)

        # The primary dataset uses the 3x3x3 grid from the header ...
        primary = parsed[DATA]
        assert primary.shape == (3, 3, 3)

        # ... and so does each of the two entries in "datas".
        datasets = parsed["datas"]
        assert len(datasets) == 2
        assert all(entry.shape == primary.shape for entry in datasets)

        # Labels come from the data-set ID line that follows the atoms.
        assert parsed["labels"] == [21, 22]

        spacing = parsed["spacing"]
        expected_steps = np.array([8.063676, 8.063676, 8.063676])
        assert spacing.shape == (3, 3)
        # Diagonal entries: grid vectors were read in the right order.
        npt.assert_almost_equal(spacing.diagonal() / Bohr, expected_steps)
        # Column sums must match as well, guarding against stray
        # off-diagonal values.
        npt.assert_almost_equal(spacing.sum(axis=0) / Bohr, expected_steps)

        origin = parsed["origin"]
        assert origin.shape == (3,)
        npt.assert_almost_equal(
            origin, np.array([-8.797610, -9.151024, -6.512752]) * Bohr
        )

        # All three grid vectors are non-zero here. In test_cube_reading, by
        # contrast, an all-zero third vector leaves that axis non-periodic.
        # So periodicity is expected along every axis.
        npt.assert_array_equal(atoms.get_pbc(), [True, True, True])


def test_reading_using_io():
    """read_cube_data should hand back a (volumetric data, Atoms) pair."""
    with tempfile.NamedTemporaryFile(mode="r+") as handle:
        # Dump the sample text, then rewind so the reader starts at line 1.
        handle.write(file_content)
        handle.seek(0)

        returned = read_cube_data(handle)
        assert isinstance(returned, tuple)
        assert len(returned) == 2

        volume, atoms = returned
        # Same grid shape as reported by read_cube for this sample.
        assert volume.shape == (5, 3, 5)

        assert isinstance(atoms, Atoms)
        npt.assert_equal(
            atoms.get_atomic_numbers(), np.array([6, 1, 1, 1, 1])
        )
