How to use a Custom Embedding Model or LLM with Trieve

The LLM is the easiest component to replace: set OPENAI_BASE_URL to any OpenAI-compatible completions endpoint that serves a gpt-3.5-turbo-style model. The easiest way to do this is by self-hosting LocalAI.
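
For example, with a self-hosted LocalAI instance listening on its default port 8080, the setting might look like the line below; treat this as a sketch, since the exact host, port, and path depend on how you deploy LocalAI:

OPENAI_BASE_URL=http://localhost:8080/v1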

This guide will demonstrate how to build a custom embedding server for the Trieve API, so you can generate embeddings with your own models instead of OpenAI's.

Prerequisites
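
You will need Python 3 installed, along with Flask, Hugging Face Transformers, and PyTorch; the standard PyPI packages are enough:

pip install flask transformers torch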

Step 1: Create a new Python file

Create a new Python file called embedding_server.py and import the following packages:

from flask import Flask, request, jsonify  # web server and JSON helpers
from transformers import AutoTokenizer, AutoModel  # Hugging Face model loading
import torch  # tensor operations for inference

Step 2: Load the model

Load the model and tokenizer. For this example, we will use the bert-base-uncased model.

# Download (on first run) and load the tokenizer and model weights.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
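
A quick sanity check is to print the model's hidden size, which is the length of each embedding vector this server will return (768 for bert-base-uncased); the rest of your stack needs to be configured for the same dimensionality:

print(model.config.hidden_size)  # prints 768 for bert-base-uncased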

Step 3: Create the Flask app

Create the Flask app and add a route for the /embed endpoint, which the Trieve API will call to generate embeddings. For now, the handler returns an empty placeholder response; we will implement it in Step 4.

app = Flask(__name__)

@app.route("/embed", methods=["POST"])
def embed():
    # Placeholder response; implemented in Step 4.
    return jsonify({"embeddings": []})

Step 4: Generate embeddings

Now we can generate real embeddings. The Trieve API will send a POST request to the /embed endpoint with a JSON body like this:

{
  "sentences": ["This is the first sentence.", "This is the second sentence."]
}

We read the sentences field, run the batch through bert-base-uncased, and mean-pool the token embeddings so each sentence maps to a single fixed-size vector, which is what a vector search index expects (the raw last_hidden_state contains one vector per token, not per sentence):

@app.route("/embed", methods=["POST"])
def embed():
    sentences = request.json["sentences"]
    inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    outputs = model(**inputs)
    embeddings = outputs.last_hidden_state.detach().numpy().tolist()
    return jsonify({"embeddings": embeddings})
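
Mean pooling is only the simplest reasonable choice here. Taking the [CLS] token vector (outputs.last_hidden_state[:, 0]) is a common alternative, and purpose-built sentence-embedding models such as those in the sentence-transformers family will generally give better retrieval quality than vanilla BERT.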

Step 5: Run the server
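
One thing is still missing: the script as written defines the app but never starts it. Add the following to the bottom of embedding_server.py so Flask's development server starts when the file is run directly (it listens on http://127.0.0.1:5000 by default):

if __name__ == "__main__":
    app.run()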

Now we can run the server and test it out. Run the following command to start the server:

python embedding_server.py

You should see output similar to the following:

* Serving Flask app 'embedding_server' (lazy loading)
* Environment: production
* Debug mode: off
* Running on http://127.0.0.1:5000

Now we can test the server by sending a POST request to the /embed endpoint:

curl -X POST -H "Content-Type: application/json" -d '{"sentences": ["This is the first sentence.", "This is the second sentence."]}' http://127.0.0.1:5000/embed

You should see output similar to the following, with one 768-element vector per input sentence (values truncated here):

{
  "embeddings": [
    [
        -0.3405719699859619,...
    ],
    [
        ...
    ]
  ]
}

Step 6: Point the Trieve API at the custom embedding server

Now that we have a custom embedding server, we can tell Trieve to use it. Open the .env file for your Trieve deployment and set the EMBEDDING_SERVER_CALL variable to the URL of the custom embedding server:

EMBEDDING_SERVER_CALL=<Link to embedding server>
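
For the server built in this guide running on the same machine as Trieve, that value would look like the following (adjust the host and port if you deployed it elsewhere):

EMBEDDING_SERVER_CALL=http://127.0.0.1:5000/embed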