RestListener.cs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  1. using System.Collections.Generic;
  2. using System.Diagnostics.CodeAnalysis;
  3. using System.Net;
  4. using System.Reflection;
  5. using System.Security.Cryptography.X509Certificates;
  6. using GenHTTP.Api.Content;
  7. using GenHTTP.Api.Infrastructure;
  8. using GenHTTP.Api.Protocol;
  9. using GenHTTP.Engine;
  10. using GenHTTP.Modules.IO;
  11. using GenHTTP.Modules.IO.FileSystem;
  12. using GenHTTP.Modules.IO.Streaming;
  13. using GenHTTP.Modules.Practices;
  14. using InABox.Clients;
  15. using InABox.Core;
  16. using InABox.Database;
  17. using InABox.Server.WebSocket;
  18. using InABox.WebSocket.Shared;
  19. using NPOI.SS.Formula.Functions;
  20. using NPOI.XSSF.Streaming.Values;
  21. using StreamContent = GenHTTP.Modules.IO.Streaming.StreamContent;
  22. namespace InABox.API
  23. {
  /// <summary>
  /// GenHTTP handler for the REST API. GET/HEAD requests serve discovery, server-info and
  /// update-file endpoints; POST requests serve authentication, notification and database endpoints.
  /// </summary>
  public class RestHandler : IHandler
  {
      // Endpoint names ("List{Entity}", "Save{Entity}", ..., plus "QueryMultiple") returned by the operations endpoint.
      // NOTE(review): "Load{Entity}" is advertised here, but methodMap has no "Load" entry, so a POST to such an
      // endpoint ends in the "method does not exist" 404 branch — confirm whether that is intended.
      private readonly List<string> endpoints;

      // Entity names with '.' replaced by '_', returned by the classes endpoint.
      private readonly List<string> operations;

      // Port handed back to clients by the "notify" endpoint; may be null.
      private readonly int? WebSocketPort;

      /// <summary>
      /// Creates the handler and builds the lists of supported entity names and endpoint names
      /// from every remotable <see cref="Entity"/> subclass that the database factory supports.
      /// </summary>
      /// <param name="parent">The parent handler in the GenHTTP handler chain.</param>
      /// <param name="webSocketPort">The port reported by the "notify" endpoint, if any.</param>
      public RestHandler(IHandler parent, int? webSocketPort)
      {
          WebSocketPort = webSocketPort;
          Parent = parent;

          endpoints = new();
          operations = new();

          var types = CoreUtils.TypeList(
              x => x.IsSubclassOf(typeof(Entity))
                   && x.GetInterfaces().Contains(typeof(IRemotable))
          );

          var DBTypes = DbFactory.SupportedTypes();

          foreach (var t in types)
          {
              // Computed once per type; used both for the support check and the advertised name.
              var operation = t.EntityName().Replace(".", "_");
              if (!DBTypes.Contains(operation))
                  continue;

              operations.Add(operation);

              endpoints.Add($"List{t.Name}");
              endpoints.Add($"Load{t.Name}");
              endpoints.Add($"Save{t.Name}");
              endpoints.Add($"MultiSave{t.Name}");
              endpoints.Add($"Delete{t.Name}");
              endpoints.Add($"MultiDelete{t.Name}");
          }

          endpoints.Add("QueryMultiple");
      }

      /// <summary>
      /// The main handler for the server; an HTTP request comes in, an HTTP response goes out.
      /// GET/HEAD and POST are dispatched on the current path segment; any other method yields
      /// 405 with an Allow header. Any exception is logged and turned into a 500 response.
      /// </summary>
      /// <param name="request">The incoming request.</param>
      /// <returns>The completed response.</returns>
      public ValueTask<IResponse?> HandleAsync(IRequest request)
      {
          try
          {
              switch (request.Method.KnownMethod)
              {
                  case RequestMethod.GET:
                  case RequestMethod.HEAD:
                      var current = request.Target.Current?.Value;

                      // An optional leading "update" segment is skipped, so "/update/version" and "/version" are equivalent.
                      if (String.Equals(current, "update"))
                      {
                          request.Target.Advance();
                          current = request.Target.Current?.Value;
                      }

                      switch (current)
                      {
                          case "operations" or "supported_operations":
                              return Completed(GetSupportedOperations(request));
                          case "classes" or "supported_classes":
                              return Completed(GetSupportedClasses(request));
                          case "version" or "releasenotes" or "release_notes" or "install" or "install_desktop":
                              return Completed(GetUpdateFile(request));
                          case "info":
                              return Completed(GetServerInfo(request));
                      }

                      Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                          string.Format("GET/HEAD request to endpoint '{0}' is unresolved, because it does not exist", current));
                      return Completed(NotFound(request));

                  case RequestMethod.POST:
                      var target = request.Target.Current;
                      if (target is null)
                          return Completed(NotFound(request));

                      return target.Value switch
                      {
                          "validate" => Completed(Validate(request)),
                          "check_2fa" => Completed(Check2FA(request)),
                          "notify" => Completed(GetNotify(request)),
                          _ => HandleDatabaseRequest(request),
                      };
              }

              return Completed(request.Respond().Status(ResponseStatus.MethodNotAllowed).Header("Allow", "GET, POST, HEAD"));
          }
          catch (Exception e)
          {
              Logger.Send(LogType.Error, "", CoreUtils.FormatException(e));
              return Completed(request.Respond().Status(ResponseStatus.InternalServerError));
          }
      }

      #region Response helpers

      /// <summary>Builds the response and wraps it in an already-completed <see cref="ValueTask{TResult}"/>.</summary>
      private static ValueTask<IResponse?> Completed(IResponseBuilder builder) =>
          new ValueTask<IResponse?>(builder.Build());

      /// <summary>Starts a 404 Not Found response for the request.</summary>
      private static IResponseBuilder NotFound(IRequest request) =>
          request.Respond().Status(ResponseStatus.NotFound);

      /// <summary>Starts an application/json response whose body is the given serialized text.</summary>
      private static IResponseBuilder Json(IRequest request, string serialized) =>
          request.Respond()
              .Type(FlexibleContentType.Get(ContentType.ApplicationJson))
              .Content(new ResourceContent(Resource.FromString(serialized).Build()));

      #endregion

      /// <summary>
      /// Returns a JSON list of operation names; used for checking support of operations client side
      /// </summary>
      /// <param name="request">The incoming request.</param>
      /// <returns>A JSON response builder containing the serialized endpoint list.</returns>
      private IResponseBuilder GetSupportedOperations(IRequest request)
      {
          var serialized = Core.Serialization.Serialize(endpoints, true) ?? "";
          return Json(request, serialized);
      }

      /// <summary>
      /// Returns a JSON list of class names; used for checking support of operations client side
      /// </summary>
      /// <param name="request">The incoming request.</param>
      /// <returns>A JSON response builder containing the serialized entity-name list.</returns>
      private IResponseBuilder GetSupportedClasses(IRequest request)
      {
          var serialized = Serialization.Serialize(operations) ?? "";
          return Json(request, serialized);
      }

      /// <summary>
      /// Returns the Splash Logo and Color Scheme for this Database
      /// </summary>
      /// <param name="request">The incoming request.</param>
      /// <returns>A JSON response builder containing the serialized <see cref="InfoResponse"/>.</returns>
      private IResponseBuilder GetServerInfo(IRequest request)
      {
          InfoResponse response = RestService.Info(new InfoRequest());
          var serialized = Core.Serialization.Serialize(response, true) ?? "";
          return Json(request, serialized);
      }

      /// <summary>
      /// Gets port for web socket
      /// </summary>
      /// <param name="request">The incoming request; its body is deserialized as a <see cref="NotifyRequest"/>.</param>
      /// <returns>
      /// 404 when the session in the request credentials is not in the credentials cache;
      /// otherwise a JSON <see cref="NotifyResponse"/> carrying the web socket port.
      /// </returns>
      private IResponseBuilder GetNotify(IRequest request)
      {
          var requestObj = Deserialize<NotifyRequest>(request.Content, true);
          if (!CredentialsCache.SessionExists(requestObj.Credentials.Session))
          {
              return NotFound(request);
          }

          var response = new NotifyResponse
          {
              Status = StatusCode.OK,
              SocketPort = WebSocketPort
          };
          return Json(request, Serialization.Serialize(response));
      }

      #region Authentication

      /// <summary>Deserializes a <see cref="ValidateRequest"/> from the body and returns the service's reply as JSON.</summary>
      private IResponseBuilder Validate(IRequest request)
      {
          var requestObj = Deserialize<ValidateRequest>(request.Content, true);
          var response = RestService.Validate(requestObj);
          return Json(request, Serialization.Serialize(response));
      }

      /// <summary>Deserializes a <see cref="Check2FARequest"/> from the body and returns the service's reply as JSON.</summary>
      private IResponseBuilder Check2FA(IRequest request)
      {
          var requestObj = Deserialize<Check2FARequest>(request.Content, true);
          var response = RestService.Check2FA(requestObj);
          return Json(request, Serialization.Serialize(response));
      }

      #endregion

      #region Database

      /// <summary>Looks up a non-public static method of this class by name, throwing if it is absent.</summary>
      private static MethodInfo GetMethod(string name) =>
          typeof(RestHandler).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Static)
          ?? throw new Exception($"Invalid method '{name}'");

      // Endpoint prefix -> generic handler method. The prefix is stripped from the endpoint to obtain
      // the entity name, and the method is then closed over the resolved entity type.
      private static readonly List<Tuple<string, MethodInfo>> methodMap = new()
      {
          new("List", GetMethod(nameof(List))),
          new("Save", GetMethod(nameof(Save))),
          new("Delete", GetMethod(nameof(Delete))),
          new("MultiSave", GetMethod(nameof(MultiSave))),
          new("MultiDelete", GetMethod(nameof(MultiDelete)))
      };

      private static QueryResponse<T> List<T>(IRequest request) where T : Entity, new()
      {
          var requestObject = Deserialize<QueryRequest<T>>(request.Content, true);
          return RestService<T>.List(requestObject);
      }

      private static SaveResponse<T> Save<T>(IRequest request) where T : Entity, new()
      {
          var requestObject = Deserialize<SaveRequest<T>>(request.Content, true);
          return RestService<T>.Save(requestObject);
      }

      private static DeleteResponse<T> Delete<T>(IRequest request) where T : Entity, new()
      {
          var requestObject = Deserialize<DeleteRequest<T>>(request.Content, true);
          return RestService<T>.Delete(requestObject);
      }

      private static MultiSaveResponse<T> MultiSave<T>(IRequest request) where T : Entity, new()
      {
          var requestObject = Deserialize<MultiSaveRequest<T>>(request.Content, true);
          return RestService<T>.MultiSave(requestObject);
      }

      private static MultiDeleteResponse<T> MultiDelete<T>(IRequest request) where T : Entity, new()
      {
          var requestObject = Deserialize<MultiDeleteRequest<T>>(request.Content, true);
          return RestService<T>.MultiDelete(requestObject);
      }

      /// <summary>
      /// Reads the whole stream as text and deserializes it to <typeparamref name="T"/>.
      /// </summary>
      /// <param name="stream">The request body; must not be null.</param>
      /// <param name="strict">Passed through to <c>Serialization.Deserialize</c>.</param>
      /// <returns>The deserialized object.</returns>
      /// <exception cref="Exception">The stream is null, or deserialization produced null.</exception>
      private static T Deserialize<T>(Stream? stream, bool strict = false)
      {
          if (stream is null)
              throw new Exception("Stream is null");

          // This helper did not create the stream, so the reader is disposed without closing it.
          using var reader = new StreamReader(stream, leaveOpen: true);
          var str = reader.ReadToEnd();

          return Serialization.Deserialize<T>(str, strict)
                 ?? throw new Exception("Deserialization failed");
      }

      /// <summary>
      /// Handler for all database requests: "QueryMultiple", or one of the prefixes in
      /// <see cref="methodMap"/> followed by an entity name.
      /// </summary>
      /// <param name="request">The incoming POST request.</param>
      /// <returns>
      /// A JSON response with the serialized result, or 404 when the endpoint, entity or method
      /// cannot be resolved or the entity implements <see cref="ISecure"/>.
      /// </returns>
      private ValueTask<IResponse?> HandleDatabaseRequest(IRequest request)
      {
          var endpoint = request.Target.Current?.Value;
          if (endpoint is null)
              return Completed(NotFound(request));

          // Endpoint names are identifiers, so compare ordinally rather than by current culture.
          if (endpoint.StartsWith("QueryMultiple", StringComparison.Ordinal))
          {
              var requestObject = Deserialize<MultiQueryRequest>(request.Content, true);
              var result = RestService.QueryMultiple(requestObject, false);
              return Completed(Json(request, Serialization.Serialize(result)));
          }

          foreach (var (name, method) in methodMap)
          {
              if (endpoint.Length <= name.Length || !endpoint.StartsWith(name, StringComparison.Ordinal))
                  continue;

              var entityName = endpoint[name.Length..];
              var entityType = GetEntity(entityName);
              if (entityType == null)
              {
                  Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                      $"Request to endpoint '{endpoint}' unresolved, because '{entityName}' is not a valid entity");
                  return Completed(NotFound(request));
              }

              // Secure entities are answered with 404 rather than being served through this handler.
              if (entityType.IsAssignableTo(typeof(ISecure)))
              {
                  Logger.Send(LogType.Error, "", $"{entityType} is a secure entity. Request failed from IP {request.Client.IPAddress}");
                  return Completed(NotFound(request));
              }

              var resolvedMethod = method.MakeGenericMethod(entityType);
              var result = resolvedMethod.Invoke(null, new object[] { request }) as Response;
              return Completed(Json(request, Serialization.Serialize(result)));
          }

          Logger.Send(LogType.Error, request.Client.IPAddress.ToString(), $"Request to endpoint '{endpoint}' unresolved, because the method does not exist");
          return Completed(NotFound(request));
      }

      // Lazily-built lookup from type name to entity type; see GetEntity.
      private Dictionary<string, Type>? _persistentRemotable;

      /// <summary>
      /// Resolves an entity type by its simple type name among the <see cref="Entity"/> subclasses
      /// that implement both <see cref="IRemotable"/> and <see cref="IPersistent"/>.
      /// </summary>
      /// <param name="entityName">The simple type name taken from the endpoint.</param>
      /// <returns>The matching type, or null when there is none.</returns>
      private Type? GetEntity(string entityName)
      {
          // NOTE(review): ToDictionary throws if two matching types share the same Name — confirm names are unique.
          _persistentRemotable ??= CoreUtils.TypeList(
              e => e.IsSubclassOf(typeof(Entity)) &&
                   e.GetInterfaces().Contains(typeof(IRemotable)) &&
                   e.GetInterfaces().Contains(typeof(IPersistent))).ToDictionary(x => x.Name, x => x);

          return _persistentRemotable.GetValueOrDefault(entityName);
      }

      #endregion

      #region Installer

      /// <summary>
      /// Serves one of the update files (version, release notes, desktop installer) from the
      /// "update" folder under the common application data directory as an attachment.
      /// </summary>
      /// <param name="request">The incoming request; its current path segment selects the file.</param>
      /// <returns>The file as an attachment, or 404 when the segment is unknown or the file is missing.</returns>
      private IResponseBuilder GetUpdateFile(IRequest request)
      {
          string? relativePath = request.Target.Current?.Value switch
          {
              "version" => "update/version.txt",
              "releasenotes" or "release_notes" => "update/Release Notes.txt",
              "install" or "install_desktop" => "update/PRSDesktopSetup.exe",
              _ => null
          };

          if (relativePath is null)
              return NotFound(request);

          var filename = Path.Combine(CoreUtils.GetCommonAppData(), relativePath);
          if (!File.Exists(filename))
              return NotFound(request);

          // The filename is quoted because "Release Notes.txt" contains a space, which is not
          // valid in the unquoted (token) form of the Content-Disposition filename parameter.
          return request.Respond()
              .Header("Content-Disposition", $"attachment; filename=\"{Path.GetFileName(filename)}\"")
              .Content(new ResourceContent(
                  new FileResource(new FileInfo(filename), null, null)));
      }

      #endregion

      #region GenHTTP stuff

      /// <summary>The parent handler supplied at construction.</summary>
      public IHandler Parent { get; }

      /// <summary>No preparation is required; returns a completed task.</summary>
      public ValueTask PrepareAsync()
      {
          return new ValueTask();
      }

      /// <summary>This handler exposes no content elements.</summary>
      public IEnumerable<ContentElement> GetContent(IRequest request)
      {
          return Enumerable.Empty<ContentElement>();
      }

      #endregion
  }
  /// <summary>
  /// Builder for <see cref="RestHandler"/>. Collects concern builders and hands them, together
  /// with a factory for the handler itself, to <c>Concerns.Chain</c>.
  /// </summary>
  public class RestHandlerBuilder : IHandlerBuilder<RestHandlerBuilder>
  {
      // Web socket port forwarded to every RestHandler this builder creates.
      private readonly int? WebSocketPort;

      // Concerns registered through Add, in registration order.
      private readonly List<IConcernBuilder> _Concerns = new List<IConcernBuilder>();

      /// <param name="webSocketPort">The port to pass on to the created handler; may be null.</param>
      public RestHandlerBuilder(int? webSocketPort) => WebSocketPort = webSocketPort;

      /// <summary>Registers a concern and returns this builder for chaining.</summary>
      public RestHandlerBuilder Add(IConcernBuilder concern)
      {
          _Concerns.Add(concern);
          return this;
      }

      /// <summary>Builds the handler under the given parent, chained with the registered concerns.</summary>
      public IHandler Build(IHandler parent) =>
          Concerns.Chain(parent, _Concerns, handlerParent => new RestHandler(handlerParent, WebSocketPort));
  }
  /// <summary>
  /// <see cref="Notifier"/> implementation that delivers notifications through a
  /// <see cref="WebSocketServer"/> it creates and owns.
  /// </summary>
  class RestNotifier : Notifier
  {
      private readonly WebSocketServer SocketServer;

      /// <summary>The port reported by the underlying socket server.</summary>
      public int Port => SocketServer.Port;

      /// <summary>Creates the socket server on the given port and subscribes to its Poll event.</summary>
      public RestNotifier(int port)
      {
          SocketServer = new WebSocketServer(port);
          SocketServer.Poll += SocketServer_Poll;
      }

      // Forwards a socket-server poll to the notifier's Poll for that session.
      private void SocketServer_Poll(NotifyState.Session session) => Poll(session.SessionID);

      /// <summary>Starts the underlying socket server.</summary>
      public void Start() => SocketServer.Start();

      /// <summary>Stops the underlying socket server.</summary>
      public void Stop() => SocketServer.Stop();

      protected override void NotifyAll<TNotification>(TNotification notification) =>
          SocketServer.Push(notification);

      protected override void NotifySession(Guid session, Type TNotification, object? notification) =>
          SocketServer.Push(session, TNotification, notification);

      protected override void NotifySession<TNotification>(Guid session, TNotification notification) =>
          SocketServer.Push(session, notification);

      // User sessions come from the credentials cache, not from the socket server.
      protected override IEnumerable<Guid> GetUserSessions(Guid userID) =>
          CredentialsCache.GetUserSessions(userID);

      protected override IEnumerable<Guid> GetSessions(Platform platform) =>
          SocketServer.GetSessions(platform);
  }
  /// <summary>
  /// Static entry point that owns the server host, its optional certificate and the optional
  /// web socket notifier. Call <see cref="Init"/> before binding ports or starting.
  /// </summary>
  public static class RestListener
  {
      private static IServerHost? host;
      private static X509Certificate2? certificate;
      private static RestNotifier? notifier;

      /// <summary>The certificate last supplied through <c>InitCertificate</c>, or null.</summary>
      public static X509Certificate2? Certificate => certificate;

      /// <summary>Starts the host and then the notifier; each is skipped when not initialised.</summary>
      public static void Start()
      {
          host?.Start();
          notifier?.Start();
      }

      /// <summary>Stops the host and then the notifier; each is skipped when not initialised.</summary>
      public static void Stop()
      {
          host?.Stop();
          notifier?.Stop();
      }

      /// <summary>Remembers the certificate and, if a host exists, binds it to the port on any address with that certificate.</summary>
      public static void InitCertificate(ushort port, X509Certificate2 certificate)
      {
          RestListener.certificate = certificate;
          host?.Bind(IPAddress.Any, port, certificate);
      }

      /// <summary>Loads the certificate from the given file and delegates to the certificate overload.</summary>
      public static void InitCertificate(ushort port, string certificateFile) =>
          InitCertificate(port, new X509Certificate2(certificateFile));

      /// <summary>If a host exists, binds it to the port on any address without a certificate.</summary>
      public static void InitPort(ushort port) => host?.Bind(IPAddress.Any, port);

      /// <summary>
      /// Clears certificate and host information, and stops the listener.
      /// </summary>
      public static void Clear()
      {
          host?.Stop();
          host = null;

          notifier?.Stop();
          notifier = null;

          certificate = null;
      }

      /// <summary>
      /// Creates the host with a <see cref="RestHandlerBuilder"/> and default settings. When
      /// <paramref name="webSocketPort"/> is non-zero, a notifier is created on that port first
      /// and installed as <c>Notify.Notifier</c>.
      /// </summary>
      /// <param name="webSocketPort">The web socket port, or 0 to create no notifier.</param>
      public static void Init(int webSocketPort)
      {
          if (webSocketPort != 0)
          {
              var socketNotifier = new RestNotifier(webSocketPort);
              notifier = socketNotifier;
              Notify.Notifier = socketNotifier;
          }

          var server = Host.Create();
          host = server;

          // The handler gets whatever port the current notifier reports (null when there is none).
          server.Handler(new RestHandlerBuilder(notifier?.Port))
              .Defaults();
      }
  }
  439. }