Is there a requirements.txt file for whisper-jax? #167

Open
alcarazolabs opened this issue Dec 7, 2023 · 2 comments

Comments

@alcarazolabs

Hi, I'm struggling with whisper-jax. First of all, I successfully installed whisper-jax for CPU usage, but it runs very slowly.
Then I decided to install JAX with CUDA support, but a lot of problems came up because of package versions.

Here is an example; I needed to install these versions:

pip install orbax-checkpoint==0.1.8
pip install flax==0.6.4

You should post the package versions (a requirements.txt) for whisper-jax; that way the community will have a better experience with this library. It's not enough to just say "we used jax 0.4.5"; you should also post the jaxlib version, the flax version, the orbax-checkpoint version, etc.

I would appreciate it if you could share the requirements with us. Thanks in advance.
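The kind of version list requested above can be generated from any working environment using only the standard library. The sketch below is a hypothetical helper (the names `pinned_requirements` and `WHISPER_JAX_STACK` are illustrative, not part of whisper-jax); it assumes Python 3.8+ for `importlib.metadata` and prints `name==version` lines for whichever of the packages mentioned in this thread are installed.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Iterable, List

# Packages named in this thread (illustrative list, not an official one).
WHISPER_JAX_STACK = [
    "jax",
    "jaxlib",
    "flax",
    "orbax-checkpoint",
    "transformers",
    "cached-property",
]


def pinned_requirements(
    names: Iterable[str],
    lookup: Callable[[str], str] = version,
) -> List[str]:
    """Return requirements.txt-style 'name==version' lines for installed packages.

    Packages that are not installed are skipped instead of guessed. `lookup`
    defaults to importlib.metadata.version and is a parameter only so the
    function can be exercised without the real packages installed.
    """
    lines = []
    for name in names:
        try:
            lines.append(f"{name}=={lookup(name)}")
        except PackageNotFoundError:
            continue  # not installed in this environment
    return lines


if __name__ == "__main__":
    # Hypothetical usage: python report_versions.py > requirements.txt
    print("\n".join(pinned_requirements(WHISPER_JAX_STACK)))
```

A file produced this way only describes the machine it was generated on; in particular, the jax and jaxlib entries depend on that machine's hardware-specific JAX install.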

@sanchit-gandhi
Owner

Hey @alcarazolabs! The idea behind Whisper JAX is that it has minimal dependencies, solely Transformers, JAX, Flax and cached-property:

whisper-jax/setup.py, lines 24 to 28 in 9c50a6e:

_deps = [
"transformers>=4.27.4,<4.35.0",
"flax",
"cached-property",
]

Because JAX has different installation routes depending on your hardware, it's impossible to put it in a requirements.txt file in a way that respects each installation route. Instead, users are advised to install JAX following the official instructions, and then install the remaining 3 dependencies for Whisper JAX (Transformers, Flax and cached-property). If you're having issues installing JAX with CUDA, I recommend you ask on the JAX repository for more detailed support!

We don't pin the Flax version, since all you need is a version that is compatible with your JAX version. If we pinned it, the requirements would become more stringent and would not be forward compatible with new versions of JAX/Flax.
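Unlike Flax, Transformers is constrained: the `transformers>=4.27.4,<4.35.0` entry in the setup.py excerpt is a half-open range, so 4.27.4 itself is allowed and 4.35.0 is not. A minimal sketch of that check follows, assuming simple dotted release numbers; `release_tuple` and `in_half_open_range` are hypothetical helpers. Real tools (pip, or `SpecifierSet` from the `packaging` library) implement the full PEP 440 rules, including pre-releases, which this sketch does not.

```python
import re
from typing import Tuple


def release_tuple(v: str) -> Tuple[int, ...]:
    """Parse the leading dotted release number of a version string.

    Trailing zeros are dropped so that '4.35' and '4.35.0' compare equal.
    Suffixes such as '.dev0' or 'rc1' are simply ignored, which is a
    simplification of PEP 440, not a faithful implementation of it.
    """
    match = re.match(r"\d+(?:\.\d+)*", v.strip())
    if match is None:
        raise ValueError(f"not a version string: {v!r}")
    parts = [int(p) for p in match.group(0).split(".")]
    while parts and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def in_half_open_range(v: str, lower: str, upper: str) -> bool:
    """True if lower <= v < upper, i.e. what '>=lower,<upper' expresses."""
    return release_tuple(lower) <= release_tuple(v) < release_tuple(upper)


if __name__ == "__main__":
    bounds = ("4.27.4", "4.35.0")  # bounds from the setup.py excerpt
    for candidate in ["4.27.4", "4.30.2", "4.35.0"]:
        print(candidate, in_half_open_range(candidate, *bounds))
```

Tuples compare element by element, which is why stripping trailing zeros matters: without it, `(4, 35)` would sort below `(4, 35, 0)` and "4.35" would wrongly pass the `<4.35.0` bound.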

Therefore, the requirements are fully defined as required dependencies in setup.py. You can install the trio of requirements by running:

pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
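Since the JAX install route is hardware-specific, one quick way to confirm which install you ended up with is to ask JAX for its default backend. This is a sketch that assumes JAX may or may not be installed; `jax_backend_info` is a hypothetical helper, while `jax.__version__` and `jax.default_backend()` are part of JAX's public API.

```python
import importlib.util
from typing import Optional, Tuple


def jax_backend_info() -> Optional[Tuple[str, str]]:
    """Return (jax version, default backend name), or None if JAX is absent.

    A backend of 'cpu' on a machine that has a GPU suggests JAX is not using
    the GPU, for example because a CPU-only JAX build was installed.
    """
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so this check runs even without JAX

    return jax.__version__, jax.default_backend()


if __name__ == "__main__":
    info = jax_backend_info()
    if info is None:
        print("JAX is not installed; install it for your hardware first.")
    else:
        print(f"jax {info[0]}, default backend: {info[1]}")
```

Running this before installing whisper-jax separates JAX/CUDA setup problems (a question for the JAX repository) from whisper-jax dependency issues.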

@bvrockwell

Hi @sanchit-gandhi, would you be able to list your benchmark/* dependencies, please? These scripts don't seem to work anymore with naive package installs, but I could be wrong. If I simply install the additional required packages naively (e.g. pip install datasets, ...), the benchmarking scripts fail and complain about jax_array. For context: I'm using Python 3.9 and JAX 0.4.25, and I ran "pip install ." after cloning the repo. The source code itself works fine, but I'd like to run the benchmarks to validate performance. Thanks very much!
