Skip to content

Commit 962294b

Browse files
authored
Merge pull request #179 from TomNong/connector-update
Fix *StochasticConnector when `transform=False`
2 parents 5a8fb32 + 1d71c15 commit 962294b

File tree

4 files changed

+55
-17
lines changed

4 files changed

+55
-17
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
* Fix [GPT-2](https://github.com/asyml/texar/tree/master/examples/gpt-2) tokenization loading path. ([#165](https://github.com/asyml/texar/pull/165))
2222
* Fix [examples/vae_text](https://github.com/asyml/texar/tree/master/examples/vae_text) EOS bug. ([#168](https://github.com/asyml/texar/pull/168))
2323
* Fix transformer [bleu_tool.py](https://github.com/asyml/texar/blob/master/examples/transformer/bleu_tool.py) when `translation_length` is 0. ([#176](https://github.com/asyml/texar/pull/176))
24+
* Fix `StochasticConnector` and `ReparameterizedStochasticConnector` when `transform=False`. ([#179](https://github.com/asyml/texar/pull/179))
2425

2526
## [v0.2.0](https://github.com/asyml/texar/releases/tag/v0.2.0) (2019-04-09)
2627

texar/modules/connectors/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,3 @@
2323

2424
from texar.modules.connectors.connector_base import *
2525
from texar.modules.connectors.connectors import *
26-

texar/modules/connectors/connectors.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2018 The Texar Authors. All Rights Reserved.
1+
# Copyright 2019 The Texar Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -54,7 +54,10 @@ def _assert_same_size(outputs, output_size):
5454
flat_output = nest.flatten(outputs)
5555

5656
for (output, size) in zip(flat_output, flat_output_size):
57-
if output[0].shape != tf.TensorShape(size):
57+
if isinstance(size, tf.TensorShape):
58+
if output.shape == size:
59+
pass
60+
elif output[0].shape != tf.TensorShape(size):
5861
raise ValueError(
5962
"The output size does not match the the required output_size")
6063

@@ -518,7 +521,8 @@ class instance.
518521
- output: A Tensor or a (nested) tuple of Tensors with the same \
519522
structure and size of :attr:`output_size`. The batch dimension \
520523
equals :attr:`num_samples` if specified, or is determined by the \
521-
distribution dimensionality.
524+
distribution dimensionality. If :attr:`transform` is `False`, \
525+
:attr:`output` will be equal to :attr:`sample`.
522526
- sample: The sample from the distribution, prior to transformation.
523527
524528
Raises:
@@ -549,9 +553,10 @@ class instance.
549553
fn_modules = ['tensorflow', 'tensorflow.nn', 'texar.custom']
550554
activation_fn = get_function(self.hparams.activation_fn, fn_modules)
551555
output = _mlp_transform(sample, self._output_size, activation_fn)
556+
else:
557+
output = sample
552558

553559
_assert_same_size(output, self._output_size)
554-
555560
if not self._built:
556561
self._add_internal_trainable_variables()
557562
self._built = True
@@ -616,7 +621,7 @@ def default_hparams():
616621
def _build(self,
617622
distribution='MultivariateNormalDiag',
618623
distribution_kwargs=None,
619-
transform=False,
624+
transform=True,
620625
num_samples=None):
621626
"""Samples from a distribution and optionally performs transformation
622627
with an MLP layer.
@@ -649,7 +654,8 @@ class instance.
649654
- output: A Tensor or a (nested) tuple of Tensors with the same \
650655
structure and size of :attr:`output_size`. The batch dimension \
651656
equals :attr:`num_samples` if specified, or is determined by the \
652-
distribution dimensionality.
657+
distribution dimensionality. If :attr:`transform` is `False`, \
658+
:attr:`output` will be equal to :attr:`sample`.
653659
- sample: The sample from the distribution, prior to transformation.
654660
655661
Raises:
@@ -661,31 +667,32 @@ class instance.
661667
"tensorflow.contrib.distributions", "texar.custom"])
662668

663669
if num_samples:
664-
output = dstr.sample(num_samples)
670+
sample = dstr.sample(num_samples)
665671
else:
666-
output = dstr.sample()
672+
sample = dstr.sample()
667673

668674
if dstr.event_shape == []:
669-
output = tf.reshape(output,
670-
output.shape.concatenate(tf.TensorShape(1)))
675+
sample = tf.reshape(sample,
676+
sample.shape.concatenate(tf.TensorShape(1)))
671677

672678
# Disable gradients through samples
673-
output = tf.stop_gradient(output)
679+
sample = tf.stop_gradient(sample)
674680

675-
output = tf.cast(output, tf.float32)
681+
sample = tf.cast(sample, tf.float32)
676682

677683
if transform:
678684
fn_modules = ['tensorflow', 'tensorflow.nn', 'texar.custom']
679685
activation_fn = get_function(self.hparams.activation_fn, fn_modules)
680-
output = _mlp_transform(output, self._output_size, activation_fn)
686+
output = _mlp_transform(sample, self._output_size, activation_fn)
687+
else:
688+
output = sample
681689

682690
_assert_same_size(output, self._output_size)
683-
684691
if not self._built:
685692
self._add_internal_trainable_variables()
686693
self._built = True
687694

688-
return output
695+
return output, sample
689696

690697

691698
#class ConcatConnector(ConnectorBase):

texar/modules/connectors/connectors_test.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from texar.core import layers
1616
from texar.modules import ConstantConnector
1717
from texar.modules import MLPTransformConnector
18-
from texar.modules import ReparameterizedStochasticConnector
18+
from texar.modules import (ReparameterizedStochasticConnector,
19+
StochasticConnector)
1920
from texar.modules.connectors.connectors import _assert_same_size
2021

2122
# pylint: disable=too-many-locals, invalid-name
@@ -132,6 +133,36 @@ def test_reparameterized_stochastic_connector(self):
132133
# self.assertAlmostEqual(0, sample_mu[i], delta=0.2)
133134
# self.assertAlmostEqual(1, sample_var[i], delta=0.2)
134135

136+
def test_stochastic_connector(self):
137+
"""Tests the logic of
138+
:class:`~texar.modules.StochasticConnector`.
139+
"""
140+
state_size = (10, 10)
141+
variable_size = 100
142+
state_size_ts = tf.TensorShape([self._batch_size, variable_size])
143+
gauss_connector = StochasticConnector(state_size)
144+
mu = tf.zeros([self._batch_size, variable_size])
145+
var = tf.ones([self._batch_size, variable_size])
146+
gauss_ds = tfpd.MultivariateNormalDiag(loc=mu, scale_diag=var)
147+
output_1, _ = gauss_connector(gauss_ds)
148+
149+
gauss_connector_2 = StochasticConnector(state_size_ts)
150+
output_2, sample2 = gauss_connector_2(
151+
distribution="MultivariateNormalDiag",
152+
distribution_kwargs={"loc": mu, "scale_diag": var}, transform=False)
153+
test_list = [output_1, output_2, sample2]
154+
155+
with self.test_session() as sess:
156+
sess.run(tf.global_variables_initializer())
157+
out_list = sess.run(test_list)
158+
out1 = out_list[0]
159+
out2 = out_list[1]
160+
sample2 = out_list[2]
161+
self.assertEqual(out1[0].shape,
162+
tf.TensorShape([self._batch_size, state_size[0]]))
163+
self.assertEqual(out2.shape, state_size_ts)
164+
self.assertEqual(out2.shape, sample2.shape)
165+
135166
#def test_concat_connector(self): # pylint: disable=too-many-locals
136167
# """Tests the logic of
137168
# :class:`~texar.modules.connectors.ConcatConnector`.

0 commit comments

Comments
 (0)