Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add_agc_plus_receptor_potential #21

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 34 additions & 21 deletions python/jax/carfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -2302,7 +2302,7 @@ def run_segment(
state: CarfacState,
open_loop: bool = False,
) -> Tuple[
jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray
jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray,
]:
"""This function runs the entire CARFAC model.

Expand Down Expand Up @@ -2341,8 +2341,9 @@ def run_segment(
(only populated with non-zeros when ihc_style equals "two_cap_with_syn")
state: the updated state of the CARFAC model.
BM: The basilar membrane motion
seg_ohc & seg_agc are optional extra outputs useful for seeing what the
ohc nonlinearity and agc are doing; both in terms of extra damping.
receptor_pot: receptor potential of ihc (optional extra)
seg_ohc: za_memory parameter to observe ohc nonlinearity (optional extra)
seg_agc: agc output (stage 1) parameter to observe agc activity (optional extra)
"""
if len(input_waves.shape) < 2:
input_waves = jnp.reshape(input_waves, (-1, 1))
Expand All @@ -2356,9 +2357,11 @@ def run_segment(
# (n_ears, cfp.n_ears))

n_ch = hypers.ears[0].car.n_ch
n_agc_stages = jnp.shape(state.ears[0].agc)[0]
naps = jnp.zeros((n_samp, n_ch, n_ears)) # allocate space for result
naps_fibers = jnp.zeros((n_samp, n_ch, n_fibertypes, n_ears))
bm = jnp.zeros((n_samp, n_ch, n_ears))
receptor_pot = jnp.zeros((n_samp, n_ch, n_ears))
seg_ohc = jnp.zeros((n_samp, n_ch, n_ears))
seg_agc = jnp.zeros((n_samp, n_ch, n_ears))

Expand All @@ -2376,7 +2379,7 @@ def run_segment(
# Note that we can use naive for loops here because it will make gradient
# computation very slow.
def run_segment_scan_helper(carry, k):
naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves = carry
naps, naps_fibers, state, bm, receptor_pot, seg_ohc, seg_agc, input_waves = carry
agc_updated = False
for ear in range(n_ears):
# This would be cleaner if we could just get and use a reference to
Expand Down Expand Up @@ -2406,10 +2409,11 @@ def run_segment_scan_helper(carry, k):
)
# save some output data:
naps = naps.at[k, :, ear].set(ihc_out)
receptor_pot = receptor_pot.at[k, :, ear].set(v_recep)
bm = bm.at[k, :, ear].set(car_out)
car_state = state.ears[ear].car
seg_ohc = seg_ohc.at[k, :, ear].set(car_state.za_memory)
seg_agc = seg_agc.at[k, :, ear].set(car_state.zb_memory)
seg_agc = seg_agc.at[k, :, ear].set(state.ears[ear].agc[0].agc_memory)

def close_agc_loop_helper(
hypers: CarfacHypers, weights: CarfacWeights, state: CarfacState
Expand All @@ -2431,11 +2435,11 @@ def close_agc_loop_helper(
state,
)

return (naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves), None
return (naps, naps_fibers, state, bm, receptor_pot, seg_ohc, seg_agc, input_waves), None

return jax.lax.scan(
run_segment_scan_helper,
(naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves),
(naps, naps_fibers, state, bm, receptor_pot, seg_ohc, seg_agc, input_waves),
jnp.arange(n_samp),
)[0][:-1]

Expand All @@ -2454,7 +2458,7 @@ def run_segment_jit(
state: CarfacState,
open_loop: bool = False,
) -> Tuple[
jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray
jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray,
]:
"""A JITted version of run_segment for convenience.

Expand All @@ -2463,11 +2467,11 @@ def run_segment_jit(
way to account for this is to always make a deep copy of the hypers and modify
those. Example usage if modifying the hypers (which most users should not):

naps, _, _, _, _ = run_segment_jit(input, hypers, weights, state)
naps, _, _, _, _, _, _, _ = run_segment_jit(input, hypers, weights, state)

hypers_copy = copy.deepcopy(hypers)
hypers_jax2.ears[0].car.r1_coeffs /= 2.0
naps, _, _, _, _ = run_segment_jit(input, hypers_copy, weights, state)
naps, _, _, _, _, _, _, _ = run_segment_jit(input, hypers_copy, weights, state)

If no modifications to the CarfacHypers are made, the same hypers object
should be reused.
Expand All @@ -2485,9 +2489,10 @@ def run_segment_jit(
(only populated with non-zeros when ihc_style equals "two_cap_with_syn")
state: the updated state of the CARFAC model.
BM: The basilar membrane motion
seg_ohc & seg_agc are optional extra outputs useful for seeing what the
ohc nonlinearity and agc are doing; both in terms of extra damping.
"""
receptor_pot: receptor potential of ihc (optional extra)
seg_ohc: za_memory parameter to observe ohc nonlinearity (optional extra)
seg_agc: agc output (stage 1) parameter to observe agc activity (optional extra)
\ """
return run_segment(input_waves, hypers, weights, state, open_loop)


Expand All @@ -2499,7 +2504,7 @@ def run_segment_jit_in_chunks_notraceable(
open_loop: bool = False,
segment_chunk_length: int = 32 * 48000,
) -> tuple[
jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray
jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray,
]:
"""Runs the jitted segment runner in segment groups.

Expand All @@ -2524,11 +2529,14 @@ def run_segment_jit_in_chunks_notraceable(
largest chunk.

Returns:
naps: Neural activity pattern as a numpy array.
naps_out: Neural activity pattern as a numpy array.
naps_fibers_out: neural activity of the different fiber types
(only populated with non-zeros when ihc_style equals "two_cap_with_syn")
state: The updated state of the CARFAC model.
BM: The basilar membrane motion as a numpy array.
seg_ohc & seg_agc are optional extra outputs useful for seeing what the
ohc nonlinearity and agc are doing; both in terms of extra damping.
bm_out: The basilar membrane motion as a numpy array.
v_recep_out: receptor potential of ihc (optional extra)
ohc_out za_memory parameter to observe ohc nonlinearity (optional extra)
agc_out: agc output (stage 1) to observe agc activity (optional extra)

Raises:
RuntimeError: If this function is being JITTed, which it should not be.
Expand All @@ -2545,37 +2553,42 @@ def run_segment_jit_in_chunks_notraceable(
naps_out = []
naps_fibers_out = []
bm_out = []
v_recep_out = []
ohc_out = []
agc_out = []
agc_memory_out = []
# NOMUTANTS -- This is a performance optimization.
while segment_length > 16:
[n_samp, _] = input_waves.shape
if n_samp >= segment_length:
[current_waves, input_waves] = jnp.split(input_waves, [segment_length], 0)
naps_jax, naps_fibers_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = (
naps_jax, naps_fibers_jax, state, bm_jax, receptor_pot_jax, seg_ohc_jax, seg_agc_jax, = (
run_segment_jit(current_waves, hypers, weights, state, open_loop)
)
naps_out.append(naps_jax)
naps_fibers_out.append(naps_fibers_jax)
bm_out.append(bm_jax)
v_recep_out.append(receptor_pot_jax)
ohc_out.append(seg_ohc_jax)
agc_out.append(seg_agc_jax)
else:
segment_length //= 2
[n_samp, _] = input_waves.shape
# Take the last few items and just run them.
if n_samp > 0:
naps_jax, naps_fibers_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = (
naps_jax, naps_fibers_jax, state, bm_jax, receptor_pot_jax, seg_ohc_jax, seg_agc_jax, = (
run_segment_jit(input_waves, hypers, weights, state, open_loop)
)
naps_out.append(naps_jax)
naps_fibers_out.append(naps_fibers_jax)
v_recep_out.append(receptor_pot_jax)
bm_out.append(bm_jax)
ohc_out.append(seg_ohc_jax)
agc_out.append(seg_agc_jax)
naps_out = np.concatenate(naps_out, 0)
naps_fibers_out = np.concatenate(naps_fibers_out, 0)
bm_out = np.concatenate(bm_out, 0)
v_recep_out = np.concatenate(v_recep_out, 0)
ohc_out = np.concatenate(ohc_out, 0)
agc_out = np.concatenate(agc_out, 0)
return naps_out, naps_fibers_out, state, bm_out, ohc_out, agc_out
return naps_out, naps_fibers_out, state, bm_out, v_recep_out, ohc_out, agc_out,
12 changes: 6 additions & 6 deletions python/jax/carfac_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def loss_func(
weights: carfac_jax.CarfacWeights,
state: carfac_jax.CarfacState,
):
nap_output, _, _, _, _, _ = carfac_jax.run_segment(
nap_output, _, _, _, _, _, _ = carfac_jax.run_segment(
audio, hypers, weights, state
)
return jnp.sum(nap_output), nap_output
Expand Down Expand Up @@ -242,7 +242,7 @@ def bench_jit_compile_time(state: google_benchmark.State):
# that this benchmark is appropriate.
n_samp += 1
state.resume_timing()
naps_jax, _, state_jax, _, _, _ = carfac_jax.run_segment_jit(
naps_jax, _, state_jax, _, _, _, _ = carfac_jax.run_segment_jit(
run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False
)
naps_jax.block_until_ready()
Expand Down Expand Up @@ -295,7 +295,7 @@ def bench_jax_in_slices(state: google_benchmark.State):
for _, segment in enumerate(silence_slices):
if segment.shape not in compiled_shapes:
compiled_shapes.add(segment.shape)
naps_jax, _, _, _, _, _ = carfac_jax.run_segment_jit(
naps_jax, _, _, _, _, _, _ = carfac_jax.run_segment_jit(
segment, hypers_jax, weights_jax, state_jax, open_loop=False
)
naps_jax.block_until_ready()
Expand All @@ -316,7 +316,7 @@ def bench_jax_in_slices(state: google_benchmark.State):
jax_loop_state = state_jax
state.resume_timing()
for _, segment in enumerate(run_seg_slices):
seg_naps, _, jax_loop_state, seg_bm, seg_ohc, seg_agc = (
seg_naps, seg_naps_fibers, jax_loop_state, seg_bm, seg_receptor_pot, seg_ohc, seg_agc = (
carfac_jax.run_segment_jit(
segment, hypers_jax, weights_jax, jax_loop_state, open_loop=False
)
Expand Down Expand Up @@ -389,7 +389,7 @@ def bench_jax(state: google_benchmark.State):
params_jax
)
short_silence = jnp.zeros(shape=(n_samp, n_ears))
naps_jax, _, state_jax, _, _, _ = run_segment_function(
naps_jax, _, state_jax, _, _, _, _ = run_segment_function(
short_silence, hypers_jax, weights_jax, state_jax, open_loop=False
)
# This block ensures calculation.
Expand All @@ -404,7 +404,7 @@ def bench_jax(state: google_benchmark.State):
jax.random.normal(random_generator, (n_samp, n_ears)) * _NOISE_FACTOR
).block_until_ready()
state.resume_timing()
naps_jax, _, state_jax, _, _, _ = run_segment_function(
naps_jax, _, state_jax, _, _, _, _ = run_segment_function(
run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False
)
if state.range(0) != 1:
Expand Down
2 changes: 1 addition & 1 deletion python/jax/carfac_float64_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def loss(weights, input_waves, hypers, state):
# A loss function for tests. Note that we shouldn't use `run_segment_jit`
# here because it will donate the `state` which causes unnecessary
# inconvenience for tests.
naps_jax, _, state_jax, _, _, _ = carfac_jax.run_segment(
naps_jax, _, state_jax, _, _, _, _ = carfac_jax.run_segment(
input_waves, hypers, weights, state, open_loop=False
)
# For testing, just fit `naps` to 1.
Expand Down
29 changes: 11 additions & 18 deletions python/jax/carfac_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,24 +324,17 @@ def test_chunked_naps_same_as_jit(self, random_seed, ihc_style):
state_jax_copied = copy.deepcopy(state_jax)

# Only tests the JITted version because this is what we will use.
naps_jax, _, _, bm_jax, ohc_jax, agc_jax = (
carfac_jax.run_segment_jit(
run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False
)
naps_jax, naps_fibers_jax, _, bm_jax, ohc_jax, agc_jax = carfac_jax.run_segment_jit(
run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False
)
(
naps_jax_chunked,
_,
_,
bm_chunked,
ohc_chunked,
agc_chunked,
) = carfac_jax.run_segment_jit_in_chunks_notraceable(
run_seg_input,
hypers_jax,
weights_jax,
state_jax_copied,
open_loop=False,
naps_jax_chunked, naps_fibes_jax_chunked, _, bm_chunked, ohc_chunked, agc_chunked = (
carfac_jax.run_segment_jit_in_chunks_notraceable(
run_seg_input,
hypers_jax,
weights_jax,
state_jax_copied,
open_loop=False,
)
)
self.assertLess(jnp.max(abs(naps_jax_chunked - naps_jax)), 1e-7)
self.assertLess(jnp.max(abs(bm_chunked - bm_jax)), 1e-7)
Expand Down Expand Up @@ -387,7 +380,7 @@ def test_equal_forward_pass(
run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears))

# Only tests the JITted version because this is what we will use.
naps_jax, _, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = (
naps_jax, naps_fibers_jax, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = (
carfac_jax.run_segment_jit(
run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False
)
Expand Down
10 changes: 9 additions & 1 deletion python/jax/carfac_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def run_multiple_segment_states_shmap(
jnp.ndarray,
jnp.ndarray,
jnp.ndarray,
jnp.ndarray,
jnp.ndarray,
]
]:
"""Run multiple equal-length, segments in carfac, Jitted, in parallel.
Expand Down Expand Up @@ -89,7 +91,7 @@ def parallel_helper(input_waves, state):
"""
input_waves = input_waves[0]
state = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), state)
naps, naps_fibers, ret_state, bm, seg_ohc, seg_agc = (
naps, naps_fibers, ret_state, bm, receptor_pot, seg_ohc, seg_agc = (
carfac_jax.run_segment_jit(
input_waves, hypers, weights, state, open_loop
)
Expand All @@ -102,6 +104,7 @@ def parallel_helper(input_waves, state):
naps_fibers[None],
ret_state,
bm[None],
receptor_pot[None],
seg_ohc[None],
seg_agc[None],
)
Expand All @@ -111,6 +114,7 @@ def parallel_helper(input_waves, state):
stacked_naps_fibers,
stacked_states,
stacked_bm,
stacked_receptor_pot,
stacked_ohc,
stacked_agc,
) = parallel_helper(input_waves_array, batch_state)
Expand All @@ -124,6 +128,7 @@ def parallel_helper(input_waves, state):
stacked_naps_fibers[i],
output_state,
stacked_bm[i],
stacked_receptor_pot[i],
stacked_ohc[i],
stacked_agc[i],
)
Expand All @@ -146,6 +151,7 @@ def run_multiple_segment_pmap(
jnp.ndarray,
jnp.ndarray,
jnp.ndarray,
jnp.ndarray,
]
]:
"""Run multiple equal-length, segments in carfac, Jitted, in parallel.
Expand All @@ -171,6 +177,7 @@ def run_multiple_segment_pmap(
stacked_naps_fibers,
stacked_states,
stacked_bm,
stacked_receptor_pot,
stacked_ohc,
stacked_agc,
) = pmapped(input_waves_array, hypers, weights, state, open_loop)
Expand All @@ -183,6 +190,7 @@ def run_multiple_segment_pmap(
stacked_naps_fibers[i],
output_state,
stacked_bm[i],
stacked_receptor_pot[i],
stacked_ohc[i],
stacked_agc[i],
)
Expand Down
Loading