Sampling from a GP

Preliminary steps

Loading the necessary packages

using Plots
using AugmentedGaussianProcesses
using Distributions
using LinearAlgebra

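Generating a synthetic dataset: we draw a latent function `f` from a GP prior with a squared-exponential kernel on 50 evenly spaced points in [0, 10], then sample binary labels `y` ~ Bernoulli(logistic(`f`)). `y_sign` maps these labels to -1/+1.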

# Squared-exponential kernel evaluated on 50 evenly spaced inputs in [0, 10]
kernel = SqExponentialKernel()
x = range(0, 10; length=50)
K = kernelmatrix(kernel, x)
# One latent function drawn from the GP prior; the tiny diagonal jitter is
# presumably there to keep K numerically positive definite
f = rand(MvNormal(K + 1e-8I))
# Binary labels: y[i] ~ Bernoulli(logistic(f[i])), giving a Vector{Bool}
y = map(fi -> rand(Bernoulli(AGP.logistic(fi))), f)
# Same labels encoded as -1.0 / 1.0
y_sign = ifelse.(y, 1.0, -1.0)
50-element Vector{Float64}:
  1.0
 -1.0
 -1.0
 -1.0
 -1.0
 -1.0
 -1.0
  1.0
 -1.0
 -1.0
  ⋮
  1.0
  1.0
  1.0
  1.0
  1.0
  1.0
  1.0
  1.0
  1.0

We create a function to visualize the data

"""
    plot_data(x, y; size=(300, 500))

Scatter-plot the observations `y` against the inputs `x` using translucent,
borderless markers with no legend entry, and return the resulting plot.
"""
function plot_data(x, y; size=(300, 500))
    marker_style = (alpha=0.2, markerstrokewidth=0.0, lab="")
    return Plots.scatter(x, y; marker_style..., size=size)
end
plot_data(x, y; size=(500, 500))

Run the variational Gaussian process (VGP) approximation

@info "Running full model"
# Full (non-sparse) variational GP on the ±1 labels, with analytic
# variational updates; kernel hyperparameters are kept fixed
mfull = VGP(
    x,
    y_sign,
    kernel,
    LogisticLikelihood(),
    AnalyticVI();
    optimiser=false,
)
# Five training iterations, timed
@time train!(mfull, 5)
[ Info: Running full model
  0.001201 seconds (250 allocations: 897.969 KiB)

We can also create a sampling-based model, using Gibbs sampling

@info "Sampling from model"
# Sampling-based GP with the same kernel and likelihood as the VI model.
# It is fed the same ±1 labels (`y_sign`) so that both models see identical targets.
mmcmc = MCGP(
    x,
    y_sign,
    kernel,
    LogisticLikelihood(),
    GibbsSampling();
    optimiser=false,
)
# Draw 1000 samples with the Gibbs sampler, timed
@time samples = sample(mmcmc, 1000)
1000-element Vector{Vector{Vector{Float64}}}:
 [[-0.014698455470761096, -0.2606842291029119, -0.5747296460687921, -0.8476086714886385, -1.0643196039076182, -1.2542348031293136, -1.3650390573319462, -1.4483829901313894, -1.4647996118018591, -1.5098525393440219  …  -0.6645916949707495, -0.23299073303666196, 0.14348371869399668, 0.4842577914728982, 0.8007928605238441, 1.0820387207066244, 1.304694067096788, 1.4636205491637313, 1.5678377831804235, 1.5572041683841147]]
 [[0.13814311594407674, 0.3579116747321309, 0.5118141353433958, 0.5446455816140432, 0.44658872890956536, 0.20997788810686235, -0.1213426970506768, -0.5283061271022496, -0.9058061798376846, -1.183535753655523  …  0.34693416006864325, 0.5843181504453661, 0.752893874300611, 0.8704194047022101, 0.9728052177664355, 1.0626670686993984, 1.1013464515843787, 1.091679235803359, 1.029616376363528, 0.9432989785016939]]
 [[-0.11241074367192383, -0.21767399291390088, -0.43611299873035014, -0.7478363295415273, -1.1126177384429508, -1.474532009937946, -1.7316628338369657, -1.9711103951598872, -2.1470333054489132, -2.216042360609863  …  0.15793012321309963, 0.22526963419806267, 0.28285444358217293, 0.33661277462189754, 0.3789616069054611, 0.42294554519929317, 0.4180458663434534, 0.3596525078356214, 0.1575992484061879, -0.09565794543488781]]
 [[-0.57193495838684, -0.63405016807356, -0.6283508286712771, -0.5862837420791737, -0.5194412055188121, -0.39285088180573424, -0.31105900592912317, -0.2429955507464585, -0.15533432183803086, -0.15438251527733826  …  0.1916816130028308, 0.5426731920852058, 0.9419696567466566, 1.2928643849346373, 1.5835483912896404, 1.787771135779797, 1.827330485950759, 1.7013384463349697, 1.4165897443068118, 1.0366835185680852]]
 [[-1.211704267424024, -1.1234101499228981, -1.0494556551417273, -1.0603019236555116, -1.1454520231056273, -1.2432691520736947, -1.4152817998914016, -1.583906697050999, -1.7613505431057899, -1.9031660156920156  …  0.6586210572449736, 0.9359890449014294, 1.0942210616453745, 1.2608238192434258, 1.4821204088431539, 1.8047816895800697, 2.2475907689209333, 2.7203683352296313, 3.196404872564133, 3.613822719937569]]
 [[-0.30376388694365686, -0.39062241357907745, -0.44081378016116196, -0.42022603697609384, -0.40065113003517083, -0.3379101241786574, -0.20309764168316224, -0.08553148619566675, -0.02430321983876793, -0.06284050252628992  …  -0.2121633875144166, 0.29639949018837036, 0.8131246798707139, 1.2507475245092994, 1.5920908898636668, 1.8274350851512104, 1.9303572889397576, 1.9271559687608575, 1.8712576286625864, 1.7533581634646411]]
 [[0.44229827297065794, 0.11863816935652427, -0.3116753089422817, -0.7507623401090359, -1.201294572049276, -1.5631845368070159, -1.8316368550904318, -1.9317212631364482, -1.9246876410596983, -1.8105708045747013  …  0.3722856544030999, 0.7405684132272676, 1.0873236132736546, 1.3564642120843282, 1.5298026760375787, 1.5709391306342897, 1.4750676420860398, 1.3205987533111654, 1.1309539840465945, 0.9068524650946743]]
 [[-0.12939751754789608, -0.3175294708378315, -0.5076562909494117, -0.7161207395197218, -0.9021206840058608, -1.1009249207512721, -1.2616783488659231, -1.3400544130175471, -1.332842256197198, -1.2353287527481305  …  0.19063354761082185, 0.5676708380883853, 0.9478062270795963, 1.282046947813016, 1.5737360052802942, 1.7673322378768552, 1.9047604610654842, 1.9076969838115674, 1.8659846731052845, 1.7435691222068035]]
 [[-1.7736327095401498, -2.176179530923793, -2.443727858857337, -2.527053682128021, -2.427027029429727, -2.158399400328447, -1.817189067554951, -1.4583080041118879, -1.1171029409687416, -0.8575591457352518  …  -0.020920914167655916, 0.20967579590079596, 0.4673341577279797, 0.8122259589943074, 1.0693965548610136, 1.2829096899834813, 1.4149659532427212, 1.4622511783327958, 1.4153500923615117, 1.2814053076979277]]
 [[-0.9601395691184522, -1.0741333735105503, -1.2013051643694819, -1.3475582344028694, -1.5228683205201188, -1.6442464261688716, -1.7436803638074538, -1.737201450385095, -1.5701757378140166, -1.2537695541937541  …  0.8190665344229676, 1.0346724247432322, 1.204939651884069, 1.3326767849946113, 1.4479340257683193, 1.5600417869409586, 1.7321851910539001, 1.9112843787254543, 2.0553685957676517, 2.1636836716726027]]
 ⋮
 [[0.26047124479878303, 0.07273895364200256, -0.09766077360326064, -0.27603521963583, -0.36741472863156943, -0.40623541638962846, -0.4014096099272457, -0.326826125474187, -0.2039462593329171, -0.08482882963470262  …  1.2962929362953706, 1.7155678331051216, 2.0848343616768, 2.354477601690575, 2.5416582031106105, 2.673636274146658, 2.6657857856384126, 2.635279580950429, 2.4961134479041642, 2.2606609302956686]]
 [[0.690365532041324, 0.4217067644905026, 0.13664485115621317, -0.18580667217728186, -0.46719913055402923, -0.7026433590069381, -0.7820886558825257, -0.7420568433075396, -0.6099266468026338, -0.44496973305869925  …  -0.24715598184031934, 0.19383373101448242, 0.5875997484354917, 0.9596766926344658, 1.240329489574642, 1.4381591359787667, 1.5800032260285686, 1.6455349745918646, 1.6245215197274265, 1.4858042886454679]]
 [[-0.8627918310997778, -1.219418086561622, -1.5180873825840193, -1.721764944918637, -1.8484247343437346, -1.8345283770745113, -1.7470419218548192, -1.601763533422972, -1.3854750531634163, -1.1783372528150307  …  -0.027938946283974087, 0.16701068783267115, 0.4608823970009598, 0.8020640486592915, 1.1894948926304543, 1.5612179226000689, 1.8770779651359284, 2.136945682552444, 2.2644156631231844, 2.299447034987501]]
 [[1.1288726562263267, 0.593619607942639, 0.06793578257153343, -0.4541988107748717, -0.8541620808862078, -1.1001799884137318, -1.226242667678614, -1.2569991348546687, -1.1846478300422385, -1.0817678130682609  …  -0.2822308228347663, 0.16564009371697463, 0.5907663762395692, 0.9281692364833475, 1.18205187408091, 1.3713096169143553, 1.4491536027673908, 1.381308810747439, 1.2409757057564592, 1.0001015996246023]]
 [[-0.9839858974590066, -0.9901114696644299, -1.0639288392151973, -1.1762734171756715, -1.2749579373758093, -1.3281423834201307, -1.3184433869713335, -1.2312379930887682, -1.0972271802339704, -1.0443395092282526  …  -0.3904734979744384, -0.023206136521283782, 0.26110526447926635, 0.4721521187400826, 0.5879717005656202, 0.5992748897905723, 0.64012559440061, 0.6110270177850754, 0.594620267611407, 0.5520975321661459]]
 [[-2.052837195035986, -2.27509724082655, -2.4495006184496986, -2.5376175464107993, -2.5514177729447702, -2.51204082761456, -2.4228632547365723, -2.3096493895031855, -2.1599927252385225, -1.960765112176504  …  0.2065778102562575, 0.6342344392109767, 1.0138227923711933, 1.3034292658439504, 1.4445528252579873, 1.3972292216142619, 1.2076743039193785, 0.9668743734391229, 0.6238589557290756, 0.3087525639378672]]
 [[-0.6592061657569372, -0.9422114489083857, -1.196969614922026, -1.3869411228771056, -1.5528229812572487, -1.6835226188352994, -1.7799464657534805, -1.897641596441726, -1.9201859046338083, -1.936671902168356  …  0.17519792372103365, 0.30332269770189313, 0.3518384507853386, 0.427535374704895, 0.5372233640413239, 0.6514249860956796, 0.7711999017717223, 0.8329662202005983, 0.8812184943669265, 0.8122566033565221]]
 [[-0.4468244604691769, -0.2525431758818718, -0.10522887810255044, 0.005224173240489183, 0.02551803456805457, -0.05844314850906751, -0.27014526031894126, -0.6242309690610751, -1.0188480447673611, -1.4378515879040001  …  0.6216029669563777, 1.0397621037558706, 1.4198049216142519, 1.7271045697369662, 2.039426432008374, 2.29584071121226, 2.495299851586293, 2.6254331130948, 2.638516745359973, 2.508585735239729]]
 [[-0.761546753106194, -1.0613990428819273, -1.3862945236814113, -1.6133358238584083, -1.7665002333880517, -1.7880068811815284, -1.7013719973991446, -1.5381875511427665, -1.3569195678734016, -1.2169355836723466  …  0.02099852999776114, 0.08357484372302848, 0.18258731848564158, 0.32339354392083486, 0.44245856155323726, 0.6136528043614053, 0.7126228768429053, 0.8007580221517201, 0.7788298906033436, 0.787466375783348]]

We can now visualize the results of both models

We first plot the latent function f: the true function, the MCMC samples, and the VI estimate

# Latent function: ground truth, MCMC samples, and VI posterior mean with a
# ribbon of the square root of the posterior variance
p1 = plot(x, f; label="true f")
# Each entry of `samples` holds one vector per latent GP. This model has a
# single latent, so `first.` gives one curve over `x` per sample.
plot!(x, first.(samples); color=:black, alpha=0.02, label="")
plot!(x, mean(mfull[1]); ribbon=sqrt.(var(mfull[1])), label="VI")

We can also plot the predictions of both models against the data

# Observed labels with the predictive output of both models overlaid
p2 = plot_data(x, y; size=(600, 400))
# `proba_y` is unpacked into a central value and a spread term per input.
# NOTE(review): if the second value is a variance rather than a standard
# deviation, the ribbons should use sqrt.(σ_vi) / sqrt.(σ_mcmc) — confirm
# against the AugmentedGaussianProcesses docs.
μ_vi, σ_vi = proba_y(mfull, x)
μ_mcmc, σ_mcmc = proba_y(mmcmc, x)
plot!(x, μ_vi; ribbon=σ_vi, label="VI")
plot!(x, μ_mcmc; ribbon=σ_mcmc, label="MCMC")

This page was generated using Literate.jl.