|
- import uvicorn
- from fastapi import FastAPI
- from pydantic import BaseModel
- from starlette.middleware.cors import CORSMiddleware
- import numpy as np
-
- from ge.classify import read_node_label, Classifier
- from ge import Struc2Vec
- from sklearn.linear_model import LogisticRegression
-
- import matplotlib.pyplot as plt
- import networkx as nx
- from sklearn.manifold import TSNE
-
# Application object carrying the service's title, description and version.
app = FastAPI(
    title='node embedding struc2vec 相似图谱可视化',
    description='基于node embedding任务,包括struc2vec训练模块、推理模块。',
    version='1.0.0',
)

# CORS policy: any origin, method and header is accepted, but credentials
# are explicitly disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
- import json
-
# Load the relation records and keep only the first 10000 of them.
# A context manager is used so the file handle is closed as soon as the
# JSON has been parsed instead of lingering until garbage collection.
with open("../data/relation_data.json", "r", encoding="utf-8") as relation_file:
    family_doctor = json.load(relation_file)[:10000]

# Directed graph with one edge per record: element 0 -> element 2.
# Element 1 of each record is not used for the graph structure.
G = nx.DiGraph()
for record in family_doctor:
    G.add_edge(record[0], record[2])

# The model is trained at import time, so `model` and `embeddings` are ready
# before the first request is served.
# NOTE(review): the positional arguments 5 and 40 are presumably walk_length
# and num_walks -- confirm against the ge.Struc2Vec signature.
model = Struc2Vec(G, 5, 40, workers=6, verbose=20)
model.train()
embeddings = model.get_embeddings()
-
-
def extract_spoes(query):
    """Collect relation triples around the nodes most similar to *query*.

    The lookup runs in three steps:

    1. Every key of the trained word2vec vocabulary that contains *query*
       as a substring is treated as a match.
    2. For each match, the first five entries returned by
       ``similar_by_word`` are taken as its neighbours. Only these
       neighbours are used afterwards; the matched keys themselves are not
       added.
    3. Every record of the module-level ``family_doctor`` list that contains
       at least one neighbour yields the triple
       ``(record[0], "导致", record[2])``.

    Args:
        query: Text searched for inside the vocabulary keys.

    Returns:
        A ``(triples, total)`` tuple. ``triples`` holds at most the first
        500 unique triples in first-seen order, so the truncation is
        deterministic from run to run. ``total`` is the number of unique
        triples before truncation.
    """
    word_vectors = model.w2v_model.wv

    matched_keys = [key for key in word_vectors.index_to_key if query in key]

    # dict.fromkeys removes duplicates while keeping first-seen order; a
    # plain set() would make the later [:500] cut pick an arbitrary subset.
    neighbours = list(dict.fromkeys(
        name
        for key in matched_keys
        for name, _score in word_vectors.similar_by_word(key)[:5]
    ))

    # any() stops at the first neighbour found in a record, so each record
    # contributes its triple at most once.
    unique_triples = dict.fromkeys(
        (record[0], "导致", record[2])
        for record in family_doctor
        if any(name in record for name in neighbours)
    )

    all_out = list(unique_triples)
    return all_out[:500], len(all_out)
-
-
# Request body accepted by the /search endpoint.
# Documented with comments rather than a docstring so the generated schema
# for this model stays exactly as it was.
class Item(BaseModel):
    array: str  # search text; handed unchanged to extract_spoes as the query
-
-
@app.post("/search")
async def root(item: Item):
    # Documented with comments rather than a docstring: FastAPI would publish
    # a docstring as the endpoint description, changing the generated docs.
    #
    # Runs the similarity search for item.array and reshapes each returned
    # triple into a link record with source/target/rela/type keys.
    triples, total = extract_spoes(item.array)
    links = [
        {"source": triple[0], "target": triple[2], 'rela': "导致", "type": 'resolved'}
        for triple in triples
    ]
    # result_len is the full count of unique triples; prediction holds the
    # (possibly truncated) list that extract_spoes returned.
    return {
        "result_len": total,
        "prediction": links,
    }
-
-
# When executed as a script, serve the app on the loopback interface only,
# port 6009.
if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=6009)
|