import os

import jax
import jax.numpy as jnp  # JAX NumPy
import numpy as np  # Ordinary NumPy
import wandb
import chex
from matplotlib import pyplot as plt

from backprop import sl
from utils import helpers, models, evo
from args import get_args
from evosax import NetworkMapper, ParameterReshaper, FitnessShaper
from flax.core import FrozenDict

os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'


# Cosine similarity between two flat arrays (1.0 when they are aligned).
def cosine(x, y):
    return jnp.sum(x * y) / (jnp.sqrt(jnp.sum(x ** 2)) * jnp.sqrt(jnp.sum(y ** 2)))


# Negated L2 (Euclidean) distance, so that larger values mean closer arrays.
def l2(x, y):
    return -1 * jnp.sqrt(jnp.sum((x - y) ** 2))  # / jnp.sqrt(jnp.sum(x ** 2))


# Negated L1 (Manhattan) distance.
def l1(x, y):
    return -1 * jnp.sum(jnp.abs(x - y))
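

# A quick sanity check (illustrative; `_distance_demo` is a hypothetical helper
# not used by the training code): each measure above attains its maximum when
# the inputs coincide, so all are usable as fitness scores to maximize.
def _distance_demo():
    x = jnp.array([1.0, 2.0, 3.0])
    y = jnp.array([1.0, 2.0, 3.5])
    assert jnp.isclose(cosine(x, x), 1.0)
    assert l2(x, x) == 0.0 and l2(x, y) < 0.0
    assert l1(x, x) == 0.0 and l1(x, y) < 0.0
    return cosine(x, y), l2(x, y), l1(x, y)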


# Uniformly quantize an array to integer levels in [0, 2**n_bits - 1],
# using the array's own min/max as the quantization range.
def quantize(array, n_bits):
    max_val = array.max()
    min_val = array.min()
    step = (max_val - min_val) / (2 ** n_bits - 1)
    array = ((array - min_val) / step).round()
    return array


# Map integer levels back to floats; the caller must pass the same
# (min_val, max_val) range that was used when quantizing.
def dequantize(array, min_val, max_val, n_bits):
    step = (max_val - min_val) / (2 ** n_bits - 1)
    array = array * step + min_val
    return array
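

# Round-trip sketch (illustrative; `_quantization_roundtrip_demo` is a
# hypothetical helper): dequantize(quantize(x)) reconstructs each entry to
# within half a quantization step, (max - min) / (2 ** n_bits - 1) / 2.
def _quantization_roundtrip_demo(n_bits=4):
    x = jnp.linspace(-1.0, 1.0, 7)
    q = quantize(x, n_bits)
    x_hat = dequantize(q, x.min(), x.max(), n_bits)
    half_step = (x.max() - x.min()) / (2 ** n_bits - 1) / 2
    assert jnp.max(jnp.abs(x - x_hat)) <= half_step + 1e-6
    return x, x_hat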


# Plot a histogram of an array's values (e.g. to inspect the weight
# distribution before choosing a bit width).
def plot_hist(array):
    plt.hist(array, bins=100)
    plt.show()
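

# Usage sketch (hypothetical, with names from TaskManager.run below): inspect
# the flattened parameter distribution before picking quantize_bits; heavy
# tails suggest more bits are needed for an accurate reconstruction.
#   flat = param_reshaper.network_to_flat(state.params)
#   plot_hist(np.asarray(flat))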


class TaskManager:
    def __init__(self, rng: chex.PRNGKey, args):
        wandb.run.name = '{}-{}-{} b{} s{} q{} -- {}' \
            .format(args.dataset, args.algo,
                    args.dist,
                    args.batch_size,
                    args.seed, args.quantize_bits, wandb.run.id)

        wandb.run.save()
        self.args = args

    def run(self, rng: chex.PRNGKey):
        train_ds, test_ds = sl.get_datasets(wandb.config.dataset.lower())
        rng, init_rng = jax.random.split(rng)

        learning_rate = wandb.config.lr
        momentum = wandb.config.momentum
        network = NetworkMapper[wandb.config.network_name](**wandb.config.network_config)

        state = sl.create_train_state(init_rng, network, learning_rate, momentum)
        param_reshaper = ParameterReshaper(state.params, n_devices=self.args.n_devices)
        test_param_reshaper = ParameterReshaper(state.params, n_devices=1)
        # The fitness shaper and the commented-out strategy below belong to the
        # evolutionary path, which is currently disabled in this script.
        # strategy, es_params = evo.get_strategy_and_params(self.args.pop_size, param_reshaper.total_params, self.args)
        fit_shaper = FitnessShaper(centered_rank=True, z_score=True, w_decay=self.args.w_decay, maximize=True)
        # server = strategy.initialize(init_rng, es_params)
        # server = server.replace(mean=test_param_reshaper.network_to_flat(state.params))
        # del init_rng  # Must not be used anymore.

        num_epochs = wandb.config.n_rounds
        batch_size = wandb.config.batch_size
        n_bits = wandb.config.quantize_bits
        X, y = jnp.array(train_ds['image']), jnp.array(train_ds['label'])

        for epoch in range(1, num_epochs + 1):
            # Use a separate PRNG key to permute image data during shuffling.
            rng, input_rng = jax.random.split(rng)
            # Run an optimization step over a training batch.
            target_state, loss, acc = sl.train_epoch(state, X, y, batch_size, input_rng)
            # Flatten the updated parameters, quantize to n_bits levels, and
            # dequantize to simulate a lossily compressed model update.
            flat_params = param_reshaper.network_to_flat(target_state.params)
            max_val = flat_params.max()
            min_val = flat_params.min()
            quantized = quantize(flat_params, n_bits)
            server = dequantize(quantized, min_val, max_val, n_bits)

            # Total absolute quantization error between the reconstructed and
            # the original parameters, logged as 'Information Loss'.
            ad = jnp.sum(jnp.abs(server - flat_params))

            state = sl.update_train_state(learning_rate, momentum, test_param_reshaper.reshape_single_net(server))
            # Evaluate on the test set after each training epoch.
            test_loss, test_accuracy = sl.eval_model(state.params, test_ds, input_rng)
            wandb.log({
                'Round': epoch,
                'Test Loss': test_loss,
                'Train Loss': loss,
                'Test Accuracy': test_accuracy,
                'Train Accuracy': acc,
                'Global Accuracy': test_accuracy,
                'Information Loss': ad,
            })
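

# Back-of-envelope payload accounting for the quantized update (a sketch,
# assuming parameters would otherwise be sent as float32; `_compression_ratio`
# is a hypothetical helper): each parameter costs n_bits instead of 32, plus
# two float32 values to transmit the (min_val, max_val) range.
def _compression_ratio(n_params, n_bits):
    payload_bits = n_params * n_bits + 2 * 32  # quantized levels + range
    return (n_params * 32) / payload_bits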


def run():
    print(jax.devices())
    args = get_args()
    config = helpers.load_config(args.config)
    wandb.init(project='evofed', config=args)
    wandb.config.update(config)
    args = wandb.config
    rng = jax.random.PRNGKey(args.seed)
    rng, rng_init, rng_run = jax.random.split(rng, 3)
    manager = TaskManager(rng_init, args)
    manager.run(rng_run)

SWEEPS = {
    'cifar-bp': 'bc4zva3u',
    'cifar-bp2': '82la1zw0',
    'fmnist-mah': '1yksrmvs',
    'cifar-mah': 'mtheusi1',
}

if __name__ == '__main__':
    run()
    # wandb.agent(SWEEPS['cifar-mah'], function=run, project='evofed', count=10)