Lesson 2 Gradient Descent
Published: 2019-05-24


The goal is to find $x$ such that

$$\min_x f(x)$$

We use the gradient descent algorithm to obtain the minimum of the function.

Let $y = f(x)$.
Initialize: $x = x_0$, $y_0 = f(x_0)$, with step size $\alpha$ and convergence precision $\epsilon$.

The $i$-th iteration can be expressed as:

$$x_i = x_{i-1} - \alpha f'(x_{i-1})$$

Example: solve for the minimum of the function $f(x) = x^2 - 3x + 2$.

Let $x_0 = 0$, step $\alpha = 0.1$, convergence precision $\epsilon = 10^{-4}$.

f = @(x) x.^2 - 3*x + 2;
hold on
% plot the curve of f over [0, 3]
xs = 0:0.001:3;
plot(xs, f(xs), 'k-');
% starting point
x = 0;
y0 = f(x);
plot(x, y0, 'ro-');
alpha = 0.1;
epsilon = 10^(-4);
gnorm = inf;
while (gnorm > epsilon)
    x = x - alpha*(2*x - 3);   % f'(x) = 2x - 3
    y = f(x);
    gnorm = abs(y - y0);       % stop when f barely changes
    plot(x, y, 'ro');
    y0 = y;
end

[Figure: gradient descent iterates plotted on the curve of $f(x) = x^2 - 3x + 2$]
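For readers without MATLAB, the same iteration can be sketched in plain Python (a minimal sketch: the function, step size, and precision match the MATLAB snippet above; the plotting calls are omitted):

```python
f = lambda x: x**2 - 3*x + 2
df = lambda x: 2*x - 3              # f'(x)

x = 0.0
y0 = f(x)
alpha, epsilon = 0.1, 1e-4
gnorm = float('inf')
while gnorm > epsilon:
    x = x - alpha * df(x)           # x_i = x_{i-1} - alpha * f'(x_{i-1})
    y = f(x)
    gnorm = abs(y - y0)             # stop when f barely changes between steps
    y0 = y
```

The iterates approach the minimizer $x = 1.5$; tightening $\epsilon$ brings them closer.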

Let's move to the multi-variable case. Say we have $m$ samples, each with $n$ features. $X$ is expressed as:

$$X = \begin{bmatrix} x_1^T \\ x_2^T \\ \vdots \\ x_m^T \end{bmatrix}$$

where

$$x_i = \begin{bmatrix} x_{i1} \\ x_{i2} \\ \vdots \\ x_{in} \end{bmatrix}$$

Then $X$ can be denoted as:

$$X = \begin{bmatrix} x_{11} & x_{12} & \cdots & x_{1n} \\ x_{21} & x_{22} & \cdots & x_{2n} \\ \vdots & \vdots & & \vdots \\ x_{m1} & x_{m2} & \cdots & x_{mn} \end{bmatrix}$$

Assume $h(x_i) = \sum_{j=1}^{n} a_j x_{ij} = x_i^T a$.

Here,

$$a = \begin{bmatrix} a_1 \\ a_2 \\ \vdots \\ a_n \end{bmatrix}$$

is the unknown vector we need to solve for.

$$Xa - y = \begin{bmatrix} h(x_1) - y_1 \\ h(x_2) - y_2 \\ \vdots \\ h(x_m) - y_m \end{bmatrix}$$

Now the objective function is

$$\min_a f(a) = \frac{1}{2}(Xa - y)^T (Xa - y)$$

Before the derivation, I would like to introduce some facts:

(1) $\mathrm{tr}(AB) = \mathrm{tr}(BA)$
(2) $\mathrm{tr}(ABC) = \mathrm{tr}(BCA) = \mathrm{tr}(CAB)$
(3) $\mathrm{tr}(A) = \mathrm{tr}(A^T)$
(4) if $a \in \mathbb{R}$, then $\mathrm{tr}(a) = a$
(5) $\nabla_A \mathrm{tr}(AB) = B^T$
(6) $\nabla_A \mathrm{tr}(ABA^TC) = CAB + C^TAB^T$
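These identities are easy to sanity-check numerically. The sketch below (NumPy, with random matrices of my own choosing) verifies (1)-(3) directly and (5) by central finite differences:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 3))
C = rng.standard_normal((3, 3))

# (1) tr(AB) = tr(BA)
assert np.isclose(np.trace(A @ B), np.trace(B @ A))
# (2) tr(ABC) = tr(BCA) = tr(CAB)
assert np.isclose(np.trace(A @ B @ C), np.trace(B @ C @ A))
assert np.isclose(np.trace(A @ B @ C), np.trace(C @ A @ B))
# (3) tr(A) = tr(A^T), for a square matrix
assert np.isclose(np.trace(C), np.trace(C.T))

# (5) grad_A tr(AB) = B^T, checked entry by entry with central differences
def num_grad(phi, A, h=1e-6):
    G = np.zeros_like(A)
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            E = np.zeros_like(A)
            E[i, j] = h
            G[i, j] = (phi(A + E) - phi(A - E)) / (2 * h)
    return G

G = num_grad(lambda M: np.trace(M @ B), A)
assert np.allclose(G, B.T, atol=1e-5)
```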

In order to obtain the critical points of $f(a)$, we take the derivative of $f(a)$ with respect to $a$ and set it to zero.

$$
\begin{aligned}
\nabla_a f(a) &= \nabla_a \frac{1}{2}(Xa - y)^T (Xa - y) \\
&= \frac{1}{2}\nabla_a \left(a^T X^T X a - a^T X^T y - y^T X a + y^T y\right) \\
&= \frac{1}{2}\nabla_a \,\mathrm{tr}\left(a^T X^T X a - a^T X^T y - y^T X a + y^T y\right) && \text{the trace of a scalar is the scalar itself, by (4)} \\
&= \frac{1}{2}\left(\nabla_a \mathrm{tr}(a^T X^T X a) - \nabla_a \mathrm{tr}(a^T X^T y) - \nabla_a \mathrm{tr}(y^T X a)\right) && y^T y \text{ does not depend on } a \\
&= \frac{1}{2}\left(\nabla_a \mathrm{tr}(a^T X^T X a) - 2\,\nabla_a \mathrm{tr}(y^T X a)\right) && \text{by (3), } \mathrm{tr}(a^T X^T y) = \mathrm{tr}(y^T X a) \\
&= \frac{1}{2}\left(\nabla_a \mathrm{tr}(a^T X^T X a) - 2 X^T y\right) && \text{by (2) and (5)} \\
&= \frac{1}{2}\left(\nabla_a \mathrm{tr}(a\, a^T X^T X) - 2 X^T y\right) && \text{by (2)} \\
&= \frac{1}{2}\left(\nabla_a \mathrm{tr}(a\, I\, a^T X^T X) - 2 X^T y\right) \\
&= \frac{1}{2}\left(2 X^T X a - 2 X^T y\right) && \text{by (6) with } A = a,\ B = I,\ C = X^T X \\
&= X^T X a - X^T y = 0
\end{aligned}
$$

Provided $X^T X$ is invertible, we can easily get $a$ as follows:

$$a = (X^T X)^{-1} X^T y$$
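As a cross-check (a NumPy sketch, with synthetic data of my own choosing), gradient descent on $f(a)$ using the gradient $X^T X a - X^T y$ derived above converges to the same closed-form solution:

```python
import numpy as np

rng = np.random.default_rng(1)
m, n = 50, 3
X = rng.standard_normal((m, n))
a_true = np.array([1.0, -2.0, 0.5])
y = X @ a_true

# closed-form solution: solve (X^T X) a = X^T y
a_closed = np.linalg.solve(X.T @ X, X.T @ y)

# gradient descent on f(a) = 1/2 (Xa - y)^T (Xa - y),
# whose gradient is X^T X a - X^T y
a = np.zeros(n)
alpha = 0.01
for _ in range(5000):
    a -= alpha * (X.T @ X @ a - X.T @ y)

assert np.allclose(a, a_closed, atol=1e-6)
```

In practice the inverse is never formed explicitly; solving the linear system $X^T X a = X^T y$ (as `np.linalg.solve` does here) is cheaper and numerically safer.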

function [xopt,fopt,niter,gnorm,dx] = grad_descent(varargin)
% grad_descent.m demonstrates how the gradient descent method can be used
% to solve a simple unconstrained optimization problem. Taking large step
% sizes can lead to algorithm instability. The variable alpha below
% specifies the fixed step size. Increasing alpha above 0.32 results in
% instability of the algorithm. An alternative approach would involve a
% variable step size determined through line search.
%
% This example was used originally for an optimization demonstration in ME
% 149, Engineering System Design Optimization, a graduate course taught at
% Tufts University in the Mechanical Engineering Department. A
% corresponding video is available at:
%
% http://www.youtube.com/watch?v=cY1YGQQbrpQ
%
% Author: James T. Allison, Assistant Professor, University of Illinois at
% Urbana-Champaign
% Date: 3/4/12

if nargin==0
    % define starting point
    x0 = [3 3]';
elseif nargin==1
    % if a single input argument is provided, it is a user-defined starting
    % point.
    x0 = varargin{1};
else
    error('Incorrect number of input arguments.')
end

% termination tolerance
tol = 1e-6;
% maximum number of allowed iterations
maxiter = 1000;
% minimum allowed perturbation
dxmin = 1e-6;
% step size (0.33 causes instability, 0.2 quite accurate)
alpha = 0.1;

% initialize gradient norm, optimization vector, iteration counter, perturbation
gnorm = inf; x = x0; niter = 0; dx = inf;

% define the objective function:
f = @(x1,x2) x1.^2 + x1.*x2 + 3*x2.^2;
% plot objective function contours for visualization:
figure(1); clf; ezcontour(f,[-5 5 -5 5]); axis equal; hold on
% redefine objective function syntax for use with optimization:
f2 = @(x) f(x(1),x(2));

% gradient descent algorithm:
while and(gnorm>=tol, and(niter <= maxiter, dx >= dxmin))
    % calculate gradient:
    g = grad(x);
    gnorm = norm(g);
    % take step:
    xnew = x - alpha*g;
    % check step
    if any(~isfinite(xnew))
        display(['Number of iterations: ' num2str(niter)])
        error('x is inf or NaN')
    end
    % plot current point
    plot([x(1) xnew(1)],[x(2) xnew(2)],'ko-')
    refresh
    % update termination metrics
    niter = niter + 1;
    dx = norm(xnew-x);
    x = xnew;
end
xopt = x;
fopt = f2(xopt);
niter = niter - 1;

% define the gradient of the objective (as a column vector, so that
% x - alpha*g stays a column vector)
function g = grad(x)
g = [2*x(1) + x(2); x(1) + 6*x(2)];

[Figure: contour plot of $f(x_1, x_2) = x_1^2 + x_1 x_2 + 3x_2^2$ with the gradient descent path]

The same solver can also be visualized on the 3-D surface of the objective instead of a contour plot:

function [xopt,fopt,niter,gnorm,dx] = grad_descent(varargin)
if nargin==0
    % define starting point
    x0 = [3 3]';
elseif nargin==1
    % if a single input argument is provided, it is a user-defined starting
    % point.
    x0 = varargin{1};
else
    error('Incorrect number of input arguments.')
end

% termination tolerance
tol = 1e-6;
% maximum number of allowed iterations
maxiter = 1000;
% minimum allowed perturbation
dxmin = 1e-6;
% step size (0.33 causes instability, 0.2 quite accurate)
alpha = 0.1;

% initialize gradient norm, optimization vector, iteration counter, perturbation
gnorm = inf; x = x0; niter = 0; dx = inf;

% define the objective function:
f = @(x1,x2) x1.^2 + x1.*x2 + 3*x2.^2;
m = -5:0.1:5;
[X,Y] = meshgrid(m);
Z = f(X,Y);
% plot the objective function surface for visualization:
figure(1); clf; meshc(X,Y,Z); hold on
% redefine objective function syntax for use with optimization:
f2 = @(x) f(x(1),x(2));

% gradient descent algorithm:
while and(gnorm>=tol, and(niter <= maxiter, dx >= dxmin))
    % calculate gradient:
    g = grad(x);
    gnorm = norm(g);
    % take step:
    xnew = x - alpha*g;
    % check step
    if any(~isfinite(xnew))
        display(['Number of iterations: ' num2str(niter)])
        error('x is inf or NaN')
    end
    % plot current point in the (x1, x2) plane and on the surface
    plot([x(1) xnew(1)],[x(2) xnew(2)],'ko-')
    plot3([x(1) xnew(1)],[x(2) xnew(2)],...
          [f(x(1),x(2)) f(xnew(1),xnew(2))],'r+-');
    refresh
    % update termination metrics
    niter = niter + 1;
    dx = norm(xnew-x);
    x = xnew;
end
xopt = x;
fopt = f2(xopt);
niter = niter - 1;

% define the gradient of the objective (as a column vector, so that
% x - alpha*g stays a column vector)
function g = grad(x)
g = [2*x(1) + x(2); x(1) + 6*x(2)];

[Figure: 3-D surface plot of the objective with the gradient descent path]
