TextGAN is a PyTorch framework for Generative Adversarial Network (GAN) based text generation models, including general text generation models and category text generation models. TextGAN serves as a benchmarking platform to support research on GAN-based text generation. Since most GAN-based text generation models are implemented in TensorFlow, TextGAN can help those who are used to PyTorch enter the text generation field faster.
If you find any mistakes in my implementation, please let me know! Also, please feel free to contribute to this repository if you want to add other models.
To install, run `pip install -r requirements.txt`. In case of CUDA problems, consult the official PyTorch Get Started guide.
Download stable release and unzip: http://kheafield.com/code/kenlm.tar.gz
Requires Boost >= 1.42.0 and bjam:
Ubuntu: `sudo apt-get install libboost-all-dev`
macOS: `brew install boost; brew install bjam`
Run within kenlm directory:
mkdir -p build
cd build
cmake ..
make -j 4
pip install https://github.com/kpu/kenlm/archive/master.zip
For more information on KenLM see: https://github.com/kpu/kenlm and http://kheafield.com/code/kenlm/
git clone https://github.com/williamSYSU/TextGAN-PyTorch.git
cd TextGAN-PyTorch
The datasets (`Image COCO`, `EMNLP NEWs`, `Movie Review`, `Amazon Review`) can be downloaded from here.
cd run
python3 run_[model_name].py 0 0 # The first 0 is job_id, the second 0 is gpu_id
# For example
python3 run_seqgan.py 0 0
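As a rough illustration of how a run script could read these two positional arguments, here is a minimal sketch. `parse_run_args` is a hypothetical helper, and falling back to `0` when the arguments are missing is an assumption, not necessarily what the run files do.

```python
def parse_run_args(argv):
    """Return (job_id, gpu_id) from a command line such as `run_seqgan.py 0 0`.

    Hypothetical helper; defaulting both ids to 0 is an assumption.
    """
    if len(argv) > 2:
        return int(argv[1]), int(argv[2])
    return 0, 0


# Equivalent to: python3 run_seqgan.py 0 0
job_id, gpu_id = parse_run_args(["run_seqgan.py", "0", "0"])
print(f"job_id={job_id}, gpu_id={gpu_id}")
```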
Instructor
For each model, the entire running process is defined in an instructor, for example `instructor/oracle_data/seqgan_instructor.py` for SeqGAN in the synthetic data experiment. Some basic functions like `init_model()` and `optimize()` are defined in the base class `BasicInstructor` in `instructor.py`. If you want to add a new GAN-based text generation model, please create a new instructor under `instructor/oracle_data` and define the training process for the model.
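The shape of a new instructor might look like the sketch below. It does not import the real code: the `BasicInstructor` here is a stand-in so the example runs on its own, the method names `init_model()` and `optimize()` come from the text above, and everything else (the signatures, `MyGANInstructor`, `pretrain_generator`, `adv_train`, `_run`, the `opt` dictionary and the placeholder losses) is hypothetical.

```python
class BasicInstructor:
    """Stand-in for the base class in instructor.py (signatures are assumed)."""

    def __init__(self, opt):
        self.opt = opt
        self.log = []  # records calls so the control flow is visible

    def init_model(self):
        # The real method presumably loads or initialises model weights.
        self.log.append("init_model")

    def optimize(self, loss):
        # The real method presumably runs a backward pass and an optimizer step.
        self.log.append(f"optimize(loss={loss:.2f})")


class MyGANInstructor(BasicInstructor):
    """Hypothetical instructor: pre-training followed by adversarial training."""

    def pretrain_generator(self, epochs):
        for epoch in range(epochs):
            self.optimize(loss=1.0 / (epoch + 1))  # placeholder loss value

    def adv_train(self, steps):
        for _ in range(steps):
            self.optimize(loss=0.5)  # placeholder loss value

    def _run(self):
        self.init_model()
        self.pretrain_generator(self.opt["pre_epochs"])
        self.adv_train(self.opt["adv_steps"])


inst = MyGANInstructor({"pre_epochs": 2, "adv_steps": 1})
inst._run()
print(inst.log)
```

The point of the split is that shared plumbing stays in the base class, while each model's instructor only describes its own training schedule.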
Visualization
Use `utils/visualization.py` to visualize the log file, including model loss and metric scores. List your log files in `log_file_list`, with no more than `len(color_list)` entries. The log filenames should exclude the `.txt` extension.
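The two constraints on `log_file_list` can be checked before plotting. The helper below is a hypothetical sketch, not part of `utils/visualization.py`; it assumes that the script appends `.txt` itself, and the file names and colors are made up.

```python
def check_log_file_list(log_file_list, color_list):
    """Validate log names and return the file names to open.

    Hypothetical helper; assumes the '.txt' extension is appended by the script.
    """
    if len(log_file_list) > len(color_list):
        raise ValueError(
            f"{len(log_file_list)} log files but only {len(color_list)} colors"
        )
    for name in log_file_list:
        if name.endswith(".txt"):
            raise ValueError(f"drop the '.txt' extension from {name!r}")
    return [name + ".txt" for name in log_file_list]


color_list = ["red", "blue", "green"]           # made-up values
log_file_list = ["log_seqgan", "log_leakgan"]   # made-up names, without '.txt'
print(check_log_file_list(log_file_list, color_list))
```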
Logging
TextGAN-PyTorch uses Python's `logging` module to record the running process, such as the generator's loss and metric scores. For the convenience of visualization, two identical log files are saved, in `log/log_****_****.txt` and `save/**/log.txt` respectively. Furthermore, the code automatically saves the state dict of the models and a batch of the generator's samples in `./save/**/models` and `./save/**/samples` at every log step, where `**` depends on your hyper-parameters.
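Writing the same record to two log files can be done with two `FileHandler`s on one logger. This is a minimal sketch of that idea rather than the repository's own setup: `create_logger`, the logger name, the temporary paths and the logged message are all illustrative assumptions.

```python
import logging
import os
import tempfile


def create_logger(name, log_paths):
    """Return a logger that writes every record to each path in log_paths."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # keep demo output out of the root logger
    for path in log_paths:
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        handler = logging.FileHandler(path, mode="w")
        handler.setFormatter(logging.Formatter("%(message)s"))
        logger.addHandler(handler)
    return logger


root = tempfile.mkdtemp()
paths = [
    os.path.join(root, "log", "log_demo.txt"),      # stands in for log/log_****_****.txt
    os.path.join(root, "save", "demo", "log.txt"),  # stands in for save/**/log.txt
]
logger = create_logger("textgan_demo", paths)
logger.info("[ADV] epoch 0: g_loss = 0.1234")  # made-up message
```

Both files receive identical lines, so either one can be passed to a plotting script.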
Running Signal
You can easily control the training process with the class `Signal` (please refer to `utils/helpers.py`), which is based on the dictionary file `run_signal.txt`.
To use `Signal`, just edit the local file `run_signal.txt`. For example, if you set `pre_sig` to `False`, the program will stop the pre-training process and step into the next training phase. This is convenient for stopping training early if you think the current training is enough.
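The mechanism can be sketched as a small class that re-reads the file on every epoch. This is not the code in `utils/helpers.py`: `FileSignal` is a hypothetical name, and the sketch assumes that `run_signal.txt` holds a Python dict literal such as `{'pre_sig': True}`.

```python
import ast
import os
import tempfile


class FileSignal:
    """Sketch of a file-backed training switch (assumed file format: dict literal)."""

    def __init__(self, signal_file):
        self.signal_file = signal_file
        self.pre_sig = True
        self.update()

    def update(self):
        """Re-read the file so that edits made during training take effect."""
        with open(self.signal_file, "r") as fin:
            signal_dict = ast.literal_eval(fin.read())
        self.pre_sig = bool(signal_dict.get("pre_sig", True))


path = os.path.join(tempfile.mkdtemp(), "run_signal.txt")
with open(path, "w") as fout:
    fout.write("{'pre_sig': True}")

sig = FileSignal(path)
epochs_run = 0
for epoch in range(5):
    if epoch == 2:  # simulate the user editing the file mid-training
        with open(path, "w") as fout:
            fout.write("{'pre_sig': False}")
    sig.update()
    if not sig.pre_sig:
        break  # leave pre-training and move on to the next phase
    epochs_run += 1
```

`ast.literal_eval` is used here because it only accepts literals, which is a safer choice than `eval` for a file that is edited by hand.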
Automatically select GPU
In `config.py`, the program automatically selects the GPU device with the least `GPU-Util` in `nvidia-smi`. This feature is enabled by default. If you want to manually select a GPU device, please uncomment the `--device` args in `run_[run_model].py` and specify a GPU device on the command line.
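One way to pick the least-utilized GPU is to query `nvidia-smi` in CSV form and take the minimum. The sketch below illustrates that idea and is not the code in `config.py`; `least_utilized_gpu` and `query_nvidia_smi` are hypothetical names, and the sample output string is made up.

```python
import subprocess


def least_utilized_gpu(smi_csv):
    """Return the index of the GPU with the lowest utilization.

    smi_csv is the text printed by
    nvidia-smi --query-gpu=index,utilization.gpu --format=csv,noheader,nounits
    with one "index, utilization" pair per line.
    """
    best_index, best_util = None, None
    for line in smi_csv.strip().splitlines():
        index, util = (int(field) for field in line.split(","))
        if best_util is None or util < best_util:
            best_index, best_util = index, util
    if best_index is None:
        raise RuntimeError("no GPU found in nvidia-smi output")
    return best_index


def query_nvidia_smi():
    """Run nvidia-smi and return its CSV output (requires an NVIDIA driver)."""
    return subprocess.check_output(
        [
            "nvidia-smi",
            "--query-gpu=index,utilization.gpu",
            "--format=csv,noheader,nounits",
        ],
        text=True,
    )


sample = "0, 87\n1, 3\n2, 45\n"  # made-up output for three GPUs
print(least_utilized_gpu(sample))
```

On a machine with GPUs, `least_utilized_gpu(query_nvidia_smi())` would give the device index to use. Keeping the parsing separate from the subprocess call makes the selection logic testable without a GPU.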
run file: run_seqgan.py
Instructors: oracle_data, real_data
Models: generator, discriminator
Structure (from SeqGAN)
run file: run_leakgan.py
Instructors: oracle_data, real_data
Models: generator, discriminator
Structure (from LeakGAN)
run file: run_maligan.py
Instructors: oracle_data, real_data
Models: generator, discriminator
Structure (from my understanding)
run file: run_jsdgan.py
Instructors: oracle_data, real_data
Models: generator (No discriminator)
Structure (from my understanding)
run file: run_relgan.py
Instructors: oracle_data, real_data
Models: generator, discriminator
Structure (from my understanding)
run file: run_dpgan.py
Instructors: oracle_data, real_data
Models: generator, discriminator
Structure (from DPGAN)
run file: run_dgsan.py
Instructors: oracle_data, real_data
Models: generator, discriminator
run file: run_cot.py
Instructors: oracle_data, real_data
Models: generator, discriminator
Structure (from CoT)
run file: run_sentigan.py
Instructors: oracle_data, real_data
Models: generator, discriminator
Structure (from SentiGAN)
run file: run_catgan.py
Instructors: oracle_data, real_data
Models: generator, discriminator
Structure (from CatGAN)
MIT license