|
| 1 | +import chex |
| 2 | +import jax |
| 3 | +import jax.numpy as jnp # JAX NumPy |
| 4 | +import numpy as np |
| 5 | +import tensorflow_datasets as tfds # TFDS for MNIST |
| 6 | +import wandb |
| 7 | +from evosax import NetworkMapper |
| 8 | +from backprop import sl |
| 9 | +from args import get_args |
| 10 | +from utils import helpers |
| 11 | +from flax.core import FrozenDict |
| 12 | +from evosax import ParameterReshaper |
| 13 | +import os |
| 14 | + |
| 15 | +os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' |
| 16 | + |
| 17 | + |
| 18 | +# compress the array with quantization based on array distribution |
| 19 | +def quantize(array, min_val, max_val, n_bits): |
| 20 | + # max_val = array.max() |
| 21 | + # min_val = array.min() |
| 22 | + step = (max_val - min_val) / (2 ** n_bits - 1) |
| 23 | + array = ((array - min_val) / step).round() |
| 24 | + return array |
| 25 | + |
| 26 | + |
| 27 | +# dequantization array |
| 28 | +def dequantize(array, min_val, max_val, n_bits): |
| 29 | + step = (max_val - min_val) / (2 ** n_bits - 1) |
| 30 | + array = array * step + min_val |
| 31 | + return array |
| 32 | + |
| 33 | +def sparsify(array, percentage): |
| 34 | + original = array |
| 35 | + array = jnp.abs(array.flatten()) |
| 36 | + array = jnp.sort(array) |
| 37 | + threshold = array[int(len(array) * percentage)] |
| 38 | + array = jnp.where(jnp.abs(original) < threshold, 0, original) |
| 39 | + return array |
| 40 | + |
| 41 | +# L2 distance |
| 42 | +def l2(x, y): |
| 43 | + return -1 * jnp.sqrt(jnp.sum((x - y) ** 2)) # / jnp.sqrt(jnp.sum(x ** 2)) |
| 44 | + |
| 45 | + |
| 46 | +class TaskManager: |
| 47 | + def __init__(self, rng: chex.PRNGKey, args): |
| 48 | + wandb.run.name = '{}-{}-{} b{} c{} s{} q{} -- {}' \ |
| 49 | + .format(args.dataset, args.algo, |
| 50 | + args.dist, |
| 51 | + args.batch_size, args.n_clients, |
| 52 | + args.seed, args.quantize_bits, wandb.run.id) |
| 53 | + wandb.run.save() |
| 54 | + self.train_ds, self.test_ds = sl.get_fed_datasets(args.dataset, args.n_clients, 2, args.dist == 'IID') |
| 55 | + |
| 56 | + rng = jax.random.PRNGKey(0) |
| 57 | + rng, init_rng = jax.random.split(rng) |
| 58 | + self.learning_rate = wandb.config.lr |
| 59 | + self.momentum = wandb.config.momentum |
| 60 | + network = NetworkMapper[wandb.config.network_name](**wandb.config.network_config) |
| 61 | + |
| 62 | + self.state = sl.create_train_state(init_rng, network, self.learning_rate, self.momentum) |
| 63 | + del init_rng # Must not be used anymore. |
| 64 | + |
| 65 | + self.param_count = sum(x.size for x in jax.tree_leaves(self.state.params)) |
| 66 | + self.num_epochs = wandb.config.n_rounds |
| 67 | + self.batch_size = wandb.config.batch_size |
| 68 | + self.client_epoch = wandb.config.client_epoch |
| 69 | + self.n_clients = args.n_clients |
| 70 | + min_cut = 10000 |
| 71 | + # if args.dataset == 'mnist': |
| 72 | + # min_cut = 5421 |
| 73 | + |
| 74 | + self.X = jnp.array([train['image'][:min_cut] for train in self.train_ds]) |
| 75 | + self.y = jnp.array([train['label'][:min_cut] for train in self.train_ds]) |
| 76 | + self.args = args |
| 77 | + self.n_bits = args.quantize_bits |
| 78 | + self.param_reshaper = ParameterReshaper(self.state.params, n_devices=1) |
| 79 | + |
| 80 | + def run(self, rng: chex.PRNGKey): |
| 81 | + for epoch in range(0, self.num_epochs + 1): |
| 82 | + # Use a separate PRNG key to permute image data during shuffling |
| 83 | + rng, input_rng = jax.random.split(rng) |
| 84 | + # Run an optimization step over a training batch |
| 85 | + # clients = [self.state for i in range(5)] |
| 86 | + clients, loss, acc = jax.vmap(sl.train_epoch, in_axes=(None, 0, 0, None, None))(self.state, |
| 87 | + self.X, |
| 88 | + self.y, |
| 89 | + self.batch_size, input_rng) |
| 90 | + for c_epoch in range(self.client_epoch): |
| 91 | + input_rng, c_rng = jax.random.split(input_rng) |
| 92 | + clients, loss, acc = jax.vmap(sl.train_epoch, in_axes=(0, 0, 0, None, None))(clients, |
| 93 | + self.X, |
| 94 | + self.y, |
| 95 | + self.batch_size, c_rng) |
| 96 | + wandb.log({ |
| 97 | + 'Epoch': epoch * self.client_epoch + c_epoch, |
| 98 | + 'Train Loss': loss.mean(), |
| 99 | + 'Train Accuracy': acc.mean(), |
| 100 | + }) |
| 101 | + |
| 102 | + server = self.param_reshaper.network_to_flat(self.state.params) |
| 103 | + target_server = jax.vmap(self.param_reshaper.network_to_flat)(clients.params) |
| 104 | + target_server = (target_server - server) |
| 105 | + min_val, max_val = jax.vmap(jnp.min)(target_server), jax.vmap(jnp.max)(target_server) |
| 106 | + target_server = jax.vmap(sparsify, in_axes=(0, None))(target_server, self.args.percentage) |
| 107 | + |
| 108 | + # target_server = jax.vmap(quantize, in_axes=(0, 0, 0, None))(target_server, min_val, max_val, self.n_bits) |
| 109 | + # target_server = jax.vmap(dequantize, in_axes=(0, 0, 0, None))(target_server, min_val, max_val, self.n_bits) |
| 110 | + target_server = jax.vmap(jnp.mean)(target_server.T) |
| 111 | + # target_server = jax.vmap(quantize, in_axes=(0, None))(target_server, self.n_bits) |
| 112 | + # target_server = dequantize(target_server, min_val.mean(), max_val.mean(), self.n_bits) |
| 113 | + target_server = sparsify(target_server, self.args.percentage) |
| 114 | + |
| 115 | + target_server = target_server + server |
| 116 | + params = self.param_reshaper.reshape_single_net(target_server) |
| 117 | + self.state = self.state.replace(params=FrozenDict(params)) |
| 118 | + rng, eval_rng = jax.random.split(rng) |
| 119 | + test_loss, test_accuracy = sl.eval_model(params, self.test_ds, eval_rng) |
| 120 | + remining_params = self.param_count * (1 - self.args.percentage) |
| 121 | + |
| 122 | + wandb.log({ |
| 123 | + 'Round': epoch, |
| 124 | + 'Test Loss': test_loss, |
| 125 | + 'Global Accuracy': test_accuracy, |
| 126 | + # 'Communication': epoch * 2 * self.param_count / (32 / self.n_bits), |
| 127 | + 'Communication': epoch * 2 * remining_params * ((self.n_bits + np.log2(self.param_count))/ 32), |
| 128 | + }) |
| 129 | + |
| 130 | +def run(): |
| 131 | + print(jax.devices()) |
| 132 | + args = get_args() |
| 133 | + config = helpers.load_config(args.config) |
| 134 | + wandb.init(project='evofed-publish', config=args) |
| 135 | + wandb.config.update(config) |
| 136 | + args = wandb.config |
| 137 | + rng = jax.random.PRNGKey(args.seed) |
| 138 | + rng, rng_init, rng_run = jax.random.split(rng, 3) |
| 139 | + manager = TaskManager(rng_init, args) |
| 140 | + manager.run(rng_run) |
| 141 | + |
| 142 | + |
| 143 | +if __name__ == '__main__': |
| 144 | + # wandb.agent('tdt4lz81', function=run, project='evofed', count=10) |
| 145 | + run() |
0 commit comments