Sampling from a GP

Preliminary steps

Loading necessary packages

using Plots
using AugmentedGaussianProcesses
using Distributions
using LinearAlgebra

Generating some random data

kernel = SqExponentialKernel()
x = range(0, 10; length=50)
K = kernelmatrix(kernel, x)
# Draw a single latent function from the GP prior; the 1e-8 jitter on the
# diagonal keeps the covariance numerically positive definite.
f = rand(MvNormal(K + 1e-8I)) # Sample a random GP
# Binary observations: y[i] ~ Bernoulli(logistic(f[i])), i.e. a Bool per input.
y = rand.(Bernoulli.(AGP.logistic.(f)))
# Re-encode the Bool labels as integers in {-1, +1} (false => -1, true => +1).
y_sign = 2 .* y .- 1
50-element Vector{Int64}:
  1
  1
  1
 -1
 -1
  1
  1
 -1
 -1
 -1
  ⋮
  1
  1
  1
  1
  1
  1
  1
  1
  1

We create a function to visualize the data

"""
    plot_data(x, y; size=(300, 500))

Scatter-plot the observations `y` against the inputs `x` with faint (alpha 0.2),
borderless markers and no legend entry. `size` is forwarded to Plots as the
figure size.
"""
function plot_data(x, y; size=(300, 500))
    marker_style = (alpha=0.2, markerstrokewidth=0.0, lab="")
    return Plots.scatter(x, y; marker_style..., size=size)
end
plot_data(x, y; size=(500, 500))

Run the variational Gaussian process approximation

@info "Running full model"
# Full variational GP on the ±1 labels with a logistic likelihood, fitted by
# analytic variational inference; `optimiser=false` switches off kernel
# hyperparameter optimisation. Then run 5 training iterations (timed).
mfull = VGP(
    x, y_sign, kernel, LogisticLikelihood(), AnalyticVI(); optimiser=false
)
@time train!(mfull, 5)
(Variational Gaussian Process with a BernoulliLikelihood{GPLikelihoods.LogisticLink}(GPLikelihoods.LogisticLink(LogExpFunctions.logistic)) infered by Analytic Variational Inference , (local_vars = (c = [0.8754280027981819, 0.8016574173236087, 0.7275602563879209, 0.666928053209864, 0.6287656411927122, 0.6117386144033556, 0.6086588046776524, 0.617609101312484, 0.6465940025500956, 0.7038493245983231  …  1.2693870867075527, 1.432077792048025, 1.5779889855449594, 1.6899935179163021, 1.7567998047404605, 1.7738551604934378, 1.743098341262098, 1.6719581550096556, 1.5718978869669944, 1.456675158651429], θ = [0.23516942022761236, 0.23741925537415726, 0.2395260845373134, 0.24112789566262766, 0.24207671738468114, 0.2424847769188537, 0.24255756343344664, 0.24234516713878507, 0.2416392951057221, 0.24016598134616987  …  0.22108134874602478, 0.21453084120893604, 0.20844221583447298, 0.20368321166207928, 0.20082447879637214, 0.20009324942015447, 0.20141157650993, 0.20445291554503206, 0.20869930641688217, 0.21351611601117662]), opt_state = (NamedTuple(),), hyperopt_state = (NamedTuple(),), kernel_matrices = ((K = LinearAlgebra.Cholesky{Float64, Matrix{Float64}}([1.0000499987500624 0.9793417135471325 … 1.4538695041232462e-21 1.9286534177037293e-22; 0.9793906794086937 0.20245939866196944 … 4.489420225833076e-20 6.248468235352852e-21; … ; 1.453942195781206e-21 1.0513088244072457e-20 … 0.034540938136573676 0.07728907039399337; 1.9287498479639178e-22 1.453942195781206e-21 … 0.9793906794087034 0.03454092804975722], 'U', 0),),)))

We can also create a sampling-based model (using Gibbs sampling)

@info "Sampling from model"
# Gibbs-sampling counterpart of the variational model. It is fed the same
# ±1-encoded labels (`y_sign`) as the VGP so both models see identically encoded data.
mmcmc = MCGP(x, y_sign, kernel, LogisticLikelihood(), GibbsSampling(); optimiser=false)
m = mmcmc
# Draw 1000 samples of the latent function (timed).
@time samples = sample(mmcmc, 1000)
1000-element Vector{Vector{Vector{Float64}}}:
 [[1.03588000511236, 0.8665938335129812, 0.7431918308535894, 0.6869819022345018, 0.7308256666691322, 0.7917217290196091, 0.9075904822508606, 0.9733346241018603, 1.0378718003135459, 1.0740216723444875  …  1.1541075212901721, 1.1550707924229457, 1.1951251130032012, 1.3330389635369224, 1.5167671980104998, 1.7800644800186696, 2.0493654560793066, 2.3416237462703235, 2.605983657837581, 2.7896494322600063]]
 [[0.3185550147240809, 0.19909111433891896, -0.046775164140757775, -0.3501836826136149, -0.6565027303568248, -0.8583992681073341, -0.9322669537580232, -0.8754981693211129, -0.6946720747971568, -0.4717820918031349  …  -0.02887528340942458, 0.19777280990292723, 0.46185240364835556, 0.728114468840688, 0.9755065005223003, 1.1452573317495314, 1.256566945921038, 1.2842915028883835, 1.2067260939117146, 1.0425484013311455]]
 [[-0.3204591346735259, -0.42689007188705635, -0.48802503014401205, -0.5008039319769456, -0.4377029989370466, -0.2100220501035158, 0.15406404513568506, 0.5651166909915981, 1.015051000201769, 1.4536836004471958  …  -0.20491401629277317, -0.16916410771432489, -0.13169465498435273, -0.10280393250747433, -0.054232094894074434, -0.02923722556619479, 0.03217402173897321, 0.059848338041946114, 0.09873494915015257, 0.13983929787972382]]
 [[0.10289722510834082, 0.14103478524269447, 0.15493017548076832, 0.0525686651780794, -0.08425644044769537, -0.25502863469923753, -0.3938780621994951, -0.4555595563292951, -0.39840957435647406, -0.23481047746640799  …  0.8920741507807516, 1.0980589129562985, 1.3295332891426552, 1.4971658309468903, 1.5244576707116537, 1.3963113642043294, 1.0901599191517366, 0.615795131366764, 0.06820980243192942, -0.49194799066738215]]
 [[1.117063505651636, 0.9054035514033438, 0.6318717981352082, 0.3087640239198286, -0.0004917193763026018, -0.2657651382873503, -0.46232124648841094, -0.5613300924322221, -0.4802851158171905, -0.33097493457282695  …  1.1595114434181415, 1.308242679524583, 1.4591876341458798, 1.5501862386721976, 1.6228274857332416, 1.6724822676313922, 1.6824296100601286, 1.631038991156763, 1.542362068642813, 1.3662905731718786]]
 [[2.5953799034242353, 2.5106047458778407, 2.259815329045826, 1.9004420278046985, 1.452978871022773, 0.925869650064709, 0.4617481675303407, 0.10770110003715358, -0.11453531407066136, -0.16503323704523098  …  -0.18563296960807552, 0.0556968786284483, 0.2826693439585253, 0.49367315776103216, 0.6307391644500588, 0.6368997403034087, 0.5748189696950681, 0.45982181550338463, 0.3645701572643687, 0.2609559601379461]]
 [[-0.43533182963367967, -0.508866428226572, -0.575332487416349, -0.568908341069857, -0.4948529643955661, -0.3412499724987401, -0.07327814152596152, 0.23249457204786037, 0.5168264814810558, 0.7730640481799972  …  2.154997791904152, 2.124462537501796, 2.0725494144065424, 2.1107005911383734, 2.25347318150025, 2.441685281652774, 2.7076386890786104, 2.9077353781405324, 2.971356719025219, 2.838901390847778]]
 [[1.851756814135838, 1.8047166712234157, 1.6826268502385495, 1.4395233961477063, 1.1639173834882865, 0.8266263274703469, 0.48816282153890844, 0.2567954786904236, 0.1556399231251771, 0.18120182269401752  …  1.6186006963542912, 1.7929941175193582, 1.8855981171402212, 1.901738925952382, 1.8770052805620117, 1.8544103144046606, 1.8562987081866131, 1.884829368253902, 1.9241518421016244, 1.9744793423128382]]
 [[0.23521089614848645, 0.38774902967941816, 0.4843258300695398, 0.523198918794793, 0.5897970124899637, 0.6602594681507984, 0.7331491630281735, 0.8560149869985336, 1.0136105606144883, 1.1769424301761846  …  1.0090194144425724, 1.17519933680965, 1.2823129364086778, 1.3676551446568916, 1.3654996756287576, 1.3472737593519941, 1.2665639943332525, 1.179208246976181, 1.0701780992726093, 0.9122604066591304]]
 [[0.8994532244873231, 0.8402862276762235, 0.7936164915125563, 0.7198123978603742, 0.6669135017579791, 0.7021066175582333, 0.8575733708369768, 1.1023875044900744, 1.4375282109282903, 1.7299597522968613  …  1.754606510096628, 1.6196416715487212, 1.5557258777099112, 1.5120840054420992, 1.521775676543065, 1.5524338766936971, 1.6691725180747214, 1.7913254536443266, 1.9456201204730372, 2.099695700200599]]
 ⋮
 [[2.142790665383947, 1.9953774551122923, 1.6460608471657525, 1.1941550513488797, 0.6871676696072114, 0.2537816402177395, -0.03394762279428516, -0.13030637378317592, -0.008592183488040195, 0.2402276536777241  …  1.8123102180442534, 1.7359772079907672, 1.6659380055707098, 1.5885818475329463, 1.4868231906025717, 1.330383453773803, 1.1480570220752302, 0.9219494658412525, 0.6504422796363643, 0.3881756116609419]]
 [[1.0793623905052816, 1.2312835817162329, 1.279589455062725, 1.194850809989223, 1.063856605485046, 0.9287125878461889, 0.8256620069744616, 0.7549893415167435, 0.7351164159547532, 0.7419906039251378  …  1.0387024633103412, 1.1271264547453876, 1.13242773075797, 1.078791401785542, 0.9812027850850651, 0.9357563566463629, 0.9722477883004925, 1.1122649043846042, 1.3859577573461144, 1.7595729643269526]]
 [[2.3530074134108876, 1.942173291180622, 1.5143846269880328, 1.1441628658142342, 0.8958277463728254, 0.7736031476559073, 0.8080461855690004, 0.9541576102992128, 1.135141824979654, 1.2902363958645826  …  1.1334572312465914, 1.2714169212817326, 1.3266767999338722, 1.3319896750225872, 1.341361546515117, 1.3235823843142838, 1.3347877327654891, 1.372134505256873, 1.4424439523743462, 1.4894140333891923]]
 [[-0.673334591866461, -0.6188662000070337, -0.5884093921721674, -0.478396685435988, -0.4064606745719163, -0.2895944581411177, -0.14148331150668564, 0.03971749114931575, 0.23463902655689828, 0.41128601602137754  …  0.6791438685110335, 1.0846276161325472, 1.3906188529464376, 1.5775479601012798, 1.62220676101967, 1.4999726074570736, 1.258045793040635, 0.9943398957715255, 0.7652292416099267, 0.5994430078499196]]
 [[0.8950423932924649, 0.7434394179118422, 0.5052777469321739, 0.17507609136212277, -0.19781583165074507, -0.5253098565639533, -0.7925424281066589, -0.933711572388777, -0.9076827908310001, -0.7036669801429052  …  0.679478217907886, 0.6759419874459471, 0.6501467944106545, 0.5485697119824604, 0.443187322235695, 0.35569216836820283, 0.3152171029785744, 0.32341842453485325, 0.4304068476583771, 0.5778139242836905]]
 [[-0.5240332159018871, -0.4395309963234583, -0.432177888026915, -0.42611363769026006, -0.4235216793713237, -0.32373819611029214, -0.15288621926186668, 0.05455573616555548, 0.3133347439619225, 0.5948593848231893  …  1.619177079470236, 1.6347822351219863, 1.62200698244699, 1.498145526430743, 1.3121898711886841, 1.0944531655683205, 0.8131465102324246, 0.4877035136597102, 0.19098000057550246, -0.10891913152319099]]
 [[2.09397592433748, 1.8025286562637324, 1.4691668508263944, 1.0639774852349966, 0.6404544187505341, 0.23799838228047984, -0.09617344052656052, -0.35286933309311797, -0.5070645784348626, -0.5852436956671379  …  0.04131294970278154, 0.10147703129516805, 0.2381365366286754, 0.4103265705393597, 0.5783744708113052, 0.7012222711310159, 0.7847551803698027, 0.8160821141630947, 0.7493143070658559, 0.6486963417216488]]
 [[1.183436288135546, 0.8291743193109005, 0.47206039405518474, 0.17768580139505288, 0.031285844971386534, 0.06829374408114694, 0.27550033431993537, 0.5951083905352647, 0.9230248111181054, 1.2179950462011884  …  0.231511540159633, 0.5979961114261358, 0.9874137547756529, 1.3495174614249514, 1.6247190360127386, 1.7864116505624106, 1.847380791113246, 1.8052476142431182, 1.6726684722707645, 1.5152243320462078]]
 [[1.442351495809, 1.132333629097298, 0.7768351143602444, 0.4367571482876603, 0.16271635349585623, -0.02335895703125121, -0.09531096196094807, -0.056175483112862856, 0.10955797713890467, 0.39082767491015724  …  1.3404428856413657, 1.2139349356308686, 1.1174395237730874, 1.1203786696390108, 1.1713666302531576, 1.3544082932178232, 1.5847582190277252, 1.8526464646606076, 2.0934552079420667, 2.2575628144267497]]

We can now visualize the results of both models

We first plot the latent function f: the ground truth, the VI estimate, and the MCMC samples

# Latent function: ground truth, the Gibbs samples and the VI estimate.
p1 = plot(x, f; label="true f")
# Each element of `samples` wraps its latent vector in an outer one-element
# vector (`Vector{Vector{Vector{Float64}}}`). `only` unwraps it, so every
# draw becomes one faint series of length `length(x)`.
plot!(x, only.(samples); label="", color=:black, alpha=0.02)
# VI mean with a ±1 standard-deviation ribbon.
plot!(x, mean(mfull[1]); ribbon=sqrt.(var(mfull[1])), label="VI")

We can also plot the predicted probabilities against the data

p2 = plot_data(x, y; size=(600, 400))
# `proba_y` returns the predictive probability together with its *variance*.
# `ribbon` expects a standard deviation, so take the square root. The variance
# is clamped at zero to guard against tiny negative round-off.
μ_vi, σ²_vi = proba_y(mfull, x)
σ_vi = sqrt.(max.(σ²_vi, 0))
plot!(x, μ_vi; ribbon=σ_vi, label="VI")
μ_mcmc, σ²_mcmc = proba_y(mmcmc, x)
σ_mcmc = sqrt.(max.(σ²_mcmc, 0))
plot!(x, μ_mcmc; ribbon=σ_mcmc, label="MCMC")

This page was generated using Literate.jl.