
Fix defaults and rebuild issues: enable error_if by default, isolate absl flag parsing in the CLI, dedupe transport config union #1067

@AlankritVerma01

Description


Several unit tests are failing:

  1. RuntimeParamsSliceTest.test_wext_in_dynamic_runtime_params_cannot_be_negative and
    InterpolatedParamTest.test_interpolated_param_need_xs_to_be_sorted both expect JAX error checks (error_if) to be enabled by default, but jax_utils._ERRORS_ENABLED defaults to False.

  2. RunSimulationMainTest tests fail under pytest with

    FATAL Flags parsing error: Unknown command line flag 'rootdir'

    because run_simulation_main.py calls jax.config.parse_flags_with_absl() at import time, which parses all of sys.argv (including pytest's own arguments, here rootdir) before the tests can override any flags.

  3. TransportSmoothingTest ends up with a Union[..., FakeTransportConfig, FakeTransportConfig] after two test classes each do

    ToraxConfig.model_fields['transport'].annotation |= FakeTransportConfig
    ToraxConfig.model_rebuild(force=True)

    Pydantic rejects the duplicate discriminator values, so the union must be deduplicated before the rebuild.
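
Because typing.Union already drops repeated members that are the same object, a union that prints FakeTransportConfig twice most likely holds two distinct classes that share a name (for example, one defined in each test module). The stdlib-only sketch below illustrates this. ConstantTransportConfig and make_fake are hypothetical stand-ins, not TORAX code, and deduplicating by class name is one possible approach rather than a tested fix.

```python
from typing import Union, get_args


class ConstantTransportConfig:
    """Hypothetical stand-in for a real transport config class."""


def make_fake():
    # Each call creates a new class object, like two test modules that each
    # define their own FakeTransportConfig.
    class FakeTransportConfig:
        pass

    return FakeTransportConfig


FakeA = make_fake()
FakeB = make_fake()

# typing.Union drops repeats of the *same* class object...
same = Union[ConstantTransportConfig, FakeA, FakeA]
# ...but keeps distinct class objects that merely share a name.
distinct = Union[ConstantTransportConfig, FakeA, FakeB]


def dedupe_by_name(annotation):
    """Collapse union members that share a class name.

    The most recently seen class wins, but it keeps the position of the
    first class with that name.
    """
    members = {}
    for t in get_args(annotation):
        members[t.__name__] = t  # later definitions overwrite earlier ones
    return Union[tuple(members.values())]


deduped = dedupe_by_name(distinct)
print(len(get_args(same)), len(get_args(distinct)), len(get_args(deduped)))
# prints: 2 3 2
```

Identity-based deduplication such as dict.fromkeys leaves `distinct` unchanged, which is why this sketch keys on `__name__` instead.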

Reproduction
Run the full test suite under pytest:

pytest -q

You’ll see exactly 7 failures (the three categories above).

Proposed Fixes

  1. In torax/jax_utils.py, change

    _ERRORS_ENABLED: bool = env_bool('TORAX_ERRORS_ENABLED', False)

    to

    _ERRORS_ENABLED: bool = env_bool('TORAX_ERRORS_ENABLED', True)

    so that error_if raises by default and the affected tests pass. Also document that production users can still set TORAX_ERRORS_ENABLED=False to disable these checks and re-enable the JAX compilation cache.

  2. In run_simulation_main.py, remove the top-level jax.config.parse_flags_with_absl() call and mark the flags as parsed immediately, so pytest’s arguments are never parsed as absl flags:

    from absl import flags
    flags.FLAGS.mark_as_parsed()

    (or wrap the parse so it only runs when invoked as a script, not on import).

  3. In torax/torax_pydantic/model_config.py (inside ToraxConfig.model_rebuild), detect and remove duplicate types in the transport Union before calling the original rebuild. For example:

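    # before rebuild:
    from typing import Union, get_args

    field = ToraxConfig.model_fields['transport']
    # typing.Union already drops repeats of the same class object, so remaining
    # duplicates are distinct classes that share a name (e.g. one
    # FakeTransportConfig per test module). Key on the class name, keeping the
    # most recently added class in the first-seen position.
    unique = {t.__name__: t for t in get_args(field.annotation)}
    field.annotation = Union[tuple(unique.values())]
    # then call the original model_rebuild(...)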
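
The alternative in fix 2 (parsing flags only when the module runs as a script) can be illustrated with a stdlib-only analog. This sketch uses argparse instead of absl, and the --config flag is hypothetical; in run_simulation_main.py, the jax.config.parse_flags_with_absl() call would move under the same `if __name__ == '__main__':` guard.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical flag; run_simulation_main.py defines its own absl flags.
    parser = argparse.ArgumentParser(prog='run_simulation_main')
    parser.add_argument('--config', default=None, help='Path to a config file.')
    return parser


def main(argv=None) -> argparse.Namespace:
    # Flags are parsed here, when the entry point runs, not at module import.
    # A test runner that imports this module with its own argv therefore
    # never has its arguments fed to this parser.
    return build_parser().parse_args(argv)


if __name__ == '__main__':
    main()
```

Importing the module defines the parser but parses nothing. Calling `main(['--rootdir', '/tmp'])` directly still exits with status 2, which is the same kind of failure pytest triggers when parsing runs at import time.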

Tests
All existing tests should pass. No new behavior change beyond fixing the test suite.


Let me know if you’d like any tweaks before I PR it!
