#include "unit_test.hpp"

#include <array>
#include <iostream>
#include <regex>

#include "phylogenetics.hpp"
#include "test_utils.hpp"

// Forward declarations: everything below is defined later in this file and
// is declared here so it can be referenced before its definition.

// Prints a banner with the given text on std::cerr.
void print_header(std::string const& header);

// Tests executed by run_unit_tests(1).
void test_distanceMatrixToString();
void test_clusterToString();
void test_clusterIdPairToString();
void test_treeToString();

// Tests executed by run_unit_tests(2).
void test_calculateDistance();
void test_initTree();
void test_initDistanceMatrix();
void test_eraseColumn();
void test_eraseRow();

// Tests executed by run_unit_tests(3).
void test_minimumDistance();
void test_mergeClusters();
void test_buildPhylogeneticTree();
void test_phylogeneticTreeToString();

// Writes a three-line banner to std::cerr: a rule of dashes, the given
// header text, and the same rule again. Each line is terminated with
// std::endl. Test cases call this first to announce their own name.
void print_header(std::string const& header) {
    constexpr char rule[] = "---------------";

    std::cerr << rule << std::endl << header << std::endl << rule << std::endl;
}

//
// Unit tests
//

// Entry point of the test suite: runs the group of tests that belongs to
// `part`. Parts 1, 2 and 3 each execute a fixed list of tests in order; any
// other value only reports the offending part number on std::cerr.
void run_unit_tests(int part) {
    if (part == 1) {
        test_distanceMatrixToString();
        test_clusterToString();
        test_clusterIdPairToString();
        test_treeToString();
        return;
    }

    if (part == 2) {
        test_calculateDistance();
        test_initTree();
        test_initDistanceMatrix();
        test_eraseColumn();
        test_eraseRow();
        return;
    }

    if (part == 3) {
        test_minimumDistance();
        test_mergeClusters();
        test_buildPhylogeneticTree();
        test_phylogeneticTreeToString();
        return;
    }

    std::cerr << "Part should be either 1, 2, or 3. Provided part = " << part
              << std::endl;
}

void test_distanceMatrixToString() {
    print_header("test_distanceMatrixToString");

    DistanceMatrix matrix{{0, 5, 3, 1.8},
                          {5, 0, 12.4, 0.1},
                          {3, 12.4, 0, 100.9},
                          {1.8, 0.1, 100.9, 0}};

    // not verbose
    {
        std::string expected(
            "0 5 3 1.8\n"
            "5 0 12.4 0.1\n"
            "3 12.4 0 100.9\n"
            "1.8 0.1 100.9 0\n");

        std::string actual(toString(matrix, false));

        // remove whitespace at the end of each line
        actual = std::regex_replace(actual, std::regex(R"(\s+\n)"), "\n");

        check_equal(expected, actual);
    }

    // verbose
    {
        std::string expected(
            "0 - 0 = 0\n"
            "0 - 1 = 5\n"
            "0 - 2 = 3\n"
            "0 - 3 = 1.8\n"
            "1 - 0 = 5\n"
            "1 - 1 = 0\n"
            "1 - 2 = 12.4\n"
            "1 - 3 = 0.1\n"
            "2 - 0 = 3\n"
            "2 - 1 = 12.4\n"
            "2 - 2 = 0\n"
            "2 - 3 = 100.9\n"
            "3 - 0 = 1.8\n"
            "3 - 1 = 0.1\n"
            "3 - 2 = 100.9\n"
            "3 - 3 = 0\n");

        std::string actual(toString(matrix, true));

        // remove whitespace at the end of each line
        actual = std::regex_replace(actual, std::regex(R"(\s+\n)"), "\n");

        check_equal(expected, actual);
    }
}

// Checks clusterToString on a hand-built three-leaf cluster, in both the
// short form (leaves shown as Taxon_<n>) and the verbose form (leaves shown
// by their sequence, every node followed by an [i:..,h:..,s:..] annotation).
void test_clusterToString() {
    print_header("test_clusterToString");

    // Pairwise distances of the leaves used to pick the heights below:
    //   t1 <-> t2 = 5,  t1 <-> t3 = 11,  t2 <-> t3 = 13
    Cluster t1{"ATTACCCGGATTAAC", 0, 0.0, nullptr, nullptr, 1};
    Cluster t2{"CCATCCCGGATTAAT", 1, 0.0, nullptr, nullptr, 1};
    Cluster t3{"TTCCCTCCTCACGCC", 2, 0.0, nullptr, nullptr, 1};

    // Inner node joining t1 and t2.
    Cluster t1t2{"", -1, 2.5, &t1, &t2, 2};

    // Root joining (t1,t2) with t3; their distance is (11 + 13) / 2 = 12.
    Cluster t1t2t3{"", -1, 6.0, &t1t2, &t3, 3};

    // Short form.
    {
        const std::string wanted = "((Taxon_0,Taxon_1),Taxon_2)";
        const std::string produced = clusterToString(&t1t2t3, false);

        check_equal(wanted, produced);
    }

    // Verbose form.
    {
        const std::string wanted =
            "((ATTACCCGGATTAAC[i:0,h:0,s:1],CCATCCCGGATTAAT[i:1,h:0,s:1])[i:-1,"
            "h:2.5,s:2],TTCCCTCCTCACGCC[i:2,h:0,s:1])[i:-1,h:6,s:3]";
        const std::string produced = clusterToString(&t1t2t3, true);

        check_equal(wanted, produced);
    }
}

// Checks that toString renders a ClusterIdPair as "<first>-<second>" followed
// by a newline, keeping the order in which the two values were given.
void test_clusterIdPairToString() {
    print_header("test_clusterIdPairToString");

    // Single-digit values, in both orders.
    const auto low_then_high = toString(ClusterIdPair{0, 5});
    check_equal("0-5\n", low_then_high);

    const auto high_then_low = toString(ClusterIdPair{5, 0});
    check_equal("5-0\n", high_then_low);

    // A multi-digit value, in both orders.
    const auto short_then_long = toString(ClusterIdPair{2, 201});
    check_equal("2-201\n", short_then_long);

    const auto long_then_short = toString(ClusterIdPair{201, 2});
    check_equal("201-2\n", long_then_short);
}

// Checks toString on a Tree holding two entries (a two-leaf cluster and a
// single leaf): each entry is expected on its own line, terminated by ';'.
void test_treeToString() {
    print_header("test_treeToString");

    Cluster t1{"ATTACCCGGATTAAC", 0, 0.0, nullptr, nullptr, 1};
    Cluster t2{"CCATCCCGGATTAAT", 1, 0.0, nullptr, nullptr, 1};
    Cluster t3{"TTCCCTCCTCACGCC", 2, 0.0, nullptr, nullptr, 1};

    // Inner node joining t1 and t2; t3 stays a separate entry of the tree.
    Cluster t1t2{"", -1, 2.5, &t1, &t2, 2};

    Tree tree{&t1t2, &t3};

    // Short form.
    {
        const std::string wanted =
            "(Taxon_0,Taxon_1);\n"
            "Taxon_2;\n";
        const std::string produced = toString(tree, false);

        check_equal(wanted, produced);
    }

    // Verbose form.
    {
        const std::string wanted =
            "(ATTACCCGGATTAAC[i:0,h:0,s:1],CCATCCCGGATTAAT[i:1,h:0,s:1])[i:-1,"
            "h:2.5,s:2];\n"
            "TTCCCTCCTCACGCC[i:2,h:0,s:1];\n";
        const std::string produced = toString(tree, true);

        check_equal(wanted, produced);
    }
}

// Checks calculateDistance on two taxa of different length (15 and 18
// characters) and requires the same result for both argument orders.
void test_calculateDistance() {
    print_header("test_calculateDistance");

    Taxon taxon1("ATTACCCGGATTAAC");
    Taxon taxon2("CCATCCCGGATTAATAAC");

    // NOTE(review): 8 matches 5 mismatching positions within the first 15
    // characters plus the 3 surplus characters of taxon2 -- confirm against
    // the definition of calculateDistance.
    const int expected = 8;

    const auto as_given = calculateDistance(taxon1, taxon2);
    check_equal(expected, as_given);

    const auto swapped = calculateDistance(taxon2, taxon1);
    check_equal(expected, swapped);
}

// Checks that initTree turns a list of five taxa into a Tree that compares
// equal to five hand-built childless clusters numbered 0..4 in input order.
void test_initTree() {
    print_header("test_initTree");

    std::vector<Taxon> taxa{
        "ATATAAACTGAATCATCGAC",
        "AGGTTGGGCTGCTATAACCC",
        "GGCACCAGAATTCCAATCAGCTCTT",
        "AGTGCTTGTAAACGCTCG",
        "AATCTTGAACCTCTAACCGT",
    };

    // One expected cluster per taxon, in the same order as `taxa`.
    Cluster t1{"ATATAAACTGAATCATCGAC", 0, 0.0, nullptr, nullptr, 1};
    Cluster t2{"AGGTTGGGCTGCTATAACCC", 1, 0.0, nullptr, nullptr, 1};
    Cluster t3{"GGCACCAGAATTCCAATCAGCTCTT", 2, 0.0, nullptr, nullptr, 1};
    Cluster t4{"AGTGCTTGTAAACGCTCG", 3, 0.0, nullptr, nullptr, 1};
    Cluster t5{"AATCTTGAACCTCTAACCGT", 4, 0.0, nullptr, nullptr, 1};

    Tree expected{&t1, &t2, &t3, &t4, &t5};
    Tree actual(initTree(taxa));

    check_equal(expected, actual);

    // Clean-up of the tree returned by initTree.
    //  !!! UNCOMMENT THE FOLLOWING LINE WHEN deleteTree IS IMPLEMENTED : !!!
    // deleteTree(actual);
}

// Checks that initDistanceMatrix, given a Tree of five leaf clusters, returns
// the expected symmetric 5x5 matrix with a zero diagonal.
void test_initDistanceMatrix() {
    print_header("test_initDistanceMatrix");

    // Leaves the matrix is computed from; row/column i of the expected
    // matrix lines up with the i-th entry of `tree`.
    Cluster t1{"ATATAAACTGAATCATCGAC", 0, 0.0, nullptr, nullptr, 1};
    Cluster t2{"AGGTTGGGCTGCTATAACCC", 1, 0.0, nullptr, nullptr, 1};
    Cluster t3{"GGCACCAGAATTCCAATCAGCTCTT", 2, 0.0, nullptr, nullptr, 1};
    Cluster t4{"AGTGCTTGTAAACGCTCG", 3, 0.0, nullptr, nullptr, 1};
    Cluster t5{"AATCTTGAACCTCTAACCGT", 4, 0.0, nullptr, nullptr, 1};

    Tree tree{&t1, &t2, &t3, &t4, &t5};

    DistanceMatrix expected{
        {0, 16, 21, 13, 17},
        {16, 0, 21, 17, 15},
        {21, 21, 0, 20, 19},
        {13, 17, 20, 0, 15},
        {17, 15, 19, 15, 0},
    };
    DistanceMatrix actual(initDistanceMatrix(tree));

    check_equal(expected, actual);
}

void test_eraseColumn() {
    print_header("test_eraseColumn");

    DistanceMatrix matrix{{0, 16, 21, 13, 17},
                          {16, 0, 21, 17, 15},
                          {21, 21, 0, 20, 19},
                          {13, 17, 20, 0, 15},
                          {17, 15, 19, 15, 0}};

    DistanceMatrix expected{{0, 16, 21, 17},
                            {16, 0, 21, 15},
                            {21, 21, 0, 19},
                            {13, 17, 20, 15},
                            {17, 15, 19, 0}};

    DistanceMatrix actual(matrix);
    eraseColumn(actual, 3);

    check_equal(expected, actual);
}

void test_eraseRow() {
    print_header("test_eraseRow");

    DistanceMatrix matrix{{0, 16, 21, 13, 17},
                          {16, 0, 21, 17, 15},
                          {21, 21, 0, 20, 19},
                          {13, 17, 20, 0, 15},
                          {17, 15, 19, 15, 0}};

    DistanceMatrix expected{{0, 16, 21, 13, 17},
                            {16, 0, 21, 17, 15},
                            {21, 21, 0, 20, 19},
                            {17, 15, 19, 15, 0}};

    DistanceMatrix actual(matrix);
    eraseRow(actual, 3);

    check_equal(expected, actual);
}

// Checks which pair of indices minimumDistance reports for a 5x5 matrix, once
// with a unique smallest off-diagonal value and once with a tie.
void test_minimumDistance() {
    print_header("test_minimumDistance");

    // Unique minimum: 13 appears only between indices 0 and 3.
    {
        DistanceMatrix matrix{
            {0, 16, 21, 13, 17},
            {16, 0, 21, 17, 15},
            {21, 21, 0, 20, 19},
            {13, 17, 20, 0, 15},
            {17, 15, 19, 15, 0},
        };

        // {3, 0} names the same cell, but only {0, 3} is compared against.
        ClusterIdPair expected{0, 3};
        ClusterIdPair actual(minimumDistance(matrix));

        check_equal(expected, actual);
    }

    // Tied minimum: 13.5 appears between 0 and 1 as well as between 0 and 3.
    {
        DistanceMatrix matrix{
            {0, 13.5, 21, 13.5, 17},
            {13.5, 0, 21, 17, 15},
            {21, 21, 0, 20, 19},
            {13.5, 17, 20, 0, 15},
            {17, 15, 19, 15, 0},
        };

        // {1, 0} names the same cell, but only {0, 1} is compared against.
        ClusterIdPair expected{0, 1};
        ClusterIdPair actual(minimumDistance(matrix));

        check_equal(expected, actual);
    }
}

// Drives mergeClusters through three consecutive merges of a four-leaf tree,
// once with WPGMA and once with UPGMA, and compares the tree and the distance
// matrix against hand-built expectations after every single merge.
void test_mergeClusters() {
    print_header("test_mergeClusters");

    Cluster t1{"CATAGACCTGACGCCAGCTC", 0, 0.0, nullptr, nullptr, 1};
    Cluster t2{"CATAGACCCGCCATGAGCTC", 1, 0.0, nullptr, nullptr, 1};
    Cluster t3{"CGTAGACTGGGCGCCAGCTC", 2, 0.0, nullptr, nullptr, 1};
    Cluster t4{"CCTAGACGTCGCGGCAGTCC", 3, 0.0, nullptr, nullptr, 1};

    Tree initial_tree{&t1, &t2, &t3, &t4};

    DistanceMatrix initial_matrix{
        {0, 5, 4, 7},
        {5, 0, 7, 10},
        {4, 7, 0, 7},
        {7, 10, 7, 0},
    };

    // ---- WPGMA ----
    {
        // Each scenario starts from its own copy of the shared fixtures.
        Tree actual_tree(initial_tree);
        DistanceMatrix actual_matrix(initial_matrix);

        // Merge 1: entries 0 and 2 (t1 and t3, distance 4 -> height 2).
        // Distances of the new cluster: (5 + 7) / 2 = 6 and (7 + 7) / 2 = 7.
        Cluster t1t3{"", -1, 2.0, &t1, &t3, 2};
        Tree expected_tree1{&t1t3, &t2, &t4};

        DistanceMatrix expected_matrix1{{0, 6, 7}, {6, 0, 10}, {7, 10, 0}};

        mergeClusters({0, 2}, actual_tree, actual_matrix, WPGMA);

        check_equal(expected_tree1, actual_tree);
        check_equal(expected_matrix1, actual_matrix);

        // Merge 2: entries 0 and 1 ((t1,t3) and t2, distance 6 -> height 3).
        // Remaining distance: (7 + 10) / 2 = 8.5.
        Cluster t1t3t2{"", -1, 3.0, &t1t3, &t2, 3};
        Tree expected_tree2{&t1t3t2, &t4};

        DistanceMatrix expected_matrix2{{0, 8.5}, {8.5, 0}};

        mergeClusters({0, 1}, actual_tree, actual_matrix, WPGMA);

        check_equal(expected_tree2, actual_tree);
        check_equal(expected_matrix2, actual_matrix);

        // Merge 3: the last two entries (distance 8.5 -> height 4.25).
        Cluster t1t3t2t4{"", -1, 4.25, &t1t3t2, &t4, 4};
        Tree expected_tree3{&t1t3t2t4};

        DistanceMatrix expected_matrix3{{0}};

        mergeClusters({0, 1}, actual_tree, actual_matrix, WPGMA);

        check_equal(expected_tree3, actual_tree);
        check_equal(expected_matrix3, actual_matrix);
    }

    // ---- UPGMA ----
    {
        // Each scenario starts from its own copy of the shared fixtures.
        Tree actual_tree(initial_tree);
        DistanceMatrix actual_matrix(initial_matrix);

        // Merge 1: same expectations as in the WPGMA scenario.
        Cluster t1t3{"", -1, 2.0, &t1, &t3, 2};
        Tree expected_tree1{&t1t3, &t2, &t4};

        DistanceMatrix expected_matrix1{{0, 6, 7}, {6, 0, 10}, {7, 10, 0}};

        mergeClusters({0, 2}, actual_tree, actual_matrix, UPGMA);

        check_equal(expected_tree1, actual_tree);
        check_equal(expected_matrix1, actual_matrix);

        // Merge 2: same tree as WPGMA, but the remaining distance differs:
        // 8 here (consistent with a size-weighted mean, (2*7 + 1*10) / 3)
        // instead of 8.5.
        Cluster t1t3t2{"", -1, 3.0, &t1t3, &t2, 3};
        Tree expected_tree2{&t1t3t2, &t4};

        DistanceMatrix expected_matrix2{{0, 8}, {8, 0}};

        mergeClusters({0, 1}, actual_tree, actual_matrix, UPGMA);

        check_equal(expected_tree2, actual_tree);
        check_equal(expected_matrix2, actual_matrix);

        // Merge 3: distance 8 -> height 4 (WPGMA had 4.25).
        Cluster t1t3t2t4{"", -1, 4.0, &t1t3t2, &t4, 4};
        Tree expected_tree3{&t1t3t2t4};

        DistanceMatrix expected_matrix3{{0}};

        mergeClusters({0, 1}, actual_tree, actual_matrix, UPGMA);

        check_equal(expected_tree3, actual_tree);
        check_equal(expected_matrix3, actual_matrix);
    }
}

// Runs buildPhylogeneticTree on a four-leaf tree, once with WPGMA and once
// with UPGMA, and compares the resulting single-root tree and 1x1 distance
// matrix against hand-built expectations.
void test_buildPhylogeneticTree() {
    print_header("test_buildPhylogeneticTree");

    Cluster t1{"CATAGACCTGACGCCAGCTC", 0, 0.0, nullptr, nullptr, 1};
    Cluster t2{"CATAGACCCGCCATGAGCTC", 1, 0.0, nullptr, nullptr, 1};
    Cluster t3{"CGTAGACTGGGCGCCAGCTC", 2, 0.0, nullptr, nullptr, 1};
    Cluster t4{"CCTAGACGTCGCGGCAGTCC", 3, 0.0, nullptr, nullptr, 1};

    Tree initial_tree{&t1, &t2, &t3, &t4};

    DistanceMatrix initial_matrix{
        {0, 5, 4, 7},
        {5, 0, 7, 10},
        {4, 7, 0, 7},
        {7, 10, 7, 0},
    };

    // ---- WPGMA ----
    {
        // Each scenario starts from its own copy of the shared fixtures.
        Tree actual_tree(initial_tree);
        DistanceMatrix actual_matrix(initial_matrix);

        // Expected hierarchy, bottom-up: (t1,t3), then +t2, then +t4.
        Cluster t1t3{"", -1, 2.0, &t1, &t3, 2};
        Cluster t1t3t2{"", -1, 3.0, &t1t3, &t2, 3};
        Cluster t1t3t2t4{"", -1, 4.25, &t1t3t2, &t4, 4};

        Tree expected_tree{&t1t3t2t4};
        DistanceMatrix expected_matrix{{0}};

        buildPhylogeneticTree(actual_tree, actual_matrix, WPGMA);

        check_equal(expected_tree, actual_tree);
        check_equal(expected_matrix, actual_matrix);
    }

    // ---- UPGMA ----
    {
        // Each scenario starts from its own copy of the shared fixtures.
        Tree actual_tree(initial_tree);
        DistanceMatrix actual_matrix(initial_matrix);

        // Same hierarchy as for WPGMA; only the root's height differs
        // (4 instead of 4.25).
        Cluster t1t3{"", -1, 2.0, &t1, &t3, 2};
        Cluster t1t3t2{"", -1, 3.0, &t1t3, &t2, 3};
        Cluster t1t3t2t4{"", -1, 4.0, &t1t3t2, &t4, 4};

        Tree expected_tree{&t1t3t2t4};
        DistanceMatrix expected_matrix{{0}};

        buildPhylogeneticTree(actual_tree, actual_matrix, UPGMA);

        check_equal(expected_tree, actual_tree);
        check_equal(expected_matrix, actual_matrix);
    }
}

// Checks the multi-line drawing produced by phylogeneticTreeToString for a
// hand-built four-leaf hierarchy: inner nodes appear as " (<number>)", leaves
// as Taxon_<n> in the short form and as their sequence in the verbose form.
void test_phylogeneticTreeToString() {
    print_header("test_phylogeneticTreeToString");

    // Leaves.
    Cluster t1{"CATAGACCTGACGCCAGCTC", 0, 0.0, nullptr, nullptr, 1};
    Cluster t2{"CATAGACCCGCCATGAGCTC", 1, 0.0, nullptr, nullptr, 1};
    Cluster t3{"CGTAGACTGGGCGCCAGCTC", 2, 0.0, nullptr, nullptr, 1};
    Cluster t4{"CCTAGACGTCGCGGCAGTCC", 3, 0.0, nullptr, nullptr, 1};

    // Inner nodes, bottom-up: (t1,t3), then +t2, then +t4 as the root.
    Cluster t1t3{"", -1, 2.0, &t1, &t3, 2};
    Cluster t1t3t2{"", -1, 3.0, &t1t3, &t2, 3};
    Cluster t1t3t2t4{"", -1, 4.0, &t1t3t2, &t4, 4};

    Cluster* treeRoot = &t1t3t2t4;

    // Short form.
    {
        const std::string wanted =
            " (4)\n"
            "|   +--- (3)\n"
            "|   |   +--- (2)\n"
            "|   |   |   +---Taxon_0\n"
            "|   |   |   +---Taxon_2\n"
            "|   |   +---Taxon_1\n"
            "|   +---Taxon_3\n";
        const std::string produced = phylogeneticTreeToString(treeRoot, false);

        check_equal(wanted, produced);
    }

    // Verbose form.
    {
        const std::string wanted =
            " (4)\n"
            "|   +--- (3)\n"
            "|   |   +--- (2)\n"
            "|   |   |   +---CATAGACCTGACGCCAGCTC\n"
            "|   |   |   +---CGTAGACTGGGCGCCAGCTC\n"
            "|   |   +---CATAGACCCGCCATGAGCTC\n"
            "|   +---CCTAGACGTCGCGGCAGTCC\n";
        const std::string produced = phylogeneticTreeToString(treeRoot, true);

        check_equal(wanted, produced);
    }
}
