Summary: This post explores how to build Retrieval-Augmented Generation (RAG) applications using LLaMA 2 and .NET. Learn how to set up LLaMA 2, implement vector search, and create a complete RAG system that enhances LLM responses with relevant information from your own data sources.
Introduction
In July 2023, Meta released LLaMA 2, a significant advancement in open-source large language models (LLMs). With improved performance and a permissive license for commercial use, LLaMA 2 has quickly become a popular choice for organizations looking to build AI applications with their own data.
One of the most powerful applications of LLMs like LLaMA 2 is Retrieval-Augmented Generation (RAG). RAG combines the generative capabilities of LLMs with information retrieval systems to produce responses that are both contextually relevant and factually accurate based on your specific data sources.
For .NET developers, building RAG applications with LLaMA 2 opens up exciting possibilities for creating intelligent applications that can reason over your organization’s proprietary data. In this post, we’ll explore how to build a complete RAG system using LLaMA 2 and .NET, covering everything from setting up the model to implementing vector search and creating a seamless user experience.
Understanding RAG
Before diving into implementation, let’s understand what RAG is and why it’s so valuable.
What is Retrieval-Augmented Generation?
RAG is a hybrid approach that combines:
- Retrieval: Finding relevant information from a knowledge base or data source
- Augmentation: Enhancing the LLM’s context with this retrieved information
- Generation: Using the LLM to generate a response based on both the user query and the retrieved information
This approach addresses several limitations of pure LLMs:
- Knowledge cutoff: LLMs only know information up to their training cutoff date
- Hallucinations: LLMs can generate plausible-sounding but incorrect information
- Proprietary data: LLMs don’t know about your organization’s specific data
- Traceability: Pure LLM responses don’t provide sources for their information
The RAG Architecture
A typical RAG system consists of these components:
- Document processing pipeline: Ingests, chunks, and processes documents
- Embedding model: Converts text into vector representations
- Vector database: Stores and enables semantic search of embeddings
- Retriever: Finds relevant documents based on a query
- LLM: Generates responses using the retrieved documents as context
- Orchestrator: Coordinates the flow between components
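In code terms, the orchestrator's job reduces to a short pipeline. Here's a minimal sketch of that flow in C# (illustrative only; we'll build concrete versions of each component in the sections below):

```csharp
using System.Linq;
using System.Threading.Tasks;

// Illustrative orchestrator; EmbeddingService, VectorDatabase, and
// LlamaClient are implemented later in this post.
public class RagOrchestrator
{
    private readonly EmbeddingService _embeddings;
    private readonly VectorDatabase _vectorDb;
    private readonly LlamaClient _llm;

    public RagOrchestrator(EmbeddingService embeddings, VectorDatabase vectorDb, LlamaClient llm)
    {
        _embeddings = embeddings;
        _vectorDb = vectorDb;
        _llm = llm;
    }

    public async Task<string> AnswerAsync(string question)
    {
        // 1. Retrieval: embed the question and find the closest document chunks
        var queryEmbedding = await _embeddings.GenerateEmbeddingAsync(question);
        var relevantChunks = await _vectorDb.SearchSimilarDocumentsAsync(queryEmbedding);

        // 2. Augmentation + 3. Generation: hand the query and the retrieved
        // context to the LLM and return its grounded answer
        return await _llm.GenerateResponseWithContextAsync(
            question,
            relevantChunks.Select(c => c.Text).ToList());
    }
}
```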
Setting Up LLaMA 2 for .NET
Let’s start by setting up LLaMA 2 for use with .NET applications.
Prerequisites
To follow along with this tutorial, you’ll need:
- .NET 6 or later
- Visual Studio 2022 or Visual Studio Code
- Access to LLaMA 2 (we’ll use OLLAMA for easy setup)
- At least 16GB of RAM (32GB recommended for larger models)
Installing OLLAMA
OLLAMA is a tool that simplifies running LLMs locally. Let’s install it:
Windows:
- Download the installer from ollama.ai
- Run the installer and follow the prompts
macOS:
```bash
brew install ollama
```
Linux:
```bash
curl -fsSL https://ollama.ai/install.sh | sh
```
Pulling LLaMA 2
Once OLLAMA is installed, pull the LLaMA 2 model:
```bash
ollama pull llama2
```
For better performance with RAG applications, you might want to use the 13B parameter version:
```bash
ollama pull llama2:13b
```
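Before writing any .NET code, it's worth a quick smoke test from the terminal to confirm the model is installed and responding:

```bash
# List installed models, then send a one-off prompt
ollama list
ollama run llama2 "Say hello in one sentence."
```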
Creating a .NET Project
Let’s create a new .NET project for our RAG application:
```bash
dotnet new console -n LlamaRagApp
cd LlamaRagApp
```
Add the necessary packages:
```bash
dotnet add package System.Net.Http.Json
dotnet add package Pgvector.EntityFrameworkCore
dotnet add package Npgsql.EntityFrameworkCore.PostgreSQL
dotnet add package Microsoft.Extensions.Hosting
dotnet add package Microsoft.ML
```
Building the RAG Components
Now, let’s build the components of our RAG system.
1. Document Processing
First, we need to process documents into chunks that can be embedded and stored in our vector database.
```csharp
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Text.RegularExpressions;

public class DocumentProcessor
{
    private readonly int _maxChunkSize;
    private readonly int _chunkOverlap;

    public DocumentProcessor(int maxChunkSize = 1000, int chunkOverlap = 200)
    {
        _maxChunkSize = maxChunkSize;
        _chunkOverlap = chunkOverlap;
    }

    public List<DocumentChunk> ProcessFile(string filePath)
    {
        string text = File.ReadAllText(filePath);
        string fileName = Path.GetFileName(filePath);
        return ChunkText(text, fileName);
    }

    public List<DocumentChunk> ProcessText(string text, string source)
    {
        return ChunkText(text, source);
    }

    private List<DocumentChunk> ChunkText(string text, string source)
    {
        // Clean the text (line endings are normalized to "\n" here)
        text = CleanText(text);

        // Split into paragraphs
        string[] paragraphs = text.Split(new[] { "\n\n" }, StringSplitOptions.RemoveEmptyEntries);

        var chunks = new List<DocumentChunk>();
        var currentChunk = new StringBuilder();

        foreach (var paragraph in paragraphs)
        {
            // If adding this paragraph would exceed the max chunk size,
            // save the current chunk and start a new one
            if (currentChunk.Length + paragraph.Length > _maxChunkSize && currentChunk.Length > 0)
            {
                chunks.Add(new DocumentChunk
                {
                    Text = currentChunk.ToString().Trim(),
                    Source = source
                });

                // Start the next chunk with the tail of the previous one,
                // so context isn't lost at chunk boundaries
                int overlapStart = Math.Max(0, currentChunk.Length - _chunkOverlap);
                currentChunk = new StringBuilder(currentChunk.ToString().Substring(overlapStart));
            }

            currentChunk.AppendLine(paragraph);
            currentChunk.AppendLine();
        }

        // Add the last chunk if it's not empty
        if (currentChunk.Length > 0)
        {
            chunks.Add(new DocumentChunk
            {
                Text = currentChunk.ToString().Trim(),
                Source = source
            });
        }

        return chunks;
    }

    private string CleanText(string text)
    {
        // Normalize line endings first so paragraph detection is consistent
        text = text.Replace("\r\n", "\n");

        // Collapse runs of spaces and tabs, but keep newlines so the
        // paragraph boundaries used by ChunkText survive
        text = Regex.Replace(text, @"[ \t]+", " ");
        text = Regex.Replace(text, @"\n{3,}", "\n\n");

        // Remove special characters that might cause issues
        text = Regex.Replace(text, @"[^\w\s.,;:!?()[\]{}\-""'`]", "");

        return text.Trim();
    }
}

public class DocumentChunk
{
    public string Text { get; set; }
    public string Source { get; set; }
    public List<float> Embedding { get; set; }
}
```
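A quick usage sketch of the processor, assuming a plain-text file on disk (the path is a placeholder):

```csharp
var processor = new DocumentProcessor(maxChunkSize: 1000, chunkOverlap: 200);

// "docs/manual.txt" is a placeholder; point it at any text file you have
var chunks = processor.ProcessFile("docs/manual.txt");
Console.WriteLine($"Produced {chunks.Count} chunks from {chunks[0].Source}");
```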
2. Embedding Generation
Next, we need to generate embeddings for our document chunks. We’ll use OLLAMA’s embedding API:
```csharp
using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Net.Http.Json;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;

public class EmbeddingService
{
    private readonly HttpClient _httpClient;
    private readonly string _baseUrl;
    private readonly string _modelName;

    public EmbeddingService(string baseUrl = "http://localhost:11434", string modelName = "llama2")
    {
        _baseUrl = baseUrl;
        _modelName = modelName;
        _httpClient = new HttpClient();
    }

    public async Task<List<float>> GenerateEmbeddingAsync(string text)
    {
        var request = new
        {
            model = _modelName,
            prompt = text
        };

        var response = await _httpClient.PostAsJsonAsync($"{_baseUrl}/api/embeddings", request);
        response.EnsureSuccessStatusCode();

        var content = await response.Content.ReadAsStringAsync();
        var embeddingResponse = JsonSerializer.Deserialize<EmbeddingResponse>(content);
        return embeddingResponse.Embedding;
    }

    public async Task<List<DocumentChunk>> GenerateEmbeddingsForChunksAsync(List<DocumentChunk> chunks)
    {
        foreach (var chunk in chunks)
        {
            chunk.Embedding = await GenerateEmbeddingAsync(chunk.Text);
        }
        return chunks;
    }

    private class EmbeddingResponse
    {
        [JsonPropertyName("embedding")]
        public List<float> Embedding { get; set; }
    }
}
```
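One practical note: GenerateEmbeddingsForChunksAsync embeds chunks strictly one at a time, which is the simplest approach against a local OLLAMA instance but slow for large document sets. If ingestion throughput matters, a bounded-concurrency variant could be added to EmbeddingService; here's a sketch (the limit of 4 is an arbitrary starting point, not a tuned value):

```csharp
// Requires: using System.Linq; using System.Threading;
public async Task<List<DocumentChunk>> GenerateEmbeddingsParallelAsync(
    List<DocumentChunk> chunks, int maxConcurrency = 4)
{
    using var semaphore = new SemaphoreSlim(maxConcurrency);

    var tasks = chunks.Select(async chunk =>
    {
        await semaphore.WaitAsync();
        try
        {
            chunk.Embedding = await GenerateEmbeddingAsync(chunk.Text);
        }
        finally
        {
            semaphore.Release();
        }
    });

    await Task.WhenAll(tasks);
    return chunks;
}
```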
3. Vector Database
Now, let's set up a vector database to store our embeddings. We'll use PostgreSQL with the pgvector extension, which the Pgvector.EntityFrameworkCore package plugs into EF Core (pgvector is a PostgreSQL extension, so you'll need a PostgreSQL server rather than an embedded database):
```csharp
using Microsoft.EntityFrameworkCore;
using Pgvector;
using Pgvector.EntityFrameworkCore;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

public class VectorDbContext : DbContext
{
    public DbSet<DocumentEntity> Documents { get; set; }

    public VectorDbContext(DbContextOptions<VectorDbContext> options) : base(options)
    {
    }

    protected override void OnModelCreating(ModelBuilder modelBuilder)
    {
        // Register the pgvector extension so EnsureCreated can enable it
        modelBuilder.HasPostgresExtension("vector");

        modelBuilder.Entity<DocumentEntity>(entity =>
        {
            entity.HasKey(e => e.Id);
            // llama2's embeddings are 4096-dimensional; adjust the size
            // if you use a different embedding model
            entity.Property(e => e.Embedding).HasColumnType("vector(4096)");
        });
    }
}

public class DocumentEntity
{
    public int Id { get; set; }
    public string Text { get; set; }
    public string Source { get; set; }
    public Vector Embedding { get; set; }
}

public class VectorDatabase
{
    private readonly VectorDbContext _dbContext;

    public VectorDatabase(VectorDbContext dbContext)
    {
        _dbContext = dbContext;
    }

    public async Task InitializeDatabaseAsync()
    {
        await _dbContext.Database.EnsureCreatedAsync();
    }

    public async Task AddDocumentsAsync(List<DocumentChunk> chunks)
    {
        foreach (var chunk in chunks)
        {
            var documentEntity = new DocumentEntity
            {
                Text = chunk.Text,
                Source = chunk.Source,
                Embedding = new Vector(chunk.Embedding.ToArray())
            };
            _dbContext.Documents.Add(documentEntity);
        }
        await _dbContext.SaveChangesAsync();
    }

    public async Task<List<DocumentEntity>> SearchSimilarDocumentsAsync(List<float> queryEmbedding, int limit = 5)
    {
        var queryVector = new Vector(queryEmbedding.ToArray());

        // Nearest neighbours by Euclidean (L2) distance, computed in the database
        var results = await _dbContext.Documents
            .OrderBy(d => d.Embedding.L2Distance(queryVector))
            .Take(limit)
            .ToListAsync();

        return results;
    }
}
```
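One wiring detail the snippet above leaves implicit: the Npgsql EF Core provider has to be told about pgvector's types when the context is configured. A minimal sketch, assuming a local PostgreSQL server (the connection string values are placeholders):

```csharp
using Microsoft.EntityFrameworkCore;
using Pgvector.EntityFrameworkCore;

// Placeholder connection string; point it at your own PostgreSQL server
var connectionString = "Host=localhost;Database=ragdb;Username=postgres;Password=postgres";

var options = new DbContextOptionsBuilder<VectorDbContext>()
    // o.UseVector() registers pgvector's vector type with the provider
    .UseNpgsql(connectionString, o => o.UseVector())
    .Options;

var vectorDb = new VectorDatabase(new VectorDbContext(options));
await vectorDb.InitializeDatabaseAsync();
```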
4. LLaMA 2 Client
Next, let’s create a client to interact with LLaMA 2 through OLLAMA:
```csharp
using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Net.Http.Json;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;

public class LlamaClient
{
    private readonly HttpClient _httpClient;
    private readonly string _baseUrl;
    private readonly string _modelName;

    public LlamaClient(string baseUrl = "http://localhost:11434", string modelName = "llama2")
    {
        _baseUrl = baseUrl;
        _modelName = modelName;
        _httpClient = new HttpClient();
    }

    public async Task<string> GenerateResponseAsync(string prompt, string systemPrompt = null)
    {
        var request = new
        {
            model = _modelName,
            prompt = prompt,
            system = systemPrompt,
            stream = false
        };

        var response = await _httpClient.PostAsJsonAsync($"{_baseUrl}/api/generate", request);
        response.EnsureSuccessStatusCode();

        var content = await response.Content.ReadAsStringAsync();
        var generateResponse = JsonSerializer.Deserialize<GenerateResponse>(content);
        return generateResponse.Response;
    }

    public async Task<string> GenerateResponseWithContextAsync(string query, List<string> contextTexts)
    {
        // Combine the retrieved texts into a single context block
        var contextBuilder = new StringBuilder();
        foreach (var text in contextTexts)
        {
            contextBuilder.AppendLine(text);
            contextBuilder.AppendLine();
        }
        string context = contextBuilder.ToString();

        // Create a prompt that grounds the answer in the retrieved context
        string prompt = $@"Answer the following question based on the provided context. If the answer cannot be determined from the context, say 'I don't have enough information to answer this question.'

Context:
{context}

Question: {query}

Answer:";

        string systemPrompt = "You are a helpful assistant that provides accurate information based on the context provided. Always cite your sources when possible.";

        return await GenerateResponseAsync(prompt, systemPrompt);
    }

    private class GenerateResponse
    {
        [JsonPropertyName("response")]
        public string Response { get; set; }
    }
}
```
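With all four components in place, a hypothetical end-to-end run ties them together. The file path and question below are placeholders, and vectorDb is the configured instance from the vector database section:

```csharp
using System;
using System.Linq;

var processor = new DocumentProcessor();
var embeddings = new EmbeddingService();
var llama = new LlamaClient();

// 1. Ingest: chunk a document, embed the chunks, and store them
var chunks = processor.ProcessFile("docs/manual.txt"); // placeholder path
await embeddings.GenerateEmbeddingsForChunksAsync(chunks);
await vectorDb.AddDocumentsAsync(chunks);

// 2. Retrieve: embed the question and pull back the closest chunks
var question = "How do I configure the service?"; // placeholder question
var queryEmbedding = await embeddings.GenerateEmbeddingAsync(question);
var matches = await vectorDb.SearchSimilarDocumentsAsync(queryEmbedding, limit: 3);

// 3. Generate: answer grounded in the retrieved context
var answer = await llama.GenerateResponseWithContextAsync(
    question,
    matches.Select(m => m.Text).ToList());
Console.WriteLine(answer);
```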