CHUNYU0505 committed on
Commit 058eba2 · verified · 1 parent: e3d4ffe

Update app.py

Files changed (1)
  1. app.py +48 -28
app.py CHANGED
@@ -1,3 +1,4 @@
+# app.py
 import os, torch
 from langchain.docstore.document import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -9,20 +10,34 @@ from huggingface_hub import login, snapshot_download
 import gradio as gr
 
 # -------------------------------
-# 1. Model setup (Chinese T5)
+# 1. Model setup (Chinese T5 / Pegasus)
 # -------------------------------
-MODEL_NAME = "Langboat/mengzi-t5-base"  # ✅ switched to a stable Chinese T5
+PRIMARY_MODEL = "imxly/t5-pegasus-small"  # good fit for Chinese summarization/generation
+FALLBACK_MODEL = "uer/gpt2-chinese-cluecorpussmall"  # fall back to GPT-2 if the T5 cannot be downloaded
 
 HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
 if HF_TOKEN:
     login(token=HF_TOKEN)
     print("✅ Logged in to Hugging Face with HUGGINGFACEHUB_API_TOKEN")
 
-# Try to download the model
-LOCAL_MODEL_DIR = f"./models/{MODEL_NAME.split('/')[-1]}"
-if not os.path.exists(LOCAL_MODEL_DIR):
-    print(f"⬇️ Trying to download model {MODEL_NAME} ...")
-    snapshot_download(repo_id=MODEL_NAME, token=HF_TOKEN, local_dir=LOCAL_MODEL_DIR)
+def try_download_model(repo_id):
+    local_dir = f"./models/{repo_id.split('/')[-1]}"
+    if not os.path.exists(local_dir):
+        print(f"⬇️ Trying to download model {repo_id} ...")
+        try:
+            snapshot_download(repo_id=repo_id, token=HF_TOKEN, local_dir=local_dir)
+        except Exception as e:
+            print(f"⚠️ Model {repo_id} could not be downloaded: {e}")
+            return None
+    return local_dir
+
+LOCAL_MODEL_DIR = try_download_model(PRIMARY_MODEL)
+if LOCAL_MODEL_DIR is None:
+    print("⚠️ Switching to the fallback model: small GPT2-Chinese")
+    LOCAL_MODEL_DIR = try_download_model(FALLBACK_MODEL)
+    MODEL_NAME = FALLBACK_MODEL
+else:
+    MODEL_NAME = PRIMARY_MODEL
 
 print(f"👉 Final model in use: {MODEL_NAME}")
 
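Note: snapshot_download only proves the files arrived, not that the checkpoint loads. A minimal smoke-test sketch, reusing try_download_model, PRIMARY_MODEL, and FALLBACK_MODEL from the hunk above (the pick_usable_model helper is hypothetical, not part of this commit):

    from transformers import AutoTokenizer

    def pick_usable_model(candidates):
        # Walk the candidate list; keep the first repo that downloads AND loads.
        for repo_id in candidates:
            local_dir = try_download_model(repo_id)
            if local_dir is None:
                continue
            try:
                AutoTokenizer.from_pretrained(local_dir, use_fast=False)  # cheap load check
                return repo_id, local_dir
            except Exception as e:
                print(f"⚠️ {repo_id} downloaded but failed to load: {e}")
        raise RuntimeError("No usable model among candidates")

    MODEL_NAME, LOCAL_MODEL_DIR = pick_usable_model([PRIMARY_MODEL, FALLBACK_MODEL])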
@@ -31,22 +46,28 @@ print(f"👉 Final model in use: {MODEL_NAME}")
 # -------------------------------
 tokenizer = AutoTokenizer.from_pretrained(
     LOCAL_MODEL_DIR,
-    use_fast=False  # avoid tiktoken / fast-tokenizer issues
+    use_fast=False  # guard against sentencepiece issues
 )
-model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR)
+
+# Pick GPU or CPU
+device = 0 if torch.cuda.is_available() else -1
+print(f"💻 Device: {'GPU' if device == 0 else 'CPU'}")
+
+try:
+    model = AutoModelForSeq2SeqLM.from_pretrained(LOCAL_MODEL_DIR)
+except Exception:  # not a seq2seq checkpoint, so load it as a causal LM instead
+    from transformers import AutoModelForCausalLM
+    model = AutoModelForCausalLM.from_pretrained(LOCAL_MODEL_DIR)
 
 generator = pipeline(
-    "text2text-generation",  # the task for seq2seq models
+    "text2text-generation" if ("t5" in MODEL_NAME or "pegasus" in MODEL_NAME) else "text-generation",
     model=model,
     tokenizer=tokenizer,
-    device=-1  # CPU
+    device=device
 )
 
 def call_local_inference(prompt, max_new_tokens=256):
     try:
-        if "中文" not in prompt:
-            prompt += "\n(請用中文回答)"
-
         outputs = generator(
             prompt,
             max_new_tokens=max_new_tokens,
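Note: the substring test ("t5" in MODEL_NAME or "pegasus" in MODEL_NAME) silently picks the wrong pipeline task for any future model whose repo name fits neither pattern. A sketch of a more robust check, assuming only that the checkpoint in LOCAL_MODEL_DIR ships a standard config:

    from transformers import AutoConfig

    # Encoder-decoder checkpoints (T5, Pegasus, BART, ...) set is_encoder_decoder=True.
    config = AutoConfig.from_pretrained(LOCAL_MODEL_DIR)
    task = "text2text-generation" if config.is_encoder_decoder else "text-generation"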
@@ -58,7 +79,7 @@ def call_local_inference(prompt, max_new_tokens=256):
         return f"(Generation failed: {e})"
 
 # -------------------------------
-# 3. RAG: the vector store
+# 3. Load the FAISS vector store
 # -------------------------------
 DB_PATH = "./faiss_db"
 EMBEDDINGS_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
@@ -68,10 +89,10 @@ if os.path.exists(os.path.join(DB_PATH, "index.faiss")):
     print("✅ Loading existing vector store...")
     db = FAISS.load_local(DB_PATH, embeddings_model, allow_dangerous_deserialization=True)
 else:
-    print("⚠️ No vector store found; please build faiss_db first")
-    db = None
+    print("⚠️ Vector store not found; creating an empty DB")
+    db = FAISS.from_texts([""], embeddings_model)  # from_documents([]) fails: FAISS needs at least one text to infer the embedding size
 
-retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3}) if db else None
+retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
 
 # -------------------------------
 # 4. Article generation (with RAG)
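Note: the new else-branch only papers over a missing index; retrieval against an empty store returns nothing. A sketch of how ./faiss_db could be built from transcripts, reusing embeddings_model and DB_PATH from this file (the ./corpus folder of UTF-8 .txt files is a hypothetical path):

    import os
    from langchain.docstore.document import Document
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    docs = []
    for name in os.listdir("./corpus"):
        if name.endswith(".txt"):
            with open(os.path.join("./corpus", name), encoding="utf-8") as f:
                docs.append(Document(page_content=f.read(), metadata={"source": name}))

    # Split long transcripts into overlapping chunks before embedding.
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_documents(docs)

    db = FAISS.from_documents(chunks, embeddings_model)
    db.save_local(DB_PATH)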
@@ -83,12 +104,9 @@ def generate_article_progress(query, segments=3):
 
     all_text = []
 
-    # 🔍 Retrieve from the vector store
-    context = ""
-    if retriever:
-        retrieved_docs = retriever.get_relevant_documents(query)
-        context_texts = [d.page_content for d in retrieved_docs]
-        context = "\n".join([f"{i+1}. {txt}" for i, txt in enumerate(context_texts[:3])])
+    retrieved_docs = retriever.get_relevant_documents(query)
+    context_texts = [d.page_content for d in retrieved_docs]
+    context = "\n".join([f"{i+1}. {txt}" for i, txt in enumerate(context_texts[:3])])
 
     for i in range(segments):
         prompt = (
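Note: get_relevant_documents is deprecated in newer LangChain releases in favor of the runnable interface; if the pinned version supports it, the retrieval above can be written as follows (same k=5 retriever, a sketch):

    retrieved_docs = retriever.invoke(query)  # replaces retriever.get_relevant_documents(query)
    context = "\n".join(f"{i+1}. {d.page_content}" for i, d in enumerate(retrieved_docs[:3]))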
@@ -100,27 +118,29 @@
         all_text.append(paragraph)
         doc.add_paragraph(paragraph)
 
-        yield "\n\n".join(all_text), None, f"Model used: {MODEL_NAME}"
+        yield "\n\n".join(all_text), None, f"Model used: {MODEL_NAME}, device: {'GPU' if device == 0 else 'CPU'}"
 
     doc.save(docx_file)
-    yield "\n\n".join(all_text), docx_file, f"Model used: {MODEL_NAME}"
+    yield "\n\n".join(all_text), docx_file, f"Model used: {MODEL_NAME}, device: {'GPU' if device == 0 else 'CPU'}"
 
 # -------------------------------
 # 5. Gradio UI
 # -------------------------------
 with gr.Blocks() as demo:
     gr.Markdown("# 📺 TV Dharma Talk Video-to-Article RAG System")
+    gr.Markdown("Local Hugging Face model + FAISS RAG; articles are generated from the database only.")
+
     query_input = gr.Textbox(lines=2, placeholder="Enter an article topic", label="Article topic")
     segments_input = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of paragraphs")
     output_text = gr.Textbox(label="Generated article")
     output_file = gr.File(label="Download DOCX")
-    model_info = gr.Textbox(label="Model info")
+    status_info = gr.Label(label="Status")
 
     btn = gr.Button("Generate article")
     btn.click(
         generate_article_progress,
         inputs=[query_input, segments_input],
-        outputs=[output_text, output_file, model_info]
+        outputs=[output_text, output_file, status_info]
     )
 
 if __name__ == "__main__":
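Note: because generate_article_progress is a generator, Gradio streams each yield to the three declared outputs, so every yield must carry exactly one value per output component (text, file, status). A minimal sketch of that contract (fake_progress is hypothetical):

    def fake_progress(topic, n):
        # One value per declared output at every yield: (text, file, status).
        for i in range(int(n)):
            yield f"paragraph {i+1}/{n} on {topic}", None, f"generating paragraph {i+1}..."
        yield f"all {n} paragraphs done", None, "finished"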
 