
The Neural Network: Five Embeddings, One Trunk, Three Heads

Deep Dive · 2026-04-19 · 8 min read

A ~234K-parameter policy + value network that learns to play Slay the Spire 2 — told in one flow diagram and three cards.

[Figure: network flow diagram. Five embedding tables (card_emb 16d, nn.Embedding(64, 16); enemy_emb 16d, nn.Embedding(32, 16); relic_emb 8d, nn.Embedding(128, 8); pile_emb 8d, nn.Embedding(64, 8); power_emb 8d, nn.Embedding(64, 8)) and 18 scalar features (HP, block, energy, turn, debuffs) are concatenated into a 583-dim input. That input passes through two Linear · LayerNorm · ReLU layers of width 256 and feeds three heads: policy (61 actions, masked softmax, Linear(256, 61)), value (scalar V(s), Linear(256, 1)), and an auxiliary discard head (Linear(256, 10)).]

One shared trunk, five embedding tables, two (really three) heads. The trunk is smaller than the input — almost all of the representation work happens in the embeddings.

~234K params · policy + value + discard · orthogonal init

The shape of it

I kept waiting for this network to need to be bigger. It never did. Two hidden layers of 256, five embedding tables, three linear heads, about 234K parameters total — the whole policy fits in a 2.3 MB .pth file and the shared trunk is actually smaller than the input that feeds it. For a while that felt wrong. I'd read enough RL papers to expect a network with a recognizable name and a lot more depth. Instead, the thing that worked was an MLP small enough to fit on a napkin.

The shape earns its small size because the interesting work happens in the embedding tables on the left side of the diagram, not in the trunk. Each entity type in the game — cards, enemies, relics, pile cards, generic powers — gets its own learned embedding. Five tables, roughly 4K parameters worth of weights across them, and that's where the network builds its picture of what each game object is. The trunk's job is to mix the concatenated picture into something the heads can decode. The heads themselves barely exist: 61 outputs for the policy, one for the value, ten for a discard-priority helper. Most of the architectural decisions in this file were about the tables.

This article walks through those decisions in the order the diagram reads — embeddings on the left, trunk in the middle, heads on the right — and each section tries to answer one question: why is that piece shaped the way it is?

Embeddings

The previous version of this network used one-hot everything. Cards were 80-dim one-hots, enemies were one-hots over a ~15-entity pool, and relics and powers were bitmaps. The input cleared 1500 dimensions and, more importantly, the network had no structural reason to treat Strike and Strike+ as related. An earlier deep dive covers why one-hot was wrong and how I found out. This one picks up where it left off: five nn.Embedding tables, and the reasoning that sized each of them.

self.card_emb  = nn.Embedding(card_vocab_size,  16, padding_idx=0)
self.pile_emb  = nn.Embedding(card_vocab_size,  8,  padding_idx=0)
self.enemy_emb = nn.Embedding(enemy_vocab_size, 16, padding_idx=0)
self.relic_emb = nn.Embedding(relic_vocab_size, 8,  padding_idx=0)
self.power_emb = nn.Embedding(power_vocab_size, 8,  padding_idx=0)
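Here is a quick check on the "~4K" budget, as a sketch. The vocab sizes below are read off the diagram (64 cards, 32 enemies, 128 relics, 64 powers), and the excerpt's card_vocab_size-style names are assumed to take those values.

```python
import torch.nn as nn

# Vocab sizes from the diagram; assumed values for the *_vocab_size names above.
CARD_VOCAB, ENEMY_VOCAB, RELIC_VOCAB, POWER_VOCAB = 64, 32, 128, 64

tables = nn.ModuleDict({
    "card_emb":  nn.Embedding(CARD_VOCAB, 16, padding_idx=0),
    "pile_emb":  nn.Embedding(CARD_VOCAB, 8, padding_idx=0),   # same vocab, separate table
    "enemy_emb": nn.Embedding(ENEMY_VOCAB, 16, padding_idx=0),
    "relic_emb": nn.Embedding(RELIC_VOCAB, 8, padding_idx=0),
    "power_emb": nn.Embedding(POWER_VOCAB, 8, padding_idx=0),
})

# Each table holds vocab_size * dim weights and no bias.
per_table = {name: emb.weight.numel() for name, emb in tables.items()}
total = sum(per_table.values())
print(per_table)
print(total)  # 3584
```

The relic table is the narrowest, but it ties card_emb for the most weights because its vocab is the largest.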

Cards and enemies get 16-dim embeddings; relics, pile cards, and player powers get 8. The split wasn't a theoretical calculation — it was a budget question. At every step the agent takes, the most decision-relevant entities are the cards in hand and the enemies on screen. Playing a card is the action. Targeting an enemy is the action. The gradient signal on card_emb and enemy_emb is dense and immediate, and giving those tables more dimensions gave the network more room to discover useful structure. When I tried dropping them to 8, training took roughly twice as long to reach the same win rate on stage 3. When I tried bumping them to 32, training looked identical and the forward pass cost more. 16 landed.

Relics and powers are different. They modify the game but aren't chosen as actions; relics drop at the end of a combat, not turn-by-turn. That makes the signal on relic_emb sparse. An equipped relic shows up in every step's input, but the relic set only changes a handful of times per episode, so most steps give the table little contrast between one relic and another. 8 dimensions turns out to be plenty to encode "what does this relic do" once the network has trained for long enough. I tried 16 for relics once, and the table's weights converged to near-identical rows for functionally similar relics: it was overparameterized for the amount of data flowing through it.

pile_emb is the weird one. It uses the same vocab as card_emb (every card in the game has one entry), but a separate embedding table at 8 dimensions. Pile cards aren't played directly; they're context. "Roughly what's in my draw pile" is useful; "this exact slot of the draw pile is Strike+" is not, because the pile is shuffled. A separate smaller table gave the network a distinct representation for cards-as-context versus cards-as-actions, and let the two representations evolve independently. If I'd shared one 16-dim table, the gradient on playing a card and the gradient on pooling-over-pile would have fought each other.

The padding_idx=0 argument on each table is the detail that made masking almost free. Each table's zeroth row is initialized to a zero vector and never receives a gradient update. Empty hand slots, empty enemy slots, and pile entries past the observed count all map to index 0. They produce a zero vector and add nothing to the first trunk layer's output beyond its bias. The network learns to read "all zeros" as "nothing is here" without me writing a masking branch into the trunk.
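A minimal sketch of what padding_idx=0 does, using a pile-sized table with random weights. The card ids 5 and 9 are arbitrary placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(64, 8, padding_idx=0)

# Two real cards (placeholder ids 5 and 9) followed by two empty slots.
ids = torch.tensor([5, 9, 0, 0])
vecs = emb(ids)                                   # [4, 8]

# Row 0 is initialized to zeros, so empty slots embed to the zero vector.
pad_is_zero = bool((vecs[2:] == 0).all())

# Row 0 also gets no gradient, so it stays zero through training;
# rows that were actually looked up do get one.
vecs.sum().backward()
pad_grad_is_zero = bool((emb.weight.grad[0] == 0).all())
real_grad_nonzero = bool((emb.weight.grad[5] != 0).any())
print(pad_is_zero, pad_grad_is_zero, real_grad_nonzero)  # True True True
```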

Unordered collections go through mean-pooling before they hit the trunk. The draw pile, discard pile, and exhaust pile each mean-pool their card embeddings to 8 dims. Relics mean-pool across all equipped relics. Player powers mean-pool across active powers. Mean-pooling isn't the most expressive operator — attention would let the trunk weight pile cards by relevance — but it matches what the game actually exposes: unordered sets with no positional meaning. Encoding "slot 0 of the draw pile" as a specific representation would be baking in a lie, because slot 0 is wherever the shuffle put the next card.
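A minimal sketch of mean-pooling over a padded pile follows. It assumes piles arrive as fixed-length id tensors padded with 0, and the function name masked_mean_pool is hypothetical. It divides by the number of real cards rather than the slot count, because a plain .mean() over the slot axis would count padding zeros in the denominator and shrink the pooled vector for small piles. The text doesn't say which denominator the original encoder uses.

```python
import torch
import torch.nn as nn

def masked_mean_pool(emb: nn.Embedding, ids: torch.Tensor) -> torch.Tensor:
    """Mean of embeddings over non-padding entries.

    ids: [B, N] integer tensor where 0 marks an empty slot (padding_idx).
    Returns [B, D]; a pile with no cards pools to the zero vector.
    """
    vecs = emb(ids)                                    # [B, N, D]
    present = (ids != 0).unsqueeze(-1).to(vecs.dtype)  # [B, N, 1], 1.0 for real cards
    total = (vecs * present).sum(dim=1)                # [B, D]
    count = present.sum(dim=1).clamp(min=1.0)          # avoid 0/0 on empty piles
    return total / count

torch.manual_seed(0)
pile_emb = nn.Embedding(64, 8, padding_idx=0)

# Same multiset of (placeholder) cards in a different order and padding layout.
a = torch.tensor([[3, 7, 7, 0, 0]])
b = torch.tensor([[7, 0, 3, 7, 0]])
empty = torch.zeros(1, 5, dtype=torch.long)

pooled_a = masked_mean_pool(pile_emb, a)
pooled_b = masked_mean_pool(pile_emb, b)
naive_a = pile_emb(a).mean(dim=1)   # divides by 5 slots instead of 3 cards
print(torch.allclose(pooled_a, pooled_b))  # True: order and padding don't matter
```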

5 tables · ~4K embedding params total · learned end-to-end

Trunk

self.trunk = nn.Sequential(
    nn.Linear(input_dim, 256),
    nn.LayerNorm(256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.LayerNorm(256),
    nn.ReLU(),
)

Two layers of 256. Linear, LayerNorm, ReLU. Nothing else.

I want to be honest about how little of this network's design is in the trunk. Two of my iterations touched it — once to try three layers instead of two, once to try a 512-dim hidden size — and both either trained no better or trained slightly worse while doubling compute. I kept 256 and moved on. The trunk is smaller than the 583-dim input that flows into it, which reliably surprises people when they see the diagram; the intuition is that the first hidden layer should be at least as wide as the input. But the input isn't dense information. It's a concatenation of learned embeddings plus a handful of scalar features, and 256 dimensions is enough to hold a useful mixing of those.
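To put "smaller than the input" in numbers, here is a sketch that rebuilds the trunk with the diagram's 583-wide input and counts its parameters. input_dim = 583 is taken from the diagram; the layers match the excerpt.

```python
import torch
import torch.nn as nn

input_dim = 583  # concat width from the diagram
trunk = nn.Sequential(
    nn.Linear(input_dim, 256), nn.LayerNorm(256), nn.ReLU(),
    nn.Linear(256, 256), nn.LayerNorm(256), nn.ReLU(),
)

def n_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())

first_linear = n_params(trunk[0])          # 583 * 256 weights + 256 biases
print(first_linear, n_params(trunk))       # 149504 216320

h = trunk(torch.randn(4, input_dim))
print(h.shape)                             # torch.Size([4, 256])
```

The narrowing from 583 to 256 happens in that first projection, which on its own holds about 150K parameters.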

LayerNorm instead of BatchNorm is a deliberate call specific to this project's actor-learner setup. The actor proxies run the network one state at a time to produce actions during rollout, which means inference happens with batch size 1. BatchNorm normalizes with per-batch statistics during training and with running averages at inference, so the learner's batched pass and the actor's single-sample pass would normalize differently. The network would train on one normalization and deploy under another, and in train mode BatchNorm can't normalize a batch of one at all. LayerNorm normalizes each sample independently and doesn't care whether the batch is 1 or 1024, so the actor and learner see identical forward passes.
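Here is a small check of the batch-size argument, using stock torch.nn modules and random data (the 32-row batch is arbitrary). LayerNorm normalizes a sample identically whether it arrives alone or inside a batch. BatchNorm in train mode depends on the rest of the batch and rejects a batch of one.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(32, 256)                 # a learner-sized batch

ln = nn.LayerNorm(256)
# Per-row normalization: sample 0 comes out the same alone or in the batch.
ln_same = torch.allclose(ln(x)[:1], ln(x[:1]), atol=1e-6)

bn = nn.BatchNorm1d(256)
bn.train()
# Per-batch statistics: sample 0 comes out differently next to 31 vs 7 batchmates.
bn_same = torch.allclose(bn(x)[:1], bn(x[:8])[:1], atol=1e-6)

# A single sample can't be batch-normalized in train mode.
try:
    bn(x[:1])
    bn_batch_of_one_ok = True
except ValueError:
    bn_batch_of_one_ok = False

print(ln_same, bn_same, bn_batch_of_one_ok)  # True False False
```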

Orthogonal initialization (nn.init.orthogonal_(m.weight, gain=sqrt(2))) is applied to every nn.Linear. Orthogonal init is a PPO convention — the theory is that it preserves gradient magnitudes better than Xavier/He for the ReLU stack, and in practice training curves noticeably differ. The policy head gets a much smaller gain (0.01) so the initial action distribution is near-uniform instead of peaked at some arbitrary action. Without that override, the first few hundred episodes punish the agent for preferences the init baked in, which is both a waste of rollouts and a source of confusing early-training regressions.
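A minimal sketch of the init scheme described above: orthogonal weights with gain sqrt(2) on every nn.Linear, then an override to gain 0.01 on the policy head. Zeroing the biases and the helper name init_weights are assumptions, not taken from the text.

```python
import math
import torch
import torch.nn as nn

def init_weights(module: nn.Module) -> None:
    # Orthogonal init with gain sqrt(2) on every Linear; zero bias is an assumption.
    if isinstance(module, nn.Linear):
        nn.init.orthogonal_(module.weight, gain=math.sqrt(2))
        nn.init.zeros_(module.bias)

trunk = nn.Sequential(nn.Linear(583, 256), nn.LayerNorm(256), nn.ReLU(),
                      nn.Linear(256, 256), nn.LayerNorm(256), nn.ReLU())
policy_head = nn.Linear(256, 61)

trunk.apply(init_weights)
policy_head.apply(init_weights)
# Override for the policy head: rows of norm 0.01 keep the initial logits near zero.
nn.init.orthogonal_(policy_head.weight, gain=0.01)

torch.manual_seed(0)
with torch.no_grad():
    probs = torch.softmax(policy_head(trunk(torch.randn(8, 583))), dim=-1)
print(probs.min().item(), probs.max().item(), 1 / 61)
```

Each row of the policy weight has norm 0.01, and the trunk output has norm at most 16 after LayerNorm and ReLU. Every initial logit therefore stays within about 0.16 of zero, and all 61 probabilities start close to 1/61.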

Policy head + masked softmax

logits = self.policy_head(h)                         # [B, 61]
logits = logits.masked_fill(~enc.action_mask, -1e9)  # illegal → -1e9 (finite, not -inf)

The action space is end_turn plus play(card_idx, target_idx) for 10 hand slots × 6 possible targets, for 61 actions total. The policy head is one nn.Linear(256, 61). The network produces logits for all 61 slots, masks away the illegal ones, and samples from the categorical distribution over what remains.

Most states have nowhere near 61 legal actions. A typical mid-turn state has three to seven cards in hand, each legal against one to three enemies, plus end_turn as a baseline. Everything else is illegal — the slot has no card, the card costs too much energy, the card targets an enemy type that doesn't match. The encoder walks the game's legal-action list and builds a Boolean mask of length 61; the forward pass applies masked_fill(~mask, -1e9) to the logits before they hit softmax.
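Here is a sketch of mask construction and sampling. The text only fixes the sizes (10 slots × 6 targets + end_turn = 61). The flat index layout (end_turn at index 0, then one block of 6 targets per hand slot), the helper names, and treating end_turn as always legal are all assumptions for illustration.

```python
import torch
from torch.distributions import Categorical

N_HAND, N_TARGETS = 10, 6
N_ACTIONS = 1 + N_HAND * N_TARGETS          # 61
END_TURN = 0                                # hypothetical: end_turn at index 0

def play_index(card_idx: int, target_idx: int) -> int:
    # Hypothetical layout: after end_turn, one block of 6 targets per hand slot.
    return 1 + card_idx * N_TARGETS + target_idx

def build_mask(legal_plays: list[tuple[int, int]]) -> torch.Tensor:
    mask = torch.zeros(N_ACTIONS, dtype=torch.bool)
    mask[END_TURN] = True                   # assumed always legal
    for card_idx, target_idx in legal_plays:
        mask[play_index(card_idx, target_idx)] = True
    return mask

# A made-up state: three playable cards, one of them with two legal targets.
mask = build_mask([(0, 1), (0, 2), (1, 1), (3, 0)])

torch.manual_seed(0)
logits = torch.randn(N_ACTIONS)             # stand-in for policy_head(h)
dist = Categorical(logits=logits.masked_fill(~mask, -1e9))
actions = dist.sample((1000,))

print(int(mask.sum()))                # 5
print(bool(mask[actions].all()))      # True: only legal actions are ever sampled
```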

Caveat · -1e9 vs float('-inf')

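-1e9 is not float('-inf'), and I learned the distinction the hard way around training episode 400 the second time I ran training. float('-inf') works for sampling: the softmax assigns zero probability to illegal actions, as intended. It breaks in any loss term that touches every action rather than only the sampled one. An entropy bonus multiplies each probability by its log-probability, and for an illegal action that product is 0 × -inf = NaN. A full-distribution KL between old and new policies subtracts one -inf log-probability from another, which is also NaN. Once a NaN reaches the loss, the backward pass spreads it into every weight. -1e9 is finite. After softmax the illegal probabilities underflow to exactly zero, but the log-probabilities stay finite, so 0 × -1e9 is just 0 and those terms stay clean, with a comfortable numerical margin.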
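The failure can be reproduced in miniature with a hand-written entropy term, one of the loss terms that touches every action. The three-action logits are made up, and the only difference between the two runs is the fill value.

```python
import torch

def entropy(masked_logits: torch.Tensor) -> torch.Tensor:
    # Hand-written entropy: -sum(p * log p) over every action, legal or not.
    logp = torch.log_softmax(masked_logits, dim=-1)
    return -(logp.exp() * logp).sum(dim=-1)

mask = torch.tensor([True, True, False])    # last action is illegal

results = {}
for fill in (float("-inf"), -1e9):
    logits = torch.tensor([0.5, 0.5, 0.2], requires_grad=True)
    h = entropy(logits.masked_fill(~mask, fill))
    h.backward()
    # (entropy value, whether every gradient entry is finite)
    results[fill] = (h.item(), bool(torch.isfinite(logits.grad).all()))

print(results)
# -inf: entropy is nan and the gradient contains nan
# -1e9: entropy is log(2) for two equally likely legal actions, gradient finite
```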

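The masked logits are also what the PPO surrogate loss differentiates through. The surrogate only reads the log-probability of the action that was actually taken, which is always legal. The log_softmax normalizer effectively sums over legal actions only, since illegal ones contribute zero probability, and masked_fill blocks any gradient from reaching the pre-mask logits of illegal actions. That is the correct treatment: a policy update shouldn't reward or penalize actions that couldn't have been taken.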
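A minimal sketch of the clipped PPO surrogate computed on masked logits. Only the taken action's log-probability enters the loss, and the check at the end confirms that the pre-mask logits of illegal actions receive exactly zero gradient. The function name, clip_eps=0.2, and the random stand-in data are assumptions, not the project's code.

```python
import torch

def ppo_policy_loss(raw_logits, mask, actions, old_logp, advantages, clip_eps=0.2):
    """Clipped PPO surrogate on masked logits (clip_eps is an assumed value)."""
    logits = raw_logits.masked_fill(~mask, -1e9)
    logp_all = torch.log_softmax(logits, dim=-1)
    logp = logp_all.gather(-1, actions.unsqueeze(-1)).squeeze(-1)  # taken action only
    ratio = torch.exp(logp - old_logp)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()

torch.manual_seed(0)
B, A = 4, 61
raw_logits = torch.randn(B, A, requires_grad=True)  # stand-in for policy_head(h)
mask = torch.rand(B, A) < 0.1
mask[:, 0] = True                                   # at least one legal action per row

# Actions and old log-probs come from the masked distribution, as in a rollout.
with torch.no_grad():
    probs = torch.softmax(raw_logits.masked_fill(~mask, -1e9), dim=-1)
    actions = torch.multinomial(probs, 1).squeeze(-1)
    old_logp = torch.log(probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1))
advantages = torch.randn(B)

loss = ppo_policy_loss(raw_logits, mask, actions, old_logp, advantages)
loss.backward()
print(bool((raw_logits.grad[~mask] == 0).all()))   # True: illegal logits get no gradient
```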

I considered a variable-length action head — some kind of pointer network or masked attention over the current legal-action list — and decided against it. At 61 actions the fixed-width linear is 16K parameters, the masking is three lines, and the rest of the PPO pipeline doesn't need to special-case anything. If the action space grew by an order of magnitude, or if legal-action sets had internal structure a pointer network could exploit, I'd reconsider. At this scale the simple version wins.

Value head

value = self.value_head(h)  # [B, 1]

One linear, 256 to 1, 257 parameters including bias. That's the whole head.

PPO's advantage computation needs a value estimate to subtract as a baseline: the advantage is return - V(state), and the policy update scales by advantage. The value estimate doesn't have to be accurate in absolute terms; it has to correlate with return well enough that the baseline reduces variance in the policy gradient. One linear layer on top of the 256-d trunk is enough for that, and the value loss trains against the discounted return with standard MSE. Both the policy loss and the value loss backpropagate through the shared trunk, which is the main reason the trunk learns useful features: the value signal is denser per step than the policy signal and keeps the representation well-calibrated.
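A minimal sketch of the two quantities the value head feeds, for a single finished episode with made-up rewards and value estimates. The first is the discounted return used as the MSE target. The second is the advantage return - V(s) that scales the policy update. Plain discounted returns match the text; many PPO implementations use GAE instead, which this sketch doesn't show. gamma = 0.9 is an arbitrary choice.

```python
import torch
import torch.nn.functional as F

def discounted_returns(rewards: torch.Tensor, gamma: float) -> torch.Tensor:
    """G_t = r_t + gamma * G_{t+1}, computed backwards over one finished episode."""
    returns = torch.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

rewards = torch.tensor([0.0, 0.0, 1.0])             # e.g. a win on the last step
returns = discounted_returns(rewards, gamma=0.9)    # ≈ [0.81, 0.9, 1.0]

# Stand-in for value_head(h).squeeze(-1)
values = torch.tensor([0.5, 0.7, 0.9], requires_grad=True)

# Baseline subtraction; detach so the policy term doesn't push gradients into V.
advantages = returns - values.detach()              # ≈ [0.31, 0.2, 0.1]
value_loss = F.mse_loss(values, returns)            # ≈ 0.0487
```

In a shared-trunk setup the value loss is typically added to the policy loss, often scaled by a coefficient, before a single backward pass.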

The discard head is an architectural footnote. self.discard_head = nn.Linear(256, 10) produces ten scores, one per hand slot, and sorting them in descending order gives the agent a "discard this first" priority list. It fires only when STS2 forces a mid-turn discard. I considered folding the decision into the main policy head, but "pick one action out of 61" and "rank ten slots" are different shapes. Giving each its own head keeps each loss, and its gradients, tied to a single decision shape.
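A minimal sketch of turning the discard head's ten scores into a priority list. It assumes empty hand slots are marked with card id 0 and should rank last; the hand contents and that masking step are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
discard_head = nn.Linear(256, 10)           # one score per hand slot

h = torch.randn(1, 256)                      # stand-in for the trunk output
# Hypothetical hand: five cards in slots 0-4, slots 5-9 empty (id 0).
hand_ids = torch.tensor([[12, 4, 4, 30, 7, 0, 0, 0, 0, 0]])

with torch.no_grad():
    scores = discard_head(h)                                # [1, 10]
    scores = scores.masked_fill(hand_ids == 0, -1e9)        # empty slots sink to the end
    priority = torch.argsort(scores, dim=-1, descending=True)  # discard-first order

print(priority[0, :5].tolist())   # the five occupied slots, best discard first
```

The article doesn't cover how the discard head is trained, so the sketch only shows inference.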

Takeaways

01 Embeddings are the representation work. The trunk mixes; the heads decode.
02 Separate pile_emb from card_emb. Cards-as-actions and cards-as-context are different jobs.
03 LayerNorm over BatchNorm for actor-learner RL. Inference is batch-of-1; BatchNorm drifts.
04 masked_fill(~mask, -1e9) before softmax. The finite margin survives the backward pass without NaN.
05 Orthogonal init with gain=0.01 on the policy head. Starts near-uniform, not peaked on an arbitrary action.