td3.py
'''
Train a model with the PPO algorithm from stable-baselines3 to solve a function optimization problem.
'''
import gymnasium as gym
from function_env import FunctionEnv
from stable_baselines3 import PPO
import numpy as np
from stable_baselines3.common.callbacks import CheckpointCallback
# Define the objective functions (each returns the negated value, since the RL agent maximizes reward)
def sphere(x):
    return -np.sum(x ** 2)
def ackley(x):
    a = 20
    b = 0.2
    c = 2 * np.pi
    d = len(x)  # dimensionality
    # Vectorized numpy operations over every element of x
    part1 = -a * np.exp(-b * np.sqrt(np.sum(x ** 2) / d))
    part2 = -np.exp(np.sum(np.cos(c * x)) / d)
    result = part1 + part2 + a + np.exp(1)
    return -result
def rastrigin(x):
    """
    Rastrigin function, supports arbitrary dimensionality.
    f(x) = A*n + sum(x_i^2 - A*cos(2*pi*x_i))
    where A = 10 and n is the number of dimensions.
    """
    A = 10
    n = len(x)
    result = A * n + np.sum(x**2 - A * np.cos(2 * np.pi * x))
    # Return the negated value (the RL agent maximizes reward), matching the other objectives
    return -result
def griewank(x):
    """
    Griewank function, supports arbitrary dimensionality.
    f(x) = 1 + (1/4000) * sum(x_i^2) - prod(cos(x_i / sqrt(i)))
    Properties:
    - Countless regularly distributed local minima
    - Global minimum at f(0, ..., 0) = 0
    - Typical search space: x_i in [-600, 600]
    """
    # First part: sum(x_i^2) / 4000
    sum_part = np.sum(x**2) / 4000
    # Second part: prod(cos(x_i / sqrt(i)))
    # Note: numpy indexing starts at 0, while i starts at 1 in the Griewank definition
    indices = np.arange(1, len(x) + 1)
    prod_part = np.prod(np.cos(x / np.sqrt(indices)))
    # Griewank function value
    result = 1 + sum_part - prod_part
    # Return the negated value (the RL agent maximizes reward)
    return -result
def levy(x):
    """
    Levy function, supports arbitrary dimensionality.
    f(x) = sin^2(pi*w_1) + sum_{i=1}^{n-1} [(w_i - 1)^2 * (1 + 10*sin^2(pi*w_{i+1}))]
           + (w_n - 1)^2 * (1 + sin^2(2*pi*w_n))
    where w_i = 1 + (x_i - 1) / 4
    Properties:
    - Multimodal, with many local minima
    - Global minimum at f(1, ..., 1) = 0
    - Typical search space: x_i in [-10, 10]
    """
    w = 1.0 + (x - 1.0) / 4.0
    term1 = np.sin(np.pi * w[0]) ** 2
    term2 = np.sum((w[:-1] - 1.0) ** 2 * (1.0 + 10.0 * np.sin(np.pi * w[1:]) ** 2))
    term3 = (w[-1] - 1.0) ** 2 * (1.0 + np.sin(2.0 * np.pi * w[-1]) ** 2)
    result = term1 + term2 + term3
    # Return the negated value (the RL agent maximizes reward)
    return -result
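# Optional sanity check (illustrative addition, not part of the original training flow):
# every objective above returns the *negated* function value, so maximizing reward
# corresponds to minimizing the function, and the known global optima should
# evaluate to (approximately) zero.
assert np.isclose(sphere(np.zeros(5)), 0.0)
assert np.isclose(ackley(np.zeros(5)), 0.0)
assert np.isclose(rastrigin(np.zeros(5)), 0.0)
assert np.isclose(griewank(np.zeros(5)), 0.0)
assert np.isclose(levy(np.ones(5)), 0.0)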
# Create the environment
env = FunctionEnv(
    function=levy,
    dim=12,
    step_size=0.1,
    bound=[-10, 10],
    max_steps=10000,
    reset_state=np.array([-7.0] * 12)
)
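# Note (inferred from how this script uses the environment below, not guaranteed by
# function_env itself): FunctionEnv is assumed to follow the gymnasium API, with reset()
# returning (obs, info) and step() returning (obs, reward, terminal, truncated, info),
# where info exposes at least 'value', 'best', 'best_value' and 'current_steps'.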
# Create the PPO model and a periodic checkpoint callback
model_name = "PPO_12dim_levy_step02_max100_mixreward"
checkpoint_cb = CheckpointCallback(save_freq=100_000, save_path='./logs/', name_prefix=model_name)
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log=r'./tensorboard_logs/', device='cpu', learning_rate=1e-3)
model.learn(total_timesteps=10_000_000, log_interval=1, callback=checkpoint_cb)
# Save the model
model.save(f"./models/{model_name}")
del model
# Load the model
model = PPO.load(f"./models/{model_name}")
# Test the model
obs, info = env.reset()
init_obs = obs
init_val = info['value']
while True:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminal, truncated, info = env.step(action)
    print(f'state: {obs}, action: {action}, reward: {reward}, \nval: {info["value"]}, best: {info["best"]}, best_value: {info["best_value"]}, current_steps: {info["current_steps"]}')
    print('----------------------------------')
    if terminal or truncated:
        break
print(f'init_obs: {init_obs}, init_val: {init_val}, \nbest: {info["best"]}, best_value: {info["best_value"]}')
env.close()
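# Optional (illustrative, not part of the original script): intermediate checkpoints written
# by CheckpointCallback can be reloaded just like the final model. The path below assumes
# stable-baselines3's default "<name_prefix>_<num_timesteps>_steps.zip" naming; adjust the
# step count to an existing checkpoint before uncommenting.
# checkpoint_model = PPO.load(f"./logs/{model_name}_100000_steps")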