Sampling from a GP

Preliminary steps

Loading necessary packages

using Plots
using AugmentedGaussianProcesses
using Distributions
using LinearAlgebra

Generating some random data

# Squared-exponential kernel evaluated on 50 evenly spaced inputs in [0, 10]
kernel = SqExponentialKernel()
x = range(0, 10; length=50)
K = kernelmatrix(kernel, x)

# Draw one latent function from the GP prior; the 1e-8 diagonal jitter keeps the
# covariance numerically positive definite for MvNormal.
f = rand(MvNormal(K + 1e-8 * I))

# Map the latent values through the logistic function and draw one Bernoulli label each
y = @. rand(Bernoulli(AGP.logistic(f)))

# The same labels re-encoded from {false, true} to {-1, 1}
y_sign = @. Int(sign(y - 0.5))
50-element Vector{Int64}:
 -1
  1
  1
 -1
  1
 -1
 -1
 -1
 -1
 -1
  ⋮
  1
 -1
 -1
  1
 -1
  1
  1
 -1
 -1

We create a function to visualize the data

"""
    plot_data(x, y; size=(300, 500))

Scatter `y` against `x` with translucent, stroke-less markers and no legend entry.
`size` is forwarded to `Plots.scatter` as the figure size.
"""
plot_data(x, y; size=(300, 500)) =
    Plots.scatter(x, y; size, alpha=0.2, markerstrokewidth=0.0, label="")

plot_data(x, y; size=(500, 500))

Run the variational Gaussian process approximation

@info "Running full model"
# Full variational GP on the {-1, 1} labels, fitted with analytic variational
# inference; `optimiser=false` disables hyperparameter optimisation.
mfull = let likelihood = LogisticLikelihood(), inference = AnalyticVI()
    VGP(x, y_sign, kernel, likelihood, inference; optimiser=false)
end
# Run 5 training iterations and report the elapsed time
@time train!(mfull, 5)
(Variational Gaussian Process with a BernoulliLikelihood{GPLikelihoods.LogisticLink}(GPLikelihoods.LogisticLink()) infered by Analytic Variational Inference , (local_vars = (c = [0.7268688519271366, 0.6835118083607156, 0.6695566835093446, 0.7072217436731908, 0.8064486275080067, 0.9571508620286138, 1.1372811201426742, 1.3244791318411353, 1.500664123702649, 1.6535621857042413  …  0.7025260788078077, 0.6845433297681739, 0.6634782788070405, 0.6454618108106023, 0.6346721444247598, 0.6338307573060205, 0.6449926927064928, 0.6699641377965432, 0.7086332783369658, 0.7577622642768607], θ = [0.23954498463074714, 0.24070110184943624, 0.24106082441326024, 0.24007601274839913, 0.23727762185416854, 0.2325140232179493, 0.22613590884863385, 0.21889808990894175, 0.21168861304801456, 0.20523685890452112  …  0.24020118726734263, 0.24067426974739795, 0.24121558651285327, 0.24166737991273946, 0.2419329443544306, 0.24195349444434558, 0.24167900459731315, 0.24105040837002795, 0.2400382527617409, 0.23868660215795917]), opt_state = (NamedTuple(),), hyperopt_state = (NamedTuple(),), kernel_matrices = ((K = LinearAlgebra.Cholesky{Float64, Matrix{Float64}}([1.0000499987500624 0.9793417135471325 … 1.4538695041232462e-21 1.9286534177037293e-22; 0.9793906794086937 0.20245939866196944 … 4.489420225833076e-20 6.248468235352852e-21; … ; 1.453942195781206e-21 1.0513088244072457e-20 … 0.034540938136573676 0.07728907039399337; 1.9287498479639178e-22 1.453942195781206e-21 … 0.9793906794087034 0.03454092804975722], 'U', 0),),)))

We can also create a sampling-based model, which uses Gibbs sampling instead of variational inference

@info "Sampling from model"
# Same inputs and kernel, but the posterior is explored with Gibbs sampling;
# note this model is given the Bool labels `y` rather than `y_sign`.
# `m` is kept as a short alias of `mmcmc`.
m = mmcmc = MCGP(x, y, kernel, LogisticLikelihood(), GibbsSampling(); optimiser=false)
# Draw 1000 samples (timed)
samples = @time sample(mmcmc, 1000)
1000-element Vector{Vector{Vector{Float64}}}:
 [[1.1987772738903215, 1.0626602880803855, 0.8254462650963545, 0.4804301055609584, 0.059044285426006315, -0.3325928102033546, -0.6493525210664619, -0.8768270787921308, -0.991443780365423, -1.0474669446015117  …  -0.2867232384928631, -0.43607506785749484, -0.5352244830490909, -0.6579203277397203, -0.7310969207307079, -0.7729361923955621, -0.7705022873905667, -0.8189020060734118, -0.8641034809144862, -0.8833218941470851]]
 [[1.0782063968438542, 1.2142655215688105, 1.1038294457014184, 0.7107993098112406, 0.11322299964961269, -0.6190982337866073, -1.3440008055893164, -1.9528217753569281, -2.369473707476083, -2.5840392917227355  …  -0.35646471348308056, -0.1935425390615092, -0.08735299938927785, -0.03828616368456619, -0.1221183900012801, -0.28113364573643507, -0.5483929377553849, -0.8427486817112102, -1.126765702561727, -1.3841177727243397]]
 [[-0.45709197422065895, -0.8336186826547725, -1.1179112296590439, -1.3579831901601513, -1.525151277008475, -1.6005962282133415, -1.6072910144763113, -1.5485468428064206, -1.4579605382327545, -1.4131009272693782  …  0.6494098792561656, 0.6884187472914299, 0.6356132261921111, 0.4854591900314239, 0.18914777632094715, -0.206313982850915, -0.6651134422167932, -1.1134916842028042, -1.4796116850589411, -1.767770014143595]]
 [[-1.3975719679294487, -1.3719368963676084, -1.371651682867419, -1.473772275139685, -1.725020368304538, -2.067821907058688, -2.517860356288251, -2.9691907476994848, -3.372357060174525, -3.663173409279381  …  -0.33511240301934236, -0.5123895010020223, -0.661845185994863, -0.7785185485396082, -0.9108442417780652, -1.0248925325236855, -1.1416054842141585, -1.1952985705256727, -1.2863688294449163, -1.3790116231129022]]
 [[1.4216500835360437, 1.0806935450904247, 0.5448726387228692, -0.011418779177443383, -0.6011057314438053, -1.1142282446513108, -1.5102404930356923, -1.788712661319774, -2.0522468626737904, -2.2338880855982564  …  -0.11531308589331743, -0.10035716820361404, -0.0912239643994478, -0.01295197792491326, 0.0525421190625007, 0.12664189529443784, 0.1300768868937587, 0.023894615866770108, -0.23999244254677884, -0.6129542344381254]]
 [[0.7777480939812128, 0.6934466126282481, 0.4250833401568493, 0.007702810302029295, -0.5021084919749748, -1.0566881068611216, -1.6182386640694653, -2.1191236515814946, -2.5188725819861624, -2.8104575752460885  …  -0.5296760522454922, -0.41425119619888634, -0.3666104906311517, -0.28594413292231824, -0.2412320772044007, -0.16132586679015273, -0.06284283933404156, 0.033787744888611276, 0.1462457598373954, 0.23877803981300494]]
 [[1.5908535398771517, 1.5717453521590312, 1.33408645412333, 0.9002774256648921, 0.29900718383299996, -0.33339118180153493, -0.9402025450737429, -1.4025788955689058, -1.755063818458122, -1.9405874408894177  …  -0.04512709673575738, 0.20053872134527517, 0.4550152961627843, 0.6615294696730001, 0.7602627716995543, 0.6968464949082231, 0.4650698876428886, 0.13401377262158945, -0.2665232775517152, -0.6983551457280376]]
 [[0.5463834751633856, 0.5070710747137659, 0.4271659732728558, 0.3491788263342075, 0.23892205813561573, 0.05651542797259479, -0.18685513283497357, -0.5437711710050515, -0.9498635642095419, -1.4021731525274124  …  -1.3119870435157288, -1.2950923693950582, -1.2730545697499265, -1.2018514065616845, -1.1124807648559287, -1.0670682336290884, -1.0733505904281775, -1.0523254150639447, -1.050580712374384, -1.0280759032039883]]
 [[1.5009010883402818, 1.2268445801560794, 0.9268881680157782, 0.6138575434096509, 0.27160965721242747, -0.10874323141585951, -0.534678553551557, -0.9784309977239485, -1.446324304551272, -1.8674570646745474  …  -1.1792355267675578, -1.2363393601559678, -1.223108214722293, -1.1330346176265707, -0.9295206733116476, -0.6211867184948823, -0.21889819339686262, 0.2364785514529038, 0.6003666018715887, 0.8932220323764323]]
 [[-0.07529846762842884, -0.23600137605786597, -0.4941665351923584, -0.8400376286578658, -1.2313903907463462, -1.6138039645442879, -1.9283492120504266, -2.1625043849623102, -2.278447552909056, -2.2894259183197705  …  -0.9340656566652998, -1.1853292835712312, -1.3287785338828288, -1.4073619951818204, -1.4246217746639958, -1.3368397368168972, -1.1450656777223063, -0.8639392779760842, -0.47901577325876543, -0.061500185234524024]]
 ⋮
 [[-0.030911190716562828, 0.14804874832262996, 0.2382832943528112, 0.2427073595221414, 0.14144186760374344, -0.03800957171464836, -0.22962802950100347, -0.46995323410453715, -0.7199047183015084, -0.9572149721601062  …  0.2064010213862597, 0.2372403208238234, 0.32030422982223494, 0.40678367287655914, 0.46142804348837696, 0.48615839575720565, 0.4726602955406138, 0.3649391772531057, 0.20294002597884814, 0.02036052824426604]]
 [[-0.006521258082229742, -0.24858179406177533, -0.5489901452393564, -0.8524215659179084, -1.1671418908296154, -1.4402171473330823, -1.667741832161242, -1.842793335663639, -1.9810399295855443, -2.09036996849746  …  0.5324300816515344, 0.464983451686276, 0.3231272663120737, 0.14366072879800748, -0.05169471827733049, -0.2591854783982982, -0.46571735033311346, -0.6416940839403443, -0.7617904593903275, -0.9101088145249614]]
 [[1.0501799664831837, 0.7372770929301343, 0.3683525188321548, -0.04817831952660678, -0.4915263275947954, -0.9571836022851095, -1.4092222231528335, -1.8237495644210688, -2.1401291649022727, -2.406592222659046  …  -0.42199836761915194, -0.48708750747946905, -0.4790914599599898, -0.4300597398494042, -0.36066090349569435, -0.32713995234921756, -0.3313667690013663, -0.35643343352865, -0.39919694251482063, -0.403186700654797]]
 [[-0.8095182752263433, -0.7163698575968257, -0.6425123106957665, -0.6678078515512932, -0.7886812241227763, -0.9609896529740624, -1.2011362841736901, -1.4780586538556886, -1.7156567858735519, -1.9114957201546967  …  -0.4467384118321283, -0.47560404382271565, -0.43394924932406326, -0.35332931837800124, -0.2885669839534609, -0.2676534243133115, -0.27923690139578944, -0.37468721774817554, -0.5083121974522276, -0.5912824346855667]]
 [[0.47618463261242217, 0.3132369546926018, 0.06893674163491861, -0.2480168895811698, -0.5922472956012514, -0.9507524238699329, -1.1944456613521757, -1.2859333103079706, -1.2758542042572527, -1.1652088961625666  …  -0.20174395812120677, -0.036290684411797586, 0.17736697673667406, 0.32744204152575207, 0.453111837540595, 0.5254602120434281, 0.5267489051084443, 0.43607133667572495, 0.3282854243319616, 0.16834086955724858]]
 [[0.6025571246344351, 0.3004056261890749, -0.21559998055837204, -0.831365569408755, -1.4128049645958174, -1.8991377871324382, -2.2051113604741714, -2.340560103298859, -2.341107599340484, -2.263578566276431  …  -0.604130797156291, -0.5843438295758092, -0.5284685541112775, -0.46337315468977025, -0.39529443183829915, -0.32388531440770146, -0.3036132649598269, -0.3656334020232844, -0.4330226039483557, -0.5471336519860786]]
 [[-0.029224753830451443, -0.47276540472228096, -0.8712785246169051, -1.2060544081365894, -1.496073243891201, -1.7200464422603403, -1.8350662231180903, -1.8769018661338834, -1.7963388509555762, -1.6537401299091175  …  -0.7455840137123335, -0.7974276568745027, -0.7903378695491685, -0.7033483529149348, -0.5466776626846825, -0.33678941668081874, -0.07845177400316525, 0.18188273000813532, 0.43374101070030013, 0.6487228618276203]]
 [[0.5282328298215987, 0.38159959100042373, 0.16548109334895172, -0.06256552812030408, -0.2706343677036811, -0.4796265918244418, -0.7091173507459086, -0.9140093216838332, -1.1574657977677918, -1.455783842175181  …  -0.7876095169446913, -1.050349861537839, -1.215380680668162, -1.2928350339328265, -1.3572056648816404, -1.2998673176858797, -1.2760956790059215, -1.2149124695884181, -1.1313543968604054, -0.9962739231175588]]
 [[-0.6582068402287555, -0.7935098526352072, -0.9139365161497157, -1.0668097127338676, -1.1559920874445195, -1.2149816083412897, -1.1783990007714684, -1.0954555091790272, -0.9938015536528918, -0.9092103904942155  …  0.07943490285614702, 0.12509263668882709, 0.20940685465088182, 0.22446917068654654, 0.21762231179152233, 0.25814578880202166, 0.26674117629134375, 0.27270213801338417, 0.25666830176481226, 0.13388675394605587]]

We can now visualize the results of both models

We first plot the latent function f: the ground truth, the VI estimate, and the Gibbs samples

# Ground-truth latent function
p1 = plot(x, f; label="true f")
# Each element of `samples` holds one vector per latent GP (a single one here), so
# take the first entry of every draw and stack the draws as matrix columns:
# Plots draws each column as its own faint series.
plot!(x, reduce(hcat, first.(samples)); label="", color=:black, alpha=0.02)
# Posterior mean of the VI model with a ±1 standard-deviation ribbon
plot!(x, mean(mfull[1]); ribbon=sqrt.(var(mfull[1])), label="VI")