RestListener.cs 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568
  1. using System.Net;
  2. using System.Reflection;
  3. using System.Security.Cryptography.X509Certificates;
  4. using GenHTTP.Api.Content;
  5. using GenHTTP.Api.Infrastructure;
  6. using GenHTTP.Api.Protocol;
  7. using GenHTTP.Engine;
  8. using GenHTTP.Modules.IO;
  9. using GenHTTP.Modules.IO.FileSystem;
  10. using GenHTTP.Modules.IO.Streaming;
  11. using GenHTTP.Modules.Practices;
  12. using InABox.Clients;
  13. using InABox.Core;
  14. using InABox.Database;
  15. using InABox.Remote.Shared;
  16. using InABox.Server.WebSocket;
  17. using InABox.WebSocket.Shared;
  18. using NPOI.POIFS.Crypt.Dsig;
  19. using RequestMethod = GenHTTP.Api.Protocol.RequestMethod;
  20. namespace InABox.API
  21. {
  /// <summary>
  /// GenHTTP handler for the InABox REST API. GET/HEAD requests serve discovery, server-info, ping
  /// and update-file endpoints; POST requests serve authentication, notification discovery and the
  /// generic per-entity database operations. Any other method is answered with 405 Method Not Allowed.
  /// </summary>
  public class RestHandler : IHandler
  {
      /// <summary>
      /// Prefixes combined with each entity type name to build the advertised operation list.
      /// </summary>
      /// <remarks>
      /// NOTE(review): "Load" is advertised here but has no entry in <c>methodMap</c>, so a
      /// "Load{Entity}" request resolves to 404. Kept because clients may check for its presence — confirm.
      /// </remarks>
      private static readonly string[] AdvertisedOperationPrefixes =
          { "List", "Load", "Save", "MultiSave", "Delete", "MultiDelete" };

      /// <summary>Operation names returned by the "supported_operations" endpoint (e.g. "List{TypeName}").</summary>
      private readonly List<string> _supportedOperations = new();

      /// <summary>Entity names ('.' replaced by '_') returned by the "supported_classes" endpoint.</summary>
      private readonly List<string> _supportedClasses = new();

      /// <summary>Web socket port reported by the "notify" endpoint; may be null.</summary>
      private readonly int? _webSocketPort;

      /// <summary>
      /// Builds the advertised class/operation lists from every remotable <see cref="Entity"/> subtype
      /// that the database factory reports as supported.
      /// </summary>
      /// <param name="parent">Parent handler in the GenHTTP handler chain.</param>
      /// <param name="webSocketPort">Port returned to clients by the "notify" endpoint.</param>
      public RestHandler(IHandler parent, int? webSocketPort)
      {
          Parent = parent;
          _webSocketPort = webSocketPort;

          var types = CoreUtils.TypeList(
              x => x.IsSubclassOf(typeof(Entity))
                   && x.GetInterfaces().Contains(typeof(IRemotable))
          );
          var dbTypes = DbFactory.SupportedTypes();
          foreach (var type in types)
          {
              var entityName = type.EntityName().Replace(".", "_");
              if (!dbTypes.Contains(entityName))
                  continue;

              _supportedClasses.Add(entityName);
              foreach (var prefix in AdvertisedOperationPrefixes)
                  _supportedOperations.Add(prefix + type.Name);
          }
          _supportedOperations.Add("QueryMultiple");
      }

      /// <summary>
      /// Reads the serialization options from the query string: "serializationVersion" (binary settings,
      /// default V1_0), "format" (request body format) and "responseFormat" (default Json).
      /// Unparseable or undefined format values are ignored and the defaults are kept.
      /// </summary>
      private static RequestData GetRequestData(IRequest request)
      {
          BinarySerializationSettings settings =
              request.Query.TryGetValue("serializationVersion", out var versionString)
                  ? BinarySerializationSettings.ConvertVersionString(versionString)
                  : BinarySerializationSettings.V1_0;

          var data = new RequestData(settings)
          {
              ResponseFormat = SerializationFormat.Json
          };
          if (TryParseFormat(request, "format", out var requestFormat))
              data.RequestFormat = requestFormat;
          if (TryParseFormat(request, "responseFormat", out var responseFormat))
              data.ResponseFormat = responseFormat;
          return data;
      }

      /// <summary>
      /// Parses query parameter <paramref name="key"/> as a <see cref="SerializationFormat"/>.
      /// Returns false when the parameter is missing, unparseable, or names a value that is not a
      /// defined member of the enum. Enum.TryParse alone accepts arbitrary numbers and comma lists.
      /// </summary>
      private static bool TryParseFormat(IRequest request, string key, out SerializationFormat format)
      {
          format = default;
          return request.Query.TryGetValue(key, out var value)
                 && Enum.TryParse(value, out format)
                 && Enum.IsDefined(typeof(SerializationFormat), format);
      }

      /// <summary>
      /// The main handler for the server; an HTTP request comes in, an HTTP response goes out.
      /// Any exception is logged and turned into 500 Internal Server Error.
      /// </summary>
      /// <param name="request">The incoming GenHTTP request.</param>
      /// <returns>A synchronously completed task holding the response.</returns>
      public ValueTask<IResponse?> HandleAsync(IRequest request)
      {
          IResponse response;
          try
          {
              response = request.Method.KnownMethod switch
              {
                  RequestMethod.GET or RequestMethod.HEAD => HandleGet(request),
                  RequestMethod.POST => HandlePost(request),
                  _ => request.Respond()
                      .Status(ResponseStatus.MethodNotAllowed)
                      .Header("Allow", "GET, POST, HEAD")
                      .Build()
              };
          }
          catch (Exception e)
          {
              Logger.Send(LogType.Error, "", CoreUtils.FormatException(e));
              response = request.Respond().Status(ResponseStatus.InternalServerError).Build();
          }
          return new ValueTask<IResponse?>(response);
      }

      /// <summary>
      /// Dispatches GET/HEAD requests. An optional leading "update" path segment is skipped, so
      /// "/update/version" resolves the same way as "/version". Unknown endpoints are logged and get 404.
      /// </summary>
      private IResponse HandleGet(IRequest request)
      {
          var current = request.Target.Current?.Value;
          if (string.Equals(current, "update", StringComparison.Ordinal))
          {
              request.Target.Advance();
              current = request.Target.Current?.Value;
          }

          switch (current)
          {
              case "operations" or "supported_operations":
                  return GetSupportedOperations(request).Build();
              case "classes" or "supported_classes":
                  return GetSupportedClasses(request).Build();
              case "version" or "releasenotes" or "release_notes" or "install" or "install_desktop":
                  return GetUpdateFile(request).Build();
              case "info":
                  return GetServerInfo(request).Build();
              case "ping":
                  return request.Respond().Status(ResponseStatus.OK).Build();
          }

          Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
              $"GET/HEAD request to endpoint '{current}' is unresolved, because it does not exist");
          return request.Respond().Status(ResponseStatus.NotFound).Build();
      }

      /// <summary>
      /// Dispatches POST requests to the authentication/notify handlers, falling back to the
      /// database endpoints. A request with no path segment gets 404.
      /// </summary>
      private IResponse HandlePost(IRequest request)
      {
          var target = request.Target.Current;
          if (target is null)
              return request.Respond().Status(ResponseStatus.NotFound).Build();

          var data = GetRequestData(request);
          return target.Value switch
          {
              "validate" => Validate(request, data).Build(),
              "check_2fa" => Check2FA(request, data).Build(),
              "notify" => GetNotify(request, data).Build(),
              _ => HandleDatabaseRequest(request, data)
          };
      }

      /// <summary>
      /// Returns a JSON list of operation names; used for checking support of operations client side.
      /// </summary>
      private IResponseBuilder GetSupportedOperations(IRequest request)
      {
          var serialized = Core.Serialization.Serialize(_supportedOperations, true) ?? "";
          return request.Respond()
              .Type(new FlexibleContentType(ContentType.ApplicationJson))
              .Content(new ResourceContent(Resource.FromString(serialized).Build()));
      }

      /// <summary>
      /// Returns a JSON list of class names; used for checking support of classes client side.
      /// </summary>
      private IResponseBuilder GetSupportedClasses(IRequest request)
      {
          var serialized = Serialization.Serialize(_supportedClasses) ?? "";
          return request.Respond()
              .Type(new FlexibleContentType(ContentType.ApplicationJson))
              .Content(new ResourceContent(Resource.FromString(serialized).Build()));
      }

      /// <summary>
      /// Returns the result of <c>RestService.Info</c> (splash logo / colour scheme, per the original
      /// comment), serialized in the format requested via the query string.
      /// </summary>
      private IResponseBuilder GetServerInfo(IRequest request)
      {
          var data = GetRequestData(request);
          InfoResponse response = RestService.Info(new InfoRequest());
          return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
      }

      /// <summary>
      /// Returns the notification web socket port to a client whose session exists in
      /// <c>CredentialsCache</c>; unknown sessions get 404.
      /// </summary>
      private IResponseBuilder GetNotify(IRequest request, RequestData data)
      {
          var requestObj = Deserialize<NotifyRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          if (!CredentialsCache.SessionExists(requestObj.Credentials.Session))
          {
              return request.Respond().Status(ResponseStatus.NotFound);
          }
          var response = new NotifyResponse
          {
              Status = StatusCode.OK,
              SocketPort = _webSocketPort
          };
          return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
      }

      #region Authentication

      /// <summary>Deserializes a <c>ValidateRequest</c> and forwards it to <c>RestService.Validate</c>.</summary>
      private IResponseBuilder Validate(IRequest request, RequestData data)
      {
          var requestObj = Deserialize<ValidateRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          var response = RestService.Validate(requestObj);
          return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
      }

      /// <summary>Deserializes a <c>Check2FARequest</c> and forwards it to <c>RestService.Check2FA</c>.</summary>
      private IResponseBuilder Check2FA(IRequest request, RequestData data)
      {
          var requestObj = Deserialize<Check2FARequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          var response = RestService.Check2FA(requestObj);
          return SerializeResponse(request, data.ResponseFormat, data.BinarySerializationSettings, response);
      }

      #endregion

      #region Database

      /// <summary>Looks up one of this class's private static generic handlers by name; throws if missing.</summary>
      private static MethodInfo GetMethod(string name) =>
          typeof(RestHandler).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Static)
          ?? throw new Exception($"Invalid method '{name}'");

      /// <summary>
      /// Endpoint prefix → open generic handler. An endpoint "{Prefix}{EntityName}" is routed to the
      /// handler closed over the entity type. The first matching prefix in list order wins.
      /// </summary>
      private static readonly List<Tuple<string, MethodInfo>> methodMap = new()
      {
          new("List", GetMethod(nameof(List))),
          new("Save", GetMethod(nameof(Save))),
          new("Delete", GetMethod(nameof(Delete))),
          new("MultiSave", GetMethod(nameof(MultiSave))),
          new("MultiDelete", GetMethod(nameof(MultiDelete)))
      };

      /// <summary>Per-request serialization options parsed from the query string.</summary>
      private class RequestData
      {
          public SerializationFormat RequestFormat { get; set; }
          public SerializationFormat ResponseFormat { get; set; }
          public BinarySerializationSettings BinarySerializationSettings { get; set; }

          public RequestData(BinarySerializationSettings binarySerializationSettings)
          {
              BinarySerializationSettings = binarySerializationSettings;
          }
      }

      private static QueryResponse<T> List<T>(IRequest request, RequestData data) where T : Entity, new()
      {
          var requestObject = Deserialize<QueryRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService<T>.List(requestObject);
      }

      private static SaveResponse<T> Save<T>(IRequest request, RequestData data) where T : Entity, new()
      {
          var requestObject = Deserialize<SaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService<T>.Save(requestObject);
      }

      private static DeleteResponse<T> Delete<T>(IRequest request, RequestData data) where T : Entity, new()
      {
          var requestObject = Deserialize<DeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService<T>.Delete(requestObject);
      }

      private static MultiSaveResponse<T> MultiSave<T>(IRequest request, RequestData data) where T : Entity, new()
      {
          var requestObject = Deserialize<MultiSaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService<T>.MultiSave(requestObject);
      }

      private static MultiDeleteResponse<T> MultiDelete<T>(IRequest request, RequestData data) where T : Entity, new()
      {
          var requestObject = Deserialize<MultiDeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService<T>.MultiDelete(requestObject);
      }

      private static MultiQueryResponse QueryMultiple(IRequest request, RequestData data)
      {
          var requestObject = Deserialize<MultiQueryRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService.QueryMultiple(requestObject, false);
      }

      /// <summary>
      /// Reads a <typeparamref name="T"/> from the request body. Binary is used only when requested
      /// AND <typeparamref name="T"/> implements <c>ISerializeBinary</c>; otherwise the body is read
      /// as text and passed to <c>Serialization.Deserialize</c>.
      /// </summary>
      /// <exception cref="Exception">The stream is null or deserialization returned null.</exception>
      private static T Deserialize<T>(Stream? stream, SerializationFormat requestFormat, BinarySerializationSettings binarySettings, bool strict = false)
      {
          if (stream is null)
              throw new Exception("Stream is null");

          if (requestFormat == SerializationFormat.Binary && typeof(T).IsAssignableTo(typeof(ISerializeBinary)))
              return (T)Serialization.ReadBinary(typeof(T), stream, binarySettings);

          var str = new StreamReader(stream).ReadToEnd();
          return Serialization.Deserialize<T>(str, strict)
                 ?? throw new Exception("Deserialization failed");
      }

      /// <summary>
      /// Writes <paramref name="result"/> as binary (octet-stream) when requested and supported,
      /// otherwise as JSON.
      /// </summary>
      private IResponseBuilder SerializeResponse(IRequest request, SerializationFormat responseFormat, BinarySerializationSettings binarySettings, Response? result)
      {
          if (responseFormat == SerializationFormat.Binary && result is ISerializeBinary binary)
          {
              var stream = new MemoryStream();
              binary.SerializeBinary(new CoreBinaryWriter(stream, binarySettings));
              // Rewind so the content is streamed from the beginning regardless of whether the
              // framework seeks the stream itself.
              stream.Position = 0;
              // NOTE(review): the "checksum" is the stream's object hash code, not a content hash.
              return request.Respond()
                  .Type(new FlexibleContentType(ContentType.ApplicationOctetStream))
                  .Content(stream, (ulong?)stream.Length, () => new ValueTask<ulong?>((ulong)stream.GetHashCode()));
          }

          var serialized = Serialization.Serialize(result);
          return request.Respond()
              .Type(new FlexibleContentType(ContentType.ApplicationJson))
              .Content(new ResourceContent(Resource.FromString(serialized).Build()));
      }

      /// <summary>
      /// Handler for all database requests. "QueryMultiple*" goes to <see cref="QueryMultiple"/>.
      /// Otherwise the endpoint must be "{Prefix}{EntityName}" for a prefix in <c>methodMap</c> and a
      /// persistent, remotable entity. Secure entities, unknown entities and unknown prefixes get 404.
      /// </summary>
      private IResponse HandleDatabaseRequest(IRequest request, RequestData requestData)
      {
          var endpoint = request.Target.Current?.Value ?? "";
          if (endpoint.StartsWith("QueryMultiple", StringComparison.Ordinal))
          {
              var result = QueryMultiple(request, requestData);
              return SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build();
          }

          foreach (var (name, method) in methodMap)
          {
              // Ordinal comparison: the slice below assumes the prefix occupies exactly name.Length chars.
              if (endpoint.Length <= name.Length || !endpoint.StartsWith(name, StringComparison.Ordinal))
                  continue;

              var entityName = endpoint[name.Length..];
              var entityType = GetEntity(entityName);
              if (entityType is null)
              {
                  Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                      $"Request to endpoint '{endpoint}' unresolved, because '{entityName}' is not a valid entity");
                  return request.Respond().Status(ResponseStatus.NotFound).Build();
              }

              if (entityType.IsAssignableTo(typeof(ISecure)))
              {
                  Logger.Send(LogType.Error, "", $"{entityType} is a secure entity. Request failed from IP {request.Client.IPAddress}");
                  return request.Respond().Status(ResponseStatus.NotFound).Build();
              }

              var resolvedMethod = method.MakeGenericMethod(entityType);
              var result = resolvedMethod.Invoke(null, new object[] { request, requestData }) as Response;
              return SerializeResponse(request, requestData.ResponseFormat, requestData.BinarySerializationSettings, result).Build();
          }

          Logger.Send(LogType.Error, request.Client.IPAddress.ToString(), $"Request to endpoint '{endpoint}' unresolved, because the method does not exist");
          return request.Respond().Status(ResponseStatus.NotFound).Build();
      }

      /// <summary>Lazily built map of type Name → Entity subtypes implementing IRemotable and IPersistent.</summary>
      private Dictionary<string, Type>? _persistentRemotable;

      /// <summary>
      /// Resolves an entity type by its simple (unqualified) type name, or returns null.
      /// NOTE(review): ToDictionary throws if two matching types share a Name. The field would then
      /// never be assigned, so every database request would fail — confirm names are unique.
      /// </summary>
      private Type? GetEntity(string entityName)
      {
          _persistentRemotable ??= CoreUtils.TypeList(
              e => e.IsSubclassOf(typeof(Entity)) &&
                   e.GetInterfaces().Contains(typeof(IRemotable)) &&
                   e.GetInterfaces().Contains(typeof(IPersistent))).ToDictionary(x => x.Name, x => x);
          return _persistentRemotable.GetValueOrDefault(entityName);
      }

      #endregion

      #region Installer

      /// <summary>
      /// Serves a file from "{CommonAppData}/update" as an attachment: "version" → version.txt,
      /// "releasenotes"/"release_notes" → "Release Notes.txt", "install"/"install_desktop" →
      /// PRSDesktopSetup.exe. Unknown names or missing files get 404.
      /// </summary>
      private IResponseBuilder GetUpdateFile(IRequest request)
      {
          var segment = request.Target.Current?.Value;
          string? fileName = segment switch
          {
              "version" => "version.txt",
              "releasenotes" or "release_notes" => "Release Notes.txt",
              "install" or "install_desktop" => "PRSDesktopSetup.exe",
              _ => null
          };
          if (fileName is null)
              return request.Respond().Status(ResponseStatus.NotFound);

          var path = Path.Combine(CoreUtils.GetCommonAppData(), "update", fileName);
          if (!File.Exists(path))
              return request.Respond().Status(ResponseStatus.NotFound);

          // The filename must be quoted: "Release Notes.txt" contains a space, which is not
          // allowed in an unquoted token (RFC 6266).
          return request.Respond()
              .Header("Content-Disposition", $"attachment; filename=\"{fileName}\"")
              .Content(new ResourceContent(
                  new FileResource(new FileInfo(path), null, null)));
      }

      #endregion

      #region GenHTTP stuff

      public IHandler Parent { get; }

      public ValueTask PrepareAsync()
      {
          return new ValueTask();
      }

      public IEnumerable<ContentElement> GetContent(IRequest request)
      {
          return Enumerable.Empty<ContentElement>();
      }

      #endregion
  }
  /// <summary>
  /// Fluent builder that produces a <see cref="RestHandler"/> wrapped in every concern registered
  /// through <see cref="Add"/>.
  /// </summary>
  public class RestHandlerBuilder : IHandlerBuilder<RestHandlerBuilder>
  {
      private readonly List<IConcernBuilder> _concerns = new();
      private readonly int? _webSocketPort;

      /// <param name="webSocketPort">Port handed to every <see cref="RestHandler"/> this builder creates.</param>
      public RestHandlerBuilder(int? webSocketPort) => _webSocketPort = webSocketPort;

      /// <summary>Registers a concern to be chained around the handler; returns this builder.</summary>
      public RestHandlerBuilder Add(IConcernBuilder concern)
      {
          _concerns.Add(concern);
          return this;
      }

      /// <summary>Chains the registered concerns around a new <see cref="RestHandler"/>.</summary>
      public IHandler Build(IHandler parent) =>
          Concerns.Chain(parent, _concerns, inner => new RestHandler(inner, _webSocketPort));
  }
  /// <summary>
  /// Notifier that delivers notifications through a web socket server created on the given port.
  /// Poll requests arriving on a socket session are forwarded to <c>Notify.Poll</c>.
  /// </summary>
  class RestNotifier : Notifier
  {
      private readonly WebSocketServer _server;

      /// <summary>Port reported by the underlying web socket server.</summary>
      public int Port => _server.Port;

      public RestNotifier(int port)
      {
          _server = new WebSocketServer(port);
          _server.Poll += OnServerPoll;
      }

      // Relays a session's poll to the global dispatcher, keyed by the session ID.
      private void OnServerPoll(NotifyState.Session session) => Notify.Poll(session.SessionID);

      /// <summary>Starts the underlying web socket server.</summary>
      public void Start() => _server.Start();

      /// <summary>Stops the underlying web socket server.</summary>
      public void Stop() => _server.Stop();

      protected override void NotifyAll<TNotification>(TNotification notification) =>
          _server.Push(notification);

      protected override void NotifySession(Guid session, Type TNotification, BaseObject notification) =>
          _server.Push(session, TNotification, notification);

      protected override void NotifySession<TNotification>(Guid session, TNotification notification) =>
          _server.Push(session, notification);

      // User → session lookup comes from the credentials cache rather than the socket server.
      protected override IEnumerable<Guid> GetUserSessions(Guid userID) =>
          CredentialsCache.GetUserSessions(userID);

      protected override IEnumerable<Guid> GetSessions(Platform platform) =>
          _server.GetSessions(platform);
  }
  /// <summary>
  /// Static owner of the GenHTTP host that serves <see cref="RestHandler"/> and of the optional
  /// web socket notifier. Typical order: <see cref="Init"/>, then <see cref="InitPort"/> or
  /// <see cref="InitCertificate(ushort, X509Certificate2)"/>, then <see cref="Start"/>.
  /// </summary>
  public static class RestListener
  {
      private static IServerHost? _host;
      private static X509Certificate2? _certificate;
      private static RestNotifier? _notifier;

      /// <summary>The certificate last supplied to InitCertificate; null until then or after <see cref="Clear"/>.</summary>
      public static X509Certificate2? Certificate => _certificate;

      /// <summary>Starts the host, then the notifier; each is skipped if not initialised.</summary>
      public static void Start()
      {
          _host?.Start();
          _notifier?.Start();
      }

      /// <summary>Stops the host, then the notifier; each is skipped if not initialised.</summary>
      public static void Stop()
      {
          _host?.Stop();
          _notifier?.Stop();
      }

      /// <summary>
      /// Records <paramref name="certificate"/>, marks <c>RestService</c> as HTTPS and binds the host
      /// (if one exists) to all addresses on <paramref name="port"/> using that certificate.
      /// </summary>
      public static void InitCertificate(ushort port, X509Certificate2 certificate)
      {
          _certificate = certificate;
          RestService.IsHTTPS = true;
          _host?.Bind(IPAddress.Any, port, certificate);
      }

      /// <summary>Loads the certificate from <paramref name="certificateFile"/> and binds with it.</summary>
      public static void InitCertificate(ushort port, string certificateFile) =>
          InitCertificate(port, new X509Certificate2(certificateFile));

      /// <summary>Marks <c>RestService</c> as plain HTTP and binds the host (if any) to <paramref name="port"/>.</summary>
      public static void InitPort(ushort port)
      {
          RestService.IsHTTPS = false;
          _host?.Bind(IPAddress.Any, port);
      }

      /// <summary>
      /// Clears certificate and host information, and stops the listener.
      /// NOTE(review): the notifier is stopped but not removed from <c>Notify</c>, where Init registered it.
      /// </summary>
      public static void Clear()
      {
          _host?.Stop();
          _host = null;

          _notifier?.Stop();
          _notifier = null;

          _certificate = null;
      }

      /// <summary>
      /// Creates a fresh host whose handler reports the notifier's port. When
      /// <paramref name="webSocketPort"/> is non-zero, a notifier is also created and registered with <c>Notify</c>.
      /// NOTE(review): with a zero port, a notifier from an earlier Init (without Clear) is kept — confirm intended.
      /// </summary>
      public static void Init(int webSocketPort)
      {
          if (webSocketPort != 0)
          {
              _notifier = new RestNotifier(webSocketPort);
              Notify.AddNotifier(_notifier);
          }

          _host = Host.Create();
          _host.Handler(new RestHandlerBuilder(_notifier?.Port))
               .Defaults()
               .Backlog(1024);
      }
  }
  483. }