Skip to content

Commit f182ee7

Browse files
authored
Merge pull request #25 from MJ10/fixinstall
SMO Ablation + Correct version of DUE
2 parents f916ace + ef69bf9 commit f182ee7

File tree

6 files changed

+308
-29
lines changed

6 files changed

+308
-29
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ runs/
99
notebooks/figs
1010
notebooks/old notebooks
1111
notebooks/pickles
12+
notebooks/exploring-bias.ipynb
1213

1314
results/
14-
figures/
15+
figures/
16+
new_results/
17+
18+
*.DS_Store

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ The notebook `notebooks/fixed_training_set.ipynb` illustrates how DEUP is used t
1919
## Rejecting Difficult Examples
2020
Install DUE:
2121
```bash
22-
pip install git+https://github.com/y0ast/DUE.git
22+
pip install git+https://github.com/y0ast/DUE.git@482f9f05788ca62d539e8e5e8684cfc63d39d6f0
2323
```
2424

2525
We first train the main predictor, variance source and density source on the entire dataset and the spilts for training. The procedure is described in Appendix D.1. This script should take about a day to run on a V100 GPU.

examples/SMO/main.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
parser.add_argument("--function", default='multi_optima', help='one of the keys of SMO.test_functions.functions')
1818
parser.add_argument("--noise", type=float, default=0, help='additive aleatoric noise')
1919
parser.add_argument("--method", default='deup', help='one of deup, gp, mcdropout, ensemble')
20+
parser.add_argument("--features", default='xv', help='one of xv, x, v, xd, d, xvd')
2021
parser.add_argument("--save_base_path", default='.', help='path to save results')
2122

2223
args = parser.parse_args()
@@ -38,7 +39,7 @@
3839

3940
results = np.zeros((n_seeds, 1 + n_steps))
4041
use_log_unc = True
41-
features = 'xv'
42+
features = args.features
4243

4344
for seed in range(n_seeds):
4445
torch.manual_seed(10 + seed)
@@ -80,10 +81,13 @@
8081
plot_stuff=False,
8182
n_steps=n_steps, epochs=200, domain=X, domain_image=Y, print_each=100, use_log_unc=True,
8283
estimator='gp')
84+
else:
85+
raise NotImplementedError(f"Method {args.method} not implemented !")
8386
results[seed] = outs[0]
87+
print(results[seed])
8488

85-
string = f"{args.method}_{args.function}_{args.n_init}"
89+
string = f"rebuttal_{args.method}_{features}_{args.function}_{args.n_init}"
8690
filename = os.path.join(args.save_base_path, string)
8791
pickle.dump({'results': results, 'string': string}, open(filename, 'wb'))
8892

89-
print('Results saved !')
93+
print(f'Results saved in {filename}!')

notebooks/SMO.ipynb

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
"import warnings\n",
3939
"warnings.filterwarnings('ignore')\n",
4040
"\n",
41-
"from test_functions import functions\n",
4241
"from examples.SMO.test_functions import functions, bounds as boundsx\n",
4342
"from uncertaintylearning.features.feature_generator import FeatureGenerator\n",
4443
"from examples.SMO.buffer import Buffer\n",
@@ -144,6 +143,15 @@
144143
" plt.show()\n"
145144
]
146145
},
146+
{
147+
"cell_type": "code",
148+
"execution_count": null,
149+
"metadata": {},
150+
"outputs": [],
151+
"source": [
152+
"res_gpdeup"
153+
]
154+
},
147155
{
148156
"cell_type": "markdown",
149157
"metadata": {},
@@ -169,7 +177,7 @@
169177
],
170178
"metadata": {
171179
"kernelspec": {
172-
"display_name": "Python 3",
180+
"display_name": "Python 3.10.6 ('deup')",
173181
"language": "python",
174182
"name": "python3"
175183
},
@@ -183,7 +191,12 @@
183191
"name": "python",
184192
"nbconvert_exporter": "python",
185193
"pygments_lexer": "ipython3",
186-
"version": "3.8.5"
194+
"version": "3.10.6"
195+
},
196+
"vscode": {
197+
"interpreter": {
198+
"hash": "57c0ade180fa71cb589f7bf6b9e050bb8252cc64cdf7f2fc57c1f989acce9d4b"
199+
}
187200
}
188201
},
189202
"nbformat": 4,

notebooks/paperfigures.ipynb

Lines changed: 24 additions & 21 deletions
Large diffs are not rendered by default.

notebooks/smo-ablation-features.ipynb

Lines changed: 255 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)