机器学习 (2)——线性回归

本贴最后更新于 2235 天前,其中的信息可能已经渤澥桑田

0x00 前言

第一篇记录了机器学习的定义、分类和范围,这一篇开始从机器学习的方法学起,总结机器学习的经典方法,然后尽量自己写代码实现。

0x01 回归算法

回归算法属于机器学习中监督学习类的算法,是入门机器学习最基础的算法。

回归分析是研究自变量和因变量之间关系的一种预测模型技术。这些技术应用于预测,时间序列模型和找到变量之间关系。

回归算法就是量化因变量受自变量影响的大小,建立线性回归方程或者非线性回归方程,从而达对因变量的预测,或者对因变量的解释作用。

0x02 回归分析流程

① 探索性分析,画不同变量之间的散点图,进行相关性检验等,了解数据的大致情况,以及得知重点关注那几个变量;

② 变量和模型选择;

③ 回归分析假设条件验证;

④ 共线性和强影响点检查;

⑤ 模型修改,并且重复 ③④;

⑥ 模型验证。

0x03 回归算法分类

回归算法主要通过三种方法分类:自变量的个数、因变量的类型和回归线的形状。

常见的回归算法有:

  • 线性回归
  • 逻辑回归
  • 多项式回归
  • 逐步回归
  • 岭回归
  • Lasso 回归
  • ElasticNet 回归

0x04 线性回归(Linear Regression)

线性回归是世界上最知名的建模方法之一,在线性回归模型中,因变量是连续型的,自变量可以使连续型或离散型的,回归线是线性的。

线性回归用最适直线(回归线)去建立因变量 Y 和一个或多个自变量 X 之间的关系。可以用公式来表示:

Y=A+B*X+e

A 为截距,B 为回归线的斜率,e 是误差项。

简单线性回归与多元线性回归的差别在于:多元线性回归有多个(>1)自变量,而简单线性回归只有一个自变量。

简单线性回归

我们首先实现一个只有单一自变量的简单线性回归

我们实现这个算法,可以先以 Andrew Ng 机器学习讲义中美国俄亥俄州 Portland Oregon 城市房屋价格为例:

这个例子中近简化使用房屋面积一个因子作为自变量,y 轴对应其因变量房屋价格。所以我们机器学习的线性回归就变为对于给定有限的数据集,进行一元线性回归,即找到一个一次函数 y=y(x) + e,使得 y 满足

当 x={2104, 1600, 2400, 1416, 3000, … }, y={400, 330, 369, 232, 540, … }
面积(feet²) 价格(1000$)
2104 400
1600 330
2400 369
1416 232
3000 540
··· ···
对这个问题我们先给出假设函数即需要拟合的直线:

其中 a 和 b 是我们要求得的参数,参数得变化会引起函数的变化。

而我们解出参数之后的函数是否为最优解,我们需要引入一个概念:Cost Function,即代价函数或成本函数。

代价函数(Cost Function)

在回归问题中,衡量最优解的常用代价函数为平方误差。

平方误差在高中和大学的概率论、统计学等课程中我们都有所了解,就是用样本数据和拟合出的线做差值,然后对差值进行平方和并除以点数 m 计算平均值。

而在这里,我们要导出代价函数,额外除以 1/2 做数学简化,形成以下代价函数:

ps:这里额外除以 1/2,是为了之后平方函数的微分项将抵消 1/2 项,以方便计算梯度下降。

下来我们求解最优解的问题就转变为了求解代价函数的最小值。

其中 J 是基于 θ 的函数,我们可以先将其简化成只有 θ1 的函数,令 θ0=0.

然后我们不断给定 θ1 的值,基于样本值进行计算代价函数 J,就可以得到一个 θ1 和 J 的函数,并在某一点取得极小值。

如样本数据为 y ={(1,1), (2,2),(3,3)}时,可以得到如下的 J-θ1 图形:

我们求解线性回归最优解的方法一般是梯度下降法和最小二乘法

最小二乘法

代价函数中使用的均方误差,其实对应了我们常用的欧几里得的距离(欧式距离,Euclidean Distance), 基于均方误差最小化进行模型求解的方法称为“最小二乘法”(least square method),即通过最小化误差的平方和寻找数据的最佳函数匹配。

当函数子变量为一维时,最小二乘法就蜕变成寻找一条直线。

如我们上例中的模型,寻找 J 极小值就是分别用 J 对 θ1 和 θ0 求偏导,然后寻找偏导为零的点。

解得:

局限性

最小二乘法算法简单,容易理解,而然在现实机器学习却有其局限性:

  • 并非所有函数都可以求出驻点,即导数为 0 的点,f(x)=0
  • 求解方程困难,或求根公式复杂
  • 导数并无解析解
  • 最小二乘法的矩阵公式,计算一个矩阵的逆是相当耗费时间的, 而且求逆也会存在数值不稳定的情况

梯度下降法

正是由于在实际中,最小二乘法遇到的困难和局限性,尤其是多数超定方程组不存在解,我们由求导转向迭代逼近。也就是梯度下降算法。

首先我们了解一下什么是梯度,这在复变函数等大学课程中都曾经学过。

方向导数

方向导数即研究在某一点的任意方向的变化率,是偏导数的广义扩展。

梯度

梯度则基于方向导数,是一个向量而非数,梯度代表了各个方向导数中,变化趋势最大的那个方向。

那么,梯度方向就是增长最快的方向,负梯度方向就是减小最快的方向。

梯度下降算法

梯度下降算法通常也被称作最速下降法。其目的是找到一个局部极小值点;其目标与最小二乘法相同,都是使得估算值与实际值的总平方差尽量小。

其方法是采用计算数学的迭代法,先给定一初始点,然后向下降最快的方向调整,在若干次迭代之后找到局部最小。

比如我们给定上面的方程,初始参数是 θ0,θ1,我们不断改变 θ0,θ1 从而减少 J(θ0,θ1)的值,具体做法是求导。直到最终收敛。

迭代公式如下:

其中 θj 可以是 θ0 和 θ1 这两个参数,α 为步长,整个式子的意义为,θ0,θ1 每次向 J(θ0,θ1)负梯度方向下降步长 α。

学习率

公式中的步长 α,也称为学习率,用来控制每次下降的幅度。

我们应该调整参数 α 以确保梯度下降算法在合理的时间内收敛。

  • 如果 α 过小,每步会移动非常近,收敛时间就会很长。
  • 如果 α 过大,每步会移动比较远,会导致直接越过极小值,甚至无法收敛到最低点。

如果我们时间耗费较长或无法收敛,那就说明我们要重新制定学习率 α。

线性回归梯度下降

对于线性模型,我们可以这样写梯度下降函数。

h(x)是需要拟合的函数。

J(θ)称为均方误差或 cost function。用来衡量训练集众的样本对线性模式的拟合程度。

m 为训练集众样本的个数。

θ 是我们最终需要通过梯度下降法来求得的参数。

接下来的梯度下降法就有两种不同的迭代思路。

批量梯度下降(Batch Gradient Descent)

可以看到上述每次迭代都需要计算所有样本的残差并加和,批量梯度下降是梯度下降法最原始的形式,它的具体思路是在更新每一参数时都使用所有的样本来进行更新。

1.计算 J(θ)关于 θT 的偏导数,也就得到了向量中每一个 θ 的梯度。

2.沿着梯度的反方向更新参数 θ 的值

3.迭代直到收敛。

优点:全局最优解,易于并行实现。
缺点:当样本数目很多时,训练过程会很慢。

随机梯度下降(Stochastic gradient descent)

和批量梯度有所不同的地方在于,每次迭代只选取一个样本的数据,一旦到达最大的迭代次数或是满足预期的精度,就停止。

随机梯度下降法的 θ 更新表达式。

迭代直到收敛。

优点:训练速度快。
缺点:准确度下降,并不是全局最优,不易于并行实现。

视觉效果

当我们的成本函数处于图的坑底时,J 值最小,为最佳解。

  • 机器学习

    机器学习(Machine Learning)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。

    83 引用 • 37 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • 机器学习

    机器学习(Machine Learning)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。

    83 引用 • 37 回帖
  • 运维

    互联网运维工作,以服务为中心,以稳定、安全、高效为三个基本点,确保公司的互联网业务能够 7×24 小时为用户提供高质量的服务。

    149 引用 • 257 回帖
  • 持续集成

    持续集成(Continuous Integration)是一种软件开发实践,即团队开发成员经常集成他们的工作,通过每个成员每天至少集成一次,也就意味着每天可能会发生多次集成。每次集成都通过自动化的构建(包括编译,发布,自动化测试)来验证,从而尽早地发现集成错误。

    15 引用 • 7 回帖 • 1 关注
  • 开源中国

    开源中国是目前中国最大的开源技术社区。传播开源的理念,推广开源项目,为 IT 开发者提供了一个发现、使用、并交流开源技术的平台。目前开源中国社区已收录超过两万款开源软件。

    7 引用 • 86 回帖
  • 工具

    子曰:“工欲善其事,必先利其器。”

    286 引用 • 729 回帖
  • Java

    Java 是一种可以撰写跨平台应用软件的面向对象的程序设计语言,是由 Sun Microsystems 公司于 1995 年 5 月推出的。Java 技术具有卓越的通用性、高效性、平台移植性和安全性。

    3187 引用 • 8213 回帖
  • ActiveMQ

    ActiveMQ 是 Apache 旗下的一款开源消息总线系统,它完整实现了 JMS 规范,是一个企业级的消息中间件。

    19 引用 • 13 回帖 • 672 关注
  • 博客

    记录并分享人生的经历。

    273 引用 • 2388 回帖
  • Logseq

    Logseq 是一个隐私优先、开源的知识库工具。

    Logseq is a joyful, open-source outliner that works on top of local plain-text Markdown and Org-mode files. Use it to write, organize and share your thoughts, keep your to-do list, and build your own digital garden.

    6 引用 • 63 回帖 • 1 关注
  • 微软

    微软是一家美国跨国科技公司,也是世界 PC 软件开发的先导,由比尔·盖茨与保罗·艾伦创办于 1975 年,公司总部设立在华盛顿州的雷德蒙德(Redmond,邻近西雅图)。以研发、制造、授权和提供广泛的电脑软件服务业务为主。

    8 引用 • 44 回帖
  • MongoDB

    MongoDB(来自于英文单词“Humongous”,中文含义为“庞大”)是一个基于分布式文件存储的数据库,由 C++ 语言编写。旨在为应用提供可扩展的高性能数据存储解决方案。MongoDB 是一个介于关系数据库和非关系数据库之间的产品,是非关系数据库当中功能最丰富,最像关系数据库的。它支持的数据结构非常松散,是类似 JSON 的 BSON 格式,因此可以存储比较复杂的数据类型。

    90 引用 • 59 回帖 • 1 关注
  • 周末

    星期六到星期天晚,实行五天工作制后,指每周的最后两天。再过几年可能就是三天了。

    14 引用 • 297 回帖 • 1 关注
  • Thymeleaf

    Thymeleaf 是一款用于渲染 XML/XHTML/HTML5 内容的模板引擎。类似 Velocity、 FreeMarker 等,它也可以轻易的与 Spring 等 Web 框架进行集成作为 Web 应用的模板引擎。与其它模板引擎相比,Thymeleaf 最大的特点是能够直接在浏览器中打开并正确显示模板页面,而不需要启动整个 Web 应用。

    11 引用 • 19 回帖 • 354 关注
  • Maven

    Maven 是基于项目对象模型(POM)、通过一小段描述信息来管理项目的构建、报告和文档的软件项目管理工具。

    186 引用 • 318 回帖 • 304 关注
  • SMTP

    SMTP(Simple Mail Transfer Protocol)即简单邮件传输协议,它是一组用于由源地址到目的地址传送邮件的规则,由它来控制信件的中转方式。SMTP 协议属于 TCP/IP 协议簇,它帮助每台计算机在发送或中转信件时找到下一个目的地。

    4 引用 • 18 回帖 • 614 关注
  • 一些有用的避坑指南。

    69 引用 • 93 回帖
  • B3log

    B3log 是一个开源组织,名字来源于“Bulletin Board Blog”缩写,目标是将独立博客与论坛结合,形成一种新的网络社区体验,详细请看 B3log 构思。目前 B3log 已经开源了多款产品:SymSoloVditor思源笔记

    1063 引用 • 3453 回帖 • 203 关注
  • 心情

    心是产生任何想法的源泉,心本体会陷入到对自己本体不能理解的状态中,因为心能产生任何想法,不能分出对错,不能分出自己。

    59 引用 • 369 回帖
  • Python

    Python 是一种面向对象、直译式电脑编程语言,具有近二十年的发展历史,成熟且稳定。它包含了一组完善而且容易理解的标准库,能够轻松完成很多常见的任务。它的语法简捷和清晰,尽量使用无异义的英语单词,与其它大多数程序设计语言使用大括号不一样,它使用缩进来定义语句块。

    543 引用 • 672 回帖
  • InfluxDB

    InfluxDB 是一个开源的没有外部依赖的时间序列数据库。适用于记录度量,事件及实时分析。

    2 引用 • 72 关注
  • 以太坊

    以太坊(Ethereum)并不是一个机构,而是一款能够在区块链上实现智能合约、开源的底层系统。以太坊是一个平台和一种编程语言 Solidity,使开发人员能够建立和发布下一代去中心化应用。 以太坊可以用来编程、分散、担保和交易任何事物:投票、域名、金融交易所、众筹、公司管理、合同和知识产权等等。

    34 引用 • 367 回帖
  • PostgreSQL

    PostgreSQL 是一款功能强大的企业级数据库系统,在 BSD 开源许可证下发布。

    22 引用 • 22 回帖
  • WebClipper

    Web Clipper 是一款浏览器剪藏扩展,它可以帮助你把网页内容剪藏到本地。

    3 引用 • 9 回帖
  • 知乎

    知乎是网络问答社区,连接各行各业的用户。用户分享着彼此的知识、经验和见解,为中文互联网源源不断地提供多种多样的信息。

    10 引用 • 66 回帖
  • SQLServer

    SQL Server 是由 [微软] 开发和推广的关系数据库管理系统(DBMS),它最初是由 微软、Sybase 和 Ashton-Tate 三家公司共同开发的,并于 1988 年推出了第一个 OS/2 版本。

    21 引用 • 31 回帖
  • Sandbox

    如果帖子标签含有 Sandbox ,则该帖子会被视为“测试帖”,主要用于测试社区功能,排查 bug 等,该标签下内容不定期进行清理。

    407 引用 • 1246 回帖 • 582 关注
  • 域名

    域名(Domain Name),简称域名、网域,是由一串用点分隔的名字组成的 Internet 上某一台计算机或计算机组的名称,用于在数据传输时标识计算机的电子方位(有时也指地理位置)。

    43 引用 • 208 回帖