博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
(一)神经网络入门之线性回归
阅读量:2440 次
发布时间:2019-05-10

本文共 5219 字,大约阅读时间需要 17 分钟。

作者:chen_h

微信号 & QQ:862251340
微信公众号:coderpai
简书地址:


这篇教程是翻译写的神经网络教程,作者已经授权翻译,这是。

该教程将介绍如何入门神经网络,一共包含五部分。你可以在以下链接找到完整内容。

这篇教程中的代码是由 Python 2 产生的,在教程的最后,我会给出全部代码的链接,帮助学习。神经网络中有关矩阵的运算我们采用来构建,画图使用来构建。如果你来没有按照这些软件,那么我强烈建议你使用来安装,这个软件包中包含了运行这个教程的所有软件包,非常方便使用。

我们先导入教程需要的软件包

from __future__ import print_functionimport numpy as npimport matplotlib.pyplot as plt

线性回归


本教程主要包含三部分:

* 一个非常简单的神经网络
* 一些概念,比如目标函数,损失函数
* 梯度下降

首先我们来构建一个最简单的神经网络,这个神经网络只有一个输入,一个输出,用来构建一个线性回归模型,从输入的x来预测一个真实结果t。神经网络的模型结构为y = x * w,其中x是输入参数,w是权重,y是预测结果。神经网络的模型可以被表示为下图:

在常规的神经网络中,神经网络结构中有多个层,非线性激活函数和每个节点上面的偏差单元。在这个教程中,我们只使用一个只有一个权重w的层,并且没有激活函数和偏差单元。在中,权重w和偏差单元一般都写成一个参数向量β,其中偏差单元是y轴上面的截距,w是回归线的斜率。在线性回归中,我们一般使用来优化这些参数。

在这篇教程中,我们的目的是最小化目标损失函数,使得实际输出的y和正确结果t尽可能的接近。损失函数我们定义为:

损失函数

对于损失函数的优化,我们采用,这个方法是神经网络中常见的优化方法。

定义目标函数

在这个例子中,我们使用函数f来产生目标结果t,但是对目标结果加上一些N(0, 0.2),其中N表示正态分布,均值是0,方差是0.2f定义为f(x) = 2xx是输入参数,回归线的斜率是2,截距是0。所以最后的t = f(x) + N(0, 0.2)

我们将产生20个均匀分布的数据作为数据样本x,然后设计目标结果t。下面的程序我们生成了xt,以及画出了他们之间的线性关系。

# Define the vector of input samples as x, with 20 values sampled from a uniform distribution# between 0 and 1x = np.random.uniform(0, 1, 20)# Generate the target values t from x with small gaussian noise so the estimation won't be perfect.# Define a function f that represents the line that generates t without noisedef f(x): return x * 2# Create the targets t with some gaussian noisenoise_variance = 0.2 # Variance of the gaussian noise# Gaussian noise error for each sample in xnoise = np.random.randn(x.shape[0]) * noise_variance# Create targets tt = f(x) + noise
# Plot the target t versus the input xplt.plot(x, t, 'o', label='t')# Plot the initial lineplt.plot([0, 1], [f(0), f(1)], 'b-', label='f(x)')plt.xlabel('$x$', fontsize=15)plt.ylabel('$t$', fontsize=15)plt.ylim([0,2])plt.title('inputs (x) vs targets (t)')plt.grid()plt.legend(loc=2)plt.show()

定义损失函数

我们将优化模型y = w * x中的参数w,使得对于训练集中的N个样本,损失函数达到最小。

损失函数

即,我们的优化目标是:

损失函数优化目标

从函数中,我们可以发现,我们将所有样本的误差都进行了累加,这就是所谓的批训练(batch training)。我们也可以在训练的时候,每次训练一个样本,这种方法在在线训练中非常常用。

我们利用以下函数画出损失函数与权重的关系。从图中,我们可以看出损失函数的值达到最小时,w的值是2。这个值就是我们函数f(x)的斜率。这个损失函数是一个,并且只有一个全局最小值。

nn(x, w)函数实现了神经网络模型,cost(y, t)函数实现了损失函数。

# Define the neural network function y = x * wdef nn(x, w): return x*w# Define the cost functiondef cost(y, t): return ((t - y) ** 2).sum()

cost vs weight

优化损失函数

对于教程中简单的损失函数,可能你看一眼就能知道最佳的权重是什么。但是对于或者更高维度的损失函数,这就是我们为什么要使用各种的原因了。

梯度下降

在训练神经网络中,算法是一种比较常用的优化算法。梯度下降算法的原理是损失函数对于每个参数进行,并且利用对参数进行更新。权重w通过循环进行更新:

权重更新函数

其中,w(k)表示权重w更新到第k步时的值,Δw为定义为:

权重的梯度

其中,μ是学习率,它的含义是在参数更新的时候,每一步的跨度大小。∂ξ/∂w表示损失函数ξ对于w的梯度。对于每一个训练样本i,我们可以利用推导出对应的梯度,如下:

链式规则

其中,ξi是第i个样本的损失函数,因此,∂ξi/∂yi可以这样进行推导:

函数推导

因为y(i) = x(i) ∗ w,所以我们对于∂yi/∂w可以这样进行推导:

函数推导

因此,对于第i个训练样本,Δw的完整推导如下:

Δw 完整推导

在批处理过程中,我们将所有的梯度都进行累加:

批处理函数推导

在进行梯度下降之前,我们需要对权重进行一个初始化,然后再使用梯度下降算法进行训练,最后直至算法收敛。学习率作为一个,需要单独调试。

gradient(w, x, t)函数实现了梯度∂ξ/∂wdelta_w(w_k, x, t, learning_rate)函数实现了Δw

# define the gradient function. Remember that y = nn(x, w) = x * wdef gradient(w, x, t):  return 2 * x * (nn(x, w) - t)# define the update function delta wdef delta_w(w_k, x, t, learning_rate):  return learning_rate * gradient(w_k, x, t).sum()# Set the initial weight parameterw = 0.1# Set the learning ratelearning_rate = 0.1# Start performing the gradient descent updates, and print the weights and cost:nb_of_iterations = 4 # number of gradient descent updatesw_cost = [(w, cost(nn(x, w), t))] # List to store the weight, costs valuesfor i in range(nb_of_iterations):  dw = delta_w(w, x, t, learning_rate) # Get the delta w update  w = w - dw # Update the current weight parameter  w_cost.append((w, cost(nn(x, w), t))) # Add weight, cost to list# Print the final w, and costfor i in range(0, len(w_cost)):  print('w({}): {:.4f} \t cost: {:.4f}'.format(i, w_cost[i][0], w_cost[i][1]))# outputw(0): 0.1000   cost: 23.3917w(1): 2.3556   cost: 1.0670w(2): 2.0795   cost: 0.7324w(3): 2.1133   cost: 0.7274w(4): 2.1091   cost: 0.7273

从计算结果中,我们很容易的看出来了,梯度下降算法很快的收敛到了2.0左右,接下来可视化一下梯度下降过程。

# Plot the first 2 gradient descent updatesplt.plot(ws, cost_ws, 'r-')  # Plot the error curve# Plot the updatesfor i in range(0, len(w_cost)-2):  w1, c1 = w_cost[i]  w2, c2 = w_cost[i+1]  plt.plot(w1, c1, 'bo')  plt.plot([w1, w2],[c1, c2], 'b-')  plt.text(w1, c1+0.5, '$w({})$'.format(i)) # Show figureplt.xlabel('$w$', fontsize=15)plt.ylabel('$\\xi$', fontsize=15)plt.title('Gradient descent updates plotted on cost function')plt.grid()plt.show()

Gradient descent updates plotted on cost function

梯度更新

上图展示了梯度下降的可视化过程。图中蓝色的点表示在第k轮中w(k)的值。从图中我们可以得知,w的值越来越收敛于2.0。该模型训练10次就能收敛,如下图所示。

w = 0# Start performing the gradient descent updatesnb_of_iterations = 10  # number of gradient descent updatesfor i in range(nb_of_iterations):  dw = delta_w(w, x, t, learning_rate)  # get the delta w update  w = w - dw  # update the current weight parameter
# Plot the fitted line agains the target line# Plot the target t versus the input xplt.plot(x, t, 'o', label='t')# Plot the initial lineplt.plot([0, 1], [f(0), f(1)], 'b-', label='f(x)')# plot the fitted lineplt.plot([0, 1], [0*w, 1*w], 'r-', label='fitted line')plt.xlabel('input x')plt.ylabel('target t')plt.ylim([0,2])plt.title('input vs. target')plt.grid()plt.legend(loc=2)plt.show()

input vs. target


作者:chen_h

微信号 & QQ:862251340
简书地址:

CoderPai 是一个专注于算法实战的平台,从基础的算法到人工智能算法都有设计。如果你对算法实战感兴趣,请快快关注我们吧。加入AI实战微信群,AI实战QQ群,ACM算法微信群,ACM算法QQ群。长按或者扫描如下二维码,关注 “CoderPai” 微信号(coderpai)

这里写图片描述

你可能感兴趣的文章
大家好,新学生。 请问怎么升级Redhat9.0 kernel 2.4.X-->2.6.18 的详细过程(转)
查看>>
FreeBSD6.1+无线+永中......桌面安装【附笔记】(转)
查看>>
adsl设置(转)
查看>>
Wii将有一个可升级的Linux操作系统(转)
查看>>
Linux机为先锋智能机和PDA06销量大(转)
查看>>
Oracle与SQL Server在企业应用中的比较(转)
查看>>
Unix类操作系统入门(转)
查看>>
让FreeBSD使用ntpd同步时间(转)
查看>>
用cat命令查看文件内的特殊字符(转)
查看>>
debian sid下vmware不能运行一则(转)
查看>>
Linux操作系统套接字编程的5个隐患(转)
查看>>
Ubuntu Linux:定制Ubuntu安装CD(转)
查看>>
调查显示:企业级Linux用户不断攀升(转)
查看>>
Ubuntu/Linux入门介绍-dpkg(转)
查看>>
SCO UNIX学习宝典 高级进阶(转)
查看>>
Oracle9i RAC for RedFlag Linux DC4.1 32bit 安装流程(转)
查看>>
Sybase和Oracle安装过程中常遇到的问题(转)
查看>>
红帽Linux新系统整合虚拟技术 实现简易操作(转)
查看>>
Linux下/etc/default/boot文件字段说明(转)
查看>>
Linux壁纸系列三十四(转)
查看>>