{
"cells": [
{
"cell_type": "markdown",
"id": "4159be0e",
"metadata": {
"hideCode": false,
"hidePrompt": false,
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# 2022-03-28 Gradient Descent\n",
"\n",
"## Last time\n",
"\n",
"* Assumptions of linear models\n",
"* Look at your data!\n",
"* Partial derivatives\n",
"* Loss functions\n",
"\n",
"## Today\n",
"\n",
"* Discuss projects\n",
"* Gradient-based optimization for linear models\n",
"* Nonlinear models"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "eb781a7a",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [
{
"data": {
"text/plain": [
"vcond (generic function with 1 method)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using LinearAlgebra\n",
"using Plots\n",
"default(linewidth=4, legendfontsize=12)\n",
"\n",
"function vander(x, k=nothing)\n",
" if isnothing(k)\n",
" k = length(x)\n",
" end\n",
" m = length(x)\n",
" V = ones(m, k)\n",
" for j in 2:k\n",
" V[:, j] = V[:, j-1] .* x\n",
" end\n",
" V\n",
"end\n",
"\n",
"function vander_chebyshev(x, n=nothing)\n",
" if isnothing(n)\n",
" n = length(x) # Square by default\n",
" end\n",
" m = length(x)\n",
" T = ones(m, n)\n",
" if n > 1\n",
" T[:, 2] = x\n",
" end\n",
" for k in 3:n\n",
" #T[:, k] = x .* T[:, k-1]\n",
" T[:, k] = 2 * x .* T[:,k-1] - T[:, k-2]\n",
" end\n",
" T\n",
"end\n",
"\n",
"function chebyshev_regress_eval(x, xx, n)\n",
" V = vander_chebyshev(x, n)\n",
" vander_chebyshev(xx, n) / V\n",
"end\n",
"\n",
"runge(x) = 1 / (1 + 10*x^2)\n",
"runge_noisy(x, sigma) = runge.(x) + randn(size(x)) * sigma\n",
"\n",
"CosRange(a, b, n) = (a + b)/2 .+ (b - a)/2 * cos.(LinRange(-pi, 0, n))\n",
"\n",
"vcond(mat, points, nmax) = [cond(mat(points(-1, 1, n))) for n in 2:nmax]"
]
},
{
"cell_type": "markdown",
"id": "19abcf99",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Workshopping projects today\n",
"\n",
"* Maybe you've started, maybe you're still looking for a good project\n",
"* What have you been thinking about?\n",
"* What do you need help on?\n",
"\n",
"## Each breakout group will report out\n",
"1. One interesting thing about process\n",
"2. One question you have relevant to the projects (specific or general)\n",
"3. One question from content we've covered"
]
},
{
"cell_type": "markdown",
"id": "f29e52ed",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Variational notation for derivatives\n",
"\n",
"It's convenient to express derivatives in terms of how they act on an infinitesimal perturbation. So we might write\n",
"\n",
"$$ \\delta f = \\frac{\\partial f}{\\partial x} \\delta x .$$\n",
"\n",
"(It's common to use $\\delta x$ or $dx$ for these infinitesimals.) This makes inner products look like a normal product rule\n",
"\n",
"$$ \\delta(\\mathbf x^T \\mathbf y) = \\mathbf y^T (\\delta \\mathbf x) + \\mathbf x^T (\\delta \\mathbf y). $$\n",
"\n",
"A powerful example of variational notation is differentiating a matrix inverse\n",
"\n",
"$$ 0 = \\delta I = \\delta(A^{-1} A) = (\\delta A^{-1}) A + A^{-1} (\\delta A) $$\n",
"and thus\n",
"$$ \\delta A^{-1} = - A^{-1} (\\delta A) A^{-1} $$"
]
},
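{
"cell_type": "markdown",
"id": "b7d2e4a1",
"metadata": {},
"source": [
"The formula for the variation of the inverse can be checked numerically. A minimal sketch, assuming a small hypothetical matrix `A` and perturbation `dA` chosen only for illustration:\n",
"\n",
"```julia\n",
"using LinearAlgebra\n",
"\n",
"# Hypothetical test data, not from the lecture\n",
"A = [4.0 1.0; 2.0 3.0]\n",
"dA = 1e-6 * [1.0 -2.0; 0.5 3.0]    # small perturbation standing in for delta A\n",
"\n",
"actual = inv(A + dA) - inv(A)       # true change in the inverse\n",
"predicted = -inv(A) * dA * inv(A)   # first-order variational formula\n",
"\n",
"# Relative mismatch between the true change and the first-order prediction\n",
"relerr = norm(actual - predicted) / norm(predicted)\n",
"@show relerr\n",
"```\n",
"\n",
"The mismatch is small because the formula only drops terms that are second order in `dA`."
]
},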
{
"cell_type": "markdown",
"id": "eea01b31",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Optimization for linear models\n",
"Given data $(x,y)$ and loss function $L(c; x,y)$, we wish to find the coefficients $c$ that minimize the loss, thus yielding the \"best predictor\" (in a sense that can be made statistically precise). I.e.,\n",
"$$ \\bar c = \\arg\\min_c L(c; x,y) . $$\n",
"\n",
"It is usually desirable to design models such that the loss function is differentiable with respect to the coefficients $c$, because this allows the use of more efficient optimization methods. Recall that our forward model is given in terms of the Vandermonde matrix,\n",
"\n",
"$$ f(x, c) = V(x) c $$\n",
"\n",
"and thus\n",
"\n",
"$$ \\frac{\\partial f}{\\partial c} = V(x) . $$"
]
},
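{
"cell_type": "markdown",
"id": "c3a9f0e2",
"metadata": {},
"source": [
"As a worked instance, assuming a quadratic fit in the monomial basis with coefficients $c_0, c_1, c_2$ and data points $x_1, \\dotsc, x_m$, the model is\n",
"\n",
"$$ f(x, c) = \\begin{bmatrix} 1 & x_1 & x_1^2 \\\\ \\vdots & \\vdots & \\vdots \\\\ 1 & x_m & x_m^2 \\end{bmatrix} \\begin{bmatrix} c_0 \\\\ c_1 \\\\ c_2 \\end{bmatrix}, \\qquad \\frac{\\partial f_i}{\\partial c_j} = x_i^{\\,j}, \\quad j = 0, 1, 2 . $$\n",
"\n",
"Each entry of $f$ is linear in $c$, so the derivative is the matrix $V(x)$ itself and does not depend on $c$."
]
},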
{
"cell_type": "markdown",
"id": "254b58df",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Derivative of loss function\n",
"\n",
"We now differentiate our loss function\n",
"$$ L(c; x, y) = \\frac 1 2 \\lVert f(x, c) - y \\rVert^2 $$\n",
"with respect to $c$. Working in linear-algebraic form rather than component by component, the result is\n",
"\\begin{align} \\nabla_c L(c; x,y) &= \\big( f(x,c) - y \\big)^T V(x) \\\\\n",
"&= \\big(V(x) c - y \\big)^T V(x) , \\\\\n",
"\\nabla_c L(c; x,y)^T &= V(x)^T \\big( V(x) c - y \\big) ,\n",
"\\end{align}\n",
"where the last line transposes the row vector to write the gradient as a column vector.\n",
"A necessary condition for the loss function to be minimized is that $\\nabla_c L(c; x,y) = 0$.\n",
"\n",
"* Is the condition sufficient for general $f(x, c)$?\n",
"* Is the condition sufficient for the linear model $f(x,c) = V(x) c$?\n",
"* Have we seen this sort of equation before?"
]
},
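{
"cell_type": "markdown",
"id": "d51e7b93",
"metadata": {},
"source": [
"The gradient formula for the linear model can be sanity-checked against centered finite differences. A minimal sketch, assuming a small hypothetical data set; the names `xs`, `ys`, `Vs`, `loss_demo`, `grad_demo`, and `ctest` are illustrative and not part of the lecture code:\n",
"\n",
"```julia\n",
"using LinearAlgebra\n",
"\n",
"# Hypothetical data, for illustration only\n",
"xs = LinRange(-1, 1, 10)\n",
"ys = cos.(3 .* xs)\n",
"Vs = [xi^j for xi in xs, j in 0:2]   # monomial Vandermonde matrix, 10 x 3\n",
"\n",
"loss_demo(c) = 0.5 * norm(Vs * c - ys)^2\n",
"grad_demo(c) = Vs' * (Vs * c - ys)   # gradient as a column vector\n",
"\n",
"ctest = [0.3, -0.2, 0.5]\n",
"h = 1e-6\n",
"# Centered difference of the loss along each coordinate direction\n",
"fd = map(1:3) do j\n",
"    e = zeros(3); e[j] = h\n",
"    (loss_demo(ctest + e) - loss_demo(ctest - e)) / (2h)\n",
"end\n",
"fd_error = norm(fd - grad_demo(ctest))\n",
"@show fd_error\n",
"```\n",
"\n",
"Because the loss is quadratic in $c$, centered differences agree with the analytic gradient up to rounding error."
]
},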
{
"cell_type": "markdown",
"id": "ab251374",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Gradient descent\n",
"\n",
"Instead of solving the least squares problem using linear algebra (QR factorization), we could solve it using gradient descent. That is, on each iteration, we'll take a step in the direction of the negative gradient."
]
},
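{
"cell_type": "markdown",
"id": "e84c1a57",
"metadata": {},
"source": [
"Written out for the linear least squares problem, with learning rate $\\gamma > 0$ and iterates $c^{(k)}$ (the superscript counts iterations, not components), each step computes\n",
"\n",
"$$ g^{(k)} = V(x)^T \\big( V(x) c^{(k)} - y \\big), \\qquad c^{(k+1)} = c^{(k)} - \\gamma \\, g^{(k)} . $$\n",
"\n",
"A natural stopping rule is to quit once $\\lVert g^{(k)} \\rVert$ falls below a tolerance, or after a fixed number of steps."
]
},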
{
"cell_type": "code",
"execution_count": 17,
"id": "04841eea",
"metadata": {
"cell_style": "center"
},
"outputs": [
{
"data": {
"text/plain": [
"grad_descent (generic function with 1 method)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function grad_descent(loss, grad, c0; gamma=1e-3, tol=1e-5)\n",
" # Minimize loss(c) via gradient descent with initial guess c0\n",
" # using learning rate gamma. Declares convergence when the norm\n",
" # of the gradient is less than tol; stops after at most 500 steps.\n",
" c = copy(c0)\n",
" chist = [copy(c)]\n",
" lhist = [loss(c)]\n",
" for it in 1:500\n",
" g = grad(c)\n",
" c -= gamma * g\n",
" push!(chist, copy(c))\n",
" push!(lhist, loss(c))\n",
" if norm(g) < tol\n",
" break\n",
" end\n",
" end\n",
" (c, hcat(chist...), lhist)\n",
"end"
]
},
{
"cell_type": "markdown",
"id": "ca23b5fe",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Quadratic model"
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "ab3a31f0",
"metadata": {
"cell_style": "split",
"hideCode": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cond(A) = 9.46578492882319\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"A = [1 1; 1 8]\n",
"@show cond(A)\n",
"loss(c) = .5 * c' * A * c\n",
"grad(c) = A * c\n",
"\n",
"c, chist, lhist = grad_descent(loss, grad, [.9, .9],\n",
" gamma=.22)\n",
"plot(lhist, yscale=:log10, xlims=(0, 80))"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "7909e35d",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"plot(chist[1, :], chist[2, :], marker=:circle)\n",
"x = LinRange(-1, 1, 30)\n",
"contour!(x, x, (x,y) -> loss([x, y]))"
]
},
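{
"cell_type": "markdown",
"id": "f26b8d04",
"metadata": {},
"source": [
"For this quadratic loss, each step maps $c$ to $(I - \\gamma A) c$, so the iteration converges exactly when every eigenvalue $\\lambda$ of $A$ satisfies $|1 - \\gamma \\lambda| < 1$, that is, when $\\gamma < 2/\\lambda_{\\max}$. A minimal sketch that checks this for the matrix and step size used above (the names `lambda` and `rho` are illustrative):\n",
"\n",
"```julia\n",
"using LinearAlgebra\n",
"\n",
"A = [1.0 1.0; 1.0 8.0]\n",
"gamma = 0.22\n",
"lambda = eigvals(Symmetric(A))              # real eigenvalues, ascending\n",
"rho = maximum(abs.(1 .- gamma .* lambda))   # worst-case contraction factor per step\n",
"@show lambda\n",
"@show 2 / maximum(lambda)                   # step sizes above this diverge\n",
"@show rho\n",
"```\n",
"\n",
"Here the threshold $2/\\lambda_{\\max}$ is about 0.246, so $\\gamma = 0.22$ is close to the stability limit but still gives `rho < 1`."
]
},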
{
"cell_type": "markdown",
"id": "ddd68837",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Chebyshev regression via optimization\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "5520e580",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"8-element Vector{Float64}:\n",
" 0.7891712932446856\n",
" 0.007025641791364225\n",
" -1.957064629386119\n",
" -0.36715187386667464\n",
" 0.9229823103908219\n",
" 0.3031710357963882\n",
" 0.4284481908323782\n",
" 0.27267026221101087"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = LinRange(-1, 1, 200)\n",
"sigma = 0.5; n = 8\n",
"y = runge_noisy(x, sigma)\n",
"V = vander(x, n)\n",
"function loss(c)\n",
" r = V * c - y\n",
" .5 * r' * r\n",
"end\n",
"function grad(c)\n",
" r = V * c - y\n",
" V' * r\n",
"end\n",
"c, _, lhist = grad_descent(loss, grad, ones(n),\n",
" gamma=0.008)\n",
"c"
]
},
{
"cell_type": "code",
"execution_count": 65,
"id": "77957be4",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cond(V' * V) = 52902.52994792479\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c0 = V \\ y\n",
"l0 = 0.5 * norm(V * c0 - y)^2\n",
"@show cond(V' * V)\n",
"plot(lhist, yscale=:log10)\n",
"plot!(i -> l0, color=:black)"
]
},
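{
"cell_type": "markdown",
"id": "a09d3c6e",
"metadata": {},
"source": [
"The slow convergence is consistent with the standard estimate for gradient descent on a quadratic with Hessian $V^T V$ and condition number $\\kappa = \\lambda_{\\max}/\\lambda_{\\min}$. As a worked estimate, assuming the best fixed step size $\\gamma = 2/(\\lambda_{\\max} + \\lambda_{\\min})$ (not necessarily the value used above), the slowest error component contracts per step by the factor\n",
"\n",
"$$ \\frac{\\kappa - 1}{\\kappa + 1} = 1 - \\frac{2}{\\kappa + 1} \\approx 1 - 3.8 \\times 10^{-5} \\quad \\text{for } \\kappa \\approx 5.29 \\times 10^4 . $$\n",
"\n",
"Reducing that component by a factor of $e$ then takes roughly $\\kappa/2 \\approx 26{,}000$ iterations, far more than the 500-step limit."
]
},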
{
"cell_type": "markdown",
"id": "8c8e60c7",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Why use QR vs gradient-based optimization?"
]
},
{
"cell_type": "markdown",
"id": "7a95e2e9",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Nonlinear models\n",
"\n",
"Instead of the linear model\n",
"$$ f(x,c) = V(x) c = c_0 + c_1 \\underbrace{x}_{T_1(x)} + c_2 T_2(x) + \\dotsb $$\n",
"let's consider a rational model with only three parameters\n",
"$$ f(x,c) = \\frac{1}{c_1 + c_2 x + c_3 x^2} = (c_1 + c_2 x + c_3 x^2)^{-1} . $$\n",
"We'll use the same loss function\n",
"$$ L(c; x,y) = \\frac 1 2 \\lVert f(x,c) - y \\rVert^2 . $$\n",
"\n",
"We will also need the gradient\n",
"$$ \\nabla_c L(c; x,y) = \\big( f(x,c) - y \\big)^T \\nabla_c f(x,c) $$\n",
"where\n",
"\\begin{align}\n",
"\\frac{\\partial f(x,c)}{\\partial c_1} &= -(c_1 + c_2 x + c_3 x^2)^{-2} = - f(x,c)^2 \\\\\n",
"\\frac{\\partial f(x,c)}{\\partial c_2} &= -(c_1 + c_2 x + c_3 x^2)^{-2} x = - f(x,c)^2 x \\\\\n",
"\\frac{\\partial f(x,c)}{\\partial c_3} &= -(c_1 + c_2 x + c_3 x^2)^{-2} x^2 = - f(x,c)^2 x^2 .\n",
"\\end{align}"
]
},
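{
"cell_type": "markdown",
"id": "b1c47e29",
"metadata": {},
"source": [
"The gradient expression follows from variational notation. With residual $r = f(x,c) - y$ and $y$ held fixed,\n",
"\n",
"$$ \\delta L = \\delta \\big( \\tfrac 1 2 r^T r \\big) = r^T \\, \\delta r = \\big( f(x,c) - y \\big)^T \\nabla_c f(x,c) \\, \\delta c , $$\n",
"\n",
"and reading off the factor that multiplies $\\delta c$ gives $\\nabla_c L$. Here $\\nabla_c f(x,c)$ is the $m \\times 3$ matrix whose columns are the three partial derivatives, evaluated at the $m$ data points."
]
},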
{
"cell_type": "markdown",
"id": "247e6403",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Fitting a rational function"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "d6c610e7",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"gradient (generic function with 1 method)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f(x, c) = 1 ./ (c[1] .+ c[2].*x + c[3].*x.^2)\n",
"function gradf(x, c)\n",
" f2 = f(x, c).^2\n",
" [-f2 -f2.*x -f2.*x.^2]\n",
"end\n",
"function loss(c)\n",
" r = f(x, c) - y\n",
" 0.5 * r' * r\n",
"end\n",
"function gradient(c)\n",
" r = f(x, c) - y\n",
" vec(r' * gradf(x, c))\n",
"end"
]
},
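{
"cell_type": "markdown",
"id": "c7e05f38",
"metadata": {},
"source": [
"The hand-derived partial derivatives can be checked against centered finite differences. A minimal sketch that repeats the definitions of `f` and `gradf` so it runs on its own; the sample points `xs`, coefficients `ctest`, and step `h` are hypothetical:\n",
"\n",
"```julia\n",
"f(x, c) = 1 ./ (c[1] .+ c[2].*x + c[3].*x.^2)\n",
"function gradf(x, c)\n",
"    f2 = f(x, c).^2\n",
"    [-f2 -f2.*x -f2.*x.^2]\n",
"end\n",
"\n",
"xs = collect(LinRange(-1, 1, 5))   # hypothetical sample points\n",
"ctest = [1.0, 0.5, 2.0]            # hypothetical coefficients; denominator stays positive on [-1, 1]\n",
"h = 1e-6\n",
"J = gradf(xs, ctest)               # 5 x 3 matrix of partial derivatives\n",
"# Largest mismatch between each column of J and a centered difference\n",
"max_err = maximum(1:3) do j\n",
"    e = zeros(3); e[j] = h\n",
"    fd = (f(xs, ctest + e) - f(xs, ctest - e)) / (2h)\n",
"    maximum(abs.(fd - J[:, j]))\n",
"end\n",
"@show max_err\n",
"```\n",
"\n",
"A check like this is cheap insurance before running an optimizer, since a sign or power error in a hand-coded gradient can otherwise look like a poorly chosen learning rate."
]
},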
{
"cell_type": "code",
"execution_count": 13,
"id": "ffd37ce6",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c, _, lhist = grad_descent(loss, gradient, ones(3), gamma=8e-2)\n",
"plot(lhist, yscale=:log10)"
]
},
{
"cell_type": "markdown",
"id": "7f4be7cd",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Compare fits on noisy data"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "11ab1f9e",
"metadata": {
"slideshow": {
"slide_type": ""
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"scatter(x, y)\n",
"V = vander_chebyshev(x, 7)\n",
"plot!(x -> runge(x), color=:black, label=\"Runge\")\n",
"plot!(x, V * (V \\ y), label=\"Chebyshev fit\")\n",
"plot!(x -> f(x, c), label=\"Rational fit\")"
]
}
],
"metadata": {
"@webio": {
"lastCommId": null,
"lastKernelId": null
},
"celltoolbar": "Slideshow",
"hide_code_all_hidden": false,
"kernelspec": {
"display_name": "Julia 1.7.2",
"language": "julia",
"name": "julia-1.7"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.7.2"
},
"rise": {
"enable_chalkboard": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}