DbFactory.cs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. using System.Reflection;
  2. using FluentResults;
  3. using InABox.Clients;
  4. using InABox.Configuration;
  5. using InABox.Core;
  6. using InABox.Scripting;
  7. namespace InABox.Database;
  /// <summary>
  /// Metadata describing this database instance, stored by <c>DbFactory</c> as a
  /// global-settings entry whose section is the name of this class.
  /// </summary>
  public class DatabaseMetadata : BaseObject, IGlobalConfigurationSettings
  {
      /// <summary>
      /// Identifier of the database. Each newly constructed instance starts with a fresh random GUID.
      /// </summary>
      public Guid DatabaseID
      {
          get;
          set;
      } = Guid.NewGuid();
  }
  /// <summary>
  /// Signals that the database refused an operation because it is locked (read-only).
  /// The parameterless constructor uses the default license-expiry message; the other
  /// constructors follow the standard .NET exception pattern (CA1032).
  /// </summary>
  public class DbLockedException : Exception
  {
      private const string DefaultMessage = "Database is read-only due to PRS license expiry.";

      /// <summary>Creates the exception with the default license-expiry message.</summary>
      public DbLockedException() : base(DefaultMessage) { }

      /// <summary>Creates the exception with a custom message.</summary>
      /// <param name="message">Description of why the database is locked.</param>
      public DbLockedException(string message) : base(message) { }

      /// <summary>Creates the exception with a custom message and the exception that caused it.</summary>
      /// <param name="message">Description of why the database is locked.</param>
      /// <param name="innerException">The underlying cause.</param>
      public DbLockedException(string message, Exception innerException) : base(message, innerException) { }
  }
  /// <summary>
  /// Static entry point to the server-side database layer. It does four things:
  /// <list type="bullet">
  /// <item>holds the active <see cref="IProvider"/>;</item>
  /// <item>creates or upgrades the schema in <see cref="Start"/>;</item>
  /// <item>periodically checks the license and switches the database to read-only
  /// (<see cref="IsReadOnly"/>) while it is invalid;</item>
  /// <item>resolves <see cref="IStore"/> instances for entity types.</item>
  /// </list>
  /// </summary>
  public static class DbFactory
  {
      /// <summary>
      /// Successfully compiled server-side scripts, keyed by "{Section} {ScriptType}".
      /// Cleared and rebuilt by <see cref="LoadScripts"/>.
      /// </summary>
      public static Dictionary<string, ScriptDocument> LoadedScripts = new();

      private static DatabaseMetadata MetaData { get; set; } = new();

      /// <summary>
      /// Identifier of this database, taken from the persisted <see cref="DatabaseMetadata"/>.
      /// Assigning it immediately saves the metadata through <see cref="Provider"/>.
      /// </summary>
      public static Guid ID
      {
          get => MetaData.DatabaseID;
          set
          {
              MetaData.DatabaseID = value;
              SaveMetadata();
          }
      }

      private static IProvider? _provider;

      /// <summary>The active database provider.</summary>
      /// <exception cref="InvalidOperationException">Thrown by the getter when no provider has been assigned.</exception>
      public static IProvider Provider
      {
          get => _provider ?? throw new InvalidOperationException("Provider is not set");
          set => _provider = value;
      }

      /// <summary>True once <see cref="Provider"/> has been assigned.</summary>
      public static bool IsProviderSet => _provider is not null;

      public static string? ColorScheme { get; set; }

      public static byte[]? Logo { get; set; }

      // See notes in the Request.DatabaseInfo class.
      // Once the RPC transport is stable, these settings need to be removed.
      public static int RestPort { get; set; }
      public static int RPCPort { get; set; }

      /// <summary>All entity types known to <see cref="CoreUtils"/> that implement <see cref="IPersistent"/>.</summary>
      public static IEnumerable<Type> Entities =>
          CoreUtils.Entities.Where(x => x.GetInterfaces().Contains(typeof(IPersistent)));

      /// <summary>
      /// Registered custom store types. The setter keeps only concrete, non-generic classes
      /// and throws if any of those does not implement <see cref="IStore"/>.
      /// </summary>
      public static Type[] Stores
      {
          get => stores;
          set => SetStoreTypes(value);
      }

      // NOTE(review): not assigned anywhere in this class; presumably set by the host — confirm.
      public static DateTime Expiry { get; set; }

      /// <summary>
      /// Brings the database online, in this order:
      /// <list type="number">
      /// <item>Validates the stored schema against the entity model and creates or upgrades it if needed.</item>
      /// <item>Starts the provider and loads (or creates) the database metadata.</item>
      /// <item>Runs <c>DataUpdater.MigrateDatabase</c>.</item>
      /// <item>Loads custom properties into <c>DatabaseSchema</c>.</item>
      /// <item>Checks the license and starts the periodic license check.</item>
      /// <item>Initialises the registered stores and compiles the cached scripts.</item>
      /// </list>
      /// </summary>
      /// <exception cref="Exception">
      /// Schema creation or upgrade failed (the original error is kept as <see cref="Exception.InnerException"/>),
      /// or the data migration reported failure.
      /// </exception>
      public static void Start()
      {
          CoreUtils.CheckLicensing();

          var status = ValidateSchema();
          if (status == SchemaStatus.New)
          {
              try
              {
                  Provider.CreateSchema(ConsolidatedObjectModel());
                  SaveSchema();
              }
              catch (Exception err)
              {
                  // Keep the original exception so its type and stack trace are not lost.
                  throw new Exception($"Unable to Create Schema\n\n{err.Message}", err);
              }
          }
          else if (status == SchemaStatus.Changed)
          {
              try
              {
                  Provider.UpgradeSchema(ConsolidatedObjectModel());
                  SaveSchema();
              }
              catch (Exception err)
              {
                  throw new Exception($"Unable to Update Schema\n\n{err.Message}", err);
              }
          }

          // Start the provider. Unsubscribing first keeps a repeated Start() from logging every message twice.
          Provider.Types = ConsolidatedObjectModel();
          Provider.OnLog -= LogMessage;
          Provider.OnLog += LogMessage;
          Provider.Start();

          CheckMetadata();

          if (!DataUpdater.MigrateDatabase())
              throw new Exception("Database migration failed. Aborting startup");

          // Load the custom properties. Clients can't be used here (we are already inside the
          // database layer), so the provider is queried directly.
          var props = Provider.Query<CustomProperty>().Rows.Select(x => x.ToObject<CustomProperty>()).ToArray();
          DatabaseSchema.Load(props);

          AssertLicense();
          BeginLicenseCheckTimer();
          InitStores();
          LoadScripts();
      }

      #region MetaData

      /// <summary>Persists <see cref="MetaData"/> as a GlobalSettings row (section = "DatabaseMetadata", empty key).</summary>
      private static void SaveMetadata()
      {
          var settings = new GlobalSettings
          {
              Section = nameof(DatabaseMetadata),
              Key = "",
              Contents = Serialization.Serialize(MetaData)
          };
          Provider.Save(settings);
      }

      /// <summary>
      /// Loads the stored <see cref="DatabaseMetadata"/>. If none is stored, or it cannot be
      /// deserialized, a new instance (with a fresh ID) is created and saved.
      /// </summary>
      private static void CheckMetadata()
      {
          var stored = Provider.Query(new Filter<GlobalSettings>(x => x.Section).IsEqualTo(nameof(DatabaseMetadata)))
              .Rows.FirstOrDefault()?.ToObject<GlobalSettings>();
          var data = stored is not null ? Serialization.Deserialize<DatabaseMetadata>(stored.Contents) : null;

          if (data is null)
          {
              MetaData = new DatabaseMetadata();
              SaveMetadata();
          }
          else
          {
              MetaData = data;
          }
      }

      #endregion

      #region License

      private enum LicenseValidation
      {
          Valid,
          Missing,
          Expired,
          Corrupt,
          Tampered
      }

      /// <summary>
      /// Inspects the first stored <see cref="License"/> and classifies it.
      /// </summary>
      /// <param name="expiry">
      /// The license expiry once the license has decrypted and passed the tamper checks;
      /// otherwise <see cref="DateTime.MinValue"/>.
      /// </param>
      private static LicenseValidation CheckLicenseValidity(out DateTime expiry)
      {
          expiry = DateTime.MinValue;

          var license = Provider.Load<License>().FirstOrDefault();
          if (license is null)
              return LicenseValidation.Missing;

          if (!LicenseUtils.TryDecryptLicense(license.Data, out var licenseData, out _))
              return LicenseValidation.Corrupt;

          if (!LicenseUtils.ValidateMacAddresses(licenseData.Addresses))
              return LicenseValidation.Tampered;

          // Every user-tracking record named in the license must still exist in the database.
          var existingTrackingIDs = Provider.Query(
                  new Filter<UserTracking>(x => x.ID).InList(licenseData.UserTrackingItems),
                  new Columns<UserTracking>(x => x.ID),
                  log: false)
              .Rows
              .Select(r => r.Get<UserTracking, Guid>(c => c.ID))
              .ToHashSet();
          if (licenseData.UserTrackingItems.Any(id => !existingTrackingIDs.Contains(id)))
              return LicenseValidation.Tampered;

          expiry = licenseData.Expiry;
          return licenseData.Expiry < DateTime.Now
              ? LicenseValidation.Expired
              : LicenseValidation.Valid;
      }

      // Number of license checks since the last throttled expiry reminder.
      private static int _expiredLicenseCounter;

      private static readonly TimeSpan LicenseCheckInterval = TimeSpan.FromMinutes(10);

      // Written from the license timer's thread and read by other threads, hence volatile.
      private static volatile bool _readOnly;

      /// <summary>True while the license is missing or invalid; the database should then refuse saves.</summary>
      public static bool IsReadOnly => _readOnly;

      // Must be declared after LicenseCheckInterval: static initializers run in textual order.
      private static readonly System.Timers.Timer LicenseTimer =
          new System.Timers.Timer(LicenseCheckInterval.TotalMilliseconds) { AutoReset = true };

      private static void LogRenew(string message)
      {
          LogImportant($"{message} Please renew your license before then, or your database will go into read-only mode; it will be locked for saving anything until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
      }

      /// <summary>Logs that the database is currently read-only.</summary>
      public static void LogReadOnly()
      {
          LogImportant("Your database is in read-only mode; please renew your license to enable database updates.");
      }

      /// <summary>
      /// Logs an expiry reminder for a valid license. It logs:
      /// <list type="bullet">
      /// <item>on every check when expiry is today or less than a day away;</item>
      /// <item>roughly hourly when expiry is less than three days away;</item>
      /// <item>roughly every two hours when expiry is less than a week away.</item>
      /// </list>
      /// The hourly and two-hourly rates are throttled by counting checks of <see cref="LicenseCheckInterval"/>.
      /// </summary>
      private static void LogLicenseExpiry(DateTime expiry)
      {
          if (expiry.Date == DateTime.Today)
          {
              // Format the DateTime itself: "HH:mm" is not a valid TimeSpan format and
              // expiry.TimeOfDay:HH:mm threw a FormatException on the expiry day.
              LogRenew($"Your database license is expiring today at {expiry:HH:mm}!");
              return;
          }

          var diffInDays = (expiry - DateTime.Now).TotalDays;
          if (diffInDays < 1)
          {
              LogRenew($"Your database license will expire in less than a day, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
          }
          else if (diffInDays < 3 && (_expiredLicenseCounter * LicenseCheckInterval).TotalHours >= 1)
          {
              LogRenew($"Your database license will expire in less than three days, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
              _expiredLicenseCounter = 0;
          }
          else if (diffInDays < 7 && (_expiredLicenseCounter * LicenseCheckInterval).TotalHours >= 2)
          {
              LogRenew($"Your database license will expire in less than a week, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
              _expiredLicenseCounter = 0;
          }
          ++_expiredLicenseCounter;
      }

      private static void BeginReadOnly()
      {
          if (IsReadOnly)
              return;

          LogImportant(
              "Your database is now in read-only mode, since your license is invalid; you will be unable to save any records to the database until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
          _readOnly = true;
      }

      private static void EndReadOnly()
      {
          if (!IsReadOnly)
              return;

          LogImportant("Valid license found; the database is no longer read-only.");
          _readOnly = false;
      }

      private static void BeginLicenseCheckTimer()
      {
          // Unsubscribe first so a repeated Start() does not run the check several times per tick.
          LicenseTimer.Elapsed -= LicenseTimer_Elapsed;
          LicenseTimer.Elapsed += LicenseTimer_Elapsed;
          LicenseTimer.Start();
      }

      private static void LicenseTimer_Elapsed(object? sender, System.Timers.ElapsedEventArgs e)
      {
          // System.Timers.Timer suppresses exceptions thrown by Elapsed handlers, so a failing
          // check would otherwise vanish without trace; log it instead.
          try
          {
              AssertLicense();
          }
          catch (Exception err)
          {
              LogError($"License check failed: {err.Message}");
          }
      }

      /// <summary>
      /// Checks the stored license and updates <see cref="IsReadOnly"/>. A valid license ends
      /// read-only mode and may log an expiry reminder; any other outcome logs the reason and
      /// enters read-only mode.
      /// </summary>
      public static void AssertLicense()
      {
          var result = CheckLicenseValidity(out DateTime expiry);
          switch (result)
          {
              case LicenseValidation.Valid:
                  LogLicenseExpiry(expiry);
                  EndReadOnly();
                  break;
              case LicenseValidation.Missing:
                  LogImportant("Database is unlicensed!");
                  BeginReadOnly();
                  break;
              case LicenseValidation.Expired:
                  LogImportant("Database license has expired!");
                  BeginReadOnly();
                  break;
              case LicenseValidation.Corrupt:
                  LogImportant("Database license is corrupt - you will need to renew your license.");
                  BeginReadOnly();
                  break;
              case LicenseValidation.Tampered:
                  LogImportant("Database license has been tampered with - you will need to renew your license.");
                  BeginReadOnly();
                  break;
          }
      }

      #endregion

      #region Logging

      private static void LogMessage(LogType type, string message) => Logger.Send(type, "", message);

      private static void LogInfo(string message) => Logger.Send(LogType.Information, "", message);

      private static void LogImportant(string message) => Logger.Send(LogType.Important, "", message);

      private static void LogError(string message) => Logger.Send(LogType.Error, "", message);

      #endregion

      /// <summary>Creates one instance of every registered store type and calls its <c>Init</c>.</summary>
      public static void InitStores()
      {
          foreach (var storeType in stores)
          {
              // SetStoreTypes guarantees every registered type implements IStore.
              var store = (IStore)Activator.CreateInstance(storeType)!;
              store.Provider = Provider;
              store.Init();
          }
      }

      /// <summary>
      /// Returns a store for <paramref name="type"/>, bound to the current provider and the given user context.
      /// It uses the first registered store that subclasses <c>Store&lt;type&gt;</c>, falling back to the generic store.
      /// </summary>
      public static IStore FindStore(Type type, Guid userguid, string userid, Platform platform, string version)
      {
          var defType = typeof(Store<>).MakeGenericType(type);
          var subType = Stores.FirstOrDefault(myType => myType.IsSubclassOf(defType));

          var store = (IStore)Activator.CreateInstance(subType ?? defType)!;
          store.Provider = Provider;
          store.UserGuid = userguid;
          store.UserID = userid;
          store.Platform = platform;
          store.Version = version;
          return store;
      }

      /// <summary>Strongly typed variant of <see cref="FindStore(Type, Guid, string, Platform, string)"/>.</summary>
      public static IStore<TEntity> FindStore<TEntity>(Guid userguid, string userid, Platform platform, string version)
          where TEntity : Entity, new()
      {
          return (FindStore(typeof(TEntity), userguid, userid, platform, version) as IStore<TEntity>)!;
      }

      // Invoked by reflection from QueryMultiple (looked up by name); keep it private and static.
      private static CoreTable DoQueryMultipleQuery<TEntity>(
          IQueryDef query,
          Guid userguid, string userid, Platform platform, string version)
          where TEntity : Entity, new()
      {
          var store = FindStore<TEntity>(userguid, userid, platform, version);
          return store.Query(query.Filter as Filter<TEntity>, query.Columns as Columns<TEntity>, query.SortOrder as SortOrder<TEntity>);
      }

      /// <summary>
      /// Runs each named query in parallel and returns the resulting tables under the same names.
      /// Blocks until every query has finished; failures surface as an <see cref="AggregateException"/>.
      /// </summary>
      public static Dictionary<string, CoreTable> QueryMultiple(
          Dictionary<string, IQueryDef> queries,
          Guid userguid, string userid, Platform platform, string version)
      {
          var queryMethod = typeof(DbFactory).GetMethod(nameof(DoQueryMultipleQuery), BindingFlags.NonPublic | BindingFlags.Static)!;

          // Each task returns its own (key, table) pair. Writing into one shared Dictionary
          // from several tasks at once is not thread-safe.
          var tasks = queries.Select(item => Task.Run(() =>
          {
              // The target method is static, so no instance is passed to Invoke.
              var table = (queryMethod.MakeGenericMethod(item.Value.Type).Invoke(null, new object[]
              {
                  item.Value,
                  userguid, userid, platform, version
              }) as CoreTable)!;
              return new KeyValuePair<string, CoreTable>(item.Key, table);
          })).ToArray();

          Task.WaitAll(tasks);
          return tasks.ToDictionary(t => t.Result.Key, t => t.Result.Value);
      }

      #region Private Methods

      /// <summary>
      /// Clears <see cref="LoadedScripts"/>, then compiles every stored query/save/delete/load hook script.
      /// Scripts that compile are cached; compile failures are logged.
      /// </summary>
      public static void LoadScripts()
      {
          LogInfo("Loading Script Cache...");
          LoadedScripts.Clear();

          var scripts = Provider.Load(
              new Filter<Script>(x => x.ScriptType).IsEqualTo(ScriptType.BeforeQuery)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterQuery)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeSave)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterSave)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeDelete)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterDelete)
                  .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterLoad)
          );

          foreach (var script in scripts)
          {
              var doc = new ScriptDocument(script.Code);
              if (doc.Compile())
              {
                  LogInfo($"- {script.Section}.{script.ScriptType} Compiled Successfully");
                  LoadedScripts[$"{script.Section} {script.ScriptType}"] = doc;
              }
              else
              {
                  LogError($"- {script.Section}.{script.ScriptType} Compile Exception:\n{doc.Result}");
              }
          }

          LogInfo("Loading Script Cache Complete");
      }

      private static Type[] stores = { };

      /// <summary>
      /// Keeps only the concrete, non-generic classes from <paramref name="types"/> and registers
      /// them as stores.
      /// </summary>
      /// <exception cref="ArgumentException">A kept type does not implement <see cref="IStore"/>.</exception>
      private static void SetStoreTypes(Type[] types)
      {
          var candidates = types
              .Where(t => t.IsClass && !t.IsAbstract && !t.IsGenericType)
              .ToArray();

          var invalid = candidates.FirstOrDefault(t => !t.GetInterfaces().Contains(typeof(IStore)));
          if (invalid is not null)
              throw new ArgumentException($"{invalid.Name} is not a valid store", nameof(types));

          stores = candidates;
      }

      /// <summary>The persistent entity classes (non-generic subclasses of <see cref="Entity"/>) that make up the database model.</summary>
      private static Type[] ConsolidatedObjectModel()
      {
          return Entities
              .Where(x => x.IsClass && !x.IsGenericType && x.IsSubclassOf(typeof(Entity)))
              .ToArray();
      }

      private enum SchemaStatus
      {
          New,
          Changed,
          Validated
      }

      /// <summary>
      /// Flattens the model into "{TypeName}.{PropertyName}" → property type, using the property
      /// list that <c>CoreUtils.PropertyList</c> returns for each entity type.
      /// </summary>
      private static Dictionary<string, Type> GetSchema()
      {
          var model = new Dictionary<string, Type>();
          foreach (var type in ConsolidatedObjectModel())
          {
              Dictionary<string, Type> properties = CoreUtils.PropertyList(type, x => true, true);
              foreach (var (property, propertyType) in properties)
                  model[type.Name + "." + property] = propertyType;
          }
          return model;
      }

      /// <summary>
      /// Returns <see cref="SchemaStatus.New"/> when the provider has no stored schema. Otherwise the
      /// JSON serializations of the model schema and the stored schema are compared; any textual
      /// difference (including ordering) counts as <see cref="SchemaStatus.Changed"/>.
      /// </summary>
      private static SchemaStatus ValidateSchema()
      {
          var dbSchema = Provider.GetSchema();
          if (!dbSchema.Any())
              return SchemaStatus.New;

          var modelJson = Serialization.Serialize(GetSchema());
          var dbJson = Serialization.Serialize(dbSchema);
          return modelJson.Equals(dbJson) ? SchemaStatus.Validated : SchemaStatus.Changed;
      }

      private static void SaveSchema()
      {
          Provider.SaveSchema(GetSchema());
      }

      #endregion
  }