DSDUDEd commited on
Commit
3213643
Β·
verified Β·
1 Parent(s): e9076b0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ from datasets import load_dataset
3
+ import random
4
+
5
+ # Load tokenizer & model
6
+ tokenizer = AutoTokenizer.from_pretrained("driaforall/mem-agent")
7
+ model = AutoModelForCausalLM.from_pretrained("driaforall/mem-agent")
8
+
9
+ # Load dataset of prompts
10
+ ds = load_dataset("fka/awesome-chatgpt-prompts")
11
+ prompts = ds["train"]["prompt"]
12
+
13
+ # Pick a random prompt to initialize Firebox's "persona"
14
+ firebox_persona = random.choice(prompts)
15
+ print("πŸ”₯ Firebox Persona:", firebox_persona)
16
+
17
+ def firebox_chat(user_message, history=[]):
18
+ """
19
+ Firebox chatbot using driaforall/mem-agent with persona prompts.
20
+ """
21
+ # If history is empty, insert persona/system prompt first
22
+ if not history:
23
+ history.append({"role": "system", "content": firebox_persona})
24
+
25
+ # Add user message
26
+ history.append({"role": "user", "content": user_message})
27
+
28
+ # Apply Hugging Face chat template
29
+ inputs = tokenizer.apply_chat_template(
30
+ history,
31
+ add_generation_prompt=True,
32
+ tokenize=True,
33
+ return_dict=True,
34
+ return_tensors="pt",
35
+ ).to(model.device)
36
+
37
+ # Generate Firebox reply
38
+ outputs = model.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.7)
39
+ reply = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
40
+
41
+ # Add assistant reply to history
42
+ history.append({"role": "assistant", "content": reply})
43
+
44
+ return reply, history
45
+
46
+
47
+ # Example usage
48
+ if __name__ == "__main__":
49
+ history = []
50
+ print("πŸ”₯ Firebox:", firebox_persona)
51
+ while True:
52
+ user_input = input("You: ")
53
+ if user_input.lower() in ["quit", "exit", "bye"]:
54
+ print("πŸ”₯ Firebox: Goodbye! Stay safe.")
55
+ break
56
+ reply, history = firebox_chat(user_input, history)
57
+ print("πŸ”₯ Firebox:", reply)