Wee Sun Lee
School of Computing, National University of Singapore
[email protected]
Neural Information Processing Systems, December 2021
Message Passing
Machine learning using a distributed algorithm on a graph G = (V, E). Each node v_i ∈ V does local computation that depends only on information at v_i, its neighbours, and its incident edges e_ij ∈ E. Information is sent through e_ij to v_j as a message m_ij.
[Figure: nodes v_i, v_j, v_k; message m_ij sent along edge e_ij]
Why Message Passing Algorithms?
Effective in practice!
• Graph neural networks
• Transformer
• Probabilistic graphical model inference
• Value iteration for Markov decision processes
• …
Potentially easier to parallelize.
Plan for Tutorial
Cover better understood, more "interpretable" models and algorithms:
• Probabilistic graphical models (PGM), inference algorithms
• Markov decision process (MDP), decision algorithms
What do the components of the models represent? What objective functions are the algorithms optimizing for?
Discuss more flexible, less "interpretable" methods:
• Graph neural networks
• Attention networks, transformer
Connect to PGM and MDP to help understand the inductive biases.
Probabilistic Graphical Model
Focus on Markov random fields: n random variables {X_1, X_2, …, X_n},

P(X) = P(X_1 = x_1, X_2 = x_2, …, X_n = x_n)

factors into a product of functions

P(X) = (1/Z) ∏_a ψ_a(x_a)

• M non-negative compatibility or potential functions ψ_1, ψ_2, …, ψ_M
• x_a, the argument of ψ_a, is a subset of x_1, x_2, …, x_n
• Z, the partition function, is the normalizing constant
Graphical representation using a factor graph: a bipartite graph in which each factor node is connected to the variable nodes that it depends on, e.g.

P(X) = (1/Z) ψ_A(x_1, x_2) ψ_B(x_1, x_3, x_4) ψ_C(x_4)

[Figure: factor graph with variable nodes x_1, x_2, x_3, x_4 and factor nodes A, B, C]
When the compatibility functions are always positive, we can write

P(X) = (1/Z) ∏_a ψ_a(x_a) = (1/Z) e^{−E(X)}

where

E(X) = −∑_a ln ψ_a(x_a)

is the energy function.
Error Correcting Codes
[Fig from http://www.inference.org.uk/mackay/codes/gifs/]
[Figure: factor graph with codeword bits x_1, …, x_4, parity-check factors A, B, C, and observations y_1, …, y_4; the x_i form the codeword and the parity bits]
y_i: observed, corrupted version of x_i

ψ_A(x_1, x_2, x_4) = 1 if x_1 ⊕ x_2 ⊕ x_4 = 1, and 0 otherwise
MAP, Marginals, and Partition Function
Maximum a posteriori (MAP): find a state X that maximizes P(X)
• Equivalently, minimizes the energy E(X)
Marginal: probabilities for an individual variable

P_i(x_i) = ∑_{X \ x_i} P(X)

Partition function: compute the normalizing constant

Z = ∑_X ∏_a ψ_a(x_a)
Semantic Segmentation
In a conditional random field (CRF), the conditional distribution P(X|I) = (1/Z(I)) e^{−E(X|I)} is modeled.
• For semantic segmentation, I is the image and X is the semantic class labelling.

E(X|I) = ∑_i ψ_u(x_i|I) + ∑_{i<j} ψ_p(x_i, x_j|I)

[Figs from Zheng et al. 2015]
Message passing algorithm
Sum-product computes the marginals:

μ_{i→a}(x_i) = ∏_{c ∈ N(i)\a} μ_{c→i}(x_i)
μ_{a→i}(x_i) = ∑_{x_a \ x_i} ψ_a(x_a) ∏_{j ∈ N(a)\i} μ_{j→a}(x_j)
P_i(x_i) ∝ ∏_{a ∈ N(i)} μ_{a→i}(x_i)

Max-product solves the MAP problem: just replace the sum with max in the message passing.
Works exactly on trees (dynamic programming).
Belief Propagation
[Figure: the example factor graph with variable nodes x_1, …, x_4 and factor nodes A, B, C, redrawn as a tree with diameter D]
Correctness:
Every variable and factor node computes messages in parallel at each iteration.
• After O(D) iterations, where D is the diameter of the tree, all messages and all marginals are correct.
It suffices to pass messages from the leaves to the root and back.
• More efficient for serial computation
Belief Propagation on Trees
Loopy Belief Propagation
Belief propagation can also be applied to general probabilistic graphical models, where it is often called loopy belief propagation. As a message passing algorithm:

Init all messages μ_{i→a}, μ_{a→i} to all-one vectors
repeat T iterations:
  for each variable i and factor a compute (in parallel):
    μ_{i→a}(x_i) = ∏_{c ∈ N(i)\a} μ_{c→i}(x_i) for each a ∈ N(i)
    μ_{a→i}(x_i) = ∑_{x_a \ x_i} ψ_a(x_a) ∏_{j ∈ N(a)\i} μ_{j→a}(x_j) for each i ∈ N(a)
return P_i(x_i) = (1/Z_i) ∏_{a ∈ N(i)} μ_{a→i}(x_i) for each i

Belief propagation may fail when there are cycles:
• May not even converge
• Often works well in practice when it converges
Variational inference ideas help understand loopy belief propagation.
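As a concrete illustration (not from the slides), here is a minimal numpy sketch of loopy belief propagation with tabular factors; all names are illustrative, and on a tree it reproduces the exact marginals:

```python
import numpy as np

def loopy_bp(factors, n_vars, n_states=2, n_iters=20):
    """factors: list of (scope, table); scope is a tuple of variable indices,
    table has shape (n_states,) * len(scope). Returns the beliefs P_i(x_i)."""
    # Messages mu_{i->a} (v2f) and mu_{a->i} (f2v), init to all-one vectors.
    v2f = {(i, a): np.ones(n_states)
           for a, (scope, _) in enumerate(factors) for i in scope}
    f2v = {(a, i): np.ones(n_states)
           for a, (scope, _) in enumerate(factors) for i in scope}
    for _ in range(n_iters):
        # Variable-to-factor: product of incoming factor messages, except from a.
        new_v2f = {}
        for (i, a) in v2f:
            msg = np.ones(n_states)
            for (b, j), m in f2v.items():
                if j == i and b != a:
                    msg = msg * m
            new_v2f[(i, a)] = msg / msg.sum()          # normalize for stability
        # Factor-to-variable: multiply in messages, then sum out x_a \ x_i.
        new_f2v = {}
        for (a, i) in f2v:
            scope, table = factors[a]
            t = table.astype(float).copy()
            for pos, j in enumerate(scope):
                if j != i:
                    shape = [1] * len(scope)
                    shape[pos] = n_states
                    t = t * new_v2f[(j, a)].reshape(shape)
            axes = tuple(p for p, j in enumerate(scope) if j != i)
            msg = t.sum(axis=axes)
            new_f2v[(a, i)] = msg / msg.sum()
        v2f, f2v = new_v2f, new_f2v
    beliefs = []
    for i in range(n_vars):
        b = np.ones(n_states)
        for (a, j), m in f2v.items():
            if j == i:
                b = b * m
        beliefs.append(b / b.sum())
    return beliefs

# On a tree (a single factor over x_0, x_1) the beliefs are exact marginals.
beliefs = loopy_bp([((0, 1), np.array([[1., 2.], [3., 4.]]))], n_vars=2)
```

On this tree, the returned beliefs match the brute-force marginals of the single factor.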
Variational Principles
View message passing algorithms through the lens of variational principles.
• A variational principle solves a problem by viewing the solution as an extremum (maximum, minimum, saddle point) of a function or functional.
• To understand or "interpret" an algorithm, ask: "what objective might the algorithm implicitly be optimizing for?"

In standard variational inference, we approximate the Helmholtz free energy

F_H = −ln Z,  Z = ∑_X e^{−E(X)},  P(X) = e^{−E(X)} / Z

by turning it into an optimization problem.

For target belief P and arbitrary belief Q,

F_H = F(Q) − KL(Q || P)

where F(Q) is the variational free energy

F(Q) = ∑_X Q(X) E(X) + ∑_X Q(X) ln Q(X)

and KL(Q || P) is the Kullback–Leibler divergence between Q and P:

KL(Q || P) = ∑_X Q(X) ln (Q(X) / P(X))
Variational Inference
Derivation:

ln P(x) = −E(x) − ln Z

F_H = −ln Z = ∑_x Q(x) E(x) + ∑_x Q(x) ln P(x)
    = ∑_x Q(x) E(x) + ∑_x Q(x) ln Q(x) + ∑_x Q(x) ln P(x) − ∑_x Q(x) ln Q(x)
    = ∑_x Q(x) E(x) + ∑_x Q(x) ln Q(x) − ∑_x Q(x) ln (Q(x)/P(x))
    = F(Q) − KL(Q || P)
Terminology
The variational free energy

F(Q) = ∑_X Q(X) E(X) + ∑_X Q(X) ln Q(X) = U(Q) − H(Q)

where U(Q) = ∑_X Q(X) E(X) is the variational average energy and H(Q) = −∑_X Q(X) ln Q(X) is the variational entropy.

KL(Q || P) ≥ 0, and it is zero when Q = P. From F_H = F(Q) − KL(Q || P), F(Q) is an upper bound for F_H.
• Minimizing F(Q) improves the approximation; exact when Q = P.
Minimizing F(Q) is intractable in general.
• One approximate method is to use a tractable Q.
• Mean field uses a factorized belief:

Q_MF(X) = ∏_{i=1}^{n} Q_i(x_i)
Mean Field
Mean field is often solved by coordinate descent.
• Optimize one variable at a time, holding the other variables constant:

Q_i(x_i) = (1/Z_i) exp( −∑_{X \ x_i} ∏_{j ≠ i} Q_j(x_j) E(X) )

Coordinate descent converges to a local optimum.
• A local optimum is a fixed point of the updates for all variables.
• Parallel updates can also be done but may not always converge.
Derivation
Mean field as message passing. Recall E(X) = −∑_a ln ψ_a(x_a), so

Q_i(x_i) = (1/Z_i) exp( ∑_{X \ x_i} ∏_{j ≠ i} Q_j(x_j) ∑_a ln ψ_a(x_a) )
         = (1/Z_i) exp( ∑_{a ∈ N(i)} ∑_{X \ x_i} ∏_{j ≠ i} Q_j(x_j) ln ψ_a(x_a)
                        + ∑_{a ∉ N(i)} ∑_{X \ x_i} ∏_{j ≠ i} Q_j(x_j) ln ψ_a(x_a) )

The second sum does not depend on x_i, so it is a constant. To compute Q_i(x_i), we only need ψ_a(x_a) for the neighbouring factors a ∈ N(i).

Previously:

Q_i(x_i) = (1/Z_i) exp( −∑_{X \ x_i} ∏_{j ≠ i} Q_j(x_j) E(X) )

As a message passing algorithm on a factor graph:

repeat T iterations:
  for each variable i compute (serially or in parallel):
    m_{a→i}(x_i) = ∑_{x_a \ x_i} ∏_{j ∈ N(a)\i} Q_j(x_j) ln ψ_a(x_a) for a ∈ N(i), in parallel
    Q_i(x_i) = (1/Z_i) exp( ∑_{a ∈ N(i)} m_{a→i}(x_i) )
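To make the update concrete, a small numpy sketch of the mean field coordinate update for a model with unary and pairwise log-potentials; this restriction to pairwise factors and all variable/function names are illustrative, not from the slides:

```python
import numpy as np

def mean_field(log_unary, log_pairwise, edges, n_iters=50):
    """log_unary[i]: length-2 vector ln psi_i(x_i).
    log_pairwise[(i, j)]: 2x2 matrix ln psi_ij(x_i, x_j) for (i, j) in edges."""
    n = len(log_unary)
    Q = np.full((n, 2), 0.5)                # factorized beliefs Q_i(x_i)
    for _ in range(n_iters):
        for i in range(n):                  # coordinate update, one variable at a time
            s = log_unary[i].copy()         # message from a unary factor is just ln psi_i
            for (a, b) in edges:            # m_{a->i}(x_i) = sum_j Q_j(x_j) ln psi(x_i, x_j)
                if a == i:
                    s = s + log_pairwise[(a, b)] @ Q[b]
                elif b == i:
                    s = s + log_pairwise[(a, b)].T @ Q[a]
            e = np.exp(s - s.max())         # Q_i(x_i) = exp(sum of messages) / Z_i
            Q[i] = e / e.sum()
    return Q

# One variable with psi(x) = (1, 3): the fixed point is simply (0.25, 0.75).
Q = mean_field([np.array([0.0, np.log(3.0)])], {}, [])
```

With edges present, the inner loop is exactly the m_{a→i} message followed by the exp-and-normalize update above.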
Loopy Belief Propagation and Bethe Free Energy1
For a tree-structured factor graph, the variational free energy

F(Q) = ∑_X Q(X) E(X) + ∑_X Q(X) ln Q(X)
     = −∑_{a=1}^{M} ∑_{x_a} Q_a(x_a) ln ψ_a(x_a)
       + ∑_{a=1}^{M} ∑_{x_a} Q_a(x_a) ln ( Q_a(x_a) / ∏_{i ∈ N(a)} Q_i(x_i) )
       + ∑_{i=1}^{n} ∑_{x_i} Q_i(x_i) ln Q_i(x_i)

1 Yedidia, Jonathan S., William T. Freeman, and Yair Weiss. "Constructing free-energy approximations and generalized belief propagation algorithms." IEEE Transactions on Information Theory 51.7 (2005): 2282-2312.
Derivation
(The first term is the variational average energy; the remaining terms give the negative variational entropy.)
For the Bethe approximation, the following Bethe free energy F_Bethe(Q) = U_Bethe(Q) − H_Bethe(Q) is used even though the graph may not be a tree, where

U_Bethe(Q) = −∑_{a=1}^{M} ∑_{x_a} Q_a(x_a) ln ψ_a(x_a)

H_Bethe(Q) = −∑_{a=1}^{M} ∑_{x_a} Q_a(x_a) ln ( Q_a(x_a) / ∏_{i ∈ N(a)} Q_i(x_i) )
             − ∑_{i=1}^{n} ∑_{x_i} Q_i(x_i) ln Q_i(x_i)
For a tree-structured factor graph, this equals the variational free energy F(Q) above. In addition, we impose the constraints:
• ∑_{x_i} Q_i(x_i) = ∑_{x_a} Q_a(x_a) = 1
• Q_i(x_i) ≥ 0, Q_a(x_a) ≥ 0
• ∑_{x_a \ x_i} Q_a(x_a) = Q_i(x_i)
What does loopy belief propagation optimize?
The loopy belief propagation equations give the stationary points of the constrained Bethe free energy.
Derivation
We only specify the factor marginals Q_a(x_a) and the variable marginals Q_i(x_i).
• There may be no distribution Q whose marginals agree with Q_a(x_a).
• They are therefore often called pseudomarginals instead of marginals.
Furthermore, the Bethe entropy H_Bethe(Q) is only an approximation of the variational entropy when the graph is not a tree.
[Figure from Wainwright and Jordan 2008: the set of marginals from valid probability distributions M(G) is a strict subset of the set of pseudomarginals L(G).]
Variational Inference Methods
Mean field minimizes the variational free energy F(Q) = U(Q) − H(Q).
• Assumes a fully factorized Q for tractability
• Can be extended to other tractable Q: structured mean field
• Minimizes an upper bound of the Helmholtz free energy F_H = −ln Z
• Converges to a local optimum if coordinate descent is used; may not converge for parallel updates
• Update equations can be computed as message passing on a graph
Variational Inference Methods
Loopy belief propagation can be viewed as minimizing the Bethe free energy, F_Bethe(Q) = U_Bethe(Q) − H_Bethe(Q), an approximation of the variational free energy.
• May not be an upper bound of F_H
• The resulting Q may not be consistent with any probability distribution
• May not converge, but performance is often good when it converges
• Message passing on a graph; various methods help convergence, e.g. scheduling messages, damping, etc.
• Extends to generalized belief propagation for other region-based free energies, e.g. the Kikuchi free energy
Other commonly found variational inference message passing methods include expectation propagation, as well as max-product linear programming relaxations for finding MAP approximations.
Parameter Estimation
Learn parameterized compatibility functions or components of the energy function E(X|θ) = −∑_a ln ψ_a(x_a|θ_a).
• Can do maximum likelihood estimation
• If some variables are not observed, can use the EM algorithm
• If inference is intractable, a variational approximation for estimating the latent variables is one approach: variational EM
• With the mean field approximation, this maximizes a lower bound of the likelihood function
• Can also treat the parameters as latent variables: variational Bayes
For this tutorial, we focus on unrolling the message passing algorithm into a deep neural network and doing end-to-end learning (later).
Markov Decision Process
A Markov Decision Process (MDP) is defined by (S, A, T, R).
State S: the current description of the world.
• Markov: the past is irrelevant once we know the state
• Navigation example: position of the robot
Robot navigation
MDP (S, A, T, R). Actions A: the set of available actions.
• Navigation example: Move North, Move South, Move East, Move West
Robot navigation
MDP (S, A, T, R). Transition function T:
• T(s, a, s′) = P(s′|s, a)
• Navigation example: darker shade, higher probability
Robot navigation
MDP (S, A, T, R). Reward function R: the reward received when action a in state s results in a transition to state s′.
• R(s, a, s′)
• Navigation example: 100 if s′ is Home, −100 if s′ is in the danger zone, −1 otherwise
• Can be a function of a subset of (s, a, s′), as in the navigation example
Robot navigation
Example of a three-state, two-action Markov Decision Process (S, A, T, R).
• Transitions can be sparse, as in the navigation example
[Fig by waldoalvarez CC BY-SA 4.0 ]
MDP (S, A, T, R). Policy π: a function from state and time step to action.
• a = π(s, t)
• Navigation example: which direction to move at the current location at step t
Robot navigation
Robot navigation
MDP (S, A, T, R). Value function V^π: how good is a policy π when started from state s?
• V^π(s_0) = ∑_{t=0}^{T−1} E[R(s_t, π(s_t, t), s_{t+1})]
• T is the horizon
• When the horizon is infinite, usually use discounted reward:
  V^π(s) = ∑_{t=0}^{∞} E[γ^t R(s_t, π(s_t, t), s_{t+1})]
• γ ∈ (0, 1) is the discount factor
MDP (S, A, T, R). Optimal policy π*: the policy that maximizes

V^π(s_0) = ∑_{t=0}^{T−1} E[R(s_t, π(s_t, t), s_{t+1})]

• For an infinite horizon, discounted reward MDP, the optimal policy π*(s) is stationary (independent of t)
• Optimal value V(s) = V^{π*}(s): the value corresponding to the optimal policy
Robot navigation
Value Iteration Algorithm
Dynamic programming algorithm for solving MDPs. Let V(s, h) denote the optimal value at s when the horizon is h, initialized with V(s, 0) = v_0. Then

V(s, h) = max_a E[R(s, a, s′) + V(s′, h−1)]
        = max_a ∑_{s′} P(s′|s, a) (R(s, a, s′) + V(s′, h−1))

As message passing on a graph:
• A node at each state s, initialized to V(s, 0) = v_0
• Utilize |A| "heads", one for each action a
• repeat h iterations:
    for each action a of each state s (in parallel):
      Collect P(s′|s, a)(R(s, a, s′) + V(s′)) from all s′ to s at a, if P(s′|s, a) is non-zero
      Sum all messages
    for each node s (in parallel):
      Collect the messages from its corresponding actions a
      Take the maximum of the messages
[Fig by waldoalvarez CC BY-SA 4.0]
V(s, h) = max_a ∑_{s′} P(s′|s, a) (R(s, a, s′) + V(s′, h−1))
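A minimal numpy sketch of finite-horizon value iteration in this form (illustrative names; each row of the Q array plays the role of one action "head"):

```python
import numpy as np

def value_iteration(T, R, horizon, v0=None):
    """T[a, s, t] = P(t | s, a); R[a, s, t] = reward; returns V(s, horizon)."""
    n_actions, n_states, _ = T.shape
    V = np.zeros(n_states) if v0 is None else v0.astype(float).copy()  # V(s, 0)
    for _ in range(horizon):
        # One message per action "head": expected reward plus next-state value.
        Q = np.einsum('ast,ast->as', T, R + V[None, None, :])
        V = Q.max(axis=0)                     # each state takes the max over its heads
    return V

# Tiny chain: state 0 moves to state 1 (reward 1); state 1 is absorbing (reward 0).
T = np.array([[[0., 1.], [0., 1.]]])
R = np.array([[[0., 1.], [0., 0.]]])
V = value_iteration(T, R, horizon=3)          # V = [1, 0]
```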
Robot navigation
Shortest path example:
• Deterministic transitions (only one next state with probability 1 for each action)
• Initialize V(s, 0) = 0 for the goal state, and −∞ for the other states
• Self-loop with 0 reward at the goal state for all actions
• Reward −w_{ij} for moving from i to j; −∞ if there is no edge between the two nodes
• Value iteration is then the Bellman-Ford shortest path algorithm
After h iterations, the value at each node is the negative length of the shortest path from the node to the goal among paths of at most h steps.
Robot navigation
[Figures: grid values at initialization (0 at the goal, −∞ elsewhere), after 1 iteration (−1 next to the goal), and after 2 iterations (−1 and −2). After h iterations, the value at each node is the negative length of the shortest path to the goal reachable within h steps.]
For the infinite horizon discounted case, the dynamic programming equation (Bellman equation) is

V(s) = max_a E[R(s, a, s′) + γ V(s′)]
     = max_a ∑_{s′} P(s′|s, a) (R(s, a, s′) + γ V(s′))

The same value iteration algorithm applies with the message changed to P(s′|s, a)(R(s, a, s′) + γ V(s′)).
It converges to the optimal value function.
Convergence of Value Iteration
Derivation
The optimal value function V(s) satisfies Bellman's principle of optimality:

V(s) = max_a ∑_{s′} P(s′|s, a) (R(s, a, s′) + γ V(s′))

The Bellman update B in value iteration transforms V_t into V_{t+1} as follows:

V_{t+1}(s) = max_a ∑_{s′} P(s′|s, a) (R(s, a, s′) + γ V_t(s′))

We denote this as V_{t+1} = B V_t. The optimal value function is a fixed point of this operator: V = BV. The Bellman update is a contraction (in the max norm), i.e.

||B V_1 − B V_2|| ≤ γ ||V_1 − V_2||

From Bellman's equation, we have V = BV for the optimal value function V. Applying V_t = B V_{t−1} repeatedly and using the contraction property ||B V_1 − B V_2|| ≤ γ ||V_1 − V_2||, we have

||V_t − V|| = ||B V_{t−1} − B V|| ≤ γ ||V_{t−1} − V|| ≤ γ^t ||V_0 − V||

The distance converges exponentially to 0 for any initial value V_0.
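The contraction argument can be checked numerically; a sketch on a random tabular MDP (sizes and names are illustrative), where each Bellman update shrinks the max-norm distance between two value functions by at least a factor γ:

```python
import numpy as np

rng = np.random.default_rng(0)
n_states, n_actions, gamma = 5, 3, 0.9
T = rng.random((n_actions, n_states, n_states))
T = T / T.sum(axis=2, keepdims=True)          # rows of T[a] are P(.|s, a)
R = rng.random((n_actions, n_states, n_states))

def bellman(V):
    """One Bellman update B: V'(s) = max_a sum_s' P(s'|s,a)(R + gamma V(s'))."""
    return np.einsum('ast,ast->as', T, R + gamma * V[None, None, :]).max(axis=0)

V1, V2 = rng.random(n_states), rng.random(n_states)
for _ in range(10):
    d_before = np.abs(V1 - V2).max()
    V1, V2 = bellman(V1), bellman(V2)
    # Contraction: the max-norm distance shrinks by at least a factor gamma.
    assert np.abs(V1 - V2).max() <= gamma * d_before + 1e-12
```

After 10 updates the two value functions agree to within γ^10 of their initial distance, matching the bound above.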
Graph Neural Networks
• Many effective graph neural networks (GNN): GCN, GIN, …
• Message passing neural networks (MPNN)1: a general formulation of GNNs as message passing algorithms
• Input: G = (V, E), node attribute vectors x_i, i ∈ V, and edge attribute vectors x_ij, (i, j) ∈ E
• Output: a label or value for graph classification or regression, or a label/value for each node in structured prediction

1 Gilmer, Justin, Samuel S. Schoenholz, Patrick F. Riley, Oriol Vinyals, and George E. Dahl. "Neural message passing for quantum chemistry." In International Conference on Machine Learning, pp. 1263-1272. PMLR, 2017.
[Figure: graph with nodes v_1, v_2, v_3, node attributes x_1, x_2, x_3, and edge attributes x_ij]
MPNN pseudocode
Initialization: h_i^0 = x_i for each v_i ∈ V
for layers ℓ = 1, …, L:
  for every edge (i, j) ∈ E (in parallel):
    m_ij^ℓ = MSG_ℓ(h_i^{ℓ−1}, h_j^{ℓ−1}, v_i, v_j, x_ij)
  for every node v_i ∈ V (in parallel):
    h_i^ℓ = UPD_ℓ({m_ji^ℓ : j ∈ N(i)}, h_i^{ℓ−1})
return h_i^L for every v_i ∈ V, or y = READ({h_i^L : v_i ∈ V})

• MSG_ℓ is an arbitrary function, usually a neural net
• UPD_ℓ aggregates the messages from the neighbours (usually with a set function) and combines them with the node embedding
• READ is a set function, for graph classification or regression tasks
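A minimal numpy sketch of one MPNN layer (not from the slides), assuming sum aggregation and single-matrix ReLU networks for MSG and UPD; all weight shapes and names are illustrative:

```python
import numpy as np

def mpnn_layer(h, edges, edge_attr, W_msg, W_upd):
    """h: (n, d) node embeddings; edges: directed pairs (i, j);
    edge_attr[(i, j)]: (d,) edge attributes. One layer with sum aggregation."""
    n, d = h.shape
    agg = np.zeros((n, d))
    for (i, j) in edges:
        z = np.concatenate([h[i], h[j], edge_attr[(i, j)]])
        m = np.maximum(W_msg @ z, 0.0)          # MSG: message m_ij on edge (i, j)
        agg[j] += m                             # aggregate at the receiving node j
    h_new = np.empty_like(h)
    for i in range(n):
        z = np.concatenate([agg[i], h[i]])
        h_new[i] = np.maximum(W_upd @ z, 0.0)   # UPD: combine with the old embedding
    return h_new

rng = np.random.default_rng(0)
h = rng.random((3, 4))
edges = [(0, 1), (1, 2)]
attr = {e: rng.random(4) for e in edges}
out = mpnn_layer(h, edges, attr, rng.random((4, 12)), rng.random((4, 8)))
```

Stacking L such layers, each with its own weights, gives the full MPNN above.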
GNN properties
If the depth and width are large enough, the message functions MSG_ℓ and update functions UPD_ℓ are sufficiently powerful, and nodes can uniquely distinguish each other, then the MPNN is computationally universal1.
• Equivalent to the LOCAL model in distributed algorithms
• Can compute any function computable with respect to the graph and attributes (just send all information, including the graph structure, to a single node, then compute there)
1 Loukas, Andreas. "What graph neural networks cannot learn: depth vs width." ICLR 2020
Notation
Depth: number of layers. Width: largest embedding dimension.

Example:

P = [[0, 0, 1], [1, 0, 0], [0, 1, 0]],  A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

P permutes vertices (1, 2, 3) → (3, 1, 2).

PA swaps the rows:

PA = [[7, 8, 9], [1, 2, 3], [4, 5, 6]]

(PA)P^T swaps the columns after that, with P^T = [[0, 1, 0], [0, 0, 1], [1, 0, 0]]:

(PA)P^T = [[9, 7, 8], [3, 1, 2], [6, 4, 5]]
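The worked example above can be verified directly with numpy:

```python
import numpy as np

P = np.array([[0, 0, 1],
              [1, 0, 0],
              [0, 1, 0]])
A = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]])

PA = P @ A            # rows permuted: [[7, 8, 9], [1, 2, 3], [4, 5, 6]]
PAPt = PA @ P.T       # columns permuted after that: [[9, 7, 8], [3, 1, 2], [6, 4, 5]]

assert (PA == [[7, 8, 9], [1, 2, 3], [4, 5, 6]]).all()
assert (PAPt == [[9, 7, 8], [3, 1, 2], [6, 4, 5]]).all()
```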
When using graph neural networks, we are often interested in permutation invariance and equivariance. Given an adjacency matrix A, a permutation matrix P, and an attribute matrix X containing the attributes x_i on the i-th row:
• Permutation invariance: f(PAP^T, PX) = f(A, X)
• Permutation equivariance: f(PAP^T, PX) = P f(A, X)
If MSG_ℓ does not depend on the node ids v_i, v_j, the MPNN is permutation invariant and equivariant: for any permutation matrix P, MPNN(PAP^T, PX) = P MPNN(A, X).
• However, we lose approximation power if messages do not depend on node ids, since the network then cannot distinguish some graph structures.
• Discrimination power is then at most as powerful as the 1-dimensional Weisfeiler-Lehman (WL) graph isomorphism test, which cannot distinguish certain graphs.
• The graph isomorphism network (GIN)1 is as powerful as 1-WL.
1 Xu, Keyulu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. "How powerful are graph neural networks?" ICLR 2019
Drug Discovery
Molecules are naturally represented as graphs. GNNs are commonly used for predicting properties of molecules, e.g. whether a molecule inhibits certain bacteria.
[Figs on paracetamol by Benjah-bmm27 and Ben Mills are in the public domain]
Algorithmic Alignment1
Bellman-Ford shortest path:

for ℓ = 1, …, L:
  for v ∈ V (in parallel):
    d[ℓ][v] = min_u d[ℓ−1][u] + w(u, v)

Graph neural network:

for layers ℓ = 1, …, L:
  for v ∈ V (in parallel):
    h_v^ℓ = UPD({MLP(h_u^{ℓ−1}, h_v^{ℓ−1}, w(u, v))}, h_v^{ℓ−1})

Sample complexity for learning a GNN is smaller for tasks that the GNN is algorithmically aligned with1.
• d[ℓ][v] is an easy function for the MLP to approximate as a function of d[ℓ−1][u] and w.
• The same MLP is shared by all nodes: a good inductive bias, so low sample complexity.
• In contrast, learning the entire for loop with a single function requires higher sample complexity.

1 Xu, Keyulu, Jingling Li, Mozhi Zhang, Simon S. Du, Ken-ichi Kawarabayashi, and Stefanie Jegelka. "What can neural networks reason about?" ICLR 2020
Alignment of GNN and VI
The value iteration algorithm for MDPs can be represented in network form: the value iteration network (VIN)1.
• GNNs are well aligned with value iteration.

Value iteration:
for ℓ = 1, …, L:
  for v ∈ V (in parallel):
    V(v, ℓ) = max_a ∑_u P(u|a, v) (R(v, a, u) + V(u, ℓ−1))

Graph neural network:
for layers ℓ = 1, …, L:
  for v ∈ V (in parallel):
    h_v^ℓ = UPD({MLP(h_u^{ℓ−1}, h_v^{ℓ−1}, {P(u|a, v), R(v, a, u)})}, h_v^{ℓ−1})

1 Tamar, Aviv, Yi Wu, Garrett Thomas, Sergey Levine, and Pieter Abbeel. "Value iteration networks." NeurIPS 2016.
1 Karkus, Peter, David Hsu, and Wee Sun Lee. "QMDP-net: Deep learning for planning under partial observability." NeurIPS 2017.
Example: robot navigation using a map, using a VIN1.
• The transition P(u|a, v) and reward R(v, a, u) functions may also need to be learned, instead of being provided.
• The message passing structure suggests:
  • representing the transition and reward separately, learned as functions of the image
  • using them as input to the same function at all states

GNNs are well aligned with value iteration and may work well here. The VIN has a stronger inductive bias: it encodes the value iteration equations directly.
• Action heads: for each action, take a weighted sum from the neighbours
• Then take the max over actions
A GNN is potentially more flexible: it is also aligned with other similar algorithms, particularly dynamic programming algorithms.
• May work better if the MDP assumption is not accurate
• Optimization may be easier for some types of architectures1

1 Lee, Lisa, Emilio Parisotto, Devendra Singh Chaplot, Eric Xing, and Ruslan Salakhutdinov. "Gated path planning networks." ICML 2018.
Alignment of GNN and Graphical Model Algorithms
Mean field:
repeat T iterations:
  for each variable i compute (serially or in parallel):
    m_{a→i}(x_i) = ∑_{x_a \ x_i} ∏_{j ∈ N(a)\i} Q_j(x_j) ln ψ_a(x_a) for a ∈ N(i), in parallel
    Q_i(x_i) = (1/Z_i) exp( ∑_{a ∈ N(i)} m_{a→i}(x_i) )

When all potential functions are pairwise, ψ_a(x_a) = ψ_{ij}(x_i, x_j), the message m_{a→i}(x_i) is a function of only Q_j(x_j).
• Can send messages directly from variable to variable, without combining messages at factor nodes.
• Can interpret the node embedding as a feature representation of the belief and learn a mapping from one belief to another as a message function1.
• GNNs are algorithmically well aligned with mean field for pairwise potential functions.
• Similarly well aligned with loopy belief propagation for pairwise potentials.

1 Dai, Hanjun, Bo Dai, and Le Song. "Discriminative embeddings of latent variable models for structured data." ICML 2016.

What about higher order potentials ψ_a(x_a), where x_a consists of n_a variables?
• In a tabular representation, the size of ψ_a(x_a) grows exponentially with n_a, so even loopy belief propagation is not efficient.
• But if ψ_a(x_a) is a low-rank tensor, then loopy belief propagation can be efficiently implemented1,2.
• Represent ψ_a(x_a) as a tensor decomposition (CP decomposition), where R_a is the rank:

ψ_a(x_a) = ∑_{r=1}^{R_a} w_{a,1}^r(x_{a,1}) w_{a,2}^r(x_{a,2}) ⋯ w_{a,n_a}^r(x_{a,n_a})

• The factor graph neural network (FGNN), which passes messages on a factor graph, is well aligned with this.

1 Dupty, Mohammed Haroon, and Wee Sun Lee. "Neuralizing Efficient Higher-order Belief Propagation." arXiv preprint arXiv:2010.09283 (2020).
2 Zhang, Zhen, Fan Wu, and Wee Sun Lee. "Factor graph neural network." NeurIPS 2020.
[Figure: factor graph with variables x_1, …, x_4 and factors A, B, C]
Implement the message functions

μ_{i→a}(x_i) = ∏_{c ∈ N(i)\a} μ_{c→i}(x_i)
μ_{a→i}(x_i) = ∑_{x_a \ x_i} ψ_a(x_a) ∏_{j ∈ N(a)\i} μ_{j→a}(x_j)

using

ψ_a(x_a) = ∑_{r=1}^{R_a} ∏_{j ∈ N(a)} w_{a,j}^r(x_j)

Then

μ_{a→i}(x_i) = ∑_{x_a \ x_i} ∑_{r=1}^{R_a} ∏_{j ∈ N(a)} w_{a,j}^r(x_j) ∏_{j ∈ N(a)\i} μ_{j→a}(x_j)
             = ∑_{r=1}^{R_a} w_{a,i}^r(x_i) ∑_{x_a \ x_i} ∏_{j ∈ N(a)\i} w_{a,j}^r(x_j) μ_{j→a}(x_j)
             = ∑_{r=1}^{R_a} w_{a,i}^r(x_i) ∏_{j ∈ N(a)\i} ∑_{x_j} w_{a,j}^r(x_j) μ_{j→a}(x_j)

Matrix notation:
• μ_{a→i}: vector of length n_i (the number of values of x_i)
• μ_{i→a}: vector of length n_i
• w_{a,i}^r: vector of length n_i
• W_{a,i}: matrix with R_a rows, where each row is (w_{a,i}^r)^T
• ⊙: element-wise multiplication
μ_{i→a}(x_i) = ∏_{c ∈ N(i)\a} μ_{c→i}(x_i)
μ_{a→i}(x_i) = ∑_{r=1}^{R_a} w_{a,i}^r(x_i) ∏_{j ∈ N(a)\i} ∑_{x_j} w_{a,j}^r(x_j) μ_{j→a}(x_j)

In matrix notation:

μ_{a→i} = W_{a,i}^T ⊙_{j ∈ N(a)\i} W_{a,j} μ_{j→a}
μ_{i→a} = ⊙_{c ∈ N(i)\a} μ_{c→i}

To make the factor and variable messages symmetric:

m_{a→i} = ⊙_{j ∈ N(a)\i} W_{a,j} μ_{j→a}
μ_{i→a} = ⊙_{c ∈ N(i)\a} W_{c,i}^T m_{c→i}

Matrix-vector multiplication followed by product aggregation.
The factor graph provides an easy way to specify dependencies, even higher-order ones. Loopy belief propagation can be approximated with the following message passing equations:

μ_{a→i} = ⊙_{j ∈ N(a)\i} W_{a,j} μ_{j→a}
μ_{i→a} = ⊙_{c ∈ N(i)\a} W_{c,i}^T μ_{c→i}

This optimizes the Bethe free energy if it converges.
• Uses a low-rank approximation for the potential functions.
• Increasing the number of rows of W_{a,i} increases the rank of the tensor decomposition approximation.
Neuralizing Loopy Belief Propagation
Alignment shows that a neural network with a small number of parameters can approximate an algorithm, giving smaller sample complexity in learning: analysis. We can also use these ideas in design, by neuralizing the algorithm:
• Start with the network representing the algorithm, to capture the inductive bias.
• Modify the algorithm, e.g. add computational elements to enhance approximation capability, while mostly maintaining the structure to keep the inductive bias.

Start with the algorithm in network form, e.g.

μ_{a→i} = ⊙_{j ∈ N(a)\i} W_{a,j} μ_{j→a}
μ_{i→a} = ⊙_{c ∈ N(i)\a} W_{c,i}^T μ_{c→i}

Add network elements to potentially make the network more powerful, enlarging the class of algorithms that can be learned, e.g.

μ_{a→i} = MLP(⊙_{j ∈ N(a)\i} W_{a,j} μ_{j→a})
μ_{i→a} = MLP(⊙_{c ∈ N(i)\a} W_{c,i}^T μ_{c→i})

It is usually simpler to keep messages only on nodes instead of on edges, simplifying while keeping the message passing structure. We can also change the aggregator, e.g. to sum, max, etc. This works well in practice:

μ_a = MLP(AGG_{j ∈ N(a)} W_{a,j} μ_j)
μ_i = MLP(AGG_{c ∈ N(i)} W_{c,i}^T μ_c)

[Figure: message passing neural net on a factor graph with variables x_1, …, x_4 and factors A, B, C]
Distributed Representation of Graphs and Matrices
A graph can be represented using an adjacency matrix. A matrix A, in turn, can be factorized as A = UV^T. In factorized form, u_i^T v_j = A_ij.
• Entry (i, j) of matrix A is the inner product of row i of matrix U with row j of matrix V.
• Node i of the graph has embeddings u_i, v_i such that the value of edge (i, j) can be computed as u_i^T v_j.
• This is a distributed representation of the graph: the information is distributed to the nodes.

A = UV^T,  A_ij = u_i^T v_j = ∑_k u_ik v_jk

(For example, a 3×3 matrix A can be written as a 3×2 matrix U times a 2×3 matrix V^T.)

With a distributed representation, using a subset of the embeddings u_i, v_i gives a representation of the corresponding subgraph!
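A small numpy sketch of such a distributed representation, using the SVD as one way (an illustrative choice, not prescribed by the slides) to obtain the factorization:

```python
import numpy as np

A = np.array([[0., 1., 0.],
              [1., 0., 1.],
              [0., 1., 0.]])                   # adjacency matrix of a 3-node path

U_, S, Vt = np.linalg.svd(A)
r = np.linalg.matrix_rank(A)                   # here r = 2
U = U_[:, :r] * S[:r]                          # absorb singular values into U
V = Vt[:r, :].T                                # so that A = U V^T

# Node i stores rows u_i and v_i; edge (i, j) is recovered as an inner product.
assert np.allclose(U @ V.T, A)
assert np.isclose(U[0] @ V[1], A[0, 1])
```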
Attention Network
Using matrix factorization, we can show that the transformer-type attention network is well aligned with value iteration for MDPs. In the attention network, we have a set of nodes, each of which has an embedding as input, and the same operations (implemented with a network) are applied at each node in a layer.

[Figure: transformer block, consisting of multi-head attention, a residual add, a feedforward network, and another residual add]

Multi-head attention
Let the embedding at node i be x_i.
• Each node has K attention heads.
• Each attention head has weight matrices: query weights W_Q used to compute the query q_i = W_Q x_i, key weights W_K used to compute the key k_i = W_K x_i, and value weights W_V used to compute the value v_i = W_V x_i.
• Compute e_ij = q_i^T k_j with all other nodes j. Compute the probability vector p_ij for all j using the softmax function p_ij = e^{e_ij} / ∑_{j′} e^{e_ij′}. The output of the head is a weighted average of all the values, ∑_j p_ij v_j.
Output of multi-head attention
• Concatenate the outputs of the K attention heads
• Project back to a vector of the same length
• Pass to a feedforward network through a residual operation (added to x_i)
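A minimal numpy sketch of a single attention head as just described (weight matrices are random placeholders; names are illustrative):

```python
import numpy as np

def attention_head(X, Wq, Wk, Wv):
    """X: (n, d) embeddings; returns the head outputs o_i = sum_j p_ij v_j."""
    Q, K, V = X @ Wq.T, X @ Wk.T, X @ Wv.T    # q_i = Wq x_i, k_i = Wk x_i, v_i = Wv x_i
    E = Q @ K.T                               # e_ij = q_i^T k_j
    E = E - E.max(axis=1, keepdims=True)      # stabilize the softmax
    P = np.exp(E)
    P = P / P.sum(axis=1, keepdims=True)      # p_ij = exp(e_ij) / sum_j' exp(e_ij')
    return P @ V                              # weighted average of the values

rng = np.random.default_rng(0)
X = rng.random((4, 6))
out = attention_head(X, rng.random((3, 6)), rng.random((3, 6)), rng.random((3, 6)))
```

Each output row is a convex combination of the value vectors, which is exactly the structure exploited in the value iteration alignment below.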
Alignment of Attention Network with Value Iteration
Constructing the input x_s:
• Use matrix factorization to get a distributed representation of the log of each transition matrix, L_a = U_a V_a^T for each action a, where L_a[s, s′] = log P(s′|s, a).
• Construct the input for node s, x_s, by stacking up K + 1 copies of the initial value v_s followed by the embeddings of the transition matrices.
• Compute the expected reward r_{a,s} = E[R(s, a, s′)] for each of the K actions.
• Place the expected reward r_{a,s} for each action a into a single vector and concatenate it to the earlier input vector.

[Figure: the input vector for state s, consisting of the rewards r_{K,s}, …, r_{1,s} for each action, the transition-matrix embeddings v_{K,s}, u_{K,s}, …, v_{1,s}, u_{1,s}, and K + 1 copies of the initial value v_s]
We need to extract u_{a,s} and v_{a,s′} to compute P(s′|s, a) = exp(u_{a,s}^T v_{a,s′}) (via the softmax).
• Setting W_Q = [0, I_d, 0], where I_d is a d-by-d identity matrix placed at the appropriate columns, extracts u_{a,s} = W_Q x_s as the query.
• Similarly, we can construct W_K to extract v_{a,s′} as the key, allowing the softmax of the inner product u_{a,s}^T v_{a,s′} to correctly compute the transition probabilities.
• Set W_V to extract the value component from x_s.

With this construction, the output of head a is

o_a = ∑_{s′} P(s′|s, a) v_{s′}

Recall each layer of value iteration:
for each action a of each state s (in parallel):
  Collect P(s′|s, a)(R(s, a, s′) + V(s′)) from all s′ to s at a, if P(s′|s, a) is non-zero
  Sum all messages
for each node s (in parallel):
  Collect the messages from its corresponding actions a
  Take the maximum of the messages
The output of head a is

o_a = ∑_{s′} P(s′|s, a) v_{s′}

Concatenate [o_1, …, o_K] and add to the first K components of x_s to form the input to the feedforward network.
[Figure: transformer block (multi-head attention, add, feedforward, add) with state input x_s; after the residual addition, the first components hold v_s + o_{a,s}]
Feedforward network:
• Subtract v_s from v_s + o_{a,s} to get o_{a,s}.
• Compute v′_s = max_a {r_{a,s} + o_{a,s}}.
• Construct the output so that adding back the input of the feedforward network gives v′_s in the first K + 1 positions and the original MDP parameters in the remaining positions: output v′_s − (v_s + o_{a,s}) as the first K elements and v′_s − v_s as the (K + 1)-st element.
All the methods discussed are message passing methods on a graph.
• Messages are constructed using operations such as addition, multiplication, max, exponentiation, etc.
• They can be represented using a network of computational elements.
• Usually the same network is used at each graph node.
Each iteration of message passing forms a layer; putting the layers together forms a deep network.
• Different layers can have different parameters, for additional flexibility.
With appropriate loss functions, we can learn using gradient descent if all the network elements are differentiable: backpropagation.
Backpropagation and Recurrent Backpropagation
If the different layers all have the same parameters, we have a recurrent neural network. For some recurrent networks, the inputs are the same at each iteration and we are aiming for the solution at convergence:
• Loopy belief propagation
• Value iteration for discounted MDPs
• Some graph neural networks1
• Some transformer architectures2
Recurrent backpropagation and other optimization methods based on implicit functions can be used3.
• Constant memory usage
1 Scarselli, Franco, Marco Gori, Ah Chung Tsoi, Markus Hagenbuchner, and Gabriele Monfardini. "The graph neural network model." IEEE Transactions on Neural Networks, 2008.
2 Bai, Shaojie, J. Zico Kolter, and Vladlen Koltun. "Deep equilibrium models." NeurIPS 2019.
3 Zico Kolter, David Duvenaud, and Matt Johnson. "Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond." NeurIPS 2020 Tutorial. http://implicit-layers-tutorial.org/
Message Passing in Machine Learning
We viewed some "classic" message passing algorithms in machine learning through variational principles:
• Loopy belief propagation
• Mean field
• Value iteration
What optimization problems are they solving?
We related graph neural networks and attention networks to the "classic" message passing algorithms:
• Algorithmic alignment (analysis): can they simulate those algorithms using a small network?
• Neuralizing algorithms (design): can we enhance those algorithms into more powerful neural versions?
References1
• David MacKay's Error Correcting Code demo. http://www.inference.org.uk/mackay/codes/gifs/
• Zheng, Shuai, Sadeep Jayasumana, Bernardino Romera-Paredes, Vibhav Vineet, Zhizhong Su, Dalong Du, Chang Huang, and Philip H. S. Torr. "Conditional random fields as recurrent neural networks." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1529-1537. 2015.
โข Yedidia, Jonathan S., William T. Freeman, and Yair Weiss. "Constructing free-energy approximations and generalized belief propagation algorithms." IEEE Transactions on information theory 51.7 (2005): 2282-2312.
โข Wainwright, Martin J., and Michael Irwin Jordan. Graphical models, exponential families, and variational inference. Now Publishers Inc, 2008.
1 This list contains only references referred to within the slides and not the many other works related to the material in the tutorial.
โข Markov Decision Process (figure) by waldoalvarez - Own work, CC BY-SA 4.0, https://commons.wikimedia.org/w/index.php?curid=59364518
โข Gilmer, Justin, Samuel S. Schoenholz, Patrick F. Riley, Oriol Vinyals, and George E. Dahl. "Neural message passing for quantum chemistry." In International conference on machine learning, pp. 1263-1272. PMLR, 2017.
โข Loukas, Andreas. "What graph neural networks cannot learn: depth vs width.โ ICLR 2020
โข Xu, Keyulu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. "How powerful are graph neural networks?" ICLR 2019
โข Xu, Keyulu, Jingling Li, Mozhi Zhang, Simon S. Du, Ken-ichiKawarabayashi, and Stefanie Jegelka. "What can neural networks reason about?.โ ICLR 2020
โข Skeletal formula of paracetamol by Benjah-bmm27 is in the public domain, https://en.wikipedia.org/wiki/Paracetamol#/media/File:Paracetamol-skeletal.svg
โข Ball and stick model of paracetamol by Ben Mills is in the public domain, https://en.wikipedia.org/wiki/Paracetamol#/media/File:Paracetamol-from-xtal-3D-balls.png
โข Tamar, Aviv, Yi Wu, Garrett Thomas, Sergey Levine, and Pieter Abbeel. "Value iteration networks.โ NeurIPS 2016.
โข Lee, Lisa, Emilio Parisotto, Devendra Singh Chaplot, Eric Xing, and Ruslan Salakhutdinov. "Gated path planning networks.โ ICML 2018.
โข Karkus, Peter, David Hsu, and Wee Sun Lee. "QMDP-net: Deep learning for planning under partial observability." NeurIPS 2017.
โข Dai, Hanjun, Bo Dai, and Le Song. "Discriminative embeddings of latent variable models for structured data." ICML 2016.
โข Dupty, Mohammed Haroon, and Wee Sun Lee. "Neuralizing Efficient Higher-order Belief Propagation." arXiv preprint arXiv:2010.09283 (2020).
โข Zhang, Zhen, Fan Wu, and Wee Sun Lee. "Factor graph neural network." NeurIPS 2020.
โข Scarselli, Franco, Marco Gori, Ah Chung Tsoi, Markus Hagenbuchner, and Gabriele Monfardini. "The graph neural network model." IEEE Transactions on Neural Networks 20, no. 1 (2008): 61-80.
โข Bai, Shaojie, J. Zico Kolter, and Vladlen Koltun. "Deep equilibrium models." NeurIPS 2019.
โข Zico Kolter, David Duvenaud, and Matt Johnson, โDeep Implicit Layers - Neural ODEs, Deep Equilibirum Models, and Beyond.โ NeurIPS 2020 Tutorial http://implicit-layers-tutorial.org/
Correctness of Belief Propagation on Trees
[Figure: subtree $T_i$ hanging off edge $(a, i)$: variable $x_i$ receives message $m_{ai}(x_i)$ from factor $a$, which is connected to variables $x_j, x_k$ below it.]

The messages and beliefs are

$$ n_{ia}(x_i) = \prod_{c \in N(i)\setminus a} m_{ci}(x_i) $$

$$ m_{ai}(x_i) = \sum_{x_a \setminus x_i} f_a(x_a) \prod_{j \in N(a)\setminus i} n_{ja}(x_j) $$

$$ b_i(x_i) \propto \prod_{a \in N(i)} m_{ai}(x_i) $$

Assume $T_i$ is the subtree along edge $(a, i)$, rooted at $x_i$, and message $m_{ai}(x_i)$ is sent from $a$ to $i$. After $h$ iterations of message passing, the message correctly marginalizes away the other variables in the subtree:

$$ m_{ai}(x_i) = \sum_{x' \in T_i \setminus x_i} \prod_{c \in T_i} f_c(x'_c) $$

where $h$ is the height of the subtree.

Proof by induction on $h$. Assume marginalization is correctly done for all subtrees of height less than $h$.
• Push the summation inside the product; because of the tree structure, the terms can be grouped by the subtrees two levels below $x_i$ (those rooted at the neighbours $x_j$ of $a$).
• By the inductive hypothesis, marginalization is correct at those subtrees.
So after $h$ iterations of message passing, marginalization at height $h$ is correct.
Correctness
Initialize all messages to all-one vectors. Then:

[Figure: two snapshots of message passing on the subtree below $x_i$. In the first iteration, $m_{ai}(x_i) = \sum_{x_a \setminus x_i} \mathbf{1}_{x_j} \mathbf{1}_{x_k} \cdots f_a(x_i, x_j, x_k)$, so messages out of height-1 subtrees are already correct. In later iterations, the all-one vectors are replaced by messages such as $m_{cj}(x_j)$ computed at the factors below, so each message becomes correct once its subtree has been correctly marginalized.]
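The correctness argument can be checked numerically. The sketch below is a minimal chain factor graph of my own (variable names like `m_A_x2` are mine, not from the slides): it runs one leaf-to-root and one root-to-leaf pass of sum-product message passing and compares the resulting beliefs with brute-force marginals.

```python
import numpy as np

# Minimal sketch (my own example): sum-product belief propagation on the
# chain factor graph x1 - A - x2 - B - x3, checked against brute-force
# marginalization. All variables are binary.
rng = np.random.default_rng(1)
A = rng.random((2, 2)) + 0.1      # f_A(x1, x2)
B = rng.random((2, 2)) + 0.1      # f_B(x2, x3)

# Leaf-to-root then root-to-leaf passes (exact on a tree).
m_x1_A = np.ones(2)               # variable-to-factor message n_{1A} (leaf)
m_A_x2 = A.T @ m_x1_A             # factor-to-variable: sum out x1
m_x3_B = np.ones(2)               # n_{3B} (leaf)
m_B_x2 = B @ m_x3_B               # sum out x3
m_x2_A = m_B_x2                   # n_{2A}: product of other incoming messages
m_A_x1 = A @ m_x2_A               # sum out x2
m_x2_B = m_A_x2
m_B_x3 = B.T @ m_x2_B

def normalize(v):
    return v / v.sum()

b1 = normalize(m_A_x1)            # belief at x1
b2 = normalize(m_A_x2 * m_B_x2)   # belief at x2: product of incoming messages
b3 = normalize(m_B_x3)

# Brute-force marginals of p(x) proportional to f_A(x1,x2) f_B(x2,x3).
p = np.einsum('ij,jk->ijk', A, B)
p = p / p.sum()
```

On this tree the beliefs match the exact marginals of p, as the induction argument above guarantees.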
Mean Field Update Derivation
In the mean field approximation, we find $q$ that minimizes the variational free energy

$$ F(q) = \sum_x q(x) E(x) + \sum_x q(x) \ln q(x) $$

when $q$ is restricted to the factored form

$$ q(x) = \prod_{i=1}^n q_i(x_i). $$

Consider the dependence on a single variable $q_j(x_j)$ with all other variables fixed:

$$ F(q) = \sum_x \prod_{i=1}^n q_i(x_i)\, E(x) + \sum_x \prod_{i=1}^n q_i(x_i) \ln \prod_{i=1}^n q_i(x_i) $$

$$ = \sum_{x_j} q_j(x_j) \sum_{x \setminus x_j} \prod_{i \ne j} q_i(x_i)\, E(x) + \sum_{x_j} q_j(x_j) \ln q_j(x_j) + \mathit{const} $$

$$ = \sum_{x_j} q_j(x_j) \ln \frac{1}{q^*_j(x_j)} + \sum_{x_j} q_j(x_j) \ln q_j(x_j) + \mathit{const} $$

where

$$ q^*_j(x_j) = \frac{1}{Z'_j} \exp\Big(-\sum_{x \setminus x_j} \prod_{i \ne j} q_i(x_i)\, E(x)\Big) $$

so that

$$ F(q) = \mathrm{KL}(q_j \,\|\, q^*_j) + \mathit{const}, $$

which is minimized at $q_j(x_j) = q^*_j(x_j)$.
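The update $q_j = q^*_j$ can be tested on a toy model. The sketch below uses my own 3-variable Ising-style energy (not from the slides): it performs coordinate updates of the factors $q_j$ and checks that the variational free energy $F(q)$ never increases, since each update exactly minimizes $F$ in that coordinate.

```python
import numpy as np

# Minimal sketch (my own toy setup): coordinate-descent mean field on a
# 3-variable binary MRF with energy
#   E(x) = -sum_{i<j} J[i,j] * s_i * s_j,  s_i = 2*x_i - 1.
# Each update sets q_j(x_j) proportional to exp(-E_{q_{-j}}[E(x) | x_j]).
rng = np.random.default_rng(2)
n = 3
J = np.triu(rng.standard_normal((n, n)), 1)

def energy(x):
    s = 2 * np.asarray(x) - 1
    return -sum(J[i, j] * s[i] * s[j] for i in range(n) for j in range(i + 1, n))

states = [(a, b, c) for a in (0, 1) for b in (0, 1) for c in (0, 1)]

def free_energy(q):
    # F(q) = E_q[E(x)] + sum_i sum_{x_i} q_i ln q_i   (entropy of factored q)
    F = sum(np.prod([q[i][x[i]] for i in range(n)]) * energy(x) for x in states)
    for qi in q:
        F += np.sum(qi * np.log(qi))
    return F

q = [np.full(2, 0.5) for _ in range(n)]
Fs = [free_energy(q)]
for sweep in range(20):
    for j in range(n):
        # Expected energy of each setting of x_j under the other q's.
        e = np.zeros(2)
        for x in states:
            w = np.prod([q[i][x[i]] for i in range(n) if i != j])
            e[x[j]] += w * energy(x)
        qj = np.exp(-e)
        q[j] = qj / qj.sum()        # q_j = q*_j, the minimizer in coordinate j
    Fs.append(free_energy(q))
```

Each coordinate update sets $F$ to its minimum over $q_j$, so the recorded free energies `Fs` form a nonincreasing sequence.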
Variational Free Energy for Trees
We can write a tree distribution as

$$ p(x) = \frac{\prod_a b_a(x_a) \prod_i b_i(x_i)}{\prod_a \prod_{i \in N(a)} b_i(x_i)} $$

[Figure: tree factor graph with variables $x_1, x_2, x_3, x_4$ and factors $A, B, C$.]

To see this, write the tree distribution as

$$ p(x) = b_r(x_r) \prod_a b_a(x_a \mid pa(x_a)) $$

where $x_r$ is the root and $pa(x_a)$ is the parent variable of factor $a$ in the tree. We write

$$ b_a(x_a \mid pa(x_a)) = \frac{b_a(x_a)}{b_{pa}(x_{pa})} $$

Each variable appears as a parent $d - 1$ times, where $d$ is the degree of the variable node, except the root, which appears $d$ times. As $b_r(x_r)$ also appears in the numerator, we can write the distribution as

$$ p(x) = \frac{\prod_a b_a(x_a)}{\prod_i b_i(x_i)^{d_i - 1}} $$

Multiplying the numerator and denominator by $\prod_i b_i(x_i)$ and observing that $\prod_a \prod_{i \in N(a)} b_i(x_i) = \prod_i b_i(x_i)^{d_i}$ gives the required expression.

Substituting this form of $p(x)$ into

$$ F(p) = \sum_x p(x) E(x) + \sum_x p(x) \ln p(x) $$

we get

$$ F(b) = -\sum_{a=1}^M \sum_{x_a} b_a(x_a) \ln f_a(x_a) + \sum_{a=1}^M \sum_{x_a} b_a(x_a) \ln \frac{b_a(x_a)}{\prod_{i \in N(a)} b_i(x_i)} + \sum_{i=1}^N \sum_{x_i} b_i(x_i) \ln b_i(x_i) $$
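As a sanity check (my own toy example, not from the slides): on a tree, the Bethe form of $p(x)$ above is exact, so substituting the exact marginals into the Bethe free energy recovers $-\ln Z$ exactly. The sketch below verifies this on a small chain.

```python
import numpy as np

# Minimal sketch (my own check): on a tree factor graph, the Bethe free
# energy evaluated at the exact marginals equals -ln Z.
# Chain: x1 - A - x2 - B - x3, binary variables; degrees d1 = d3 = 1, d2 = 2.
rng = np.random.default_rng(3)
A = rng.random((2, 2)) + 0.1
B = rng.random((2, 2)) + 0.1

p = np.einsum('ij,jk->ijk', A, B)   # unnormalized p(x1, x2, x3)
Z = p.sum()
p = p / Z

# Exact factor and variable marginals.
bA = p.sum(axis=2)                  # b_A(x1, x2)
bB = p.sum(axis=0)                  # b_B(x2, x3)
b1, b2, b3 = p.sum(axis=(1, 2)), p.sum(axis=(0, 2)), p.sum(axis=(0, 1))

# F_Bethe = sum_a sum_{x_a} b_a ln(b_a / f_a) + sum_i (1 - d_i) sum_{x_i} b_i ln b_i
F = (np.sum(bA * np.log(bA / A)) + np.sum(bB * np.log(bB / B))
     + (1 - 1) * np.sum(b1 * np.log(b1))
     + (1 - 2) * np.sum(b2 * np.log(b2))
     + (1 - 1) * np.sum(b3 * np.log(b3)))
# F agrees with -ln Z because the Bethe entropy is exact on trees.
```

The `(1 - d_i)` form used in the code is the same Bethe free energy as above, obtained by expanding $\ln(b_a / \prod_i b_i)$ and collecting the $b_i \ln b_i$ terms.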
Belief Propagation as Stationary Point of Bethe Free Energy
Recall the Bethe free energy

$$ F_{Bethe}(b) = U_{Bethe}(b) - H_{Bethe}(b) $$

$$ U_{Bethe}(b) = -\sum_{a=1}^M \sum_{x_a} b_a(x_a) \ln f_a(x_a) $$

$$ H_{Bethe}(b) = -\sum_{a=1}^M \sum_{x_a} b_a(x_a) \ln \frac{b_a(x_a)}{\prod_{i \in N(a)} b_i(x_i)} - \sum_{i=1}^N \sum_{x_i} b_i(x_i) \ln b_i(x_i) $$

Consider optimizing the Bethe free energy subject to the constraints described. We form the Lagrangian

$$ L = F_{Bethe} + \sum_a \gamma_a \Big[1 - \sum_{x_a} b_a(x_a)\Big] + \sum_i \gamma_i \Big[1 - \sum_{x_i} b_i(x_i)\Big] + \sum_i \sum_{a \in N(i)} \sum_{x_i} \lambda_{ai}(x_i) \Big[b_i(x_i) - \sum_{x_a \setminus x_i} b_a(x_a)\Big] $$

Differentiating with respect to $b_i(x_i)$ and setting to 0:

$$ 0 = -d_i + 1 + \ln b_i(x_i) - \gamma_i + \sum_{a \in N(i)} \lambda_{ai}(x_i) $$

$$ b_i(x_i) \propto \exp\Big(-\sum_{a \in N(i)} \lambda_{ai}(x_i)\Big) = \prod_{a \in N(i)} m_{ai}(x_i) $$

where $d_i$ is the degree of the variable node and $m_{ai}(x_i) = \exp(-\lambda_{ai}(x_i))$.

Differentiating with respect to $b_a(x_a)$ and setting to 0:

$$ 0 = -\ln f_a(x_a) + 1 + \ln b_a(x_a) - \ln \prod_{i \in N(a)} b_i(x_i) - \gamma_a - \sum_{i \in N(a)} \lambda_{ai}(x_i) $$

Using $\ln b_i(x_i) = -\sum_{a' \in N(i)} \lambda_{a'i}(x_i) + \mathit{const}$ from above,

$$ 0 = -\ln f_a(x_a) + \ln b_a(x_a) + \mathit{const} + \sum_{i \in N(a)} \sum_{a' \in N(i)} \lambda_{a'i}(x_i) - \sum_{i \in N(a)} \lambda_{ai}(x_i) $$

$$ b_a(x_a) \propto f_a(x_a) \exp\Big(-\sum_{i \in N(a)} \sum_{a' \in N(i)\setminus a} \lambda_{a'i}(x_i)\Big) = f_a(x_a) \prod_{i \in N(a)} \prod_{a' \in N(i)\setminus a} m_{a'i}(x_i) $$

where $m_{a'i}(x_i) = \exp(-\lambda_{a'i}(x_i))$.

For $\sum_{x_a \setminus x_i} b_a(x_a)$ to be consistent with $b_i(x_i) = \prod_{c \in N(i)} m_{ci}(x_i)$, we get the belief propagation equation

$$ m_{ai}(x_i) = \sum_{x_a \setminus x_i} f_a(x_a) \prod_{j \in N(a)\setminus i} \prod_{c \in N(j)\setminus a} m_{cj}(x_j) $$

Stationary points of the Bethe free energy are fixed points of the loopy belief propagation updates.
[Fig from Yedidia et al. 2005]
Bellman Update Contraction
We show that the Bellman update $B$ is a contraction:

$$ \|B V_1 - B V_2\|_\infty \le \gamma \|V_1 - V_2\|_\infty $$

First we show that for any two functions $f(a)$ and $g(a)$,

$$ \Big|\max_a f(a) - \max_a g(a)\Big| \le \max_a |f(a) - g(a)| $$

To see this, consider the case where $\max_a f(a) \ge \max_a g(a)$. Then

$$ \Big|\max_a f(a) - \max_a g(a)\Big| = \max_a f(a) - \max_a g(a) = f(a^*) - \max_a g(a) \le f(a^*) - g(a^*) \le \max_a |f(a) - g(a)| $$

where $a^* = \arg\max_a f(a)$. The case where $\max_a g(a) \ge \max_a f(a)$ is similar.

Now, for any state $s$,

$$ |B V_1(s) - B V_2(s)| = \Big|\max_a \sum_{s'} p(s' \mid s, a)\big[r(s, a, s') + \gamma V_1(s')\big] - \max_a \sum_{s'} p(s' \mid s, a)\big[r(s, a, s') + \gamma V_2(s')\big]\Big| $$

$$ \le \max_a \Big|\sum_{s'} p(s' \mid s, a)\big[r(s, a, s') + \gamma V_1(s')\big] - \sum_{s'} p(s' \mid s, a)\big[r(s, a, s') + \gamma V_2(s')\big]\Big| $$

$$ = \gamma \max_a \Big|\sum_{s'} p(s' \mid s, a)\big(V_1(s') - V_2(s')\big)\Big| = \gamma \Big|\sum_{s'} p(s' \mid s, \tilde a)\big(V_1(s') - V_2(s')\big)\Big| \le \gamma \max_{s'} |V_1(s') - V_2(s')| $$

where $\tilde a$ is the maximizing action and we have used $|\max_a f(a) - \max_a g(a)| \le \max_a |f(a) - g(a)|$. Taking the maximum over $s$ gives $\|B V_1 - B V_2\|_\infty \le \gamma \|V_1 - V_2\|_\infty$.
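The contraction inequality can be verified numerically. A minimal sketch, assuming a randomly generated tabular MDP of my own construction (the names `P`, `R`, `bellman` are mine, not from the slides):

```python
import numpy as np

# Minimal sketch (random tabular MDP, my own construction): check the
# contraction property ||B V1 - B V2||_inf <= gamma * ||V1 - V2||_inf.
rng = np.random.default_rng(4)
nS, nA, gamma = 5, 3, 0.9
P = rng.random((nS, nA, nS))
P /= P.sum(axis=2, keepdims=True)       # transition kernel p(s' | s, a)
R = rng.standard_normal((nS, nA, nS))   # reward r(s, a, s')

def bellman(V):
    # (B V)(s) = max_a sum_{s'} p(s'|s,a) [r(s,a,s') + gamma V(s')]
    return np.max(np.einsum('san,san->sa', P, R + gamma * V[None, None, :]),
                  axis=1)

V1 = rng.standard_normal(nS)
V2 = rng.standard_normal(nS)
lhs = np.max(np.abs(bellman(V1) - bellman(V2)))   # ||B V1 - B V2||_inf
rhs = gamma * np.max(np.abs(V1 - V2))             # gamma ||V1 - V2||_inf
```

The inequality `lhs <= rhs` holds for any pair of value functions, which is what makes value iteration converge to the unique fixed point.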