
Serving errors: deprecated dependencies and structure error #103

Open
sjw8793 opened this issue Nov 7, 2023 · 2 comments


sjw8793 commented Nov 7, 2023

When I tried to serve LLaMA on a v3-8 TPU as suggested in the example script, I ran into some errors.

Environment

  • TPU: v3-8
  • Software: tpu-vm-base

Command

$ git clone https://github.com/young-geng/EasyLM
$ cd EasyLM
$ ./scripts/tpu_vm_setup.sh
$
$ python -m EasyLM.models.llama.llama_train \
    --mesh_dim='1,-1,1' \
    --dtype='bf16' \
    --total_steps=500 \
    --log_freq=50 \
    --load_llama_config='1b' \
    --update_llama_config='' \
    --load_dataset_state='' \
    --load_checkpoint='' \
    --save_model_freq=100 \
    --tokenizer.vocab_file='/path/to/tokenizer.model' \
    --optimizer.type='adamw' \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=1e-3 \
    --optimizer.adamw_optimizer.end_lr=1e-4 \
    --optimizer.adamw_optimizer.lr_warmup_steps=10 \
    --optimizer.adamw_optimizer.lr_decay_steps=100 \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.json_dataset.path='/path/to/dataset.jsonl' \
    --train_dataset.json_dataset.seq_length=1024 \
    --train_dataset.json_dataset.batch_size=64 \
    --train_dataset.json_dataset.tokenizer_processes=1 \
    --checkpointer.save_optimizer_state=True \
    --checkpointer.float_dtype=bf16 \
    --logger.online=False \
    --logger.output_dir="~/ellama_checkpoints/" \
|& tee $HOME/output1107_wiki.txt 
$ 
$ python -m EasyLM.models.llama.llama_serve \
    --load_llama_config='1b' \
    --load_checkpoint="params::/path/to/streaming_train_state" \
    --tokenizer.vocab_file='/path/to/tokenizer.model' \
    --mesh_dim='1,-1,1' \
    --dtype='bf16' \
    --input_length=1024 \
    --seq_length=2048 \
    --lm_server.batch_size=4 \
    --lm_server.port=8888 \
    --lm_server.pre_compile='all'

1. Deprecation warnings

ImportError: cannot import name 'soft_unicode' from 'markupsafe'
ImportError: Pandas requires version '3.0.0' or newer of 'jinja2'

These can be solved by adding two lines to tpu_requirements.txt:

markupsafe==2.0.1
jinja2~=3.0.0

DeprecationWarning: concurrency_count has been deprecated. Set the concurrency_limit directly on event listeners e.g. btn.click(fn, ..., concurrency_limit=10) or gr.Interface(concurrency_limit=10). If necessary, the total number of workers can be configured via max_threads in launch().

I was able to solve this by deleting concurrency_count=1 in serving.py, line 403.
According to the Gradio v4.0.0 changelog, concurrency_count has been removed and can be replaced with concurrency_limit. Since I don't fully understand what it is supposed to do and it defaults to 1, I simply removed it.

2. Structure error

However, after fixing the deprecation errors above, this error appears:

Error Log

I1107 06:16:48.996244 140573565926464 mesh_utils.py:260] Reordering mesh to physical ring order on single-tray TPU v2/v3.
$HOME/.local/lib/python3.8/site-packages/gradio/blocks.py:889: UserWarning: api_name user_fn already exists, using user_fn_1
  warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
$HOME/.local/lib/python3.8/site-packages/gradio/blocks.py:889: UserWarning: api_name model_fn already exists, using model_fn_1
  warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
$HOME/.local/lib/python3.8/site-packages/gradio/blocks.py:889: UserWarning: api_name model_fn already exists, using model_fn_2
  warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
Traceback (most recent call last):
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 386, in <module>
    mlxu.run(main)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 382, in main
    server.run()
  File "$HOME/EasyLM/EasyLM/serving.py", line 417, in run
    self.loglikelihood(pre_compile_data, pre_compile_data)
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 208, in loglikelihood
    loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 775, in infer_params
    return common_infer_params(pjit_info_args, *args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 88, in forward_loglikelihood
    logits = hf_model.module.apply(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/flax/linen/module.py", line 1511, in apply
    return apply(
  File "$HOME/.local/lib/python3.8/site-packages/flax/core/scope.py", line 930, in wrapper
    raise errors.ApplyScopeInvalidVariablesStructureError(variables)
jax._src.traceback_util.UnfilteredStackTrace: flax.errors.ApplyScopeInvalidVariablesStructureError: Expect the `variables` (first argument) passed to apply() to be a dict with the structure {"params": ...}, but got a dict with an extra params layer, i.e.  {"params": {"params": ... } }. You should instead pass in your dict's ["params"]. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ApplyScopeInvalidVariablesStructureError)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 386, in <module>
    mlxu.run(main)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 382, in main
    server.run()
  File "$HOME/EasyLM/EasyLM/serving.py", line 417, in run
    self.loglikelihood(pre_compile_data, pre_compile_data)
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 208, in loglikelihood
    loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 88, in forward_loglikelihood
    logits = hf_model.module.apply(

flax.errors.ApplyScopeInvalidVariablesStructureError: Expect the variables (first argument) passed to apply() to be a dict with the structure {"params": ...}, but got a dict with an extra params layer, i.e. {"params": {"params": ... } }. You should instead pass in your dict's ["params"]. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ApplyScopeInvalidVariablesStructureError)

It seems like something went wrong with the "params" loading in the load_trainstate_checkpoint function in checkpoint.py, but I couldn't figure out where.
Does anyone know what's wrong?
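The Flax error message itself pins down the mismatch: apply() wants a dict shaped {"params": ...}, but the loaded checkpoint contributes one extra "params" layer. A minimal sketch with plain dicts (the "dense"/"kernel" leaf names are made up for illustration):

```python
# Shape apply() expects when loading a bare parameters checkpoint:
expected = {"params": {"dense": {"kernel": [[1.0]]}}}

# A streaming train-state checkpoint stores the parameters one level
# deeper, so loading it with the "params::" prefix yields an extra layer:
loaded = {"params": {"params": {"dense": {"kernel": [[1.0]]}}}}

# The fix the error message suggests: pass the inner dict instead.
unwrapped = loaded["params"]
assert unwrapped == expected
```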


sjw8793 commented Nov 8, 2023

There was a misunderstanding on my part; I should have used trainstate_params instead of params in my case.
So the serving command should look like this:

$ python -m EasyLM.models.llama.llama_serve \
    --load_llama_config='1b' \
    --load_checkpoint="trainstate_params::/path/to/streaming_train_state" \
    --tokenizer.vocab_file='/path/to/tokenizer.model' \
    --mesh_dim='1,-1,1' \
    --dtype='bf16' \
    --input_length=1024 \
    --seq_length=2048 \
    --lm_server.batch_size=4 \
    --lm_server.port=8888 \
    --lm_server.pre_compile='all'
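For context, the load_checkpoint flag uses a "type::path" spec, where the prefix selects how the file is interpreted: trainstate_params extracts just the parameters from a saved train state, while params expects a bare parameters file. A minimal sketch of such a spec parser, purely illustrative and not EasyLM's actual code:

```python
def parse_checkpoint_spec(spec: str):
    """Split a 'type::path' checkpoint spec into its two parts.

    Hypothetical helper for illustration; EasyLM's real
    load_trainstate_checkpoint may parse its argument differently.
    """
    load_type, _, load_path = spec.partition("::")
    return load_type, load_path

# The working spec from this comment:
print(parse_checkpoint_spec("trainstate_params::/path/to/streaming_train_state"))
```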

@sjw8793 sjw8793 closed this as completed Nov 8, 2023

sjw8793 commented Nov 8, 2023

Sorry for reopening; I thought it would be better to keep this issue open until the dependency deprecations are resolved.

@sjw8793 sjw8793 reopened this Nov 8, 2023