
What Is Reinforcement Learning: Part 3, Surviving on the Ice (II)

By 圣_狒司机 | Published 2023-05-03 23:17

1. A Small Improvement

def chose_action(self, ai=False):
    # Candidate next states: the current state plus every action offset.
    pre_s_space = self.s + self.action_space
    # Keep only candidates that do not already appear in the visited-state chain.
    pre_s_space = [pre_s for pre_s in pre_s_space
                   if not np.all(self.s_chain == pre_s, axis=1).any()]
    # Turn the surviving states back into actions; np.array is needed because
    # a plain Python list of rows does not support subtraction.
    if pre_s_space:
        self.action_space = np.array(pre_s_space) - self.s
    a = random.choice(self.action_space)
    return a

This returns:

playing 0 times!
526 times success!

Forbidding repeated states roughly doubles the speed, and the same holds when the environment is made a bit more complex:

playing 20000 times!
playing 40000 times!
playing 60000 times!
70613 times success!

This improvement is not actually very useful, because it is designed specifically for this game; our little AI has not yet acquired any general learning ability.
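The heart of the filter above is the row-membership test `np.all(self.s_chain == pre_s, axis=1).any()`, which is true exactly when the candidate state already appears as a row of the visited chain. A minimal sketch with made-up values:

import numpy as np

s_chain = np.array([[0, 0], [0, 1], [1, 1]])   # states visited so far
pre_s = np.array([0, 1])                       # candidate next state
# Compare element-wise, require both coordinates to match (axis=1),
# then ask whether any visited row matched.
print(np.all(s_chain == pre_s, axis=1).any())  # True: [0, 1] was already visited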

2. Q-learning

Now let's add something that genuinely deserves to be called artificial intelligence: Q-learning!

The formula:

Q(s,a) \leftarrow Q(s,a) + \alpha \cdot \left( r + \gamma \cdot \max_{a'} Q(s',a') - Q(s,a) \right)

Written in in-place (+=) form:

Q(s,a) += \alpha \cdot \left( r + \gamma \cdot \max_{a'} Q(s',a') - Q(s,a) \right)

We only need this one formula! For the underlying theory, see 【强化学习】Q-Learning算法详解.
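To make the update concrete, here is a toy, hand-computed step; the numbers are made up for illustration, with alpha and gamma matching the 0.1 and 0.9 used in the code below:

alpha, gamma = 0.1, 0.9   # learning rate and discount, as in the code below
Q_sa = 0.0                # current estimate Q(s, a)
r = -1.0                  # reward just received (e.g. stepping into a hole)
Q_next_max = 0.0          # max over a' of Q(s', a')
Q_sa += alpha * (r + gamma * Q_next_max - Q_sa)
print(Q_sa)               # -0.1: the penalty begins to flow backwards into Q(s, a)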
Change the environment's step function to:

def step(self, a=[0,0], ai=False):
    s = self.agent.s
    s_ = np.array(s) + np.array(a)
    if 0 <= s_[0] <= 3 and 0 <= s_[1] <= 11:
        # In bounds: move there and collect the reward stored in the grid.
        self.agent.s = s_
        r = self.env[s_[0], s_[1]]
    else:
        # Out of bounds: stay put and take a -1 penalty.
        s_ = s
        r = -1
    # Hand the transition to the agent so it can update its Q-table.
    self.agent.post_step(s, a, r, s_)
    return s_, r
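As a quick sanity check (a sketch assuming the full Env and Agent classes from section 4), moving up from the start cell (0, 0) goes out of bounds, so the state is unchanged and the reward is -1:

env = Env()
env.register(Agent())
s_, r = env.step(a=[-1, 0])   # "up" from the top-left corner
print(s_, r)                  # [0 0] -1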

The changes to the agent:

class Agent:
    def __init__(self):
        # Actions: right, up, left, down (as row/column offsets).
        self.action_space = np.array([[0,1],[-1,0],[0,-1],[1,0]])
        self.s = np.array([0,0])
        self.s_chain = np.expand_dims(self.s, 0)
        self.sar_chain = np.expand_dims(np.hstack([np.array([0,0]), [0,0], 0]), 0)
        # One Q-value per (row, column, action).
        self.Q = np.zeros((4,12,4))
        self.epsilon = 0.2

    def chose_action(self):
        # Note: this condition explores whenever random.random() > epsilon,
        # i.e. the agent takes a random action about 80% of the time.
        if self.epsilon < random.random():
            a = random.choice(self.action_space)
        else:
            a = self.action_space[np.argmax(self.Q[self.s[0], self.s[1]])]
        return a

    def post_step(self, s, a, r, s_):
        self.s_chain = np.vstack([self.s_chain, s])
        self.sar_chain = np.vstack([self.sar_chain, np.hstack([s, a, r])])
        # Index of action a within the action space.
        a_number = np.where((self.action_space == a).all(axis=1))[0][0]
        if r == -1:
            # Mark the bad state: every action from it gets value -1.
            self.Q[s_[0], s_[1]] = -1
        # Q-learning update with alpha = 0.1 and gamma = 0.9.
        update = 0.1 * (r + 0.9 * self.Q[s_[0], s_[1]].max() - self.Q[s[0], s[1], a_number])
        self.Q[s[0], s[1], a_number] += update

    def reset(self):
        self.action_space = np.array([[0,1],[-1,0],[0,-1],[1,0]])
        self.s = np.array([0,0])
        self.s_chain = np.expand_dims(self.s, 0)
        self.sar_chain = np.expand_dims(np.hstack([np.array([0,0]), [0,0], 0]), 0)
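A quick, self-contained check of the exploration rate implied by that condition (the seed is arbitrary, just for reproducibility):

import random

random.seed(0)
trials = 10_000
explores = sum(0.2 < random.random() for _ in range(trials))
print(explores / trials)   # roughly 0.8: the agent explores ~80% of the time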

3. Taking It for a Spin

Make the environment a little more complex; after some three hundred thousand rounds of bumping into walls, it learned to find the exit:

327737 times success!
[Figure 1: the rendered grid, with the agent's path to the exit marked]

What it has learned is called the Q-table; think of it as its knowledge base. Let's open it up and take a look:

# Show the (4, 12, 4) table as an RGB image by dropping the last action channel.
plt.imshow((agent.Q[:,:,:-1]*255).astype(np.uint8))
plt.show()
[Figure 2: the Q-table rendered as an image]

You can see that the squares closest to the exit are the brightest, while the darkest squares are the positions it considers unsafe.
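If you would rather see the raw numbers behind the picture, one quick sketch is to print the best value per cell:

np.set_printoptions(precision=2, suppress=True)
print(agent.Q.max(axis=2))   # one number per grid cell: the value of its best action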

4. The Full Code

import numpy as np
from matplotlib import pyplot as plt
import random
from itertools import count

class Env:
    def __init__(self):
        self.action_space = []
        self.agent = None
        # 4 x 12 grid: 0 = ice, -1 = hole, 1 = exit.
        self.env = np.zeros((4,12))
        self.env[-1,1:-1] = -1   # holes across the bottom row
        self.env[:2,3] = -1      # holes in the top two rows of column 3
        self.env[1:-1,8] = -1    # holes in the middle rows of column 8
        self.env[-1,-1] = 1      # the exit in the bottom-right corner
        self.env_show = self.env.copy()

    def step(self, a=[0,0], ai=False):
        s = self.agent.s
        s_ = np.array(s) + np.array(a)
        if 0 <= s_[0] <= 3 and 0 <= s_[1] <= 11:
            # In bounds: move there and collect the reward stored in the grid.
            self.agent.s = s_
            r = self.env[s_[0], s_[1]]
        else:
            # Out of bounds: stay put and take a -1 penalty.
            s_ = s
            r = -1
        # Hand the transition to the agent so it can update its Q-table.
        self.agent.post_step(s, a, r, s_)
        return s_, r

    def play(self):
        self.reset()
        for t in count(1):
            a = self.agent.chose_action()
            if a is not None:
                s, r = self.step(a)
                if r in [-1, 1]:
                    # Episode ends on a hole (-1) or the exit (+1).
                    break
            else:
                r = None
                break
        return t, r
    
    def play_until_success(self):
        for t in count(1):
            _, r = self.play()
            if r:
                if t % 20000 == 0:
                    print(f"playing {t} times!")
                if r == 1:
                    print(f"{t} times success!")
                    self.render()
                    break
            else:
                break

    def render(self):
        # Mark every visited cell with 0.5 so the path shows up in the plot.
        for i, j in self.agent.s_chain:
            self.env_show[i, j] = 0.5
        plt.imshow(self.env_show)
        plt.show()

    def reset(self):
        self.agent.reset()
        self.env_show = self.env.copy()

    def register(self,agent):
        self.agent = agent

class Agent:
    def __init__(self):
        # Actions: right, up, left, down (as row/column offsets).
        self.action_space = np.array([[0,1],[-1,0],[0,-1],[1,0]])
        self.s = np.array([0,0])
        self.s_chain = np.expand_dims(self.s, 0)
        self.sar_chain = np.expand_dims(np.hstack([np.array([0,0]), [0,0], 0]), 0)
        # One Q-value per (row, column, action).
        self.Q = np.zeros((4,12,4))
        self.epsilon = 0.2

    def chose_action(self):
        # Note: this condition explores whenever random.random() > epsilon,
        # i.e. the agent takes a random action about 80% of the time.
        if self.epsilon < random.random():
            a = random.choice(self.action_space)
        else:
            a = self.action_space[np.argmax(self.Q[self.s[0], self.s[1]])]
        return a

    def post_step(self, s, a, r, s_):
        self.s_chain = np.vstack([self.s_chain, s])
        self.sar_chain = np.vstack([self.sar_chain, np.hstack([s, a, r])])
        # Index of action a within the action space.
        a_number = np.where((self.action_space == a).all(axis=1))[0][0]
        if r == -1:
            # Mark the bad state: every action from it gets value -1.
            self.Q[s_[0], s_[1]] = -1
        # Q-learning update with alpha = 0.1 and gamma = 0.9.
        update = 0.1 * (r + 0.9 * self.Q[s_[0], s_[1]].max() - self.Q[s[0], s[1], a_number])
        self.Q[s[0], s[1], a_number] += update

    def reset(self):
        self.action_space = np.array([[0,1],[-1,0],[0,-1],[1,0]])
        self.s = np.array([0,0])
        self.s_chain = np.expand_dims(self.s, 0)
        self.sar_chain = np.expand_dims(np.hstack([np.array([0,0]), [0,0], 0]), 0)

env = Env()
agent = Agent()
env.register(agent)
# env.render()
env.play_until_success()
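Once training finishes, here is a hypothetical way to read the greedy policy out of the learned Q-table; the arrow order matches action_space (right, up, left, down):

arrows = {0: '→', 1: '↑', 2: '←', 3: '↓'}   # indices into agent.action_space
policy = np.argmax(agent.Q, axis=2)         # best action index per grid cell
for row in policy:
    print(''.join(arrows[a] for a in row))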
