Skip to content

Commit fbda101

Browse files
authored
Adding learnable output scaling alpha (#353)
1 parent c737a52 commit fbda101

20 files changed

+635
-211
lines changed

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ The format is based on [Keep a Changelog], and this project adheres to
1212
* `Fixed` for any bug fixes.
1313
* `Security` in case of vulnerabilities.
1414

15-
## [0.5.0] - 2022/01/19
15+
## [0.5.0] - 2022/01/27
1616

1717
### Added
1818

@@ -36,6 +36,8 @@ The format is based on [Keep a Changelog], and this project adheres to
3636
output sizes can be configured for the ``*Mapped`` layers. (\#331)
3737
* Notebooks directory with several notebook examples (\#333, \#334)
3838
* Analog information summary function. (\#316)
39+
* The `alpha` weight scaling factor can now be defined as a learnable parameter by enabling
40+
`learn_out_scaling_alpha` in the `rpu_config.mapping` parameters. (\#353)
3941

4042
### Fixed
4143

@@ -62,6 +64,7 @@ The format is based on [Keep a Changelog], and this project adheres to
6264
* Digital bias is now accessible through ``MappingParameter``. (\#331)
6365
* The aihwkit documentation. New content around analog ai concepts, training presets, analog ai
6466
optimizers, new references, and examples. (\#348)
67+
* The `weight_scaling_omega` can now be defined in the `rpu_config.mapping`. (\#353)
6568

6669
### Deprecated
6770

examples/11_vgg8_training.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from aihwkit.optim import AnalogSGD
3636
from aihwkit.simulator.presets.configs import GokmenVlasovPreset
3737
from aihwkit.simulator.rpu_base import cuda
38+
from aihwkit.simulator.configs.utils import MappingParameter
3839

3940
# Check device
4041
USE_CUDA = 0
@@ -59,7 +60,8 @@
5960

6061
# Select the device model to use in the training. In this case we are using one of the presets,
6162
# but it can be changed to any of a number of presets to explore different analog devices
62-
RPU_CONFIG = GokmenVlasovPreset()
63+
mapping = MappingParameter(weight_scaling_omega=WEIGHT_SCALING_OMEGA)
64+
RPU_CONFIG = GokmenVlasovPreset(mapping=mapping)
6365

6466

6567
def load_images():
@@ -93,33 +95,33 @@ def create_analog_network():
9395
nn.ReLU(),
9496
AnalogConv2d(in_channels=channel[0], out_channels=channel[0], kernel_size=3, stride=1,
9597
padding=1,
96-
rpu_config=RPU_CONFIG, weight_scaling_omega=WEIGHT_SCALING_OMEGA),
98+
rpu_config=RPU_CONFIG),
9799
nn.BatchNorm2d(channel[0]),
98100
nn.ReLU(),
99101
nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1),
100102
AnalogConv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=3, stride=1,
101103
padding=1,
102-
rpu_config=RPU_CONFIG, weight_scaling_omega=WEIGHT_SCALING_OMEGA),
104+
rpu_config=RPU_CONFIG),
103105
nn.ReLU(),
104106
AnalogConv2d(in_channels=channel[1], out_channels=channel[1], kernel_size=3, stride=1,
105107
padding=1,
106-
rpu_config=RPU_CONFIG, weight_scaling_omega=WEIGHT_SCALING_OMEGA),
108+
rpu_config=RPU_CONFIG),
107109
nn.BatchNorm2d(channel[1]),
108110
nn.ReLU(),
109111
nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1),
110112
AnalogConv2d(in_channels=channel[1], out_channels=channel[2], kernel_size=3, stride=1,
111113
padding=1,
112-
rpu_config=RPU_CONFIG, weight_scaling_omega=WEIGHT_SCALING_OMEGA),
114+
rpu_config=RPU_CONFIG),
113115
nn.ReLU(),
114116
AnalogConv2d(in_channels=channel[2], out_channels=channel[2], kernel_size=3, stride=1,
115117
padding=1,
116-
rpu_config=RPU_CONFIG, weight_scaling_omega=WEIGHT_SCALING_OMEGA),
118+
rpu_config=RPU_CONFIG),
117119
nn.BatchNorm2d(channel[2]),
118120
nn.ReLU(),
119121
nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1),
120122
nn.Flatten(),
121123
AnalogLinear(in_features=16 * channel[2], out_features=fc_size,
122-
rpu_config=RPU_CONFIG, weight_scaling_omega=WEIGHT_SCALING_OMEGA),
124+
rpu_config=RPU_CONFIG),
123125
nn.ReLU(),
124126
nn.Linear(in_features=fc_size, out_features=N_CLASSES),
125127
nn.LogSoftmax(dim=1)

examples/16_mnist_gan.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@
5252
# optimizer used (e.g. mixed precision or full analog update)
5353

5454
# As an example we use a mixed precision preset using an ECRAM device model
55-
RPU_CONFIG = MixedPrecisionEcRamMOPreset()
56-
WS_OMEGA = 0.8
55+
from aihwkit.simulator.configs.utils import MappingParameter
56+
mapping = MappingParameter(weight_scaling_omega=0.8)
57+
RPU_CONFIG = MixedPrecisionEcRamMOPreset(mapping=mapping)
5758

5859
# Set your parameters
5960
SEED = 1
@@ -129,8 +130,7 @@ def get_generator_block(input_dim, output_dim):
129130
input_dim,
130131
output_dim,
131132
bias=True,
132-
rpu_config=RPU_CONFIG,
133-
weight_scaling_omega=WS_OMEGA,
133+
rpu_config=RPU_CONFIG
134134
),
135135
nn.BatchNorm1d(output_dim),
136136
nn.ReLU(inplace=True),
@@ -159,8 +159,7 @@ def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
159159
hidden_dim * 8,
160160
im_dim,
161161
bias=True,
162-
rpu_config=RPU_CONFIG,
163-
weight_scaling_omega=WS_OMEGA,
162+
rpu_config=RPU_CONFIG
164163
),
165164
nn.Sigmoid(),
166165
)

examples/17_resnet34_imagenet_conversion_to_analog.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@
3535
# Define device and chip configuration used in the RPU tile
3636
mapping = MappingParameter(max_input_size=512, # analog tile size
3737
max_output_size=512,
38-
digital_bias=True) # whether to use analog or digital bias
38+
digital_bias=True, # whether to use analog or digital bias
39+
weight_scaling_omega=0.6) # if non-zero, scale analog weights by omega / abs-max of the weights
3940
# Choose any preset or RPU configuration here
4041
rpu_config = TikiTakaReRamSBPreset(mapping=mapping)
4142

4243
# Convert the model to its analog version.
4344
# this will replace ``Linear`` layers with ``AnalogLinearMapped``
44-
model = convert_to_analog_mapped(model, rpu_config, weight_scaling_omega=0.6)
45+
model = convert_to_analog_mapped(model, rpu_config)
4546

4647
# Note: One can also use ``convert_to_analog`` instead to convert
4748
# ``Linear`` to ``AnalogLinear`` (without mapping to multiple tiles)

examples/18_cifar10_on_resnet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from aihwkit.nn.conversion import convert_to_analog
3535
from aihwkit.simulator.presets import TikiTakaEcRamPreset
3636
from aihwkit.simulator.rpu_base import cuda
37+
from aihwkit.simulator.configs.utils import MappingParameter
3738

3839
# Device to use
3940
USE_CUDA = 0
@@ -56,7 +57,8 @@
5657
N_CLASSES = 10
5758

5859
# Device used in the RPU tile
59-
RPU_CONFIG = TikiTakaEcRamPreset()
60+
mapping = MappingParameter(weight_scaling_omega=0.6)
61+
RPU_CONFIG = TikiTakaEcRamPreset(mapping=mapping)
6062

6163

6264
class ResidualBlock(nn.Module):
@@ -295,7 +297,7 @@ def main():
295297
model = create_model()
296298

297299
# Convert the model to its analog version
298-
model = convert_to_analog(model, RPU_CONFIG, weight_scaling_omega=0.6)
300+
model = convert_to_analog(model, RPU_CONFIG)
299301
# Load saved weights if previously saved
300302
# model.load_state_dict(load(WEIGHT_PATH))
301303

src/aihwkit/nn/conversion.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def convert_to_analog(
4545
module: Module,
4646
rpu_config: RPUConfigGeneric,
4747
realistic_read_write: bool = False,
48-
weight_scaling_omega: float = 0.0,
4948
conversion_map: Optional[Dict] = None
5049
) -> Module:
5150
"""Convert a given digital model to analog counter parts.
@@ -63,9 +62,6 @@ def convert_to_analog(
6362
Applied to all converted tiles.
6463
realistic_read_write: Whether to use closed-loop programming
6564
when setting the weights. Applied to all converted tiles.
66-
weight_scaling_omega: If non-zero, the analog weights will be
67-
scaled by ``weight_scaling_omega`` divided by the absolute
68-
maximum value of the original weight matrix.
6965
7066
Note:
7167
Make sure that the weight max and min settings of the
@@ -92,7 +88,7 @@ def convert_to_analog(
9288
# Convert parent.
9389
if module.__class__ in conversion_map:
9490
module = conversion_map[module.__class__].from_digital( # type: ignore
95-
module, rpu_config, realistic_read_write, weight_scaling_omega)
91+
module, rpu_config, realistic_read_write)
9692

9793
# Convert children.
9894
convert_dic = {}
@@ -101,11 +97,11 @@ def convert_to_analog(
10197
n_grand_children = len(list(mod.named_children()))
10298
if n_grand_children > 0:
10399
new_mod = convert_to_analog(mod, rpu_config, realistic_read_write,
104-
weight_scaling_omega, conversion_map)
100+
conversion_map)
105101

106102
elif mod.__class__ in conversion_map:
107103
new_mod = conversion_map[mod.__class__].from_digital( # type: ignore
108-
mod, rpu_config, realistic_read_write, weight_scaling_omega)
104+
mod, rpu_config, realistic_read_write)
109105
else:
110106
continue
111107

@@ -125,7 +121,6 @@ def convert_to_analog_mapped(
125121
module: Module,
126122
rpu_config: RPUConfigGeneric,
127123
realistic_read_write: bool = False,
128-
weight_scaling_omega: float = 0.0,
129124
) -> Module:
130125
"""Convert a given digital model to its analog counterpart with tile
131126
mapping support.
@@ -142,9 +137,6 @@ def convert_to_analog_mapped(
142137
rpu_config: RPU config to apply to all converted tiles.
143138
realistic_read_write: Whether to use closed-loop programming
144139
when setting the weights. Applied to all converted tiles.
145-
weight_scaling_omega: If non-zero, the analog weights will be
146-
scaled by ``weight_scaling_omega`` divided by the absolute
147-
maximum value of the original weight matrix.
148140
149141
Note:
150142
Make sure that the weight max and min settings of the
@@ -159,6 +151,5 @@ def convert_to_analog_mapped(
159151
module,
160152
rpu_config,
161153
realistic_read_write,
162-
weight_scaling_omega,
163154
_DEFAULT_MAPPED_CONVERSION_MAP
164155
)

src/aihwkit/nn/modules/base.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# that they have been altered from the originals.
1212

1313
"""Base class for analog Modules."""
14+
import warnings
1415

1516
from typing import (
1617
Any, Dict, List, Optional, Tuple, NamedTuple, Union,
@@ -72,14 +73,15 @@ class AnalogModuleBase(Module):
7273
ANALOG_CTX_PREFIX: str = 'analog_ctx_'
7374
ANALOG_SHARED_WEIGHT_PREFIX: str = 'analog_shared_weights_'
7475
ANALOG_STATE_PREFIX: str = 'analog_tile_state_'
76+
ANALOG_OUT_SCALING_ALPHA_PREFIX: str = 'analog_out_scaling_alpha_'
7577

7678
def __init__(
7779
self,
7880
in_features: int,
7981
out_features: int,
8082
bias: bool,
8183
realistic_read_write: bool = False,
82-
weight_scaling_omega: float = 0.0,
84+
weight_scaling_omega: Optional[float] = None,
8385
mapping: Optional[MappingParameter] = None,
8486
) -> None:
8587
# pylint: disable=super-init-not-called
@@ -93,9 +95,21 @@ def __init__(
9395
self.use_bias = bias
9496
self.digital_bias = bias and mapping.digital_bias
9597
self.analog_bias = bias and not mapping.digital_bias
98+
self.weight_scaling_omega = mapping.weight_scaling_omega if weight_scaling_omega is None \
99+
else weight_scaling_omega
100+
if weight_scaling_omega is not None:
101+
warnings.warn(DeprecationWarning('\nSetting the weight_scaling_omega through the '
102+
'layers input parameters will be deprecated in the '
103+
'future. Please set it through the MappingParameter '
104+
'of the rpu_config.\n'))
105+
106+
self.weight_scaling_omega_columnwise = mapping.weight_scaling_omega_columnwise
107+
self.learn_out_scaling_alpha = mapping.learn_out_scaling_alpha
108+
109+
if self.learn_out_scaling_alpha and self.weight_scaling_omega == 0:
110+
raise ValueError('out_scaling_alpha can only be learned if weight_scaling_omega > 0')
96111

97112
self.realistic_read_write = realistic_read_write
98-
self.weight_scaling_omega = weight_scaling_omega
99113
self.in_features = in_features
100114
self.out_features = out_features
101115

@@ -129,6 +143,15 @@ def register_analog_tile(self, tile: 'BaseTile', name: Optional[str] = None) ->
129143
if par_name not in self._registered_helper_parameter:
130144
self._registered_helper_parameter.append(par_name)
131145

146+
if self.learn_out_scaling_alpha:
147+
if not isinstance(tile.out_scaling_alpha, Parameter):
148+
tile.out_scaling_alpha = Parameter(tile.out_scaling_alpha)
149+
par_name = self.ANALOG_OUT_SCALING_ALPHA_PREFIX + str(self._analog_tile_counter)
150+
self.register_parameter(par_name, tile.out_scaling_alpha)
151+
152+
if par_name not in self._registered_helper_parameter:
153+
self._registered_helper_parameter.append(par_name)
154+
132155
self._analog_tile_counter += 1
133156

134157
def unregister_parameter(self, param_name: str) -> None:
@@ -235,9 +258,12 @@ def set_weights(
235258
analog_tile = analog_tiles[0]
236259

237260
if self.weight_scaling_omega > 0.0:
238-
analog_tile.set_weights_scaled(weight, bias if self.analog_bias else None,
239-
realistic=realistic,
240-
omega=self.weight_scaling_omega)
261+
analog_tile.set_weights_scaled(
262+
weight, bias if self.analog_bias else None,
263+
realistic=realistic,
264+
omega=self.weight_scaling_omega,
265+
weight_scaling_omega_columnwise=self.weight_scaling_omega_columnwise,
266+
learn_out_scaling_alpha=self.learn_out_scaling_alpha)
241267
else:
242268
analog_tile.set_weights(weight, bias if self.analog_bias else None,
243269
realistic=realistic)
@@ -283,7 +309,9 @@ def get_weights(
283309

284310
realistic = self.realistic_read_write and not force_exact
285311
if self.weight_scaling_omega > 0.0:
286-
weight, bias = analog_tile.get_weights_scaled(realistic=realistic)
312+
weight, bias = analog_tile.get_weights_scaled(
313+
realistic=realistic,
314+
weight_scaling_omega_columnwise=self.weight_scaling_omega_columnwise)
287315
else:
288316
weight, bias = analog_tile.get_weights(realistic=realistic)
289317

0 commit comments

Comments
 (0)