This notebook covers:
- the mathematical machinery that BNNs require in the end,
- the general Bayesian problems Pyro is built to tackle,
- how these are applied to the special case of BNNs.

The notebook therefore starts off by introducing somewhat heavy mathematical machinery, which we are then going to progressively hide/abstract away with:
- Pyro for arbitrary Bayesian machine learning,
- TyXe for Bayesian neural networks.

Let's make sure we have everything we need.
If you installed the dependencies (including ipykernel) in a virtual/conda environment, select Kernel > Change Kernel > <environment name>.
# run this cell to reset the kernel or select kernel > restart kernel
%reset -s -f
!jupyter nbextension enable --py widgetsnbextension
Enabling notebook extension jupyter-js-widgets/extension...
- Validating: OK
import logging
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import pyro
import pyro.distributions as dist
assert pyro.__version__.startswith('1.4.0'), f"For TyXe, pyro version must be exactly 1.4.0, not {pyro.__version__}"
# pyro.enable_validation(True)
pyro.distributions.enable_validation(False)
pyro.set_rng_seed(42)
logging.basicConfig(format='%(message)s', level=logging.INFO)
# Set matplotlib settings
%matplotlib inline
plt.style.use('default')
# Is GPU available?
USE_CUDA = torch.cuda.is_available()
USE_CUDA
True
Most data analysis problems can be understood as elaborations on three basic high-level questions:
- What do we know about the problem before observing any data?
- What conclusions can we draw from data given our prior knowledge?
- Do these conclusions make sense?
In the probabilistic or Bayesian approach to data science and machine learning,
we formalize these in terms of mathematical operations on probability distributions.
We express everything we know about the variables in a problem and the relationships between them in the form of a *probabilistic model*:
It usually has a joint density function of the form
$$p_{\theta}({\bf x}, {\bf z}) = p_{\theta}({\bf x}|{\bf z}) p_{\theta}({\bf z})$$

The distribution over latent variables $p_{\theta}({\bf z})$ in this formula is called the *prior*, and the distribution over observed variables given latent variables $p_{\theta}({\bf x}|{\bf z})$ is called the *likelihood*.
Probabilistic models are often depicted in a standard graphical notation (Bayesian Networks):
Pyro >= 1.8.0 supports rendering arbitrary functions containing Pyro expressions as graphical models using pyro.render_model(function, model_args=(x, y)).
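For example, here is a minimal sketch with a throwaway model (illustrative only, and note this requires Pyro >= 1.8.0, whereas this notebook pins 1.4.0 for TyXe):

# Sketch: rendering a tiny model as a graphical model with pyro.render_model (Pyro >= 1.8.0).
def tiny_model(x):
    z = pyro.sample("z", dist.Normal(0., 1.))        # latent variable
    with pyro.plate("data", x.shape[0]):
        pyro.sample("x", dist.Normal(z, 1.), obs=x)  # observed variables
# pyro.render_model(tiny_model, model_args=(torch.ones(3),))  # returns a graphviz Digraph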
Once we have specified a model, Bayes' rule tells us how to use it to perform *inference*, or draw conclusions about latent variables from data, by computing the *posterior distribution* over $\bf z$:
$$ p_{\theta}({\bf z} | {\bf x}) = \frac{p_{\theta}({\bf x} , {\bf z})}{ \int \! p_{\theta}({\bf x} , {\bf z})\; d{\bf z} } $$

To check the results of modeling and inference, we would like to know how well a model fits observed data $x$, which we can quantify with the *evidence* (or *marginal likelihood*)

$$p_{\theta}({\bf x}) = \int \! p_{\theta}({\bf x} , {\bf z})\; d{\bf z} $$

and also to make predictions for new data (and sample new data points!), which we can do with the *posterior predictive distribution*

$$p_{\theta}(x' | {\bf x}) = \int \! p_{\theta}(x' | {\bf z}) p_{\theta}({\bf z} | {\bf x})\; d{\bf z} $$

It is often desirable to *learn* the parameters $\theta$ of our models from observed data $x$, which we can do by maximizing the evidence:
$$\theta_{\rm{max}} = \rm{argmax}_\theta \, p_{\theta}({\bf x}) = \rm{argmax}_\theta \int \! p_{\theta}({\bf x} , {\bf z}) \; d{\bf z} $$

One naive way to set the parameters $\theta$ is Maximum Likelihood Estimation (MLE):

$$\theta_{\rm{max}} = \rm{argmax}_\theta \, p_{\theta}({\bf x} | {\bf z})$$

Going one step beyond MLE, we can instead maximize the posterior; this is called Maximum-a-Posteriori estimation (MAP) and requires picking a prior $p_\theta({\bf z})$:

$$ \theta_{\rm{max}} = \rm{argmax}_\theta \, p_\theta({\bf z} | {\bf x}) = \rm{argmax}_\theta \frac{p_\theta({\bf x} | {\bf z}) p_\theta({\bf z})}{ \int_Z \! p_\theta({\bf x} | {\bf z})p_\theta({\bf z})\; d{\bf z} } = \rm{argmax}_\theta \, p_\theta({\bf x} | {\bf z} ) p_\theta({\bf z}) $$

(We see that Bayes' theorem implies MLE is actually the most naive case of MAP, with the prior $p_\theta({\bf z})$ being uniform!)
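To make the difference concrete, here is a small worked example (not from the original notebook, the numbers are made up): for a Bernoulli likelihood with a conjugate Beta prior, both estimates have closed forms, and a uniform Beta(1, 1) prior makes them coincide, as noted above.

# Illustration: MLE vs. MAP for a Bernoulli success probability with a Beta prior.
k, n = 7, 10            # 7 successes in 10 trials
alpha, beta = 2.0, 2.0  # Beta(alpha, beta) prior hyperparameters

mle = k / n                                            # argmax_p p(x | p)
map_beta22 = (k + alpha - 1) / (n + alpha + beta - 2)  # argmax_p p(x | p) p(p) with Beta(2, 2) prior
map_uniform = (k + 1 - 1) / (n + 1 + 1 - 2)            # Beta(1, 1) (uniform) prior -> identical to MLE

print(f"MLE = {mle:.3f}, MAP Beta(2,2) = {map_beta22:.3f}, MAP uniform = {map_uniform:.3f}")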
Pyro supports many different approximate inference algorithms; the flagship is stochastic variational inference (SVI). The idea is to search over a parameterized family of distributions $q_{\phi}({\bf z})$ (Pyro calls this the *guide*) that are by construction easy to sample from (e.g. the family of Gaussians), to find the one that is most similar to the true posterior according to some measure of distance. That distance is the KL divergence between the guide and the posterior, which decomposes into the log evidence $\log p_\theta({\bf x})$, which does not depend on $q_{\phi}$, and a tractable term called the *evidence lower bound (ELBO)*:

$$\mathrm{KL}\big(q_\phi({\bf z}) \,\|\, p_\theta({\bf z}\,|\,{\bf x})\big) = \log p_\theta({\bf x}) - \mathrm{ELBO}$$

Maximizing the ELBO with respect to $\phi$ therefore minimizes the KL divergence, while maximizing it with respect to $\theta$ pushes up a lower bound on the log evidence.
$\Rightarrow$ The guide and model play chase: as $\theta$ changes, the posterior that the guide is trying to match moves.
Despite the moving target, this optimization problem can be solved well enough for many different problems.
So at a high level, variational inference is easy: write down a model, write down a guide, and maximize the ELBO with stochastic gradient steps (as sketched below).
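Here is a minimal sketch of that recipe on a toy coin-flip problem, reusing the torch/pyro/dist imports from the setup cell above (the toy model, guide and data are illustrative and not part of the notebook's VAE/BNN examples):

# Minimal SVI sketch: infer the fairness p of a coin from a handful of flips.
toy_data = torch.tensor([0., 1., 1., 1., 0., 1.])

def toy_model(x):
    p = pyro.sample("p", dist.Beta(1., 1.))           # latent fairness, uniform prior
    with pyro.plate("flips", x.shape[0]):
        pyro.sample("obs", dist.Bernoulli(p), obs=x)  # likelihood of the observed flips

def toy_guide(x):
    # variational parameters phi, registered with pyro.param
    a = pyro.param("a", torch.tensor(1.), constraint=dist.constraints.positive)
    b = pyro.param("b", torch.tensor(1.), constraint=dist.constraints.positive)
    pyro.sample("p", dist.Beta(a, b))                 # same site name as in the model

toy_svi = pyro.infer.SVI(toy_model, toy_guide, pyro.optim.Adam({"lr": 1e-2}),
                         loss=pyro.infer.Trace_ELBO())
for _ in range(200):
    toy_svi.step(toy_data)                            # one stochastic gradient step on -ELBO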
Probabilistic models in Pyro are specified as Python functions, e.g. model(*args, **kwargs), that generate observed data from latent variables using special primitive functions whose behavior can be changed by Pyro's internals depending on the high-level computation being performed.
Specifically, the different mathematical pieces of model() and guide() are encoded via the mapping:
- latent random variables ${\bf z}$ $\longleftrightarrow$ pyro.sample statements
- observed random variables ${\bf x}$ $\longleftrightarrow$ pyro.sample statements with the obs keyword argument (model only)
- learnable parameters $\theta$, $\phi$ $\longleftrightarrow$ pyro.param statements (or registered torch modules, see pyro.nn.module)

Just like model(...), a guide is encoded as a Python program guide(...) that contains pyro.sample and pyro.param statements. Random variables are declared in Pyro with the primitive statement pyro.sample(); the first argument is the name of the random variable. If the model samples a latent variable z_1 without the obs= keyword,

def model():
    pyro.sample("z_1", ...)

then the guide MUST have a matching sample statement:

def guide():
    pyro.sample("z_1", ...)

The distributions used in the two cases can be different, but the names must line up 1-to-1.
The guide may not contain observed random variables (i.e. no obs= keyword), since the guide needs to be a properly normalized distribution so that it is easy to sample from! Finally, model() and guide() should take the same arguments.

Pyro contains powerful high-level abstractions that change the behavior of the model and guide depending on the inference algorithm being run. Please note: the following VAE example is meant to illustrate Pyro's above-mentioned advantages.

# ---- VAE example on MNIST's handwritten digit data ----
from vae import Encoder, Decoder # appropriate torch.nn.Modules
class VAE(nn.Module):
def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
super().__init__()
# create the encoder and decoder networks
self.encoder = Encoder(z_dim, hidden_dim)
self.decoder = Decoder(z_dim, hidden_dim)
if use_cuda:
# calling cuda() here will put all the parameters of
# the encoder and decoder networks into gpu memory
self.cuda()
self.use_cuda = use_cuda
self.z_dim = z_dim
# define the model p(x|z)p(z)
def model(self, x):
# register PyTorch module `decoder` with Pyro
pyro.module("decoder", self.decoder)
# this context makes the samples in the batch independent:
with pyro.plate("data", x.shape[0]):
# setup parameters for gaussian prior p(z)
z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
# sample from prior (value will also be sampled by guide when computing the ELBO)
z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
# dist.to_event(n) declares the n rightmost dimensions as being RVs, the rest are batch dimensions
# decode the latent code z
img = self.decoder(z)
# score against actual images
# (results in reconstruction loss = log likelihood for decoder)
pyro.sample("obs", dist.Bernoulli(img).to_event(1), obs=x.reshape(-1, 784)) # bernoulli to sample black/white
# return the img only so that we can visualize it later
return img
VAE.model = model
VAE.model(x)

The first thing model() does is to register the (previously instantiated) decoder module with Pyro: pyro.module lets Pyro know about all the parameters inside of the decoder network. Note that:
- the pyro.plate context declares the samples in the mini-batch as conditionally independent,
- .to_event(1) when sampling from the latent z tells Pyro that the rightmost dimension is multivariate (an event dimension). See Pyro's Tensor Shapes tutorial for more details.

VAE.model(x) as a graphical model:
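As a quick illustration of the .to_event note above (a throwaway check, not part of the original notebook), compare the batch and event shapes of a Normal before and after .to_event(1):

# batch_shape vs. event_shape: .to_event(1) turns the rightmost batch dimension into an event dimension
d = dist.Normal(torch.zeros(5, 3), torch.ones(5, 3))
print(d.batch_shape, d.event_shape)                          # torch.Size([5, 3]) torch.Size([])
print(d.to_event(1).batch_shape, d.to_event(1).event_shape)  # torch.Size([5]) torch.Size([3])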
# define the guide (i.e. variational distribution) q(z|x)
def guide(self, x):
# register PyTorch module `encoder` with Pyro
pyro.module("encoder", self.encoder)
with pyro.plate("data", x.shape[0]):
# use the encoder to get the parameters used to define q(z|x)
z_loc, z_scale = self.encoder(x)
# sample the latent code z
pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
VAE.guide = guide
VAE.guide(x)

Just like in the model, we first register the PyTorch module (here the encoder) with Pyro. We take the mini-batch of images x and pass it through the encoder. Using the location and scale it outputs, we then sample the latent code under the same name 'latent' as in the model.

VAE.guide(x) as a graphical model:
Now that we've defined the full model and guide inside our torch.nn.Module VAE, we can set up inference:
# preliminaries
from utils import make_loaders_mnist
from vae import train
from vae import evaluate
# Run options
smoke_test = True # short run
LEARNING_RATE = 1.0e-3
# Run only for a single iteration for testing
NUM_EPOCHS = 1 if smoke_test else 100
TEST_FREQUENCY = 5
# load MNIST
train_loader, test_loader = make_loaders_mnist(
batch_size=32,
use_cuda=USE_CUDA
)
# clear pyro's global parameter store
pyro.clear_param_store()
To set up SVI, all we need to do is instantiate the model and guide, pick an optimizer, and pass all three to pyro.infer.SVI together with an ELBO loss:
# setup the VAE
vae = VAE(
z_dim=10,
hidden_dim=40,
use_cuda=USE_CUDA
)
# setup the optimizer
optimizer = pyro.optim.Adam({"lr": LEARNING_RATE}) # pyro.optim wraps torch.optim
# setup the stochastic variational inference algorithm
svi = pyro.infer.SVI(vae.model, vae.guide, optimizer, loss=pyro.infer.Trace_ELBO())
# training
train_elbo = []
test_elbo = []
# training loop
for epoch in range(NUM_EPOCHS):
# simply calls svi.step(batch) as train loop
total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
train_elbo.append(-total_epoch_loss_train)
print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train))
if epoch % TEST_FREQUENCY == 0:
# simply calls svi.evaluate_loss(batch) as test loop
total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)
test_elbo.append(-total_epoch_loss_test)
print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))
Bayesian model evaluation with posterior predictive checks (sometimes called a sanity check):
# plot example images
row = 12
column = 3
plt.figure(figsize=(20,5))
mock_input = torch.zeros([1, 784]) # does not matter, check VAE.model
if USE_CUDA:
mock_input = mock_input.cuda()
for i in range(1, row * column +1):
plt.subplot(column, row, i)
# get img from the model.
# model samples internally and practically ignores mock_input
sample_img = vae.model(mock_input) # p(img|z)p(z)
img = sample_img[0].view(28, 28).cpu().data.numpy()
plt.xticks([])
plt.yticks([])
plt.imshow(img, cmap="gray")
# time for a coffee pause? :)
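The samples above come from the prior over z (the mock input is ignored). To condition on real test images instead, i.e. to sample from the posterior predictive $p_\theta(x'|{\bf x})$, one can run the guide and replay the model with the sampled latent. A rough sketch, assuming the test loader yields (image, label) batches and that the Encoder/Decoder in vae.py handle the flattening as in the standard Pyro VAE tutorial:

# Sketch: posterior predictive reconstruction of real test digits.
x_test, _ = next(iter(test_loader))
if USE_CUDA:
    x_test = x_test.cuda()
guide_trace = pyro.poutine.trace(vae.guide).get_trace(x_test)               # z ~ q(z | x)
reconstruction = pyro.poutine.replay(vae.model, trace=guide_trace)(x_test)  # decode that z
plt.imshow(reconstruction[0].view(28, 28).cpu().data.numpy(), cmap="gray")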
The training workflow for BNNs starts from a network pretrained with MLE, whose weights initialize the Gaussian means of the variational posterior.
Usually we fix the means to the pretrained weights and only learn the variances.
For a minibatch of data, the ELBO now takes on the form

$$\mathrm{ELBO} = -\log q_\phi(z) + \log p_\theta(z) + \frac{1}{n}\sum_{i=1}^n \log p_\theta(x_i|z)$$

where $z \sim q_\phi(z)$ is drawn ahead of time, following Weight Uncertainty in Neural Networks (Blundell et al., 2015). (This is a Monte Carlo estimate of the ELBO with one sample.)
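For intuition, here is what that single-sample estimate looks like written out by hand for a one-dimensional toy "weight" (purely illustrative, with a made-up Gaussian observation model; pyro.infer.Trace_ELBO does this bookkeeping for us below):

# Hand-rolled single-sample Monte Carlo ELBO estimate for a 1-D toy latent z.
q_loc, q_scale = torch.tensor(0.1), torch.tensor(0.5)  # phi: variational parameters
guide_dist = dist.Normal(q_loc, q_scale)                # q_phi(z)
prior = dist.Normal(0., 1.)                             # p_theta(z)

z = guide_dist.rsample()                        # draw z ahead of time (reparameterized)
x_batch = torch.randn(8)                        # a minibatch of n = 8 toy observations
log_lik = dist.Normal(z, 1.).log_prob(x_batch)  # log p_theta(x_i | z)

elbo_estimate = -guide_dist.log_prob(z) + prior.log_prob(z) + log_lik.mean()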
Let's see how this works in Pyro, practically.

# preliminaries
# dataset helper functions
from utils import make_loaders_bnns, make_net
dataset = "cifar10"
model_name = "resnet18"
pretrained = False # usually, start with a net pretrained using MLE!
# load cifar data
train_loader, test_loader, _ = make_loaders_bnns(
dataset, "./data", 32, 32, False, False
)
# take an existing net:
resnet: torch.nn.Module = make_net(dataset, model_name, pretrained=pretrained)
# make_net only reshapes final layer of torchvision.models.model_name()
# initialize gaussian means using these:
pretrained_weights = resnet.state_dict()
# convert model to pyro module inplace
pyro.nn.module.to_pyro_module_(resnet)
Now we have to make choices for the following three Bayesian components: the prior, the likelihood, and the guide (the variational posterior). For each of these three components, we now write a few lines of Pyro code.

The prior: we pick standard normals as our prior belief $p_\theta(\mathrm{weights})$ about the weights.
for m in resnet.modules():
# replace weights and biases of
# fully connected and convolutional modules
# (do not be bayesian about batchnorms)
if isinstance(m, (nn.Linear, nn.Conv2d)):
        # attributes wrapped in PyroSample are re-sampled on each forward pass
m.weight = pyro.nn.PyroSample(
dist.Normal(
torch.zeros_like(m.weight),
torch.ones_like(m.weight)
).to_event()
)
if m.bias is not None:
m.bias = pyro.nn.PyroSample(
dist.Normal(
torch.zeros_like(m.bias),
torch.ones_like(m.bias)
).to_event()
)
We define the likelihood $p_\theta(\mathrm{data}\,|\,\mathrm{weights})$ by defining our model $p_\theta(\mathrm{data}\,|\,\mathrm{weights}) \, p_\theta(\mathrm{weights})$, i.e.
p(data | weights) * p(weights)
= likelihood * prior
= joint distribution
def model(x, y=None):
logits = resnet(x) # forward samples from prior
# define the likelihood
# p(data | weights)
with pyro.plate("data_plate", x.shape[0]):
# context for IID data points in batch
# log likelihood reconstruction loss
pyro.sample("data", dist.Categorical(logits=logits), obs=y)
# return for prediction/testing only:
return logits
We decide that the guide $q_\phi(\mathrm{weights})$ should be Gaussian.
In this case, the use of AutoNormal
leads to a diagonal covariance, i.e. independent weights.
This is called the mean-field assumption and drastically reduces training time.
guide = pyro.infer.autoguide.AutoNormal(
model,
init_scale=1e-4, # initialize the variances uniformly (vars=phi above)
init_loc_fn=pyro.infer.autoguide.init_to_value(
values=pretrained_weights
) # init gauss means to pretrained weights
)
# an optimizer can be taken from pyro.optim (wrapper around torch.optim API)
optim = pyro.optim.Adam({"lr": 1e-3})
# set up stochastic variational inference
svi = pyro.infer.SVI(
model,
guide,
optim,
pyro.infer.Trace_ELBO() # objective function
)
# training
# fit the BNN
num_epochs = 1
first_n_batches_train = 20
for _ in range(num_epochs):
for i, (x, y) in enumerate(iter(train_loader)):
if i > first_n_batches_train:
break
# 1. forward guide(x,y), memorize sampled weight values
# 2. forward model(x,y), using memorized values
# 3. compute the elbo
# 4. elbo.backward() # updates guide
svi.step(x, y)
# prediction
def make_prediction(model, guide, x, y):
trace = pyro.poutine.trace(guide).get_trace(x) # memorize sampled weight values of guide
logits = pyro.poutine.replay(model, trace=trace)(x) # use sampled weight values of guide
predictions = logits.argmax(-1)
acc = ((predictions == y).sum()/y.shape[0]).item()
print(f"Batch acc = {acc}")
first_n_batches_test = 10
test_predictions = [make_prediction(model, guide, x, y) for i, (x, y) in enumerate(test_loader) if i < first_n_batches_test]
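make_prediction above uses a single weight sample per batch. A Bayesian model average instead averages the predictive distribution over several weight samples; here is a small sketch building on the same trace/replay pattern (the function name and sample count are just illustrative):

# Sketch: average class probabilities over multiple sampled weight configurations.
def predict_bayes_averaged(model, guide, x, num_samples=8):
    probs = 0.
    for _ in range(num_samples):
        trace = pyro.poutine.trace(guide).get_trace(x)               # sample one set of weights
        logits = pyro.poutine.replay(model, trace=trace)(x)          # forward pass with those weights
        probs = probs + torch.softmax(logits, dim=-1) / num_samples  # running average of p(y | x, weights)
    return probs.argmax(-1)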
There are a few things to note here:
"Pyro was primarily designed" for a "smaller-scale Bayesian workflow"
where one is interested in "modelling and making inferences on a given dataset
rather than predicting on held-out test data".
In the next section, we will see that the Pyro-based BNN library TyXe solves these problems:
TyXe users only need to be familiar with PyTorch and to have minimal knowledge of the BNN workflow.
# first some more imports:
import os
import contextlib
import functools
from typing import List, Optional
import tyxe # BNN library on top of pyro
from utils import make_loaders_bnns, make_net
In TyXe, the following short piece of code is equivalent to our previous Bayesian ResNet:
# take an existing torch.nn.Module:
resnet: torch.nn.Module = make_net("cifar10", "resnet18", pretrained=False)
# setup bnn:
prior = tyxe.priors.IIDPrior(dist.Normal(0,1), expose_all=False, hide_module_types=(nn.BatchNorm2d,))
likelihood = tyxe.likelihoods.Categorical(len(train_loader.sampler)) # dataset size, not the number of batches
guide = functools.partial(
tyxe.guides.AutoNormal,
train_loc=False,
init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(resnet)
)
bnn = tyxe.VariationalBNN(resnet, prior, likelihood, guide)
# train bnn
# (execute with care; very compute intensive, may be too much for cpu!)
# bnn.fit(train_loader, optim=pyro.optim.Adam({"lr":1e-3}), num_epochs=1, device=torch.device("cuda" if USE_CUDA else "cpu"))
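Once fitted, the BNN can be evaluated by averaging over multiple weight samples. A minimal sketch using the same evaluate call that appears in the training callback further below (commented out, like the fit above, since it requires the fitted bnn):

# Sketch: multi-sample evaluation on a single test batch (err = error rate, ll = log-likelihood).
# x_test, y_test = next(iter(test_loader))
# err, ll = bnn.evaluate(x_test, y_test, num_predictions=8)
# print(f"error = {err}, log-likelihood = {ll}")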
As you can see, users technically don't need to be intricately familiar with Pyro
to use TyXe
.
With TyXe
we've reached a level of abstraction appropriate for BNNs,
so let's look at some more typical training setups in TyXe
.
First, let's get the standard Hyperparameters out of the way:
# Standard Hyperparameters first ...
# MODEL
architecture: str = "resnet18"
dataset: str = "cifar10"
pretrained: bool = False
mock_dataset: bool = False
# DATA
train_batch_size: int = 10
test_batch_size: int = 10
num_epochs: int = 1
test_samples: int = 20
# MISC
root: str = os.environ.get("DATASETS_PATH", "./data")
seed: int = 42
output_dir: Optional[str] = None
# OPTIMIZER
lr: float = 0.001 # initial learning rate
milestones: Optional[List[int]] = None # epochs at which to do scheduler step
gamma: float = 0.1 # scheduler step factor
resnets = [n for n in dir(torchvision.models) if (n.startswith("resnet") or n.startswith("wide_resnet")) and n[-1].isdigit()]
assert architecture in resnets, architecture
datasets = ["cifar10", "cifar100", "mnist"]
assert dataset in datasets, dataset
# Some BNN Hyperparameters:
inference: str = "mean-field"
local_reparameterization: bool = False # important: variance reduction for gradients!
flipout: bool = False # important: variance reduction for gradients!
max_guide_scale: float = 0.1 # to prevent underfitting: clamp learned variance
rank: int = 10 # low rank setting for inference == "last-layer-low-rank"
scale_only: bool = False # train variance only, leaving means at pretrained values
# More comments on these inference options in the section 'guide' below
inference_options = [
"mle", # maximum likelihood estimation: weights = argmax p(data|weights)
"map", # maximum a posteriori inference: weights = argmax p(data|weights)*p(weights)
"mean-field", # svi with autonormal guide (diagonal covariance)
"last-layer-mean-field", # svi for last layer only, autonormal guide and diagonal covariance
"last-layer-full", # svi for last layer only, autonormal guide and FULL covariance
"last-layer-low-rank" # svi for last layer only, low rank
]
assert inference in inference_options, inference
TyXe integrates into any existing PyTorch workflow: we take an existing torch.nn.Module and never touch its internals!

# ----- set up pyro & torch -----
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
# ----- set up dataset and model -----
train_loader, test_loader, ood_loader = make_loaders_bnns(
dataset, root, train_batch_size, test_batch_size, use_cuda, mock_dataset
)
net: torch.nn.Module = make_net(dataset, architecture, pretrained=pretrained).to(device)
Once again, to set up a BNN we need to make choices for the Prior, Likelihood and Guide.

The likelihood of our training data, $p_\theta({\bf x}|{\bf z})$: the likelihood object must be constructed with the total number of training samples, so that the mini-batch log-likelihood is scaled correctly in the ELBO.
# tyxe.likelihoods includes:
# Bernoulli
# Categorical
# HeteroskedasticGaussian
# HomoskedasticGaussian
likelihood = tyxe.likelihoods.Categorical(len(train_loader.sampler))
# uncomment for documentation:
# tyxe.likelihoods.HeteroskedasticGaussian?
The guide is where most of our flexibility lies. TyXe's BNNs expect a partially initialized autoguide function, either from Pyro or TyXe.

1. MLE: no guide at all, i.e. None.
2. MAP: an AutoDelta guide, i.e. deterministic weights.
3. Mean-Field: an AutoNormal guide with diagonal covariance, i.e. independent weights ($\Rightarrow$ training $\in\mathcal{O}(\#params)$ instead of $\mathcal{O}(\#params^6)$ - Oxford Blog Post. According to the post & paper, the diagonal covariances are good enough, especially for deeper models.)
4. Last Layer: often only the last layer of a big conv net is tuned with SVI.

See the paper for a comparison of these methods both on CIFAR's test set and Out Of Domain data.
if inference == "mle":
# do maximum likelihood estimation
test_samples = 1
guide = None
if inference == "map":
# maximum a posteriori inference
test_samples = 1
guide = functools.partial(
pyro.infer.autoguide.AutoDelta, # deterministic weights
init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net)
)
if inference == "mean-field":
# SVI with diagonal covariances
guide = functools.partial(
tyxe.guides.AutoNormal,
init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net),
init_scale=1e-4,
max_guide_scale=max_guide_scale, # prevent underfitting
train_loc=not scale_only # train gaussian means?
)
if inference.startswith("last-layer"):
# usually only done for pretrained network:
# if not pretrained:
# raise ValueError("Asked to do last-layer inference, but no pre-trained weights were provided.")
# turning parameters except for last layer in buffers to avoid training them
# this might be avoidable via poutine.block
for module in net.modules():
if module is not net.fc:
for param_name, param in list(module.named_parameters(recurse=False)):
delattr(module, param_name)
module.register_buffer(param_name, param.detach().data)
if inference == "last-layer-mean-field":
guide = functools.partial(
tyxe.guides.AutoNormal,
init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net),
init_scale=1e-4
)
elif inference == "last-layer-full":
guide = functools.partial(
pyro.infer.autoguide.AutoMultivariateNormal,
init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net),
init_scale=1e-4
)
elif inference == "last-layer-low-rank":
guide = functools.partial(
pyro.infer.autoguide.AutoLowRankMultivariateNormal,
rank=rank,
init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net),
init_scale=1e-4
)
print("Pyro's Automatic Guide Families:")
print([g for g in dir(pyro.infer.autoguide) if g.startswith("Auto")])
print("TyXe's Automatic Guide Families:")
print([g for g in dir(tyxe.guides) if g.startswith("Auto")])
# uncomment for documentation:
# pyro.infer.autoguide.AutoDelta?
Pyro's Automatic Guide Families: ['AutoCallable', 'AutoContinuous', 'AutoDelta', 'AutoDiagonalNormal', 'AutoDiscreteParallel', 'AutoGuide', 'AutoGuideList', 'AutoIAFNormal', 'AutoLaplaceApproximation', 'AutoLowRankMultivariateNormal', 'AutoMultivariateNormal', 'AutoNormal', 'AutoNormalizingFlow']
TyXe's Automatic Guide Families: ['AutoNormal']
# it is standard practice to not be bayesian about batchnorm modules:
prior_kwargs = {
"expose_all": False, # do not treat all nn.Modules with pyro
"hide_module_types": (nn.BatchNorm2d,) # specifically, ignore batchnorms
}
# our choice of guide impacts how we need to initialize the Prior:
if inference == "mle":
# we dont want a prior for maximum likelihood estimation
prior_kwargs["hide_all"] = True
elif inference.startswith("last-layer"):
# only be bayesian about the final, fully connected layer
del prior_kwargs['hide_module_types']
prior_kwargs["expose_modules"] = [net.fc]
prior = tyxe.priors.IIDPrior(
dist.Normal(
torch.zeros(1, device=device),
torch.ones(1, device=device)
),
**prior_kwargs
)
print("TyXe's Available Prior Distributions:")
print([p for p in dir(tyxe.priors) if p.endswith("Prior")])
# uncomment for documentation:
# tyxe.priors.IIDPrior?
TyXe's Available Prior Distributions: ['DictPrior', 'IIDPrior', 'LambdaPrior', 'LayerwiseNormalPrior', 'Prior']
That's it! We're ready to set up our BNN:
# Finally set up our VariationalBNN!
bnn = tyxe.VariationalBNN(
net, prior, likelihood, guide
)
# uncomment for documentation:
# bnn?
Pyro provides many variance reduction techniques and tutorials; TyXe provides some context managers specific to its BNNs:

# gradient variance reduction techniques:
if local_reparameterization:
if flipout:
raise RuntimeError("Can't use both local reparameterization and flipout, pick one.")
train_context = tyxe.poutine.local_reparameterization
# turns each
# torch.distributions.Normal(loc, scale).sample() (gradient w.r.t. loc, scale is stochastic)
# into
# loc + scale * torch.distributions.Normal(0, 1).sample() (gradient w.r.t. loc, scale is deterministic)
elif flipout:
# usually: use one sampled weight for entire minibatch
# flipout: efficiently sample pseudo-independent weights along the minibatch dimension
train_context = tyxe.poutine.flipout
else:
train_context = contextlib.nullcontext
# pyro-specific: optimizer must come from pyro.optim
if milestones is None:
optim = pyro.optim.Adam({"lr": lr})
else:
optimizer = torch.optim.Adam
optim = pyro.optim.MultiStepLR({"optimizer": optimizer, "optim_args": {"lr": lr}, "milestones": milestones, "gamma": gamma})
print("All typical pytorch optimizers & schedulers are supported by pyro.optim:")
print([opt for opt in dir(pyro.optim) if "_" not in opt and opt[0] == opt[0].upper()])
All typical pytorch optimizers & schedulers are supported by pyro.optim: ['ASGD', 'Adadelta', 'Adagrad', 'AdagradRMSProp', 'Adam', 'AdamW', 'Adamax', 'ChainedScheduler', 'ClippedAdam', 'ConstantLR', 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts', 'CyclicLR', 'DCTAdam', 'ExponentialLR', 'LambdaLR', 'LinearLR', 'MultiStepLR', 'MultiplicativeLR', 'NAdam', 'OneCycleLR', 'PyroLRScheduler', 'PyroOptim', 'RAdam', 'RMSprop', 'ReduceLROnPlateau', 'Rprop', 'SGD', 'SequentialLR', 'SparseAdam', 'StepLR']
# tyXe-specific: evaluation and logging done after every epoch:
def callback(
b: tyxe.VariationalBNN, # bnn
i: int, # epoch number
avg_elbo: float # mean elbo this epoch
):
avg_err, avg_ll = 0., 0.
    for x, y in iter(test_loader):
err, ll = b.evaluate(x.to(device), y.to(device), num_predictions=test_samples)
avg_err += err / len(test_loader.sampler)
avg_ll += ll / len(test_loader.sampler)
print(f"ELBO={avg_elbo}; test error={100 * avg_err:.2f}%; LL={avg_ll:.4f}")
# ------ TRAIN THE MODEL ------
with train_context():
bnn.fit(train_loader, optim, num_epochs, callback=callback, device=device)
This notebook and its dependencies were assembled from, in rough order: