# rudall-e / app.py
# Author: anton-l (HF staff)
# Last commit: ea2efae "Update app.py"
import html
import random
from functools import partialmethod

import torch
import numpy as np
from tqdm import tqdm
import gradio as gr
from gradio.mix import Series
from transformers import pipeline, FSMTForConditionalGeneration, FSMTTokenizer
from rudalle.pipelines import generate_images
from rudalle import get_rudalle_model, get_tokenizer, get_vae
# Disable tqdm progress bars from the rudalle pipeline by making disable=True the
# default for every tqdm instance. This patches the tqdm class itself, so it
# affects all tqdm users in this process, not only rudalle.
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is only enabled on CUDA: fp16 inference on CPU is unsupported
# or impractical for these models, which would defeat the CPU fallback above.
use_fp16 = device.type == "cuda"

# English -> Russian translator used to adapt prompts for the Russian-only model.
translation_model = FSMTForConditionalGeneration.from_pretrained(
    "facebook/wmt19-en-ru",
    torch_dtype=torch.float16 if use_fp16 else torch.float32,
).to(device)
translation_tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-en-ru")

# ruDALL-E "Malevich" text-to-image model plus its tokenizer and VAE.
dalle = get_rudalle_model("Malevich", pretrained=True, fp16=use_fp16, device=device)
tokenizer = get_tokenizer()
vae = get_vae().to(device)
def translation_wrapper(text: str) -> str:
    """Translate a prompt from English to Russian with the FSMT wmt19-en-ru model.

    Args:
        text: Prompt typed by the user.

    Returns:
        The decoded first generated sequence, with special tokens removed.
    """
    input_ids = translation_tokenizer.encode(text, return_tensors="pt")
    outputs = translation_model.generate(input_ids.to(device))
    # Generated token ids are integers; decode them directly. Casting them to
    # float would only work by accident (float/int key equality) and can lose
    # precision for large ids.
    return translation_tokenizer.decode(outputs[0], skip_special_tokens=True)
def dalle_wrapper(prompt: str):
    """Generate a single ruDALL-E image for ``prompt``.

    On each call, one (top_k, top_p) sampling preset is picked at random
    from a fixed list.

    Args:
        prompt: Text prompt; in this app it is the translator's output.

    Returns:
        A ``(title, image)`` pair: an HTML snippet with the prompt in bold,
        and the first image returned by ``generate_images``.
    """
    top_k, top_p = random.choice([
        (1024, 0.98),
        (512, 0.97),
        (384, 0.96),
    ])
    images, _ = generate_images(
        prompt,
        tokenizer,
        dalle,
        vae,
        top_k=top_k,
        images_num=1,
        top_p=top_p,
    )
    # The title is shown through an HTML output component, so the user-derived
    # prompt must be escaped to prevent markup/script injection.
    title = f"<b>{html.escape(prompt)}</b>"
    return title, images[0]
# First stage: free-form text in, translated text out.
translator = gr.Interface(
    fn=translation_wrapper,
    inputs=[gr.inputs.Textbox(label='What would you like to see?')],
    outputs="text",
)

# Second stage: text in, two outputs out (HTML title, generated image).
outputs = [gr.outputs.HTML(label=""), gr.outputs.Image(label="")]
generator = gr.Interface(
    fn=dalle_wrapper,
    inputs="text",
    outputs=outputs,
)
# Blurb displayed under the demo title.
description = " ".join((
    "ruDALL-E is a 1.3B params text-to-image model by SberAI (links at the bottom).",
    "This demo uses an English-Russian translation model to adapt the prompts.",
    "Try pressing [Submit] multiple times to generate new images!",
))

# Footer with links to the project, separated by " | ".
article = "<p style='text-align: center'>{}</p>".format(
    " | ".join((
        "<a href='https://github.com/sberbank-ai/ru-dalle'>GitHub</a>",
        "<a href='https://habr.com/ru/company/sberbank/blog/586926/'>Article (in Russian)</a>",
    ))
)

# Clickable example prompts, one single-input row each. The second prompt is
# already in Russian ("A city in cyberpunk style").
examples = [
    [example_prompt]
    for example_prompt in (
        "A still life of grapes and a bottle of wine",
        "Город в стиле киберпанк",
        "A colorful photo of a coral reef",
        "A white cat sitting in a cardboard box",
    )
]
# Chain the two interfaces: the translator's text output feeds the generator.
series = Series(
    translator,
    generator,
    title='Kinda-English ruDALL-E',
    description=description,
    article=article,
    layout='horizontal',
    theme='huggingface',
    examples=examples,
    allow_flagging=False,
    live=False,
    enable_queue=True,
)

# Only start the web server when run as a script, not when imported.
if __name__ == "__main__":
    series.launch()