Skip to content

Commit b1e159e

Browse files
committed
updated toy example to use minibatched function
1 parent a8a1969 commit b1e159e

File tree

1 file changed

+35
-12
lines changed

1 file changed

+35
-12
lines changed

examples/2D.ipynb

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
{
1313
"cell_type": "code",
1414
"execution_count": 47,
15-
"metadata": {},
15+
"metadata": {
16+
"collapsed": true
17+
},
1618
"outputs": [],
1719
"source": [
1820
"import torch\n",
@@ -36,7 +38,9 @@
3638
{
3739
"cell_type": "code",
3840
"execution_count": 63,
39-
"metadata": {},
41+
"metadata": {
42+
"collapsed": true
43+
},
4044
"outputs": [],
4145
"source": [
4246
"# random points at least 2r apart\n",
@@ -66,7 +70,9 @@
6670
{
6771
"cell_type": "code",
6872
"execution_count": 64,
69-
"metadata": {},
73+
"metadata": {
74+
"collapsed": false
75+
},
7076
"outputs": [
7177
{
7278
"name": "stdout",
@@ -122,7 +128,9 @@
122128
{
123129
"cell_type": "code",
124130
"execution_count": 65,
125-
"metadata": {},
131+
"metadata": {
132+
"collapsed": false
133+
},
126134
"outputs": [
127135
{
128136
"data": {
@@ -162,6 +170,7 @@
162170
"cell_type": "code",
163171
"execution_count": 66,
164172
"metadata": {
173+
"collapsed": false,
165174
"scrolled": true
166175
},
167176
"outputs": [
@@ -203,7 +212,7 @@
203212
"data = []\n",
204213
"opt = optim.Adam(robust_net.parameters(), lr=1e-3)\n",
205214
"for i in range(1000):\n",
206-
" robust_ce, robust_err = robust_loss_batch(robust_net, epsilon, Variable(X), Variable(y), False, False)\n",
215+
" robust_ce, robust_err = robust_loss(robust_net, epsilon, Variable(X), Variable(y))\n",
207216
" out = robust_net(Variable(X))\n",
208217
" l2 = nn.CrossEntropyLoss()(out, Variable(y))\n",
209218
" err = (out.max(1)[1].data != y).float().mean()\n",
@@ -227,7 +236,9 @@
227236
{
228237
"cell_type": "code",
229238
"execution_count": 67,
230-
"metadata": {},
239+
"metadata": {
240+
"collapsed": false
241+
},
231242
"outputs": [
232243
{
233244
"data": {
@@ -266,7 +277,9 @@
266277
{
267278
"cell_type": "code",
268279
"execution_count": 68,
269-
"metadata": {},
280+
"metadata": {
281+
"collapsed": true
282+
},
270283
"outputs": [],
271284
"source": [
272285
"def plot_grid(net, ax): \n",
@@ -289,7 +302,9 @@
289302
{
290303
"cell_type": "code",
291304
"execution_count": 69,
292-
"metadata": {},
305+
"metadata": {
306+
"collapsed": false
307+
},
293308
"outputs": [
294309
{
295310
"data": {
@@ -320,7 +335,9 @@
320335
{
321336
"cell_type": "code",
322337
"execution_count": 26,
323-
"metadata": {},
338+
"metadata": {
339+
"collapsed": false
340+
},
324341
"outputs": [
325342
{
326343
"data": {
@@ -357,7 +374,9 @@
357374
{
358375
"cell_type": "code",
359376
"execution_count": 43,
360-
"metadata": {},
377+
"metadata": {
378+
"collapsed": false
379+
},
361380
"outputs": [
362381
{
363382
"data": {
@@ -385,7 +404,9 @@
385404
{
386405
"cell_type": "code",
387406
"execution_count": 14,
388-
"metadata": {},
407+
"metadata": {
408+
"collapsed": false
409+
},
389410
"outputs": [
390411
{
391412
"ename": "AttributeError",
@@ -428,7 +449,9 @@
428449
{
429450
"cell_type": "code",
430451
"execution_count": null,
431-
"metadata": {},
452+
"metadata": {
453+
"collapsed": true
454+
},
432455
"outputs": [],
433456
"source": []
434457
}

0 commit comments

Comments
 (0)