diff --git a/python/jax/carfac.py b/python/jax/carfac.py index 4a6d9d4..25e1580 100644 --- a/python/jax/carfac.py +++ b/python/jax/carfac.py @@ -2302,7 +2302,7 @@ def run_segment( state: CarfacState, open_loop: bool = False, ) -> Tuple[ - jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray + jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, ]: """This function runs the entire CARFAC model. @@ -2341,8 +2341,9 @@ def run_segment( (only populated with non-zeros when ihc_style equals "two_cap_with_syn") state: the updated state of the CARFAC model. BM: The basilar membrane motion - seg_ohc & seg_agc are optional extra outputs useful for seeing what the - ohc nonlinearity and agc are doing; both in terms of extra damping. + receptor_pot: receptor potential of ihc (optional extra) + seg_ohc: za_memory parameter to observe ohc nonlinearity (optional extra) + seg_agc: agc output (stage 1) parameter to observe agc activity (optional extra) """ if len(input_waves.shape) < 2: input_waves = jnp.reshape(input_waves, (-1, 1)) @@ -2356,9 +2357,11 @@ def run_segment( # (n_ears, cfp.n_ears)) n_ch = hypers.ears[0].car.n_ch + n_agc_stages = jnp.shape(state.ears[0].agc)[0] naps = jnp.zeros((n_samp, n_ch, n_ears)) # allocate space for result naps_fibers = jnp.zeros((n_samp, n_ch, n_fibertypes, n_ears)) bm = jnp.zeros((n_samp, n_ch, n_ears)) + receptor_pot = jnp.zeros((n_samp, n_ch, n_ears)) seg_ohc = jnp.zeros((n_samp, n_ch, n_ears)) seg_agc = jnp.zeros((n_samp, n_ch, n_ears)) @@ -2376,7 +2379,7 @@ def run_segment( # Note that we can use naive for loops here because it will make gradient # computation very slow. def run_segment_scan_helper(carry, k): - naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves = carry + naps, naps_fibers, state, bm, receptor_pot, seg_ohc, seg_agc, input_waves = carry agc_updated = False for ear in range(n_ears): # This would be cleaner if we could just get and use a reference to @@ -2406,10 +2409,11 @@ def run_segment_scan_helper(carry, k): ) # save some output data: naps = naps.at[k, :, ear].set(ihc_out) + receptor_pot = receptor_pot.at[k, :, ear].set(v_recep) bm = bm.at[k, :, ear].set(car_out) car_state = state.ears[ear].car seg_ohc = seg_ohc.at[k, :, ear].set(car_state.za_memory) - seg_agc = seg_agc.at[k, :, ear].set(car_state.zb_memory) + seg_agc = seg_agc.at[k, :, ear].set(state.ears[ear].agc[0].agc_memory) def close_agc_loop_helper( hypers: CarfacHypers, weights: CarfacWeights, state: CarfacState @@ -2431,11 +2435,11 @@ def close_agc_loop_helper( state, ) - return (naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves), None + return (naps, naps_fibers, state, bm, receptor_pot, seg_ohc, seg_agc, input_waves), None return jax.lax.scan( run_segment_scan_helper, - (naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves), + (naps, naps_fibers, state, bm, receptor_pot, seg_ohc, seg_agc, input_waves), jnp.arange(n_samp), )[0][:-1] @@ -2454,7 +2458,7 @@ def run_segment_jit( state: CarfacState, open_loop: bool = False, ) -> Tuple[ - jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray + jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, ]: """A JITted version of run_segment for convenience. @@ -2463,11 +2467,11 @@ def run_segment_jit( way to account for this is to always make a deep copy of the hypers and modify those. Example usage if modifying the hypers (which most users should not): - naps, _, _, _, _ = run_segment_jit(input, hypers, weights, state) + naps, _, _, _, _, _, _, _ = run_segment_jit(input, hypers, weights, state) hypers_copy = copy.deepcopy(hypers) hypers_jax2.ears[0].car.r1_coeffs /= 2.0 - naps, _, _, _, _ = run_segment_jit(input, hypers_copy, weights, state) + naps, _, _, _, _, _, _, _ = run_segment_jit(input, hypers_copy, weights, state) If no modifications to the CarfacHypers are made, the same hypers object should be reused. @@ -2485,9 +2489,10 @@ def run_segment_jit( (only populated with non-zeros when ihc_style equals "two_cap_with_syn") state: the updated state of the CARFAC model. BM: The basilar membrane motion - seg_ohc & seg_agc are optional extra outputs useful for seeing what the - ohc nonlinearity and agc are doing; both in terms of extra damping. - """ + receptor_pot: receptor potential of ihc (optional extra) + seg_ohc: za_memory parameter to observe ohc nonlinearity (optional extra) + seg_agc: agc output (stage 1) parameter to observe agc activity (optional extra) +\ """ return run_segment(input_waves, hypers, weights, state, open_loop) @@ -2499,7 +2504,7 @@ def run_segment_jit_in_chunks_notraceable( open_loop: bool = False, segment_chunk_length: int = 32 * 48000, ) -> tuple[ - jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray + jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, ]: """Runs the jitted segment runner in segment groups. @@ -2524,11 +2529,14 @@ def run_segment_jit_in_chunks_notraceable( largest chunk. Returns: - naps: Neural activity pattern as a numpy array. + naps_out: Neural activity pattern as a numpy array. + naps_fibers_out: neural activity of the different fiber types + (only populated with non-zeros when ihc_style equals "two_cap_with_syn") state: The updated state of the CARFAC model. - BM: The basilar membrane motion as a numpy array. - seg_ohc & seg_agc are optional extra outputs useful for seeing what the - ohc nonlinearity and agc are doing; both in terms of extra damping. + bm_out: The basilar membrane motion as a numpy array. + v_recep_out: receptor potential of ihc (optional extra) + ohc_out za_memory parameter to observe ohc nonlinearity (optional extra) + agc_out: agc output (stage 1) to observe agc activity (optional extra) Raises: RuntimeError: If this function is being JITTed, which it should not be. @@ -2545,19 +2553,22 @@ def run_segment_jit_in_chunks_notraceable( naps_out = [] naps_fibers_out = [] bm_out = [] + v_recep_out = [] ohc_out = [] agc_out = [] + agc_memory_out = [] # NOMUTANTS -- This is a performance optimization. while segment_length > 16: [n_samp, _] = input_waves.shape if n_samp >= segment_length: [current_waves, input_waves] = jnp.split(input_waves, [segment_length], 0) - naps_jax, naps_fibers_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = ( + naps_jax, naps_fibers_jax, state, bm_jax, receptor_pot_jax, seg_ohc_jax, seg_agc_jax, = ( run_segment_jit(current_waves, hypers, weights, state, open_loop) ) naps_out.append(naps_jax) naps_fibers_out.append(naps_fibers_jax) bm_out.append(bm_jax) + v_recep_out.append(receptor_pot_jax) ohc_out.append(seg_ohc_jax) agc_out.append(seg_agc_jax) else: @@ -2565,17 +2576,19 @@ def run_segment_jit_in_chunks_notraceable( [n_samp, _] = input_waves.shape # Take the last few items and just run them. if n_samp > 0: - naps_jax, naps_fibers_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = ( + naps_jax, naps_fibers_jax, state, bm_jax, receptor_pot_jax, seg_ohc_jax, seg_agc_jax, = ( run_segment_jit(input_waves, hypers, weights, state, open_loop) ) naps_out.append(naps_jax) naps_fibers_out.append(naps_fibers_jax) + v_recep_out.append(receptor_pot_jax) bm_out.append(bm_jax) ohc_out.append(seg_ohc_jax) agc_out.append(seg_agc_jax) naps_out = np.concatenate(naps_out, 0) naps_fibers_out = np.concatenate(naps_fibers_out, 0) bm_out = np.concatenate(bm_out, 0) + v_recep_out = np.concatenate(v_recep_out, 0) ohc_out = np.concatenate(ohc_out, 0) agc_out = np.concatenate(agc_out, 0) - return naps_out, naps_fibers_out, state, bm_out, ohc_out, agc_out + return naps_out, naps_fibers_out, state, bm_out, v_recep_out, ohc_out, agc_out, diff --git a/python/jax/carfac_bench.py b/python/jax/carfac_bench.py index ca58785..89ae033 100644 --- a/python/jax/carfac_bench.py +++ b/python/jax/carfac_bench.py @@ -179,7 +179,7 @@ def loss_func( weights: carfac_jax.CarfacWeights, state: carfac_jax.CarfacState, ): - nap_output, _, _, _, _, _ = carfac_jax.run_segment( + nap_output, _, _, _, _, _, _ = carfac_jax.run_segment( audio, hypers, weights, state ) return jnp.sum(nap_output), nap_output @@ -242,7 +242,7 @@ def bench_jit_compile_time(state: google_benchmark.State): # that this benchmark is appropriate. n_samp += 1 state.resume_timing() - naps_jax, _, state_jax, _, _, _ = carfac_jax.run_segment_jit( + naps_jax, _, state_jax, _, _, _, _ = carfac_jax.run_segment_jit( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) naps_jax.block_until_ready() @@ -295,7 +295,7 @@ def bench_jax_in_slices(state: google_benchmark.State): for _, segment in enumerate(silence_slices): if segment.shape not in compiled_shapes: compiled_shapes.add(segment.shape) - naps_jax, _, _, _, _, _ = carfac_jax.run_segment_jit( + naps_jax, _, _, _, _, _, _ = carfac_jax.run_segment_jit( segment, hypers_jax, weights_jax, state_jax, open_loop=False ) naps_jax.block_until_ready() @@ -316,7 +316,7 @@ def bench_jax_in_slices(state: google_benchmark.State): jax_loop_state = state_jax state.resume_timing() for _, segment in enumerate(run_seg_slices): - seg_naps, _, jax_loop_state, seg_bm, seg_ohc, seg_agc = ( + seg_naps, seg_naps_fibers, jax_loop_state, seg_bm, seg_receptor_pot, seg_ohc, seg_agc = ( carfac_jax.run_segment_jit( segment, hypers_jax, weights_jax, jax_loop_state, open_loop=False ) @@ -389,7 +389,7 @@ def bench_jax(state: google_benchmark.State): params_jax ) short_silence = jnp.zeros(shape=(n_samp, n_ears)) - naps_jax, _, state_jax, _, _, _ = run_segment_function( + naps_jax, _, state_jax, _, _, _, _ = run_segment_function( short_silence, hypers_jax, weights_jax, state_jax, open_loop=False ) # This block ensures calculation. @@ -404,7 +404,7 @@ def bench_jax(state: google_benchmark.State): jax.random.normal(random_generator, (n_samp, n_ears)) * _NOISE_FACTOR ).block_until_ready() state.resume_timing() - naps_jax, _, state_jax, _, _, _ = run_segment_function( + naps_jax, _, state_jax, _, _, _, _ = run_segment_function( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) if state.range(0) != 1: diff --git a/python/jax/carfac_float64_test.py b/python/jax/carfac_float64_test.py index 3237024..cf75038 100644 --- a/python/jax/carfac_float64_test.py +++ b/python/jax/carfac_float64_test.py @@ -46,7 +46,7 @@ def loss(weights, input_waves, hypers, state): # A loss function for tests. Note that we shouldn't use `run_segment_jit` # here because it will donate the `state` which causes unnecessary # inconvenience for tests. - naps_jax, _, state_jax, _, _, _ = carfac_jax.run_segment( + naps_jax, _, state_jax, _, _, _, _ = carfac_jax.run_segment( input_waves, hypers, weights, state, open_loop=False ) # For testing, just fit `naps` to 1. diff --git a/python/jax/carfac_test.py b/python/jax/carfac_test.py index 246b5f0..acb985b 100644 --- a/python/jax/carfac_test.py +++ b/python/jax/carfac_test.py @@ -324,24 +324,17 @@ def test_chunked_naps_same_as_jit(self, random_seed, ihc_style): state_jax_copied = copy.deepcopy(state_jax) # Only tests the JITted version because this is what we will use. - naps_jax, _, _, bm_jax, ohc_jax, agc_jax = ( - carfac_jax.run_segment_jit( - run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False - ) + naps_jax, naps_fibers_jax, _, bm_jax, ohc_jax, agc_jax = carfac_jax.run_segment_jit( + run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) - ( - naps_jax_chunked, - _, - _, - bm_chunked, - ohc_chunked, - agc_chunked, - ) = carfac_jax.run_segment_jit_in_chunks_notraceable( - run_seg_input, - hypers_jax, - weights_jax, - state_jax_copied, - open_loop=False, + naps_jax_chunked, naps_fibes_jax_chunked, _, bm_chunked, ohc_chunked, agc_chunked = ( + carfac_jax.run_segment_jit_in_chunks_notraceable( + run_seg_input, + hypers_jax, + weights_jax, + state_jax_copied, + open_loop=False, + ) ) self.assertLess(jnp.max(abs(naps_jax_chunked - naps_jax)), 1e-7) self.assertLess(jnp.max(abs(bm_chunked - bm_jax)), 1e-7) @@ -387,7 +380,7 @@ def test_equal_forward_pass( run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears)) # Only tests the JITted version because this is what we will use. - naps_jax, _, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = ( + naps_jax, naps_fibers_jax, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = ( carfac_jax.run_segment_jit( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) diff --git a/python/jax/carfac_util.py b/python/jax/carfac_util.py index d789415..3bc5957 100644 --- a/python/jax/carfac_util.py +++ b/python/jax/carfac_util.py @@ -44,6 +44,8 @@ def run_multiple_segment_states_shmap( jnp.ndarray, jnp.ndarray, jnp.ndarray, + jnp.ndarray, + jnp.ndarray, ] ]: """Run multiple equal-length, segments in carfac, Jitted, in parallel. @@ -89,7 +91,7 @@ def parallel_helper(input_waves, state): """ input_waves = input_waves[0] state = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), state) - naps, naps_fibers, ret_state, bm, seg_ohc, seg_agc = ( + naps, naps_fibers, ret_state, bm, receptor_pot, seg_ohc, seg_agc = ( carfac_jax.run_segment_jit( input_waves, hypers, weights, state, open_loop ) @@ -102,6 +104,7 @@ def parallel_helper(input_waves, state): naps_fibers[None], ret_state, bm[None], + receptor_pot[None], seg_ohc[None], seg_agc[None], ) @@ -111,6 +114,7 @@ def parallel_helper(input_waves, state): stacked_naps_fibers, stacked_states, stacked_bm, + stacked_receptor_pot, stacked_ohc, stacked_agc, ) = parallel_helper(input_waves_array, batch_state) @@ -124,6 +128,7 @@ def parallel_helper(input_waves, state): stacked_naps_fibers[i], output_state, stacked_bm[i], + stacked_receptor_pot[i], stacked_ohc[i], stacked_agc[i], ) @@ -146,6 +151,7 @@ def run_multiple_segment_pmap( jnp.ndarray, jnp.ndarray, jnp.ndarray, + jnp.ndarray, ] ]: """Run multiple equal-length, segments in carfac, Jitted, in parallel. @@ -171,6 +177,7 @@ def run_multiple_segment_pmap( stacked_naps_fibers, stacked_states, stacked_bm, + stacked_receptor_pot, stacked_ohc, stacked_agc, ) = pmapped(input_waves_array, hypers, weights, state, open_loop) @@ -183,6 +190,7 @@ def run_multiple_segment_pmap( stacked_naps_fibers[i], output_state, stacked_bm[i], + stacked_receptor_pot[i], stacked_ohc[i], stacked_agc[i], ) diff --git a/python/jax/carfac_util_test.py b/python/jax/carfac_util_test.py index a88c449..bf9bd6f 100644 --- a/python/jax/carfac_util_test.py +++ b/python/jax/carfac_util_test.py @@ -59,7 +59,7 @@ def test_same_outputs_parallel_for_pmap(self): ], axis=0, ) - nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a = ( + nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, receptor_pot_a, ohc_out_a, agc_out_a = ( carfac.run_segment_jit( self.sample_a, self.hypers, @@ -68,7 +68,7 @@ def test_same_outputs_parallel_for_pmap(self): self.open_loop, ) ) - nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b = ( + nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, receptor_pot_b, ohc_out_b, agc_out_b = ( carfac.run_segment_jit( self.sample_b, self.hypers, @@ -90,10 +90,12 @@ def test_same_outputs_parallel_for_pmap(self): self.assertTrue((combined_output[1][1] == nap_fibers_out_b).all()) self.assertTrue((combined_output[0][3] == bm_out_a).all()) self.assertTrue((combined_output[1][3] == bm_out_b).all()) - self.assertTrue((combined_output[0][4] == ohc_out_a).all()) - self.assertTrue((combined_output[1][4] == ohc_out_b).all()) - self.assertTrue((combined_output[0][5] == agc_out_a).all()) - self.assertTrue((combined_output[1][5] == agc_out_b).all()) + self.assertTrue((combined_output[0][4] == receptor_pot_a).all()) + self.assertTrue((combined_output[1][4] == receptor_pot_b).all()) + self.assertTrue((combined_output[0][5] == ohc_out_a).all()) + self.assertTrue((combined_output[1][5] == ohc_out_b).all()) + self.assertTrue((combined_output[0][6] == agc_out_a).all()) + self.assertTrue((combined_output[1][6] == agc_out_b).all()) self.assertTrue( jax.tree_util.tree_all( jax.tree.map(jnp.allclose, state_out_a, combined_output[0][2]) @@ -114,7 +116,7 @@ def test_same_outputs_parallel_for_shmap(self): axis=0, ) - nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a = ( + nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, receptor_pot_a, ohc_out_a, agc_out_a, = ( carfac.run_segment_jit( self.sample_a, self.hypers, @@ -126,7 +128,7 @@ def test_same_outputs_parallel_for_shmap(self): # Run sample B twice, so we have a separate "starting" state for the # test for shmap. - _, _, state_out_b_first, _, _, _ = carfac.run_segment_jit( + _, _, _, state_out_b_first, _, _, _, _ = carfac.run_segment_jit( self.sample_b, self.hypers, self.weights, @@ -134,7 +136,7 @@ def test_same_outputs_parallel_for_shmap(self): self.open_loop, ) - nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b = ( + nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, receptor_pot_b, ohc_out_b, agc_out_b, = ( carfac.run_segment_jit( self.sample_b, self.hypers, @@ -156,10 +158,13 @@ def test_same_outputs_parallel_for_shmap(self): self.assertTrue((combined_output[1][1] == nap_fibers_out_b).all()) self.assertTrue((combined_output[0][3] == bm_out_a).all()) self.assertTrue((combined_output[1][3] == bm_out_b).all()) - self.assertTrue((combined_output[0][4] == ohc_out_a).all()) - self.assertTrue((combined_output[1][4] == ohc_out_b).all()) - self.assertTrue((combined_output[0][5] == agc_out_a).all()) - self.assertTrue((combined_output[1][5] == agc_out_b).all()) + self.assertTrue((combined_output[0][4] == receptor_pot_a).all()) + self.assertTrue((combined_output[1][4] == receptor_pot_b).all()) + self.assertTrue((combined_output[0][5] == ohc_out_a).all()) + self.assertTrue((combined_output[1][5] == ohc_out_b).all()) + self.assertTrue((combined_output[0][6] == agc_out_a).all()) + self.assertTrue((combined_output[1][6] == agc_out_b).all()) + self.assertTrue( jax.tree_util.tree_all( jax.tree.map(jnp.allclose, state_out_a, combined_output[0][2])