sklearn-文本分析

本贴最后更新于 2992 天前,其中的信息可能已经时移世异

本章节的目的是通过一个实际的问题来介绍 scikit-learn 的主要文本分析工具。该问题是:分析有 20 个主题的文本文件(新闻帖)。

在本章节中,我们会接触到如下内容:

  • 加载文件内容和类别
  • 抽取适合机器学习的特征向量
  • 训练线性模型来拟合分类
  • 使用网格搜索来寻找适合特征抽取和分类的参数配置

#开始
在开始该教程之前,你必须安装 scikit-learn 和所有需求的依赖。
安装相关的请查看 installation

该教程的源码可以在你的 scikit-learn 文件夹下找到:

scikit-learn/doc/tutorial/text_analytics/  

教程文件下,应该包含了如下文件:

  • *.rst files - 使用 sphinx 写的教程文档
  • data - 本教程将用到的数据集
  • skeletons - 练习题的不完全示例脚本
  • solutions - 练习题的答案

你可以将 skeletons 复制到你硬盘上的文件夹下,并重命名为 sklearn_tut_workspace,这样你就可以编辑自己的练习题解决方法,同时也不影响原来的内容:

% cp -r skeletons work_directory/sklearn_tut_workspace  

机器学习算法需要数据。到每个 $TUTORIAL_HOME/data 子文件价下,运行 fetch_data.py 脚本。
例如:

% cd $TUTORIAL_HOME/data/languages
% less fetch_data.py
% python fetch_data.py

#加载“Twenty Newsgroups”数据集
这是“Twenty Newsgroups”数据集的官方描述

20 Newsgroups 数据集是大约 20000 新闻报道文档的集合,大致覆盖了 20 类不同的新闻报道。这些文档最初是由 Ken Lang 为了支撑他的论文“Newsweeder: Learning to filter netnews”收集的。20 Newsgroups 数据集很快在机器学习处理文本技术实验中流行起来,常用于文本分类和聚类。
接下来,我们将使用 sklearn 内建的数据集加载器加载 20 newsgroups 数据集。当然,你也可以在网上下载数据集,再用 sklearn.datasets.load_files 指向解压出来的 20news-bydate-train 子文件夹。

为了节约时间,在第一个例子中,我们只是关注 20 类中的 4 类新闻报道:

>>> categories = ['alt.atheism', 'soc.religion.christian',
...               'comp.graphics', 'sci.med']

现在我们加载属于上述 4 类的新闻的文件:

>>> from sklearn.datasets import fetch_20newsgroups
>>> twenty_train = fetch_20newsgroups(subset='train',
...     categories=categories, shuffle=True, random_state=42)  

返回的数据集是 sklearn 中的 bunch 实体:包含的字段信息和数据,可以像 python 中的 dict 或 object 一样访问。target_names 属性保存了类别:

>>> twenty_train.target_names
['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']

加载到内容中的文件数据存储在 data 属性中。也可以使用 filenames 属性访问:

>>> len(twenty_train.data)
2257
>>> len(twenty_train.filenames)
2257

打印加载的第一个文件的第一行:

>>> print("\n".join(twenty_train.data[0].split("\n")[:3]))
From: sd345@city.ac.uk (Michael Collier)
Subject: Converting images to HP LaserJet III?
Nntp-Posting-Host: hampton

>>> print(twenty_train.target_names[twenty_train.target[0]])
comp.graphics

有监督学习算法在训练集中需要每个文档和对应的类别属性。在这个例子中,类别是新闻报道的类别,同时也是每个文档的父文件夹的名字。类别属性用整数代表按顺序存储在 target 属性中:

>>> twenty_train.target[:10]
array([1, 1, 3, 3, 3, 3, 3, 2, 2, 2])

可以用以下方法还原类别的真正名称:

>>> for t in twenty_train.target[:10]:
...     print(twenty_train.target_names[t])
...
comp.graphics
comp.graphics
soc.religion.christian
soc.religion.christian
soc.religion.christian
soc.religion.christian
soc.religion.christian
sci.med
sci.med
sci.med  

你可以注意到样本已经被随机洗牌,这对以下这种情况特别有用:你只是选择第一个样本来快速训练模型,并以训练的结果来启发之后的正式训练。


#从文本文件抽取特征
为了对文本文件使用机器学习,首先我们需要将文本内容转化为数值特征向量。
##Bags of words
最直观的方法就是抽取有代表性的单词:

  1. 为训练集中每个文档中出现的每个单词分配一个固定的数字 id(建立从单词映射到数字索引的 dict)
  2. 对每个文档 i,计算每个单词 w 出现的次数并存在 X[i, j],其中特征 j 是单词 w 在 1 中分配的 id 值。

由 Bags of words 方法产生的向量的维度 n_features 是语料库中不同单词的数量:约大于 100,000

若样本数量 n_samples == 10000,特征向量 X 用类型为 float32 的 numpy array 表示,那么需要 10000 * 100000 * 4 bytes = 4GB 内存,即使对于目前的计算机来说也是很勉强的。

幸运的是,特征 X 中的大部分值为 0,因为给定的文档中使用的单词不超过几千个。因此,我们认为 bags of words 的结果 是典型的 高维系数数据集。我们可以通过只存储向量中非零的部分来节约内存。

scipy.sparse 矩阵正是为了解决这种问题设计的数据结构,sklearn 内建中已经支持了这种数据结构。

##使用 sklearn 进行分词(Tokenizing text)
文本预处理,分词和过滤被包含在高级组件中,可以用于创建特征字典和将文档转化为特征向量:

>>> from sklearn.feature_extraction.text import CountVectorizer
>>> count_vect = CountVectorizer()
>>> X_train_counts = count_vect.fit_transform(twenty_train.data)
>>> X_train_counts.shape
(2257, 35788)

CountVectorizer 支持计算 N-grams 单词或字符序列。一旦 fit 完成,CountVectorizer 建立起特征索引的 dict:

>>> count_vect.vocabulary_.get(u'algorithm')
4690  

单词表中单词的索引值指向其在整个训练语料库中的出现次数。

##将出现次数转化为频率

计算出现次数是一个好的开端,但是存在如下问题:更长的文档中单词的平均出现次数会比短文档的更高,即使他们的主题是一致的。

为了避免出现上述可能的差异,使用文档中每个单词出现的数量除以该文档单词的总数量:这个新特征称为 tf (Term Frequencies,词频)。

另一个需要考虑的问题是,一个文档的语料库越小,则每个语料包含的信息量越大。因此需要削减语料库大的文档中单词特征的权重。

这种削减方法是 tf-idf(Term Frequency times Inverse Document Frequency)

tf 和 tf-idf 可以通过下面代码计算:

>>> from sklearn.feature_extraction.text import TfidfTransformer
>>> tf_transformer = TfidfTransformer(use_idf=False).fit(X_train_counts)
>>> X_train_tf = tf_transformer.transform(X_train_counts)
>>> X_train_tf.shape
(2257, 35788)  

在上面示例代码中,我们先使用 fit(...) 方法使用数据调整 estimator,接着使用 transform(...)方法将我们的计数矩阵转化为 tf-idf 表示。直接使用 fit_transform(..) 方法将这两个步骤可以合并到一起以减少一些中间计算。 以下代码实现的功能和上面的代码一直:

>>> tfidf_transformer = TfidfTransformer()
>>> X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)
>>> X_train_tfidf.shape
(2257, 35788)

#训练分类器
现在我们已经得到了特征,我们可以训练一个分类器并尝试预测测试数据的类别。我们将以 naive Bayes 开始,该分类器可以为我们的问题提供一个良好的基线。sklearn 包含了 naive Bayes 的几个变种版本,其中多项式版本最适合于词频。

>>> from sklearn.naive_bayes import MultinomialNB
>>> clf = MultinomialNB().fit(X_train_tfidf, twenty_train.target)  

预测新的文档的类别,我们需要用和上述几乎一直的方法抽出特征。不同的是,我们直接调用 TfidfTransformer 的 transform 方法,而不是 fit_transform。因为之前我们已经使用训练样本 fit 过了。

>>> docs_new = ['God is love', 'OpenGL on the GPU is fast']
>>> X_new_counts = count_vect.transform(docs_new)
>>> X_new_tfidf = tfidf_transformer.transform(X_new_counts)

>>> predicted = clf.predict(X_new_tfidf)

>>> for doc, category in zip(docs_new, predicted):
...     print('%r => %s' % (doc, twenty_train.target_names[category]))
...
'God is love' => soc.religion.christian
'OpenGL on the GPU is fast' => comp.graphics  

#建立管道
为了更加简便地使用 vectorizer => transformer => classifier 工作流程,sklearn 提供了 Pipeline 类,该类类似于混合分类器:

>>> from sklearn.pipeline import Pipeline
>>> text_clf = Pipeline([('vect', CountVectorizer()),
...                      ('tfidf', TfidfTransformer()),
...                      ('clf', MultinomialNB()),
... ])  

其中 vect, tfidf 和 clf 是随意命名的。我们将在下面网格搜索一节中看到他们的用法。现在训练整个模型(包括特征抽取、转化、分类器训练),仅仅需要通过以下命令:

>>> text_clf = text_clf.fit(twenty_train.data, twenty_train.target)

#使用测试集评估
评估模型预测的正确率是非常简单的:

>>> import numpy as np
>>> twenty_test = fetch_20newsgroups(subset='test',
...     categories=categories, shuffle=True, random_state=42)
>>> docs_test = twenty_test.data
>>> predicted = text_clf.predict(docs_test)
>>> np.mean(predicted == twenty_test.target)            
0.834...  

我们获得 83.4% 的准确率。现在我们看看能否使用线性 SVM 模型获得更好的结果(线性 SVM 被普遍地认为是最好的文本分类算法,虽然比 naive Bayes 慢一些)。我们仅仅需要将管道中的分类器进行特换即可。

>>> from sklearn.linear_model import SGDClassifier
>>> text_clf = Pipeline([('vect', CountVectorizer()),
...                      ('tfidf', TfidfTransformer()),
...                      ('clf', SGDClassifier(loss='hinge', penalty='l2',
...                                            alpha=1e-3, n_iter=5, random_state=42)),
... ])
>>> _ = text_clf.fit(twenty_train.data, twenty_train.target)
>>> predicted = text_clf.predict(docs_test)
>>> np.mean(predicted == twenty_test.target)            
0.912...  

此外 sklearn 还提供了更详细的效果评估工具:

>>> from sklearn import metrics
>>> print(metrics.classification_report(twenty_test.target, predicted,
...     target_names=twenty_test.target_names))
...                                         
                        precision    recall  f1-score   support

           alt.atheism       0.95      0.81      0.87       319
         comp.graphics       0.88      0.97      0.92       389
               sci.med       0.94      0.90      0.92       396
soc.religion.christian       0.90      0.95      0.93       398

           avg / total       0.92      0.91      0.91      1502


>>> metrics.confusion_matrix(twenty_test.target, predicted)
array([[258,  11,  15,  35],
       [  4, 379,   3,   3],
       [  5,  33, 355,   3],
       [  5,  10,   4, 379]])  

由混淆矩阵可以看出,新闻报道中的 atheism 主题比 comp.graphics 更容易被混淆。


#使用网格搜索调整参数
我们已经在 TfidfTransformer 中使用了一些参数,如“use_idf”。同样,分类器也会有很多参数,如 MultinomialNB 分类器包含平滑参数 aipha,SGDClassifier 包含惩罚参数 alpha 等等。

逐个调整 pipline 中的参数是不明智的,我们需要一个穷举搜索方法(exhaustive search)帮助我们寻找参数网格中最好的参数组合。

>>> from sklearn.grid_search import GridSearchCV
>>> parameters = {'vect__ngram_range': [(1, 1), (1, 2)],
...               'tfidf__use_idf': (True, False),
...               'clf__alpha': (1e-2, 1e-3),
... }

显然的,穷举搜索方法开销是较大的。如果我们有多核 CPU,我们可以通过设置 n_jobs = -1,让网格搜索计算时使用所有的 cpu 进行并行计算:

>>> gs_clf = GridSearchCV(text_clf, parameters, n_jobs=-1)  

网格搜索实例和普通的 sklearn 模型一样。让我们在一个较小的数据集中机械能网格搜索,以缩短计算时间:

>>> gs_clf = gs_clf.fit(twenty_train.data[:400], twenty_train.target[:400])  

GridSearchCV 的 fit 方法返回一个分类器,我们可以使用它进行预测:

>>> twenty_train.target_names[gs_clf.predict(['God is love'])]
'soc.religion.christian'  

但是另一方面,这个分类器是相当巨大和笨拙的。我们可以使用 grid_scores_ 属性从该对象中获取最佳的参数列表。

>>> best_parameters, score, _ = max(gs_clf.grid_scores_, key=lambda x: x[1])
>>> for param_name in sorted(parameters.keys()):
...     print("%s: %r" % (param_name, best_parameters[param_name]))
...
clf__alpha: 0.001
tfidf__use_idf: True
vect__ngram_range: (1, 1)

>>> score                                              
0.900...  

练习题连接


#路在何方
以下是几点建议可以帮助你在学完本指导后,在 sklearn 路上走得更远:

  • 尝试玩一玩 CountVectorizer 下的 analyzer 和 token normalisation
  • 如果你没有类属性,尝试使用聚类方法获得
  • 如果每个文档有多个类属性,可以看看 Multiclass and multilabel section
  • 尝试使用 Truncated SVD 进行潜在语义分析
  • 使用 Out-of-core 分类方法学习没办法加载到主存的数据
  • 尝试使用 Hashing Vectorizer 替换 CountVectorizer
  • 数据挖掘
    17 引用 • 32 回帖 • 2 关注
  • Python

    Python 是一种面向对象、直译式电脑编程语言,具有近二十年的发展历史,成熟且稳定。它包含了一组完善而且容易理解的标准库,能够轻松完成很多常见的任务。它的语法简捷和清晰,尽量使用无异义的英语单词,与其它大多数程序设计语言使用大括号不一样,它使用缩进来定义语句块。

    536 引用 • 672 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • Q&A

    提问之前请先看《提问的智慧》,好的问题比好的答案更有价值。

    7015 引用 • 31704 回帖 • 220 关注
  • Netty

    Netty 是一个基于 NIO 的客户端-服务器编程框架,使用 Netty 可以让你快速、简单地开发出一个可维护、高性能的网络应用,例如实现了某种协议的客户、服务端应用。

    49 引用 • 33 回帖 • 20 关注
  • B3log

    B3log 是一个开源组织,名字来源于“Bulletin Board Blog”缩写,目标是将独立博客与论坛结合,形成一种新的网络社区体验,详细请看 B3log 构思。目前 B3log 已经开源了多款产品:SymSoloVditor思源笔记

    1083 引用 • 3461 回帖 • 257 关注
  • 负能量

    上帝为你关上了一扇门,然后就去睡觉了....努力不一定能成功,但不努力一定很轻松 (° ー °〃)

    88 引用 • 1234 回帖 • 442 关注
  • GraphQL

    GraphQL 是一个用于 API 的查询语言,是一个使用基于类型系统来执行查询的服务端运行时(类型系统由你的数据定义)。GraphQL 并没有和任何特定数据库或者存储引擎绑定,而是依靠你现有的代码和数据支撑。

    4 引用 • 3 回帖 • 16 关注
  • PWL

    组织简介

    用爱发电 (Programming With Love) 是一个以开源精神为核心的民间开源爱好者技术组织,“用爱发电”象征开源与贡献精神,加入组织,代表你将遵守组织的“个人开源爱好者”的各项条款。申请加入:用爱发电组织邀请帖
    用爱发电组织官网:https://programmingwithlove.stackoverflow.wiki/

    用爱发电组织的核心驱动力:

    • 遵守开源守则,体现开源&贡献精神:以分享为目的,拒绝非法牟利。
    • 自我保护:使用适当的 License 保护自己的原创作品。
    • 尊重他人:不以各种理由、各种漏洞进行未经允许的抄袭、散播、洩露;以礼相待,尊重所有对社区做出贡献的开发者;通过他人的分享习得知识,要留下足迹,表示感谢。
    • 热爱编程、热爱学习:加入组织,热爱编程是首当其要的。我们欢迎热爱讨论、分享、提问的朋友,也同样欢迎默默成就的朋友。
    • 倾听:正确并恳切对待、处理问题与建议,及时修复开源项目的 Bug ,及时与反馈者沟通。不抬杠、不无视、不辱骂。
    • 平视:不诋毁、轻视、嘲讽其他开发者,主动提出建议、施以帮助,以和谐为本。只要他人肯努力,你也可能会被昔日小看的人所超越,所以请保持谦虚。
    • 乐观且活跃:你的努力决定了你的高度。不要放弃,多年后回头俯瞰,才会发现自己已经成就往日所仰望的水平。积极地将项目开源,帮助他人学习、改进,自己也会获得相应的提升、成就与成就感。
    1 引用 • 487 回帖
  • 机器学习

    机器学习(Machine Learning)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。

    76 引用 • 37 回帖 • 1 关注
  • Kafka

    Kafka 是一种高吞吐量的分布式发布订阅消息系统,它可以处理消费者规模的网站中的所有动作流数据。 这种动作(网页浏览,搜索和其他用户的行动)是现代系统中许多功能的基础。 这些数据通常是由于吞吐量的要求而通过处理日志和日志聚合来解决。

    35 引用 • 35 回帖
  • ZeroNet

    ZeroNet 是一个基于比特币加密技术和 BT 网络技术的去中心化的、开放开源的网络和交流系统。

    1 引用 • 21 回帖 • 609 关注
  • SQLServer

    SQL Server 是由 [微软] 开发和推广的关系数据库管理系统(DBMS),它最初是由 微软、Sybase 和 Ashton-Tate 三家公司共同开发的,并于 1988 年推出了第一个 OS/2 版本。

    19 引用 • 31 回帖 • 1 关注
  • Solidity

    Solidity 是一种智能合约高级语言,运行在 [以太坊] 虚拟机(EVM)之上。它的语法接近于 JavaScript,是一种面向对象的语言。

    3 引用 • 18 回帖 • 353 关注
  • Bug

    Bug 本意是指臭虫、缺陷、损坏、犯贫、窃听器、小虫等。现在人们把在程序中一些缺陷或问题统称为 bug(漏洞)。

    71 引用 • 1737 回帖 • 1 关注
  • CSDN

    CSDN (Chinese Software Developer Network) 创立于 1999 年,是中国的 IT 社区和服务平台,为中国的软件开发者和 IT 从业者提供知识传播、职业发展、软件开发等全生命周期服务,满足他们在职业发展中学习及共享知识和信息、建立职业发展社交圈、通过软件开发实现技术商业化等刚性需求。

    14 引用 • 155 回帖
  • GitLab

    GitLab 是利用 Ruby 一个开源的版本管理系统,实现一个自托管的 Git 项目仓库,可通过 Web 界面操作公开或私有项目。

    46 引用 • 72 回帖
  • 房星科技

    房星网,我们不和没有钱的程序员谈理想,我们要让程序员又有理想又有钱。我们有雄厚的房地产行业线下资源,遍布昆明全城的 100 家门店、四千地产经纪人是我们坚实的后盾。

    6 引用 • 141 回帖 • 566 关注
  • RIP

    愿逝者安息!

    8 引用 • 92 回帖 • 322 关注
  • OpenStack

    OpenStack 是一个云操作系统,通过数据中心可控制大型的计算、存储、网络等资源池。所有的管理通过前端界面管理员就可以完成,同样也可以通过 Web 接口让最终用户部署资源。

    10 引用 • 5 关注
  • NetBeans

    NetBeans 是一个始于 1997 年的 Xelfi 计划,本身是捷克布拉格查理大学的数学及物理学院的学生计划。此计划延伸而成立了一家公司进而发展这个商用版本的 NetBeans IDE,直到 1999 年 Sun 买下此公司。Sun 于次年(2000 年)六月将 NetBeans IDE 开源,直到现在 NetBeans 的社群依然持续增长。

    78 引用 • 102 回帖 • 646 关注
  • Electron

    Electron 基于 Chromium 和 Node.js,让你可以使用 HTML、CSS 和 JavaScript 构建应用。它是一个由 GitHub 及众多贡献者组成的活跃社区共同维护的开源项目,兼容 Mac、Windows 和 Linux,它构建的应用可在这三个操作系统上面运行。

    15 引用 • 136 回帖 • 6 关注
  • 音乐

    你听到信仰的声音了么?

    60 引用 • 510 回帖 • 1 关注
  • TensorFlow

    TensorFlow 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库。节点(Nodes)在图中表示数学操作,图中的线(edges)则表示在节点间相互联系的多维数据数组,即张量(tensor)。

    20 引用 • 19 回帖
  • flomo

    flomo 是新一代 「卡片笔记」 ,专注在碎片化时代,促进你的记录,帮你积累更多知识资产。

    4 引用 • 91 回帖
  • abitmean

    有点意思就行了

    38 关注
  • 书籍

    宋真宗赵恒曾经说过:“书中自有黄金屋,书中自有颜如玉。”

    76 引用 • 390 回帖
  • 电影

    这是一个不能说的秘密。

    120 引用 • 598 回帖
  • SMTP

    SMTP(Simple Mail Transfer Protocol)即简单邮件传输协议,它是一组用于由源地址到目的地址传送邮件的规则,由它来控制信件的中转方式。SMTP 协议属于 TCP/IP 协议簇,它帮助每台计算机在发送或中转信件时找到下一个目的地。

    4 引用 • 18 回帖 • 608 关注
  • Oracle

    Oracle(甲骨文)公司,全称甲骨文股份有限公司(甲骨文软件系统有限公司),是全球最大的企业级软件公司,总部位于美国加利福尼亚州的红木滩。1989 年正式进入中国市场。2013 年,甲骨文已超越 IBM,成为继 Microsoft 后全球第二大软件公司。

    103 引用 • 126 回帖 • 443 关注