[关闭]
@evilking 2018-04-30T20:56:19.000000Z 字数 16018 阅读 2617

NLP

FastText

FastText 是一种 Facebook AI Research 在 16 年开源的一个文本分类器. paper 链接: https://arxiv.org/pdf/1607.01759.pdf

其特点就是 fast,相对于其它传统的文本分类模型,如 SVM、Logistic Regression、Neural Netword 等模型,fasttext 在保持分类效果的同时,还可大大缩短训练时间和测试时间.

在原文中有说它可以将时间复杂度从 降到 ,其中 为类别数, 为文本表征向量长度;文中还说如果这种方法只考虑前 个目标,则时间复杂度可降到 .

本篇文档组织结构,前面理论部分主要是翻译 fasttext 的这篇原文,后面源码分析部分主要是在 github上找的一套源码实现再对其分析.


模型框架

该论文前面的 Introduction 部分主要是说目前很多文本分类模型虽然准确度可以,但是在大语料库上训练的很慢,于是借鉴 Word2VecCBOW 模型的思想提出了 fasttext ,可解决其他文本分类模型训练速度慢的问题.

其中 word2vec 的 CBOW 模型框架如下:

fasttext

表示给定当前词 的上下文 ,来预测当前词 ,即

当然 CBOW 模型的实现又有两种: Hierarchical SoftmaxNegative sampling. 所以 fasttext 也是有两套实现,这在后面的源码分析中我们会分别去分析.


对于句子分类来说,一个简单又有效的任务是如 BoW 模型一样用一个向量去表征句子,然后训练出线性分类器,如逻辑回归、或 SVM. 但是,线性分类器不在特征和类别中共享参数. 这可能限制了它们在大的输出空间中的泛化,在这个大的输出空间中,一些类的样本可能非常少。

常用的解决办法是将线性分类器分解成多个低秩矩阵 (Schutze, 1992; Mikolov et al., 2013),或者使用多层神经网络 (Collobert and Weston, 2008; Zhang et al., 2015).

fasttext1

矩阵 为词典表中每个单词的词向量构成一列; 为词向量;

图中表示当前单词 n-gram 上下文中的单词的词向量 ,经过词向量平均后生成隐藏层向量;然后喂给线性分类器;文本表征是一个隐藏变量,有可能可以被重用.

这个模型与 Word2Vec 中 CBOW 模型框架类似,不同之处在于要预测的当前词被替换成了标签(因为是要分类). 我们使用 softmax 函数 来计算预测出各种类别的概率.

假设集合有 篇文档,使用最小化每个类别的负指数似然概率:

其中 为第 篇文档归一化后的特征向量表示, 是类别标签, 是权重矩阵. 这个模型可以使用随机梯度下降法和线性下降学习率在多核 CPU 上进行异步训练.


Hierarchical softmax

当类别数很大时,计算线性分类器在计算成本上是很昂贵的. 更进一步说,计算复杂度为 是类别数, 是文本表征向量的维度. 为了优化运行时间,我们使用一个基于 Huffman 编码hierarchical softmax (Goodman, 2001) 树. 训练时,计算复杂度降到了 .

当查找最可能的类别时,hierarchical softmax 同样能改进测试时间. 由 Huffman 树的性质可知,每一个节点都关联一个概率,这个概率是从根节点到当前节点这条路径成立的概率. 如果当前节点的深度为 ,向上逐层回溯的父节点分别为 ,则这个概率为:

这意味着一个节点的概率总是小于它父节点的概率. 使用深度优先搜索扫描这棵树,寻找最大概率的叶子节点时,就允许我们删除任何有着较小概率的节点.

在实验中,我们观测到测试时计算复杂度减小到了 . 这个方法可更进一步扩展,使用 二叉堆 只计算前 个目标,可以将复杂度降到 .

二叉堆可参看:
https://zh.wikipedia.org/wiki/%E4%BA%8C%E5%8F%89%E5%A0%86

http://blog.csdn.net/zhang_zp2014/article/details/47863575


N-gram 特征

词袋模型时丢掉了次序信息的,因为若需要精确的次序信息,在计算上是非常昂贵的. 替代方案是我们使用 n-grams 词袋模型,将 n 元词向量附加到文本特征中,在小范围内表示部分次序信息. 在实验中,这种替代方法与使用精确的次序信息 (Wang and Manning, 2012)相比,也是非常有效的.

我们通过 hashing 技巧 (Weinberger et al., 2009) 去维护一个能快速索引 n-grams 的 map 集合,其中 hash 函数与 (Mikolov et al. 2011) 中提到的一致.

翻译小结

这篇理论部分跟 Word2Vec 中 CBOW 模型类似,没什么好说的。主要在实践上,使用了一些技巧去优化性能,这点在后面源码分析中也有体现.

英文水平不行,有些地方翻译的不行,有问题及时指出一起来改进.

java 版源码分析

源码下载地址: https://github.com/ivanhk/fastText_java

加载到 eclipse 后的目录结构为:

fasttext3

其中 fasttext.Main 为主类,由于这套源码本身没有提供数据集进行测试,所以我们先直接跟踪源码去分析流程吧.


首先在 Main 类中的 main() 方法开始分析:

  1. //动作命令参数
  2. String command = args[0];
  3. if ("skipgram".equalsIgnoreCase(command) || "cbow".equalsIgnoreCase(command)
  4. || "supervised".equalsIgnoreCase(command)) {
  5. //训练模型
  6. op.train(args);
  7. } else if ("test".equalsIgnoreCase(command)) {
  8. //测试准确度
  9. op.test(args);
  10. } else if ("print-vectors".equalsIgnoreCase(command)) {
  11. op.printVectors(args);
  12. } else if ("predict".equalsIgnoreCase(command) || "predict-prob".equalsIgnoreCase(command)) {
  13. //预测类别
  14. op.predict(args);
  15. } else {
  16. printUsage();
  17. System.exit(1);
  18. }

该方法主要是通过传入的参数做不同的动作,从方法命名可知,我们主要分析的就是 op.train()op.predict() 这两个功能了.


训练

模型的训练主要体现在 FastText 类的 train() 方法里,这一步之前虽然有解析传入的参数的步骤,不过我们跳过,必要时笔者会进行说明。

train() 方法的逻辑也比较清晰,前面分别去读取语料库,初始化必要的输入输出矩阵,启动模型训练线程池去真正开始训练,模型训练完成后再将模型保存到文件中。

我们一步一步来,先看语料库读取:

  1. //设置文件读取工具类
  2. dict_.setLineReaderClass(lineReaderClass_);
  3. ......
  4. //构建词典表,args.input 为输入文件的路径
  5. dict_.readFromFile(args_.input);

readFromFile(String file) 方法读取语料库,构建词典表,保存在HashMap 对象 word2int 中,这里主要是用到了上面论文中最后说的 hash 技巧.

  1. try {
  2. long minThreshold = 1;
  3. String[] lineTokens;
  4. //循环读取每一行
  5. while ((lineTokens = lineReader.readLineTokens()) != null) {
  6. //对每行实体进行处理
  7. for (int i = 0; i <= lineTokens.length; i++) {
  8. if (i == lineTokens.length) {
  9. add(EOS);
  10. } else {
  11. if (Utils.isEmpty(lineTokens[i])) {
  12. continue;
  13. }
  14. //添加到 hash 表中
  15. add(lineTokens[i]);
  16. }
  17. //列表大小超过 75% 则重新散列,防止冲突泛滥
  18. if (size_ > 0.75 * MAX_VOCAB_SIZE) {
  19. minThreshold++;
  20. threshold(minThreshold, minThreshold);
  21. }
  22. }
  23. }
  24. } finally {
  25. if (lineReader != null) {
  26. lineReader.close();
  27. }
  28. }
  29. //重新 hash 散列,使得Map 散列均匀,优化性能
  30. threshold(args_.minCount, args_.minCountLabel);
  31. //与 word2vec 中处理高频和低频单词一样
  32. //以一定的概率过滤掉单词
  33. initTableDiscard();

这一步主要是往 HashMap 中不断添加实体(单词或者标签),当添加的量超过最大容量的 75% 时,为了提升 HashMap 的性能就需要扩容并重新 Hash 散列,这一步实现是在 threshold() 方法中。在散列的同时,会删除词频小于某一阈值的单词,还会根据实体的类别统计出单词数 nwords 和类别标签数 nlabels.

  1. public void add(final String w) {
  2. //进行冲突检测, hash 散列
  3. long h = find(w);
  4. //进入的实体总数
  5. ntokens_++;
  6. //实体不存在 word2int 表中
  7. if (Utils.mapGetOrDefault(word2int_, h, WORDID_DEFAULT) == WORDID_DEFAULT) {
  8. entry e = new entry();
  9. e.word = w;
  10. e.count = 1;
  11. e.type = w.startsWith(args_.label) ? entry_type.label : entry_type.word;
  12. //实体列表,包括单词和类别标签
  13. words_.add(e);
  14. //size 为不同实体数 ,或者叫实体的索引
  15. word2int_.put(h, size_++);
  16. } else {
  17. //若实体存在,则词频 +1
  18. words_.get(word2int_.get(h)).count++;
  19. }
  20. }

先调用 find() 方法做冲突检查,按照 hash 散列算法进行处理的,返回不冲突的索引.

接着分析 readFromFile() 方法后面部分,主要看 initTableDiscard() 方法,后面的 args.model 参数我们是用 fasttext 做文本分类,所以模型类别为 model_name.sup,后面部分就不用看了.

  1. public void initTableDiscard() {
  2. pdiscard_ = new ArrayList<Float>(size_);
  3. for (int i = 0; i < size_; i++) {
  4. //词频 (words是删除低频单词后的结果列表, ntokens是总共的单词数)
  5. float f = (float) (words_.get(i).count) / (float) ntokens_;
  6. pdiscard_.add((float) (Math.sqrt(args_.t / f) + args_.t / f));
  7. }
  8. }

是为了后面以一定的概率删除词频很高的词,计算公式为

其中 为一阈值, 为词频;该处理方式与 Word2Vec 中一致.


接着 FastText 类的 train() 方法继续分析,初始化输入和输出矩阵:

  1. if (!Utils.isEmpty(args_.pretrainedVectors)) {
  2. loadVectors(args_.pretrainedVectors);
  3. } else {
  4. input_ = new Matrix(dict_.nwords() + args_.bucket, args_.dim);
  5. //均匀随机分布初始化
  6. input_.uniform(1.0f / args_.dim);
  7. }
  8. //模型使用监督学习进行分类
  9. if (args_.model == model_name.sup) {
  10. output_ = new Matrix(dict_.nlabels(), args_.dim);
  11. } else {
  12. //否则就是为了构建词向量
  13. output_ = new Matrix(dict_.nwords(), args_.dim);
  14. }
  15. output_.zero();

我们先不考虑预训练的情况,直接初始化 input 矩阵,它的每行是一个单词的词向量,由均匀随机分布初始化,值的范围在 之间.

输出我们是考虑 fasttext 做监督学习预测给定文本的类别,所以 args_.model == model_name.sup,output矩阵的每一行为一个 label 的表征向量,初始值为 0 .


下面就是真正开始启动多线程去训练了.

TrainThread 类的 run() 方法中:

  1. Model model = new Model(input_, output_, args_, threadId);
  2. if (args_.model == model_name.sup) {
  3. //根据损失函数不同初始化不同的结构
  4. model.setTargetCounts(dict_.getCounts(entry_type.label));
  5. } else {
  6. model.setTargetCounts(dict_.getCounts(entry_type.word));
  7. }
  8. final long ntokens = dict_.ntokens();
  9. long localTokenCount = 0;
  10. //在方法中传输的都是单词索引,可能是为了减小内存占用
  11. List<Integer> line = new ArrayList<Integer>();
  12. List<Integer> labels = new ArrayList<Integer>();
  13. String[] lineTokens;
  14. //将两重 for 循环变成了一重
  15. while (tokenCount_.get() < args_.epoch * ntokens) {
  16. lineTokens = lineReader.readLineTokens();
  17. //这样处理是为了能循环 epoch 次的训练集,
  18. //将原本要两重 for 循环的变成一重 for 循环
  19. if (lineTokens == null) {
  20. try {
  21. lineReader.rewind();
  22. } catch (Exception e) {
  23. }
  24. //到文件末尾后再重头开始读
  25. lineTokens = lineReader.readLineTokens();
  26. }
  27. //自适应学习率
  28. float progress = (float) (tokenCount_.get()) / (args_.epoch * ntokens);
  29. float lr = (float) (args_.lr * (1.0 - progress));
  30. //读取一行单词或标签到 line 和 labels localTokenCount += dict_.getLine(lineTokens, line, labels, model.rng);
  31. //走监督模型
  32. if (args_.model == model_name.sup) {
  33. dict_.addNgrams(line, args_.wordNgrams);
  34. if (labels.size() == 0 || line.size() == 0) {
  35. continue;
  36. }
  37. //这里才是走分类模型
  38. supervised(model, lr, line, labels);
  39. //后面是为了训练词向量,我们不考虑
  40. } else if (args_.model == model_name.cbow) {
  41. cbow(model, lr, line);
  42. } else if (args_.model == model_name.sg) {
  43. skipgram(model, lr, line);
  44. }
  45. }

这一步我们主要看 model.setTargetCounts() 方法和 supervised() 方法.

  1. public void setTargetCounts(final List<Long> counts) {
  2. Utils.checkArgument(counts.size() == osz_);
  3. //负采样
  4. if (args_.loss == loss_name.ns) {
  5. initTableNegatives(counts);
  6. }
  7. //构建Huffuman树
  8. if (args_.loss == loss_name.hs) {
  9. buildTree(counts);
  10. }
  11. }

前面也说了 CBOW 模型也有两套实现方式,分别是分层的 softmax 和负采样.

我们先看 Haffman 树的实现方式,这也是论文中提到的实现.

  1. public void buildTree(final List<Long> counts) {
  2. //每一个标签作为叶子节点,网上追溯到根节点经过的节点和 haffman 编码
  3. paths = new ArrayList<List<Integer>>(osz_);
  4. codes = new ArrayList<List<Boolean>>(osz_);
  5. //用数组的方式保存 haffman 树: tree_size = 2*num(叶子节点数) - 1
  6. tree = new ArrayList<Node>(2 * osz_ - 1);
  7. //先初始化每个节点
  8. for (int i = 0; i < 2 * osz_ - 1; i++) {
  9. Node node = new Node();
  10. node.parent = -1;
  11. node.left = -1;
  12. node.right = -1;
  13. node.count = 1000000000000000L;// 1e15f;
  14. node.binary = false;
  15. tree.add(i, node);
  16. }
  17. //tree的index 前 osz_ 存叶子节点标签
  18. for (int i = 0; i < osz_; i++) {
  19. tree.get(i).count = counts.get(i);
  20. }
  21. //从叶子节点最后往前逐步构建
  22. int leaf = osz_ - 1;
  23. int node = osz_;
  24. //首先 counts实体列表应该是按照 count 从大到小排列的顺序
  25. //从 osz 开始依次去左右两边的合成 count 较小的两个实体,然后在右边最新节点创建实体
  26. for (int i = osz_; i < 2 * osz_ - 1; i++) {
  27. int[] mini = new int[2];
  28. for (int j = 0; j < 2; j++) {
  29. if (leaf >= 0 && tree.get(leaf).count < tree.get(node).count) {
  30. mini[j] = leaf--;
  31. } else {
  32. mini[j] = node++;
  33. }
  34. }
  35. //构建父节点
  36. tree.get(i).left = mini[0];
  37. tree.get(i).right = mini[1];
  38. tree.get(i).count = tree.get(mini[0]).count + tree.get(mini[1]).count;
  39. tree.get(mini[0]).parent = i;
  40. tree.get(mini[1]).parent = i;
  41. //左侧编码为 false,右侧编码为 true
  42. tree.get(mini[1]).binary = true;
  43. } //则 tree 的最后一个节点即为根节点.
  44. //对每个叶子节点创建路由选择路径
  45. for (int i = 0; i < osz_; i++) {
  46. List<Integer> path = new ArrayList<Integer>();
  47. List<Boolean> code = new ArrayList<Boolean>();
  48. int j = i;
  49. while (tree.get(j).parent != -1) {
  50. //由 Huffman 树叶子节点的特殊性得
  51. path.add(tree.get(j).parent - osz_);
  52. code.add(tree.get(j).binary);
  53. j = tree.get(j).parent;
  54. }
  55. paths.add(path);
  56. codes.add(code);
  57. }
  58. }

上面用数组的方式来创建 Haffman 树的过程比较复杂. 首先是 counts 列表是从大到小排,依次存放在 tree 列表的前 osz 个索引位,为Haffman树的叶子节点;

然后从第 osz+1 开始往右就是存放非叶子节点。由于从 tree 列表从 0osz - 1 是从大到小排,根据 Haffman 树的构造过程,不断合并 leafnode 出较小的节点生成它们的父节点,并添加到 tree 的 最优边,这么递归下去,Haffman 树就构建完成了;

其中最后一个节点就是根节点.


接下来看 supervised() 方法:

  1. public void supervised(Model model, float lr, final List<Integer> line, final List<Integer> labels) {
  2. if (labels.size() == 0 || line.size() == 0)
  3. return;
  4. //随机取一个 label 作为 line 的标号
  5. int i = Utils.randomInt(model.rng, 1, labels.size()) - 1;
  6. //line 预测 labels_i
  7. model.update(line, labels.get(i), lr);
  8. }

如果检查到这一行有多个标签,则随机选一个.

  1. public void update(final List<Integer> input, int target, float lr) {
  2. if (input.size() == 0) {
  3. return;
  4. }
  5. //计算隐藏层,将上下文的词向量累加
  6. computeHidden(input, hidden_);
  7. //负采样
  8. if (args_.loss == loss_name.ns) {
  9. loss_ += negativeSampling(target, lr);
  10. //分层 softmax
  11. } else if (args_.loss == loss_name.hs) {
  12. loss_ += hierarchicalSoftmax(target, lr);
  13. } else {
  14. loss_ += softmax(target, lr);
  15. }
  16. //真正用于训练的样例数
  17. nexamples_ += 1;
  18. //梯度平均分配
  19. if (args_.model == model_name.sup) {
  20. grad_.mul(1.0f / input.size());
  21. }
  22. //分别将梯度误差贡献到每个词向量上
  23. for (Integer it : input) {
  24. wi_.addRow(grad_, it, 1.0f);
  25. }
  26. }

这步跟 Word2Vec 的步骤差不多,只是预测当前词的词向量,变成了预测 label;

  1. public void computeHidden(final List<Integer> input, Vector hidden) {
  2. Utils.checkArgument(hidden.size() == hsz_);
  3. hidden.zero();
  4. //将 wi_ 的第 it 行词向量添加到 hidden 向量上
  5. for (Integer it : input) {
  6. hidden.addRow(wi_, it);
  7. }
  8. //求个平均
  9. hidden.mul(1.0f / input.size());
  10. }

在 CBOW 模型中,隐藏层的工作就是将当前词的上下文的词向量进行累加,生成 text representation .

接下来就是将隐藏层传入输出层计算损失函数:

  1. //负采样损失函数
  2. if (args_.loss == loss_name.ns) {
  3. loss_ += negativeSampling(target, lr);
  4. //分层 softmax损失函数
  5. } else if (args_.loss == loss_name.hs) {
  6. loss_ += hierarchicalSoftmax(target, lr);
  7. } else {
  8. loss_ += softmax(target, lr);
  9. }
  10. //真正用于训练的样例数
  11. nexamples_ += 1;
  12. //梯度平均分配
  13. if (args_.model == model_name.sup) {
  14. grad_.mul(1.0f / input.size());
  15. }
  16. //分别将梯度误差贡献到每个词向量上
  17. for (Integer it : input) {
  18. wi_.addRow(grad_, it, 1.0f);
  19. }

前面我们进行损失函数初始化时使用的 args.loss == loss_name.hs ,所以我们先只考虑 hierarchicalSoftmax() .

  1. public float hierarchicalSoftmax(int target, float lr) {
  2. float loss = 0.0f;
  3. grad_.zero();
  4. //通过 target 索引获取路由路径
  5. final List<Boolean> binaryCode = codes.get(target);
  6. final List<Integer> pathToRoot = paths.get(target);
  7. //从叶子节点开始往上回溯,逐步逻辑回归
  8. for (int i = 0; i < pathToRoot.size(); i++) {
  9. loss += binaryLogistic(pathToRoot.get(i), binaryCode.get(i), lr);
  10. }
  11. return loss;
  12. }

hierarchical softmax 的思想就是从根节点开始,逐步做逻辑回归二分类,一层一层往下,沿着预测 label 的 路径到达预测 label 所在的叶子节点 . 其中 Haffman 树的节点的参数向量为权值向量 ,边的 code 即为分类的标号.

  1. public float binaryLogistic(int target, boolean label, float lr) {
  2. //逻辑回归
  3. float score = sigmoid(wo_.dotRow(hidden_, target));
  4. //线性搜索误差,不过这里是不是还应该乘以sigmoid的导数: score*(1 - score) ?
  5. float alpha = lr * ((label ? 1.0f : 0.0f) - score);
  6. //更新梯度
  7. grad_.addRow(wo_, target, alpha);
  8. //更新 haffman 节点的参数向量
  9. wo_.addRow(hidden_, target, alpha);
  10. //负指数损失
  11. if (label) {
  12. //1
  13. return -log(score);
  14. } else {
  15. //0
  16. return -log(1.0f - score);
  17. }
  18. }

逻辑回归步先计算 ,然后计算 ,然后计算误差 ,接着根据 去更新梯度,然后更新 Haffman 树节点的权重向量 ;最后根据预测正例还是反例返回负指数损失.

再回到 update() 方法中,后面再将梯度误差平均分配到每个上下文词向量中.

不过上面计算 sigmoid 函数的误差的时候,笔者感觉有问题,应该是 ,根据均方误差对 求偏导算误差的公式得出.


上面是输出层的损失函数使用的 hierarchical softmax,到这里就分析完成了;下面补充 negative sampling 的流程.

先初始化负采样表:

  1. public void initTableNegatives(final List<Long> counts) {
  2. negatives = new ArrayList<Integer>(counts.size());
  3. float z = 0.0f;
  4. //计算非规则化的词频比例
  5. for (int i = 0; i < counts.size(); i++) {
  6. z += (float) Math.pow(counts.get(i), 0.5f);
  7. }
  8. for (int i = 0; i < counts.size(); i++) {
  9. //按词频做权重剖分 负采用表大小
  10. float c = (float) Math.pow(counts.get(i), 0.5f);
  11. //等距剖分,距离为 1
  12. for (int j = 0; j < c * NEGATIVE_TABLE_SIZE / z; j++) {
  13. negatives.add(i);
  14. }
  15. }
  16. //列表混洗
  17. Utils.shuffle(negatives, rng);
  18. }

负采样逻辑与 Word2Vec 一样,先根据每个单词的词频算占总体的权重,然后对 NEGATIVE_TABLE_SIZE 长度进行等距剖分,从代码中可以看到,变量 j 每次加 1,说明等距剖分的距离为 1;然后进行混洗,方便后续随机采样.

具体可参看: http://blog.csdn.net/itplus/article/details/37999613

再看训练时的损失函数部分:

  1. public float negativeSampling(int target, float lr) {
  2. float loss = 0.0f;
  3. grad_.zero();
  4. // args.neg 为采样次数
  5. for (int n = 0; n <= args_.neg; n++) {
  6. if (n == 0) {
  7. //第一个是正例
  8. loss += binaryLogistic(target, true, lr);
  9. } else {
  10. //其他是反例
  11. loss += binaryLogistic(getNegative(target), false, lr);
  12. }
  13. }
  14. return loss;
  15. }

负采样使用的思路是:预测当前词为正例样本,预测其他词为负例样本,多次采样,循环做逻辑回归. 具体可参看 Word2Vec 的负采样版本: https://www.zybuluo.com/evilking/note/871932;区别在于 Word2Vec 中是预测当前词,而 FastText 是根据当前词的上下文预测 label .

代码逻辑如注释所述,而逻辑回归部分在前面以及分析了,这里就不讲。重点是如何采样?

  1. public int getNegative(int target) {
  2. int negative;
  3. do {
  4. //从采样表里采样
  5. negative = negatives.get(negpos);
  6. //每次 +1 遍历
  7. negpos = (negpos + 1) % negatives.size();
  8. //若采样到 target 就跳过重新采样
  9. } while (target == negative);
  10. return negative;
  11. }

到这里训练部分就分析完了,代码的整个流程都还是挺清晰的.


预测

模型训练完成后就可用于预测,给定一个文本,可以预测其标签。

回到 Main 类的 predict() 方法:

  1. //加载模型
  2. FastText fasttext = new FastText();
  3. fasttext.loadModel(args[1]);
  4. //预测文本
  5. fasttext.predict(new FileInputStream(file), k, print_prob);

注意这里的参数 k ,在论文中提到,当只考虑前 个目标时,计算复杂度可降到 ,这里的 k,这是为了进一步提高速度.

我们沿着:
fasttext.predict() ——> predict(lineTokens,k) ——> model.predict():

Model 类中:

  1. public void predict(final List<Integer> input, int k, List<Pair<Float, Integer>> heap) {
  2. predict(input, k, heap, hidden_, output_);
  3. }

对这里的参数进行说明:


  1. public void predict(final List<Integer> input, int k, List<Pair<Float, Integer>> heap, Vector hidden,
  2. Vector output) {
  3. //将上下文词向量累加
  4. computeHidden(input, hidden);
  5. if (args_.loss == loss_name.hs) {
  6. //深度优先搜索
  7. dfs(k, 2 * osz_ - 2, 0.0f, heap, hidden);
  8. } else {
  9. //直接找最好的 k 个
  10. findKBest(k, heap, hidden, output);
  11. }
  12. //重新从大到小排序
  13. Collections.sort(heap, comparePairs);
  14. }

第一步计算 将当前词的上下文词向量进行累加.

针对不同的损失函数,预测这里有不同的方法,先看若输出层为 hierarchical softmax 时如何预测:

  1. public void dfs(int k, int node, float score, List<Pair<Float, Integer>> heap, Vector hidden) {
  2. //如果 heap 存 k 个已经满了,而当前计算的score比heap 中最小的节点概率还小,可直接过滤不考虑
  3. if (heap.size() == k && score < heap.get(heap.size() - 1).getKey()) {
  4. return;
  5. }
  6. //当 node 为叶子节点时
  7. if (tree.get(node).left == -1 && tree.get(node).right == -1) {
  8. //添加进 heap 中
  9. heap.add(new Pair<Float, Integer>(score, node));
  10. //将 heap 重新由大到小排序
  11. Collections.sort(heap, comparePairs);
  12. //如果这时 heap 的大小大于 k 了,就移除概率最小的那个节点
  13. if (heap.size() > k) {
  14. Collections.sort(heap, comparePairs);
  15. heap.remove(heap.size() - 1); // pop last
  16. }
  17. return;
  18. }
  19. //逻辑回归算分为正例的概率
  20. float f = sigmoid(wo_.dotRow(hidden, node - osz_));
  21. //递归左侧路径,分为负例
  22. dfs(k, tree.get(node).left, score + log(1.0f - f), heap, hidden);
  23. //递归右侧路径,分为正例
  24. dfs(k, tree.get(node).right, score + log(f), heap, hidden);
  25. }

这里依然是取前 k 个路径概率最大的节点,用到了 Haffman 树的特性: 子节点的路径概率肯定小于父节点的路径概率,如果父节点由于路径概率过小而被剔除,那子节点更加不用考虑.

对于输出层为负采样损失函数来说复杂度就高点:

  1. public void findKBest(int k, List<Pair<Float, Integer>> heap, Vector hidden, Vector output) {
  2. //分别计算预测到所有类别的概率,并用softmax规则化
  3. computeOutputSoftmax(hidden, output);
  4. //遍历所有类别,取前 k 个最大概率的节点
  5. for (int i = 0; i < osz_; i++) {
  6. if (heap.size() == k && log(output.get(i)) < heap.get(heap.size() - 1).getKey()) {
  7. continue;
  8. }
  9. heap.add(new Pair<Float, Integer>(log(output.get(i)), i));
  10. Collections.sort(heap, comparePairs);
  11. if (heap.size() > k) {
  12. Collections.sort(heap, comparePairs);
  13. // pop last
  14. heap.remove(heap.size() - 1);
  15. }
  16. }
  17. }

这个方法主要是 computeOutputSoftmax(hidden, output) 这句,后面是为了遍历所有类别,取前 k 个预测类别的概率最大的节点,代码都好理解.

  1. public void computeOutputSoftmax(Vector hidden, Vector output) {
  2. //矩阵运算,每个类别的权值向量与文本表示成绩
  3. output.mul(wo_, hidden);
  4. //计算最大的 X^T * theta
  5. float max = output.get(0), z = 0.0f;
  6. for (int i = 1; i < osz_; i++) {
  7. max = Math.max(output.get(i), max);
  8. }
  9. //softmax 规则化
  10. for (int i = 0; i < osz_; i++) {
  11. output.set(i, (float) Math.exp(output.get(i) - max));
  12. z += output.get(i);
  13. }
  14. for (int i = 0; i < osz_; i++) {
  15. output.set(i, output.get(i) / z);
  16. }
  17. }

第一步根据矩阵乘法,计算出所有类别的输出层输入概率,即 ,然后做 softmax 规则化,使用的公式如下:


其中 笔者以为是为了做数据标准化,但是又不对,如果按最大最小标准化的话应该是 ,如果 softmax 的分母将 给提出来的话,代码中的公式就是对的了.


小结

本篇上半部分主要是翻译 FastText 的原论文,在该论文中也没怎么说具体的操作步骤,因为 fasttext 的模型框架与 Word2Vec 中 CBOW 模型类似,只是将预测词向量变成了预测 label 向量,而其他技巧主要是编程实现上做了改进.

所以在下半部分源码分析中做了说明,而源码分析笔者也只能带领读者梳理下分析流程,具体源码中的用意还需要读者自己多理解.

希望能对读者理解 fasttext 有所帮助,谢谢!

参考

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注