#include <iostream>
#include <vector>
#include <limits>
#include <algorithm>
#include <random>
#include <tuple>
#include <chrono>

using Matrix = std::vector<std::vector<int> >;

/////////// BACKTRACKING VANILLA ///////////////////////
// Exhaustive depth-first search over every way of giving each agent (row of
// `m`) a distinct job (column of `m`).
//
// At level `depth`, agent `depth` is tried with every job that does not yet
// appear in `current_assignment`. When every agent has a job, the running
// cost is compared with the incumbent, which is replaced only if the new
// assignment is strictly cheaper. Every push/add is undone before moving on,
// so `current_assignment` and `cost_current_assignemnt` are unchanged when
// this call returns.
void
backtracking_vanilla_aux(
        Matrix const& m, int depth,
        std::vector<int>  & current_assignment,
        int               & cost_current_assignemnt,
        std::vector<int>  & best_solution_so_far,
        int               & cost_best_solution_so_far )
{
    const int n = static_cast<int>( m.size() );

    // leaf: a complete assignment has been built
    if ( depth == n ) {
        if ( cost_best_solution_so_far > cost_current_assignemnt ) {
            best_solution_so_far      = current_assignment;
            cost_best_solution_so_far = cost_current_assignemnt;
        }
        return;
    }

    std::vector<int> const& row = m[depth];
    for ( int job = 0; job < n; ++job ) {
        const bool taken =
            std::find( current_assignment.begin(), current_assignment.end(), job )
            != current_assignment.end();

        if ( !taken ) {
            // tentatively give `job` to agent `depth`
            current_assignment.push_back( job );
            cost_current_assignemnt += row[job];

            backtracking_vanilla_aux( m, depth + 1,
                    current_assignment, cost_current_assignemnt,
                    best_solution_so_far, cost_best_solution_so_far );

            // roll the tentative choice back
            cost_current_assignemnt -= row[job];
            current_assignment.pop_back( );
        }
    }
}

// Solves the assignment problem by plain exhaustive search.
//
// `m` is treated as square: m.size() is both the number of agents (rows) and
// the number of jobs (columns). Returns, for each agent, the index of the job
// it receives in a minimum-cost assignment. The result is empty for an empty
// matrix.
std::vector<int> backtracking_vanilla( Matrix const& m ) {
    int running_cost = 0;
    int best_cost    = std::numeric_limits<int>::max();
    std::vector<int> partial;
    std::vector<int> best;

    backtracking_vanilla_aux( m, 0, partial, running_cost, best, best_cost );
    return best;
}

/////////// BACKTRACKING + BRANCH & BOUND ///////////////////////

//lower bound evaluation function
int
lower_bound( 
        Matrix const& m, 
        int           depth, 
        std::vector<int> &current_assignment,
        int    const& cost_current_assignemnt )
{
    int lower_bound = cost_current_assignemnt;
    // for each future agent (row) find the minimum cost of assigning it to a job (column) 
    // that is not already assigned to a previous agent
    for ( int i=depth; i < m.size(); ++i ) {
        int min_in_row = std::numeric_limits<int>::max();
        for ( int j=0; j < m.size(); ++j ) {
            //check if job is not assigned (column is taken)
            if ( std::find( current_assignment.begin(), current_assignment.end(), j) == current_assignment.end() ) {
                if ( min_in_row > m[i][j] ) min_in_row = m[i][j];
            }
        }
        lower_bound += min_in_row;
    }
    return lower_bound;
}

// Depth-first branch-and-bound search.
//
// The exploration order is the same as backtracking_vanilla_aux. Before
// descending into the child that gives job j to agent `depth`, it computes
// lower_bound() for that child. The child is skipped when the bound cannot
// beat the incumbent (lb >= cost_best_solution_so_far). Every push/add is
// undone before moving on, so `current_assignment` and
// `cost_current_assignemnt` are unchanged when this call returns.
void
backtracking_branch_bound_aux(
        Matrix const& m, int depth,
        std::vector<int>  & current_assignment,
        int               & cost_current_assignemnt,
        std::vector<int>  & best_solution_so_far,
        int               & cost_best_solution_so_far )
{
    const int n = static_cast<int>( m.size() );

    // base case - complete assignment reached
    if ( depth == n ) {
        if ( cost_current_assignemnt < cost_best_solution_so_far ) {
            cost_best_solution_so_far = cost_current_assignemnt;
            best_solution_so_far = current_assignment;
        }
        // Stop here: falling into the loop below would only avoid reading
        // m[n] (out of bounds) because every column is already taken.
        return;
    }

    for ( int j = 0; j < n; ++j ) {
        // skip jobs already given to an earlier agent
        if ( std::find( current_assignment.begin(), current_assignment.end(), j ) != current_assignment.end() ) continue;

        current_assignment.push_back( j );
        cost_current_assignemnt += m[depth][j];

        int const lb = lower_bound( m, depth + 1, current_assignment, cost_current_assignemnt );

        // prune: this subtree cannot contain a strictly cheaper assignment
        if ( lb < cost_best_solution_so_far ) {
            backtracking_branch_bound_aux( m, depth + 1,
                    current_assignment, cost_current_assignemnt,
                    best_solution_so_far, cost_best_solution_so_far );
        }

        current_assignment.pop_back( );
        cost_current_assignemnt -= m[depth][j];
    }
}

// Entry point for the depth-first branch-and-bound solver.
//
// `m` is treated as square: m.size() is both the number of agents (rows) and
// the number of jobs (columns). Returns the job index assigned to each agent
// in a minimum-cost assignment. The result is empty for an empty matrix.
std::vector<int>
backtracking_branch_bound( Matrix const& m ) {
    int partial_cost   = 0;
    int incumbent_cost = std::numeric_limits<int>::max();
    std::vector<int> partial;
    std::vector<int> incumbent;

    backtracking_branch_bound_aux( m, 0, partial, partial_cost, incumbent, incumbent_cost );
    return incumbent;
}

/////////// BACKTRACKING + BRANCH & BOUND  + BEST FIRST ////////////////
// Branch-and-bound search that visits children in increasing order of their
// lower bound ("best first").
//
// For agent `depth`, a lower bound is computed for every free job. The jobs
// are then tried smallest-bound first. As soon as a job's bound is not
// strictly below the incumbent cost, that job and all later ones are pruned,
// because the later bounds are at least as large. Every push/add is undone
// before moving on, so `current_assignment` and `cost_current_assignemnt` are
// unchanged when this call returns.
void
backtracking_branch_bound_best_first_aux(
        Matrix const& m, int depth,
        std::vector<int>  & current_assignment,
        int               & cost_current_assignemnt,
        std::vector<int>  & best_solution_so_far,
        int               & cost_best_solution_so_far )
{
    const int n = static_cast<int>( m.size() );

    // base case - complete assignment reached
    if ( depth == n ) {
        if ( cost_current_assignemnt < cost_best_solution_so_far ) {
            cost_best_solution_so_far = cost_current_assignemnt;
            best_solution_so_far = current_assignment;
        }
        return; // nothing left to branch on (m[depth] would be out of bounds)
    }

    // (job index, lower bound of the child that gives that job to `depth`)
    std::vector< std::tuple<int,int> > ordered_branches;
    ordered_branches.reserve( m.size() );

    for ( int j = 0; j < n; ++j ) {
        if ( std::find( current_assignment.begin(), current_assignment.end(), j ) != current_assignment.end() ) continue;

        // temporarily apply the choice so lower_bound sees the child state
        current_assignment.push_back( j );
        cost_current_assignemnt += m[depth][j];
        int const lb = lower_bound( m, depth + 1, current_assignment, cost_current_assignemnt );
        current_assignment.pop_back( );
        cost_current_assignemnt -= m[depth][j];

        ordered_branches.push_back( std::make_tuple( j, lb ) );
    }

    // Increasing lower bound. stable_sort keeps ties in increasing job order,
    // independent of the standard library's std::sort implementation.
    std::stable_sort( ordered_branches.begin(), ordered_branches.end(),
            []( std::tuple<int,int> const& a, std::tuple<int,int> const& b ) {
                return std::get<1>(a) < std::get<1>(b);
            } );

    for ( auto const& branch : ordered_branches ) {
        int const job_index = std::get<0>( branch );
        int const lb        = std::get<1>( branch );

        // The branches are sorted by bound. Once one cannot beat the
        // incumbent, none of the remaining ones can either.
        if ( lb >= cost_best_solution_so_far ) break;

        current_assignment.push_back( job_index );
        cost_current_assignemnt += m[depth][job_index];

        backtracking_branch_bound_best_first_aux( m, depth + 1,
                current_assignment, cost_current_assignemnt,
                best_solution_so_far, cost_best_solution_so_far );

        current_assignment.pop_back( );
        cost_current_assignemnt -= m[depth][job_index];
    }
}

// Entry point for the best-first branch-and-bound solver.
//
// `m` is treated as square: m.size() is both the number of agents (rows) and
// the number of jobs (columns). Returns the job index assigned to each agent
// in a minimum-cost assignment. The result is empty for an empty matrix.
std::vector<int>
backtracking_branch_bound_best_first( Matrix const& m ) {
    int partial_cost   = 0;
    int incumbent_cost = std::numeric_limits<int>::max();
    std::vector<int> partial;
    std::vector<int> incumbent;

    backtracking_branch_bound_best_first_aux( m, 0, partial, partial_cost, incumbent, incumbent_cost );
    return incumbent;
}

int main() {
    int cost = 0;
    std::vector<int> solution;
   
//    Matrix m = {
//        {6,2,4,8},
//        {3,4,7,6},
//        {2,7,8,5},
//        {3,5,4,2}
//    };
    int N = 10;
    // randomely generate a cost matrix with values between 1 and 10 using C++-11 random number generation library
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<> dis(1, 10);

    Matrix m(N, std::vector<int>(N));
    for (int i=0; i<N; ++i) {
        for (int j=0; j<N; ++j) {
            m[i][j] = dis(gen);
        }
    }

    // time the three algorithms
    auto start = std::chrono::high_resolution_clock::now();
    if ( N > 12 ) {
        std::cout << "Backtracking vanilla is too slow for N > 10, skipping..." << std::endl;
    } else {
      solution = backtracking_vanilla( m );
      cost = 0;
      for (int i=0; i<N; ++i) {
        cost += m[i][ solution[i] ];
        std::cout << solution[i] << " ";
      }
      std::cout << "  cost " << cost << std::endl;
    }
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> elapsed = end - start;
    std::cout << "Backtracking vanilla time: " << elapsed.count() << " seconds" << std::endl;

    start = std::chrono::high_resolution_clock::now();
    solution = backtracking_branch_bound( m );
    cost = 0;
    for (int i=0; i<N; ++i) {
        cost += m[i][ solution[i] ];
        std::cout << solution[i] << " ";
    }
    std::cout << "  cost " << cost << std::endl;
    end = std::chrono::high_resolution_clock::now();
    elapsed = end - start;
    std::cout << "Backtracking branch and bound time: " << elapsed.count() << " seconds" << std::endl;

    start = std::chrono::high_resolution_clock::now();
    solution = backtracking_branch_bound_best_first( m );
    cost = 0;
    for (int i=0; i<N; ++i) {
        cost += m[i][ solution[i] ];
        std::cout << solution[i] << " ";
    }
    std::cout << "  cost " << cost << std::endl;
    end = std::chrono::high_resolution_clock::now();
    elapsed = end - start;
    std::cout << "Backtracking best first time: " << elapsed.count() << " seconds" << std::endl;

    return 0;
}
