Update 20/07/2023
This article refers to ChromaDB version 0.3.26. There have been breaking changes in the API between that release and the latest version, 0.4.2. Consequently, a couple of changes are warranted:
- Instead of chromadb.Client, one should now use chromadb.PersistentClient. It takes a path argument for persistence, and the chromadbSettings object is no longer needed for that. One should also pass the settings argument with chromadbSettings(allow_reset=True), so that the app.chromadb.reset() call used by the flask reindex CLI implemented in app.py still works.
- This post's chromadb.persist() calls should be removed, as the method is not supported anymore.
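For reference, under the newer 0.4.x API the client initialization might look like the following sketch; the cache_chromadb path simply mirrors the one used later in this post:

```python
# Sketch for ChromaDB >= 0.4.x: PersistentClient replaces
# Client + the duckdb+parquet settings used in this article.
import chromadb
from chromadb.config import Settings

chroma = chromadb.PersistentClient(
    path="cache_chromadb",               # replaces persist_directory
    settings=Settings(allow_reset=True)  # required so client.reset() works
)
```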
Like many developers using Flask, my first steps followed Miguel Grinberg’s excellent tutorials on creating Flask web applications (you can find more info on these resources in his blog).
One of the more advanced features showcased in his Udemy course is the integration of a full-text capable search functionality, implemented using elasticsearch (you may find a written version in Grinberg’s blog here). By running a local elasticsearch installation and initializing a connection from the Flask web app, the author populated elasticsearch indexes with duplicated text data from the app’s relational database. Afterwards, utilizing a mixin class alongside the SQLAlchemy model classes, the searchable fields of the target database model could be specified and queried with a search function linking SQLAlchemy with elasticsearch queries.
While trying to implement such a search functionality during my Journal Hub project’s early development, I had some trouble installing elasticsearch on MacOS 13 via homebrew, and since dockerizing seemed overkill at the time, I wondered what alternatives there were. Having previously worked on information-enhanced chatbot development with langchain
(see this post), I quickly realized that any full-text query & retrieval implementation could also be done via the lightweight, open-source embedding models out there.
I jump-started with ChromaDB and its default embeddings model, which fortunately is quite slim: the 80 MB all-MiniLM-L6-v2
model from the SentenceTransformers framework, available also in the HuggingFace Hub. It embeds sequences of up to 256 tokens into a 384-dimensional space (each embedded sequence thus becomes a single 384-dimensional vector), and is advertised as suitable for many use cases and for basic embedding proximity/similarity searches.
Text embeddings turn words into numbers in a high-dimensional space (in this case, 384 dimensions). The dimensions can be chosen by hand, derived from a dictionary of unique word tokens in a dataset, or, more commonly, learned by training a fixed-size Embedding layer in a neural network. During this training, the network learns to assign similar vectors to words that have similar meanings or are used in similar contexts. The resulting parameterized embedding layer usually captures far more information about words and language than a simple one-to-one dictionary mapping.
We implement search by passing each sequence we want to be able to query through the embedding function/model and storing the calculated embedding vectors (one vector per embedded sequence). Then, we calculate the similarity between all the stored embeddings and our query string, and fetch the results with the highest similarity ranking. Similarity is usually computed as a distance measure between two vectors, for example the angle between them: if the angle is 0 degrees, the vectors are maximally similar; if it is 90 degrees, they are orthogonal and considered dissimilar.
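As a toy illustration of this idea, here is the cosine of the angle between two vectors computed by hand; the three-dimensional "embeddings" below are made up for demonstration, as real embeddings have hundreds of dimensions:

```python
import math

def cosine_similarity(a, b):
    # cos(angle) between two vectors: 1.0 = same direction, 0.0 = orthogonal
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

# Made-up 3-dimensional "embeddings", for illustration only
paper = [0.9, 0.1, 0.2]
article = [0.8, 0.2, 0.1]  # similar meaning -> similar direction
banana = [0.0, 1.0, 0.1]   # unrelated meaning -> different direction

print(cosine_similarity(paper, article))  # close to 1
print(cosine_similarity(paper, banana))   # close to 0
```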
Depending on the structure of the Flask app, there are subtle differences in how to implement the necessary parts. I use a blueprint-based structure, where the Flask app is initialized in a factory module, so the setup looks like the following:
Here, the app/__init__.py
file contains the factory initialization function create_app
that is used by app.py
to create the app
variable and all related flask extensions.
# app.py
# ...
app = create_app(os.getenv("FLASK_CONFIG") or "default")
# ...
# app/__init__.py
# ...
def create_app(config_class='default'):
    app = Flask(__name__)
    app.config.from_object(config[config_class])

    from .main import bp as main_bp
    app.register_blueprint(main_bp)

    from .auth import auth as auth_blueprint
    app.register_blueprint(auth_blueprint, url_prefix='/auth')
    # ...
    return app
# ...
We initialize the ChromaDB in the app/__init__.py
file, and we make sure to specify a persistent directory for the calculated embedding vectors for each processed and stored text sequence.
# app/__init__.py
# ...
# chromadbSettings is chromadb.config.Settings imported under an alias;
# the client is attached to the app instance (e.g. app.chromadb = chroma)
# so it can be reached via current_app.chromadb later on
chroma = chromadb.Client(chromadbSettings(
    chroma_db_impl="duckdb+parquet",
    persist_directory="cache_chromadb",
))
# ...
Note that we are using a locally run Chroma instance. Much like using SQLite instead of a client/server database, however, this is not ideal when looking to scale the app. ChromaDB supports a client/server setup via a docker image.
We now need to define some functions that will manipulate the ChromaDB collection objects during our application’s runtime.
Remember, as the implementation relates to the Journal Hub project, the search functionality is about searching for papers relevant to a full-text query, that have their respective SQLAlchemy models defined in app/models.py
. Each model can have separate searchable attributes, defined via the __searchable__
dunder attribute.
The approach is as modular as possible with respect to other db.Model subclasses, but a few if/else checks will be needed within the functions to differentiate the handling for each model class.
# app/models.py
# ...
class PaperShowcase(db.Model):
    __searchable__ = ['source', 'identifier', 'secondary_identifier', 'abstract', 'title']
    __tablename__ = 'papers'

    id = db.Column(db.Integer, primary_key=True)
    identifier = db.Column(db.String(50), unique=True, index=True)
    secondary_identifier = db.Column(db.String(50), index=True)
    abstract = db.Column(db.String(10000))
    title = db.Column(db.String(10000))
# ...
We create a new app/search.py
file to contain these functions:
from flask import current_app
from chromadb.utils import embedding_functions


def add_to_collection(collection, model):
    # collection builder for SQLAlchemy model
    c = current_app.chromadb.get_or_create_collection(
        collection, embedding_function=embedding_functions.DefaultEmbeddingFunction())
    # searchable document is the concatenation of all relevant text fields
    entry = ""
    for field in model.__searchable__:
        if getattr(model, field) is not None:
            entry = entry + " " + getattr(model, field)
    # reference source: identifier, plus secondary_identifier when present
    source = getattr(model, 'identifier')
    if getattr(model, 'secondary_identifier') is not None:
        source = source + getattr(model, 'secondary_identifier')
    metadatas = [{'source': source}]
    # note upsert instead of add
    c.upsert(documents=[entry],
             metadatas=metadatas,
             ids=[f"{collection}_{str(getattr(model, 'id'))}"])
    current_app.chromadb.persist()  # needed?


def remove_from_collection(collection, model):
    try:
        c = current_app.chromadb.get_collection(
            collection, embedding_function=embedding_functions.DefaultEmbeddingFunction())
    except Exception as e:
        print(f"Could not find collection {collection} - {e}")
        return
    # ids were stored with the collection-name prefix, so delete with the same format
    c.delete(ids=[f"{collection}_{getattr(model, 'id')}"])
    current_app.chromadb.persist()  # needed?


def query_collection(collection, query_str):
    try:
        c = current_app.chromadb.get_collection(
            collection, embedding_function=embedding_functions.DefaultEmbeddingFunction())
    except Exception as e:
        print(f"Could not find collection {collection} - {e}")
        return [], []
    results = c.query(query_texts=[query_str])
    # for a single query string, unpack the list of lists in results
    ids = [int(i.split(f"{collection}_")[1]) for i in results['ids'][0]]
    scores = [float(i) for i in results['distances'][0]]
    return ids, scores
In add_to_collection, we handle the logic for whenever we want to embed some data stored under a model class into our ChromaDB collection object. We first create an entry to be embedded by concatenating all the searchable fields in the database row/instance. The metadata dictionary passed to ChromaDB contains a 'source' key-value pair, providing a reference source built from the relevant model's identifier (sometimes a secondary_identifier exists too, which just takes some extra logic). Thus, the embedded information for a paper will be referenced in the source by the paper's identifier(s) in the database. The entry is then upserted into the collection, with a collection id derived from the database's model.id. Whether the current_app.chromadb.persist() call is strictly necessary is not clear to me, but I added it for good measure. The remove_from_collection function simply deletes an object from the collection.
When the add_to_collection
is run for the first time, the default embeddings model will be downloaded and stored in a local cache directory.
The querying of the ChromaDB collection happens in query_collection, which takes a query_str argument. We call the collection's query() method and retrieve a results dictionary. This contains ids and calculated similarity distances, as lists corresponding to our query string. A collection's query() method accepts a list of query strings, but in our case we provide a single one. We return the integer ids for each matched vector in the collection, along with the similarity scores calculated against our query.
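The id convention doing the work here, prefixing the database id with the collection name, can be sketched in isolation; the collection name and row ids below are arbitrary examples:

```python
def make_chroma_id(collection, db_id):
    # ChromaDB ids must be strings; prefix the integer database id
    # with the collection name so it can be recovered later
    return f"{collection}_{db_id}"

def parse_chroma_id(collection, chroma_id):
    # invert make_chroma_id: strip the prefix and restore the integer id
    return int(chroma_id.split(f"{collection}_")[1])

cid = make_chroma_id("papers", 42)
print(cid)                             # papers_42
print(parse_chroma_id("papers", cid))  # 42
```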
Now that we have references to the ids and the scores, we can interface with SQLAlchemy queries.
I hinted at a Mixin class earlier, and this is where it comes into play. We basically want to create a class from which our db.Model subclasses will inherit a search() method. We also need to handle automatic synchronization between the SQL database and ChromaDB at insertion, modification, and deletion time of the to-be-queried paper data.
Here’s the class:
# app/models.py
# ...
# searchable SQLalchemy mixin for generic searching functionality
class SearchableMixin(object):
    @classmethod
    def search(cls, expression, user, page=None, per_page=None):
        ids, scores = query_collection(cls.__tablename__, expression)
        if len(scores) == 0:
            return cls.query.filter(db.false()), []
        # map each id to its position, so results keep the similarity order
        ranking = []
        for i in range(len(ids)):
            ranking.append((ids[i], i))
        results = db.session.query(cls)\
            .filter(cls.id.in_(ids))\
            .order_by(db.case(*ranking, value=cls.id))
        return results, scores

    @classmethod
    def before_commit(cls, session):
        setattr(session, f"{cls.__tablename__}_changes", {
            'add': [obj for obj in session.new if isinstance(obj, cls)],
            'update': [obj for obj in session.dirty if isinstance(obj, cls)],
            'delete': [obj for obj in session.deleted if isinstance(obj, cls)],
        })

    @classmethod
    def after_commit(cls, session):
        session_idx_change = getattr(session, f"{cls.__tablename__}_changes")
        if session_idx_change:
            for obj in session_idx_change['add']:
                add_to_collection(cls.__tablename__, obj)
            for obj in session_idx_change['update']:
                add_to_collection(cls.__tablename__, obj)
            for obj in session_idx_change['delete']:
                remove_from_collection(cls.__tablename__, obj)
            setattr(session, f"{cls.__tablename__}_changes", None)
        current_app.chromadb.persist()

    @classmethod
    def reindex(cls):
        for obj in cls.query:
            add_to_collection(cls.__tablename__, obj)
# ...
For the search
method, we get ids
and scores
from the query_collection
function introduced previously in the app/search.py
module. The rest is a procedure to use the ids as filters in an SQLAlchemy query. For this, we create a ranking list of tuples that we can use in the order_by method. This way, we receive SQL query results ordered by similarity score. The edge case of zero returned scores is also handled; it basically corresponds to an empty collection, since this implementation returns results no matter how small the similarity score is.
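Stripped of SQLAlchemy, the ranking list feeding db.case() is just (id, position) pairs; the ids below are illustrative:

```python
# Hypothetical ids as returned by query_collection, best match first
ids = [17, 3, 42]

# Build (id, position) pairs for an SQL CASE expression:
# CASE papers.id WHEN 17 THEN 0 WHEN 3 THEN 1 WHEN 42 THEN 2 END
ranking = [(id_, pos) for pos, id_ in enumerate(ids)]
print(ranking)  # [(17, 0), (3, 1), (42, 2)]

# Ordering rows by that CASE value restores the similarity order,
# which a plain "WHERE id IN (...)" query would otherwise lose
rows = [3, 42, 17]  # rows in whatever order the database returns them
case_value = dict(ranking)
print(sorted(rows, key=case_value.get))  # [17, 3, 42]
```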
The before_commit
and after_commit
class methods are defined so that they perform the necessary manipulations based on session-level changes. Note that the session
object here is a db.session
, not Flask’s session
cookie storage. We use these methods in SQLAlchemy listeners created at the end of the app/models.py
file:
# app/models.py
# ...
db.event.listen(db.session, 'before_commit', PaperShowcase.before_commit)
db.event.listen(db.session, 'after_commit', PaperShowcase.after_commit)
We also implement a reindex()
class method, to be used when an empty ChromaDB needs to be re-populated based on our relational database.
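The update note at the top mentions a flask reindex CLI command implemented in app.py; a minimal sketch of such wiring might look like the following (the command body is my assumption, not this post's actual code):

```python
# app.py (sketch) -- expose reindex() as a "flask reindex" CLI command;
# app and PaperShowcase are assumed to exist as elsewhere in this post.
@app.cli.command("reindex")
def reindex():
    """Re-populate the ChromaDB collections from the relational database."""
    app.chromadb.reset()  # in newer versions requires allow_reset=True
    PaperShowcase.reindex()
```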
The final piece in this puzzle is the definition of an endpoint to handle the search requests. As done in Grinberg’s post, I assume a GET
-capable search form in the navbar (accessible from all routes, via flask’s g
global storage):
# app/main/routes.py
# ...
@main.before_app_request
def before_request():
    g.search_form = SearchNavbarForm()
# ...
and now we can handle this with GET
requests:
# app/main/routes.py
# ...
@main.route('/search')
def search():
    if not g.search_form.validate():
        print(g.search_form.errors)
        return redirect(url_for('main.index'))
    results, scores = PaperShowcase.search(g.search_form.q.data, current_user)
    results = [i.to_dict() for i in results]
    return render_template('showcases/search.html',
                           type="papers",
                           query=g.search_form.q.data,
                           results=results,
                           scores=scores)
# ...
Here, the to_dict()
is a helper instance method for the PaperShowcase
model, allowing us to convert results to JSON for use with JavaScript components in the Jinja frontend. Similar handling could of course happen without converting from SQLAlchemy objects if there's no such requirement.
Et voilà! You have implemented full-text search functionality using modern text-embedding models!
Here’s how it works in the Journal Hub app:
As a reminder, note that to scale your deployment, you need to reconsider the local duckdb+parquet implementation of ChromaDB and move to a client/server setup. Without it, I've had inconsistencies and bugs pop up even when just using a gunicorn production server with many worker threads.
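For reference, pointing the app at a server-mode Chroma instance (e.g. the official docker image) is a one-line change in newer ChromaDB versions; the host and port below are assumptions matching the docker image's defaults:

```python
# Sketch: connect to a Chroma server instead of an embedded local instance
# (ChromaDB 0.4.x API; host and port are example values).
import chromadb

chroma = chromadb.HttpClient(host="localhost", port=8000)
```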