VAEs on MNIST typically perform reasonably well with a two-dimensional latent space after enough epochs, but the only way to know this for certain is to test that assumption and try a few other sizes.
For the implementation described in this book, this is a fairly quick change:
w0 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(784, 256), gorgonia.WithName("w0"), gorgonia.WithInit(gorgonia.GlorotU(1.0)))
w1 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(256, 128), gorgonia.WithName("w1"), gorgonia.WithInit(gorgonia.GlorotU(1.0)))
w5 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(8, 128), gorgonia.WithName("w5"), gorgonia.WithInit(gorgonia.GlorotU(1.0)))
w6 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(128, 256), gorgonia.WithName("w6"), gorgonia.WithInit(gorgonia.GlorotU(1.0)))
w7 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(256, 784), gorgonia.WithName("w7"), gorgonia.WithInit(gorgonia.GlorotU(1.0)))
estMean := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(128, 8), gorgonia.WithName("estMean"), gorgonia.WithInit(gorgonia.GlorotU(1.0)))
estSd := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(128, 8), gorgonia.WithName("estSd"), gorgonia.WithInit(gorgonia.GlorotU(1.0)))
floatHalf := gorgonia.NewScalar(g, dt, gorgonia.WithName("floatHalf"))
gorgonia.Let(floatHalf, 0.5)
epsilon := gorgonia.GaussianRandomNode(g, dt, 0, 1, 100, 8)
The basic implementation here uses eight latent dimensions. To get it to work with two, all we have to do is change each 8 in the shapes of w5, estMean, estSd, and epsilon to 2 (leaving values such as 128 and 784 alone), resulting in the following:
w0 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(784, 256), gorgonia.WithName("w0"), gorgonia.WithInit(gorgonia.GlorotU(1.0)))
w1 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(256, 128), gorgonia.WithName("w1"), gorgonia.WithInit(gorgonia.GlorotU(1.0)))
w5 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(2, 128), gorgonia.WithName("w5"), gorgonia.WithInit(gorgonia.GlorotU(1.0)))
w6 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(128, 256), gorgonia.WithName("w6"), gorgonia.WithInit(gorgonia.GlorotU(1.0)))
w7 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(256, 784), gorgonia.WithName("w7"), gorgonia.WithInit(gorgonia.GlorotU(1.0)))
estMean := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(128, 2), gorgonia.WithName("estMean"), gorgonia.WithInit(gorgonia.GlorotU(1.0)))
estSd := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(128, 2), gorgonia.WithName("estSd"), gorgonia.WithInit(gorgonia.GlorotU(1.0)))
floatHalf := gorgonia.NewScalar(g, dt, gorgonia.WithName("floatHalf"))
gorgonia.Let(floatHalf, 0.5)
epsilon := gorgonia.GaussianRandomNode(g, dt, 0, 1, 100, 2)
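Because the latent size appears in four separate places, one way to make this kind of experiment less error-prone is to derive those shapes from a single value. The following is a minimal sketch rather than the book's code: `layerShape` and `vaeShapes` are hypothetical names, the second argument stands in for the 100 passed to the epsilon node (assumed here to be the batch size), and the program only prints the shapes instead of building a Gorgonia graph.

```go
package main

import "fmt"

// layerShape is a hypothetical helper type holding the name and the
// (rows, cols) shape of one node from the listing above.
type layerShape struct {
	Name       string
	Rows, Cols int
}

// vaeShapes returns the shapes used in the listing, with every
// occurrence of the latent size taken from latentDim.
// batchSize is an assumption: it replaces the literal 100 in epsilon.
func vaeShapes(latentDim, batchSize int) []layerShape {
	return []layerShape{
		{"w0", 784, 256},
		{"w1", 256, 128},
		{"w5", latentDim, 128},
		{"w6", 128, 256},
		{"w7", 256, 784},
		{"estMean", 128, latentDim},
		{"estSd", 128, latentDim},
		{"epsilon", batchSize, latentDim},
	}
}

func main() {
	// Print the shapes for each latent size we might want to compare.
	for _, dim := range []int{2, 5, 8, 20} {
		fmt.Printf("latent dimensions: %d\n", dim)
		for _, s := range vaeShapes(dim, 100) {
			fmt.Printf("  %-8s (%d, %d)\n", s.Name, s.Rows, s.Cols)
		}
	}
}
```

In the real model, each returned shape would be passed to `gorgonia.WithShape` in place of the hard-coded numbers, so switching latent sizes means changing one argument.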
Now all we have to do is recompile the code and run it. Repeating this with other values lets us see what happens when we try a latent space with more dimensions.
As we can see, the 2-dimension model is clearly at a disadvantage, but the differences become less obvious as we move up the ladder. The 20-dimension model produces appreciably sharper results on average, but the 5-dimension model may already be more than sufficient for most purposes: