
Commit bb73e15

Merge pull request #292 from appwrite/feat-generate-with-tensorflow
feat: python generate with tensorflow
2 parents: 882f67b + cdd4c06

File tree

7 files changed: +474 -0 lines changed
.gitignore
Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
README.md
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
# 🤖 Python Generate with TensorFlow Function

Generate text using a TensorFlow-based RNN model.

## 🧰 Usage

### GET /

Returns an HTML form for interacting with the function.

### POST /

Query the model for a text generation completion.

**Parameters**

| Name         | Description                           | Location | Type               | Sample Value       |
| ------------ | ------------------------------------- | -------- | ------------------ | ------------------ |
| Content-Type | The content type of the request body  | Header   | `application/json` | N/A                |
| prompt       | Text to prompt the model              | Body     | String             | `Once upon a time` |

Sample `200` Response:

Response from the model.

```json
{
  "ok": true,
  "completion": "Once upon a time, in a land far, far away, there lived a wise old owl."
}
```

Sample `400` Response:

Response when the prompt is missing from the request body.

```json
{
  "ok": false,
  "error": "Missing required fields: prompt"
}
```

Sample `500` Response:

Response when the model fails to respond.

```json
{
  "ok": false,
  "error": "Failed to query model."
}
```

## ⚙️ Configuration

| Setting           | Value                                                     |
| ----------------- | --------------------------------------------------------- |
| Runtime           | Python ML (3.11)                                          |
| Entrypoint        | `src/main.py`                                             |
| Build Commands    | `pip install -r requirements.txt && python src/train.py`  |
| Permissions       | `any`                                                     |
| Timeout (Seconds) | 30                                                        |
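
For reference, a minimal sketch of querying the deployed function from Python's standard library. The URL is a placeholder, not part of this template; substitute your own function's domain:

```python
import json
import urllib.request

# Placeholder URL; replace with the domain of your deployed function.
URL = "https://<FUNCTION_ID>.appwrite.global/"

payload = json.dumps({"prompt": "Once upon a time"}).encode("utf-8")
request = urllib.request.Request(
    URL,
    data=payload,
    headers={"Content-Type": "application/json"},
    method="POST",
)

# A 200 response carries {"ok": true, "completion": "..."}.
with urllib.request.urlopen(request) as response:
    body = json.loads(response.read().decode("utf-8"))

print(body["completion"])
```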
requirements.txt
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
tensorflow
numpy
src/main.py
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
import tensorflow as tf
import numpy as np
from .utils import get_static_file, throw_if_missing


def main(context):
    if context.req.method == "GET":
        return context.res.send(
            get_static_file("index.html"),
            200,
            {"content-type": "text/html; charset=utf-8"},
        )

    try:
        throw_if_missing(context.req.body, ["prompt"])
    except ValueError as err:
        return context.res.json({"ok": False, "error": str(err)}, 400)

    prompt = context.req.body["prompt"]
    generated_text = generate_text(prompt)
    return context.res.json({"ok": True, "completion": generated_text}, 200)


def generate_text(prompt):
    # Load the trained model and the character lookup tables saved by train.py
    model = tf.keras.models.load_model("text_generation_model.h5")
    char2idx = np.load("char2idx.npy", allow_pickle=True).item()
    idx2char = np.load("idx2char.npy", allow_pickle=True)

    # Vectorize the prompt, skipping characters absent from the training
    # vocabulary (indexing them directly would raise a KeyError)
    input_eval = [char2idx[s] for s in prompt if s in char2idx]
    input_eval = tf.expand_dims(input_eval, 0)

    # Generate 1000 characters, feeding each prediction back in as the next input
    text_generated = []
    temperature = 1.0

    model.reset_states()
    for _ in range(1000):
        predictions = model(input_eval)
        predictions = tf.squeeze(predictions, 0)
        predictions = predictions / temperature
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

        input_eval = tf.expand_dims([predicted_id], 0)
        text_generated.append(idx2char[predicted_id])

    return prompt + "".join(text_generated)
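
The sampling step above draws the next character id from temperature-scaled logits rather than taking the argmax, which keeps the output from collapsing into repetitive loops. A minimal, self-contained sketch of that step, using dummy logits for a hypothetical five-character vocabulary:

```python
import tensorflow as tf

# Dummy logits for one timestep over a hypothetical 5-character vocabulary.
logits = tf.constant([[2.0, 0.5, 0.1, 0.1, 0.1]])

# Dividing the logits by the temperature before sampling: values below 1.0
# sharpen the distribution, values above 1.0 flatten it.
for temperature in (0.5, 1.0, 2.0):
    predicted_id = tf.random.categorical(logits / temperature, num_samples=1)[-1, 0].numpy()
    print(f"temperature={temperature}: sampled id {predicted_id}")
```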
src/train.py
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
import tensorflow as tf
import numpy as np
import os


def main():
    # Download the Shakespeare corpus used as training data
    path_to_file = tf.keras.utils.get_file(
        "shakespeare.txt",
        "https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt",
    )
    text = open(path_to_file, "rb").read().decode(encoding="utf-8")

    # Build the character-level vocabulary and lookup tables
    vocab = sorted(set(text))
    char2idx = {u: i for i, u in enumerate(vocab)}
    idx2char = np.array(vocab)

    # Window the corpus into sequences of seq_length + 1 characters
    text_as_int = np.array([char2idx[c] for c in text])
    seq_length = 100
    char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
    sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)

    def split_input_target(chunk):
        # Input is the sequence minus its last character; target is the
        # same sequence shifted one character to the right
        input_text = chunk[:-1]
        target_text = chunk[1:]
        return input_text, target_text

    dataset = sequences.map(split_input_target)
    BATCH_SIZE = 64
    BUFFER_SIZE = 10000
    dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

    vocab_size = len(vocab)
    embedding_dim = 256
    rnn_units = 1024

    # Embedding -> GRU -> Dense producing per-character logits over the vocabulary
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Embedding(
                vocab_size, embedding_dim, batch_input_shape=[BATCH_SIZE, None]
            ),
            tf.keras.layers.GRU(
                rnn_units,
                return_sequences=True,
                stateful=True,
                recurrent_initializer="glorot_uniform",
            ),
            tf.keras.layers.Dense(vocab_size),
        ]
    )

    model.compile(
        optimizer="adam", loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True)
    )

    EPOCHS = 10
    checkpoint_dir = "./training_checkpoints"

    os.makedirs(checkpoint_dir, exist_ok=True)

    # Keras substitutes the epoch number for the literal {epoch} at save time
    checkpoint_prefix = f"{checkpoint_dir}/ckpt_{{epoch}}"

    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_prefix, save_weights_only=True
    )

    model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])

    # Export the model and the character lookup tables for inference in main.py
    model.save("text_generation_model.h5")
    np.save("char2idx.npy", char2idx)
    np.save("idx2char.npy", idx2char)

    os.remove(path_to_file)


if __name__ == "__main__":
    main()
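
The `split_input_target` windowing above pairs each 101-character sequence with its one-step-shifted copy, so the model learns to predict the next character at every position. In the pipeline the chunks are integer tensors, but plain-string slicing illustrates the same shift:

```python
def split_input_target(chunk):
    input_text = chunk[:-1]
    target_text = chunk[1:]
    return input_text, target_text

# For each position i, target[i] is the character that follows input[i].
input_text, target_text = split_input_target("First Citizen")
print(input_text)   # First Citize
print(target_text)  # irst Citizen
```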
src/utils.py
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import os

__dirname = os.path.dirname(os.path.abspath(__file__))
static_folder = os.path.join(__dirname, "../static")


def get_static_file(file_name: str) -> str:
    """
    Returns the contents of a file in the static folder

    Parameters:
        file_name (str): Name of the file to read

    Returns:
        (str): Contents of static/{file_name}
    """
    file_path = os.path.join(static_folder, file_name)
    with open(file_path, "r") as file:
        return file.read()


def throw_if_missing(obj: object, keys: list[str]) -> None:
    """
    Throws an error if any of the keys are missing from the object

    Parameters:
        obj (object): Object to check
        keys (list[str]): List of keys to check

    Raises:
        ValueError: If any keys are missing
    """
    missing = [key for key in keys if key not in obj or not obj[key]]
    if missing:
        raise ValueError(f"Missing required fields: {', '.join(missing)}")
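
A quick illustration of the validation helper (assuming `throw_if_missing` is in scope): note that an empty string also counts as missing because of the `not obj[key]` check, so blank prompts are rejected the same way as absent ones.

```python
body = {"prompt": ""}

try:
    throw_if_missing(body, ["prompt"])
except ValueError as err:
    print(err)  # Missing required fields: prompt
```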
