|
- import argparse
- import os
- import pathlib
- import time
-
- import matplotlib.pyplot as plt
- import shap
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
-
-
class ITCexplainer:
    """Explain a Hugging Face text-classification model with SHAP.

    Wraps a sequence-classification model in ``transformers`` pipelines and a
    ``shap.Explainer``. Classification results and SHAP visualisations for
    ``data`` are written into ``result_path``.
    """

    def __init__(self, model_path, data, result_path):
        """Load the model, build the pipelines/explainer and prepare output.

        Args:
            model_path: model name or local path (``str`` / ``pathlib.Path``)
                accepted by ``from_pretrained``; otherwise an object exposing
                ``.tokenizer`` and ``.model`` attributes.
            data: the sentences to analyse (list-like of ``str``).
            result_path: directory that receives every output file; it is
                created if it does not exist yet.
        """
        # Load the text-classification model.
        if isinstance(model_path, (str, pathlib.Path)):
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        else:
            # model_path is of type text.Models.HuggingFaceNLPVictimModel
            self.tokenizer = model_path.tokenizer
            self.model = model_path.model
        # `pred` returns the score of every class (what the SHAP explainer is
        # built on); `classifier` returns only the top label per sentence.
        # NOTE(review): newer transformers releases deprecate
        # `return_all_scores` in favour of `top_k=None`; kept for compatibility.
        self.pred = pipeline("sentiment-analysis", model=self.model, tokenizer=self.tokenizer, return_all_scores=True)
        self.classifier = pipeline("sentiment-analysis", model=self.model, tokenizer=self.tokenizer)

        # Class id <-> label name mappings from the model config.
        self.id2label = self.model.config.id2label
        self.label2id = self.model.config.label2id

        # SHAP explainer over the all-scores pipeline.
        self.explainer = shap.Explainer(self.pred)

        # Sentences analysed by classify() / draw_*().
        self.data = data

        # Use a Chinese-capable font for matplotlib figures, and draw the minus
        # sign as an ASCII hyphen (the Unicode minus may be missing from it).
        plt.rcParams["font.sans-serif"] = ["SimHei"]
        plt.rcParams["axes.unicode_minus"] = False

        # Output directory; exist_ok avoids the check-then-create race.
        self.result_path = result_path
        os.makedirs(self.result_path, exist_ok=True)

    def classify(self):
        """Classify ``self.data`` and log the predicted label of each sentence.

        Writes one line per sentence to ``<result_path>/classification.log``
        (UTF-8 encoded).
        """
        results = self.classifier(self.data)
        log_path = pathlib.Path(self.result_path) / "classification.log"
        # Explicit UTF-8: labels may be non-ASCII, and the platform default
        # encoding (e.g. on Windows) may be unable to encode them.
        with open(log_path, "w", encoding="utf-8") as log_file:
            for idx, item in enumerate(results):
                # NOTE: the "setence" spelling is part of the existing log
                # format and is left unchanged for anything parsing it.
                log_file.write("for setence " + str(idx) + ", the class is:\t" + str(item["label"]) + "\n")

    def calc_shap_value(self, data):
        """Compute SHAP values for ``data`` and store them in ``self.shap_values``.

        This is slow, so avoid passing too many sentences at once. ``data`` is
        a one-dimensional collection of strings (list, array, dataframe
        column, ...).
        """
        self.shap_values = self.explainer(data)

    def save_text_interpretation_html(self, html, file_name):
        """Write the HTML string ``html`` to ``file_name`` as UTF-8."""
        # The SHAP text plot embeds the analysed (possibly Chinese) text, so
        # the platform default encoding is not safe here.
        with open(file_name, "w", encoding="utf-8") as html_file:
            html_file.write(html)

    def draw_text_analysis(self, whole):
        """Render per-token SHAP text plots for ``self.data`` as HTML files.

        Args:
            whole: if true, all sentences go into a single ``whole.html``;
                otherwise sentence ``i`` is written to ``<i>.html``.
        """
        self.calc_shap_value(self.data)
        result_dir = pathlib.Path(self.result_path)
        # shap.plots.text(display=False) returns an HTML string, so no
        # matplotlib state needs resetting here.
        if whole:
            html = shap.plots.text(self.shap_values, display=False)
            self.save_text_interpretation_html(html, str(result_dir / "whole.html"))
        else:
            for idx in range(len(self.data)):
                html = shap.plots.text(self.shap_values[idx, :, :], display=False)
                self.save_text_interpretation_html(html, str(result_dir / (str(idx) + ".html")))

    def draw_shap_bar(self, label):
        """Save a SHAP bar chart for class id ``label`` for every sentence.

        Sentence ``i`` is written to ``<result_path>/<i>.png``.

        Args:
            label: class id. It is used as a key of ``id2label`` and as the
                index into the explanation's output axis (assumes that axis
                is ordered by class id, as the all-scores pipeline returns it).
        """
        self.calc_shap_value(self.data)
        # Looking the id up first fails early for a class the model lacks.
        self.label = self.id2label[label]
        # Select the requested output once instead of on every iteration.
        label_values = self.shap_values[:, :, label]
        result_dir = pathlib.Path(self.result_path)
        for idx in range(len(self.data)):
            # Clear the current axes so bars from the previous sentence do not
            # remain in the next saved figure.
            plt.cla()
            shap.plots.bar(label_values[idx], show=False)
            plt.savefig(str(result_dir / (str(idx) + ".png")))
-
-
class ITCdataset:
    """Holds the sentences to classify and explain."""

    def __init__(self, data_path):
        """Load sentences from a text file or accept an in-memory list.

        Args:
            data_path: path (``str`` / ``pathlib.Path``) to a UTF-8 text file
                with one sentence per line (a leading BOM is tolerated), or an
                already-built list of sentences, which is used as-is.
        """
        if isinstance(data_path, (str, pathlib.Path)):
            # "utf-8-sig" also reads plain UTF-8 but drops a BOM, which would
            # otherwise be glued onto the first sentence.
            with open(data_path, "r", encoding="utf-8-sig") as data_file:
                # Strip line terminators so they are not part of every
                # sentence, and skip blank lines, which would become empty inputs.
                self.data = [line.rstrip("\r\n") for line in data_file if line.strip()]
        else:
            # data_path is list of str (sample sentences)
            self.data = data_path

    def get_data(self):
        """Return the loaded sentences."""
        return self.data
-
-
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``model``, ``data``, ``task``,
        ``whole``, ``label`` and ``result_dir``.
    """
    cli = argparse.ArgumentParser(description="Interpretability of Text Classification")
    # (flag, add_argument keyword arguments), registered in this order so the
    # generated help text lists them the same way.
    option_specs = [
        (
            "--model",
            {
                "type": str,
                "default": "./model/roberta-base-finetuned-ifeng-chinese",
                "help": "the name or path of a huggingface text classification model",
            },
        ),
        ("--data", {"type": str, "default": "./data/data.txt", "help": "data path of text input"}),
        # Task name(s); main() splits a comma-separated value.
        ("--task", {"type": str, "default": "classify", "help": "classify analysis bar"}),
        (
            "--whole",
            {"action": "store_true", "help": "draw the analysis results of multiple sentences together"},
        ),
        (
            "--label",
            {"type": int, "default": 0, "help": "draw the bar chart of the shap value for this label"},
        ),
        (
            "--result-dir",
            {
                "type": str,
                "default": None,
                "help": "the directory to save the results",
                "dest": "result_dir",
            },
        ),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
-
-
def main(args: dict):
    """Run the requested interpretability task(s).

    Args:
        args: option mapping (e.g. ``vars(parse_args())``). Missing keys fall
            back to the defaults below. ``task`` is either a comma-separated
            string or a sequence of task names chosen from ``classify``,
            ``analysis`` and ``bar``. The caller's dict is not modified.

    Results go to ``result_dir`` or, when that is ``None``, to
    ``./result/<timestamp>_<task>``.
    """
    default_args = {
        "model": "./model/roberta-base-finetuned-ifeng-chinese",
        "data": "./data/data.txt",
        "task": "classify",
        "whole": False,
        "label": 0,
        "result_dir": None,
    }
    # Caller-supplied values override the defaults.
    args = {**default_args, **args}

    # Normalise `task` to one string (used in the default result directory
    # name) and to a list of stripped task names.
    task_str = args["task"] if isinstance(args["task"], str) else ",".join(args["task"])
    tasks = [task.strip() for task in task_str.split(",")]

    known_tasks = ("classify", "analysis", "bar")
    if not any(task in known_tasks for task in tasks):
        # Report before the expensive model load and before an empty result
        # directory gets created.
        print("Please fill in the task type correctly.")
        print("There are three options: classify, analysis, and bar")
        return
    unknown_tasks = [task for task in tasks if task and task not in known_tasks]
    if unknown_tasks:
        print("Ignoring unknown task(s): " + ", ".join(unknown_tasks))

    # Set the output path for the analysis results.
    timestamp = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
    if args["result_dir"] is None:
        result_path = str(pathlib.Path("./result").expanduser().resolve() / (timestamp + "_" + task_str))
    else:
        result_path = str(pathlib.Path(args["result_dir"]).expanduser().resolve())

    # Load the data.
    data = ITCdataset(args["data"]).get_data()

    # Load the model and set up the explainer.
    explainer = ITCexplainer(args["model"], data, result_path)

    # Run each requested task in a fixed order; outputs go to result_path.
    if "classify" in tasks:
        explainer.classify()
        print("classify completed successfully.")
    if "analysis" in tasks:
        explainer.draw_text_analysis(args["whole"])
        print("analysis completed successfully.")
    if "bar" in tasks:
        explainer.draw_shap_bar(args["label"])
        print("bar completed successfully.")
-
-
if __name__ == "__main__":
    # Script entry point: parse the CLI options (model, data, task, ...) and
    # hand them to main() as a plain dict.
    main(vars(parse_args()))
|