CHUNYU0505 committed on
Commit
6f0de2e
·
verified ·
1 Parent(s): 033a019

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -80
app.py CHANGED
@@ -1,108 +1,61 @@
1
- # app.py
2
- import os, torch
3
- from langchain.docstore.document import Document
4
- from langchain.text_splitter import RecursiveCharacterTextSplitter
5
- from langchain_community.vectorstores import FAISS
6
- from langchain_huggingface import HuggingFaceEmbeddings
7
- from docx import Document as DocxDocument
8
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
9
- from huggingface_hub import login, snapshot_download
10
- import gradio as gr
11
-
12
  # -------------------------------
13
# 1. Model selection (Chinese GPT-2 with a smaller fallback)
# -------------------------------
PRIMARY_MODEL = "uer/gpt2-chinese-cluecorpusmedium"
FALLBACK_MODEL = "uer/gpt2-chinese-cluecorpussmall"

# Optional Hugging Face login: only performed when the token env var is set,
# so the Space also starts without credentials.
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("✅ 已使用 HUGGINGFACEHUB_API_TOKEN 登入 Hugging Face")
22
 
23
def try_download_model(repo_id):
    """Download *repo_id* into ``./models/<name>`` and return the local path.

    Returns ``None`` when the download fails so the caller can fall back to
    another model. Fix: a failed ``snapshot_download`` used to leave a partial
    directory behind, which the ``os.path.exists`` guard would treat as a
    complete model on the next run; partial directories are now removed.
    """
    local_dir = f"./models/{repo_id.split('/')[-1]}"
    if not os.path.exists(local_dir):
        print(f"⬇️ 嘗試下載模型 {repo_id} ...")
        try:
            snapshot_download(repo_id=repo_id, token=HF_TOKEN, local_dir=local_dir)
        except Exception as e:
            print(f"⚠️ 模型 {repo_id} 無法下載: {e}")
            # Clean up partially downloaded files so a later run retries the
            # download instead of loading a broken model directory.
            import shutil
            shutil.rmtree(local_dir, ignore_errors=True)
            return None
    return local_dir
33
-
34
# Try the primary model first; on failure, fall back to the small variant.
LOCAL_MODEL_DIR = try_download_model(PRIMARY_MODEL)
if LOCAL_MODEL_DIR is None:
    print("⚠️ 切換到 fallback 模型:小型 GPT2-Chinese")
    LOCAL_MODEL_DIR = try_download_model(FALLBACK_MODEL)
    MODEL_NAME = FALLBACK_MODEL
else:
    MODEL_NAME = PRIMARY_MODEL

# Fail fast with a clear message when BOTH downloads failed; otherwise
# AutoTokenizer.from_pretrained(None) would crash later with a cryptic error.
if LOCAL_MODEL_DIR is None:
    raise RuntimeError(
        "無法下載任何模型,請檢查網路連線或 HUGGINGFACEHUB_API_TOKEN"
    )

print(f"👉 最終使用模型:{MODEL_NAME}")
44
-
45
# -------------------------------
# 2. Build the text-generation pipeline
# -------------------------------
tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_DIR)
model = AutoModelForCausalLM.from_pretrained(LOCAL_MODEL_DIR)

# GPT-2 checkpoints ship without a pad token; reuse EOS so generation with
# padding does not fail.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,  # CPU only
)
61
-
62
def call_local_inference(prompt, max_new_tokens=256):
    """Generate text for *prompt* with the local pipeline.

    Appends a "answer in Chinese" hint when the prompt lacks one, and returns
    an error string instead of raising when generation fails.
    """
    try:
        # Nudge the model toward Chinese output if the prompt does not
        # already mention it.
        if "中文" not in prompt:
            prompt += "\n(請用中文回答)"

        result = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id,
        )
        return result[0]["generated_text"]
    except Exception as e:
        return f"(生成失敗:{e})"
78
 
79
  # -------------------------------
80
# 3. Article generation
# -------------------------------
def generate_article_progress(query, segments=5):
    """Yield ``(text_so_far, docx_path_or_None, model_note)`` per segment.

    The DOCX path is only yielded with the final item, once the document
    has been saved to disk.
    """
    docx_file = "/tmp/generated_article.docx"
    doc = DocxDocument()
    doc.add_heading(query, level=1)

    all_text = []
    base_prompt = f"請依據下列主題生成一篇中文文章,主題:{query}\n每段約150-200字。\n"
    model_note = f"本次使用模型:{MODEL_NAME}"

    for idx in range(segments):
        paragraph = call_local_inference(base_prompt + f"第{idx+1}段:")
        all_text.append(paragraph)
        doc.add_paragraph(paragraph)
        # Progress update: no file path yet, the document is still in memory.
        yield "\n\n".join(all_text), None, model_note

    doc.save(docx_file)
    yield "\n\n".join(all_text), docx_file, model_note
99
 
100
  # -------------------------------
101
- # 4. Gradio 介面
102
  # -------------------------------
103
  with gr.Blocks() as demo:
104
  gr.Markdown("# 📺 電視弘法視頻生成文章RAG系統")
105
- gr.Markdown("固定使用 GPT2-Chinese(medium 下載失敗會自動 fallback 到 small)。")
106
 
107
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
108
  segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數")
 
 
 
 
 
 
 
 
 
 
 
 
1
  # -------------------------------
2
# 0. Load the FAISS vector store
# -------------------------------
EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embeddings_model = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)

DB_PATH = "./faiss_db"
if os.path.exists(DB_PATH):
    print("載入現有向量資料庫...")
    # The index is built by this project, so deserialization is trusted.
    db = FAISS.load_local(
        DB_PATH, embeddings_model, allow_dangerous_deserialization=True
    )
else:
    raise ValueError("❌ 沒找到 faiss_db,請先建立向量資料庫")

# Plain top-5 similarity search over the loaded index.
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
 
16
  # -------------------------------
17
# Article generation (RAG: retrieved snippets + progress + incremental DOCX)
# -------------------------------
def generate_article_progress(query, segments=5):
    """Generate a *segments*-paragraph Chinese article on *query* using RAG.

    Yields ``(text_so_far, docx_path_or_None, model_note, context, progress)``
    after every paragraph; the DOCX path is only included once the final
    document has been saved.
    """
    docx_file = "/tmp/generated_article.docx"
    doc = DocxDocument()
    doc.add_heading(query, level=1)
    # Create the file up front so even an early failure leaves a valid docx.
    doc.save(docx_file)

    all_text = []

    # Retrieve supporting passages from FAISS; keep the top 3 as context.
    retrieved_docs = retriever.get_relevant_documents(query)
    context_texts = [d.page_content for d in retrieved_docs]
    context = "\n".join(f"{i+1}. {txt}" for i, txt in enumerate(context_texts[:3]))

    model_note = f"本次使用模型:{MODEL_NAME}"
    for i in range(segments):
        progress_text = f"⏳ 正在生成第 {i+1}/{segments} 段..."
        prompt = (
            f"以下是佛教經論的相關段落:\n{context}\n\n"
            f"請依據上面內容,寫一段約150-200字的中文文章,"
            f"主題:{query}。\n第{i+1}段:"
        )
        paragraph = call_local_inference(prompt)
        all_text.append(paragraph)

        # Persist after every paragraph. Fix: keep the single in-memory
        # Document instead of re-opening and re-parsing the DOCX file from
        # disk on each iteration — same resulting file, no repeated parsing.
        doc.add_paragraph(paragraph)
        doc.save(docx_file)

        yield "\n\n".join(all_text), None, model_note, context, progress_text

    final_progress = f"✅ 已完成全部 {segments} 段生成!"
    yield "\n\n".join(all_text), docx_file, model_note, context, final_progress
51
 
 
 
52
 
53
  # -------------------------------
54
+ # Gradio 介面
55
  # -------------------------------
56
  with gr.Blocks() as demo:
57
  gr.Markdown("# 📺 電視弘法視頻生成文章RAG系統")
58
+ gr.Markdown("使用 GPT2-Chinese + FAISS RAG,生成文章。")
59
 
60
  query_input = gr.Textbox(lines=2, placeholder="請輸入文章主題", label="文章主題")
61
  segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="段落數")