• 发表时间:2018
  • 论文链接:https://www.aclweb.org/anthology/P18-2033
  • 代码:https://github.com/LiuQL2/MedicalChatbot
  • 代码语言:python

摘要

本文构建了一个用于自动诊断的对话系统。首先,从线上医学论坛上病人的自述以及病人医生间的交谈中提取症状,从而构建数据集;然后,本文提出了用于自动诊断的任务型对话系统框架,该系统能够通过与病人交谈,获取除病人自述外的其他症状。实验表明,从交谈中获取的额外症状能够极大地提升疾病诊断精度,本文的对话系统能够自动地收集这些症状,而且诊断准确度更高。

数据

数据是从中文医学网站上的儿科中收集的,包括四种疾病类型:上呼吸道感染、儿童功能性消化不良、腹泻和支气管炎。标记数据包括两个过程:症状提取、症状归一化。

症状提取

图1

图1
标记数据时,采用BIO进行症状识别,每个提取出的症状表述打上True或False的标签来表明病人有没有该症状。

症状归一化

不同人表述症状是不一样的,比如有人说拉肚子,有人说腹泻,因此要将这些症状表述为专业术语,采用的是SNOMED CT(一种临床医学语标准)标准,如图2所示。通过病人所提供的症状可以分为两类:显性症状和隐性症状。显性症状是指病人在咨询时提供的症状,如病人说:“医生,我流鼻涕打喷嚏,这是怎么回事啊”,其中鼻流涕和打喷嚏就是两个显性症状。隐性症状是指医生通过咨询获知的症状,如医生接着问:“那你拉肚子不?”, 此处的腹泻就是隐性症状。

在这里插入图片描述

图2

本文框架

本文对话系统框架包括三大模块:NLU(自然语言理解):检测用户意图、提取槽位值;DM(对话管理):追踪对话状态、给出系统行动;NLG(自然语言生成):根据系统行动生成自然语言。其中NLU和NLG模块都是采用基于模板的方法,重点研究DM模块。
DM模块包含两个子模块:对话状态追踪(DST)和策略学习。

对话策略学习

用户模拟器

训练是采用DQN,网络的输入是当前状态 S t S_t St,输出为agent的action。那么训练数据从何而来呢,这时候就需要用户模拟器,和agent模拟任务驱动对话过程。我们称这个过程为warm_start。此处假设有4种疾病,warm_start过程描述如下:

  1. 对话管理系统初始化
  • 用户模拟器初始化action a u , 0 a_{u,0} au,0:随机从所有数据中选取一条,将数据中的所有显性症状作为inform_slots(例如上述例子中的鼻流涕和打喷嚏),action为:request, request_slots为disease。
  • 利用初始的用户action更新状态
  • 初始化agent。
  1. 模拟对话系统
  • agent基于规则根据状态选择action:此处分析一下状态中用户提供了哪些显性症状,用户的需求是询问哪种疾病,例如:agent在数据集中发现只有鼻流涕和打喷嚏是不能确定任何一种疾病的,否则对话就结束了。那么agent就在数据集中查询4种疾病中哪种出现这两种症状的频率比较高,假设发现4种疾病中上呼吸道感染最容易出现这两种症状,但是还有三种症状需要询问病人去确认一下,那就从这三种症状中随机选一种吧,此时,agent采取action 为request,request_slots就是刚才随机选择的症状;
  • agent采取行动后,DM就更新状态(此状态包含的信息比较多:对话轮数、agent_action、user_action、current_slots(提供当前需要的信息)、agent和用户的历史action信息等)
  • user再根据状态和agent的action进行回复,如确认一下有没有agent询问的特征;
  • DM再次更新状态;
    上述过程不断进行,直到agent确诊了疾病是什么,或者达到我们设定的最大对话轮数。

从上述过程中,我们将(state,agent_action,reward,next_state,对话是否结束)这些信息记录下来,后续用作训练数据。
其中括号中的大多都是采用热编码的形式。(具体的过程请参看代码)

DQN训练

DQN是一种off-policy的深度强化学习算法,off-policy就表明了agent的动作选择网络和目标网络肯定不是同一个,因此有两个一样的网络,网络的输入是state,输出为agent_action。

  • 网络结构:网络结构非常简单,一层全连接层;
  • 输入(batch_size, 200):200是个假设的state表示维度,在实际使用中决定其大小的因素比较多;
  • 输出:假设agent_action有300种,则输出为(batch_size,300)

动作选择网络每个iteration的参数都会利用loss回传进行更新,那么loss函数是怎么算的呢?
其标签就是将 next_state记为 s ′ s^{'} s送给目标网络, y i = r + γ m a x a ′ Q ( s ′ , a ′ ) y_{i} = r+\gamma max_{a^{'}}Q(s^{'},a^{'}) yi=r+γmaxaQ(s,a),其中 r r r表示及时回报。
loss函数就是目标网络和动作选择网络输出 y i y_{i} yi间的均方误差。

动作选择网络利用上述过程实时更新,目标网络在每个epoch结束时,直接copy动作选择网络的参数进行更新。

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐