简介:使用scikit-learn进行数据挖掘

本贴最后更新于 3152 天前,其中的信息可能已经渤澥桑田

该简介翻译自 An introduction to machine learning with scikit-learn
选择翻译这篇简介的原因很简单:

  • scikit-learn 是非常优秀的 python 机器学习库
  • 该篇写得非常好,即使不使用 sklearn,也可以作为数据挖掘入门的短文。

以下是翻译的内容。


#本节内容

在本章节中,我们介绍一些在 sklearn 中会使用到的机器学习专业名词,并给出一些简单的例子。

机器学习:问题设定

通常来说,学习问题关注样本大小为 n 的数据集,并尝试预测未知的数据集。若每个样本不只是一个简单的数字,而是一个多维的条目,我们称之有多个属性或特征。

我们可以把学习问题划分为几个大的类别:

  • 有监督学习(supervised learning),在这种学习问题中,数据会附带我们要预测的属性。有监督学习可以进而分为以下两类:
    • 分类(classification):样本属于两个或多个分类,我们要从已经标记类别的数据中学习,并对未标记类别的数据进行预测。分类问题的一个典型例子是识别手写数字,该问题的目的识别每个输入向量对应的有限且离散的数字。换句话说分类问题是,离散形式(相对于连续)的有监督学习,提供的 n 个样本的类别是有限的,我们尝试为每个样本标记正确的分类。
    • 回归(regression):若输出的期望值是 1 个或多个连续变量,我们称该问题为回归。回归问题的一个典型例子是通过三文鱼的年龄和重量,预测其长度。
  • 无监督学习(unsupervised learning),在这种学习问题中,训练数据集是不包含任何目标值的输入向量 x。学习的目的有多种:
    • 聚类(clustering),发现数据中相似的样本分组。
    • 密度估计(density estimation),通过输入空间确定数据的分布。
    • 为了数据可视化或其他目的,将多维空间降低至 2 或 3 维

训练集和测试集
可粗略认为,机器学习就是从一个数据集中学习隐含的规则,并应用到新的数据集上。因此在机器学习实践中,为了评估算法,总是强制把数据集分为两个部分:训练集,用于学习隐含规则;测试集,用于测试规则。


#加载样例数据集
scikit-learn 自带了几个标准数据集,例如用于分类的 iris 和 digits 数据集,用于回归的 boston house prices 数据集。

接下来,我们使用 Python 交互式环境加载 iris 和 digits 数据集。
我们约定用 '$'表示 shell 类型,>>> 表示 python 交互环境。

$ python
>>> from sklearn import datasets
>>> iris = datasets.load_iris()
>>> digits = datasets.load_digits()

数据集是一个类字典对象,包括了全部的数据和该数据的元数据。数据保存在 .data 成员中,该成员是(n 个向量*m 个特征)的数组。在有监督学习中,类别变量存储在 .target 成员中。
例如,在 digits 数据集中,通过 digits.data 可以获取用于分类的向量。

>>> print(digits.data)  
[[  0.   0.   5. ...,   0.   0.   0.]
 [  0.   0.   0. ...,  10.   0.   0.]
 [  0.   0.   0. ...,  16.   9.   0.]
 ...,
 [  0.   0.   1. ...,   6.   0.   0.]
 [  0.   0.   2. ...,  12.   0.   0.]
 [  0.   0.  10. ...,  12.   1.   0.]]

digits.target 中存储了 digits 数据集中对应每个向量的类别,也是我们预测的目标。

>>> digits.target
array([0, 1, 2, ..., 8, 9, 8])

数据格式
数据集总是一个二维数组,格式为(n 个向量 * m 个特征),尽管原始数据可能是其他不同的格式。在 digits 数据集中,每个原始数据是用(8,8)表示的图像(在 digits.data 中被压缩到一行):

>>> digits.images[0]
array([[  0.,   0.,   5.,  13.,   9.,   1.,   0.,   0.],
       [  0.,   0.,  13.,  15.,  10.,  15.,   5.,   0.],
       [  0.,   3.,  15.,   2.,   0.,  11.,   8.,   0.],
       [  0.,   4.,  12.,   0.,   0.,   8.,   8.,   0.],
       [  0.,   5.,   8.,   0.,   0.,   9.,   8.,   0.],
       [  0.,   4.,  11.,   0.,   1.,  12.,   7.,   0.],
       [  0.,   2.,  14.,   5.,  10.,  12.,   0.,   0.],
       [  0.,   0.,   6.,  13.,  10.,   0.,   0.,   0.]])

#学习和预测
在 digits 数据集中,目标是预测给定的图像数据代表的数字。我们知道训练样本对应的分类(数字 0 到 9),训练对应的 estimator,用于预测未知分类的图像。

在 scikit-learn 中,用于分类的 estimator 是一个实现了 fit(X, y)predict(T) 的 Python 对象。

实现了支持向量分类的 sklearn.svm.SVC 类就是一个 estimator。estimator 的构造函数接受模型的参数。但暂时,我们把 estimator 当作一个黑盒:

>>> from sklearn import svm
>>> clf = svm.SVC(gamma=0.001, C=100.)

选择模型的参数
这上面的例子中,我们手动地设置 gamma 的值。通过使用类似于 grid search 或 cross validation 工具,可以自动地寻找适合的参数。

上面例子将我们的 estimator 实例命名为 clf,因为其是一个分类器(classifier)。现在,需要将其通过学习调整对应模型。这个过程通过将训练数据集传给 fit 方法来实现。我们用除了最后一个图像的 digits 数据集作为训练数据集,在 python 中可以方便地使用[:-1]来构造训练集:

>>> clf.fit(digits.data[:-1], digits.target[:-1])  
SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma=0.001, kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)  

现在我们可以用该模型对新数据进行预测,可以询问模型刚才没有使用的最后一个图像对应的数字:

>>> clf.predict(digits.data[-1:])
array([8])  

最后一个图像数据对应的图像如下:

digit imag

如你所见,这确实是一个具有挑战性的任务:图像的分辨率特别差。你同意分类器的判定吗?

这里给出一个完整的分类问题的例子:Recognizing hand-written digits,你可以执行这个代码,并进行学习。


#模型持久化
通过 Python 内建的序列化模块 pickle,可以将 sklearn 中的模型进行持久化。

>>> from sklearn import svm
>>> from sklearn import datasets
>>> clf = svm.SVC()
>>> iris = datasets.load_iris()
>>> X, y = iris.data, iris.target
>>> clf.fit(X, y)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

>>> import pickle
>>> s = pickle.dumps(clf)
>>> clf2 = pickle.loads(s)
>>> clf2.predict(X[0:1])
array([0])
>>> y[0]
0

特别的,在 sklearn 中,可以使用 joblib 替代 pickle (joblib.dump 和 joblib.load),joblib 在大数据上表现更加高效,但只能序列化到磁盘中,而非字符串。

>>> from sklearn.externals import joblib
>>> joblib.dump(clf, 'filename.pkl')   

然后,你可以重新读取并反序列化该模型(可能在另外的一个 python 程序中):

>>> clf = joblib.load('filename.pkl') 


joblib.dump 返回一个文件名列表。clf 对象中包含的每一个单独的 numpy 数组会被序列化为文件系统中的一个单独文件。当使用 joblib.load 读取模型时,文件夹下的每个文件都是必要的。

注意 pickle 有一些安全性和可维护性的问题。参考 Model persistence,获取更多有关 sklearn 中模型持久化的信息。


#惯例
scikit-learn 中的 estimator 遵循以下的规则,好让他们的行为更加可预测。

##类型转换
除非明确指明,否则输入将会被强制转换为 float64

>>> import numpy as np
>>> from sklearn import random_projection

>>> rng = np.random.RandomState(0)
>>> X = rng.rand(10, 2000)
>>> X = np.array(X, dtype='float32')
>>> X.dtype
dtype('float32')

>>> transformer = random_projection.GaussianRandomProjection()
>>> X_new = transformer.fit_transform(X)
>>> X_new.dtype
dtype('float64')

在上面例子中,X 的类型为 float32,通过 .fit_transform(X) 被转化为 float64

回归的结果被转化为 float32, 分类的结果保持不变:

>>> from sklearn import datasets
>>> from sklearn.svm import SVC
>>> iris = datasets.load_iris()
>>> clf = SVC()
>>> clf.fit(iris.data, iris.target)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

>>> list(clf.predict(iris.data[:3]))
[0, 0, 0]

>>> clf.fit(iris.data, iris.target_names[iris.target])  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

>>> list(clf.predict(iris.data[:3]))  
['setosa', 'setosa', 'setosa']  

在上面例子中,第一个 predict() 返回整数数组,因为用于训练的 iris.target 是整数数组。第二个 predict() 返回字符串数组,因为用于训练的 iris.target_names 是字符串数组。

##改变和升级参数
通过 sklearn.pipeline.Pipeline.set_params 方法 estimator 的超参数在构造后仍然可以修改。通过多次调用 fit() 方法可以覆盖之前的 fit()

>>> import numpy as np
>>> from sklearn.svm import SVC

>>> rng = np.random.RandomState(0)
>>> X = rng.rand(100, 10)
>>> y = rng.binomial(1, 0.5, 100)
>>> X_test = rng.rand(5, 10)

>>> clf = SVC()
>>> clf.set_params(kernel='linear').fit(X, y)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='linear',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)
>>> clf.predict(X_test)
array([1, 0, 1, 1, 0])

>>> clf.set_params(kernel='rbf').fit(X, y)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)
>>> clf.predict(X_test)
array([0, 0, 0, 1, 0])

在该例子中,SVC()构造函数中设定了的默认核函数为 rbf,但是随后被改为 linear 并训练模型,然后又重新修改为 rbf 并重新训练模型。

  • 数据挖掘
    17 引用 • 32 回帖 • 3 关注
  • Python

    Python 是一种面向对象、直译式电脑编程语言,具有近二十年的发展历史,成熟且稳定。它包含了一组完善而且容易理解的标准库,能够轻松完成很多常见的任务。它的语法简捷和清晰,尽量使用无异义的英语单词,与其它大多数程序设计语言使用大括号不一样,它使用缩进来定义语句块。

    545 引用 • 672 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...
  • wizardforcel

    什么都好。。就是官方的教程太少了。。

  • 其他回帖
  • R

    👍 楼主翻译的?

  • R

    @Zing 挺好的 @88250 编辑记录功能调整下,都可以弄个文档翻译区了

  • Zing
    作者

    @R 嗯嗯 是的 我翻译的 水平有限

推荐标签 标签

  • 工具

    子曰:“工欲善其事,必先利其器。”

    288 引用 • 734 回帖
  • PostgreSQL

    PostgreSQL 是一款功能强大的企业级数据库系统,在 BSD 开源许可证下发布。

    22 引用 • 22 回帖
  • 分享

    有什么新发现就分享给大家吧!

    248 引用 • 1795 回帖
  • Solo

    Solo 是一款小而美的开源博客系统,专为程序员设计。Solo 有着非常活跃的社区,可将文章作为帖子推送到社区,来自社区的回帖将作为博客评论进行联动(具体细节请浏览 B3log 构思 - 分布式社区网络)。

    这是一种全新的网络社区体验,让热爱记录和分享的你不再感到孤单!

    1435 引用 • 10056 回帖 • 489 关注
  • SpaceVim

    SpaceVim 是一个社区驱动的模块化 vim/neovim 配置集合,以模块的方式组织管理插件以
    及相关配置,为不同的语言开发量身定制了相关的开发模块,该模块提供代码自动补全,
    语法检查、格式化、调试、REPL 等特性。用户仅需载入相关语言的模块即可得到一个开箱
    即用的 Vim-IDE。

    3 引用 • 31 回帖 • 105 关注
  • jsoup

    jsoup 是一款 Java 的 HTML 解析器,可直接解析某个 URL 地址、HTML 文本内容。它提供了一套非常省力的 API,可通过 DOM,CSS 以及类似于 jQuery 的操作方法来取出和操作数据。

    6 引用 • 1 回帖 • 484 关注
  • 爬虫

    网络爬虫(Spider、Crawler),是一种按照一定的规则,自动地抓取万维网信息的程序。

    106 引用 • 275 回帖 • 1 关注
  • Gzip

    gzip (GNU zip)是 GNU 自由软件的文件压缩程序。我们在 Linux 中经常会用到后缀为 .gz 的文件,它们就是 Gzip 格式的。现今已经成为互联网上使用非常普遍的一种数据压缩格式,或者说一种文件格式。

    9 引用 • 12 回帖 • 147 关注
  • Typecho

    Typecho 是一款博客程序,它在 GPLv2 许可证下发行,基于 PHP 构建,可以运行在各种平台上,支持多种数据库(MySQL、PostgreSQL、SQLite)。

    12 引用 • 65 回帖 • 445 关注
  • 七牛云

    七牛云是国内领先的企业级公有云服务商,致力于打造以数据为核心的场景化 PaaS 服务。围绕富媒体场景,七牛先后推出了对象存储,融合 CDN 加速,数据通用处理,内容反垃圾服务,以及直播云服务等。

    27 引用 • 225 回帖 • 163 关注
  • 996
    13 引用 • 200 回帖 • 10 关注
  • Flume

    Flume 是一套分布式的、可靠的,可用于有效地收集、聚合和搬运大量日志数据的服务架构。

    9 引用 • 6 回帖 • 637 关注
  • V2Ray
    1 引用 • 15 回帖 • 1 关注
  • TGIF

    Thank God It's Friday! 感谢老天,总算到星期五啦!

    288 引用 • 4485 回帖 • 663 关注
  • 智能合约

    智能合约(Smart contract)是一种旨在以信息化方式传播、验证或执行合同的计算机协议。智能合约允许在没有第三方的情况下进行可信交易,这些交易可追踪且不可逆转。智能合约概念于 1994 年由 Nick Szabo 首次提出。

    1 引用 • 11 回帖 • 2 关注
  • Caddy

    Caddy 是一款默认自动启用 HTTPS 的 HTTP/2 Web 服务器。

    12 引用 • 54 回帖 • 159 关注
  • Docker

    Docker 是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一个可移植的容器中,然后发布到任何流行的操作系统上。容器完全使用沙箱机制,几乎没有性能开销,可以很容易地在机器和数据中心中运行。

    492 引用 • 926 回帖
  • 反馈

    Communication channel for makers and users.

    123 引用 • 913 回帖 • 250 关注
  • NetBeans

    NetBeans 是一个始于 1997 年的 Xelfi 计划,本身是捷克布拉格查理大学的数学及物理学院的学生计划。此计划延伸而成立了一家公司进而发展这个商用版本的 NetBeans IDE,直到 1999 年 Sun 买下此公司。Sun 于次年(2000 年)六月将 NetBeans IDE 开源,直到现在 NetBeans 的社群依然持续增长。

    78 引用 • 102 回帖 • 683 关注
  • Kubernetes

    Kubernetes 是 Google 开源的一个容器编排引擎,它支持自动化部署、大规模可伸缩、应用容器化管理。

    110 引用 • 54 回帖 • 1 关注
  • 阿里云

    阿里云是阿里巴巴集团旗下公司,是全球领先的云计算及人工智能科技公司。提供云服务器、云数据库、云安全等云计算服务,以及大数据、人工智能服务、精准定制基于场景的行业解决方案。

    89 引用 • 345 回帖
  • CSS

    CSS(Cascading Style Sheet)“层叠样式表”是用于控制网页样式并允许将样式信息与网页内容分离的一种标记性语言。

    196 引用 • 540 回帖 • 1 关注
  • 区块链

    区块链是分布式数据存储、点对点传输、共识机制、加密算法等计算机技术的新型应用模式。所谓共识机制是区块链系统中实现不同节点之间建立信任、获取权益的数学算法 。

    91 引用 • 751 回帖 • 1 关注
  • SVN

    SVN 是 Subversion 的简称,是一个开放源代码的版本控制系统,相较于 RCS、CVS,它采用了分支管理系统,它的设计目标就是取代 CVS。

    29 引用 • 98 回帖 • 694 关注
  • SEO

    发布对别人有帮助的原创内容是最好的 SEO 方式。

    35 引用 • 200 回帖 • 27 关注
  • Electron

    Electron 基于 Chromium 和 Node.js,让你可以使用 HTML、CSS 和 JavaScript 构建应用。它是一个由 GitHub 及众多贡献者组成的活跃社区共同维护的开源项目,兼容 Mac、Windows 和 Linux,它构建的应用可在这三个操作系统上面运行。

    15 引用 • 136 回帖
  • Hibernate

    Hibernate 是一个开放源代码的对象关系映射框架,它对 JDBC 进行了非常轻量级的对象封装,使得 Java 程序员可以随心所欲的使用对象编程思维来操纵数据库。

    39 引用 • 103 回帖 • 715 关注