Commit 99c26fe

add fedavg quantization
1 parent dbcfc1f commit 99c26fe

File tree

1 file changed: +145 −0 lines changed

fedavg_quantization.py

Lines changed: 145 additions & 0 deletions
import os

import chex
import jax
import jax.numpy as jnp  # JAX NumPy
import numpy as np
import tensorflow_datasets as tfds  # TFDS for MNIST
import wandb
from evosax import NetworkMapper, ParameterReshaper
from flax.core import FrozenDict

from args import get_args
from backprop import sl
from utils import helpers

# Let TensorFlow allocate GPU memory on demand so it can coexist with JAX.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

# Compress an array by uniform quantization over [min_val, max_val]:
# map each value to the nearest of 2**n_bits evenly spaced levels and
# return the integer level index. min_val/max_val are passed in (rather
# than computed via array.min()/array.max()) so that quantize and
# dequantize share the same range.
def quantize(array, min_val, max_val, n_bits):
    step = (max_val - min_val) / (2 ** n_bits - 1)
    array = ((array - min_val) / step).round()
    return array


# Invert quantize: map integer level indices back to values in [min_val, max_val].
def dequantize(array, min_val, max_val, n_bits):
    step = (max_val - min_val) / (2 ** n_bits - 1)
    array = array * step + min_val
    return array

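For intuition, a minimal round trip through the two helpers above; the values and bit width here are illustrative, not taken from the training script:

x = jnp.array([-1.0, -0.25, 0.0, 0.5, 1.0])
levels = quantize(x, x.min(), x.max(), n_bits=4)    # integer indices in [0, 15]
x_hat = dequantize(levels, x.min(), x.max(), n_bits=4)
# Reconstruction error is at most half a step: (max - min) / (2**4 - 1) / 2 ~= 0.067
print(jnp.max(jnp.abs(x - x_hat)))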
# Sparsify by magnitude: zero out roughly the smallest `percentage`
# fraction of entries, keeping the largest-magnitude values unchanged.
def sparsify(array, percentage):
    original = array
    array = jnp.abs(array.flatten())
    array = jnp.sort(array)
    threshold = array[int(len(array) * percentage)]
    array = jnp.where(jnp.abs(original) < threshold, 0, original)
    return array

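A quick illustration of the magnitude thresholding; the values are made up:

g = jnp.array([0.05, -0.9, 0.2, -0.01, 0.7])
print(sparsify(g, 0.6))  # zeros the ~60% smallest magnitudes -> [0., -0.9, 0., 0., 0.7]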
# Negative L2 distance (sign flipped so that closer vectors score higher).
def l2(x, y):
    return -1 * jnp.sqrt(jnp.sum((x - y) ** 2))  # / jnp.sqrt(jnp.sum(x ** 2))

class TaskManager:
    def __init__(self, rng: chex.PRNGKey, args):
        wandb.run.name = '{}-{}-{} b{} c{} s{} q{} -- {}' \
            .format(args.dataset, args.algo, args.dist,
                    args.batch_size, args.n_clients,
                    args.seed, args.quantize_bits, wandb.run.id)
        wandb.run.save()
        self.train_ds, self.test_ds = sl.get_fed_datasets(args.dataset, args.n_clients, 2, args.dist == 'IID')

        # Split the caller's key to derive the network-initialization key.
        rng, init_rng = jax.random.split(rng)
        self.learning_rate = wandb.config.lr
        self.momentum = wandb.config.momentum
        network = NetworkMapper[wandb.config.network_name](**wandb.config.network_config)

        self.state = sl.create_train_state(init_rng, network, self.learning_rate, self.momentum)
        del init_rng  # Must not be used anymore.

        self.param_count = sum(x.size for x in jax.tree_util.tree_leaves(self.state.params))
        self.num_epochs = wandb.config.n_rounds
        self.batch_size = wandb.config.batch_size
        self.client_epoch = wandb.config.client_epoch
        self.n_clients = args.n_clients
        # Truncate every client shard to a common length so the per-client
        # arrays stack into one batched array for vmap.
        min_cut = 10000
        # if args.dataset == 'mnist':
        #     min_cut = 5421

        self.X = jnp.array([train['image'][:min_cut] for train in self.train_ds])
        self.y = jnp.array([train['label'][:min_cut] for train in self.train_ds])
        self.args = args
        self.n_bits = args.quantize_bits
        self.param_reshaper = ParameterReshaper(self.state.params, n_devices=1)

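For readers without evosax at hand, the flatten/unflatten round trip that ParameterReshaper provides (network_to_flat and reshape_single_net, used below) can be sketched with plain JAX; this stand-in is illustrative, not the API the script actually calls:

from jax.flatten_util import ravel_pytree

# Illustrative stand-in for ParameterReshaper's round trip.
toy_params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}
flat, unravel = ravel_pytree(toy_params)  # pytree -> 1-D vector (+ inverse fn)
restored = unravel(flat)                  # 1-D vector -> original pytree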
    def run(self, rng: chex.PRNGKey):
        for epoch in range(0, self.num_epochs + 1):
            # Use a separate PRNG key to permute image data during shuffling.
            rng, input_rng = jax.random.split(rng)
            # Broadcast the server state to every client and run one local
            # training epoch per client in parallel.
            clients, loss, acc = jax.vmap(sl.train_epoch, in_axes=(None, 0, 0, None, None))(
                self.state, self.X, self.y, self.batch_size, input_rng)
            for c_epoch in range(self.client_epoch):
                input_rng, c_rng = jax.random.split(input_rng)
                clients, loss, acc = jax.vmap(sl.train_epoch, in_axes=(0, 0, 0, None, None))(
                    clients, self.X, self.y, self.batch_size, c_rng)
                wandb.log({
                    'Epoch': epoch * self.client_epoch + c_epoch,
                    'Train Loss': loss.mean(),
                    'Train Accuracy': acc.mean(),
                })

            # Aggregate on flat client deltas relative to the server weights.
            server = self.param_reshaper.network_to_flat(self.state.params)
            target_server = jax.vmap(self.param_reshaper.network_to_flat)(clients.params)
            target_server = (target_server - server)
            min_val, max_val = jax.vmap(jnp.min)(target_server), jax.vmap(jnp.max)(target_server)
            target_server = jax.vmap(sparsify, in_axes=(0, None))(target_server, self.args.percentage)

            # Per-client quantization of the deltas (currently disabled):
            # target_server = jax.vmap(quantize, in_axes=(0, 0, 0, None))(target_server, min_val, max_val, self.n_bits)
            # target_server = jax.vmap(dequantize, in_axes=(0, 0, 0, None))(target_server, min_val, max_val, self.n_bits)

            # FedAvg step: average the sparsified deltas across clients,
            # then sparsify the averaged delta once more.
            target_server = jax.vmap(jnp.mean)(target_server.T)
            # target_server = jax.vmap(quantize, in_axes=(0, None))(target_server, self.n_bits)
            # target_server = dequantize(target_server, min_val.mean(), max_val.mean(), self.n_bits)
            target_server = sparsify(target_server, self.args.percentage)

            target_server = target_server + server
            params = self.param_reshaper.reshape_single_net(target_server)
            self.state = self.state.replace(params=FrozenDict(params))
            rng, eval_rng = jax.random.split(rng)
            test_loss, test_accuracy = sl.eval_model(params, self.test_ds, eval_rng)
            remaining_params = self.param_count * (1 - self.args.percentage)

            wandb.log({
                'Round': epoch,
                'Test Loss': test_loss,
                'Global Accuracy': test_accuracy,
                # Cumulative bits on the wire, in 32-bit words: each surviving
                # entry costs n_bits for its value plus ~log2(param_count)
                # bits for its index; the factor 2 counts upload and download.
                # 'Communication': epoch * 2 * self.param_count / (32 / self.n_bits),
                'Communication': epoch * 2 * remaining_params * ((self.n_bits + np.log2(self.param_count)) / 32),
            })

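Putting the pieces together, a toy sketch of one aggregation round on synthetic client deltas; unlike the method above, quantize/dequantize are enabled here, and all shapes and values are illustrative:

key = jax.random.PRNGKey(42)
deltas = jax.random.normal(key, (4, 10))  # 4 clients x 10 parameters

mins, maxs = jax.vmap(jnp.min)(deltas), jax.vmap(jnp.max)(deltas)
deltas = jax.vmap(sparsify, in_axes=(0, None))(deltas, 0.5)
deltas = jax.vmap(quantize, in_axes=(0, 0, 0, None))(deltas, mins, maxs, 8)
deltas = jax.vmap(dequantize, in_axes=(0, 0, 0, None))(deltas, mins, maxs, 8)

avg_delta = jax.vmap(jnp.mean)(deltas.T)  # per-parameter mean across clients
print(avg_delta.shape)                    # (10,)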
def run():
    print(jax.devices())
    args = get_args()
    config = helpers.load_config(args.config)
    wandb.init(project='evofed-publish', config=args)
    wandb.config.update(config)
    args = wandb.config
    rng = jax.random.PRNGKey(args.seed)
    rng, rng_init, rng_run = jax.random.split(rng, 3)
    manager = TaskManager(rng_init, args)
    manager.run(rng_run)


if __name__ == '__main__':
    # wandb.agent('tdt4lz81', function=run, project='evofed', count=10)
    run()
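A hedged usage note: args.config is read above, so get_args presumably exposes a config-file flag; assuming it is named --config, the script would be launched along these lines, with the remaining hyperparameters supplied by the config file and wandb:

python fedavg_quantization.py --config <path-to-config>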
