#include "test_utils.hpp"

#include <cmath>
#include <iostream>
#include <stdexcept>
#include <utility>

// Forward declaration: this helper and operator==(Cluster, Cluster) below are
// mutually recursive (operator== compares children through this helper, and
// the helper compares two non-null clusters with operator==).
bool clusters_both_null_or_equal(Cluster const* cluster_ptr1,
                                 Cluster const* cluster_ptr2);

/*
    For unit tests
*/

// Returns true when `double1` and `double2` differ by strictly less than
// `precision` (a gap exactly equal to `precision` counts as different).
// A NaN in any argument makes the result false.
bool doubles_equal(double double1, double double2, double precision) {
    double const gap = std::fabs(double1 - double2);
    return precision > gap;
}

/*
    This `operator==` uses a technique called "operator overloading".
    It implements the '==' operator for the type Cluster so that we can do
    comparisons to see if two Clusters are equal. If you want to know more
    take a look here: https://en.cppreference.com/w/cpp/language/operators
*/
// Structural equality of two clusters: same taxon, same id, heights equal
// within doubles_equal's default tolerance, and children that match either
// in the same order or swapped (child order does not matter).
bool operator==(Cluster const& cluster1, Cluster const& cluster2) {
    // Cheap scalar fields first: on any mismatch there is no need to walk
    // the (possibly deep) subtrees at all.
    if (!(cluster1.taxon == cluster2.taxon)) return false;
    if (!(cluster1.id == cluster2.id)) return false;
    if (!doubles_equal(cluster1.height, cluster2.height)) return false;

    // Accept either the straight pairing (left-left, right-right) or the
    // crossed one (left-right, right-left). `&&`/`||` short-circuit, so a
    // subtree comparison only runs while a match is still possible.
    // Evaluating all four recursive comparisons eagerly costs 4^depth.
    // Here, trees that match in the same order are compared in linear time.
    return (clusters_both_null_or_equal(cluster1.left, cluster2.left) &&
            clusters_both_null_or_equal(cluster1.right, cluster2.right)) ||
           (clusters_both_null_or_equal(cluster1.left, cluster2.right) &&
            clusters_both_null_or_equal(cluster1.right, cluster2.left));
}

// Null-aware cluster comparison: two null pointers are equal, a null and a
// non-null pointer are not, and two non-null pointers are equal when the
// clusters they point to compare equal with operator==.
bool clusters_both_null_or_equal(Cluster const* cluster_ptr1,
                                 Cluster const* cluster_ptr2) {
    // At least one side is null: they match only if both are null.
    if (cluster_ptr1 == nullptr || cluster_ptr2 == nullptr)
        return cluster_ptr1 == cluster_ptr2;

    // Both sides are present: compare the pointed-to clusters structurally.
    return *cluster_ptr1 == *cluster_ptr2;
}

// Unordered pair equality: {a, b} equals both {a, b} and {b, a}.
bool operator==(ClusterIdPair const& pair1, ClusterIdPair const& pair2) {
    bool const same_order = pair1[0] == pair2[0] && pair1[1] == pair2[1];
    if (same_order) return true;

    // Otherwise the only remaining possibility is the swapped order.
    return pair1[0] == pair2[1] && pair1[1] == pair2[0];
}

// Positional tree equality: both trees must have the same size, and slot i
// of one must hold a cluster equal to slot i of the other (or null in both).
bool operator==(Tree const& tree1, Tree const& tree2) {
    if (tree1.size() != tree2.size()) return false;

    size_t idx(0);
    while (idx < tree1.size()) {
        if (!clusters_both_null_or_equal(tree1[idx], tree2[idx])) return false;
        ++idx;
    }
    return true;
}

/*
    Every test should report its outcome through this function so that all
    results are printed in the same [Passed]/[Failed] format.
*/
// Prints "[Passed]" when `condition` holds. Otherwise prints "[Failed]"
// followed by `fail_message`. Output goes to std::cerr, and each report is
// followed by a blank separator line.
void check(bool condition, std::string const& fail_message) {
    std::ostream& report = std::cerr;
    if (!condition) {
        // Verdict first, then the caller-supplied diagnostic.
        report << "[Failed]" << std::endl;
        report << fail_message << std::endl;
    } else {
        report << "[Passed]" << std::endl;
    }
    // Empty line keeps consecutive reports visually separated.
    report << std::endl;
}

/*
    This overload should be called by all of the check_equal overloads below.

    Note that it does not compare message_expected and message_actual
    (they are just used for the fail message)
*/
void check_equal(bool is_equal, std::string const& message_expected,
                 std::string const& message_actual) {
    // Assemble the diagnostic shown on failure. The two messages are only
    // displayed; the verdict comes solely from `is_equal`.
    std::string report("Expected:\n");
    report += message_expected;
    report += "\nActual:\n";
    report += message_actual;
    check(is_equal, report);
}

void check_equal(int expected, int actual) {
    check_equal(expected == actual, std::to_string(expected),
                std::to_string(actual));
}

// Passes when `expected` and `actual` differ by strictly less than
// `precision` (see doubles_equal).
// Note: std::to_string(double) prints with "%f" (6 decimals), so for very
// tight tolerances the Expected/Actual text may look identical even when
// the check fails.
void check_equal(double expected, double actual, double precision) {
    // doubles_equal already yields the verdict. The previous extra
    // `< precision` compared that bool (0 or 1) against the tolerance.
    // That inverted the result for precision in (0, 1], always passed for
    // precision > 1, and always failed for precision <= 0.
    check_equal(doubles_equal(expected, actual, precision),
                std::to_string(expected), std::to_string(actual));
}

void check_equal(std::string const& expected, std::string const& actual) {
    check_equal(expected == actual, expected, actual);
}

// Structural cluster comparison via operator==(Cluster, Cluster). On
// failure both clusters are rendered with clusterToString(ptr, true).
void check_equal(Cluster const& expected, Cluster const& actual) {
    bool const same = (expected == actual);
    auto const expected_text = clusterToString(&expected, true);
    auto const actual_text = clusterToString(&actual, true);
    check_equal(same, expected_text, actual_text);
}

// DistanceMatrix comparison via its operator==. Both matrices are rendered
// with toString on failure.
void check_equal(DistanceMatrix const& expected, DistanceMatrix const& actual) {
    bool const same = (expected == actual);
    auto const expected_text = toString(expected);
    auto const actual_text = toString(actual);
    check_equal(same, expected_text, actual_text);
}

// Unordered id-pair comparison via operator==(ClusterIdPair, ClusterIdPair).
// Both pairs are rendered with toString on failure.
void check_equal(ClusterIdPair const& expected, ClusterIdPair const& actual) {
    bool const same = (expected == actual);
    auto const expected_text = toString(expected);
    auto const actual_text = toString(actual);
    check_equal(same, expected_text, actual_text);
}

// Positional tree comparison via operator==(Tree, Tree). Both trees are
// rendered with toString(tree, true) on failure.
void check_equal(Tree const& expected, Tree const& actual) {
    bool const same = (expected == actual);
    auto const expected_text = toString(expected, true);
    auto const actual_text = toString(actual, true);
    check_equal(same, expected_text, actual_text);
}
