| import ctranslate2 | |
| from transformers import AutoTokenizer | |
| import threading | |
| import gradio as gr | |
| from typing import Optional | |
| from queue import Queue | |
class TokenIteratorStreamer:
    """Thread-safe bridge between a producer thread pushing token ids and a
    consumer iterating them.

    The producer calls ``put`` for every generated token id; the consumer
    iterates the streamer and stops automatically when ``end_token_id`` is
    dequeued (the end token itself is never yielded).
    """

    def __init__(self, end_token_id: int, timeout: Optional[float] = None):
        # Token id that marks end-of-stream for the consumer.
        self.end_token_id = end_token_id
        # Unbounded FIFO shared between producer and consumer threads.
        self.queue = Queue()
        # Optional blocking timeout applied to both put() and get().
        self.timeout = timeout

    def put(self, token: int):
        """Enqueue one token id for the consumer."""
        self.queue.put(token, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        """Block until the next token id arrives; stop at the end token."""
        next_token = self.queue.get(timeout=self.timeout)
        if next_token == self.end_token_id:
            raise StopIteration()
        return next_token
def generate_prompt(history):
    """Render the chat history into the model's prompt format and tokenize it.

    Every completed turn becomes ``<human>: ...\n<bot>: ...\n``; the final
    (pending) user message is appended with an open ``<bot>:`` tag so the
    model continues from there. Returns the prompt as a list of tokens
    suitable for CTranslate2.
    """
    completed_turns = [
        f"<human>: {turn[0]}\n<bot>: {turn[1]}\n" for turn in history[:-1]
    ]
    prompt = "".join(completed_turns) + f"<human>: {history[-1][0]}\n<bot>:"
    # CTranslate2 expects token strings, not ids.
    return tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))
def generate(streamer, history):
    """Run greedy generation for the chat history, streaming token ids.

    Intended to run on a worker thread: each step's token id is pushed to
    ``streamer`` as it is produced, and the stream is always terminated with
    ``end_token_id`` so the consumer's iteration ends.

    Args:
        streamer: TokenIteratorStreamer receiving token ids.
        history: Gradio chat history, list of [user, bot] pairs.

    Returns:
        The CTranslate2 translate_batch results.
    """
    def step_result_callback(result):
        # Push each token to the consumer as soon as it is generated.
        streamer.put(result.token_id)
        if result.is_last and (result.token_id != end_token_id):
            # Model stopped without emitting the end token (e.g. hit
            # max_decoding_length) — terminate the stream explicitly.
            streamer.put(end_token_id)
        print(f"step={result.step}, batch_id={result.batch_id}, token={result.token}")

    tokens = generate_prompt(history)
    try:
        results = translator.translate_batch(
            [tokens],
            beam_size=1,  # callback-based streaming requires greedy search
            max_decoding_length=256,
            repetition_penalty=1.8,
            callback=step_result_callback,
        )
    except Exception:
        # Bug fix: if generation fails before the last step, the sentinel was
        # never enqueued and the consumer would block forever on queue.get()
        # (no timeout is configured). Unblock it, then surface the error.
        streamer.put(end_token_id)
        raise
    return results
# Module-level setup (runs on import): load the converted CTranslate2 model
# from the local "model" directory using 2 intra-op threads.
translator = ctranslate2.Translator("model", intra_threads=2)
# Tokenizer matching the checkpoint the model was converted from.
tokenizer = AutoTokenizer.from_pretrained("DKYoon/mt5-xl-lm-adapt")
# End-of-sequence marker; its first encoded id is used as the stream sentinel.
end_token = "</s>"
end_token_id = tokenizer.encode(end_token)[0]
# Gradio chat UI: a textbox submits the user's turn, then the bot's reply is
# streamed into the chatbot widget token by token.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Append the user's turn with an empty bot slot; clear the textbox."""
        return "", history + [[user_message, ""]]

    def bot(history):
        """Generator: run generation on a worker thread and yield the history
        with the partial bot message re-decoded after every streamed token."""
        bot_message_tokens = []
        streamer = TokenIteratorStreamer(end_token_id = end_token_id)
        # Generation runs on its own thread so this generator can consume
        # tokens and yield UI updates while decoding is still in progress.
        generation_thread = threading.Thread(target=generate, args=(streamer, history))
        generation_thread.start()
        for token in streamer:
            bot_message_tokens.append(token)
            # Re-decode the full token list each step so multi-token
            # characters/words render correctly as they complete.
            history[-1][1] = tokenizer.decode(bot_message_tokens)
            yield history
        generation_thread.join()

    # On submit: update history immediately (queue=False), then stream the
    # bot's reply into the chatbot component.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

# Queueing is required for generator (streaming) event handlers.
demo.queue()
if __name__ == "__main__":
    demo.launch()