RestListener.cs 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578
  1. using System.Net;
  2. using System.Reflection;
  3. using System.Security.Cryptography.X509Certificates;
  4. using GenHTTP.Api.Content;
  5. using GenHTTP.Api.Infrastructure;
  6. using GenHTTP.Api.Protocol;
  7. using GenHTTP.Engine;
  8. using GenHTTP.Modules.IO;
  9. using GenHTTP.Modules.IO.FileSystem;
  10. using GenHTTP.Modules.IO.Streaming;
  11. using GenHTTP.Modules.Practices;
  12. using InABox.Clients;
  13. using InABox.Core;
  14. using InABox.Database;
  15. using InABox.Remote.Shared;
  16. using InABox.Server.WebSocket;
  17. using InABox.WebSocket.Shared;
  18. using RequestMethod = GenHTTP.Api.Protocol.RequestMethod;
  19. namespace InABox.API
  20. {
  /// <summary>
  /// GenHTTP handler serving the REST API: capability discovery, server info and update-file downloads
  /// over GET/HEAD; authentication, notification and per-entity database operations over POST.
  /// </summary>
  public class RestHandler : IHandler
  {
      // Endpoint names returned by the "operations"/"supported_operations" GET endpoint.
      private readonly List<string> endpoints;

      // Entity names (with '.' replaced by '_') returned by the "classes"/"supported_classes" GET endpoint.
      private readonly List<string> operations;

      // Reported to clients by the "notify" POST endpoint; may be null.
      private readonly int? WebSocketPort;

      /// <summary>
      /// Creates the handler and precomputes the advertised endpoint and entity lists.
      /// </summary>
      /// <param name="parent">The parent handler in the GenHTTP handler chain.</param>
      /// <param name="webSocketPort">Port returned by the "notify" endpoint; may be null.</param>
      public RestHandler(IHandler parent, int? webSocketPort)
      {
          WebSocketPort = webSocketPort;
          Parent = parent;
          endpoints = new();
          operations = new();

          // Only remotable entities that the database factory reports as supported are advertised.
          var types = CoreUtils.TypeList(
              x => x.IsSubclassOf(typeof(Entity))
                  && x.GetInterfaces().Contains(typeof(IRemotable))
          );
          var dbTypes = DbFactory.SupportedTypes();
          foreach (var t in types)
          {
              var entityName = t.EntityName().Replace(".", "_");
              if (!dbTypes.Contains(entityName))
                  continue;

              operations.Add(entityName);
              // NOTE(review): "Load{Name}" is advertised here, but methodMap has no "Load" entry, so
              // HandleDatabaseRequest answers such requests with 404 - confirm this is intended.
              endpoints.Add($"List{t.Name}");
              endpoints.Add($"Load{t.Name}");
              endpoints.Add($"Save{t.Name}");
              endpoints.Add($"MultiSave{t.Name}");
              endpoints.Add($"Delete{t.Name}");
              endpoints.Add($"MultiDelete{t.Name}");
          }
          endpoints.Add("QueryMultiple");
      }

      /// <summary>
      /// Reads the "serializationVersion" and "format" query parameters of <paramref name="request"/>.
      /// </summary>
      /// <remarks>
      /// Binary settings default to <c>BinarySerializationSettings.V1_0</c>; the request format keeps its
      /// default when "format" is absent or does not parse as a <c>SerializationFormat</c>.
      /// </remarks>
      private RequestData GetRequestData(IRequest request)
      {
          BinarySerializationSettings settings = BinarySerializationSettings.V1_0;
          if (request.Query.TryGetValue("serializationVersion", out var versionString))
          {
              settings = BinarySerializationSettings.ConvertVersionString(versionString);
          }

          var data = new RequestData(settings);
          if (request.Query.TryGetValue("format", out var formatString) && Enum.TryParse<SerializationFormat>(formatString, out var format))
          {
              data.RequestFormat = format;
          }
          return data;
      }

      /// <summary>
      /// The main handler for the server; an HTTP request comes in, an HTTP response goes out.
      /// </summary>
      /// <param name="request">The incoming request.</param>
      /// <returns>
      /// The response: 404 for unknown targets, 405 for methods other than GET/HEAD/POST, and 500
      /// (after logging) when any exception escapes request processing.
      /// </returns>
      public ValueTask<IResponse?> HandleAsync(IRequest request)
      {
          try
          {
              switch (request.Method.KnownMethod)
              {
                  case RequestMethod.GET:
                  case RequestMethod.HEAD:
                      {
                          // An optional leading "update" path segment is skipped.
                          var current = request.Target.Current?.Value;
                          if (string.Equals(current, "update"))
                          {
                              request.Target.Advance();
                              current = request.Target.Current?.Value;
                          }

                          switch (current)
                          {
                              case "operations" or "supported_operations":
                                  return Completed(GetSupportedOperations(request));
                              case "classes" or "supported_classes":
                                  return Completed(GetSupportedClasses(request));
                              case "version" or "releasenotes" or "release_notes" or "install" or "install_desktop":
                                  return Completed(GetUpdateFile(request));
                              case "info":
                                  return Completed(GetServerInfo(request));
                          }

                          Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                              string.Format("GET/HEAD request to endpoint '{0}' is unresolved, because it does not exist", current));
                          return Completed(request.Respond().Status(ResponseStatus.NotFound));
                      }
                  case RequestMethod.POST:
                      {
                          var target = request.Target.Current;
                          if (target is null)
                              return Completed(request.Respond().Status(ResponseStatus.NotFound));

                          // Parsed once and shared with the database path, which previously re-parsed it.
                          var data = GetRequestData(request);
                          return target.Value switch
                          {
                              "validate" => Completed(Validate(request, data)),
                              "check_2fa" => Completed(Check2FA(request, data)),
                              "notify" => Completed(GetNotify(request, data)),
                              _ => HandleDatabaseRequest(request, data),
                          };
                      }
              }

              return Completed(request.Respond()
                  .Status(ResponseStatus.MethodNotAllowed)
                  .Header("Allow", "GET, POST, HEAD"));
          }
          catch (Exception e)
          {
              Logger.Send(LogType.Error, "", CoreUtils.FormatException(e));
              return Completed(request.Respond().Status(ResponseStatus.InternalServerError));
          }
      }

      /// <summary>Builds <paramref name="builder"/> and wraps the result in a completed ValueTask.</summary>
      private static ValueTask<IResponse?> Completed(IResponseBuilder builder) =>
          new ValueTask<IResponse?>(builder.Build());

      /// <summary>
      /// Builds an application/json response whose body is <paramref name="serialized"/>, or an
      /// empty body when it is null.
      /// </summary>
      private static IResponseBuilder JsonResponse(IRequest request, string? serialized) =>
          request.Respond()
              .Type(new FlexibleContentType(ContentType.ApplicationJson))
              .Content(new ResourceContent(Resource.FromString(serialized ?? "").Build()));

      /// <summary>
      /// Returns a JSON list of endpoint names; used for checking support of operations client side.
      /// </summary>
      private IResponseBuilder GetSupportedOperations(IRequest request) =>
          JsonResponse(request, Core.Serialization.Serialize(endpoints, true));

      /// <summary>
      /// Returns a JSON list of supported entity names; used for checking support of classes client side.
      /// </summary>
      private IResponseBuilder GetSupportedClasses(IRequest request) =>
          JsonResponse(request, Serialization.Serialize(operations));

      /// <summary>
      /// Returns the result of <c>RestService.Info</c> (splash logo / colour scheme information) as JSON.
      /// </summary>
      private IResponseBuilder GetServerInfo(IRequest request)
      {
          InfoResponse response = RestService.Info(new InfoRequest());
          return JsonResponse(request, Core.Serialization.Serialize(response, true));
      }

      /// <summary>
      /// Returns the web-socket port to a client whose session exists in the credentials cache.
      /// </summary>
      /// <returns>A JSON <c>NotifyResponse</c>, or 404 when the session is unknown.</returns>
      private IResponseBuilder GetNotify(IRequest request, RequestData data)
      {
          var requestObj = Deserialize<NotifyRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          if (!CredentialsCache.SessionExists(requestObj.Credentials.Session))
          {
              return request.Respond().Status(ResponseStatus.NotFound);
          }

          var response = new NotifyResponse
          {
              Status = StatusCode.OK,
              SocketPort = WebSocketPort
          };
          return JsonResponse(request, Serialization.Serialize(response));
      }

      #region Authentication

      /// <summary>Deserializes a <c>ValidateRequest</c>, passes it to <c>RestService.Validate</c> and returns the result as JSON.</summary>
      private IResponseBuilder Validate(IRequest request, RequestData data)
      {
          var requestObj = Deserialize<ValidateRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          var response = RestService.Validate(requestObj);
          return JsonResponse(request, Serialization.Serialize(response));
      }

      /// <summary>Deserializes a <c>Check2FARequest</c>, passes it to <c>RestService.Check2FA</c> and returns the result as JSON.</summary>
      private IResponseBuilder Check2FA(IRequest request, RequestData data)
      {
          var requestObj = Deserialize<Check2FARequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          var response = RestService.Check2FA(requestObj);
          return JsonResponse(request, Serialization.Serialize(response));
      }

      #endregion

      #region Database

      /// <summary>Looks up a private static method of this class by name; throws if it does not exist.</summary>
      private static MethodInfo GetMethod(string name) =>
          typeof(RestHandler).GetMethod(name, BindingFlags.NonPublic | BindingFlags.Static)
              ?? throw new Exception($"Invalid method '{name}'");

      // Endpoint prefix -> generic handler method; checked in order by HandleDatabaseRequest.
      private static readonly List<Tuple<string, MethodInfo>> methodMap = new()
      {
          new("List", GetMethod(nameof(List))),
          new("Save", GetMethod(nameof(Save))),
          new("Delete", GetMethod(nameof(Delete))),
          new("MultiSave", GetMethod(nameof(MultiSave))),
          new("MultiDelete", GetMethod(nameof(MultiDelete)))
      };

      /// <summary>Per-request deserialization options taken from the query string.</summary>
      private class RequestData
      {
          public SerializationFormat RequestFormat { get; set; }

          public BinarySerializationSettings BinarySerializationSettings { get; set; }

          public RequestData(BinarySerializationSettings binarySerializationSettings)
          {
              BinarySerializationSettings = binarySerializationSettings;
          }
      }

      private static QueryResponse<T> List<T>(IRequest request, RequestData data) where T : Entity, new()
      {
          var requestObject = Deserialize<QueryRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService<T>.List(requestObject);
      }

      private static SaveResponse<T> Save<T>(IRequest request, RequestData data) where T : Entity, new()
      {
          var requestObject = Deserialize<SaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService<T>.Save(requestObject);
      }

      private static DeleteResponse<T> Delete<T>(IRequest request, RequestData data) where T : Entity, new()
      {
          var requestObject = Deserialize<DeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService<T>.Delete(requestObject);
      }

      private static MultiSaveResponse<T> MultiSave<T>(IRequest request, RequestData data) where T : Entity, new()
      {
          var requestObject = Deserialize<MultiSaveRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService<T>.MultiSave(requestObject);
      }

      private static MultiDeleteResponse<T> MultiDelete<T>(IRequest request, RequestData data) where T : Entity, new()
      {
          var requestObject = Deserialize<MultiDeleteRequest<T>>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService<T>.MultiDelete(requestObject);
      }

      private static MultiQueryResponse QueryMultiple(IRequest request, RequestData data)
      {
          var requestObject = Deserialize<MultiQueryRequest>(request.Content, data.RequestFormat, data.BinarySerializationSettings, true);
          return RestService.QueryMultiple(requestObject, false);
      }

      /// <summary>
      /// Deserializes the request body: binary when <paramref name="requestFormat"/> is Binary and
      /// <typeparamref name="T"/> implements ISerializeBinary, otherwise as text via Serialization.Deserialize.
      /// </summary>
      /// <exception cref="Exception">The stream is null, or text deserialization returned null.</exception>
      private static T Deserialize<T>(Stream? stream, SerializationFormat requestFormat, BinarySerializationSettings binarySettings, bool strict = false)
      {
          if (stream is null)
              throw new Exception("Stream is null");

          if (requestFormat == SerializationFormat.Binary && typeof(T).IsAssignableTo(typeof(ISerializeBinary)))
          {
              return (T)Serialization.ReadBinary(typeof(T), stream, binarySettings);
          }

          var str = new StreamReader(stream).ReadToEnd();
          return Serialization.Deserialize<T>(str, strict)
              ?? throw new Exception("Deserialization failed");
      }

      /// <summary>
      /// Serializes <paramref name="result"/> as an octet-stream when binary was requested and the result
      /// supports it; otherwise as JSON.
      /// </summary>
      private IResponseBuilder SerializeResponse(IRequest request, SerializationFormat responseFormat, BinarySerializationSettings binarySettings, Response? result)
      {
          if (responseFormat == SerializationFormat.Binary && result is ISerializeBinary binary)
          {
              var stream = new MemoryStream();
              binary.SerializeBinary(new CoreBinaryWriter(stream, binarySettings));
              // Rewind so the body is sent from the start rather than relying on the framework to seek.
              stream.Position = 0;
              // NOTE(review): the "checksum" is the stream object's hash code, not a content hash.
              return request.Respond()
                  .Type(new FlexibleContentType(ContentType.ApplicationOctetStream))
                  .Content(stream, (ulong?)stream.Length, () => new ValueTask<ulong?>((ulong)stream.GetHashCode()));
          }

          return JsonResponse(request, Serialization.Serialize(result));
      }

      /// <summary>
      /// Handler for all database requests: "QueryMultiple..." or "{List|Save|Delete|MultiSave|MultiDelete}{EntityName}".
      /// </summary>
      /// <param name="request">The incoming POST request.</param>
      /// <param name="requestData">Deserialization options already parsed from the query string.</param>
      /// <returns>The serialized result, or 404 for unknown methods, unknown entities or ISecure entities.</returns>
      private ValueTask<IResponse?> HandleDatabaseRequest(IRequest request, RequestData requestData)
      {
          var responseFormat = SerializationFormat.Json;
          if (request.Query.TryGetValue("responseFormat", out var formatString) && Enum.TryParse<SerializationFormat>(formatString, out var format))
          {
              responseFormat = format;
          }

          var endpoint = request.Target.Current?.Value ?? "";
          if (endpoint.StartsWith("QueryMultiple", StringComparison.Ordinal))
          {
              var result = QueryMultiple(request, requestData);
              return Completed(SerializeResponse(request, responseFormat, requestData.BinarySerializationSettings, result));
          }

          foreach (var (name, method) in methodMap)
          {
              // Ordinal match: the entity name is sliced off using name.Length, which is only valid
              // for an exact character-by-character prefix (culture-aware matching may skip characters).
              if (endpoint.Length <= name.Length || !endpoint.StartsWith(name, StringComparison.Ordinal))
                  continue;

              var entityName = endpoint[name.Length..];
              var entityType = GetEntity(entityName);
              if (entityType is null)
              {
                  Logger.Send(LogType.Error, request.Client.IPAddress.ToString(),
                      $"Request to endpoint '{endpoint}' unresolved, because '{entityName}' is not a valid entity");
                  return Completed(request.Respond().Status(ResponseStatus.NotFound));
              }

              if (entityType.IsAssignableTo(typeof(ISecure)))
              {
                  Logger.Send(LogType.Error, "", $"{entityType} is a secure entity. Request failed from IP {request.Client.IPAddress}");
                  return Completed(request.Respond().Status(ResponseStatus.NotFound));
              }

              var resolvedMethod = method.MakeGenericMethod(entityType);
              // DoNotWrapExceptions: let HandleAsync log the handler's own exception, not a TargetInvocationException wrapper.
              var invoked = resolvedMethod.Invoke(null, BindingFlags.DoNotWrapExceptions, null, new object[] { request, requestData }, null) as Response;
              return Completed(SerializeResponse(request, responseFormat, requestData.BinarySerializationSettings, invoked));
          }

          Logger.Send(LogType.Error, request.Client.IPAddress.ToString(), $"Request to endpoint '{endpoint}' unresolved, because the method does not exist");
          return Completed(request.Respond().Status(ResponseStatus.NotFound));
      }

      // Lazily built map of short type name -> remotable, persistent entity type.
      private Dictionary<string, Type>? _persistentRemotable;

      /// <summary>
      /// Resolves an entity type by its short <c>Type.Name</c>; returns null when no remotable, persistent entity matches.
      /// </summary>
      private Type? GetEntity(string entityName)
      {
          _persistentRemotable ??= CoreUtils.TypeList(
              e => e.IsSubclassOf(typeof(Entity)) &&
                  e.GetInterfaces().Contains(typeof(IRemotable)) &&
                  e.GetInterfaces().Contains(typeof(IPersistent))).ToDictionary(x => x.Name, x => x);
          return _persistentRemotable.GetValueOrDefault(entityName);
      }

      #endregion

      #region Installer

      /// <summary>
      /// Serves a file from the "update" folder under the common app-data directory as an attachment.
      /// </summary>
      /// <returns>The file, or 404 when the target is not a known update file or the file does not exist.</returns>
      private IResponseBuilder GetUpdateFile(IRequest request)
      {
          string? relativePath = request.Target.Current?.Value switch
          {
              "version" => "update/version.txt",
              "releasenotes" or "release_notes" => "update/Release Notes.txt",
              "install" or "install_desktop" => "update/PRSDesktopSetup.exe",
              _ => null
          };
          if (relativePath is null)
              return request.Respond().Status(ResponseStatus.NotFound);

          var filename = Path.Combine(CoreUtils.GetCommonAppData(), relativePath);
          if (!File.Exists(filename))
              return request.Respond().Status(ResponseStatus.NotFound);

          // Quoted: "Release Notes.txt" contains a space, which an unquoted
          // Content-Disposition parameter value cannot carry (RFC 6266).
          return request.Respond()
              .Header("Content-Disposition", $"attachment; filename=\"{Path.GetFileName(filename)}\"")
              .Content(new ResourceContent(
                  new FileResource(new FileInfo(filename), null, null)));
      }

      #endregion

      #region GenHTTP stuff

      public IHandler Parent { get; }

      public ValueTask PrepareAsync()
      {
          return new ValueTask();
      }

      public IEnumerable<ContentElement> GetContent(IRequest request)
      {
          return Enumerable.Empty<ContentElement>();
      }

      #endregion
  }
  /// <summary>
  /// Builds a <see cref="RestHandler"/> wrapped in the concerns registered through <see cref="Add"/>.
  /// </summary>
  public class RestHandlerBuilder : IHandlerBuilder<RestHandlerBuilder>
  {
      // Concerns handed to Concerns.Chain when the handler is built.
      private readonly List<IConcernBuilder> _Concerns = new();

      // Forwarded to every RestHandler this builder creates; may be null. Set only in the constructor.
      private readonly int? WebSocketPort;

      /// <summary>Creates a builder whose handlers report <paramref name="webSocketPort"/>.</summary>
      /// <param name="webSocketPort">Port passed to each created <see cref="RestHandler"/>; may be null.</param>
      public RestHandlerBuilder(int? webSocketPort)
      {
          WebSocketPort = webSocketPort;
      }

      /// <summary>Registers a concern to wrap around the handler.</summary>
      /// <param name="concern">The concern builder to add.</param>
      /// <returns>This builder, for chaining.</returns>
      public RestHandlerBuilder Add(IConcernBuilder concern)
      {
          _Concerns.Add(concern);
          return this;
      }

      /// <summary>Creates the concern chain around a new <see cref="RestHandler"/> under <paramref name="parent"/>.</summary>
      public IHandler Build(IHandler parent) =>
          Concerns.Chain(parent, _Concerns, p => new RestHandler(p, WebSocketPort));
  }
  /// <summary>
  /// <see cref="Notifier"/> that pushes notifications through a <see cref="WebSocketServer"/> and
  /// forwards the server's poll events to <c>Poll</c> with the polling session's ID.
  /// </summary>
  internal class RestNotifier : Notifier
  {
      // Assigned only in the constructor; readonly enforces that.
      private readonly WebSocketServer SocketServer;

      /// <summary>The port reported by the underlying web-socket server.</summary>
      public int Port => SocketServer.Port;

      /// <summary>Creates the web-socket server on <paramref name="port"/> and subscribes to its poll event.</summary>
      /// <param name="port">Port for the web-socket server.</param>
      public RestNotifier(int port)
      {
          SocketServer = new WebSocketServer(port);
          SocketServer.Poll += SocketServer_Poll;
      }

      private void SocketServer_Poll(NotifyState.Session session)
      {
          Poll(session.SessionID);
      }

      /// <summary>Starts the underlying web-socket server.</summary>
      public void Start()
      {
          SocketServer.Start();
      }

      /// <summary>Stops the underlying web-socket server.</summary>
      public void Stop()
      {
          SocketServer.Stop();
      }

      protected override void NotifyAll<TNotification>(TNotification notification)
      {
          SocketServer.Push(notification);
      }

      protected override void NotifySession(Guid session, Type TNotification, object? notification)
      {
          SocketServer.Push(session, TNotification, notification);
      }

      protected override void NotifySession<TNotification>(Guid session, TNotification notification)
      {
          SocketServer.Push(session, notification);
      }

      // User sessions come from the credentials cache rather than the socket server.
      protected override IEnumerable<Guid> GetUserSessions(Guid userID)
      {
          return CredentialsCache.GetUserSessions(userID);
      }

      protected override IEnumerable<Guid> GetSessions(Platform platform)
      {
          return SocketServer.GetSessions(platform);
      }
  }
  /// <summary>
  /// Static owner of the HTTP server host and the optional web-socket notifier.
  /// Call <see cref="Init"/> to create them, then bind a port, then <see cref="Start"/>.
  /// </summary>
  public static class RestListener
  {
      private static IServerHost? host;
      private static X509Certificate2? certificate;
      private static RestNotifier? notifier;

      /// <summary>The certificate most recently passed to InitCertificate, or null after Clear.</summary>
      public static X509Certificate2? Certificate => certificate;

      /// <summary>Starts the host and then the notifier, skipping whichever has not been created.</summary>
      public static void Start()
      {
          if (host is { } h) h.Start();
          if (notifier is { } n) n.Start();
      }

      /// <summary>Stops the host and then the notifier, skipping whichever has not been created.</summary>
      public static void Stop()
      {
          if (host is { } h) h.Stop();
          if (notifier is { } n) n.Stop();
      }

      /// <summary>Records the certificate, marks the service as HTTPS and binds the host to <paramref name="port"/> on all interfaces.</summary>
      /// <remarks>NOTE(review): if Init has not run yet, only the certificate and HTTPS flag are set; no binding happens.</remarks>
      public static void InitCertificate(ushort port, X509Certificate2 certificate)
      {
          RestListener.certificate = certificate;
          RestService.IsHTTPS = true;
          if (host is { } h) h.Bind(IPAddress.Any, port, certificate);
      }

      /// <summary>Loads the certificate from <paramref name="certificateFile"/> and delegates to the X509Certificate2 overload.</summary>
      public static void InitCertificate(ushort port, string certificateFile) =>
          InitCertificate(port, new X509Certificate2(certificateFile));

      /// <summary>Marks the service as plain HTTP and binds the host to <paramref name="port"/> on all interfaces.</summary>
      public static void InitPort(ushort port)
      {
          RestService.IsHTTPS = false;
          if (host is { } h) h.Bind(IPAddress.Any, port);
      }

      /// <summary>
      /// Clears certificate and host information, and stops the listener.
      /// </summary>
      public static void Clear()
      {
          if (host is { } h) h.Stop();
          host = null;
          if (notifier is { } n) n.Stop();
          notifier = null;
          certificate = null;
      }

      /// <summary>
      /// Creates the server host (and, for a non-zero <paramref name="webSocketPort"/>, a web-socket notifier
      /// installed as <c>Notify.Notifier</c>) with the REST handler, practice defaults and a backlog of 1024.
      /// </summary>
      /// <remarks>NOTE(review): with a port of 0 any notifier from an earlier Init is kept; confirm Clear is always called first.</remarks>
      public static void Init(int webSocketPort)
      {
          if (webSocketPort is not 0)
          {
              notifier = new RestNotifier(webSocketPort);
              Notify.Notifier = notifier;
          }

          var created = Host.Create();
          host = created;
          created.Handler(new RestHandlerBuilder(notifier?.Port))
              .Defaults()
              .Backlog(1024);
      }
  }
  491. }