The Neural Network: Five Embeddings, One Trunk, Three Heads
A ~234K-parameter policy + value network that learns to play Slay the Spire 2 — told in one flow diagram and three cards.
One shared trunk, five embedding tables, two (really three) heads. The trunk is smaller than the input — almost all of the representation work happens in the embeddings.
The shape of it
I kept waiting for this network to need to be bigger. It never did. Two
hidden layers of 256, five embedding tables, three linear heads, about
234K parameters total — the whole policy fits in a 2.3 MB
.pth file and the shared trunk is actually
smaller than the input that feeds it. For a while that felt
wrong. I'd read enough RL papers to expect a network with a
recognizable name and a lot more depth. Instead, the thing that worked
was an MLP small enough to fit on a napkin.
The shape earns its small size because the interesting work happens in the embedding tables on the left side of the diagram, not in the trunk. Each entity type in the game — cards, enemies, relics, pile cards, generic powers — gets its own learned embedding. Five tables, roughly 4K parameters worth of weights across them, and that's where the network builds its picture of what each game object is. The trunk's job is to mix the concatenated picture into something the heads can decode. The heads themselves barely exist: 61 outputs for the policy, one for the value, ten for a discard-priority helper. Most of the architectural decisions in this file were about the tables.
This article walks through those decisions in the order the diagram reads — embeddings on the left, trunk in the middle, heads on the right — and each section tries to answer one question: why is that piece shaped the way it is?
Embeddings
The previous version of this network used one-hot everything. Cards
were 80-dim one-hots, enemies were one-hots over a ~15-entity pool,
relics and powers were bitmaps. The input cleared 1500 dimensions,
and — more importantly — the network had no structural reason to
treat Strike and Strike+ as related.
That earlier article covers why
one-hot was wrong and how I found out. This one picks up where it
left off: five nn.Embedding tables, and the reasoning
that sized each of them.
self.card_emb = nn.Embedding(card_vocab_size, 16, padding_idx=0)
self.pile_emb = nn.Embedding(card_vocab_size, 8, padding_idx=0)
self.enemy_emb = nn.Embedding(enemy_vocab_size, 16, padding_idx=0)
self.relic_emb = nn.Embedding(relic_vocab_size, 8, padding_idx=0)
self.power_emb = nn.Embedding(power_vocab_size, 8, padding_idx=0)
Cards and enemies get 16-dim embeddings; relics, pile cards, and
player powers get 8. The split wasn't a theoretical calculation — it
was a budget question. At every step the agent takes, the most
decision-relevant entities are the cards in hand and the enemies on
screen. Playing a card is the action. Targeting an enemy
is the action. The gradient signal on card_emb
and enemy_emb is dense and immediate, and giving those
tables more dimensions gave the network more room to discover useful
structure. When I tried dropping them to 8, training took roughly
twice as long to reach the same win rate on stage 3. When I tried
bumping them to 32, training looked identical and the forward pass
cost more. 16 landed.
Relics and powers are different. They modify the game but aren't
chosen as actions — relics drop at the end of a combat, not
turn-by-turn. The gradient signal on relic_emb is sparse;
each relic updates a handful of times per episode, not per step. 8
dimensions turns out to be plenty to encode "what does this relic do"
once the network has trained for long enough. I tried 16 for relics
once and the table's weights converged to near-identical rows for
functionally similar relics — it was overparameterized for the amount
of data flowing through it.
pile_emb is the weird one. It uses the same vocab as
card_emb (every card in the game has one entry), but a
separate embedding table at 8 dimensions. Pile cards aren't
played directly; they're context. "Roughly what's in my draw pile" is
useful; "this exact slot of the draw pile is Strike+" is not, because
the pile is shuffled. A separate smaller table gave the network a
distinct representation for cards-as-context versus cards-as-actions,
and let the two representations evolve independently. If I'd shared
one 16-dim table, the gradient on playing a card and the gradient on
pooling-over-pile would have fought each other.
The padding_idx=0 argument on each table is the detail
that made masking almost free. Every table's zero-th slot is pinned
to a zero vector, never gradient-updated. Empty hand slots, empty
enemy slots, pile entries past the observed count — all of them map
to index 0, produce a zero vector, and contribute nothing to the
downstream concat. The network learns to read "all zeros" as "nothing
is here" without me writing a masking branch into the trunk.
Unordered collections go through mean-pooling before they hit the trunk. The draw pile, discard pile, and exhaust pile each mean-pool their card embeddings to 8 dims. Relics mean-pool across all equipped relics. Player powers mean-pool across active powers. Mean-pooling isn't the most expressive operator — attention would let the trunk weight pile cards by relevance — but it matches what the game actually exposes: unordered sets with no positional meaning. Encoding "slot 0 of the draw pile" as a specific representation would be baking in a lie, because slot 0 is wherever the shuffle put the next card.
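Continuing the toy table above, a minimal sketch of the pooling step. Whether the real code divides by the observed count or by the padded length is my assumption; here I divide by the non-empty count.
pile_idx = torch.tensor([[3, 5, 2, 0, 0, 0]])                    # [B, max_pile], 0 = empty slot
pile_vecs = emb(pile_idx)                                        # [B, max_pile, 4], zeros at the padding
count = (pile_idx != 0).sum(dim=1, keepdim=True).clamp(min=1)    # cards actually observed
pile_summary = pile_vecs.sum(dim=1) / count                      # [B, 4], mean over observed cards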
Trunk
self.trunk = nn.Sequential(
nn.Linear(input_dim, 256),
nn.LayerNorm(256),
nn.ReLU(),
nn.Linear(256, 256),
nn.LayerNorm(256),
nn.ReLU(),
)
Two layers of 256. Linear, LayerNorm, ReLU. Nothing else.
I want to be honest about how little of this network's design is in the trunk. Two of my iterations touched it — once to try three layers instead of two, once to try a 512-dim hidden size — and both either trained no better or trained slightly worse while doubling compute. I kept 256 and moved on. The trunk is smaller than the 583-dim input that flows into it, which reliably surprises people when they see the diagram; the intuition is that the first hidden layer should be at least as wide as the input. But the input isn't dense information. It's a concatenation of learned embeddings plus a handful of scalar features, and 256 dimensions is enough to hold a useful mixing of those.
LayerNorm instead of BatchNorm is a deliberate call specific to this project's actor-learner setup. The actor proxies run the network one state at a time to produce actions during rollout, which means inference happens with batch size 1. BatchNorm's running statistics would drift between the batched training pass on the learner and the single-sample inference on the actor — the network would train on one normalization and deploy under another. LayerNorm normalizes each sample independently and doesn't care whether the batch is 1 or 1024, so the actor and learner see identical forward passes.
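The property that matters is easy to check directly (sizes here are arbitrary): the same sample, normalized alone or inside a big batch, gets an identical LayerNorm output.
ln = nn.LayerNorm(256)
batch = torch.randn(1024, 256)
out_batched = ln(batch)[:1]                       # sample 0, normalized as part of a 1024-sample batch
out_single = ln(batch[:1])                        # the same sample, normalized at batch size 1
print(torch.allclose(out_batched, out_single))    # True: LayerNorm never looks across the batch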
Orthogonal initialization
(nn.init.orthogonal_(m.weight, gain=sqrt(2))) is applied
to every nn.Linear. Orthogonal init is a PPO convention
— the theory is that it preserves gradient magnitudes better than
Xavier/He for the ReLU stack, and in practice the training curves
noticeably differ with and without it. The policy head gets a much smaller gain
(0.01) so the initial action distribution is
near-uniform instead of peaked at some arbitrary action. Without
that override, the first few hundred episodes punish the agent for
preferences the init baked in, which is both a waste of rollouts
and a source of confusing early-training regressions.
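A sketch of how that initialization pass typically looks; the loop structure and the zero biases are my guesses at the pattern, not the project's exact code.
import math

def init_weights(model, policy_head):
    for m in model.modules():
        if isinstance(m, nn.Linear):
            # small gain on the policy head keeps the initial action distribution near-uniform
            gain = 0.01 if m is policy_head else math.sqrt(2)
            nn.init.orthogonal_(m.weight, gain=gain)
            nn.init.zeros_(m.bias)    # zero biases: an assumption, but a common PPO default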
Policy head + masked softmax
logits = self.policy_head(h) # [B, 61]
logits = logits.masked_fill(~enc.action_mask, -1e9) # illegal → -1e9, effectively -∞
The action space is end_turn plus
play(card_idx, target_idx) for 10 hand slots × 6 possible
targets, for 61 actions total. The policy head is one
nn.Linear(256, 61), and the network produces a
categorical distribution over all 61 slots — then masks away the
illegal ones before sampling.
Most states have nowhere near 61 legal actions. A typical mid-turn
state has three to seven cards in hand, each legal against one to
three enemies, plus end_turn as a baseline. Everything
else is illegal — the slot has no card, the card costs too much
energy, the card targets an enemy type that doesn't match. The
encoder walks the game's legal-action list and builds a Boolean mask
of length 61; the forward pass applies
masked_fill(~mask, -1e9) to the logits before they hit
softmax.
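A sketch of that path with a made-up legal-action list; the action indexing (0 for end_turn, then 1 + card_idx * 6 + target_idx) is purely illustrative, not necessarily the project's layout.
legal = [(0, 0), (0, 1), (3, 0)]                      # (card_idx, target_idx) pairs from the game
action_mask = torch.zeros(61, dtype=torch.bool)
action_mask[0] = True                                 # end_turn is always available
for card_idx, target_idx in legal:
    action_mask[1 + card_idx * 6 + target_idx] = True

logits = torch.randn(61)                              # stand-in for self.policy_head(h)
masked = logits.masked_fill(~action_mask, -1e9)
action = torch.distributions.Categorical(logits=masked).sample()   # illegal slots underflow to probability 0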
-1e9 is not float('-inf'), and I learned
the distinction the hard way around training episode 400 the second
time I ran training. float('-inf') works at forward
time — the softmax assigns zero probability to illegal actions,
perfect. It breaks at backward time: any two legal actions with
identical logits produce a logsumexp that differences
two -infs and yields NaN gradients. -1e9
is finite, survives the backward pass, and drives illegal actions
to effectively zero probability after softmax (the exponential simply
underflows), while keeping a comfortable numerical margin.
The masked logits are also what the PPO surrogate loss differentiates
through. log_softmax(logits) with the mask applied
zeroes out the contributions of illegal actions, which is the
correct treatment — a policy update shouldn't reward or penalize
actions that couldn't have been taken.
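A sketch of how those masked logits feed the PPO surrogate, assuming the actions, old log-probs, and advantages were saved from rollout; the variable names and the 0.2 clip range are placeholders and common defaults, not numbers from this project.
log_probs = torch.log_softmax(logits, dim=-1)                     # masked logits from above, [B, 61]
new_logp = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)   # log-prob of each action actually taken
ratio = torch.exp(new_logp - old_logp)                            # old_logp saved by the actor at rollout time
clipped = torch.clamp(ratio, 1 - 0.2, 1 + 0.2)                    # 0.2 clip range is an assumption here
policy_loss = -torch.min(ratio * advantages, clipped * advantages).mean()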
I considered a variable-length action head — some kind of pointer network or masked attention over the current legal-action list — and decided against it. At 61 actions the fixed-width linear is 16K parameters, the masking is three lines, and the rest of the PPO pipeline doesn't need to special-case anything. If the action space grew by an order of magnitude, or if legal-action sets had internal structure a pointer network could exploit, I'd reconsider. At this scale the simple version wins.
Value head
value = self.value_head(h) # [B, 1]
One linear, 256 to 1, 257 parameters including bias. That's the whole head.
PPO's advantage computation needs a value estimate to subtract as a
baseline — the advantage is return - V(state), and the
policy update scales by advantage. The value estimate doesn't have
to be accurate in absolute terms; it has to correlate with return
well enough that the baseline reduces variance in the policy
gradient. One linear layer on top of the 256-d trunk is enough for
that, and the value loss trains against the discounted return with
standard MSE. Both the policy loss and the value loss backpropagate
through the shared trunk, which is the main reason the trunk learns
useful features — the value signal is denser per step than the
policy signal and keeps the representation well-calibrated.
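Putting the two losses together, a sketch continuing the placeholder names above; the 0.5 value-loss coefficient is a common default, not a number from the source.
h = self.trunk(x)                                          # shared features, [B, 256]
value = self.value_head(h).squeeze(-1)                     # [B]
value_loss = torch.nn.functional.mse_loss(value, returns)  # MSE against the discounted returns
loss = policy_loss + 0.5 * value_loss                      # both terms backpropagate through the shared trunk
loss.backward()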
The discard head is an architectural footnote:
self.discard_head = nn.Linear(256, 10), producing one
score per hand slot; sorted descending, the ten scores become the
agent's "discard this first" priority list. It fires only when STS2 forces a
mid-turn discard. I considered folding the decision into the main
policy head, but "pick one action out of 61" and "rank ten slots"
are different shapes; giving each its own head kept each loss
attributed to one decision shape, and kept the gradients specific
to that shape.
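For completeness, a sketch of how the priority list falls out of those scores; how ties and empty hand slots are handled is a detail I'm guessing at.
discard_scores = self.discard_head(h)                        # [B, 10], one score per hand slot
priority = discard_scores.argsort(dim=-1, descending=True)   # hand-slot indices in "discard this first" order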
Takeaways
A separate pile_emb from card_emb. Cards-as-actions and cards-as-context are different jobs.
masked_fill(~mask, -1e9) before softmax. The finite margin survives the backward pass without NaN.
gain=0.01 on the policy head. Starts near-uniform, not peaked on an arbitrary action.