From 96134a9d9a3e1bd32513977ac76d97716016615d Mon Sep 17 00:00:00 2001
From: Hirosora <474516010@qq.com>
Date: Sat, 10 Aug 2019 02:53:58 +0800
Subject: [PATCH] Add FiBiNET and NFFM models

---
 README.md         |   4 ++-
 core/blocks.py    | 101 ++++++++++++++++++++++++++++++++++++++++++++++++
 core/utils.py     |   4 +--
 models/FiBiNet.py |  67 ++++++++++++++++++++++++++++++++
 models/NFFM.py    |  78 ++++++++++++++++++++++++++++++++++++
 5 files changed, 251 insertions(+), 3 deletions(-)
 create mode 100644 models/FiBiNet.py
 create mode 100644 models/NFFM.py

diff --git a/README.md b/README.md
index 27bcc79..241e3bc 100644
--- a/README.md
+++ b/README.md
@@ -24,4 +24,6 @@ dataset](https://www.kaggle.com/c/avazu-ctr-prediction) with 10000 rows.
 | AutoInt | [arxiv 2018][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/abs/1810.11921) |
 | Convolutional Click Prediction Model (CCPM) | [CIKM 2015][A Convolutional Click Prediction Model](http://ir.ia.ac.cn/bitstream/173211/12337/1/A%20Convolutional%20Click%20Prediction%20Model.pdf) |
 | Feature Generation by Convolutional Neural Network (FGCNN) | [WWW 2019][Feature Generation by Convolutional Neural Network for Click-Through Rate Prediction ](https://arxiv.org/pdf/1904.04447) |
-| Mixed Logistic Regression (MLR) | [arxiv 2017][Learning Piece-wise Linear Models from Large Scale Data for Ad Click Prediction](https://arxiv.org/pdf/1704.05194.pdf) |
\ No newline at end of file
+| Mixed Logistic Regression (MLR) | [arxiv 2017][Learning Piece-wise Linear Models from Large Scale Data for Ad Click Prediction](https://arxiv.org/pdf/1704.05194.pdf) |
+| FiBiNET | [RecSys 2019][FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction](https://arxiv.org/pdf/1905.09433.pdf) |
+| NFFM | [arxiv 2019][Operation-aware Neural Networks for User Response Prediction](https://arxiv.org/pdf/1904.12579.pdf) |

diff --git a/core/blocks.py b/core/blocks.py
index c82a0e9..b8f70b7 100644
--- a/core/blocks.py
+++ b/core/blocks.py
@@ -1,4 +1,5 @@
 from collections import Iterable
+import itertools
 
 import tensorflow as tf
 from tensorflow.python.keras import layers
@@ -423,3 +424,103 @@ def call(self, inputs, **kwargs):
                 shape=(-1, output.shape[1] * self.new_feat_filters, output.shape[2]))
 
         return output, new_feat_output
+
+
+class BiInteraction(tf.keras.Model):
+    """Bilinear feature interaction from the FiBiNET paper.
+
+    mode='all': one weight matrix shared by all field pairs.
+    mode='each': one weight matrix per left-hand field.
+    mode='interaction': one weight matrix per field pair.
+    """
+
+    def __init__(self, mode='all', **kwargs):
+
+        super(BiInteraction, self).__init__(**kwargs)
+
+        self.mode = mode
+
+    def call(self, inputs, **kwargs):
+
+        # inputs: list of per-field embeddings, each [batch_size, embedding_size].
+        # Weights are created lazily on the first call.
+        output = list()
+        embedding_size = inputs[0].shape[-1]
+
+        if self.mode == 'all':
+            W = self.add_weight(
+                shape=(embedding_size, embedding_size),
+                initializer='glorot_uniform',
+                regularizer=tf.keras.regularizers.l2(1e-5),
+                trainable=True
+            )
+            # Interact every distinct field pair (i < j) through the shared weight
+            for i, j in itertools.combinations(range(len(inputs)), 2):
+                inter = tf.tensordot(inputs[i], W, axes=(-1, 0)) * inputs[j]
+                output.append(inter)
+
+        elif self.mode == 'each':
+            for i in range(len(inputs) - 1):
+                # One weight matrix per left-hand field i
+                W = self.add_weight(
+                    shape=(embedding_size, embedding_size),
+                    initializer='glorot_uniform',
+                    regularizer=tf.keras.regularizers.l2(1e-5),
+                    trainable=True
+                )
+                for j in range(i + 1, len(inputs)):
+                    inter = tf.tensordot(inputs[i], W, axes=(-1, 0)) * inputs[j]
+                    output.append(inter)
+
+        elif self.mode == 'interaction':
+            # One weight matrix per distinct field pair (i < j)
+            for i, j in itertools.combinations(range(len(inputs)), 2):
+                W = self.add_weight(
+                    shape=(embedding_size, embedding_size),
+                    initializer='glorot_uniform',
+                    regularizer=tf.keras.regularizers.l2(1e-5),
+                    trainable=True
+                )
+                inter = tf.tensordot(inputs[i], W, axes=(-1, 0)) * inputs[j]
+                output.append(inter)
+
+        output = tf.concat(output, axis=1)
+        return output
+
+
+class SENet(tf.keras.Model):
+
+    def __init__(self, axis=-1, reduction=4, **kwargs):
+
+        super(SENet, self).__init__(**kwargs)
+
+        self.axis = axis
+        self.reduction = reduction
+
+    def call(self, inputs, **kwargs):
+
+        # inputs: [batch_size, feats_num, embedding_size]
+        feats_num = inputs.shape[1]
+
+        # Squeeze: average each field's embedding down to one scalar
+        weights = tf.reduce_mean(inputs, axis=self.axis, keepdims=False)  # [batch_size, feats_num]
+
+        # Excitation: two FC layers with a bottleneck of width `reduction`
+        W1 = self.add_weight(
+            shape=(feats_num, self.reduction),
+            trainable=True,
+            initializer='glorot_normal'
+        )
+        W2 = self.add_weight(
+            shape=(self.reduction, feats_num),
+            trainable=True,
+            initializer='glorot_normal'
+        )
+        weights = tf.keras.activations.relu(tf.tensordot(weights, W1, axes=(-1, 0)))
+        weights = tf.keras.activations.relu(tf.tensordot(weights, W2, axes=(-1, 0)))
+
+        # Re-weight: scale each field's embedding by its learned importance
+        weights = tf.expand_dims(weights, axis=-1)
+        output = tf.multiply(weights, inputs)  # [batch_size, feats_num, embedding_size]
+
+        return output
diff --git a/core/utils.py b/core/utils.py
index 0bb8066..b20b71f 100644
--- a/core/utils.py
+++ b/core/utils.py
@@ -81,8 +81,8 @@ def group_embedded_by_dim(embedded_dict):
     """
     Group a embedded features' dict according to embedding dimension.
 
-    :param embedded_dict: Dict of embedded sparse features {name: 3D_embedded_feature}
-    :return: Dict of grouped embedded features {embedding_dim: [3D_embedded_features]}
+    :param embedded_dict: Dict of embedded sparse features {name: embedded_feature}
+    :return: Dict of grouped embedded features {embedding_dim: [embedded_features]}
     """
 
     groups = dict()
diff --git a/models/FiBiNet.py b/models/FiBiNet.py
new file mode 100644
index 0000000..0bdd1fd
--- /dev/null
+++ b/models/FiBiNet.py
@@ -0,0 +1,67 @@
+import tensorflow as tf
+
+from core.features import FeatureMetas, Features
+from core.blocks import DNN, BiInteraction, SENet
+from core.utils import split_tensor
+
+
+def FiBiNet(
+        feature_metas,
+        interaction_mode='all',
+        interaction_mode_se='all',
+        embedding_initializer='glorot_uniform',
+        embedding_regularizer=tf.keras.regularizers.l2(1e-5),
+        fixed_embedding_dim=None,
+        dnn_hidden_units=(128, 64, 1),
+        dnn_activations=('relu', 'relu', None),
+        dnn_use_bias=True,
+        dnn_use_bn=False,
+        dnn_dropout=0,
+        dnn_kernel_initializers='glorot_uniform',
+        dnn_bias_initializers='zeros',
+        dnn_kernel_regularizers=tf.keras.regularizers.l2(1e-5),
+        dnn_bias_regularizers=None,
+        name='FiBiNet'):
+
+    assert isinstance(feature_metas, FeatureMetas)
+
+    with tf.name_scope(name):
+
+        features = Features(metas=feature_metas)
+
+        embedded_dict = features.get_embedded_dict(
+            group_name='embedding',
+            fixed_embedding_dim=fixed_embedding_dim,
+            embedding_initializer=embedding_initializer,
+            embedding_regularizer=embedding_regularizer,
+            slots_filter=None
+        )
+
+        # Bilinear interactions over the raw embeddings
+        inputs = list(embedded_dict.values())
+        interactions = BiInteraction(mode=interaction_mode)(inputs)
+
+        # Bilinear interactions over the SENet-reweighted embeddings
+        inputs_se = SENet(axis=-1)(tf.stack(inputs, axis=1))
+        interactions_se = BiInteraction(mode=interaction_mode_se)(split_tensor(inputs_se, axis=1))
+
+        dnn_inputs = tf.concat([interactions, interactions_se], axis=1)
+
+        dnn_output = DNN(
+            units=dnn_hidden_units,
+            use_bias=dnn_use_bias,
+            activations=dnn_activations,
+            use_bn=dnn_use_bn,
+            dropout=dnn_dropout,
+            kernel_initializers=dnn_kernel_initializers,
+            bias_initializers=dnn_bias_initializers,
+            kernel_regularizers=dnn_kernel_regularizers,
+            bias_regularizers=dnn_bias_regularizers
+        )(dnn_inputs)
+
+        # Output
+        output = tf.keras.activations.sigmoid(dnn_output)
+
+        model = tf.keras.Model(inputs=features.get_inputs_list(), outputs=output)
+
+        return model
diff --git a/models/NFFM.py b/models/NFFM.py
new file mode 100644
index 0000000..65f665a
--- /dev/null
+++ b/models/NFFM.py
@@ -0,0 +1,78 @@
+import tensorflow as tf
+
+from core.features import FeatureMetas, Features
+from core.blocks import DNN, BiInteraction
+from core.utils import group_embedded_by_dim
+
+
+def NFFM(
+        feature_metas,
+        biinteraction_mode='all',
+        embedding_initializer='glorot_uniform',
+        embedding_regularizer=tf.keras.regularizers.l2(1e-5),
+        fm_fixed_embedding_dim=None,
+        linear_use_bias=True,
+        linear_kernel_initializer=tf.keras.initializers.RandomNormal(stddev=1e-4, seed=1024),
+        linear_kernel_regularizer=tf.keras.regularizers.l2(1e-5),
+        dnn_hidden_units=(128, 64, 1),
+        dnn_activations=('relu', 'relu', None),
+        dnn_use_bias=True,
+        dnn_use_bn=False,
+        dnn_dropout=0,
+        dnn_kernel_initializers='glorot_uniform',
+        dnn_bias_initializers='zeros',
+        dnn_kernel_regularizers=tf.keras.regularizers.l2(1e-5),
+        dnn_bias_regularizers=None,
+        name='NFFM'):
+
+    assert isinstance(feature_metas, FeatureMetas)
+
+    with tf.name_scope(name):
+
+        features = Features(metas=feature_metas)
+
+        # Linear Part
+        with tf.name_scope('Linear'):
+            linear_output = features.get_linear_logit(use_bias=linear_use_bias,
+                                                      kernel_initializer=linear_kernel_initializer,
+                                                      kernel_regularizer=linear_kernel_regularizer,
+                                                      embedding_group='dot_embedding',
+                                                      slots_filter=None)
+
+        # Interaction Part
+        with tf.name_scope('Interaction'):
+            fm_embedded_dict = features.get_embedded_dict(
+                group_name='embedding',
+                fixed_embedding_dim=fm_fixed_embedding_dim,
+                embedding_initializer=embedding_initializer,
+                embedding_regularizer=embedding_regularizer,
+                slots_filter=None
+            )
+            fm_dim_groups = group_embedded_by_dim(fm_embedded_dict)
+            # Fields can only interact bilinearly within a common embedding dim
+            interactions = list()
+            for fm_group in fm_dim_groups.values():
+                group_interaction = BiInteraction(mode=biinteraction_mode)(fm_group)
+                interactions.append(group_interaction)
+
+            interactions = tf.concat(interactions, axis=1)
+
+        dnn_output = DNN(
+            units=dnn_hidden_units,
+            use_bias=dnn_use_bias,
+            activations=dnn_activations,
+            use_bn=dnn_use_bn,
+            dropout=dnn_dropout,
+            kernel_initializers=dnn_kernel_initializers,
+            bias_initializers=dnn_bias_initializers,
+            kernel_regularizers=dnn_kernel_regularizers,
+            bias_regularizers=dnn_bias_regularizers
+        )(interactions)
+
+        # Output
+        output = tf.add_n([linear_output, dnn_output])
+        output = tf.keras.activations.sigmoid(output)
+
+        model = tf.keras.Model(inputs=features.get_inputs_list(), outputs=output)
+
+        return model
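
Reviewer sketch (not part of the patch): a quick shape and parameter-count
check for the three BiInteraction modes, assuming the patched core.blocks is
importable and per-field embeddings are 2-D [batch_size, embedding_size]
tensors. With n fields there are n*(n-1)/2 distinct pairs; the modes differ
only in how many bilinear weight matrices they create.

    import tensorflow as tf

    from core.blocks import BiInteraction

    n, d = 4, 8  # illustrative field count and embedding size
    fields = [tf.keras.Input(shape=(d,)) for _ in range(n)]

    for mode, expected_weights in [('all', 1),
                                   ('each', n - 1),
                                   ('interaction', n * (n - 1) // 2)]:
        layer = BiInteraction(mode=mode)
        out = layer(fields)
        # each distinct field pair contributes one d-wide interaction vector
        assert out.shape[-1] == n * (n - 1) // 2 * d
        assert len(layer.weights) == expected_weights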
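
A companion sketch for SENet, showing the stack-then-reweight pattern that
models/FiBiNet.py applies before its second bilinear pass (shapes are again
illustrative; SENet leaves the input shape unchanged):

    import tensorflow as tf

    from core.blocks import SENet

    # stack per-field embeddings into [batch_size, feats_num, embedding_size]
    fields = [tf.keras.Input(shape=(8,)) for _ in range(4)]
    stacked = tf.stack(fields, axis=1)  # [None, 4, 8]

    # squeeze each field to a scalar, run a feats_num -> reduction -> feats_num
    # MLP to get per-field importances, then rescale the original embeddings
    reweighted = SENet(axis=-1, reduction=2)(stacked)
    assert reweighted.shape[1:] == (4, 8)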