# app.py
import os
from pathlib import Path

import torch
from threading import Event, Thread
from typing import List, Tuple

# Importing necessary packages
from transformers import AutoConfig, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from langchain_community.tools import DuckDuckGoSearchRun
from optimum.intel.openvino import OVModelForCausalLM
import openvino as ov
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams

from gradio_helper import make_demo  # UI logic import
from llm_config import SUPPORTED_LLM_MODELS
# Model configuration setup
max_new_tokens = 256
model_language_value = "English"
model_id_value = "qwen2.5-0.5b-instruct"
prepare_int4_model_value = True
enable_awq_value = False
device_value = "CPU"
model_to_run_value = "INT4"

pt_model_id = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]["model_id"]
pt_model_name = model_id_value.split("-")[0]
int4_model_dir = Path(model_id_value) / "INT4_compressed_weights"
int4_weights = int4_model_dir / "openvino_model.bin"

model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
model_name = model_configuration["model_id"]
start_message = model_configuration["start_message"]
history_template = model_configuration.get("history_template")
has_chat_template = model_configuration.get("has_chat_template", history_template is None)
current_message_template = model_configuration.get("current_message_template")
stop_tokens = model_configuration.get("stop_tokens")
tokenizer_kwargs = model_configuration.get("tokenizer_kwargs", {})
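# The INT4 weights are expected to already exist on disk (e.g. exported beforehand with
# `optimum-cli export openvino --model <hf-model-id> --weight-format int4 <output-dir>`).
# Optional sanity check (sketch):
# assert int4_weights.exists(), f"Missing {int4_weights}; export the INT4 model first."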
# Model loading
core = ov.Core()
ov_config = {
    hints.performance_mode(): hints.PerformanceMode.LATENCY,
    streams.num(): "1",
    props.cache_dir(): "",
}
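# LATENCY mode with a single inference stream favors fast single-request responses,
# which suits an interactive chatbot; hints.PerformanceMode.THROUGHPUT is the alternative
# for batched/offline workloads. The empty cache_dir string means no cache directory is set.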
tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
ov_model = OVModelForCausalLM.from_pretrained(
    int4_model_dir,
    device=device_value,
    ov_config=ov_config,
    config=AutoConfig.from_pretrained(int4_model_dir, trust_remote_code=True),
    trust_remote_code=True,
)
# Define stopping criteria for specific token sequences
class StopOnTokens(StoppingCriteria):
    def __init__(self, token_ids):
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return any(input_ids[0][-1] == stop_id for stop_id in self.token_ids)


if stop_tokens is not None:
    if isinstance(stop_tokens[0], str):
        stop_tokens = tok.convert_tokens_to_ids(stop_tokens)
    stop_tokens = [StopOnTokens(stop_tokens)]
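# Illustrative example: for a chat model whose template ends each turn with "<|im_end|>",
# stop_tokens in llm_config might be ["<|im_end|>"]; the branch above converts the strings
# to token ids before wrapping them in StopOnTokens.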
# Helper function for partial text update
def default_partial_text_processor(partial_text: str, new_text: str) -> str:
    return partial_text + new_text


text_processor = model_configuration.get("partial_text_processor", default_partial_text_processor)
# Convert conversation history to tokens based on model template
def convert_history_to_token(history: List[Tuple[str, str]]):
    if pt_model_name == "baichuan2":
        system_tokens = tok.encode(start_message)
        history_tokens = []
        for old_query, response in history[:-1]:
            round_tokens = [195] + tok.encode(old_query) + [196] + tok.encode(response)
            history_tokens = round_tokens + history_tokens
        input_tokens = system_tokens + history_tokens + [195] + tok.encode(history[-1][0]) + [196]
        input_token = torch.LongTensor([input_tokens])
    elif history_template is None or has_chat_template:
        messages = [{"role": "system", "content": start_message}]
        for idx, (user_msg, model_msg) in enumerate(history):
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})
        input_token = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt")
    else:
        text = start_message + "".join(
            [history_template.format(num=round, user=item[0], assistant=item[1]) for round, item in enumerate(history[:-1])]
        )
        text += current_message_template.format(num=len(history) + 1, user=history[-1][0], assistant=history[-1][1])
        input_token = tok(text, return_tensors="pt", **tokenizer_kwargs).input_ids
    return input_token
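# Illustrative example of the chat-template branch: history = [("Hi", "Hello!"), ("What is OpenVINO?", "")]
# yields messages = [system, user("Hi"), assistant("Hello!"), user("What is OpenVINO?")],
# and apply_chat_template then appends the generation prompt for the assistant's reply.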
# Initialize search tool
search = DuckDuckGoSearchRun()


# Determine if a search is needed based on the query
def should_use_search(query: str) -> bool:
    search_keywords = [
        "latest", "news", "update", "which", "who", "what", "when", "why", "how", "recent", "current",
        "announcement", "bulletin", "report", "brief", "insight", "disclosure",
        "release", "memo", "headline", "ongoing", "fresh", "upcoming", "immediate",
        "recently", "new", "now", "in-progress", "inquiry", "query", "ask", "investigate",
        "explore", "seek", "clarify", "confirm", "discover", "learn", "describe", "define",
        "illustrate", "outline", "interpret", "expound", "detail", "summarize", "elucidate",
        "break down", "outcome", "effect", "consequence", "finding", "achievement", "conclusion",
        "product", "performance", "resolution",
    ]
    return any(keyword in query.lower() for keyword in search_keywords)
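# Note: this is a plain substring check, so e.g. "which" also matches "sandwiches".
# A stricter word-boundary variant (sketch):
# import re
# return any(re.search(rf"\b{re.escape(k)}\b", query.lower()) for k in search_keywords)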
# Construct the prompt with optional search context
def construct_model_prompt(user_query: str, search_context: str, history: List[Tuple[str, str]]) -> str:
    # history is accepted for signature compatibility but is not folded into this prompt;
    # the non-search path tokenizes it via convert_history_to_token instead.
    instructions = (
        "Use the information below if relevant to provide an accurate and concise answer. "
        "If no information is available, rely on your general knowledge."
    )
    prompt = f"{instructions}\n\n{search_context}\n\n{user_query}\n"
    return prompt
# Fetch search results for a query
def fetch_search_results(query: str) -> str:
    search_results = search.invoke(query)
    print("Search results:", search_results)  # Optional: debugging output
    return f"Relevant and recent information:\n{search_results}"
# Main chatbot function
def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    user_query = history[-1][0]
    search_context = fetch_search_results(user_query) if should_use_search(user_query) else ""
    prompt = construct_model_prompt(user_query, search_context, history)
    input_ids = (
        tok(prompt, return_tensors="pt", truncation=True, max_length=2500).input_ids
        if search_context
        else convert_history_to_token(history)
    )

    # Limit input length to avoid exceeding the context budget: keep only the latest turn
    # and re-tokenize so the truncation actually takes effect.
    if input_ids.shape[1] > 2000:
        history = [history[-1]]
        if not search_context:
            input_ids = convert_history_to_token(history)

    # Configure response streaming
    streamer = TextIteratorStreamer(tok, timeout=4600.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "do_sample": temperature > 0.0,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
        "stopping_criteria": StoppingCriteriaList(stop_tokens) if stop_tokens is not None else None,
    }

    # Signal completion of generation from the worker thread
    stream_complete = Event()

    def generate_and_signal_complete():
        try:
            ov_model.generate(**generate_kwargs)
        except RuntimeError as e:
            # An infer request cancelled via request_cancel() surfaces as a RuntimeError
            if "Infer Request was canceled" in str(e):
                print("Generation request was canceled.")
            else:
                # Re-raise any other RuntimeError
                raise
        finally:
            # Signal completion of the stream
            stream_complete.set()

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    # Stream partial output back to the UI as it is produced
    partial_text = ""
    for new_text in streamer:
        partial_text = text_processor(partial_text, new_text)
        history[-1] = (user_query, partial_text)
        yield history
# Cancel an in-flight OpenVINO inference request
def request_cancel():
    ov_model.request.cancel()
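# Note: request_cancel is defined but not passed anywhere below. If gradio_helper's
# make_demo accepts a stop callback (an assumption about its signature), wiring
# request_cancel to it lets the UI's Stop button cancel generation cleanly.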
# Gradio setup and launch
demo = make_demo(run_fn=bot, title="OpenVINO Search & Reasoning Chatbot", language=model_language_value)

if __name__ == "__main__":
    demo.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)
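# Run with `python app.py`; the Gradio UI listens on port 7860 on all interfaces,
# and share=True additionally creates a temporary public gradio.live link.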