package org.example.tools;

import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.*;
import org.apache.lucene.index.*;
import org.apache.lucene.search.*;
import org.apache.lucene.store.*;
import org.apache.lucene.util.BytesRef;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;

/**
 * A langchain4j {@link EmbeddingStore} for {@link TextSegment}s backed by a Lucene index on the
 * local file system.
 *
 * <p>Every entry becomes one Lucene document containing:
 * <ul>
 *   <li>{@code id} - the embedding id (exact-match string, stored);</li>
 *   <li>{@code vector} - a KNN vector field compared with cosine similarity;</li>
 *   <li>{@code vector_data} - a stored copy of the raw vector so search results can return the
 *       matched embedding;</li>
 *   <li>when a segment is supplied: {@code text}, each metadata entry as {@code meta_<key>}
 *       (stored, exact) and {@code search_<key>} (tokenized, not stored), an insertion
 *       {@code timestamp} and the {@code text_length}.</li>
 * </ul>
 *
 * <p>Added documents are committed automatically once 1000 ({@code COMMIT_BATCH_SIZE}) of them
 * have accumulated since the last commit; call {@link #commit()} to persist sooner. Searches and
 * {@link #size()} first reopen a near-real-time reader from the writer, so they also see
 * uncommitted additions.
 *
 * <p>The class performs no synchronization of its own. Call {@link #close()} to release the
 * writer, reader, analyzer and directory.
 */
public class VectorStore implements EmbeddingStore<TextSegment>, AutoCloseable {
    /** Number of additions after which {@link #add(String, Embedding, TextSegment)} commits. */
    private static final int COMMIT_BATCH_SIZE = 1000;
    /** Prefix of the stored, exact-match metadata fields. */
    private static final String META_PREFIX = "meta_";
    /** Stored field holding the raw vector as big-endian float bytes. */
    private static final String VECTOR_DATA_FIELD = "vector_data";

    private final Directory directory;
    private final StandardAnalyzer analyzer;
    private final IndexWriter writer;
    private DirectoryReader reader;
    private IndexSearcher searcher;
    private final int dimension;
    private final Path indexDir;
    /** Documents added since the last commit; drives the batched auto-commit. */
    private int uncommittedAdds;

    /**
     * Opens the index at {@code indexPath}, creating it if it does not exist yet.
     *
     * @param indexPath directory holding the Lucene index
     * @param dimension length every added embedding vector must have
     * @throws IOException if the directory, writer or initial reader cannot be opened; resources
     *     opened before the failure are closed before the exception propagates
     */
    public VectorStore(String indexPath, int dimension) throws IOException {
        this.dimension = dimension;
        this.indexDir = Paths.get(indexPath);
        this.directory = FSDirectory.open(indexDir);
        this.analyzer = new StandardAnalyzer();

        IndexWriter openedWriter;
        try {
            IndexWriterConfig config = new IndexWriterConfig(analyzer);
            config.setOpenMode(IndexWriterConfig.OpenMode.CREATE_OR_APPEND);
            openedWriter = new IndexWriter(directory, config);
        } catch (IOException | RuntimeException e) {
            closeAfterFailure(e, analyzer, directory);
            throw e;
        }
        this.writer = openedWriter;

        try {
            refreshSearcher();
        } catch (IOException | RuntimeException e) {
            closeAfterFailure(e, writer, analyzer, directory);
            throw e;
        }
    }

    /**
     * Points {@link #searcher} at the writer's latest state. The reader is opened from the writer
     * (near-real-time), so uncommitted changes are included. On first use a reader is opened;
     * afterwards it is reopened only when the index changed, and the superseded reader is closed.
     */
    private void refreshSearcher() throws IOException {
        if (reader == null) {
            reader = DirectoryReader.open(writer);
            searcher = new IndexSearcher(reader);
            return;
        }
        DirectoryReader newer = DirectoryReader.openIfChanged(reader, writer);
        if (newer != null) {
            DirectoryReader stale = reader;
            reader = newer;
            searcher = new IndexSearcher(newer);
            stale.close();
        }
    }

    /**
     * Returns the searcher as of the last refresh. It may not reflect additions made since then,
     * and its reader is closed by the next refresh that observes a change.
     *
     * @throws RuntimeException wrapping an {@link IOException} if a searcher has to be opened and
     *     that fails
     */
    public IndexSearcher getSearcher() {
        try {
            if (searcher == null) {
                refreshSearcher();
            }
            return searcher;
        } catch (IOException e) {
            throw new RuntimeException("Failed to get searcher", e);
        }
    }

    /** Adds an embedding without a text segment under a random UUID and returns that id. */
    @Override
    public String add(Embedding embedding) {
        return add(embedding, null);
    }

    /** Adds an embedding without a text segment under the given id. */
    @Override
    public void add(String id, Embedding embedding) {
        add(id, embedding, null);
    }

    /** Adds an embedding and its (nullable) text segment under a random UUID and returns that id. */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = UUID.randomUUID().toString();
        add(id, embedding, textSegment);
        return id;
    }

    /** Adds every embedding without a text segment; returns the generated ids in input order. */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = new ArrayList<>(embeddings.size());
        for (Embedding embedding : embeddings) {
            ids.add(add(embedding));
        }
        return ids;
    }

    /**
     * Adds each embedding together with the segment at the same index.
     *
     * @return the generated ids in input order
     * @throws IllegalArgumentException if the two lists differ in size (checked before anything
     *     is written, so no partial batch is added)
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        if (embeddings.size() != embedded.size()) {
            throw new IllegalArgumentException("embeddings and embedded must have the same size: "
                    + embeddings.size() + " vs " + embedded.size());
        }
        List<String> ids = new ArrayList<>(embeddings.size());
        for (int i = 0; i < embeddings.size(); i++) {
            ids.add(add(embeddings.get(i), embedded.get(i)));
        }
        return ids;
    }

    /**
     * Adds one entry under {@code id}. Commits (and refreshes the searcher) once
     * {@code COMMIT_BATCH_SIZE} documents have been added since the last commit.
     *
     * @param id          id stored in the {@code id} field; duplicates are not replaced
     * @param embedding   vector to index; its length must equal the configured dimension
     * @param textSegment optional text and metadata stored alongside the vector
     * @throws IllegalArgumentException if the vector length differs from the configured dimension
     * @throws RuntimeException wrapping an {@link IOException} from Lucene
     */
    public void add(String id, Embedding embedding, TextSegment textSegment) {
        float[] vector = embedding.vector();
        if (vector.length != dimension) {
            throw new IllegalArgumentException(
                    "Expected an embedding of dimension " + dimension + " but got " + vector.length);
        }
        try {
            writer.addDocument(toDocument(id, vector, textSegment));
            uncommittedAdds++;
            // Batch commits to keep bulk loading fast.
            if (uncommittedAdds >= COMMIT_BATCH_SIZE) {
                commitAndRefresh();
            }
        } catch (IOException e) {
            throw new RuntimeException("添加文档失败", e);
        }
    }

    /** Builds the Lucene document for one entry (field layout described on the class). */
    private static Document toDocument(String id, float[] vector, TextSegment textSegment) {
        Document doc = new Document();
        doc.add(new StringField("id", id, Field.Store.YES));

        // Indexed vector for KNN search, plus a stored copy so matches can return it.
        doc.add(new KnnVectorField("vector", vector, VectorSimilarityFunction.COSINE));
        doc.add(new StoredField(VECTOR_DATA_FIELD, encodeVector(vector)));

        if (textSegment != null) {
            String text = textSegment.text();
            doc.add(new TextField("text", text, Field.Store.YES));

            // Store all metadata exactly, and index a tokenized copy for statistics/search.
            for (Map.Entry<String, String> entry : textSegment.metadata().asMap().entrySet()) {
                doc.add(new StringField(META_PREFIX + entry.getKey(), entry.getValue(), Field.Store.YES));
                doc.add(new TextField("search_" + entry.getKey(), entry.getValue(), Field.Store.NO));
            }

            // One clock read so the indexed and stored timestamps agree.
            long now = System.currentTimeMillis();
            doc.add(new LongPoint("timestamp", now));
            doc.add(new StoredField("timestamp", now));

            int textLength = text.length();
            doc.add(new NumericDocValuesField("text_length", textLength));
            doc.add(new StoredField("text_length", textLength));
        }
        return doc;
    }

    /**
     * Commits all pending changes to disk and refreshes the searcher.
     *
     * @throws RuntimeException wrapping an {@link IOException} from Lucene
     */
    public void commit() {
        try {
            commitAndRefresh();
        } catch (IOException e) {
            throw new RuntimeException("提交索引失败", e);
        }
    }

    /** Commits the writer, resets the auto-commit counter and refreshes the searcher. */
    private void commitAndRefresh() throws IOException {
        writer.commit();
        uncommittedAdds = 0;
        refreshSearcher();
    }

    /**
     * Deletes every document whose {@code id} equals {@code id}, then commits.
     *
     * @throws RuntimeException wrapping an {@link IOException} from Lucene
     */
    @Override
    public void remove(String id) {
        try {
            writer.deleteDocuments(new Term("id", id));
            commitAndRefresh();
        } catch (IOException e) {
            throw new RuntimeException("删除文档失败", e);
        }
    }

    /**
     * Returns up to {@code maxResults} nearest neighbours of {@code referenceEmbedding} whose
     * score is at least {@code minScore}, best first. Scores are Lucene's COSINE similarity
     * scores, {@code (1 + cos) / 2}.
     *
     * <p>Each match carries the stored embedding. Documents indexed before the vector copy was
     * stored fall back to {@code referenceEmbedding}. Documents added without a segment yield a
     * {@code null} embedded segment.
     *
     * @return a mutable list, empty when {@code maxResults <= 0}
     * @throws RuntimeException wrapping an {@link IOException} from Lucene
     */
    @Override
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        if (maxResults <= 0) {
            // KnnVectorQuery rejects k < 1.
            return new ArrayList<>();
        }
        try {
            // Include documents added since the last commit.
            refreshSearcher();

            KnnVectorQuery query = new KnnVectorQuery("vector", referenceEmbedding.vector(), maxResults);
            TopDocs topDocs = searcher.search(query, maxResults);

            List<EmbeddingMatch<TextSegment>> results = new ArrayList<>(topDocs.scoreDocs.length);
            for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
                double score = scoreDoc.score;
                if (score < minScore) {
                    continue;
                }
                results.add(toMatch(score, searcher.doc(scoreDoc.doc), referenceEmbedding));
            }
            return results;
        } catch (IOException e) {
            throw new RuntimeException("搜索失败", e);
        }
    }

    /** Rebuilds an {@link EmbeddingMatch} (id, stored vector, text and metadata) from a stored document. */
    private static EmbeddingMatch<TextSegment> toMatch(double score, Document doc, Embedding fallbackEmbedding) {
        String text = doc.get("text");
        TextSegment segment = null;
        // Entries added without a segment have no text; TextSegment.from would reject null.
        if (text != null) {
            Map<String, String> metadata = new HashMap<>();
            for (IndexableField field : doc.getFields()) {
                if (field.name().startsWith(META_PREFIX)) {
                    metadata.put(field.name().substring(META_PREFIX.length()), field.stringValue());
                }
            }
            segment = TextSegment.from(text, Metadata.from(metadata));
        }

        BytesRef storedVector = doc.getBinaryValue(VECTOR_DATA_FIELD);
        Embedding embedding = storedVector != null
                ? Embedding.from(decodeVector(storedVector))
                : fallbackEmbedding;
        return new EmbeddingMatch<>(score, doc.get("id"), embedding, segment);
    }

    /** Serializes a vector as big-endian IEEE-754 floats. */
    private static byte[] encodeVector(float[] vector) {
        ByteBuffer buffer = ByteBuffer.allocate(Float.BYTES * vector.length);
        buffer.asFloatBuffer().put(vector);
        return buffer.array();
    }

    /** Inverse of {@link #encodeVector(float[])}. */
    private static float[] decodeVector(BytesRef bytes) {
        FloatBuffer floats = ByteBuffer.wrap(bytes.bytes, bytes.offset, bytes.length).asFloatBuffer();
        float[] vector = new float[floats.remaining()];
        floats.get(vector);
        return vector;
    }

    /**
     * Deletes every document, commits and refreshes the searcher.
     *
     * @throws RuntimeException wrapping an {@link IOException} from Lucene
     */
    public void clear() {
        try {
            writer.deleteAll();
            commitAndRefresh();
        } catch (IOException e) {
            throw new RuntimeException("清空索引失败", e);
        }
    }

    /**
     * Returns the number of live documents, including additions and deletions not yet committed.
     *
     * @throws RuntimeException wrapping an {@link IOException} if the reader cannot be refreshed
     */
    public int size() {
        try {
            refreshSearcher();
        } catch (IOException e) {
            throw new RuntimeException("Failed to count documents", e);
        }
        return reader.numDocs();
    }

    /**
     * Merges the index down to a single segment, commits and refreshes the searcher.
     *
     * @throws RuntimeException wrapping an {@link IOException} from Lucene
     */
    public void optimize() {
        try {
            writer.forceMerge(1); // merge into a single segment
            commitAndRefresh();
        } catch (IOException e) {
            throw new RuntimeException("优化索引失败", e);
        }
    }

    /**
     * Closes the reader, writer, analyzer and directory, in that order. Every resource is
     * attempted even if an earlier one fails. The first failure is thrown, with later ones
     * attached as suppressed exceptions.
     *
     * <p>The writer uses Lucene's default configuration apart from its open mode, so
     * {@link IndexWriter#close()} commits pending changes.
     */
    @Override
    public void close() throws IOException {
        IOException failure = null;
        for (AutoCloseable resource : new AutoCloseable[] {reader, writer, analyzer, directory}) {
            if (resource == null) {
                continue;
            }
            try {
                resource.close();
            } catch (Exception e) {
                IOException wrapped = e instanceof IOException ? (IOException) e : new IOException(e);
                if (failure == null) {
                    failure = wrapped;
                } else {
                    failure.addSuppressed(wrapped);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /** Best-effort cleanup during a failed construction; close errors are attached to {@code primary}. */
    private static void closeAfterFailure(Throwable primary, AutoCloseable... resources) {
        for (AutoCloseable resource : resources) {
            try {
                resource.close();
            } catch (Exception suppressed) {
                primary.addSuppressed(suppressed);
            }
        }
    }
}