
Build an AI system for GN

Tags

  • type: feature
  • assigned: johannesm
  • priority: medium
  • status: in progress
  • keywords: llm, rag, ai, agent

Description

The aim is to build an AI system/agent/RAG able to digest mapping results and metadata in GN so that analysis can scale. This is not quite possible at the moment, because one still needs to dig into that type of information and compare it manually. And the data in GN is rather big for such an approach :)

I attempted to use deep learning for my Master's project. It could work, but it required further processing of the results for interpretation. Not quite handy! Instead, we want a system that takes care of all the work (or at least most of it) and that we can understand. This is how transformers and LLMs came into the picture.

This work is an extension of the GNQA system initiated by Shelby and Pjotr.

Tasks

  • [X] Look for transformer model ready for use and try
  • [X] Build a RAG system and test with small corpus of mapping results
  • [X] Experiment with actual mapping results and metadata
  • [X] Move from RAG to agent
  • [X] Optimize AI system
  • [] Scale analysis to more data
  • [] Compare performance of open LLMs with Claude in the system

Look for transformer model ready for use and try

Given the success of transformers, Pjotr first encouraged me to look for a model that can support different types of data, i.e. numerical (mapping results) vs textual (metadata).

I found TAPAS which:

  • takes data of different types in tabular format
  • takes a query or question in form of text
  • performs operations on rows of the data table
  • retrieves relevant information
  • returns an answer to the original query

Experiments were ongoing when Rob found, with the help of Claude, that this architecture would not go far. I know, we used an AI to assist our work on AI (at least we did not ask an AI to do the job from the start :)). But it was a good point: TAPAS is relatively old, and a lot of progress has been made with LLMs and agents since!

To take advantage of all the progress made with LLMs, we needed a way to work with text data only. LLMs are trained to understand and work with text. The metadata, being RDF, is already in text format, so I only needed to convert the mapping results to text. It is a detour worth taking if it gives more flexibility and saves development time!
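Here is a minimal sketch of that conversion. It assumes each mapping result arrives as a record with a trait, locus, LOD score, effect size and allele frequency. The `MappingResult` fields are hypothetical and are not the actual GEMMA export columns. A fixed sentence template is enough to produce text an LLM can read. The example values come from one of the retrieved chunks shown further down this page.

```python
from dataclasses import dataclass

@dataclass
class MappingResult:
    # Hypothetical fields; real GEMMA/GN exports may name these differently.
    trait: str
    locus: str
    lod: float
    effect: float
    af: float

def to_sentence(r: MappingResult) -> str:
    """Render one association result as a plain-English sentence."""
    return (f"The trait {r.trait} has a LOD score of {r.lod:.2f} at locus {r.locus}, "
            f"with an effect size of {r.effect:.3f} and an allele frequency of {r.af:.3f}.")

rows = [MappingResult("BXDPublish_12641_gemma_GWA", "Rsm10000000185", 4.14, -3.137, 0.556)]
corpus = [to_sentence(r) for r in rows]
print(corpus[0])
```

A template like this is deterministic and cheap. The trade-off is that every new field needs a template change.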

Build a RAG system and test with a small corpus of mapping results

I have read a number of books and found that RAG systems are pretty easy to design with LangChain. A RAG system is made of two components:

  • search and retrieval -> need a keyword search algorithm or embedding model
  • response generation -> need an LLM

The system takes a query and returns a response based on the retrieved information. This is very effective at increasing the factuality of LLM responses!
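The retrieve-then-generate flow fits in a few lines. This is a toy illustration, not the LangChain pipeline. The hypothetical `keyword_retrieve` ranks documents by the words they share with the query. `generate` is any callable standing in for the LLM, so the prompt it receives can be inspected.

```python
from typing import Callable

def keyword_retrieve(query: str, docs: list[str], k: int = 2) -> list[str]:
    """Toy retriever: rank documents by how many query words they share."""
    q = set(query.lower().split())
    ranked = sorted(docs, key=lambda d: len(q & set(d.lower().split())), reverse=True)
    return ranked[:k]

def rag_answer(query: str, docs: list[str], generate: Callable[[str], str], k: int = 2) -> str:
    """Retrieve context, then ask the generator to answer using only that context."""
    context = "\n".join(keyword_retrieve(query, docs, k))
    prompt = (
        "Answer the question using only the information below.\n"
        f"Context:\n{context}\n"
        f"Question: {query}\nAnswer:"
    )
    return generate(prompt)

# Stand-in for an LLM call: echo the prompt so we can see what it received.
echo = lambda prompt: prompt
docs = ["Rsm10000000185 has a LOD score of 4.14",
        "Rsm10000001536 has a LOD score of 2.07",
        "Mice were on a high fat diet"]
out = rag_answer("LOD score of Rsm10000000185", docs, echo, k=1)
print(out)
```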

I wrote a first draft with:

1. BM25 keyword search algorithm and Qwen3-Embedding-0.6B model: this is a hybrid search subsystem. It searches both by keyword and by meaning, and should be more scalable. The subsystem:

  • takes the query
  • processes it
  • extracts the most similar documents or data chunks
  • passes the results to the LLM for further processing
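For the keyword half of the hybrid search, BM25 scores each document on three things:

- how often the document contains the query terms,
- how rare each term is across the corpus,
- how long the document is (scores are normalized by length).

The sketch below is a from-scratch Okapi BM25 with the common defaults `k1=1.5` and `b=0.75`. It uses a +1-smoothed IDF, which is one of several IDF variants in use. It is an illustration, not necessarily the variant used by the retriever in this system. Exact identifiers such as marker names are where keyword search shines. The embedding model covers paraphrased questions.

```python
import math
from collections import Counter

def bm25_scores(query: str, docs: list[str], k1: float = 1.5, b: float = 0.75) -> list[float]:
    """Okapi BM25 score of each document for the query (IDF with +1 smoothing)."""
    tokenized = [d.lower().split() for d in docs]
    n_docs = len(tokenized)
    avgdl = sum(len(t) for t in tokenized) / n_docs
    # Document frequency: in how many documents each term appears.
    df = Counter(term for t in tokenized for term in set(t))
    scores = []
    for tokens in tokenized:
        tf = Counter(tokens)
        score = 0.0
        for term in query.lower().split():
            if term not in tf:
                continue
            idf = math.log((n_docs - df[term] + 0.5) / (df[term] + 0.5) + 1)
            # Term-frequency saturation (k1) and document-length normalization (b).
            norm = tf[term] * (k1 + 1) / (tf[term] + k1 * (1 - b + b * len(tokens) / avgdl))
            score += idf * norm
        scores.append(score)
    return scores

docs = ["marker D12mit280 LOD 5.21", "marker Rsm10000000559 LOD 4.61", "liver lipidomics phenotype"]
scores = bm25_scores("D12mit280", docs)
best = max(range(len(docs)), key=scores.__getitem__)
print(docs[best])
```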

2. Qwen2.5-72B-Instruct LLM: this is the response subsystem. It:

  • takes the original query
  • takes retrieved information
  • reasons using a tree-of-thought mechanism for validation
  • generates an answer to the query using the retrieved information

Small bonus: I suspected that more than one query might be required to address some questions. I needed a way to make the system conversational, so that one could ask many questions without the information being lost in between. This came down to adding memory to the system, so I added a small model for memory propagation:

3. Phi-3-mini-4k-instruct: it uses this model and the ConversationSummary class from LangChain to summarize the key information in each interaction, so that it remains accessible between sessions.
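The memory idea reduces to keeping one running summary instead of the whole transcript. This is a minimal sketch, assuming a `summarize` callable that wraps the small model. The `SummaryMemory` class is hypothetical and is not LangChain's API. The stand-in summarizer only truncates, so the example runs without a model.

```python
from typing import Callable

class SummaryMemory:
    """Hypothetical running-summary memory: one summary instead of a transcript."""

    def __init__(self, summarize: Callable[[str], str]):
        self.summarize = summarize  # e.g. a call to a small model like Phi-3-mini
        self.summary = ""

    def context(self) -> str:
        """Text to include as chat history in the next prompt."""
        return self.summary

    def save(self, question: str, answer: str) -> None:
        """Fold the latest exchange into the running summary."""
        transcript = f"{self.summary}\nUser: {question}\nAssistant: {answer}".strip()
        self.summary = self.summarize(transcript)

# Stand-in summarizer so the example runs without a model: keep the last 200 chars.
memory = SummaryMemory(lambda text: text[-200:])
memory.save("What is the LOD at position 100?", "The LOD score at position 100 is 1.")
memory.save("Which position was examined?", "Position 100.")
print(memory.context())
```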

Time to test!

  • Query 1
Question: What is the lod for trait leptin receptor EPFLMouseLiverCDEx0413 at position 100?

Response:
**Expert 1:**
The question asks for the LOD score at position 100. From the relevant information, I see that there is a LOD score of 1 at position 100.

**Expert 2:**
I agree with Expert 1. The relevant information clearly states that the LOD score at position 100 is 1.

**Expert 3:**
I concur with both Experts 1 and 2. The LOD score for the trait leptin receptor EPFLMouseLiverCDEx0413 at position 100 is indeed 1.

**Conclusion:**
All experts agree that the LOD score for the trait leptin receptor EPFLMouseLiverCDEx0413 at position 100 is 1. <|end|>

Since the corpus used for testing is small, I could inspect it manually to confirm the accuracy of the response generated by the RAG. The lookup revealed that it was correct, yay!

  • Query 2
Question: Which position of leptin receptor EPFLMouseLiverCDEx0413 was examined in the previous prompt?

Response:
**Expert 1:**
The previous prompt asked for the LOD score at position 100. Therefore, the position examined in the previous prompt is 100.

**Expert 2:**
I agree with Expert 1. The previous prompt specifically asked about the LOD score at position 100. Thus, the position examined is 100.

**Expert 3:**
I concur with both Experts 1 and 2. The previous prompt clearly specified that it was asking for the LOD score at position 100. Therefore, the position examined in the previous prompt is 100.

**Conclusion:**
All experts agree that the position of leptin receptor EPFLMouseLiverCDEx0413 examined in the previous prompt is 100. <|end|>

This was also correct. I was very excited, and so was Rob!

I received feedback that made me clarify the meaning of position. I rebuilt the small corpus using a BXD dataset of traits I ran GEMMA on for my Master's project. This time I made sure to use actual marker positions and to include marker names as well.

Let's experiment again!

I got results similar to the ones above, except that the marker positions are now the true ones and marker names are supported.

I faced a challenge though :(

For queries that require combining different data chunks or documents (non-atomic queries), the system does not perform well. For example, for the query

  • How many traits hepatic nuclear factor 4 are in the datasets?

the system was confused. Even after prompt engineering, the generated answer was not accurate.

  • Identify 2 traits that have similar lod values on chromosome 1 position 3010274

The system sometimes missed them or caught only one trait with a LOD value at that position.

This is probably because the system cannot execute more than one retrieval run. To get there, I needed to make the RAG more autonomous; this is how the concept of an agent came up.
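The missing capability can be sketched as a loop: retrieve, keep what is relevant, exclude documents already seen, and repeat until enough results are found or a round limit is hit. Everything here is hypothetical (`iterative_retrieve`, the paging retriever, the relevance check). `target` and `max_rounds` play the same role as the `target` and `max_iterations` settings in the agent code later on this page.

```python
from typing import Callable

def iterative_retrieve(query: str,
                       retrieve: Callable[[str, set[str]], list[str]],
                       is_relevant: Callable[[str], bool],
                       target: int = 3,
                       max_rounds: int = 5) -> list[str]:
    """Retrieve in rounds, excluding documents already seen, until `target`
    relevant documents are collected or `max_rounds` rounds have run."""
    seen: set[str] = set()
    relevant: list[str] = []
    for _ in range(max_rounds):
        new_docs = [d for d in retrieve(query, seen) if d not in seen]
        if not new_docs:
            break  # nothing new left to retrieve
        for doc in new_docs:
            seen.add(doc)
            if is_relevant(doc):
                relevant.append(doc)
        if len(relevant) >= target:
            break
    return relevant[:target]

# Toy corpus and a retriever that pages past excluded documents, two at a time
# (it ignores the query text; a real retriever would rank by it).
corpus = [f"trait_{i} LOD {lod}" for i, lod in enumerate([4.5, 2.1, 4.1, 1.0, 4.6, 3.9])]

def page_retriever(query: str, exclude: set[str], k: int = 2) -> list[str]:
    return [d for d in corpus if d not in exclude][:k]

hits = iterative_retrieve("LOD > 4.0", page_retriever,
                          lambda d: float(d.split()[-1]) > 4.0)
print(hits)
```

The retriever must receive the set of seen documents. A retriever asked the same question twice returns the same top-k results.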

Experiment with actual mapping results and metadata

Building an agent required more reading. In the meantime, I decided to get actual mapping results and metadata for experimentation. It would be sad to proceed only to find that the system is not compatible with the data used in production :)

I waited for Pjotr to precompute GEMMA association results and export them with metadata to an endpoint. The RDF schema was very interesting to learn, and Bonz did some work on it in the past :)

You can check out recent developments of Pjotr's work here:

For Bonz's work, see:

Anyway, it took some time, but I finally got a glimpse of the data.

This started with the metadata from an old endpoint created by Bonz. I also had to learn SPARQL, which was quite new to me!
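For readers equally new to SPARQL, here is a minimal standard-library sketch of querying an endpoint over HTTP and flattening the results. It sends the query as the `query` parameter, per the SPARQL 1.1 Protocol, and parses the SPARQL 1.1 JSON results format. The endpoint URL is a parameter because the actual GN endpoint address is not given here. `sparql_select` and `flatten_bindings` are hypothetical names.

```python
import json
import urllib.parse
import urllib.request

def sparql_select(endpoint: str, query: str, timeout: float = 30.0) -> list[dict[str, str]]:
    """Run a SELECT query with HTTP GET and return flattened result rows."""
    url = endpoint + "?" + urllib.parse.urlencode({"query": query})
    req = urllib.request.Request(url, headers={"Accept": "application/sparql-results+json"})
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return flatten_bindings(json.load(resp))

def flatten_bindings(results: dict) -> list[dict[str, str]]:
    """Turn SPARQL 1.1 JSON results into plain {variable: value} rows."""
    return [{var: cell["value"] for var, cell in row.items()}
            for row in results["results"]["bindings"]]

# Offline example of the JSON shape an endpoint returns.
sample = {"head": {"vars": ["trait", "locus"]},
          "results": {"bindings": [
              {"trait": {"type": "uri", "value": "http://genenetwork.org/id/traitBxd_20537"},
               "locus": {"type": "uri", "value": "http://genenetwork.org/id/Rsm10000002554"}}]}}
rows = flatten_bindings(sample)
print(rows)
```

For example, `sparql_select(endpoint, "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10")` would return up to ten rows as dictionaries.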

We thought LLMs could make sense of RDF data directly (it is still text), but it turns out they cannot. They recognize that it is RDF, but with all the URIs in between they start making mistakes quite quickly. Instead of using RDF natively, we decided to use LLMs to first convert RDF data (metadata or mapping results) to natural text before using it with the RAG system. The system should perform better that way, and we confirmed it!
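Part of what trips up the models is URI noise. A hypothetical rule-based baseline strips each URI down to its local name, which already yields short, readable statements. This is not the LLM-based approach described here. The triple below comes from the naturalization prompt used later on this page. The predicate wording stays mechanical (for example "has isReferencedBy"), and that is exactly where an LLM adds value.

```python
def local_name(term: str) -> str:
    """Shorten a URI to its last '#' or '/' segment; leave literals untouched."""
    if not term.startswith(("http://", "https://")):
        return term  # numbers and free text (e.g. "60% kCal/fat") stay as they are
    return term.rstrip("/").rsplit("#", 1)[-1].rsplit("/", 1)[-1]

def naturalize_triple(subject: str, predicate: str, obj: str) -> str:
    """Render one RDF triple as a short statement built from local names."""
    return f"{local_name(subject)} has {local_name(predicate)} {local_name(obj)}"

sentence = naturalize_triple("http://genenetwork.org/id/traitBxd_20537",
                             "http://genenetwork.org/term/locus",
                             "http://genenetwork.org/id/Rsm10000002554")
print(sentence)
```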

Pjotr made the first version of the global endpoint available. Nothing should stop me now :) I wrote a script to fetch metadata from the endpoint. I have not shared my code so far, so let me fix that right now. You can follow this link for the script I was referring to above:

Pjotr also made available the ttl files in my home directory on balg01 - full flexibility!

I naturalized some RDF triples. The corpus now looked like this:

The phenotype identified as BXD_19926, or abbreviated as TAG(48:3)_HFD, is part of the EPFL LISP3 Study. It's a measurement of Liver Lipidomics: TAG48:3_HFD, which has side chain FA composition [TG12:0_18:1_18:2,TG14:0_16:0_18:3,TG14:0_16:1_18:2,TG14:0_16:2_18:1,TG14:1_16:0_18:2,TG14:1_16:1_18:1,TG16:0_16:1_16:2,TG16:1_16:1_16:1] belonging to the lipid class triacylglycerol TAG. This phenotype was measured by MS [R.U.AUC] in overnight fasted male mice at 29 weeks of age, who were on a high fat diet (60% kCal/fat HarlanTD.06414).
The phenotype was submitted by Jha P and contributed to by Auwerx J. It belongs to the group setBxd and has a sequence number of 1. The mean value for this phenotype is 0.00410489, with an additive effect of 0.00066854. The LOD score for this phenotype is 2.91779.
The locus associated with this phenotype is Rsm10000001653.

It was time to proceed to testing. The results I show below are not exactly from the RAG system I explained above, since I was improving the RAG in parallel. So watch out; I will explain everything in the next task :)

For now, you can have a look at the first implementation at:

The goal was to try a complex query that the previous system failed on. As you can see, the question is indeed not atomic.

Question: Identify traits having a lod score > 4.0

Response:
In addressing the query to identify traits with a LOD score exceeding 4.0, it's pertinent to highlight three mapped loci that meet these criteria: Rsm10000011038_BXDPublish_17086_gemma_GWA_19ca1aeb (LOD: 4.5), Rsm10000000185_BXDPublish_12641_gemma_GWA_8c8d0548 (LOD: 4.14) and Rsm10000000559 (LOD: 4.61). These findings underscore the significance of these loci in genetic studies, given their high LOD scores indicative of strong linkage evidence.

It is already better: at least the system combined information from different data or document chunks. The retrieved chunks were:

[[Document(metadata={}, page_content='\nThe mapped locus associated with Rsm10000001536 has a LOD score of 2.07, and the mapped SNP is identified as GEMMAMapped_LOCO_BXDPublish_24451_gemma_GWA_aecf628e. The effect size for this locus is 0.01, with an allele frequency (AF) of 0.333.\n                '),
Document(metadata={}, page_content='\nThe mapped locus Rsm10000011536 is associated with a LOD score of 5.69, an effect size of 0.385 and an allele frequency of 0.526. This locus has been mapped to the SNP GEMMAMapped_LOCO_BXDPublish_20320_gemma_GWA_6832c0e4.\n                '),
Document(metadata={}, page_content='\nThe mapped locus, Rsm10000000185_BXDPublish_12641_gemma_GWA_8c8d0548, has an effect size of -3.137 and a LOD score of 4.14. This locus is associated with the mapped SNP GEMMAMapped_LOCO_BXDPublish_12641_gemma_GWA_8c8d0548, and it has an allele frequency of 0.556.\n                '),
Document(metadata={}, page_content='\nIn plain English, this data refers to a mapped locus associated with the Rsm10000011038_BXDPublish_17086_gemma_GWA_19ca1aeb identifier. This locus is linked to the Rsm10000011038 identifier, has an effect size of -0.048, a LOD score of 4.5, and an allele frequency (AF) of 0.167. The mapped SNP associated with this data can be found under the GEMMAMapped_LOCO_BXDPublish_17086_gemma_GWA_19ca1aeb identifier.\n                '),
Document(metadata={}, page_content='\n                In plain English, the data describes a genetic locus identified as Rsm10000000559. This locus was mapped through an effect size of -34.191, with an allele frequency of 0.438. The mapping achieved a LOD score of 4.61, indicating the statistical significance of this genetic association. The mapped locus is associated with a specific SNP (Single Nucleotide Polymorphism) identified as GEMMAMapped_LOCO_BXDPublish_12016_gemma_GWA_bc6adcae.\n                ')]]
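The chunks state LOD scores in a regular pattern, so a deterministic check can verify this kind of answer. This is a minimal sketch, assuming the phrasing "LOD score of <number>" seen in the chunks above. The regex and `loci_above` are hypothetical. Run on those chunks, it also flags Rsm10000011536 (LOD 5.69), which the generated answer does not list.

```python
import re

LOD_PATTERN = re.compile(r"LOD score of (\d+(?:\.\d+)?)")

def loci_above(chunks: list[str], threshold: float) -> list[tuple[str, float]]:
    """Return (chunk, LOD) pairs whose stated LOD score exceeds the threshold."""
    hits = []
    for chunk in chunks:
        match = LOD_PATTERN.search(chunk)
        if match and float(match.group(1)) > threshold:
            hits.append((chunk, float(match.group(1))))
    return hits

# Excerpts of the five retrieved chunks above.
chunks = [
    "The mapped locus associated with Rsm10000001536 has a LOD score of 2.07",
    "The mapped locus Rsm10000011536 is associated with a LOD score of 5.69",
    "The mapped locus, Rsm10000000185_BXDPublish_12641_gemma_GWA_8c8d0548, "
    "has an effect size of -3.137 and a LOD score of 4.14.",
    "This locus is linked to the Rsm10000011038 identifier, has an effect size of -0.048, "
    "a LOD score of 4.5",
    "The mapping achieved a LOD score of 4.61",
]
lods = [lod for _, lod in loci_above(chunks, 4.0)]
print(lods)
```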

Move from RAG to agent

This is where I made the system more autonomous, i.e. agentic. I will now explain how I did it. I read a couple of sources and found that a RAG system built with LangChain could be made agentic with LangGraph. LangGraph creates a graph structure that splits the task among different nodes or agents. Each agent achieves a specific subtask, and a final node manages the integration.
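Stripped of the framework, the control flow is simple. Each node takes the state and returns an updated state, and a router decides which node runs next. The sketch below is plain Python, not LangGraph; `run_graph` and the toy nodes are hypothetical. LangGraph's `StateGraph` adds typed state, conditional edges and async execution on top of the same idea.

```python
from typing import Callable

State = dict
Node = Callable[[State], State]

def run_graph(nodes: dict[str, Node], route: Callable[[State], str],
              state: State, max_steps: int = 50) -> State:
    """Run nodes until the router says "end"; the step budget guards against loops."""
    for _ in range(max_steps):
        next_node = route(state)
        if next_node == "end":
            return state
        state = nodes[next_node](state)
    raise RuntimeError("graph did not terminate within max_steps")

# Toy nodes: each returns a new state and names the node that should run next.
nodes = {
    "retrieve": lambda s: {**s, "docs": s["docs"] + [f"doc{len(s['docs'])}"], "next": "analyze"},
    "analyze": lambda s: {**s, "answer": f"{len(s['docs'])} docs read", "next": "check"},
    "check": lambda s: {**s, "next": "end" if len(s["docs"]) >= 2 else "retrieve"},
}
final = run_graph(nodes, lambda s: s["next"], {"docs": [], "next": "retrieve"})
print(final["answer"])
```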

Check out this commit to see the results:

You can clearly see differences between *rag_langchain.py* and *rag_langgraph.py*
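The rag_langgraph.py excerpts on this page omit their imports and the `State` type. They need at least `asyncio`, `json`, `time` and `typing.Any`, plus `StateGraph`, `START` and `END` from `langgraph.graph`. Below is a plausible `State`, inferred only from the keys the nodes read and write; the actual definition may differ.

```python
from typing import Any, TypedDict

class State(TypedDict, total=False):
    """Shared graph state; keys inferred from what the nodes read and write."""
    input: str                   # the user question
    context: list[Any]           # raw documents returned by the retriever
    digested_context: list[str]  # naturalized (plain-English) versions of `context`
    seen_documents: list[str]    # page contents already rejected, to avoid re-retrieval
    answer: str                  # latest generated answer or summary
    chat_history: list[str]      # running bullet-point summaries
    should_continue: str         # name of the next node chosen by a node
    iterations: int              # retrieval rounds so far
    max_iterations: int          # hard cap on retrieval rounds
    result_count: int            # answers judged relevant so far
    target: int                  # relevant answers wanted before stopping

s: State = {"input": "What is the lod score of BXDPublish_10187_gemma_GWA at D12mit280?"}
```

`total=False` matches the way the nodes read every key with `state.get(...)` and a default.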

Basically,

def ask_question(self, question: str):
        start=time.time()
        memory_var=self.memory.load_memory_variables({})
        chat_history=memory_var.get('chat_history', '')
        result=self.retrieval_chain.invoke(
            {'question': question,
             'input': question,
             'chat_history': chat_history})
        answer=result.get("answer")
        citations=result.get("context")
        self.memory.save_context(
            {'input': question},
            {'answer': answer})
        # Close LLMs
        GENERATIVE_MODEL.client.close()
        SUMMARY_MODEL.client.close()
        end=time.time()
        print(f'ask_question: {end-start}')
        return {
            "question": question,
            "answer": answer,
            "citations": citations,
        }

became:

    def retrieve(self, state: State) -> dict:
        # Define graph node for retrieval.
        # Retrievers match text; they do not follow instructions. Query with
        # the raw question and drop already-seen documents in code instead.
        seen = set(state.get("seen_documents", []))
        retrieved_docs = [doc for doc in
                          self.ensemble_retriever.invoke(state["input"])
                          if doc.page_content not in seen]
        return {"input": state["input"],
                "context": retrieved_docs,
                "digested_context": state.get("digested_context", []),
                "result_count": state.get("result_count", 0),
                "target": state.get("target", 3),
                "max_iterations": state.get("max_iterations", 5),
                "should_continue": "naturalize",
                "iterations": state.get("iterations", 0) + 1, # Add one per run
                "chat_history": state.get("chat_history", []),
                "answer": state.get("answer", ""),
                "seen_documents": state.get("seen_documents", [])}

    def manage(self, state:State) -> dict:
        # Define graph node for task orchestration
        context = state.get("context", [])
        digested_context = state.get("digested_context", [])
        answer = state.get("answer", "")
        iterations = state.get("iterations", 0)
        chat_history = state.get("chat_history", [])
        result_count = state.get("result_count", 0)
        target = state.get("target", 3)
        max_iterations = state.get("max_iterations", 5)
        should_continue = state.get("should_continue", "retrieve")
        # Orchestration logic
        if iterations >= max_iterations or result_count >= target:
            should_continue = "summarize"
        elif should_continue == "retrieve":
            # Reset fields
            context = []
            digested_context = []
            answer = ""
        elif should_continue == "naturalize" and not context:
            should_continue = "retrieve"  # Can't naturalize without context
            context = []
            digested_context = []
            answer = ""
        elif should_continue == "analyze" and \
             (not context or not digested_context):
            should_continue = "retrieve"  # Can't analyze without context
            context = []
            digested_context = []
            answer = ""
        elif should_continue == "check_relevance" and not answer:
            should_continue = "analyze"  # Can't check relevance without answer
        elif should_continue not in ["retrieve", \
                "naturalize", "check_relevance", "analyze", "summarize"]:
            should_continue = "summarize"  # Fallback
        return {"input": state["input"],
                "should_continue": should_continue,
                "result_count": result_count,
                "target": target,
                "iterations": iterations,
                "max_iterations": max_iterations,
                "context": context,
                "digested_context": digested_context,
                "chat_history": chat_history,
                "answer": answer,
                "seen_documents": state.get("seen_documents", [])}

    def analyze(self, state:State) -> dict:
        # Define graph node for analysis and text generation
        context = "\n".join(state.get("digested_context", []))
        existing_history="\n".join(state.get("chat_history", [])) \
            if state.get("chat_history") else ""
        iterations = state.get("iterations", 0)
        max_iterations = state.get("max_iterations", 5)
        result_count = state.get("result_count", 0)
        target = state.get("target", 3)
        if not context: # Cannot proceed without context
            should_continue = "summarize" if iterations >= max_iterations \
                or result_count >= target else "retrieve"
            response = ""
        else:
            prompt = f"""
             <|im_start|>system
             You are an experienced analyst that can use available information
             to provide accurate and concise feedback.
             <|im_end|>
             <|im_start|>user
             Answer the question below using following information.
             Context: {context}
             History: {existing_history}
             Question: {state["input"]}
             Answer:
             <|im_end|>
             <|im_start|>assistant"""
            response = GENERATIVE_MODEL.invoke(prompt)
            if not response or not isinstance(response, str) or \
                    response.strip() == "": # Need valid generation
                should_continue = "summarize" if iterations >= max_iterations \
                    or result_count >= target else "retrieve"
                response = ""  # Ensure a clean state
            else:
                should_continue = "check_relevance"
        return {"input": state["input"],
                "answer": response,
                "should_continue": should_continue,
                "context": state.get("context", []),
                "digested_context": state.get("digested_context", []),
                "iterations": iterations,
                "max_iterations": max_iterations,
                "result_count": result_count,
                "target": target,
                "chat_history": state.get("chat_history", []),
                "seen_documents": state.get("seen_documents", [])}

    
    def summarize(self, state:State) -> dict:
        # Define node for summarization
        existing_history = state.get("chat_history", [])
        current_interaction=f"""
            User: {state["input"]}\nAssistant: {state["answer"]}"""
        full_context = "\n".join(existing_history) + "\n" + \
            current_interaction if existing_history else current_interaction
        result_count = state.get("result_count", 0)
        target = state.get("target", 3)
        iterations = state.get("iterations", 0)
        max_iterations = state.get("max_iterations", 5)
        prompt = f"""
            <|system|>
            You are an excellent and concise summary maker.
            <|end|>
            <|user|>
            Summarize in bullet points the conversation below.
            Follow this format: input - answer
            Conversation: {full_context}
            <|end|>
            <|assistant|>"""
        summary = GENERATIVE_MODEL.invoke(prompt) # central task
        # Validate before calling .strip(), which would fail on a non-string
        if not isinstance(summary, str) or not summary.strip():
            summary = f"- {state['input']} - No valid answer generated"
        summary = summary.strip()
        should_continue="end" if result_count >= target or \
            iterations >= max_iterations else "retrieve"
        updated_history = existing_history + [summary] # update chat_history
        print(f"\nChat history in summarize: {updated_history}")
        return {"input": state["input"],
                "answer": summary,
                "should_continue": should_continue,
                "context": state.get("context", []),
                "digested_context": state.get("digested_context", []),
                "iterations": iterations,
                "max_iterations": max_iterations,
                "result_count": result_count,
                "target": target,
                "chat_history": updated_history,
                "seen_documents": state.get("seen_documents", [])}

    def check_relevance(self, state:State) -> dict:
        # Define node to check relevance of retrieved data
        context = "\n".join(state.get("digested_context", []))
        result_count = state.get("result_count", 0)
        target = state.get("target", 3)
        iterations = state.get("iterations", 0)
        max_iterations = state.get("max_iterations", 5)
        seen_documents = state.get("seen_documents", [])
        prompt = f"""
            <|system|>
            You are an expert in evaluating data relevance. You do it seriously.
            <|end|>
            <|user|>
            Assess if the provided answer is relevant to the query.
            Return only yes or no. Nothing else.
            Answer: {state["answer"]}
            Query: {state["input"]}
            Context: {context}
            <|end|>
            <|assistant|>"""
        assessment = GENERATIVE_MODEL.invoke(prompt).strip().lower()
        if assessment.startswith("yes"): # tolerate "Yes", "yes." and similar
            result_count = result_count + 1
            should_continue = "summarize"
        elif result_count >= target or iterations >= max_iterations:
            should_continue = "summarize"
        else:
            should_continue = "retrieve"
            seen_documents.extend([doc.page_content for doc in \
                state.get("context", [])])
        return {"input": state["input"],
                "context": state.get("context", []),
                "digested_context": state.get("digested_context", []),
                "iterations": iterations,
                "max_iterations": max_iterations,
                "answer": state["answer"],
                "result_count": result_count,
                "target": target,
                "seen_documents": seen_documents,
                "chat_history": state.get("chat_history", []),
                "should_continue": should_continue}
        
    def route_manage(self, state: State) -> str:
            should_continue = state.get("should_continue", "retrieve")
            iterations = state.get("iterations", 0)
            max_iterations = state.get("max_iterations", 5)
            result_count = state.get("result_count", 0)
            target = state.get("target", 3)
            context = state.get("context", [])
            digested_context = state.get("digested_context", [])
            answer = state.get("answer", "")
            # Validate state and enforce termination
            if iterations >= max_iterations or result_count >= target:
                return "summarize"
            if should_continue not in ["retrieve", "naturalize", \
                    "check_relevance", "analyze", "summarize"]:
                return "summarize"  # Fallback to summarize
            return should_continue

    def initialize_langgraph_chain(self) -> Any:
        graph_builder = StateGraph(State)
        graph_builder.add_node("manage", self.manage)
        graph_builder.add_node("retrieve", self.retrieve)
        graph_builder.add_node("naturalize", self.naturalize)
        graph_builder.add_node("check_relevance", self.check_relevance)
        graph_builder.add_node("analyze", self.analyze)
        graph_builder.add_node("summarize", self.summarize)
        graph_builder.add_edge(START, "manage")
        graph_builder.add_edge("retrieve", "naturalize")
        graph_builder.add_edge("naturalize", "analyze")
        graph_builder.add_edge("analyze", "check_relevance")
        graph_builder.add_edge("check_relevance", "manage")
        graph_builder.add_edge("summarize", END)
        graph_builder.add_conditional_edges(
            "manage",
            self.route_manage,
            {"retrieve": "retrieve",
             "naturalize": "naturalize",
             "check_relevance": "check_relevance",
             "analyze": "analyze",
             "summarize": "summarize"})
        graph=graph_builder.compile()
        return graph

    async def invoke_langgraph(self, question: str) -> Any:
        graph = self.initialize_langgraph_chain()
        initial_state = {
            "input": question,
            "chat_history": [],
            "context": [],
            "digested_context": [],
            "seen_documents": [],
            "answer": "",
            "iterations": 0,
            "result_count": 0,
            "should_continue": "retrieve",
            "target": 3,  # stop once 3 answers have been judged relevant
            "max_iterations": 5 # hard cap on retrieval rounds (iterations)
        }
        result = await graph.ainvoke(initial_state) # Run graph asynchronously
        return result

    
    def answer_question(self, question: str) -> Any:
        start = time.time()
        result = asyncio.run(self.invoke_langgraph(question))
        end = time.time()
        print(f'answer_question: {end-start}')
        return {"result": result["chat_history"],
                "state": result}

As mentioned above, we quickly spotted the need for the naturalization of RDF triples. This explains the addition of a naturalization node to the graph:

def naturalize(self, state: State) -> dict:
        # Define graph node for RDF naturalization
        prompt = f"""
        <|im_start|>system
        You are extremely good at naturalizing RDF and inferring meaning.
        <|im_end|>
        <|im_start|>user
        Take element in the list of RDF triples one by one and
        make it sounds like Plain English. Repeat for each the subject
        which is at the start. You should return a list. Nothing else.
        List: ["Entity http://genenetwork.org/id/traitBxd_20537 \
        \nhas http://purl.org/dc/terms/isReferencedBy of \
        http://genenetwork.org/id/unpublished22893", "has \
        http://genenetwork.org/term/locus of \
        http://genenetwork.org/id/Rsm10000002554"]
        <|im_end|>
        <|im_start|>assistant
        New list: ["traitBxd_20537 isReferencedBy unpublished22893", \
        "traitBxd_20537 has a locus Rsm10000002554"]
        <|im_end|>
        <|im_start|>user
        Take element in the list of RDF triples one by one and
        make it sounds like Plain English. Repeat for each the subject
        which is at the start. You should return a list. Nothing else.
        List: {state.get("context", [])}
        <|im_end|>
        <|im_start|>assistant"""
        response = GENERATIVE_MODEL.invoke(prompt)
        print(f"Response in naturalize: {response}")
        if isinstance(response, str):
            start=response.find("[")
            end=response.rfind("]") + 1 # offset by 1 to include "]"
            try:
                response=json.loads(response[start:end])
            except json.JSONDecodeError: # no parsable JSON list in the output
                response=[]
        else:
            response=[]
        return {"input": state["input"],
                "context": state.get("context", []),
                "digested_context": response,
                "result_count": state.get("result_count", 0),
                "target": state.get("target", 3),
                "max_iterations": state.get("max_iterations", 5),
                "should_continue": "analyze",
                "iterations": state.get("iterations", 0),
                "chat_history": state.get("chat_history", []),
                "answer": state.get("answer", ""),
                "seen_documents": state.get("seen_documents", [])}

The next step was to compare the RAG system (rag_langchain.py) to the agent system (rag_langgraph.py) on some queries.

  • Query 1
Question: What is the lod score of BXDPublish_10187_gemma_GWA at D12mit280?

RAG response: 3.01245

Agent response: 5.21

Hmm, the result of the RAG system is dubious: I did not expect such a long floating-point number for a LOD score. Fortunately, Pjotr's endpoint is handy for quick queries against the RDF data. By the way, this shows how complementary SPARQL queries and LLM calls are. Going through RDF was a very wise choice :)

Using the endpoint, I fetched all LOD scores for BXDPublish_10187_gemma_GWA at D12mit280, and only the result of the agent system appeared to be correct.
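To make RAG-versus-agent comparisons repeatable, answers can be scored against values fetched from the endpoint. This is a minimal sketch; `score_answer` is hypothetical. It accepts an answer if any number in it is within 1% of the ground truth. The ground truth is taken to be the agent's 5.21, per the endpoint lookup above. Identifiers containing digits (such as D12mit280) can produce spurious numbers, so a real harness should strip them first or parse a dedicated answer field.

```python
import math
import re

NUMBER = re.compile(r"-?\d+(?:\.\d+)?")

def score_answer(answer: str, truth: float, rel_tol: float = 0.01) -> bool:
    """True if any number in the answer is within rel_tol of the ground truth."""
    return any(math.isclose(float(x), truth, rel_tol=rel_tol)
               for x in NUMBER.findall(answer))

answers = {"rag": "3.01245", "agent": "5.21"}
truth = 5.21  # assumption: the value confirmed on the endpoint for this trait/marker
results = {system: score_answer(text, truth) for system, text in answers.items()}
print(results)
```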

  • Query 2
Question: I want you look for D12mit280. You are allowed to initiate many rounds of retrieval until you get 20 relevant results. Next, extract the lod score and trait for each result. List for me traits that have a lod score > 4.0. Join to the list the corresponding lod scores so I can confirm. Show results using the following format: trait - lod score

RAG response: It seems there's a misunderstanding. The provided data doesn't include any results related to "D12mit280". I can only provide information or perform tasks based on the given context and data. If you have another query or need clarification on something else, feel free to ask.

Agent response: It appears there are no results for D12mit280 with a lod score > 4.0 in the provided entities. If you have more data or different criteria, please let me know so I can assist further.

This was surprising. My hypothesis was that the retrieval node was not fetching relevant documents for the LLM to produce a sound answer. There is only one way to confirm this: check the retrieved documents.

Printing the documents after retrieval showed that I was right. I also noticed a couple of other problems in the execution of the program: nodes were mostly not executed in the order I expected. Consequently, I embarked on a long quest of incremental improvements :)
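One low-effort way to see the actual execution order is to wrap each node before registering it and record its name on every call. This is a hypothetical sketch; `traced` is not a LangGraph feature. In the agent it could be used as `graph_builder.add_node("retrieve", traced("retrieve", trace)(self.retrieve))`. The wrapper is also a natural place to print the retrieved documents.

```python
import functools
from typing import Any, Callable

def traced(name: str, trace: list[str]) -> Callable:
    """Decorator factory: append `name` to `trace` every time the node runs."""
    def decorator(node: Callable[[dict], Any]) -> Callable[[dict], Any]:
        @functools.wraps(node)
        def wrapper(state: dict) -> Any:
            trace.append(name)
            return node(state)
        return wrapper
    return decorator

trace: list[str] = []
retrieve = traced("retrieve", trace)(lambda s: {**s, "context": ["doc"]})
analyze = traced("analyze", trace)(lambda s: {**s, "answer": "ok"})
state = analyze(retrieve({"input": "q"}))
print(trace)
```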

Optimize AI system

A couple of things needed to be optimized. This included retrieval, node orchestration and GPU acceleration.

  • Retrieval

Let's start with retrieval. I played with different parameters of the retriever. It was an EnsembleRetriever using both keyword and semantic search as illustrated below:

ensemble_retriever = EnsembleRetriever(
            retrievers=[
                self.chroma_db.as_retriever(search_kwargs={"k": 10}),
                bm25_retriever,
            ],
            weights=[0.4, 0.6],
        )

I arrived at this selection by trying different combinations of weights. However, more rigorous work is needed to systematically identify the best hyperparameters for retrieval.
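LangChain documents `EnsembleRetriever` as merging its retrievers' rankings with weighted Reciprocal Rank Fusion. Each document's score is the sum over retrievers of `weight / (c + rank)`, with a constant `c` (60 by default). The sketch below re-implements that formula for illustration; it is not LangChain's code. It also adds a hypothetical weight sweep over a small labeled set, which is one way to make the weight choice more systematic. Weights are ordered as in the code above: dense (Chroma) first, keyword (BM25) second.

```python
from collections import defaultdict

def weighted_rrf(rankings: list[list[str]], weights: list[float], c: int = 60) -> list[str]:
    """Fuse ranked lists of document ids: score(d) = sum_i w_i / (c + rank_i(d))."""
    scores: dict[str, float] = defaultdict(float)
    for ranking, weight in zip(rankings, weights):
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += weight / (c + rank)
    return sorted(scores, key=scores.get, reverse=True)

def hit_rate(labeled, weights, k: int) -> float:
    """Fraction of queries whose relevant doc lands in the fused top-k.

    `labeled` holds (dense_ranking, keyword_ranking, relevant_id) triples.
    """
    hits = sum(relevant in weighted_rrf([dense, keyword], weights)[:k]
               for dense, keyword, relevant in labeled)
    return hits / len(labeled)

def sweep(labeled, k: int, steps: int = 11) -> tuple[float, float]:
    """Try dense/keyword weight pairs on a grid and keep the best (first on ties)."""
    grid = [(round(i / (steps - 1), 2), round(1 - i / (steps - 1), 2)) for i in range(steps)]
    return max(grid, key=lambda w: hit_rate(labeled, list(w), k))

fused = weighted_rrf([["a", "b", "c"], ["b", "d", "a"]], [0.4, 0.6])
print(fused)
```

In practice the labeled set would be questions with known answer chunks. The metric could be recall at the retriever's `k` rather than a single hit.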

  • Node orchestration

Moving on to node orchestration. It took me some time and reflection to realize that the nodes I had at that point only make sense when executed sequentially. Analysis (analyze node) should always be followed by relevance checking (check_relevance node) and then summarization (summarize node), in that order. Any other execution sequence is not useful. I had to modify the code to comply with this and to avoid getting into unnecessary loops :)

But this also highlighted other limitations of the system: a lack of flexibility and a lack of autonomy.

To address the lack of flexibility, I introduced a new node to split a query into multiple queries that can be solved independently and asynchronously. The node split_query works as follows:

def split_query(self, query: str) -> list[str]:

        prompt = f"""
            <|im_start|>system
            You are a very powerful task generator.
        
            Split the query into task and context based on tags.
            Based on the context, ask relevant questions that help achieve the task. Make sure the subquestions are atomic and do not rely on each other.
            Return only the subquestions.
            Return strictly a JSON list of strings, nothing else.
            <|im_end|>
            <|im_start|>user
            Query:
            Task: Identify traits with a lod score > 3.0 for the marker Rsm10000011643. Tell me what marker Rsm10000011643 is involved in biology.
            Context: A trait has a long name and contain generally strings like GWA or GEMMA. The goal is to know the biological processes which might be related to the marker previously mentioned.
        
            Result:
            <|im_end|>
            <|im_start|>assistant
            ["What traits (containing GWA or GEMMA) have a lod score > 3.0 at Rsm10000011643?", "Which biological processes are related to Rsm10000011643?"]
            <|im_end|>
            <|im_start|>user
            Query:
            {query}
            Result:
            <|im_end|>
            <|im_start|>assistant"""

        with self.generative_lock:
            response = GENERATIVE_MODEL.invoke(prompt)
        print(f"Subqueries in split_query: {response}")

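        # Fall back to the original query if no valid JSON list comes back.
        subqueries = [query]
        if isinstance(response, str):
            start = response.find("[")
            end = response.rfind("]") + 1
            if start != -1 and end > start:
                try:
                    parsed = json.loads(response[start:end])
                except json.JSONDecodeError:
                    parsed = None
                if isinstance(parsed, list) and parsed:
                    subqueries = [str(item) for item in parsed]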

        return subqueries
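The prompts in split_query and finalize are written by hand in the ChatML layout used by Qwen models (`<|im_start|>role ... <|im_end|>`), with one worked example as a few-shot demonstration. A small helper that assembles such a prompt from (role, content) pairs could make these templates easier to maintain. The sketch below is hypothetical (`build_chatml_prompt` is not part of the codebase); in practice the tokenizer's own chat template, e.g. `apply_chat_template` in Hugging Face transformers, is the safer source of truth.

```python
def build_chatml_prompt(messages: list[tuple[str, str]]) -> str:
    """Render (role, content) pairs in ChatML and open an assistant turn."""
    allowed = {"system", "user", "assistant"}
    parts = []
    for role, content in messages:
        if role not in allowed:
            raise ValueError(f"unknown role: {role}")
        parts.append(f"<|im_start|>{role}\n{content.strip()}<|im_end|>\n")
    # Leave the assistant turn open so the model writes the answer next.
    parts.append("<|im_start|>assistant\n")
    return "".join(parts)


few_shot = [
    ("system", "Split the query into atomic subquestions. "
               "Return strictly a JSON list of strings."),
    # Worked example: a user query followed by the expected assistant answer.
    ("user", "Query:\nIdentify two traits related to diabetes. "
             "Compare their lod scores at Rsm149505."),
    ("assistant", '["Identify two traits related to diabetes", '
                  '"Compare lod scores of same traits at Rsm149505"]'),
    # The real query comes last.
    ("user", "Query:\nWhich traits have a lod score > 3.0 at Rsm10000011643?"),
]
print(build_chatml_prompt(few_shot))
```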

Another node is needed to reconcile the answers generated for each subquery. This motivated the addition of the finalize node:

def finalize(self, query: str, subqueries: list[str], answers: list[str]) -> str:

        prompt = f"""
            <|im_start|>system
            You are an experienced biology scientist. Given the subqueries and corresponding answers, generate a comprehensive explanation to address the query using all information provided.
            Ensure the response is insightful, concise, and draws logical inferences where possible.
            Do not modify entities names such as trait and marker.            
            Make sure to link based on what is common in the answers.
            Provide only the story, nothing else.
            Do not repeat answers. Use only 200 words max.
            <|im_end|>
            <|im_start|>user
            Query:
            Identify two traits related to diabetes.
            Compare their lod scores at Rsm149505.
            Subqueries:
            ["Identify two traits related to diabetes",
            "Compare lod scores of same traits at Rsm149505"]
            Answers:
            ["Traits A and B are related to diabetes", \
            "The lod score at Rsm149505 is 2.3 and 3.4 for trait A and B"]
            Conclusion:
            <|im_end|>
            <|im_start|>assistant
            Traits A and B are related to diabetes and have a lod score of\
            2.3 and 3.4 at Rsm149505. The two traits could interact via a\
            gene close to the marker Rsm149505.
            <|im_end|>
            <|im_start|>user
            Query:
            {query}
            Subqueries:
            {subqueries}
            Answers:
            {answers}
            Conclusion:
            <|im_end|>
            <|im_start|>assistant"""
        with self.generative_lock:
            response = GENERATIVE_MODEL.invoke(prompt)
        print(f"Response in finalize: {response}")

        # Fall back to a fixed message if the model returned nothing.
        final_answer = (
            response
            if response
            else "Sorry, we are unable to provide an overall answer "
            "due to a lack of relevant data."
        )

        return final_answer

The system can now take a multi-faceted query, split it into subqueries, and address each of them asynchronously by running the retriever, analyze, check_relevance and summarize steps in sequence. The results are combined at the end before an answer is returned to the user.
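This fan-out and reconciliation can be sketched with a thread pool from Python's standard library. `answer_subquery` and `reconcile` below are hypothetical stand-ins for the retrieve, analyze, check_relevance, summarize chain and for finalize; the real system may schedule subqueries differently. Since the code above serializes model calls with `self.generative_lock`, threads presumably help most by overlapping the non-generative work such as retrieval.

```python
from concurrent.futures import ThreadPoolExecutor


def answer_subquery(subquery: str) -> str:
    # Hypothetical stand-in for retrieve -> analyze -> check_relevance -> summarize.
    return f"answer to: {subquery}"


def reconcile(query: str, subqueries: list[str], answers: list[str]) -> str:
    # Hypothetical stand-in for finalize: pair each subquery with its answer.
    lines = [f"{q} => {a}" for q, a in zip(subqueries, answers)]
    return f"{query}\n" + "\n".join(lines)


def answer_in_parallel(query: str, subqueries: list[str], max_workers: int = 4) -> str:
    """Answer independent subqueries concurrently, then combine the answers."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map returns results in input order, so answers line up with subqueries.
        answers = list(pool.map(answer_subquery, subqueries))
    return reconcile(query, subqueries, answers)


subqueries = [
    "What traits (containing GWA or GEMMA) have a lod score > 3.0 at Rsm10000011643?",
    "Which biological processes are related to Rsm10000011643?",
]
print(answer_in_parallel("Traits and biology of Rsm10000011643", subqueries))
```

Keeping the answers in the same order as the subqueries matters, because the reconciliation prompt lists them side by side.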

Time to make the system really agentic: so far it is not truly agentic because it lacks autonomy! An agentic system requires access to many tools and a core LLM that can reason on its own about which sequence of tools to call to solve a problem. This sounds scary, but it is manageable if well designed :) I also planned some safeguards to prevent infinite loops, which could consume a lot of tokens very quickly.

I registered the graph built so far as a subgraph of a bigger graph (the actual AI system). This arm of the AI system is called researcher and is defined as follows:

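import logging
from typing import Any


def researcher(self, state: AgentState) -> Any:
    # At the start of a run, work on the user's original query;
    # later, work on the most recent message (e.g. a suggestion from the reflector).
    if len(state.messages) < 3:
        message = state.messages[0]
    else:
        message = state.messages[-1]
    query = message.content
    logging.info(f"Input in researcher: {query}")
    # Delegate to the subgraph that splits, solves and reconciles subqueries.
    result = self.manage_subtasks(query)
    logging.info(f"Result in researcher: {result}")

    return {
        "messages": [result],
    }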

I also designed a planner, a reflector and a supervisor for the system to use. As their names indicate, the planner helps plan the steps needed to solve the problem, and the reflector provides feedback to improve the output of the researcher. The supervisor is the core handler: it manages interactions between the planner, researcher and reflector.

You can inspect the code for the planner, reflector and supervisor below:

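from typing import Any

from langchain_core.messages import AIMessage, HumanMessage


def planner(self, state: AgentState) -> Any:
    # Prepend the planning instructions to the conversation so far.
    messages = [self.plan_system_prompt] + state.messages
    result = plan(background=messages)
    answer = result.get("answer")
    return {
        "messages": [answer],
    }


def reflector(self, state: AgentState) -> Any:
    # Swap AI and human roles (keeping the original query as-is) so the
    # reflector reads the researcher's output as user input to critique.
    trans_map = {AIMessage: HumanMessage, HumanMessage: AIMessage}
    translated_messages = [self.refl_system_prompt, state.messages[0]] + [
        trans_map[msg.__class__](content=msg.content) for msg in state.messages[1:]
    ]
    result = tune(background=translated_messages)
    answer = result.get("answer")
    answer = (
        "Progress has been made. Now use all the resources to address this new suggestion: "
        + answer
    )
    # Returned as a human message so the next node treats it as a new request.
    return {
        "messages": [HumanMessage(answer)],
    }


def supervisor(self, state: AgentState) -> Any:
    messages = [
        ("system", self.sup_system_prompt1),
        *state.messages,
        ("system", self.sup_system_prompt2),
    ]

    # Safeguard against endless loops: stop once the conversation grows too long.
    if len(messages) > self.max_global_visits:
        return {"next": "end"}

    # Ask the LLM which node should run next.
    result = supervise(background=messages)
    next_node = result.get("next")

    return {
        "next": next_node,
    }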
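Two ideas in this code are worth isolating: the reflector swaps AI and human roles so that the model critiques the researcher's output as if a user had written it, and the supervisor stops once the conversation exceeds `max_global_visits`, which bounds token usage. The sketch below reproduces both in plain Python with a toy `Message` class and a scripted router in place of the `supervise` LLM call; all names here are hypothetical.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Message:
    role: str  # "human" or "ai"
    content: str


def swap_roles(messages: list[Message]) -> list[Message]:
    """Keep the first message (the user's query) and flip human/ai on the rest."""
    flip = {"human": "ai", "ai": "human"}
    return messages[:1] + [Message(flip[m.role], m.content) for m in messages[1:]]


Worker = Callable[[list[Message]], Message]


def run_supervised(
    query: str,
    decide: Callable[[list[Message]], str],
    workers: dict[str, Worker],
    max_visits: int = 6,
) -> tuple[list[Message], list[str]]:
    """Let `decide` pick the next worker until it returns "end" or the cap is hit."""
    messages = [Message("human", query)]
    route: list[str] = []
    # Hard cap: even a router that never says "end" cannot loop forever.
    while len(messages) <= max_visits:
        next_node = decide(messages)
        if next_node == "end":
            break
        route.append(next_node)
        messages.append(workers[next_node](messages))
    return messages, route


# Hypothetical workers standing in for the planner, researcher and reflector.
workers: dict[str, Worker] = {
    "planner": lambda msgs: Message("ai", "plan: research, then reflect"),
    "researcher": lambda msgs: Message("ai", f"findings for: {msgs[0].content}"),
    # The reflector sees the history with roles swapped, as in the code above.
    "reflector": lambda msgs: Message(
        "human", f"Progress has been made. Reviewed {len(swap_roles(msgs))} messages."
    ),
}

# A scripted router instead of an LLM call.
script = iter(["planner", "researcher", "reflector", "researcher", "end"])
messages, route = run_supervised("traits at Rsm149505", lambda m: next(script), workers)
print(route)  # ['planner', 'researcher', 'reflector', 'researcher']

# A router that never stops is cut off by the cap.
_, route = run_supervised("q", lambda m: "researcher", workers, max_visits=4)
print(len(route))  # 4
```

In the supervisor above, the cap compares the number of messages (including the two system prompts) with `max_global_visits`; the sketch simplifies this to the number of conversation messages.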

  • GPU acceleration

The last point is GPU acceleration. Pjotr installed a GPU on balg01 for this purpose. You can check out the details here:

The GPU is used automatically for LLM-related work, and at first I simply relied on that. Later, I learned about SGLang, which serves an LLM behind a local server and enables even faster inference. The code for deploying the server is here:

I also added support for DSPy, which makes it easy to switch between models, closed or open. For details, check out the following commit:

A small gotcha: with models served locally through SGLang, not every open model fits within the VRAM (the GPU's memory) constraint. It took me some experiments to find workable models that are fine-tuned for instruction following and have decent performance. At the time of writing, I am working with Qwen/Qwen2.5-7B-Instruct, accessed via Hugging Face. This is the LLM. There is also an embedding model, but I have not enabled GPU acceleration for it, to save GPU memory. We have limited resources for now :)
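A back-of-the-envelope estimate shows why model size matters here: the weights alone need roughly (number of parameters) × (bytes per parameter), before counting the KV cache and activations that an inference server also reserves memory for. The helper below is a rough sketch under that simplification; the parameter counts are nominal sizes, not measured footprints of any specific model.

```python
GIB = 1024 ** 3

# Bytes per parameter for common weight formats.
BYTES_PER_PARAM = {"fp32": 4.0, "bf16": 2.0, "fp16": 2.0, "int8": 1.0, "int4": 0.5}


def weight_memory_gib(n_params: float, dtype: str = "bf16") -> float:
    """Lower bound on VRAM for model weights only (no KV cache, activations or overhead)."""
    return n_params * BYTES_PER_PARAM[dtype] / GIB


for size in (7e9, 14e9, 32e9):
    print(f"{size / 1e9:.0f}B params in bf16: ~{weight_memory_gib(size):.1f} GiB")
```

A nominal 7B model in bf16 already needs about 13 GiB for its weights, and each doubling of the parameter count doubles that figure, which is why larger open models quickly stop fitting on a single GPU unless they are quantized.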

I also did a series of refactorings and formatting passes to improve the readability of the source code. Find them here:

Scale analysis to more data
