Java 服務器多線程編程
1. 前言
前面小節介紹的 Java TCP Socket 程序是單線程模型,也是阻塞式模型。我們調用 java.net.ServerSocket 的 accept 方法,此時線程會被阻塞,等待客戶端連接。當有新客戶端連接到服務器以后,accept 方法會返回一個 java.net.Socket 類型的對象,此對象代表了客戶端和服務器完成了三次握手后,建立的新連接。 調用 java.net.Socket 的 recv 和 send 方法和客戶端進行數據收發。由于我們采用的是阻塞式 Socket 編程,java.net.ServerSocket 的 accept 方法會阻塞線程,java.net.Socket 的 recv 和 send 方法也會阻塞線程。因此,如果采用此模型,在同一時刻,服務器只能和一個客戶端通信。
要想服務器同時和多個客戶端進行通信,要么采用非阻塞式 Socket 編程,通過 I/O 多路復用機制 實現此目的;要么采用多線程編程模型。當然,在非阻塞式 Socket 編程模型下,往往也采用多線程編程。因為目前的計算機都是多核處理器,采用多線程編碼模型,可以充分利用 CPU 多核的優勢,最大化 CPU 資源的利用。
本節主要介紹阻塞式 Socket 編程中常用的兩種線程模型:
- 每線程模型
- 線程池模型
2. Java 多線程編程方法
由于本節會涉及到 Java 多線程編程,所以需要你能預先掌握 Java 多線程編程的方法。比如,線程的創建,線程的啟動,線程之間的同步和線程之間的通信。
在 Java 平臺下,創建線程的方法有兩種:
-
第一,是創建一個用戶自定義的線程類,然后繼承 java.leng.Thread 類,同時要覆寫它的 run 方法,調用它的 start 方法啟動線程。例如:
class MyThread extends Thread { @Override public void run() { super.run(); } } new MyThread().start();
-
第二,是創建一個任務類。
首先,實現 Runnable 接口,并且重寫它的 run 方法。然后,創建 java.leng.Thread 類的對象,同時將 Runnable 的實例通過 java.lang.Thread 的構造方法傳入。最后,調用 java.lang.Thread 的 start 方法啟動線程。例如:class MyTask implements Runnable { @Override public void run() { } } new Thread(new MyTask()).start();
3. 每線程模型
下圖展示了每線程模型的結構。
從圖中可以看出,每線程模型的程序結構如下:
- 創建一個監聽線程,通常會采用 Java 主線程作為監聽線程。
- 創建一個 java.net.ServerSocket 實例,調用它的 accept 方法等待客戶端的連接。
- 當有新的客戶端和服務器建立連接,accept 方法會返回,創建一個新的線程和客戶端通信。此時監聽線程返回,繼續調用 accept 方法,等待新的客戶端連接。
- 在新線程中調用 java.net.Socket 的 recv 和 send 方法和客戶端進行數據收發。
- 當數據收發完成后,調用 java.net.Socket 的 close 方法關閉連接,同時線程退出。
下來,我們通過一個簡單的示例程序演示一下每線程模型服務器的編寫方法。示例程序的基本功能如下:
- 客戶端每隔 1 秒向服務器發送一個消息。
- 服務器收到客戶端的消息后,向客戶端發送一個響應消息。
- 客戶端發送完 10 個消息后,關閉 Socket 連接,程序退出。
- 服務器檢測到客戶端關閉連接后,同樣關閉 Socket 連接,并且負責和客戶端通信的線程也退出。
客戶端代碼:
import java.io.*;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
public class TCPClientMultiThread {
// 服務器監聽的端口號
private static final int PORT = 56002;
// 連接超時時間
private static final int TIMEOUT = 15000;
// 客戶端執行次數
private static final int TEST_TIMES = 10;
public static void main(String[] args) {
Socket client = null;
try {
// 測試次數
int testCount = 0;
// 調用無參構造方法
client = new Socket();
// 構造服務器地址結構
SocketAddress serverAddr = new InetSocketAddress("192.168.0.101", PORT);
// 連接服務器,超時時間是 15 毫秒
client.connect(serverAddr, TIMEOUT);
System.out.println("Client start:" + client.getLocalSocketAddress().toString());
while (true) {
// 向服務器發送數據
DataOutputStream out = new DataOutputStream(
new BufferedOutputStream(client.getOutputStream()));
String req = "Hello Server!";
out.writeInt(req.getBytes().length);
out.write(req.getBytes());
// 不能忘記 flush 方法的調用
out.flush();
System.out.println("Send to server:" + req);
// 接收服務器的數據
DataInputStream in = new DataInputStream(
new BufferedInputStream(client.getInputStream()));
int msgLen = in.readInt();
byte[] inMessage = new byte[msgLen];
in.read(inMessage);
System.out.println("Recv from server:" + new String(inMessage));
// 如果執行次數已經達到上限,結束測試。
if (++testCount >= TEST_TIMES) {
break;
}
// 等待 1 秒然后再執行
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace);
}
}
} catch (IOException e) {
e.printStackTrace();
} finally {
if (client != null){
try {
client.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
}
服務器代碼:
import java.io.*;
import java.net.ServerSocket;
import java.net.Socket;
public class TCPServerPerThread implements Runnable{
private static final int PORT =56002;
private Socket sock = null;
TCPServerPerThread(Socket sock){
this.sock = sock;
}
@Override
public void run() {
// 讀取客戶端數據
try {
while (true){
// 讀取客戶端數據
DataInputStream in = new DataInputStream(
new BufferedInputStream(sock.getInputStream()));
int msgLen = in.readInt();
byte[] inMessage = new byte[msgLen];
in.read(inMessage);
System.out.println("Recv from client:" + new String(inMessage) + "length:" + msgLen);
// 向客戶端發送數據
String rsp = "Hello Client!\n";
DataOutputStream out = new DataOutputStream(
new BufferedOutputStream(sock.getOutputStream()));
out.writeInt(rsp.getBytes().length);
out.write(rsp.getBytes());
out.flush();
System.out.println("Send to client:" + rsp + " length:" + rsp.getBytes().length);
}
} catch (IOException e) {
e.printStackTrace();
} finally {
if (sock != null){
try {
sock.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
public static void main(String[] args) {
ServerSocket ss = null;
try {
// 創建一個服務器 Socket
ss = new ServerSocket(PORT);
while (true){
// 監聽新的連接請求
Socket conn = ss.accept();
System.out.println("Accept a new connection:"
+ conn.getRemoteSocketAddress().toString());
Thread t = new Thread(new TCPServerPerThread(conn));
t.start();
}
} catch (IOException e) {
e.printStackTrace();
} finally {
if (ss != null){
try {
ss.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
}
客戶端采用單線程模型。服務器采用每線程模型,我們采用實現 Runnable 接口的方式實現多線程邏輯。從示例代碼可以看出,每線程模型的優點就是結構簡單,相比單線程模型,也沒有增加復雜度。缺點就是針對每個客戶端都創建線程,當和客戶端通信結束后,線程要退出。頻繁的創建、銷毀線程,對系統的資源消耗比較大,只能用在簡單的業務場景下。
3. 線程池模型
線程池模型的結構如下:
從圖中可以看出,線程池模型的程序結構如下:
- 創建一個監聽線程,通常會采用 Java 主線程作為監聽線程。
- 創建一個 java.net.ServerSocket 實例,調用它的 accept 方法等待客戶端的連接。
- 服務器預先創建一組線程,叫做線程池。線程池中的線程,在服務運行過程中,一直運行,不會退出。
- 當有新的客戶端和服務器建立連接,accept 方法會返回 java.net.Socket 對象,表示新的連接。服務器一般會創建一個處理 java.net.Socket 邏輯的任務,并且將此任務投遞給線程池去處理。然后,監聽線程返回,繼續調用 accept 方法,等待新的客戶端連接。
- 線程池調度空閑的線程去處理任務。
- 在新新任務中調用 java.net.Socket 的 recv 和 send 方法和客戶端進行數據收發。
- 當數據收發完成后,調用 java.net.Socket 的 close 方法關閉連接,任務完成。
- 線程重新回歸線程池,等待調度。
下來,我們同樣通過示例代碼演示一下線程池模型的編寫方法。程序功能和每線程模型完全一致,所以我們只編寫服務端程序,客戶端程序采用每線程模型的客戶端。
示例代碼如下:
import java.io.*;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class TCPServerThreadPool{
// 服務監聽端口號
private static final int PORT =56002;
// 開啟線程數
private static final int THREAD_NUMS = 20;
private static ExecutorService pool = null;
// 創建一個 socket Task 類,處理數據收發
private static class SockTask implements Callable<Void> {
private Socket sock = null;
public SockTask(Socket sock){
this.sock = sock;
}
@Override
public Void call() throws Exception {
try {
while (true){
// 讀取客戶端數據
DataInputStream in = new DataInputStream(
new BufferedInputStream(sock.getInputStream()));
int msgLen = in.readInt();
byte[] inMessage = new byte[msgLen];
in.read(inMessage);
System.out.println("Recv from client:" + new String(inMessage) + "length:" + msgLen);
// 向客戶端發送數據
String rsp = "Hello Client!\n";
DataOutputStream out = new DataOutputStream(
new BufferedOutputStream(sock.getOutputStream()));
out.writeInt(rsp.getBytes().length);
out.write(rsp.getBytes());
out.flush();
System.out.println("Send to client:" + rsp + " length:" + rsp.getBytes().length);
}
} catch (IOException e) {
e.printStackTrace();
} finally {
if (sock != null){
try {
sock.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
return null;
}
}
public static void main(String[] args) {
ServerSocket ss = null;
try {
pool = Executors.newFixedThreadPool(THREAD_NUMS);
// 創建一個服務器 Socket
ss = new ServerSocket(PORT);
while (true){
// 監聽新的連接請求
Socket conn = ss.accept();
System.out.println("Accept a new connection:"
+ conn.getRemoteSocketAddress().toString());
pool.submit(new SockTask(conn));
}
} catch (IOException e) {
e.printStackTrace();
} finally {
if (ss != null){
try {
ss.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
}
4. 小結
本節主要介紹的是 Java 服務器編程中比較簡單的兩種線程模型,每線程模型和線程池模型。示例程序都采用了阻塞式 Socket 編程。編寫 Java 服務器程序,通常需要采用多線程模型。對于非常簡單的業務場景,可以采用每線程模型。對于比較復雜的業務場景,通常需要采用線程池模型。