Commit 1db5e33

Author kkkkahlua committed: initial commit

0 parents · commit 1db5e33


92 files changed: +17771 −0 lines changed

README.md

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
# CROP: Certifying Robust Policies for Reinforcement Learning through Functional Smoothing

We propose CROP, the first unified framework for certifying robust policies for RL against test-time evasion attacks on agent observations. In particular, we propose two robustness certification criteria: *robustness of per-state actions* and *lower bound of cumulative rewards*. We then develop three novel methods (LoAct, GRe, and LoRe) to achieve certification under these two criteria. More details can be found in our paper:

*Fan Wu, Linyi Li, Zijian Huang, Yevgeniy Vorobeychik, Ding Zhao*, and *Bo Li*, "CROP: Certifying Robust Policies for Reinforcement Learning through Functional Smoothing", [ICLR 2022](https://openreview.net/forum?id=HOjLHrlZhmx)

All experimental results are available at the website https://crop-leaderboard.github.io.

## Content of the repository

In our paper, we apply our **three** certification algorithms (*CROP-LoAct*, *CROP-GRe*, and *CROP-LoRe*) to certify **nine** RL methods (*StdTrain, GaussAug, AdvTrain, SA-MDP (PGD, CVX), RadialRL, CARRL, NoisyNet*, and *GradDQN*) on two high-dimensional Atari games (*Pong* and *Freeway*), one low-dimensional control environment (*CartPole*), and an autonomous driving environment (*Highway*). For every algorithm in every environment, we obtain certification of either per-state action stability or a lower bound on the cumulative reward.

In this repository, we provide the code for our CROP framework, built on top of the deep RL codebase [CleanRL](https://github.com/vwxyzjn/cleanrl). The repository contains the code for

1. training policies (via functionality provided by the CleanRL codebase), and
2. certifying the trained policies (via testing code and APIs in our CROP framework).

Below, we first present example commands for running the certifications (LoAct, GRe, and LoRe), and then describe the easy-to-use plug-and-play APIs so that interested readers can directly integrate these certification APIs into their own testing code for their trained models.
## Example commands for certification

In this part, we present example commands for obtaining certification under the two certification criteria via the three certification algorithms.

### CROP-LoAct

We first run the pre-processing step to estimate the output range of the Q-network, e.g.,

```bash
python cleanrl_estimate_q_range.py \
    --load-checkpoint <model_path> --dqn-type <model_type> \
    --m 10000 --sigma 0.01
```

Then, we update the configuration file `config_v_table.py` and run LoAct to obtain the certification of per-state action stability via local smoothing, e.g.,

```bash
python cleanrl_certify_r.py \
    --load-checkpoint <model_path> --dqn-type <model_type> \
    --m 10000 --sigma 0.01
```

The results are stored in files with the suffix `_certify-r-{i}.pt`.
### CROP-GRe

Example command for running GRe to obtain the certification of cumulative reward via global smoothing:

```bash
python cleanrl_run_global.py \
    --gym-id PongNoFrameskip-v4 --restrict-actions 4 \
    --load-checkpoint <model_path> --dqn-type <model_type> \
    --max-episodes 10000 --sigma 0.01
```

The results are stored in the file with the suffix `_global-reward.pt`.
### CROP-LoRe

Example command for running LoRe to obtain the certification of cumulative reward via the adaptive search algorithm combined with local smoothing:

```bash
python cleanrl_certify_r.py \
    --gym-id PongNoFrameskip-v4 --restrict-actions 4 \
    --load-checkpoint <model_path> --dqn-type <model_type> \
    --m 10000 --sigma 0.01
```

The results are stored in the file with the suffix `_certify-map.pt`.
## Usage of APIs

### class LoAct

- **Filepath**: ``lo_act.py``
- **Class name**: ``LoAct``
- **Input variables**:

  ``log_func``: the function for logging information

  ``input_shape``: shape of the state observation

  ``model``: the model (Q-network)

  ``forward_func``: the forward function of the given model (e.g., ``model.forward``). This function returns the Q-values.

  ``m``: number of samples for randomized smoothing

  ``sigma``: standard deviation of the smoothing Gaussian noise

  ``v_lo`` and ``v_hi``: the estimated output range of the Q-network; see Section 4.2 of the paper for details

  ``conf_bound``: parameter for computing the confidence interval via Hoeffding's inequality (sketched below)
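For intuition, the standard two-sided Hoeffding bound for $m$ i.i.d. smoothing samples whose Q-value estimates lie in $[v_{lo}, v_{hi}]$ reads

$$\Pr\left[\,\bigl|\hat{Q} - \mathbb{E}[\hat{Q}]\bigr| \le (v_{hi} - v_{lo})\sqrt{\frac{\ln(2/\alpha)}{2m}}\,\right] \ge 1 - \alpha,$$

where $\hat{Q}$ is the empirical mean of the $m$ sampled Q-values and $\alpha$ is the failure probability governed by ``conf_bound``. (This is the textbook bound for intuition only; the exact form used in ``lo_act.py`` and in the paper may differ.)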
- **Functions**:

  ``__init__``: initialization

  ``init_record_list``: initialize the statistics to be saved

  ``forward``: 1) perform randomized smoothing; 2) compute the certification via Theorem 1 in Section 4.1

  ``save``: save the statistics and reset

* **How to incorporate the API** (see the sketch below)

  1. *Model loading*: after loading the model as in the original testing code, wrap the loaded model into ``LoAct``;
  2. *Forwarding*: replace the original forwarding step via the model with the forwarding step via ``LoAct``;
  3. *Statistics saving*: after finishing one episode, save the stored statistics and reset.

* **Example file for proper usage of the API**: ``cleanrl_certify_r.py``
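Putting the three steps together, a minimal sketch might look as follows. The ``QNetwork`` class, the environment loop, the classic Gym ``reset``/``step`` API, and the assumption that ``forward`` directly returns the chosen action are illustrative; ``v_lo``/``v_hi`` come from the pre-processing step above, and ``cleanrl_certify_r.py`` remains the authoritative reference.

```python
# A minimal sketch of the three integration steps (assumptions as noted above).
import torch
from lo_act import LoAct

model = QNetwork(env)                         # hypothetical Q-network class
model.load_state_dict(torch.load('<model_path>'))

lo_act = LoAct(
    log_func=print,                           # function for logging information
    input_shape=env.observation_space.shape,  # shape of the state observation
    model=model,
    forward_func=model.forward,               # returns the Q-values
    m=10000,                                  # number of smoothing samples
    sigma=0.01,                               # std of the smoothing Gaussian noise
    v_lo=v_lo, v_hi=v_hi,                     # from the pre-processing step
    conf_bound=0.05,                          # Hoeffding confidence parameter
)

lo_act.init_record_list()                     # initialize the statistics to be saved
obs, done = env.reset(), False
while not done:                               # step 2: forward via LoAct
    action = lo_act.forward(torch.as_tensor(obs))
    obs, reward, done, info = env.step(action)
lo_act.save()                                 # step 3: save statistics and reset
```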
### class GRe

- **Filepath**: ``g_re.py``
- **Class name**: ``GRe``
- **Input variables**:

  ``log_func``, ``input_shape``, ``model``, ``forward_func``, ``sigma``: same as described above for class LoAct

- **Functions**:

  ``__init__``: initialization

  ``forward``: perform global smoothing by adding one noise sample at each given time step

* **How to incorporate the API** (see the sketch below)

  1. *Model loading*: after loading the model as in the original testing code, wrap the loaded model into ``GRe``;
  2. *Forwarding*: replace the original forwarding step via the model with the forwarding step via ``GRe``;
  3. *Statistics saving*: after completing ``args.max_episodes`` trajectories via ``GRe`` forwarding and obtaining the cumulative rewards of these ``args.max_episodes`` randomized trajectories, save these reward values.

* **Example file for proper usage of the API**: ``cleanrl_run_global.py``
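A minimal sketch under the same assumptions as the LoAct sketch above (the loaded ``model``, the Gym ``reset``/``step`` API, ``forward`` returning the chosen action, and the output filename are all illustrative); ``cleanrl_run_global.py`` is the authoritative reference.

```python
# A minimal sketch of the GRe loop (assumptions as noted above).
import torch
from g_re import GRe

g_re = GRe(
    log_func=print,
    input_shape=env.observation_space.shape,
    model=model,                   # the loaded Q-network, as in the LoAct sketch
    forward_func=model.forward,
    sigma=0.01,
)

rewards = []
for episode in range(10000):       # args.max_episodes randomized trajectories
    obs, done, total_reward = env.reset(), False, 0.0
    while not done:
        action = g_re.forward(torch.as_tensor(obs))  # one noise sample per step
        obs, reward, done, info = env.step(action)
        total_reward += reward
    rewards.append(total_reward)
torch.save(rewards, 'results_global-reward.pt')      # hypothetical output path
```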
### class LoRe

- **Filepath**: ``lo_re.py``
- **Class name**: ``LoRe``
- **Input variables**:

  ``log_func``, ``input_shape``, ``model``, ``forward_func``, ``m``, ``sigma``, ``v_lo``, ``v_hi``, ``conf_bound``: same as described above for class LoAct

  ``max_frames_per_episode``: the trajectory/horizon length to evaluate for the reward certification, i.e., $H$

- **Main functions**:

  ``__init__``: initialization, including the preparation of the priority queue and the memoization used in the search

  ``run``: the entire adaptive search algorithm, alternating between the *trajectory exploration and expansion* step and the *perturbation magnitude growth* step in a loop

  ``expand``: the *trajectory exploration and expansion* step

  ``take_action``: decide the possible action set at each current step, via Theorem 4 in Section 5.2

  ``save``: save ``certified_map``, which contains the list of mappings from perturbation magnitudes to the corresponding certified lower bounds, to the corresponding file

* **How to incorporate the API** (see the sketch below)

  1. *Model loading*: after loading the model as in the original testing code, wrap the loaded model into ``LoRe``;
  2. *Adaptive search*: directly call ``lo_re.run(env, obs)``, where ``obs`` is the fixed initial observation.

* **Example file for proper usage of the API**: ``cleanrl_tree.py``

* **Sidenote**: During the growth of the tree, we keep track of the nodes and edges corresponding to the states and transitions. The tree structure can be saved via ``save_tree`` at the end of the adaptive search, which facilitates both the visualization of the search tree and the understanding of the certification procedure.
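A minimal sketch, again with illustrative constructor arguments (``v_lo``/``v_hi`` from the pre-processing step, a hypothetical horizon of 200 frames, and exact method signatures unverified); ``cleanrl_tree.py`` is the authoritative reference.

```python
# A minimal sketch of running the adaptive search (assumptions as noted above).
from lo_re import LoRe

lo_re = LoRe(
    log_func=print,
    input_shape=env.observation_space.shape,
    model=model,                   # the loaded Q-network, as in the LoAct sketch
    forward_func=model.forward,
    m=10000, sigma=0.01,
    v_lo=v_lo, v_hi=v_hi,          # from the pre-processing step
    conf_bound=0.05,
    max_frames_per_episode=200,    # hypothetical horizon H
)

obs = env.reset()                  # fixed initial observation
lo_re.run(env, obs)                # alternates expansion and magnitude growth
lo_re.save()                       # writes the magnitude -> lower-bound map
lo_re.save_tree()                  # optional: save the search tree for visualization
```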
## Reference

```tex
@inproceedings{wu2022crop,
  title={CROP: Certifying Robust Policies for Reinforcement Learning through Functional Smoothing},
  author={Wu, Fan and Li, Linyi and Huang, Zijian and Vorobeychik, Yevgeniy and Zhao, Ding and Li, Bo},
  booktitle={International Conference on Learning Representations},
  year={2022}
}
```

attacks.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
import torch
from torch import autograd
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

TARGET_MULT = 10000.0

USE_CUDA = torch.cuda.is_available()
# Wrap a tensor in an autograd.Variable, moving it to the GPU when available.
Variable = lambda *args, **kwargs: autograd.Variable(*args, **kwargs).cuda() if USE_CUDA else autograd.Variable(*args, **kwargs)

# Per-dimension observation standard deviations, used to scale the attack
# budget in the low-dimensional control environments.
CARTPOLE_STD = [0.7322321, 1.0629482, 0.12236707, 0.43851405]
ACROBOT_STD = [0.36641926, 0.65119815, 0.6835106, 0.67652863, 2.0165246, 3.0202584]
def pgd(model, X, y, verbose=False, params={}, env_id="", norm_type='l_2'):
    # Projected gradient descent on image observations; X is expected in [0, 255].
    X /= 255
    epsilon = params.get('epsilon', 0.00392)
    niters = params.get('niters', 10)
    img_min = params.get('img_min', 0.0)
    img_max = params.get('img_max', 1.0)
    network_type = params.get('network_type', 'nature')
    loss_func = params.get('loss_func', nn.CrossEntropyLoss())
    step_size = epsilon * 1.0 / niters
    y = Variable(torch.tensor(y))
    if verbose:
        print('epsilon: {}, step size: {}, target label: {}'.format(epsilon, step_size, y))

    X_adv = Variable(X.data, requires_grad=True)

    for i in range(niters):

        if network_type == 'noisynet':
            model.model.sample()  # resample NoisyNet parameters each iteration

        _, logits = model.forward_requires_grad(X_adv, return_q=True)

        loss = loss_func(logits, y)
        if verbose:
            print('current loss: ', loss.data.cpu().numpy())
        model.zero_grad()
        loss.backward()

        # gradient step: sign step for l_inf, normalized step for l_2
        if norm_type == 'l_inf':
            eta = step_size * X_adv.grad.data.sign()
        elif norm_type == 'l_2':
            if not torch.norm(X_adv.grad).item():
                eta = step_size * X_adv.grad.data  # zero gradient: no update
            else:
                eta = step_size * X_adv.grad.data / torch.norm(X_adv.grad).data

        X_adv = Variable(X_adv.data + eta, requires_grad=True)

        # project the total perturbation back into the epsilon-ball
        if norm_type == 'l_inf':
            eta = torch.clamp(X_adv.data - X.data, -epsilon, epsilon)
        elif norm_type == 'l_2':
            eta = X_adv.data - X.data
            if torch.norm(eta) > epsilon:
                eta *= epsilon / torch.norm(eta)

        X_adv.data = X.data + eta
        if verbose:
            print('max eta: ', np.max(abs(eta.data.cpu().numpy())))
            print('linf diff before clamp: ', np.max(abs(X_adv.data.cpu().numpy() - X.data.cpu().numpy())))

        # keep the adversarial example a valid image
        X_adv.data = torch.clamp(X_adv.data, img_min, img_max)
        if verbose:
            print('linf diff after clamp: ', np.max(abs(X_adv.data.cpu().numpy() - X.data.cpu().numpy())))

    if verbose:
        print('{} iterations'.format(i + 1))

    return torch.clamp((X_adv.data * 255).long(), 0, 255)
def fgsm(model, X, y, verbose=False, params={}):
    # Single-step fast gradient sign method.
    epsilon = params.get('epsilon', 1)
    img_min = params.get('img_min', 0.0)
    img_max = params.get('img_max', 1.0)
    X_adv = Variable(X.data, requires_grad=True)
    logits = model.forward(X_adv)
    loss = F.nll_loss(logits, y)
    model.features.zero_grad()
    loss.backward()
    eta = epsilon * X_adv.grad.data.sign()
    X_adv = Variable(X_adv.data + eta, requires_grad=True)
    X_adv.data = torch.clamp(X_adv.data, img_min, img_max)
    return X_adv.data
def rand_attack(model, X, y, verbose=False, params={}, env_id=""):
    # Uniform random noise within the epsilon-ball (no gradient information).
    epsilon = params.get('epsilon', 0.00392)
    # scale the budget by the per-dimension observation std for control tasks
    if env_id == "CartPole-v0":
        epsilon = torch.tensor(CARTPOLE_STD) * epsilon
    if env_id == "Acrobot-v1":
        epsilon = torch.tensor(ACROBOT_STD) * epsilon
    img_min = params.get('img_min', 0.0)
    img_max = params.get('img_max', 1.0)
    noise = 2 * epsilon * torch.rand(X.data.size()) - epsilon
    if USE_CUDA:
        noise = noise.cuda()
    X_adv = torch.clamp(X.data + noise, img_min, img_max)
    X_adv = Variable(X_adv.data, requires_grad=True)
    return X_adv.data
def attack(model, X, attack_config, loss_func=nn.CrossEntropyLoss(), epsilon=0.00392, smooth_type='', network_type='nature'):
    # The attack method and its parameters are currently hard-coded here
    # rather than read from attack_config.
    method = 'pgd'
    verbose = False
    params = {
        'epsilon': epsilon,
        'network_type': network_type,
        'loss_func': loss_func,
    }

    if network_type == 'noisynet':
        model.model.sample()

    # query the smoothed model for Q-values to derive the (pseudo-)labels
    if smooth_type == 'local':
        _, output = model.forward(X, cert=False, return_q=True)
    elif smooth_type == 'global':
        _, output = model.forward(X, return_q=True)
    else:
        raise NotImplementedError(f'smooth_type = {smooth_type} not implemented!')

    y = torch.argmax(output, dim=1)
    if method == 'cw':
        raise NotImplementedError('the CW attack is not implemented in this file')
    elif method == 'rand':
        atk = rand_attack
    elif method == 'fgsm':
        atk = fgsm
    else:
        atk = pgd
    adv_X = atk(model, X, y, verbose=verbose, params=params)
    abs_diff = abs(adv_X.cpu().numpy() - X.cpu().numpy())
    if verbose:
        print('adv image range: {}-{}, ori action: {}, adv action: {}, l1 norm: {}, l2 norm: {}, linf norm: {}'.format(
            torch.min(adv_X).cpu().numpy(), torch.max(adv_X).cpu().numpy(),
            model.act(X)[0], model.act(adv_X)[0],
            np.sum(abs_diff), np.linalg.norm(abs_diff), np.max(abs_diff)))
    return adv_X
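For reference, a hypothetical invocation against a locally smoothed (``LoAct``-wrapped) model might look like the following; ``lo_act`` and ``obs`` are placeholders, and the call runs PGD since the method is hard-coded above.

```python
# obs: observation tensor in [0, 255]; lo_act: model wrapped in LoAct,
# which exposes the forward(X, cert=..., return_q=True) interface used above.
adv_obs = attack(lo_act, obs, attack_config={},
                 epsilon=0.00392, smooth_type='local', network_type='nature')
```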

auto_LiRPA/LICENSE

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
Copyright 2020 Kaidi Xu, Zhouxing Shi, Huan Zhang

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
