In Chapter 4, Federated Learning Server Implementation with Python, and Chapter 5, Federated Learning Client-Side Implementation, both about the implementation of federated learning (FL) systems, internal library functions were introduced to simplify the explanation of the FL server and client functionalities and the machine learning (ML) applications. Here, we will discuss those internal libraries, such as the communication handler, the data structure handler, and the enumeration class definitions, in more detail so that you can easily implement FL systems that work over the internet and on the cloud. These internal libraries and supporting functions can all be found in the fl_main/lib/util directory of the provided simple-fl GitHub repository.
In this appendix, we will provide an overview of the internal library and utility classes and functions, with code samples to demonstrate their functionalities.
In this appendix, we're going to cover the following main topics:

- Enumeration classes for implementing the FL system
- Communication handler functionalities
- Data structure handler for the ML models
- Helper and supporting libraries
- Messenger functions for creating communication payloads
All the library code files introduced in this chapter can be found in the fl_main/lib/util directory of the GitHub repository (https://github.com/tie-set/simple-fl).
Important note
You can use the code files for personal or educational purposes. Please note that we will not support deployment for commercial use and will not be responsible for any errors, issues, or damages caused by using the code.
Figure A.1 shows the Python code components for the internal libraries found in the lib/util folder of the fl_main directory, which is used in the database, aggregator, and agent of the FL system:
Figure A.1 – Python software components for the internal libraries used in the database, aggregator, and agent
The following are brief descriptions of the Python files for the internal libraries found in the lib/util folder of the FL system.
The states.py file in the lib/util folder defines a variety of enumeration classes to support the implementation of the FL system. The classes define the FL client states, the types of ML models and messages, and the index locations of the information fields within the various messages.
The communication_handler.py file in the lib/util folder can provide communication functionalities among the database, FL server, and clients, mainly defining the send and receive functions between them. Also, it provides the functions to start the servers for the database, aggregator, and agent.
The data_struc.py file in the lib/util folder defines the class called LimitedDict to support an aggregation process of the FL cycle. It provides functions to convert ML models with a dictionary format into LimitedDict and vice versa.
The helpers.py file in the lib/util folder has a collection of internal helper functions, such as reading configuration files, generating unique hash IDs, packaging ML models into a dictionary, loading and saving local ML models, getting the IP address of the machine, and manipulating the FL client state.
The messengers.py file in the lib/util folder is for generating a variety of messages as communication payloads among FL systems to facilitate the implementation of communication protocols of the simple FL system discussed throughout the book.
Now that we have discussed an overview of the FL system’s internal libraries, next, let’s talk about the individual code files in more detail.
Enumeration classes assist the implementation of the FL system. They are defined in the states.py file found in the lib/util folder of the fl_main directory. Let us look into which libraries are imported to define the enumeration classes.
In this states.py code example, the file imports general libraries such as Enum and IntEnum from enum:
from enum import Enum, IntEnum
Next, we’ll explain the class that defines the prefixes of three components of the FL system.
The following is a list of classes to define the FL system components. IDPrefix is the prefix to indicate which FL component is referred to in the code, such as agent, aggregator, or database:
class IDPrefix:
    agent = 'agent'
    aggregator = 'aggregator'
    db = 'database'
Next, we’ll provide a list of the classes for the client state.
The following is a list of enumeration classes related to the FL client states, including the state of waiting for global models (waiting_gm), the state of ML training (training), the state of sending local ML models (sending), and the state of receiving the global models (gm_ready). The client states defined in the agent specification are as follows:
# CLIENT STATE
class ClientState(IntEnum):
    waiting_gm = 0
    training = 1
    sending = 2
    gm_ready = 3
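Because ClientState derives from IntEnum, its members behave like plain integers, which is what lets the agent persist its state as a number in a text file and compare states directly. The following stand-alone sketch (re-declaring the enum so that it runs without the simple-fl package) illustrates this:

```python
from enum import IntEnum

# Same definition as in states.py
class ClientState(IntEnum):
    waiting_gm = 0
    training = 1
    sending = 2
    gm_ready = 3

# IntEnum members can be used wherever plain ints are expected
state = ClientState.training
print(int(state))         # 1 -- the form written to the state file
print(ClientState(2))     # ClientState.sending -- parsed back from a file
```

This integer compatibility is exactly what the read_state and write_state helpers discussed later in this appendix rely on.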
The following is a list of classes defining the types of ML models and messages related to the FL system implementation.
The types of ML models, including local models and cluster models (global models), are defined as follows:
class ModelType(Enum):
    local = 0
    cluster = 1
The message types are defined in the communication protocol between an aggregator and database, as follows:
class DBMsgType(Enum):
    push = 0
The message types are defined in the communication protocol sent from an agent to an aggregator, as follows:
class AgentMsgType(Enum):
    participate = 0
    update = 1
    polling = 2
The message types are defined in the communication protocol sent from an aggregator to an agent, as follows:
class AggMsgType(Enum):
    welcome = 0
    update = 1
    ack = 2
The following is a list of classes defining the message location related to communication between the FL systems.
The index indicator to read a participation message from an agent to the aggregator is as follows:
class ParticipateMSGLocation(IntEnum):
    msg_type = 0
    agent_id = 1
    model_id = 2
    lmodels = 3
    init_flag = 4
    sim_flag = 5
    exch_socket = 6
    gene_time = 7
    meta_data = 8
    agent_ip = 9
    agent_name = 10
    round = 11
The index indicator to read a participation confirmation message sent back from the aggregator is as follows:
class ParticipateConfirmationMSGLocation(IntEnum):
    msg_type = 0
    aggregator_id = 1
    model_id = 2
    global_models = 3
    round = 4
    agent_id = 5
    exch_socket = 6
    recv_socket = 7
The index indicator to read a push message from an aggregator to the database is as follows:
class DBPushMsgLocation(IntEnum):
    msg_type = 0
    component_id = 1
    round = 2
    model_type = 3
    models = 4
    model_id = 5
    gene_time = 6
    meta_data = 7
    req_id_list = 8
The index indicator to read a global model distribution message from an aggregator to agents is as follows:
class GMDistributionMsgLocation(IntEnum):
    msg_type = 0
    aggregator_id = 1
    model_id = 2
    round = 3
    global_models = 4
The index indicator for a message uploading local ML models from an agent to an aggregator is as follows:
class ModelUpMSGLocation(IntEnum):
    msg_type = 0
    agent_id = 1
    model_id = 2
    lmodels = 3
    gene_time = 4
    meta_data = 5
The index indicator for a polling message from an agent to an aggregator is as follows:
class PollingMSGLocation(IntEnum):
    msg_type = 0
    round = 1
    agent_id = 2
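Because these location classes derive from IntEnum, their members can be used directly as list indices. The stand-alone sketch below (re-declaring PollingMSGLocation, with a plain string standing in for the message-type enum) shows how a received payload list is read by field name instead of by a bare number:

```python
from enum import IntEnum

# Same definition as in states.py
class PollingMSGLocation(IntEnum):
    msg_type = 0
    round = 1
    agent_id = 2

# A polling payload as produced by the messenger functions:
# [message type, round number, agent ID]
msg = ['polling', 3, 'agent-123']

# Enum members index the list directly, so the reading code
# never hard-codes magic numbers
print(msg[PollingMSGLocation.round])     # 3
print(msg[PollingMSGLocation.agent_id])  # agent-123
```

The same pattern applies to all the location classes above: the messenger functions append the fields in this fixed order, and the receiving side reads them back by enum member.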
We have defined the enumeration classes that are utilized throughout the code of the FL system. In the next section, we will discuss the communication handler functionalities.
The communication handler functionalities are implemented in the communication_handler.py file, which can be found in the lib/util folder of the fl_main directory.
In this communication_handler.py code example, the handler imports general libraries such as websockets, asyncio, pickle, and logging:
import websockets, asyncio, pickle, logging
Next, we’ll provide a list of functions of the communication handler.
The following is a list of the functions related to the communication handler. Although the Secure Sockets Layer (SSL) or Transport Layer Security (TLS) framework is not implemented in the communication handler code here for simplicity, it is recommended to support it so that communication among the FL components is always secured.
The init_db_server function is for starting the database server on the FL server side. It takes a function, database IP address, and socket information as inputs and initiates the server functionality based on the WebSocket framework. You can use any other communication protocol, such as HTTP, as well. Here is the sample code to initiate the database server:
def init_db_server(func, ip, socket):
    start_server = websockets.serve(
        func, ip, socket, max_size=None, max_queue=None)
    loop = asyncio.get_event_loop()
    loop.run_until_complete(start_server)
    loop.run_forever()
The init_fl_server function is for starting the FL server on the aggregator side. As parameters, it takes three functions for agent registration, receiving messages from agents, and the model synthesis routine, as well as the aggregator’s IP address and registration and receiver sockets info (to receive messages from agents) to initiate the server functionality based on the WebSocket framework. Here is the sample code for initiating the FL server:
def init_fl_server(register, receive_msg_from_agent,
                   model_synthesis_routine, aggr_ip,
                   reg_socket, recv_socket):
    loop = asyncio.get_event_loop()
    start_server = websockets.serve(
        register, aggr_ip, reg_socket,
        max_size=None, max_queue=None)
    start_receiver = websockets.serve(
        receive_msg_from_agent, aggr_ip, recv_socket,
        max_size=None, max_queue=None)
    loop.run_until_complete(asyncio.gather(
        start_server, start_receiver, model_synthesis_routine))
    loop.run_forever()
The init_client_server function is for starting the FL client-side server functionalities. It takes a function, the agent’s IP address, and the socket info to receive messages from an aggregator as inputs and initiates the functionality based on the WebSocket framework. Here is sample code for initiating the FL client-side server functionality:
def init_client_server(func, ip, socket):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    client_server = websockets.serve(
        func, ip, socket, max_size=None, max_queue=None)
    loop.run_until_complete(asyncio.gather(client_server))
    loop.run_forever()
The send function is for sending a message to the destination specified by the IP address and socket info taken as parameters together with a message to be sent. It returns a response message sent back from the destination node to the source node, if there is one:
async def send(msg, ip, socket):
    resp = None
    try:
        wsaddr = f'ws://{ip}:{socket}'
        async with websockets.connect(
                wsaddr, max_size=None, max_queue=None,
                ping_interval=None) as websocket:
            await websocket.send(pickle.dumps(msg))
            try:
                rmsg = await websocket.recv()
                resp = pickle.loads(rmsg)
            except:
                pass
            return resp
    except:
        return resp
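Under the hood, send and receive exchange arbitrary Python objects by serializing them with pickle before they hit the WebSocket. The round trip below reproduces that serialization step in isolation (no sockets involved) to show that a payload list survives the trip unchanged:

```python
import pickle

# A payload list like the ones built by the messenger functions
msg = ['update', 'agent-123', {'layer1': [0.1, 0.2]}]

# What send() puts on the wire ...
wire_bytes = pickle.dumps(msg)

# ... and what the receiving side recovers
recovered = pickle.loads(wire_bytes)

print(recovered == msg)  # True
```

Note that unpickling data from an untrusted peer is unsafe in general, which is another reason to secure the channels between FL components in a real deployment.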
The send_websocket function is for returning a message to the message source specified by the WebSocket information, taken as a parameter together with a message to be sent:
async def send_websocket(msg, websocket):
    while not websocket:
        await asyncio.sleep(0.001)
    await websocket.send(pickle.dumps(msg))
The receive function receives a message through the WebSocket taken as a parameter and returns the unpickled message:
async def receive(websocket):
    return pickle.loads(await websocket.recv())
Next, we will talk about the data structure class that handles processing ML models.
The data structure handler is implemented in the data_struc.py file, which can be found in the lib/util folder of the fl_main directory. The data structure class has the LimitedDict class to handle the aggregation of the ML models in a consistent manner.
In this data_struc.py code example, the handler imports general libraries, such as numpy and Dict:
from typing import Dict
import numpy as np
Next, let’s move on to the LimitedDict class and its functions related to the data structure handler.
The following is a definition of the LimitedDict class and its functions related to the data structure handler.
The LimitedDict class converts a dictionary format into a class restricted to a fixed set of keys. LimitedDict is used as the buffer for ML models, storing local and cluster models in the memory space of the aggregator’s state manager:
class LimitedDict(dict):
    def __init__(self, keys):
        self._keys = keys
        self.clear()

    def __setitem__(self, key, value):
        if key not in self._keys:
            raise KeyError
        dict.__setitem__(self, key, value)

    def clear(self):
        for key in self._keys:
            self[key] = list()
The convert_LDict_to_Dict function is used to convert the LimitedDict instance defined previously into a normal dictionary format:
def convert_LDict_to_Dict(ld: LimitedDict) -> Dict[str, np.array]:
    d = dict()
    for key, val in ld.items():
        d[key] = val[0]
    return d
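To see how the two pieces fit together, the stand-alone sketch below re-declares LimitedDict and convert_LDict_to_Dict, stores a set of dummy model arrays in the buffer, and converts it back to a plain dictionary (the layer names w1 and w2 are illustrative):

```python
from typing import Dict
import numpy as np

# Same definitions as in data_struc.py
class LimitedDict(dict):
    def __init__(self, keys):
        self._keys = keys
        self.clear()

    def __setitem__(self, key, value):
        if key not in self._keys:
            raise KeyError
        dict.__setitem__(self, key, value)

    def clear(self):
        for key in self._keys:
            self[key] = list()

def convert_LDict_to_Dict(ld: LimitedDict) -> Dict[str, np.array]:
    d = dict()
    for key, val in ld.items():
        d[key] = val[0]
    return d

# Buffer restricted to the two model keys known in advance
buf = LimitedDict(['w1', 'w2'])
buf['w1'].append(np.zeros(2))
buf['w2'].append(np.ones(2))

# Back to {layer name: weight array} for the ML code
models = convert_LDict_to_Dict(buf)
print(models['w2'])  # [1. 1.]
```

Attempting to write an unexpected key, such as buf['w3'] = [], raises a KeyError, which keeps the model buffer consistent across rounds.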
In the next section, we will talk about the helper and supporting libraries.
The helper and supporting functions are implemented in the helpers.py file, which can be found in the lib/util folder of the fl_main directory.
In this helpers.py code example, the file imports general libraries such as json and time:
import json, time, pickle, pathlib, socket, asyncio
from getmac import get_mac_address as gma
from typing import Dict, List, Any
from hashlib import sha256
from fl_main.lib.util.states import IDPrefix, ClientState
Next, let’s move on to the list of functions of the helper library.
The following is a list of functions related to the helper library.
The set_config_file function takes the type of the config file, such as db, aggregator, or agent, as a parameter and returns a string of the path to the configuration file:
def set_config_file(config_type: str) -> str:
    # set the config file name
    module_path = pathlib.Path.cwd()
    config_file = f'{module_path}/setups/config_{config_type}.json'
    return config_file
The read_config function reads a JSON configuration file to set up the database, aggregator, or agent. It takes a config path as a parameter and returns config info in a dictionary format:
def read_config(config_path: str) -> Dict[str, Any]:
    with open(config_path) as jf:
        config = json.load(jf)
    return config
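The two functions are meant to be used together: set_config_file builds the path to a config_<type>.json file and read_config parses the JSON found there. The sketch below writes a minimal, hypothetical agent config to a temporary directory and reads it back; the aggr_ip and reg_socket keys are illustrative, not the full schema used by simple-fl:

```python
import json
import pathlib
import tempfile

# Same logic as helpers.read_config
def read_config(config_path: str):
    with open(config_path) as jf:
        config = json.load(jf)
    return config

# Write a minimal, illustrative config file to a temp dir
tmpdir = pathlib.Path(tempfile.mkdtemp())
config_file = tmpdir / 'config_agent.json'
config_file.write_text(json.dumps(
    {'aggr_ip': '127.0.0.1', 'reg_socket': 8765}))

config = read_config(str(config_file))
print(config['aggr_ip'])  # 127.0.0.1
```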
The generate_id function generates a system-wide unique ID based on the MAC address and instantiation time with a hash function (sha256) returning the hash value as an ID:
def generate_id() -> str:
    macaddr = gma()
    in_time = time.time()
    raw = f'{macaddr}{in_time}'
    hash_id = sha256(raw.encode('utf-8'))
    return hash_id.hexdigest()
The generate_model_id function generates a system-wide unique ID for a set of models. The ID is produced by a hash function (sha256) from the component type, the component ID, and the generation time of the models, and the function returns the hash value as the model ID:
def generate_model_id(component_type: str, component_id: str,
                      gene_time: float) -> str:
    raw = f'{component_type}{component_id}{gene_time}'
    hash_id = sha256(raw.encode('utf-8'))
    return hash_id.hexdigest()
The create_data_dict_from_models function creates the data dictionary for ML models, taking the model ID, the models themselves, and the component ID as parameters. It returns a data dictionary containing the ML models:
def create_data_dict_from_models(model_id, models, component_id):
    data_dict = dict()
    data_dict['models'] = models
    data_dict['model_id'] = model_id
    data_dict['my_id'] = component_id
    data_dict['gene_time'] = time.time()
    return data_dict
The create_meta_data_dict function creates the metadata dictionary with the metadata of the ML models, taking the performance metrics (perf_val) and the number of samples (num_samples) as parameters, and returns meta_data_dict, containing the performance value and the number of samples:
def create_meta_data_dict(perf_val, num_samples):
    meta_data_dict = dict()
    meta_data_dict["accuracy"] = perf_val
    meta_data_dict["num_samples"] = num_samples
    return meta_data_dict
The compatible_data_dict_read function takes data_dict, which contains the information related to ML models, extracts the values if the corresponding key exists in the dictionary, and returns the component ID, the generation time of the ML models, the ML models themselves, and the model IDs:
def compatible_data_dict_read(data_dict: Dict[str, Any]) -> List[Any]:
    if 'my_id' in data_dict.keys():
        id = data_dict['my_id']
    else:
        id = generate_id()
    if 'gene_time' in data_dict.keys():
        gene_time = data_dict['gene_time']
    else:
        gene_time = time.time()
    if 'models' in data_dict.keys():
        models = data_dict['models']
    else:
        models = data_dict
    if 'model_id' in data_dict.keys():
        model_id = data_dict['model_id']
    else:
        model_id = generate_model_id(
            IDPrefix.agent, id, gene_time)
    return id, gene_time, models, model_id
The save_model_file function saves a given set of models to a local file. It takes the data dictionary containing the ML models, the file path and name, and an optional performance dictionary as parameters:
def save_model_file(
        data_dict: Dict[str, Any],
        path: str,
        name: str,
        performance_dict: Dict[str, float] = dict()):
    data_dict['performance'] = performance_dict
    fname = f'{path}/{name}'
    with open(fname, 'wb') as f:
        pickle.dump(data_dict, f)
The load_model_file function reads a local model file, taking the file path and name as parameters. It returns the unpickled ML models and the performance data, both in Dict format:
def load_model_file(path: str, name: str) \
        -> (Dict[str, Any], Dict[str, float]):
    fname = f'{path}/{name}'
    with open(fname, 'rb') as f:
        data_dict = pickle.load(f)
    performance_dict = data_dict.pop('performance')
    # data_dict only includes models
    return data_dict, performance_dict
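A quick round trip shows the contract between save_model_file and load_model_file: the performance dictionary is embedded into the data dictionary on save and popped back out on load, leaving only the model data. The sketch below re-declares both helpers (without the type hints) so that it runs stand-alone against a temporary directory; the file name lms.binaryfile is illustrative:

```python
import pickle
import tempfile

# Same logic as in helpers.py
def save_model_file(data_dict, path, name, performance_dict=dict()):
    data_dict['performance'] = performance_dict
    with open(f'{path}/{name}', 'wb') as f:
        pickle.dump(data_dict, f)

def load_model_file(path, name):
    with open(f'{path}/{name}', 'rb') as f:
        data_dict = pickle.load(f)
    # Split the performance entry back out of the model data
    performance_dict = data_dict.pop('performance')
    return data_dict, performance_dict

tmpdir = tempfile.mkdtemp()
save_model_file({'models': {'w1': [0.5]}}, tmpdir,
                'lms.binaryfile', {'accuracy': 0.9})

models, perf = load_model_file(tmpdir, 'lms.binaryfile')
print(models)  # {'models': {'w1': [0.5]}}
print(perf)    # {'accuracy': 0.9}
```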
The read_state function reads a local state file, taking the file path and name as parameters. It returns the client state indicated in the file (for example, training or sending) as an integer corresponding to a ClientState value. If the client state file is being written at the time of access, it tries to read the file again after 0.01 seconds:
def read_state(path: str, name: str) -> ClientState:
    fname = f'{path}/{name}'
    with open(fname, 'r') as f:
        st = f.read()
    if st == '':
        time.sleep(0.01)
        return read_state(path, name)
    return int(st)
The write_state function changes the client state in the state file of the agent. It takes the file path and name and the new client state as parameters:
def write_state(path: str, name: str, state: ClientState):
    fname = f'{path}/{name}'
    with open(fname, 'w') as f:
        f.write(str(int(state)))
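write_state and read_state form the file-based handshake between the agent's processes: one side writes the state as an integer, and the other polls the file. The stand-alone round trip below re-declares both helpers, with a plain int standing in for ClientState and an illustrative file name:

```python
import tempfile
import time

# Same logic as in helpers.py; a plain int stands in for ClientState
def write_state(path, name, state):
    with open(f'{path}/{name}', 'w') as f:
        f.write(str(int(state)))

def read_state(path, name):
    with open(f'{path}/{name}', 'r') as f:
        st = f.read()
    if st == '':
        # File caught mid-write: back off briefly and retry
        time.sleep(0.01)
        return read_state(path, name)
    return int(st)

tmpdir = tempfile.mkdtemp()
write_state(tmpdir, 'state', 1)     # 1 corresponds to training
print(read_state(tmpdir, 'state'))  # 1
```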
The get_ip function obtains the IP address of the machine and returns the value of the IP address:
def get_ip() -> str:
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # doesn't even have to be reachable
        s.connect(('1.1.1.1', 1))
        ip = s.getsockname()[0]
    except:
        ip = '127.0.0.1'
    finally:
        s.close()
    return ip
The init_loop function starts a continuous event loop. It takes a coroutine, runs it on a new event loop, and keeps the loop running forever:
def init_loop(func):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(asyncio.gather(func))
    loop.run_forever()
In the next section, let’s look at the messenger functions to create communication payloads.
The messenger functions are defined in the messengers.py file, which can be found in the lib/util folder of the fl_main directory.
In this messengers.py code example, the file imports general libraries, such as time and numpy. It also imports ModelType, DBMsgType, AgentMsgType, and AggMsgType, which were defined in the Enumeration classes for implementing the FL system section in this chapter:
import time
import numpy as np
from typing import Dict, List, Any
from fl_main.lib.util.states import \
    ModelType, DBMsgType, AgentMsgType, AggMsgType
Next, let’s move on to the list of functions of the messengers library.
The following is a list of functions related to the messengers library.
The generate_db_push_message function generates and returns a message for pushing ML models to the database. It takes the component ID, round number, model type, models, model ID, generation time, and performance dictionary as parameters and packages them as a payload message (in List format, with the message type defined as push) between the aggregator and database.
The following code provides the functionality of generating the preceding database push message:
def generate_db_push_message(
        component_id: str,
        round: int,
        model_type: ModelType,
        models: Dict[str, np.array],
        model_id: str,
        gene_time: float,
        performance_dict: Dict[str, float]) -> List[Any]:
    msg = list()
    msg.append(DBMsgType.push)  # 0
    msg.append(component_id)  # 1
    msg.append(round)  # 2
    msg.append(model_type)  # 3
    msg.append(models)  # 4
    msg.append(model_id)  # 5
    msg.append(gene_time)  # 6
    msg.append(performance_dict)  # 7
    return msg
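The messenger functions and the location enums are two halves of the same protocol: the generator appends fields in a fixed order, and DBPushMsgLocation reads them back by name. The stand-alone sketch below reproduces the pattern, with plain strings standing in for the DBMsgType and ModelType enum members:

```python
from enum import IntEnum

# Same definition as in states.py
class DBPushMsgLocation(IntEnum):
    msg_type = 0
    component_id = 1
    round = 2
    model_type = 3
    models = 4
    model_id = 5
    gene_time = 6
    meta_data = 7
    req_id_list = 8

# Simplified generator: 'push' stands in for DBMsgType.push
def generate_db_push_message(component_id, round, model_type,
                             models, model_id, gene_time,
                             performance_dict):
    return ['push', component_id, round, model_type,
            models, model_id, gene_time, performance_dict]

msg = generate_db_push_message('aggr-1', 5, 'cluster',
                               {'w1': [0.1]}, 'mid-1',
                               1234.5, {'accuracy': 0.8})

# Receiver side: fields addressed by enum member, not magic index
print(msg[DBPushMsgLocation.round])     # 5
print(msg[DBPushMsgLocation.model_id])  # mid-1
```

Keeping the append order and the enum definition in sync is the one invariant the sender and receiver must both respect.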
The generate_lmodel_update_message function generates and returns a message for sending the aggregator the local models created in an agent. It takes the agent ID, model ID, local models, and performance dictionary as parameters and packages them as a payload message (in List format, with the message type defined as update) between the agent and aggregator.
The following code shows the functionality of generating the preceding local model update message:
def generate_lmodel_update_message(
        agent_id: str,
        model_id: str,
        local_models: Dict[str, np.array],
        performance_dict: Dict[str, float]) -> List[Any]:
    msg = list()
    msg.append(AgentMsgType.update)  # 0
    msg.append(agent_id)  # 1
    msg.append(model_id)  # 2
    msg.append(local_models)  # 3
    msg.append(time.time())  # 4
    msg.append(performance_dict)  # 5
    return msg
The generate_cluster_model_dist_message function generates and returns a message containing the global models created by an aggregator, to be sent to the connected agents. It takes the aggregator ID, model ID, round number, and global models as parameters and packages them as a payload message (in List format, with the message type defined as update) between the aggregator and agent.
The following code shows the functionality of generating the preceding cluster model distribution message:
def generate_cluster_model_dist_message(
        aggregator_id: str,
        model_id: str,
        round: int,
        models: Dict[str, np.array]) -> List[Any]:
    msg = list()
    msg.append(AggMsgType.update)  # 0
    msg.append(aggregator_id)  # 1
    msg.append(model_id)  # 2
    msg.append(round)  # 3
    msg.append(models)  # 4
    return msg
The generate_agent_participation_message function generates and returns a participation request message containing the initial models created by an agent, to be sent to the connected aggregator. It takes the agent name, agent ID, model ID, initial models, initial weights flag, simulation flag, exchange socket, generation time, metadata dictionary, and agent IP address as parameters and packages them as a payload message (in List format, with the message type defined as participate) between the agent and aggregator.
The following code shows the functionality of generating the preceding agent participation message:
def generate_agent_participation_message(
        agent_name: str,
        agent_id: str,
        model_id: str,
        models: Dict[str, np.array],
        init_weights_flag: bool,
        simulation_flag: bool,
        exch_socket: str,
        gene_time: float,
        meta_dict: Dict[str, float],
        agent_ip: str) -> List[Any]:
    msg = list()
    msg.append(AgentMsgType.participate)  # 0
    msg.append(agent_id)  # 1
    msg.append(model_id)  # 2
    msg.append(models)  # 3
    msg.append(init_weights_flag)  # 4
    msg.append(simulation_flag)  # 5
    msg.append(exch_socket)  # 6
    msg.append(gene_time)  # 7
    msg.append(meta_dict)  # 8
    msg.append(agent_ip)  # 9
    msg.append(agent_name)  # 10
    return msg
The generate_agent_participation_confirm_message function generates and returns a participation confirmation message containing the global models, sent back to the agent. It takes the aggregator ID, model ID, global models, round number, agent ID, exchange socket, and receiving socket as parameters and packages them as a payload message (in List format, with the message type defined as welcome) between the aggregator and agent.
The following code shows the functionality of generating the preceding agent participation confirmation message:
def generate_agent_participation_confirm_message(
        aggregator_id: str,
        model_id: str,
        models: Dict[str, np.array],
        round: int,
        agent_id: str,
        exch_socket: str,
        recv_socket: str) -> List[Any]:
    msg = list()
    msg.append(AggMsgType.welcome)  # 0
    msg.append(aggregator_id)  # 1
    msg.append(model_id)  # 2
    msg.append(models)  # 3
    msg.append(round)  # 4
    msg.append(agent_id)  # 5
    msg.append(exch_socket)  # 6
    msg.append(recv_socket)  # 7
    return msg
The generate_polling_message function generates and returns a polling message to be sent to the aggregator. It takes the round number and agent ID as parameters and packages them as a payload message (in List format, with the message type defined as polling) between the agent and aggregator.
The following code shows the functionality of generating the preceding polling message:
def generate_polling_message(round: int, agent_id: str):
    msg = list()
    msg.append(AgentMsgType.polling)  # 0
    msg.append(round)  # 1
    msg.append(agent_id)  # 2
    return msg
The generate_ack_message function generates and returns a message to send an ack message containing the acknowledgment signal back to an agent. No parameter is required to create a payload message (in List format with the message type defined as ack) between the aggregator and agent:
def generate_ack_message():
    msg = list()
    msg.append(AggMsgType.ack)  # 0
    return msg
In this appendix, we have explained the internal libraries in detail so that you can implement the entire FL system without further investigating what and how to code for basic functionalities such as communication and data structure conversion frameworks.
The internal library covers five main aspects: enumeration classes, defining system states such as the FL client states; the communication handler, supporting the send and receive functionalities; the data structure handler, managing ML models during aggregation; helper and support functions, covering basic operations such as saving data and generating unique IDs; and messenger functions, generating the various payloads sent among the database, aggregator, and agents.
With these functions, you will find the implementation of FL systems easy and smooth. However, these libraries only support a minimal FL system; it is up to you to enhance them further to create a more robust platform that can be used in real-life applications and technologies.