CHANGED TO MIRROR

This commit is contained in:
DerTyp187
2021-10-25 09:20:01 +02:00
parent bd712107b7
commit e509a919b6
611 changed files with 38291 additions and 1216 deletions

View File

@@ -0,0 +1,149 @@
using System;
using System.IO;
using System.Security.Cryptography;
using System.Text;
namespace Mirror.SimpleWeb
{
/// <summary>
/// Handles Handshakes from new clients on the server
/// <para>The server handshake has buffers to reduce allocations when clients connect</para>
/// </summary>
internal class ServerHandshake
{
const int GetSize = 3;
const int ResponseLength = 129;
const int KeyLength = 24;
const int MergedKeyLength = 60;
const string KeyHeaderString = "Sec-WebSocket-Key: ";
// this isn't an official max, just a reasonable size for a websocket handshake
readonly int maxHttpHeaderSize = 3000;
readonly SHA1 sha1 = SHA1.Create();
readonly BufferPool bufferPool;
public ServerHandshake(BufferPool bufferPool, int handshakeMaxSize)
{
this.bufferPool = bufferPool;
this.maxHttpHeaderSize = handshakeMaxSize;
}
~ServerHandshake()
{
sha1.Dispose();
}
public bool TryHandshake(Connection conn)
{
Stream stream = conn.stream;
using (ArrayBuffer getHeader = bufferPool.Take(GetSize))
{
if (!ReadHelper.TryRead(stream, getHeader.array, 0, GetSize))
return false;
getHeader.count = GetSize;
if (!IsGet(getHeader.array))
{
Log.Warn($"First bytes from client was not 'GET' for handshake, instead was {Log.BufferToString(getHeader.array, 0, GetSize)}");
return false;
}
}
string msg = ReadToEndForHandshake(stream);
if (string.IsNullOrEmpty(msg))
return false;
try
{
AcceptHandshake(stream, msg);
return true;
}
catch (ArgumentException e)
{
Log.InfoException(e);
return false;
}
}
string ReadToEndForHandshake(Stream stream)
{
using (ArrayBuffer readBuffer = bufferPool.Take(maxHttpHeaderSize))
{
int? readCountOrFail = ReadHelper.SafeReadTillMatch(stream, readBuffer.array, 0, maxHttpHeaderSize, Constants.endOfHandshake);
if (!readCountOrFail.HasValue)
return null;
int readCount = readCountOrFail.Value;
string msg = Encoding.ASCII.GetString(readBuffer.array, 0, readCount);
Log.Verbose(msg);
return msg;
}
}
static bool IsGet(byte[] getHeader)
{
// just check bytes here instead of using Encoding.ASCII
return getHeader[0] == 71 && // G
getHeader[1] == 69 && // E
getHeader[2] == 84; // T
}
void AcceptHandshake(Stream stream, string msg)
{
using (
ArrayBuffer keyBuffer = bufferPool.Take(KeyLength),
responseBuffer = bufferPool.Take(ResponseLength))
{
GetKey(msg, keyBuffer.array);
AppendGuid(keyBuffer.array);
byte[] keyHash = CreateHash(keyBuffer.array);
CreateResponse(keyHash, responseBuffer.array);
stream.Write(responseBuffer.array, 0, ResponseLength);
}
}
static void GetKey(string msg, byte[] keyBuffer)
{
int start = msg.IndexOf(KeyHeaderString) + KeyHeaderString.Length;
Log.Verbose($"Handshake Key: {msg.Substring(start, KeyLength)}");
Encoding.ASCII.GetBytes(msg, start, KeyLength, keyBuffer, 0);
}
static void AppendGuid(byte[] keyBuffer)
{
Buffer.BlockCopy(Constants.HandshakeGUIDBytes, 0, keyBuffer, KeyLength, Constants.HandshakeGUID.Length);
}
byte[] CreateHash(byte[] keyBuffer)
{
Log.Verbose($"Handshake Hashing {Encoding.ASCII.GetString(keyBuffer, 0, MergedKeyLength)}");
return sha1.ComputeHash(keyBuffer, 0, MergedKeyLength);
}
static void CreateResponse(byte[] keyHash, byte[] responseBuffer)
{
string keyHashString = Convert.ToBase64String(keyHash);
// compiler should merge these strings into 1 string before format
string message = string.Format(
"HTTP/1.1 101 Switching Protocols\r\n" +
"Connection: Upgrade\r\n" +
"Upgrade: websocket\r\n" +
"Sec-WebSocket-Accept: {0}\r\n\r\n",
keyHashString);
Log.Verbose($"Handshake Response length {message.Length}, IsExpected {message.Length == ResponseLength}");
Encoding.ASCII.GetBytes(message, 0, ResponseLength, responseBuffer, 0);
}
}
}

View File

@@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 6268509ac4fb48141b9944c03295da11
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,74 @@
using System;
using System.IO;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
namespace Mirror.SimpleWeb
{
public struct SslConfig
{
public readonly bool enabled;
public readonly string certPath;
public readonly string certPassword;
public readonly SslProtocols sslProtocols;
public SslConfig(bool enabled, string certPath, string certPassword, SslProtocols sslProtocols)
{
this.enabled = enabled;
this.certPath = certPath;
this.certPassword = certPassword;
this.sslProtocols = sslProtocols;
}
}
internal class ServerSslHelper
{
readonly SslConfig config;
readonly X509Certificate2 certificate;
public ServerSslHelper(SslConfig sslConfig)
{
config = sslConfig;
if (config.enabled)
certificate = new X509Certificate2(config.certPath, config.certPassword);
}
internal bool TryCreateStream(Connection conn)
{
NetworkStream stream = conn.client.GetStream();
if (config.enabled)
{
try
{
conn.stream = CreateStream(stream);
return true;
}
catch (Exception e)
{
Log.Error($"Create SSLStream Failed: {e}", false);
return false;
}
}
else
{
conn.stream = stream;
return true;
}
}
Stream CreateStream(NetworkStream stream)
{
SslStream sslStream = new SslStream(stream, true, acceptClient);
sslStream.AuthenticateAsServer(certificate, false, config.sslProtocols, false);
return sslStream;
}
bool acceptClient(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
{
// always accept client
return true;
}
}
}

View File

@@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 11061fee528ebdd43817a275b1e4a317
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,105 @@
using System;
using System.Collections.Generic;
using UnityEngine;
namespace Mirror.SimpleWeb
{
public class SimpleWebServer
{
readonly int maxMessagesPerTick;
readonly WebSocketServer server;
readonly BufferPool bufferPool;
public SimpleWebServer(int maxMessagesPerTick, TcpConfig tcpConfig, int maxMessageSize, int handshakeMaxSize, SslConfig sslConfig)
{
this.maxMessagesPerTick = maxMessagesPerTick;
// use max because bufferpool is used for both messages and handshake
int max = Math.Max(maxMessageSize, handshakeMaxSize);
bufferPool = new BufferPool(5, 20, max);
server = new WebSocketServer(tcpConfig, maxMessageSize, handshakeMaxSize, sslConfig, bufferPool);
}
public bool Active { get; private set; }
public event Action<int> onConnect;
public event Action<int> onDisconnect;
public event Action<int, ArraySegment<byte>> onData;
public event Action<int, Exception> onError;
public void Start(ushort port)
{
server.Listen(port);
Active = true;
}
public void Stop()
{
server.Stop();
Active = false;
}
public void SendAll(List<int> connectionIds, ArraySegment<byte> source)
{
ArrayBuffer buffer = bufferPool.Take(source.Count);
buffer.CopyFrom(source);
buffer.SetReleasesRequired(connectionIds.Count);
// make copy of array before for each, data sent to each client is the same
foreach (int id in connectionIds)
{
server.Send(id, buffer);
}
}
public void SendOne(int connectionId, ArraySegment<byte> source)
{
ArrayBuffer buffer = bufferPool.Take(source.Count);
buffer.CopyFrom(source);
server.Send(connectionId, buffer);
}
public bool KickClient(int connectionId)
{
return server.CloseConnection(connectionId);
}
public string GetClientAddress(int connectionId)
{
return server.GetClientAddress(connectionId);
}
public void ProcessMessageQueue(MonoBehaviour behaviour)
{
int processedCount = 0;
// check enabled every time in case behaviour was disabled after data
while (
behaviour.enabled &&
processedCount < maxMessagesPerTick &&
// Dequeue last
server.receiveQueue.TryDequeue(out Message next)
)
{
processedCount++;
switch (next.type)
{
case EventType.Connected:
onConnect?.Invoke(next.connId);
break;
case EventType.Data:
onData?.Invoke(next.connId, next.data.ToSegment());
next.data.Release();
break;
case EventType.Disconnected:
onDisconnect?.Invoke(next.connId);
break;
case EventType.Error:
onError?.Invoke(next.connId, next.exception);
break;
}
}
}
}
}

View File

@@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: bd51d7896f55a5e48b41a4b526562b0e
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,230 @@
using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Net.Sockets;
using System.Threading;
namespace Mirror.SimpleWeb
{
public class WebSocketServer
{
public readonly ConcurrentQueue<Message> receiveQueue = new ConcurrentQueue<Message>();
readonly TcpConfig tcpConfig;
readonly int maxMessageSize;
TcpListener listener;
Thread acceptThread;
bool serverStopped;
readonly ServerHandshake handShake;
readonly ServerSslHelper sslHelper;
readonly BufferPool bufferPool;
readonly ConcurrentDictionary<int, Connection> connections = new ConcurrentDictionary<int, Connection>();
int _idCounter = 0;
public WebSocketServer(TcpConfig tcpConfig, int maxMessageSize, int handshakeMaxSize, SslConfig sslConfig, BufferPool bufferPool)
{
this.tcpConfig = tcpConfig;
this.maxMessageSize = maxMessageSize;
sslHelper = new ServerSslHelper(sslConfig);
this.bufferPool = bufferPool;
handShake = new ServerHandshake(this.bufferPool, handshakeMaxSize);
}
public void Listen(int port)
{
listener = TcpListener.Create(port);
listener.Start();
Log.Info($"Server has started on port {port}");
acceptThread = new Thread(acceptLoop);
acceptThread.IsBackground = true;
acceptThread.Start();
}
public void Stop()
{
serverStopped = true;
// Interrupt then stop so that Exception is handled correctly
acceptThread?.Interrupt();
listener?.Stop();
acceptThread = null;
Log.Info("Server stopped, Closing all connections...");
// make copy so that foreach doesn't break if values are removed
Connection[] connectionsCopy = connections.Values.ToArray();
foreach (Connection conn in connectionsCopy)
{
conn.Dispose();
}
connections.Clear();
}
void acceptLoop()
{
try
{
try
{
while (true)
{
TcpClient client = listener.AcceptTcpClient();
tcpConfig.ApplyTo(client);
// TODO keep track of connections before they are in connections dictionary
// this might not be a problem as HandshakeAndReceiveLoop checks for stop
// and returns/disposes before sending message to queue
Connection conn = new Connection(client, AfterConnectionDisposed);
Log.Info($"A client connected {conn}");
// handshake needs its own thread as it needs to wait for message from client
Thread receiveThread = new Thread(() => HandshakeAndReceiveLoop(conn));
conn.receiveThread = receiveThread;
receiveThread.IsBackground = true;
receiveThread.Start();
}
}
catch (SocketException)
{
// check for Interrupted/Abort
Utils.CheckForInterupt();
throw;
}
}
catch (ThreadInterruptedException e) { Log.InfoException(e); }
catch (ThreadAbortException e) { Log.InfoException(e); }
catch (Exception e) { Log.Exception(e); }
}
void HandshakeAndReceiveLoop(Connection conn)
{
try
{
bool success = sslHelper.TryCreateStream(conn);
if (!success)
{
Log.Error($"Failed to create SSL Stream {conn}");
conn.Dispose();
return;
}
success = handShake.TryHandshake(conn);
if (success)
{
Log.Info($"Sent Handshake {conn}");
}
else
{
Log.Error($"Handshake Failed {conn}");
conn.Dispose();
return;
}
// check if Stop has been called since accepting this client
if (serverStopped)
{
Log.Info("Server stops after successful handshake");
return;
}
conn.connId = Interlocked.Increment(ref _idCounter);
connections.TryAdd(conn.connId, conn);
receiveQueue.Enqueue(new Message(conn.connId, EventType.Connected));
Thread sendThread = new Thread(() =>
{
SendLoop.Config sendConfig = new SendLoop.Config(
conn,
bufferSize: Constants.HeaderSize + maxMessageSize,
setMask: false);
SendLoop.Loop(sendConfig);
});
conn.sendThread = sendThread;
sendThread.IsBackground = true;
sendThread.Name = $"SendLoop {conn.connId}";
sendThread.Start();
ReceiveLoop.Config receiveConfig = new ReceiveLoop.Config(
conn,
maxMessageSize,
expectMask: true,
receiveQueue,
bufferPool);
ReceiveLoop.Loop(receiveConfig);
}
catch (ThreadInterruptedException e) { Log.InfoException(e); }
catch (ThreadAbortException e) { Log.InfoException(e); }
catch (Exception e) { Log.Exception(e); }
finally
{
// close here in case connect fails
conn.Dispose();
}
}
void AfterConnectionDisposed(Connection conn)
{
if (conn.connId != Connection.IdNotSet)
{
receiveQueue.Enqueue(new Message(conn.connId, EventType.Disconnected));
connections.TryRemove(conn.connId, out Connection _);
}
}
public void Send(int id, ArrayBuffer buffer)
{
if (connections.TryGetValue(id, out Connection conn))
{
conn.sendQueue.Enqueue(buffer);
conn.sendPending.Set();
}
else
{
Log.Warn($"Cant send message to {id} because connection was not found in dictionary. Maybe it disconnected.");
}
}
public bool CloseConnection(int id)
{
if (connections.TryGetValue(id, out Connection conn))
{
Log.Info($"Kicking connection {id}");
conn.Dispose();
return true;
}
else
{
Log.Warn($"Failed to kick {id} because id not found");
return false;
}
}
public string GetClientAddress(int id)
{
if (connections.TryGetValue(id, out Connection conn))
{
return conn.client.Client.RemoteEndPoint.ToString();
}
else
{
Log.Error($"Cant close connection to {id} because connection was not found in dictionary");
return null;
}
}
}
}

View File

@@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 5c434db044777d2439bae5a57d4e8ee7
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant: