
Commit c051882

add BP with quent
1 parent 6ee59da commit c051882

File tree

1 file changed: +137 -0 lines changed


quantization.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
import jax
import jax.numpy as jnp  # JAX NumPy
import numpy as np  # Ordinary NumPy
import wandb
from backprop import sl
from utils import helpers, models, evo
import chex
from args import get_args
from evosax import NetworkMapper, ParameterReshaper, FitnessShaper
from flax.core import FrozenDict

import os

os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

# cosine similarity (higher means more similar)
def cosine(x, y):
    return jnp.sum(x * y) / (jnp.sqrt(jnp.sum(x ** 2)) * jnp.sqrt(jnp.sum(y ** 2)))


def cosine2(x, y):
    return jnp.sum(x * y) / (jnp.sqrt(jnp.sum(x ** 2)) * jnp.sqrt(jnp.sum(y ** 2)))


# negative l2 distance (higher means more similar)
def l2(x, y):
    return -1 * jnp.sqrt(jnp.sum((x - y) ** 2))  # / jnp.sqrt(jnp.sum(x ** 2))


# negative l1 distance
def l1(x, y):
    return -1 * jnp.sum(jnp.abs(x - y))

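
# Minimal sanity check for the similarity helpers above (illustrative sketch,
# not called anywhere in this file): identical vectors give a cosine score of 1
# and negative-distance scores of 0, so "larger is better" holds for all of
# them when they are used as fitness values.
def _check_distances():
    v = jnp.array([1.0, 2.0, 3.0])
    assert jnp.isclose(cosine(v, v), 1.0)
    assert jnp.isclose(l2(v, v), 0.0)
    assert jnp.isclose(l1(v, v), 0.0)
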

# compress an array by uniform quantization over its [min, max] value range
def quantize(array, n_bits):
    max_val = array.max()
    min_val = array.min()
    step = (max_val - min_val) / (2 ** n_bits - 1)
    array = ((array - min_val) / step).round()
    return array


# reconstruct (dequantize) an array of quantized codes back to values
def dequantize(array, min_val, max_val, n_bits):
    step = (max_val - min_val) / (2 ** n_bits - 1)
    array = array * step + min_val
    return array

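
# Minimal round-trip sketch for quantize/dequantize (illustrative values and
# bit-width, not taken from any experiment config): the caller has to record
# the original min/max before quantizing, exactly as the training loop below
# does, in order to reconstruct the parameter values.
def _quantization_round_trip(n_bits=8):
    w = jnp.array([-0.5, 0.0, 0.25, 0.5])
    lo, hi = w.min(), w.max()
    codes = quantize(w, n_bits)                # integer codes in [0, 2 ** n_bits - 1]
    w_hat = dequantize(codes, lo, hi, n_bits)  # approximate reconstruction of w
    return jnp.max(jnp.abs(w - w_hat))         # worst-case quantization error
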
# plot a histogram of an array
from matplotlib import pyplot as plt


def plot_hist(array):
    plt.hist(array, bins=100)
    plt.show()

class TaskManager:
    def __init__(self, rng: chex.PRNGKey, args):
        wandb.run.name = '{}-{}-{} b{} s{} q{} -- {}' \
            .format(args.dataset, args.algo,
                    args.dist,
                    args.batch_size,
                    args.seed, args.quantize_bits, wandb.run.id)

        wandb.run.save()
        self.args = args

    def run(self, rng: chex.PRNGKey):
        train_ds, test_ds = sl.get_datasets(wandb.config.dataset.lower())
        rng, init_rng = jax.random.split(rng)

        learning_rate = wandb.config.lr
        momentum = wandb.config.momentum
        network = NetworkMapper[wandb.config.network_name](**wandb.config.network_config)

        state = sl.create_train_state(init_rng, network, learning_rate, momentum)
        param_reshaper = ParameterReshaper(state.params, n_devices=self.args.n_devices)
        test_param_reshaper = ParameterReshaper(state.params, n_devices=1)
        # strategy, es_params = evo.get_strategy_and_params(self.args.pop_size, param_reshaper.total_params, self.args)
        fit_shaper = FitnessShaper(centered_rank=True, z_score=True, w_decay=self.args.w_decay, maximize=True)
        # server = strategy.initialize(init_rng, es_params)
        # server = server.replace(mean=test_param_reshaper.network_to_flat(state.params))
        # del init_rng  # Must not be used anymore.

        num_epochs = wandb.config.n_rounds
        batch_size = wandb.config.batch_size
        nbits = wandb.config.quantize_bits
        X, y = jnp.array(train_ds['image']), jnp.array(train_ds['label'])

        for epoch in range(1, num_epochs + 1):
            # Use a separate PRNG key to permute image data during shuffling
            rng, input_rng, rng_ask = jax.random.split(rng, 3)
            # Run an optimization step over a training batch
            target_state, loss, acc = sl.train_epoch(state, X, y, batch_size, input_rng)
            # Quantize the updated parameters and reconstruct them, simulating
            # transmission of the flat parameter vector at nbits per entry
            target_server = param_reshaper.network_to_flat(target_state.params)
            max_val = target_server.max()
            min_val = target_server.min()
            quantized = quantize(target_server, nbits)
            server = dequantize(quantized, min_val, max_val, nbits)

            # Signed sum of the reconstruction error introduced by quantization
            ad = jnp.sum(server - target_server)

            state = sl.update_train_state(learning_rate, momentum, test_param_reshaper.reshape_single_net(server))
            # Evaluate on the test set after each training epoch
            test_loss, test_accuracy = sl.eval_model(state.params, test_ds, input_rng)
            wandb.log({
                'Round': epoch,
                'Test Loss': test_loss,
                'Train Loss': loss,
                'Test Accuracy': test_accuracy,
                'Train Accuracy': acc,
                'Global Accuracy': test_accuracy,
                'Information Loss': ad,
            })

def run():
    print(jax.devices())
    args = get_args()
    config = helpers.load_config(args.config)
    wandb.init(project='evofed', config=args)
    wandb.config.update(config)
    args = wandb.config
    rng = jax.random.PRNGKey(args.seed)
    rng, rng_init, rng_run = jax.random.split(rng, 3)
    manager = TaskManager(rng_init, args)
    manager.run(rng_run)


SWEEPS = {
    'cifar-bp': 'bc4zva3u',
    'cifar-bp2': '82la1zw0',
    'fmnits-mah': '1yksrmvs',
    'cifar-mah': 'mtheusi1',
}

if __name__ == '__main__':
    run()
    # wandb.agent(SWEEPS['cifar-mah'], function=run, project='evofed', count=10)
