Categorizing illustrations using deep learning: Part 2.
Welcome back.
This is the second part of a two-part series.
- Part 1: Backstory, data gathering, model training, and experimentation.
- Part 2: Deployment and search integration.
In part 1 we trained, tuned, and tested a CNN to categorize illustration styles at Unsplash. In this post, we’ll get it out of the lab and into production.
The current state of affairs
Well… it’s been, what, six months? Time flies. To be honest, the project has been running in production for quite some time, but I’ve been procrastinating writing this follow-up post.
I’ve got my enhanced coffee (read: anti-procrastination fuel), so let’s do this.
Just to be crystal clear, we are deploying an ML model to categorize illustration styles on Unsplash. The goal is to improve search and discovery of illustrations by allowing users to filter results by style.
Here’s a little sneak peek of the feature in action:
Let’s now rewind a little bit and see how we got here.
The inference phase
The model has been trained and is ready for its first center-stage appearance. The time has come to set it loose. This is the inference phase: predictions step out of the lab and into the real world. We deploy, we integrate, and we listen closely as the model speaks.
Of course, the real world comes with network issues, IAM roles, Docker images, and the occasional existential crisis.
MLOps and infrastructure work are an integral part of developing and deploying machine learning models. As with the rest of this project, I had to learn how to deploy models as I went. At Unsplash, we were already all-in on AWS, so naturally this was the first place I looked.
After skimming the options, I landed on SageMaker, AWS’s managed service for building, training, and deploying machine learning models at scale. SageMaker has a lot of features, and figuring out exactly what we needed took some time. First, we need to clarify the two main ways of running inference in production:
- Online inference: latency-sensitive, real-time predictions served over the network from an HTTP endpoint.
- Batch inference: offline processing of input batches.
Usually, the deciding factors between the two are latency requirements, access, and cost. Does the model need to respond to user requests in real time, or can we afford to process inputs later?
The search feature we were trying to build required categorizing illustration styles. Essentially, we wanted to store the predicted category alongside the image metadata in our database and make it available at search time. This task doesn’t require the model to be always available and respond in real time. Instead, we can batch process a bunch of illustrations and send them to the model every once in a while. This makes batch inference the perfect fit for our use case.
And guess what? We decided to do online inference 😅. Online inference wasn’t the cheapest option, but it was the lowest-friction option and we cared more about shipping a usable feature than building a perfect pipeline upfront.
Sometimes the right architecture is the one that ships before your model becomes a forgotten MLflow artifact.
Having a live endpoint that anyone in the company could invoke to get predictions right away was the most straightforward option. Batch inference would have required more moving parts: scheduling transform jobs, storing the results in S3, and having the backend read them from S3 and write them to Postgres before finally making them available in the search index. We’re not ruling out switching to batch inference in the future if online inference turns out to be too costly.
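For comparison, here is roughly what the first of those moving parts could look like. This is a minimal sketch of something we don’t run, not our code: it builds the arguments for boto3’s SageMaker `create_transform_job` call. The job name, model name, S3 URIs, and instance type are hypothetical placeholders, and it assumes the input files are JSON Lines where every line is a request body such as {"inputs": ["https://..."]}.

```python
from typing import Any


def build_transform_job_request(
    job_name: str,
    model_name: str,
    input_s3_uri: str,
    output_s3_uri: str,
    instance_type: str,
    instance_count: int = 1,
) -> dict[str, Any]:
    """Build keyword arguments for boto3's SageMaker `create_transform_job`.

    Assumes `input_s3_uri` holds JSON Lines files whose lines are request
    bodies like {"inputs": ["https://..."]} (a hypothetical input layout).
    """
    return {
        "TransformJobName": job_name,
        "ModelName": model_name,
        # Send one line (one request body) per /invocations call.
        "BatchStrategy": "SingleRecord",
        "TransformInput": {
            "DataSource": {
                "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": input_s3_uri}
            },
            "ContentType": "application/json",
            "SplitType": "Line",
        },
        "TransformOutput": {
            "S3OutputPath": output_s3_uri,
            "Accept": "application/json",
            # Join the per-record responses with newlines in the output files.
            "AssembleWith": "Line",
        },
        "TransformResources": {
            "InstanceType": instance_type,
            "InstanceCount": instance_count,
        },
    }


def submit_transform_job(request: dict[str, Any]) -> str:
    """Submit the job and return its ARN; needs boto3 and AWS credentials."""
    import boto3  # imported lazily so building the request has no AWS dependency

    client = boto3.client("sagemaker")
    response = client.create_transform_job(**request)
    return response["TransformJobArn"]
```

And that only covers scheduling: reading the results back from S3, writing them to Postgres, and reindexing would still be separate work.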
Deploying the model
If you read the previous post, you’ll know that we tracked our model parameters using MLflow. We also logged the trained model as a pyfunc model so it could be easily loaded back up for inference.
As it turns out, MLflow has built-in support for deploying models to SageMaker endpoints. It handles building the Docker image and wires up most of the SageMaker deployment plumbing. You end up with a network-accessible endpoint that serves predictions following the MLflow serving API.
Without further ado, let’s create the Docker image.
```bash
# A unique identifier for models tracked in MLflow.
MODEL_URI="models:/<model-id>"

mlflow models generate-dockerfile --model-uri "$MODEL_URI" --output-directory .
```

Before we continue, I should mention that the Docker image MLflow generates includes a default inference server. In fact, that’s exactly what generate-dockerfile sets up: for pyfunc models, the container entrypoint starts MLflow’s own scoring server.
This may work for you, but for us it wasn’t ideal: we saw overhead on the order of seconds per request. Switching to a custom FastAPI server brought latency down to a few hundred milliseconds.
I commented on an issue in the MLflow repo to try to get more clarity.
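If you want to check whether the default server’s overhead affects you too, a small timing harness is enough. A minimal sketch, assuming `send_request` is any zero-argument callable that performs one full round trip (for example, a hypothetical wrapper around a POST to /invocations); it reports median, p95, and max latency in milliseconds.

```python
import math
import statistics
import time
from typing import Callable


def measure_latency(
    send_request: Callable[[], object], n: int = 50, warmup: int = 5
) -> dict[str, float]:
    """Time `n` calls to `send_request` after `warmup` untimed calls; results in ms."""
    if n < 1:
        raise ValueError("n must be at least 1")

    for _ in range(warmup):
        send_request()  # warm up connections and any lazy initialization

    samples_ms = []
    for _ in range(n):
        start = time.perf_counter()
        send_request()
        samples_ms.append((time.perf_counter() - start) * 1000)

    samples_ms.sort()
    # Nearest-rank 95th percentile on the sorted samples.
    p95 = samples_ms[math.ceil(0.95 * n) - 1]
    return {
        "median_ms": statistics.median(samples_ms),
        "p95_ms": p95,
        "max_ms": samples_ms[-1],
    }
```

Running it against the stock MLflow image and a custom server on the same instance type keeps the comparison fair.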
Our server wraps FastAPI and provides the routes SageMaker requires.
```python
import logging
import os
from contextlib import asynccontextmanager

import pydantic
from fastapi import Depends, FastAPI, Request, Response, status
from fastapi.responses import JSONResponse
from mlflow.pyfunc import PyFuncModel, load_model

logger = logging.getLogger(__name__)


def get_model(req: Request) -> PyFuncModel:
    # The model is loaded once at startup (see `lifespan`) and shared across requests.
    return req.app.state.model


@asynccontextmanager
async def lifespan(app: FastAPI):
    MODEL_PATH = os.getenv("MODEL_PATH")

    # Our model was logged as a `pyfunc` model so we can load it back up.
    # https://mlflow.org/docs/latest/ml/traditional-ml/tutorials/creating-custom-pyfunc/part2-pyfunc-components/
    model = load_model(MODEL_PATH)

    app.state.model = model
    yield


app = FastAPI(lifespan=lifespan)


class RequestBody(pydantic.BaseModel):
    inputs: list[str] = pydantic.Field(min_length=1)


# This must be named "invocations" to comply with SageMaker.
# A plain `def` lets FastAPI run the blocking predict() call in its threadpool,
# so /ping stays responsive while a prediction is running.
@app.post("/invocations")
def invocations(body: RequestBody, model: PyFuncModel = Depends(get_model)):
    try:
        return {
            "predictions": model.predict(body.inputs),
        }
    except Exception:
        # Log the real error server-side; don't leak internals to the caller.
        logger.exception("Prediction failed")
        return JSONResponse(
            content={"error": "Unknown error"},
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )


# This must be named "ping" to comply with SageMaker
@app.get("/ping")
async def ping() -> Response:
    return Response(content="\n", status_code=status.HTTP_200_OK)


if __name__ == "__main__":
    import uvicorn

    # SageMaker sends /ping and /invocations requests to port 8080 in the container.
    port = int(os.getenv("PORT", 8080))
    host = os.getenv("HOST", "0.0.0.0")

    # Pass the app object directly so this also works under `python -m inference.server`.
    uvicorn.run(app, port=port, host=host)
```

Nothing too fancy. We need two endpoints, both required by SageMaker: /ping for health checks and /invocations for predictions. The model is loaded once at server startup and stored in the application state.
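Before shipping the image, it helps to exercise the same contract SageMaker will use. A minimal stdlib sketch, assuming the container is running locally at a base URL such as http://localhost:8080 (adjust to whatever port you mapped): it checks that GET /ping answers with 200 and that POST /invocations returns one prediction per input. The function name is hypothetical.

```python
import json
import urllib.request


def smoke_test(base_url: str, inputs: list[str], timeout: float = 30.0) -> list:
    """Call /ping and /invocations the way SageMaker would; return the predictions."""
    # Health check: SageMaker expects HTTP 200 from GET /ping.
    with urllib.request.urlopen(f"{base_url}/ping", timeout=timeout) as resp:
        if resp.status != 200:
            raise RuntimeError(f"/ping returned {resp.status}")

    # Prediction: POST the same JSON body shape the server accepts.
    request = urllib.request.Request(
        f"{base_url}/invocations",
        data=json.dumps({"inputs": inputs}).encode("utf-8"),
        headers={"Content-Type": "application/json", "Accept": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        payload = json.load(resp)

    predictions = payload["predictions"]
    if len(predictions) != len(inputs):
        raise RuntimeError("expected one prediction per input")
    return predictions
```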
Next, the server needs to start when the container starts. We can’t point the container directly at our server script because SageMaker calls the entrypoint with a "serve" argument. Since we would also like to be able to run the server locally for testing, we can create a small entrypoint.py script that launches our FastAPI server when called with serve inside SageMaker.
SageMaker AI overrides default CMD statements in a container by specifying the serve argument after the image name. The serve argument overrides arguments that you provide with the CMD command in the Dockerfile.
```python
"""This entrypoint is called within the Docker container.

SageMaker overrides the Docker image CMD and calls it with the `serve` command.
Upon calling `serve`, we need to launch FastAPI.
https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-run-image
"""

import os
import sys


def main():
    # SageMaker passes "serve"; we translate it into launching our FastAPI server module.
    if len(sys.argv) > 1 and sys.argv[1] == "serve":
        # exec replaces this process, so the server keeps PID 1 and receives SIGTERM/SIGINT directly
        os.execvp(sys.executable, [sys.executable, "-u", "-m", "inference.server"])

    # Fallback: allow other subcommands (e.g., "bash", "python", etc.)
    if len(sys.argv) < 2:
        sys.exit("usage: entrypoint.py serve | <command> [args...]")
    os.execvp(sys.argv[1], sys.argv[1:])


if __name__ == "__main__":
    main()
```

We now need to tell Docker to use our entrypoint for the container. We can do this by modifying the generated Dockerfile.
```bash
ECR_REPO="xxx.dkr.ecr.us-east-1.amazonaws.com"
ECR_IMAGE_QUALIFIED_NAME="$ECR_REPO/unsplash-categorize-illustrations:latest"
# A unique identifier for models tracked in MLflow.
MODEL_URI="models:/<model-id>"

mlflow models generate-dockerfile --model-uri "$MODEL_URI" --output-directory .

# SageMaker will override the CMD and call it with `serve`. We mimic this behaviour when running the container locally.
# https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-run-image
sed -i '/^ENTRYPOINT/ i\CMD ["serve"]' Dockerfile

# Replace the ENTRYPOINT command to use the entrypoint script
sed -i 's|^ENTRYPOINT.*|ENTRYPOINT ["python", "-u", "/opt/inference/entrypoint.py"]|' Dockerfile

docker build -f Dockerfile -t "$ECR_IMAGE_QUALIFIED_NAME" .
```

We can push to the ECR repository next.
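One portability caveat: the sed -i calls above use GNU sed syntax, while BSD sed (the default on macOS) expects an explicit backup-suffix argument after -i, so the script can behave differently on a laptop than in CI. If that bites, the two edits are easy to reproduce in Python. A minimal sketch that mirrors them (insert CMD ["serve"] above the ENTRYPOINT line, then replace that line); the function names are hypothetical.

```python
from pathlib import Path

CMD_LINE = 'CMD ["serve"]'
ENTRYPOINT_LINE = 'ENTRYPOINT ["python", "-u", "/opt/inference/entrypoint.py"]'


def patch_dockerfile(text: str) -> str:
    """Mirror the two sed edits on the generated Dockerfile's contents."""
    patched_lines = []
    for line in text.splitlines():
        if line.startswith("ENTRYPOINT"):
            # sed '/^ENTRYPOINT/ i\CMD ["serve"]': insert the CMD line above it.
            patched_lines.append(CMD_LINE)
            # sed 's|^ENTRYPOINT.*|...|': replace the whole ENTRYPOINT line.
            patched_lines.append(ENTRYPOINT_LINE)
        else:
            patched_lines.append(line)
    return "\n".join(patched_lines) + "\n"


def patch_dockerfile_in_place(path: str = "Dockerfile") -> None:
    """Rewrite the Dockerfile at `path` with both edits applied."""
    dockerfile = Path(path)
    dockerfile.write_text(patch_dockerfile(dockerfile.read_text()))
```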
Note: for the following I’ll assume some prior knowledge about ECR and IAM roles. Getting into those is beyond the scope of this post.
```bash
ECR_REPO="xxx.dkr.ecr.us-east-1.amazonaws.com"
ECR_IMAGE_QUALIFIED_NAME="$ECR_REPO/unsplash-categorize-illustrations:latest"
# A unique identifier for models tracked in MLflow.
MODEL_URI="models:/<model-id>"
AWS_DEFAULT_REGION="us-east-1"

mlflow models generate-dockerfile --model-uri "$MODEL_URI" --output-directory .

# SageMaker will override the CMD and call it with `serve`
# https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-run-image
sed -i '/^ENTRYPOINT/ i\CMD ["serve"]' Dockerfile

# Replace the ENTRYPOINT command to use the custom inference server script
sed -i 's|^ENTRYPOINT.*|ENTRYPOINT ["python", "-u", "/opt/inference/entrypoint.py"]|' Dockerfile

docker build -f Dockerfile -t "$ECR_IMAGE_QUALIFIED_NAME" .

aws ecr get-login-password --region "$AWS_DEFAULT_REGION" | docker login --username AWS --password-stdin "$ECR_REPO"
docker push "$ECR_IMAGE_QUALIFIED_NAME"
```

With the image pushed to ECR, we can now create a SageMaker model and endpoint using the MLflow CLI.
```bash
SAGEMAKER_ENDPOINT="some-endpoint"

mlflow deployments update \
  --target "sagemaker:/$AWS_DEFAULT_REGION" \
  --config image_url="$ECR_IMAGE_QUALIFIED_NAME" \
  --config execution_role_arn=arn:aws:iam::some-exec-role \
  --config bucket=bucket-to-store-sagemaker-models \
  --model-uri "$MODEL_URI" \
  --config instance_type=<instance-type> \
  --config env="{\"MODEL_ID\":\"$MODEL_URI\",\"MODEL_ENV\":\"production\"}" \
  --config instance_count=1 \
  --name "$SAGEMAKER_ENDPOINT"

# Optional, tag the endpoint with the model ID and environment
ENDPOINT_ARN=$(
  aws sagemaker describe-endpoint \
    --endpoint-name "$SAGEMAKER_ENDPOINT" \
    --query 'EndpointArn' \
    --output text
)

aws sagemaker add-tags \
  --resource-arn "$ENDPOINT_ARN" \
  --tags \
    Key=mlflow_model_id,Value="$MODEL_URI" \
    Key=environment,Value=production
```

And just like that, we have a live SageMaker endpoint serving our model! Hopefully it was smooth sailing for you too. AWS infrastructure is always a breeze, am I right?
The endpoint is now ready to accept requests and return predictions for images. Note that SageMaker endpoints are not publicly accessible by default: every request has to be signed with AWS credentials. To reach the endpoint from the outside, you can either call it through the AWS SDK or set up a gateway in front of it.
In our case, we decided to use the SDK.
```bash
# With AWS CLI v2, --cli-binary-format raw-in-base64-out sends the JSON body as-is instead of expecting base64.
aws sagemaker-runtime invoke-endpoint \
  --endpoint-name your-endpoint \
  --body '{"inputs": ["https://plus.unsplash.com/premium_vector-1755095018170-a30cc2a1d6a4?q=80&w=800&auto=format&fit=crop"]}' \
  --content-type application/json \
  --accept application/json \
  --cli-binary-format raw-in-base64-out \
  results.json
```

The response is written to results.json:

```json
{
  "predictions": [
    {
      "pred": "hand_drawn",
      "probs": {
        "3d": 0.0022154166363179684,
        "flat": 0.023129383102059364,
        "hand_drawn": 0.9739536046981812,
        "line_art": 0.0007015919545665383
      }
    }
  ]
}
```

The model thinks our image is most likely hand-drawn. Since this is a discriminative model, it returns a probability distribution over all classes. In this case, the model is quite confident about its prediction, assigning about 97% probability to hand_drawn.
Would you agree it’s hand-drawn?
I think it is.
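For completeness, here is what the SDK route could look like in Python with boto3’s sagemaker-runtime client. This is an illustrative sketch, not necessarily the language or code our backend uses: the endpoint name is a placeholder, the parsing assumes the {"predictions": [{"pred": ..., "probs": {...}}]} response shape shown above, and the client is injectable so the parsing can be tested without AWS.

```python
import json
from typing import Any, Optional


def predict_styles(
    image_urls: list[str],
    endpoint_name: str,
    client: Optional[Any] = None,
) -> list[dict[str, Any]]:
    """Invoke the SageMaker endpoint and return one {"pred", "probs"} dict per URL."""
    if client is None:
        import boto3  # needs AWS credentials allowed to call sagemaker:InvokeEndpoint

        client = boto3.client("sagemaker-runtime")

    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        # boto3 accepts the raw JSON bytes directly.
        Body=json.dumps({"inputs": image_urls}).encode("utf-8"),
    )
    # `Body` is a streaming object; read it fully before decoding.
    payload = json.loads(response["Body"].read())
    return payload["predictions"]


def top_style(prediction: dict[str, Any]) -> tuple[str, float]:
    """Return the predicted label and the probability the model assigned to it."""
    label = prediction["pred"]
    return label, prediction["probs"][label]
```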
Search integration
Now that the model is deployed, it’s time to put it to good use. Earlier, I showed a small sneak peek of the search feature we’ve built around the model predictions. You can try it yourself on Unsplash.
Notice we’ve designed a set of filters that allow users to filter illustrations by style. These filters are powered by the model predictions. When someone uploads an illustration, we enqueue it. A worker picks up the job and calls the SageMaker endpoint to get predictions for the image. The style is stored alongside the image metadata and indexed for search.
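The upload flow just described can be sketched as a small worker loop. This is a simplified, stdlib-only illustration, not our production worker: queue.Queue stands in for the real job queue, the job fields are hypothetical, and `predict` and `save_style` are hypothetical callables (for example, a wrapper around the endpoint call and a database write).

```python
import queue
from typing import Callable, Optional


def run_worker(
    jobs: "queue.Queue[Optional[dict]]",
    predict: Callable[[list[str]], list[dict]],
    save_style: Callable[[str, str], None],
) -> int:
    """Process jobs until a `None` sentinel arrives; return how many were handled."""
    handled = 0
    while True:
        job = jobs.get()
        try:
            if job is None:  # sentinel: stop the worker
                return handled
            # One image per call here; a real worker might batch several URLs per request.
            prediction = predict([job["image_url"]])[0]
            # Store the style next to the image metadata so search can index it.
            save_style(job["image_id"], prediction["pred"])
            handled += 1
        finally:
            jobs.task_done()
```

In this sketch an exception from `predict` simply stops the loop; a real worker would add retries and somewhere to park jobs that keep failing.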
Every new illustration goes through a human review before it’s published. During review, the team can sanity-check the model’s prediction and correct it when needed. This allows us to build a feedback loop and collect more labeled data for future training sessions.
As with any statistical model, ours will make mistakes from time to time. We monitor model performance closely to make sure it doesn’t degrade over time. If we start noticing a drop in performance, the plan is to retrain the model with fresh data.
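Because every illustration goes through human review, the review step doubles as a cheap performance signal. A minimal sketch, assuming each review is recorded as a (predicted_label, reviewed_label) pair (a hypothetical schema): it computes how often reviewers kept the model’s label and which confusions they correct most often. Tracking that rate over time is one way to spot degradation; where to set an alert is a judgment call.

```python
from collections import Counter
from typing import Iterable


def review_summary(reviews: Iterable[tuple[str, str]]) -> dict:
    """Summarize (predicted_label, reviewed_label) pairs from the human review step."""
    total = 0
    agreed = 0
    corrections: Counter = Counter()
    for predicted, reviewed in reviews:
        total += 1
        if predicted == reviewed:
            agreed += 1
        else:
            # Count which (model label -> reviewer label) fixes happen most.
            corrections[(predicted, reviewed)] += 1
    return {
        "total": total,
        "agreement_rate": agreed / total if total else float("nan"),
        "top_corrections": corrections.most_common(3),
    }
```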
Additionally, there are many other things we can do to improve the model:
- Experiment with different model architectures
- Hyperparameter tuning
- Data augmentation techniques
- Ensemble methods
- Drift detection
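To make the last item concrete: one simple form of drift detection compares the distribution of predicted styles in a recent window against a reference window, here using the population stability index (PSI). A minimal sketch; the smoothing constant is an assumption, and the commonly quoted rule of thumb that PSI above roughly 0.2 signals a meaningful shift is a heuristic starting point, not a threshold we have validated.

```python
import math
from collections import Counter
from typing import Iterable


def psi(reference: Iterable[str], recent: Iterable[str], eps: float = 1e-4) -> float:
    """Population stability index between two samples of predicted labels."""
    ref_counts, new_counts = Counter(reference), Counter(recent)
    ref_total, new_total = sum(ref_counts.values()), sum(new_counts.values())
    if ref_total == 0 or new_total == 0:
        raise ValueError("both samples must be non-empty")

    score = 0.0
    for label in set(ref_counts) | set(new_counts):
        # Floor proportions at `eps` so a label missing from one side can't divide by zero.
        expected = max(ref_counts[label] / ref_total, eps)
        actual = max(new_counts[label] / new_total, eps)
        score += (actual - expected) * math.log(actual / expected)
    return score
```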
Model improvements could easily fill a future blog post of their own. We’re still quite early in the process and learning a lot along the way.
Conclusion
Well, that’s it for now. This wraps up the two-part series on categorizing illustration styles at Unsplash. I hope you found it insightful. It was certainly eye-opening for me.
A lot of work, a ton of learning, and honestly a lot of fun.
If you have any questions or feedback, feel free to reach out!