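This walkthrough builds a retrieval-augmented generation (RAG) pipeline with LangChain, OpenAI, and Pinecone, using chunked Llama 2 arXiv papers as the knowledge base.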
```python
import os
import time

from datasets import load_dataset
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.schema import (
    SystemMessage,
    HumanMessage,
    AIMessage
)
from pinecone import Pinecone, ServerlessSpec
from tqdm.auto import tqdm
```
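Load the pre-chunked Llama 2 arXiv dataset from the Hugging Face Hub. Each record contains a text chunk along with metadata such as the paper's DOI, source, and title, which we use later when indexing: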
```python
dataset = load_dataset(
    "jamescalam/llama-2-arxiv-papers-chunked",
    split="train"
)
```
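Next, initialize the chat model and seed it with a short conversation history so the model has prior turns to condition on: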
```python
# the OpenAI API key is read from the environment
chat = ChatOpenAI(
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    model="gpt-3.5-turbo"
)

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Hi AI, how are you today?"),
    AIMessage(content="I'm great thank you. How can I help you?"),
    HumanMessage(content="I'd like to understand string theory.")
]

res = chat.invoke(messages)
# add the latest AI response to the message history
messages.append(res)
```
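Now set up the Pinecone vector database. The client creates a serverless index whose dimensionality matches the embedding model (1536 for `text-embedding-ada-002`):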
```python
# configure the Pinecone client
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))

# serverless index specification
spec = ServerlessSpec(cloud="aws", region="us-east-1")

index_name = "llama-2-rag"
existing_indexes = [
    index_info["name"] for index_info in pc.list_indexes()
]

# check if index already exists (it shouldn't if this is first time)
if index_name not in existing_indexes:
    # if it does not exist, create the index
    pc.create_index(
        index_name,
        dimension=1536,  # dimensionality of text-embedding-ada-002
        metric="dotproduct",
        spec=spec
    )
    # wait for the index to be initialized
    while not pc.describe_index(index_name).status["ready"]:
        time.sleep(1)

# connect to the index
index = pc.Index(index_name)
time.sleep(1)
# view index stats
index.describe_index_stats()
```
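With the index ready, embed the dataset and upsert it in batches of 100. Each vector gets a unique `doi`/`chunk-id` identifier and carries the chunk text, source, and title as metadata: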
```python
# embedding model used to encode the chunks
embed_model = OpenAIEmbeddings(model="text-embedding-ada-002")

data = dataset.to_pandas()
batch_size = 100

# iterate over the dataset in batches
for i in tqdm(range(0, len(data), batch_size)):
    i_end = min(len(data), i + batch_size)
    # get batch of data
    batch = data.iloc[i:i_end]
    # generate unique ids for each chunk
    ids = [f"{x['doi']}-{x['chunk-id']}" for _, x in batch.iterrows()]
    # get text to embed
    texts = [x["chunk"] for _, x in batch.iterrows()]
    # embed text
    embeds = embed_model.embed_documents(texts)
    # get metadata to store in Pinecone
    metadata = [
        {"text": x["chunk"],
         "source": x["source"],
         "title": x["title"]} for _, x in batch.iterrows()
    ]
    # add to Pinecone
    index.upsert(vectors=zip(ids, embeds, metadata))

index.describe_index_stats()
```
#### Retrieval Augmented Generation
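With the index populated, wrap it in a LangChain vector store so queries can be embedded and matched against the stored chunks. Importing `PineconeVectorStore` from `langchain_pinecone` also avoids shadowing the `Pinecone` client class imported earlier: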
```python
from langchain_pinecone import PineconeVectorStore

# the metadata field that contains our text
text_field = "text"

# initialize the vector store object
vectorstore = PineconeVectorStore(
    index=index, embedding=embed_model, text_key=text_field
)

query = "What is so special about Llama 2?"
vectorstore.similarity_search(query, k=3)
```
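To connect the retrieval output to the chat model, build a helper that fetches the top matches for a query and splices them into an augmented prompt: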
```python
def augment_prompt(query: str):
    # get top 3 results from the knowledge base
    results = vectorstore.similarity_search(query, k=3)
    # get the text from the results
    source_knowledge = "\n".join([x.page_content for x in results])
    # feed into an augmented prompt
    augmented_prompt = f"""Using the contexts below, answer the query.
Contexts:
{source_knowledge}
Query: {query}"""
    return augmented_prompt
```
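Now send the augmented prompt through the existing conversation: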
```python
# create a new user prompt with the retrieved context injected
prompt = HumanMessage(
    content=augment_prompt(query)
)
# add to messages and generate a grounded answer
messages.append(prompt)

res = chat.invoke(messages)
print(res.content)
```
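Each later turn can be grounded the same way: append the previous answer to the history, retrieve fresh context for the new question with `augment_prompt`, and call the model again. A minimal sketch — the follow-up question here is just an illustrative example, not part of the original notebook:

```python
# illustrative follow-up (any question about the indexed papers works)
followup = "What safety measures were used in the development of Llama 2?"

# keep the assistant's previous answer in the history, then
# augment the new question with freshly retrieved context
messages.append(res)
messages.append(HumanMessage(content=augment_prompt(followup)))

res = chat.invoke(messages)
print(res.content)
```

Because retrieval runs on every turn, the model's answers stay tied to the indexed Llama 2 papers rather than drifting to its parametric knowledge.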