From 0ee1e3f9ba6c2239a8619f8617cf8eb0cf53d5c7 Mon Sep 17 00:00:00 2001 From: JasonMH17 <134568474+JasonMH17@users.noreply.github.com> Date: Thu, 26 Dec 2024 14:35:25 +1100 Subject: [PATCH 1/7] add_agc_plus_potential_receptor --- python/jax/carfac.py | 49 ++++++++++++++++++++------- python/jax/carfac_bench.py | 10 +++--- python/jax/carfac_float64_test.py | 2 +- python/jax/carfac_test.py | 29 +++++++--------- python/jax/carfac_util.py | 16 ++++++++- python/jax/carfac_util_test.py | 55 ++++++++++++++++++------------- 6 files changed, 101 insertions(+), 60 deletions(-) diff --git a/python/jax/carfac.py b/python/jax/carfac.py index 4a6d9d4..5bb329d 100644 --- a/python/jax/carfac.py +++ b/python/jax/carfac.py @@ -2302,7 +2302,7 @@ def run_segment( state: CarfacState, open_loop: bool = False, ) -> Tuple[ - jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray + jnp.ndarray, jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, ]: """This function runs the entire CARFAC model. @@ -2339,10 +2339,12 @@ def run_segment( naps: neural activity pattern naps_fibers: neural activity of different fibers (only populated with non-zeros when ihc_style equals "two_cap_with_syn") + receptor_pot: receptor potential of ihc state: the updated state of the CARFAC model. BM: The basilar membrane motion seg_ohc & seg_agc are optional extra outputs useful for seeing what the ohc nonlinearity and agc are doing; both in terms of extra damping. + seg_agc_memory is an optional extra that gives the actual agc activity. """ if len(input_waves.shape) < 2: input_waves = jnp.reshape(input_waves, (-1, 1)) @@ -2356,11 +2358,14 @@ def run_segment( # (n_ears, cfp.n_ears)) n_ch = hypers.ears[0].car.n_ch + n_agc_stages = jnp.shape(state.ears[0].agc)[0] naps = jnp.zeros((n_samp, n_ch, n_ears)) # allocate space for result naps_fibers = jnp.zeros((n_samp, n_ch, n_fibertypes, n_ears)) + receptor_pot = jnp.zeros((n_samp, n_ch, n_ears)) bm = jnp.zeros((n_samp, n_ch, n_ears)) seg_ohc = jnp.zeros((n_samp, n_ch, n_ears)) seg_agc = jnp.zeros((n_samp, n_ch, n_ears)) + seg_agc_memory = jnp.zeros((n_samp, n_agc_stages, n_ch, n_ears)) # A 2022 addition to make open-loop running behave: if open_loop: @@ -2376,7 +2381,7 @@ def run_segment( # Note that we can use naive for loops here because it will make gradient # computation very slow. def run_segment_scan_helper(carry, k): - naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves = carry + naps, naps_fibers, receptor_pot, state, bm, seg_ohc, seg_agc, seg_agc_memory, input_waves = carry agc_updated = False for ear in range(n_ears): # This would be cleaner if we could just get and use a reference to @@ -2406,10 +2411,15 @@ def run_segment_scan_helper(carry, k): ) # save some output data: naps = naps.at[k, :, ear].set(ihc_out) + receptor_pot = receptor_pot.at[k, :, ear].set(v_recep) bm = bm.at[k, :, ear].set(car_out) car_state = state.ears[ear].car seg_ohc = seg_ohc.at[k, :, ear].set(car_state.za_memory) seg_agc = seg_agc.at[k, :, ear].set(car_state.zb_memory) + for i in range(n_agc_stages): + seg_agc_memory = seg_agc_memory.at[k, i, :, ear].set( + state.ears[ear].agc[i].agc_memory + ) def close_agc_loop_helper( hypers: CarfacHypers, weights: CarfacWeights, state: CarfacState @@ -2431,11 +2441,11 @@ def close_agc_loop_helper( state, ) - return (naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves), None + return (naps, naps_fibers, receptor_pot, state, bm, seg_ohc, seg_agc, seg_agc_memory, input_waves), None return jax.lax.scan( run_segment_scan_helper, - (naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves), + (naps, naps_fibers, receptor_pot, state, bm, seg_ohc, seg_agc, seg_agc_memory, input_waves), jnp.arange(n_samp), )[0][:-1] @@ -2454,7 +2464,7 @@ def run_segment_jit( state: CarfacState, open_loop: bool = False, ) -> Tuple[ - jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray + jnp.ndarray, jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, ]: """A JITted version of run_segment for convenience. @@ -2483,10 +2493,13 @@ def run_segment_jit( naps: neural activity pattern naps_fibers: neural activity of the different fiber types (only populated with non-zeros when ihc_style equals "two_cap_with_syn") + receptor_pot: receptor potential of ihc state: the updated state of the CARFAC model. BM: The basilar membrane motion seg_ohc & seg_agc are optional extra outputs useful for seeing what the ohc nonlinearity and agc are doing; both in terms of extra damping. + seg_agc_memory is an optional extra that gives the actual agc activity. + """ return run_segment(input_waves, hypers, weights, state, open_loop) @@ -2499,7 +2512,7 @@ def run_segment_jit_in_chunks_notraceable( open_loop: bool = False, segment_chunk_length: int = 32 * 48000, ) -> tuple[ - jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray + jnp.ndarray, jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, ]: """Runs the jitted segment runner in segment groups. @@ -2524,11 +2537,15 @@ def run_segment_jit_in_chunks_notraceable( largest chunk. Returns: - naps: Neural activity pattern as a numpy array. + naps_out: Neural activity pattern as a numpy array. + naps_fibers_out: neural activity of the different fiber types + (only populated with non-zeros when ihc_style equals "two_cap_with_syn") + v_recep_out: receptor potential of ihc state: The updated state of the CARFAC model. - BM: The basilar membrane motion as a numpy array. - seg_ohc & seg_agc are optional extra outputs useful for seeing what the + bm_out: The basilar membrane motion as a numpy array. + ohc_out & agc_out are optional extra outputs useful for seeing what the ohc nonlinearity and agc are doing; both in terms of extra damping. + agc_memory_out is optional and gives access to the actual 4-stage agc output Raises: RuntimeError: If this function is being JITTed, which it should not be. @@ -2544,38 +2561,46 @@ def run_segment_jit_in_chunks_notraceable( input_waves = jnp.reshape(input_waves, (-1, 1)) naps_out = [] naps_fibers_out = [] + v_recep_out = [] bm_out = [] ohc_out = [] agc_out = [] + agc_memory_out = [] # NOMUTANTS -- This is a performance optimization. while segment_length > 16: [n_samp, _] = input_waves.shape if n_samp >= segment_length: [current_waves, input_waves] = jnp.split(input_waves, [segment_length], 0) - naps_jax, naps_fibers_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = ( + naps_jax, naps_fibers_jax, receptor_pot, state, bm_jax, seg_ohc_jax, seg_agc_jax, seg_agc_memory_jax = ( run_segment_jit(current_waves, hypers, weights, state, open_loop) ) naps_out.append(naps_jax) naps_fibers_out.append(naps_fibers_jax) + v_recep_out.append(receptor_pot) bm_out.append(bm_jax) ohc_out.append(seg_ohc_jax) agc_out.append(seg_agc_jax) + agc_memory_out.append(seg_agc_memory_jax) else: segment_length //= 2 [n_samp, _] = input_waves.shape # Take the last few items and just run them. if n_samp > 0: - naps_jax, naps_fibers_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = ( + naps_jax, naps_fibers_jax, receptor_pot, state, bm_jax, seg_ohc_jax, seg_agc_jax, seg_agc_memory_jax, = ( run_segment_jit(input_waves, hypers, weights, state, open_loop) ) naps_out.append(naps_jax) naps_fibers_out.append(naps_fibers_jax) + v_recep_out.append(receptor_pot) bm_out.append(bm_jax) ohc_out.append(seg_ohc_jax) agc_out.append(seg_agc_jax) + agc_memory_out.append(seg_agc_memory_jax) naps_out = np.concatenate(naps_out, 0) naps_fibers_out = np.concatenate(naps_fibers_out, 0) + v_recep_out = np.concatenate(v_recep_out, 0) bm_out = np.concatenate(bm_out, 0) ohc_out = np.concatenate(ohc_out, 0) agc_out = np.concatenate(agc_out, 0) - return naps_out, naps_fibers_out, state, bm_out, ohc_out, agc_out + agc_memory_out = np.concatenate(agc_memory_out, 0) + return naps_out, naps_fibers_out, v_recep_out, state, bm_out, ohc_out, agc_out, agc_memory_out, diff --git a/python/jax/carfac_bench.py b/python/jax/carfac_bench.py index ca58785..4490e7f 100644 --- a/python/jax/carfac_bench.py +++ b/python/jax/carfac_bench.py @@ -179,7 +179,7 @@ def loss_func( weights: carfac_jax.CarfacWeights, state: carfac_jax.CarfacState, ): - nap_output, _, _, _, _, _ = carfac_jax.run_segment( + nap_output, naps_fibers_output, _, _, _, _ = carfac_jax.run_segment( audio, hypers, weights, state ) return jnp.sum(nap_output), nap_output @@ -242,7 +242,7 @@ def bench_jit_compile_time(state: google_benchmark.State): # that this benchmark is appropriate. n_samp += 1 state.resume_timing() - naps_jax, _, state_jax, _, _, _ = carfac_jax.run_segment_jit( + naps_jax, naps_fibers_jax, state_jax, _, _, _ = carfac_jax.run_segment_jit( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) naps_jax.block_until_ready() @@ -295,7 +295,7 @@ def bench_jax_in_slices(state: google_benchmark.State): for _, segment in enumerate(silence_slices): if segment.shape not in compiled_shapes: compiled_shapes.add(segment.shape) - naps_jax, _, _, _, _, _ = carfac_jax.run_segment_jit( + naps_jax, naps_fibers_jax, _, _, _, _ = carfac_jax.run_segment_jit( segment, hypers_jax, weights_jax, state_jax, open_loop=False ) naps_jax.block_until_ready() @@ -316,7 +316,7 @@ def bench_jax_in_slices(state: google_benchmark.State): jax_loop_state = state_jax state.resume_timing() for _, segment in enumerate(run_seg_slices): - seg_naps, _, jax_loop_state, seg_bm, seg_ohc, seg_agc = ( + seg_naps, seg_naps_fibers, jax_loop_state, seg_bm, seg_ohc, seg_agc = ( carfac_jax.run_segment_jit( segment, hypers_jax, weights_jax, jax_loop_state, open_loop=False ) @@ -404,7 +404,7 @@ def bench_jax(state: google_benchmark.State): jax.random.normal(random_generator, (n_samp, n_ears)) * _NOISE_FACTOR ).block_until_ready() state.resume_timing() - naps_jax, _, state_jax, _, _, _ = run_segment_function( + naps_jax, naps_fibers_jax, state_jax, _, _, _ = run_segment_function( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) if state.range(0) != 1: diff --git a/python/jax/carfac_float64_test.py b/python/jax/carfac_float64_test.py index 3237024..7caaf71 100644 --- a/python/jax/carfac_float64_test.py +++ b/python/jax/carfac_float64_test.py @@ -46,7 +46,7 @@ def loss(weights, input_waves, hypers, state): # A loss function for tests. Note that we shouldn't use `run_segment_jit` # here because it will donate the `state` which causes unnecessary # inconvenience for tests. - naps_jax, _, state_jax, _, _, _ = carfac_jax.run_segment( + naps_jax, _, _, state_jax, _, _, _, _= carfac_jax.run_segment( input_waves, hypers, weights, state, open_loop=False ) # For testing, just fit `naps` to 1. diff --git a/python/jax/carfac_test.py b/python/jax/carfac_test.py index 246b5f0..acb985b 100644 --- a/python/jax/carfac_test.py +++ b/python/jax/carfac_test.py @@ -324,24 +324,17 @@ def test_chunked_naps_same_as_jit(self, random_seed, ihc_style): state_jax_copied = copy.deepcopy(state_jax) # Only tests the JITted version because this is what we will use. - naps_jax, _, _, bm_jax, ohc_jax, agc_jax = ( - carfac_jax.run_segment_jit( - run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False - ) + naps_jax, naps_fibers_jax, _, bm_jax, ohc_jax, agc_jax = carfac_jax.run_segment_jit( + run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) - ( - naps_jax_chunked, - _, - _, - bm_chunked, - ohc_chunked, - agc_chunked, - ) = carfac_jax.run_segment_jit_in_chunks_notraceable( - run_seg_input, - hypers_jax, - weights_jax, - state_jax_copied, - open_loop=False, + naps_jax_chunked, naps_fibes_jax_chunked, _, bm_chunked, ohc_chunked, agc_chunked = ( + carfac_jax.run_segment_jit_in_chunks_notraceable( + run_seg_input, + hypers_jax, + weights_jax, + state_jax_copied, + open_loop=False, + ) ) self.assertLess(jnp.max(abs(naps_jax_chunked - naps_jax)), 1e-7) self.assertLess(jnp.max(abs(bm_chunked - bm_jax)), 1e-7) @@ -387,7 +380,7 @@ def test_equal_forward_pass( run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears)) # Only tests the JITted version because this is what we will use. - naps_jax, _, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = ( + naps_jax, naps_fibers_jax, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = ( carfac_jax.run_segment_jit( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) diff --git a/python/jax/carfac_util.py b/python/jax/carfac_util.py index d789415..de19e63 100644 --- a/python/jax/carfac_util.py +++ b/python/jax/carfac_util.py @@ -38,12 +38,14 @@ def run_multiple_segment_states_shmap( open_loop: bool = False, ) -> Sequence[ Tuple[ + jnp.ndarray, jnp.ndarray, jnp.ndarray, carfac_jax.CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, + jnp.ndarray, ] ]: """Run multiple equal-length, segments in carfac, Jitted, in parallel. @@ -89,7 +91,7 @@ def parallel_helper(input_waves, state): """ input_waves = input_waves[0] state = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), state) - naps, naps_fibers, ret_state, bm, seg_ohc, seg_agc = ( + naps, naps_fibers, receptor_pot, ret_state, bm, seg_ohc, seg_agc, seg_agc_memory = ( carfac_jax.run_segment_jit( input_waves, hypers, weights, state, open_loop ) @@ -100,19 +102,23 @@ def parallel_helper(input_waves, state): return ( naps[None], naps_fibers[None], + receptor_pot[None], ret_state, bm[None], seg_ohc[None], seg_agc[None], + seg_agc_memory[None], ) ( stacked_naps, stacked_naps_fibers, + stacked_receptor_pot, stacked_states, stacked_bm, stacked_ohc, stacked_agc, + stacked_agc_memory, ) = parallel_helper(input_waves_array, batch_state) output_states = _tree_unstack(stacked_states) output = [] @@ -122,10 +128,12 @@ def parallel_helper(input_waves, state): tup = ( stacked_naps[i], stacked_naps_fibers[i], + stacked_receptor_pot[i], output_state, stacked_bm[i], stacked_ohc[i], stacked_agc[i], + stacked_agc_memory[i], ) output.append(tup) return output @@ -140,12 +148,14 @@ def run_multiple_segment_pmap( open_loop: bool = False, ) -> Sequence[ Tuple[ + jnp.ndarray, jnp.ndarray, jnp.ndarray, carfac_jax.CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, + jnp.ndarray, ] ]: """Run multiple equal-length, segments in carfac, Jitted, in parallel. @@ -169,10 +179,12 @@ def run_multiple_segment_pmap( ( stacked_naps, stacked_naps_fibers, + stacked_receptor_pot, stacked_states, stacked_bm, stacked_ohc, stacked_agc, + stacked_agc_memory, ) = pmapped(input_waves_array, hypers, weights, state, open_loop) output_states = _tree_unstack(stacked_states) @@ -181,10 +193,12 @@ def run_multiple_segment_pmap( tup = ( stacked_naps[i], stacked_naps_fibers[i], + stacked_receptor_pot[i], output_state, stacked_bm[i], stacked_ohc[i], stacked_agc[i], + stacked_agc_memory[i], ) output.append(tup) return output diff --git a/python/jax/carfac_util_test.py b/python/jax/carfac_util_test.py index a88c449..b239a2b 100644 --- a/python/jax/carfac_util_test.py +++ b/python/jax/carfac_util_test.py @@ -59,7 +59,7 @@ def test_same_outputs_parallel_for_pmap(self): ], axis=0, ) - nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a = ( + nap_out_a, nap_fibers_out_a, receptor_pot_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a, agc_memory_out_a = ( carfac.run_segment_jit( self.sample_a, self.hypers, @@ -68,7 +68,7 @@ def test_same_outputs_parallel_for_pmap(self): self.open_loop, ) ) - nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b = ( + nap_out_b, nap_fibers_out_b, receptor_pot_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b, agc_memory_out_b = ( carfac.run_segment_jit( self.sample_b, self.hypers, @@ -88,20 +88,24 @@ def test_same_outputs_parallel_for_pmap(self): self.assertTrue((combined_output[1][0] == nap_out_b).all()) self.assertTrue((combined_output[0][1] == nap_fibers_out_a).all()) self.assertTrue((combined_output[1][1] == nap_fibers_out_b).all()) - self.assertTrue((combined_output[0][3] == bm_out_a).all()) - self.assertTrue((combined_output[1][3] == bm_out_b).all()) - self.assertTrue((combined_output[0][4] == ohc_out_a).all()) - self.assertTrue((combined_output[1][4] == ohc_out_b).all()) - self.assertTrue((combined_output[0][5] == agc_out_a).all()) - self.assertTrue((combined_output[1][5] == agc_out_b).all()) + self.assertTrue((combined_output[0][2] == receptor_pot_a).all()) + self.assertTrue((combined_output[1][2] == receptor_pot_b).all()) + self.assertTrue((combined_output[0][4] == bm_out_a).all()) + self.assertTrue((combined_output[1][4] == bm_out_b).all()) + self.assertTrue((combined_output[0][5] == ohc_out_a).all()) + self.assertTrue((combined_output[1][5] == ohc_out_b).all()) + self.assertTrue((combined_output[0][6] == agc_out_a).all()) + self.assertTrue((combined_output[1][6] == agc_out_b).all()) + self.assertTrue((combined_output[0][7] == agc_memory_out_a).all()) + self.assertTrue((combined_output[1][7] == agc_memory_out_b).all()) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_a, combined_output[0][2]) + jax.tree.map(jnp.allclose, state_out_a, combined_output[0][3]) ) ) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_b, combined_output[1][2]) + jax.tree.map(jnp.allclose, state_out_b, combined_output[1][3]) ) ) @@ -114,7 +118,7 @@ def test_same_outputs_parallel_for_shmap(self): axis=0, ) - nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a = ( + nap_out_a, nap_fibers_out_a, receptor_pot_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a, agc_memory_out_a, = ( carfac.run_segment_jit( self.sample_a, self.hypers, @@ -126,7 +130,7 @@ def test_same_outputs_parallel_for_shmap(self): # Run sample B twice, so we have a separate "starting" state for the # test for shmap. - _, _, state_out_b_first, _, _, _ = carfac.run_segment_jit( + _, _, _, state_out_b_first, _, _, _, _ = carfac.run_segment_jit( self.sample_b, self.hypers, self.weights, @@ -134,7 +138,7 @@ def test_same_outputs_parallel_for_shmap(self): self.open_loop, ) - nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b = ( + nap_out_b, nap_fibers_out_b, receptor_pot_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b, agc_memory_out_b, = ( carfac.run_segment_jit( self.sample_b, self.hypers, @@ -154,20 +158,25 @@ def test_same_outputs_parallel_for_shmap(self): self.assertTrue((combined_output[1][0] == nap_out_b).all()) self.assertTrue((combined_output[0][1] == nap_fibers_out_a).all()) self.assertTrue((combined_output[1][1] == nap_fibers_out_b).all()) - self.assertTrue((combined_output[0][3] == bm_out_a).all()) - self.assertTrue((combined_output[1][3] == bm_out_b).all()) - self.assertTrue((combined_output[0][4] == ohc_out_a).all()) - self.assertTrue((combined_output[1][4] == ohc_out_b).all()) - self.assertTrue((combined_output[0][5] == agc_out_a).all()) - self.assertTrue((combined_output[1][5] == agc_out_b).all()) + self.assertTrue((combined_output[0][2] == receptor_pot_a).all()) + self.assertTrue((combined_output[1][2] == receptor_pot_b).all()) + self.assertTrue((combined_output[0][4] == bm_out_a).all()) + self.assertTrue((combined_output[1][4] == bm_out_b).all()) + self.assertTrue((combined_output[0][5] == ohc_out_a).all()) + self.assertTrue((combined_output[1][5] == ohc_out_b).all()) + self.assertTrue((combined_output[0][6] == agc_out_a).all()) + self.assertTrue((combined_output[1][6] == agc_out_b).all()) + self.assertTrue((combined_output[0][6] == agc_memory_out_a).all()) + self.assertTrue((combined_output[1][6] == agc_memory_out_b).all()) + self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_a, combined_output[0][2]) + jax.tree.map(jnp.allclose, state_out_a, combined_output[0][3]) ) ) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_b, combined_output[1][2]) + jax.tree.map(jnp.allclose, state_out_b, combined_output[1][3]) ) ) @@ -175,12 +184,12 @@ def test_same_outputs_parallel_for_shmap(self): # equality is complete and double sided. self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, combined_output[0][2], state_out_a) + jax.tree.map(jnp.allclose, combined_output[0][3], state_out_a) ) ) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, combined_output[1][2], state_out_b) + jax.tree.map(jnp.allclose, combined_output[1][3], state_out_b) ) ) From 52b7159b494e828354cec0f2e2e0e6e7b6161120 Mon Sep 17 00:00:00 2001 From: JasonMH17 <134568474+JasonMH17@users.noreply.github.com> Date: Thu, 26 Dec 2024 14:51:54 +1100 Subject: [PATCH 2/7] Removed space in comment --- python/jax/carfac.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/jax/carfac.py b/python/jax/carfac.py index 5bb329d..b3f939c 100644 --- a/python/jax/carfac.py +++ b/python/jax/carfac.py @@ -2499,7 +2499,6 @@ def run_segment_jit( seg_ohc & seg_agc are optional extra outputs useful for seeing what the ohc nonlinearity and agc are doing; both in terms of extra damping. seg_agc_memory is an optional extra that gives the actual agc activity. - """ return run_segment(input_waves, hypers, weights, state, open_loop) From 40dd379c9413625002ed7c0f2abf19a1769259e7 Mon Sep 17 00:00:00 2001 From: JasonMH17 <134568474+JasonMH17@users.noreply.github.com> Date: Tue, 31 Dec 2024 12:47:34 +1100 Subject: [PATCH 3/7] Made changes to carfac_bench.py according to review Made changes as requested + altered other calls to run_segment/run_segment_jit that accompany changes to carfac.py --- python/jax/carfac_bench.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/jax/carfac_bench.py b/python/jax/carfac_bench.py index 4490e7f..b5319d5 100644 --- a/python/jax/carfac_bench.py +++ b/python/jax/carfac_bench.py @@ -179,7 +179,7 @@ def loss_func( weights: carfac_jax.CarfacWeights, state: carfac_jax.CarfacState, ): - nap_output, naps_fibers_output, _, _, _, _ = carfac_jax.run_segment( + nap_output, _, _, _, _, _, _, _ = carfac_jax.run_segment( audio, hypers, weights, state ) return jnp.sum(nap_output), nap_output @@ -242,7 +242,7 @@ def bench_jit_compile_time(state: google_benchmark.State): # that this benchmark is appropriate. n_samp += 1 state.resume_timing() - naps_jax, naps_fibers_jax, state_jax, _, _, _ = carfac_jax.run_segment_jit( + naps_jax, _, state_jax, _, _, _, _, _ = carfac_jax.run_segment_jit( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) naps_jax.block_until_ready() @@ -295,7 +295,7 @@ def bench_jax_in_slices(state: google_benchmark.State): for _, segment in enumerate(silence_slices): if segment.shape not in compiled_shapes: compiled_shapes.add(segment.shape) - naps_jax, naps_fibers_jax, _, _, _, _ = carfac_jax.run_segment_jit( + naps_jax, _, _, _, _, _, _, _ = carfac_jax.run_segment_jit( segment, hypers_jax, weights_jax, state_jax, open_loop=False ) naps_jax.block_until_ready() @@ -316,7 +316,7 @@ def bench_jax_in_slices(state: google_benchmark.State): jax_loop_state = state_jax state.resume_timing() for _, segment in enumerate(run_seg_slices): - seg_naps, seg_naps_fibers, jax_loop_state, seg_bm, seg_ohc, seg_agc = ( + seg_naps, seg_naps_fibers, jax_loop_state, seg_bm, seg_receptor_pot, seg_ohc, seg_agc, seg_agc_actual = ( carfac_jax.run_segment_jit( segment, hypers_jax, weights_jax, jax_loop_state, open_loop=False ) @@ -389,7 +389,7 @@ def bench_jax(state: google_benchmark.State): params_jax ) short_silence = jnp.zeros(shape=(n_samp, n_ears)) - naps_jax, _, state_jax, _, _, _ = run_segment_function( + naps_jax, _, state_jax, _, _, _, _, _ = run_segment_function( short_silence, hypers_jax, weights_jax, state_jax, open_loop=False ) # This block ensures calculation. @@ -404,7 +404,7 @@ def bench_jax(state: google_benchmark.State): jax.random.normal(random_generator, (n_samp, n_ears)) * _NOISE_FACTOR ).block_until_ready() state.resume_timing() - naps_jax, naps_fibers_jax, state_jax, _, _, _ = run_segment_function( + naps_jax, _, state_jax, _, _, _, _, _ = run_segment_function( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) if state.range(0) != 1: From e7a09c83749342f9f5b3480e5be957e30cb02631 Mon Sep 17 00:00:00 2001 From: JasonMH17 <134568474+JasonMH17@users.noreply.github.com> Date: Tue, 31 Dec 2024 13:46:47 +1100 Subject: [PATCH 4/7] Altered order of outputs in carfac.py receptor_pot (ihc receptor potential in two_cap model) is now an optional extra and positioned accordingly. Also separated + clarified origins of seg_ohc + seg_agc --- python/jax/carfac.py | 56 ++++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/python/jax/carfac.py b/python/jax/carfac.py index b3f939c..d7394de 100644 --- a/python/jax/carfac.py +++ b/python/jax/carfac.py @@ -2302,7 +2302,7 @@ def run_segment( state: CarfacState, open_loop: bool = False, ) -> Tuple[ - jnp.ndarray, jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, + jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, ]: """This function runs the entire CARFAC model. @@ -2339,12 +2339,12 @@ def run_segment( naps: neural activity pattern naps_fibers: neural activity of different fibers (only populated with non-zeros when ihc_style equals "two_cap_with_syn") - receptor_pot: receptor potential of ihc state: the updated state of the CARFAC model. BM: The basilar membrane motion - seg_ohc & seg_agc are optional extra outputs useful for seeing what the - ohc nonlinearity and agc are doing; both in terms of extra damping. - seg_agc_memory is an optional extra that gives the actual agc activity. + receptor_pot: receptor potential of ihc (optional extra) + seg_ohc: za_memory parameter to observe ohc nonlinearity (optional extra) + seg_agc: zb_memory parameter to observe agc activity (optional extra) + seg_agc_memory: actual 4-stage agc output (optional extra) """ if len(input_waves.shape) < 2: input_waves = jnp.reshape(input_waves, (-1, 1)) @@ -2361,8 +2361,8 @@ def run_segment( n_agc_stages = jnp.shape(state.ears[0].agc)[0] naps = jnp.zeros((n_samp, n_ch, n_ears)) # allocate space for result naps_fibers = jnp.zeros((n_samp, n_ch, n_fibertypes, n_ears)) - receptor_pot = jnp.zeros((n_samp, n_ch, n_ears)) bm = jnp.zeros((n_samp, n_ch, n_ears)) + receptor_pot = jnp.zeros((n_samp, n_ch, n_ears)) seg_ohc = jnp.zeros((n_samp, n_ch, n_ears)) seg_agc = jnp.zeros((n_samp, n_ch, n_ears)) seg_agc_memory = jnp.zeros((n_samp, n_agc_stages, n_ch, n_ears)) @@ -2381,7 +2381,7 @@ def run_segment( # Note that we can use naive for loops here because it will make gradient # computation very slow. def run_segment_scan_helper(carry, k): - naps, naps_fibers, receptor_pot, state, bm, seg_ohc, seg_agc, seg_agc_memory, input_waves = carry + naps, naps_fibers, state, bm, receptor_pot, seg_ohc, seg_agc, seg_agc_memory, input_waves = carry agc_updated = False for ear in range(n_ears): # This would be cleaner if we could just get and use a reference to @@ -2441,11 +2441,11 @@ def close_agc_loop_helper( state, ) - return (naps, naps_fibers, receptor_pot, state, bm, seg_ohc, seg_agc, seg_agc_memory, input_waves), None + return (naps, naps_fibers, state, bm, receptor_pot, seg_ohc, seg_agc, seg_agc_memory, input_waves), None return jax.lax.scan( run_segment_scan_helper, - (naps, naps_fibers, receptor_pot, state, bm, seg_ohc, seg_agc, seg_agc_memory, input_waves), + (naps, naps_fibers, state, bm, receptor_pot, seg_ohc, seg_agc, seg_agc_memory, input_waves), jnp.arange(n_samp), )[0][:-1] @@ -2464,7 +2464,7 @@ def run_segment_jit( state: CarfacState, open_loop: bool = False, ) -> Tuple[ - jnp.ndarray, jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, + jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, ]: """A JITted version of run_segment for convenience. @@ -2473,11 +2473,11 @@ def run_segment_jit( way to account for this is to always make a deep copy of the hypers and modify those. Example usage if modifying the hypers (which most users should not): - naps, _, _, _, _ = run_segment_jit(input, hypers, weights, state) + naps, _, _, _, _, _, _, _ = run_segment_jit(input, hypers, weights, state) hypers_copy = copy.deepcopy(hypers) hypers_jax2.ears[0].car.r1_coeffs /= 2.0 - naps, _, _, _, _ = run_segment_jit(input, hypers_copy, weights, state) + naps, _, _, _, _, _, _, _ = run_segment_jit(input, hypers_copy, weights, state) If no modifications to the CarfacHypers are made, the same hypers object should be reused. @@ -2493,12 +2493,12 @@ def run_segment_jit( naps: neural activity pattern naps_fibers: neural activity of the different fiber types (only populated with non-zeros when ihc_style equals "two_cap_with_syn") - receptor_pot: receptor potential of ihc state: the updated state of the CARFAC model. BM: The basilar membrane motion - seg_ohc & seg_agc are optional extra outputs useful for seeing what the - ohc nonlinearity and agc are doing; both in terms of extra damping. - seg_agc_memory is an optional extra that gives the actual agc activity. + receptor_pot: receptor potential of ihc (optional extra) + seg_ohc: za_memory parameter to observe ohc nonlinearity (optional extra) + seg_agc: zb_memory parameter to observe agc activity (optional extra) + seg_agc_memory: actual 4-stage agc output (optional extra) """ return run_segment(input_waves, hypers, weights, state, open_loop) @@ -2511,7 +2511,7 @@ def run_segment_jit_in_chunks_notraceable( open_loop: bool = False, segment_chunk_length: int = 32 * 48000, ) -> tuple[ - jnp.ndarray, jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, + jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, ]: """Runs the jitted segment runner in segment groups. @@ -2539,12 +2539,12 @@ def run_segment_jit_in_chunks_notraceable( naps_out: Neural activity pattern as a numpy array. naps_fibers_out: neural activity of the different fiber types (only populated with non-zeros when ihc_style equals "two_cap_with_syn") - v_recep_out: receptor potential of ihc state: The updated state of the CARFAC model. bm_out: The basilar membrane motion as a numpy array. - ohc_out & agc_out are optional extra outputs useful for seeing what the - ohc nonlinearity and agc are doing; both in terms of extra damping. - agc_memory_out is optional and gives access to the actual 4-stage agc output + v_recep_out: receptor potential of ihc (optional extra) + ohc_out za_memory parameter to observe ohc nonlinearity (optional extra) + agc_out: zb_memory parameter to observe agc activity (optional extra) + agc_memory_out: actual 4-stage agc output (optional extra) Raises: RuntimeError: If this function is being JITTed, which it should not be. @@ -2560,8 +2560,8 @@ def run_segment_jit_in_chunks_notraceable( input_waves = jnp.reshape(input_waves, (-1, 1)) naps_out = [] naps_fibers_out = [] - v_recep_out = [] bm_out = [] + v_recep_out = [] ohc_out = [] agc_out = [] agc_memory_out = [] @@ -2570,13 +2570,13 @@ def run_segment_jit_in_chunks_notraceable( [n_samp, _] = input_waves.shape if n_samp >= segment_length: [current_waves, input_waves] = jnp.split(input_waves, [segment_length], 0) - naps_jax, naps_fibers_jax, receptor_pot, state, bm_jax, seg_ohc_jax, seg_agc_jax, seg_agc_memory_jax = ( + naps_jax, naps_fibers_jax, state, bm_jax, receptor_pot_jax, seg_ohc_jax, seg_agc_jax, seg_agc_memory_jax = ( run_segment_jit(current_waves, hypers, weights, state, open_loop) ) naps_out.append(naps_jax) naps_fibers_out.append(naps_fibers_jax) - v_recep_out.append(receptor_pot) bm_out.append(bm_jax) + v_recep_out.append(receptor_pot_jax) ohc_out.append(seg_ohc_jax) agc_out.append(seg_agc_jax) agc_memory_out.append(seg_agc_memory_jax) @@ -2585,21 +2585,21 @@ def run_segment_jit_in_chunks_notraceable( [n_samp, _] = input_waves.shape # Take the last few items and just run them. if n_samp > 0: - naps_jax, naps_fibers_jax, receptor_pot, state, bm_jax, seg_ohc_jax, seg_agc_jax, seg_agc_memory_jax, = ( + naps_jax, naps_fibers_jax, state, bm_jax, receptor_pot_jax, seg_ohc_jax, seg_agc_jax, seg_agc_memory_jax, = ( run_segment_jit(input_waves, hypers, weights, state, open_loop) ) naps_out.append(naps_jax) naps_fibers_out.append(naps_fibers_jax) - v_recep_out.append(receptor_pot) + v_recep_out.append(receptor_pot_jax) bm_out.append(bm_jax) ohc_out.append(seg_ohc_jax) agc_out.append(seg_agc_jax) agc_memory_out.append(seg_agc_memory_jax) naps_out = np.concatenate(naps_out, 0) naps_fibers_out = np.concatenate(naps_fibers_out, 0) - v_recep_out = np.concatenate(v_recep_out, 0) bm_out = np.concatenate(bm_out, 0) + v_recep_out = np.concatenate(v_recep_out, 0) ohc_out = np.concatenate(ohc_out, 0) agc_out = np.concatenate(agc_out, 0) agc_memory_out = np.concatenate(agc_memory_out, 0) - return naps_out, naps_fibers_out, v_recep_out, state, bm_out, ohc_out, agc_out, agc_memory_out, + return naps_out, naps_fibers_out, state, bm_out, v_recep_out, ohc_out, agc_out, agc_memory_out, From 88805b1988e3f7b81a2e7ebcfda638c32959ec9a Mon Sep 17 00:00:00 2001 From: JasonMH17 <134568474+JasonMH17@users.noreply.github.com> Date: Tue, 31 Dec 2024 13:52:45 +1100 Subject: [PATCH 5/7] Altered carfac_float64_test + carfac_util carfac_float64_test + carfac_util to match changes to run_segment in carfac.py --- python/jax/carfac_float64_test.py | 2 +- python/jax/carfac_util.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/jax/carfac_float64_test.py b/python/jax/carfac_float64_test.py index 7caaf71..6202a40 100644 --- a/python/jax/carfac_float64_test.py +++ b/python/jax/carfac_float64_test.py @@ -46,7 +46,7 @@ def loss(weights, input_waves, hypers, state): # A loss function for tests. Note that we shouldn't use `run_segment_jit` # here because it will donate the `state` which causes unnecessary # inconvenience for tests. - naps_jax, _, _, state_jax, _, _, _, _= carfac_jax.run_segment( + naps_jax, _, state_jax, _, _, _, _,_ = carfac_jax.run_segment( input_waves, hypers, weights, state, open_loop=False ) # For testing, just fit `naps` to 1. diff --git a/python/jax/carfac_util.py b/python/jax/carfac_util.py index de19e63..59dd06d 100644 --- a/python/jax/carfac_util.py +++ b/python/jax/carfac_util.py @@ -38,7 +38,6 @@ def run_multiple_segment_states_shmap( open_loop: bool = False, ) -> Sequence[ Tuple[ - jnp.ndarray, jnp.ndarray, jnp.ndarray, carfac_jax.CarfacState, @@ -46,6 +45,7 @@ def run_multiple_segment_states_shmap( jnp.ndarray, jnp.ndarray, jnp.ndarray, + jnp.ndarray, ] ]: """Run multiple equal-length, segments in carfac, Jitted, in parallel. @@ -91,7 +91,7 @@ def parallel_helper(input_waves, state): """ input_waves = input_waves[0] state = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), state) - naps, naps_fibers, receptor_pot, ret_state, bm, seg_ohc, seg_agc, seg_agc_memory = ( + naps, naps_fibers, ret_state, bm, receptor_pot, seg_ohc, seg_agc, seg_agc_memory = ( carfac_jax.run_segment_jit( input_waves, hypers, weights, state, open_loop ) @@ -102,9 +102,9 @@ def parallel_helper(input_waves, state): return ( naps[None], naps_fibers[None], - receptor_pot[None], ret_state, bm[None], + receptor_pot[None], seg_ohc[None], seg_agc[None], seg_agc_memory[None], @@ -113,9 +113,9 @@ def parallel_helper(input_waves, state): ( stacked_naps, stacked_naps_fibers, - stacked_receptor_pot, stacked_states, stacked_bm, + stacked_receptor_pot, stacked_ohc, stacked_agc, stacked_agc_memory, @@ -128,9 +128,9 @@ def parallel_helper(input_waves, state): tup = ( stacked_naps[i], stacked_naps_fibers[i], - stacked_receptor_pot[i], output_state, stacked_bm[i], + stacked_receptor_pot[i], stacked_ohc[i], stacked_agc[i], stacked_agc_memory[i], @@ -148,7 +148,6 @@ def run_multiple_segment_pmap( open_loop: bool = False, ) -> Sequence[ Tuple[ - jnp.ndarray, jnp.ndarray, jnp.ndarray, carfac_jax.CarfacState, @@ -156,6 +155,7 @@ def run_multiple_segment_pmap( jnp.ndarray, jnp.ndarray, jnp.ndarray, + jnp.ndarray, ] ]: """Run multiple equal-length, segments in carfac, Jitted, in parallel. @@ -179,9 +179,9 @@ def run_multiple_segment_pmap( ( stacked_naps, stacked_naps_fibers, - stacked_receptor_pot, stacked_states, stacked_bm, + stacked_receptor_pot, stacked_ohc, stacked_agc, stacked_agc_memory, @@ -193,9 +193,9 @@ def run_multiple_segment_pmap( tup = ( stacked_naps[i], stacked_naps_fibers[i], - stacked_receptor_pot[i], output_state, stacked_bm[i], + stacked_receptor_pot[i], stacked_ohc[i], stacked_agc[i], stacked_agc_memory[i], From 421e50cd36b0e9b5b7130ff1d25be251bcf4b5f3 Mon Sep 17 00:00:00 2001 From: JasonMH17 <134568474+JasonMH17@users.noreply.github.com> Date: Tue, 31 Dec 2024 14:10:57 +1100 Subject: [PATCH 6/7] Updated carfac_util_test.py No matches carfac.py changes to receptor_pot --- python/jax/carfac_util_test.py | 36 +++++++++++++++++----------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/python/jax/carfac_util_test.py b/python/jax/carfac_util_test.py index b239a2b..af2e8ff 100644 --- a/python/jax/carfac_util_test.py +++ b/python/jax/carfac_util_test.py @@ -59,7 +59,7 @@ def test_same_outputs_parallel_for_pmap(self): ], axis=0, ) - nap_out_a, nap_fibers_out_a, receptor_pot_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a, agc_memory_out_a = ( + nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, receptor_pot_a, ohc_out_a, agc_out_a, agc_memory_out_a = ( carfac.run_segment_jit( self.sample_a, self.hypers, @@ -68,7 +68,7 @@ def test_same_outputs_parallel_for_pmap(self): self.open_loop, ) ) - nap_out_b, nap_fibers_out_b, receptor_pot_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b, agc_memory_out_b = ( + nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, receptor_pot_b, ohc_out_b, agc_out_b, agc_memory_out_b = ( carfac.run_segment_jit( self.sample_b, self.hypers, @@ -88,10 +88,10 @@ def test_same_outputs_parallel_for_pmap(self): self.assertTrue((combined_output[1][0] == nap_out_b).all()) self.assertTrue((combined_output[0][1] == nap_fibers_out_a).all()) self.assertTrue((combined_output[1][1] == nap_fibers_out_b).all()) - self.assertTrue((combined_output[0][2] == receptor_pot_a).all()) - self.assertTrue((combined_output[1][2] == receptor_pot_b).all()) - self.assertTrue((combined_output[0][4] == bm_out_a).all()) - self.assertTrue((combined_output[1][4] == bm_out_b).all()) + self.assertTrue((combined_output[0][3] == bm_out_a).all()) + self.assertTrue((combined_output[1][3] == bm_out_b).all()) + self.assertTrue((combined_output[0][4] == receptor_pot_a).all()) + self.assertTrue((combined_output[1][4] == receptor_pot_b).all()) self.assertTrue((combined_output[0][5] == ohc_out_a).all()) self.assertTrue((combined_output[1][5] == ohc_out_b).all()) self.assertTrue((combined_output[0][6] == agc_out_a).all()) @@ -100,12 +100,12 @@ def test_same_outputs_parallel_for_pmap(self): self.assertTrue((combined_output[1][7] == agc_memory_out_b).all()) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_a, combined_output[0][3]) + jax.tree.map(jnp.allclose, state_out_a, combined_output[0][2]) ) ) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_b, combined_output[1][3]) + jax.tree.map(jnp.allclose, state_out_b, combined_output[1][2]) ) ) @@ -118,7 +118,7 @@ def test_same_outputs_parallel_for_shmap(self): axis=0, ) - nap_out_a, nap_fibers_out_a, receptor_pot_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a, agc_memory_out_a, = ( + nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, receptor_pot_a,, ohc_out_a, agc_out_a, agc_memory_out_a, = ( carfac.run_segment_jit( self.sample_a, self.hypers, @@ -138,7 +138,7 @@ def test_same_outputs_parallel_for_shmap(self): self.open_loop, ) - nap_out_b, nap_fibers_out_b, receptor_pot_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b, agc_memory_out_b, = ( + nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, receptor_pot_b, ohc_out_b, agc_out_b, agc_memory_out_b, = ( carfac.run_segment_jit( self.sample_b, self.hypers, @@ -158,10 +158,10 @@ def test_same_outputs_parallel_for_shmap(self): self.assertTrue((combined_output[1][0] == nap_out_b).all()) self.assertTrue((combined_output[0][1] == nap_fibers_out_a).all()) self.assertTrue((combined_output[1][1] == nap_fibers_out_b).all()) - self.assertTrue((combined_output[0][2] == receptor_pot_a).all()) - self.assertTrue((combined_output[1][2] == receptor_pot_b).all()) - self.assertTrue((combined_output[0][4] == bm_out_a).all()) - self.assertTrue((combined_output[1][4] == bm_out_b).all()) + self.assertTrue((combined_output[0][3] == bm_out_a).all()) + self.assertTrue((combined_output[1][3] == bm_out_b).all()) + self.assertTrue((combined_output[0][4] == receptor_pot_a).all()) + self.assertTrue((combined_output[1][4] == receptor_pot_b).all()) self.assertTrue((combined_output[0][5] == ohc_out_a).all()) self.assertTrue((combined_output[1][5] == ohc_out_b).all()) self.assertTrue((combined_output[0][6] == agc_out_a).all()) @@ -171,12 +171,12 @@ def test_same_outputs_parallel_for_shmap(self): self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_a, combined_output[0][3]) + jax.tree.map(jnp.allclose, state_out_a, combined_output[0][2]) ) ) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_b, combined_output[1][3]) + jax.tree.map(jnp.allclose, state_out_b, combined_output[1][2]) ) ) @@ -184,12 +184,12 @@ def test_same_outputs_parallel_for_shmap(self): # equality is complete and double sided. self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, combined_output[0][3], state_out_a) + jax.tree.map(jnp.allclose, combined_output[0][2], state_out_a) ) ) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, combined_output[1][3], state_out_b) + jax.tree.map(jnp.allclose, combined_output[1][2], state_out_b) ) ) From 3921e339717ffe140a566bbb7e3a6f70b4939042 Mon Sep 17 00:00:00 2001 From: JasonMH17 <134568474+JasonMH17@users.noreply.github.com> Date: Fri, 10 Jan 2025 13:52:16 +1100 Subject: [PATCH 7/7] Removed agc_memory_out as parameter and updated agc_out agc_memory_out has been removed and replaced with agc_out. agc_out now equals the stage 1 and final agc output --- python/jax/carfac.py | 39 +++++++++++-------------------- python/jax/carfac_bench.py | 12 +++++----- python/jax/carfac_float64_test.py | 2 +- python/jax/carfac_util.py | 8 +------ python/jax/carfac_util_test.py | 12 ++++------ 5 files changed, 26 insertions(+), 47 deletions(-) diff --git a/python/jax/carfac.py b/python/jax/carfac.py index d7394de..25e1580 100644 --- a/python/jax/carfac.py +++ b/python/jax/carfac.py @@ -2302,7 +2302,7 @@ def run_segment( state: CarfacState, open_loop: bool = False, ) -> Tuple[ - jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, + jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, ]: """This function runs the entire CARFAC model. @@ -2343,8 +2343,7 @@ def run_segment( BM: The basilar membrane motion receptor_pot: receptor potential of ihc (optional extra) seg_ohc: za_memory parameter to observe ohc nonlinearity (optional extra) - seg_agc: zb_memory parameter to observe agc activity (optional extra) - seg_agc_memory: actual 4-stage agc output (optional extra) + seg_agc: agc output (stage 1) parameter to observe agc activity (optional extra) """ if len(input_waves.shape) < 2: input_waves = jnp.reshape(input_waves, (-1, 1)) @@ -2365,7 +2364,6 @@ def run_segment( receptor_pot = jnp.zeros((n_samp, n_ch, n_ears)) seg_ohc = jnp.zeros((n_samp, n_ch, n_ears)) seg_agc = jnp.zeros((n_samp, n_ch, n_ears)) - seg_agc_memory = jnp.zeros((n_samp, n_agc_stages, n_ch, n_ears)) # A 2022 addition to make open-loop running behave: if open_loop: @@ -2381,7 +2379,7 @@ def run_segment( # Note that we can use naive for loops here because it will make gradient # computation very slow. def run_segment_scan_helper(carry, k): - naps, naps_fibers, state, bm, receptor_pot, seg_ohc, seg_agc, seg_agc_memory, input_waves = carry + naps, naps_fibers, state, bm, receptor_pot, seg_ohc, seg_agc, input_waves = carry agc_updated = False for ear in range(n_ears): # This would be cleaner if we could just get and use a reference to @@ -2415,11 +2413,7 @@ def run_segment_scan_helper(carry, k): bm = bm.at[k, :, ear].set(car_out) car_state = state.ears[ear].car seg_ohc = seg_ohc.at[k, :, ear].set(car_state.za_memory) - seg_agc = seg_agc.at[k, :, ear].set(car_state.zb_memory) - for i in range(n_agc_stages): - seg_agc_memory = seg_agc_memory.at[k, i, :, ear].set( - state.ears[ear].agc[i].agc_memory - ) + seg_agc = seg_agc.at[k, :, ear].set(state.ears[ear].agc[0].agc_memory) def close_agc_loop_helper( hypers: CarfacHypers, weights: CarfacWeights, state: CarfacState @@ -2441,11 +2435,11 @@ def close_agc_loop_helper( state, ) - return (naps, naps_fibers, state, bm, receptor_pot, seg_ohc, seg_agc, seg_agc_memory, input_waves), None + return (naps, naps_fibers, state, bm, receptor_pot, seg_ohc, seg_agc, input_waves), None return jax.lax.scan( run_segment_scan_helper, - (naps, naps_fibers, state, bm, receptor_pot, seg_ohc, seg_agc, seg_agc_memory, input_waves), + (naps, naps_fibers, state, bm, receptor_pot, seg_ohc, seg_agc, input_waves), jnp.arange(n_samp), )[0][:-1] @@ -2464,7 +2458,7 @@ def run_segment_jit( state: CarfacState, open_loop: bool = False, ) -> Tuple[ - jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, + jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, ]: """A JITted version of run_segment for convenience. @@ -2497,9 +2491,8 @@ def run_segment_jit( BM: The basilar membrane motion receptor_pot: receptor potential of ihc (optional extra) seg_ohc: za_memory parameter to observe ohc nonlinearity (optional extra) - seg_agc: zb_memory parameter to observe agc activity (optional extra) - seg_agc_memory: actual 4-stage agc output (optional extra) - """ + seg_agc: agc output (stage 1) parameter to observe agc activity (optional extra) +\ """ return run_segment(input_waves, hypers, weights, state, open_loop) @@ -2511,7 +2504,7 @@ def run_segment_jit_in_chunks_notraceable( open_loop: bool = False, segment_chunk_length: int = 32 * 48000, ) -> tuple[ - jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, + jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, ]: """Runs the jitted segment runner in segment groups. @@ -2543,8 +2536,7 @@ def run_segment_jit_in_chunks_notraceable( bm_out: The basilar membrane motion as a numpy array. v_recep_out: receptor potential of ihc (optional extra) ohc_out za_memory parameter to observe ohc nonlinearity (optional extra) - agc_out: zb_memory parameter to observe agc activity (optional extra) - agc_memory_out: actual 4-stage agc output (optional extra) + agc_out: agc output (stage 1) to observe agc activity (optional extra) Raises: RuntimeError: If this function is being JITTed, which it should not be. @@ -2570,7 +2562,7 @@ def run_segment_jit_in_chunks_notraceable( [n_samp, _] = input_waves.shape if n_samp >= segment_length: [current_waves, input_waves] = jnp.split(input_waves, [segment_length], 0) - naps_jax, naps_fibers_jax, state, bm_jax, receptor_pot_jax, seg_ohc_jax, seg_agc_jax, seg_agc_memory_jax = ( + naps_jax, naps_fibers_jax, state, bm_jax, receptor_pot_jax, seg_ohc_jax, seg_agc_jax, = ( run_segment_jit(current_waves, hypers, weights, state, open_loop) ) naps_out.append(naps_jax) @@ -2579,13 +2571,12 @@ def run_segment_jit_in_chunks_notraceable( v_recep_out.append(receptor_pot_jax) ohc_out.append(seg_ohc_jax) agc_out.append(seg_agc_jax) - agc_memory_out.append(seg_agc_memory_jax) else: segment_length //= 2 [n_samp, _] = input_waves.shape # Take the last few items and just run them. if n_samp > 0: - naps_jax, naps_fibers_jax, state, bm_jax, receptor_pot_jax, seg_ohc_jax, seg_agc_jax, seg_agc_memory_jax, = ( + naps_jax, naps_fibers_jax, state, bm_jax, receptor_pot_jax, seg_ohc_jax, seg_agc_jax, = ( run_segment_jit(input_waves, hypers, weights, state, open_loop) ) naps_out.append(naps_jax) @@ -2594,12 +2585,10 @@ def run_segment_jit_in_chunks_notraceable( bm_out.append(bm_jax) ohc_out.append(seg_ohc_jax) agc_out.append(seg_agc_jax) - agc_memory_out.append(seg_agc_memory_jax) naps_out = np.concatenate(naps_out, 0) naps_fibers_out = np.concatenate(naps_fibers_out, 0) bm_out = np.concatenate(bm_out, 0) v_recep_out = np.concatenate(v_recep_out, 0) ohc_out = np.concatenate(ohc_out, 0) agc_out = np.concatenate(agc_out, 0) - agc_memory_out = np.concatenate(agc_memory_out, 0) - return naps_out, naps_fibers_out, state, bm_out, v_recep_out, ohc_out, agc_out, agc_memory_out, + return naps_out, naps_fibers_out, state, bm_out, v_recep_out, ohc_out, agc_out, diff --git a/python/jax/carfac_bench.py b/python/jax/carfac_bench.py index b5319d5..89ae033 100644 --- a/python/jax/carfac_bench.py +++ b/python/jax/carfac_bench.py @@ -179,7 +179,7 @@ def loss_func( weights: carfac_jax.CarfacWeights, state: carfac_jax.CarfacState, ): - nap_output, _, _, _, _, _, _, _ = carfac_jax.run_segment( + nap_output, _, _, _, _, _, _ = carfac_jax.run_segment( audio, hypers, weights, state ) return jnp.sum(nap_output), nap_output @@ -242,7 +242,7 @@ def bench_jit_compile_time(state: google_benchmark.State): # that this benchmark is appropriate. n_samp += 1 state.resume_timing() - naps_jax, _, state_jax, _, _, _, _, _ = carfac_jax.run_segment_jit( + naps_jax, _, state_jax, _, _, _, _ = carfac_jax.run_segment_jit( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) naps_jax.block_until_ready() @@ -295,7 +295,7 @@ def bench_jax_in_slices(state: google_benchmark.State): for _, segment in enumerate(silence_slices): if segment.shape not in compiled_shapes: compiled_shapes.add(segment.shape) - naps_jax, _, _, _, _, _, _, _ = carfac_jax.run_segment_jit( + naps_jax, _, _, _, _, _, _ = carfac_jax.run_segment_jit( segment, hypers_jax, weights_jax, state_jax, open_loop=False ) naps_jax.block_until_ready() @@ -316,7 +316,7 @@ def bench_jax_in_slices(state: google_benchmark.State): jax_loop_state = state_jax state.resume_timing() for _, segment in enumerate(run_seg_slices): - seg_naps, seg_naps_fibers, jax_loop_state, seg_bm, seg_receptor_pot, seg_ohc, seg_agc, seg_agc_actual = ( + seg_naps, seg_naps_fibers, jax_loop_state, seg_bm, seg_receptor_pot, seg_ohc, seg_agc = ( carfac_jax.run_segment_jit( segment, hypers_jax, weights_jax, jax_loop_state, open_loop=False ) @@ -389,7 +389,7 @@ def bench_jax(state: google_benchmark.State): params_jax ) short_silence = jnp.zeros(shape=(n_samp, n_ears)) - naps_jax, _, state_jax, _, _, _, _, _ = run_segment_function( + naps_jax, _, state_jax, _, _, _, _ = run_segment_function( short_silence, hypers_jax, weights_jax, state_jax, open_loop=False ) # This block ensures calculation. @@ -404,7 +404,7 @@ def bench_jax(state: google_benchmark.State): jax.random.normal(random_generator, (n_samp, n_ears)) * _NOISE_FACTOR ).block_until_ready() state.resume_timing() - naps_jax, _, state_jax, _, _, _, _, _ = run_segment_function( + naps_jax, _, state_jax, _, _, _, _ = run_segment_function( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) if state.range(0) != 1: diff --git a/python/jax/carfac_float64_test.py b/python/jax/carfac_float64_test.py index 6202a40..cf75038 100644 --- a/python/jax/carfac_float64_test.py +++ b/python/jax/carfac_float64_test.py @@ -46,7 +46,7 @@ def loss(weights, input_waves, hypers, state): # A loss function for tests. Note that we shouldn't use `run_segment_jit` # here because it will donate the `state` which causes unnecessary # inconvenience for tests. - naps_jax, _, state_jax, _, _, _, _,_ = carfac_jax.run_segment( + naps_jax, _, state_jax, _, _, _, _ = carfac_jax.run_segment( input_waves, hypers, weights, state, open_loop=False ) # For testing, just fit `naps` to 1. diff --git a/python/jax/carfac_util.py b/python/jax/carfac_util.py index 59dd06d..3bc5957 100644 --- a/python/jax/carfac_util.py +++ b/python/jax/carfac_util.py @@ -91,7 +91,7 @@ def parallel_helper(input_waves, state): """ input_waves = input_waves[0] state = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), state) - naps, naps_fibers, ret_state, bm, receptor_pot, seg_ohc, seg_agc, seg_agc_memory = ( + naps, naps_fibers, ret_state, bm, receptor_pot, seg_ohc, seg_agc = ( carfac_jax.run_segment_jit( input_waves, hypers, weights, state, open_loop ) @@ -107,7 +107,6 @@ def parallel_helper(input_waves, state): receptor_pot[None], seg_ohc[None], seg_agc[None], - seg_agc_memory[None], ) ( @@ -118,7 +117,6 @@ def parallel_helper(input_waves, state): stacked_receptor_pot, stacked_ohc, stacked_agc, - stacked_agc_memory, ) = parallel_helper(input_waves_array, batch_state) output_states = _tree_unstack(stacked_states) output = [] @@ -133,7 +131,6 @@ def parallel_helper(input_waves, state): stacked_receptor_pot[i], stacked_ohc[i], stacked_agc[i], - stacked_agc_memory[i], ) output.append(tup) return output @@ -155,7 +152,6 @@ def run_multiple_segment_pmap( jnp.ndarray, jnp.ndarray, jnp.ndarray, - jnp.ndarray, ] ]: """Run multiple equal-length, segments in carfac, Jitted, in parallel. @@ -184,7 +180,6 @@ def run_multiple_segment_pmap( stacked_receptor_pot, stacked_ohc, stacked_agc, - stacked_agc_memory, ) = pmapped(input_waves_array, hypers, weights, state, open_loop) output_states = _tree_unstack(stacked_states) @@ -198,7 +193,6 @@ def run_multiple_segment_pmap( stacked_receptor_pot[i], stacked_ohc[i], stacked_agc[i], - stacked_agc_memory[i], ) output.append(tup) return output diff --git a/python/jax/carfac_util_test.py b/python/jax/carfac_util_test.py index af2e8ff..bf9bd6f 100644 --- a/python/jax/carfac_util_test.py +++ b/python/jax/carfac_util_test.py @@ -59,7 +59,7 @@ def test_same_outputs_parallel_for_pmap(self): ], axis=0, ) - nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, receptor_pot_a, ohc_out_a, agc_out_a, agc_memory_out_a = ( + nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, receptor_pot_a, ohc_out_a, agc_out_a = ( carfac.run_segment_jit( self.sample_a, self.hypers, @@ -68,7 +68,7 @@ def test_same_outputs_parallel_for_pmap(self): self.open_loop, ) ) - nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, receptor_pot_b, ohc_out_b, agc_out_b, agc_memory_out_b = ( + nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, receptor_pot_b, ohc_out_b, agc_out_b = ( carfac.run_segment_jit( self.sample_b, self.hypers, @@ -96,8 +96,6 @@ def test_same_outputs_parallel_for_pmap(self): self.assertTrue((combined_output[1][5] == ohc_out_b).all()) self.assertTrue((combined_output[0][6] == agc_out_a).all()) self.assertTrue((combined_output[1][6] == agc_out_b).all()) - self.assertTrue((combined_output[0][7] == agc_memory_out_a).all()) - self.assertTrue((combined_output[1][7] == agc_memory_out_b).all()) self.assertTrue( jax.tree_util.tree_all( jax.tree.map(jnp.allclose, state_out_a, combined_output[0][2]) @@ -118,7 +116,7 @@ def test_same_outputs_parallel_for_shmap(self): axis=0, ) - nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, receptor_pot_a,, ohc_out_a, agc_out_a, agc_memory_out_a, = ( + nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, receptor_pot_a, ohc_out_a, agc_out_a, = ( carfac.run_segment_jit( self.sample_a, self.hypers, @@ -138,7 +136,7 @@ def test_same_outputs_parallel_for_shmap(self): self.open_loop, ) - nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, receptor_pot_b, ohc_out_b, agc_out_b, agc_memory_out_b, = ( + nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, receptor_pot_b, ohc_out_b, agc_out_b, = ( carfac.run_segment_jit( self.sample_b, self.hypers, @@ -166,8 +164,6 @@ def test_same_outputs_parallel_for_shmap(self): self.assertTrue((combined_output[1][5] == ohc_out_b).all()) self.assertTrue((combined_output[0][6] == agc_out_a).all()) self.assertTrue((combined_output[1][6] == agc_out_b).all()) - self.assertTrue((combined_output[0][6] == agc_memory_out_a).all()) - self.assertTrue((combined_output[1][6] == agc_memory_out_b).all()) self.assertTrue( jax.tree_util.tree_all(