class FaissIdx:
    """Exact inner-product FAISS index that maps search hits back to document text.

    Each call to ``add_doc`` embeds one document with ``model.get_embedding``
    and stores the vector in a ``faiss.IndexFlatIP``. ``search_doc`` embeds a
    query the same way and returns the stored texts of the best-scoring vectors.

    NOTE(review): scores are raw inner products. They equal cosine similarity
    only if ``get_embedding`` returns L2-normalized vectors — confirm.
    """

    def __init__(self, model, dim=768):
        """Create an empty index.

        Args:
            model: object exposing ``get_embedding(text)``; presumably returns
                one vector of length ``dim`` per text — verify against caller.
            dim: embedding dimensionality the FAISS index is built for.
        """
        self.index = faiss.IndexFlatIP(dim)
        # FAISS vector id -> original document text.
        self.doc_map = {}
        self.model = model
        # Number of documents added through add_doc.
        self.ctr = 0

    @staticmethod
    def _as_matrix(embedding):
        """Return *embedding* as a 2-D ``(rows, dim)`` array, as FAISS requires.

        FAISS's Python ``add``/``search`` unpack ``x.shape`` into ``(n, d)``,
        so a single 1-D vector (or a plain list) would be rejected. A plain
        sequence is first converted to a float32 NumPy array. A 1-D input is
        then reshaped to a single row; 2-D inputs pass through unchanged.
        """
        if not hasattr(embedding, "shape"):
            import numpy as np

            embedding = np.asarray(embedding, dtype=np.float32)
        if len(embedding.shape) == 1:
            embedding = embedding.reshape(1, -1)
        return embedding

    def add_doc(self, document_text):
        """Embed *document_text* and add it to the index.

        The text is recorded under the FAISS id its vector receives, so
        ``search_doc`` hits map back to the right document.
        """
        embedding = self._as_matrix(self.model.get_embedding(document_text))
        # IndexFlat assigns ids sequentially starting at ntotal, so this is
        # the id the new vector gets. Unlike a private counter, it stays in
        # sync with the index even if vectors reach it by other routes.
        doc_id = self.index.ntotal
        self.index.add(embedding)
        self.doc_map[doc_id] = document_text
        self.ctr += 1

    def search_doc(self, query, k=3):
        """Return up to *k* stored documents scoring highest against *query*.

        Returns:
            A list of single-entry dicts ``{document_text: score}``, in the
            order FAISS returns them (highest inner product first). FAISS pads
            the result with id -1 when fewer than *k* vectors are stored.
            Such ids are not in ``doc_map`` and are dropped, so the list may
            be shorter than *k*.
        """
        query_matrix = self._as_matrix(self.model.get_embedding(query))
        scores, ids = self.index.search(query_matrix, k)
        # A single query row was searched, so only row 0 of the results is used.
        return [
            {self.doc_map[idx]: score}
            for idx, score in zip(ids[0], scores[0])
            if idx in self.doc_map
        ]