|
- import utils
- import AISyncore as syncore
- import argparse
- import json
-
# Seed handed to utils.setup_seed before any model/data is built; presumably
# it seeds the framework RNGs for repeatable runs — confirm in utils.
_SEED = 1
utils.setup_seed(_SEED)

# Project-wide configuration object from utils; main() reads BATCH_SIZE,
# DEVICE, EPOCHS, TASK_SERVER_IP and TASK_SERVER_PORT from it.
_config = utils._config
-
class CifarClient(syncore.client.NumPyClient):
    """Federated client that trains the received global model on local data.

    Only ``fit`` is supported: parameter initialisation and evaluation are
    expected to happen on the server, so ``get_parameters`` and ``evaluate``
    raise ``NotImplementedError``.
    """

    def __init__(self, model, trainloader, testloader, num_examples, device):
        """Store the model, data loaders and device used for local training.

        Args:
            model: Local model whose weights are replaced on every ``fit``.
            trainloader: Loader over the local training split.
            testloader: Loader over the local test split (passed to
                ``utils.train``).
            num_examples: Mapping of split sizes; ``num_examples["trainset"]``
                is reported back from ``fit``.
            device: Device handed to ``utils.train``.
        """
        self.model = model
        self.trainloader = trainloader
        self.testloader = testloader
        self.num_examples = num_examples
        self.device = device

    def get_parameters(self, config=None):
        """Unsupported: the server performs parameter initialisation.

        Args:
            config: Ignored. Accepted so the method also matches client API
                versions that pass a config dict (NOTE(review): confirm which
                signature AISyncore calls).

        Raises:
            NotImplementedError: Always.
        """
        print("GETTING PARAMETERS")
        raise NotImplementedError("Not implemented (server-side parameter initialization)")

    def set_parameters(self, parameters):
        """Load ``parameters`` into the local model.

        Returns:
            Whatever ``utils.set_model_params`` returns; ``fit`` uses it as
            the model to train.
        """
        print("SETTING PARAMETERS")
        return utils.set_model_params(self.model, parameters)

    def fit(self, parameters, config):
        """Train the received parameters on the locally held training set.

        Args:
            parameters: Global model parameters sent by the server.
            config: Per-round settings (presumably the round's training
                hyperparameters), forwarded unchanged to ``utils.train``.

        Returns:
            Tuple of (updated parameters, number of training examples,
            results returned by ``utils.train``).
        """
        model = self.set_parameters(parameters)
        results = utils.train(model, self.trainloader, self.testloader, self.device, config)
        return utils.get_model_params(model), self.num_examples["trainset"], results

    def evaluate(self, parameters, config):
        """Unsupported: evaluation is performed on the server.

        Raises:
            NotImplementedError: Always.
        """
        print("EVALUATING")
        raise NotImplementedError("Not implemented (server-side evaluation)")
-
-
def client_dry_run(model, client, epochs=1, log_path="./log/dry_log.json"):
    """Smoke-test ``client.fit`` for a few rounds without a server.

    Each round feeds the model's current parameters to ``client.fit`` with the
    config from ``utils.fit_server_config`` and records the reported test
    accuracy.

    Args:
        model: Model whose parameters are passed to ``client.fit`` each round.
        client: Client under test (e.g. ``CifarClient``).
        epochs: Number of rounds to run; round numbers start at 1.
        log_path: JSON file that receives ``[[round, test_accuracy], ...]``.
            Its parent directory is created if it does not exist.
    """
    from pathlib import Path

    dry_run_log = []
    for rnd in range(1, epochs + 1):
        config = utils.fit_server_config(rnd, None)
        _, _, result = client.fit(utils.get_model_params(model), config)
        dry_run_log.append([float(rnd), result["test_accuracy"]])

    log_file = Path(log_path)
    # Opening for write fails with FileNotFoundError if ./log is missing.
    log_file.parent.mkdir(parents=True, exist_ok=True)
    with log_file.open("w", encoding="utf-8") as file_obj:
        json.dump(dry_run_log, file_obj, ensure_ascii=False)
    print("Dry Run Successful")
-
-
def main() -> None:
    """Parse CLI options, build the client, then dry-run it or join the server.

    ``--dry`` runs ``client_dry_run`` locally for ``_config.EPOCHS`` rounds;
    otherwise the client connects to the configured task server.
    ``--partition`` selects which data partition is loaded.
    """
    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--dry",
        action="store_true",
        help="Do a dry-run to check the client",
    )
    parser.add_argument(
        "--partition",
        type=int,
        default=0,
        required=False,
        # Implicit concatenation: a backslash continuation inside the literal
        # would embed the next line's indentation in the help text.
        help=(
            "Specifies the artificial data partition of CIFAR10 to be used. "
            "Picks partition 0 by default"
        ),
    )
    args = parser.parse_args()

    model = utils.load_model()
    trainloader, testloader, num_examples = utils.load_partition_dataloader(
        args.partition, batch_size=_config.BATCH_SIZE
    )
    client = CifarClient(model, trainloader, testloader, num_examples, _config.DEVICE)

    if args.dry:
        client_dry_run(model, client, epochs=_config.EPOCHS)
    else:
        # An f-string works whether the port is configured as str or int;
        # "+" concatenation raises TypeError for an int port.
        server_address = f"{_config.TASK_SERVER_IP}:{_config.TASK_SERVER_PORT}"
        syncore.client.run_numpy_client(server_address, client=client)


if __name__ == "__main__":
    main()
|