Categorizing illustrations using deep learning: Part 1.

In this post, I’ll share how I convinced Unsplash to train a custom neural network to categorize illustrations, so people can find the right image faster. Machine learning isn’t really part of our core expertise, so this was something I pushed for. I’ll assume some ML knowledge, but I’ll link helpful resources too. If you’re curious about putting ML to work in real projects, read on.

This is the first part of a two-part series:

  • Part 1: Backstory, data gathering, model training, and experimentation.
  • Part 2: Deployment and search integration (Coming soon)

Humble beginnings

I’ve been really interested in machine learning for a while now. I haven’t studied it formally, but as a software engineer, I wanted to learn what I could do with it. My journey began when my coworker @luke introduced me to the fastai course. It was eye-opening to see how easy it is to get started with ML. I worked my way through the course and learned a lot. It was intense and the material was dense, but I pushed through and the effort eventually paid off (more on that later).

I needed to start applying what I was learning, so my first project was to train a neural net to classify images of my cat. Getting decent results with very little time and effort was incredibly motivating. Next, I took on a more challenging problem: building a model to predict AQI (air quality index) levels in Montreal, QC. I was quite happy with the outcome. It wasn’t perfect, but it exceeded what I thought I could achieve. I was getting more comfortable with fastai and core ML concepts. I also explored the math behind it all through 3Blue1Brown’s videos, and things started to click (though I’ve forgotten some of it by now).

Those two projects were fun, but I wanted to do more. I wanted to build something real, something users could actually use, and I wanted to work on it full-time. Around that time, I came across Designing Machine Learning Systems by Chip Huyen. The timing felt perfect, so I picked it up and devoured every page. I was excited by all the tools and techniques it covers for bringing an ML project into production: build, train, deploy, monitor, and iterate.

I wanted to do that.
I was ready to find a use case for Unsplash.

Do I CLIP or do I not?

We had recently added support for illustrations on Unsplash, and while I was talking with my coworker @kirillz, he showed me some new search filters. Users could narrow down illustrations by style, like “line art” or “3D.” I thought it was cool, but then I asked:

Do we have the data to confidently filter illustrations by these styles?

We did not have it.

I was between tasks at the time, and I figured this was a good opportunity to practice what I’d learned.

Before investing time into training a model, I decided to leverage one of the key technologies behind our search: CLIP. I’m not going to dive deeply into what CLIP is or how it works, but in short, we use it for semantic search. By comparing text and image embeddings, we can find the most relevant results for a given search query.

Jeremy Howard explains how CLIP embeddings work in one of his Stable Diffusion lessons.
If you’re caffeinated enough, I highly recommend watching the whole lesson; it’s a helpful way to understand CLIP’s role in the diffusion process.

My initial plan was to craft prompts, encode them into embeddings, and compare them to image embeddings to find illustrations that matched certain styles, like “line art” or “3D.” In a categorization context, this approach is called zero-shot classification. Here are some of the prompts I tried:

styles = {
    "3d": "3D illustrations, 3D image, shading, lighting, realistic materials",
    "flat": "flat illustrations, solid color areas, clean edges, minimal shading",
    "hand_drawn": "hand drawn illustrations, pencil, ink, digital brush.",
    "line_art": "line art illustrations. outlines, minimal shading",
    "cartoonish": "cartoon illustrations, playful, exaggerated, bright, pastel colors",
    "doodle": "doodle illustrations, casual, freeform look, sketch feel",
    "minimalistic": "minimalistic illustrations, few colors, negative space",
    "isometric": "isometric illustrations, shapes, symmetry, grid, angular forms",
    "abstract": "abstract illustrations, shapes, form",
    "retro": "retro and vintage illustrations, muted or grainy colors, typography, printing",
    "realistic": "realistic illustrations, accurate proportions, realism",
}
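In practice, zero-shot classification boils down to picking the prompt whose embedding is closest to the image embedding. Here’s a minimal NumPy sketch, assuming the image and prompt embeddings were already computed with CLIP’s image and text encoders; the toy 4-dimensional vectors and the `zero_shot_classify` helper are hypothetical stand-ins, not our actual pipeline.

```python
import numpy as np


def zero_shot_classify(image_embs, text_embs, labels):
    """Hypothetical helper: return the label whose prompt embedding is most similar to each image."""
    # L2-normalize so that a dot product equals cosine similarity.
    img = image_embs / np.linalg.norm(image_embs, axis=1, keepdims=True)
    txt = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    sims = img @ txt.T  # shape: (n_images, n_labels)
    return [labels[i] for i in sims.argmax(axis=1)], sims


labels = ["3d", "flat", "line_art"]
# Toy vectors standing in for real CLIP text embeddings of each style prompt.
text_embs = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]])
# Toy vectors standing in for real CLIP image embeddings.
image_embs = np.array([[0.9, 0.1, 0.0, 0.2], [0.1, 0.2, 0.8, 0.1]])

preds, sims = zero_shot_classify(image_embs, text_embs, labels)
print(preds)
```

CLIP usually scales these similarities and applies a softmax to get probabilities, but that doesn’t change which label scores highest, so an argmax over cosine similarities is enough to pick a class.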

While I was getting some results, I didn’t think they were good enough. To understand why CLIP wasn’t returning what I expected, I started digging into the embeddings it produced because I wanted to see what kind of information was being captured.

To analyze the embeddings more effectively, we can group images into clusters based on their embedding vectors (for example, with k-means) and visualize how the clusters relate to each other.
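As a rough sketch of the clustering step, here’s scikit-learn’s k-means applied to L2-normalized embeddings. The exact clustering algorithm isn’t pinned down here, so k-means is an assumption, the synthetic vectors stand in for real CLIP embeddings, and `cluster_embeddings` is a hypothetical helper.

```python
import numpy as np
from sklearn.cluster import KMeans


def cluster_embeddings(embeddings, n_clusters=20, seed=42):
    """Assign each embedding to one of `n_clusters` clusters with k-means."""
    # CLIP embeddings are usually compared with cosine similarity, so normalize first:
    # Euclidean k-means on unit vectors then groups vectors pointing in similar directions.
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed)
    return kmeans.fit_predict(normed)


# Synthetic stand-in for CLIP image embeddings: 3 well-separated groups of 50 vectors each.
rng = np.random.default_rng(0)
centers = rng.normal(size=(3, 512))
fake_embeddings = np.concatenate([c + 0.05 * rng.normal(size=(50, 512)) for c in centers])

labels = cluster_embeddings(fake_embeddings, n_clusters=3)
print(np.bincount(labels))
```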

Here we have roughly 20 clusters. This is what it looks like when plotted in FiftyOne.

Visualization of the different clusters

I’ve highlighted one of the clusters so you can see how CLIP embeddings get grouped together.

I discovered that CLIP captures a lot of information about themes and subjects, but much less about aesthetic styles. It tends to pick up on similarities like subject matter, number of people, colors, and shapes, but not as much on specific styles like “3D” or “hand-drawn.”

In the example above, we can see it grouping images that contain a single individual, with similar color palettes and backgrounds.

Perhaps there were some dimensions capturing aesthetics, but they weren’t dominant enough to show up in the clusters.

To wrap up this experiment, I tried various dimensionality reduction techniques like PCA and UMAP to downplay the dominant features that might have been working against me and surface the aesthetic components of the embeddings. But that didn’t produce noticeably better results.
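The exact way PCA was applied isn’t spelled out above, but one common way to downplay dominant directions is to project them out: find the top principal components and subtract each embedding’s projection onto them. Here’s a minimal NumPy sketch of that idea; `remove_top_components` and the synthetic data are hypothetical.

```python
import numpy as np


def remove_top_components(embeddings, k):
    """Project out the k highest-variance directions of a set of embeddings."""
    centered = embeddings - embeddings.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions, sorted by decreasing variance.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    top = vt[:k]  # shape: (k, dim)
    # Subtract each vector's projection onto the top-k directions.
    return centered - (centered @ top.T) @ top


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 16))
X[:, 0] *= 10  # make one direction dominate the variance, like a strong "subject" signal
residual = remove_top_components(X, k=1)
print(X[:, 0].var(), residual[:, 0].var())
```

Whatever signal was left in the remaining directions is what we hoped would carry style information.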

Maybe CLIP wasn’t the best tool for this job — it was time to try something else.

A new model

I was convinced a vision model would be able to categorize illustrations; it didn’t seem like an impossible problem to me.

I went on a quest to build a proof of concept. Gathering data was the easy part, since I had the whole Unsplash library at my disposal. I downloaded a few thousand illustrations, asked the team to help me label them, and had a dataset ready to go. We didn’t really have official labels yet, but that didn’t matter much. I used placeholders like “line art” or “isometric” just to get started and to explore what future search feeds could look like when filtered by style.

With the dataset in hand, it was time to get coding.

Since I had used fastai before, I decided to stick with it. If you haven’t tried it yet, here’s a quick example of how easy it is to train a vision model.

Suppose I had a CSV with my labels, like this:

filename,label
/images/image1.jpg,line_art
/images/image2.jpg,isometric
/images/image3.jpg,hand_drawn

We can write a training script like so:

from fastai.vision.all import *
import pandas as pd

df = pd.read_csv("labels.csv")

dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_x=ColReader("filename"),
    get_y=ColReader("label"),
    splitter=RandomSplitter(valid_pct=0.2, seed=42),  # 80/20 train/val split
    item_tfms=Resize(224),
    batch_tfms=aug_transforms(size=224),
)
dls = dblock.dataloaders(df)

# Metrics and loss function are omitted for now; fastai infers a default
# cross-entropy loss from the CategoryBlock targets.
learner = vision_learner(dls, resnet34)
learner.fine_tune(5)

There’s obviously a lot more nuance to it, but this is the gist. This is how you use transfer learning to train a vision model using fastai. You might notice I left out metrics and the loss function. We’ll talk about those later.

I was confident I could pull it off, so I wrapped up my proof of concept and presented it to my team. To my surprise, they were on board! After a bit of back and forth, I got the green light to develop this model officially.

I convinced my company to train a neural network to categorize illustrations.

Now that I was officially full-time on this project, what was the next step?
Good question. I had to figure it out as I went. I had no prior experience building ML systems for production, so there was a lot to learn: new tools, new workflows, and plenty of unknowns. I had to learn how to gather more data, label it, train the model, track metrics, deploy it, and finally integrate the predictions into our search system.

Gathering data

Gathering more data was the first step and, luckily, one of the easier ones. Since it was mostly a software engineering task, I won’t go too deep into the details.

I had access to the whole Unsplash library. I put together a few Jupyter notebooks to download, preprocess, and store images in a Postgres database.

Labeling the images was fairly straightforward. We built a simple web app to associate images with labels, and a few members of the team helped us with tagging. The tricky part was making sure we were all labeling things consistently. That’s crucial for training a reliable model.

We found that many images could reasonably fit into multiple styles, but we wanted each image to have a single label. So we had to agree on a clear set of rules to follow while labeling.

In the end, we settled on four different labels:

  • 3d
  • flat
  • hand_drawn
  • line_art

Once we had enough data, we created a dataset from a snapshot of the database.

We take a snapshot to freeze the data in time, which helps with reproducibility. That way, if we ever need to understand exactly what went into training, we’re not relying on live data that might have changed, for example, if someone updates the label on an existing image. We also use the snapshot to create our validation and test sets, so the results are fully deterministic.

Evening out the dataset

After analyzing the distribution of the labels in the dataset, I realized it was quite imbalanced. We had tens of thousands of images for some labels, but only a few thousand for others. This kind of imbalance can lead to a model that’s biased toward the majority class.

For example, here’s what the distribution looked like in a dataset of around 30k images:

Label distribution in the dataset

As you can see, the distribution is quite uneven. The flat label dominates, followed by hand_drawn, while 3d and line_art illustrations are seriously lacking. The obvious solution would be to gather more images for the underrepresented classes, but that’s easier said than done.

Since we didn’t have any prior data explicitly categorizing images as “3d” or “line art,” it was difficult to target those styles directly. Instead, we had to bulk import a large number of images and rely on human labeling to find the ones that fit.

We did try using CLIP and Gemini to auto-label images, but it wasn’t reliable enough. We still needed humans to review and correct the labels, so the amount of work ended up being pretty much the same.

Aside from getting more data, there are a few other ways to deal with class imbalance. I’ll break the next section down into three parts:

  • Sampling strategies
  • Loss function
  • Metrics

Sampling strategies

Before diving into sampling, it’s worth mentioning the strategy we used to split the dataset into train, validation, and test sets.

Since the training data is heavily imbalanced, skewed towards flat, we can reasonably assume that production data will have a similar distribution. In other words, the model will most likely encounter more flat illustrations. With that in mind, we want our validation and test sets to reflect the kind of data the model will see in the real world.

To achieve that, we used a stratified split. This allowed us to preserve the original label distribution across all sets, making sure the model is evaluated on a sample that’s truly representative of what it will encounter in production.

We can use train_test_split from sklearn to do just that:

from sklearn.model_selection import train_test_split


def stratified_split(
    df, train_frac=0.8, valid_frac=0.15, test_frac=0.05, random_state=42
):
    # 1a. Reserve test set
    df_trainval, df_test = train_test_split(
        df, test_size=test_frac, stratify=df["label"], random_state=random_state
    )
    # 1b. Split train vs. validation.
    # Rescale valid_frac so it's relative to the remaining rows (0.15 / 0.95 with the defaults).
    valid_adj = valid_frac / (train_frac + valid_frac)
    df_train, df_valid = train_test_split(
        df_trainval,
        test_size=valid_adj,
        stratify=df_trainval["label"],
        random_state=random_state,
    )
    return df_train, df_valid, df_test


df_train, df_valid, df_test = stratified_split(df)

I’m mentioning splitting here because it’s important to perform this step before applying any sampling strategy to the training set. Otherwise, you risk introducing data leakage, where information from the validation or test sets ends up in training. For example, if you oversample by duplicating images before splitting, copies of the same image can land in both the training and test sets. Leakage inflates your evaluation metrics and can hide overfitting, so the model looks better on paper than it will actually perform on new data.
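A cheap safeguard is to check that no image appears in more than one split. Here’s a minimal pandas sketch, assuming each row is identified by its `filename` column (as in the CSV example earlier); `find_split_overlap` is a hypothetical helper, and the tiny DataFrames simulate an image that was duplicated before splitting.

```python
import pandas as pd


def find_split_overlap(splits, key="filename"):
    """Return {(split_a, split_b): shared keys} for every pair of splits that share rows."""
    names = list(splits)
    overlaps = {}
    for i, a in enumerate(names):
        for b in names[i + 1:]:
            shared = set(splits[a][key]) & set(splits[b][key])
            if shared:
                overlaps[(a, b)] = shared
    return overlaps


# image1.jpg was duplicated (e.g. by oversampling) before splitting, so it lands in two sets.
splits = {
    "train": pd.DataFrame({"filename": ["/images/image1.jpg", "/images/image2.jpg"]}),
    "valid": pd.DataFrame({"filename": ["/images/image1.jpg", "/images/image3.jpg"]}),
    "test": pd.DataFrame({"filename": ["/images/image4.jpg"]}),
}
print(find_split_overlap(splits))
```

This only catches exact matches on the key, not visually similar copies saved under a different filename.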

Now that we have our train, validation and test sets in place, we can apply sampling techniques to the training set to help mitigate the imbalance. Here are a couple of strategies we considered:

  • Undersampling the majority class
  • Oversampling the minority class

I wasn’t particularly a fan of oversampling, since it usually involves duplicating existing images—sometimes with slight variations like padding, cropping, or skewing. That might be useful later as a form of augmentation, but for now, I decided to focus on undersampling the majority class instead.

I implemented a soft undersampling of the majority class, aiming to reduce the number of flat images so it’s closer to the size of the minority classes. I didn’t want to discard too much data, so I introduced a tolerance threshold, something we could experiment with and adjust over time.

Implementing that in Python was quite straightforward:

"""
We'll tolerate a bit of imbalance.
Ideally, we'll need to reduce this number as we acquire more minority samples.
A low number means that we won't let majority classes dominate the dataset too much, bringing more balance between classes but less samples overall.
A high number means we will allow majority to unbalance the dataset, but we will have more samples overall.
"""
SOFT_SAMPLING_TOLERANCE = 1
counts = df_train["label"].value_counts()
min_count = counts.min()
# build target count per class
target_counts = {
label: min(int(min_count * (1 + SOFT_SAMPLING_TOLERANCE)), count)
for label, count in counts.items()
}
# Soft sampling based on the target counts
dfs = []
for label, group in df_train.groupby("label", sort=False):
n = target_counts[label]
dfs.append(group.sample(n=n, random_state=42))
df_train_soft = pd.concat(dfs, ignore_index=True)

Various tests showed that lowering the tolerance too much actually hurt model performance. My guess is that it’s because we end up with too few images per class. The hard part is finding the sweet spot—enough reduction to address the imbalance, but not so much that the model lacks data to learn from.
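To get a feel for the tolerance knob, here’s the same capping rule from the snippet above pulled into a small function and run on hypothetical class counts (made up for illustration, not our real numbers). `soft_target_counts` is a hypothetical name.

```python
def soft_target_counts(counts, tolerance):
    """Cap every class at (1 + tolerance) x the size of the smallest class."""
    min_count = min(counts.values())
    cap = int(min_count * (1 + tolerance))
    return {label: min(cap, count) for label, count in counts.items()}


# Hypothetical class counts for an imbalanced training split.
counts = {"flat": 10_000, "hand_drawn": 6_000, "line_art": 2_500, "3d": 2_000}

for tol in (0, 0.5, 1, 2):
    kept = soft_target_counts(counts, tol)
    print(tol, kept, sum(kept.values()))
```

Each step down in tolerance trims the majority classes further and shrinks the training set, which is exactly the trade-off described above.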

Over time, this threshold should naturally decrease as we gather more data and the dataset becomes more balanced.

For example, here’s what our splits look like after sampling. Notice how the distribution is preserved in the validation and test sets. We can then be confident the model will be evaluated on a realistic distribution of the data.

Label distributions after sampling
split        count   percent
soft_train   11968   66.37%
validation    4548   25.22%
test          1516    8.41%

Loss function

The loss function is another important part of the training process. I’m not going to go into a deep explanation here; there are plenty of great resources out there that cover it in detail.

In short, the loss is calculated during the forward pass and tells us how far off the model’s predictions are from the actual labels. During the backward pass, we compute the gradient of the loss with respect to the model’s weights, and optimizers like SGD use those gradients to adjust the weights, with the goal of minimizing the loss over time.
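To make “forward pass, gradient, weight update” concrete, here’s a minimal NumPy sketch of a single SGD step on a toy linear softmax classifier. It’s a stand-in, not a CNN and not how fastai works internally (PyTorch’s autograd computes the gradients for you); the data and learning rate are made up.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def cross_entropy(probs, y):
    return -np.log(probs[np.arange(len(y)), y]).mean()


rng = np.random.default_rng(0)
X = rng.normal(size=(32, 8))      # 32 toy "images" with 8 features each
y = rng.integers(0, 4, size=32)   # 4 classes
W = np.zeros((8, 4))              # linear classifier weights
lr = 0.5

# Forward pass: predictions and loss.
probs = softmax(X @ W)
loss_before = cross_entropy(probs, y)

# Backward pass: gradient of the mean cross-entropy w.r.t. W is X^T (probs - one_hot(y)) / n.
one_hot = np.eye(4)[y]
grad = X.T @ (probs - one_hot) / len(y)

# SGD update: take a step against the gradient.
W -= lr * grad
loss_after = cross_entropy(softmax(X @ W), y)
print(loss_before, loss_after)
```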

For categorization tasks, a standard choice is cross-entropy loss, and that’s what I’ve used, though with a small twist. As you might remember, our dataset is quite imbalanced. We can bake this imbalance into the loss function itself by penalizing the model more for misclassifying the minority classes. Since the model sees fewer examples of these classes during training, we want to make sure it pays extra attention to them. The majority classes, on the other hand, are easier for the model to learn, simply because they appear more often.

To account for this, we assign a weight to each class and use a weighted cross-entropy loss, where the minority classes carry more weight during training.
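To see what the weights actually do, here’s a minimal NumPy sketch of weighted cross-entropy. It follows the convention PyTorch documents for `reduction="mean"` (the weighted sum of per-sample losses divided by the summed weights of the target classes); the class counts and logits are hypothetical.

```python
import numpy as np


def weighted_cross_entropy(logits, targets, class_weights):
    """Weighted mean cross-entropy: sum(w[y] * nll) / sum(w[y])."""
    # Numerically stable log-softmax.
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]  # per-sample loss
    w = class_weights[targets]  # weight of each sample's true class
    return (w * nll).sum() / w.sum()


# Hypothetical class counts: class 0 is the majority, classes 2 and 3 are rare.
counts = np.array([8000.0, 4000.0, 1000.0, 1000.0])
weights = 1.0 / counts                           # inverse frequency
weights = weights / weights.sum() * len(counts)  # rescale so they sum to the number of classes

# A model that always bets on class 0: right on two majority samples, wrong on one rare sample.
logits = np.array([[3.0, 0.0, 0.0, 0.0]] * 3)
targets = np.array([0, 0, 2])
plain = weighted_cross_entropy(logits, targets, np.ones(4))
weighted = weighted_cross_entropy(logits, targets, weights)
print(weights.round(3), plain, weighted)
```

With uniform weights, the missed rare sample is one of three equal contributions; with inverse-frequency weights it dominates the loss, which pushes the optimizer to fix it.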

With fastai, we can create a weighted loss function like so:

from fastai.vision.all import CrossEntropyLossFlat
import torch

"""
Compute a loss function that is weighted by the class imbalance in the dataset.
The model's optimizer will leverage this to penalize the model more for misclassifying the minority classes and penalize it less for misclassifying the majority classes. This will help the model learn more about underrepresented classes.
Remember, optimizers like to decrease the loss.
"""
train_labels = df.loc[df["split"] == "train", "label"]
# Count each class in the same order as dls.vocab so the weights line up with the model's outputs.
counts = (
    train_labels.value_counts(sort=False).reindex(dls.vocab, fill_value=0).values
)
counts = torch.tensor(counts, dtype=torch.float32)
# We'll build weights using inverse frequency.
# Rare classes will get a bigger weight than the more frequent ones.
# clamp(min=1) avoids dividing by zero if a class is missing from the training split.
weights = 1.0 / counts.clamp(min=1)
# Rescale the weights so that they sum to the number of classes.
# This is mostly to keep the loss function in a reasonable range so it doesn't scale weirdly by the sum of the weights.
weights = weights / weights.sum() * len(counts)
# The weights must live on the same device as the model's outputs (e.g. the GPU).
loss_fn = CrossEntropyLossFlat(weight=weights.to(dls.device))

Metrics

Whereas the loss is used by optimizers to guide training, metrics are a proxy for humans to understand how well a model is performing on the task at hand.

Metrics help us define goals and track progress over time. These are usually the numbers you show your boss to earn a nice little pat on the back (when they’re improving, of course).

Every single book I’ve read emphasizes how important it is to choose the right metrics for your task. If you don’t, you might end up optimizing for the wrong thing. That can lead to a bunch of wasted time troubleshooting instead of improving the model. In some cases, you might even have to start over from scratch. And then the boss won’t be so happy anymore.

Figuring out the metrics was a bit more challenging for me, since I had no prior experience building production models. After doing some research and reasoning about the problem I was trying to solve, I decided to use precision, recall, and their harmonic mean, f1. I’d come across these metrics before when benchmarking embedding models. If you’re not familiar with them, I recommend pausing here and reading this section; it’ll help you follow along.

We can compute precision, recall, and f1 for each class individually, which gives us a clear picture of where the model struggles most. But to get a quick sense of overall performance, it’s helpful to have a single score: some kind of average across all classes.
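For reference, here’s how the per-class numbers fall out of true positive, false positive, and false negative counts, plus the macro average, in plain Python. The `per_class_scores` helper and the toy labels are made up for illustration; scikit-learn’s `precision_recall_fscore_support` computes the same quantities.

```python
def per_class_scores(y_true, y_pred, labels):
    """Precision, recall and f1 for each label (0.0 when a ratio is undefined)."""
    scores = {}
    for label in labels:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores[label] = {"precision": precision, "recall": recall, "f1": f1}
    return scores


labels = ["flat", "3d", "line_art"]
y_true = ["flat", "flat", "flat", "3d", "3d", "line_art"]
y_pred = ["flat", "flat", "3d", "3d", "flat", "line_art"]

s = per_class_scores(y_true, y_pred, labels)
macro_f1 = sum(s[label]["f1"] for label in labels) / len(labels)
print(s, macro_f1)
```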

There are three common ways to average these metrics: micro, macro, and weighted. Each has its own use case, and it’s important to understand the difference. In my case, I wanted every class to be treated equally: misclassifying a 3d image should matter as much as misclassifying a flat one. That’s why I chose macro averaging.

In fastai, this looks like:

from fastai.vision.all import F1Score
metric = F1Score(average="macro")
  • macro computes the unweighted mean of the f1 scores for each class. Every class contributes equally, regardless of how many samples it has.
  • weighted also averages the f1 scores across classes, but weights each one by its support (i.e. the number of samples). This is useful if you want the performance on more common classes to have a bigger influence on the final score.
  • micro aggregates the contributions of all classes (total true positives, false positives, and false negatives) to compute a global score. In single-label multi-class classification, micro f1 is equal to accuracy.
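Here’s how the three strategies diverge on a small, made-up, imbalanced example, using scikit-learn’s `f1_score` (fastai’s `F1Score` wraps it and accepts the same `average` options):

```python
from sklearn.metrics import accuracy_score, f1_score

y_true = ["flat"] * 6 + ["3d"] * 2 + ["line_art"] * 2
y_pred = ["flat"] * 6 + ["flat", "3d"] + ["line_art", "flat"]

scores = {avg: f1_score(y_true, y_pred, average=avg) for avg in ("macro", "weighted", "micro")}
print(scores, accuracy_score(y_true, y_pred))
```

flat is the easiest class here, so weighted sits above macro, and micro simply matches accuracy.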

You can learn more about the different strategies here.


Another metric I considered was accuracy.

It’s a very simple and intuitive metric. Basically, “out of all the predictions made, how many were correct?” If we made 100 predictions and 80 of them were right, then the accuracy is 80%. Sounds good, right? But now imagine this: out of those 100 predictions, 80 were from the majority class flat, 10 were hand_drawn, and the other 10 were 3d. If the model gets all the flat predictions right but completely fails on hand_drawn and 3d, the accuracy is still 80%. On paper that looks good, but in reality, the model is doing a terrible job with anything that isn’t the majority class.

That’s why I didn’t choose accuracy. It’s not a good metric for imbalanced datasets, as it tends to be skewed toward the majority classes.
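The scenario above is easy to reproduce with scikit-learn, assuming the failing model simply predicts flat for every image:

```python
from sklearn.metrics import accuracy_score, f1_score

y_true = ["flat"] * 80 + ["hand_drawn"] * 10 + ["3d"] * 10
y_pred = ["flat"] * 100  # the model never predicts anything but the majority class

acc = accuracy_score(y_true, y_pred)
# zero_division=0 scores precision as 0 for classes that are never predicted, instead of warning.
macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
print(acc, round(macro_f1, 4))
```

Accuracy reports 80%, while macro f1 drops to about 0.30, exposing that two of the three classes are never recognized.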

Training

If we come back to our simple training script, we can now plug in the loss function and metrics we’ve just talked about:

import torch
from fastai.vision.all import CrossEntropyLossFlat, F1Score, resnet34, vision_learner


def weighted_cross_entropy():
    train_labels = df.loc[df["split"] == "train", "label"]
    counts = (
        train_labels.value_counts(sort=False).reindex(dls.vocab, fill_value=0).values
    )
    counts = torch.tensor(counts, dtype=torch.float32)
    # We'll build weights using inverse frequency.
    # Rare classes will get a bigger weight than the more frequent ones.
    weights = 1.0 / counts.clamp(min=1)
    # Rescale the weights so that they sum to the number of classes.
    # This is mostly to keep the loss function in a reasonable range so it doesn't scale weirdly by the sum of the weights.
    weights = weights / weights.sum() * len(counts)
    return CrossEntropyLossFlat(weight=weights.to(dls.device))


learner = vision_learner(
    dls,
    resnet34,
    metrics=F1Score(average="macro"),
    loss_func=weighted_cross_entropy(),
)

We can then train the model for a couple of epochs:

learner.fine_tune(5)

If you haven’t noticed, there is something we haven’t covered yet: the model’s architecture. In the earlier snippet, we used resnet34. We can break it down into two parts:

  • The architecture: resnet, short for residual network, is a type of convolutional neural network (CNN) designed to train deeper networks more effectively by using skip connections.
  • The depth: 34 refers to the number of layers with learnable weights. Deeper models can capture more complex patterns, but they’re also more prone to overfitting, especially with smaller datasets.
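To make the skip connection concrete, here’s a heavily simplified NumPy sketch of a residual block: two weight layers compute a correction F(x), and the block outputs relu(F(x) + x). Real ResNet blocks use 3x3 convolutions with batch normalization; plain matrix multiplications stand in for them here, and the function names are hypothetical.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def residual_block(x, w1, w2):
    """Simplified residual block: relu(F(x) + x), with F(x) = relu(x @ w1) @ w2."""
    residual = relu(x @ w1) @ w2
    return relu(residual + x)  # skip connection: add the input back


def plain_block(x, w1, w2):
    """The same layers without the skip connection."""
    return relu(relu(x @ w1) @ w2)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16))
w1 = rng.normal(size=(16, 16))
w2 = np.zeros((16, 16))  # imagine the second layer hasn't learned anything useful yet

# With F(x) == 0 the residual block still passes its (rectified) input through,
# while the plain block outputs all zeros.
print(np.allclose(residual_block(x, w1, w2), relu(x)), np.allclose(plain_block(x, w1, w2), 0))
```

That’s the intuition behind why residual networks train well at depth: each block only has to learn a correction on top of its input, and gradients can flow back through the identity path.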

resnet34 was a solid starting point for our early iterations. We can always try deeper models like resnet50 or resnet101 down the line. There are also plenty of other CNN architectures we could experiment with in the future.

Note: I haven’t tried any vision transformer models yet. I don’t understand them as well as CNNs, and they seem more complex and data-hungry. I’ll probably explore them later, once the basics are solid.

Experimenting

At this point, we’ve gathered the data, defined our metrics and loss function, and chosen a model architecture to start with. Now it’s time to experiment and see how well the model actually performs.

You can always run these experiments locally or on platforms like Google Colab and track the metrics manually. But I wanted a framework for tracking experiments and tweaking settings that would make it relatively easy to compare metrics across runs.

After a bit of research, I came across MLflow. Needless to say, I was convinced. I deployed a self-hosted instance and started tracking various parameters and metrics.

fastai has a callback system that allows you to tap into the training process. We can leverage that and build a custom callback to log metrics at each epoch:

import mlflow
import numpy as np
from fastai.callback.core import Callback


def format_metrics(names, values, vocab):
    """
    names: list of metric names, e.g. ['train_loss','valid_loss','f1']
    values: list of metric values (scalars or sequences)
    vocab: list of labels to zip against when a value is a sequence
    """
    out = {}
    for name, val in zip(names, values):
        if isinstance(val, np.ndarray):
            # Per-class metrics come back as arrays: log one entry per label, named "{metric}_{label}".
            for label, score in zip(vocab, val.tolist()):
                out[f"{name}_{label}"] = score
        else:
            out[name] = val
    return out


class MLFlowTrackingCallback(Callback):
    # We want this tracking callback to run _after_ the fastai Recorder callback, otherwise the metrics won't be computed in time.
    # https://github.com/fastai/fastai/blob/f4de849baa79b64792ae5e98ef37e5f8fa4be66f/fastai/learner.py#L548
    order = 60

    def __init__(self, freeze_epochs: int):
        super().__init__()
        self.freeze_epochs = freeze_epochs

    def before_fit(self):
        # fine_tune() calls fit twice: first for the frozen "head" epochs, then for the "full" unfrozen epochs.
        # Each call gets its own child run nested under the parent run.
        phase = "head" if self.n_epoch == self.freeze_epochs else "full"
        active_run = mlflow.active_run()
        if active_run is None:
            raise RuntimeError(
                "MLflow active run is not set. Please start a run before training."
            )
        # Start a child run
        mlflow.start_run(run_name=f"{phase}-{active_run.info.run_name}", nested=True)

    def after_epoch(self):
        # Drop the first and last metric names, which are the epoch and time.
        metric_names = self.recorder.metric_names[1:-1]
        # Grab last epoch metrics.
        values = self.recorder.values[-1]
        metrics = format_metrics(metric_names, values, self.dls.vocab)
        mlflow.log_metrics(metrics, step=self.epoch)

    def after_fit(self):
        # End the child run
        mlflow.end_run()

You can track as many things as you want. The goal is to be able to reproduce the results of a given run, compare them, and understand if the model is improving.

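import mlflow
from fastai.vision.all import resnet34, vision_learner

# dls, item_tfms_parsed, batch_tfms_parsed, git_sha, git_branch, dataset_id and
# dataset_source_sql_string are defined elsewhere (not shown).
hyper_params = {
    "batch_size": 256,
    "epochs": 5,
    "freeze_epochs": 1,
    "arch": "resnet34",
}

with mlflow.start_run():
    mlflow.log_artifact("../pyproject.toml", "./model")
    mlflow.log_params(hyper_params)
    mlflow.log_params({"item_tfms": item_tfms_parsed})
    mlflow.log_params({"batch_tfms": batch_tfms_parsed})
    mlflow.set_tags(
        dict(
            git_sha=git_sha,
            git_branch=git_branch,
            arch=hyper_params["arch"],
            dataset_id=dataset_id,
        )
    )
    mlflow.log_text(dataset_source_sql_string, "dataset_source_query.sql")
    learner = vision_learner(
        dls,
        resnet34,
        cbs=[MLFlowTrackingCallback(freeze_epochs=hyper_params["freeze_epochs"])],
    )
    # Train inside the run so the callback can attach its child runs to it.
    learner.fine_tune(
        hyper_params["epochs"], freeze_epochs=hyper_params["freeze_epochs"]
    )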

Here’s what our MLflow dashboard currently looks like:

Unsplash's MLflow dashboard

As you can see, we’re already tracking quite a bit. Here’s a quick summary of what we log:

  • Parameters: Hyperparameters like learning rate, batch size, number of epochs, and so on.
  • Metrics: f1, precision, recall. We track both the macro-averaged scores and the per-class metrics.
  • Artifacts: Files generated during training, such as the pickled model, as well as files needed to replicate the Python environment used to run the model.
  • Tags: Metadata like the model architecture, the dataset version, the Git SHA, and other context useful for reproducibility.

This isn’t an exhaustive list, but it gives you a good idea of what can be tracked. Moving forward, we also plan to take advantage of the MLflow Model Registry and its deployment capabilities.

Observations

After running a series of experiments, our best run so far achieved a macro f1 score of 0.8129 on the test split. This is just the beginning, and we’re planning to test many more ideas, but we’re already quite happy with the initial results.

Metric               Value
test_valid_loss      0.5029
test_f1_macro        0.8129
test_f1_3d           0.7692
test_f1_flat         0.8774
test_f1_hand_drawn   0.7325
test_f1_line_art     0.8727

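If you remember the distribution of the training set, the metrics only loosely correlate with the imbalance in the data. Among the underrepresented classes, the model struggles the most with 3d, which isn’t surprising given how little data we have for it. Adding more 3d images should help improve those metrics. The lowest score actually belongs to hand_drawn, even though it’s the second most common class, so imbalance doesn’t explain everything.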

Interestingly, the model doesn’t struggle as much with line_art, even though it was the second least represented class. My guess is that line_art is visually very distinct, which makes it easier for the model to recognize—possibly even more so than hand_drawn.

Here’s its confusion matrix for fun:

Test confusion matrix

A confusion matrix is a table that helps visualize how well a classification model is performing. Each row represents the actual class, and each column represents the predicted class—so it’s easy to spot where the model is getting things right or mixing them up.
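Here’s a small sketch with scikit-learn’s `confusion_matrix`, which follows the same convention (rows are actual classes, columns are predicted classes). The labels and predictions are made up; per-class recall is the diagonal divided by the row sums, and precision is the diagonal divided by the column sums.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

labels = ["3d", "flat", "hand_drawn", "line_art"]
y_true = ["3d", "3d", "flat", "flat", "flat", "hand_drawn", "hand_drawn", "line_art"]
y_pred = ["3d", "flat", "flat", "flat", "hand_drawn", "hand_drawn", "flat", "line_art"]

cm = confusion_matrix(y_true, y_pred, labels=labels)
diag = np.diag(cm)
recall = diag / cm.sum(axis=1)     # of the images that truly are class i, how many we caught
precision = diag / cm.sum(axis=0)  # of the images predicted as class j, how many were right
print(cm)
```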

Conclusion

Congratulations, you’ve made it! This post was getting a bit lengthy, so I decided to split it into two parts.

In this first part, we covered some of the backstory behind the project. We walked through the process of gathering data, defining metrics and a loss function, and training a model. We also touched on how we’re using MLflow to track experiments and metrics.

A few things stood out for me so far:

  • Consistent labels are key — getting everyone involved to label things the same way was the real challenge.
  • Training our own model felt like the right response for this kind of task. I was confident a convolutional neural net could categorize the images well, and having our own model means we can keep improving it over time — something built-in solutions like CLIP don’t really allow.
  • Tracking experiments early saved a lot of guesswork later.

In the next part, we’ll dive into how we plan to deploy the model and use its predictions to power the search feature. That part hasn’t been built yet, so there’s still plenty of problem solving ahead. We’ll also cover our approach to iterating on the model and improving it over time.

Stay tuned! ✌️
