A Simple Implementation of TCP Concurrency
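In the earlier code, a loop accepts connections, but that code can serve only one connection at a time: as soon as a connection is accepted, execution moves into the communication loop for that socket, and only when that conversation ends does it return to the outer loop and block again waiting for the next connection.
This clearly cannot handle concurrent clients, yet real-world servers on the internet all do. To implement concurrent TCP in Python, the standard library's socketserver module can be used.
Let's write a simple concurrent server first and study the principle afterwards. The idea is easy to guess: each time a connection arrives, instantiate an object that can handle that connection and its communication; when another connection arrives, instantiate another object.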
# server
import socketserver
buffer_size = 1024
ip_port = ('127.0.0.1', 8080)
class MyServer(socketserver.BaseRequestHandler):
    def handle(self):  # the communication loop (receive and send); a connection loop is still needed
        print('conn is:', self.request)  # conn
        print('address is:', self.client_address)  # address
        while True:
            try:
                # receive a message
                data = self.request.recv(buffer_size)
                if not data:
                    break
                print('Received message:', data.decode('utf-8'))
                # send a reply
                self.request.sendall(data.upper())
            except Exception as e:
                print(e)
                break
if __name__ == '__main__':
    s = socketserver.ThreadingTCPServer(ip_port, MyServer)  # the server object
    s.serve_forever()  # the server's endless loop, which plays the role of the connection loop
A server like this can accept multiple TCP connections at the same time. How does socketserver achieve that? Let's walk through how it works:
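To see the concurrency claim in action, here is a minimal sketch that starts a ThreadingTCPServer in a background thread and talks to two clients whose connections overlap. The names UpperHandler and run_demo are made up for this example, and port 0 (let the OS pick a free port) is an assumption made so the demo does not collide with the article's 8080.

```python
import socket
import socketserver
import threading


class UpperHandler(socketserver.BaseRequestHandler):
    """Hypothetical handler mirroring MyServer: reply with upper-cased data."""

    def handle(self):
        while True:
            data = self.request.recv(1024)
            if not data:
                break
            self.request.sendall(data.upper())


def run_demo():
    # Port 0 asks the OS for any free port (an assumption for this demo).
    with socketserver.ThreadingTCPServer(('127.0.0.1', 0), UpperHandler) as server:
        threading.Thread(target=server.serve_forever, daemon=True).start()
        address = server.server_address
        # Open both connections before sending anything, so they overlap.
        first = socket.create_connection(address, timeout=5)
        second = socket.create_connection(address, timeout=5)
        replies = []
        try:
            # Talk to the second client first: a one-connection-at-a-time
            # server would still be busy in the first client's loop.
            for client, message in ((second, b'world'), (first, b'hello')):
                client.sendall(message)
                replies.append(client.recv(1024))
        finally:
            first.close()
            second.close()
            server.shutdown()
    return replies


if __name__ == '__main__':
    print(run_demo())
```

Swapping ThreadingTCPServer for the plain TCPServer in this sketch should make the second client's recv() time out, because the single thread would still be inside the first client's handle().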
socketserver Source Code Analysis: the TCP Concurrency Part
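To understand how a module is structured, Ctrl-click the module name in PyCharm to open its source code, then use PyCharm's UML class diagram feature to see how the file is organized. To open it, right-click in the current file -> Diagrams -> Show Diagrams -> Python Class Diagrams. Clicking the m icon at the top of the diagram view shows member functions, and clicking the f icon shows member variables. The socketserver class diagram reveals the overall design:
There is a BaseServer, from which TCPServer is derived. UDPServer in turn inherits from TCPServer, but it can be treated as a main branch of its own.
Then come four concrete servers: ThreadingTCPServer, ThreadingUnixStreamServer, ThreadingUnixDatagramServer and ThreadingUDPServer. As the names suggest, these are the multithreaded TCP and UDP network servers and the multithreaded stream and datagram servers for Unix domain sockets. Every multithreaded class (those with Threading in the name) inherits from the ThreadingMixIn class.
There are also ForkingTCPServer and ForkingUDPServer, the multiprocess servers, which both inherit from the ForkingMixIn class.
All of these server classes ultimately inherit from TCPServer or UDPServer. The relationships are not too complicated: once you know which kind of server you need, instantiate the corresponding class.
These server classes correspond to the connection-accepting part of our hand-written server. The actual communication is the job of one further base class and its two subclasses.
DatagramRequestHandler and StreamRequestHandler, as their names indicate, are used for UDP and TCP communication respectively, and both inherit from BaseRequestHandler.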
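As an illustration of the TCP-oriented subclass, the sketch below uses StreamRequestHandler, which wraps the connected socket in the file-like objects self.rfile and self.wfile. The handler name LineUpperHandler, the helper ask and the newline-terminated message format are assumptions made for this example, not part of the original server.

```python
import socket
import socketserver
import threading


class LineUpperHandler(socketserver.StreamRequestHandler):
    """Hypothetical handler: answer each text line with its upper-case form."""

    def handle(self):
        # rfile and wfile are file-like wrappers around self.request,
        # prepared by StreamRequestHandler.setup().
        for line in self.rfile:
            self.wfile.write(line.upper())


def ask(lines):
    """Send each line to a throwaway server and collect the replies."""
    answers = []
    with socketserver.TCPServer(('127.0.0.1', 0), LineUpperHandler) as server:
        threading.Thread(target=server.serve_forever, daemon=True).start()
        try:
            # Closing both the file wrapper and the socket lets the handler
            # see end-of-file and return before shutdown() is called.
            with socket.create_connection(server.server_address, timeout=5) as client, \
                    client.makefile('rb') as reader:
                for line in lines:
                    client.sendall(line)
                    answers.append(reader.readline())
        finally:
            server.shutdown()
    return answers


if __name__ == '__main__':
    print(ask([b'hello\n', b'tcp\n']))
```

Compared with calling recv() on self.request directly, the file-like interface makes line-based protocols easier to write, at the cost of requiring a message delimiter.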
Now let's analyze our code:
MyServer(socketserver.BaseRequestHandler) shows that MyServer inherits from BaseRequestHandler. Here is that class's code:
class BaseRequestHandler:
    """Base class for request handler classes.
    This class is instantiated for each request to be handled. The
    constructor sets the instance variables request, client_address
    and server, and then calls the handle() method. To implement a
    specific service, all you need to do is to derive a class which
    defines a handle() method.
    The handle() method can find the request as self.request, the
    client address as self.client_address, and the server (in case it
    needs access to per-server information) as self.server. Since a
    separate instance is created for each request, the handle() method
    can define other arbitrary instance variables.
    """
    def __init__(self, request, client_address, server):
        self.request = request
        self.client_address = client_address
        self.server = server
        self.setup()
        try:
            self.handle()
        finally:
            self.finish()
    def setup(self):
        pass
    def handle(self):
        pass
    def finish(self):
        pass
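The class diagram shows two classes inheriting from this one. Opening them reveals that neither implements handle(), which is why we have to write handle() ourselves. So what exactly are the three arguments request, client_address and server that are passed in when BaseRequestHandler is initialized?
Judging from the receive and send statements we wrote ourselves, request is most likely a socket object.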
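Before chasing those three arguments, the call order inside BaseRequestHandler.__init__ can be checked without any network at all. The sketch below instantiates a hypothetical subclass, TracingHandler, with placeholder arguments (a string instead of a socket, None instead of a server); these stand-ins are assumptions that work only because __init__ merely stores the values before calling the hooks.

```python
import socketserver


class TracingHandler(socketserver.BaseRequestHandler):
    """Hypothetical handler that records which hooks run, and in what order."""

    calls = []

    def setup(self):
        self.calls.append('setup')

    def handle(self):
        self.calls.append('handle')
        raise RuntimeError('simulated failure inside handle()')

    def finish(self):
        self.calls.append('finish')


def trace_calls():
    TracingHandler.calls = []
    try:
        # Placeholder arguments: no real socket or server is involved here.
        TracingHandler('fake-request', ('127.0.0.1', 12345), None)
    except RuntimeError:
        pass
    return TracingHandler.calls


if __name__ == '__main__':
    print(trace_calls())
```

Because handle() is wrapped in try/finally, finish() still runs when handle() raises, which makes finish() a reasonable place for cleanup code.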
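Next, look at the line s = socketserver.ThreadingTCPServer(ip_port, MyServer).
As the name suggests, this instantiates a multithreaded TCP server object. Its arguments are ip_port and MyServer, the message-handling class we just defined. Since the server's job is to handle connections, a reasonable guess is that it instantiates MyServer to handle the communication on the socket object produced by each new connection.
Looking at the ThreadingTCPServer class, it turns out to be a single line, class ThreadingTCPServer(ThreadingMixIn, TCPServer): pass, with an empty body. Everything is inherited from ThreadingMixIn and TCPServer, so we continue looking for the __init__ method in the two parent classes. ThreadingMixIn contains the following code: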
class ThreadingMixIn:
    """Mix-in class to handle each request in a new thread."""
    # Decides how threads will act upon termination of the
    # main process
    daemon_threads = False
    def process_request_thread(self, request, client_address):
        """Same as in BaseServer but as a thread.
        In addition, exception handling is done here.
        """
        try:
            self.finish_request(request, client_address)
        except Exception:
            self.handle_error(request, client_address)
        finally:
            self.shutdown_request(request)
    def process_request(self, request, client_address):
        """Start a new thread to process the request."""
        t = threading.Thread(target = self.process_request_thread,
                             args = (request, client_address))
        t.daemon = self.daemon_threads
        t.start()
ThreadingMixIn has no __init__ method. TCPServer does have one, and TCPServer in turn inherits from BaseServer:
class TCPServer(BaseServer):
    address_family = socket.AF_INET  # address family
    socket_type = socket.SOCK_STREAM  # socket type
    request_queue_size = 5  # backlog value passed to listen()
    allow_reuse_address = False  # whether address reuse is allowed
    def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True):
        """Constructor. May be extended, do not override."""
        BaseServer.__init__(self, server_address, RequestHandlerClass)
        self.socket = socket.socket(self.address_family,
                                    self.socket_type)
        if bind_and_activate:
            try:
                self.server_bind()
                self.server_activate()
            except:
                self.server_close()
                raise
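The search order used so far (ThreadingTCPServer first, then ThreadingMixIn, then TCPServer, then BaseServer) is the class's method resolution order, so the manual hunt through the source can also be done programmatically. This is a small sketch; defining_class and describe are helper names invented for the example.

```python
import socketserver


def defining_class(cls, name):
    """Return the name of the first class in cls's MRO that defines `name`."""
    for klass in cls.__mro__:
        if name in vars(klass):
            return klass.__name__
    return None


def describe(cls=socketserver.ThreadingTCPServer):
    """Map a few method names to the class that actually defines them."""
    names = ('__init__', 'server_bind', 'get_request',
             'process_request', 'finish_request', 'serve_forever')
    return {name: defining_class(cls, name) for name in names}


if __name__ == '__main__':
    print([klass.__name__ for klass in socketserver.ThreadingTCPServer.__mro__])
    for method, owner in describe().items():
        print(method, '->', owner)
```

Listing ThreadingMixIn before TCPServer in the class statement is what lets the mix-in's process_request take precedence over the one in BaseServer.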
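TCPServer's class attributes contain some very familiar things. Because socketserver imports the socket module at the top, we see several values used before: the address family is AF_INET, the socket type is SOCK_STREAM (that is, TCP), request_queue_size is the backlog, and allowing address reuse is something we have also met before.
The __init__ method receives, in order, the instance, the server address, the RequestHandlerClass class, and a parameter with a default value, bind_and_activate=True.
The instance needs no explanation: it is s itself. server_address receives the ip_port tuple, the same tuple of IP and port as in our earlier programs. MyServer is our class derived from BaseRequestHandler, which satisfies the RequestHandlerClass parameter (a request handler class is expected). This initializer then calls BaseServer's __init__ method, so let's follow it up: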
class BaseServer:
    def __init__(self, server_address, RequestHandlerClass):
        """Constructor. May be extended, do not override."""
        self.server_address = server_address
        self.RequestHandlerClass = RequestHandlerClass
        self.__is_shut_down = threading.Event()
        self.__shutdown_request = False
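This is straightforward: a few instance attributes are set. server_address is the IP and port that were passed in, and RequestHandlerClass refers to the class that was passed in. Ignore the rest for now and jump back to the next line of TCPServer's __init__, where something familiar appears:
    self.socket = socket.socket(self.address_family, self.socket_type)
The instance attribute socket is a newly created socket object, just like the socket objects we created ourselves.
Right after that, __init__ tries to call the two methods self.server_bind() and self.server_activate(), closing the server and re-raising the exception if either fails.
Follow them to TCPServer's server_bind() and server_activate() methods: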
class TCPServer(BaseServer):
    def server_bind(self):
        """Called by constructor to bind the socket.
        May be overridden.
        """
        if self.allow_reuse_address:
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind(self.server_address)
        self.server_address = self.socket.getsockname()
    def server_activate(self):
        """Called by constructor to activate the server.
        May be overridden.
        """
        self.socket.listen(self.request_queue_size)
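Familiar statements again. server_bind() does with our ip_port exactly what socket.bind() did in our hand-written code, then updates server_address with the result of getsockname(). In effect this is where the IP and port are bound.
server_activate() then does something else we have written before: listen().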
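The two steps can also be run by hand. The sketch below passes bind_and_activate=False so that __init__ stops after creating the socket, then calls server_bind() and server_activate() explicitly. The subclass ReusableTCPServer and the use of port 0 are assumptions for the demo; with port 0 the effect of the getsockname() update becomes visible, because the real port replaces the requested 0.

```python
import socketserver


class ReusableTCPServer(socketserver.TCPServer):
    """Hypothetical subclass: makes server_bind() set SO_REUSEADDR."""
    allow_reuse_address = True


def bind_in_two_steps():
    # bind_and_activate=False: the constructor creates the socket only.
    server = ReusableTCPServer(('127.0.0.1', 0),
                               socketserver.BaseRequestHandler,
                               bind_and_activate=False)
    try:
        requested = server.server_address   # still the tuple we passed in
        server.server_bind()                # bind(), then getsockname()
        server.server_activate()            # listen(request_queue_size)
        return requested, server.server_address
    finally:
        server.server_close()


if __name__ == '__main__':
    print(bind_in_two_steps())
```

Overriding class attributes such as allow_reuse_address or request_queue_size in a subclass is the intended way to tune these steps, since both methods read them from self.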
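Now look back at s = socketserver.ThreadingTCPServer(ip_port, MyServer).
Once this statement has run, s has a number of attributes. One of them, socket, is a socket object bound to the ip_port we passed in and already in the listening state. And what role does the MyServer class we passed in play?
On to the next statement:
s.serve_forever(). To find the serve_forever method: ThreadingTCPServer does not have it, neither does ThreadingMixIn, neither does TCPServer, and it finally turns up in BaseServer: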
# s.serve_forever()
class BaseServer:
    def serve_forever(self, poll_interval=0.5):
        """Handle one request at a time until shutdown.
        Polls for shutdown every poll_interval seconds. Ignores
        self.timeout. If you need to do periodic tasks, do them in
        another thread.
        """
        self.__is_shut_down.clear()
        try:
            # XXX: Consider using another file descriptor or connecting to the
            # socket to wake this up instead of polling. Polling reduces our
            # responsiveness to a shutdown request and wastes cpu at all other
            # times.
            with _ServerSelector() as selector:
                selector.register(self, selectors.EVENT_READ)
                while not self.__shutdown_request:
                    ready = selector.select(poll_interval)
                    if ready:
                        self._handle_request_noblock()
                    self.service_actions()
        finally:
            self.__shutdown_request = False
            self.__is_shut_down.set()
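Leaving the selector statements aside, look at the body of the while loop. Read literally, the loop keeps running until a shutdown is requested, and whenever the listening socket is ready it calls _handle_request_noblock(). Searching again, neither parent class has it, and _handle_request_noblock() is once more found in BaseServer: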
# _handle_request_noblock()
    def _handle_request_noblock(self):
        """Handle one request, without blocking.
        I assume that selector.select() has returned that the socket is
        readable before this function was called, so there should be no risk of
        blocking in get_request().
        """
        try:
            request, client_address = self.get_request()
        except OSError:
            return
        if self.verify_request(request, client_address):
            try:
                self.process_request(request, client_address)
            except Exception:
                self.handle_error(request, client_address)
                self.shutdown_request(request)
            except:
                self.shutdown_request(request)
                raise
        else:
            self.shutdown_request(request)
Its first statement is request, client_address = self.get_request(), so we look for get_request() next. This time it is found in TCPServer:
    def get_request(self):
        """Get the request and client address from the socket.
        May be overridden.
        """
        return self.socket.accept()
This is simply socket.accept(). So in request, client_address = self.get_request() above, request is the connected socket returned once the three-way handshake has completed, and client_address is literally the client's address.
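That conclusion can be checked directly by calling get_request() ourselves instead of letting serve_forever() do it. The following is a sketch only: accept_once is a made-up helper, and the handler class is never instantiated because no request is processed.

```python
import socket
import socketserver


def accept_once():
    """Connect one client, then accept it by calling get_request() by hand."""
    with socketserver.TCPServer(('127.0.0.1', 0),
                                socketserver.BaseRequestHandler) as server:
        # The constructor already called bind() and listen(), so the kernel
        # completes the handshake before accept() is ever called.
        with socket.create_connection(server.server_address, timeout=5) as client:
            request, client_address = server.get_request()
            try:
                is_socket = isinstance(request, socket.socket)
                same_peer = client_address == client.getsockname()
                return is_socket, same_peer
            finally:
                request.close()


if __name__ == '__main__':
    print(accept_once())
```

The address returned by accept() on the server side matches what the client reports from getsockname(), which confirms that client_address identifies the remote end of that one connection.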
Back in _handle_request_noblock(), we now know what request is.
Skipping the self.verify_request check for now, look at the line self.process_request(request, client_address). Searching for this method, we find it in the ThreadingMixIn class:
class ThreadingMixIn:
    daemon_threads = False
    def process_request_thread(self, request, client_address):
        """Same as in BaseServer but as a thread.
        In addition, exception handling is done here.
        """
        try:
            self.finish_request(request, client_address)
        except Exception:
            self.handle_error(request, client_address)
        finally:
            self.shutdown_request(request)
    def process_request(self, request, client_address):
        """Start a new thread to process the request."""
        t = threading.Thread(target = self.process_request_thread,
                             args = (request, client_address))
        t.daemon = self.daemon_threads
        t.start()
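This is the first core step. We have not studied threads yet, but it is clear that the process_request method of ThreadingMixIn uses the threading module to start a new thread that runs self.process_request_thread (also defined in ThreadingMixIn), passing it request (the TCP connection object whose three-way handshake is complete) and the client address. Inside process_request_thread(self, request, client_address), where everything from here on executes in the new thread, there is a call to self.finish_request(request, client_address). Searching once more, we find it in BaseServer: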
class BaseServer:
    def finish_request(self, request, client_address):
        """Finish one request by instantiating RequestHandlerClass."""
        self.RequestHandlerClass(request, client_address, self)
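After all this searching we reach the second core step: RequestHandlerClass, that is, the MyServer class we passed in, is finally used. finish_request does something very simple: it instantiates MyServer.
Back to our MyServer class. Remember the attributes that were set when BaseServer was initialized?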
class BaseServer:
    def __init__(self, server_address, RequestHandlerClass):
        """Constructor. May be extended, do not override."""
        self.server_address = server_address
        self.RequestHandlerClass = RequestHandlerClass
        self.__is_shut_down = threading.Event()
        self.__shutdown_request = False
self.RequestHandlerClass here is the MyServer we passed in, and finish_request instantiates MyServer with the arguments request, client_address, self, in that order. MyServer inherits from BaseRequestHandler, so look at BaseRequestHandler's initializer:
class BaseRequestHandler:
    def __init__(self, request, client_address, server):
        self.request = request
        self.client_address = client_address
        self.server = server
        self.setup()
        try:
            self.handle()
        finally:
            self.finish()
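Now everything is clear: request is a TCP connection object whose three-way handshake is complete, client_address is the client address of that connection, and server is the s object.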
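Since server is the s object itself, a handler can use self.server to reach state shared by all connections. The sketch below illustrates this with a counter; CountingHandler, count_connections and the connections_seen attribute are all invented for the example, and handle_request() is used to serve exactly one connection at a time so that no threads are involved.

```python
import socket
import socketserver


class CountingHandler(socketserver.BaseRequestHandler):
    """Hypothetical handler that reports a per-server connection counter."""

    def handle(self):
        # self.server is the server object that instantiated this handler.
        self.server.connections_seen += 1
        self.request.sendall(str(self.server.connections_seen).encode('ascii'))


def count_connections(times=3):
    replies = []
    with socketserver.TCPServer(('127.0.0.1', 0), CountingHandler) as server:
        server.connections_seen = 0      # attribute invented for this demo
        for _ in range(times):
            with socket.create_connection(server.server_address, timeout=5) as client:
                server.handle_request()  # accept and serve one connection
                replies.append(client.recv(16))
    return replies


if __name__ == '__main__':
    print(count_connections())
```

With ThreadingTCPServer the same counter would be updated from several threads, so a real program would probably want to protect it with a lock.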
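This explains why TCP concurrency works. The waiting loop itself is unchanged: once s is created it owns the single listening socket s.socket, and serve_forever keeps checking it for new TCP connections. As soon as a connection with a completed handshake is accepted, process_request starts a new thread running process_request_thread, and that thread instantiates the communication handler for the connection. serve_forever then simply carries on with its loop.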
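Stripped of its class hierarchy, the mechanism amounts to a connection loop that hands every accepted socket to a new thread. The sketch below is a hand-written approximation of that idea using only socket and threading. It is not how socketserver is actually coded, and talk, accept_loop and demo are names made up for the illustration.

```python
import socket
import threading


def talk(conn, address):
    """Communication loop for one client, run in its own thread."""
    with conn:
        while True:
            data = conn.recv(1024)
            if not data:
                break
            conn.sendall(data.upper())


def accept_loop(listener):
    """Connection loop: hand every accepted socket to a new thread."""
    while True:
        try:
            conn, address = listener.accept()
        except OSError:
            break  # accept() failed, e.g. because the listener was closed
        threading.Thread(target=talk, args=(conn, address), daemon=True).start()


def demo():
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.bind(('127.0.0.1', 0))   # port 0: any free port, for the demo
    listener.listen(5)
    threading.Thread(target=accept_loop, args=(listener,), daemon=True).start()
    address = listener.getsockname()
    first = socket.create_connection(address, timeout=5)
    second = socket.create_connection(address, timeout=5)
    try:
        # The second client is answered while the first is still connected.
        second.sendall(b'second')
        reply_second = second.recv(1024)
        first.sendall(b'first')
        reply_first = first.recv(1024)
    finally:
        first.close()
        second.close()
        listener.close()
    return reply_first, reply_second


if __name__ == '__main__':
    print(demo())
```

What socketserver adds on top of this is mostly structure: overridable hooks such as verify_request and handle_error, a clean shutdown() mechanism, and interchangeable mix-ins for threads or processes.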
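As for why we must define handle(): __init__ calls handle() right away, and whatever communication logic we want has to be written in our own handle() method.
One more point: if the server object is created with ForkingTCPServer (multiprocess) instead of ThreadingTCPServer (multithreaded), the service runs just the same (on Windows this fails because the os module has no fork there). The difference is that threads carry less overhead than processes, which will become clearer when we study concurrency later. The inheritance relationships also show that the key to the multiprocess version is the ForkingMixIn class.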
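Because the forking classes are not available on every platform, a program that wants to run everywhere can pick the server class at start-up. This is a minimal sketch under that assumption; pick_server_class is a hypothetical helper, and it relies on socketserver defining ForkingTCPServer only where os.fork exists.

```python
import socketserver


def pick_server_class(prefer_processes=True):
    """Return a concurrent TCP server class that exists on this platform."""
    if prefer_processes and hasattr(socketserver, 'ForkingTCPServer'):
        return socketserver.ForkingTCPServer
    # Fallback, e.g. on Windows, where the forking classes are not defined.
    return socketserver.ThreadingTCPServer


if __name__ == '__main__':
    server_class = pick_server_class()
    print('using', server_class.__name__)
    # The article's server would then be created as:
    #     s = server_class(ip_port, MyServer)
    #     s.serve_forever()
```

Either class accepts the same (ip_port, MyServer) arguments, so the handler code does not need to change when switching between threads and processes.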