Skip to content

Commit

Permalink
New model xDeepFM
Browse files Browse the repository at this point in the history
  • Loading branch information
张浩天 committed Jul 28, 2019
1 parent 1040eef commit d387332
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
| Wide & Deep | [DLRS 2016][Wide & Deep Learning for Recommender Systems](https://arxiv.org/pdf/1606.07792.pdf)|
| DeepFM | [IJCAI 2017][DeepFM: A Factorization-Machine based Neural Network for CTR Prediction](http://www.ijcai.org/proceedings/2017/0239.pdf)|
| Deep & Cross Network (DCN) | [ADKDD 2017][Deep & Cross Network for Ad Click Predictions](https://arxiv.org/abs/1708.05123) |
| xDeepFM | [KDD 2018][xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems](https://arxiv.org/pdf/1803.05170.pdf) |
60 changes: 58 additions & 2 deletions core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,16 @@ class CrossNetwork(tf.keras.Model):
def __init__(self,
             kernel_initializer='glorot_uniform',
             kernel_regularizer=tf.keras.regularizers.l2(1e-5),
             bias_initializer='zeros',
             bias_regularizer=None,
             **kwargs):
    """Cross-network constructor.

    Only stores the weight-creation settings; the actual kernels and
    biases are materialized lazily in ``call`` via ``add_weight``.

    :param kernel_initializer: initializer for the per-layer cross kernels
    :param kernel_regularizer: regularizer for the cross kernels
    :param bias_initializer: initializer for the per-layer bias vectors
    :param bias_regularizer: regularizer for the bias vectors (None = none)
    :param kwargs: forwarded to ``tf.keras.Model``
    """
    super(CrossNetwork, self).__init__(**kwargs)

    # Keep kernel and bias settings separate so biases can default to
    # zeros without inheriting the kernel's L2 penalty.
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer
    self.kernel_regularizer = kernel_regularizer
    self.bias_regularizer = bias_regularizer

def call(self, inputs, layers_num=3, require_logit=True, **kwargs):

Expand All @@ -214,8 +218,8 @@ def call(self, inputs, layers_num=3, require_logit=True, **kwargs):
regularizer=self.kernel_regularizer,
trainable=True)
bias = self.add_weight(shape=(x0.shape[1], 1),
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
trainable=True)
x = tf.matmul(tf.matmul(x0, x), kernel) + bias + tf.transpose(x, [0, 2, 1])
x = tf.transpose(x, [0, 2, 1])
Expand All @@ -229,3 +233,55 @@ def call(self, inputs, layers_num=3, require_logit=True, **kwargs):
x = tf.matmul(x, kernel)

return x


class CIN(tf.keras.Model):
    """Compressed Interaction Network (the explicit-interaction part of xDeepFM).

    Each layer forms all pairwise element-wise products between the raw
    feature map and the previous layer's feature map, then compresses the
    pair dimension down to a configurable width with a learned kernel.
    The sum-pooled output of every layer is concatenated and projected to
    a single logit.
    """

    def __init__(self,
                 kernel_initializer='glorot_uniform',
                 kernel_regularizer=tf.keras.regularizers.l2(1e-5),
                 **kwargs):
        """Store weight-creation settings; weights are built lazily in ``call``.

        :param kernel_initializer: initializer for all compression kernels
        :param kernel_regularizer: regularizer for all compression kernels
        :param kwargs: forwarded to ``tf.keras.Model``
        """
        super(CIN, self).__init__(**kwargs)

        self.kernel_initializer = kernel_initializer
        self.kernel_regularizer = kernel_regularizer

    def call(self, inputs, hidden_width=(128, 64), require_logit=True, **kwargs):
        """Run the CIN over a list of same-dimension embedded features.

        :param inputs: list of [batch, m] embeddings (equal embedding dim m)
        :param hidden_width: compression widths of the successive CIN layers;
            an initial layer of width n (the field count) is always prepended
        :param require_logit: unused here; kept for interface consistency
        :return: [batch, 1] logit tensor

        NOTE(review): weights are created inside ``call`` (matching the
        file's CrossNetwork convention), so this model is intended for
        one-shot graph construction via the functional API.
        """
        # Stack the field embeddings into a single [b, n, m] feature map.
        base = tf.stack(inputs, axis=1)
        current = tf.identity(base)

        # The first compression layer keeps width n (the field count).
        widths = [base.shape[1]] + list(hidden_width)

        pooled = []
        for width in widths:
            field_count = base.shape[1]
            prev_count = current.shape[1]

            # Index pairs enumerating every (base row i, current row j)
            # combination, i-major — same order as a nested i/j loop.
            row_idx = [i for i in range(field_count) for _ in range(prev_count)]
            col_idx = [j for _ in range(field_count) for j in range(prev_count)]

            # [b, pair, m] element-wise interaction of every row pair.
            left = tf.gather(base, row_idx, axis=1)
            right = tf.gather(current, col_idx, axis=1)
            # [b, m, pair] — move the pair axis last so matmul compresses it.
            interactions = tf.transpose(tf.multiply(left, right), [0, 2, 1])

            kernel = self.add_weight(shape=(interactions.shape[-1], width),
                                     initializer=self.kernel_initializer,
                                     regularizer=self.kernel_regularizer,
                                     trainable=True)
            # [b, width, m] compressed feature map for the next layer.
            current = tf.transpose(tf.matmul(interactions, kernel), [0, 2, 1])
            # Sum-pool the embedding axis -> [b, width].
            pooled.append(tf.reduce_sum(current, axis=-1, keepdims=False))

        # Concatenate every layer's pooled vector and project to one logit.
        pooled = tf.concat(pooled, axis=-1)
        out_kernel = self.add_weight(shape=(pooled.shape[-1], 1),
                                     initializer=self.kernel_initializer,
                                     regularizer=self.kernel_regularizer,
                                     trainable=True)

        return tf.matmul(pooled, out_kernel)
8 changes: 6 additions & 2 deletions models/DCN.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ def DCN(
feature_metas,
cross_kernel_initializer='glorot_uniform',
cross_kernel_regularizer=tf.keras.regularizers.l2(1e-5),
cross_bias_initializer='zeros',
cross_bias_regularizer=None,
cross_layers_num=3,
embedding_initializer='glorot_uniform',
embedding_regularizer=tf.keras.regularizers.l2(1e-5),
Expand All @@ -21,7 +23,7 @@ def DCN(
dnn_bias_initializers='zeros',
dnn_kernel_regularizers=tf.keras.regularizers.l2(1e-5),
dnn_bias_regularizers=None,
name='Deep&Crossing Network'):
name='Deep&Cross Network'):

assert isinstance(feature_metas, FeatureMetas)

Expand Down Expand Up @@ -59,7 +61,9 @@ def DCN(
cross_inputs = list(embedded_dict.values())
cross_output = CrossNetwork(
kernel_initializer=cross_kernel_initializer,
kernel_regularizer=cross_kernel_regularizer
kernel_regularizer=cross_kernel_regularizer,
bias_initializer=cross_bias_initializer,
bias_regularizer=cross_bias_regularizer
)(cross_inputs, layers_num=cross_layers_num, require_logit=True)

output = tf.keras.activations.sigmoid(deep_output + cross_output)
Expand Down
84 changes: 84 additions & 0 deletions models/xDeepFM.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import tensorflow as tf

from core.features import FeatureMetas, Features, group_embedded_by_dim
from core.blocks import DNN, CIN


def xDeepFM(
        feature_metas,
        linear_slots,
        fm_slots,
        dnn_slots,
        embedding_initializer='glorot_uniform',
        embedding_regularizer=tf.keras.regularizers.l2(1e-5),
        fm_fixed_embedding_dim=None,
        fm_kernel_initializer='glorot_uniform',
        fm_kernel_regularizer=None,
        fm_hidden_width=(128, 64),
        linear_use_bias=True,
        linear_kernel_initializer='glorot_uniform',
        linear_kernel_regularizer=None,
        dnn_hidden_units=(128, 64, 1),
        dnn_activations=('relu', 'relu', None),
        dnn_use_bias=True,
        dnn_use_bn=False,
        dnn_dropout=0,
        dnn_kernel_initializers='glorot_uniform',
        dnn_bias_initializers='zeros',
        dnn_kernel_regularizers=tf.keras.regularizers.l2(1e-5),
        dnn_bias_regularizers=None,
        name='xDeepFM'):
    """Build an xDeepFM model: linear + CIN (explicit interactions) + DNN.

    The three sub-networks are summed into a single logit and passed
    through a sigmoid (binary CTR-style output).

    :param feature_metas: FeatureMetas describing all input features
    :param linear_slots: slot names fed to the linear part
    :param fm_slots: slot names fed to the CIN part
    :param dnn_slots: slot names fed to the DNN part
    :param fm_fixed_embedding_dim: if set, force every CIN embedding to this
        dim (CIN requires equal dims within a group; see note below)
    :param fm_hidden_width: compression widths of the successive CIN layers
        (previously hard-coded to CIN's default (128, 64))
    :param name: name scope / model name
    :return: a compiled-ready ``tf.keras.Model``
    :raises ValueError: when no embedding-dim group contains at least two
        features, so no CIN interaction can be formed
    """
    assert isinstance(feature_metas, FeatureMetas)

    with tf.name_scope(name):

        features = Features(metas=feature_metas)

        # Linear Part: first-order weights over the raw features.
        with tf.name_scope('Linear'):
            linear_output = features.get_linear_logit(use_bias=linear_use_bias,
                                                      kernel_initializer=linear_kernel_initializer,
                                                      kernel_regularizer=linear_kernel_regularizer,
                                                      embedding_group='dot_embedding',
                                                      slots_filter=linear_slots)

        # CIN Part: explicit high-order interactions per embedding-dim group.
        # CIN needs equal embedding dims, so features are grouped by dim and
        # one CIN is built per group with two or more members.
        with tf.name_scope('CIN'):
            fm_embedded_dict = features.get_embedded_dict(group_name='embedding',
                                                          fixed_embedding_dim=fm_fixed_embedding_dim,
                                                          embedding_initializer=embedding_initializer,
                                                          embedding_regularizer=embedding_regularizer,
                                                          slots_filter=fm_slots)
            fm_dim_groups = group_embedded_by_dim(fm_embedded_dict)
            fms = [CIN(
                kernel_initializer=fm_kernel_initializer,
                kernel_regularizer=fm_kernel_regularizer
            )(group, hidden_width=fm_hidden_width)
                for group in fm_dim_groups.values() if len(group) > 1]
            # tf.add_n([]) raises an opaque error — fail with a clear message
            # when no group can produce an interaction logit.
            if not fms:
                raise ValueError(
                    'xDeepFM: no embedding-dim group in fm_slots has more than '
                    'one feature; set fm_fixed_embedding_dim to make dims equal '
                    'or widen fm_slots.')
            fm_output = tf.add_n(fms)

        # DNN Part: implicit interactions over the concatenated embeddings.
        with tf.name_scope('DNN'):
            dnn_inputs = features.gen_concated_feature(embedding_group='embedding',
                                                       fixed_embedding_dim=fm_fixed_embedding_dim,
                                                       embedding_initializer=embedding_initializer,
                                                       embedding_regularizer=embedding_regularizer,
                                                       slots_filter=dnn_slots)
            dnn_output = DNN(
                units=dnn_hidden_units,
                use_bias=dnn_use_bias,
                activations=dnn_activations,
                use_bn=dnn_use_bn,
                dropout=dnn_dropout,
                kernel_initializers=dnn_kernel_initializers,
                bias_initializers=dnn_bias_initializers,
                kernel_regularizers=dnn_kernel_regularizers,
                bias_regularizers=dnn_bias_regularizers
            )(dnn_inputs)

        # Output: sum the three logits and squash to a probability.
        output = tf.add_n([linear_output, fm_output, dnn_output])
        output = tf.keras.activations.sigmoid(output)

        model = tf.keras.Model(inputs=features.get_inputs_list(), outputs=output)

        return model

0 comments on commit d387332

Please sign in to comment.