llamaindex_context.py
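"""Streamlit app that indexes the documents in ./tmp with LlamaIndex,
persists the vectors in a local Chroma collection under ./db, and answers
free-text queries through an OpenAI-backed query engine."""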
import chromadb
import streamlit as st
from chromadb.api.models.Collection import Collection
from llama_index import (
    ServiceContext,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
)
from llama_index.core import BaseQueryEngine
from llama_index.llms import OpenAI
from llama_index.response.schema import RESPONSE_TYPE
from llama_index.vector_stores import ChromaVectorStore


def create_vectors(collection_name: str = "tmp_collection") -> Collection:
    """Return a persistent Chroma collection stored under ./db."""
    chroma_client = chromadb.PersistentClient("./db")
    return chroma_client.get_or_create_collection(collection_name)


class LlamaIndexContext:
    _query_engine: BaseQueryEngine

    def __init__(self):
        # Build the index once and cache the query engine in the Streamlit
        # session state, so script reruns do not re-embed the documents.
        if "query_engine" not in st.session_state:
            service_context = LlamaIndexContext.create_service_context()
            storage_context = LlamaIndexContext.create_storage_context()
            index: VectorStoreIndex = VectorStoreIndex.from_documents(
                documents=SimpleDirectoryReader("./tmp").load_data(),
                service_context=service_context,
                storage_context=storage_context,
                show_progress=True,
            )
            st.session_state["query_engine"] = index.as_query_engine()
        self._query_engine = st.session_state["query_engine"]

    @classmethod
    def create_service_context(cls) -> ServiceContext:
        return ServiceContext.from_defaults(
            chunk_overlap=0,
            chunk_size=500,
            llm=OpenAI(),
        )

    @classmethod
    def create_storage_context(cls) -> StorageContext:
        # Pass the collection via the documented chroma_collection keyword.
        return StorageContext.from_defaults(
            vector_store=ChromaVectorStore(chroma_collection=create_vectors())
        )

    def run(self):
        query: str = st.text_input("Query", placeholder="Enter query here")
        if query:
            result: RESPONSE_TYPE = self._query_engine.query(query)
            st.write(result.response)

    def __call__(self):
        self.run()


if __name__ == "__main__":
    LlamaIndexContext()()
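
# Usage sketch (assumptions, not part of the original file: streamlit,
# chromadb, and a pre-0.10 llama-index release matching these imports are
# installed, documents exist in ./tmp, and OPENAI_API_KEY is set in the
# environment):
#
#   streamlit run llamaindex_context.py
#
# Streamlit serves the app locally; entering a question in the "Query"
# box runs it against the indexed documents and renders the response.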