Google's Tensor Processing Units: A Deep Dive into TPU Architecture and AI Acceleration in 2026
How Google's TPUs actually work — systolic arrays, the MXU, bfloat16, TPU pods and optical interconnects, the XLA/JAX stack, and when to choose TPUs over GPUs for training and inference.
Introduction
Every time you call a large language model, something has to multiply matrices — a lot of them, very fast. The economics of modern AI are, at bottom, the economics of matrix multiplication: how many multiply-accumulate operations you can do per second, per watt, per dollar. For most of the last decade the answer was "buy more GPUs." But there has always been a second answer, quietly powering Google Search, Translate, Photos, and Gemini behind the scenes: the Tensor Processing Unit.
I have spent years building and deploying LLM systems at Hureka Technologies, and one of the most consequential — and least understood — decisions a team makes is what silicon it runs on. GPUs get all the attention. TPUs get almost none, despite the fact that some of the largest models in the world are trained on them. That imbalance exists mostly because TPUs are less familiar, not because they are less capable.
This guide fixes that. By the end you will understand what a TPU actually is at the transistor-and-dataflow level, why the systolic array is such an elegant answer to the matrix-multiply problem, how Google wires tens of thousands of these chips into a single supercomputer, how the software stack (XLA, JAX, PyTorch/XLA) actually maps your model onto the hardware, and — the part that matters for your budget — when a TPU is the right choice over a GPU and when it absolutely is not. This is written for engineers who have to make real infrastructure decisions, not for a marketing deck.
What Is a Tensor Processing Unit? The Foundation
A Tensor Processing Unit is an application-specific integrated circuit (ASIC) designed by Google to do one family of things extraordinarily well: dense linear algebra, specifically the matrix multiplications and convolutions that dominate neural network workloads. The word "tensor" is the giveaway — the chip is purpose-built to move and multiply multi-dimensional arrays of numbers.
To understand why TPUs exist, you have to understand what they are not. A CPU is a latency-optimized generalist: a handful of very sophisticated cores with deep cache hierarchies and branch predictors, built to run arbitrary sequential code quickly. A GPU is a throughput-optimized generalist: thousands of simpler cores that run the same instruction across many data elements (SIMT), originally for graphics, later repurposed for parallel compute. A TPU is neither generalist — it is a domain-specific architecture. It throws away most of the flexibility of a CPU or GPU and spends the reclaimed silicon area on one thing: a giant hardware matrix multiplier.
The story begins around 2013. Google did an internal projection and realized that if every Android user used voice search for just three minutes a day, the company would need to roughly double its data center footprint to serve the neural network inference — using CPUs. That was economically impossible. The response was a crash program to build custom inference silicon. The first TPU (TPU v1) went into production in 2015 and was publicly revealed in 2016. It was inference-only, integer-based (int8), and it delivered order-of-magnitude improvements in performance-per-watt over the CPUs and GPUs of the day for Google's workloads.
Since then the lineage has advanced roughly every 12–24 months. TPU v2 (2017) added training support and floating point via bfloat16. v3 added liquid cooling and roughly doubled throughput. v4 (2021) introduced reconfigurable optical interconnects between chips. The v5 generation split into v5e (efficiency, cost-optimized) and v5p (peak performance for the largest training runs). v6, codenamed Trillium, delivered a large jump in per-chip compute and HBM, and by 2026 the seventh generation (Ironwood) is oriented heavily toward large-scale inference for reasoning models. The through-line has stayed the same: maximize matrix-multiply throughput per watt and, from v2 onward, make the chips easy to gang together by the thousand.
How It Works: The Architecture
The heart of every TPU is the Matrix Multiply Unit (MXU), and the MXU is built on a beautifully old idea from the 1970s: the systolic array. Understanding the systolic array is understanding the TPU.
In a conventional processor, computing `C = A × B` means repeatedly reading operands from registers or cache, multiplying, accumulating, and writing back. Every one of those reads and writes touches the memory hierarchy, and memory access, not arithmetic, is what burns time and energy. The systolic array's insight is to stop moving data back to memory between operations. Instead, you build a physical grid of small processing elements (PEs), each of which does one multiply-accumulate (MAC), and you let data flow rhythmically through the grid, PE to PE, like blood pumping through tissue (hence "systolic"). Each value is read from memory once, then reused by every PE it passes through.
```
SYSTOLIC ARRAY MATRIX MULTIPLY (C = A x B)
──────────────────────────────────────────────────────────────
Weights (B, shown as w) are pre-loaded and held stationary in the grid.
Activations (A) stream in from the left: PE row k receives column k of A,
one element per cycle, each row skewed one cycle behind the row above.
Partial sums accumulate downward and exit at the bottom.

                  ┌─────┬─────┬─────┬─────┐
a31 a21 a11 ─►    │ PE  │ PE  │ PE  │ PE  │   each PE:
                  │ w11 │ w12 │ w13 │ w14 │     sum += a * w
                  ├─────┼─────┼─────┼─────┤     pass a right
a32 a22 a12 ─►    │ PE  │ PE  │ PE  │ PE  │     pass sum down
                  │ w21 │ w22 │ w23 │ w24 │
                  ├─────┼─────┼─────┼─────┤
a33 a23 a13 ─►    │ PE  │ PE  │ PE  │ PE  │
                  │ w31 │ w32 │ w33 │ w34 │
                  └──┬──┴──┬──┴──┬──┴──┬──┘
                     ▼     ▼     ▼     ▼
          partial sums accumulate out the bottom
```
A TPU MXU is typically a 128×128 grid of these MACs. That is 16,384 multiply-accumulate units doing useful work every single clock cycle once the pipeline is full. Because the weights sit still and the activations flow through, the data-reuse factor is enormous: a weight loaded once is used across an entire batch of activations, and an activation read once is multiplied against a whole row of weights. This is why a TPU can hit such high arithmetic intensity — the ratio of compute to memory traffic — which is exactly the ratio that determines whether you are compute-bound (good, the silicon is busy) or memory-bound (bad, the silicon is starving).
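To make the dataflow concrete, here is a toy cycle-by-cycle model of the weight-stationary array in the diagram, written in plain Python. It is a sketch under simplifying assumptions (weights already resident in the PEs, exact integer arithmetic, no accumulator widths or weight-loading latency), not a description of the real MXU's microarchitecture; `systolic_matmul` and its internals are illustrative names.

```python
def systolic_matmul(A, W):
    """Toy cycle-level model of a weight-stationary K x N systolic array computing A @ W.

    PE(k, n) permanently holds W[k][n]. Row k of the grid is fed A[m][k] at cycle m + k
    (the skew); activations move one PE right per cycle, partial sums one PE down.
    """
    M, K, N = len(A), len(W), len(W[0])
    act = [[None] * N for _ in range(K)]   # activation each PE passed right last cycle
    psum = [[None] * N for _ in range(K)]  # partial sum each PE passed down last cycle
    C = [[None] * N for _ in range(M)]

    # The last result, C[M-1][N-1], leaves the bottom row at cycle (M-1)+(K-1)+(N-1).
    for t in range(M + K + N - 2):
        new_act = [[None] * N for _ in range(K)]
        new_psum = [[None] * N for _ in range(K)]
        for k in range(K):
            for n in range(N):
                if n == 0:
                    m = t - k  # skewed input feed on the left edge
                    a = A[m][k] if 0 <= m < M else None
                else:
                    a = act[k][n - 1]  # activation from the left neighbour
                if a is None:
                    continue  # pipeline bubble: nothing to do this cycle
                above = 0 if k == 0 else psum[k - 1][n]  # partial sum from above
                new_act[k][n] = a
                new_psum[k][n] = above + a * W[k][n]  # the MAC: sum += a * w
                if k == K - 1:
                    # Bottom row: a finished dot product exits the array.
                    C[t - (K - 1) - n][n] = new_psum[k][n]
        act, psum = new_act, new_psum
    return C


print(systolic_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```

Each activation enters the array once and is reused by every PE in its row as it hops rightward, which is the data-reuse argument above in miniature. One tile takes M + K + N - 2 cycles end to end, mostly pipeline fill and drain when M is small; streaming many rows of activations back to back amortizes that overhead.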
Surrounding the MXU are the supporting organs: a Vector Processing Unit (VPU) for elementwise operations, activations, and normalization; a large pool of High-Bandwidth Memory (HBM) stacked next to the die to feed the array; a scalar unit; and, in modern generations, SparseCores that accelerate the giant embedding lookups common in recommender systems. Crucially, the whole thing is orchestrated by a compiler ahead of time: there is very little dynamic scheduling on-chip. The TPU trades runtime flexibility for a dead-simple, extremely dense, extremely efficient dataflow.
Why bfloat16 Matters
TPUs popularized the bfloat16 number format, and it is one of those small design choices with outsized consequences. A standard IEEE fp16 has 5 exponent bits and 10 mantissa bits. bfloat16 keeps the full 8 exponent bits of fp32 but truncates the mantissa to 7 bits. The result has the same dynamic range as fp32 (so gradients rarely overflow or underflow during training) but half the storage. You get most of fp32's numerical stability at half the memory bandwidth and roughly double the compute throughput, with no loss-scaling gymnastics. It is the format that made large-scale mixed-precision training on TPUs routine.
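The range-versus-precision trade-off is easy to see by rounding a few values in pure Python. This sketch uses the standard `struct` module: the `'e'` format gives IEEE fp16, and bfloat16 is emulated by keeping the top 16 bits of the fp32 encoding with round-to-nearest-even. The helper names are illustrative, and NaN handling is deliberately omitted.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even), returned as a float.

    bfloat16 is the top 16 bits of an IEEE fp32: 1 sign, 8 exponent, 7 mantissa bits.
    Assumes finite inputs; NaN handling is omitted for brevity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # fp32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                   # round to nearest even
    bits &= 0xFFFF0000                                     # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def to_fp16(x: float) -> float:
    """Round x to IEEE fp16 via struct's 'e' format (raises OverflowError past fp16's range)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Range: 70,000 exceeds fp16's maximum (65,504) but sits comfortably inside bfloat16's range.
print(to_bfloat16(70000.0))  # 70144.0 (nearest bf16 value; the spacing here is 512)
try:
    to_fp16(70000.0)
except OverflowError:
    print("fp16 overflow")

# Small values: fp16 flushes 1e-8 to zero; bfloat16 keeps roughly 2-3 significant digits.
print(to_fp16(1e-8), to_bfloat16(1e-8))

# Precision: bfloat16 cannot resolve 1 + 2**-9, while fp16's 10-bit mantissa can.
print(to_bfloat16(1.0 + 2**-9), to_fp16(1.0 + 2**-9))  # 1.0 1.001953125
```

fp16 overflows on 70,000 and flushes 1e-8 to zero, while bfloat16 represents both, if coarsely; in exchange, bfloat16 cannot distinguish 1 + 2^-9 from 1, which fp16 can. That coarser mantissa is the price paid for fp32's dynamic range.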
Core Components Deep Dive
Let's go one level deeper into the pieces that determine real-world performance.
The MXU and Arithmetic Intensity
The MXU only earns its keep when it is kept full. A single 128×128 MXU can retire 16,384 MACs/cycle, but only if activations arrive fast enough. The governing relationship is the roofline model: your achievable throughput is `min(peak_FLOPS, arithmetic_intensity × memory_bandwidth)`. For a matrix multiply, arithmetic intensity scales with batch size and matrix dimensions. This is why TPUs love large batches and large, dense matmuls: they push you to the compute-bound side of the roofline. Small batches or skinny matrices leave the array idle.
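The roofline relationship can be checked with a few lines of arithmetic. This sketch assumes an idealized matmul in which each bf16 operand crosses HBM exactly once, and it uses round illustrative hardware numbers (200 TFLOP/s peak, 1 TB/s bandwidth) rather than any particular TPU's specification; the function names are hypothetical.

```python
def matmul_arithmetic_intensity(m, k, n, bytes_per_elem=2):
    """FLOPs per byte of HBM traffic for an (m x k) @ (k x n) matmul.

    Assumes each operand is read once and the result written once (ideal tiling),
    with bf16 (2-byte) elements by default.
    """
    flops = 2 * m * k * n  # one multiply and one add per MAC
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved


def roofline(intensity, peak_flops, mem_bw):
    """Attainable FLOP/s = min(peak compute, intensity x memory bandwidth)."""
    return min(peak_flops, intensity * mem_bw)


# Illustrative round numbers, not any specific TPU's datasheet.
PEAK_FLOPS = 200e12  # 200 TFLOP/s
MEM_BW = 1e12        # 1 TB/s of HBM bandwidth -> ridge point at 200 FLOP/byte

for batch in (1, 512):
    ai = matmul_arithmetic_intensity(batch, 4096, 4096)
    frac = roofline(ai, PEAK_FLOPS, MEM_BW) / PEAK_FLOPS
    print(f"batch={batch:4d}  intensity={ai:7.1f} FLOP/B  attainable={frac:.1%} of peak")
```

Under these assumptions a batch-1 layer reaches about 1 FLOP per byte, roughly 0.5% of peak, while batch 512 reaches about 410 FLOP per byte, past the 200 FLOP/byte ridge point and therefore compute-bound.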
```python
# A mental model of MXU utilization for a dense layer.
# The point: every dimension is padded up to a multiple of the 128-wide array, so
# dimensions that are not multiples of 128 (batch=1 being the extreme case) waste work.
def mxu_utilization(batch, in_features, out_features, mxu=128):
    # A matmul tiles onto the systolic array in 128x128 blocks.
    m_tiles = -(-batch // mxu)  # ceil division
    k_tiles = -(-in_features // mxu)
    n_tiles = -(-out_features // mxu)

    used = batch * in_features * out_features
    padded = (m_tiles * mxu) * (k_tiles * mxu) * (n_tiles * mxu)
    return used / padded  # fraction of the padded work that is "real"


print(mxu_utilization(512, 4096, 4096))  # 1.0 -> great, dimensions align
print(mxu_utilization(1, 4096, 4096))    # ~0.008 -> terrible, batch=1 wastes the array
```
That second number is the single most important intuition in this article: a TPU running batch-size-1 inference wastes ~99% of its MXU. TPUs are throughput monsters, not latency monsters. Design accordingly.
High-Bandwidth Memory and the Feeding Problem
Each TPU chip pairs its compute with High-Bandwidth Memory, from 16 GB on v5e to tens of gigabytes or more on the performance-oriented parts, at bandwidths from under one to several terabytes per second. This is what feeds the array. When a model does not fit in one chip's HBM (the normal case for anything with tens of billions of parameters), you must shard it across many chips, and now the interconnect between chips becomes the bottleneck. Which brings us to the most distinctive part of the TPU story.
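A quick back-of-the-envelope check shows why sharding is the normal case. The sketch below computes a lower bound on chip count from memory alone. The 32 GB per chip, the 80% usable fraction, and the 16 bytes per parameter for mixed-precision Adam training (bf16 weights and gradients plus fp32 master weights and two fp32 optimizer moments) are illustrative assumptions; activations and KV caches would add more on top.

```python
import math


def min_chips_for_state(n_params, bytes_per_param, hbm_gb_per_chip, usable_fraction=0.8):
    """Lower bound on chips needed just to hold per-parameter state in HBM.

    usable_fraction leaves headroom for activations, KV cache, and compiler buffers
    (0.8 is an arbitrary illustrative choice).
    """
    state_bytes = n_params * bytes_per_param
    usable_bytes_per_chip = hbm_gb_per_chip * 1e9 * usable_fraction
    return math.ceil(state_bytes / usable_bytes_per_chip)


# Hypothetical 70B-parameter model on chips with 32 GB of HBM each.
print(min_chips_for_state(70e9, bytes_per_param=2, hbm_gb_per_chip=32))   # 6: bf16 weights only
print(min_chips_for_state(70e9, bytes_per_param=16, hbm_gb_per_chip=32))  # 44: mixed-precision Adam
```

Even serving-only weights overflow a single chip here, and training state multiplies the requirement several times over, which is why the interconnect, not just the chip, decides how fast large models run.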
The Inter-Chip Interconnect (ICI) and Optical Circuit Switching
A single TPU is unremarkable. A TPU pod is the point. Google connects chips with a dedicated high-speed Inter-Chip Interconnect (ICI) arranged as a torus (3D on the v4 and v5p pods, 2D on the cost-optimized v5e and Trillium pods): each chip talks directly to its neighbors with no host CPU or Ethernet switch in the path. Starting with v4, Google added Optical Circuit Switches (OCS): reconfigurable mirrors that let the topology be rewired in software. This means a pod can be sliced into many independent "slices" for different jobs, faulty chips can be routed around without killing a training run, and the network topology can be matched to the communication pattern of the model (data-parallel, tensor-parallel, pipeline-parallel). A full pod scales to thousands of chips presenting themselves to your software as one enormous accelerator.
```
TPU POD TOPOLOGY (simplified 2D slice of a 3D torus)

[TPU]══ICI══[TPU]══ICI══[TPU]══ICI══[TPU]
  ║           ║           ║           ║
[TPU]══ICI══[TPU]══ICI══[TPU]══ICI══[TPU]
  ║           ║           ║           ║
[TPU]══ICI══[TPU]══ICI══[TPU]══ICI══[TPU]
  ║           ║           ║           ║

(wraps around edges -> torus; OCS can rewire slices)
```
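The wrap-around links are what make a torus attractive: they cap the distance along each axis at half the ring. This plain-Python sketch models only the logical topology of a hypothetical 4×4×4 slice (no link bandwidth, no OCS rewiring), and the function names are illustrative.

```python
def torus_neighbors(coord, dims):
    """Direct ICI neighbours of a chip in a torus: one step +/- along each axis, with wrap-around.

    (For an axis of size 2, both directions reach the same chip.)
    """
    neighbors = []
    for axis, size in enumerate(dims):
        for step in (-1, 1):
            n = list(coord)
            n[axis] = (n[axis] + step) % size  # the modulo is the wrap-around link
            neighbors.append(tuple(n))
    return neighbors


def torus_hops(a, b, dims):
    """Minimum hop count between two chips, going the shorter way around each ring."""
    return sum(min(abs(x - y), size - abs(x - y)) for x, y, size in zip(a, b, dims))


def mesh_hops(a, b):
    """The same distance without wrap-around links (a plain mesh), for comparison."""
    return sum(abs(x - y) for x, y in zip(a, b))


dims = (4, 4, 4)  # hypothetical 64-chip 3D slice
print(torus_neighbors((0, 0, 0), dims))
print(torus_hops((0, 0, 0), (3, 3, 3), dims), mesh_hops((0, 0, 0), (3, 3, 3)))  # 3 9
```

Going from one corner to the opposite corner takes 3 hops on the torus versus 9 on a mesh without wrap-around links, which matters for collectives such as all-reduce that traverse whole rings.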
Real-World Applications and Use Cases
TPUs are not a science project — they run at planetary scale. Here are concrete places they earn their keep:
- Frontier model training. Google's Gemini family is trained on large TPU pods. When your training run needs tens of thousands of accelerators cooperating on one gradient step, the ICI/OCS fabric's tight, switch-free coupling is a genuine advantage over commodity GPU clusters stitched together with InfiniBand.
- High-throughput serving. Search ranking, Translate, YouTube recommendations, and Photos all lean on TPU inference where the workload is high-volume, batchable, and latency-tolerant enough to fill the array. Cost-per-query is the metric, and TPUs win it.
- Recommender systems with huge embeddings. The SparseCore on modern TPUs accelerates the massive embedding-table lookups that dominate ranking models — a workload GPUs handle less gracefully.
- Enterprise LLM workloads on Cloud TPU. Teams renting Cloud TPU v5e slices for fine-tuning and batch inference of open-weight models frequently see materially better throughput-per-dollar than comparable GPU instances, provided their software stack is JAX- or XLA-friendly. In an internal Hureka benchmark on a batch summarization pipeline, a v5e slice beat a comparable GPU deployment on cost-per-million-tokens once we tuned batch sizes to fill the MXU.
Implementation Guide
The honest truth: the hardware is only as good as the compiler that targets it. On TPUs that compiler is XLA (Accelerated Linear Algebra), and the most idiomatic way to reach it is JAX. Here is a minimal but real end-to-end shape.
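```python
import jax
import jax.numpy as jnp
from jax import random

# 1) Confirm you actually have TPU devices.
print(jax.devices())  # e.g. [TpuDevice(id=0, ...), TpuDevice(id=1, ...), ...]

# 2) A simple forward pass. XLA fuses these ops and maps the matmuls to the MXU.
def mlp(params, x):
    for W, b in params[:-1]:
        x = jax.nn.gelu(x @ W + b)
    Wf, bf = params[-1]
    return x @ Wf + bf

# 3) jit-compile: XLA traces the graph once, then runs it natively on the TPU.
fast_mlp = jax.jit(mlp)

# Illustrative initializer (not a library function): scaled-normal weights, zero biases, bf16.
def init_params(key, sizes):
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, sub = random.split(key)
        W = (random.normal(sub, (d_in, d_out)) / jnp.sqrt(d_in)).astype(jnp.bfloat16)
        params.append((W, jnp.zeros((d_out,), dtype=jnp.bfloat16)))
    return params

key = random.PRNGKey(0)
pkey, xkey = random.split(key)
params = init_params(pkey, [4096, 4096, 1024])  # hypothetical layer sizes
x = random.normal(xkey, (512, 4096), dtype=jnp.bfloat16)  # big batch, bf16 -> fills the MXU
y = fast_mlp(params, x)  # first call compiles; same-shape calls reuse the compiled program
print(y.shape, y.dtype)  # (512, 1024) bfloat16
```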
To scale across a pod, JAX exposes explicit sharding rather than hiding it. You declare a device mesh and annotate how arrays are partitioned; XLA inserts the collective communication (all-reduce, all-gather) over the ICI for you.
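```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

# Build a 2D logical mesh over the physical TPU slice.
# Assumes exactly 8 chips; the reshape fails on any other device count.
devices = np.array(jax.devices()).reshape(4, 2)  # 8 chips -> (data, model)
mesh = Mesh(devices, axis_names=('data', 'model'))

# Shard activations along the batch (data-parallel) and weights along 'model'
# (tensor-parallel). The compiler handles the cross-chip collectives.
act_sharding = NamedSharding(mesh, P('data', None))
w_sharding = NamedSharding(mesh, P(None, 'model'))

# Place arrays with those shardings; a jitted function then runs SPMD across the mesh.
x = jax.device_put(jnp.ones((512, 4096), jnp.bfloat16), act_sharding)
W = jax.device_put(jnp.ones((4096, 4096), jnp.bfloat16), w_sharding)
y = jax.jit(lambda a, w: a @ w)(x, W)
print(y.sharding)  # the output sharding chosen by XLA
```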
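It helps to know what each chip actually holds under these specs. The helper below is hypothetical (it is not part of JAX) and handles only the simple case of at most one mesh axis per array dimension; it mirrors the 4×2 mesh above.

```python
def shard_shape(global_shape, spec, mesh_axes):
    """Per-chip block shape for a PartitionSpec-style spec (hypothetical helper).

    global_shape: array shape, e.g. (512, 4096)
    spec: one mesh-axis name (or None for "replicated") per array dimension
    mesh_axes: mesh axis name -> size, e.g. {"data": 4, "model": 2}
    """
    out = []
    for dim, axis in zip(global_shape, spec):
        parts = 1 if axis is None else mesh_axes[axis]
        if dim % parts:
            raise ValueError(f"dimension {dim} is not divisible by mesh axis {axis!r} ({parts})")
        out.append(dim // parts)
    return tuple(out)


mesh_axes = {"data": 4, "model": 2}
print(shard_shape((512, 4096), ("data", None), mesh_axes))    # (128, 4096): activations
print(shard_shape((4096, 4096), (None, "model"), mesh_axes))  # (4096, 2048): weights
```

Because neither spec shards the contraction dimension (the 4096 shared by the activations and the weights), each chip can compute its (128, 2048) block of the product from data it already holds; sharding the contraction dimension instead would require a reduction collective over the ICI.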
If you live in PyTorch, you are not locked out: PyTorch/XLA lets you target TPUs by lowering the same graph through XLA, and the newer `torch_xla` dynamo integration makes this far less painful than it used to be. The mental model is identical (trace, compile, run); you just express it in PyTorch.
The one workflow rule that will save you the most grief: avoid recompilation. XLA compiles for a specific set of input shapes. If your tensor shapes change between calls (variable sequence lengths, dynamic batch sizes), XLA recompiles, and compilation is slow. Pad to fixed "bucket" shapes so the compiled program is reused.
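Bucketing is simple to implement on the host side, before data reaches the compiled function. A minimal sketch, assuming hypothetical power-of-two bucket sizes and a pad token id of 0; the returned mask lets the model ignore padded positions.

```python
import bisect

SEQ_BUCKETS = (128, 256, 512, 1024, 2048)  # hypothetical bucket sizes


def pick_bucket(length, buckets=SEQ_BUCKETS):
    """Smallest bucket that fits `length` tokens."""
    i = bisect.bisect_left(buckets, length)
    if i == len(buckets):
        raise ValueError(f"length {length} exceeds the largest bucket ({buckets[-1]})")
    return buckets[i]


def pad_to_bucket(tokens, pad_id=0, buckets=SEQ_BUCKETS):
    """Right-pad a token list to its bucket and return (padded_tokens, attention_mask)."""
    target = pick_bucket(len(tokens), buckets)
    n_pad = target - len(tokens)
    return list(tokens) + [pad_id] * n_pad, [1] * len(tokens) + [0] * n_pad


padded, mask = pad_to_bucket([101, 2023, 2003, 102])
print(len(padded), sum(mask))  # 128 4
print(len({pick_bucket(n) for n in range(1, 2049)}))  # 5 distinct compiled shapes
```

Every length from 1 to 2048 maps onto one of five shapes, so XLA compiles at most five programs instead of one per distinct length.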
Production Patterns and Best Practices
Lessons learned deploying TPU workloads for enterprise clients at Hureka Technologies:
- Fill the array or go home. The number-one performance mistake is under-batching. If you are serving at batch size 1, you are paying for 16,384 MACs and using roughly 128 of them (one row of a 128×128 array). Use continuous/dynamic batching to aggregate requests before they hit the MXU.
- Standardize on bucketed shapes. Pick a small set of `(batch, seq_len)` buckets, pad to them, and compile once per bucket at startup. Treat an unexpected recompilation in production as an incident: it will show up as a latency spike.
- Keep data off the critical path. The MXU can consume data faster than most input pipelines can produce it. Use `tf.data` / Grain with prefetching and on-device sharding so the accelerator never waits on the host.
- Checkpoint asynchronously. On thousand-chip runs, a synchronous checkpoint stalls every chip. Use async checkpointing (e.g. Orbax) to overlap I/O with compute; a stalled all-reduce on a pod is catastrophically expensive.
- Design for preemption. Cloud TPU slices, especially spot/preemptible ones, can vanish. Idempotent, resumable-from-checkpoint training is not optional at scale.
- Profile with the right tool. The XLA profiler and TensorBoard's trace viewer show you MXU utilization and step-time breakdown. Optimize what the profiler says is slow, not what you assume is slow.
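The aggregation step from the first bullet can be sketched as a simple coalescing policy: close a batch when it is full or when its oldest request has waited long enough. This is a deterministic, offline simplification (a real server would fire a timer rather than check on the next arrival), and `Request`, `coalesce`, and the default limits are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Request:
    arrival_ms: float
    payload: str


def coalesce(requests, max_batch=64, max_wait_ms=5.0):
    """Group requests into batches that close when full or when the oldest has waited too long."""
    batches, current = [], []
    for req in sorted(requests, key=lambda r: r.arrival_ms):
        full = len(current) == max_batch
        stale = bool(current) and req.arrival_ms - current[0].arrival_ms > max_wait_ms
        if full or stale:
            batches.append(current)
            current = []
        current.append(req)
    if current:
        batches.append(current)
    return batches


# 200 requests arriving 0.01 ms apart all land within 5 ms, so batches close on size.
burst = [Request(arrival_ms=i * 0.01, payload=f"req-{i}") for i in range(200)]
print([len(b) for b in coalesce(burst)])  # [64, 64, 64, 8]
```

Each closed batch would then be padded to a bucketed shape and sent to the compiled model in one call, so the MXU sees 64 rows at a time instead of 1.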
Performance, Benchmarks, and Optimization
Concrete numbers and trade-offs, with the caveat that exact figures shift every generation:
- Peak throughput on a modern per-chip TPU is on the order of hundreds of teraFLOPS to low petaFLOPS in bf16/int8, and a full pod aggregates into the multi-exaFLOP range. But peak is a fiction; Model FLOPs Utilization (MFU) is the honest metric. A well-tuned large-model training run lands roughly in the 45–65% MFU band; if you are under ~30%, you have an input-pipeline, sharding, or shape-bucketing problem, not a hardware problem.
- Performance-per-watt is where TPUs have historically shone, because the systolic array removes the register-file and instruction-scheduling overhead that GPUs pay per operation. For batchable inference this often translates into a better cost-per-query than GPU alternatives.
- Quantization to int8 (and increasingly lower) can double or quadruple effective inference throughput on TPUs, with careful calibration to preserve accuracy. For serving, this is usually the highest-leverage optimization after batching.
- Optimization priority order that I actually use: (1) maximize batch / fill the MXU, (2) eliminate recompilation via fixed shapes, (3) fix the input pipeline so it out-runs the accelerator, (4) tune the sharding strategy to minimize cross-chip collectives, (5) quantize for inference. Do them in that order; the early ones dwarf the later ones.
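MFU itself is a one-line calculation. This sketch uses the common approximation of 6 × parameters FLOPs per training token (forward plus backward, ignoring attention's extra FLOPs); the model size, throughput, chip count, and per-chip peak below are hypothetical, not measurements.

```python
def training_mfu(n_params, tokens_per_sec, n_chips, peak_flops_per_chip):
    """Model FLOPs Utilization using the ~6 * N FLOPs-per-token training estimate."""
    model_flops_per_sec = 6 * n_params * tokens_per_sec
    return model_flops_per_sec / (n_chips * peak_flops_per_chip)


# Hypothetical run: 8B-parameter model, 256 chips at 200 TFLOP/s peak each.
mfu = training_mfu(
    n_params=8_000_000_000,
    tokens_per_sec=500_000,
    n_chips=256,
    peak_flops_per_chip=200_000_000_000_000,
)
print(f"MFU = {mfu:.1%}")  # MFU = 46.9%
```

At roughly 47%, this hypothetical run sits inside the 45–65% band quoted above; plugging in your own measured tokens per second and your slice's published peak gives the real number.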
Common Mistakes and How to Avoid Them
1. Serving at batch size 1 and blaming the hardware. The fix is aggregation: continuous batching, request coalescing, larger serving batches. The TPU is not slow; it is starving.
2. Ignoring recompilation. Dynamic shapes silently retrigger XLA compilation. Fix: pad to a fixed set of bucketed shapes and warm up each at startup.
3. Treating a TPU like a GPU. Porting CUDA-kernel-level thinking wholesale fails. TPUs want a *graph* handed to a compiler, not hand-written kernels. Fix: express the model in JAX/XLA idioms and let the compiler do the mapping.
4. Starving the array with a weak input pipeline. Host-side preprocessing becomes the bottleneck. Fix: prefetch, shard input on-device, and profile step-time to confirm the accelerator is the bottleneck (it should be).
5. Naive sharding that maximizes communication. Splitting a model the wrong way turns the ICI into the bottleneck. Fix: match the parallelism strategy (data / tensor / pipeline) to layer shapes and pod topology; profile the collectives.
6. Assuming fp16 habits transfer. Loss-scaling and fp16 overflow tricks are largely unnecessary on TPUs thanks to bfloat16's fp32-range exponent. Fix: default to bf16 and drop the loss-scaling machinery.
Tool and Technology Comparison
| Dimension | Google TPU (v5p / Trillium class) | NVIDIA GPU (H100 / B200 class) | AWS Trainium/Inferentia | CPU |
|---|---|---|---|---|
| Core compute primitive | Systolic array (MXU) | SIMT cores + Tensor Cores | Systolic-style engines | Scalar/vector cores |
| Best-fit workload | Large dense matmul, huge-scale training & batch inference | Broad: training, inference, research, sparse/dynamic | Cost-optimized AWS training/inference | Control logic, small models |
| Interconnect | ICI 3D-torus + optical (OCS), switch-free | NVLink + InfiniBand/Ethernet fabric | NeuronLink | PCIe / network |
| Software stack | JAX / XLA / TF, PyTorch-XLA | CUDA / cuDNN, huge ecosystem | Neuron SDK | Everything |
| Ecosystem maturity | Strong but narrower; JAX-centric | Largest by far; de-facto standard | Growing, AWS-only | Universal |
| Availability | Google Cloud only | Everywhere (all clouds + on-prem) | AWS only | Everywhere |
| Sweet spot | You are all-in on GCP and workloads are batchable | You need flexibility, portability, or bleeding-edge kernels | You are all-in on AWS and cost-sensitive | Non-matmul workloads |
The short version: GPUs win on flexibility, portability, and ecosystem; TPUs win on throughput-per-dollar and per-watt for large, batchable, dense workloads — if you are willing to live inside Google Cloud and the XLA/JAX world. For a startup that needs to move fast across clouds, GPUs are usually right. For a team training or serving at massive, steady scale on GCP, TPUs can be dramatically cheaper.
Future Trends and What Is Coming Next
The trajectory through 2026–2028 is clear even if the exact specs are not. First, inference is eating the roadmap. The v7 (Ironwood) generation is explicitly oriented toward serving reasoning models at scale, because the compute cost of test-time reasoning — models that "think" longer before answering — is exploding. Expect future TPUs to optimize aggressively for low-precision inference (int8, fp8, and below), giant HBM capacity to hold long contexts and KV caches, and SparseCore-style acceleration generalized beyond embeddings.
Second, the network is becoming the computer. Optical circuit switching will deepen; the differentiator between generations is increasingly how many chips you can couple tightly and reconfigure dynamically, not the per-chip FLOPS. Third, the software moat is narrowing. JAX and XLA are getting easier to adopt, PyTorch/XLA is maturing, and OpenXLA is opening the compiler stack — which erodes the historical "but everything is CUDA" objection and makes TPUs a realistic option for more teams. Finally, expect co-design: models architected specifically to exploit the systolic array's love of large dense matmuls, and hardware tuned to the shapes real frontier models actually use.
Conclusion and Next Steps
Tensor Processing Units are the clearest example in modern computing of what happens when you specialize hardware to a workload. By throwing away generality and building the entire chip around a systolic-array matrix multiplier, Google produced silicon that — for the specific shape of neural network math — delivers throughput-per-watt and throughput-per-dollar that general-purpose chips struggle to match. The elegance of the systolic array, the pragmatism of bfloat16, and the audacity of optically-switched pods are worth understanding whether or not you ever deploy on a TPU, because they teach you how to think about the physical economics of AI.
Your actionable next steps: if you are on Google Cloud and your workload is batchable and matmul-dominated, spin up a Cloud TPU v5e slice and benchmark your real pipeline against your current GPU cost-per-token; measure MFU, not peak FLOPS. Learn enough JAX to express a model and `jax.jit` it, because that is the shortest path to seeing the hardware perform. And whatever silicon you run on, internalize the one lesson that transfers everywhere: keep the matrix multiplier full. That single principle, arithmetic intensity over everything, is what separates a cost-effective AI system from an expensive one.
If you want help deciding whether TPUs, GPUs, or a hybrid fits your specific training and serving workloads, [reach out via the contact page](/contact) — hardware-aware AI architecture is exactly the kind of work we do for enterprise clients at Hureka Technologies. And if you found this useful, the companion piece on [IBM's 0.7 nm chip](/blog/ibm-07nm-chip-ai-future-2026) covers the manufacturing side of the same story.