Working with arviz
¶
arviz
is a wonderful library for analysis of Bayesian models. It includes many functions to help researchers determine how well a model performs both visually and numerically. However, those coming from primarily pandas
and numpy
experience may find it a bit challenging to orient themselves to using arviz
and its underlying data structure derived from the xarray
package. Here, we give a short breakdown on how to start analyzing an arviz.InferenceData
object.
For more information on these two packages, see their documentation:
xarray
¶
An arviz.InferenceData
is essentially a collection of xarray
objects so we will cover xarray
first. xarray
is a Python package designed for efficient and elegant interaction of multi-dimensional data. In Bayesian analysis we deal primarily with data that includes both chain and draw as dimensions in addition to the original parameter dimensions. For example, in a birdman.NegativeBinomial
model, the \(\beta\) parameter will be output with 4 dimensions: chain, draw, covariate, and feature. Dealing with all of these dimensions in simple NumPy arrays can get confusing as you try to keep track of which dimension is which. xarray
uses named dimensions and coordinates to make this process much cleaner and intuitive.
xarray.Dataset
¶
The Dataset
is the primary data structure that you will be interfacing with in xarray
. A Dataset
is a collection of data variables (for example parameters from posterior sampling) that can have different dimensionality. As an example, here is a Dataset
holding posterior draws from a birdman.NegativeBinomial
model:
<xarray.Dataset>
Dimensions: (chain: 4, covariate: 2, draw: 100, feature: 143)
Coordinates:
* chain (chain) int64 0 1 2 3
* draw (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99
* feature (feature) object '193358' '4465746' ... '212228' '192368'
* covariate (covariate) object 'Intercept' 'diet[T.DIO]'
Data variables:
beta (chain, draw, covariate, feature) float64 7.142 3.549 ... 0.4587
phi (chain, draw, feature) float64 0.1216 0.3286 ... 0.8205 0.6151
Attributes:
created_at: 2021-04-01T22:37:03.663497
arviz_version: 0.11.2
inference_library: cmdstanpy
inference_library_version: 0.9.68
The Dimensions
descriptor shows the names and number of entries in each dimensions. The Coordinates
entry holds the labels for each of the dimensions. In this example we see that the chains are labeled 0-3 and the draws are labeled 0-99. However, the features are labeled with OTU IDs and the covariates are labeled with the entries in the design matrix. The Data variables
entry contains the actual data (in this case parameters) and lists the dimensionality of each. Note that beta
is of dimension chain, draw, covariate, feature while phi
is only of dimension chain, draw, feature. The Dataset
is a powerful data structure because it can hold data variables of varying dimensionality.
xarray.DataArray
¶
A Dataset
is simply a collection of DataArray
objects. Whereas a Dataset
can contain multiple data variables, a DataArray
contains only one. If you want to access the beta
variable from the above Dataset
, you simply index it like you would a dictionary. If you have a Dataset
, ds
, with a data variable beta
, you would access it with ds["beta_var"]
which returns:
<xarray.DataArray 'beta' (chain: 4, draw: 100, covariate: 2, feature: 143)>
array([[[[ 7.14216 , ..., -0.03673 ],
[-0.639199, ..., 0.958811]],
...,
[[ 7.071049, ..., -0.374691],
[-0.61982 , ..., 0.5009 ]]],
...,
[[[ 7.096262, ..., -0.281968],
[-0.607823, ..., 1.190807]],
...,
[[ 7.185024, ..., 0.038614],
[-0.671318, ..., 0.458722]]]])
Coordinates:
* chain (chain) int64 0 1 2 3
* draw (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99
* feature (feature) object '193358' '4465746' ... '212228' '192368'
* covariate (covariate) object 'Intercept' 'diet[T.DIO]'
Selecting and indexing data¶
Manipulating data in xarray
is a bit more involved than in NumPy. The most important thing to keep in mind is the notion of dims
(dimensions) and coords
(coordinates). Dimensions are the names while coordinates are the labels.
You can use the .sel
function to select specific slices of data. To extract all values from just chain 0, you would run:
ds["beta_var"].sel(chain=0)
You can also index across multiple dimensions - if you wanted only values from chain 2 from the diet covariate you would run:
ds["beta_var"].sel(chain=2, covariate="diet[T.DIO]")
This also works with multiple values for a given dimension. As an example if you wanted to get all diet posterior samples from just features 193358 and 4465746 you would run:
ds["beta_var"].sel(feature=["193358", "4465746"], covariate="diet[T.DIO]")
See the documentation for more on indexing and selecting data.
arviz
¶
An arviz.InferenceData
object is a collection of xarray.Datasets
organized for use in Bayesian model analysis. Each inference comprises several groups such as posterior draws, sample stats, log likelihood values, etc. arviz
organizes these different groups such that they can be used seamlessly for downstream analysis.
If you run a birdman.NegativeBinomial
model and convert it to an inference object, you can print this object and see the following:
Inference data with groups:
> posterior
> posterior_predictive
> log_likelihood
> sample_stats
> observed_data
Each group is an xarray.Dataset
that you can interact with as described above. You can access each of these groups with either attribute notation (inference.posterior
) or index notation (inference["posterior"]
).
Saving and loading data¶
It is useful to be able to save the results of BIRDMAn so that they can be analyzed later or distributed to collaborators. The best way to do this is to save the InferenceData
object in the NetCDF format. This is a compressed format that works very well with multi-dimensional arrays.
You can save and load fitted models with to_netcdf
and from_netcdf
.
import arviz as az
inference.to_netcdf("inference.nc")
inference_loaded = az.from_netcdf("inference.nc")