import os
import sys
import gc
import tempfile
from pathlib import Path

# Environment variables to work around permission issues (must be set at the very top)
temp_dir = tempfile.gettempdir()
os.environ["STREAMLIT_HOME"] = temp_dir
os.environ["STREAMLIT_CONFIG_DIR"] = os.path.join(temp_dir, ".streamlit")
os.environ["STREAMLIT_SERVER_HEADLESS"] = "true"
os.environ["STREAMLIT_BROWSER_GATHER_USAGE_STATS"] = "false"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Point the model cache directories at temp as well
os.environ["TRANSFORMERS_CACHE"] = os.path.join(temp_dir, "transformers_cache")
os.environ["HF_HOME"] = os.path.join(temp_dir, "huggingface")

# Work around the torch.classes path conflict that can crash Streamlit's file watcher
try:
    import torch
    import importlib.util
    torch_classes_path = os.path.join(os.path.dirname(importlib.util.find_spec("torch").origin), "classes")
    if hasattr(torch, "classes"):
        torch.classes.__path__ = [torch_classes_path]
except Exception:
    pass
import streamlit as st

# Import the transformers library and record whether it is available
try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch
    TRANSFORMERS_AVAILABLE = True
    IMPORT_ERROR_MESSAGE = ""
except ImportError as e:
    TRANSFORMERS_AVAILABLE = False
    # st.set_page_config() must be the first Streamlit call, so defer showing the error
    IMPORT_ERROR_MESSAGE = str(e)
# Page settings (st.set_page_config must be the first Streamlit command)
st.set_page_config(
    page_title="TinyLlama Demo",
    page_icon="🦙",
    layout="wide",
    initial_sidebar_state="expanded"
)

st.title("🦙 TinyLlama 1.1B (CPU Only) Demo")

if not TRANSFORMERS_AVAILABLE:
    st.error(f"Could not import the Transformers library: {IMPORT_ERROR_MESSAGE}")
    st.error("Please install the required libraries:")
    st.code("""
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install transformers
pip install streamlit
""", language="bash")
    st.stop()
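
# Cache the loaded model across Streamlit reruns so it is not reloaded on every
# interaction. Assumption: the running Streamlit version provides st.cache_resource
# (1.18+), which also replays the st.info/st.warning calls made inside the cached function.
@st.cache_resource(show_spinner=False)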
def load_tinyllama_model():
    """Load the TinyLlama 1.1B model (CPU only)."""
    try:
        # Try several candidate model names in order
        model_options = [
            "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            "microsoft/DialoGPT-small"  # fallback option
        ]

        for model_name in model_options:
            try:
                st.info(f"Trying model: {model_name}")

                # Load the tokenizer
                tokenizer = AutoTokenizer.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    cache_dir=os.environ.get("TRANSFORMERS_CACHE")
                )

                # Load the model
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    torch_dtype=torch.float32,
                    trust_remote_code=True,
                    cache_dir=os.environ.get("TRANSFORMERS_CACHE"),
                    device_map="cpu"
                )

                # Move explicitly to CPU and switch to evaluation mode
                model = model.to("cpu")
                model.eval()

                # Make sure the tokenizer has a pad token
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token

                # Free temporary allocations
                gc.collect()

                return model, tokenizer, f"✅ {model_name} loaded successfully!"

            except Exception as model_error:
                st.warning(f"Failed to load {model_name}: {str(model_error)}")
                continue

        return None, None, "❌ Failed to load all candidate models"

    except Exception as e:
        return None, None, f"❌ Loading failed: {str(e)}"
def generate_text(model, tokenizer, prompt, max_new_tokens=150, temperature=0.7):
    """Generate text safely on CPU."""
    try:
        # Cap the prompt length
        max_input_length = 400

        # Tokenize
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_length,
            padding=False
        )

        # Move tensors to CPU
        inputs = {k: v.to("cpu") for k, v in inputs.items()}
        input_length = inputs['input_ids'].shape[1]

        # Keep the total sequence (prompt + generation) under roughly 800 tokens
        safe_max_tokens = min(max_new_tokens, 800 - input_length)
        if safe_max_tokens < 20:
            safe_max_tokens = 20

        # Generation settings
        generation_kwargs = {
            "max_new_tokens": safe_max_tokens,
            "temperature": temperature,
            "do_sample": True,
            "top_p": 0.9,
            "top_k": 50,
            "repetition_penalty": 1.1,
            "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "use_cache": True,
            "early_stopping": True  # only affects beam search; harmless with sampling
        }

        # Free memory before generating
        gc.collect()

        # Run generation
        with st.spinner("Generating text..."):
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    **generation_kwargs
                )

        # Keep only the newly generated tokens
        new_tokens = outputs[0][input_length:]
        generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

        return generated_text.strip()

    except Exception as e:
        raise Exception(f"Error during generation: {str(e)}")
def main():
    # Load the model
    with st.spinner("Loading the TinyLlama model... (the first run may take a while because of the download)"):
        model, tokenizer, status = load_tinyllama_model()

    st.info(status)

    if not (model and tokenizer):
        st.error("Model loading failed. Check your internet connection and try again.")
        return

    # Sidebar settings
    st.sidebar.header("⚙️ Generation Settings")
    max_new_tokens = st.sidebar.slider("Max new tokens", 20, 200, 100)
    temperature = st.sidebar.slider("Temperature (creativity)", 0.1, 1.0, 0.7, 0.1)

    # Help
    st.sidebar.header("📖 Usage Guide")
    st.sidebar.info("""
    **Tips:**
    - Keep prompts clear and concise
    - Generation takes time because everything runs on the CPU
    - The first run takes longer while the model downloads
    """)
    # Main interface
    st.header("💬 Text Generation")

    # Example prompts
    col1, col2 = st.columns([2, 1])

    with col1:
        example_prompts = [
            "Custom input",
            "The future of artificial intelligence is",
            "Once upon a time in a magical forest,",
            "Python is a programming language that",
            "Climate change is an important issue because",
            "The benefits of reading books include"
        ]
        selected_prompt = st.selectbox("Choose an example prompt:", example_prompts)

    with col2:
        st.write("")  # spacing
        st.write("")  # spacing
        if st.button("🎲 Random example", help="Pick a random example prompt"):
            import random
            random_prompt = random.choice(example_prompts[1:])  # skip the "Custom input" entry
            st.session_state.random_prompt = random_prompt

    # Prompt input
    if selected_prompt == "Custom input":
        default_prompt = st.session_state.get('random_prompt', '')
        prompt = st.text_area(
            "Enter your prompt:",
            value=default_prompt,
            height=100,
            placeholder="Type your text here..."
        )
    else:
        prompt = st.text_area(
            "Prompt:",
            value=selected_prompt,
            height=100
        )

    # Show the token count of the current prompt
    if prompt and tokenizer:
        try:
            token_count = len(tokenizer.encode(prompt))
            st.caption(f"Current prompt token count: {token_count}")
            if token_count > 400:
                st.warning("⚠️ The prompt is too long and will be truncated to 400 tokens.")
        except Exception:
            pass
    # Generate / clear buttons
    col1, col2, col3 = st.columns([1, 1, 2])

    with col1:
        generate_btn = st.button("🚀 Generate", type="primary", use_container_width=True)

    with col2:
        clear_btn = st.button("🗑️ Clear results", use_container_width=True)

    # Clear previous results
    if clear_btn:
        if 'generated_result' in st.session_state:
            del st.session_state['generated_result']
        st.rerun()

    # Text generation
    if generate_btn and prompt.strip():
        try:
            # Show generation progress
            progress_container = st.container()

            with progress_container:
                progress_bar = st.progress(0)
                status_text = st.empty()

                status_text.text("Tokenizing...")
                progress_bar.progress(20)

                status_text.text("Generating text... (this takes a while on CPU)")
                progress_bar.progress(40)

                # Run the actual generation
                generated_text = generate_text(
                    model, tokenizer, prompt.strip(),
                    max_new_tokens, temperature
                )

                progress_bar.progress(80)
                status_text.text("Processing the result...")

                # Store the result in session state
                full_result = prompt + generated_text
                st.session_state['generated_result'] = {
                    'prompt': prompt,
                    'generated': generated_text,
                    'full_text': full_result
                }

                progress_bar.progress(100)
                status_text.text("Done!")

                # Remove the progress indicators
                progress_bar.empty()
                status_text.empty()

        except Exception as e:
            st.error(f"An error occurred during generation: {str(e)}")
            st.info("💡 Try again or use a shorter prompt.")

    elif generate_btn:
        st.warning("⚠️ Please enter a prompt.")
    # Display results
    if 'generated_result' in st.session_state:
        result = st.session_state['generated_result']

        st.header("📝 Generated Result")

        # Split the output into tabs
        tab1, tab2 = st.tabs(["🎯 Generated text only", "📄 Full text"])

        with tab1:
            st.markdown("**Newly generated part:**")
            st.markdown(f'<div style="background-color: #f0f2f6; padding: 15px; border-radius: 10px; border-left: 4px solid #4CAF50;">{result["generated"]}</div>', unsafe_allow_html=True)

        with tab2:
            st.markdown("**Full text (prompt + generation):**")
            st.text_area(
                "Full result:",
                value=result['full_text'],
                height=200,
                disabled=True
            )

        # Download button
        st.download_button(
            label="💾 Save as text file",
            data=result['full_text'],
            file_name=f"tinyllama_output_{len(result['full_text'])}.txt",
            mime="text/plain",
            use_container_width=True
        )
    # System information sidebar
    st.sidebar.header("💻 System Info")
    st.sidebar.write(f"**Python:** {sys.version.split()[0]}")
    if TRANSFORMERS_AVAILABLE:
        st.sidebar.write(f"**PyTorch:** {torch.__version__}")
        st.sidebar.write(f"**CUDA available:** {'✅' if torch.cuda.is_available() else '❌'}")
    st.sidebar.write("**Execution mode:** CPU only")

    # Performance tips
    with st.sidebar.expander("🚀 Performance optimization tips"):
        st.markdown("""
        **Faster generation:**
        - Keep prompts to about 100 words or fewer
        - Limit generation to 150 tokens or fewer
        - Avoid running the app in several tabs at once

        **Lower memory use:**
        - Close other heavy applications
        - Keep the number of open browser tabs small
        """)


if __name__ == "__main__":
    main()
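
# To run the app locally (assumes this file is saved as app.py; any filename works):
#   streamlit run app.py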