From 5abdbc320f00c406895607f697042c9a57f00161 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=B5=A9=E5=A4=A9?= Date: Mon, 29 Jul 2019 16:59:07 +0800 Subject: [PATCH] New model AFM --- core/blocks.py | 60 +++++++++++++++++++++++++++++++++++- models/AFM.py | 71 +++++++++++++++++++++++++++++++++++++++++++ models/DeepFM.py | 4 +-- models/NFM.py | 4 +-- models/WideAndDeep.py | 4 +-- models/xDeepFM.py | 4 +-- 6 files changed, 138 insertions(+), 9 deletions(-) create mode 100644 models/AFM.py diff --git a/core/blocks.py b/core/blocks.py index f1e3fb2..2819555 100644 --- a/core/blocks.py +++ b/core/blocks.py @@ -3,6 +3,12 @@ import tensorflow as tf from tensorflow.python.keras import layers +from tensorflow.python.keras.initializers import (Zeros, glorot_normal, + glorot_uniform) +from tensorflow.python.keras.regularizers import l2 +from tensorflow.python.keras import backend as K +import itertools + def get_activation(activation): if activation is None: @@ -151,7 +157,7 @@ def __init__(self, **kwargs): super(InnerProduct, self).__init__(**kwargs) - def call(self, inputs, **kwargs): + def call(self, inputs, concat=True, **kwargs): inner_products_list = list() @@ -290,3 +296,55 @@ def call(self, inputs, hidden_width=(128, 64), require_logit=True, **kwargs): logit = tf.matmul(finals, kernel) return logit + + +class AttentionBasedPoolingLayer(tf.keras.Model): + + def __init__(self, + attention_factor=4, + kernel_initializer='glorot_uniform', + kernel_regularizer=tf.keras.regularizers.l2(1e-5), + bias_initializer='zeros', + bias_regularizer=None, + **kwargs): + + super(AttentionBasedPoolingLayer, self).__init__(**kwargs) + + self.attention_factor = attention_factor + self.kernel_initializer = kernel_initializer + self.kernel_regularizer = kernel_regularizer + self.bias_initializer = bias_initializer + self.bias_regularizer = bias_regularizer + + self.att_layer = layers.Dense( + units=self.attention_factor, + activation='relu', + use_bias=True, + kernel_initializer=self.kernel_initializer, + kernel_regularizer=self.kernel_regularizer, + bias_initializer=self.bias_initializer, + bias_regularizer=self.bias_regularizer + ) + self.att_proj_layer = layers.Dense( + units=1, + activation=None, + use_bias=False, + kernel_initializer=self.kernel_initializer + ) + + def call(self, inputs, **kwargs): + + interactions = list() + + for i in range(len(inputs) - 1): + for j in range(i + 1, len(inputs)): + interactions.append(tf.multiply(inputs[i], inputs[j])) + + interactions = tf.stack(interactions, axis=1) + att_weight = self.att_layer(interactions) + att_weight = self.att_proj_layer(att_weight) + + att_weight = layers.Softmax(axis=1)(att_weight) + output = tf.reduce_sum(interactions * att_weight, axis=1) + + return output diff --git a/models/AFM.py b/models/AFM.py new file mode 100644 index 0000000..50cb051 --- /dev/null +++ b/models/AFM.py @@ -0,0 +1,71 @@ +import tensorflow as tf + +from core.features import FeatureMetas, Features, group_embedded_by_dim +from core.blocks import DNN, AttentionBasedPoolingLayer + + +def AFM( + feature_metas, + linear_slots, + fm_slots, + embedding_initializer='glorot_uniform', + embedding_regularizer=tf.keras.regularizers.l2(1e-5), + fm_fixed_embedding_dim=None, + linear_use_bias=True, + linear_kernel_initializer=tf.keras.initializers.RandomNormal(stddev=1e-4, seed=1024), + linear_kernel_regularizer=tf.keras.regularizers.l2(1e-5), + dnn_hidden_units=(128, 64, 1), + dnn_activations=('relu', 'relu', None), + dnn_use_bias=True, + dnn_use_bn=False, + dnn_dropout=0, + 
+        dnn_kernel_initializers='glorot_uniform',
+        dnn_bias_initializers='zeros',
+        dnn_kernel_regularizers=tf.keras.regularizers.l2(1e-5),
+        dnn_bias_regularizers=None,
+        name='AFM'):
+
+    assert isinstance(feature_metas, FeatureMetas)
+
+    with tf.name_scope(name):
+
+        features = Features(metas=feature_metas)
+
+        # Linear Part
+        with tf.name_scope('Linear'):
+            linear_output = features.get_linear_logit(use_bias=linear_use_bias,
+                                                      kernel_initializer=linear_kernel_initializer,
+                                                      kernel_regularizer=linear_kernel_regularizer,
+                                                      embedding_group='dot_embedding',
+                                                      slots_filter=linear_slots)
+
+        # Interaction
+        with tf.name_scope('Interaction'):
+            fm_embedded_dict = features.get_embedded_dict(group_name='embedding',
+                                                          fixed_embedding_dim=fm_fixed_embedding_dim,
+                                                          embedding_initializer=embedding_initializer,
+                                                          embedding_regularizer=embedding_regularizer,
+                                                          slots_filter=fm_slots)
+            fm_dim_groups = group_embedded_by_dim(fm_embedded_dict)
+            fms = [AttentionBasedPoolingLayer()(group)
+                   for group in fm_dim_groups.values() if len(group) > 1]
+            dnn_inputs = tf.concat(fms, axis=1)
+            dnn_output = DNN(
+                units=dnn_hidden_units,
+                use_bias=dnn_use_bias,
+                activations=dnn_activations,
+                use_bn=dnn_use_bn,
+                dropout=dnn_dropout,
+                kernel_initializers=dnn_kernel_initializers,
+                bias_initializers=dnn_bias_initializers,
+                kernel_regularizers=dnn_kernel_regularizers,
+                bias_regularizers=dnn_bias_regularizers
+            )(dnn_inputs)
+
+        # Output
+        output = tf.add_n([linear_output, dnn_output])
+        output = tf.keras.activations.sigmoid(output)
+
+        model = tf.keras.Model(inputs=features.get_inputs_list(), outputs=output)
+
+        return model
diff --git a/models/DeepFM.py b/models/DeepFM.py
index 1d4b919..3a1fb77 100644
--- a/models/DeepFM.py
+++ b/models/DeepFM.py
@@ -13,8 +13,8 @@ def DeepFM(
         embedding_regularizer=tf.keras.regularizers.l2(1e-5),
         fm_fixed_embedding_dim=None,
         linear_use_bias=True,
-        linear_kernel_initializer='glorot_uniform',
-        linear_kernel_regularizer=None,
+        linear_kernel_initializer=tf.keras.initializers.RandomNormal(stddev=1e-4, seed=1024),
+        linear_kernel_regularizer=tf.keras.regularizers.l2(1e-5),
         dnn_hidden_units=(128, 64, 1),
         dnn_activations=('relu', 'relu', None),
         dnn_use_bias=True,
diff --git a/models/NFM.py b/models/NFM.py
index c9f98d2..605460c 100644
--- a/models/NFM.py
+++ b/models/NFM.py
@@ -12,8 +12,8 @@ def NFM(
         embedding_regularizer=tf.keras.regularizers.l2(1e-5),
         fm_fixed_embedding_dim=None,
         linear_use_bias=True,
-        linear_kernel_initializer='glorot_uniform',
-        linear_kernel_regularizer=None,
+        linear_kernel_initializer=tf.keras.initializers.RandomNormal(stddev=1e-4, seed=1024),
+        linear_kernel_regularizer=tf.keras.regularizers.l2(1e-5),
         dnn_hidden_units=(128, 64, 1),
         dnn_activations=('relu', 'relu', None),
         dnn_use_bias=True,
diff --git a/models/WideAndDeep.py b/models/WideAndDeep.py
index de2c243..e12574b 100644
--- a/models/WideAndDeep.py
+++ b/models/WideAndDeep.py
@@ -11,8 +11,8 @@ def WideAndDeep(
         embedding_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=1e-4),
         embedding_regularizer=tf.keras.regularizers.l2(1e-5),
         wide_use_bias=True,
-        wide_kernel_initializer='glorot_uniform',
-        wide_kernel_regularizer=None,
+        wide_kernel_initializer=tf.keras.initializers.RandomNormal(stddev=1e-4, seed=1024),
+        wide_kernel_regularizer=tf.keras.regularizers.l2(1e-5),
         deep_fixed_embedding_dim=None,
         deep_hidden_units=(128, 64, 1),
         deep_activations=('relu', 'relu', None),
diff --git a/models/xDeepFM.py b/models/xDeepFM.py
index c150436..8f4f676 100644
--- a/models/xDeepFM.py
+++ b/models/xDeepFM.py
@@ -15,8 +15,8 @@ def xDeepFM(
         fm_kernel_initializer='glorot_uniform',
         fm_kernel_regularizer=None,
         linear_use_bias=True,
-        linear_kernel_initializer='glorot_uniform',
-        linear_kernel_regularizer=None,
+        linear_kernel_initializer=tf.keras.initializers.RandomNormal(stddev=1e-4, seed=1024),
+        linear_kernel_regularizer=tf.keras.regularizers.l2(1e-5),
         dnn_hidden_units=(128, 64, 1),
         dnn_activations=('relu', 'relu', None),
         dnn_use_bias=True,
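
Reviewer note: below is a minimal, self-contained sketch of what the new
AttentionBasedPoolingLayer computes, written with plain tf.keras ops. It is not
part of the patch: the shapes, the toy attention_factor of 4, and all variable
names are illustrative only.

    import tensorflow as tf

    # Toy setup: 4 embedded fields, batch of 2, embedding dim 8 (illustrative).
    batch_size, num_fields, embed_dim = 2, 4, 8
    fields = [tf.random.normal((batch_size, embed_dim)) for _ in range(num_fields)]

    # Element-wise product of every field pair -> (batch, num_pairs, embed_dim).
    pairs = [tf.multiply(fields[i], fields[j])
             for i in range(num_fields - 1)
             for j in range(i + 1, num_fields)]
    interactions = tf.stack(pairs, axis=1)

    # Attention network: a ReLU hidden layer (attention_factor units) followed by
    # a bias-free 1-unit projection, normalised with a softmax over the pair axis.
    # This mirrors att_layer / att_proj_layer in the new layer.
    att_hidden = tf.keras.layers.Dense(4, activation='relu')(interactions)
    att_logit = tf.keras.layers.Dense(1, use_bias=False)(att_hidden)
    att_weight = tf.keras.layers.Softmax(axis=1)(att_logit)   # (batch, num_pairs, 1)

    # Attention-weighted sum of the pairwise interactions -> (batch, embed_dim).
    pooled = tf.reduce_sum(interactions * att_weight, axis=1)
    print(pooled.shape)                                       # (2, 8)

Because the softmax runs over the pair axis, the attention weights for each
example sum to 1 across all field pairs, and the pooled vector keeps the
embedding dimension, so it can be concatenated and fed straight into the DNN
stack that models/AFM.py uses for the final logit.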