from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import spaces

# Model and tokenizer initialization
MODEL_NAME = "inclusionAI/Ring-mini-2.0"
# Default system prompt (Chinese). English translation: "You are Ring, an AI assistant
# developed by Ant Group, dedicated to providing users with useful information and help;
# answer users' questions in Chinese."
DEFAULT_SYSTEM_PROMPT = "你是 Ring,蚂蚁集团开发的智能助手,致力于为用户提供有用的信息和帮助,用中文回答用户的问题。"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True
)

@spaces.GPU(duration=150)
def generate_response(message, history, system_prompt=None):
    """Stream a chat response, yielding progressively longer partial strings.

    `history` is a list of (user, assistant) tuples; `system_prompt`, if provided,
    overrides DEFAULT_SYSTEM_PROMPT.
    """
    
    # Determine the system prompt to use
    prompt_to_use = system_prompt if system_prompt is not None else DEFAULT_SYSTEM_PROMPT

    # Build the chat: start with the system prompt, then append the
    # user/assistant turns from history.
    messages = [
        {"role": "system", "content": prompt_to_use}
    ]
    
    # Add conversation history
    # history is a list of (human, assistant) tuples
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        if assistant: # Ensure assistant message is not None
            messages.append({"role": "assistant", "content": assistant})
    
    # Add current message from user
    messages.append({"role": "user", "content": message})
    
    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    
    # Tokenize input
    model_inputs = tokenizer([text], return_tensors="pt", return_token_type_ids=False).to(model.device)
    
    # Generate response with streaming
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
    
    generation_kwargs = dict(
        **model_inputs,
        max_new_tokens=8192,
        temperature=0.7,
        do_sample=True,
        streamer=streamer,
    )
    
    # Start generation in a separate thread to enable streaming
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    
    # Accumulate the streamed text and yield the partial response as it grows
    response = ""
    for new_text in streamer:
        response += new_text
        yield response
    
    # wait for the generation thread to finish
    thread.join()
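
# ---------------------------------------------------------------------------
# Usage sketch (assumption): this file does not show how generate_response is
# exposed. If it backs a Gradio chat app (as the `spaces` import suggests), a
# minimal wiring could look like the guarded example below. The tuple-style
# history the function expects matches Gradio's tuple chat format.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import gradio as gr  # assumed dependency; not imported elsewhere in this file

    demo = gr.ChatInterface(
        fn=generate_response,
        additional_inputs=[
            gr.Textbox(value=DEFAULT_SYSTEM_PROMPT, label="System prompt"),
        ],
        title="Ring-mini-2.0",
    )
    demo.queue().launch()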