Transcript of: Wee Sun Lee, School of Computing, National University of Singapore, [email protected]. Neural Information Processing Systems, December 2021.

Message Passing

Machine learning using a distributed algorithm on a graph G = (V, E). Each node v_i ∈ V does local computation that depends only on information at v_i, its neighbours, and the incident edges e_ij ∈ E. Information is sent through e_ij to v_j as the message m_ij.

[Figure: nodes v_i, v_j, v_k, with message m_ij passed along edge (i, j)]

Why Message Passing Algorithms?

Effective in practice!
• Graph neural networks
• Transformer
• Probabilistic graphical model inference
• Value iteration for Markov decision process
• ...

Potentially easier to parallelize.

Plan for Tutorial

Cover better understood, more "interpretable" models and algorithms
• Probabilistic graphical models (PGM), inference algorithms
• Markov decision process (MDP), decision algorithms
What do the components of the models represent?
What objective functions are the algorithms optimizing for?

Discuss more flexible, less "interpretable" methods
• Graph neural networks
• Attention networks, transformer

Connect these to PGM and MDP to help understand their inductive biases.

Outline

โ€ข Probabilistic Graphical Models

Probabilistic Graphical Model

Focus on Markov random fields: N random variables {X_1, X_2, …, X_N} whose joint distribution
p(x) = p(X_1 = x_1, X_2 = x_2, …, X_N = x_N)
factors into a product of functions

p(x) = (1/Z) ∏_{a=1}^{M} ψ_a(x_a)

• M non-negative compatibility or potential functions ψ_1, ψ_2, …, ψ_M
• x_a, the argument of ψ_a, is a subset of {x_1, x_2, …, x_N}
• Z, the partition function, is the normalizing constant

Graphical representation using a factor graph: a bipartite graph in which each factor node is connected to the variable nodes that it depends on, e.g.
p(x) = (1/Z) ψ_A(x_1, x_3) ψ_B(x_1, x_2, x_4) ψ_C(x_4)

[Factor graph figure: variable nodes x_1, x_2, x_3, x_4 and factor nodes A, B, C]

When the compatibility functions are always positive, we can write

p(x) = (1/Z) ∏_a ψ_a(x_a) = (1/Z) e^{−E(x)}

where

E(x) = −∑_a ln ψ_a(x_a)

is the energy function.
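The definitions above can be checked by brute force on a tiny model. This is a minimal numpy sketch, with hypothetical random potentials on the running four-variable factor graph (the factor scopes are illustrative):

```python
import numpy as np
from itertools import product

# Hypothetical factor graph over four binary variables x1..x4 with
# factors A(x1, x3), B(x1, x2, x4), C(x4).
rng = np.random.default_rng(0)
psi_A = rng.random((2, 2))       # psi_A(x1, x3)
psi_B = rng.random((2, 2, 2))    # psi_B(x1, x2, x4)
psi_C = rng.random(2)            # psi_C(x4)

def unnormalized(x):
    x1, x2, x3, x4 = x
    return psi_A[x1, x3] * psi_B[x1, x2, x4] * psi_C[x4]

# Partition function Z: sum of the factor product over all states.
Z = sum(unnormalized(x) for x in product([0, 1], repeat=4))

def p(x):
    return unnormalized(x) / Z

def energy(x):
    # E(x) = -sum_a ln psi_a(x_a), so p(x) = exp(-E(x)) / Z
    return -np.log(unnormalized(x))

x = (0, 1, 1, 0)
assert np.isclose(p(x), np.exp(-energy(x)) / Z)
```

Enumeration is exponential in N, which is exactly why the message passing algorithms below matter.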

Error Correcting Codes

[Fig from http://www.inference.org.uk/mackay/codes/gifs/]

[Factor graph figure: codeword bits x_1, …, x_4 with parity factors A, B, C; each y_i is an observed, corrupted version of x_i. The x_i are the codeword, the factors encode the parity bits.]

Each factor enforces a parity-check constraint, e.g.

ψ_A(x_1, x_2, x_4) = 1 if x_1 ⊕ x_2 ⊕ x_4 = 1, and 0 otherwise

MAP, Marginals, and Partition Function

Maximum a posteriori (MAP): find a state x that maximizes p(x)
• Equivalently, minimizes the energy E(x)

Marginal: probabilities for individual variables
p_i(x_i) = ∑_{x \ x_i} p(x)

Partition function: compute the normalizing constant
Z = ∑_x ∏_a ψ_a(x_a)

Semantic Segmentation

In a conditional random field (CRF), the conditional distribution
p(x|I) = (1/Z(I)) e^{−E(x|I)}
is modeled.

• For semantic segmentation, I is the image and x are the semantic class labels:
E(x|I) = ∑_i φ_u(x_i|I) + ∑_{i,j} φ_p(x_i, x_j|I)

[Figs from Zheng et al. 2015]

Message passing algorithm

Sum product computes the marginals:

n_{ia}(x_i) = ∏_{c∈N(i)\a} m_{ci}(x_i)
m_{ai}(x_i) = ∑_{x_a\x_i} ψ_a(x_a) ∏_{j∈N(a)\i} n_{ja}(x_j)
p_i(x_i) ∝ ∏_{a∈N(i)} m_{ai}(x_i)

Max product solves the MAP problem: just replace sum with max in the message passing.
Works exactly on trees (dynamic programming).
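On a tree these messages give exact marginals. A minimal sketch on a 3-variable chain x1 - x2 - x3 with hypothetical pairwise potentials, checked against brute-force enumeration:

```python
import numpy as np

# Sum-product on a 3-variable chain x1 - x2 - x3 (a tree).
rng = np.random.default_rng(1)
psi12 = rng.random((2, 2))   # psi12(x1, x2)
psi23 = rng.random((2, 2))   # psi23(x2, x3)

# Leaf variables send all-one messages, so each factor-to-x2 message
# is just a sum over the leaf variable.
m_12_to_x2 = psi12.sum(axis=0)          # sum_x1 psi12(x1, x2)
m_23_to_x2 = psi23.sum(axis=1)          # sum_x3 psi23(x2, x3)
belief = m_12_to_x2 * m_23_to_x2        # product of incoming messages
p_x2 = belief / belief.sum()

# Brute-force check of the marginal p_2(x_2).
joint = psi12[:, :, None] * psi23[None, :, :]   # indexed [x1, x2, x3]
brute = joint.sum(axis=(0, 2))
brute /= brute.sum()
assert np.allclose(p_x2, brute)
```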

Belief Propagation on Trees

[Figure: the factor graph with variables x_1, …, x_4 and factors A, B, C, rearranged as a tree with diameter D]

Correctness: every variable node and factor node computes its messages in parallel at each iteration.
• After O(D) iterations, where D is the diameter of the tree, all messages and all marginals are correct.

It suffices to pass messages from the leaves to the root and back.
• More efficient for serial computation

Loopy Belief Propagation

Belief propagation can also be applied to general probabilistic graphical models; it is then often called loopy belief propagation. As a message passing algorithm:

Init all messages n_{ia}, m_{ai} to all-one vectors
repeat T iterations:
  for each variable i and factor a compute (in parallel):
    n_{ia}(x_i) = ∏_{c∈N(i)\a} m_{ci}(x_i) for each a ∈ N(i)
    m_{ai}(x_i) = ∑_{x_a\x_i} ψ_a(x_a) ∏_{j∈N(a)\i} n_{ja}(x_j) for each i ∈ N(a)
return p_i(x_i) = (1/Z_i) ∏_{a∈N(i)} m_{ai}(x_i) for each i

[Factor graph figure with a cycle: variables x_1, …, x_4, factors A, B, C]

Belief propagation may fail when there are cycles:
• May not even converge
• Often works well in practice when it converges

Variational inference ideas help understand loopy belief propagation.


Variational Principles

View message passing algorithms through the lens of variational principles.
• A variational principle solves a problem by viewing the solution as an extremum (maximum, minimum, saddle point) of a function or functional
• To understand or "interpret" an algorithm, ask: "what objective might the algorithm implicitly be optimizing for?"

In standard variational inference, we approximate the Helmholtz free energy

F_H = −ln Z, where Z = ∑_x e^{−E(x)} and p(x) = e^{−E(x)}/Z,

by turning it into an optimization problem.

For the target belief p and an arbitrary belief q,
F_H = F(q) − KL(q||p)
where F(q) is the variational free energy
F(q) = ∑_x q(x) E(x) + ∑_x q(x) ln q(x)
and KL(q||p) is the Kullback-Leibler divergence between q and p:
KL(q||p) = ∑_x q(x) ln (q(x)/p(x))

Variational Inference

Derivation:ln ๐‘ ๐‘ฅ = โˆ’๐ธ ๐‘ฅ โˆ’ ln ๐‘

๐น! = โˆ’ ln ๐‘ =1"

๐‘ž ๐‘ฅ ๐ธ ๐‘ฅ +1"

๐‘ž ๐‘ฅ ln ๐‘ ๐‘ฅ

=1"

๐‘ž ๐‘ฅ ๐ธ ๐‘ฅ +1"

๐‘ž ๐‘ฅ ln ๐‘ ๐‘ฅ + 1"

๐‘ž ๐‘ฅ ln ๐‘ž ๐‘ฅ โˆ’1"

๐‘ž ๐‘ฅ ln ๐‘ž ๐‘ฅ

=1"

๐‘ž ๐‘ฅ ๐ธ ๐‘ฅ + 1"

๐‘ž ๐‘ฅ ln ๐‘ž ๐‘ฅ โˆ’1"

๐‘ž ๐‘ฅ ln ๐‘ž ๐‘ฅ /๐‘(๐‘ฅ)

Terminology

The variational free energy
F(q) = ∑_x q(x) E(x) + ∑_x q(x) ln q(x) = U(q) − H(q)
where U(q) = ∑_x q(x) E(x) is the variational average energy and H(q) = −∑_x q(x) ln q(x) is the variational entropy.

KL(q||p) ≥ 0 and is zero when q = p. From F_H = F(q) − KL(q||p), F(q) is an upper bound for F_H.
• Minimizing F(q) improves the approximation; it is exact when q = p

Minimizing F(q) is intractable in general.
• One approximate method is to use a tractable q
• Mean field uses a factorized belief

q_MF(x) = ∏_{i=1}^{N} q_i(x_i)

Mean Field

Mean field is often solved by coordinate descent.
• Optimize one variable at a time, holding the other variables constant:

q_i(x_i) = (1/Z_i) exp( −∑_{x\x_i} ∏_{j≠i} q_j(x_j) E(x) )

Coordinate descent converges to a local optimum.
• A local optimum is a fixed point of the updates for all variables.
• Parallel updates can also be done but may not always converge.
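The coordinate descent update can be sketched for a small pairwise MRF. This is a minimal illustration with hypothetical random potentials; for pairwise factors, the expectation of ln ψ under the other variables reduces to a matrix-vector product:

```python
import numpy as np

# Mean-field coordinate descent on a hypothetical 3-variable pairwise
# MRF: p(x) ∝ psi12(x1,x2) psi23(x2,x3), binary variables.
rng = np.random.default_rng(3)
log_psi = {(0, 1): np.log(rng.random((2, 2)) + 0.1),
           (1, 2): np.log(rng.random((2, 2)) + 0.1)}
q = [np.full(2, 0.5) for _ in range(3)]   # factorized beliefs q_i

def update(i):
    # q_i(x_i) ∝ exp( sum over factors touching i of E_q[ln psi] )
    s = np.zeros(2)
    for (a, b), lp in log_psi.items():
        if i == a:
            s += lp @ q[b]       # E_{q_b}[ln psi(x_i, x_b)]
        elif i == b:
            s += lp.T @ q[a]     # E_{q_a}[ln psi(x_a, x_i)]
    e = np.exp(s - s.max())      # subtract max for numerical stability
    return e / e.sum()

for sweep in range(50):          # coordinate descent sweeps
    for i in range(3):
        q[i] = update(i)
```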

Derivation

Mean field as message passing. Recall E(x) = −∑_a ln ψ_a(x_a). Then

q_j(x_j) = (1/Z_j) exp( ∑_{x\x_j} ∏_{i≠j} q_i(x_i) ∑_a ln ψ_a(x_a) )
         = (1/Z_j) exp( ∑_{a∈N(j)} ∑_{x\x_j} ∏_{i≠j} q_i(x_i) ln ψ_a(x_a) + ∑_{a∉N(j)} ∑_{x\x_j} ∏_{i≠j} q_i(x_i) ln ψ_a(x_a) )

The second sum does not depend on x_j, so it is a constant absorbed into Z_j. Hence, to compute q_j(x_j), we only need ψ_a(x_a) for the neighbouring factors a ∈ N(j).

Previously๐‘ž* ๐‘ฅ* = !

812 exp โˆ’โˆ‘๐’™\51โˆ%<*

/ ๐‘ž% ๐‘ฅ% ๐ธ ๐’™

As a message passing algorithm on a factor graph:repeat ๐‘‡ iterations

for each variable ๐‘— compute (serially or in parallel)๐‘š78(๐‘ฅ8) = โˆ‘๐’™!\:1โˆ;โˆˆ= 7 ,;?8

= ๐‘ž; ๐‘ฅ; ln ๐œ“7 ๐’™7 for ๐‘Ž โˆˆ ๐‘ ๐‘— in parallel

๐‘ž8 ๐‘ฅ8 = @A12 exp โˆ‘7โˆˆ= 8 ๐‘š78 ๐‘ฅ8

๐‘ฅ!

๐‘ฅ"๐‘ฅ# ๐‘ฅ$

๐ด๐ต

๐ถ

Loopy Belief Propagation and Bethe Free Energy¹

For a tree-structured factor graph, the variational free energy

F(q) = ∑_x q(x) E(x) + ∑_x q(x) ln q(x)
     = −∑_{a=1}^{M} ∑_{x_a} q_a(x_a) ln ψ_a(x_a)
       + ∑_{a=1}^{M} ∑_{x_a} q_a(x_a) ln ( q_a(x_a) / ∏_{i∈N(a)} q_i(x_i) )
       + ∑_{i=1}^{N} ∑_{x_i} q_i(x_i) ln q_i(x_i)

¹ Yedidia, Jonathan S., William T. Freeman, and Yair Weiss. "Constructing free-energy approximations and generalized belief propagation algorithms." IEEE Transactions on Information Theory 51.7 (2005): 2282-2312.

Derivation

For the Bethe approximation, the Bethe free energy F_Bethe(q) = U_Bethe(q) − H_Bethe(q) is used even though the graph may not be a tree, where the variational average energy and variational entropy are approximated by

U_Bethe(q) = −∑_{a=1}^{M} ∑_{x_a} q_a(x_a) ln ψ_a(x_a)

H_Bethe(q) = −∑_{a=1}^{M} ∑_{x_a} q_a(x_a) ln ( q_a(x_a) / ∏_{i∈N(a)} q_i(x_i) ) − ∑_{i=1}^{N} ∑_{x_i} q_i(x_i) ln q_i(x_i)


In addition, we impose the constraints:
• ∑_{x_i} q_i(x_i) = ∑_{x_a} q_a(x_a) = 1
• q_i(x_i) ≥ 0, q_a(x_a) ≥ 0
• ∑_{x_a\x_i} q_a(x_a) = q_i(x_i)

What does loopy belief propagation optimize? The loopy belief propagation equations give the stationary points of the constrained Bethe free energy.

Derivation

We only specify the factor marginals q_a(x_a) and the variable marginals q_i(x_i).
• There may be no distribution q whose marginals agree with q_a(x_a)
• These are often called pseudomarginals instead of marginals

Furthermore, the Bethe entropy H_Bethe(q) is only an approximation of the variational entropy when the graph is not a tree.

Figure from Wainwright and Jordan 2008: the set of marginals from valid probability distributions M(G) is a strict subset of the set of pseudomarginals L(G).

Variational Inference Methods

Mean field minimizes the variational free energy F(q) = U(q) − H(q).
• Assumes a fully factorized q for tractability
• Can be extended to other tractable q: structured mean field
• Minimizes an upper bound of the Helmholtz free energy F_H = −ln Z
• Converges to a local optimum if coordinate descent is used; may not converge for parallel updates
• The update equations can be computed as message passing on a graph

Variational Inference Methods

Loopy belief propagation can be viewed as minimizing the Bethe free energy, F_Bethe(q) = U_Bethe(q) − H_Bethe(q), an approximation of the variational free energy.
• May not be an upper bound of F_H
• The resulting q may not be consistent with any probability distribution
• May not converge, but performance is often good when it converges
• Message passing on a graph; various methods help convergence, e.g. scheduling messages, damping, etc.
• Extends to generalized belief propagation for other region-based free energies, e.g. the Kikuchi free energy

Other commonly found variational inference message passing methods include expectation propagation, as well as max product linear programming relaxations for finding MAP approximations.

Parameter Estimation

Learn parameterized compatibility functions or components of the energy function E(x|θ) = −∑_a ln ψ_a(x_a|θ_a).
• Can do maximum likelihood estimation
• If some variables are not observed, can use the EM algorithm
• If inference is intractable, a variational approximation for estimating the latent variables is one approach: variational EM
• With the mean field approximation, this maximizes a lower bound of the likelihood function
• Can also treat the parameters as latent variables: variational Bayes

For this tutorial, we focus on unrolling the message passing algorithm into a deep neural network and doing end-to-end learning (later).

Outline

โ€ข Markov Decision Process

Markov Decision Process

A Markov decision process (MDP) is defined by (S, A, T, R).

States S: the current description of the world
• Markov: the past is irrelevant once we know the state
• Navigation example: position of the robot

Robot navigation

MDP ๐‘†, ๐ด, ๐‘‡, ๐‘…Actions ๐ด : Set of available actions

โ€ข Navigation example:โ€ข Move Northโ€ข Move Southโ€ข Move Eastโ€ข Move West

Robot navigation

MDP ๐‘†, ๐ด, ๐‘‡, ๐‘…Transition function ๐‘‡ :

โ€ข ๐‘‡ ๐‘ , ๐‘Ž, ๐‘ A = ๐‘ƒ(๐‘ A|๐‘ , ๐‘Ž)โ€ข Navigation example:

โ€ข Darker shade, higher probability

Robot navigation

MDP ๐‘†, ๐ด, ๐‘‡, ๐‘…Reward function ๐‘… : Reward received when action ๐‘Ž in state ๐‘ results in transition to state ๐‘ โ€ฒ

โ€ข ๐‘…(๐‘ , ๐‘Ž, ๐‘ 5)โ€ข Navigation example:

โ€ข 100 if ๐‘ โ€ฒ is Homeโ€ข -100 if ๐‘ โ€ฒ is in the danger zoneโ€ข -1 otherwise

โ€ข Can be a function of a subset of ๐‘ , ๐‘Ž, ๐‘ 5 as in navigation example

Robot navigation

Example of a 3-state, 2-action Markov decision process (S, A, T, R).

• Transitions can be sparse, as in the navigation example

[Fig by waldoalvarez CC BY-SA 4.0 ]

MDP ๐‘†, ๐ด, ๐‘‡, ๐‘…Policy ๐œ‹: Function from state and time step to action

โ€ข ๐‘Ž = ๐œ‹ (๐‘ , ๐‘ก)โ€ข Navigation example:

โ€ข Which direction to move at current location at step ๐‘ก

Robot navigation

MDP (S, A, T, R)

Value function V_π: how good policy π is when started from state s
• V_π(s_0) = ∑_{t=0}^{T−1} E[R(s_t, π(s_t, t), s_{t+1})]
• T is the horizon
• When the horizon is infinite, we usually use discounted reward:
  V_π(s) = ∑_{t=0}^{∞} E[γ^t R(s_t, π(s_t, t), s_{t+1})]
• γ ∈ (0, 1) is the discount factor

MDP ๐‘†, ๐ด, ๐‘‡, ๐‘…Optimal policy ๐œ‹โˆ—: policy that maximizes

๐‘‰F ๐‘ G =<HIG

J#@๐ธ[๐‘…(๐‘ H, ๐œ‹ ๐‘ H, ๐‘ก , ๐‘ HK@)

โ€ข For infinite horizon, discounted reward MDP, optimal policy ๐œ‹โˆ—(๐‘ ) is stationary (independent of ๐‘ก)

โ€ข Optimal value ๐‘‰ ๐‘  = ๐‘‰Fโˆ—(๐‘ ): value corresponding to optimal policy

Robot navigation

Value Iteration Algorithm

A dynamic programming algorithm for solving MDPs. Let V(s, T) denote the optimal value at s when the horizon is T, initialized with V(s, 0) = v_s. Then

V(s, T) = max_a E[R(s, a, s′) + V(s′, T−1)]
        = max_a ∑_{s′} p(s′|a, s) (R(s, a, s′) + V(s′, T−1))

As message passing on a graph:
• A node at each state s, initialized to V(s, 0) = v_s
• Utilize |A| 'heads', one for each action a
• repeat T iterations:
    for each action a of each state s (in parallel):
      Collect p(s′|a, s)(R(s, a, s′) + V(s′)) from all s′ with p(s′|a, s) non-zero, at head a of s
      Sum all the messages
    for each node s (in parallel):
      Collect a message from each of its action heads a
      Take the maximum of the messages

[Fig by waldoalvarez CC BY-SA 4.0 ]

๐‘‰(๐‘ , ๐‘‡) = max,T

HA๐‘ ๐‘ โ€ฒ|๐‘Ž, ๐‘  (๐‘… ๐‘ , ๐‘Ž, ๐‘ A + ๐‘‰ ๐‘ A, ๐‘‡ โˆ’ 1 )

Robot navigation: shortest path example
• Deterministic transitions (only one next state, with probability 1, for each action)
• Initialize V(s, 0) = 0 for the goal state, −∞ for all other states
• Self-loop with 0 reward at the goal state for all actions
• Reward −w_ij for moving from i to j; −∞ if there is no edge between the two nodes
• Value iteration is then the Bellman-Ford shortest path algorithm

After k iterations, the value at each node is the (negated) length of the shortest path from that node to the goal, among paths of at most k steps.

Robot navigation

At initialization: [grid figure, all node values −∞ except 0 at the goal]

After 1 iteration: [grid figure, nodes one step from the goal have value −1; the rest remain −∞]

After 2 iterations: [grid figure, values −1 and −2 have propagated one step further; unreached nodes remain −∞]

For the infinite-horizon discounted case, the dynamic programming equation (the Bellman equation) is

V(s) = max_a E[R(s, a, s′) + γ V(s′)]
     = max_a ∑_{s′} p(s′|a, s) (R(s, a, s′) + γ V(s′))

The same value iteration algorithm applies, with the message changed to
p(s′|a, s) (R(s, a, s′) + γ V(s′))

It converges to the optimal value function.

Convergence of Value Iteration

Derivation

The optimal value function V(s) satisfies Bellman's principle of optimality:

V(s) = max_a ∑_{s′} p(s′|a, s) (R(s, a, s′) + γ V(s′))

The Bellman update B in value iteration transforms V_t into V_{t+1} as follows:

V_{t+1}(s) = max_a ∑_{s′} p(s′|a, s) (R(s, a, s′) + γ V_t(s′))

We denote this as V_{t+1} = B V_t. The optimal value function is a fixed point of this operator, V = B V. The Bellman update is a contraction (in the max norm), i.e.

‖B V_1 − B V_2‖ ≤ γ ‖V_1 − V_2‖

From Bellman's equation, V = B V for the optimal value function V. Applying V_t = B V_{t−1} repeatedly and using the contraction property,

‖V_t − V‖ = ‖B V_{t−1} − B V‖ ≤ γ ‖V_{t−1} − V‖ ≤ γ^t ‖V_0 − V‖

The distance converges exponentially to 0 for any initial value V_0.

Outline

โ€ข Graph Neural Networks and Attention Networks

Graph Neural Networks

• Many effective graph neural networks (GNN): GCN, GIN, …
• Message passing neural networks (MPNN)¹: a general formulation of GNNs as message passing algorithms

• Input: G = (V, E), node attribute vectors x_i for i ∈ V, edge attribute vectors x_ij for (i, j) ∈ E
• Output: a label or value for graph classification or regression, or a label/value for each node in structured prediction

1 Gilmer, Justin, Samuel S. Schoenholz, Patrick F. Riley, Oriol Vinyals, and George E. Dahl. "Neural message passing for quantum chemistry." In International conference on machine learning, pp. 1263-1272. PMLR, 2017

[Figure: graph with nodes v_1, v_2, v_3, node attributes x_1, x_2, x_3 and edge attributes x_ij on the edges]

MPNN pseudocode

Initialization: h_i^0 = x_i for each v_i ∈ V
for layers ℓ = 1, …, d:
  for every edge (i, j) ∈ E (in parallel):
    m_ij^ℓ = MSG_ℓ(h_i^{ℓ−1}, h_j^{ℓ−1}, v_i, v_j, x_ij)
  for every node v_i ∈ V (in parallel):
    h_i^ℓ = UP_ℓ({m_ij^ℓ : j ∈ N(i)}, h_i^{ℓ−1})
return h_i^d for every v_i ∈ V, or y = READ({h_i^d : v_i ∈ V})

• MSG_ℓ is an arbitrary function, usually a neural net
• UP_ℓ aggregates the messages from the neighbours (usually with a set function) and combines them with the node embedding
• READ is a set function for graph classification or regression tasks
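A minimal sketch of this pseudocode in numpy, with hypothetical random linear maps standing in for the MSG and UP neural networks, sum aggregation, and a sum READ (all of these choices are illustrative):

```python
import numpy as np

# A minimal MPNN on a path graph v0 - v1 - v2.
rng = np.random.default_rng(6)
nbrs = {0: [1], 1: [0, 2], 2: [1]}   # neighbour lists N(i)
x = rng.random((3, 4))               # node attribute vectors, dim 4
W_msg = rng.random((8, 16))          # MSG: acts on [h_i ; h_j]
W_up = rng.random((20, 4))           # UP: acts on [aggregated msg ; h_i]

h = x.copy()                         # h_i^0 = x_i
for layer in range(2):               # depth d = 2
    new_h = np.zeros_like(h)
    for i, js in nbrs.items():
        # message from each neighbour j, then sum-aggregate
        msgs = [np.tanh(np.concatenate([h[i], h[j]]) @ W_msg) for j in js]
        agg = np.sum(msgs, axis=0)
        new_h[i] = np.tanh(np.concatenate([agg, h[i]]) @ W_up)
    h = new_h

y = h.sum(axis=0)                    # READ: set function over node embeddings
```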


GNN properties

If the depth and width are large enough, the message functions MSG_ℓ and update functions UP_ℓ are sufficiently powerful, and nodes can uniquely distinguish each other, then the MPNN is computationally universal¹.
• Equivalent to the LOCAL model in distributed algorithms
• Can compute any function computable with respect to the graph and attributes (just send all information, including the graph structure, to a single node, then compute there)

1 Loukas, Andreas. "What graph neural networks cannot learn: depth vs width.โ€ ICLR 2020

Notation

Depth: number of layers
Width: largest embedding dimension

Example:

P = [0 0 1; 1 0 0; 0 1 0],  A = [1 2 3; 4 5 6; 7 8 9]

P permutes vertices (1, 2, 3) → (3, 1, 2).

PA swaps the rows:

[0 0 1; 1 0 0; 0 1 0] [1 2 3; 4 5 6; 7 8 9] = [7 8 9; 1 2 3; 4 5 6]

(PA)P^T swaps the columns after that:

[7 8 9; 1 2 3; 4 5 6] [0 1 0; 0 0 1; 1 0 0] = [9 7 8; 3 1 2; 6 4 5]

When using graph neural networks, we are often interested in permutation invariance and equivariance. Given an adjacency matrix A, a permutation matrix P, and an attribute matrix X containing the attributes x_i in the i-th row:

• Permutation invariance: f(P A P^T, P X) = f(A, X)
• Permutation equivariance: f(P A P^T, P X) = P f(A, X)

If MSG_ℓ does not depend on the node ids v_i, v_j, the MPNN is permutation equivariant: for any permutation matrix P, MPNN(P A P^T, P X) = P MPNN(A, X) (and with a set-function READ, the graph-level output is permutation invariant).

• However, we lose approximation power if the messages do not depend on node ids: some graph structures cannot be distinguished
• Discrimination power is then at most as powerful as the 1-dimensional Weisfeiler-Lehman (WL) graph isomorphism test, which cannot distinguish certain graphs
• The graph isomorphism network (GIN)¹ is as powerful as 1-WL


1 Xu, Keyulu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. "How powerful are graph neural networks?" ICLR 2019

Drug Discovery

Molecules are naturally represented as graphs. GNNs are commonly used for predicting properties of molecules, e.g. whether a molecule inhibits certain bacteria.

[Figs on paracetamol by Benjah-bmm27 and Ben Mills are in the public domain]

Algorithmic Alignment

Bellman-Ford shortest path:
for ℓ = 1, …, d:
  for v ∈ V (in parallel):
    d^ℓ(v) = min_u d^{ℓ−1}(u) + w(u, v)

Graph neural network:
for layers ℓ = 1, …, d:
  for v ∈ V (in parallel):
    h_v^ℓ = UP({MLP(h_u^{ℓ−1}, h_v^{ℓ−1}, w(u, v))}, h_v^{ℓ−1})

The sample complexity of learning a GNN is smaller for tasks that the GNN is algorithmically aligned with¹.

d^ℓ(v) is an easy function for the MLP to approximate as a function of d^{ℓ−1}(u) and w. The same MLP is shared by all nodes, a good inductive bias, so the sample complexity is low. In contrast, learning the entire for loop with a single function requires higher sample complexity.

¹ Xu, Keyulu, Jingling Li, Mozhi Zhang, Simon S. Du, Ken-ichi Kawarabayashi, and Stefanie Jegelka. "What can neural networks reason about?" ICLR 2020

Alignment of GNN and VI

The value iteration algorithm for MDPs can be represented in network form: the value iteration network (VIN)¹. The GNN is well aligned with value iteration.

Value iteration:
for ℓ = 1, …, d:
  for v ∈ V (in parallel):
    𝒱(v, ℓ) = max_a ∑_u p(u|a, v) (R(v, a, u) + 𝒱(u, ℓ−1))

Graph neural network:
for layers ℓ = 1, …, d:
  for v ∈ V (in parallel):
    h_v^ℓ = UP({MLP(h_u^{ℓ−1}, h_v^{ℓ−1}, {p(u|a, v), R(v, a, u)})}, h_v^{ℓ−1})

¹ Tamar, Aviv, Yi Wu, Garrett Thomas, Sergey Levine, and Pieter Abbeel. "Value iteration networks." NeurIPS 2016

1 Karkus, Peter, David Hsu, and Wee Sun Lee. "QMDP-net: Deep learning for planning under partial observability." NeurIPS 2017

Example: robot navigation using a map, using a VIN¹.

• The transition function p(u|a, v) and reward function R(v, a, u) may also need to be learned, instead of being provided.
• The message passing structure suggests:
  • representing the transition and reward separately, learned as functions of the image
  • using them as input to the same function at all states

1 Lee, Lisa, Emilio Parisotto, Devendra Singh Chaplot, Eric Xing, and Ruslan Salakhutdinov. "Gated path planning networks.โ€ ICML 2018.

The GNN is well aligned with value iteration, so it may work well here. The VIN has a stronger inductive bias: it encodes the value iteration equations directly.
• Action heads: for each action, take a weighted sum from the neighbours
• Then max over actions

The GNN is potentially more flexible: it is also aligned with other similar algorithms, particularly dynamic programming algorithms.
• May work better if the MDP assumption is not accurate
• Optimization may be easier for some types of architectures¹

Alignment of GNN and Graphical Model Algorithms

Mean field:
repeat T iterations:
  for each variable j compute (serially or in parallel):
    m_{aj}(x_j) = ∑_{x_a\x_j} ∏_{i∈N(a), i≠j} q_i(x_i) ln ψ_a(x_a) for each a ∈ N(j), in parallel
    q_j(x_j) = (1/Z_j) exp( ∑_{a∈N(j)} m_{aj}(x_j) )

1 Dai, Hanjun, Bo Dai, and Le Song. "Discriminative embeddings of latent variable models for structured data." ICML 2016.

When all potential functions are pairwise, ψ_a(x_a) = ψ_{i,j}(x_i, x_j), then m_{aj}(x_j) is a function of only q_i(x_i).

• Can send messages directly from variable to variable without combining the messages at factor nodes
• Can interpret the node embedding as a feature representation of the belief, and learn a mapping from one belief to another as a message function¹
• The GNN is algorithmically well aligned with mean field for pairwise potential functions
• Similarly well aligned with loopy belief propagation for pairwise potentials

What about higher order potentials ψ_a(x_a), where x_a consists of n_a variables?

• In a tabular representation, the size of ψ_a(x_a) grows exponentially with n_a, so even loopy belief propagation is not efficient
• But if ψ_a(x_a) is a low-rank tensor, then loopy belief propagation can be implemented efficiently¹,²
• Represent ψ_a(x_a) as a tensor decomposition (CP decomposition), where k_a is the rank:
  ψ_a(x_a) = ∑_{ℓ=1}^{k_a} w_{a,1}^ℓ(x_{a,1}) w_{a,2}^ℓ(x_{a,2}) … w_{a,n_a}^ℓ(x_{a,n_a})
• The factor graph neural network (FGNN), which passes messages on a factor graph, is well aligned with this

¹ Dupty, Mohammed Haroon, and Wee Sun Lee. "Neuralizing Efficient Higher-order Belief Propagation." arXiv preprint arXiv:2010.09283 (2020)
² Zhang, Zhen, Fan Wu, and Wee Sun Lee. "Factor graph neural network." NeurIPS 2020


Implement the message functions

n_{ja}(x_j) = ∏_{c∈N(j)\a} m_{cj}(x_j)
m_{ai}(x_i) = ∑_{x_a\x_i} ψ_a(x_a) ∏_{j∈N(a)\i} n_{ja}(x_j)

using

ψ_a(x_a) = ∑_{ℓ=1}^{k_a} ∏_{j∈N(a)} w_{a,j}^ℓ(x_j)

Then

m_{ai}(x_i) = ∑_{x_a\x_i} ψ_a(x_a) ∏_{j∈N(a)\i} n_{ja}(x_j)
            = ∑_{x_a\x_i} ∑_{ℓ=1}^{k_a} ∏_{j∈N(a)} w_{a,j}^ℓ(x_j) ∏_{j∈N(a)\i} n_{ja}(x_j)
            = ∑_{ℓ=1}^{k_a} w_{a,i}^ℓ(x_i) ∑_{x_a\x_i} ∏_{j∈N(a)\i} w_{a,j}^ℓ(x_j) n_{ja}(x_j)
            = ∑_{ℓ=1}^{k_a} w_{a,i}^ℓ(x_i) ∏_{j∈N(a)\i} ∑_{x_j} w_{a,j}^ℓ(x_j) n_{ja}(x_j)

Matrix notation:
• m_{ai}: vector of length n_i (the number of values of x_i)
• n_{ia}: vector of length n_i
• w_{a,i}^ℓ: vector of length n_i
• W_{ai}: matrix with k_a rows, where row ℓ is (w_{a,i}^ℓ)^T
• ⊙: element-wise multiplication

n_{ia}(x_i) = ∏_{c∈N(i)\a} m_{ci}(x_i)
m_{ai}(x_i) = ∑_{ℓ=1}^{k_a} w_{a,i}^ℓ(x_i) ∏_{j∈N(a)\i} ∑_{x_j} w_{a,j}^ℓ(x_j) n_{ja}(x_j)

In matrix notation:
m_{ai} = W_{ai}^T ⊙_{j∈N(a)\i} W_{aj} n_{ja}
n_{ia} = ⊙_{c∈N(i)\a} m_{ci}

To make the factor and variable messages symmetric:
m̃_{ai} = ⊙_{j∈N(a)\i} W_{aj} n_{ja}
n_{ia} = ⊙_{c∈N(i)\a} W_{ci}^T m̃_{ci}

Matrix-vector multiplication followed by product aggregation.
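The low-rank message identity can be verified numerically: for a rank-k CP factor, the matrix form W_{a,1}^T (W_{a,2} n_2 ⊙ W_{a,3} n_3) matches the direct sum-product message. The weights here are hypothetical:

```python
import numpy as np

# Check the low-rank message computation for a 3-variable CP factor.
rng = np.random.default_rng(7)
k = 3
W1, W2, W3 = (rng.random((k, 2)) for _ in range(3))   # rows are w_{a,j}^l
psi = np.einsum('la,lb,lc->abc', W1, W2, W3)          # CP reconstruction

n2, n3 = rng.random(2), rng.random(2)                 # incoming messages
# Direct sum-product message to variable 1:
direct = np.einsum('abc,b,c->a', psi, n2, n3)
# Matrix form: matrix-vector products, then product aggregation.
lowrank = W1.T @ ((W2 @ n2) * (W3 @ n3))
assert np.allclose(direct, lowrank)
```

The direct form costs O(2^3) per message in the tabular representation; the matrix form costs O(k · n) per neighbour, which is the efficiency gain the slide describes.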

A factor graph provides an easy way to specify dependencies, even higher order ones. Loopy belief propagation can be approximated with the following message passing equations:

m_{ai} = ⊙_{j∈N(a)\i} W_{aj} n_{ja}
n_{ia} = ⊙_{c∈N(i)\a} W_{ci}^T m_{ci}

This optimizes the Bethe free energy if it converges.
• Uses a low-rank approximation for the potential functions.
• Increasing the number of rows of W_{ai} increases the rank of the tensor decomposition approximation.

Neuralizing Loopy Belief Propagation

Alignment shows that a neural network with a small number of parameters can approximate an algorithm, giving smaller sample complexity in learning: analysis. We can also use these ideas in design, by neuralizing the algorithm:

• Start with the network representing the algorithm, to capture the inductive bias.
• Modify the algorithm, e.g. add computational elements to enhance approximation capability, while mostly maintaining the structure to keep the inductive bias.

Concretely:
• Start with an algorithm in network form, e.g.
  m_{ai} = ⊙_{j∈N(a)\i} W_{aj} n_{ja}
  n_{ia} = ⊙_{c∈N(i)\a} W_{ci}^T m_{ci}
• Add network elements to potentially make the network more powerful, enlarging the class of algorithms that can be learned, e.g.
  m_{ai} = MLP(⊙_{j∈N(a)\i} W_{aj} n_{ja})
  n_{ia} = MLP(⊙_{c∈N(i)\a} W_{ci}^T m_{ci})
• It is usually simpler to keep messages only on nodes instead of on edges, simplifying while keeping the message passing structure. The aggregator can also be changed, e.g. to sum, max, etc. This works well in practice:
  m_a = MLP(AGG_{j∈N(a)} W_{aj} n_j)
  n_i = MLP(AGG_{c∈N(i)} W_{ci}^T m_c)

[Figure: message passing neural net on a factor graph with variables x_1, …, x_4 and factors A, B, C]

Attention Network

Distributed Representation of Graphs and Matrices

A graph can be represented using an adjacency matrix. A matrix A, in turn, can be factorized as A = UV^T. In factorized form, u_i^T v_j = A_ij:
• Entry (i, j) of matrix A is the inner product of row i of matrix U with row j of matrix V
• Node i of the graph has embeddings u_i, v_i such that the value of edge (i, j) can be computed as u_i^T v_j
• A distributed representation of the graph: the information is distributed to the nodes

A = UV^T,  A_ij = u_i^T v_j = ∑_k u_ik v_jk

[a_11 a_12 a_13; a_21 a_22 a_23; a_31 a_32 a_33] = [u_11 u_12; u_21 u_22; u_31 u_32] [v_11 v_21 v_31; v_12 v_22 v_32]

With a distributed representation, using a subset of the embeddings u_i, v_i gives a representation of the corresponding subgraph!
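A minimal numpy sketch of this idea, using an SVD (one of several possible factorizations) to obtain the node embeddings:

```python
import numpy as np

# Distributed representation of a weighted adjacency matrix: A = U V^T,
# so edge values are inner products of node embeddings.
rng = np.random.default_rng(8)
A = rng.random((4, 4))
U, s, Vt = np.linalg.svd(A)
U = U * s                      # fold singular values into U
V = Vt.T                       # now A = U V^T (full rank, exact)

assert np.allclose(A, U @ V.T)
# Edge (i, j) is recoverable from the two node embeddings alone:
i, j = 1, 3
assert np.isclose(A[i, j], U[i] @ V[j])
# A subset of embeddings represents the corresponding subgraph:
sub = [0, 2]
assert np.allclose(A[np.ix_(sub, sub)], U[sub] @ V[sub].T)
```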

Attention Network

Using matrix factorization, we can show that the transformer-type attention network is well aligned with value iteration for MDPs. In the attention network, we have a set of nodes, each with an embedding as input, and the same operations (implemented with a network) are applied at each node in a layer.

[Block diagram: Multi-Head Attention → Add → Feedforward → Add]

Multi-head attention. Let the embedding at node i be x_i.
• Each node has K attention heads
• Each attention head has weight matrices: query weights W_Q used to compute the query q_i = W_Q x_i, key weights W_K used to compute the key k_i = W_K x_i, and value weights W_V used to compute the value v_i = W_V x_i
• Compute a_ij = q_i^T k_j with all other nodes j. Compute the probability vector p_ij for all j using the softmax function, p_ij = e^{a_ij} / ∑_ℓ e^{a_iℓ}. The output of the head is a weighted average of all the values, ∑_j p_ij v_j.

[Figure: one attention head computing q_i = W_Q x_i, k_i = W_K x_i, v_i = W_V x_i, scores a_ij = q_i^T k_j, softmax, weighted average; the outputs are concatenated across K heads and projected]

Output of multi-head attention:
• Concatenate the outputs of the K attention heads
• Project back to a vector of the same length
• Pass the result to a feedforward network through a residual operation (added to x_i)

Alignment of Attention Network with Value Iteration

Constructing input x_i
• Use matrix factorization to get a distributed representation of the log of each transition matrix, L_a = U_a V_a^T for each action a, where L_a[s, s'] = log P(s'|s, a).
• Construct the input for node i, x_i, by stacking up K + 1 copies of the initial value v_i followed by the embeddings of the transition matrices:

x_i = (R_Ki, …, R_1i, v_Ki, u_Ki, …, v_1i, u_1i, v_i, …, v_i)^T

with, from top to bottom: the reward for each action for state i, the embeddings of the transition matrices for each action for state i, and K + 1 copies of the initial value v_i.

Constructing input x_i
• Compute the expected reward R_ai = E[R(s, a, s')] for each of the K actions.
• Place the expected reward R_ai for each action a into a single vector and concatenate it to the earlier input vector.


Need to extract u_ai, v_aj to compute p(j|a, i) = exp(u_ai^T v_aj)
• Setting W^Q = [0, I_k, 0], where I_k is a k × k identity matrix placed at the appropriate columns, will extract u_ai = W^Q x_i as the query.
• Similarly, we can construct W^K to extract v_aj as the key, allowing the softmax of the inner product u_ai^T v_aj to correctly compute the transition probabilities.
• Set W^V to extract the value component from x_j.

With this construction, the output of head a is o_ai = Σ_j p(j|a, i) v_j
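This step can be checked numerically; the sketch below (a random row-stochastic transition matrix and an exact SVD factorization, both assumed for illustration) shows that the softmax over the inner products u_ai^T v_aj recovers P(s'|s, a):

```python
import numpy as np

# If the log transition matrix factorizes exactly as L = U V^T, the
# softmax over the inner products u_i . v_j recovers P(j | a, i).
rng = np.random.default_rng(0)
P_a = rng.random((4, 4))
P_a /= P_a.sum(axis=1, keepdims=True)        # row-stochastic P(s'|s, a)
L = np.log(P_a)                              # L[s, s'] = log P(s'|s, a)

# Exact (full-rank) factorization via SVD: L = U V^T.
Us, S, Vt = np.linalg.svd(L)
U = Us * S                                   # row i is u_i
V = Vt.T                                     # row j is v_j

logits = U @ V.T                             # entry (i, j) = u_i . v_j
probs = np.exp(logits)
probs /= probs.sum(axis=1, keepdims=True)    # softmax over j
assert np.allclose(probs, P_a)
```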

At each layer of value iteration:
  for each action a of each state s (in parallel):
    Collect P(s'|a, s)(R(s, a, s') + V(s')) from all s' to a at s if P(s'|a, s) is non-zero
    Sum all the messages
  for each node s (in parallel):
    Collect the messages from its corresponding actions a
    Take the maximum of the messages
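The two phases above can be written directly for a toy MDP (the sizes, rewards, and discount factor below are arbitrary choices for illustration):

```python
import numpy as np

def value_iteration_layer(P, R, V, gamma=0.9):
    """One layer of value iteration as two message passing phases.
    P[a, s, s'] are transition probabilities, R[a, s, s'] rewards."""
    # Phase 1: each action node (a, s) sums messages P(s'|a,s)(R + gamma*V(s')).
    Q = (P * (R + gamma * V[None, None, :])).sum(axis=2)
    # Phase 2: each state node takes the max over its action messages.
    return Q.max(axis=0)

P = np.array([[[1.0, 0.0], [0.0, 1.0]],    # action 0: stay
              [[0.0, 1.0], [1.0, 0.0]]])   # action 1: switch states
R = np.ones((2, 2, 2))                     # reward 1 for every transition
V = np.zeros(2)
V1 = value_iteration_layer(P, R, V)
assert np.allclose(V1, [1.0, 1.0])
```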

Setting W^Q = [0 ⋯ 0  I_k  0 ⋯ 0], with I_k at the columns of u_ai, gives

[0 ⋯ 0  I_k  0 ⋯ 0] (R_Ki, …, R_1i, v_Ki, u_Ki, …, v_1i, u_1i, v_i, …, v_i)^T = u_ai

The output of head a is

o_ai = Σ_j p(j|a, i) v_j

Concatenate the head outputs as (o_1i, …, o_Ki) and add to the first K components of x_i to form the input to the feedforward network.


[Diagram: after the multi-head attention and residual add, the first K components of node i's vector are v_i + o_Ki, …, v_i + o_1i.]

Feedforward network:
• Subtract v_i from v_i + o_ai to get o_ai
• Compute v_i' = max_a {R_ai + o_ai}
• Construct the output so that adding back the input of the feedforward network gives v_i' in the first K + 1 positions and the original MDP parameters in the remaining positions: output the value v_i' − (v_i + o_ai) as the first K elements and v_i' − v_i as the (K + 1)st element.
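A small numeric sketch of this bookkeeping for one node (the numbers and K = 3 are made up):

```python
import numpy as np

def feedforward(stream, v_i, R_i):
    """stream[a] holds v_i + o_ai for the K actions; R_i[a] = R_ai."""
    o = stream - v_i                       # subtract v_i to recover o_ai
    v_new = (R_i + o).max()                # v_i' = max_a { R_ai + o_ai }
    return v_new - stream, v_new - v_i     # first K elements, (K+1)st element

stream = np.array([1.5, 2.0, 0.5])         # v_i + o_ai for K = 3 actions
v_i = 1.0
R_i = np.array([0.2, -0.1, 0.8])
first_K, last = feedforward(stream, v_i, R_i)

# Adding the output back to the input leaves v_i' in the first K+1 positions.
v_new = (R_i + (stream - v_i)).max()
assert np.allclose(stream + first_K, v_new)
assert np.isclose(v_i + last, v_new)
```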


Learning

All methods discussed are message passing methods on a graph
• Messages are constructed using operations such as addition, multiplication, max, exponentiation, etc.
• Can be represented using a network of computational elements
• Usually the same network at each graph node

Each iteration of message passing forms a layer. Putting the layers together forms a deep network.
• Different layers can have different parameters, giving additional flexibility

With appropriate loss functions, we can learn using gradient descent if all network elements are differentiable: backpropagation.

Backpropagation and Recurrent Backpropagation

If different layers all have the same parameters, we have a recurrent neural network. For some recurrent networks, the inputs are the same at each iteration and we are aiming for the solution at convergence:
• Loopy belief propagation
• Value iteration for discounted MDP
• Some graph neural networks¹
• Some transformer architectures²

Recurrent backpropagation and other optimization methods based on implicit functions can be used³
• Constant memory usage
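A scalar illustration of this (the layer f(z) = tanh(wz + x) is a made-up example, not any of the cited architectures): iterate the same layer with the same input to a fixed point using constant memory, then differentiate through the fixed point with the implicit function theorem, as recurrent backpropagation does:

```python
import numpy as np

# Iterate z <- f(z) = tanh(w*z + x) with the same input x each iteration;
# only the current iterate is stored (constant memory).
def fixed_point(w, x, iters=100):
    z = 0.0
    for _ in range(iters):
        z = np.tanh(w * z + x)
    return z

z_star = fixed_point(0.5, 1.0)
# At convergence, z* satisfies z* = tanh(w*z* + x).
assert abs(z_star - np.tanh(0.5 * z_star + 1.0)) < 1e-9

# Implicit differentiation at the fixed point:
# dz*/dx = f'(z*) / (1 - w * f'(z*)), with f'(z) = 1 - tanh(w*z + x)^2.
fp = 1 - np.tanh(0.5 * z_star + 1.0) ** 2
dz_dx = fp / (1 - 0.5 * fp)

# Compare against a finite-difference estimate through the iteration.
eps = 1e-6
num = (fixed_point(0.5, 1.0 + eps) - fixed_point(0.5, 1.0 - eps)) / (2 * eps)
assert abs(dz_dx - num) < 1e-4
```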

1 Scarselli, Franco, Marco Gori, Ah Chung Tsoi, Markus Hagenbuchner, and Gabriele Monfardini. "The graph neural network model." IEEE Transactions on Neural Networks, 2008.
2 Bai, Shaojie, J. Zico Kolter, and Vladlen Koltun. "Deep equilibrium models." NeurIPS 2019.
3 Zico Kolter, David Duvenaud, and Matt Johnson, "Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond." NeurIPS 2020 Tutorial. http://implicit-layers-tutorial.org/

Summary

Message Passing in Machine Learning

We viewed some "classic" message passing algorithms in machine learning through variational principles

โ€ข Loopy belief propagationโ€ข Mean fieldโ€ข Value iteration

What optimization problems are they solving?

We relate graph neural networks and attention networks to the โ€œclassicโ€ message passing algorithms

โ€ข Algorithmic alignment (analysis): Can they simulate those algorithms using a small network?

• Neuralizing algorithms (design): Can we enhance those algorithms into more powerful neural versions?

References¹

• David Mackay's Error Correcting Code demo, http://www.inference.org.uk/mackay/codes/gifs/
• Zheng, Shuai, Sadeep Jayasumana, Bernardino Romera-Paredes, Vibhav Vineet, Zhizhong Su, Dalong Du, Chang Huang, and Philip HS Torr. "Conditional random fields as recurrent neural networks." ICCV 2015, pp. 1529-1537.
• Yedidia, Jonathan S., William T. Freeman, and Yair Weiss. "Constructing free-energy approximations and generalized belief propagation algorithms." IEEE Transactions on Information Theory 51.7 (2005): 2282-2312.
• Wainwright, Martin J., and Michael Irwin Jordan. Graphical Models, Exponential Families, and Variational Inference. Now Publishers Inc, 2008.

¹ This list contains only references referred to within the slides and not the many other works related to the material in the tutorial.

• Markov Decision Process (figure) by waldoalvarez - Own work, CC BY-SA 4.0, https://commons.wikimedia.org/w/index.php?curid=59364518
• Gilmer, Justin, Samuel S. Schoenholz, Patrick F. Riley, Oriol Vinyals, and George E. Dahl. "Neural message passing for quantum chemistry." ICML 2017, pp. 1263-1272.
• Loukas, Andreas. "What graph neural networks cannot learn: depth vs width." ICLR 2020.
• Xu, Keyulu, Weihua Hu, Jure Leskovec, and Stefanie Jegelka. "How powerful are graph neural networks?" ICLR 2019.
• Xu, Keyulu, Jingling Li, Mozhi Zhang, Simon S. Du, Ken-ichi Kawarabayashi, and Stefanie Jegelka. "What can neural networks reason about?" ICLR 2020.
• Skeletal formula of paracetamol by Benjah-bmm27 is in the public domain, https://en.wikipedia.org/wiki/Paracetamol#/media/File:Paracetamol-skeletal.svg
• Ball and stick model of paracetamol by Ben Mills is in the public domain, https://en.wikipedia.org/wiki/Paracetamol#/media/File:Paracetamol-from-xtal-3D-balls.png
• Tamar, Aviv, Yi Wu, Garrett Thomas, Sergey Levine, and Pieter Abbeel. "Value iteration networks." NeurIPS 2016.
• Lee, Lisa, Emilio Parisotto, Devendra Singh Chaplot, Eric Xing, and Ruslan Salakhutdinov. "Gated path planning networks." ICML 2018.
• Karkus, Peter, David Hsu, and Wee Sun Lee. "QMDP-net: Deep learning for planning under partial observability." NeurIPS 2017.
• Dai, Hanjun, Bo Dai, and Le Song. "Discriminative embeddings of latent variable models for structured data." ICML 2016.
• Dupty, Mohammed Haroon, and Wee Sun Lee. "Neuralizing Efficient Higher-order Belief Propagation." arXiv preprint arXiv:2010.09283 (2020).
• Zhang, Zhen, Fan Wu, and Wee Sun Lee. "Factor graph neural network." NeurIPS 2020.
• Scarselli, Franco, Marco Gori, Ah Chung Tsoi, Markus Hagenbuchner, and Gabriele Monfardini. "The graph neural network model." IEEE Transactions on Neural Networks 20, no. 1 (2008): 61-80.
• Bai, Shaojie, J. Zico Kolter, and Vladlen Koltun. "Deep equilibrium models." NeurIPS 2019.
• Zico Kolter, David Duvenaud, and Matt Johnson, "Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond." NeurIPS 2020 Tutorial, http://implicit-layers-tutorial.org/

Appendix

Correctness of Belief Propagation on Trees

๐‘š,% ๐‘ฅ%

๐‘ป๐’Š

๐‘ ๐‘

๐‘ฅe

๐‘ฅ%

๐‘ฅ* ๐‘ฅ`

๐‘Ž

๐‘š

๐‘›%, ๐‘ฅ% = F-โˆˆ/(%)\3

๐‘š-%(๐‘ฅ%)

๐‘š,% ๐‘ฅ% = โˆ‘๐’™!\5"

๐œ“, ๐’™, F*โˆˆ/(,)\6

๐‘›*%(๐‘ฅ*)

๐‘% 7" โˆ F,โˆˆ/(%)

๐‘š,%(๐‘ฅ%)

Assume ๐‘‡% is the subtree along edge ๐‘Ž, ๐‘– , rooted at ๐‘ฅ% and message ๐‘š,% ๐‘ฅ% is sent

from ๐‘Ž to ๐‘–. After ๐’Œ iterations of message passing, the message correctly marginalizes away other variables in the subtree๐‘š,% ๐‘ฅ% = โˆ‘7AโˆˆE"\5"โˆ,โˆˆE"๐œ“,(๐’™,

A )where ๐‘˜ is the height of the subtree.

Assume marginalization is correctly done for subtrees of height < k.
• Push the summation inside the product; the terms can be grouped at the subtrees at depth 2 below a because of the tree structure.
• By the inductive hypothesis, marginalization is correct at depth 2 below.

So after k iterations of message passing, marginalization at height k is correct.
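The claim can be checked on a toy chain (three binary variables and random potentials, assumed for illustration): the leaf messages marginalize their subtrees exactly, so the belief at the middle variable matches the brute-force marginal:

```python
import numpy as np

# Sum-product messages on a 3-variable chain x1 - psi_a - x2 - psi_b - x3.
rng = np.random.default_rng(0)
psi_a = rng.random((2, 2))    # psi_a(x1, x2)
psi_b = rng.random((2, 2))    # psi_b(x2, x3)

# Leaf-to-root messages toward x2 (each marginalizes a height-1 subtree):
m_a2 = psi_a.sum(axis=0)      # m_{a->2}(x2) = sum_{x1} psi_a(x1, x2)
m_b2 = psi_b.sum(axis=1)      # m_{b->2}(x2) = sum_{x3} psi_b(x2, x3)
belief = m_a2 * m_b2
belief /= belief.sum()        # b_2(x2) ∝ product of incoming messages

# Brute-force marginal of x2 from p(x) ∝ psi_a(x1,x2) psi_b(x2,x3):
joint = psi_a[:, :, None] * psi_b[None, :, :]   # indexed [x1, x2, x3]
marg = joint.sum(axis=(0, 2))
marg /= marg.sum()
assert np.allclose(belief, marg)
```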

Correctness

Initialize all messages to all-one vectors. Then:

[Figure: factor graph tree rooted at x_i, with factor a connected to x_i, x_j, x_k and further factors b, c, … below.]

m_ai(x_i) = Σ_{x_j} Σ_{x_k} Σ_{x_l} Σ_{x_…} ψ_a(x_i, x_j, x_k) ψ_b(…) …

Pushing the summations inside the products,

m_ai(x_i) = Σ_{x_j} Σ_{x_k} ψ_a(x_i, x_j, x_k) [ Σ_{x_l} Σ_{x_…} ψ_b(…) ψ_c(…) … ]

and the bracketed sums can be computed at m_bj(x_j) instead.

Mean Field Update Derivation

In the mean field approximation, we find q that minimizes the variational free energy

F(q) = Σ_x q(x) E(x) + Σ_x q(x) ln q(x)

when q is restricted to the factored form

q_MF(x) = Π_{j=1}^N q_j(x_j).

Consider the dependence on a single variable q_i(x_i), with all other variables fixed:

F(q) = Σ_x Π_{j=1}^N q_j(x_j) E(x) + Σ_x Π_{j=1}^N q_j(x_j) ln Π_{j=1}^N q_j(x_j)

= Σ_{x_i} q_i(x_i) Σ_{x\x_i} Π_{j≠i} q_j(x_j) E(x) + Σ_{x_i} q_i(x_i) ln q_i(x_i) + const

= Σ_{x_i} q_i(x_i) ln (1 / b_i(x_i)) + Σ_{x_i} q_i(x_i) ln q_i(x_i) + const

where

b_i(x_i) = (1/Z_i) exp( − Σ_{x\x_i} Π_{j≠i} q_j(x_j) E(x) )

so that

F(q) = KL(q_i || b_i) + const

which is minimized at q_i(x_i) = b_i(x_i).
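A toy instance of the resulting update (a two-spin model with energy E(x) = −J x₁x₂, chosen only for illustration):

```python
import numpy as np

# Mean field coordinate updates q_i(x_i) <- b_i(x_i), where b_i is
# proportional to exp(-E averaged over the other variable's q).
J = 2.0                                  # coupling strong enough to break symmetry
states = np.array([-1.0, 1.0])           # x in {-1, +1}

def mf_update(q_other):
    m = states @ q_other                 # E[x_other] under q_other
    b = np.exp(J * states * m)           # b(x_i) ∝ exp(-E_avg(x_i)) = exp(J*x_i*m)
    return b / b.sum()

q1 = np.array([0.5, 0.5])
q2 = np.array([0.9, 0.1])                # start biased toward x2 = -1
for _ in range(50):                      # coordinate updates to a fixed point
    q1 = mf_update(q2)
    q2 = mf_update(q1)

assert np.allclose(q1, q2)               # converged: both marginals agree
assert q1[0] > 0.5                       # spins aligned with the initial bias
```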

Variational Free Energy for Trees

We can write a tree distribution as

q(x) = Π_a q_a(x_a) Π_i q_i(x_i) / Π_a Π_{i∈N(a)} q_i(x_i)

[Figure: factor tree with variables x_1, x_2, x_3, x_4 and factors A, B, C]

To see this, we can write a tree distribution as

q(x) = q_r(x_r) Π_a q_a(x_a | pa(x_a))

where x_r is the root and pa(x_a) is the parent variable for factor a in the tree. We write

q_a(x_a | pa(x_a)) = q_a(x_a) / q_{pa(a)}(pa(x_a))

Each variable appears as a parent d − 1 times, where d is the degree of the variable node, except the root, which appears d times. As q_r(x_r) also appears in the numerator, we can write the distribution as

Π_a q_a(x_a) / Π_i q_i(x_i)^{d_i − 1}

Multiplying the numerator and denominator by Π_i q_i(x_i) and observing that Π_a Π_{i∈N(a)} q_i(x_i) = Π_i q_i(x_i)^{d_i} gives the required expression.

Substituting

q(x) = Π_a q_a(x_a) Π_i q_i(x_i) / Π_a Π_{i∈N(a)} q_i(x_i)

into

F(q) = Σ_x q(x) E(x) + Σ_x q(x) ln q(x)

we get

F(q) = − Σ_{a=1}^M Σ_{x_a} q_a(x_a) ln ψ_a(x_a) + Σ_{a=1}^M Σ_{x_a} q_a(x_a) ln [ q_a(x_a) / Π_{i∈N(a)} q_i(x_i) ] + Σ_{i=1}^N Σ_{x_i} q_i(x_i) ln q_i(x_i)

Belief Propagation as Stationary Point of Bethe Free Energy

Consider optimizing the Bethe free energy subject to the constraints described. We form the Lagrangian

L = F_Bethe + Σ_a γ_a [1 − Σ_{x_a} q_a(x_a)] + Σ_i γ_i [1 − Σ_{x_i} q_i(x_i)] + Σ_i Σ_{a∈N(i)} Σ_{x_i} λ_ai(x_i) [q_i(x_i) − Σ_{x_a\x_i} q_a(x_a)]

Differentiating with respect to q_i(x_i) and setting to 0:

0 = −d_i + 1 + ln q_i(x_i) − γ_i + Σ_{a∈N(i)} λ_ai(x_i)

q_i(x_i) ∝ exp( − Σ_{a∈N(i)} λ_ai(x_i) ) = Π_{a∈N(i)} m_ai(x_i)

where d_i is the degree of node i and m_ai(x_i) = exp(−λ_ai(x_i)).

F_Bethe(q) = U_Bethe(q) − H_Bethe(q)

U_Bethe(q) = − Σ_{a=1}^M Σ_{x_a} q_a(x_a) ln ψ_a(x_a)

H_Bethe(q) = − Σ_{a=1}^M Σ_{x_a} q_a(x_a) ln [ q_a(x_a) / Π_{i∈N(a)} q_i(x_i) ] − Σ_{i=1}^N Σ_{x_i} q_i(x_i) ln q_i(x_i)


Previously:

q_i(x_i) ∝ exp( − Σ_{a∈N(i)} λ_ai(x_i) )

Consider optimizing the Bethe free energy subject to the constraints described, with the Lagrangian

L = F_Bethe + Σ_a γ_a [1 − Σ_{x_a} q_a(x_a)] + Σ_i γ_i [1 − Σ_{x_i} q_i(x_i)] + Σ_i Σ_{a∈N(i)} Σ_{x_i} λ_ai(x_i) [q_i(x_i) − Σ_{x_a\x_i} q_a(x_a)]

Differentiating with respect to q_a(x_a) and setting to 0:

0 = − ln ψ_a(x_a) + 1 + ln q_a(x_a) − ln Π_{i∈N(a)} q_i(x_i) − γ_a − Σ_{i∈N(a)} λ_ai(x_i)

= − ln ψ_a(x_a) + ln q_a(x_a) + const + Σ_{i∈N(a)} Σ_{b∈N(i)} λ_bi(x_i) − Σ_{i∈N(a)} λ_ai(x_i)

q_a(x_a) ∝ ψ_a(x_a) exp( − Σ_{i∈N(a)} Σ_{b∈N(i)\a} λ_bi(x_i) ) = ψ_a(x_a) Π_{i∈N(a)} Π_{b∈N(i)\a} m_bi(x_i)

where m_bi(x_i) = exp(−λ_bi(x_i))

For Σ_{x_a\x_i} q_a(x_a) to be consistent with q_i(x_i) = Π_{b∈N(i)} m_bi(x_i), we get the belief propagation equation

m_ai(x_i) = Σ_{x_a\x_i} ψ_a(x_a) Π_{j∈N(a)\i} Π_{b∈N(j)\a} m_bj(x_j)

Stationary points of the Bethe free energy are fixed points of the loopy belief propagation updates.

Previously:

q_i(x_i) = Π_{b∈N(i)} m_bi(x_i)

q_a(x_a) = ψ_a(x_a) Π_{j∈N(a)} Π_{b∈N(j)\a} m_bj(x_j)

[Fig from Yedidia et al. 2005]

Bellman Update Contraction

We show that the Bellman update is a contraction:

‖BV_1 − BV_2‖ ≤ γ ‖V_1 − V_2‖

First we show that for any two functions f(a) and g(a),

|max_a f(a) − max_a g(a)| ≤ max_a |f(a) − g(a)|

To see this, consider the case where max_a f(a) ≥ max_a g(a). Then

|max_a f(a) − max_a g(a)| = max_a f(a) − max_a g(a)
= f(a*) − max_a g(a)
≤ f(a*) − g(a*)
≤ max_a |f(a) − g(a)|

where a* = argmax_a f(a). The case where max_a g(a) ≥ max_a f(a) is similar.

Now, for any state s,

BV_1(s) − BV_2(s)
= max_a Σ_{s'} P(s'|s, a) [R(s, a, s') + γ V_1(s')] − max_a Σ_{s'} P(s'|s, a) [R(s, a, s') + γ V_2(s')]
≤ max_a | Σ_{s'} P(s'|s, a) [R(s, a, s') + γ V_1(s')] − Σ_{s'} P(s'|s, a) [R(s, a, s') + γ V_2(s')] |
= γ max_a | Σ_{s'} P(s'|s, a) (V_1(s') − V_2(s')) |
= γ | Σ_{s'} P(s'|s, a*) (V_1(s') − V_2(s')) |

where we have used |max_a f(a) − max_a g(a)| ≤ max_a |f(a) − g(a)|.

Finally, we show the contraction:

‖BV_1 − BV_2‖ = max_s |BV_1(s) − BV_2(s)|
≤ γ max_s Σ_{s'} P(s'|s, a*) |V_1(s') − V_2(s')|
≤ γ max_s |V_1(s) − V_2(s)|
= γ ‖V_1 − V_2‖