简介:使用scikit-learn进行数据挖掘

本贴最后更新于 3105 天前,其中的信息可能已经渤澥桑田

该简介翻译自 An introduction to machine learning with scikit-learn
选择翻译这篇简介的原因很简单:

  • scikit-learn 是非常优秀的 python 机器学习库
  • 该篇写得非常好,即使不使用 sklearn,也可以作为数据挖掘入门的短文。

以下是翻译的内容。


#本节内容

在本章节中,我们介绍一些在 sklearn 中会使用到的机器学习专业名词,并给出一些简单的例子。

机器学习:问题设定

通常来说,学习问题关注样本大小为 n 的数据集,并尝试预测未知的数据集。若每个样本不只是一个简单的数字,而是一个多维的条目,我们称之有多个属性或特征。

我们可以把学习问题划分为几个大的类别:

  • 有监督学习(supervised learning),在这种学习问题中,数据会附带我们要预测的属性。有监督学习可以进而分为以下两类:
    • 分类(classification):样本属于两个或多个分类,我们要从已经标记类别的数据中学习,并对未标记类别的数据进行预测。分类问题的一个典型例子是识别手写数字,该问题的目的识别每个输入向量对应的有限且离散的数字。换句话说分类问题是,离散形式(相对于连续)的有监督学习,提供的 n 个样本的类别是有限的,我们尝试为每个样本标记正确的分类。
    • 回归(regression):若输出的期望值是 1 个或多个连续变量,我们称该问题为回归。回归问题的一个典型例子是通过三文鱼的年龄和重量,预测其长度。
  • 无监督学习(unsupervised learning),在这种学习问题中,训练数据集是不包含任何目标值的输入向量 x。学习的目的有多种:
    • 聚类(clustering),发现数据中相似的样本分组。
    • 密度估计(density estimation),通过输入空间确定数据的分布。
    • 为了数据可视化或其他目的,将多维空间降低至 2 或 3 维

训练集和测试集
可粗略认为,机器学习就是从一个数据集中学习隐含的规则,并应用到新的数据集上。因此在机器学习实践中,为了评估算法,总是强制把数据集分为两个部分:训练集,用于学习隐含规则;测试集,用于测试规则。


#加载样例数据集
scikit-learn 自带了几个标准数据集,例如用于分类的 iris 和 digits 数据集,用于回归的 boston house prices 数据集。

接下来,我们使用 Python 交互式环境加载 iris 和 digits 数据集。
我们约定用 '$'表示 shell 类型,>>> 表示 python 交互环境。

$ python
>>> from sklearn import datasets
>>> iris = datasets.load_iris()
>>> digits = datasets.load_digits()

数据集是一个类字典对象,包括了全部的数据和该数据的元数据。数据保存在 .data 成员中,该成员是(n 个向量*m 个特征)的数组。在有监督学习中,类别变量存储在 .target 成员中。
例如,在 digits 数据集中,通过 digits.data 可以获取用于分类的向量。

>>> print(digits.data)  
[[  0.   0.   5. ...,   0.   0.   0.]
 [  0.   0.   0. ...,  10.   0.   0.]
 [  0.   0.   0. ...,  16.   9.   0.]
 ...,
 [  0.   0.   1. ...,   6.   0.   0.]
 [  0.   0.   2. ...,  12.   0.   0.]
 [  0.   0.  10. ...,  12.   1.   0.]]

digits.target 中存储了 digits 数据集中对应每个向量的类别,也是我们预测的目标。

>>> digits.target
array([0, 1, 2, ..., 8, 9, 8])

数据格式
数据集总是一个二维数组,格式为(n 个向量 * m 个特征),尽管原始数据可能是其他不同的格式。在 digits 数据集中,每个原始数据是用(8,8)表示的图像(在 digits.data 中被压缩到一行):

>>> digits.images[0]
array([[  0.,   0.,   5.,  13.,   9.,   1.,   0.,   0.],
       [  0.,   0.,  13.,  15.,  10.,  15.,   5.,   0.],
       [  0.,   3.,  15.,   2.,   0.,  11.,   8.,   0.],
       [  0.,   4.,  12.,   0.,   0.,   8.,   8.,   0.],
       [  0.,   5.,   8.,   0.,   0.,   9.,   8.,   0.],
       [  0.,   4.,  11.,   0.,   1.,  12.,   7.,   0.],
       [  0.,   2.,  14.,   5.,  10.,  12.,   0.,   0.],
       [  0.,   0.,   6.,  13.,  10.,   0.,   0.,   0.]])

#学习和预测
在 digits 数据集中,目标是预测给定的图像数据代表的数字。我们知道训练样本对应的分类(数字 0 到 9),训练对应的 estimator,用于预测未知分类的图像。

在 scikit-learn 中,用于分类的 estimator 是一个实现了 fit(X, y)predict(T) 的 Python 对象。

实现了支持向量分类的 sklearn.svm.SVC 类就是一个 estimator。estimator 的构造函数接受模型的参数。但暂时,我们把 estimator 当作一个黑盒:

>>> from sklearn import svm
>>> clf = svm.SVC(gamma=0.001, C=100.)

选择模型的参数
这上面的例子中,我们手动地设置 gamma 的值。通过使用类似于 grid search 或 cross validation 工具,可以自动地寻找适合的参数。

上面例子将我们的 estimator 实例命名为 clf,因为其是一个分类器(classifier)。现在,需要将其通过学习调整对应模型。这个过程通过将训练数据集传给 fit 方法来实现。我们用除了最后一个图像的 digits 数据集作为训练数据集,在 python 中可以方便地使用[:-1]来构造训练集:

>>> clf.fit(digits.data[:-1], digits.target[:-1])  
SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma=0.001, kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)  

现在我们可以用该模型对新数据进行预测,可以询问模型刚才没有使用的最后一个图像对应的数字:

>>> clf.predict(digits.data[-1:])
array([8])  

最后一个图像数据对应的图像如下:

digit imag

如你所见,这确实是一个具有挑战性的任务:图像的分辨率特别差。你同意分类器的判定吗?

这里给出一个完整的分类问题的例子:Recognizing hand-written digits,你可以执行这个代码,并进行学习。


#模型持久化
通过 Python 内建的序列化模块 pickle,可以将 sklearn 中的模型进行持久化。

>>> from sklearn import svm
>>> from sklearn import datasets
>>> clf = svm.SVC()
>>> iris = datasets.load_iris()
>>> X, y = iris.data, iris.target
>>> clf.fit(X, y)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

>>> import pickle
>>> s = pickle.dumps(clf)
>>> clf2 = pickle.loads(s)
>>> clf2.predict(X[0:1])
array([0])
>>> y[0]
0

特别的,在 sklearn 中,可以使用 joblib 替代 pickle (joblib.dump 和 joblib.load),joblib 在大数据上表现更加高效,但只能序列化到磁盘中,而非字符串。

>>> from sklearn.externals import joblib
>>> joblib.dump(clf, 'filename.pkl')   

然后,你可以重新读取并反序列化该模型(可能在另外的一个 python 程序中):

>>> clf = joblib.load('filename.pkl') 


joblib.dump 返回一个文件名列表。clf 对象中包含的每一个单独的 numpy 数组会被序列化为文件系统中的一个单独文件。当使用 joblib.load 读取模型时,文件夹下的每个文件都是必要的。

注意 pickle 有一些安全性和可维护性的问题。参考 Model persistence,获取更多有关 sklearn 中模型持久化的信息。


#惯例
scikit-learn 中的 estimator 遵循以下的规则,好让他们的行为更加可预测。

##类型转换
除非明确指明,否则输入将会被强制转换为 float64

>>> import numpy as np
>>> from sklearn import random_projection

>>> rng = np.random.RandomState(0)
>>> X = rng.rand(10, 2000)
>>> X = np.array(X, dtype='float32')
>>> X.dtype
dtype('float32')

>>> transformer = random_projection.GaussianRandomProjection()
>>> X_new = transformer.fit_transform(X)
>>> X_new.dtype
dtype('float64')

在上面例子中,X 的类型为 float32,通过 .fit_transform(X) 被转化为 float64

回归的结果被转化为 float32, 分类的结果保持不变:

>>> from sklearn import datasets
>>> from sklearn.svm import SVC
>>> iris = datasets.load_iris()
>>> clf = SVC()
>>> clf.fit(iris.data, iris.target)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

>>> list(clf.predict(iris.data[:3]))
[0, 0, 0]

>>> clf.fit(iris.data, iris.target_names[iris.target])  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

>>> list(clf.predict(iris.data[:3]))  
['setosa', 'setosa', 'setosa']  

在上面例子中,第一个 predict() 返回整数数组,因为用于训练的 iris.target 是整数数组。第二个 predict() 返回字符串数组,因为用于训练的 iris.target_names 是字符串数组。

##改变和升级参数
通过 sklearn.pipeline.Pipeline.set_params 方法 estimator 的超参数在构造后仍然可以修改。通过多次调用 fit() 方法可以覆盖之前的 fit()

>>> import numpy as np
>>> from sklearn.svm import SVC

>>> rng = np.random.RandomState(0)
>>> X = rng.rand(100, 10)
>>> y = rng.binomial(1, 0.5, 100)
>>> X_test = rng.rand(5, 10)

>>> clf = SVC()
>>> clf.set_params(kernel='linear').fit(X, y)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='linear',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)
>>> clf.predict(X_test)
array([1, 0, 1, 1, 0])

>>> clf.set_params(kernel='rbf').fit(X, y)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)
>>> clf.predict(X_test)
array([0, 0, 0, 1, 0])

在该例子中,SVC()构造函数中设定了的默认核函数为 rbf,但是随后被改为 linear 并训练模型,然后又重新修改为 rbf 并重新训练模型。

  • 数据挖掘
    17 引用 • 32 回帖 • 3 关注
  • Python

    Python 是一种面向对象、直译式电脑编程语言,具有近二十年的发展历史,成熟且稳定。它包含了一组完善而且容易理解的标准库,能够轻松完成很多常见的任务。它的语法简捷和清晰,尽量使用无异义的英语单词,与其它大多数程序设计语言使用大括号不一样,它使用缩进来定义语句块。

    541 引用 • 672 回帖 • 1 关注

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...
  • R

    👍 楼主翻译的?

  • Zing
    作者

    @R 嗯嗯 是的 我翻译的 水平有限

  • R

    @Zing 挺好的 @88250 编辑记录功能调整下,都可以弄个文档翻译区了

  • wizardforcel

    什么都好。。就是官方的教程太少了。。

推荐标签 标签

  • 单点登录

    单点登录(Single Sign On)是目前比较流行的企业业务整合的解决方案之一。SSO 的定义是在多个应用系统中,用户只需要登录一次就可以访问所有相互信任的应用系统。

    9 引用 • 25 回帖
  • SOHO

    为成为自由职业者在家办公而努力吧!

    7 引用 • 55 回帖 • 18 关注
  • Elasticsearch

    Elasticsearch 是一个基于 Lucene 的搜索服务器。它提供了一个分布式多用户能力的全文搜索引擎,基于 RESTful 接口。Elasticsearch 是用 Java 开发的,并作为 Apache 许可条款下的开放源码发布,是当前流行的企业级搜索引擎。设计用于云计算中,能够达到实时搜索,稳定,可靠,快速,安装使用方便。

    117 引用 • 99 回帖 • 223 关注
  • Linux

    Linux 是一套免费使用和自由传播的类 Unix 操作系统,是一个基于 POSIX 和 Unix 的多用户、多任务、支持多线程和多 CPU 的操作系统。它能运行主要的 Unix 工具软件、应用程序和网络协议,并支持 32 位和 64 位硬件。Linux 继承了 Unix 以网络为核心的设计思想,是一个性能稳定的多用户网络操作系统。

    939 引用 • 940 回帖
  • PostgreSQL

    PostgreSQL 是一款功能强大的企业级数据库系统,在 BSD 开源许可证下发布。

    22 引用 • 22 回帖 • 1 关注
  • 酷鸟浏览器

    安全 · 稳定 · 快速
    为跨境从业人员提供专业的跨境浏览器

    3 引用 • 59 回帖 • 31 关注
  • golang

    Go 语言是 Google 推出的一种全新的编程语言,可以在不损失应用程序性能的情况下降低代码的复杂性。谷歌首席软件工程师罗布派克(Rob Pike)说:我们之所以开发 Go,是因为过去 10 多年间软件开发的难度令人沮丧。Go 是谷歌 2009 发布的第二款编程语言。

    497 引用 • 1387 回帖 • 294 关注
  • SQLServer

    SQL Server 是由 [微软] 开发和推广的关系数据库管理系统(DBMS),它最初是由 微软、Sybase 和 Ashton-Tate 三家公司共同开发的,并于 1988 年推出了第一个 OS/2 版本。

    19 引用 • 31 回帖
  • Git

    Git 是 Linux Torvalds 为了帮助管理 Linux 内核开发而开发的一个开放源码的版本控制软件。

    209 引用 • 358 回帖
  • 小说

    小说是以刻画人物形象为中心,通过完整的故事情节和环境描写来反映社会生活的文学体裁。

    28 引用 • 108 回帖
  • WiFiDog

    WiFiDog 是一套开源的无线热点认证管理工具,主要功能包括:位置相关的内容递送;用户认证和授权;集中式网络监控。

    1 引用 • 7 回帖 • 586 关注
  • 宕机

    宕机,多指一些网站、游戏、网络应用等服务器一种区别于正常运行的状态,也叫“Down 机”、“当机”或“死机”。宕机状态不仅仅是指服务器“挂掉了”、“死机了”状态,也包括服务器假死、停用、关闭等一些原因而导致出现的不能够正常运行的状态。

    13 引用 • 82 回帖 • 53 关注
  • Mobi.css

    Mobi.css is a lightweight, flexible CSS framework that focus on mobile.

    1 引用 • 6 回帖 • 733 关注
  • 服务器

    服务器,也称伺服器,是提供计算服务的设备。由于服务器需要响应服务请求,并进行处理,因此一般来说服务器应具备承担服务并且保障服务的能力。

    124 引用 • 580 回帖
  • 旅游

    希望你我能在旅途中找到人生的下一站。

    90 引用 • 899 回帖
  • OkHttp

    OkHttp 是一款 HTTP & HTTP/2 客户端库,专为 Android 和 Java 应用打造。

    16 引用 • 6 回帖 • 60 关注
  • BAE

    百度应用引擎(Baidu App Engine)提供了 PHP、Java、Python 的执行环境,以及云存储、消息服务、云数据库等全面的云服务。它可以让开发者实现自动地部署和管理应用,并且提供动态扩容和负载均衡的运行环境,让开发者不用考虑高成本的运维工作,只需专注于业务逻辑,大大降低了开发者学习和迁移的成本。

    19 引用 • 75 回帖 • 632 关注
  • GitBook

    GitBook 使您的团队可以轻松编写和维护高质量的文档。 分享知识,提高团队的工作效率,让用户满意。

    3 引用 • 8 回帖 • 2 关注
  • 代码片段

    代码片段分为 CSS 与 JS 两种代码,添加在 [设置 - 外观 - 代码片段] 中,这些代码会在思源笔记加载时自动执行,用于改善笔记的样式或功能。

    用户在该标签下分享代码片段时需在帖子标题前添加 [css] [js] 用于区分代码片段类型。

    54 引用 • 292 回帖
  • 思源笔记

    思源笔记是一款隐私优先的个人知识管理系统,支持完全离线使用,同时也支持端到端加密同步。

    融合块、大纲和双向链接,重构你的思维。

    22019 引用 • 87804 回帖 • 2 关注
  • JVM

    JVM(Java Virtual Machine)Java 虚拟机是一个微型操作系统,有自己的硬件构架体系,还有相应的指令系统。能够识别 Java 独特的 .class 文件(字节码),能够将这些文件中的信息读取出来,使得 Java 程序只需要生成 Java 虚拟机上的字节码后就能在不同操作系统平台上进行运行。

    180 引用 • 120 回帖 • 1 关注
  • 创业

    你比 99% 的人都优秀么?

    84 引用 • 1399 回帖 • 1 关注
  • Flume

    Flume 是一套分布式的、可靠的,可用于有效地收集、聚合和搬运大量日志数据的服务架构。

    9 引用 • 6 回帖 • 621 关注
  • RabbitMQ

    RabbitMQ 是一个开源的 AMQP 实现,服务器端用 Erlang 语言编写,支持多种语言客户端,如:Python、Ruby、.NET、Java、C、PHP、ActionScript 等。用于在分布式系统中存储转发消息,在易用性、扩展性、高可用性等方面表现不俗。

    49 引用 • 60 回帖 • 366 关注
  • OnlyOffice
    4 引用 • 2 关注
  • Netty

    Netty 是一个基于 NIO 的客户端-服务器编程框架,使用 Netty 可以让你快速、简单地开发出一个可维护、高性能的网络应用,例如实现了某种协议的客户、服务端应用。

    49 引用 • 33 回帖 • 19 关注
  • 反馈

    Communication channel for makers and users.

    123 引用 • 911 回帖 • 237 关注