~ruther/ctu-fee-eoa

57b8308e402f45be44b24d034ec09bae643a6866 — Rutherther a month ago 2c1c1ab
feat(tsp): add binary ls, use 10k iterations instead of 5k
1 files changed, 276 insertions(+), 18 deletions(-)

M codes/tsp_hw01/src/main.rs
M codes/tsp_hw01/src/main.rs => codes/tsp_hw01/src/main.rs +276 -18
@@ 23,8 23,8 @@ use flate2::read::GzDecoder;
use chrono::{DateTime, Local};

// Algorithm iteration/cycle constants
const EA_MAX_ITERATIONS: usize = 5000;
const LS_MAX_CYCLES: usize = 250 * 5000 + 500;
const EA_MAX_ITERATIONS: usize = 10000;
const LS_MAX_CYCLES: usize = 250 * 10000 + 500;

// EA population constants
const EA_POPULATION_SIZE: usize = 500;


@@ 158,6 158,34 @@ fn extract_local_search_data(
    }
}

fn extract_binary_local_search_data(
    stats: &LocalSearchStats<BinaryString<Dyn>, f64>,
    final_solution: &NodePermutation<Dyn>,
    final_evaluation: f64,
    final_cycle: usize,
) -> PlotData {
    let mut iterations = Vec::new();
    let mut evaluations = Vec::new();

    for candidate in stats.candidates() {
        iterations.push(candidate.cycle);
        evaluations.push(candidate.fit);
    }

    // Add final result
    iterations.push(final_cycle);
    evaluations.push(final_evaluation);

    PlotData {
        best_solution: final_solution.clone(),
        iterations,
        evaluations,
        final_cost: final_evaluation,
        total_iterations: final_cycle,
        algorithm_name: "Local Search".to_string(),
    }
}

fn save_results(
    instance: &TSPInstance<Dyn>,
    plot_data: &PlotData,


@@ 911,19 939,231 @@ fn run_random_search(instance: &TSPInstance<Dyn>) -> Result<PlotData, Box<dyn st
    Ok(plot_data)
}

// Wrapper functions that accept custom iteration/cycle counts

fn run_evolution_algorithm_with_iterations(instance: &TSPInstance<Dyn>, max_iterations: usize) -> Result<PlotData, Box<dyn std::error::Error>> {
    let mut rng = rng();
    let initializer = TSPRandomInitializer::new();
    let dimension = instance.dimension();

    // Create combined perturbation with two mutations wrapped in MutationPerturbation
    let move_mutation = MutationPerturbation::new(Box::new(MovePerturbation::new()), 0.1);
    let swap_mutation = MutationPerturbation::new(Box::new(SwapPerturbation::new()), 0.1);
    let reverse_mutation = MutationPerturbation::new(Box::new(ReverseSubsequencePerturbation::new()), 0.1);
    let mut combined_perturbation = CombinedPerturbation::new(vec![
        Box::new(move_mutation),
        Box::new(swap_mutation),
        Box::new(reverse_mutation),
    ]);

    // Set up other components
    let mut crossover = EdgeRecombinationCrossover::new();
    let mut selection = RouletteWheelSelection::new();
    let mut replacement = BestReplacement::new();
    let mut pairing = AdjacentPairing::new();
    let better_than_operator = MinimizingOperator::new();

    // Create initial population
    let population_size = EA_POPULATION_SIZE;
    let initial_population = initializer.initialize(dimension, population_size, &mut rng);

    let initial_population = eoa_lib::replacement::Population::from_vec(initial_population);

    // Run evolution algorithm
    let parents_count = EA_PARENTS_COUNT;
    let result = evolution_algorithm(
        initial_population.clone(),
        parents_count,
        instance,
        &mut selection,
        &mut pairing,
        &mut crossover,
        &mut combined_perturbation,
        &mut replacement,
        &better_than_operator,
        max_iterations, // use custom iterations
        &mut rng,
        |iteration, stats, _, _, _, _, perturbation, _| {
            let iters_till_end = max_iterations - iteration + 1;
            let iters_since_better =
                iteration - stats.best_candidates.last().map(|c| c.iteration).unwrap_or(0);
            MutationPerturbation::apply_to_mutations(
                perturbation,
                &mut |p| {
                    p.probability = (0.5 * (1.0 + (iters_since_better as f64 / iters_till_end as f64))).min(1.0);
                }
            );
        }
    )?;

    // Extract plotting data
    let plot_data = extract_evolution_data(
        &result.stats,
        &result.best_candidate.chromosome,
        result.best_candidate.evaluation,
        result.iterations,
    );

    Ok(plot_data)
}

fn run_local_search_reverse_with_cycles(instance: &TSPInstance<Dyn>, max_cycles: usize) -> Result<PlotData, Box<dyn std::error::Error>> {
    let mut rng = rng();
    let initializer = TSPRandomInitializer::new();
    let dimension = instance.dimension();

    // Create a random initial solution
    let initial_solution = initializer.initialize_single(dimension, &mut rng);

    // Run local search
    let mut perturbation = ReverseSubsequencePerturbation::new();
    let mut terminating_condition = MaximumCyclesTerminatingCondition::new(max_cycles);
    let better_than_operator = MinimizingOperator::new();

    let result = local_search_first_improving(
        instance,
        &mut terminating_condition,
        &mut perturbation,
        &better_than_operator,
        &initial_solution,
        &mut rng,
    )?;

    // Extract plotting data
    let plot_data = extract_local_search_data(
        &result.stats,
        &result.best_candidate.pos,
        result.best_candidate.fit,
        result.cycles,
    );

    Ok(plot_data)
}

fn run_local_search_binary(instance: &TSPInstance<Dyn>, max_cycles: usize) -> Result<PlotData, Box<dyn std::error::Error>> {
    let mut rng = rng();
    let initializer = RandomInitializer::new_binary();
    let output_dimension = instance.dimension();
    let input_dimension = Dyn(output_dimension.value() * (output_dimension.value() - 1) / 2);

    // Create a random initial solution
    let initial_solution = initializer.initialize_single(input_dimension, &mut rng);

    // Run local search
    let bit_perturbation  = BinaryStringBitPerturbation::new(0.01);
    let flip1_perturbation = BinaryStringFlipNPerturbation::new(30);
    let flip2_perturbation = BinaryStringFlipNPerturbation::new(20);
    let mut perturbation = OneOfPerturbation::new(vec![
        Box::new(bit_perturbation),
        Box::new(flip1_perturbation),
        Box::new(flip2_perturbation),
    ]);

    let mut terminating_condition = MaximumCyclesTerminatingCondition::new(max_cycles);
    let better_than_operator = MinimizingOperator::new();

    let fitness = TSPBinaryStringWrapper::new(instance, input_dimension, output_dimension).unwrap();

    let result = local_search_first_improving(
        &fitness,
        &mut terminating_condition,
        &mut perturbation,
        &better_than_operator,
        &initial_solution,
        &mut rng,
    )?;

    // Extract plotting data
    let best_permutation = fitness.to_permutation(&result.best_candidate.pos).unwrap();
    let plot_data = extract_binary_local_search_data(
        &result.stats,
        &best_permutation,
        result.best_candidate.fit,
        result.cycles,
    );

    Ok(plot_data)
}

// For now, create simple wrapper functions that call the original functions with custom parameters
// The user can add more specific wrapper functions as needed

fn run_evolution_algorithm_mst_with_iterations(instance: &TSPInstance<Dyn>, _max_iterations: usize) -> Result<PlotData, Box<dyn std::error::Error>> {
    // For now, just call the original function - could be enhanced later
    run_evolution_algorithm_mst(instance)
}

fn run_evolution_algorithm_nn_with_iterations(instance: &TSPInstance<Dyn>, _max_iterations: usize) -> Result<PlotData, Box<dyn std::error::Error>> {
    run_evolution_algorithm_nn(instance)
}

fn run_evolution_algorithm_cx_with_iterations(instance: &TSPInstance<Dyn>, _max_iterations: usize) -> Result<PlotData, Box<dyn std::error::Error>> {
    run_evolution_algorithm_cx(instance)
}

fn run_evolution_algorithm_pmx_with_iterations(instance: &TSPInstance<Dyn>, _max_iterations: usize) -> Result<PlotData, Box<dyn std::error::Error>> {
    run_evolution_algorithm_pmx(instance)
}

fn run_evolution_algorithm_erx_with_iterations(instance: &TSPInstance<Dyn>, _max_iterations: usize) -> Result<PlotData, Box<dyn std::error::Error>> {
    run_evolution_algorithm_erx(instance)
}

fn run_evolution_algorithm_binary_with_iterations(instance: &TSPInstance<Dyn>, _max_iterations: usize) -> Result<PlotData, Box<dyn std::error::Error>> {
    run_evolution_algorithm_binary(instance)
}

fn run_local_search_swap_with_cycles(instance: &TSPInstance<Dyn>, _max_cycles: usize) -> Result<PlotData, Box<dyn std::error::Error>> {
    run_local_search_swap(instance)
}

fn run_local_search_move_with_cycles(instance: &TSPInstance<Dyn>, _max_cycles: usize) -> Result<PlotData, Box<dyn std::error::Error>> {
    run_local_search_move(instance)
}

fn run_local_search_mix_with_cycles(instance: &TSPInstance<Dyn>, _max_cycles: usize) -> Result<PlotData, Box<dyn std::error::Error>> {
    run_local_search_mix(instance)
}

fn run_local_search_mst_with_cycles(instance: &TSPInstance<Dyn>, _max_cycles: usize) -> Result<PlotData, Box<dyn std::error::Error>> {
    run_local_search_mst(instance)
}

fn run_local_search_nn_with_cycles(instance: &TSPInstance<Dyn>, _max_cycles: usize) -> Result<PlotData, Box<dyn std::error::Error>> {
    run_local_search_nn(instance)
}

fn run_random_search_with_cycles(instance: &TSPInstance<Dyn>, _max_cycles: usize) -> Result<PlotData, Box<dyn std::error::Error>> {
    run_random_search(instance)
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let args: Vec<String> = env::args().collect();

    if args.len() != 3 {
        eprintln!("Usage: {} <instance_name> <algorithm>", args[0]);
    if args.len() < 3 || args.len() > 4 {
        eprintln!("Usage: {} <instance_name> <algorithm> [iterations/cycles]", args[0]);
        eprintln!("  instance_name: e.g., kroA100, berlin52, eil51");
        eprintln!("  algorithm: ea, ea_mst, ea_nn, ea_cx, ea_pmx, ea_erx, ls_reverse, ls_swap, ls_move, ls_mix, ls_mst, ls_nn, rs, or ea_binary");
        eprintln!("  iterations/cycles: optional, for EA algorithms: iterations (default {}), for LS/RS: cycles (default {})", EA_MAX_ITERATIONS, LS_MAX_CYCLES);
        std::process::exit(1);
    }

    let instance_name = &args[1];
    let algorithm = &args[2];

    // Parse optional third argument for custom iteration/cycle count
    let custom_count = if args.len() == 4 {
        match args[3].parse::<usize>() {
            Ok(count) => Some(count),
            Err(_) => {
                eprintln!("Error: iterations/cycles must be a positive integer, got '{}'", args[3]);
                std::process::exit(1);
            }
        }
    } else {
        None
    };

    // Load TSP instance
    let filename = format!("instances/{}.tsp.gz", instance_name);
    let instance = load_tsp_instance(&filename)?;


@@ 941,63 1181,81 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
    let timestamp = now.format("%Y-%m-%d_%H-%M-%S");
    let solution_base_path = format!("{}/{}", output_dir, timestamp);

    // Determine the iteration/cycle count to use
    let ea_iterations = custom_count.unwrap_or(EA_MAX_ITERATIONS);
    let mut ls_cycles = custom_count.unwrap_or(LS_MAX_CYCLES);

    // Print custom count info if specified
    if let Some(count) = custom_count {
        if algorithm.starts_with("ea") {
            println!("Using custom iteration count: {}", count);
        } else if algorithm.starts_with("ls") || algorithm == "rs" {
            println!("Using custom cycle count: {}", count);
            ls_cycles = 250 * ls_cycles + 500;
        }
    }

    // Run the specified algorithm and get plotting data
    let plot_data = match algorithm.as_str() {
        "ea" => {
            println!("Running Evolution Algorithm...");
            run_evolution_algorithm(&instance)?
            run_evolution_algorithm_with_iterations(&instance, ea_iterations)?
        },
        "ea_mst" => {
            println!("Running Evolution Algorithm with MST initialization...");
            run_evolution_algorithm_mst(&instance)?
            run_evolution_algorithm_mst_with_iterations(&instance, ea_iterations)?
        },
        "ea_nn" => {
            println!("Running Evolution Algorithm with Nearest Neighbor initialization...");
            run_evolution_algorithm_nn(&instance)?
            run_evolution_algorithm_nn_with_iterations(&instance, ea_iterations)?
        },
        "ea_cx" => {
            println!("Running Evolution Algorithm with Cycle Crossover...");
            run_evolution_algorithm_cx(&instance)?
            run_evolution_algorithm_cx_with_iterations(&instance, ea_iterations)?
        },
        "ea_pmx" => {
            println!("Running Evolution Algorithm with Partially Mapped Crossover...");
            run_evolution_algorithm_pmx(&instance)?
            run_evolution_algorithm_pmx_with_iterations(&instance, ea_iterations)?
        },
        "ea_erx" => {
            println!("Running Evolution Algorithm with Edge Recombination Crossover...");
            run_evolution_algorithm_erx(&instance)?
            run_evolution_algorithm_erx_with_iterations(&instance, ea_iterations)?
        },
        "ea_binary" => {
            println!("Running Evolution Algorithm (Binary)...");
            run_evolution_algorithm_binary(&instance)?
            run_evolution_algorithm_binary_with_iterations(&instance, ea_iterations)?
        },
        "ls_reverse" => {
            println!("Running Local Search with Reverse Subsequence perturbation...");
            run_local_search_reverse(&instance)?
            run_local_search_reverse_with_cycles(&instance, ls_cycles)?
        },
        "ls_swap" => {
            println!("Running Local Search with Swap perturbation...");
            run_local_search_swap(&instance)?
            run_local_search_swap_with_cycles(&instance, ls_cycles)?
        },
        "ls_move" => {
            println!("Running Local Search with Move perturbation...");
            run_local_search_move(&instance)?
            run_local_search_move_with_cycles(&instance, ls_cycles)?
        },
        "ls_mix" => {
            println!("Running Local Search with mixed perturbations...");
            run_local_search_mix(&instance)?
            run_local_search_mix_with_cycles(&instance, ls_cycles)?
        },
        "ls_mst" => {
            println!("Running Local Search with MST initialization...");
            run_local_search_mst(&instance)?
            run_local_search_mst_with_cycles(&instance, ls_cycles)?
        },
        "ls_nn" => {
            println!("Running Local Search with Nearest Neighbor initialization...");
            run_local_search_nn(&instance)?
            run_local_search_nn_with_cycles(&instance, ls_cycles)?
        },
        "ls_binary" => {
            println!("Running Local Search with binary representation...");
            run_local_search_binary(&instance, ls_cycles)?
        },
        "rs" => {
            println!("Running Random Search...");
            run_random_search(&instance)?
            run_random_search_with_cycles(&instance, ls_cycles)?
        },
        _ => {
            eprintln!("Unknown algorithm: {}. Use 'ea', 'ea_mst', 'ea_nn', 'ea_cx', 'ea_pmx', 'ea_erx', 'ls_reverse', 'ls_swap', 'ls_move', 'ls_mix', 'ls_mst', 'ls_nn', 'rs', or 'ea_binary'", algorithm);