简介:使用scikit-learn进行数据挖掘

本贴最后更新于 2999 天前,其中的信息可能已经渤澥桑田

该简介翻译自 An introduction to machine learning with scikit-learn
选择翻译这篇简介的原因很简单:

  • scikit-learn 是非常优秀的 python 机器学习库
  • 该篇写得非常好,即使不使用 sklearn,也可以作为数据挖掘入门的短文。

以下是翻译的内容。


#本节内容

在本章节中,我们介绍一些在 sklearn 中会使用到的机器学习专业名词,并给出一些简单的例子。

机器学习:问题设定

通常来说,学习问题关注样本大小为 n 的数据集,并尝试预测未知的数据集。若每个样本不只是一个简单的数字,而是一个多维的条目,我们称之有多个属性或特征。

我们可以把学习问题划分为几个大的类别:

  • 有监督学习(supervised learning),在这种学习问题中,数据会附带我们要预测的属性。有监督学习可以进而分为以下两类:
    • 分类(classification):样本属于两个或多个分类,我们要从已经标记类别的数据中学习,并对未标记类别的数据进行预测。分类问题的一个典型例子是识别手写数字,该问题的目的识别每个输入向量对应的有限且离散的数字。换句话说分类问题是,离散形式(相对于连续)的有监督学习,提供的 n 个样本的类别是有限的,我们尝试为每个样本标记正确的分类。
    • 回归(regression):若输出的期望值是 1 个或多个连续变量,我们称该问题为回归。回归问题的一个典型例子是通过三文鱼的年龄和重量,预测其长度。
  • 无监督学习(unsupervised learning),在这种学习问题中,训练数据集是不包含任何目标值的输入向量 x。学习的目的有多种:
    • 聚类(clustering),发现数据中相似的样本分组。
    • 密度估计(density estimation),通过输入空间确定数据的分布。
    • 为了数据可视化或其他目的,将多维空间降低至 2 或 3 维

训练集和测试集
可粗略认为,机器学习就是从一个数据集中学习隐含的规则,并应用到新的数据集上。因此在机器学习实践中,为了评估算法,总是强制把数据集分为两个部分:训练集,用于学习隐含规则;测试集,用于测试规则。


#加载样例数据集
scikit-learn 自带了几个标准数据集,例如用于分类的 iris 和 digits 数据集,用于回归的 boston house prices 数据集。

接下来,我们使用 Python 交互式环境加载 iris 和 digits 数据集。
我们约定用 '$'表示 shell 类型,>>> 表示 python 交互环境。

$ python
>>> from sklearn import datasets
>>> iris = datasets.load_iris()
>>> digits = datasets.load_digits()

数据集是一个类字典对象,包括了全部的数据和该数据的元数据。数据保存在 .data 成员中,该成员是(n 个向量*m 个特征)的数组。在有监督学习中,类别变量存储在 .target 成员中。
例如,在 digits 数据集中,通过 digits.data 可以获取用于分类的向量。

>>> print(digits.data)  
[[  0.   0.   5. ...,   0.   0.   0.]
 [  0.   0.   0. ...,  10.   0.   0.]
 [  0.   0.   0. ...,  16.   9.   0.]
 ...,
 [  0.   0.   1. ...,   6.   0.   0.]
 [  0.   0.   2. ...,  12.   0.   0.]
 [  0.   0.  10. ...,  12.   1.   0.]]

digits.target 中存储了 digits 数据集中对应每个向量的类别,也是我们预测的目标。

>>> digits.target
array([0, 1, 2, ..., 8, 9, 8])

数据格式
数据集总是一个二维数组,格式为(n 个向量 * m 个特征),尽管原始数据可能是其他不同的格式。在 digits 数据集中,每个原始数据是用(8,8)表示的图像(在 digits.data 中被压缩到一行):

>>> digits.images[0]
array([[  0.,   0.,   5.,  13.,   9.,   1.,   0.,   0.],
       [  0.,   0.,  13.,  15.,  10.,  15.,   5.,   0.],
       [  0.,   3.,  15.,   2.,   0.,  11.,   8.,   0.],
       [  0.,   4.,  12.,   0.,   0.,   8.,   8.,   0.],
       [  0.,   5.,   8.,   0.,   0.,   9.,   8.,   0.],
       [  0.,   4.,  11.,   0.,   1.,  12.,   7.,   0.],
       [  0.,   2.,  14.,   5.,  10.,  12.,   0.,   0.],
       [  0.,   0.,   6.,  13.,  10.,   0.,   0.,   0.]])

#学习和预测
在 digits 数据集中,目标是预测给定的图像数据代表的数字。我们知道训练样本对应的分类(数字 0 到 9),训练对应的 estimator,用于预测未知分类的图像。

在 scikit-learn 中,用于分类的 estimator 是一个实现了 fit(X, y)predict(T) 的 Python 对象。

实现了支持向量分类的 sklearn.svm.SVC 类就是一个 estimator。estimator 的构造函数接受模型的参数。但暂时,我们把 estimator 当作一个黑盒:

>>> from sklearn import svm
>>> clf = svm.SVC(gamma=0.001, C=100.)

选择模型的参数
这上面的例子中,我们手动地设置 gamma 的值。通过使用类似于 grid search 或 cross validation 工具,可以自动地寻找适合的参数。

上面例子将我们的 estimator 实例命名为 clf,因为其是一个分类器(classifier)。现在,需要将其通过学习调整对应模型。这个过程通过将训练数据集传给 fit 方法来实现。我们用除了最后一个图像的 digits 数据集作为训练数据集,在 python 中可以方便地使用[:-1]来构造训练集:

>>> clf.fit(digits.data[:-1], digits.target[:-1])  
SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma=0.001, kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)  

现在我们可以用该模型对新数据进行预测,可以询问模型刚才没有使用的最后一个图像对应的数字:

>>> clf.predict(digits.data[-1:])
array([8])  

最后一个图像数据对应的图像如下:

digit imag

如你所见,这确实是一个具有挑战性的任务:图像的分辨率特别差。你同意分类器的判定吗?

这里给出一个完整的分类问题的例子:Recognizing hand-written digits,你可以执行这个代码,并进行学习。


#模型持久化
通过 Python 内建的序列化模块 pickle,可以将 sklearn 中的模型进行持久化。

>>> from sklearn import svm
>>> from sklearn import datasets
>>> clf = svm.SVC()
>>> iris = datasets.load_iris()
>>> X, y = iris.data, iris.target
>>> clf.fit(X, y)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

>>> import pickle
>>> s = pickle.dumps(clf)
>>> clf2 = pickle.loads(s)
>>> clf2.predict(X[0:1])
array([0])
>>> y[0]
0

特别的,在 sklearn 中,可以使用 joblib 替代 pickle (joblib.dump 和 joblib.load),joblib 在大数据上表现更加高效,但只能序列化到磁盘中,而非字符串。

>>> from sklearn.externals import joblib
>>> joblib.dump(clf, 'filename.pkl')   

然后,你可以重新读取并反序列化该模型(可能在另外的一个 python 程序中):

>>> clf = joblib.load('filename.pkl') 


joblib.dump 返回一个文件名列表。clf 对象中包含的每一个单独的 numpy 数组会被序列化为文件系统中的一个单独文件。当使用 joblib.load 读取模型时,文件夹下的每个文件都是必要的。

注意 pickle 有一些安全性和可维护性的问题。参考 Model persistence,获取更多有关 sklearn 中模型持久化的信息。


#惯例
scikit-learn 中的 estimator 遵循以下的规则,好让他们的行为更加可预测。

##类型转换
除非明确指明,否则输入将会被强制转换为 float64

>>> import numpy as np
>>> from sklearn import random_projection

>>> rng = np.random.RandomState(0)
>>> X = rng.rand(10, 2000)
>>> X = np.array(X, dtype='float32')
>>> X.dtype
dtype('float32')

>>> transformer = random_projection.GaussianRandomProjection()
>>> X_new = transformer.fit_transform(X)
>>> X_new.dtype
dtype('float64')

在上面例子中,X 的类型为 float32,通过 .fit_transform(X) 被转化为 float64

回归的结果被转化为 float32, 分类的结果保持不变:

>>> from sklearn import datasets
>>> from sklearn.svm import SVC
>>> iris = datasets.load_iris()
>>> clf = SVC()
>>> clf.fit(iris.data, iris.target)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

>>> list(clf.predict(iris.data[:3]))
[0, 0, 0]

>>> clf.fit(iris.data, iris.target_names[iris.target])  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

>>> list(clf.predict(iris.data[:3]))  
['setosa', 'setosa', 'setosa']  

在上面例子中,第一个 predict() 返回整数数组,因为用于训练的 iris.target 是整数数组。第二个 predict() 返回字符串数组,因为用于训练的 iris.target_names 是字符串数组。

##改变和升级参数
通过 sklearn.pipeline.Pipeline.set_params 方法 estimator 的超参数在构造后仍然可以修改。通过多次调用 fit() 方法可以覆盖之前的 fit()

>>> import numpy as np
>>> from sklearn.svm import SVC

>>> rng = np.random.RandomState(0)
>>> X = rng.rand(100, 10)
>>> y = rng.binomial(1, 0.5, 100)
>>> X_test = rng.rand(5, 10)

>>> clf = SVC()
>>> clf.set_params(kernel='linear').fit(X, y)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='linear',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)
>>> clf.predict(X_test)
array([1, 0, 1, 1, 0])

>>> clf.set_params(kernel='rbf').fit(X, y)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)
>>> clf.predict(X_test)
array([0, 0, 0, 1, 0])

在该例子中,SVC()构造函数中设定了的默认核函数为 rbf,但是随后被改为 linear 并训练模型,然后又重新修改为 rbf 并重新训练模型。

  • 数据挖掘
    17 引用 • 32 回帖 • 2 关注
  • Python

    Python 是一种面向对象、直译式电脑编程语言,具有近二十年的发展历史,成熟且稳定。它包含了一组完善而且容易理解的标准库,能够轻松完成很多常见的任务。它的语法简捷和清晰,尽量使用无异义的英语单词,与其它大多数程序设计语言使用大括号不一样,它使用缩进来定义语句块。

    536 引用 • 672 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...
  • R

    👍 楼主翻译的?

  • 其他回帖
  • wizardforcel

    什么都好。。就是官方的教程太少了。。

  • Zing
    作者

    @R 嗯嗯 是的 我翻译的 水平有限

  • R

    @Zing 挺好的 @88250 编辑记录功能调整下,都可以弄个文档翻译区了

推荐标签 标签

  • 前端

    前端技术一般分为前端设计和前端开发,前端设计可以理解为网站的视觉设计,前端开发则是网站的前台代码实现,包括 HTML、CSS 以及 JavaScript 等。

    247 引用 • 1347 回帖 • 2 关注
  • 微信

    腾讯公司 2011 年 1 月 21 日推出的一款手机通讯软件。用户可以通过摇一摇、搜索号码、扫描二维码等添加好友和关注公众平台,同时可以将自己看到的精彩内容分享到微信朋友圈。

    130 引用 • 793 回帖
  • Vue.js

    Vue.js(读音 /vju ː/,类似于 view)是一个构建数据驱动的 Web 界面库。Vue.js 的目标是通过尽可能简单的 API 实现响应的数据绑定和组合的视图组件。

    262 引用 • 664 回帖
  • 智能合约

    智能合约(Smart contract)是一种旨在以信息化方式传播、验证或执行合同的计算机协议。智能合约允许在没有第三方的情况下进行可信交易,这些交易可追踪且不可逆转。智能合约概念于 1994 年由 Nick Szabo 首次提出。

    1 引用 • 11 回帖 • 7 关注
  • Mac

    Mac 是苹果公司自 1984 年起以“Macintosh”开始开发的个人消费型计算机,如:iMac、Mac mini、Macbook Air、Macbook Pro、Macbook、Mac Pro 等计算机。

    164 引用 • 594 回帖
  • 心情

    心是产生任何想法的源泉,心本体会陷入到对自己本体不能理解的状态中,因为心能产生任何想法,不能分出对错,不能分出自己。

    59 引用 • 369 回帖
  • OpenShift

    红帽提供的 PaaS 云,支持多种编程语言,为开发人员提供了更为灵活的框架、存储选择。

    14 引用 • 20 回帖 • 606 关注
  • ZeroNet

    ZeroNet 是一个基于比特币加密技术和 BT 网络技术的去中心化的、开放开源的网络和交流系统。

    1 引用 • 21 回帖 • 609 关注
  • Sillot

    Insights(注意当前设置 master 为默认分支)

    汐洛彖夲肜矩阵(Sillot T☳Converbenk Matrix),致力于服务智慧新彖乄,具有彖乄驱动、极致优雅、开发者友好的特点。其中汐洛绞架(Sillot-Gibbet)基于自思源笔记(siyuan-note),前身是思源笔记汐洛版(更早是思源笔记汐洛分支),是智慧新录乄终端(多端融合,移动端优先)。

    主仓库地址:Hi-Windom/Sillot

    文档地址:sillot.db.sc.cn

    注意事项:

    1. ⚠️ 汐洛仍在早期开发阶段,尚不稳定
    2. ⚠️ 汐洛并非面向普通用户设计,使用前请了解风险
    3. ⚠️ 汐洛绞架基于思源笔记,开发者尽最大努力与思源笔记保持兼容,但无法实现 100% 兼容
    29 引用 • 25 回帖 • 53 关注
  • Logseq

    Logseq 是一个隐私优先、开源的知识库工具。

    Logseq is a joyful, open-source outliner that works on top of local plain-text Markdown and Org-mode files. Use it to write, organize and share your thoughts, keep your to-do list, and build your own digital garden.

    5 引用 • 62 回帖
  • React

    React 是 Facebook 开源的一个用于构建 UI 的 JavaScript 库。

    192 引用 • 291 回帖 • 430 关注
  • 周末

    星期六到星期天晚,实行五天工作制后,指每周的最后两天。再过几年可能就是三天了。

    14 引用 • 297 回帖
  • 链滴

    链滴是一个记录生活的地方。

    记录生活,连接点滴

    143 引用 • 3752 回帖
  • 房星科技

    房星网,我们不和没有钱的程序员谈理想,我们要让程序员又有理想又有钱。我们有雄厚的房地产行业线下资源,遍布昆明全城的 100 家门店、四千地产经纪人是我们坚实的后盾。

    6 引用 • 141 回帖 • 566 关注
  • 友情链接

    确认过眼神后的灵魂连接,站在链在!

    24 引用 • 373 回帖 • 1 关注
  • HBase

    HBase 是一个分布式的、面向列的开源数据库,该技术来源于 Fay Chang 所撰写的 Google 论文 “Bigtable:一个结构化数据的分布式存储系统”。就像 Bigtable 利用了 Google 文件系统所提供的分布式数据存储一样,HBase 在 Hadoop 之上提供了类似于 Bigtable 的能力。

    17 引用 • 6 回帖 • 61 关注
  • SSL

    SSL(Secure Sockets Layer 安全套接层),及其继任者传输层安全(Transport Layer Security,TLS)是为网络通信提供安全及数据完整性的一种安全协议。TLS 与 SSL 在传输层对网络连接进行加密。

    69 引用 • 190 回帖 • 474 关注
  • 正则表达式

    正则表达式(Regular Expression)使用单个字符串来描述、匹配一系列遵循某个句法规则的字符串。

    31 引用 • 94 回帖 • 1 关注
  • 京东

    京东是中国最大的自营式电商企业,2015 年第一季度在中国自营式 B2C 电商市场的占有率为 56.3%。2014 年 5 月,京东在美国纳斯达克证券交易所正式挂牌上市(股票代码:JD),是中国第一个成功赴美上市的大型综合型电商平台,与腾讯、百度等中国互联网巨头共同跻身全球前十大互联网公司排行榜。

    14 引用 • 102 回帖 • 403 关注
  • SMTP

    SMTP(Simple Mail Transfer Protocol)即简单邮件传输协议,它是一组用于由源地址到目的地址传送邮件的规则,由它来控制信件的中转方式。SMTP 协议属于 TCP/IP 协议簇,它帮助每台计算机在发送或中转信件时找到下一个目的地。

    4 引用 • 18 回帖 • 609 关注
  • Mobi.css

    Mobi.css is a lightweight, flexible CSS framework that focus on mobile.

    1 引用 • 6 回帖 • 714 关注
  • 持续集成

    持续集成(Continuous Integration)是一种软件开发实践,即团队开发成员经常集成他们的工作,通过每个成员每天至少集成一次,也就意味着每天可能会发生多次集成。每次集成都通过自动化的构建(包括编译,发布,自动化测试)来验证,从而尽早地发现集成错误。

    14 引用 • 7 回帖 • 5 关注
  • RESTful

    一种软件架构设计风格而不是标准,提供了一组设计原则和约束条件,主要用于客户端和服务器交互类的软件。基于这个风格设计的软件可以更简洁,更有层次,更易于实现缓存等机制。

    30 引用 • 114 回帖 • 2 关注
  • 思源笔记

    思源笔记是一款隐私优先的个人知识管理系统,支持完全离线使用,同时也支持端到端加密同步。

    融合块、大纲和双向链接,重构你的思维。

    20156 引用 • 77717 回帖
  • Eclipse

    Eclipse 是一个开放源代码的、基于 Java 的可扩展开发平台。就其本身而言,它只是一个框架和一组服务,用于通过插件组件构建开发环境。

    75 引用 • 258 回帖 • 632 关注
  • AngularJS

    AngularJS 诞生于 2009 年,由 Misko Hevery 等人创建,后为 Google 所收购。是一款优秀的前端 JS 框架,已经被用于 Google 的多款产品当中。AngularJS 有着诸多特性,最为核心的是:MVC、模块化、自动化双向数据绑定、语义化标签、依赖注入等。2.0 版本后已经改名为 Angular。

    12 引用 • 50 回帖 • 441 关注
  • 黑曜石

    黑曜石是一款强大的知识库工具,支持本地 Markdown 文件编辑,支持双向链接和关系图。

    A second brain, for you, forever.

    10 引用 • 88 回帖