Embed custom operations in ONNX graphs.

Recently, I’ve been working on a project where we needed to apply a small transformation to the embeddings coming out of a model and make sure that exact same transformation happened during inference, so everything stayed in the same embedding space. We didn’t want to re-implement that logic in three different services and then spend the next month debugging drift. So we reached for ONNX: bake the transformation into the ONNX graph itself, so clients only need a single model file and nothing else.

With a bit of ONNX graph surgery, you can update an existing model by adding, removing, or replacing mathematical operations. That’s a practical way to layer custom behavior on top of a pre-trained model without relying on external post-processing steps.

Open Neural Network Exchange (ONNX)

First, we need to quickly understand what an ONNX graph is. Their website is better suited to explain, but the TL;DR is that ONNX is an open-source format for representing machine learning models as a standardized computational graph. That graph is made up of inputs, outputs, and nodes, where each node represents a mathematical operation applied to tensors flowing through the graph.

ONNX models can be exported from frameworks like PyTorch or TensorFlow and then executed in a wide range of runtimes.

In our case, we had a model trained in PyTorch that now runs as an ONNX model inside a Ruby application, no Python interpreter involved. In fact, it’s powering a key feature of unsplash.com (maybe a future blog post?).

Hopefully, this gives you an idea of why ONNX is useful in environments where not all systems are running the same stack.

Peeking inside an ONNX graph

Since ONNX graphs are made of nodes, modifying those nodes means modifying the computation itself and ultimately the model output.

Let’s peek inside a graph to build some intuition. I’ll be using Python for this, but equivalent libraries exist in other languages.

If you prefer a visual approach, tools like Netron let you interactively explore ONNX graphs.

We’ll need a base ONNX model to work with. I picked intfloat/multilingual-e5-base, mainly because it’s a model I’ve worked with extensively and know pretty well.

We can use the Hugging Face CLI to download the .onnx file:

hf download intfloat/multilingual-e5-base onnx/model.onnx --local-dir ./model

import onnx

model = onnx.load("./model/onnx/model.onnx")  # ModelProto

At this point, we have a ModelProto object that contains the full graph definition.

Inputs and outputs

We’ve got the model loaded in memory. We can start by inspecting its inputs:

model.graph.input
[
name: "input_ids"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_param: "batch_size"
}
dim {
dim_param: "sequence_length"
}
}
}
},
name: "attention_mask"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_param: "batch_size"
}
dim {
dim_param: "sequence_length"
}
}
}
}
]

Alright, what have we got here? We can see the model takes two inputs: input_ids and attention_mask. If you’re not familiar with transformers, these are the standard inputs for most encoder-style models.

input_ids represents the tokenized text. The tokenizer breaks a sentence into smaller chunks called tokens. A token might be a whole word, or a letter, but most of the time it’s somewhere in between. Each token is then mapped to a unique integer from the model’s vocabulary.

It has a shape of (B, S), where:

  • B: Batch size (number of sequences processed together)
  • S: Sequence length (number of tokens per sequence, often padded/truncated to a given length)

attention_mask is also of shape (B, S) and indicates which tokens are “real” (1) versus “padding” (0).

Padding tokens are usually added by the tokenizer so every sequence in a batch ends up the same length (more on that later).

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-base")
tokenized = tokenizer(["Hello world"], return_tensors="pt")
tokenized["input_ids"] # tensor([[ 0, 35378, 8999, 2]])
tokenized["input_ids"].shape # torch.Size([1, 4])
tokenized["attention_mask"] # tensor([[1, 1, 1, 1]])
tokenized["attention_mask"].shape # torch.Size([1, 4])

Now, let’s check the outputs:

model.graph.output
[
name: "last_hidden_state"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_param: "batch_size"
}
dim {
dim_param: "sequence_length"
}
dim {
dim_value: 768
}
}
}
}
]

This model has a single output last_hidden_state with shape (B, S, D) where:

  • B: Batch size
  • S: Sequence length
  • D: Hidden size (768 here)

This is where things get interesting. I thought my job was done here, but I was expecting an output of shape (B, D) instead: one embedding per input string.

Instead, we get token-level embeddings: one vector per token in the sequence.

We need a way to aggregate these token embeddings into a single sentence embedding, going from (B, S, D) to (B, D). The term for this operation is pooling. There are different pooling strategies, but the model card recommends mean pooling.

A Python implementation of mean pooling looks like this:

import torch

def mean_pool(
    last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """
    Parameters
    ----------
    last_hidden_states : torch.Tensor
        Tensor of shape (batch_size, sequence_length, hidden_size)
    attention_mask : torch.Tensor
        Tensor of shape (batch_size, sequence_length). 1 for real tokens, 0 for padding tokens.
    """
    # Expand mask to (B, S, 1) so it broadcasts across D
    mask = attention_mask.unsqueeze(-1).to(last_hidden_states.dtype)
    # Zero-out padded token embeddings, then sum across tokens
    summed = (last_hidden_states * mask).sum(dim=1)
    # Count non-padding tokens
    counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return summed / counts

This feels a bit magical at first: we use the attention mask to ignore padding tokens, sum the remaining token vectors, then divide by the number of real tokens to get an average embedding per sequence. We’ll deconstruct this step-by-step later.
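To make the arithmetic concrete, here’s a tiny worked example of those same steps, with made-up numbers and one padded token:

```python
import torch

hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])  # (1, 3, 2)
mask = torch.tensor([[1, 1, 0]])  # third token is padding

m = mask.unsqueeze(-1).to(hidden.dtype)  # (1, 3, 1)
summed = (hidden * m).sum(dim=1)         # [[4., 6.]] -- padded row ignored
counts = mask.sum(dim=1, keepdim=True)   # [[2]] real tokens
pooled = summed / counts                 # [[2., 3.]]
```

The padded vector [5.0, 6.0] never touches the result: it’s zeroed by the mask before the sum, and excluded from the count before the division.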

This brings us back to our initial goal: clients shouldn’t need to remember to apply pooling (or any other transformation) after the model. Ideally, they load one model from disk and get back sentence embeddings.

ONNX graphs are made of nodes carrying mathematical operations. Each node takes some inputs and produces some outputs. You can think of nodes as pure functions and compose them together to create more complex operations.

Mean pooling is really just averaging (summing and dividing) token-level embeddings.

We can break down this operation into a series of ONNX nodes:

  • Mul: multiplies its inputs element-wise. Useful for zeroing out padding token embeddings so they don’t affect the mean.
  • ReduceSum: sums a tensor along specified axes. For summing token embeddings together.
  • Div: divides its first input by its second. Ultimately needed for computing the mean.

We’ll also need to slide in a couple of helper ops to reshape tensors and make sure types and dimensions match.

  • Cast: converts the type of a tensor, e.g. from fp32 to fp16.
  • Unsqueeze: add a dimension of size 1 at a specified axis.
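Before wiring up actual ONNX nodes, it can help to sketch the same chain in NumPy, one line per op (the values here are made up):

```python
import numpy as np

last_hidden_state = np.ones((2, 4, 3), dtype=np.float32)  # (B, S, D)
attention_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]])   # (B, S), integers

mask_f = attention_mask.astype(np.float32)  # Cast: int -> fp32
mask_3d = np.expand_dims(mask_f, axis=2)    # Unsqueeze: (B, S) -> (B, S, 1)
masked = last_hidden_state * mask_3d        # Mul: zero out padding embeddings
summed = masked.sum(axis=1)                 # ReduceSum: (B, S, D) -> (B, D)
counts = mask_3d.sum(axis=1)                # ReduceSum: (B, S, 1) -> (B, 1)
pooled = summed / counts                    # Div: mean over real tokens only
```

The ONNX version below is the same pipeline, except each line becomes a node in the graph.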

With that in mind, we can start crafting those nodes.

Building pooling nodes

First, we need to understand how to create a node and attach it to the graph. Since nodes are like functions being chained together, the output of one node must match the expected input of the next.

Let’s start with an Identity node to build a better intuition. This node simply returns its input.

We’ll feed it the current model output tensor last_hidden_state and have it output a new tensor last_hidden_state_2.

model = onnx.load("./model/onnx/model.onnx")
node = onnx.helper.make_node(
    "Identity", inputs=["last_hidden_state"], outputs=["last_hidden_state_2"]
)
model.graph.node.extend([node])
# Verify the model is still valid
onnx.checker.check_model(model)

So far, so good: we’ve added a node, but the model output is still wired to the original last_hidden_state.

assert model.graph.output[0].name == "last_hidden_state"

To make the model return our new tensor, we need to update the graph output definition.

output = onnx.helper.make_tensor_value_info(
    "last_hidden_state_2",
    onnx.TensorProto.FLOAT,
    ["batch_size", "sequence_length", 768],
)
# Replace the model output
del model.graph.output[:]
model.graph.output.append(output)
onnx.checker.check_model(model)

We can verify the output has been updated:

assert model.graph.output[0].name == "last_hidden_state_2"

Yay.

Now that we know how to append nodes and rewire outputs, we can start implementing mean pooling step by step.

Casting attention mask

Attention masks are usually a list of integers (0 and 1). In frameworks like PyTorch or NumPy, operations usually handle type coercion automatically. ONNX is much stricter: tensor element types have to match exactly for most operations.

Since we’ll multiply the attention mask with the last hidden state (which is fp32) later, we need to cast the mask to fp32 as well.

nodes = []
nodes.append(
    onnx.helper.make_node(
        "Cast",
        # We can reference model inputs directly
        inputs=["attention_mask"],
        outputs=["attention_mask_float"],
        to=onnx.TensorProto.FLOAT,
    )
)

This gives us an attention mask of type FLOAT that we can safely feed into other nodes.

Unsqueezing attention mask

Here’s another preprocessing step for the attention mask. Recall the shapes:

  • attention mask: (B, S)
  • last hidden state: (B, S, D)

Since we’ll be multiplying these tensors together, their shapes need to be compatible. In PyTorch, you’d typically add a unit dimension with unsqueeze, turning (B, S) into (B, S, 1). That extra dimension allows the mask to broadcast across the embedding dimension.

We can do the same in ONNX using the Unsqueeze node:

nodes.append(
    onnx.helper.make_node(
        "Unsqueeze",
        inputs=["attention_mask_float"],
        outputs=["attention_mask_float_unsqueezed"],
        axes=[2],
    )
)

Note: this model targets ONNX opset 11, which is why some operators (like Unsqueeze) may look a bit different than in more recent versions.

attention_mask_float_unsqueezed is now of shape (B, S, 1) and type FLOAT. We can now start the real business.

Zero out padding tokens

During tokenization, when processing multiple text sequences of different lengths, the tokenizer will usually pad the shorter sequences. This is done by appending special padding tokens so that all sequences in a batch have the same length, which makes batching on GPUs more efficient.

For example:

tokenized = tokenizer(["Hello world"], return_tensors="pt", padding=True)
tokenized["input_ids"].shape # torch.Size([1, 4])
tokenized = tokenizer(
    ["Hello world", "This is a longer sentence."],
    return_tensors="pt",
    padding=True,
)
tokenized["input_ids"].shape # torch.Size([2, 7])

Notice how the length jumped from 4 tokens to 7. The shorter sequence was padded to match the longer one.

These padding tokens get embedded by the model, but don’t carry any semantic meaning and should be ignored when pooling for sentence embeddings. Fortunately, the tokenizer tells us which tokens are padding via the attention mask.

To filter out token embeddings corresponding to padding tokens, we can multiply the model output by the attention mask and zero them out.

In ONNX, we can use a Mul node to achieve this:

nodes.append(
    onnx.helper.make_node(
        "Mul",
        inputs=["last_hidden_state", "attention_mask_float_unsqueezed"],
        outputs=["masked_hidden_state"],
    )
)

Here’s a small example to build some intuition. The multiplication is element-wise:

import numpy as np

hidden_state = np.array([[
    [0.5, 0.2, 0.1],
    [1.5, 2.2, 3.1],
]])  # Shape: (1, 2, 3)
attention_mask_unsqueezed = np.array([[
    [1],
    [0],
]])  # Shape: (1, 2, 1)
# The mask broadcasts to (1, 2, 3) during multiplication
result = hidden_state * attention_mask_unsqueezed
# result:
# [[
#   [0.5, 0.2, 0.1],
#   [0.0, 0.0, 0.0],
# ]]

We now have a hidden state with padding tokens zeroed out.

Summing token embeddings

Next up, we need to turn our hidden state from (B, S, D) into (B, D) to get sentence embeddings. Since the ultimate goal is to compute the mean of all token embeddings, we’ll start by summing them together and eliminate the S dimension in the process.

A ReduceSum node takes in the masked hidden state we computed earlier and sums token embeddings along the sequence length dimension. Naturally, zeroed-out tokens don’t contribute to the sum.

nodes.append(
    onnx.helper.make_node(
        "ReduceSum",
        inputs=["masked_hidden_state"],
        outputs=["sum_hidden_state"],
        axes=[1],  # Sum along the sequence length dimension
        keepdims=0,  # Drop the reduced dimension
    )
)

For clarity, here’s a contrived example for a single batch element:

import numpy as np

hidden_state = np.array([[
    [0.5, 0.2, 0.0],
    [0.0, 2.2, 3.1],
]])  # Shape is (1, 2, 3)
# Summing along axis 1 (the sequence length dimension)
sum_ = hidden_state.sum(axis=1)
# [[0.5 + 0.0, 0.2 + 2.2, 0.0 + 3.1]] -> shape (1, 3)

Our new hidden state is now (B, D). We’re almost done.

Compute the mean

To compute the mean, we have to count how many non-padding tokens there were in the original sequence. Since the attention mask contains 1s for real tokens and 0s for padding, this is as simple as summing it along the sequence dimension.

Because we already unsqueezed the mask to (B, S, 1), this ReduceSum produces a (B, 1) tensor: one token count per sequence, with a trailing unit dimension that will come in handy for broadcasting.

nodes.append(
    onnx.helper.make_node(
        "ReduceSum",
        inputs=["attention_mask_float_unsqueezed"],
        outputs=["token_counts"],
        axes=[1],  # Sum along the sequence length dimension
        keepdims=0,
    )
)

And finally, we divide the summed hidden state (B, D) by the token counts (B, 1) and get the mean pooled embeddings.

nodes.append(
    onnx.helper.make_node(
        "Div",
        inputs=["sum_hidden_state", "token_counts"],
        outputs=["pooled_embeddings"],
    )
)
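To see the final broadcast at work, here’s the division with concrete (made-up) numbers: the (B, 1) counts stretch across the D dimension.

```python
import numpy as np

summed = np.array([[4.0, 6.0, 8.0],
                   [3.0, 3.0, 3.0]])  # (2, 3) summed token embeddings
counts = np.array([[2.0], [3.0]])     # (2, 1) real-token counts

pooled = summed / counts  # counts broadcast across the last dimension
# [[2., 3., 4.],
#  [1., 1., 1.]]
```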

Done.

Before you run off, we need to make sure the model output is now wired to our new pooled_embeddings tensor for our happy clients.

model.graph.node.extend(nodes)
output = onnx.helper.make_tensor_value_info(
    "pooled_embeddings",
    onnx.TensorProto.FLOAT,
    ["batch_size", 768],
)
# Replace the model output
del model.graph.output[:]
model.graph.output.append(output)
onnx.checker.check_model(model)
assert model.graph.output[0].name == "pooled_embeddings"

Alright! Our model is ready for prime time testing. Let’s save it to disk and create an inference session to try it out.

import onnxruntime as ort

onnx.save(model, "./model/onnx/multilingual-e5-base-pooled.onnx")
session = ort.InferenceSession("./model/onnx/multilingual-e5-base-pooled.onnx")
tokenized = tokenizer(
    ["Hello world", "This is a longer sentence."],
    return_tensors="np",
    padding=True,
)
output = session.run(None, dict(tokenized))
assert output[0].shape == (2, 768)

And that’s it. We’ve embedded a custom operation directly inside an ONNX graph by surgically adding nodes to it. Clients can now load a single model file and get sentence embeddings back directly—no extra post-processing required.

This approach extends naturally to other use cases. In our case, we’ve also embedded things like L2 normalization and mean debiasing on top of pooling. I won’t cover those here and will leave them as an exercise for the reader.

Conclusion

ONNX graph surgery is a powerful technique for embedding custom operations directly inside ONNX models. By understanding the structure of ONNX graphs and the available operations, you can extend existing models with additional behavior like pooling, normalization, or other post-processing steps.

In this post, we took a pre-trained embedding model and moved sentence pooling into the ONNX graph itself. The result is a single model file that produces sentence embeddings directly, with no extra client-side logic and no risk of subtle inconsistencies creeping in across environments.

If you’re deploying models across different languages or runtimes, this approach can offer an alternative to re-implementing post-processing logic in each client, or relying on an external service.
