sklearn-文本分析

本贴最后更新于 3268 天前,其中的信息可能已经时移世异

本章节的目的是通过一个实际的问题来介绍 scikit-learn 的主要文本分析工具。该问题是:分析有 20 个主题的文本文件(新闻帖)。

在本章节中,我们会接触到如下内容:

  • 加载文件内容和类别
  • 抽取适合机器学习的特征向量
  • 训练线性模型来拟合分类
  • 使用网格搜索来寻找适合特征抽取和分类的参数配置

#开始
在开始该教程之前,你必须安装 scikit-learn 和所有需求的依赖。
安装相关的请查看 installation

该教程的源码可以在你的 scikit-learn 文件夹下找到:

scikit-learn/doc/tutorial/text_analytics/

教程文件下,应该包含了如下文件:

  • *.rst files - 使用 sphinx 写的教程文档
  • data - 本教程将用到的数据集
  • skeletons - 练习题的不完全示例脚本
  • solutions - 练习题的答案

你可以将 skeletons 复制到你硬盘上的文件夹下,并重命名为 sklearn_tut_workspace,这样你就可以编辑自己的练习题解决方法,同时也不影响原来的内容:

% cp -r skeletons work_directory/sklearn_tut_workspace

机器学习算法需要数据。到每个 $TUTORIAL_HOME/data 子文件价下,运行 fetch_data.py 脚本。
例如:

% cd $TUTORIAL_HOME/data/languages % less fetch_data.py % python fetch_data.py

#加载“Twenty Newsgroups”数据集
这是“Twenty Newsgroups”数据集的官方描述

20 Newsgroups 数据集是大约 20000 新闻报道文档的集合,大致覆盖了 20 类不同的新闻报道。这些文档最初是由 Ken Lang 为了支撑他的论文“Newsweeder: Learning to filter netnews”收集的。20 Newsgroups 数据集很快在机器学习处理文本技术实验中流行起来,常用于文本分类和聚类。
接下来,我们将使用 sklearn 内建的数据集加载器加载 20 newsgroups 数据集。当然,你也可以在网上下载数据集,再用 sklearn.datasets.load_files 指向解压出来的 20news-bydate-train 子文件夹。

为了节约时间,在第一个例子中,我们只是关注 20 类中的 4 类新闻报道:

>>> categories = ['alt.atheism', 'soc.religion.christian', ... 'comp.graphics', 'sci.med']

现在我们加载属于上述 4 类的新闻的文件:

>>> from sklearn.datasets import fetch_20newsgroups >>> twenty_train = fetch_20newsgroups(subset='train', ... categories=categories, shuffle=True, random_state=42)

返回的数据集是 sklearn 中的 bunch 实体:包含的字段信息和数据,可以像 python 中的 dict 或 object 一样访问。target_names 属性保存了类别:

>>> twenty_train.target_names ['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']

加载到内容中的文件数据存储在 data 属性中。也可以使用 filenames 属性访问:

>>> len(twenty_train.data) 2257 >>> len(twenty_train.filenames) 2257

打印加载的第一个文件的第一行:

>>> print("\n".join(twenty_train.data[0].split("\n")[:3])) From: sd345@city.ac.uk (Michael Collier) Subject: Converting images to HP LaserJet III? Nntp-Posting-Host: hampton >>> print(twenty_train.target_names[twenty_train.target[0]]) comp.graphics

有监督学习算法在训练集中需要每个文档和对应的类别属性。在这个例子中,类别是新闻报道的类别,同时也是每个文档的父文件夹的名字。类别属性用整数代表按顺序存储在 target 属性中:

>>> twenty_train.target[:10] array([1, 1, 3, 3, 3, 3, 3, 2, 2, 2])

可以用以下方法还原类别的真正名称:

>>> for t in twenty_train.target[:10]: ... print(twenty_train.target_names[t]) ... comp.graphics comp.graphics soc.religion.christian soc.religion.christian soc.religion.christian soc.religion.christian soc.religion.christian sci.med sci.med sci.med

你可以注意到样本已经被随机洗牌,这对以下这种情况特别有用:你只是选择第一个样本来快速训练模型,并以训练的结果来启发之后的正式训练。


#从文本文件抽取特征
为了对文本文件使用机器学习,首先我们需要将文本内容转化为数值特征向量。
##Bags of words
最直观的方法就是抽取有代表性的单词:

  1. 为训练集中每个文档中出现的每个单词分配一个固定的数字 id(建立从单词映射到数字索引的 dict)
  2. 对每个文档 i,计算每个单词 w 出现的次数并存在 X[i, j],其中特征 j 是单词 w 在 1 中分配的 id 值。

由 Bags of words 方法产生的向量的维度 n_features 是语料库中不同单词的数量:约大于 100,000

若样本数量 n_samples == 10000,特征向量 X 用类型为 float32 的 numpy array 表示,那么需要 10000 * 100000 * 4 bytes = 4GB 内存,即使对于目前的计算机来说也是很勉强的。

幸运的是,特征 X 中的大部分值为 0,因为给定的文档中使用的单词不超过几千个。因此,我们认为 bags of words 的结果 是典型的 高维系数数据集。我们可以通过只存储向量中非零的部分来节约内存。

scipy.sparse 矩阵正是为了解决这种问题设计的数据结构,sklearn 内建中已经支持了这种数据结构。

##使用 sklearn 进行分词(Tokenizing text)
文本预处理,分词和过滤被包含在高级组件中,可以用于创建特征字典和将文档转化为特征向量:

>>> from sklearn.feature_extraction.text import CountVectorizer >>> count_vect = CountVectorizer() >>> X_train_counts = count_vect.fit_transform(twenty_train.data) >>> X_train_counts.shape (2257, 35788)

CountVectorizer 支持计算 N-grams 单词或字符序列。一旦 fit 完成,CountVectorizer 建立起特征索引的 dict:

>>> count_vect.vocabulary_.get(u'algorithm') 4690

单词表中单词的索引值指向其在整个训练语料库中的出现次数。

##将出现次数转化为频率

计算出现次数是一个好的开端,但是存在如下问题:更长的文档中单词的平均出现次数会比短文档的更高,即使他们的主题是一致的。

为了避免出现上述可能的差异,使用文档中每个单词出现的数量除以该文档单词的总数量:这个新特征称为 tf (Term Frequencies,词频)。

另一个需要考虑的问题是,一个文档的语料库越小,则每个语料包含的信息量越大。因此需要削减语料库大的文档中单词特征的权重。

这种削减方法是 tf-idf(Term Frequency times Inverse Document Frequency)

tf 和 tf-idf 可以通过下面代码计算:

>>> from sklearn.feature_extraction.text import TfidfTransformer >>> tf_transformer = TfidfTransformer(use_idf=False).fit(X_train_counts) >>> X_train_tf = tf_transformer.transform(X_train_counts) >>> X_train_tf.shape (2257, 35788)

在上面示例代码中,我们先使用 fit(...) 方法使用数据调整 estimator,接着使用 transform(...)方法将我们的计数矩阵转化为 tf-idf 表示。直接使用 fit_transform(..) 方法将这两个步骤可以合并到一起以减少一些中间计算。 以下代码实现的功能和上面的代码一直:

>>> tfidf_transformer = TfidfTransformer() >>> X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts) >>> X_train_tfidf.shape (2257, 35788)

#训练分类器
现在我们已经得到了特征,我们可以训练一个分类器并尝试预测测试数据的类别。我们将以 naive Bayes 开始,该分类器可以为我们的问题提供一个良好的基线。sklearn 包含了 naive Bayes 的几个变种版本,其中多项式版本最适合于词频。

>>> from sklearn.naive_bayes import MultinomialNB >>> clf = MultinomialNB().fit(X_train_tfidf, twenty_train.target)

预测新的文档的类别,我们需要用和上述几乎一直的方法抽出特征。不同的是,我们直接调用 TfidfTransformer 的 transform 方法,而不是 fit_transform。因为之前我们已经使用训练样本 fit 过了。

>>> docs_new = ['God is love', 'OpenGL on the GPU is fast'] >>> X_new_counts = count_vect.transform(docs_new) >>> X_new_tfidf = tfidf_transformer.transform(X_new_counts) >>> predicted = clf.predict(X_new_tfidf) >>> for doc, category in zip(docs_new, predicted): ... print('%r => %s' % (doc, twenty_train.target_names[category])) ... 'God is love' => soc.religion.christian 'OpenGL on the GPU is fast' => comp.graphics

#建立管道
为了更加简便地使用 vectorizer => transformer => classifier 工作流程,sklearn 提供了 Pipeline 类,该类类似于混合分类器:

>>> from sklearn.pipeline import Pipeline >>> text_clf = Pipeline([('vect', CountVectorizer()), ... ('tfidf', TfidfTransformer()), ... ('clf', MultinomialNB()), ... ])

其中 vect, tfidf 和 clf 是随意命名的。我们将在下面网格搜索一节中看到他们的用法。现在训练整个模型(包括特征抽取、转化、分类器训练),仅仅需要通过以下命令:

>>> text_clf = text_clf.fit(twenty_train.data, twenty_train.target)

#使用测试集评估
评估模型预测的正确率是非常简单的:

>>> import numpy as np >>> twenty_test = fetch_20newsgroups(subset='test', ... categories=categories, shuffle=True, random_state=42) >>> docs_test = twenty_test.data >>> predicted = text_clf.predict(docs_test) >>> np.mean(predicted == twenty_test.target) 0.834...

我们获得 83.4% 的准确率。现在我们看看能否使用线性 SVM 模型获得更好的结果(线性 SVM 被普遍地认为是最好的文本分类算法,虽然比 naive Bayes 慢一些)。我们仅仅需要将管道中的分类器进行特换即可。

>>> from sklearn.linear_model import SGDClassifier >>> text_clf = Pipeline([('vect', CountVectorizer()), ... ('tfidf', TfidfTransformer()), ... ('clf', SGDClassifier(loss='hinge', penalty='l2', ... alpha=1e-3, n_iter=5, random_state=42)), ... ]) >>> _ = text_clf.fit(twenty_train.data, twenty_train.target) >>> predicted = text_clf.predict(docs_test) >>> np.mean(predicted == twenty_test.target) 0.912...

此外 sklearn 还提供了更详细的效果评估工具:

>>> from sklearn import metrics >>> print(metrics.classification_report(twenty_test.target, predicted, ... target_names=twenty_test.target_names)) ... precision recall f1-score support alt.atheism 0.95 0.81 0.87 319 comp.graphics 0.88 0.97 0.92 389 sci.med 0.94 0.90 0.92 396 soc.religion.christian 0.90 0.95 0.93 398 avg / total 0.92 0.91 0.91 1502 >>> metrics.confusion_matrix(twenty_test.target, predicted) array([[258, 11, 15, 35], [ 4, 379, 3, 3], [ 5, 33, 355, 3], [ 5, 10, 4, 379]])

由混淆矩阵可以看出,新闻报道中的 atheism 主题比 comp.graphics 更容易被混淆。


#使用网格搜索调整参数
我们已经在 TfidfTransformer 中使用了一些参数,如“use_idf”。同样,分类器也会有很多参数,如 MultinomialNB 分类器包含平滑参数 aipha,SGDClassifier 包含惩罚参数 alpha 等等。

逐个调整 pipline 中的参数是不明智的,我们需要一个穷举搜索方法(exhaustive search)帮助我们寻找参数网格中最好的参数组合。

>>> from sklearn.grid_search import GridSearchCV >>> parameters = {'vect__ngram_range': [(1, 1), (1, 2)], ... 'tfidf__use_idf': (True, False), ... 'clf__alpha': (1e-2, 1e-3), ... }

显然的,穷举搜索方法开销是较大的。如果我们有多核 CPU,我们可以通过设置 n_jobs = -1,让网格搜索计算时使用所有的 cpu 进行并行计算:

>>> gs_clf = GridSearchCV(text_clf, parameters, n_jobs=-1)

网格搜索实例和普通的 sklearn 模型一样。让我们在一个较小的数据集中机械能网格搜索,以缩短计算时间:

>>> gs_clf = gs_clf.fit(twenty_train.data[:400], twenty_train.target[:400])

GridSearchCV 的 fit 方法返回一个分类器,我们可以使用它进行预测:

>>> twenty_train.target_names[gs_clf.predict(['God is love'])] 'soc.religion.christian'

但是另一方面,这个分类器是相当巨大和笨拙的。我们可以使用 grid_scores_ 属性从该对象中获取最佳的参数列表。

>>> best_parameters, score, _ = max(gs_clf.grid_scores_, key=lambda x: x[1]) >>> for param_name in sorted(parameters.keys()): ... print("%s: %r" % (param_name, best_parameters[param_name])) ... clf__alpha: 0.001 tfidf__use_idf: True vect__ngram_range: (1, 1) >>> score 0.900...

练习题连接


#路在何方
以下是几点建议可以帮助你在学完本指导后,在 sklearn 路上走得更远:

  • 尝试玩一玩 CountVectorizer 下的 analyzer 和 token normalisation
  • 如果你没有类属性,尝试使用聚类方法获得
  • 如果每个文档有多个类属性,可以看看 Multiclass and multilabel section
  • 尝试使用 Truncated SVD 进行潜在语义分析
  • 使用 Out-of-core 分类方法学习没办法加载到主存的数据
  • 尝试使用 Hashing Vectorizer 替换 CountVectorizer
  • 数据挖掘
    17 引用 • 32 回帖 • 3 关注
  • Python

    Python 是一种面向对象、直译式电脑编程语言,具有近二十年的发展历史,成熟且稳定。它包含了一组完善而且容易理解的标准库,能够轻松完成很多常见的任务。它的语法简捷和清晰,尽量使用无异义的英语单词,与其它大多数程序设计语言使用大括号不一样,它使用缩进来定义语句块。

    556 引用 • 675 回帖 • 1 关注

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • TextBundle

    TextBundle 文件格式旨在应用程序之间交换 Markdown 或 Fountain 之类的纯文本文件时,提供更无缝的用户体验。

    1 引用 • 2 回帖 • 82 关注
  • 职场

    找到自己的位置,萌新烦恼少。

    127 引用 • 1708 回帖
  • Java

    Java 是一种可以撰写跨平台应用软件的面向对象的程序设计语言,是由 Sun Microsystems 公司于 1995 年 5 月推出的。Java 技术具有卓越的通用性、高效性、平台移植性和安全性。

    3198 引用 • 8215 回帖
  • AngularJS

    AngularJS 诞生于 2009 年,由 Misko Hevery 等人创建,后为 Google 所收购。是一款优秀的前端 JS 框架,已经被用于 Google 的多款产品当中。AngularJS 有着诸多特性,最为核心的是:MVC、模块化、自动化双向数据绑定、语义化标签、依赖注入等。2.0 版本后已经改名为 Angular。

    12 引用 • 50 回帖 • 507 关注
  • DNSPod

    DNSPod 建立于 2006 年 3 月份,是一款免费智能 DNS 产品。 DNSPod 可以为同时有电信、网通、教育网服务器的网站提供智能的解析,让电信用户访问电信的服务器,网通的用户访问网通的服务器,教育网的用户访问教育网的服务器,达到互联互通的效果。

    6 引用 • 26 回帖 • 531 关注
  • 印象笔记
    3 引用 • 16 回帖 • 2 关注
  • iOS

    iOS 是由苹果公司开发的移动操作系统,最早于 2007 年 1 月 9 日的 Macworld 大会上公布这个系统,最初是设计给 iPhone 使用的,后来陆续套用到 iPod touch、iPad 以及 Apple TV 等产品上。iOS 与苹果的 Mac OS X 操作系统一样,属于类 Unix 的商业操作系统。

    88 引用 • 139 回帖
  • 大疆创新

    深圳市大疆创新科技有限公司(DJI-Innovations,简称 DJI),成立于 2006 年,是全球领先的无人飞行器控制系统及无人机解决方案的研发和生产商,客户遍布全球 100 多个国家。通过持续的创新,大疆致力于为无人机工业、行业用户以及专业航拍应用提供性能最强、体验最佳的革命性智能飞控产品和解决方案。

    2 引用 • 14 回帖 • 1 关注
  • RYMCU

    RYMCU 致力于打造一个即严谨又活泼、专业又不失有趣,为数百万人服务的开源嵌入式知识学习交流平台。

    4 引用 • 6 回帖 • 55 关注
  • Logseq

    Logseq 是一个隐私优先、开源的知识库工具。

    Logseq is a joyful, open-source outliner that works on top of local plain-text Markdown and Org-mode files. Use it to write, organize and share your thoughts, keep your to-do list, and build your own digital garden.

    7 引用 • 69 回帖 • 1 关注
  • V2EX

    V2EX 是创意工作者们的社区。这里目前汇聚了超过 400,000 名主要来自互联网行业、游戏行业和媒体行业的创意工作者。V2EX 希望能够成为创意工作者们的生活和事业的一部分。

    16 引用 • 236 回帖 • 267 关注
  • 微服务

    微服务架构是一种架构模式,它提倡将单一应用划分成一组小的服务。服务之间互相协调,互相配合,为用户提供最终价值。每个服务运行在独立的进程中。服务于服务之间才用轻量级的通信机制互相沟通。每个服务都围绕着具体业务构建,能够被独立的部署。

    96 引用 • 155 回帖 • 1 关注
  • 区块链

    区块链是分布式数据存储、点对点传输、共识机制、加密算法等计算机技术的新型应用模式。所谓共识机制是区块链系统中实现不同节点之间建立信任、获取权益的数学算法 。

    92 引用 • 752 回帖 • 1 关注
  • 小薇

    小薇是一个用 Java 写的 QQ 聊天机器人 Web 服务,可以用于社群互动。

    由于 Smart QQ 从 2019 年 1 月 1 日起停止服务,所以该项目也已经停止维护了!

    35 引用 • 468 回帖 • 760 关注
  • 设计模式

    设计模式(Design pattern)代表了最佳的实践,通常被有经验的面向对象的软件开发人员所采用。设计模式是软件开发人员在软件开发过程中面临的一般问题的解决方案。这些解决方案是众多软件开发人员经过相当长的一段时间的试验和错误总结出来的。

    200 引用 • 120 回帖 • 3 关注
  • PHP

    PHP(Hypertext Preprocessor)是一种开源脚本语言。语法吸收了 C 语言、 Java 和 Perl 的特点,主要适用于 Web 开发领域,据说是世界上最好的编程语言。

    180 引用 • 408 回帖 • 489 关注
  • OpenResty

    OpenResty 是一个基于 NGINX 与 Lua 的高性能 Web 平台,其内部集成了大量精良的 Lua 库、第三方模块以及大多数的依赖项。用于方便地搭建能够处理超高并发、扩展性极高的动态 Web 应用、Web 服务和动态网关。

    17 引用 • 57 关注
  • danl
    165 关注
  • Windows

    Microsoft Windows 是美国微软公司研发的一套操作系统,它问世于 1985 年,起初仅仅是 Microsoft-DOS 模拟环境,后续的系统版本由于微软不断的更新升级,不但易用,也慢慢的成为家家户户人们最喜爱的操作系统。

    227 引用 • 476 回帖 • 1 关注
  • GitHub

    GitHub 于 2008 年上线,目前,除了 Git 代码仓库托管及基本的 Web 管理界面以外,还提供了订阅、讨论组、文本渲染、在线文件编辑器、协作图谱(报表)、代码片段分享(Gist)等功能。正因为这些功能所提供的便利,又经过长期的积累,GitHub 的用户活跃度很高,在开源世界里享有深远的声望,并形成了社交化编程文化(Social Coding)。

    210 引用 • 2040 回帖
  • WebClipper

    Web Clipper 是一款浏览器剪藏扩展,它可以帮助你把网页内容剪藏到本地。

    3 引用 • 9 回帖 • 3 关注
  • MongoDB

    MongoDB(来自于英文单词“Humongous”,中文含义为“庞大”)是一个基于分布式文件存储的数据库,由 C++ 语言编写。旨在为应用提供可扩展的高性能数据存储解决方案。MongoDB 是一个介于关系数据库和非关系数据库之间的产品,是非关系数据库当中功能最丰富,最像关系数据库的。它支持的数据结构非常松散,是类似 JSON 的 BSON 格式,因此可以存储比较复杂的数据类型。

    90 引用 • 59 回帖 • 3 关注
  • 正则表达式

    正则表达式(Regular Expression)使用单个字符串来描述、匹配一系列遵循某个句法规则的字符串。

    31 引用 • 94 回帖
  • BookxNote

    BookxNote 是一款全新的电子书学习工具,助力您的学习与思考,让您的大脑更高效的记忆。

    笔记整理交给我,一心只读圣贤书。

    1 引用 • 1 回帖 • 2 关注
  • App

    App(应用程序,Application 的缩写)一般指手机软件。

    91 引用 • 384 回帖
  • AWS
    11 引用 • 28 回帖 • 11 关注
  • 创业

    你比 99% 的人都优秀么?

    82 引用 • 1395 回帖 • 2 关注