Browse Source

Database proxy system; certainly not done

Kenric Nugteren 1 year ago
parent
commit
ac4af1005a

+ 14 - 0
prs.classes/Server/Properties/DatabaseProxyProperties.cs

@@ -0,0 +1,14 @@
+using InABox.Core;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace PRSServer
+{
+    public abstract class DatabaseProxyProperties : ServerProperties
+    {
+        [ComboLookupEditor(typeof(DatabaseServerLookupGenerator))]
+        [EditorSequence(1)]
+        public string Server { get; set; }
+    }
+}

+ 1 - 0
prs.classes/Server/ServerType.cs

@@ -8,6 +8,7 @@
         AutoDiscovery,
         Web,
         Certificate,
+        Proxy,
         Other = -1
     }
 }

+ 1 - 1
prs.desktop/MainWindow.xaml.cs

@@ -3049,7 +3049,7 @@ public partial class MainWindow : IPanelHostControl
                 try
                 {
                     var client = new HttpClient { BaseAddress = new Uri($"https://{domain}") };
-                    client.GetAsync("operations").Wait();
+                    client.GetAsync("ping").Wait();
                     url = $"https://{domain}";
                 }
                 catch (Exception)

+ 0 - 1
prs.licensing/Engine/LicensingEngine.cs

@@ -48,7 +48,6 @@ public class LicensingEngine : Engine<LicensingEngineProperties>
             listener = new Listener<LicensingHandler, LicensingHandlerProperties>(new LicensingHandlerProperties(Properties));
             listener.InitHTTPS((ushort)Properties.ListenPort, CertificateFileName());
 
-            Logger.Send(LogType.Information, "", "Starting Web Listener on port " + Properties.ListenPort);
             listener.Start();
         }
         catch (Exception eListen)

+ 48 - 0
prs.server/Engines/Database/Proxies/DatabaseProxyEngine.cs

@@ -0,0 +1,48 @@
+using InABox.Clients;
+using InABox.Core;
+using InABox.Rpc;
+using PRSServices;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace PRSServer;
+
+public abstract class DatabaseProxyEngine<TProperties> : Engine<TProperties>
+    where TProperties : DatabaseProxyProperties
+{
+
+    public override void Run()
+    {
+        Logger.Send(LogType.Information, "", "Starting..");
+
+        if (string.IsNullOrWhiteSpace(Properties.Server))
+        {
+            Logger.Send(LogType.Error, "", "Server is blank!");
+            return;
+        }
+
+        var transport = new RpcClientPipeTransport(DatabaseServerProperties.GetPipeName(Properties.Server, true));
+        ClientFactory.SetClientType(typeof(RpcClient<>), Platform.LicensingEngine, Version, transport);
+        CheckConnection();
+
+        RunProxy();
+    }
+
+    protected abstract void RunProxy();
+
+    private void CheckConnection()
+    {
+        // Wait for server connection
+        while (!Client.Ping())
+        {
+            Logger.Send(LogType.Error, "", "Database server unavailable. Trying again in 30 seconds...");
+            Task.Delay(30_000).Wait();
+            Logger.Send(LogType.Information, "", "Retrying connection...");
+        }
+
+        ClientFactory.SetBypass();
+    }
+}

+ 410 - 0
prs.server/Engines/Database/Proxies/HTTPDatabaseProxyEngine.cs

@@ -0,0 +1,410 @@
+using GenHTTP.Api.Content;
+using GenHTTP.Api.Protocol;
+using GenHTTP.Modules.IO;
+using GenHTTP.Modules.IO.Streaming;
+using GenHTTP.Modules.IO.Strings;
+using InABox.Clients;
+using InABox.Core;
+using PRSServices;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using RequestMethod = GenHTTP.Api.Protocol.RequestMethod;
+
+namespace PRSServer;
+
+public class HTTPDatabaseProxyProperties : DatabaseProxyProperties
+{
+    [IntegerEditor]
+    [EditorSequence(2)]
+    public int ListenPort { get; set; }
+
+    [EditorSequence(3)]
+    [FileNameEditor("Certificate Files (*.pfx)|*.pfx")]
+    public string CertificateFile { get; set; }
+
+    public override ServerType Type() => ServerType.Proxy;
+}
+
+internal class HTTPDatabaseProxyHandler : Handler<HTTPDatabaseProxyProperties>
+{
+    private readonly List<string> endpoints;
+    private readonly List<string> operations;
+
+    public RestHandler(IHandler parent)
+    {
+        Parent = parent;
+
+        endpoints = new();
+        operations = new();
+
+        var types = CoreUtils.TypeList(
+            x => x.IsSubclassOf(typeof(Entity))
+                 && x.GetInterfaces().Contains(typeof(IRemotable))
+        );
+        var DBTypes = DbFactory.SupportedTypes();
+
+        foreach (var t in types)
+            if (DBTypes.Contains(t.EntityName().Replace(".", "_")))
+            {
+                operations.Add(t.EntityName().Replace(".", "_"));
+
+                endpoints.Add(string.Format("List{0}", t.Name));
+                endpoints.Add(string.Format("Load{0}", t.Name));
+                endpoints.Add(string.Format("Save{0}", t.Name));
+                endpoints.Add(string.Format("MultiSave{0}", t.Name));
+                endpoints.Add(string.Format("Delete{0}", t.Name));
+                endpoints.Add(string.Format("MultiDelete{0}", t.Name));
+            }
+
+        endpoints.Add("QueryMultiple");
+    }
+    
+    private RequestData GetRequestData(IRequest request)
+    {
+        BinarySerializationSettings settings = BinarySerializationSettings.V1_0;
+        if (request.Query.TryGetValue("serializationVersion", out var versionString))
+        {
+            settings = BinarySerializationSettings.ConvertVersionString(versionString);
+        }
+
+        var data = new RequestData(settings);
+        if (request.Query.TryGetValue("format", out var formatString) && Enum.TryParse<SerializationFormat>(formatString, out var format))
+        {
+            data.RequestFormat = format;
+        }
+        data.ResponseFormat = SerializationFormat.Json;
+        if (request.Query.TryGetValue("responseFormat", out formatString) && Enum.TryParse<SerializationFormat>(formatString, out format))
+        {
+            data.ResponseFormat = format;
+        }
+
+        return data;
+    }
+
+    /// <summary>
+    /// The main handler for the server; an HTTP request comes in, an HTTP response goes out.
+    /// </summary>
+    /// <param name="request"></param>
+    /// <returns></returns>
+    public ValueTask<IResponse?> HandleAsync(IRequest request)
+    {
+        try
+        {
+            switch (request.Method.KnownMethod)
+            {
+                case RequestMethod.GET:
+                case RequestMethod.HEAD:
+                    var current = request.Target.Current?.Value;
+                    if (String.Equals(current,"update"))
+                    {
+                        request.Target.Advance();
+                        current = request.Target.Current?.Value;
+                    }
+                    switch (current)
+                    {
+                        case "operations" or "supported_operations":
+                            Logger.Send(LogType.Error, "", "Supported operations is no longer supported");
+                            return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
+                        case "classes" or "supported_classes":
+                            Logger.Send(LogType.Error, "", "Supported classes is no longer supported");
+                            return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
+                        case "version" or "releasenotes" or "release_notes" or "install" or "install_desktop":
+                            return new ValueTask<IResponse?>(GetUpdateFile(request).Build());
+                        case "info":
+                            return new ValueTask<IResponse?>(GetServerInfo(request).Build());
+                        case "ping":
+                            return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.OK).Build());
+                    }
+
+                    Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
+                        string.Format("GET/HEAD request to endpoint '{0}' is unresolved, because it does not exist", current));
+                    return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
+                case RequestMethod.POST:
+                    var target = request.Target.Current;
+                    if (target is not null)
+                    {
+                        var data = GetRequestData(request);
+                        return target.Value switch
+                        {
+                            "validate" => new ValueTask<IResponse?>(Validate(request, data).Build()),
+                            "check_2fa" => new ValueTask<IResponse?>(Check2FA(request, data).Build()),
+                            _ => HandleDatabaseRequest(request, data),
+                        };
+                    }
+                    return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
+            }
+
+            return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.MethodNotAllowed).Header("Allow", "GET, POST, HEAD").Build());
+        }
+        catch(Exception e)
+        {
+            Logger.Send(LogType.Error, "", CoreUtils.FormatException(e));
+            return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.InternalServerError).Build());
+        }
+    }
+    
+    /// <summary>
+    /// Returns the Splash Logo and Color Scheme for this Database
+    /// </summary>
+    /// <param name="request"></param>
+    /// <returns></returns>
+    private IResponseBuilder GetServerInfo(IRequest request)
+    {
+        var data = GetRequestData(request);
+        var response = new InfoResponse(Client.Info());
+        response.Status = StatusCode.OK;
+        return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
+    }
+
+    #region Authentication
+
+    private IResponseBuilder Validate(IRequest request, RequestData data)
+    {
+        var requestObj = Deserialize<ValidateRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+        var response = RestService.Validate(requestObj);
+
+        return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
+    }
+
+    private IResponseBuilder Check2FA(IRequest request, RequestData data)
+    {
+        var requestObj = Deserialize<Check2FARequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+        var response = RestService.Check2FA(requestObj);
+
+        return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
+    }
+
+    #endregion
+
+    #region Database
+
+    private static MethodInfo GetMethod(string name) =>
+        typeof(RestHandler).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Static)
+        ?? throw new Exception($"Invalid method '{name}'");
+
+    private static readonly List<Tuple<string, MethodInfo>> methodMap = new()
+    {
+        new("List", GetMethod(nameof(List))),
+        new("Save", GetMethod(nameof(Save))),
+        new("Delete", GetMethod(nameof(Delete))),
+        new("MultiSave", GetMethod(nameof(MultiSave))),
+        new("MultiDelete", GetMethod(nameof(MultiDelete)))
+    };
+
+    private class RequestData
+    {
+        public SerializationFormat RequestFormat { get; set; }
+        public SerializationFormat ResponseFormat { get; set; }
+
+        public BinarySerializationSettings BinarySerializationSettings { get; set; }
+
+        public RequestData(BinarySerializationSettings binarySerializationSettings)
+        {
+            BinarySerializationSettings = binarySerializationSettings;
+        }
+    }
+
+    private static QueryResponse<T> List<T>(IRequest request, RequestData data) where T : Entity, new()
+    {
+        var requestObject = Deserialize<QueryRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+        return RestService<T>.List(requestObject);
+    }
+    private static SaveResponse<T> Save<T>(IRequest request, RequestData data) where T : Entity, new()
+    {
+        var requestObject = Deserialize<SaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+        return RestService<T>.Save(requestObject);
+    }
+    private static DeleteResponse<T> Delete<T>(IRequest request, RequestData data) where T : Entity, new()
+    {
+        var requestObject = Deserialize<DeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+        return RestService<T>.Delete(requestObject);
+    }
+    private static MultiSaveResponse<T> MultiSave<T>(IRequest request, RequestData data) where T : Entity, new()
+    {
+        var requestObject = Deserialize<MultiSaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+        return RestService<T>.MultiSave(requestObject);
+    }
+    private static MultiDeleteResponse<T> MultiDelete<T>(IRequest request, RequestData data) where T : Entity, new()
+    {
+        var requestObject = Deserialize<MultiDeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+        return RestService<T>.MultiDelete(requestObject);
+    }
+    private static MultiQueryResponse QueryMultiple(IRequest request, RequestData data)
+    {
+        var requestObject = Deserialize<MultiQueryRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
+
+        return RestService.QueryMultiple(requestObject, false);
+    }
+
+    private static T Deserialize<T>(Stream? stream, SerializationFormat requestFormat, BinarySerializationSettings binarySettings, bool strict = false)
+    {
+        if (stream is null)
+            throw new Exception("Stream is null");
+        if (requestFormat == SerializationFormat.Binary && typeof(T).IsAssignableTo(typeof(ISerializeBinary)))
+        {
+            return (T)Serialization.ReadBinary(typeof(T), stream, binarySettings);
+        }
+        else
+        {
+            var str = new StreamReader(stream).ReadToEnd();
+            return Serialization.Deserialize<T>(str, strict)
+                   ?? throw new Exception("Deserialization failed");
+        }
+    }
+
+    private IResponseBuilder SerializeResponse(IRequest request, SerializationFormat responseFormat, BinarySerializationSettings binarySettings, Response? result)
+    {
+        if (responseFormat == SerializationFormat.Binary && result is ISerializeBinary binary)
+        {
+            var stream = new MemoryStream();
+            binary.SerializeBinary(new CoreBinaryWriter(stream, binarySettings));
+
+            var response = request.Respond()
+                .Type(new FlexibleContentType(ContentType.ApplicationOctetStream))
+                .Content(stream, (ulong?)stream.Length, () => new ValueTask<ulong?>((ulong)stream.GetHashCode()));
+            return response;
+        }
+        else
+        {
+            var serialized = Serialization.Serialize(result);
+
+            var response = request.Respond()
+                .Type(new FlexibleContentType(ContentType.ApplicationJson))
+                .Content(new ResourceContent(Resource.FromString(serialized).Build()));
+            return response;
+        }
+    }
+
+    /// <summary>
+    /// Handler for all database requests
+    /// </summary>
+    /// <param name="request"></param>
+    /// <returns></returns>
+    private ValueTask<IResponse?> HandleDatabaseRequest(IRequest request, RequestData requestData)
+    {
+        var endpoint = request.Target.Current?.Value ?? "";
+        if (endpoint.StartsWith("QueryMultiple"))
+        {
+            var result = QueryMultiple(request, requestData);
+            return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build());
+        }
+
+        foreach (var (name, method) in methodMap)
+            if (endpoint.Length > name.Length && endpoint.StartsWith(name))
+            {
+                var entityName = endpoint[name.Length..];
+
+                var entityType = GetEntity(entityName);
+                if (entityType != null)
+                {
+                    if (entityType.IsAssignableTo(typeof(ISecure)))
+                    {
+                        Logger.Send(LogType.Error, "", $"{entityType} is a secure entity. Request failed from IP {request.Client.IPAddress}");
+                        return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
+                    }
+
+                    var resolvedMethod = method.MakeGenericMethod(entityType);
+                    var result = resolvedMethod.Invoke(null, new object[] { request, requestData }) as Response;
+
+                    return new ValueTask<IResponse?>(SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build());
+                }
+
+                Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
+                    $"Request to endpoint '{endpoint}' unresolved, because '{entityName}' is not a valid entity");
+                return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
+            }
+
+        Logger.Send(LogType.Error, request.Client.IPAddress.ToString(), $"Request to endpoint '{endpoint}' unresolved, because the method does not exist");
+        return new ValueTask<IResponse?>(request.Respond().Status(ResponseStatus.NotFound).Build());
+    }
+
+    private Dictionary<string, Type>? _persistentRemotable;
+
+    private Type? GetEntity(string entityName)
+    {
+        _persistentRemotable ??= CoreUtils.TypeList(
+            e => e.IsSubclassOf(typeof(Entity)) &&
+                 e.GetInterfaces().Contains(typeof(IRemotable)) &&
+                 e.GetInterfaces().Contains(typeof(IPersistent))).ToDictionary(x => x.Name, x => x);
+        return _persistentRemotable.GetValueOrDefault(entityName);
+    }
+
+    #endregion
+
+    #region Installer
+
+    
+    private IResponseBuilder GetUpdateFile(IRequest request)
+    {
+        var endpoint = request.Target.Current;
+        switch (endpoint?.Value)
+        {
+            case "version":
+                return request.Respond()
+                    .Type(new FlexibleContentType(ContentType.TextPlain))
+                    .Content(new ResourceContent(Resource.FromString(Client.Version()).Build()));
+            case "releasenotes" or "release_notes":
+                return request.Respond()
+                    .Type(new FlexibleContentType(ContentType.TextPlain))
+                    .Content(new ResourceContent(Resource.FromString(Client.ReleaseNotes()).Build()));
+            case "install" or "install_desktop":
+                return request.Respond()
+                    .Header("Content-Disposition", $"attachment; filename=PRSDesktopSetup.exe")
+                    .Content(new ResourceContent(new ByteArrayResource(Client.Installer() ?? Array.Empty<byte>(), "PRSDesktopSetup.exe", new FlexibleContentType(ContentType.ApplicationOctetStream), null)));
+        }
+        return request.Respond().Status(ResponseStatus.NotFound);
+    }
+
+    #endregion
+
+    #region GenHTTP stuff
+    public IHandler Parent { get; }
+
+    public ValueTask PrepareAsync()
+    {
+        return new ValueTask();
+    }
+
+    public IEnumerable<ContentElement> GetContent(IRequest request)
+    {
+        return Enumerable.Empty<ContentElement>();
+    }
+
+    #endregion
+}
+
+internal class HTTPDatabaseProxyEngine : DatabaseProxyEngine<HTTPDatabaseProxyProperties>
+{
+    private Listener<HTTPDatabaseProxyHandler, HTTPDatabaseProxyProperties>? Listener;
+
+    protected override void RunProxy()
+    {
+        Logger.Send(LogType.Information, "", "Registering Classes");
+
+        Logger.Send(LogType.Information, "", "Starting Listener on port " + Properties.ListenPort);
+
+        try
+        {
+            Listener = new Listener<HTTPDatabaseProxyHandler, HTTPDatabaseProxyProperties>(Properties);
+            Listener.InitHTTPS((ushort)Properties.ListenPort, CertificateFileName());
+            Listener.Start();
+        }
+        catch (Exception eListen)
+        {
+            Logger.Send(LogType.Error, ClientFactory.UserID, eListen.Message);
+        }
+    }
+    private string CertificateFileName() =>
+        !string.IsNullOrWhiteSpace(Properties.CertificateFile)
+        ? Properties.CertificateFile
+        : CertificateEngine.CertificateFile;
+
+    public override void Stop()
+    {
+        Logger.Send(LogType.Information, "", "Stopping");
+        Listener?.Stop();
+    }
+}