|
@@ -1,21 +1,125 @@
|
|
|
using H.Pipes;
|
|
|
using H.Pipes.AccessControl;
|
|
|
+using H.Pipes.Args;
|
|
|
using InABox.API;
|
|
|
using InABox.Clients;
|
|
|
using InABox.Core;
|
|
|
+using InABox.Server.WebSocket;
|
|
|
+using Piping;
|
|
|
+using System.Collections.Concurrent;
|
|
|
using System.IO.Pipes;
|
|
|
using System.Reflection;
|
|
|
using System.Security.Principal;
|
|
|
|
|
|
namespace InABox.IPC
|
|
|
{
|
|
|
+ delegate void IPCPollEvent(IPCNotifyState.Session session);
|
|
|
+
|
|
|
+ class IPCNotifyState
|
|
|
+ {
|
|
|
+ public class Session
|
|
|
+ {
|
|
|
+ public PipeConnection<IPCMessage?> Connection { get; }
|
|
|
+
|
|
|
+ public Guid SessionID { get; }
|
|
|
+
|
|
|
+ public Platform Platform { get; }
|
|
|
+
|
|
|
+ public Session(PipeConnection<IPCMessage?> connection, Guid sessionID, Platform platform)
|
|
|
+ {
|
|
|
+ Connection = connection;
|
|
|
+ SessionID = sessionID;
|
|
|
+ Platform = platform;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public ConcurrentDictionary<Guid, Session> SessionMap = new();
|
|
|
+
|
|
|
+ public event IPCPollEvent? OnPoll;
|
|
|
+
|
|
|
+ public void Poll(Session session)
|
|
|
+ {
|
|
|
+ OnPoll?.Invoke(session);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ class IPCNotifier : Notifier
|
|
|
+ {
|
|
|
+ IPCNotifyState NotifyState { get; set; }
|
|
|
+
|
|
|
+ public IPCNotifier(IPCNotifyState notifyState)
|
|
|
+ {
|
|
|
+ NotifyState = notifyState;
|
|
|
+ NotifyState.OnPoll += NotifyState_OnPoll;
|
|
|
+ }
|
|
|
+
|
|
|
+ private void NotifyState_OnPoll(IPCNotifyState.Session session)
|
|
|
+ {
|
|
|
+ Notify.Poll(session.SessionID);
|
|
|
+ }
|
|
|
+
|
|
|
+ protected override IEnumerable<Guid> GetSessions(Platform platform)
|
|
|
+ {
|
|
|
+ return NotifyState.SessionMap.Where(x => x.Value.Platform == platform).Select(x => x.Key);
|
|
|
+ }
|
|
|
+
|
|
|
+ protected override IEnumerable<Guid> GetUserSessions(Guid userID)
|
|
|
+ {
|
|
|
+ return CredentialsCache.GetUserSessions(userID);
|
|
|
+ }
|
|
|
+
|
|
|
+ protected override void NotifyAll<TNotification>(TNotification notification)
|
|
|
+ {
|
|
|
+ foreach(var session in NotifyState.SessionMap.Values)
|
|
|
+ {
|
|
|
+ session.Connection.WriteAsync(IPCMessage.Notification(notification)).ContinueWith(task =>
|
|
|
+ {
|
|
|
+ if(task.Exception != null)
|
|
|
+ {
|
|
|
+ Logger.Send(LogType.Error, "", $"Error in notification: {CoreUtils.FormatException(task.Exception)}");
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ protected override void NotifySession<TNotification>(Guid sessionID, TNotification notification)
|
|
|
+ {
|
|
|
+ if(NotifyState.SessionMap.TryGetValue(sessionID, out var session))
|
|
|
+ {
|
|
|
+ session.Connection.WriteAsync(IPCMessage.Notification(notification)).ContinueWith(task =>
|
|
|
+ {
|
|
|
+ if(task.Exception != null)
|
|
|
+ {
|
|
|
+ Logger.Send(LogType.Error, "", $"Error in notification: {CoreUtils.FormatException(task.Exception)}");
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ protected override void NotifySession(Guid sessionID, Type TNotification, BaseObject notification)
|
|
|
+ {
|
|
|
+ if(NotifyState.SessionMap.TryGetValue(sessionID, out var session))
|
|
|
+ {
|
|
|
+ session.Connection.WriteAsync(IPCMessage.Notification(TNotification, notification)).ContinueWith(task =>
|
|
|
+ {
|
|
|
+ if(task.Exception != null)
|
|
|
+ {
|
|
|
+ Logger.Send(LogType.Error, "", $"Error in notification: {CoreUtils.FormatException(task.Exception)}");
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public class IPCServer : IDisposable
|
|
|
{
|
|
|
- PipeServer<IPCRequest> Server;
|
|
|
+ PipeServer<IPCMessage> Server;
|
|
|
+
|
|
|
+ IPCNotifyState NotifyState = new();
|
|
|
|
|
|
public IPCServer(string name)
|
|
|
{
|
|
|
- Server = new PipeServer<IPCRequest>(name);
|
|
|
+ Server = new PipeServer<IPCMessage>(name);
|
|
|
|
|
|
#if WINDOWS
|
|
|
SetPipeSecurity();
|
|
@@ -93,77 +197,86 @@ namespace InABox.IPC
|
|
|
_ => null
|
|
|
};
|
|
|
}
|
|
|
+ private class RequestData
|
|
|
+ {
|
|
|
+ public ConnectionMessageEventArgs<IPCMessage?> e { get; }
|
|
|
+
|
|
|
+ public RequestData(ConnectionMessageEventArgs<IPCMessage?> e)
|
|
|
+ {
|
|
|
+ this.e = e;
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- private static IPCRequest QueryMultiple(IPCRequest request)
|
|
|
+ private IPCMessage QueryMultiple(IPCMessage request, RequestData data)
|
|
|
{
|
|
|
var response = RestService.QueryMultiple(request.GetRequest<MultiQueryRequest>(), true);
|
|
|
return request.Respond(response);
|
|
|
}
|
|
|
|
|
|
- private static IPCRequest Validate(IPCRequest request)
|
|
|
+ private IPCMessage Validate(IPCMessage request, RequestData data)
|
|
|
{
|
|
|
var response = RestService.Validate(request.GetRequest<ValidateRequest>());
|
|
|
return request.Respond(response);
|
|
|
}
|
|
|
|
|
|
- private static IPCRequest Ping(IPCRequest request) => request.Respond(new PingResponse().Status(StatusCode.OK));
|
|
|
+ private IPCMessage Ping(IPCMessage request, RequestData data) => request.Respond(new PingResponse().Status(StatusCode.OK));
|
|
|
|
|
|
- private static IPCRequest Info(IPCRequest request)
|
|
|
+ private IPCMessage Info(IPCMessage request, RequestData data)
|
|
|
{
|
|
|
var response = RestService.Info(request.GetRequest<InfoRequest>());
|
|
|
return request.Respond(response);
|
|
|
}
|
|
|
|
|
|
- private static IPCRequest Check2FA(IPCRequest request)
|
|
|
+ private IPCMessage Check2FA(IPCMessage request, RequestData data)
|
|
|
{
|
|
|
var response = RestService.Check2FA(request.GetRequest<Check2FARequest>());
|
|
|
return request.Respond(response);
|
|
|
}
|
|
|
|
|
|
- private static IPCRequest Query<T>(IPCRequest request) where T : Entity, new()
|
|
|
+ private IPCMessage Query<T>(IPCMessage request, RequestData data) where T : Entity, new()
|
|
|
{
|
|
|
var response = RestService<T>.List(request.GetRequest<QueryRequest<T>>());
|
|
|
return request.Respond(response);
|
|
|
}
|
|
|
|
|
|
- private static IPCRequest Save<T>(IPCRequest request) where T : Entity, new()
|
|
|
+ private IPCMessage Save<T>(IPCMessage request, RequestData data) where T : Entity, new()
|
|
|
{
|
|
|
var response = RestService<T>.Save(request.GetRequest<SaveRequest<T>>());
|
|
|
return request.Respond(response);
|
|
|
}
|
|
|
- private static IPCRequest MultiSave<T>(IPCRequest request) where T : Entity, new()
|
|
|
+ private IPCMessage MultiSave<T>(IPCMessage request, RequestData data) where T : Entity, new()
|
|
|
{
|
|
|
var response = RestService<T>.MultiSave(request.GetRequest<MultiSaveRequest<T>>());
|
|
|
return request.Respond(response);
|
|
|
}
|
|
|
|
|
|
- private static IPCRequest Delete<T>(IPCRequest request) where T : Entity, new()
|
|
|
+ private IPCMessage Delete<T>(IPCMessage request, RequestData data) where T : Entity, new()
|
|
|
{
|
|
|
var response = RestService<T>.Delete(request.GetRequest<DeleteRequest<T>>());
|
|
|
return request.Respond(response);
|
|
|
}
|
|
|
- private static IPCRequest MultiDelete<T>(IPCRequest request) where T : Entity, new()
|
|
|
+ private IPCMessage MultiDelete<T>(IPCMessage request, RequestData data) where T : Entity, new()
|
|
|
{
|
|
|
var response = RestService<T>.MultiDelete(request.GetRequest<MultiDeleteRequest<T>>());
|
|
|
return request.Respond(response);
|
|
|
}
|
|
|
|
|
|
- private static MethodInfo QueryMethod = GetMethod(nameof(Query));
|
|
|
- private static MethodInfo SaveMethod = GetMethod(nameof(Save));
|
|
|
- private static MethodInfo MultiSaveMethod = GetMethod(nameof(MultiSave));
|
|
|
- private static MethodInfo DeleteMethod = GetMethod(nameof(Delete));
|
|
|
- private static MethodInfo MultiDeleteMethod = GetMethod(nameof(MultiDelete));
|
|
|
- private static MethodInfo QueryMultipleMethod = GetMethod(nameof(QueryMultiple));
|
|
|
- private static MethodInfo ValidateMethod = GetMethod(nameof(Validate));
|
|
|
- private static MethodInfo Check2FAMethod = GetMethod(nameof(Check2FA));
|
|
|
- private static MethodInfo PingMethod = GetMethod(nameof(Ping));
|
|
|
- private static MethodInfo InfoMethod = GetMethod(nameof(Info));
|
|
|
+ private static readonly MethodInfo QueryMethod = GetMethod(nameof(Query));
|
|
|
+ private static readonly MethodInfo SaveMethod = GetMethod(nameof(Save));
|
|
|
+ private static readonly MethodInfo MultiSaveMethod = GetMethod(nameof(MultiSave));
|
|
|
+ private static readonly MethodInfo DeleteMethod = GetMethod(nameof(Delete));
|
|
|
+ private static readonly MethodInfo MultiDeleteMethod = GetMethod(nameof(MultiDelete));
|
|
|
+ private static readonly MethodInfo QueryMultipleMethod = GetMethod(nameof(QueryMultiple));
|
|
|
+ private static readonly MethodInfo ValidateMethod = GetMethod(nameof(Validate));
|
|
|
+ private static readonly MethodInfo Check2FAMethod = GetMethod(nameof(Check2FA));
|
|
|
+ private static readonly MethodInfo PingMethod = GetMethod(nameof(Ping));
|
|
|
+ private static readonly MethodInfo InfoMethod = GetMethod(nameof(Info));
|
|
|
|
|
|
private static MethodInfo GetMethod(string name) =>
|
|
|
- typeof(IPCServer).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Static)
|
|
|
+ typeof(IPCServer).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Instance)
|
|
|
?? throw new Exception($"Invalid method '{name}'");
|
|
|
|
|
|
- private void Server_MessageReceived(object? sender, H.Pipes.Args.ConnectionMessageEventArgs<IPCRequest?> e)
|
|
|
+ private void Server_MessageReceived(object? sender, H.Pipes.Args.ConnectionMessageEventArgs<IPCMessage?> e)
|
|
|
{
|
|
|
Task.Run(() =>
|
|
|
{
|
|
@@ -193,8 +306,14 @@ namespace InABox.IPC
|
|
|
method = method.MakeGenericMethod(entityType);
|
|
|
}
|
|
|
|
|
|
- var response = method.Invoke(null, new object[] { e.Message }) as IPCRequest;
|
|
|
- e.Connection.WriteAsync(response);
|
|
|
+ var response = method.Invoke(this, new object[] { e.Message, new RequestData(e) }) as IPCMessage;
|
|
|
+ e.Connection.WriteAsync(response).ContinueWith(task =>
|
|
|
+ {
|
|
|
+ if (task.Exception != null)
|
|
|
+ {
|
|
|
+ Logger.Send(LogType.Error, "", $"Error in response: {CoreUtils.FormatException(task.Exception)}");
|
|
|
+ }
|
|
|
+ });
|
|
|
}
|
|
|
catch (Exception err)
|
|
|
{
|
|
@@ -207,20 +326,30 @@ namespace InABox.IPC
|
|
|
var response = (Activator.CreateInstance(responseType) as Response)!;
|
|
|
response.Status = StatusCode.Error;
|
|
|
response.Messages.Add(err.Message);
|
|
|
- e.Connection.WriteAsync(e.Message.Respond(response));
|
|
|
+ e.Connection.WriteAsync(e.Message.Respond(response)).ContinueWith(task =>
|
|
|
+ {
|
|
|
+ if (task.Exception != null)
|
|
|
+ {
|
|
|
+ Logger.Send(LogType.Error, "", $"Error in response: {CoreUtils.FormatException(task.Exception)}");
|
|
|
+ }
|
|
|
+ });
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
});
|
|
|
}
|
|
|
|
|
|
- private void Server_ClientDisconnected(object? sender, H.Pipes.Args.ConnectionEventArgs<IPCRequest> e)
|
|
|
+ private void Server_ClientDisconnected(object? sender, H.Pipes.Args.ConnectionEventArgs<IPCMessage> e)
|
|
|
{
|
|
|
Logger.Send(LogType.Information, "", "Client Disconnected");
|
|
|
+
|
|
|
+ var sessionID = NotifyState.SessionMap.Where(x => x.Value.Connection == e.Connection).FirstOrDefault().Key;
|
|
|
+ NotifyState.SessionMap.TryRemove(sessionID, out var session);
|
|
|
+
|
|
|
e.Connection.DisposeAsync();
|
|
|
}
|
|
|
|
|
|
- private void Server_ClientConnected(object? sender, H.Pipes.Args.ConnectionEventArgs<IPCRequest> e)
|
|
|
+ private void Server_ClientConnected(object? sender, H.Pipes.Args.ConnectionEventArgs<IPCMessage> e)
|
|
|
{
|
|
|
Logger.Send(LogType.Information, "", "Client Connected");
|
|
|
}
|