This is the PyTorch code for the NIPS paper 'Natural-Parameter Networks: A Class of Probabilistic Neural Networks'.
NPN is a class of probabilistic neural networks that treats both weights and neurons as distributions rather than as single points in a high-dimensional space. Distributions are first-class citizens in the network: they are propagated forward and backpropagated through the whole network. Given an input data point, an NPN outputs a predictive distribution that carries information about both the prediction and its uncertainty.
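To make the idea of propagating distributions concrete, here is a minimal plain-Python sketch of how a mean and a variance can pass through one linear layer. It assumes that inputs, weights, and biases are mutually independent Gaussians parameterized element-wise by means (`*_m`) and variances (`*_s`). The function `npn_linear` is hypothetical and is not the `NPNLinear` class from `npn.py`.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # shape (n_in, n_out)


def npn_linear(a_m: Vector, a_s: Vector,
               W_m: Matrix, W_s: Matrix,
               b_m: Vector, b_s: Vector) -> Tuple[Vector, Vector]:
    """Hypothetical helper: propagate the input mean/variance through a
    linear layer whose weights and biases are independent Gaussians."""
    n_in, n_out = len(W_m), len(W_m[0])
    o_m: Vector = []
    o_s: Vector = []
    for j in range(n_out):
        mean, var = b_m[j], b_s[j]
        for i in range(n_in):
            mean += a_m[i] * W_m[i][j]
            # Var(a * w) = Var(a)Var(w) + Var(a)E[w]^2 + E[a]^2 Var(w)
            var += (a_s[i] * W_s[i][j]
                    + a_s[i] * W_m[i][j] ** 2
                    + a_m[i] ** 2 * W_s[i][j])
        o_m.append(mean)
        o_s.append(var)
    return o_m, o_s


# A deterministic input (zero variance), as in the "input contains only
# the mean" case used for the first layer.
o_m, o_s = npn_linear([1.0, 2.0], [0.0, 0.0],
                      [[0.5], [0.25]], [[0.1], [0.2]],
                      [0.0], [0.01])
print(o_m, o_s)  # mean 1.0, variance approximately 0.91
```

The variance term is the standard formula for the variance of a product of independent random variables. It shows why a layer that receives only a mean (zero variance) still produces a non-zero output variance whenever its weights are uncertain.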
NPN can be used either independently or as a building block for Bayesian Deep Learning (BDL).
Note that this is the code for Gaussian NPN on the MNIST and Boston Housing datasets. For Gamma NPN or Poisson NPN, please go to the other repo.
The figure above shows the predictive distribution of NPN. The shaded regions correspond to 3 standard deviations around the predictive mean. The black curve is the data-generating function, and the blue curves show the means of the predictive distributions. Red stars are the training data.
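The shaded band in such a plot is simply the predictive mean plus or minus k standard deviations, where the standard deviation is the square root of the predicted variance. A minimal sketch with k = 3, assuming the network has already produced per-point means and variances (the helper `predictive_band` is hypothetical):

```python
import math
from typing import List, Tuple


def predictive_band(means: List[float], variances: List[float],
                    k: float = 3.0) -> List[Tuple[float, float]]:
    """Return (lower, upper) = mean -/+ k * std for each predicted Gaussian."""
    return [(m - k * math.sqrt(v), m + k * math.sqrt(v))
            for m, v in zip(means, variances)]


bands = predictive_band([0.0, 1.0], [1.0, 0.25])
print(bands)  # → [(-3.0, 3.0), (-0.5, 2.5)]
```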
The figure above shows the classification accuracy for different levels of variance (uncertainty). On the x-axis, ‘1’ means the variance is in the range [0, 0.04), ‘2’ means it is in the range [0.04, 0.08), and so on.
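The binning described above can be sketched as follows. This assumes each test sample has a single scalar predicted variance (for example, the variance of the predicted class output) and a flag saying whether its prediction was correct; which variance the original figure uses is not stated here, and `accuracy_by_variance_bin` is a hypothetical helper.

```python
import math
from collections import defaultdict
from typing import Dict, List


def accuracy_by_variance_bin(variances: List[float], correct: List[bool],
                             width: float = 0.04) -> Dict[int, float]:
    """Bin 1 = [0, width), bin 2 = [width, 2*width), ...; return the
    accuracy inside each non-empty bin."""
    hits: Dict[int, int] = defaultdict(int)
    totals: Dict[int, int] = defaultdict(int)
    for v, ok in zip(variances, correct):
        # Values lying exactly on a bin edge may be shifted by float rounding.
        b = math.floor(v / width) + 1
        totals[b] += 1
        hits[b] += int(ok)
    return {b: hits[b] / totals[b] for b in sorted(totals)}


acc = accuracy_by_variance_bin([0.01, 0.02, 0.05, 0.06, 0.07],
                               [True, True, True, False, False])
print(acc)  # → {1: 1.0, 2: 0.3333333333333333}
```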
Using only 100 training samples in the training set of MNIST:
Method | Accuracy |
---|---|
NPN (ours) | 74.58% |
MLP | 69.02% |
CNN+NPN (ours) | 86.87% |
CNN+MLP | 82.90% |
Regression task on Boston Housing:
Method | RMSE |
---|---|
NPN (ours) | 3.2197 |
MLP | 3.5748 |
The following is all that is needed to implement a three-layer NPN in PyTorch; essentially, you only need to replace nn.Linear with NPNLinear (and use NPN activations such as NPNSigmoid):
```python
import torch.nn as nn

from npn import NPNLinear, NPNSigmoid


class NPNNet(nn.Module):
    def __init__(self):
        super(NPNNet, self).__init__()
        # Last parameter of NPNLinear:
        #   True:  input contains both the mean and the variance
        #   False: input contains only the mean
        self.fc1 = NPNLinear(784, 800, False)
        self.sigmoid1 = NPNSigmoid()
        self.fc2 = NPNLinear(800, 800)
        self.sigmoid2 = NPNSigmoid()
        self.fc3 = NPNLinear(800, 10)
        self.sigmoid3 = NPNSigmoid()

    def forward(self, x):
        x = self.sigmoid1(self.fc1(x))
        x = self.sigmoid2(self.fc2(x))
        # Output the mean (x) and variance (s) of the Gaussian NPN
        x, s = self.sigmoid3(self.fc3(x))
        return x, s
```
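`NPNSigmoid` has to map an input Gaussian (mean `c_m`, variance `c_s`) to the mean and variance of the sigmoid's output, which has no exact closed form. The sketch below uses a probit-style closed-form approximation with constants zeta^2 = pi/8, alpha = 4 - 2*sqrt(2) and beta = -log(sqrt(2) + 1). We believe this corresponds to the approximation described in the NPN paper, but treat it as an assumption and check `npn.py` for the exact form used there; the function `npn_sigmoid` is hypothetical.

```python
import math
from typing import Tuple

# Constants of the assumed closed-form approximation.
ZETA_SQ = math.pi / 8
ALPHA = 4 - 2 * math.sqrt(2)
BETA = -math.log(math.sqrt(2) + 1)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def npn_sigmoid(c_m: float, c_s: float) -> Tuple[float, float]:
    """Approximate mean and variance of sigmoid(x) for x ~ N(c_m, c_s)."""
    o_m = sigmoid(c_m / math.sqrt(1 + ZETA_SQ * c_s))
    # Approximates E[sigmoid(x)^2], then subtracts the squared mean.
    o_s = sigmoid(ALPHA * (c_m + BETA)
                  / math.sqrt(1 + ZETA_SQ * ALPHA ** 2 * c_s)) - o_m ** 2
    return o_m, o_s


mean, var = npn_sigmoid(0.5, 2.0)
print(mean, var)
```

With zero input variance the mean reduces to the ordinary sigmoid, and a larger input variance pulls the mean toward 0.5, which is how uncertainty in the pre-activation shows up in the output.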
The code is tested under PyTorch 0.2.03 and Python 3.5.2.
The official MATLAB version (with GPU support) can be found here.
Another PyTorch/Python implementation (with an extension to GRU) by sohamghosh121 is also available.
Natural-Parameter Networks: A Class of Probabilistic Neural Networks
```bibtex
@inproceedings{DBLP:conf/nips/WangSY16,
  author    = {Hao Wang and
               Xingjian Shi and
               Dit{-}Yan Yeung},
  title     = {Natural-Parameter Networks: {A} Class of Probabilistic Neural Networks},
  booktitle = {Advances in Neural Information Processing Systems 29: Annual Conference
               on Neural Information Processing Systems 2016, December 5-10, 2016,
               Barcelona, Spain},
  pages     = {118--126},
  year      = {2016}
}
```