File size: 1,701 Bytes
4e6bed1
 
 
 
 
 
 
 
 
c0ac2e0
4e6bed1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cce46c
4e6bed1
 
 
 
 
 
 
 
 
c0ac2e0
4e6bed1
 
 
 
 
 
f07b911
4e6bed1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from pathlib import Path

import streamlit as st

from haystack.nodes import PreProcessor, TextConverter, FARMReader, BM25Retriever
from haystack.document_stores import InMemoryDocumentStore
from haystack.pipelines import ExtractiveQAPipeline
from haystack.pipelines.standard_pipelines import TextIndexingPipeline
from haystack.pipelines.base import Pipeline
import torch

# Hash hack, assume all outputs of ExtractiveQAPipeline type are equal
@st.cache(hash_funcs={ExtractiveQAPipeline: lambda _: "42"})
def get_pipe():
    """Build and cache the extractive-QA pipeline over the transcript corpus.

    Indexes every ``*.txt`` file under ``making_sense_transcripts/`` into an
    in-memory BM25 document store, then wires a BM25 retriever and a RoBERTa
    SQuAD2 reader into an ``ExtractiveQAPipeline``.

    Returns:
        ExtractiveQAPipeline: a ready-to-query pipeline.
    """
    store = InMemoryDocumentStore(use_bm25=True)

    # Indexing chain: File -> TextConverter -> PreProcessor -> DocumentStore
    indexer = Pipeline()
    stages = [
        (TextConverter(), "TextConverter", ["File"]),
        (PreProcessor(), "PreProcessor", ["TextConverter"]),
        (store, "DocumentStore", ["PreProcessor"]),
    ]
    for component, node_name, node_inputs in stages:
        indexer.add_node(component=component, name=node_name, inputs=node_inputs)

    # Feed every transcript through the indexing chain; each file carries its
    # own path as metadata.
    paths = list(Path("making_sense_transcripts/").glob("*.txt"))
    metas = [{"file_path": str(p)} for p in paths]
    indexer.run_batch(file_paths=paths, meta=metas)

    reader = FARMReader(
        model_name_or_path="deepset/roberta-base-squad2",
        use_gpu=torch.cuda.is_available(),  # GPU if present, else CPU
        context_window_size=200,
    )
    return ExtractiveQAPipeline(reader, BM25Retriever(document_store=store))


def ask_pipe(
    question: str,
    pipe: ExtractiveQAPipeline,
    retriever_top_k: int = 10,
    reader_top_k: int = 5,
):
    """Run one extractive-QA query through the pipeline.

    Args:
        question: Natural-language question to answer.
        pipe: Pipeline from ``get_pipe()``; any object exposing a compatible
            ``run(query=..., params=...)`` method works.
        retriever_top_k: Number of candidate documents BM25 retrieves
            (default 10, matching the previous hard-coded value).
        reader_top_k: Number of answer spans the reader returns
            (default 5, matching the previous hard-coded value).

    Returns:
        The pipeline's prediction dict (query, answers, documents, ...).
    """
    return pipe.run(
        query=question,
        params={
            "Retriever": {"top_k": retriever_top_k},
            "Reader": {"top_k": reader_top_k},
        },
    )