手写机器学习算法系列01——线性回归 - Go语言中文社区

手写机器学习算法系列01——线性回归


引言

线性回归是最简单的机器学习算法,说白了就是构造一元或者多元的线性方程,然后根据现有样本数据进行函数拟合,求解出线性方程的各个参数,之后就可以通过该线性方程进行相关预测。

模拟场景

根据常识,一个人的工龄越长,那么这个人的工资也会随之增加。那么,我们现在想要知道,工龄对工资之间具体存在什么样的数学关系。
巧妇难为无米之炊,先放上一组数据:

工龄(年) 工资(元)
1.1 39343.00
1.3 46205.00
1.5 37731.00
2.0 43525.00
2.2 39891.00
2.9 56642.00
3.0 60150.00
3.2 54445.00
3.2 64445.00
3.7 57189.00
3.9 63218.00
4.0 55794.00
4.0 56957.00
4.1 57081.00
4.5 61111.00
4.9 67938.00
5.1 66029.00
5.3 83088.00
5.9 81363.00
6.0 93940.00
6.8 91738.00
7.1 98273.00
7.9 101302.00
8.2 113812.00
8.7 109431.00
9.0 105582.00
9.5 116969.00
9.6 112635.00
10.3 122391.00
10.5 121872.00

构造方程

根据这些样本数据,可以得出下面的散点图:
在这里插入图片描述

这里我们可以根据散点图的大致走向进行猜测,工龄x与工资y之间是一元线性关系,于是我们可以写出如下线性方程:

h(x)=θ0+xθ1 h(x)= theta_0 + x theta_1

其中:

  • h(x)h(x)是最终的工资
  • xx是自变量工龄
  • θ1theta_1
  • θ0theta_0

上述是线性方程最简单的形式,更普遍的情况下可能会有nn个特征值,即以下这种形式:

h(x)=θ0+x1θ1+x2θ2+...+xnθn=x0x0=1θ0+x1θ1+x2θ2+...+xnθn=i=0n(xiθi) begin{aligned} h(x) & = theta_0 + x_1 theta_1 + x_2 theta_2+...+x_n theta_n \ & = overbrace{x_0}^{x_0=1} theta_0 + x_1theta_1 + x_2theta_2+...+x_ntheta_n \ & = sum_{i=0}^n(x_itheta_i) end{aligned}

由于机器学习中基本使用矩阵进行运算,为了后续数学推导,因此我们将上述式子改写为矩阵的形式:

h(x)=i=0n(θixi)=[x0x1...xn]×[θ0θ1...θn]=XΘ begin{aligned} h(x) & = sum_{i=0}^n(theta_ix_i) \ & = left[ begin{array}{c} x_0 \ x_1 \ ... \ x_n \ end{array} right] times left[ begin{array}{c} theta_0 & theta_1 & ... & theta_n\ end{array} right]\ & = X Theta \ end{aligned}

其中:

  • ΘTheta为方程参数的列矩阵形式
  • XX为方程自变量的行矩阵形式

误差项

接下来,我们要考虑到实际值与线性方程的解肯定会存在误差,这里我们用ϵepsilon来表示误差项:

(1)y=(XΘ)+ϵ y^{真实值} =(X Theta)^{预测值} + epsilon^{误差项} tag 1

我们当然希望误差项ϵepsilon越小越好,这里我们假定误差项ϵepsilon是独立的并且服从均值为00、方差为θ2theta^2的高斯分布:
高斯分布

  • 独立表示各个样本之间的误差之间毫无干系
  • 均值为0的意义是概率集中在误差为0的情况

我们可以直接写出每一个样本的误差项ϵepsilon的概率函数

(2)p(ϵ(i))=12πσexp((ϵ(i))22σ2) p(epsilon^{(i)}) = frac{1}{sqrt[]{2pi}sigma} exp(- frac{(epsilon^{(i)})^2}{2sigma^2}) tag 2

结合(1)(2)(1)(2)可得:

(3)p(ϵ(i))=12πσexp((y(i)X(i)Θ)22σ2) p(epsilon^{(i)}) = frac{1}{sqrt[]{2pi}sigma} exp(- frac{(y^{(i)}-X^{(i)}Theta )^2}{2sigma^2}) tag 3

0 条评论

请先 登录 后评论

官方社群

GO教程

猜你喜欢