Full-text search in Flask with ChromaDB and embeddings

(Flask + Chroma logos)

Update 20/07/2023

This article refers to ChromaDB version 0.3.26. There have been breaking API changes between that version and the latest release, 0.4.2. Consequently, a couple of changes are warranted:

  • Instead of chromadb.Client, one should now use chromadb.PersistentClient. It takes a path argument for persistence, so the chromadbSettings object used below is no longer needed for that. One should, however, pass the settings argument with Settings(allow_reset=True), so that the app.chromadb.reset() call used by the implemented flask reindex CLI in app.py keeps working (see the sketch after this list).
  • This post’s current_app.chromadb.persist() calls should be removed, as persist() is no longer supported; persistence now happens automatically.
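
For reference, here’s a minimal sketch of the updated initialization under these assumptions (same file layout as in the post below):

# app/__init__.py (ChromaDB >= 0.4 sketch)
import chromadb
from chromadb.config import Settings

chroma = chromadb.PersistentClient(
    path="cache_chromadb",                # persistence directory
    settings=Settings(allow_reset=True),  # needed for app.chromadb.reset()
)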

Implementing search in Flask

Like many developers using Flask, my first steps followed Miguel Grinberg’s excellent tutorials on creating Flask web applications (you can find more info on these resources in his blog).

One of the more advanced features showcased in his Udemy course is the integration of full-text search functionality, implemented using elasticsearch (you may find a written version in Grinberg’s blog here). By running a local elasticsearch installation and initializing a connection from the Flask web app, the author populates elasticsearch indexes with duplicated text data from the app’s relational database. Then, by means of a mixin class used alongside the SQLAlchemy model classes, the searchable fields of the target database model can be specified and queried with a search function that links SQLAlchemy queries to elasticsearch queries.

While trying to implement such a search functionality during my Journal Hub project’s early development, I had some trouble installing elasticsearch on macOS 13 via Homebrew, and since dockerizing seemed overkill at the time, I wondered what alternatives there were. Having previously worked on information-enhanced chatbot development with langchain (see this post), I quickly realized that any full-text query & retrieval implementation could also be built on one of the lightweight, open-source embedding models out there.

Embedding models

I jump-started with ChromaDB and its default embeddings model, which fortunately is quite slim: the 80 MB all-MiniLM-L6-v2 model from the SentenceTransformers framework, also available on the HuggingFace Hub. It embeds sequences of up to 256 tokens into a 384-dimensional space (each input sequence is pooled into a single 384-dimensional vector), and is advertised as suitable for many use cases and for basic embedding proximity/similarity searches.
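
As a standalone illustration (outside the Flask app), using the model directly through the sentence-transformers package shows the shape of its output:

# standalone illustration of the default embeddings model
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode(["full-text search in Flask", "vector databases"])
print(embeddings.shape)  # (2, 384): one 384-dimensional vector per input sequence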

Text embeddings turn words into numbers in a high-dimensional space (in this case, 384 dimensions). The dimensions can be chosen by hand, by creating a dictionary of unique word tokens in a dataset, or, more commonly, by training a fixed-size Embedding layer in a Neural Network. During this training, the network learns to assign similar vectors to words that have similar meanings or are used in similar contexts. The resulting parameterized embedding layer usually captures far more information about words and language than a simple one-to-one dictionary mapping.

We implement search by passing each sequence we want to be able to query through the embedding function/model and storing the calculated embedding vectors (one vector per embedded sequence). Then, we calculate the similarity between all the stored embeddings and our query string, and fetch the entries with the highest similarity ranking. Similarity is usually calculated as a distance measure between the two vectors, for example the angle between the two multi-dimensional vectors (cosine similarity): if the angle is 0 degrees, they are maximally similar; if it’s 90 degrees, they are completely dissimilar.
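
A minimal sketch of that angle-based measure, using plain numpy:

# cosine similarity between two embedding vectors
import numpy as np

def cosine_similarity(a, b):
    # 1.0 -> same direction (angle 0), 0.0 -> orthogonal (angle 90 degrees)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))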

Sample Project Structure

Depending on the structure of the Flask app, there are subtle differences in how to implement the necessary parts. I use a blueprint-based structure, where the Flask app is initialized in a factory module, so it looks like the following:

Here, the app/__init__.py file contains the factory initialization function create_app that is used by app.py to create the app variable and all related flask extensions.

# app.py
import os
from app import create_app

# ...
app = create_app(os.getenv("FLASK_CONFIG") or "default")
# ...
# app/__init__.py
from flask import Flask
from config import config  # maps config names to config classes

# ...
def create_app(config_class='default'):
    app = Flask(__name__)
    app.config.from_object(config[config_class])

    from .main import bp as main_bp
    app.register_blueprint(main_bp)

    from .auth import auth as auth_blueprint
    app.register_blueprint(auth_blueprint, url_prefix='/auth')

    # ...
    return app

# ...

Implementation

1. ChromaDB initialization

We initialize ChromaDB in the app/__init__.py file, making sure to specify a persistent directory where the embedding vectors calculated for each processed and stored text sequence will live.

# app/__init__.py
import chromadb
from chromadb.config import Settings as chromadbSettings

# ...
chroma = chromadb.Client(chromadbSettings(
        chroma_db_impl="duckdb+parquet",
        persist_directory="cache_chromadb",
        ))
# the client is attached to the app inside create_app (app.chromadb = chroma),
# so that the rest of the code can access it as current_app.chromadb
# ...

Note that we are using a locally running Chroma instance. Much like using an SQLite database instead of a client/server one, however, this is not an ideal scenario when one is looking to scale the app. ChromaDB also supports a client/server implementation via a docker image.
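
For reference, connecting to such a server (assuming ChromaDB >= 0.4 and a container exposing the default port) would look roughly like this:

# hypothetical client/server setup, e.g. after `docker run -p 8000:8000 chromadb/chroma`
import chromadb

chroma = chromadb.HttpClient(host="localhost", port=8000)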

2. ChromaDB collection management

We now need to define some functions that will manipulate the ChromaDB collection objects during our application’s runtime.

Remember, as the implementation relates to the Journal Hub project, the search functionality is about searching for papers relevant to a full-text query; the papers have their respective SQLAlchemy models defined in app/models.py. Each model can have separate searchable attributes, defined via the __searchable__ dunder attribute.

The approach is as modular as possible with respect to other db.models, but a few if/else checks will be needed within the functions to differentiate the handling for each model class.

# app/models.py

# ... 
class PaperShowcase(SearchableMixin, db.Model):  # SearchableMixin is introduced further below
    __searchable__ = ['source', 'identifier', 'secondary_identifier', 'abstract', 'title']
    __tablename__ = 'papers'
    id = db.Column(db.Integer, primary_key=True)
    identifier = db.Column(db.String(50), unique=True, index=True)
    secondary_identifier = db.Column(db.String(50), index=True)
    abstract = db.Column(db.String(10000))
    title = db.Column(db.String(10000))
    # ...

We create a new app/search.py file to contain these functions:

# app/search.py

from flask import current_app
from chromadb.utils import embedding_functions

def add_to_collection(collection, model):
    # get or create the collection for this SQLAlchemy model
    c = current_app.chromadb.get_or_create_collection(
        collection, embedding_function=embedding_functions.DefaultEmbeddingFunction())
    # the searchable document is the concatenation of all relevant text fields
    entry = ""
    for field in model.__searchable__:
        if getattr(model, field) is not None:
            entry = entry + " " + getattr(model, field)
    # the 'source' metadata references the paper's identifier(s)
    secondary = getattr(model, 'secondary_identifier', None)
    if secondary is not None:
        source = getattr(model, 'identifier') + secondary
    else:
        source = getattr(model, 'identifier')
    metadatas = [{'source': source}]
    # note upsert instead of add, so updated rows overwrite their old embeddings
    c.upsert(documents=[entry],
             metadatas=metadatas,
             ids=[f"{collection}_{str(getattr(model, 'id'))}"])
    current_app.chromadb.persist()  # needed?


def remove_from_collection(collection, model):
    try:
        c = current_app.chromadb.get_collection(
            collection, embedding_function=embedding_functions.DefaultEmbeddingFunction())
    except Exception as e:
        print(f"Could not find collection {collection} - {e}")
        return
    # ids were stored as f"{collection}_{model.id}", so delete with the same format
    c.delete(ids=[f"{collection}_{str(getattr(model, 'id'))}"])
    current_app.chromadb.persist()  # needed?


def query_collection(collection, query_str):
    try:
        c = current_app.chromadb.get_collection(
            collection, embedding_function=embedding_functions.DefaultEmbeddingFunction())
    except Exception as e:
        print(f"Could not find collection {collection} - {e}")
        return [], []
    results = c.query(query_texts=[query_str])
    # query() accepts a list of query strings and returns lists of lists;
    # we sent a single query, so we unpack the first (and only) inner list
    ids = [int(i.split(f"{collection}_")[1]) for i in results['ids'][0]]
    scores = [float(i) for i in results['distances'][0]]
    return ids, scores

In add_to_collection, we handle the logic for every occasion where we want to embed some data stored under a model class into our ChromaDB collection object. We first create an entry to be embedded by concatenating all the searchable fields of the database row/instance. The metadata dictionary we pass to ChromaDB holds a ‘source’ key-value pair, providing a reference source built from the relevant model identifiers (sometimes a secondary_identifier exists, but that’s just some extra logic). Thus, the embedded information for a paper is referenced in the source by the paper’s identifier(s) in the database. The entry is then upserted into the collection, with an id derived from the database’s model.id (prefixed with the collection name). Whether the current_app.chromadb.persist() call is strictly required is not clear to me, but I added it for good measure. The remove_from_collection function simply deletes an object from the collection.

When the add_to_collection is run for the first time, the default embeddings model will be downloaded and stored in a local cache directory.

The querying of the ChromaDB collection happens in query_collection, which takes the query_str as an argument. We call the collection’s query() method and retrieve a results dictionary containing ids and calculated similarity distances, as lists corresponding to our query string. A collection’s query() method allows for a list of query strings, but in our case we provide a single one. We return the integer ids for each matched vector in the collection, along with the similarity scores calculated against our query.
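
For illustration, a call inside an application context could look like this (the values shown are made up):

# hypothetical usage, inside an app context
ids, scores = query_collection('papers', 'attention mechanisms in transformers')
# ids    -> e.g. [42, 7, 13]         (integer model.id values, best match first)
# scores -> e.g. [0.31, 0.57, 0.88]  (embedding distances; lower means more similar)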

3. SQLAlchemy Mixin

Now that we have references to the ids and the scores, we can interface with SQLAlchemy queries.

I hinted at a Mixin class earlier, and this is where it comes into play. We basically want to create a class from which our db.models will inherit a search() method. We also need to handle automatic synchronization between the SQL database and ChromaDB at insertion, modification, and deletion time of to-be-queried paper data.

Here’s the class:

# app/models.py

from flask import current_app
from .search import add_to_collection, remove_from_collection, query_collection

# ...

# searchable SQLAlchemy mixin for generic searching functionality
class SearchableMixin(object):
    @classmethod
    def search(cls, expression, user, page=None, per_page=None):
        # page/per_page are kept for interface compatibility; ChromaDB
        # already returns only the top matches for the query
        ids, scores = query_collection(cls.__tablename__, expression)
        if len(scores) == 0:
            # no matches: return an empty query
            return cls.query.filter(False), []
        # map each returned id to its rank, so that the SQL results keep
        # the similarity-based ordering returned by ChromaDB
        ranking = [(id_, rank) for rank, id_ in enumerate(ids)]
        results = db.session.query(cls)\
            .filter(cls.id.in_(ids))\
            .order_by(db.case(*ranking, value=cls.id))
        return results, scores

    @classmethod
    def before_commit(cls, session):
        setattr(session, f"{cls.__tablename__}_changes", {
            'add': [obj for obj in session.new if isinstance(obj, cls)],
            'update': [obj for obj in session.dirty if isinstance(obj, cls)],
            'delete': [obj for obj in session.deleted if isinstance(obj, cls)],
        })

    @classmethod
    def after_commit(cls, session):
        session_idx_change = getattr(session, f"{cls.__tablename__}_changes", None)
        if session_idx_change:
            for obj in session_idx_change['add']:
                add_to_collection(cls.__tablename__, obj)
            for obj in session_idx_change['update']:
                add_to_collection(cls.__tablename__, obj)
            for obj in session_idx_change['delete']:
                remove_from_collection(cls.__tablename__, obj)
            setattr(session, f"{cls.__tablename__}_changes", None)
            current_app.chromadb.persist()

    @classmethod
    def reindex(cls):
        for obj in cls.query:
            add_to_collection(cls.__tablename__, obj)

# ...

For the search method, we get ids and scores from the query_collection function introduced previously in the app/search.py module. The rest is a procedure to use the ids as filters in a SQLAlchemy query. For this, we create a ranking list of tuples that we can use in the order_by method, so that we receive SQL query results ordered by similarity score. An edge case of zero returned scores is also handled; it basically corresponds to an empty database, since this implementation returns results no matter how low the similarity is.
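
To make the ordering trick concrete, here is roughly what the ranking list and the resulting SQL look like for a made-up result set:

# e.g. ids = [42, 7, 13]  ->  ranking = [(42, 0), (7, 1), (13, 2)]
# ORDER BY CASE papers.id WHEN 42 THEN 0 WHEN 7 THEN 1 WHEN 13 THEN 2 END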

The before_commit and after_commit class methods are defined so that they perform the necessary manipulations based on session-level changes. Note that the session object here is a db.session, not Flask’s session cookie storage. We use these methods in SQLAlchemy listeners created at the end of the app/models.py file:

# app/models.py

# ...

db.event.listen(db.session, 'before_commit', PaperShowcase.before_commit)
db.event.listen(db.session, 'after_commit', PaperShowcase.after_commit)

We also implement a reindex() class method, to be used when an empty ChromaDB needs to be re-populated based on our relational database.
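
The update note at the top mentions the flask reindex CLI in app.py; a minimal sketch of such a command (the body is an assumption based on the methods above) could be:

# app.py (sketch of a reindex CLI command; assumes PaperShowcase is imported here)
@app.cli.command()
def reindex():
    """Re-populate the ChromaDB collections from the relational database."""
    app.chromadb.reset()     # wipe the existing collections
    PaperShowcase.reindex()  # re-embed every paper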

4. Search endpoint route

The final piece in this puzzle is the definition of an endpoint to handle the search requests. As done in Grinberg’s post, I assume a GET-capable search form in the navbar (accessible from all routes, via Flask’s g global storage):

# app/main/routes.py

# ...
@main.before_app_request
def before_request():
    g.search_form = SearchNavbarForm()
# ...

and now we can handle this with GET requests:

# app/main/routes.py

# ...
@main.route('/search')
def search():
    if not g.search_form.validate():
        print(g.search_form.errors)
        return redirect(url_for('main.index'))
    results, scores = PaperShowcase.search(g.search_form.q.data, current_user)
    results = [i.to_dict() for i in results]
    return render_template('showcases/search.html', 
                           type="papers",
                           query=g.search_form.q.data, 
                           results=results, 
                           scores=scores)
# ...

Here, to_dict() is a helper instance method on the PaperShowcase model, allowing us to convert results to JSON-friendly dictionaries that the Jinja frontend can pass to JavaScript components. Similar handling could of course happen without converting the SQLAlchemy objects if there’s no such requirement.
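
A hypothetical sketch of such a helper, returning just the fields the frontend needs:

# app/models.py (hypothetical to_dict() helper)
class PaperShowcase(SearchableMixin, db.Model):
    # ...
    def to_dict(self):
        return {
            'id': self.id,
            'identifier': self.identifier,
            'title': self.title,
            'abstract': self.abstract,
        }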

Wrap-up

Et voilà! You have implemented full-text search functionality using modern text-embedding models!

Here’s how it works in the Journal Hub app:

(demo of the search feature)

As a reminder, note that to scale beyond development, you need to reconsider the local duckdb+parquet implementation of ChromaDB and move to a client/server implementation. With the local setup, I’ve had inconsistencies and bugs pop up even when using a gunicorn production server with many worker threads.