

tensorflow-DeepFM

This project includes a TensorFlow implementation of DeepFM [1].
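For orientation, here is a minimal NumPy sketch of the prediction DeepFM makes according to [1]: an FM component (first-order weights plus pairwise embedding interactions) and a deep component (an MLP over the same embeddings) are added and passed through a sigmoid. This is not this project's TensorFlow code. The function `deepfm_forward` and the parameter names `w`, `v`, `layers`, `w_out` and `bias` are hypothetical, and ReLU is used for the hidden layers only because it matches `deep_layers_activation` in the usage example.

```python
import numpy as np


def deepfm_forward(xi, xv, params):
    """Score a batch with a DeepFM-style model (illustrative sketch only).

    xi: int array (batch, fields), feature index of each field
    xv: float array (batch, fields), feature value of each field
    """
    w, v = params["w"], params["v"]  # first-order weights (n,), embeddings (n, k)
    # FM first-order term: sum over fields of w[index] * value -> (batch,)
    first_order = np.sum(w[xi] * xv, axis=1)
    # Value-scaled embeddings shared by the FM and deep parts -> (batch, fields, k)
    emb = v[xi] * xv[..., None]
    # FM second-order term: 0.5 * sum_k[(sum_f e)^2 - sum_f e^2] -> (batch,)
    square_of_sum = np.sum(emb, axis=1) ** 2
    sum_of_square = np.sum(emb ** 2, axis=1)
    second_order = 0.5 * np.sum(square_of_sum - sum_of_square, axis=1)
    # Deep part: ReLU MLP over the flattened embeddings -> (batch,)
    h = emb.reshape(emb.shape[0], -1)
    for W, b in params["layers"]:
        h = np.maximum(h @ W + b, 0.0)
    deep = (h @ params["w_out"]).ravel()
    # Sum both components and squash to a click probability
    logit = first_order + second_order + deep + params["bias"]
    return 1.0 / (1.0 + np.exp(-logit))
```

The square-of-sum minus sum-of-squares identity lets the pairwise FM term be computed in time linear in the number of fields instead of quadratic.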

NEWS

Usage

Input Format

This implementation requires the input data in the following format:

Please see example/DataReader.py for an example of how to prepare the data in the format required by DeepFM.
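The following is a small sketch of one way to turn raw rows into index/value lists, assuming (as the `Xi_*`/`Xv_*` names in the training example suggest) that `Xi` holds, for every sample, one global feature index per field and `Xv` holds the matching feature value: the raw number for a numeric field, and `1.0` for the observed value of a categorical field. The helpers `build_feature_dict` and `to_xi_xv` are hypothetical and are not the code in example/DataReader.py.

```python
def build_feature_dict(rows, categorical, numeric):
    """Assign a global index to each numeric field and to each categorical value."""
    feat_dict, next_idx = {}, 0
    for col in numeric:
        # A numeric field uses a single index; its raw value goes into Xv
        feat_dict[col] = next_idx
        next_idx += 1
    for col in categorical:
        # Each distinct categorical value gets its own index (one-hot style)
        values = sorted({row[col] for row in rows})
        feat_dict[col] = {val: next_idx + i for i, val in enumerate(values)}
        next_idx += len(values)
    return feat_dict, next_idx  # next_idx == number of distinct feature indices


def to_xi_xv(rows, feat_dict, categorical, numeric):
    """Convert rows (dicts) into parallel lists of feature indices and values."""
    xi, xv = [], []
    for row in rows:
        idx, val = [], []
        for col in numeric:
            idx.append(feat_dict[col])
            val.append(float(row[col]))
        for col in categorical:
            idx.append(feat_dict[col][row[col]])
            val.append(1.0)
        xi.append(idx)
        xv.append(val)
    return xi, xv
```

For rows `[{"color": "red", "price": 10.24}, {"color": "blue", "price": 3.5}]`, `price` gets index 0 and the color values `blue` and `red` get 1 and 2, so `Xi` is `[[0, 2], [0, 1]]` and `Xv` is `[[10.24, 1.0], [3.5, 1.0]]`. Every sample has the same number of fields, which is what allows the embeddings to be stacked per field.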

Init and train a model

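import tensorflow as tf
from sklearn.metrics import roc_auc_score

# assumes DeepFM.py from this repository is importable (e.g., on the Python path)
from DeepFM import DeepFM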

# params
dfm_params = {
    "use_fm": True,
    "use_deep": True,
    "embedding_size": 8,
    "dropout_fm": [1.0, 1.0],
    "deep_layers": [32, 32],
    "dropout_deep": [0.5, 0.5, 0.5],
    "deep_layers_activation": tf.nn.relu,
    "epoch": 30,
    "batch_size": 1024,
    "learning_rate": 0.001,
    "optimizer_type": "adam",
    "batch_norm": 1,
    "batch_norm_decay": 0.995,
    "l2_reg": 0.01,
    "verbose": True,
    "eval_metric": roc_auc_score,
    "random_seed": 2017
}

# prepare training and validation data in the required format
Xi_train, Xv_train, y_train = prepare(...)
Xi_valid, Xv_valid, y_valid = prepare(...)

# init a DeepFM model
dfm = DeepFM(**dfm_params)

# fit a DeepFM model
dfm.fit(Xi_train, Xv_train, y_train)

# make prediction
dfm.predict(Xi_valid, Xv_valid)

# evaluate a trained model
dfm.evaluate(Xi_valid, Xv_valid, y_valid)

You can enable early stopping during training as follows:

dfm.fit(Xi_train, Xv_train, y_train, Xi_valid, Xv_valid, y_valid, early_stopping=True)
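This README does not state which stopping rule `fit` applies when `early_stopping=True`. As an illustration only, here is a common patience-based rule: stop once the best validation score is at least `patience` epochs old. The function `should_stop` and its `patience` and `greater_is_better` arguments are hypothetical, not parameters of this project.

```python
def should_stop(valid_scores, patience=3, greater_is_better=True):
    """Return True when the best validation score is `patience` or more epochs old.

    valid_scores: one validation metric value per finished epoch, in order.
    """
    if not valid_scores:
        return False
    pick = max if greater_is_better else min
    # Epoch (0-based) of the first occurrence of the best score
    best_epoch = valid_scores.index(pick(valid_scores))
    return len(valid_scores) - 1 - best_epoch >= patience
```

With AUC scores `[0.60, 0.62, 0.63, 0.629, 0.628, 0.627]` and `patience=3`, the best epoch is the third one and three epochs have passed without improvement, so training would stop. For a loss such as MSE, pass `greater_is_better=False`.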

You can refit the model on the combined training and validation sets as follows:

dfm.fit(Xi_train, Xv_train, y_train, Xi_valid, Xv_valid, y_valid, early_stopping=True, refit=True)
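Refitting trains on the training and validation data together. If you want to prepare such a combined set yourself, for example to train a final model outside `fit`, the sketch below concatenates two `(Xi, Xv, y)` triples of Python lists after checking that each triple is aligned. `merge_for_refit` is a hypothetical helper; it says nothing about how many epochs `refit=True` actually trains for.

```python
def merge_for_refit(train, valid):
    """Concatenate two (Xi, Xv, y) triples into one combined training set."""
    for xi, xv, y in (train, valid):
        # Every sample needs an index row, a value row and a label
        if not (len(xi) == len(xv) == len(y)):
            raise ValueError("Xi, Xv and y must have the same number of samples")
    # list(...) also accepts tuples or 1-D arrays of labels
    return tuple(list(a) + list(b) for a, b in zip(train, valid))
```

For NumPy arrays with matching shapes, `np.concatenate` would be the more natural choice than list concatenation.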

You can use the FM part or the DNN part alone by setting use_deep or use_fm, respectively, to False.
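The switches are the `use_fm` and `use_deep` entries of the parameter dict in the usage example. A minimal sketch of deriving FM-only and DNN-only configurations from a shared base dict (trimmed to a few keys here so it runs without TensorFlow):

```python
# Shared settings; in practice this would be the full dfm_params dict
base_params = {"use_fm": True, "use_deep": True, "embedding_size": 8, "deep_layers": [32, 32]}

# FM only: disable the deep component
fm_params = {**base_params, "use_deep": False}

# DNN only: disable the FM component
dnn_params = {**base_params, "use_fm": False}

# Each dict would then be passed as before, e.g. DeepFM(**fm_params)
```

Copying with `{**base_params, ...}` leaves `base_params` unchanged, so the three variants can be trained side by side for comparison.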

Regression

This implementation also supports regression tasks. To use DeepFM for regression, set loss_type to mse and pass a regression metric as eval_metric, e.g., mean squared error (MSE) or mean absolute error (MAE).
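A sketch of regression settings, assuming `eval_metric` is called as `eval_metric(y_true, y_pred)` in the same way as `roc_auc_score` in the classification example, and that `loss_type` takes the string `"mse"`. The metric functions below are plain-Python stand-ins; `sklearn.metrics.mean_squared_error` and `sklearn.metrics.mean_absolute_error` have the same argument order and could be used instead.

```python
def mse(y_true, y_pred):
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mae(y_true, y_pred):
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Overrides to merge into the parameter dict for a regression task
regression_params = {"loss_type": "mse", "eval_metric": mse}
```

Unlike AUC, both metrics are better when lower, which matters for anything that compares validation scores, such as early stopping.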

Example

The example folder shows how to use the DeepFM, FM, and DNN models for the Porto Seguro's Safe Driver Prediction competition on Kaggle.

Please download the data from the competition website and put it into the example/data folder.

To train the DeepFM model on this dataset, run

$ cd example
$ python main.py

Please see example/DataReader.py for how to parse the raw dataset into the format required by DeepFM.

Performance

DeepFM

[Figure: DeepFM performance]

FM

[Figure: FM performance]

DNN

[Figure: DNN performance]

Some tips

Reference

[1] DeepFM: A Factorization-Machine based Neural Network for CTR Prediction, Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He.

Acknowledgments

This project draws inspiration from the following projects:

License

MIT