-
Notifications
You must be signed in to change notification settings - Fork 0
/
FurGen-1.0.py
74 lines (61 loc) · 2.71 KB
/
FurGen-1.0.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import tensorflow as tf
import numpy as np
import PIL.Image
import json
import keras
from diffusers import DiffusionPipeline
from transformers import CLIPProcessor, AutoTokenizer, TFAutoModel
print("furGen-ON")

# Prompt user for API token
api_token = input("Please enter your Hugging Face API token: ")
# Original variable name, kept so anything that already reads it still works.
os.environ["HUGGINGFACE_CO_API_TOKEN"] = api_token
if api_token:
    # Also export the token under the names the Hugging Face Hub client
    # documents for authentication. Skipped for an empty answer so a blank
    # token is never exported.
    os.environ["HF_TOKEN"] = api_token
    os.environ["HUGGING_FACE_HUB_TOKEN"] = api_token

# Load the model configuration from a URL
print("getting request for config.json url")
# The Hub serves raw file contents from ".../resolve/<revision>/<path>";
# the ".../blob/..." form is the HTML viewer page and is not valid JSON.
config_url = "https://huggingface.co/lunarfish/furry-diffusion/resolve/main/unet/config.json"
cache_dir = "D:\\cache"
config_path = tf.keras.utils.get_file("config.json", config_url, cache_dir=cache_dir)
with open(config_path, encoding="utf-8") as f:
    content = f.read()
print(content)
try:
    response = json.loads(content)
except json.JSONDecodeError as err:
    # get_file may reuse a previously cached download, so a stale non-JSON
    # copy has to be removed by hand before a retry can succeed.
    raise ValueError(
        f"{config_path} does not contain valid JSON; "
        "delete the cached file and run again"
    ) from err
model_config = response

# Load the model weights from the Hugging Face model hub
print("loading model weights")
# NOTE(review): the config above comes from the repo's "unet/" subfolder,
# which looks like a diffusers-style layout, and DiffusionPipeline is imported
# but not used in the visible part of this file -- confirm that TFAutoModel
# is the right loader for this repository.
model = TFAutoModel.from_pretrained("lunarfish/furry-diffusion", config=model_config)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load the tokenizer used to encode text prompts (the progress message below
# still says "diffusion pipeline"; it is kept as-is because it is runtime output).
print("creating diffusion pipeline")
tokenizer = AutoTokenizer.from_pretrained("lunarfish/furry-diffusion")

# Progress message emitted at start-up, just before the generation function
# is defined; no image is produced at this point.
print("generating fursona")
def generate_fursona(prompt, truncation=0.5):
    """Generate an image from a text prompt with the module-level model.

    The prompt is tokenized with the module-level ``tokenizer``, passed to
    ``model.generate`` on the CPU, and the first returned sequence is min-max
    scaled to the 0-255 range and wrapped in a PIL image.

    Args:
        prompt: Text description to encode.
        truncation: Tokenizer truncation setting. ``bool`` and ``str`` values
            are forwarded unchanged; any other value (including the numeric
            default) is reduced to its truthiness, because the tokenizer only
            accepts a flag or a strategy name.

    Returns:
        A ``PIL.Image.Image`` built from the scaled ``uint8`` array.
    """
    # Generate text input
    print("generating text input:", prompt)
    # Hugging Face tokenizers take Python strings (or lists of them), so the
    # prompt is passed as-is rather than wrapped in a tf.string tensor.
    input_text = prompt

    # Convert input text to tokens
    print("converting input text into tokens")
    if not isinstance(truncation, (bool, str)):
        truncation = bool(truncation)
    input_tokens = tokenizer(
        input_text,
        truncation=truncation,
        padding='max_length',
        max_length=128,
        return_tensors='tf',
    )

    # Generate image from input tokens
    print("generating image")
    with tf.device('/cpu:0'):
        # NOTE(review): assumes the loaded model supports generate() and that
        # its first output can be interpreted as image data -- confirm.
        generated_image = model.generate(
            input_ids=input_tokens['input_ids'],
            attention_mask=input_tokens['attention_mask'],
            max_length=128,
            num_beams=1,
            no_repeat_ngram_size=2,
            early_stopping=True,
        )[0]

    # Postprocess image
    print("postprocessing", generated_image)
    pixels = generated_image.numpy()
    lowest = pixels.min()
    value_range = pixels.max() - lowest
    if value_range == 0:
        # Constant output: avoid dividing by zero and emit an all-black array.
        scaled = np.zeros(pixels.shape, dtype=np.uint8)
    else:
        scaled = ((pixels - lowest) / value_range * 255).astype(np.uint8)
    return PIL.Image.fromarray(scaled)
# Prompt user for text input
# NOTE(review): the prompt text reads like an example description rather than
# an instruction to the user -- confirm the intended wording.
prompt = input("Fox with purple fur and green eyes: ")

# Generate and save image
print("Generating fursona image...")
image = generate_fursona(prompt)
output_path = "D:\\pfp\\furGen(OUT1).png"
# Saving fails with FileNotFoundError when the target folder is missing, so
# create it first; exist_ok makes this a no-op when it already exists.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
image.save(output_path)
print("Image saved!")