OrderReallocation-HeavyTruc.../main.py

114 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import random
import numpy as np
from data_structures import OrderData, RiskEnterpriseData, SupplierData, Config
from chromosome_utils import ChromosomeUtils
from objective_calculator import ObjectiveCalculator
from encoder import Encoder
from genetic_operators import GeneticOperator
from nsga2 import NSGA2
from visualizer import ResultVisualizer
def _make_offspring(selected, genetic_op, config):
    """Build exactly ``config.pop_size`` children from the mating pool ``selected``.

    Parents are taken in consecutive pairs (wrapping around the pool when it is
    exhausted or has odd length). With probability ``config.crossover_prob``
    a pair is recombined with two-point crossover; otherwise both parents are
    passed through unchanged as copies. Each child is then mutated with
    probability ``config.mutation_prob``.

    Args:
        selected: mating pool returned by ``NSGA2.selection``; indexable rows.
        genetic_op: object providing ``two_point_crossover`` and ``uniform_mutation``.
        config: object providing ``pop_size``, ``crossover_prob``, ``mutation_prob``.

    Returns:
        ``np.ndarray`` stacking the ``config.pop_size`` children.

    Raises:
        ValueError: if children are needed but the mating pool is empty.
    """
    pool_size = len(selected)
    if config.pop_size > 0 and pool_size == 0:
        raise ValueError("selection returned an empty mating pool")

    offspring = []
    i = 0
    while len(offspring) < config.pop_size:
        # Wrap around so we never index past the pool (the original crashed here
        # whenever a pair skipped crossover and produced only one child).
        parent1 = selected[i % pool_size]
        parent2 = selected[(i + 1) % pool_size]
        if random.random() < config.crossover_prob:
            child1, child2 = genetic_op.two_point_crossover(parent1, parent2)
        else:
            # Copy so an in-place mutation cannot alter the mating pool / parents.
            child1, child2 = np.copy(parent1), np.copy(parent2)
        offspring.append(child1)
        if len(offspring) < config.pop_size:
            offspring.append(child2)
        i += 2

    offspring = [
        genetic_op.uniform_mutation(chrom) if random.random() < config.mutation_prob else chrom
        for chrom in offspring
    ]
    return np.array(offspring)


def _first_front(nsga2, population, objectives):
    """Return ``(indices, chromosomes, objective_vectors)`` of the first non-dominated front.

    If the sort yields no fronts at all, an empty front is returned instead of
    raising ``IndexError``.
    """
    _, fronts = nsga2.fast_non_dominated_sort(objectives)
    # list() so numpy fancy indexing is used even if a front is stored as a tuple.
    front = list(fronts[0]) if len(fronts) > 0 else []
    return front, population[front], [objectives[i] for i in front]


def main():
    """Run the NSGA-II order-reallocation optimisation end to end.

    Builds the problem data and helper objects, then evolves the population for
    up to ``config.max_generations`` generations. It stops early once
    ``no_improve_count`` reaches ``config.early_stop_patience``. That counter
    grows while the summed mean objectives of the first front change by less
    than 1e-4, or the front is empty. Finally it plots and prints the first
    front of the final population.
    """
    # Problem data
    order_data = OrderData()
    risk_data = RiskEnterpriseData()
    supplier_data = SupplierData()
    config = Config()

    # Helper objects
    utils = ChromosomeUtils(order_data, risk_data, supplier_data)
    calculator = ObjectiveCalculator(order_data, risk_data, supplier_data, utils, config)
    encoder = Encoder(config, utils)
    genetic_op = GeneticOperator(config, utils)
    nsga2 = NSGA2(config.pop_size, config.objective_num)
    visualizer = ResultVisualizer(utils)

    # Initial population
    population = encoder.initialize_population()
    print(f"初始化种群完成,种群大小: {population.shape}")

    # History
    all_objectives = []       # every objective vector evaluated (for the scatter plot)
    convergence_history = []  # (mean obj[0], mean obj[1]) of the first front, per generation
    no_improve_count = 0
    prev_avg = float('inf')

    for generation in range(config.max_generations):
        # Evaluate the current population.
        # NOTE(review): environmental_selection also returns survivor objectives;
        # they are recomputed here because their type/alignment is not visible
        # from this file. Reuse them if that is confirmed safe, to save a pass.
        objectives = [calculator.calculate_objectives(chrom) for chrom in population]
        all_objectives.extend(objectives)

        current_front, _, front_objs = _first_front(nsga2, population, objectives)

        # Convergence check based on the mean objectives of the first front.
        if front_objs:
            avg_cost = sum(obj[0] for obj in front_objs) / len(front_objs)
            avg_tardiness = sum(obj[1] for obj in front_objs) / len(front_objs)
            convergence_history.append((avg_cost, avg_tardiness))
            current_avg = avg_cost + avg_tardiness
            if abs(current_avg - prev_avg) < 1e-4:
                no_improve_count += 1
            else:
                no_improve_count = 0
            prev_avg = current_avg
        else:
            no_improve_count += 1

        # Selection, then crossover + mutation
        selected = nsga2.selection(population, objectives)
        offspring = _make_offspring(selected, genetic_op, config)

        # Merge parents and offspring, then NSGA-II environmental selection.
        combined = np.vstack([population, offspring])
        combined_objs = objectives + [calculator.calculate_objectives(chrom) for chrom in offspring]
        population, _ = nsga2.environmental_selection(combined, combined_objs)

        # Early stopping
        if no_improve_count >= config.early_stop_patience:
            print(f"早停于第 {generation} 代")
            break
        if generation % 50 == 0:
            print(f"{generation} 代完成,当前最优前沿大小: {len(current_front)}")

    # Report the front of the final (post-selection) population. The in-loop
    # front predates the last environmental selection, so it would miss the
    # final generation's offspring.
    # NOTE(review): assumes nsga2.environmental_selection is elitist as in
    # standard NSGA-II. Confirm against nsga2.py.
    final_objs = [calculator.calculate_objectives(chrom) for chrom in population]
    all_objectives.extend(final_objs)
    _, best_front, best_front_objs = _first_front(nsga2, population, final_objs)

    # Visualise and print the results
    print("进化完成,绘制结果...")
    visualizer.plot_pareto_front(all_objectives, best_front_objs)
    visualizer.plot_convergence(convergence_history)
    visualizer.print_pareto_solutions(best_front, best_front_objs)
if __name__ == "__main__":
    # Seed the stdlib RNG and NumPy's global RNG with the same value so that
    # selection, crossover and mutation draws are reproducible across runs.
    for _seed_fn in (random.seed, np.random.seed):
        _seed_fn(42)
    main()