(This notebook was assembled in parts from an official Pyro example and Chapter 1 of the Problang webbook. The code is adapted from the Pyro example to fit the educational structure of the Problang chapter. I attempted to stay as close to the webbook as possible, but some minor alterations have been made that hopefully stay true to the spirit of the Problang book. ~ Marv)
Much work in formal semantics follows the tradition of positing systematic but inflexible theories of meaning. However, in practice, the meaning we derive from language is heavily dependent on nearly all aspects of context, both linguistic and situational. To formally explain these nuanced aspects of meaning and better understand the compositional mechanism that delivers them, recent work in formal pragmatics recognizes semantics not as one of the final steps in meaning calculation, but rather as one of the first.
Within the Bayesian Rational Speech Act (RSA) framework (Frank and Goodman, 2012), speakers and listeners reason about each other’s reasoning about the literal interpretation of utterances. The resulting interpretation necessarily depends on the literal interpretation of an utterance, but is not necessarily wholly determined by it. This move, reasoning about likely interpretations, provides ready explanations for complex phenomena ranging from metaphor (Kao et al., 2014) and hyperbole (Kao et al., 2014) to irony and humor.
More recent research has also produced practical applications of RSA, such as a challenge-winning RSA-based Vision-and-Language Navigation agent (Anderson et al., 2017) and an efficient multi-agent reinforcement learning framework evaluated on a StarCraft II benchmark (Kang et al., 2020).
# first some imports
import torch  # Pyro is built on top of PyTorch
torch.set_default_dtype(torch.float64) # double precision for numerical stability
from collections import namedtuple
import argparse
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro_rsa_book_utils.search_inference import Marginal
from pyro_rsa_book_utils import visualize
import warnings
warnings.filterwarnings("ignore")
Probabilistic Programming Languages (PPLs) allow for a mix of ordinary deterministic computation and randomly sampled values representing a generative process for data.
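Probabilistic programs in Pyro are built up around samples from primitive probability distributions, marked by `pyro.sample`, and around `pyro.factor` statements, which add an arbitrary term to the log-probability of the current execution: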
def pyro_program():
    # introduce randomness through pyro.sample:
    unfair_coin = pyro.sample("example", dist.Categorical(probs=torch.tensor([0.1, 0.9]))).item()
    # pyro functions may include arbitrary control flow:
    if unfair_coin == 1:
        return 10
    else:
        return 5
pyro_program()
10
Running `pyro_program()` several times shows that it is indeed random: it returns 10 with probability 0.9 and 5 otherwise.
How is this better than using a library such as Python's `random`, you ask?
Well, Pyro contains many tools to infer the distributions of `pyro.sample` statements and of return values, given the function arguments and the internal control flow.
Let's decorate the above function with `@Marginal`, which automatically constructs a distribution over its return values without needing to inspect the code inside the function. It does this through an exhaustive search over all possible executions.
@Marginal  # <- runs an exhaustive search over all sample/factor/.. calls to get an accurate distribution over return values
def pyro_program_decorated():
    # control flow may include Pyro's "sample" or "factor" statements
    unfair_coin = pyro.sample("example", dist.Categorical(probs=torch.tensor([0.1, 0.9])))
    if unfair_coin == 1:
        return 10
    else:
        return 5
distribution = pyro_program_decorated()
visualize(distribution)
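To make the difference concrete, here is a minimal plain-Python sketch (not Pyro, and not how `pyro_rsa_book_utils` implements `Marginal` internally): sampling the coin with Python's `random` module only estimates the distribution over return values, whereas visiting every branch of the program and adding up the probability of each branch gives it exactly. The helper names `run_coin_program` and `enumerate_coin_program` are hypothetical.

```python
import random
from collections import Counter, defaultdict

COIN_PROBS = [0.1, 0.9]  # same probabilities as the Categorical in pyro_program


def run_coin_program(rng):
    """One forward execution, sampling the coin with the random module."""
    unfair_coin = rng.choices([0, 1], weights=COIN_PROBS)[0]
    return 10 if unfair_coin == 1 else 5


def enumerate_coin_program():
    """Exact distribution over return values by visiting every execution path."""
    marginal = defaultdict(float)
    for coin_value, p in enumerate(COIN_PROBS):
        return_value = 10 if coin_value == 1 else 5  # same control flow as pyro_program
        marginal[return_value] += p  # accumulate probability mass per return value
    return dict(marginal)


rng = random.Random(0)
samples = Counter(run_coin_program(rng) for _ in range(1000))
print({value: count / 1000 for value, count in sorted(samples.items())})  # approximate
print(enumerate_coin_program())  # exact: {5: 0.1, 10: 0.9}
```

The Monte Carlo estimate changes with the seed and the number of samples, while the enumeration does not. Exhaustive search is attractive for small, discrete models like the ones in this notebook, but it becomes expensive as the number of execution paths grows.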
For your solutions and/or a better understanding, feel free to consult everything in `pyro_rsa_book_utils`, or uncomment the lines below:
# print(dir(distribution))
# help(distribution)
# distribution.log_prob?
Useful starting points among Pyro's excellent explanatory example notebooks are the introductions to Pyro models and to inference in Pyro, as well as the RSA examples.
The Rational Speech Act (RSA) framework views communication as recursive reasoning between a speaker and a listener. The listener interprets the speaker’s utterance by reasoning about a cooperative speaker trying to inform a naive listener about some state of affairs.
Using Bayesian inference, the listener reasons about what the state of the world is likely to be given that a speaker produced some utterance, knowing that the speaker is reasoning about how a listener is most likely to interpret that utterance.
Thus, we have (at least) three levels of inference: a literal listener $L_0$, a pragmatic speaker $S_1$ who reasons about $L_0$, and a pragmatic listener $L_1$ who reasons about $S_1$.
To make this architecture more intelligible, let’s consider a concrete example and a vanilla version of an RSA model. In its initial formulation, Frank and Goodman (2012) use the basic RSA framework to model referent choice in efficient communication. Let’s suppose that there are only three objects that the speaker and listener want to talk about, as in Fig. 1.
Fig. 1: Example referential communication scenario from Frank and Goodman (2012). Speakers make a one-word utterance, $u$, to signal an object, $s$.
In this reference game, a speaker wants to refer to one of the given objects:
$S = \{\text{blue square}, \text{blue circle}, \text{green square}\}$
The speaker may only utter one property to do this:
$U = \{\text{“square”}, \text{“circle”}, \text{“green”}, \text{“blue”}\}$
As mentioned before, a vanilla RSA model for this scenario consists of three recursively layered conditional probability rules for speaker production and listener interpretation. The idea in this game is that a pragmatic speaker $S_1$ chooses a word $u$ to best signal an object $s$ to a literal listener $L_0$, who interprets $u$ literally and finds the objects that are compatible with the meaning of $u$. The pragmatic listener $L_1$ reasons about the speaker’s reasoning and interprets $u$ accordingly, using Bayes’ rule; $L_1$ also takes into account the prior probability of objects in the scenario (i.e., an object’s salience, $\mathbb{P}(s)$):
$$P_{L_1}(s\mid u)\propto P_{S_1}(u\mid s)\cdot\mathbb{P}(s)$$
$$P_{S_1}(u\mid s)\propto\exp(\alpha\cdot U_{S_1}(u;s))$$
$$\mathbb{P}_{L_0}(s\mid u)\propto[[u]](s)\cdot\mathbb{P}(s)$$
At the base of this reasoning hierarchy, the naive, literal listener $L_0$ interprets an utterance according to its meaning. That is, $L_0$ computes the probability of $s$ given $u$ according to the semantics of $u$ and the prior probability of $s$. A standard view of the semantic content of an utterance suffices: a mapping from states of the world to truth values. For example, the utterance "blue" is true of states $blue\;square$ and $blue\;circle$ and false of state $green\;square$.
We write $[[u]]:S\to\{0,1\}$ for the denotation function of this standard Boolean semantics of utterances in terms of states. The literal listener is then defined via a function $\mathbb{P}_{L_0}$ that maps each utterance to a probability distribution over world states, like so:
$$\mathbb{P}_{L_0}(s\mid u)\propto[[u]](s)\cdot\mathbb{P}(s)$$
Here, $\mathbb{P}(s)$ is an a priori belief regarding which state or object the speaker is likely to refer to in general. These prior beliefs can capture general world knowledge, perceptual salience, or other things. For the time being, we assume a flat prior belief, meaning each object is equally likely.
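The literal listener rule can also be checked by hand. The following plain-Python sketch (no Pyro; the function name `literal_listener_table` and its optional `prior` argument are hypothetical) computes $\mathbb{P}_{L_0}(s\mid u)$ by multiplying the Boolean meaning $[[u]](s)$ with the prior and renormalizing over all objects:

```python
objects = ["blue square", "blue circle", "green square"]


def meaning(utterance, obj):
    """Boolean semantics [[u]](s): 1.0 if the word names a property of obj, else 0.0.

    Word-level match; the notebook's lambda uses a substring test, which agrees here.
    """
    return 1.0 if utterance in obj.split() else 0.0


def literal_listener_table(utterance, prior=None):
    """P_L0(s | u) ∝ [[u]](s) * P(s), normalized over all objects."""
    if prior is None:
        prior = [1.0 / len(objects)] * len(objects)  # flat prior over objects
    scores = [meaning(utterance, s) * p for s, p in zip(objects, prior)]
    total = sum(scores)  # > 0 as long as some object with nonzero prior fits the utterance
    return {s: score / total for s, score in zip(objects, scores)}


print(literal_listener_table("blue"))
print(literal_listener_table("blue", prior=[0.2, 0.6, 0.2]))  # salient blue circle
```

With a flat prior, "blue" splits its mass evenly between the two blue objects and rules out the green square; passing a non-uniform `prior` shows how salience shifts the literal listener's interpretation.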
The literal listener rule can be implemented as follows in Pyro:
# Set of states as in Fig. 1.
objects = [
    "blue square",
    "blue circle",
    "green square"
]
# uniform distribution over world states
def object_prior():
    prior_dist = dist.Categorical(probs=torch.ones(len(objects)) / len(objects))
    idx = pyro.sample("object", prior_dist)
    return objects[idx]
# set of utterances (all property values, in this case)
utterances = ["blue", "green", "square", "circle"]
# literal meaning function [[u]](s): true iff the word occurs in the object's name
meaning = lambda utterance, obj: utterance in obj
@Marginal
def literal_listener(utterance):
    obj = object_prior()
    # log-weights: 0 keeps objects compatible with the utterance, -999999 effectively rules out the rest
    pyro.factor("literal_meaning", 0. if meaning(utterance, obj) else -999999.)
    return obj
Let's run this!
utterance = "blue"
l0_dist = literal_listener(utterance)
visualize(l0_dist)
`object_prior()` returns a sample from a uniform `Categorical` distribution over the possible objects of reference. What happens when the listener's beliefs are not uniform over the possible objects of reference (e.g., the "blue circle" is very salient)?
`utterance`?
`pyro.sample` and `pyro.factor` do, respectively? Consult the Pyro Docs. (These, along with `pyro.param`, are the three most basic Pyro functions, which are called primitives.)
`literal_listener` and `marginal_l0_dist`. What is the difference?
Fantastic! We now have a way of integrating a listener’s prior beliefs about the world with the truth-functional meaning of an utterance.
Speech acts are actions; thus, the speaker is modeled as a rational (Bayesian) actor. He chooses an action (e.g., an utterance) according to its utility. The speaker simulates taking an action, evaluates its utility, and makes a choice based on the utilities. Rationality of choice is often defined as choice of an action that maximizes the agent’s (expected) utility. Here we consider a generalization in which speakers use a softmax function to approximate the (classical) rational choice to a variable degree. (For more on action as inverse planning, see www.agentmodels.org.)
In the code box below, you’ll see a generic approximately rational agent model. In rough terms, calling the decorated `agent` triggers a `Search` that enumerates the three actions. For each action, it computes a log-probability (the log prior probability of drawing that action from the uniform distribution, plus the value contributed by the `factor` statement), and then turns these into a `Distribution` by normalizing the resulting probabilities.
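In effect, the function `agent` therefore computes the distribution
$$P(a_i)=\frac{\exp(\alpha\cdot\mathrm{Util}(a_i))}{\sum_j\exp(\alpha\cdot\mathrm{Util}(a_j))}$$
Here, $\alpha$ is called the optimality (or rationality) parameter; it acts as an inverse temperature, so larger values make the choice more deterministic. The same softmax distribution commonly appears in the output layer of neural-network classifiers.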
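The softmax rule $P(a_i)\propto\exp(\alpha\cdot\mathrm{Util}(a_i))$ can be evaluated directly for the three actions. A minimal plain-Python sketch, assuming the uniform action prior and the utilities 1, 2, 3 from the Pyro code that follows (the uniform prior adds the same constant $\log\frac{1}{3}$ to every action, so it cancels in the normalization); `softmax_agent` is a hypothetical helper name:

```python
import math

actions = ["a1", "a2", "a3"]
utilities = {"a1": 1.0, "a2": 2.0, "a3": 3.0}  # same values as the notebook's lookup table


def softmax_agent(alpha):
    """P(a_i) = exp(alpha * Util(a_i)) / sum_j exp(alpha * Util(a_j))."""
    weights = {a: math.exp(alpha * utilities[a]) for a in actions}
    total = sum(weights.values())
    return {a: w / total for a, w in weights.items()}


for alpha in [0.0, 1.0, 10.0]:
    print(alpha, {a: round(p, 3) for a, p in softmax_agent(alpha).items()})
```

At $\alpha = 0$ the agent ignores utilities and chooses uniformly; at $\alpha = 1$ it picks a3 with probability of about 0.665; as $\alpha$ grows it approaches a strict utility maximizer that always picks a3.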
# aside on Bayesian decision making
# define action space
actions = ["a1", "a2", "a3"]
# define uniform prior over actions
def action_prior():
    action_dist = dist.Categorical(probs=torch.ones(len(actions)) / len(actions))
    action = pyro.sample("action", action_dist)
    return action
# define utilities for the actions
def utility(tensor_action_idx):
    action = actions[int(tensor_action_idx.item())]
    lookup_table = {
        "a1": 1.,
        "a2": 2.,
        "a3": 3.
    }
    return lookup_table[action]
# define actor optimality
alpha = 1
# define rational agent who chooses actions
# according to their expected utility
@Marginal
def agent():
    action = action_prior()
    pyro.factor("utility", alpha * utility(action))
    return action
agent_dist = agent()
visualize(agent_dist)
`utility()` returns the correct value for $a_3$.
In language understanding, the utility of an utterance is how well it communicates the state of the world $s$ to a listener. So, the speaker $S_1$ chooses utterances $u$ to communicate the state $s$ to the hypothesized literal listener $L_0$. Another way to think about this: $S_1$ wants to minimize the effort $L_0$ would need to arrive at $s$ from $u$, all while being efficient at communicating. $S_1$ thus seeks to minimize the surprisal of $s$ given $u$ for the literal listener $L_0$, while bearing in mind the utterance cost, $C(u)$. (This trade-off between efficacy and efficiency is not trivial: speakers could always aim for minimal ambiguity, but unambiguous utterances tend toward the unwieldy and are very often unnecessary. We will see this tension play out later in the book.)
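Speakers act in accordance with the speaker’s utility function $U_{S_1}$: utterances are more useful at communicating about some state as surprisal and utterance cost decrease. Since the surprisal of $s$ given $u$ for $L_0$ is $-\log L_0(s\mid u)$, the utility is
$$U_{S_1}(u;s)=\log L_0(s\mid u)-C(u)$$
(See the Appendix Chapter 2 for more on speaker utilities.)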
With this utility function in mind, $S_1$ computes the probability of an utterance $u$ given some state $s$ in proportion to the exponentiated utility, scaled by $\alpha$. We can now interpret $\alpha > 0$ as the speaker's rationality in choosing utterances. We define
$$P_{S_1}(u\mid s)\propto\exp(\alpha\cdot U_{S_1}(u;s)),$$
which expands to
$$P_{S_1}(u\mid s)\propto\exp(\alpha(\log L_0(s\mid u)-C(u))).$$
The following code implements this model of the speaker:
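# pragmatic speaker S1: one possible implementation, following the structure of Pyro's RSA example.
# Assumptions: uniform utterance prior and zero utterance cost C(u) = 0; alpha is the global defined above.
@Marginal
def pragmatic_speaker(obj):
    # draw a candidate utterance uniformly at random
    utt_idx = pyro.sample("utterance", dist.Categorical(probs=torch.ones(len(utterances)) / len(utterances)))
    utterance = utterances[utt_idx]
    # add alpha * log P_L0(obj | utterance) to the log-weight by observing obj under the scaled literal listener
    with poutine.scale(scale=torch.tensor(float(alpha))):
        pyro.sample("s1_simulation_of_l0", literal_listener(utterance), obs=obj)
    return utterance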
idx = 0
s1_dist = pragmatic_speaker(objects[idx])
visualize(s1_dist)
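As a hand check for the speaker distribution, take $s = blue\;circle$ with $\alpha = 1$ and $C(u) = 0$ for all utterances (the simplest setting; other choices change the numbers). Under a flat prior, $L_0$ assigns $blue\;circle$ probability $\frac{1}{2}$ after “blue” (there are two blue objects), $1$ after “circle”, and $0$ after “green” or “square”, so

$$P_{S_1}(\text{“circle”}\mid blue\;circle)=\frac{\exp(\log 1)}{\exp(\log\frac{1}{2})+\exp(\log 1)}=\frac{1}{\frac{1}{2}+1}=\frac{2}{3},\qquad P_{S_1}(\text{“blue”}\mid blue\;circle)=\frac{1}{3}$$

Utterances that are false of $s$ receive $\exp(-\infty)=0$. The speaker prefers the more specific word “circle”, since it leaves $L_0$ no room for confusion.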
speaker appropriately. You could make reasonable choices for $\alpha$ and $C(\cdot)$, or experiment. It should be enough to keep all distributions uniform for a start.
The pragmatic listener $L_1$ computes the probability of a state $s$ given some utterance $u$. Because $L_1$ reasons about the speaker $S_1$, this probability is proportional to the probability that $S_1$ would choose to utter $u$ to communicate about the state $s$, together with the prior probability of $s$ itself. In other words, to interpret an utterance, the pragmatic listener considers the process that generated the utterance in the first place.
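$$P_{L_1}(s\mid u)\propto P_{S_1}(u\mid s)\cdot\mathbb{P}(s)$$
@Marginal
def pragmatic_listener(utterance):
    # sample a candidate object from the prior P(s)
    obj = object_prior()
    # S1's distribution over utterances for that object
    speaker_dist = pragmatic_speaker(obj)
    # condition on observed utterance:
    pyro.sample("l1_simulation_of_speaker", speaker_dist, obs=utterance)
    return obj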
utt = "blue"
l1_dist = pragmatic_listener(utt)
visualize(l1_dist)
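For reference, the whole vanilla RSA hierarchy fits into a few lines of plain Python when the objects and utterances are enumerated by hand. This is a minimal sketch assuming a flat object prior, a uniform utterance prior, and zero utterance cost; it uses the identity $\exp(\alpha\log p)=p^{\alpha}$ for $p>0$, with false utterances getting probability $0$ (so $\alpha$ must be positive). The function names are hypothetical and not part of `pyro_rsa_book_utils`:

```python
objects = ["blue square", "blue circle", "green square"]
utterances = ["blue", "green", "square", "circle"]
prior = {s: 1.0 / len(objects) for s in objects}  # flat object prior P(s)


def normalize(scores):
    """Rescale non-negative scores so they sum to one."""
    total = sum(scores.values())
    return {k: v / total for k, v in scores.items()}


def literal_listener(u):
    """P_L0(s | u) ∝ [[u]](s) * P(s)."""
    return normalize({s: float(u in s.split()) * prior[s] for s in objects})


def speaker(s, alpha=1.0):
    """P_S1(u | s) ∝ exp(alpha * log P_L0(s | u)) = P_L0(s | u) ** alpha, with C(u) = 0, alpha > 0."""
    return normalize({u: literal_listener(u)[s] ** alpha for u in utterances})


def pragmatic_listener(u, alpha=1.0):
    """P_L1(s | u) ∝ P_S1(u | s) * P(s)."""
    return normalize({s: speaker(s, alpha)[u] * prior[s] for s in objects})


print(pragmatic_listener("blue"))
```

Hearing “blue”, the pragmatic listener favors the blue square (about 0.6 vs. 0.4 for the blue circle), because a speaker who meant the blue circle would more likely have said “circle”. This gives a target for checking the Pyro `pragmatic_listener`; the Pyro version uses a large negative log-weight (-999999) instead of exact zeros, so any differences should be negligible.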
`pragmatic_listener`.
Awesome, that's it. We've reached the top of the vanilla RSA inference hierarchy, so pat yourself on the back for making it this far.
Let's quickly throw it all together again for some more exercises.
# TODO copy everything from above
"blue" to refer to something green? Why or why not?
belong to, and who, out of speaker and listener, has access to them). How would you address this, i.e., think of a speaker and listener that do not have access to each other? What additional problems would you expect in this case? (For further discussion, see Sec. 4.3 in Kang et al. (2020).)
Happy coding :)