Description
Several unit tests are failing:
- `RuntimeParamsSliceTest.test_wext_in_dynamic_runtime_params_cannot_be_negative` and `InterpolatedParamTest.test_interpolated_param_need_xs_to_be_sorted` both expect JAX error checks to be enabled by default, but `jax_utils._ERRORS_ENABLED` defaults to `False`.
- `RunSimulationMainTest` tests fail under pytest with `FATAL Flags parsing error: Unknown command line flag 'rootdir'` because `run_simulation_main.py` calls `jax.config.parse_flags_with_absl()` at the top level. That call parses all of `sys.argv`, including pytest's own arguments, before the tests can override flags.
- `TransportSmoothingTest` ends up with a `Union[..., FakeTransportConfig, FakeTransportConfig]` after two test classes each run:

  ```python
  ToraxConfig.model_fields['transport'].annotation |= FakeTransportConfig
  ToraxConfig.model_rebuild(force=True)
  ```

  Pydantic rejects duplicate discriminators, so we need to dedupe the union before the rebuild.
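A stdlib-only sketch of how the duplicate can arise, assuming the two test classes each define their own `FakeTransportConfig` (the classes below are hypothetical stand-ins, not TORAX code). `typing.Union` already drops a repeated identical member, so a union that still lists `FakeTransportConfig` twice must hold two distinct class objects that merely share a name:

```python
from typing import Union, get_args


class LinearTransportConfig:  # hypothetical stand-in for an existing member
    pass


def make_fake_config():
    # Each call creates a *new* class object named FakeTransportConfig,
    # mimicking two test modules that each declare their own fake.
    class FakeTransportConfig:
        pass

    return FakeTransportConfig


FakeA = make_fake_config()
FakeB = make_fake_config()

# Adding the same class twice is harmless: Union skips exact repeats.
same = Union[Union[LinearTransportConfig, FakeA], FakeA]
# Adding two distinct classes with the same name keeps both members.
distinct = Union[Union[LinearTransportConfig, FakeA], FakeB]

print([t.__name__ for t in get_args(same)])
# ['LinearTransportConfig', 'FakeTransportConfig']
print([t.__name__ for t in get_args(distinct)])
# ['LinearTransportConfig', 'FakeTransportConfig', 'FakeTransportConfig']
```

This is worth confirming before choosing how to dedupe, since removing repeats by identity cannot collapse two different class objects.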
Reproduction
Run the full test suite under pytest:
```shell
pytest -q
```
You’ll see exactly 7 failures (the three categories above).
Proposed Fixes
- In `torax/jax_utils.py`, change `_ERRORS_ENABLED: bool = env_bool('TORAX_ERRORS_ENABLED', False)` to `_ERRORS_ENABLED: bool = env_bool('TORAX_ERRORS_ENABLED', True)` so that `error_if` actually raises by default (and the tests then pass). Document that production users can still set `TORAX_ERRORS_ENABLED=False` to disable the checks and re-enable the JAX compilation cache.
- In `run_simulation_main.py`, remove the top-level `jax.config.parse_flags_with_absl()` call and instead mark the flags as parsed immediately, so absl never tries to parse pytest's command-line arguments:

  ```python
  from absl import flags

  flags.FLAGS.mark_as_parsed()
  ```

  Alternatively, wrap the parse so it only runs when the module is invoked as a script, not when it is imported.
- In `torax/torax_pydantic/model_config.py` (inside `ToraxConfig.model_rebuild`), detect and remove duplicate types in the `transport` union before calling the original rebuild. For example:

  ```python
  from typing import Union, get_args

  # Before rebuilding: drop repeated members of the transport union.
  field = ToraxConfig.model_fields['transport']
  unique = tuple(dict.fromkeys(get_args(field.annotation)))  # dedupe, preserving order
  field.annotation = Union[unique]
  # ...then call the original model_rebuild(...).
  ```
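To illustrate the first fix, a minimal sketch of how an environment-variable flag with a `True` default behaves. The `env_bool` here is a hypothetical re-implementation for illustration only; TORAX's own helper may accept different spellings:

```python
import os


def env_bool(name: str, default: bool) -> bool:
    """Hypothetical stand-in: reads a boolean flag from the environment."""
    value = os.environ.get(name)
    if value is None:  # unset: fall back to the default in the source
        return default
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 'yes', 'on'):
        return True
    if normalized in ('0', 'false', 'no', 'off'):
        return False
    raise ValueError(f'{name}={value!r} is not a recognised boolean')


os.environ.pop('TORAX_ERRORS_ENABLED', None)
print(env_bool('TORAX_ERRORS_ENABLED', True))  # True: new default applies
os.environ['TORAX_ERRORS_ENABLED'] = 'False'
print(env_bool('TORAX_ERRORS_ENABLED', True))  # False: explicit opt-out
```

Flipping only the default keeps the opt-out path unchanged: anyone who sets the variable explicitly gets the same behavior as before.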
Tests
All existing tests should pass. Apart from enabling error checks by default, there is no behavior change beyond fixing the test suite.
Let me know if you’d like any tweaks before I PR it!