86. 实现 Sarsa 学习算法走出迷宫#
86.1. 介绍#
在基于价值的强化学习中我们主要实现了 Q-Learning 算法,事实上 Sarsa 算法和 Q-Learning 最大的区别就在于 Q-Table 的更新,本次实验,结合实验中 Q-Learning 的算法实现,并根据 Sarsa 的算法流程来完成迷宫挑战。
86.2. 知识点#
Q-Table 初始化
Q-Table 更新函数
Sarsa 完整算法实现
86.3. Q-Table 初始化#
根据前面的实验内容,你应该知道不论是 Q-Learning 还是 Sarsa,其核心都是基于价值迭代,所以需要先初始化 Q-Table。
挑战:按要求初始化 Q-Table。
规定:构造一个 \(16*4\) 的 DataFrame 表(16 个 state,4 个 action)作为 Q-Table。
提示:和实验中 Q-Learning 初始化方式相同。
import numpy as np
import pandas as pd
import time
from IPython import display
def init_q_table():
### 代码开始 ### (≈ 2 行代码)
actions = None
q_table = None
### 代码结束 ###
return q_table
init_q_table()
Note
本课程中,Notebook 挑战系统无法自动评判,你需要自行补充上方单元格中缺失的代码并运行,如果输出结果和下方的期望输出结果一致,即代表此挑战顺利通过。完成全部内容后,点击「提交检测」即可通过,此说明后续不再出现。
期望输出
up | down | left | right | |
---|---|---|---|---|
0 | 0.0 | 0.0 | 0.0 | 0.0 |
1 | 0.0 | 0.0 | 0.0 | 0.0 |
2 | 0.0 | 0.0 | 0.0 | 0.0 |
3 | 0.0 | 0.0 | 0.0 | 0.0 |
4 | 0.0 | 0.0 | 0.0 | 0.0 |
5 | 0.0 | 0.0 | 0.0 | 0.0 |
6 | 0.0 | 0.0 | 0.0 | 0.0 |
7 | 0.0 | 0.0 | 0.0 | 0.0 |
8 | 0.0 | 0.0 | 0.0 | 0.0 |
9 | 0.0 | 0.0 | 0.0 | 0.0 |
10 | 0.0 | 0.0 | 0.0 | 0.0 |
11 | 0.0 | 0.0 | 0.0 | 0.0 |
12 | 0.0 | 0.0 | 0.0 | 0.0 |
13 | 0.0 | 0.0 | 0.0 | 0.0 |
14 | 0.0 | 0.0 | 0.0 | 0.0 |
15 | 0.0 | 0.0 | 0.0 | 0.0 |
86.4. 动作选择#
接下来,我们需要使用
\(\epsilon-greedy\)
方法根据 Q-Table 进行动作选择,这里仿照实验内容实现
act_choose
函数。
挑战:使用 \(\epsilon-greedy\) 方法根据 Q-Table 进行动作选择。
规定:在概率为 \(1-epsilon\) ,或 Q 值都为 0 的情况下,随机选择动作;此外,按照 Q 的最大值选择动作,并且动作用 action 表示。
提示:这里可能会使用 if,else 语句判断,与实验中内容相同。
def act_choose(state, q_table, epsilon):
state_act = q_table.iloc[state, :]
actions = np.array(["up", "down", "left", "right"])
### 代码开始 ### (≈ 4 行代码)
if None:
action = None
else:
action = None
### 代码结束 ###
return action
运行测试
seed = np.random.RandomState(25) # 为了保证验证结果相同引入随机数种子
a = seed.rand(16, 4)
test_q_table = pd.DataFrame(a, columns=["up", "down", "left", "right"])
l = []
for s in [1, 4, 7, 12, 14]:
l.append(act_choose(state=s, q_table=test_q_table, epsilon=1))
l
期望输出
['left', 'right', 'right', 'right', 'left']
86.5. 行为反馈#
在行为反馈中我们同样将 terminal 终点的奖励设为
10
,将 hole 陷阱的惩罚设为
-10
,同样为了尽快找到最短路径,每一步的惩罚为
-1
。直接沿用实验中相似代码块即可。
def env_feedback(state, action, hole, terminal):
reward = 0.0
end = 0
a, b = state
if action == "up":
a -= 1
if a < 0:
a = 0
next_state = (a, b)
elif action == "down":
a += 1
if a >= 4:
a = 3
next_state = (a, b)
elif action == "left":
b -= 1
if b < 0:
b = 0
next_state = (a, b)
elif action == "right":
b += 1
if b >= 4:
b = 3
next_state = (a, b)
if next_state == terminal:
reward = 10.0
end = 2
elif next_state == hole:
reward = -10.0
end = 1
else:
reward = -1.0
return next_state, reward, end
86.6. Q-Table 更新#
接下来,就需要完成 Q-Table 更新函数。通过实验内容可知,Sarsa 的 Q-Table 的更新是与 Q-Learning 最大的区别之处,所以需要根据 Sarsa 的 Q-Table 更新公式来实现。
挑战:根据下方 Sarsa 的 Q-Table 的更新公式完善 Q-Table 更新函数。
提示:结合 Q-Learning 中 Q-Table
更新函数进行修改,通过标签查看 DataFrame 特定值时使用
.loc[]
。
def update_q_table(
q_table, state, action, next_state, next_action, terminal, gamma, alpha, reward
):
x, y = state
next_x, next_y = next_state
q_original = q_table.loc[x * 4 + y, action]
if next_state != terminal:
### 代码开始 ### (≈ 1 行代码)
q_predict = None
### 代码结束 ###
else:
q_predict = reward
### 代码开始 ### (≈ 1 行代码)
q_table.loc[None] = None
### 代码结束 ###
return q_table
运行测试(仅执行一次,重复执行请重启 kernel)
new_q_table = update_q_table(
q_table=test_q_table,
state=(2, 2),
action="right",
next_state=(2, 3),
next_action="down",
terminal=(3, 2),
gamma=0.9,
alpha=0.8,
reward=10,
)
new_q_table.loc[10, "right"]
期望输出:(仅执行一次得到的结果)
8.740755431411795
同样为了展示强化学习效果,定义一个状态展示函数,此处综合沿用实验中相应代码块即可。
def show_state(end, state, episode, step, q_table):
terminal = (3, 2)
hole = (2, 1)
env = np.array([["_ "] * 4] * 4)
env[terminal] = "$ "
env[hole] = "# "
env[state] = "L "
interaction = ""
for i in env:
interaction += "".join(i) + "\n"
if state == terminal:
message = "EPISODE: {}, STEP: {}".format(episode, step)
interaction += message
display.clear_output(wait=True)
print(interaction)
print("\n" + "q_table:")
print(q_table)
time.sleep(3) # 在成功到终点时,等待 3 秒
else:
display.clear_output(wait=True)
print(interaction)
print("\n" + "q_table:")
print(q_table)
time.sleep(0.3) # 在这里控制每走一步所需要时间
86.7. Sarsa 算法实现#
最后,我们根据 Sarsa 算法伪代码来实现完整的学习过程。
挑战:顺利完成以上几个函数后,根据下方 Sarsa 算法伪代码实现完整的学习过程。请结合 Q-Learning 完成代码。
def sarsa(max_episodes, alpha, gamma, epsilon):
q_table = init_q_table()
terminal = (3, 2)
hole = (2, 1)
episodes = 0
while episodes < max_episodes:
step = 0
state = (0, 0)
end = 0
show_state(end, state, episodes, step, q_table)
x, y = state
### 代码开始 ### (≈ 1 行代码)
action = None # 动作选择
### 代码结束 ###
while end == 0:
next_state, reward, end = env_feedback(
state, action, hole, terminal
) # 环境反馈
next_x, next_y = next_state
next_action = act_choose(next_x * 4 + next_y, q_table, epsilon) # 动作选择
### 代码开始 ### (≈ 3 行代码)
q_table = None # q-table 更新
state = None
action = None
### 代码结束 ###
step += 1
show_state(end, state, episodes, step, q_table)
if end == 2:
episodes += 1
sarsa(max_episodes=20, alpha=0.8, gamma=0.9, epsilon=0.9) # 执行测试
期望输出
_ _ _ _
_ _ _ _
_ # _ _
_ _ L _
EPISODE: 19, STEP: 5
q_table:
up down left right
0 -4.421534 -3.457078 -3.936450 -4.152483
1 -3.409185 -9.062400 -3.433181 -3.596441
2 -2.213120 4.590499 -3.514029 -3.414613
3 -1.536000 -1.574400 -2.908418 -3.936450
4 -4.109114 -2.730086 -2.836070 -2.548000
5 -2.836070 -9.984000 -2.065920 -1.720000
6 -2.850867 8.000000 -2.562662 -0.800000
7 -2.342144 -0.800000 -1.982720 0.000000
8 -2.509531 -2.348544 -2.213120 -8.000000
9 0.000000 0.000000 0.000000 0.000000
10 -3.033926 10.000000 -8.000000 -0.800000
11 -2.844488 0.000000 0.000000 -0.800000
12 -2.766862 -1.536000 -1.536000 6.142464
13 0.000000 0.000000 -2.325504 9.600000
14 0.000000 0.000000 0.000000 0.000000
15 0.000000 0.000000 0.000000 0.000000
由于 Q Table 的值是随机的,上面的实验结果仅供参考。只要随着迭代次数的增加,Q Table 按要求持续更新,并使得智能体走的步数变少,最终接近 5 步即可。
○ 欢迎分享本文链接到你的社交账号、博客、论坛等。更多的外链会增加搜索引擎对本站收录的权重,从而让更多人看到这些内容。