-
Notifications
You must be signed in to change notification settings - Fork 0
/
petals.py
86 lines (72 loc) · 2.85 KB
/
petals.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""
Petals Chat WebSocket Client in Python
This script demonstrates how to interact with the Petals Chat WebSocket API for text generation.
It opens an inference session and generates text based on a given prompt.
For more details and to set up your own backend, please visit: https://github.com/petals-infra/chat.petals.dev
"""
import json
import websocket
import time
# Module-level settings read by the WebSocket callbacks.
# Sent as "max_length" in both the open_inference_session and generate
# requests (presumably the token budget for prompt + output — confirm against
# the Petals chat API docs).
max_length: int = 150
# Set to a string to have generation stop at that text. The value is forwarded
# unchanged in the generate request, so None is serialized as JSON null.
stop_sequence: "str | None" = None
def on_open(ws):
    """Callback for the WebSocket ``open`` event.

    Reads the user's prompt from stdin into the module-level ``prompt``
    (which the message handler later sends with the generate request), then
    asks the server to open an inference session for the StableBeluga2 model.

    Args:
        ws: The WebSocket connection that just opened.
    """
    global prompt
    print("WebSocket opened. Opening inference session...")
    prompt = input("Enter your prompt: ")
    session_request = {
        "type": "open_inference_session",
        "model": "stabilityai/StableBeluga2",
        "max_length": max_length,
    }
    ws.send(json.dumps(session_request))
def _request_generation(ws):
    """Send the ``generate`` request for the user's prompt over *ws*.

    Called once the server has confirmed the inference session is open.
    Reads the module-level ``prompt`` (set by ``on_open``), ``max_length``
    and ``stop_sequence``.
    """
    print("Inference session opened. Generating text...")
    # Purely cosmetic progress dots; note this blocks the callback for ~3 s.
    for _ in range(3):
        print(".", end="", flush=True)
        time.sleep(1)
    print()
    # Show the prompt once here. Output replies are printed without it, so a
    # generation delivered over several messages does not repeat the prompt.
    print(f"Prompt: {prompt}")
    ws.send(json.dumps({
        "type": "generate",
        "inputs": prompt,
        "max_length": max_length,
        "do_sample": 1,
        "temperature": 0.6,
        "top_p": 0.9,
        "stop_sequence": stop_sequence,  # Optional
    }))


def on_message(ws, event):
    """Handle one message received on the WebSocket.

    Dispatches on the decoded JSON reply:

    * ``ok`` falsy: print the server's ``traceback`` (or "Unknown error")
      and close the socket.
    * ``ok`` with no ``outputs``: the inference session is open, so send the
      ``generate`` request.
    * ``ok`` with ``outputs``: print the newly generated text. Close the
      socket once a reply carries a truthy ``stop``; otherwise keep waiting
      for further output messages.

    A payload that is not valid JSON, or any other exception raised while
    handling the message, is reported and the socket is closed.

    Args:
        ws: The ``websocket.WebSocketApp`` that received the message.
        event: The raw message payload, decoded with ``json.loads``.
    """
    print("Received a message.")
    try:
        response = json.loads(event)
        if not response.get('ok'):
            print(f"Server responded with an error: {response.get('traceback', 'Unknown error')}")
            ws.close()
        elif 'outputs' not in response:
            _request_generation(ws)
        else:
            # Only the new text: the prompt was printed once when generation
            # was requested, so repeating it per message would garble output.
            print(f"Generated: {response['outputs']}")
            if response.get('stop', False):
                print("Stopping generation.")
                ws.close()
    except json.JSONDecodeError:
        print("Failed to decode the received message as JSON.")
        ws.close()
    except Exception as e:
        # Top-level callback boundary: report anything unexpected and close
        # the connection rather than leave it open in an unknown state.
        print(f"An unexpected error occurred: {e}")
        ws.close()
def on_error(ws, error):
    """Callback for WebSocket errors: report *error* on stdout.

    The connection is not closed here.
    """
    message = "An error occurred: {}".format(error)
    print(message)
def on_close(ws, close_status_code, close_msg):
    """Callback for the WebSocket close event: report the close code and reason."""
    summary = "WebSocket closed with code: {}, message: {}".format(
        close_status_code, close_msg
    )
    print(summary)
if __name__ == "__main__":
    # Wire the open/message/error/close callbacks into a client for the
    # public Petals chat endpoint, then run the client's event loop.
    ws = websocket.WebSocketApp(
        "wss://chat.petals.dev/api/v2/generate",
        on_open=on_open,
        on_message=on_message,
        on_error=on_error,
        on_close=on_close,
    )
    ws.run_forever()