Retrieval¶
Key Takeaway¶
The key takeaway of RAG is that, if set up correctly, customers can customize LLMs for their applications without modifying any of their existing data storage practices. By building RAG pipelines, which are analogous to traditional ETL pipelines, data can simply be maintained at the source, and LLM applications will inherit this knowledge automatically.
How to do retrieval in SGP¶
In this section we dive right into things. For more information about what retrieval is and how it works, scroll to the FAQ below.
SGP Knowledge Bases¶
The first step of retrieval is to load custom data into a vector database. Vector databases allow users to search unstructured data using natural language queries. In SGP, our Knowledge Base API manages the entire lifecycle of data ingestion into vector databases on behalf of users.
This means that a user only has to manage data at the source. Simply create a knowledge base to reflect the source and periodically upload data to it.
SGP takes care of all of the underlying challenges so users don't have to, including:
- Vector Index Creation and Management
    - Optimize shard density and size for performance
    - Automatically create new indexes when optimal index sizes are exceeded
- Multiple Data Source Integrations
    - Supports Google Drive, S3, SharePoint, direct JSON upload, and more
- Smart File-Diff Uploads
    - Delete artifacts deleted from the source
    - Re-index artifacts modified at the source
    - Do nothing for artifacts unchanged at the source
- Worker Parallelization
    - Scale ingestion horizontally to maximize throughput
    - SGP can ingest documents at ~500 MB/hour using fewer than 100 worker nodes. This throughput can easily be increased by adding nodes for the ingestion workers and the embedding model, and can also be improved by optimizing the hardware the embedding model is hosted on.
- Autoscaling
    - Autoscale ingestion workers to lower costs during dormancy and burst for spiky workloads
- Text Extraction
    - Automatically extract text from non-plaintext documents, e.g. DOCX, PPTX, PDF
- Chunking
    - Select from a list of SGP-supported chunking strategies to automatically split data into chunks during ingestion
    - Easily swap out chunking strategies just by varying a small API request payload
- Embedding
    - Automatically embed each chunk of text into a vector for storage in the vector DB
    - Create knowledge bases with different embedding models to test how different embedding models affect retrieval performance
To set up retrieval with SGP, first create a knowledge base. How to best organize data into separate knowledge bases depends on the use case and is still an area of research. For demo purposes, we will simply create a single knowledge base and upload the contents of an S3 bucket to it.
from scale_egp.sdk.client import EGPClient
from scale_egp.sdk.enums import EmbeddingModelName
from scale_egp.sdk.types.knowledge_base_uploads import (
    S3DataSourceConfig,
    CharacterChunkingStrategyConfig,
)

client = EGPClient(api_key="<YOUR_SGP_API_KEY>")

knowledge_base = client.knowledge_bases().create(
    name="example_knowledge_base",
    embedding_model_name=EmbeddingModelName.OPENAI_TEXT_EMBEDDING_ADA_002,
)

upload = client.knowledge_bases().uploads().create_remote_upload(
    knowledge_base=knowledge_base,
    data_source_config=S3DataSourceConfig(
        s3_bucket="<YOUR_S3_BUCKET>",
        s3_prefix="<PREFIX_OF_FOLDER_WITHIN_S3_BUCKET>",
        aws_region="<AWS_REGION>",
        aws_account_id="<AWS_ACCOUNT_ID>",
    ),
    data_source_auth_config=None,
    chunking_strategy_config=CharacterChunkingStrategyConfig(
        separator="\n\n",
        chunk_size=1000,
        chunk_overlap=200,
    ),
)
That's it! Data is now being ingested into your knowledge base. To check the status of the upload, simply poll the upload status:
import time

print(f"Upload ID: {upload.upload_id}\n")
complete = False
poll_count = 1
while not complete:
    upload = client.knowledge_bases().uploads().get(id=upload.upload_id, knowledge_base=knowledge_base)
    complete = upload.status == "Completed"
    print(f"Poll count: {poll_count}")
    print(f"Status: {upload.status}")
    print(f"Status Reason: {upload.status_reason}")
    print(f"Artifact Statuses: {upload.artifacts_status}\n")
    poll_count += 1
    time.sleep(3)
To analyze the contents of your knowledge base, you can list its artifacts.
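For example (mirroring the end-to-end snippet in the appendix), you can print each artifact's ID and source URI; the artifacts variable is reused in the chunk-inspection snippet below:
artifacts = client.knowledge_bases().artifacts().list(knowledge_base=knowledge_base)
for artifact in artifacts:
    print(f"({artifact.artifact_id}) {artifact.artifact_uri}")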
You can also analyze the text chunks that were extracted for a specific artifact to see if the text extraction and chunking worked properly.
Note: It is recommended that users create a knowledge base with a small sample of data and investigate its contents before ingesting large amounts of data.
artifact = client.knowledge_bases().artifacts().get(id=artifacts[0].artifact_id, knowledge_base=knowledge_base)
for index, chunk in enumerate(artifact.chunks):
    print(f"Chunk {index}")
    print("=" * 30)
    print(chunk.text)
Querying your knowledge base¶
To query a knowledge base, simply submit a natural language query. Behind the scenes, the query is embedded using the same embedding model that was used when data was ingested into the knowledge base, and a similarity search is performed between the embedded query and the embedding vectors stored in the knowledge base. The API returns the top_k chunks of text most semantically similar to the query.
chunks = client.knowledge_bases().query(
    knowledge_base=knowledge_base,
    query="<USER_QUERY>",
    top_k=10,
    include_embeddings=False,
)

for index, chunk in enumerate(chunks):
    print(f"Chunk rank: {index}")
    print("=" * 30)
    print(chunk.text)
Optimizing Retrieval Accuracy¶
At Scale, we have observed that the naive query results of a vector database alone do not yield good retrieval accuracy. One of the most powerful ways to improve retrieval accuracy is to re-rank the chunks returned from an initial query.
The easiest way to improve performance is to perform a high-recall query on a knowledge base (set top_k to a large number) and then use a cross-encoder model to re-rank the chunks before choosing the top-scoring chunks to append to the user query for the final LLM prompt.
from scale_egp.sdk.types.chunks import (
    CrossEncoderRankParams,
    CrossEncoderRankStrategy,
    RougeRankParams,
    RougeRankStrategy,
)
from scale_egp.sdk.enums import CrossEncoderModelName

# Try a larger recall
query = "<YOUR_QUERY>"
chunks = client.knowledge_bases().query(
    knowledge_base=knowledge_base,
    query=query,
    top_k=500,
    include_embeddings=False,
)

reranked_chunks = client.chunks().rank(
    query=query,
    relevant_chunks=chunks,
    rank_strategy=CrossEncoderRankStrategy(
        params=CrossEncoderRankParams(
            cross_encoder_model=CrossEncoderModelName.CROSS_ENCODER_MS_MARCO_MINILM_L12_V2.value,
        )
    ),
    top_k=3,
)
Users can see how the reranking step improved retrieval accuracy by comparing the initial top chunks with the reranked top chunks:
print("Original Top 3 Chunks")
for index, original_top_3_chunk in enumerate(chunks[:3]):
print(f"CHUNK {index + 1}")
print("="*30)
print(original_top_3_chunk.text)
print()
print("Reranked Top 3 Chunks")
for index, reranked_top_3_chunk in enumerate(reranked_chunks):
print(f"CHUNK {index + 1}")
print("="*30)
print(reranked_top_3_chunk.text)
print()
Re-ranking can add latency to the end application, but there are a variety of techniques to bring this latency down. Please talk to your Scale representative for recommendations on how to maximize accuracy while keeping latency low.
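One such pattern, used in the end-to-end snippet in the appendix, is to pre-filter the high-recall result set with a cheap lexical (ROUGE) ranker before running the heavier cross-encoder on the survivors. RougeRankStrategy and RougeRankParams come from scale_egp.sdk.types.chunks (imported above); the intermediate top_k below is illustrative:
# Cheap first pass: keep a moderate candidate pool using ROUGE recall
sub_reranked_chunks = client.chunks().rank(
    query=query,
    relevant_chunks=chunks,
    rank_strategy=RougeRankStrategy(
        params=RougeRankParams(
            method="rouge2",
            score="recall",
        )
    ),
    top_k=50,
)
# Expensive second pass: cross-encode only the surviving candidates
reranked_chunks = client.chunks().rank(
    query=query,
    relevant_chunks=sub_reranked_chunks,
    rank_strategy=CrossEncoderRankStrategy(
        params=CrossEncoderRankParams(
            cross_encoder_model=CrossEncoderModelName.CROSS_ENCODER_MS_MARCO_MINILM_L12_V2.value,
        )
    ),
    top_k=3,
)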
Prompt Engineering and RAG¶
Now that users have high-quality chunks, it is time to augment the initial user query with the retrieved information. This piece is more experimental, and the prompt can vary depending on the use case, so it is important for users to try multiple prompts.
Users can also consult their Scale representative for advice on how to craft an effective prompt.
A standard prompt engineering function may look like the following:
import string
from typing import List


def create_prompt(user_input: str, retrieved_chunks: List["Chunk"]) -> str:
    # "Chunk" objects are the chunk models returned by the query and rank calls above
    prompt_template = """Instructions: $instructions

User Query: $user_input

Additional Context Retrieved from internal sources:
$chunks_str
"""
    chunks_str = ""
    for chunk in retrieved_chunks:
        chunks_str += chunk.text
        if chunk.metadata:
            chunks_str += "\n"
            for key, value in chunk.metadata.items():
                chunks_str += f"{key}: {value}"
        chunks_str += "\n\n"
    string_template = string.Template(prompt_template)
    prompt = string_template.substitute(dict(
        instructions="Please answer the following user query using the additional context "
                     "appended below.",
        user_input=user_input,
        chunks_str=chunks_str,
    ))
    return prompt
Lastly, we feed the augmented prompt to an LLM using our completions API. Here we are using GPT-3.5 Turbo, but we can swap this model out for any other LLM that SGP supports.
rag_prompt = create_prompt(user_input="<YOUR_QUERY>", retrieved_chunks=reranked_chunks)
completion = client.completions().create(model="gpt-3.5-turbo", prompt=rag_prompt)
print(completion.completion.text)
Improve Response Accuracy¶
As you can see, there are many steps in the retrieval process. Let's review them:
- Choose an embedding model*
- Create a knowledge base using that embedding model
- Choose a chunking strategy*
- Ingest data into the knowledge base directly or from a supported data source using the chosen embedding model and chunking strategy
- Choose a reranking model*
- Rerank chunks queried from the knowledge base using the chosen reranking model
- Choose an LLM*
- Engineer a retrieval-augmented prompt*
- Generate a response using the chosen LLM
After setting up this pipeline, the next obvious step is to improve the AI's response accuracy. Here, SGP demonstrates one of its most powerful features: every step denoted with an asterisk can be swapped out for quick experimentation without modifying the rest of the RAG pipeline.
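As an illustration of this modularity, here is a minimal sketch (the alternative model name is a placeholder) that sweeps over candidate embedding models while leaving every downstream step untouched:
candidate_embedding_models = [
    EmbeddingModelName.OPENAI_TEXT_EMBEDDING_ADA_002,
    "<ALTERNATIVE_EMBEDDING_MODEL>",  # placeholder for another SGP-supported embedding model
]
for i, embedding_model_name in enumerate(candidate_embedding_models):
    kb = client.knowledge_bases().create(
        name=f"experiment_kb_{i}",
        embedding_model_name=embedding_model_name,
    )
    # Re-run the same upload, query, rerank, and prompt steps against `kb`,
    # then compare retrieval and response quality across the candidates.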
Below, we discuss several scenarios and how to optimize some of these steps.
Note: Currently, if model finetuning is needed, customers are encouraged to engage with their Scale representatives to decide what kind of data to collect. Scale representatives will finetune the models on this data, then deploy them to the customer's SGP platform. To use the finetuned model, the user simply has to swap the model name for the finetuned model name. In the near future, finetuning will be supported as a self-service feature via the SGP API.
Modify the Chunking Strategy¶
Static chunking is currently the simplest way to split up unstructured data. However, more intelligent chunking may be needed for specific use cases and data types. SGP allows users to easily swap one supported chunking strategy for another. Swapping out the chunking strategy in an upload job to an existing knowledge base will re-index all artifacts from the data source using the new chunking strategy.
SGP will continue to add supported chunking strategies based on demand.
upload = client.knowledge_bases().uploads().create_remote_upload(
    ...,
    chunking_strategy_config=CustomChunkingStrategyConfig(...),
)
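For example, assuming the same S3 source should be re-ingested with smaller, more heavily overlapping chunks (the parameter values below are illustrative, not recommendations):
upload = client.knowledge_bases().uploads().create_remote_upload(
    knowledge_base=knowledge_base,
    data_source_config=S3DataSourceConfig(
        s3_bucket="<YOUR_S3_BUCKET>",
        s3_prefix="<PREFIX_OF_FOLDER_WITHIN_S3_BUCKET>",
        aws_region="<AWS_REGION>",
        aws_account_id="<AWS_ACCOUNT_ID>",
    ),
    data_source_auth_config=None,
    # Same data source as before; only the chunking strategy changes
    chunking_strategy_config=CharacterChunkingStrategyConfig(
        separator="\n",
        chunk_size=500,
        chunk_overlap=100,
    ),
)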
Finetune or Swap the Reranking Model¶
Scale has found internally that reranking chunks is one of the most effective ways to improve retrieval accuracy. By simply replacing the base re-ranking model with a finetuned or alternative model, users can often get easy gains in retrieval accuracy.
Here are some signs that indicate that the reranking model should be finetuned:
- The query can be easily matched by a human to chunks that are available in the knowledge base
- The correct chunks to retrieve are available in the knowledge base, but the reranker scores other chunks higher.
SGP will continue to add supported ranking strategies and models based on demand.
reranked_chunks = client.chunks().rank(
    query=query,
    relevant_chunks=chunks,
    rank_strategy=CustomRankingStrategy(...),
    top_k=3,
)
Finetune or Swap the Embedding Model¶
The initial embedding model chosen may be insufficient to understand the semantics of use-case-specific queries. For example, if a lawyer asks the question "What is carry?" on a 100-page LPA document, it is extremely unlikely that any off-the-shelf embedding model encodes each chunk of the document in a way that includes enough information for the user query to semantically match the correct chunks.
Here are some signs that indicate that the embedding model should be finetuned:
- The correct chunks exist in the knowledge base, but the query cannot be easily matched by a human to these chunks
- The correct chunks do not consistently appear in the initial high-recall knowledge base query, so the reranker does not consistently see the correct chunks.
- The user wants to lower query latency for the application, so the recall size on the initial knowledge base query needs to be lowered, but the correct chunks still need to consistently appear in this smaller recall.
knowledge_base = client.knowledge_bases().create(
    name="kb_with_new_embedding_model",
    embedding_model_name="<NEW_EMBEDDING_MODEL>",
)
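Because a newly created knowledge base starts out empty, the source data also needs to be re-ingested so that it is embedded with the new model. This is simply the earlier upload call pointed at the new knowledge base:
upload = client.knowledge_bases().uploads().create_remote_upload(
    knowledge_base=knowledge_base,  # the knowledge base backed by the new embedding model
    ...,  # same data source, auth, and chunking configuration as before
)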
Finetune or Swap the Large Language Model¶
There are situations where the synthesis of the final AI response is still not accurate even when the prompt contains enough data to respond to the user query. These problems can often be fixed by finetuning the LLM that synthesizes the final response.
Here are some indications that the LLM should be finetuned:
- Even when a prompt contains sufficient information to respond to the query, the LLM says it cannot generate a proper response.
- The LLM generates a nonsensical answer instead of saying it cannot respond to the user query.
- Security and responsible AI vulnerabilities have been uncovered by red-teaming and the LLM needs to be finetuned to not respond to malicious user queries.
completion = client.completions().create(model="<NEW_LLM>", prompt="<RAG_PROMPT>")
output = completion.completion.text
Iterative Prompt Engineering¶
Prompt engineering happens entirely client side, so users have maximum flexibility to modify and test various prompts as needed.
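For example, one lightweight pattern is to keep a few candidate templates and compare the completions they produce for the same query and retrieved chunks. The template texts below are purely illustrative:
candidate_templates = [
    "Answer the user query using only the context below.\n\nQuery: {user_input}\n\nContext:\n{chunks_str}",
    "You are a careful analyst. Cite the source of every fact you use.\n\nQuery: {user_input}\n\nContext:\n{chunks_str}",
]
# Reuse the re-ranked chunks from the retrieval step above
chunks_str = "\n\n".join(chunk.text for chunk in reranked_chunks)
for template in candidate_templates:
    prompt = template.format(user_input="<YOUR_QUERY>", chunks_str=chunks_str)
    completion = client.completions().create(model="gpt-3.5-turbo", prompt=prompt)
    print(completion.completion.text)
    print("=" * 30)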
FAQ¶
What is retrieval and why is it important for enterprise?¶
Large Language Models have inspired an exciting new wave of AI applications. However, incorporating LLM capabilities into enterprise applications is more challenging. It takes millions of dollars and very specific expertise to build a foundation model. So how can companies get ChatGPT-like functionality running on their enterprise data?
Currently, one of the most effective ways to do this is through Retrieval Augmented Generation (RAG). RAG is a technique where users retrieve custom information from a data source and append this information to their LLM prompt. For example, if a user asks:
How do we think Company X will perform this quarter?
a RAG application would search various data sources (articles, databases, etc.) for information about Company X and create a prompt such as the following:
Instructions:
Please answer the following user query using the additional context appended below.
User Query:
How do we think Company X will perform this quarter?
Additional Context Retrieved from internal sources:
Expected 2023 stock growth: +10%, Actual 2023 stock growth: +20% Source: internal_database
Restrictions on international exports of Product Y are expected to slow down Company X's growth. Source: external_news/international_exports_restricted_on_product_y.pdf Publication Timestamp: 1 week ago
Author Z recommends that investors hold current assets, but temper expectations and slow down stock purchases. Source: internal_article_source/recommendations/author_z.pdf Publication Timestamp: 3 minutes ago
For the original user prompt, an off-the-shelf LLM would produce the following response:
I'm sorry I could not find any information about how Company X will perform this quarter. As an AI model I am unable to make future predictions.
For the Retrieval Augmented prompt, the same off-the-shelf LLM would now be able to generate a reasonable response:
According to data retrieved from internal sources, Company X has doubled its expected stock growth so far in 2023; however, growth is expected to slow due to restrictions imposed last week on exports of Product Y. Author Z recommends that investors hold current assets and not execute additional stock purchases.
LLMs have a unique capability called "in-context learning": even if a model was not trained on a specific piece of data, when that data is provided live in the prompt, the model is generally capable of interpreting it and using it to answer a user query.
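To see this effect directly with the completions API shown earlier, compare the response to the bare query with the response to the retrieval-augmented prompt (this sketch reuses create_prompt and reranked_chunks from the sections above):
user_query = "How do we think Company X will perform this quarter?"
bare_completion = client.completions().create(model="gpt-3.5-turbo", prompt=user_query)
rag_completion = client.completions().create(
    model="gpt-3.5-turbo",
    prompt=create_prompt(user_input=user_query, retrieved_chunks=reranked_chunks),
)
print(bare_completion.completion.text)  # typically cannot answer from training data alone
print(rag_completion.completion.text)   # grounded in the retrieved context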
Appendix¶
End to End Code Snippet¶
import json
import os
import time
from typing import List, Union

import questionary as q

from scale_egp.sdk.client import EGPClient
from scale_egp.sdk.enums import (
    CrossEncoderModelName,
    EmbeddingModelName,
)
from scale_egp.sdk.types.chunks import (
    CrossEncoderRankParams, CrossEncoderRankStrategy,
    RougeRankParams, RougeRankStrategy,
)
from scale_egp.sdk.types.knowledge_base_uploads import (
    S3DataSourceConfig,
    CharacterChunkingStrategyConfig,
)
from scale_egp.utils.model_utils import BaseModel


# Helper functions
def dump_model(model: Union[BaseModel, List[BaseModel]]):
    if isinstance(model, list):
        return json.dumps([m.dict() for m in model], indent=2, sort_keys=True, default=str)
    return json.dumps(model.dict(), indent=2, sort_keys=True, default=str)


if __name__ == "__main__":
    api_key = q.text(f"Please enter your SGP API key:", default=os.environ.get("EGP_API_KEY")).ask()
    client = EGPClient(api_key=api_key)

    # Contains public 8-Ks from Amazon, Microsoft, Apple, and Morgan Stanley
    KNOWLEDGE_BASE_ID = None
    knowledge_base_name = "small_8k_demo"
    embedding_model_name = EmbeddingModelName.OPENAI_TEXT_EMBEDDING_ADA_002
    if KNOWLEDGE_BASE_ID:
        knowledge_base_id = KNOWLEDGE_BASE_ID
    else:
        knowledge_base_id = q.text(
            f"ID of existing knowledge base (Leave blank to create a new one with name "
            f"'{knowledge_base_name}' and embedding model '{embedding_model_name}'):"
        ).ask()
    if knowledge_base_id:
        knowledge_base = client.knowledge_bases().get(id=knowledge_base_id)
    else:
        knowledge_base = client.knowledge_bases().create(
            name=knowledge_base_name,
            embedding_model_name=embedding_model_name,
        )
    print(f"Knowledge base:\n{dump_model(knowledge_base)}")
    print("=" * 50)

    UPLOAD_ID = None
    if UPLOAD_ID:
        upload_id = UPLOAD_ID
    else:
        upload_id = q.text(f"ID of existing upload (Leave blank to create a new one):").ask()
    if upload_id:
        upload = client.knowledge_bases().uploads().get(id=upload_id, knowledge_base=knowledge_base)
    else:
        print("Please enter the following information to create a new upload:")
        s3_bucket = q.text(f"S3 bucket:").ask()
        s3_prefix = q.text(f"S3 prefix:").ask()
        aws_region = q.text(f"AWS region:").ask()
        aws_account_id = q.text(f"AWS account ID:").ask()
        upload = client.knowledge_bases().uploads().create_remote_upload(
            knowledge_base=knowledge_base,
            data_source_config=S3DataSourceConfig(
                s3_bucket=s3_bucket,
                s3_prefix=s3_prefix,
                aws_region=aws_region,
                aws_account_id=aws_account_id,
            ),
            data_source_auth_config=None,
            chunking_strategy_config=CharacterChunkingStrategyConfig(
                separator="\n\n",
                chunk_size=1000,
                chunk_overlap=200,
            ),
        )
    print(f"Knowledge Base Upload:\n{dump_model(upload)}")

    complete = False
    poll_count = 1
    while not complete:
        upload = client.knowledge_bases().uploads().get(
            id=upload.upload_id, knowledge_base=knowledge_base
        )
        complete = upload.status == "Completed"
        print(f"Poll count: {poll_count}")
        print(f"Status: {upload.status}")
        print(f"Status Reason: {upload.status_reason}")
        print(f"Artifact Statuses: {upload.artifacts_status}\n")
        poll_count += 1
        time.sleep(3)
    print("=" * 50)

    print("Artifacts in knowledge base:")
    artifacts = client.knowledge_bases().artifacts().list(knowledge_base=knowledge_base)
    if not artifacts:
        print("No artifacts in knowledge base.")
    for artifact in artifacts:
        print(f"({artifact.artifact_id}) {artifact.artifact_uri}")
    print("=" * 50)

    selected_artifact = artifacts[-1]
    print(f"Chunks in artifact: {selected_artifact.artifact_uri}")
    artifact = client.knowledge_bases().artifacts().get(
        id=artifacts[-1].artifact_id, knowledge_base=knowledge_base
    )
    for index, chunk in enumerate(artifact.chunks):
        print(f"Chunk {index}")
        print("=" * 30)
        print(chunk.text)
    print("=" * 50)

    query = "What new events did Morgan Stanley report in their latest 8k?"
    print(f"Querying knowledge base: {query}")
    chunks = client.knowledge_bases().query(
        knowledge_base=knowledge_base,
        query=query,
        top_k=3,
        include_embeddings=False
    )
    if not chunks:
        print("No chunks returned.")
    for index, chunk in enumerate(chunks):
        print(f"Chunk rank: {index}")
        print("=" * 30)
        print(chunk.text)
    print("=" * 50)

    print("Comparing raw retrieval with cross-encoder/rouge re-ranked retrieval...")
    chunks = client.knowledge_bases().query(
        knowledge_base=knowledge_base,
        query=query,
        top_k=3,
        include_embeddings=False
    )
    print("Original Top 3 Chunks")
    for index, original_top_3_chunk in enumerate(chunks[:3]):
        print(f"CHUNK {index + 1}")
        print("=" * 30)
        print(original_top_3_chunk.text)
        print()

    # Try a larger recall with reranking
    print("Re-ranking chunks with cross-encoder/rouge...")
    chunks = client.knowledge_bases().query(
        knowledge_base=knowledge_base,
        query=query,
        top_k=10,
        include_embeddings=False
    )
    sub_re_ranked_chunks = client.chunks().rank(
        query=query,
        relevant_chunks=chunks,
        rank_strategy=RougeRankStrategy(
            params=RougeRankParams(
                method="rouge2",
                score="recall",
            )
        ),
        top_k=5,
    )
    re_ranked_chunks = client.chunks().rank(
        query=query,
        relevant_chunks=sub_re_ranked_chunks,
        rank_strategy=CrossEncoderRankStrategy(
            params=CrossEncoderRankParams(
                cross_encoder_model=CrossEncoderModelName.CROSS_ENCODER_MS_MARCO_MINILM_L12_V2.value,
            )
        ),
        top_k=3,
    )
    print("\n\n\nRe-ranked Top 3 Chunks")
    for index, re_ranked_top_3_chunk in enumerate(re_ranked_chunks):
        print(f"CHUNK {index + 1}")
        print("=" * 30)
        print(re_ranked_top_3_chunk.text)
        print()
    print("=" * 50)