1
- # Copyright 2018 The Texar Authors. All Rights Reserved.
1
+ # Copyright 2019 The Texar Authors. All Rights Reserved.
2
2
#
3
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
4
# you may not use this file except in compliance with the License.
@@ -54,7 +54,10 @@ def _assert_same_size(outputs, output_size):
54
54
flat_output = nest .flatten (outputs )
55
55
56
56
for (output , size ) in zip (flat_output , flat_output_size ):
57
- if output [0 ].shape != tf .TensorShape (size ):
57
+ if isinstance (size , tf .TensorShape ):
58
+ if output .shape == size :
59
+ pass
60
+ elif output [0 ].shape != tf .TensorShape (size ):
58
61
raise ValueError (
59
62
"The output size does not match the the required output_size" )
60
63
@@ -518,7 +521,8 @@ class instance.
518
521
- output: A Tensor or a (nested) tuple of Tensors with the same \
519
522
structure and size of :attr:`output_size`. The batch dimension \
520
523
equals :attr:`num_samples` if specified, or is determined by the \
521
- distribution dimensionality.
524
+ distribution dimensionality. If :attr:`transform` is `False`, \
525
+ :attr:`output` will be equal to :attr:`sample`.
522
526
- sample: The sample from the distribution, prior to transformation.
523
527
524
528
Raises:
@@ -549,9 +553,10 @@ class instance.
549
553
fn_modules = ['tensorflow' , 'tensorflow.nn' , 'texar.custom' ]
550
554
activation_fn = get_function (self .hparams .activation_fn , fn_modules )
551
555
output = _mlp_transform (sample , self ._output_size , activation_fn )
556
+ else :
557
+ output = sample
552
558
553
559
_assert_same_size (output , self ._output_size )
554
-
555
560
if not self ._built :
556
561
self ._add_internal_trainable_variables ()
557
562
self ._built = True
@@ -616,7 +621,7 @@ def default_hparams():
616
621
def _build (self ,
617
622
distribution = 'MultivariateNormalDiag' ,
618
623
distribution_kwargs = None ,
619
- transform = False ,
624
+ transform = True ,
620
625
num_samples = None ):
621
626
"""Samples from a distribution and optionally performs transformation
622
627
with an MLP layer.
@@ -649,7 +654,8 @@ class instance.
649
654
- output: A Tensor or a (nested) tuple of Tensors with the same \
650
655
structure and size of :attr:`output_size`. The batch dimension \
651
656
equals :attr:`num_samples` if specified, or is determined by the \
652
- distribution dimensionality.
657
+ distribution dimensionality. If :attr:`transform` is `False`, \
658
+ :attr:`output` will be equal to :attr:`sample`.
653
659
- sample: The sample from the distribution, prior to transformation.
654
660
655
661
Raises:
@@ -661,31 +667,32 @@ class instance.
661
667
"tensorflow.contrib.distributions" , "texar.custom" ])
662
668
663
669
if num_samples :
664
- output = dstr .sample (num_samples )
670
+ sample = dstr .sample (num_samples )
665
671
else :
666
- output = dstr .sample ()
672
+ sample = dstr .sample ()
667
673
668
674
if dstr .event_shape == []:
669
- output = tf .reshape (output ,
670
- output .shape .concatenate (tf .TensorShape (1 )))
675
+ sample = tf .reshape (sample ,
676
+ sample .shape .concatenate (tf .TensorShape (1 )))
671
677
672
678
# Disable gradients through samples
673
- output = tf .stop_gradient (output )
679
+ sample = tf .stop_gradient (sample )
674
680
675
- output = tf .cast (output , tf .float32 )
681
+ sample = tf .cast (sample , tf .float32 )
676
682
677
683
if transform :
678
684
fn_modules = ['tensorflow' , 'tensorflow.nn' , 'texar.custom' ]
679
685
activation_fn = get_function (self .hparams .activation_fn , fn_modules )
680
- output = _mlp_transform (output , self ._output_size , activation_fn )
686
+ output = _mlp_transform (sample , self ._output_size , activation_fn )
687
+ else :
688
+ output = sample
681
689
682
690
_assert_same_size (output , self ._output_size )
683
-
684
691
if not self ._built :
685
692
self ._add_internal_trainable_variables ()
686
693
self ._built = True
687
694
688
- return output
695
+ return output , sample
689
696
690
697
691
698
#class ConcatConnector(ConnectorBase):
0 commit comments