DbFactory.cs 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. using System.Reflection;
  2. using InABox.Clients;
  3. using InABox.Configuration;
  4. using InABox.Core;
  5. using InABox.Scripting;
  namespace InABox.Database
  {
      /// <summary>
      /// Static entry point of the database layer. Holds the active <see cref="IProvider"/>, creates or upgrades
      /// the schema in <see cref="Start"/>, periodically checks the database license, and creates
      /// <see cref="IStore"/> instances for entity types.
      /// </summary>
      public static class DbFactory
      {
          /// <summary>Compiled server-side scripts keyed by "{Section} {ScriptType}"; rebuilt by <see cref="LoadScripts"/>.</summary>
          public static Dictionary<string, ScriptDocument> LoadedScripts = new();

          private static IProvider? _provider;

          /// <summary>The active database provider.</summary>
          /// <exception cref="InvalidOperationException">Thrown by the getter when no provider has been assigned.</exception>
          public static IProvider Provider
          {
              get => _provider ?? throw new InvalidOperationException("Provider is not set");
              set => _provider = value;
          }

          /// <summary>True once <see cref="Provider"/> has been assigned.</summary>
          public static bool IsProviderSet => _provider is not null;

          public static string? ColorScheme { get; set; }

          public static byte[]? Logo { get; set; }

          /// <summary>Entity types from <c>CoreUtils.Entities</c> that implement <see cref="IPersistent"/>.</summary>
          public static IEnumerable<Type> Entities =>
              CoreUtils.Entities.Where(x => x.GetInterfaces().Contains(typeof(IPersistent)));

          /// <summary>
          /// Registered store types. The setter keeps only concrete, non-generic classes and throws if any of
          /// those does not implement <see cref="IStore"/>.
          /// </summary>
          public static Type[] Stores
          {
              get => stores;
              set => SetStoreTypes(value);
          }

          public static DateTime Expiry { get; set; }

          /// <summary>
          /// Starts the database layer: checks licensing, creates or upgrades the schema when it differs from the
          /// object model, starts the provider, migrates data, loads custom properties, asserts the license and
          /// starts the periodic license check, initialises the stores and compiles the cached scripts.
          /// </summary>
          /// <exception cref="Exception">
          /// Schema creation/upgrade failed (the original error is the <see cref="Exception.InnerException"/>),
          /// or the data migration failed.
          /// </exception>
          public static void Start()
          {
              CoreUtils.CheckLicensing();

              var model = ConsolidatedObjectModel();
              switch (ValidateSchema())
              {
                  case SchemaStatus.New:
                      try
                      {
                          Provider.CreateSchema(model);
                          SaveSchema();
                      }
                      catch (Exception err)
                      {
                          // Keep the original exception as InnerException so its stack trace is not lost.
                          throw new Exception($"Unable to Create Schema\n\n{err.Message}", err);
                      }
                      break;

                  case SchemaStatus.Changed:
                      try
                      {
                          Provider.UpgradeSchema(model);
                          SaveSchema();
                      }
                      catch (Exception err)
                      {
                          throw new Exception($"Unable to Update Schema\n\n{err.Message}", err);
                      }
                      break;
              }

              // Start the provider
              Provider.Types = model;
              // Unsubscribe first so that calling Start() more than once does not log every message twice.
              Provider.OnLog -= LogMessage;
              Provider.OnLog += LogMessage;
              Provider.Start();

              if (!DataUpdater.MigrateDatabase())
              {
                  throw new Exception("Database migration failed. Aborting startup");
              }

              // Custom properties cannot be loaded through the clients (we are already inside the database
              // layer), so they are queried from the provider directly.
              var props = Provider.Query<CustomProperty>().Rows.Select(x => x.ToObject<CustomProperty>()).ToArray();
              DatabaseSchema.Load(props);

              AssertLicense();
              BeginLicenseCheckTimer();
              InitStores();
              LoadScripts();
          }

          #region License

          private enum LicenseValidation
          {
              Valid,
              Missing,
              Expired,
              Corrupt,
              Tampered
          }

          /// <summary>
          /// Loads the first stored <see cref="License"/> and classifies it: missing, undecryptable (corrupt),
          /// expired, or tampered (one of its recorded UserTracking IDs no longer exists).
          /// </summary>
          private static LicenseValidation CheckLicenseValidity(out License? license, out LicenseData? licenseData)
          {
              license = Provider.Load<License>().FirstOrDefault();
              if (license is null)
              {
                  licenseData = null;
                  return LicenseValidation.Missing;
              }

              if (!LicenseUtils.TryDecryptLicense(license.Data, out licenseData, out _))
                  return LicenseValidation.Corrupt;

              if (licenseData.Expiry < DateTime.Now)
                  return LicenseValidation.Expired;

              // Materialise once into a set; the previous deferred query was re-enumerated for every item.
              var existingIds = Provider.Query(
                      new Filter<UserTracking>(x => x.ID).InList(licenseData.UserTrackingItems),
                      new Columns<UserTracking>(x => x.ID), log: false)
                  .Rows
                  .Select(row => row.Get<UserTracking, Guid>(x => x.ID))
                  .ToHashSet();

              return licenseData.UserTrackingItems.All(existingIds.Contains)
                  ? LicenseValidation.Valid
                  : LicenseValidation.Tampered;
          }

          private static int _expiredLicenseCounter = 0;

          // Must stay declared before LicenseTimer: static field initialisers run in textual order.
          private static readonly TimeSpan LicenseCheckInterval = TimeSpan.FromMinutes(10);

          // Written from the license timer's thread (AssertLicense) and read from other threads via IsReadOnly.
          private static volatile bool _readOnly;

          /// <summary>True while the database is locked against saving because its license is invalid.</summary>
          public static bool IsReadOnly => _readOnly;

          private static readonly System.Timers.Timer LicenseTimer = new(LicenseCheckInterval.TotalMilliseconds) { AutoReset = true };

          /// <summary>Logs <paramref name="message"/> followed by the standard renewal instructions.</summary>
          private static void LogRenew(string message)
          {
              LogImportant($"{message} Please renew your license before then, or your database will go into read-only mode; it will be locked for saving anything until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
          }

          /// <summary>
          /// Logs a renewal reminder whose frequency increases as <paramref name="expiry"/> approaches: on every
          /// call on the expiry day or within the last day, at most about hourly within three days, and at most
          /// about every two hours within a week (measured in <see cref="LicenseCheckInterval"/> ticks).
          /// </summary>
          private static void LogLicenseExpiry(DateTime expiry)
          {
              if (expiry.Date == DateTime.Today)
              {
                  // Format the DateTime itself: "HH:mm" is not a valid TimeSpan format and threw FormatException.
                  LogRenew($"Your database license is expiring today at {expiry:HH:mm}!");
                  return;
              }

              var diffInDays = (expiry - DateTime.Now).TotalDays;
              if (diffInDays < 1)
              {
                  LogRenew($"Your database license will expire in less than a day, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
              }
              else if (diffInDays < 3 && (_expiredLicenseCounter * LicenseCheckInterval).TotalHours >= 1)
              {
                  LogRenew($"Your database license will expire in less than three days, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
                  _expiredLicenseCounter = 0;
              }
              else if (diffInDays < 7 && (_expiredLicenseCounter * LicenseCheckInterval).TotalHours >= 2)
              {
                  LogRenew($"Your database license will expire in less than a week, on the {expiry:dd MMM yyyy} at {expiry:hh:mm tt}.");
                  _expiredLicenseCounter = 0;
              }

              ++_expiredLicenseCounter;
          }

          /// <summary>Logs that the database is read-only because of an invalid license.</summary>
          public static void LogReadOnly()
          {
              LogError("Database is read-only because your license is invalid!");
          }

          private static void BeginReadOnly()
          {
              LogImportant("Your database is now in read-only mode, since your license is invalid; you will be unable to save any records to the database until you renew your license. For help with renewing your license, please see the documentation at https://prsdigital.com.au/wiki/index.php/License_Renewal.");
              _readOnly = true;
          }

          private static void EndReadOnly()
          {
              LogImportant("Valid license found; the database is no longer read-only.");
              _readOnly = false;
          }

          /// <summary>Starts the periodic license check; safe to call more than once.</summary>
          private static void BeginLicenseCheckTimer()
          {
              // Remove any earlier subscription so repeated Start() calls do not run the check several times per tick.
              LicenseTimer.Elapsed -= LicenseTimer_Elapsed;
              LicenseTimer.Elapsed += LicenseTimer_Elapsed;
              LicenseTimer.Start();
          }

          private static void LicenseTimer_Elapsed(object? sender, System.Timers.ElapsedEventArgs e)
          {
              // System.Timers.Timer suppresses exceptions thrown by Elapsed handlers, so log failures here
              // instead of losing them silently.
              try
              {
                  AssertLicense();
              }
              catch (Exception err)
              {
                  LogError($"License check failed: {err.Message}");
              }
          }

          // Number of UserTracking IDs (sampled with replacement) recorded in a refreshed license.
          private const int LicenseTrackingSampleSize = 10;

          private static readonly Random LicenseIDGenerate = new();

          /// <summary>
          /// Re-samples UserTracking IDs created since the license's last renewal into
          /// <paramref name="licenseData"/>, then re-encrypts and saves <paramref name="license"/>.
          /// </summary>
          private static void UpdateValidLicense(License license, LicenseData licenseData)
          {
              var ids = Provider.Query(
                  new Filter<UserTracking>(x => x.Created).IsGreaterThanOrEqualTo(licenseData.LastRenewal),
                  new Columns<UserTracking>(x => x.ID), log: false);

              var newIDList = new List<Guid>();
              if (ids.Rows.Count > 0)
              {
                  for (int i = 0; i < LicenseTrackingSampleSize; i++)
                  {
                      newIDList.Add(ids.Rows[LicenseIDGenerate.Next(0, ids.Rows.Count)].Get<UserTracking, Guid>(x => x.ID));
                  }
              }
              licenseData.UserTrackingItems = newIDList.ToArray();

              if (LicenseUtils.TryEncryptLicense(licenseData, out var newData, out var error))
              {
                  license.Data = newData;
                  Provider.Save(license);
              }
              else
              {
                  LogError($"Error updating license: {error}");
              }
          }

          /// <summary>
          /// Checks the stored license. While read-only, a valid license ends read-only mode. Otherwise, any
          /// license that is not valid is currently replaced by a freshly generated one (see the TODO).
          /// </summary>
          private static void AssertLicense()
          {
              var result = CheckLicenseValidity(out var license, out _);

              if (IsReadOnly)
              {
                  if (result == LicenseValidation.Valid)
                      EndReadOnly();
                  return;
              }

              // TODO: Switch to real system. Until then an invalid or missing license is silently replaced with
              // a newly generated one, and EnforceLicense (expiry warnings / read-only mode) is deliberately not called.
              if (result == LicenseValidation.Valid)
                  return;

              var newLicense = LicenseUtils.GenerateNewLicense();
              if (LicenseUtils.TryEncryptLicense(newLicense, out var newData, out var error))
              {
                  license ??= new License();
                  license.Data = newData;
                  Provider.Save(license);
              }
              else
              {
                  LogError($"Error updating license: {error}");
              }
          }

          /// <summary>
          /// Intended license enforcement for when the TODO in AssertLicense is resolved: refreshes a valid
          /// license (logging expiry reminders) or puts the database into read-only mode. Not called yet.
          /// </summary>
          private static void EnforceLicense(LicenseValidation result, License? license, LicenseData? licenseData)
          {
              switch (result)
              {
                  case LicenseValidation.Valid:
                      // For a Valid result, CheckLicenseValidity has set both license and licenseData.
                      LogLicenseExpiry(licenseData!.Expiry);
                      UpdateValidLicense(license!, licenseData);
                      break;
                  case LicenseValidation.Missing:
                      LogImportant("Database is unlicensed!");
                      BeginReadOnly();
                      break;
                  case LicenseValidation.Expired:
                      LogImportant("Database license has expired!");
                      BeginReadOnly();
                      break;
                  case LicenseValidation.Corrupt:
                      LogImportant("Database license is corrupt - you will need to renew your license.");
                      BeginReadOnly();
                      break;
                  case LicenseValidation.Tampered:
                      LogImportant("Database license has been tampered with - you will need to renew your license.");
                      BeginReadOnly();
                      break;
              }
          }

          #endregion

          #region Logging

          private static void LogMessage(LogType type, string message)
          {
              Logger.Send(type, "", message);
          }

          private static void LogInfo(string message)
          {
              Logger.Send(LogType.Information, "", message);
          }

          private static void LogImportant(string message)
          {
              Logger.Send(LogType.Important, "", message);
          }

          private static void LogError(string message)
          {
              Logger.Send(LogType.Error, "", message);
          }

          #endregion

          /// <summary>Instantiates every registered store type, attaches the provider and calls <c>Init()</c> on it.</summary>
          public static void InitStores()
          {
              foreach (var storeType in stores)
              {
                  // SetStoreTypes guarantees every registered type implements IStore, so a direct cast is safe.
                  var store = (IStore)Activator.CreateInstance(storeType)!;
                  store.Provider = Provider;
                  store.Init();
              }
          }

          /// <summary>
          /// Creates a store for entity <paramref name="type"/>: the first registered subclass of
          /// <c>Store&lt;type&gt;</c> if any, otherwise <c>Store&lt;type&gt;</c> itself, initialised with the
          /// provider and caller details.
          /// </summary>
          public static IStore? FindStore(Type type, Guid userguid, string userid, Platform platform, string version)
          {
              var defType = typeof(Store<>).MakeGenericType(type);
              var subType = Stores.FirstOrDefault(t => t.IsSubclassOf(defType));

              var store = Activator.CreateInstance(subType ?? defType) as IStore;
              if (store != null)
              {
                  store.Provider = Provider;
                  store.UserGuid = userguid;
                  store.UserID = userid;
                  store.Platform = platform;
                  store.Version = version;
              }
              return store;
          }

          /// <summary>Typed wrapper over <see cref="FindStore(Type, Guid, string, Platform, string)"/>.</summary>
          public static IStore<TEntity>? FindStore<TEntity>(Guid userguid, string userid, Platform platform, string version)
              where TEntity : Entity, new()
          {
              return FindStore(typeof(TEntity), userguid, userid, platform, version) as IStore<TEntity>;
          }

          /// <summary>Runs one query definition against the store for <typeparamref name="TEntity"/>.</summary>
          /// <exception cref="InvalidOperationException">No store implementing IStore&lt;TEntity&gt; could be created.</exception>
          private static CoreTable DoQueryMultipleQuery<TEntity>(
              IQueryDef query,
              Guid userguid, string userid, Platform platform, string version)
              where TEntity : Entity, new()
          {
              var store = FindStore<TEntity>(userguid, userid, platform, version)
                  ?? throw new InvalidOperationException($"No store is available for {typeof(TEntity).Name}");
              return store.Query(query.Filter as Filter<TEntity>, query.Columns as Columns<TEntity>, query.SortOrder as SortOrder<TEntity>);
          }

          /// <summary>
          /// Runs every query in <paramref name="queries"/> in parallel and returns the results under the same keys.
          /// Failures surface from <see cref="Task.WaitAll(Task[])"/> as an <see cref="AggregateException"/>.
          /// </summary>
          public static Dictionary<string, CoreTable> QueryMultiple(
              Dictionary<string, IQueryDef> queries,
              Guid userguid, string userid, Platform platform, string version)
          {
              var queryMethod = typeof(DbFactory).GetMethod(nameof(DoQueryMultipleQuery), BindingFlags.NonPublic | BindingFlags.Static)!;

              // Each task returns its own key/table pair; results are merged on this thread afterwards because
              // Dictionary<TKey, TValue> is not safe for concurrent writes.
              var tasks = queries.Select(item => Task.Run(() =>
              {
                  var table = (queryMethod.MakeGenericMethod(item.Value.Type).Invoke(null, new object[]
                  {
                      item.Value,
                      userguid, userid, platform, version
                  }) as CoreTable)!;
                  return new KeyValuePair<string, CoreTable>(item.Key, table);
              })).ToArray();

              Task.WaitAll(tasks);
              return tasks.ToDictionary(t => t.Result.Key, t => t.Result.Value);
          }

          #region Supported Types

          private class ModuleConfiguration : Dictionary<string, bool>, ILocalConfigurationSettings
          {
          }

          private static Type[]? _dbtypes;

          /// <summary>Names of enabled entity types, with '.' replaced by '_'.</summary>
          public static IEnumerable<string> SupportedTypes()
          {
              _dbtypes ??= LoadSupportedTypes();
              return _dbtypes.Select(x => x.EntityName().Replace(".", "_"));
          }

          /// <summary>
          /// Reads the per-entity enable flags from the local module configuration next to the provider URL.
          /// Entities missing from the configuration are enabled and written back.
          /// </summary>
          private static Type[] LoadSupportedTypes()
          {
              var result = new List<Type>();
              // NOTE(review): the lower-cased URL is used as a filesystem path; on case-sensitive filesystems this
              // may not match the real directory — confirm this is intended.
              var path = Provider.URL.ToLower();
              var configStore = new LocalConfiguration<ModuleConfiguration>(Path.GetDirectoryName(path) ?? "", Path.GetFileName(path));
              var config = configStore.Load();

              var changed = false;
              foreach (var type in Entities)
              {
                  var key = type.EntityName();
                  if (config.TryGetValue(key, out var enabled))
                  {
                      if (enabled)
                          result.Add(type);
                      else
                          LogInfo($"Entity [{key}] is disabled");
                  }
                  else
                  {
                      // New entity: enable it by default and remember to persist the updated configuration.
                      config[key] = true;
                      result.Add(type);
                      changed = true;
                  }
              }

              if (changed)
                  configStore.Save(config);

              return result.ToArray();
          }

          /// <summary>True if entity type <typeparamref name="T"/> is enabled in the module configuration.</summary>
          public static bool IsSupported<T>() where T : Entity
          {
              _dbtypes ??= LoadSupportedTypes();
              return _dbtypes.Contains(typeof(T));
          }

          #endregion

          #region Private Methods

          /// <summary>
          /// Rebuilds <see cref="LoadedScripts"/> from all query/save/delete/load scripts in the database,
          /// keeping only the scripts that compile and logging the compile errors of the others.
          /// </summary>
          public static void LoadScripts()
          {
              LogInfo("Loading Script Cache...");
              LoadedScripts.Clear();

              var scripts = Provider.Load(
                  new Filter<Script>
                      (x => x.ScriptType).IsEqualTo(ScriptType.BeforeQuery)
                      .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterQuery)
                      .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeSave)
                      .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterSave)
                      .Or(x => x.ScriptType).IsEqualTo(ScriptType.BeforeDelete)
                      .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterDelete)
                      .Or(x => x.ScriptType).IsEqualTo(ScriptType.AfterLoad)
              );

              foreach (var script in scripts)
              {
                  var key = $"{script.Section} {script.ScriptType}";
                  var doc = new ScriptDocument(script.Code);
                  if (doc.Compile())
                  {
                      LogInfo($"- {script.Section}.{script.ScriptType} Compiled Successfully");
                      LoadedScripts[key] = doc;
                  }
                  else
                  {
                      LogError($"- {script.Section}.{script.ScriptType} Compile Exception:\n{doc.Result}");
                  }
              }

              LogInfo("Loading Script Cache Complete");
          }

          private static Type[] stores = Array.Empty<Type>();

          /// <summary>Keeps the concrete, non-generic classes from <paramref name="types"/> as the registered stores.</summary>
          /// <exception cref="ArgumentException">A remaining type does not implement <see cref="IStore"/>.</exception>
          private static void SetStoreTypes(Type[] types)
          {
              var concrete = types
                  .Where(t => t.IsClass && !t.IsAbstract && !t.IsGenericType)
                  .ToArray();

              var invalid = concrete.FirstOrDefault(t => !t.GetInterfaces().Contains(typeof(IStore)));
              if (invalid is not null)
                  throw new ArgumentException($"{invalid.Name} is not a valid store");

              stores = concrete;
          }

          /// <summary>All non-generic <see cref="Entity"/> subclasses from <see cref="Entities"/>; the model the schema is built from.</summary>
          private static Type[] ConsolidatedObjectModel()
          {
              return Entities
                  .Where(x => x.IsClass && !x.IsGenericType && x.IsSubclassOf(typeof(Entity)))
                  .ToArray();
          }

          private enum SchemaStatus
          {
              New,
              Changed,
              Validated
          }

          /// <summary>
          /// Flattens the object model into "{TypeName}.{PropertyPath}" → property type. Insertion order matters:
          /// the serialized result is compared textually in <see cref="ValidateSchema"/>.
          /// </summary>
          private static Dictionary<string, Type> GetSchema()
          {
              var model = new Dictionary<string, Type>();
              foreach (var type in ConsolidatedObjectModel())
              {
                  Dictionary<string, Type> properties = CoreUtils.PropertyList(type, x => true, true);
                  foreach (var (name, propertyType) in properties)
                      model[$"{type.Name}.{name}"] = propertyType;
              }
              return model;
          }

          /// <summary>
          /// Compares the schema stored by the provider with the current object model: New if nothing is stored,
          /// Validated if the serialized forms are equal, otherwise Changed.
          /// </summary>
          private static SchemaStatus ValidateSchema()
          {
              var dbSchema = Provider.GetSchema();
              if (!dbSchema.Any())
                  return SchemaStatus.New;

              var modelJson = Serialization.Serialize(GetSchema());
              var dbJson = Serialization.Serialize(dbSchema);
              return modelJson.Equals(dbJson) ? SchemaStatus.Validated : SchemaStatus.Changed;
          }

          private static void SaveSchema()
          {
              Provider.SaveSchema(GetSchema());
          }

          #endregion
      }
  }