简介:使用scikit-learn进行数据挖掘

本贴最后更新于 3341 天前,其中的信息可能已经渤澥桑田

该简介翻译自 An introduction to machine learning with scikit-learn
选择翻译这篇简介的原因很简单:

  • scikit-learn 是非常优秀的 python 机器学习库
  • 该篇写得非常好,即使不使用 sklearn,也可以作为数据挖掘入门的短文。

以下是翻译的内容。


#本节内容

在本章节中,我们介绍一些在 sklearn 中会使用到的机器学习专业名词,并给出一些简单的例子。

机器学习:问题设定

通常来说,学习问题关注样本大小为 n 的数据集,并尝试预测未知的数据集。若每个样本不只是一个简单的数字,而是一个多维的条目,我们称之有多个属性或特征。

我们可以把学习问题划分为几个大的类别:

  • 有监督学习(supervised learning),在这种学习问题中,数据会附带我们要预测的属性。有监督学习可以进而分为以下两类:
    • 分类(classification):样本属于两个或多个分类,我们要从已经标记类别的数据中学习,并对未标记类别的数据进行预测。分类问题的一个典型例子是识别手写数字,该问题的目的识别每个输入向量对应的有限且离散的数字。换句话说分类问题是,离散形式(相对于连续)的有监督学习,提供的 n 个样本的类别是有限的,我们尝试为每个样本标记正确的分类。
    • 回归(regression):若输出的期望值是 1 个或多个连续变量,我们称该问题为回归。回归问题的一个典型例子是通过三文鱼的年龄和重量,预测其长度。
  • 无监督学习(unsupervised learning),在这种学习问题中,训练数据集是不包含任何目标值的输入向量 x。学习的目的有多种:
    • 聚类(clustering),发现数据中相似的样本分组。
    • 密度估计(density estimation),通过输入空间确定数据的分布。
    • 为了数据可视化或其他目的,将多维空间降低至 2 或 3 维

训练集和测试集
可粗略认为,机器学习就是从一个数据集中学习隐含的规则,并应用到新的数据集上。因此在机器学习实践中,为了评估算法,总是强制把数据集分为两个部分:训练集,用于学习隐含规则;测试集,用于测试规则。


#加载样例数据集
scikit-learn 自带了几个标准数据集,例如用于分类的 iris 和 digits 数据集,用于回归的 boston house prices 数据集。

接下来,我们使用 Python 交互式环境加载 iris 和 digits 数据集。
我们约定用 '$'表示 shell 类型,>>> 表示 python 交互环境。

$ python >>> from sklearn import datasets >>> iris = datasets.load_iris() >>> digits = datasets.load_digits()

数据集是一个类字典对象,包括了全部的数据和该数据的元数据。数据保存在 .data 成员中,该成员是(n 个向量*m 个特征)的数组。在有监督学习中,类别变量存储在 .target 成员中。
例如,在 digits 数据集中,通过 digits.data 可以获取用于分类的向量。

>>> print(digits.data) [[ 0. 0. 5. ..., 0. 0. 0.] [ 0. 0. 0. ..., 10. 0. 0.] [ 0. 0. 0. ..., 16. 9. 0.] ..., [ 0. 0. 1. ..., 6. 0. 0.] [ 0. 0. 2. ..., 12. 0. 0.] [ 0. 0. 10. ..., 12. 1. 0.]]

digits.target 中存储了 digits 数据集中对应每个向量的类别,也是我们预测的目标。

>>> digits.target array([0, 1, 2, ..., 8, 9, 8])

数据格式
数据集总是一个二维数组,格式为(n 个向量 * m 个特征),尽管原始数据可能是其他不同的格式。在 digits 数据集中,每个原始数据是用(8,8)表示的图像(在 digits.data 中被压缩到一行):

>>> digits.images[0] array([[ 0., 0., 5., 13., 9., 1., 0., 0.], [ 0., 0., 13., 15., 10., 15., 5., 0.], [ 0., 3., 15., 2., 0., 11., 8., 0.], [ 0., 4., 12., 0., 0., 8., 8., 0.], [ 0., 5., 8., 0., 0., 9., 8., 0.], [ 0., 4., 11., 0., 1., 12., 7., 0.], [ 0., 2., 14., 5., 10., 12., 0., 0.], [ 0., 0., 6., 13., 10., 0., 0., 0.]])

#学习和预测
在 digits 数据集中,目标是预测给定的图像数据代表的数字。我们知道训练样本对应的分类(数字 0 到 9),训练对应的 estimator,用于预测未知分类的图像。

在 scikit-learn 中,用于分类的 estimator 是一个实现了 fit(X, y)predict(T) 的 Python 对象。

实现了支持向量分类的 sklearn.svm.SVC 类就是一个 estimator。estimator 的构造函数接受模型的参数。但暂时,我们把 estimator 当作一个黑盒:

>>> from sklearn import svm >>> clf = svm.SVC(gamma=0.001, C=100.)

选择模型的参数
这上面的例子中,我们手动地设置 gamma 的值。通过使用类似于 grid search 或 cross validation 工具,可以自动地寻找适合的参数。

上面例子将我们的 estimator 实例命名为 clf,因为其是一个分类器(classifier)。现在,需要将其通过学习调整对应模型。这个过程通过将训练数据集传给 fit 方法来实现。我们用除了最后一个图像的 digits 数据集作为训练数据集,在 python 中可以方便地使用[:-1]来构造训练集:

>>> clf.fit(digits.data[:-1], digits.target[:-1]) SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0, decision_function_shape=None, degree=3, gamma=0.001, kernel='rbf', max_iter=-1, probability=False, random_state=None, shrinking=True, tol=0.001, verbose=False)

现在我们可以用该模型对新数据进行预测,可以询问模型刚才没有使用的最后一个图像对应的数字:

>>> clf.predict(digits.data[-1:]) array([8])

最后一个图像数据对应的图像如下:

digit imag

如你所见,这确实是一个具有挑战性的任务:图像的分辨率特别差。你同意分类器的判定吗?

这里给出一个完整的分类问题的例子:Recognizing hand-written digits,你可以执行这个代码,并进行学习。


#模型持久化
通过 Python 内建的序列化模块 pickle,可以将 sklearn 中的模型进行持久化。

>>> from sklearn import svm >>> from sklearn import datasets >>> clf = svm.SVC() >>> iris = datasets.load_iris() >>> X, y = iris.data, iris.target >>> clf.fit(X, y) SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, decision_function_shape=None, degree=3, gamma='auto', kernel='rbf', max_iter=-1, probability=False, random_state=None, shrinking=True, tol=0.001, verbose=False) >>> import pickle >>> s = pickle.dumps(clf) >>> clf2 = pickle.loads(s) >>> clf2.predict(X[0:1]) array([0]) >>> y[0] 0

特别的,在 sklearn 中,可以使用 joblib 替代 pickle (joblib.dump 和 joblib.load),joblib 在大数据上表现更加高效,但只能序列化到磁盘中,而非字符串。

>>> from sklearn.externals import joblib >>> joblib.dump(clf, 'filename.pkl')

然后,你可以重新读取并反序列化该模型(可能在另外的一个 python 程序中):

>>> clf = joblib.load('filename.pkl')


joblib.dump 返回一个文件名列表。clf 对象中包含的每一个单独的 numpy 数组会被序列化为文件系统中的一个单独文件。当使用 joblib.load 读取模型时,文件夹下的每个文件都是必要的。

注意 pickle 有一些安全性和可维护性的问题。参考 Model persistence,获取更多有关 sklearn 中模型持久化的信息。


#惯例
scikit-learn 中的 estimator 遵循以下的规则,好让他们的行为更加可预测。

##类型转换
除非明确指明,否则输入将会被强制转换为 float64

>>> import numpy as np >>> from sklearn import random_projection >>> rng = np.random.RandomState(0) >>> X = rng.rand(10, 2000) >>> X = np.array(X, dtype='float32') >>> X.dtype dtype('float32') >>> transformer = random_projection.GaussianRandomProjection() >>> X_new = transformer.fit_transform(X) >>> X_new.dtype dtype('float64')

在上面例子中,X 的类型为 float32,通过 .fit_transform(X) 被转化为 float64

回归的结果被转化为 float32, 分类的结果保持不变:

>>> from sklearn import datasets >>> from sklearn.svm import SVC >>> iris = datasets.load_iris() >>> clf = SVC() >>> clf.fit(iris.data, iris.target) SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, decision_function_shape=None, degree=3, gamma='auto', kernel='rbf', max_iter=-1, probability=False, random_state=None, shrinking=True, tol=0.001, verbose=False) >>> list(clf.predict(iris.data[:3])) [0, 0, 0] >>> clf.fit(iris.data, iris.target_names[iris.target]) SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, decision_function_shape=None, degree=3, gamma='auto', kernel='rbf', max_iter=-1, probability=False, random_state=None, shrinking=True, tol=0.001, verbose=False) >>> list(clf.predict(iris.data[:3])) ['setosa', 'setosa', 'setosa']

在上面例子中,第一个 predict() 返回整数数组,因为用于训练的 iris.target 是整数数组。第二个 predict() 返回字符串数组,因为用于训练的 iris.target_names 是字符串数组。

##改变和升级参数
通过 sklearn.pipeline.Pipeline.set_params 方法 estimator 的超参数在构造后仍然可以修改。通过多次调用 fit() 方法可以覆盖之前的 fit()

>>> import numpy as np >>> from sklearn.svm import SVC >>> rng = np.random.RandomState(0) >>> X = rng.rand(100, 10) >>> y = rng.binomial(1, 0.5, 100) >>> X_test = rng.rand(5, 10) >>> clf = SVC() >>> clf.set_params(kernel='linear').fit(X, y) SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, decision_function_shape=None, degree=3, gamma='auto', kernel='linear', max_iter=-1, probability=False, random_state=None, shrinking=True, tol=0.001, verbose=False) >>> clf.predict(X_test) array([1, 0, 1, 1, 0]) >>> clf.set_params(kernel='rbf').fit(X, y) SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, decision_function_shape=None, degree=3, gamma='auto', kernel='rbf', max_iter=-1, probability=False, random_state=None, shrinking=True, tol=0.001, verbose=False) >>> clf.predict(X_test) array([0, 0, 0, 1, 0])

在该例子中,SVC()构造函数中设定了的默认核函数为 rbf,但是随后被改为 linear 并训练模型,然后又重新修改为 rbf 并重新训练模型。

  • 数据挖掘
    17 引用 • 32 回帖 • 3 关注
  • Python

    Python 是一种面向对象、直译式电脑编程语言,具有近二十年的发展历史,成熟且稳定。它包含了一组完善而且容易理解的标准库,能够轻松完成很多常见的任务。它的语法简捷和清晰,尽量使用无异义的英语单词,与其它大多数程序设计语言使用大括号不一样,它使用缩进来定义语句块。

    554 引用 • 675 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...
  • R

    👍 楼主翻译的?

  • Zing
    作者

    @R 嗯嗯 是的 我翻译的 水平有限

  • R via Android

    @Zing 挺好的 @88250 编辑记录功能调整下,都可以弄个文档翻译区了

  • wizardforcel

    什么都好。。就是官方的教程太少了。。

推荐标签 标签

  • 链书

    链书(Chainbook)是 B3log 开源社区提供的区块链纸质书交易平台,通过 B3T 实现共享激励与价值链。可将你的闲置书籍上架到链书,我们共同构建这个全新的交易平台,让闲置书籍继续发挥它的价值。

    链书社

    链书目前已经下线,也许以后还有计划重制上线。

    14 引用 • 257 回帖 • 2 关注
  • 外包

    有空闲时间是接外包好呢还是学习好呢?

    26 引用 • 233 回帖 • 3 关注
  • 印象笔记
    3 引用 • 16 回帖
  • Vim

    Vim 是类 UNIX 系统文本编辑器 Vi 的加强版本,加入了更多特性来帮助编辑源代码。Vim 的部分增强功能包括文件比较(vimdiff)、语法高亮、全面的帮助系统、本地脚本(Vimscript)和便于选择的可视化模式。

    29 引用 • 66 回帖
  • AngularJS

    AngularJS 诞生于 2009 年,由 Misko Hevery 等人创建,后为 Google 所收购。是一款优秀的前端 JS 框架,已经被用于 Google 的多款产品当中。AngularJS 有着诸多特性,最为核心的是:MVC、模块化、自动化双向数据绑定、语义化标签、依赖注入等。2.0 版本后已经改名为 Angular。

    12 引用 • 50 回帖 • 518 关注
  • GitLab

    GitLab 是利用 Ruby 一个开源的版本管理系统,实现一个自托管的 Git 项目仓库,可通过 Web 界面操作公开或私有项目。

    46 引用 • 72 回帖 • 2 关注
  • WiFiDog

    WiFiDog 是一套开源的无线热点认证管理工具,主要功能包括:位置相关的内容递送;用户认证和授权;集中式网络监控。

    1 引用 • 7 回帖 • 614 关注
  • 面试

    面试造航母,上班拧螺丝。多面试,少加班。

    326 引用 • 1395 回帖
  • Log4j

    Log4j 是 Apache 开源的一款使用广泛的 Java 日志组件。

    20 引用 • 18 回帖 • 34 关注
  • PHP

    PHP(Hypertext Preprocessor)是一种开源脚本语言。语法吸收了 C 语言、 Java 和 Perl 的特点,主要适用于 Web 开发领域,据说是世界上最好的编程语言。

    167 引用 • 408 回帖 • 489 关注
  • BookxNote

    BookxNote 是一款全新的电子书学习工具,助力您的学习与思考,让您的大脑更高效的记忆。

    笔记整理交给我,一心只读圣贤书。

    1 引用 • 1 回帖 • 1 关注
  • Ngui

    Ngui 是一个 GUI 的排版显示引擎和跨平台的 GUI 应用程序开发框架,基于
    Node.js / OpenGL。目标是在此基础上开发 GUI 应用程序可拥有开发 WEB 应用般简单与速度同时兼顾 Native 应用程序的性能与体验。

    7 引用 • 9 回帖 • 405 关注
  • Caddy

    Caddy 是一款默认自动启用 HTTPS 的 HTTP/2 Web 服务器。

    10 引用 • 54 回帖 • 181 关注
  • Mac

    Mac 是苹果公司自 1984 年起以“Macintosh”开始开发的个人消费型计算机,如:iMac、Mac mini、Macbook Air、Macbook Pro、Macbook、Mac Pro 等计算机。

    167 引用 • 597 回帖 • 1 关注
  • Sublime

    Sublime Text 是一款可以用来写代码、写文章的文本编辑器。支持代码高亮、自动完成,还支持通过插件进行扩展。

    10 引用 • 5 回帖 • 3 关注
  • 一些有用的避坑指南。

    69 引用 • 93 回帖
  • Ant-Design

    Ant Design 是服务于企业级产品的设计体系,基于确定和自然的设计价值观上的模块化解决方案,让设计者和开发者专注于更好的用户体验。

    17 引用 • 23 回帖 • 1 关注
  • 倾城之链
    23 引用 • 66 回帖 • 166 关注
  • TGIF

    Thank God It's Friday! 感谢老天,总算到星期五啦!

    291 引用 • 4495 回帖 • 662 关注
  • 互联网

    互联网(Internet),又称网际网络,或音译因特网、英特网。互联网始于 1969 年美国的阿帕网,是网络与网络之间所串连成的庞大网络,这些网络以一组通用的协议相连,形成逻辑上的单一巨大国际网络。

    98 引用 • 367 回帖
  • NetBeans

    NetBeans 是一个始于 1997 年的 Xelfi 计划,本身是捷克布拉格查理大学的数学及物理学院的学生计划。此计划延伸而成立了一家公司进而发展这个商用版本的 NetBeans IDE,直到 1999 年 Sun 买下此公司。Sun 于次年(2000 年)六月将 NetBeans IDE 开源,直到现在 NetBeans 的社群依然持续增长。

    78 引用 • 102 回帖 • 707 关注
  • Facebook

    Facebook 是一个联系朋友的社交工具。大家可以通过它和朋友、同事、同学以及周围的人保持互动交流,分享无限上传的图片,发布链接和视频,更可以增进对朋友的了解。

    4 引用 • 15 回帖 • 445 关注
  • JVM

    JVM(Java Virtual Machine)Java 虚拟机是一个微型操作系统,有自己的硬件构架体系,还有相应的指令系统。能够识别 Java 独特的 .class 文件(字节码),能够将这些文件中的信息读取出来,使得 Java 程序只需要生成 Java 虚拟机上的字节码后就能在不同操作系统平台上进行运行。

    180 引用 • 120 回帖 • 4 关注
  • BAE

    百度应用引擎(Baidu App Engine)提供了 PHP、Java、Python 的执行环境,以及云存储、消息服务、云数据库等全面的云服务。它可以让开发者实现自动地部署和管理应用,并且提供动态扩容和负载均衡的运行环境,让开发者不用考虑高成本的运维工作,只需专注于业务逻辑,大大降低了开发者学习和迁移的成本。

    19 引用 • 75 回帖 • 678 关注
  • 运维

    互联网运维工作,以服务为中心,以稳定、安全、高效为三个基本点,确保公司的互联网业务能够 7×24 小时为用户提供高质量的服务。

    151 引用 • 257 回帖 • 2 关注
  • 周末

    星期六到星期天晚,实行五天工作制后,指每周的最后两天。再过几年可能就是三天了。

    14 引用 • 297 回帖
  • 京东

    京东是中国最大的自营式电商企业,2015 年第一季度在中国自营式 B2C 电商市场的占有率为 56.3%。2014 年 5 月,京东在美国纳斯达克证券交易所正式挂牌上市(股票代码:JD),是中国第一个成功赴美上市的大型综合型电商平台,与腾讯、百度等中国互联网巨头共同跻身全球前十大互联网公司排行榜。

    14 引用 • 102 回帖 • 312 关注