More Inference Endpoints features and fixes (#68)
* feat(generator): better handle exceptions on multiprocessing

This will raise an error, signaling there was a problem. Previously, the
root thread would get stuck waiting for an agent that had died; now it
should exit instead.
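
As a rough illustration of the pattern (the worker loop below is hypothetical; only the agent_error/agent_ready events come from the diff in optimum/tpu/xla_mp_comm.py further down), the agent catches any exception, sets the shared error event, and still wakes the root up so it raises instead of blocking forever:

    def agent_loop(mailbox):
        # Hypothetical worker loop around AgentMailbox (see the diff further down);
        # only the agent_error / agent_ready events are taken from this commit.
        while True:
            try:
                command, *args = mailbox.receive()
                # ... dispatch the command (prefill, decode, ...) and fill output_data ...
                mailbox.agent_ready.set()
            except Exception:
                # Flag the failure and wake the root, so its send() call raises a
                # RuntimeError instead of waiting forever on a dead agent.
                mailbox.agent_error.set()
                mailbox.agent_ready.set()
                raise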

* feat(tgi): add more debug on server

* chore(docker): entrypoint json output is set by default

It can be disabled by setting JSON_OUTPUT_DISABLE.
It is now also possible to experiment with different batch sizes.

* feat(generator): add bucketing functions to use in prefill

* feat(generator): store position_id in current slot

This will further simplify the implementation of prefill bucketing.
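
To make the bucketing idea concrete: prefill inputs are padded up to one of a few fixed lengths, so XLA only ever compiles a small set of graph shapes. A minimal sketch of such a helper, assuming power-of-two buckets (the name and bucket choice are illustrative, not the repository's actual implementation):

    def next_bucket(length: int) -> int:
        # Round a prefill sequence length up to the nearest power-of-two bucket,
        # so that only a handful of input shapes ever reach XLA compilation.
        bucket = 1
        while bucket < length:
            bucket *= 2
        return bucket

Inputs shorter than the selected bucket are then padded, and the attention_mask marks which positions hold real tokens.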

* fix(generator): correct input_ids and attention_mask padding

* fix(TGI): fix input truncation

Truncation was sub-optimal, and it was done on the wrong side.
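
For causal generation the end of the prompt is usually what matters most, so truncation typically keeps the rightmost tokens; this sketch assumes that is the corrected behavior (the function name is illustrative, not the server's actual API):

    def truncate_input(input_ids: list[int], max_input_length: int) -> list[int]:
        # Keep only the last max_input_length tokens, dropping the oldest ones,
        # so the most recent context is preserved for generation.
        if len(input_ids) <= max_input_length:
            return input_ids
        return input_ids[-max_input_length:]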

* feat(generator): enable logs on children processes
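
Handlers configured in the parent are not inherited when child processes are spawned, so each child typically needs to set up logging itself. A hedged sketch of what that can look like (the function name is illustrative):

    import logging

    def setup_child_logging(level: int = logging.INFO) -> None:
        # Configure logging inside a freshly spawned child process; force=True
        # replaces any handlers that may already exist in this interpreter.
        logging.basicConfig(
            format="%(asctime)s %(processName)s %(levelname)s %(message)s",
            level=level,
            force=True,
        )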

* feat(tgi): warmup runs prefill/decode on all supported combinations

This will prevent XLA compilation at inference time. Note that I had to
disable dynamo compilation, though; otherwise the model did not generate
correct results. This leads to slower generation, but at least generation
now seems stable.
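
A hedged sketch of the warmup loop described above (the runner object and helper names are assumptions, not the repository's API): running one prefill and one decode for every supported (batch size, sequence length) pair forces XLA to compile each graph once before the server accepts traffic.

    def warmup(runner, batch_sizes, sequence_lengths):
        # Exercise every supported (batch size, sequence length) combination once,
        # so XLA compiles all graphs up front instead of at inference time.
        # Largest shapes first, mirroring the reversed loop order mentioned below,
        # to hit memory limits as early as possible.
        for batch_size in sorted(batch_sizes, reverse=True):
            for seq_length in sorted(sequence_lengths, reverse=True):
                batch = runner.make_dummy_batch(batch_size, seq_length)
                runner.prefill(batch)
                runner.decode(batch)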

* ci(tgi): create images when pushing to the current branch

This will allow testing Inference Endpoints (IE) before release.

* feat(tgi): reversed loop order in warmup to test memory limits earlier

* chore(ci): remove image generation for this branch
tengomucho authored Jul 6, 2024
1 parent 246fb24 commit fd29591
Showing 5 changed files with 224 additions and 60 deletions.
5 changes: 5 additions & 0 deletions optimum/tpu/xla_mp_comm.py
@@ -11,6 +11,8 @@ def __init__(self, manager: mp.Manager):
         self.root_command = manager.list()
         self.agent_ready = manager.Event()
         self.output_data = manager.list()
+        self.agent_error = manager.Event()
+        self.agent_error.clear()
 
     def send(self, command: int, *args) -> ListProxy:
         """Send a command and arguments to the agents and wait for the response.
@@ -30,6 +32,8 @@ def send(self, command: int, *args) -> ListProxy:
         self.root_bell.set()
         # wait again until agent is ready, meaning command has been processed
         self.agent_ready.wait()
+        if self.agent_error.is_set():
+            raise RuntimeError("Error on one of threads, stopping.")
         ret = self.output_data
         return ret

@@ -41,6 +45,7 @@ def __init__(self, root_mailbox: RootMailbox):
         self.root_command = root_mailbox.root_command
         self.agent_ready = root_mailbox.agent_ready
         self.output_data = root_mailbox.output_data
+        self.agent_error = root_mailbox.agent_error
 
     def receive(self) -> ListProxy:
         """Wait for a command from the root process and return it.
15 changes: 14 additions & 1 deletion text-generation-inference/docker/entrypoint.sh
100644 → 100755
@@ -5,13 +5,26 @@
 ulimit -l 68719476736
 
 # Hugging Face Hub related
+if [[ -z "${BATCH_SIZE}" ]]; then
+  BATCH_SIZE=4
+fi
+export BATCH_SIZE="${BATCH_SIZE}"
+
+if [[ -z "${JSON_OUTPUT_DISABLE}" ]]; then
+  JSON_OUTPUT_DISABLE=--json-output
+else
+  JSON_OUTPUT_DISABLE=""
+fi
+export JSON_OUTPUT_DISABLE="${JSON_OUTPUT_DISABLE}"
+
 if [[ -z "${MODEL_ID}" ]]; then
   echo "MODEL_ID must be set"
   exit 1
 fi
 export MODEL_ID="${MODEL_ID}"
 
 text-generation-launcher --port 8080 \
-  --max-batch-size 4 \
+  --max-batch-size ${BATCH_SIZE} \
+  ${JSON_OUTPUT_DISABLE} \
   --model-id ${MODEL_ID}
