import io
import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains.question_answering import load_qa_chain
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
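
# Assumed dependencies (the source does not pin or list them): streamlit, PyPDF2,
# langchain, langchain-community, langchain-huggingface, sentence-transformers,
# faiss-cpu (or faiss-gpu), transformers, and torch.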

# Persistent app state. Streamlit re-executes this script on every interaction,
# so plain module-level globals would be reset between runs; session_state persists.
if "knowledge_base" not in st.session_state:
    st.session_state.knowledge_base = None
if "qa_chain" not in st.session_state:
    st.session_state.qa_chain = None

# Load the PDF file and extract its text
def load_pdf(pdf_file):
    pdf_reader = PdfReader(pdf_file)
    # extract_text() can return None for image-only pages, so substitute ""
    text = "".join(page.extract_text() or "" for page in pdf_reader.pages)
    return text

# Split the text into overlapping chunks
def split_text(text):
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len
    )
    return text_splitter.split_text(text)

# Build the FAISS vector store from the text chunks
def create_knowledge_base(chunks):
    model_name = "sentence-transformers/all-mpnet-base-v2"  # embedding model, named explicitly
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    return FAISS.from_texts(chunks, embeddings)

# Load the Hugging Face text-generation model
@st.cache_resource  # cache the pipeline so the model is not reloaded on every rerun
def load_model():
    model_name = "halyn/gemma2-2b-it-finetuned-paperqa"
    tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=False)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.1)

# Set up the QA chain
def setup_qa_chain():
    try:
        pipe = load_model()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return
    llm = HuggingFacePipeline(pipeline=pipe)
    st.session_state.qa_chain = load_qa_chain(llm, chain_type="stuff")

# Main page UI
def main_page():
    st.title("Welcome to GemmaPaperQA")
    st.subheader("Upload Your Paper")

    paper = st.file_uploader("Upload Here!", type="pdf", label_visibility="hidden")
    if paper:
        st.write(f"Upload complete! File name: {paper.name}")

        # Check the file size (the size attribute is available without moving the read pointer)
        file_size = paper.size
        if file_size > 10 * 1024 * 1024:  # 10 MB limit
            st.error("File is too large! Please upload a file smaller than 10MB.")
            return

        # Intermediate check: preview the extracted PDF text before committing
        with st.spinner('Processing PDF...'):
            try:
                paper.seek(0)  # rewind the read pointer to the start of the file
                contents = paper.read()
                pdf_file = io.BytesIO(contents)
                text = load_pdf(pdf_file)

                # Handle PDFs from which no text could be extracted
                if len(text.strip()) == 0:
                    st.error("The PDF appears to have no extractable text. Please check the file and try again.")
                    return

                st.text_area("Preview of extracted text", text[:1000], height=200)
                st.write(f"Total characters extracted: {len(text)}")

                if st.button("Proceed with this file"):
                    chunks = split_text(text)
                    st.session_state.knowledge_base = create_knowledge_base(chunks)
                    if st.session_state.knowledge_base is None:
                        st.error("Failed to create knowledge base.")
                        return
                    setup_qa_chain()
                    st.session_state.paper_name = paper.name[:-4]  # drop the ".pdf" extension
                    st.session_state.page = "chat"
                    st.success("PDF successfully processed! You can now ask questions.")
                    st.rerun()  # switch to the chat page immediately
            except Exception as e:
                st.error(f"Failed to process the PDF: {str(e)}")

# Chat page UI
def chat_page():
    st.title(f"Ask anything about {st.session_state.paper_name}")

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation history on each rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Chat here!"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        response = get_response_from_model(prompt)
        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})

    if st.button("Go back to main page"):
        st.session_state.page = "main"
        st.rerun()  # return to the main page immediately

# Answer a question: retrieve similar chunks, then run the QA chain over them
def get_response_from_model(prompt):
    try:
        if st.session_state.knowledge_base is None:
            return "No PDF has been uploaded yet."
        if st.session_state.qa_chain is None:
            return "QA chain is not initialized."
        docs = st.session_state.knowledge_base.similarity_search(prompt)
        response = st.session_state.qa_chain.run(input_documents=docs, question=prompt)
        # The "stuff" QA prompt is echoed by the LLM; keep only the text after "Helpful Answer:"
        if "Helpful Answer:" in response:
            response = response.split("Helpful Answer:")[1].strip()
        return response
    except Exception as e:
        return f"Error: {str(e)}"

# Initialize page state
if "page" not in st.session_state:
    st.session_state.page = "main"
if "paper_name" not in st.session_state:
    st.session_state.paper_name = ""

# Page routing
if st.session_state.page == "main":
    main_page()
elif st.session_state.page == "chat":
    chat_page()
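
# To launch locally, a minimal sketch (the filename app.py is an assumption;
# the source does not name the file):
#   streamlit run app.py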