RecStudio is a unified, highly modularized and recommendation-efficient recommendation library based on PyTorch. All the algorithms are categorized as follows according to recommendation tasks.
At the core of the library, all recommendation models are grouped into three base classes:
- `TowerFreeRecommender`: The most flexible base class, which enables any complex feature-interaction modeling.
- `ItemTowerRecommender`: Item encoders are separated from the recommender, enabling fast ANN and model-based negative sampling.
- `TwoTowerRecommender`: The subclass of `ItemTowerRecommender`, where recommenders consist only of a user encoder and an item encoder.

For the dataset structure, the datasets are divided into five categories:
Dataset | Application | Examples |
---|---|---|
MFDataset | Dataset for providing user-item-rating triplets | BPR, NCF, CML, etc. |
AEDataset | Dataset for AutoEncoder-based ItemTowerRecommender | MultiVAE, RecVAE, etc. |
SeqDataset | Dataset for Sequential recommenders with Causal Prediction | GRU4Rec, SASRec, etc. |
Seq2SeqDataset | Dataset for Sequential recommenders with Masked Prediction | Bert4Rec, etc. |
ALSDataset | Dataset for recommenders optimized by alternating least squares | WRMF, etc. |
To accelerate dataset processing, processed datasets are automatically cached so that repeated training runs can start quickly.
Almost all common metrics used in recommender systems are implemented in RecStudio based on PyTorch, such as `NDCG`, `Recall`, `Precision`, etc. All metric functions share the same interface and are fully implemented with tensor operators. Therefore, the evaluation procedure can be moved to GPU, leading to a remarkable speedup of evaluation.
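As a rough illustration of what such metrics compute, the sketch below implements Recall@k and NDCG@k for a single user in plain Python, using one common definition of each. It is not RecStudio's code: the library implements its metrics with PyTorch tensor operators, and the function names `recall_at_k` and `ndcg_at_k` are hypothetical.

```python
import math

def recall_at_k(ranked_items, relevant_items, k):
    """Fraction of the relevant items that appear in the top-k ranking."""
    if not relevant_items:
        return 0.0
    hits = sum(1 for item in ranked_items[:k] if item in relevant_items)
    return hits / len(relevant_items)

def ndcg_at_k(ranked_items, relevant_items, k):
    """Binary-relevance NDCG: DCG of the ranking divided by the ideal DCG."""
    dcg = sum(1.0 / math.log2(rank + 2)
              for rank, item in enumerate(ranked_items[:k])
              if item in relevant_items)
    ideal_hits = min(len(relevant_items), k)
    idcg = sum(1.0 / math.log2(rank + 2) for rank in range(ideal_hits))
    return dcg / idcg if idcg > 0 else 0.0

ranking = [5, 2, 9, 1]   # item ids sorted by predicted score, best first
relevant = {2, 7}        # held-out items the user actually interacted with
print(recall_at_k(ranking, relevant, k=3))
print(ndcg_at_k(ranking, relevant, k=3))
```

Both functions take the same arguments (a ranking, the ground truth, and a cutoff), which mirrors the idea of a shared metric interface described above.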
To accelerate training and evaluation, RecStudio integrates various Approximate Nearest Neighbor search (ANNs) methods and negative samplers. By building indexes with ANNs, the top-k operator based on Euclidean distance, inner product, and cosine similarity can be significantly accelerated. Negative samplers consist of static samplers and model-based samplers developed by the RecStudio team. Static samplers consist of `Uniform Sampler` and `Popularity Sampler`. The model-based samplers are based on either quantization of item vectors or importance resampling. Moreover, static sampling is also implemented in the dataset, which makes it possible to generate negatives when loading data.
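The two static strategies can be sketched in a few lines of plain Python. This is only an illustration of uniform versus popularity-proportional sampling, not RecStudio's sampler classes; the helper names `uniform_negatives` and `popularity_negatives` and the toy interaction log are hypothetical.

```python
import random
from collections import Counter

def uniform_negatives(num_items, positives, count, rng):
    """Draw `count` item ids uniformly at random, rejecting the user's positives.

    Assumes at least one item is not a positive, otherwise this never returns.
    """
    negatives = []
    while len(negatives) < count:
        item = rng.randrange(num_items)
        if item not in positives:
            negatives.append(item)
    return negatives

def popularity_negatives(item_counts, positives, count, rng):
    """Draw negatives with probability proportional to interaction counts."""
    items = [i for i in item_counts if i not in positives]
    weights = [item_counts[i] for i in items]
    return rng.choices(items, weights=weights, k=count)

rng = random.Random(42)
interactions = [0, 0, 0, 1, 1, 2, 3, 3, 3, 3]  # toy log of interacted item ids
counts = Counter(interactions)
user_positives = {0}
print(uniform_negatives(num_items=4, positives=user_positives, count=3, rng=rng))
print(popularity_negatives(counts, user_positives, count=3, rng=rng))
```

With the popularity variant, item 3 (four interactions) is drawn more often than item 2 (one interaction), which is the behavior the `popularity` sampler option describes.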
In RecStudio, loss functions are categorized into three types:
- `FullScoreLoss`: Calculating scores on the whole item set, such as `SoftmaxLoss`.
- `PairwiseLoss`: Calculating scores on positive and negative items, such as `BPRLoss`, `BinaryCrossEntropyLoss`, etc.
- `PointwiseLoss`: Calculating scores for a single (user, item) interaction, such as `HingeLoss`.
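To make the pairwise category concrete, here is a minimal sketch of the standard BPR objective, the mean of -log(sigmoid(positive score - negative score)), in plain Python. It assumes RecStudio's `BPRLoss` follows this standard formulation; the function name `bpr_loss` is hypothetical and the library's own implementation operates on PyTorch tensors.

```python
import math

def bpr_loss(pos_scores, neg_scores):
    """Mean of -log(sigmoid(pos - neg)) over (positive, negative) score pairs."""
    losses = []
    for pos, neg in zip(pos_scores, neg_scores):
        diff = pos - neg
        # -log(sigmoid(x)) == log(1 + exp(-x)); the two branches avoid overflow in exp
        if diff >= 0:
            losses.append(math.log1p(math.exp(-diff)))
        else:
            losses.append(-diff + math.log1p(math.exp(diff)))
    return sum(losses) / len(losses)

print(bpr_loss([0.0, 0.0], [0.0, 0.0]))   # untrained model: every difference is 0
print(bpr_loss([3.0, 2.0], [-1.0, 0.5]))  # positives already ranked above negatives
```

When all score differences are zero the loss equals log(2), roughly 0.693, which is consistent with the `train_loss=0.6932` reported at epoch 0 in the sample training log in this README.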
Score functions are used to model users' preferences for items. Various common score functions are implemented in RecStudio, such as `InnerProduct`, `EuclideanDistance`, `CosineDistance`, `MLPScorer`, etc.
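The three distance-style scorers can be illustrated on plain Python lists. This is a sketch under assumptions: the helper names are hypothetical, and negating the Euclidean distance so that a higher score means a better match is an assumption about convention, not a statement about RecStudio's `EuclideanDistance`.

```python
import math

def inner_product(user, item):
    """Dot product of the user and item vectors."""
    return sum(u * i for u, i in zip(user, item))

def euclidean_score(user, item):
    """Negative Euclidean distance, so that closer items score higher."""
    return -math.sqrt(sum((u - i) ** 2 for u, i in zip(user, item)))

def cosine_score(user, item):
    """Cosine similarity; returns 0.0 if either vector has zero norm."""
    norm = math.sqrt(inner_product(user, user)) * math.sqrt(inner_product(item, item))
    return inner_product(user, item) / norm if norm > 0 else 0.0

user = [1.0, 2.0]
items = {"a": [2.0, 4.0], "b": [2.0, -1.0]}
for name, vec in items.items():
    print(name, inner_product(user, vec), euclidean_score(user, vec), cosine_score(user, vec))
```

Item `a` points in the same direction as the user vector and item `b` is orthogonal to it, so all three scorers rank `a` above `b` here, although in general they can disagree.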
Loss | Math Type | Sampling Distribution | Calculation Complexity | Sampling Complexity | Convergence Speed | Related Metrics |
---|---|---|---|---|---|---|
Softmax | | No sampling | | - | very fast | NDCG |
Sampled Softmax | | No sampling | | - | fast | NDCG |
BPR | | Uniform sampling | | | slow | AUC |
WARP | | Rejection sampling | | slower and slower | slow | Precision |
InfoNCE | | Popularity sampling | | | fast | DCG |
WRMF | | No sampling | | - | very fast | - |
PRIS | | Cluster sampling | | | very fast | DCG |
Figure: RecStudio Framework
After downloading the source code, you can run the provided script `run.py` to try RecStudio for the first time.
python run.py
The default configuration trains and evaluates the BPR model on the MovieLens-100k (ml-100k) dataset.
This simple example typically takes less than one minute with GPUs, and the output looks like the following:
[2022-04-11 14:30:29] INFO (faiss.loader/MainThread) Loading faiss with AVX2 support.
[2022-04-11 14:30:29] INFO (faiss.loader/MainThread) Loading faiss.
[2022-04-11 14:30:29] INFO (faiss.loader/MainThread) Successfully loaded faiss.
[2022-04-11 14:30:30] INFO (pytorch_lightning.utilities.seed/MainThread) Global seed set to 42
[2022-04-11 14:30:30] INFO (pytorch_lightning/MainThread) learning_rate=0.001
weight_decay=0
learner=adam
scheduler=None
epochs=100
batch_size=2048
num_workers=0
gpu=None
ann=None
sampler=None
negative_count=1
dataset_sampling_count=None
embed_dim=64
item_bias=False
eval_batch_size=20
split_ratio=[0.8, 0.1, 0.1]
test_metrics=['recall', 'precision', 'map', 'ndcg', 'mrr', 'hit']
val_metrics=['recall', 'ndcg']
topk=100
cutoff=10
early_stop_mode=max
split_mode=user_entry
shuffle=True
use_fields=['user_id', 'item_id', 'rating']
[2022-04-11 14:30:30] INFO (pytorch_lightning/MainThread) save_dir:/home/RecStudio/
[2022-04-11 14:30:30] INFO (pytorch_lightning.utilities.distributed/MainThread) GPU available: True, used: False
[2022-04-11 14:30:30] INFO (pytorch_lightning.utilities.distributed/MainThread) TPU available: False, using: 0 TPU cores
[2022-04-11 14:30:30] INFO (pytorch_lightning.utilities.distributed/MainThread) IPU available: False, using: 0 IPUs
[2022-04-11 14:30:30] INFO (pytorch_lightning.utilities.distributed/MainThread) The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint
[2022-04-11 14:30:30] INFO (pytorch_lightning.core.lightning/MainThread)
| Name | Type | Params
----------------------------------------------------
0 | loss_fn | BPRLoss | 0
1 | score_func | InnerProductScorer | 0
2 | item_encoder | Embedding | 107 K
3 | sampler | UniformSampler | 0
4 | user_encoder | Embedding | 60.4 K
----------------------------------------------------
168 K Trainable params
0 Non-trainable params
168 K Total params
0.673 Total estimated model params size (MB)
[2022-04-11 14:30:30] INFO (pytorch_lightning.callbacks.early_stopping/MainThread) Metric recall@10 improved. New best score: 0.007
[2022-04-11 14:30:30] INFO (pytorch_lightning/MainThread) Training: Epoch= 0 [recall@10=0.0074 ndcg@10=0.0129 train_loss=0.6932]
[2022-04-11 14:30:31] INFO (pytorch_lightning.callbacks.early_stopping/MainThread) Metric recall@10 improved by 0.006 >= min_delta = 0.0. New best score: 0.014
[2022-04-11 14:30:31] INFO (pytorch_lightning/MainThread) Training: Epoch= 1 [recall@10=0.0135 ndcg@10=0.0251 train_loss=0.6915]
[2022-04-11 14:30:32] INFO (pytorch_lightning.callbacks.early_stopping/MainThread) Metric recall@10 improved by 0.038 >= min_delta = 0.0. New best score: 0.051
...
[2022-04-11 14:31:26] INFO (pytorch_lightning/MainThread) Training: Epoch= 75 [recall@10=0.2074 ndcg@10=0.2942 train_loss=0.1909]
[2022-04-11 14:31:26] INFO (pytorch_lightning.callbacks.early_stopping/MainThread) Monitored metric recall@10 did not improve in the last 10 records. Best score: 0.211. Signaling Trainer to stop.
[2022-04-11 14:31:26] INFO (pytorch_lightning/MainThread) Training: Epoch= 76 [recall@10=0.2073 ndcg@10=0.2949 train_loss=0.1899]
[2022-04-11 14:31:26] INFO (pytorch_lightning.utilities.distributed/MainThread) The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: EarlyStopping, ModelCheckpoint
[2022-04-11 14:31:27] INFO (pytorch_lightning/MainThread) Testing: [recall@10=0.2439 precision@10=0.1893 map@10=0.5762 ndcg@10=0.3718 mrr@10=0.4487 hit@10=0.7815]
To change the model or dataset, pass them as command-line arguments:
python run.py -m=NCF -d=ml-1m
Supported command-line arguments:
args | type | description | default | options |
---|---|---|---|---|
-m,--model | str | model name | BPR | all the models in RecStudio |
-d,--dataset | str | dataset name | ml-100k | all the datasets supported by RecStudio |
--data_dir | str | dataset folder | datasets | folders that could be read by RecStudio |
mode | str | training mode | light | ['light','detail','tune'] |
--learning_rate | float | learning rate | 0.001 | |
--learner | str | optimizer name | adam | ['adam','sgd','adasgd','rmsprop','sparse_adam'] |
--weight_decay | float | weight decay for optimizer | 0 | |
--epochs | int | training epoch | 20,50 | |
--batch_size | int | the size of mini batch in training | 2048 | |
--eval_batch_size | int | the size of mini batch in evaluation | 128 | |
--embed_dim | int | the output size of embedding layers | 64 | |
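For readers who want to see how such a table maps onto a parser, the sketch below mirrors a subset of the documented flags and defaults with Python's `argparse`. It is a hypothetical stand-in, not RecStudio's actual argument handling: `build_parser` is an invented name, and the `mode` argument is left out because the table does not show its flag form.

```python
import argparse

def build_parser():
    """Hypothetical parser mirroring a subset of the documented flags and defaults."""
    parser = argparse.ArgumentParser(description="Sketch of the documented options")
    parser.add_argument("-m", "--model", type=str, default="BPR", help="model name")
    parser.add_argument("-d", "--dataset", type=str, default="ml-100k", help="dataset name")
    parser.add_argument("--data_dir", type=str, default="datasets", help="dataset folder")
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--learner", type=str, default="adam",
                        choices=["adam", "sgd", "adasgd", "rmsprop", "sparse_adam"])
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--batch_size", type=int, default=2048)
    parser.add_argument("--embed_dim", type=int, default=64)
    return parser

# Same arguments as the example command above; unspecified options keep their defaults.
args = build_parser().parse_args(["-m=NCF", "-d=ml-1m"])
print(args.model, args.dataset, args.learning_rate)
```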
For `ItemTowerRecommender`, some extra args are supported:
args | type | description | default | options |
---|---|---|---|---|
--sampler | str | sampler name | uniform | ['uniform','popularity','midx_uni','midx_pop','cluster_uni','cluster_pop'] |
--negative_count | int | number of negative samples | 1 | positive integer |
For `TwoTowerRecommender`, some extra args are supported in addition to those of `ItemTowerRecommender`:
args | type | description | default | options |
---|---|---|---|---|
--split_mode | str | split methods for the dataset | user_entry | ['user','entry','user_entry'] |
Here are details for some of the less obvious arguments.
- `mode`: in `light` mode and `detail` mode, the output is displayed on the terminal, while the latter provides more detailed information. `tune` mode uses Neural Network Intelligence (NNI) to show a visual interface. You can run it like `tune.sh` with a config file like `config.yaml`. For more details about NNI, please refer to the NNI Documentation.
- `sampler`: `uniform` means that UniformSampler is used. `popularity` means sampling according to item popularity (more popular items are sampled with higher probabilities). `midx_uni` and `midx_pop` are `midx` dynamic samplers; please refer to FastVAE for more details. `cluster_uni` and `cluster_pop` are `cluster` dynamic samplers; please refer to PRIS for more details.
- `split_mode`: `user` means splitting all users into train/valid/test datasets, so the users in those datasets are disjoint. `entry` means splitting all the interactions into those three datasets. `user_entry` means splitting the interactions of each user into three parts.
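The `user_entry` strategy can be sketched as follows, using the `split_ratio=[0.8, 0.1, 0.1]` shown in the sample log. This is an illustration under assumptions rather than RecStudio's splitting code: the function name `split_user_entry` is hypothetical, each user's interactions are split in their stored order, and rounding is done by truncation with the remainder going to the test part.

```python
def split_user_entry(interactions, ratios=(0.8, 0.1, 0.1)):
    """Split each user's interactions into train/valid/test parts by ratio."""
    by_user = {}
    for user, item in interactions:
        by_user.setdefault(user, []).append(item)

    train, valid, test = [], [], []
    for user, items in by_user.items():
        n_train = int(len(items) * ratios[0])
        n_valid = int(len(items) * ratios[1])
        train += [(user, i) for i in items[:n_train]]
        valid += [(user, i) for i in items[n_train:n_train + n_valid]]
        # Whatever remains after the train and valid slices goes to test.
        test += [(user, i) for i in items[n_train + n_valid:]]
    return train, valid, test

log = [("u1", i) for i in range(10)] + [("u2", i) for i in range(20)]
train, valid, test = split_user_entry(log)
print(len(train), len(valid), len(test))
```

Unlike the `user` mode, every user here appears in all three parts, so the model is evaluated on users it has seen during training.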
You can also install RecStudio from PyPI:
pip install recstudio
Basic usage looks like this:
import recstudio
recstudio.run(model="BPR", data_dir="./datasets/", dataset='ml-100k')
For more detailed information, please refer to our documentation https://recstudio.readthedocs.io/.
RecStudio integrates with the NNI module to tune hyper-parameters automatically. For easy usage, you can run the `tune.sh` script with your own config file, modeled on the provided `config.yaml`.
For more detailed information about NNI, please refer to the NNI Documentation.
Please let us know if you encounter a bug or have any suggestions by submitting an issue.
We welcome all contributions, from bug fixes to new features and extensions.
We expect all contributions to be discussed first in the issue tracker and then submitted as PRs.
RecStudio is developed and maintained by USTC BigData Lab.
User | Contributions |
---|---|
@DefuLian | Framework design and construction |
@AngusHuang17 | Sequential model, docs, bug fixes |
@Xiuchen519 | Knowledge-based model, bug fixes |
@JennahF | NCF, CML, logisticMF models |
@HERECJ | AutoEncoder models |
@BinbinJin | IRGAN model |
RecStudio uses the MIT License.