Late last year I read a blog post about our CSAM image scanning tool. I remember thinking: this is so cool! Image processing is always hard, and deploying a real image identification system at Cloudflare is no small achievement!
Some time later, I was chatting with Kornel: "We have all the pieces in the image processing pipeline, but we are struggling with the performance of one component." Scaling to Cloudflare needs ain't easy!
The problem was in the speed of the matching algorithm itself. Let me elaborate. As John explained in his blog post, the image matching algorithm creates a fuzzy hash from a processed image. The hash is exactly 144 bytes long. For example, it might look like this:
00e308346a494a188e1043333147267a 653a16b94c33417c12b433095c318012 5612442030d14a4ce82c623f4e224733 1dd84436734e4a5d6e25332e507a8218 6e3b89174e30372d
The hash is designed to be used in a fuzzy matching algorithm that can find "nearby", related images. The specific algorithm is well defined, but making it fast is left to the programmer — and at Cloudflare we need the matching to be done super fast. We want to match thousands of hashes per second, of images passing through our network, against a database of millions of known images. To make this work, we need to seriously optimize the matching algorithm.
Naive quadratic algorithm
The first algorithm that comes to mind has O(K*N) complexity: for each query, go through every hash in the database. In a naive implementation, this creates a lot of work. But how much work exactly?
First, we need to explain how fuzzy matching works.
Given a query hash, the fuzzy match is the "closest" hash in a database. This requires us to define a distance. We treat each hash as a vector containing 144 numbers, identifying a point in a 144-dimensional space. Given two such points, we can calculate the distance using the standard Euclidean formula.
For our particular problem, though, we are interested in the "closest" match in a database only if the distance is lower than some predefined threshold. Otherwise, when the distance is large, we can assume the images aren't similar. This is the expected result — most of our queries will not have a related image in the database.
The Euclidean distance equation used by the algorithm is standard:
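For two hashes p and q, treated as vectors of 144 byte values p_i and q_i (the symbol names are chosen here purely for illustration), the formula can be written as:

```latex
d(p, q) = \sqrt{\sum_{i=1}^{144} (p_i - q_i)^2}
```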
To calculate the distance between two 144-byte hashes, we take each byte, calculate the delta, square it, sum it to an accumulator, do a square root, and ta-dah! We have the distance!
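A minimal sketch of such a function in C. The names `HASH_LEN` and `distance_sq` are chosen here for illustration and are not taken from the original code. Note that it stops before the final square root step:

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#define HASH_LEN 144 /* hash size in bytes, as stated in the text */

/* Squared Euclidean distance between two HASH_LEN-byte hashes.
 * The largest possible result is 144 * 255 * 255 = 9,363,600,
 * which fits comfortably in a uint32_t. */
static uint32_t distance_sq(const uint8_t *a, const uint8_t *b)
{
    assert(a != NULL && b != NULL);
    uint32_t sum = 0;
    for (size_t i = 0; i < HASH_LEN; i++) {
        int32_t delta = (int32_t)a[i] - (int32_t)b[i]; /* per-byte delta */
        sum += (uint32_t)(delta * delta);              /* square and accumulate */
    }
    return sum;
}
```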
This function returns the squared distance. We avoid computing the actual distance to save us from running the square root function, which is slow. Inside the code, for performance and simplicity, we'll mostly operate on the squared value. We don't need the actual distance value; we just need to find the vector with the smallest one. Because the square root is monotonic, it doesn't matter whether we compare distances or squared distances!
As you can see, fuzzy matching is basically a standard problem of finding the closest point in a multi-dimensional space. Surely this has been solved in the past — but let's not jump ahead.
While this code might be simple, we expect it to be rather slow. Finding the smallest hash distance in a database of, say, 1M entries, would require going over all records, and would need at least:
- 144 * 1M subtractions
- 144 * 1M multiplications
- 144 * 1M additions
And more. This alone adds up to 432 million operations! How does it look in practice? To illustrate this blog post we prepared a full test suite. The large database of known hashes can be well emulated by random data. The query hashes can't be random and must be slightly more sophisticated, otherwise the exercise wouldn't be that interesting. We generated the test queries by byte-swapping actual data from the database, which allows us to precisely control the distance between test hashes and database hashes. Take a look at the scripts for details. Here's our first run of the first, naive, algorithm:
$ make naive
< test-vector.txt ./mmdist-naive > test-vector.tmp
Total: 85261.833ms, 1536 items, avg 55.509ms per query, 18.015 qps
We matched 1,536 test hashes against a database of 1 million random vectors in 85 seconds. It took 55ms of CPU time on average to find the closest neighbour. This is rather slow for our needs.
SIMD for help
An obvious improvement is to use more complex SIMD instructions. SIMD is a way to instruct the CPU to process multiple data points using one instruction. This is a perfect strategy when dealing with vector problems — as is the case for our task.
We settled on using AVX2, with 256 bit vectors. We did this for a simple reason — newer AVX versions are not supported by our AMD CPUs. Additionally, in the past, we were not thrilled by the AVX-512 frequency scaling.
Using AVX2 is easier said than done. There is no single instruction to compute the Euclidean distance between two uint8 vectors! The fastest way of computing the full distance of two 144-byte vectors with AVX2 we could find is authored by Vlad:
It’s actually simpler than it looks: load 16 bytes, convert the vector from uint8 to int16, subtract the vectors, store intermediate sums as int32, repeat. At the end, we need four somewhat complex instructions to reduce the partial sums into the final sum. This AVX2 code improves the performance around 3x:
$ make naive-avx2
Total: 25911.126ms, 1536 items, avg 16.869ms per query, 59.280 qps
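The steps listed above (load 16 bytes, widen to int16, subtract, accumulate as int32, reduce at the end) can be sketched with AVX2 intrinsics roughly as follows. This is an illustrative reconstruction, not Vlad's original code. The function names, the scalar fallback, and the GCC/Clang-specific `target` attribute and `__builtin_cpu_supports` runtime check are all assumptions made here so that the sketch builds without special compiler flags.

```c
#include <assert.h>
#include <stdint.h>

#define HASH_LEN 144

/* Portable scalar reference, also used as the fallback path. */
static uint32_t distance_sq_scalar(const uint8_t *a, const uint8_t *b)
{
    uint32_t sum = 0;
    for (int i = 0; i < HASH_LEN; i++) {
        int32_t d = (int32_t)a[i] - (int32_t)b[i];
        sum += (uint32_t)(d * d);
    }
    return sum;
}

#if (defined(__GNUC__) || defined(__clang__)) && \
    (defined(__x86_64__) || defined(__i386__))
#include <immintrin.h>
#define HAVE_AVX2_PATH 1

__attribute__((target("avx2")))
static uint32_t distance_sq_avx2(const uint8_t *a, const uint8_t *b)
{
    __m256i sum = _mm256_setzero_si256(); /* eight int32 partial sums */
    for (int i = 0; i < HASH_LEN; i += 16) {
        /* Load 16 bytes from each hash and widen uint8 -> int16. */
        __m256i a16 = _mm256_cvtepu8_epi16(
            _mm_loadu_si128((const __m128i *)(a + i)));
        __m256i b16 = _mm256_cvtepu8_epi16(
            _mm_loadu_si128((const __m128i *)(b + i)));
        __m256i d = _mm256_sub_epi16(a16, b16);
        /* madd squares each int16 delta and adds neighbouring pairs
         * into int32 lanes, so no separate widening step is needed. */
        sum = _mm256_add_epi32(sum, _mm256_madd_epi16(d, d));
    }
    /* Horizontal reduction of the eight partial sums. */
    __m128i s = _mm_add_epi32(_mm256_castsi256_si128(sum),
                              _mm256_extracti128_si256(sum, 1));
    s = _mm_hadd_epi32(s, s);
    s = _mm_hadd_epi32(s, s);
    return (uint32_t)_mm_cvtsi128_si32(s);
}
#endif

/* Uses the AVX2 path when the CPU supports it, otherwise the scalar one. */
static uint32_t distance_sq(const uint8_t *a, const uint8_t *b)
{
    assert(a && b);
#ifdef HAVE_AVX2_PATH
    if (__builtin_cpu_supports("avx2"))
        return distance_sq_avx2(a, b);
#endif
    return distance_sq_scalar(a, b);
}
```

Since 144 is a multiple of 16, the loop runs exactly nine times and needs no tail handling.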
We measured 17ms per item, which still falls short of our expectations. Unfortunately, we can't push it much further without major changes. The problem is that this code is limited by memory bandwidth. The measurements come from my Intel i7-5557U CPU, which has a maximum theoretical memory bandwidth of just 25GB/s. The database of 1 million entries takes 137MiB, so it takes at least 5ms to feed the database to my CPU. With this naive algorithm we won't be able to go below that.
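The bandwidth bound can be checked with a quick back-of-the-envelope calculation, assuming 1 GB = 10^9 bytes:

```latex
10^{6} \times 144\ \text{B} = 1.44 \times 10^{8}\ \text{B} \approx 137.3\ \text{MiB},
\qquad
t_{\min} = \frac{1.44 \times 10^{8}\ \text{B}}{25 \times 10^{9}\ \text{B/s}} \approx 5.8\ \text{ms}
```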
Vantage Point Tree algorithm
Since the naive brute force approach failed, we tried using more sophisticated algorithms. My colleague Kornel Lesiński implemented a super cool Vantage Point algorithm. After a few ups and downs, optimizations and rewrites, we gave up. Our problem turned out to be unusually hard for this kind of algorithm.
We observed "the curse of dimensionality". Space partitioning algorithms don't work well in problems with large dimensionality, and in our case we have an enormous 144 dimensions. K-D trees are doomed. Locality-sensitive hashing is also doomed. It's a bizarre situation in which the space is unimaginably vast, but everything is close together. The volume of the space is a 347-digit-long number, but the maximum distance between points is just 3060, which is sqrt(255*255*144).
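Both figures follow from simple arithmetic, taking each of the 144 coordinates to be a byte with 256 possible values:

```latex
256^{144} = 2^{1152} \approx 10^{346.8} \quad \text{(a 347-digit number)},
\qquad
\sqrt{144 \cdot 255^{2}} = 12 \cdot 255 = 3060
```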
Space partitioning algorithms are fast, because they gradually narrow the search space as they get closer to finding the closest point. But in our case, the common query is never close to any point in the set, so the search space can’t be narrowed to a meaningful degree.
A VP-tree was a promising candidate, because it operates only on distances, subdividing space into near and far partitions, like a binary tree. When it has a close match, it can be very fast, and doesn't need to visit more than O(log(N)) nodes. For non-matches, its speed drops dramatically. The algorithm ends up visiting nearly half of the nodes in the tree. Everything is close together in 144 dimensions! Even though the algorithm avoided visiting more than half of the nodes in the tree, the cost of visiting remaining nodes was higher, so the search ended up being slower overall.
Smarter brute force?
This experience got us thinking. Since space partitioning algorithms can't narrow down the search, and still need to go over a very large number of items, maybe we should focus on going over all the hashes, extremely quickly. We must be smarter about memory bandwidth though — it was the limiting factor in the naive brute force approach before.
Perhaps we don't need to fetch all the data from memory.
The breakthrough came from the realization that we don't need to compute the full distance between hashes. Instead, we can compute the distance over only a subset of dimensions, say 32 out of the total of 144. If this distance is already large, then there is no need to compute the full one! Every additional dimension contributes a non-negative squared term, so computing more dimensions can only increase the distance, never reduce it.
The proposed algorithm works as follows:
1. Take the query hash and extract a 32-byte short hash from it
2. Go over all the 1 million 32-byte short hashes from the database. They must be densely packed in the memory to allow the CPU to perform good prefetching and avoid reading data we won't need.
3. If the distance of the 32-byte short hash is greater than or equal to the best score so far, move on
4. Otherwise, investigate the hash thoroughly and compute the full distance.
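The four steps above can be sketched in plain C as follows. This is a scalar illustration, not the AVX2 code behind `make short-avx2`. It assumes the short hash is simply the first 32 bytes of the full hash, and the names `find_closest`, `partial_distance_sq`, `short_db` and `full_db` are hypothetical.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#define HASH_LEN  144
#define SHORT_LEN 32 /* assumed: the short hash is the first 32 bytes */

/* Squared distance over the first `len` dimensions only. */
static uint32_t partial_distance_sq(const uint8_t *a, const uint8_t *b, size_t len)
{
    uint32_t sum = 0;
    for (size_t i = 0; i < len; i++) {
        int32_t d = (int32_t)a[i] - (int32_t)b[i];
        sum += (uint32_t)(d * d);
    }
    return sum;
}

/* short_db: n densely packed SHORT_LEN-byte prefixes.
 * full_db:  n full HASH_LEN-byte hashes, in the same order.
 * Returns the index of the closest hash and stores its squared
 * distance in *best_out. */
static size_t find_closest(const uint8_t *short_db, const uint8_t *full_db,
                           size_t n, const uint8_t *query, uint32_t *best_out)
{
    assert(n > 0 && best_out != NULL);
    size_t best_idx = 0;
    uint32_t best = UINT32_MAX;
    for (size_t i = 0; i < n; i++) {
        /* Steps 1-2: cheap pass over the packed short hashes. */
        uint32_t sd = partial_distance_sq(query, short_db + i * SHORT_LEN, SHORT_LEN);
        /* Step 3: the full distance can only be larger, so skip. */
        if (sd >= best)
            continue;
        /* Step 4: promising candidate, compute the full distance. */
        uint32_t fd = partial_distance_sq(query, full_db + i * HASH_LEN, HASH_LEN);
        if (fd < best) {
            best = fd;
            best_idx = i;
        }
    }
    *best_out = best;
    return best_idx;
}
```

Keeping the short hashes in their own array is what makes the first pass touch only 32 bytes per entry instead of 144.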
Even though this algorithm needs to do less arithmetic and memory work, it's not faster than the previous naive one. See make short-avx2. The problem is: we still need to compute a full distance for hashes that are promising, and there are quite a lot of them. Computing the full distance for promising hashes adds enough work, both in ALU and memory latency, to offset the gains of this algorithm.
There is one detail of our particular application of the image matching problem that will help us a lot moving forward. As we described earlier, the problem is less about finding the closest neighbour and more about proving that the neighbour with a reasonable distance doesn't exist. Remember — in practice, we don't expect to find many matches! We expect almost every image we feed into the algorithm to be unrelated to image hashes stored in the database.
It's sufficient for our algorithm to prove that no neighbour exists within a predefined distance threshold. Let's assume we are not interested in hashes more distant than, say, 220, which squared is 48,400. This makes our short-distance algorithm variation work much better:
$ make short-avx2-threshold
Total: 4994.435ms, 1536 items, avg 3.252ms per query, 307.542 qps
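One way to see why the threshold helps: the "best score so far" no longer starts at infinity but at the threshold, so almost every short hash is rejected right away and the full distance is rarely computed. A scalar sketch under the same assumptions as before (short hash taken as the first 32 bytes; `find_within` and the other names are hypothetical):

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#define HASH_LEN  144
#define SHORT_LEN 32 /* assumed: the short hash is the first 32 bytes */

static uint32_t dist_sq(const uint8_t *a, const uint8_t *b, size_t len)
{
    uint32_t sum = 0;
    for (size_t i = 0; i < len; i++) {
        int32_t d = (int32_t)a[i] - (int32_t)b[i];
        sum += (uint32_t)(d * d);
    }
    return sum;
}

/* Returns the index of the closest hash whose squared distance is at
 * most threshold_sq, or -1 if no such hash exists. */
static long find_within(const uint8_t *short_db, const uint8_t *full_db,
                        size_t n, const uint8_t *query, uint32_t threshold_sq)
{
    assert(threshold_sq < UINT32_MAX);
    long best_idx = -1;
    /* Start from the threshold: anything farther is never interesting. */
    uint32_t best = threshold_sq + 1;
    for (size_t i = 0; i < n; i++) {
        if (dist_sq(query, short_db + i * SHORT_LEN, SHORT_LEN) >= best)
            continue; /* rejected by the short hash alone */
        uint32_t fd = dist_sq(query, full_db + i * HASH_LEN, HASH_LEN);
        if (fd < best) {
            best = fd;
            best_idx = (long)i;
        }
    }
    return best_idx;
}
```

With the example threshold from the text, the call would pass `220 * 220`, that is 48,400, as `threshold_sq`.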
Origin distance variation
At some point, John noted that the threshold allows an additional optimization. We can order the hashes by their distance from some origin point. Given a query hash with an origin distance of A, the triangle inequality means we only need to inspect hashes whose origin distance lies between A - threshold and A + threshold. This is pretty much how each level of a Vantage Point Tree works, just simplified. This optimization (ordering items in the database by their distance from an origin point) is relatively simple and can save us a bit of work.
While great on paper, this method doesn't introduce much gain in practice, as the vectors are not grouped in clusters: they are pretty much random! For the threshold values we are interested in, the origin distance variation gives us a ~20% speed boost, which is okay but not breathtaking. This change might bring more benefits if we ever decide to reduce the threshold value, so it might be worth doing in a production implementation. However, it doesn't work well with query batching.
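The range selection can be sketched as below. Several things here are assumptions made for illustration: the origin is taken to be the all-zero hash, origin distances are stored rounded down to integers (which keeps the window of width `threshold` safe when the threshold is an integer), and the database is assumed to be already sorted by origin distance. All function names are hypothetical.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#define HASH_LEN 144

/* floor(sqrt(x)); a simple loop is enough since the result is <= 3060. */
static uint32_t isqrt_u32(uint32_t x)
{
    uint32_t r = 0;
    while ((uint64_t)(r + 1) * (r + 1) <= x)
        r++;
    return r;
}

/* Distance from the all-zero origin (assumed), rounded down. */
static uint32_t origin_distance(const uint8_t *h)
{
    uint32_t sum = 0;
    for (size_t i = 0; i < HASH_LEN; i++)
        sum += (uint32_t)h[i] * h[i];
    return isqrt_u32(sum);
}

/* Index of the first element in sorted[0..n) that is >= key. */
static size_t lower_bound(const uint32_t *sorted, size_t n, uint32_t key)
{
    size_t lo = 0, hi = n;
    while (lo < hi) {
        size_t mid = lo + (hi - lo) / 2;
        if (sorted[mid] < key)
            lo = mid + 1;
        else
            hi = mid;
    }
    return lo;
}

/* sorted_origin: origin distances of the database hashes, ascending.
 * Outputs the half-open index range [*first, *last) of hashes that can
 * possibly lie within `threshold` of the query. */
static void candidate_range(const uint32_t *sorted_origin, size_t n,
                            const uint8_t *query, uint32_t threshold,
                            size_t *first, size_t *last)
{
    assert(first != NULL && last != NULL);
    uint32_t a = origin_distance(query);
    uint32_t lo = a > threshold ? a - threshold : 0; /* clamp at zero */
    uint32_t hi = a + threshold;
    *first = lower_bound(sorted_origin, n, lo);
    *last = lower_bound(sorted_origin, n, hi + 1);
}
```

Only the hashes inside the returned range then need to go through the short-distance check.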
Transposing data for better AVX
But we're not done with AVX optimizations! The usual problem with AVX is that the instructions don't normally fit a specific problem. Some serious mind twisting is required to adapt the right instruction to the problem, or to reverse the problem so that a specific instruction can be used. AVX2 doesn't have useful "horizontal" uint16 subtract, multiply and add operations. For example, _mm_hadd_epi16 exists, but it's slow and cumbersome.
Instead, we can twist the problem to make use of the fast uint16 operations that are available. For example, we can use the vertical (lane-wise) subtract, multiply, and add instructions, such as _mm256_add_epi16.
A plain add would overflow in our case, but fortunately there is a saturating add, _mm256_adds_epu16.
The saturating add is great and saves us the conversion to uint32. It just adds a small limitation: the threshold passed to the program (i.e., the max squared distance) must fit into uint16. However, this is fine for us.
To effectively use these instructions we need to transpose the data in the database. Instead of storing hashes in rows, we can store them in columns:
So instead of:
- [a1, a2, a3],
- [b1, b2, b3],
- [c1, c2, c3],
We can lay it out in memory transposed:
- [a1, b1, c1],
- [a2, b2, c2],
- [a3, b3, c3],
Now we can load the first byte of 16 hashes using one memory operation. In the next step, we can subtract the first byte of the querying hash using a single instruction, and so on. The algorithm stays exactly the same as defined above; we just make the data easier to load and easier to process for AVX.
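The layout change itself is a plain matrix transpose. A minimal sketch, assuming the whole array is transposed at once (a real implementation would more likely transpose in fixed-size batches); `transpose_hashes` is a hypothetical name:

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* rows: n hashes of `len` bytes each, stored one after another
 *       (byte j of hash i lives at rows[i * len + j]).
 * cols: output buffer of the same size, where byte j of every hash is
 *       stored contiguously (byte j of hash i lives at cols[j * n + i]). */
static void transpose_hashes(const uint8_t *rows, uint8_t *cols,
                             size_t n, size_t len)
{
    assert(rows != cols); /* the transpose is not done in place */
    for (size_t i = 0; i < n; i++)
        for (size_t j = 0; j < len; j++)
            cols[j * n + i] = rows[i * len + j];
}
```

Applied to the three-hash example above, `{a1, a2, a3, b1, b2, b3, c1, c2, c3}` becomes `{a1, b1, c1, a2, b2, c2, a3, b3, c3}`.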
The hot loop code even looks relatively pretty:
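As a rough illustration of what such a loop computes, here is a lane-by-lane scalar model of one transposed block of 16 hashes. It is not the original AVX2 listing. The comments name the intrinsic each step is assumed to correspond to, and `scan_block`, `add_sat_u16` and `LANES` are hypothetical names.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#define LANES 16 /* number of uint16 lanes in one 256-bit register */

/* Scalar model of a saturating uint16 add (_mm256_adds_epu16). */
static uint16_t add_sat_u16(uint16_t a, uint16_t b)
{
    uint32_t s = (uint32_t)a + b;
    return s > UINT16_MAX ? UINT16_MAX : (uint16_t)s;
}

/* cols: one transposed block, where cols[j * LANES + lane] holds byte j
 *       of hash number `lane`.
 * Returns a bitmask with bit `lane` set when the squared distance over
 * the first `dims` dimensions is at most threshold_sq. */
static uint16_t scan_block(const uint8_t *cols, const uint8_t *query,
                           size_t dims, uint16_t threshold_sq)
{
    /* A saturated lane reads 65535, so the threshold must stay below it. */
    assert(threshold_sq < UINT16_MAX);
    uint16_t acc[LANES] = {0};
    for (size_t j = 0; j < dims; j++) {
        for (int lane = 0; lane < LANES; lane++) {
            /* 16-bit subtract of the broadcast query byte */
            int32_t d = (int32_t)cols[j * LANES + lane] - (int32_t)query[j];
            /* 16-bit multiply: d * d is at most 65025, fits in uint16 */
            uint16_t sq = (uint16_t)(d * d);
            /* saturating accumulate, no widening to uint32 needed */
            acc[lane] = add_sat_u16(acc[lane], sq);
        }
    }
    uint16_t mask = 0;
    for (int lane = 0; lane < LANES; lane++)
        if (acc[lane] <= threshold_sq)
            mask |= (uint16_t)(1u << lane);
    return mask;
}
```

Lanes whose bit is set in the returned mask are the promising candidates that still need a full distance computation.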
With the well-tuned batch size and short distance size parameters we can see the performance of this algorithm:
$ make short-inv-avx2
Total: 1118.669ms, 1536 items, avg 0.728ms per query, 1373.062 qps
Whoa! This is pretty awesome. We started from 55ms per query, and we finished with just 0.73ms. There are further micro-optimizations possible, like memory prefetching or using huge pages to reduce page faults, but they have diminishing returns at this point.
If you are interested in architectural tuning such as this, take a look at the new performance book by Denis Bakhvalov. It discusses roofline model analysis, which is pretty much what we did here.
Do take a look at our code and tell us if we missed some optimization!
What an optimization journey! We jumped between memory and ALU bottlenecked code. We discussed more sophisticated algorithms, but in the end, a brute force algorithm — although tuned — gave us the best results.
To get even better numbers, I experimented with an Nvidia GPU using CUDA. The CUDA intrinsics like dp4a fit the problem perfectly. The V100 gave us some amazing numbers, but I wasn't fully satisfied with it. Considering how many AMD Ryzen cores with AVX2 we can get for the cost of a single server-grade GPU, we leaned towards general purpose computing for this particular problem.
This is a great example of the type of complexities we deal with every day. Making even the best technologies work “at Cloudflare scale” requires thinking outside the box. Sometimes we rewrite the solution dozens of times before we find the optimal one. And sometimes we settle on a brute-force algorithm, just very very optimized.
The computation of hashes and image matching are challenging problems that require running very CPU-intensive operations. The CPU we have available on the edge is scarce and workloads like this are incredibly expensive. Even with the optimization work described in this blog post, running the CSAM scanner at scale is a challenge and has required a huge engineering effort. And we’re not done! We need to solve more hard problems before we're satisfied. If you want to help, consider applying!